Add SCF tuple flattening to StableHLO preprocessing (#14548)
Previously we removed tuples from control flow by first collapsing SCF to
CFG. That makes detensoring passes much more difficult; flattening tuples
directly at the SCF level is easier.
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
index 4f6c2a3..048f845 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
@@ -45,12 +45,10 @@
passManager.addNestedPass<func::FuncOp>(
stablehlo::createLegalizeControlFlow());
- // Currently we don't handle SCF ops well and have to convert them all to CFG.
- // In the future it would be nice if we could have all of flow be both scf
- // and cfg compatible.
- passManager.addNestedPass<func::FuncOp>(createTopLevelSCFToCFGPass());
- if (detuple)
+ passManager.addPass(createFlattenTuplesInSCF());
+ if (detuple) {
passManager.addPass(createFlattenTuplesInCFG());
+ }
passManager.addPass(createStableHLOToStableHLOPreprocessing());
passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel
index 165ce95..9e2f5ae 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/BUILD.bazel
@@ -64,6 +64,7 @@
"DotGeneralToDot.cpp",
"EinsumToDotGeneral.cpp",
"FlattenTuplesInCFG.cpp",
+ "FlattenTuplesInSCF.cpp",
"GatherToTorchIndexSelect.cpp",
"LowerComplex.cpp",
"Passes.cpp",
@@ -85,6 +86,7 @@
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:SparseTensorDialect",
"@llvm-project//mlir:Support",
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt
index 8fccaef..5ef3819 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/CMakeLists.txt
@@ -57,6 +57,7 @@
"DotGeneralToDot.cpp"
"EinsumToDotGeneral.cpp"
"FlattenTuplesInCFG.cpp"
+ "FlattenTuplesInSCF.cpp"
"GatherToTorchIndexSelect.cpp"
"LowerComplex.cpp"
"Passes.cpp"
@@ -73,6 +74,7 @@
MLIRIR
MLIRMathDialect
MLIRPass
+ MLIRSCFDialect
MLIRShapeDialect
MLIRSparseTensorDialect
MLIRSupport
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp
index 2412ac3..0d13e5a 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInCFG.cpp
@@ -144,7 +144,6 @@
auto newResults = newOp.getResults();
for (auto oldResult : oldOp.getResults()) {
- llvm::SmallVector<Value, 10> subValues;
auto newResult = recursiveRetuple(oldResult.getType(), &newResults, builder,
oldOp->getLoc());
mapping.map(oldResult, newResult);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInSCF.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInSCF.cpp
new file mode 100644
index 0000000..a0432e0
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/FlattenTuplesInSCF.cpp
@@ -0,0 +1,266 @@
+// Copyright 2019 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+// Flattens tuple-typed values out of SCF control flow ops in StableHLO input.
+
+#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h"
+#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir::iree_compiler::stablehlo {
+
+#define GEN_PASS_DEF_FLATTENTUPLESINSCF
+#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h.inc"
+
+namespace {
+// Flattens `types` into `newTypes`: each non-tuple type is appended as is, and
+// each TupleType is replaced, recursively and in order, by its element types.
+void untupleTypes(TypeRange types, llvm::SmallVectorImpl<Type> &newTypes) {
+  for (Type candidate : types) {
+    auto asTuple = dyn_cast<TupleType>(candidate);
+    if (!asTuple) {
+      newTypes.push_back(candidate);
+      continue;
+    }
+    untupleTypes(asTuple.getTypes(), newTypes);
+  }
+}
+
+// Appends the non-tuple leaves of `value` to `newValues` in depth-first order.
+//
+// A non-tuple value is appended unchanged and no op is created. For a tuple
+// value, one stablehlo.get_tuple_element op per element is created through `b`
+// and each element is flattened recursively. The builder is taken by reference
+// so it is not copied at every recursion step. `mapping` is not consulted; the
+// parameter is retained so existing call sites stay valid.
+void recursiveUntuple(Value value, ImplicitLocOpBuilder &b, IRMapping &mapping,
+                      llvm::SmallVectorImpl<Value> &newValues) {
+  auto tupleType = dyn_cast<TupleType>(value.getType());
+  if (!tupleType) {
+    // Already a leaf: forward the value as is.
+    newValues.push_back(value);
+    return;
+  }
+
+  for (auto [idx, subType] : llvm::enumerate(tupleType.getTypes())) {
+    auto elementOp = b.create<mlir::stablehlo::GetTupleElementOp>(
+        subType, value, b.getI32IntegerAttr(idx));
+    recursiveUntuple(elementOp.getResult(), b, mapping, newValues);
+  }
+}
+
+// Rebuilds a value of type `oldType` from the front of `*values`, advancing
+// `*values` past every value it consumes. A non-tuple type consumes exactly one
+// value and returns it; a tuple type rebuilds each element in order and wraps
+// them in a stablehlo.tuple op created through `b`.
+Value recursiveRetuple(Type oldType, ArrayRef<Value> *values,
+                       ImplicitLocOpBuilder &b) {
+  if (auto tupleType = dyn_cast<TupleType>(oldType)) {
+    llvm::SmallVector<Value> elements;
+    for (Type elementType : tupleType.getTypes()) {
+      elements.push_back(recursiveRetuple(elementType, values, b));
+    }
+    auto tupleOp = b.create<mlir::stablehlo::TupleOp>(tupleType, elements);
+    return tupleOp.getResult();
+  }
+
+  // Leaf: take the next flattened value and shrink the remaining range.
+  Value leaf = values->front();
+  *values = values->drop_front();
+  return leaf;
+}
+
+// Fills `destRegion` with a de-tupled copy of the first block of `srcRegion`.
+//
+// A new block whose arguments have the (already flattened) `types` is created
+// in `destRegion`. Every block argument of the source block is rebuilt from
+// those flattened arguments (re-creating tuples where the old argument was a
+// tuple) and recorded in `mapping`; then each op of the source block is cloned
+// into the new block through `mapping`. On return `b` inserts at the end of the
+// new block.
+//
+// NOTE(review): assumes `destRegion` is empty on entry and that `srcRegion`
+// has at least one block -- confirm for any new caller.
+void DetupleRegion(Region &srcRegion, Region &destRegion, ArrayRef<Type> types,
+                   IRMapping &mapping, ImplicitLocOpBuilder &b) {
+  Block *newBlock = b.createBlock(&destRegion);
+  newBlock->addArguments(types,
+                         SmallVector<Location>(types.size(), b.getLoc()));
+  b.setInsertionPointToStart(newBlock);
+
+  llvm::SmallVector<Value> flatArgs;
+  for (BlockArgument argument : newBlock->getArguments()) {
+    flatArgs.push_back(argument);
+  }
+
+  // `remainingArgs` is consumed from the front as old arguments are rebuilt.
+  llvm::ArrayRef<Value> remainingArgs(flatArgs);
+  for (BlockArgument oldArg : srcRegion.front().getArguments()) {
+    Value rebuilt = recursiveRetuple(oldArg.getType(), &remainingArgs, b);
+    mapping.map(oldArg, rebuilt);
+  }
+
+  // Cloning through `mapping` already records old-result -> new-result
+  // entries, so no manual per-result bookkeeping is needed here.
+  b.setInsertionPointToEnd(newBlock);
+  for (Operation &srcOp : srcRegion.front()) {
+    b.clone(srcOp, mapping);
+  }
+}
+
+// Rewrites an scf.yield that has tuple-typed operands into an scf.yield of the
+// flattened (non-tuple) leaf values.
+class DetupleYieldOp : public OpRewritePattern<scf::YieldOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::YieldOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+
+    // Flatten every operand while noting whether any of them was a tuple.
+    bool sawTuple = false;
+    llvm::SmallVector<Value> flatOperands;
+    for (Value yielded : op.getOperands()) {
+      if (isa<TupleType>(yielded.getType())) {
+        sawTuple = true;
+      }
+      recursiveUntuple(yielded, b, mapping, flatOperands);
+    }
+
+    if (!sawTuple) {
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(op, flatOperands);
+    return success();
+  }
+};
+
+// Rewrites an scf.condition that forwards tuple-typed arguments into an
+// scf.condition with the same condition value and the flattened (non-tuple)
+// leaf values as arguments.
+class DetupleConditionOp : public OpRewritePattern<scf::ConditionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::ConditionOp op,
+                                PatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+
+    // Flatten every forwarded argument while noting whether any was a tuple.
+    bool sawTuple = false;
+    llvm::SmallVector<Value> flatArgs;
+    for (Value forwarded : op.getArgs()) {
+      if (isa<TupleType>(forwarded.getType())) {
+        sawTuple = true;
+      }
+      recursiveUntuple(forwarded, b, mapping, flatArgs);
+    }
+
+    if (!sawTuple) {
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+    }
+
+    rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(op, op.getCondition(),
+                                                        flatArgs);
+    return success();
+  }
+};
+
+// Rewrites an scf.if with tuple-typed results into an scf.if whose results are
+// the flattened leaf types. Both regions are copied into the new op, and the
+// original results are replaced by values re-tupled from the new op's results.
+class DetupleIfOp : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when at least one result is a tuple.
+    bool sawTuple = false;
+    for (Type resultType : op.getResultTypes()) {
+      sawTuple = sawTuple || isa<TupleType>(resultType);
+    }
+    if (!sawTuple) {
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+    }
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+
+    llvm::SmallVector<Type> flatTypes;
+    untupleTypes(op.getResultTypes(), flatTypes);
+    auto newOp = b.create<mlir::scf::IfOp>(flatTypes, op.getOperand());
+
+    // The new then/else blocks are created without block arguments, hence the
+    // empty type list.
+    DetupleRegion(op.getThenRegion(), newOp.getThenRegion(), {}, mapping, b);
+    DetupleRegion(op.getElseRegion(), newOp.getElseRegion(), {}, mapping, b);
+
+    // Rebuild values of the original result types right before the old op so
+    // its uses can be redirected to them.
+    b.setInsertionPoint(op);
+    llvm::SmallVector<Value> flatResults;
+    for (Value flatResult : newOp.getResults()) {
+      flatResults.push_back(flatResult);
+    }
+
+    llvm::ArrayRef<Value> remaining(flatResults);
+    llvm::SmallVector<Value, 10> retupledValues;
+    for (Value oldResult : op.getResults()) {
+      Value rebuilt = recursiveRetuple(oldResult.getType(), &remaining, b);
+      retupledValues.push_back(rebuilt);
+      mapping.map(oldResult, rebuilt);
+    }
+
+    rewriter.replaceOp(op, retupledValues);
+    return success();
+  }
+};
+
+// Rewrites an scf.while that carries tuple-typed operands or results into an
+// scf.while over the flattened leaf values.
+//
+// The "before" region gets the flattened operand types as block arguments and
+// the "after" region gets the flattened result types; the two lists differ
+// whenever the loop's operand and result types differ. The original results
+// are replaced by values re-tupled from the new op's results.
+class DetupleWhileOp : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::scf::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    // Match on results as well as operands: a loop can forward a tuple through
+    // scf.condition even when none of its init operands is a tuple.
+    bool hasTuples = false;
+    for (Type type : op->getOperandTypes()) {
+      hasTuples |= isa<TupleType>(type);
+    }
+    for (Type type : op.getResultTypes()) {
+      hasTuples |= isa<TupleType>(type);
+    }
+    if (!hasTuples)
+      return rewriter.notifyMatchFailure(op, "no tupled arguments");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    IRMapping mapping;
+
+    // Flatten the init operands; get_tuple_element ops are created before `op`.
+    llvm::SmallVector<Value> operands;
+    for (Value operand : op.getOperands()) {
+      recursiveUntuple(operand, b, mapping, operands);
+    }
+
+    // Operand and result types are flattened separately because each region
+    // takes a different one of the two lists as its block argument types.
+    llvm::SmallVector<Type> operandTypes;
+    untupleTypes(op->getOperandTypes(), operandTypes);
+    llvm::SmallVector<Type> resultTypes;
+    untupleTypes(op.getResultTypes(), resultTypes);
+
+    auto newOp = b.create<mlir::scf::WhileOp>(resultTypes, operands);
+
+    DetupleRegion(op.getBefore(), newOp.getBefore(), operandTypes, mapping, b);
+    DetupleRegion(op.getAfter(), newOp.getAfter(), resultTypes, mapping, b);
+
+    // Rebuild values of the original result types right before the old op so
+    // its uses can be redirected to them.
+    b.setInsertionPoint(op);
+    llvm::SmallVector<Value> newResultVals;
+    for (Value result : newOp.getResults()) {
+      newResultVals.push_back(result);
+    }
+
+    llvm::ArrayRef<Value> newResults(newResultVals);
+    llvm::SmallVector<Value, 10> retupledValues;
+    for (Value oldResult : op.getResults()) {
+      retupledValues.push_back(
+          recursiveRetuple(oldResult.getType(), &newResults, b));
+    }
+
+    rewriter.replaceOp(op, retupledValues);
+    return success();
+  }
+};
+
+// Module pass that flattens tuple-typed values out of scf.if / scf.while and
+// their scf.yield / scf.condition terminators by greedily applying the
+// Detuple* patterns together with the canonicalization patterns.
+struct FlattenTuplesInSCF final
+    : impl::FlattenTuplesInSCFBase<FlattenTuplesInSCF> {
+
+  // The rewrite creates scf ops and stablehlo tuple / get_tuple_element ops,
+  // so both dialects are registered as dependencies.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mlir::scf::SCFDialect>();
+    registry.insert<mlir::stablehlo::StablehloDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Run canonicalization patterns to cancel out remaining tuple ops. We need
+    // to run these manually here because StableHLO does not define
+    // folds/canonicalization patterns for its ops.
+    RewritePatternSet patterns(ctx);
+    populateCanonicalizationPatterns(ctx, &patterns);
+    patterns
+        .add<DetupleYieldOp, DetupleConditionOp, DetupleIfOp, DetupleWhileOp>(
+            ctx);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::stablehlo
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
index b52cb4b..1b4a020 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
@@ -34,10 +34,15 @@
}
def FlattenTuplesInCFG :
- Pass<"iree-stablehlo-preprocessing-flatten-tuples", "ModuleOp"> {
+ Pass<"iree-stablehlo-preprocessing-flatten-cfg-tuples", "ModuleOp"> {
let summary = "Flattens tuples in the CFG form of StableHLO";
}
+def FlattenTuplesInSCF :
+ Pass<"iree-stablehlo-preprocessing-flatten-scf-tuples", "ModuleOp"> {
+ let summary = "Flattens tuples in the SCF form of StableHLO";
+}
+
def GatherToTorchIndexSelect :
Pass<"iree-stablehlo-preprocessing-gather-to-torch-index-select", "func::FuncOp"> {
let summary = "Legalizes gathers to a torch index select";
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel
index 358e11b..b8fe70c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/BUILD.bazel
@@ -24,6 +24,7 @@
"dot_general_to_dot.mlir",
"einsum_to_dot_general.mlir",
"flatten_tuples_in_cfg.mlir",
+ "flatten_tuples_in_scf.mlir",
"gather_to_torch_index_select.mlir",
"stablehlo_to_stablehlo.mlir",
"unfuse_batch_norm.mlir",
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt
index 79e9082..da9a641 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/CMakeLists.txt
@@ -20,6 +20,7 @@
"dot_general_to_dot.mlir"
"einsum_to_dot_general.mlir"
"flatten_tuples_in_cfg.mlir"
+ "flatten_tuples_in_scf.mlir"
"gather_to_torch_index_select.mlir"
"stablehlo_to_stablehlo.mlir"
"unfuse_batch_norm.mlir"
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir
index db0dcef..dc4f204 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_cfg.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --iree-stablehlo-preprocessing-flatten-tuples %s | FileCheck %s
+// RUN: iree-opt --iree-stablehlo-preprocessing-flatten-cfg-tuples %s | FileCheck %s
// CHECK-LABEL: @flatten_func
module @flatten_func {
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_scf.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_scf.mlir
new file mode 100644
index 0000000..c57e602
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/flatten_tuples_in_scf.mlir
@@ -0,0 +1,66 @@
+// RUN: iree-opt --iree-stablehlo-preprocessing-flatten-scf-tuples %s | FileCheck %s
+
+func.func @testWhile(%arg0 : i32, %arg1 : tuple<tensor<4xf32>, tensor<4xf32>>) -> (i32, tuple<tensor<4xf32>, tensor<4xf32>>) {
+ %0:2 = scf.while (%arg2 = %arg0, %arg3 = %arg1) : (i32, tuple<tensor<4xf32>, tensor<4xf32>>) -> (i32, tuple<tensor<4xf32>, tensor<4xf32>>) {
+ %c10 = arith.constant 10 : i32
+ %1 = arith.cmpi slt, %arg2, %c10 : i32
+ scf.condition(%1) %arg2, %arg3 : i32, tuple<tensor<4xf32>, tensor<4xf32>>
+ } do {
+ ^bb0(%arg2: i32, %arg3: tuple<tensor<4xf32>, tensor<4xf32>>):
+ %c1 = arith.constant 1 : i32
+ %add = arith.addi %arg2, %c1 : i32
+ %2 = stablehlo.get_tuple_element %arg3[0] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+ %3 = stablehlo.get_tuple_element %arg3[1] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+ %4 = stablehlo.add %2, %3 : tensor<4xf32>
+ %5 = stablehlo.tuple %3, %4 : tuple<tensor<4xf32>, tensor<4xf32>>
+ scf.yield %add, %5 : i32, tuple<tensor<4xf32>, tensor<4xf32>>
+ }
+ return %0#0, %0#1 : i32, tuple<tensor<4xf32>, tensor<4xf32>>
+}
+
+// CHECK-LABEL: @testWhile
+// CHECK-SAME: %[[ARG0:.+]]: i32, %[[ARG1:.+]]: tuple<tensor<4xf32>, tensor<4xf32>>
+// CHECK: %[[C1:.+]] = arith.constant 1
+// CHECK: %[[C10:.+]] = arith.constant 10
+// CHECK: %[[L:.+]] = stablehlo.get_tuple_element %[[ARG1]][0]
+// CHECK: %[[R:.+]] = stablehlo.get_tuple_element %[[ARG1]][1]
+
+// CHECK: %[[WHILE:.+]]:3 = scf.while (%[[ARG2:.+]] = %[[ARG0]], %[[ARG3:.+]] = %[[L]], %[[ARG4:.+]] = %[[R]])
+// CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[ARG2]], %[[C10]]
+// CHECK: scf.condition(%[[CMP]]) %[[ARG2]], %[[ARG3]], %[[ARG4]]
+
+// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: tensor<4xf32>, %[[ARG4:.+]]: tensor<4xf32>):
+// CHECK: %[[ADD:.+]] = arith.addi %[[ARG2]], %[[C1]] : i32
+// CHECK: %[[SADD:.+]] = stablehlo.add %[[ARG3]], %[[ARG4]] : tensor<4xf32>
+// CHECK: scf.yield %[[ADD]], %[[ARG4]], %[[SADD]]
+
+// CHECK: %[[TUPLE:.+]] = stablehlo.tuple %[[WHILE]]#1, %[[WHILE]]#2 : tuple<tensor<4xf32>, tensor<4xf32>>
+// CHECK: return %[[WHILE]]#0, %[[TUPLE]]
+
+// -----
+
+func.func @testIf(%cond: i1, %a : tuple<tensor<4xf32>, tensor<4xf32>>) -> (tuple<tensor<4xf32>, tensor<4xf32>>) {
+ %r = scf.if %cond -> (tuple<tensor<4xf32>, tensor<4xf32>>) {
+ %2 = stablehlo.get_tuple_element %a[0] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+ %3 = stablehlo.get_tuple_element %a[1] : (tuple<tensor<4xf32>, tensor<4xf32>>) -> tensor<4xf32>
+ %5 = stablehlo.tuple %3, %2 : tuple<tensor<4xf32>, tensor<4xf32>>
+ scf.yield %5 : tuple<tensor<4xf32>, tensor<4xf32>>
+ } else {
+ scf.yield %a : tuple<tensor<4xf32>, tensor<4xf32>>
+ }
+ return %r : tuple<tensor<4xf32>, tensor<4xf32>>
+}
+
+// CHECK-LABEL: @testIf
+// CHECK-SAME: %[[ARG0:.+]]: i1, %[[ARG1:.+]]: tuple<tensor<4xf32>, tensor<4xf32>>
+
+// CHECK: %[[IF:.+]]:2 = scf.if %[[ARG0]] -> (tensor<4xf32>, tensor<4xf32>) {
+// CHECK: %[[L:.+]] = stablehlo.get_tuple_element %[[ARG1]][0]
+// CHECK: %[[R:.+]] = stablehlo.get_tuple_element %[[ARG1]][1]
+// CHECK: scf.yield %[[R]], %[[L]]
+
+// CHECK: %[[L:.+]] = stablehlo.get_tuple_element %[[ARG1]][0]
+// CHECK: %[[R:.+]] = stablehlo.get_tuple_element %[[ARG1]][1]
+// CHECK: scf.yield %[[L]], %[[R]] : tensor<4xf32>, tensor<4xf32>
+// CHECK: %[[TUPLE:.+]] = stablehlo.tuple %[[IF]]#0, %[[IF]]#1 : tuple<tensor<4xf32>, tensor<4xf32>>
+// CHECK: return %[[TUPLE]]