Add SCF tuple flattening to StableHLO preprocessing (#14548)

Previously we removed tuples from control flow by first collapsing SCF to
CFG. This makes detensoring passes much more difficult; detensoring is
easier if it works at the SCF level, so tuples are now flattened on SCF ops.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
index 4f6c2a3..048f845 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
@@ -45,12 +45,10 @@
   passManager.addNestedPass<func::FuncOp>(
       stablehlo::createLegalizeControlFlow());
 
-  // Currently we don't handle SCF ops well and have to convert them all to CFG.
-  // In the future it would be nice if we could have all of flow be both scf
-  // and cfg compatible.
-  passManager.addNestedPass<func::FuncOp>(createTopLevelSCFToCFGPass());
-  if (detuple)
+  passManager.addPass(createFlattenTuplesInSCF());
+  if (detuple) {
     passManager.addPass(createFlattenTuplesInCFG());
+  }
 
   passManager.addPass(createStableHLOToStableHLOPreprocessing());
   passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel
index 165ce95..9e2f5ae 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel
@@ -64,6 +64,7 @@
         "DotGeneralToDot.cpp",
         "EinsumToDotGeneral.cpp",
         "FlattenTuplesInCFG.cpp",
+        "FlattenTuplesInSCF.cpp",
         "GatherToTorchIndexSelect.cpp",
         "LowerComplex.cpp",
         "Passes.cpp",
@@ -85,6 +86,7 @@
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:ShapeDialect",
         "@llvm-project//mlir:SparseTensorDialect",
         "@llvm-project//mlir:Support",
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt
index 8fccaef..5ef3819 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt
@@ -57,6 +57,7 @@
     "DotGeneralToDot.cpp"
     "EinsumToDotGeneral.cpp"
     "FlattenTuplesInCFG.cpp"
+    "FlattenTuplesInSCF.cpp"
     "GatherToTorchIndexSelect.cpp"
     "LowerComplex.cpp"
     "Passes.cpp"
@@ -73,6 +74,7 @@
     MLIRIR
     MLIRMathDialect
     MLIRPass
+    MLIRSCFDialect
     MLIRShapeDialect
     MLIRSparseTensorDialect
     MLIRSupport
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp
index 2412ac3..0d13e5a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp
@@ -144,7 +144,6 @@
 
   auto newResults = newOp.getResults();
   for (auto oldResult : oldOp.getResults()) {
-    llvm::SmallVector<Value, 10> subValues;
     auto newResult = recursiveRetuple(oldResult.getType(), &newResults, builder,
                                       oldOp->getLoc());
     mapping.map(oldResult, newResult);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInSCF.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInSCF.cpp
new file mode 100644
index 0000000..a0432e0
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInSCF.cpp
@@ -0,0 +1,266 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Implements flattening of tuples across SCF ops for StableHLO inputs.
+
+#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h"
+#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir::iree_compiler::stablehlo {
+
+#define GEN_PASS_DEF_FLATTENTUPLESINSCF
+#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h.inc"
+
+namespace {
+// Appends `types` to `newTypes`, recursively expanding every tuple type.
+void untupleTypes(TypeRange types, llvm::SmallVectorImpl<Type> &newTypes) {
+  for (Type type : types) {
+    auto tupleTy = dyn_cast<TupleType>(type);
+    if (tupleTy)
+      untupleTypes(tupleTy.getTypes(), newTypes);
+    else
+      newTypes.push_back(type);
+  }
+}
+
+// Appends the non-tuple leaves of `value` to `newValues` in order, emitting a
+// stablehlo.get_tuple_element per tuple element. `mapping` is only forwarded.
+void recursiveUntuple(Value value, ImplicitLocOpBuilder b, IRMapping &mapping,
+                      llvm::SmallVectorImpl<Value> &newValues) {
+  if (auto tupleType = dyn_cast<TupleType>(value.getType())) {
+    for (auto [idx, subType] : llvm::enumerate(tupleType.getTypes())) {
+      auto getElement = b.create<mlir::stablehlo::GetTupleElementOp>(
+          subType, value, b.getI32IntegerAttr(idx));
+      recursiveUntuple(getElement.getResult(), b, mapping, newValues);
+    }
+    return;
+  }
+  newValues.push_back(value);
+}
+
+
+// Rebuilds a value of `oldType` from the front of `*values`, consuming one
+// entry per non-tuple leaf and emitting a stablehlo.tuple per tuple type.
+Value recursiveRetuple(Type oldType, ArrayRef<Value> *values,
+                       ImplicitLocOpBuilder &b) {
+  auto tupleType = dyn_cast<TupleType>(oldType);
+  if (!tupleType) {
+    // front()/drop_front() assert on an exhausted list in assert builds.
+    Value leaf = values->front();
+    *values = values->drop_front();
+    return leaf;
+  }
+  llvm::SmallVector<Value> subValues;
+  for (Type subType : tupleType.getTypes())
+    subValues.push_back(recursiveRetuple(subType, values, b));
+  return b.create<mlir::stablehlo::TupleOp>(tupleType, subValues).getResult();
+}
+
+// Creates a block in `destRegion` with arguments of `types` and clones the
+// ops of `srcRegion`'s first block into it. Each old block argument is
+// rebuilt from the new flat arguments (via stablehlo.tuple where needed) and
+// recorded in `mapping`, so the cloned ops use the rebuilt values. The
+// terminator is cloned as-is; flattening it is left to the yield/condition
+// patterns.
+void DetupleRegion(Region &srcRegion, Region &destRegion, ArrayRef<Type> types,
+                   IRMapping &mapping, ImplicitLocOpBuilder &b) {
+  Block *block = b.createBlock(&destRegion);
+  block->addArguments(types, SmallVector<Location>(types.size(), b.getLoc()));
+  b.setInsertionPointToStart(block);
+
+  // Re-create each old argument from the flat ones, consuming them in order.
+  llvm::SmallVector<Value> flatArgs(block->args_begin(), block->args_end());
+  llvm::ArrayRef<Value> remaining(flatArgs);
+  for (BlockArgument oldArg : srcRegion.front().getArguments())
+    mapping.map(oldArg, recursiveRetuple(oldArg.getType(), &remaining, b));
+
+  // OpBuilder::clone already records every cloned result in `mapping`, so no
+  // explicit per-result bookkeeping is needed.
+  b.setInsertionPointToEnd(&destRegion.front());
+  for (Operation &srcOp : srcRegion.front())
+    b.clone(srcOp, mapping);
+}
+
+
+// Rewrites an scf.yield with tuple-typed operands into one that yields the
+// flattened tuple elements (extracted via stablehlo.get_tuple_element).
+class DetupleYieldOp : public OpRewritePattern<scf::YieldOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::YieldOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only top-level operand types are inspected for tuples.
+    if (llvm::none_of(op->getOperandTypes(),
+                      [](Type type) { return isa<TupleType>(type); }))
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+    llvm::SmallVector<Value> flatOperands;
+    for (Value operand : op.getOperands())
+      recursiveUntuple(operand, b, mapping, flatOperands);
+
+    rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, flatOperands);
+    return success();
+  }
+};
+
+// Rewrites an scf.condition that forwards tuple-typed values into one that
+// forwards the flattened tuple elements; the condition itself is unchanged.
+class DetupleConditionOp : public OpRewritePattern<scf::ConditionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::ConditionOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only the forwarded args are inspected, never the condition operand.
+    if (llvm::none_of(op.getArgs().getTypes(),
+                      [](Type type) { return isa<TupleType>(type); }))
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+    llvm::SmallVector<Value> flatArgs;
+    for (Value arg : op.getArgs())
+      recursiveUntuple(arg, b, mapping, flatArgs);
+
+    rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(op, op.getCondition(),
+                                                        flatArgs);
+    return success();
+  }
+};
+
+// Flattens tuple-typed results of an scf.if. A new scf.if with the flattened
+// result types is created, both branches are cloned into it, and the original
+// results are re-packed with stablehlo.tuple so existing users are unaffected.
+// The cloned scf.yield ops are flattened by DetupleYieldOp.
+class DetupleIfOp : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Nothing to rewrite unless at least one result is tuple-typed.
+    if (llvm::none_of(op.getResultTypes(),
+                      [](Type type) { return isa<TupleType>(type); }))
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+    llvm::SmallVector<Type> flatTypes;
+    untupleTypes(op.getResultTypes(), flatTypes);
+
+    // scf.if has a single operand: its condition.
+    auto newOp = b.create<mlir::scf::IfOp>(flatTypes, op.getOperand());
+
+    // The branches take no block arguments, hence the empty type lists.
+    DetupleRegion(op.getThenRegion(), newOp.getThenRegion(), {}, mapping, b);
+    DetupleRegion(op.getElseRegion(), newOp.getElseRegion(), {}, mapping, b);
+
+    // Re-pack the flat results right before the op being replaced.
+    b.setInsertionPoint(op);
+    llvm::SmallVector<Value> flatResults(newOp->result_begin(),
+                                         newOp->result_end());
+    llvm::ArrayRef<Value> remaining(flatResults);
+    llvm::SmallVector<Value, 10> retupledValues;
+    for (OpResult oldResult : op.getResults()) {
+      Value repacked = recursiveRetuple(oldResult.getType(), &remaining, b);
+      retupledValues.push_back(repacked);
+      mapping.map(oldResult, repacked);
+    }
+
+    rewriter.replaceOp(op, retupledValues);
+    return success();
+  }
+};
+
+// Flattens tuples carried by an scf.while. Tuple inits are unpacked before the
+// loop, both regions are rebuilt with flat block arguments, and the original
+// results are re-packed with stablehlo.tuple after the loop. The cloned
+// scf.condition/scf.yield are flattened by their own patterns.
+class DetupleWhileOp : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto isTuple = [](Type type) { return isa<TupleType>(type); };
+    // Results may carry tuples even when the inits do not, so check both.
+    if (llvm::none_of(op->getOperandTypes(), isTuple) &&
+        llvm::none_of(op->getResultTypes(), isTuple))
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+    llvm::SmallVector<Value> operands;
+    for (Value operand : op.getOperands())
+      recursiveUntuple(operand, b, mapping, operands);
+
+    // The "before" block takes the init types while the "after" block takes
+    // the result types; the two lists are allowed to differ.
+    llvm::SmallVector<Type> initTypes, resultTypes;
+    untupleTypes(op->getOperandTypes(), initTypes);
+    untupleTypes(op->getResultTypes(), resultTypes);
+
+    auto newOp = b.create<mlir::scf::WhileOp>(resultTypes, operands);
+    DetupleRegion(op.getBefore(), newOp.getBefore(), initTypes, mapping, b);
+    DetupleRegion(op.getAfter(), newOp.getAfter(), resultTypes, mapping, b);
+
+    // Re-pack the flat results so existing users of `op` keep their types.
+    b.setInsertionPoint(op);
+    llvm::SmallVector<Value> flatResults(newOp->result_begin(),
+                                         newOp->result_end());
+    llvm::ArrayRef<Value> remaining(flatResults);
+    llvm::SmallVector<Value, 10> retupledValues;
+    for (Type oldType : op->getResultTypes())
+      retupledValues.push_back(recursiveRetuple(oldType, &remaining, b));
+
+    rewriter.replaceOp(op, retupledValues);
+    return success();
+  }
+};
+
+// Pass that flattens tuples in SCF ops by greedily applying the detupling
+// patterns of this file plus those of populateCanonicalizationPatterns.
+struct FlattenTuplesInSCF final
+    : impl::FlattenTuplesInSCFBase<FlattenTuplesInSCF> {
+  // The patterns create scf and stablehlo ops, so both dialects are declared.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::scf::SCFDialect>();
+    registry.insert<mlir::stablehlo::StablehloDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Run canonicalization patterns to cancel out remaining tuple ops. We need
+    // to run these manually here because StableHLO does not define
+    // folds/canonicalization patterns for its ops.
+    RewritePatternSet patterns(ctx);
+    populateCanonicalizationPatterns(ctx, &patterns);
+    patterns
+        .add<DetupleYieldOp, DetupleConditionOp, DetupleIfOp, DetupleWhileOp>(
+            ctx);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
index b52cb4b..1b4a020 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
@@ -34,10 +34,15 @@
 }
 
 def FlattenTuplesInCFG :
-    Pass<"iree-stablehlo-preprocessing-flatten-tuples", "ModuleOp"> {
+    Pass<"iree-stablehlo-preprocessing-flatten-cfg-tuples", "ModuleOp"> {
   let summary = "Flattens tuples in the CFG form of StableHLO";
 }
 
+def FlattenTuplesInSCF :
+    Pass<"iree-stablehlo-preprocessing-flatten-scf-tuples", "ModuleOp"> {
+  let summary = "Flattens tuples in the SCF form of StableHLO";
+}
+
 def GatherToTorchIndexSelect :
     Pass<"iree-stablehlo-preprocessing-gather-to-torch-index-select", "func::FuncOp"> {
   let summary = "Legalizes gathers to a torch index select";
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel
index 358e11b..b8fe70c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel
@@ -24,6 +24,7 @@
             "dot_general_to_dot.mlir",
             "einsum_to_dot_general.mlir",
             "flatten_tuples_in_cfg.mlir",
+            "flatten_tuples_in_scf.mlir",
             "gather_to_torch_index_select.mlir",
             "stablehlo_to_stablehlo.mlir",
             "unfuse_batch_norm.mlir",
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt
index 79e9082..da9a641 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt
@@ -20,6 +20,7 @@
     "dot_general_to_dot.mlir"
     "einsum_to_dot_general.mlir"
     "flatten_tuples_in_cfg.mlir"
+    "flatten_tuples_in_scf.mlir"
     "gather_to_torch_index_select.mlir"
     "stablehlo_to_stablehlo.mlir"
     "unfuse_batch_norm.mlir"
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir
index db0dcef..dc4f204 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --iree-stablehlo-preprocessing-flatten-tuples %s | FileCheck %s
+// RUN: iree-opt --iree-stablehlo-preprocessing-flatten-cfg-tuples %s | FileCheck %s
 
 // CHECK-LABEL: @flatten_func
 module @flatten_func {
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_scf.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_scf.mlir
new file mode 100644
index 0000000..c57e602
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_scf.mlir
@@ -0,0 +1,66 @@
+// RUN: iree-opt --iree-stablehlo-preprocessing-flatten-scf-tuples %s | FileCheck %s
+
+func.func @testWhile(%arg0 : i32, %arg1 : tuple<tensor<4xf32>, tensor<4xf32>>) -> (i32, tuple<tensor<4xf32>, tensor<4xf32>>) {
+  %0:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, tuple<tensor<4xf32>, tensor<4xf32>>) -> (i32, tuple<tensor<4xf32>, tensor<4xf32>>) {
+    %c10 = arith.constant 10 : i32
+    %1 = arith.cmpi slt, %arg2, %c10 : i32
+    scf.condition(%1) %arg2, %arg3 : i32, tuple<tensor<4xf32>, tensor<4xf32>>
+  } do {
+  ^bb0(%arg2: i32, %arg3: tuple<tensor<4xf32>, tensor<4xf32>>):
+    %c1 = arith.constant 1 : i32
+    %add = arith.addi %arg2, %c1 : i32
+    %2 = stablehlo.get_tuple_element %arg3[0] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+    %3 = stablehlo.get_tuple_element %arg3[1] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+    %4 = stablehlo.add %2, %3 : tensor<4xf32>
+    %5 = stablehlo.tuple %3, %4 : tuple<tensor<4xf32>, tensor<4xf32>>
+    scf.yield %add, %5 : i32, tuple<tensor<4xf32>, tensor<4xf32>>
+  }
+  return %0#0, %0#1 : i32, tuple<tensor<4xf32>, tensor<4xf32>>
+}
+
+// CHECK-LABEL: @testWhile
+// CHECK-SAME: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: tuple<tensor<4xf32>, tensor<4xf32>>
+// CHECK: %[[C1:.+]] = arith.constant 1
+// CHECK: %[[C10:.+]] = arith.constant 10
+// CHECK: %[[L:.+]] = stablehlo.get_tuple_element %[[ARG1]][0]
+// CHECK: %[[R:.+]] = stablehlo.get_tuple_element %[[ARG1]][1]
+
+// CHECK: %[[WHILE:.+]]:3 = scf.while (%[[ARG2:.+]] = %[[ARG0]], %[[ARG3:.+]] = %[[L]], %[[ARG4:.+]] = %[[R]])
+// CHECK:   %[[CMP:.+]] = arith.cmpi slt, %[[ARG2]], %[[C10]]
+// CHECK:   scf.condition(%[[CMP]]) %[[ARG2]], %[[ARG3]], %[[ARG4]]
+
+// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: tensor<4xf32>, %[[ARG4:.+]]: tensor<4xf32>):
+// CHECK:   %[[ADD:.+]] = arith.addi %[[ARG2]], %[[C1]] : i32
+// CHECK:   %[[SADD:.+]] = stablehlo.add %[[ARG3]], %[[ARG4]] : tensor<4xf32>
+// CHECK:   scf.yield %[[ADD]], %[[ARG4]], %[[SADD]]
+
+// CHECK: %[[TUPLE:.+]] = stablehlo.tuple %[[WHILE]]#1, %[[WHILE]]#2 : tuple<tensor<4xf32>, tensor<4xf32>>
+// CHECK: return %[[WHILE]]#0, %[[TUPLE]]
+
+// -----
+
+func.func @testIf(%cond: i1, %a : tuple<tensor<4xf32>, tensor<4xf32>>) -> (tuple<tensor<4xf32>, tensor<4xf32>>) {
+  %r = scf.if %cond -> (tuple<tensor<4xf32>, tensor<4xf32>>) {
+    %2 = stablehlo.get_tuple_element %a[0] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+    %3 = stablehlo.get_tuple_element %a[1] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+    %5 = stablehlo.tuple %3, %2 : tuple<tensor<4xf32>, tensor<4xf32>>
+    scf.yield %5 : tuple<tensor<4xf32>, tensor<4xf32>>
+  } else {
+    scf.yield %a : tuple<tensor<4xf32>, tensor<4xf32>>
+  }
+  return %r : tuple<tensor<4xf32>, tensor<4xf32>>
+}
+
+// CHECK-LABEL: @testIf
+// CHECK-SAME: %[[ARG0:.+]]: i1, %[[ARG1:.+]]: tuple<tensor<4xf32>, tensor<4xf32>>
+
+// CHECK:  %[[IF:.+]]:2 = scf.if %[[ARG0]] -> (tensor<4xf32>, tensor<4xf32>) {
+// CHECK:    %[[L:.+]] = stablehlo.get_tuple_element %[[ARG1]][0]
+// CHECK:    %[[R:.+]] = stablehlo.get_tuple_element %[[ARG1]][1]
+// CHECK:    scf.yield %[[R]], %[[L]]
+
+// CHECK:    %[[L:.+]] = stablehlo.get_tuple_element %[[ARG1]][0]
+// CHECK:    %[[R:.+]] = stablehlo.get_tuple_element %[[ARG1]][1]
+// CHECK:    scf.yield %[[L]], %[[R]] : tensor<4xf32>, tensor<4xf32>
+// CHECK:  %[[TUPLE:.+]] = stablehlo.tuple %[[IF]]#0, %[[IF]]#1 : tuple<tensor<4xf32>, tensor<4xf32>>
+// CHECK:  return %[[TUPLE]]