[Flow] Fix FuseDequantMatmul pass for multiple uses of dequant (#15045)
The fusion currently assumes dequant has only one use, which is not the
case if the result of dequant is used by two different matmuls. This
patch fixes this by cloning the dequant if it has multiple uses.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp
index 8be4f2f..f251be1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FuseDequantizationMatmul.cpp
@@ -43,8 +43,16 @@
regionOp = maybeRegionOp.value();
}
- FailureOr<DispatchRegionOp> maybeFusedRegionOp =
- movePrecedingOpsIntoDispatchRegion(rewriter, dequant, regionOp);
+ FailureOr<DispatchRegionOp> maybeFusedRegionOp;
+ if (dequant->hasOneUse()) {
+ maybeFusedRegionOp =
+ movePrecedingOpsIntoDispatchRegion(rewriter, dequant, regionOp);
+ } else {
+ // Clone the dequant operation if there are multiple uses of dequant.
+ maybeFusedRegionOp =
+ clonePrecedingOpIntoDispatchRegion(rewriter, dequant, regionOp);
+ }
+
if (failed(maybeFusedRegionOp))
return failure();
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir
index 441c6f6..689263d 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fuse_dequantization_matmul.mlir
@@ -297,3 +297,66 @@
// CHECK-NOT: flow.dispatch.region
// CHECK-NOT: flow.return
// CHECK: return %[[GEN1]]
+
+// -----
+
+module {
+ func.func @clone_grouped_quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x1x4096xf32>
+ %1 = tensor.empty() : tensor<4096x32x128xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
+ ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
+ %5 = arith.extui %in : i8 to i32
+ %6 = arith.uitofp %5 : i32 to f32
+ %7 = arith.subf %6, %in_1 : f32
+ %8 = arith.mulf %7, %in_0 : f32
+ linalg.yield %8 : f32
+ } -> tensor<4096x32x128xf32>
+ %barrier = util.optimization_barrier %3 : tensor<4096x32x128xf32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg1, %3 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%2 : tensor<1x1x4096xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %5 = arith.mulf %in, %in_0 : f32
+ %6 = arith.addf %5, %out : f32
+ linalg.yield %6 : f32
+ } -> tensor<1x1x4096xf32>
+ return %4 : tensor<1x1x4096xf32>
+ }
+}
+// CHECK: func.func @clone_grouped_quantized_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4096x32x128xi8>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x1x32x128xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[INIT1:.+]] = tensor.empty() : tensor<1x1x4096xf32>
+// CHECK: %[[INIT0:.+]] = tensor.empty() : tensor<4096x32x128xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]]
+// CHECK-SAME: outs(%[[INIT1]] :
+// CHECK: %[[GEN0:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG2]], %[[ARG3]] :
+// CHECK-SAME: outs(%[[INIT0]] :
+// CHECK: %[[DISP:.+]] = flow.dispatch.region -> (tensor<1x1x4096xf32>)
+// CHECK: %[[CLONE:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG2]], %[[ARG3]] :
+// CHECK-SAME: outs(%[[INIT0]] :
+// CHECK: %[[GEN1:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%[[ARG1]], %[[CLONE]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: flow.return %[[GEN1]] :
+// CHECK: return %[[DISP]]