Fix SSA use-def violation created by inserting flow.tensor.clone ops. (#6384)
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 125528a..946ebe3 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -172,7 +172,9 @@
tiedOperand.replaceUsesWithIf(clonedOperand, [&](OpOperand &use) {
Operation *user = use.getOwner();
return !excludedOps.count(user) &&
- user->getBlock() == clonedOperand.getDefiningOp()->getBlock();
+ user->getBlock() ==
+ clonedOperand.getDefiningOp()->getBlock() &&
+ clonedOperand.getDefiningOp()->isBeforeInBlock(user);
});
didClone = true;
}
diff --git a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
index 1b19641..bac3cac 100644
--- a/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
+++ b/iree/compiler/Dialect/Flow/IR/test/stream_folding.mlir
@@ -108,6 +108,32 @@
// -----
+// Testing inserted clones when the cloned value has other users in the same
+// block: the returned value is expected to be a flow.tensor.clone of the
+// dispatch result, while the reshape keeps reading the original dispatch result.
+// CHECK-LABEL: @dagImmutability
+// CHECK: %{{.+}} = flow.ex.stream.fragment
+// CHECK: %[[SRC:.+]] = flow.dispatch @_run_dispatch_1::@_run_dispatch_1[%c1, %c1, %c1]() : () -> tensor<i32>
+// CHECK: %[[RET0:.+]] = flow.tensor.clone %[[SRC]] : tensor<i32>
+// CHECK: %[[RET1:.+]] = flow.tensor.reshape %[[SRC]] : tensor<i32> -> tensor<1xi32>
+// CHECK: %[[RET2:.+]] = flow.tensor.slice
+// CHECK: flow.return %[[RET0]], %[[RET1]], %[[RET2]]
+func @dagImmutability(%arg0: tensor<1xi32>) -> (tensor<i32>, tensor<1xi32>, tensor<3xi32>) {
+ %0:3 = flow.ex.stream.fragment(%arg0) : (tensor<1xi32>) -> (tensor<i32>, tensor<1xi32>, tensor<3xi32>) =
+ (%arg1: tensor<1xi32>) -> (tensor<i32>, tensor<1xi32>, tensor<3xi32>) {
+ %c9 = constant 9 : index
+ %c1 = constant 1 : index
+ %c18 = constant 18 : index
+ %c0 = constant 0 : index
+ %c3 = constant 3 : index
+ // %1 is consumed by the second dispatch and the reshape below, and is also
+ // returned directly from the fragment.
+ %1 = flow.dispatch @_run_dispatch_1::@_run_dispatch_1[%c1, %c1, %c1]() : () -> tensor<i32>
+ %2 = flow.dispatch @_run_dispatch_2::@_run_dispatch_2[%c9, %c1, %c1](%1) : (tensor<i32>) -> tensor<9xi32>
+ %3 = flow.tensor.reshape %1 : tensor<i32> -> tensor<1xi32>
+ %4 = flow.tensor.slice %2[%c0 for %c3] : tensor<9xi32> -> tensor<3xi32>
+ flow.return %1, %3, %4 : tensor<i32>, tensor<1xi32>, tensor<3xi32>
+ }
+ return %0#0, %0#1, %0#2 : tensor<i32>, tensor<1xi32>, tensor<3xi32>
+}
+
+// -----
+
// Testing inserted clones: a clone here is required as we cannot update %_large_const in-place.
// CHECK-LABEL: func @insertCloneForUpdatedConstant