Handle fusion of any named ops with outs of root ops. (#7133)
The current code only handles fusion of linalg.fill/linalg.generic with the
outs of root operations, but the producer could be any op. Such ops would be
simultaneously added to the fusion group of the consumer root operation while
also being marked as roots themselves. Fix that, and also ensure an op that is
already within a dispatch region is not added to another "nested" dispatch
region.
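
For example (a minimal sketch mirroring the test added below; value names are
illustrative), a named op such as linalg.fill_rng_2d producing the outs of a
linalg.matmul root should be pulled into the matmul's dispatch region instead
of being treated as a separate root:

    // Sketch: named producer feeding the outs of a root matmul (see test below).
    %fill = linalg.fill_rng_2d ins(%min, %max, %seed : f64, f64, i32)
        outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
    %matmul = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
        outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>

With this change both ops end up in a single flow.dispatch.workgroups op, with
the fill fused into the matmul's outs.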
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index a5e6f48..30c7932 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -717,6 +717,9 @@
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp || !linalgOp.hasTensorSemantics()) return failure();
if (!hasRootOpAttribute(op)) return failure();
+ if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ return failure();
+ }
// TODO(ravishankarm): It is getting strange to track when to apply this
// pattern and when not to. Need to revisit this, with dynamic shape cases
@@ -798,6 +801,9 @@
PatternRewriter &rewriter) const override {
if (!hasRootOpAttribute(tilableOp)) return failure();
if (hasOnlyDimUses(tilableOp)) return failure();
+ if (tilableOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
+ return failure();
+ }
SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
SmallVector<Range> loopRanges = tilableOp.getLoopBounds(rewriter);
@@ -1079,7 +1085,7 @@
// order here.
for (Operation &op : llvm::reverse(block)) {
// Start with a root operation and fuse its producers.
- if (!isRootOp(&op)) continue;
+ if (hasFusionGroupsAttribute(&op) || !isRootOp(&op)) continue;
unsigned newGroup = numRootOps++;
setRootAttribute(context, &op, newGroup);
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index 0dc7904..2d35b8a 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -1073,3 +1073,26 @@
// CHECK: scf.for %[[X:.+]] =
// CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
// CHECK: flow.dispatch.tensor.store %[[POOL]], %[[OUTPUT]], offsets = [0, %[[Z]], %[[Y]], %[[X]]], sizes = [1, %{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+func @named_op_outs_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %cst1 = constant -1.0 : f64
+ %cstm1 = constant 1.0 : f64
+ %c12345 = constant 12345 : i32
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %fill = linalg.fill_rng_2d ins(%cst1, %cstm1, %c12345 : f64, f64, i32)
+ outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %matmul : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @named_op_outs_fusion
+// CHECK: flow.dispatch.workgroups
+// CHECK: %[[FILL:.+]] = linalg.fill_rng_2d
+// CHECK: linalg.matmul
+// CHECK-SAME: outs(%[[FILL]] : tensor<?x?xf32>)