[Flow] Fix FoldSplatReshapeIntoSplat pattern (#18818)
Apparently `replaceOpWithNewOp` does not insert the new operation at the
same location as the replaced op, but rather at the current insertion
point. So, set the insertion point to the consumer reshape op. This
ensures that the `flow.tensor.reshape`'s operands will be defined at
that point and will be valid when used to construct the
`flow.tensor.splat`. Moving the insertion point is safe because the
`flow.tensor.splat` op is known to have only one use.
Fixes https://github.com/iree-org/iree/issues/18815
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 6930906..855df65 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -961,6 +961,8 @@
if (!reshapeOp)
return failure();
+ PatternRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(reshapeOp);
rewriter.replaceOpWithNewOp<TensorSplatOp>(
reshapeOp, reshapeOp.getResult().getType(), splatOp.getValue(),
reshapeOp.getResultDims());
diff --git a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
index 1559c3c..ada2b37 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/IR/test/tensor_folding.mlir
@@ -752,6 +752,22 @@
// -----
+// CHECK-LABEL: @foldSplatReshapeIntoSplatAfterDefs
+util.func public @foldSplatReshapeIntoSplatAfterDefs(%arg0 : f32) -> tensor<?x?xf32> {
+ // CHECK: %[[RES:.+]] = flow.tensor.splat %arg0 : tensor<?x?xf32>
+ // CHECK-NEXT: util.return %[[RES]] : tensor<?x?xf32>
+ %cst10 = arith.constant 10 : index
+ %cst4 = arith.constant 4 : index
+ %nofold0 = util.optimization_barrier %cst10 : index
+ %0 = flow.tensor.splat %arg0 : tensor<?x4xf32>{%nofold0}
+ %nofold1 = util.optimization_barrier %cst10 : index
+ %nofold2 = util.optimization_barrier %cst4 : index
+ %1 = flow.tensor.reshape %0 : tensor<?x4xf32>{%nofold0} -> tensor<?x?xf32>{%nofold1, %nofold2}
+ util.return %1 : tensor<?x?xf32>
+}
+
+// -----
+
util.func public @innermost_unit_dim(%4: !flow.dispatch.tensor<readonly:tensor<3x1x16x257x88xf16>>,
%arg0: index, %arg2 : index, %10 : index, %9 : index) -> tensor<?x?x?xf16> {
%c16 = arith.constant 16 : index