[DispatchCreation] Disable batch mmt4d fusion as it's not supported by backends (#18611)
As described in #16025, we can't support mmt4d + elementwise fusion yet, so
it needs to be disabled. Batch mmt4d fusion needs to be disabled for the
same reason.
Fixes https://github.com/iree-org/iree/issues/18589
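For context, the change extends the existing producer check in FormDispatchRegions.cpp so the same early exit now rejects both mmt4d variants; `llvm::isa` accepts multiple types and matches any of them. A minimal sketch of the idiom (the helper name and surrounding structure here are illustrative, not the exact IREE code):

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "llvm/Support/Casting.h"

// Sketch only: isa<A, B>(op) returns true if op is any of the listed op
// types, so one check covers both linalg.mmt4d and linalg.batch_mmt4d.
static bool isFusableProducer(mlir::Operation *producer) {
  // TODO(#16025): re-enable once backends can set multiple lowering_configs.
  if (llvm::isa<mlir::linalg::Mmt4DOp, mlir::linalg::BatchMmt4DOp>(producer))
    return false;
  return true;  // other fusability checks elided in this sketch
}
```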
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index 222f854..3aba7fe 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -543,7 +543,7 @@
// TODO(#16025): Enable mmt4d fusion. It is disabled because the backends
// can not set multi lowering_config properly. See the issue for more details.
- if (isa<linalg::Mmt4DOp>(producer)) {
+ if (isa<linalg::Mmt4DOp, linalg::BatchMmt4DOp>(producer)) {
return false;
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index bfceeda..62bca48 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -751,3 +751,47 @@
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC2]], %[[GENERIC3]]
// CHECK: flow.return %[[GENERIC4]]
// CHECK: util.return %[[RESULT]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+util.func public @no_batch_mmt4d_fusion(%arg0: tensor<1x1x64x1x1xf32>,
+ %arg1: tensor<1x32x64x4x1xf32>, %arg2: tensor<1x1x32x1x4xf32>)
+ -> tensor<1x1x32x1x4xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x1x32x1x4xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x32x1x4xf32>) -> tensor<1x1x32x1x4xf32>
+ %2 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<1x1x64x1x1xf32>, tensor<1x32x64x4x1xf32>)
+ outs(%1 : tensor<1x1x32x1x4xf32>) -> tensor<1x1x32x1x4xf32>
+ %3 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg2, %2 : tensor<1x1x32x1x4xf32>, tensor<1x1x32x1x4xf32>)
+ outs(%0 : tensor<1x1x32x1x4xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %4 = arith.addf %in, %in_0 : f32
+ linalg.yield %4 : f32
+ } -> tensor<1x1x32x1x4xf32>
+ util.return %3 : tensor<1x1x32x1x4xf32>
+}
+
+// CHECK-LABEL: util.func public @no_batch_mmt4d_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x1x64x1x1xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x32x64x4x1xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<1x1x32x1x4xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<1x1x32x1x4xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32)
+// CHECK-SAME: outs(%[[INIT0]] : tensor<1x1x32x1x4xf32>)
+// CHECK: %[[DISP0:.+]] = flow.dispatch.region -> (tensor<1x1x32x1x4xf32>)
+// CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x1x64x1x1xf32>, tensor<1x32x64x4x1xf32>)
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x1x32x1x4xf32>)
+// CHECK: flow.return %[[MMT4D]] : tensor<1x1x32x1x4xf32>
+// CHECK: %[[DISP1:.+]] = flow.dispatch.region -> (tensor<1x1x32x1x4xf32>)
+// CHECK: %[[GEN:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG2]], %[[DISP0]] : tensor<1x1x32x1x4xf32>, tensor<1x1x32x1x4xf32>)
+// CHECK-SAME: outs(%[[INIT0]] : tensor<1x1x32x1x4xf32>)
+// CHECK: flow.return %[[GEN]] : tensor<1x1x32x1x4xf32>
+// CHECK: util.return %[[DISP1]] : tensor<1x1x32x1x4xf32>