[Flow] Loosen restrictions for dequantization fusion (#15663)
The conditions for cloning dequantization ops into a dispatch
region are too conservative, and miss fusions of some
dequantization-like ops. This loosens the restrictions for
cloning dequantization-like ops into dispatch regions.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index c89cd6f..1c66c7f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -810,7 +810,7 @@
// materializing large tensors between dispatches.
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
IREE::LinalgExt::SetEncodingOp>(op) ||
- isa<linalg::FillOp>(op) || isGroupedDequantizationOp(&op)) {
+ isa<linalg::FillOp>(op) || isDequantizationLikeOp(&op)) {
continue;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 27bb5c5..ea4f55a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -360,7 +360,7 @@
}
// Do not fuse by expand if consumer is dequant.
- if (isGroupedDequantizationOp(consumer)) {
+ if (isDequantizationLikeOp(consumer)) {
return false;
}
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3a91af5..42fbca6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -504,120 +504,65 @@
return newRegionOp;
}
-/// Returns true if the operation is an generic op that represents dequant.
-/// This function checks that the genericOp:
-/// 1. Has a body like:
-/// arith.extui
-/// arith.uitofp
-/// arith.subf
-/// arith.mulf
-/// 2. scale, offset and result is f16 or f32, while weight is i4 or i8.
-/// 3. Has 3 parallel dims
-/// 4. Has 2 (weights, scales) or 3 (weights, scales, zero points)
-/// inputs and 1 output
-/// 5. Only weight has same shape as result.
-/// 6. Weight and result have identity indexing map.
-bool Flow::isGroupedDequantizationOp(Operation *op) {
+bool Flow::isDequantizationLikeOp(Operation *op) {
auto genericOp = dyn_cast<linalg::GenericOp>(op);
- if (!genericOp)
+ if (!genericOp) {
return false;
- if (genericOp.getNumDpsInits() != 1)
+ }
+ if (genericOp.getNumDpsInits() != 1) {
return false;
- if (genericOp.getNumDpsInputs() != 2 && genericOp.getNumDpsInputs() != 3)
- return false;
+ }
- // Check that the rank is at least 3 and all loops are parallel
+ // Check that all loops are parallel
unsigned numLoops = genericOp.getNumLoops();
unsigned numParallelLoops = genericOp.getNumParallelLoops();
- if (numLoops < 3)
+ if (numLoops != numParallelLoops) {
return false;
- if (numLoops != numParallelLoops)
+ }
+
+ // Check that only one input has an identity map, and the rest are projected
+ // permutations and not full permutations
+ OpOperand *identityInput = nullptr;
+ for (OpOperand *input : genericOp.getDpsInputOperands()) {
+ auto inputMap = genericOp.getMatchingIndexingMap(input);
+ if (inputMap.isIdentity()) {
+ if (identityInput) {
+ return false;
+ }
+ identityInput = input;
+ } else if (!inputMap.isProjectedPermutation(true) ||
+ inputMap.isPermutation()) {
+ return false;
+ }
+ }
+
+ if (!identityInput) {
return false;
+ }
- auto inputs = genericOp.getInputs();
- auto weight = inputs[0];
- auto scales = inputs[1];
- auto init = genericOp.getDpsInits()[0];
- Type weightElType = getElementTypeOrSelf(weight.getType());
- Type scaleElType = getElementTypeOrSelf(scales.getType());
- Type initElType = getElementTypeOrSelf(init.getType());
-
- // Check that init and weight have parallel indexing maps.
auto indexingMaps = genericOp.getIndexingMapsArray();
- const int kWeightIndex = 0;
- const int kInitIndex = indexingMaps.size() - 1;
- if (!indexingMaps[kWeightIndex].isIdentity())
- return false;
- if (!indexingMaps[kInitIndex].isIdentity())
- return false;
-
- // Dequant weight and init need to be same type and shape.
- auto weightShape = weight.getType().dyn_cast<ShapedType>().getShape();
- auto initShape = init.getType().dyn_cast<ShapedType>().getShape();
- if (weightShape != initShape)
- return false;
-
- // Scale and init needs to be of same element type.
- if (scaleElType != initElType)
- return false;
-
- // Check weight is i4 or i8.
- if (!weightElType.isInteger(4) && !weightElType.isInteger(8)) {
+ if (!indexingMaps.back().isIdentity()) {
return false;
}
- // Check scales is f16 or f32.
- Type f32Type = Float32Type::get(op->getContext());
- Type f16Type = Float16Type::get(op->getContext());
- if (scaleElType != f32Type && scaleElType != f16Type) {
+ // Check that the identity input element bitwidth is smaller than the output
+ // element bitwidth.
+ Type inputElementType = getElementTypeOrSelf(identityInput->get().getType());
+ Type outputElementType = getElementTypeOrSelf(genericOp->getResultTypes()[0]);
+ if (!inputElementType.isIntOrFloat() || !outputElementType.isIntOrFloat()) {
+ return false;
+ }
+ if (inputElementType.getIntOrFloatBitWidth() >=
+ outputElementType.getIntOrFloatBitWidth()) {
return false;
}
- // Work back from linalg.yield and check body of genericOp.
- // The genericOp should yield the result of an arith.mulf,
- // preceded by an arith.subf, arith.uitofp, and arith.extui
- auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
- Value producerOutput;
- Operation *producer;
-
- // Producer of linalg.yield op is arith.mulf
- {
- producerOutput = yieldOp->getOperand(0);
- producer = producerOutput.getDefiningOp();
- if (!producer || producer->getNumOperands() == 0)
+ for (auto &bodyOp : genericOp.getBody()->getOperations()) {
+ if (!isa<arith::ExtUIOp, arith::ExtSIOp, arith::MulFOp, arith::MulIOp,
+ arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
+ arith::SIToFPOp, arith::UIToFPOp, linalg::YieldOp>(bodyOp)) {
return false;
- if (!matchPattern(producer, m_Op<arith::MulFOp>()))
- return false;
- }
-
- // Producer of arith.mulf op is arith.subf
- {
- producerOutput = producer->getOperand(0);
- producer = producerOutput.getDefiningOp();
- if (!producer || producer->getNumOperands() == 0)
- return false;
- if (!matchPattern(producer, m_Op<arith::SubFOp>()))
- return false;
- }
-
- // Producer of arith.subf op is arith.uitofp
- {
- producerOutput = producer->getOperand(0);
- producer = producerOutput.getDefiningOp();
- if (!producer || producer->getNumOperands() == 0)
- return false;
- if (!matchPattern(producer, m_Op<arith::UIToFPOp>()))
- return false;
- }
-
- // Producer of arith.uitofp op is arith.extui
- {
- producerOutput = producer->getOperand(0);
- producer = producerOutput.getDefiningOp();
- if (!producer)
- return false;
- if (!matchPattern(producer, m_Op<arith::ExtUIOp>()))
- return false;
+ }
}
return true;
@@ -638,7 +583,7 @@
tensor::ExtractSliceOp, complex::CreateOp>(op)) {
return true;
}
- if (isGroupedDequantizationOp(op)) {
+ if (isDequantizationLikeOp(op)) {
return true;
}
if (isa<arith::ConstantOp>(op) || isa<complex::ConstantOp>(op)) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index 30b4d9f..ca57d9e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -96,8 +96,15 @@
/// into a dispatch region.
bool isClonableIntoDispatchOp(Operation *op);
-/// Returns true if the operation is an generic op that represents dequant.
-bool isGroupedDequantizationOp(Operation *op);
+/// Returns true if the operation has dequantization-like properties.
+/// This function checks that the genericOp:
+/// 1. Has only one output, and the output has an identity indexing map
+/// 2. Has all parallel loops.
+/// 3. Has exactly one input with an identity indexing map.
+/// 4. All other inputs are projected permutations and not permutations.
+/// 5. The input with an identity indexing map has a smaller element
+/// bitwidth than the output.
+bool isDequantizationLikeOp(Operation *op);
/// Collect all ops that should be cloned into the given dispatch region op.
SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
index dbd34de..ec6baec 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
@@ -186,3 +186,58 @@
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: flow.return %[[GEN1]] :
// CHECK: return %[[DISP]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+module {
+ func.func @clone_dequantization_like(%arg0: tensor<32x1x16x1x8xi16>, %arg1: tensor<32x344x16x32x8xi4>) -> tensor<32x1x344x1x32xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %0 = tensor.empty() : tensor<32x1x16x1x8xi32>
+ %1 = linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<32x1x16x1x8xi16>) outs(%0 : tensor<32x1x16x1x8xi32>) {
+ ^bb0(%in: i16, %out: i32):
+ %7 = arith.extsi %in : i16 to i32
+ linalg.yield %7 : i32
+ } -> tensor<32x1x16x1x8xi32>
+ %2 = tensor.empty() : tensor<32x344x16x32x8xi32>
+ %3 = linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg1 : tensor<32x344x16x32x8xi4>) outs(%2 : tensor<32x344x16x32x8xi32>) {
+ ^bb0(%in: i4, %out: i32):
+ %7 = arith.extui %in : i4 to i32
+ linalg.yield %7 : i32
+ } -> tensor<32x344x16x32x8xi32>
+ %4 = tensor.empty() : tensor<32x1x344x1x32xi32>
+ %5 = linalg.fill ins(%c0_i32 : i32) outs(%4 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+ %6 = flow.dispatch.region -> (tensor<32x1x344x1x32xi32>) {
+ %7 = linalg.batch_mmt4d ins(%1, %3 : tensor<32x1x16x1x8xi32>, tensor<32x344x16x32x8xi32>) outs(%5 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+ flow.return %7 : tensor<32x1x344x1x32xi32>
+ }
+ return %6 : tensor<32x1x344x1x32xi32>
+ }
+}
+// CHECK: func.func @clone_dequantization_like
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32x1x16x1x8xi16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x344x16x32x8xi4>
+// CHECK: %[[DISP:.+]] = flow.dispatch.region -> (tensor<32x1x344x1x32xi32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<32x1x16x1x8xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<32x1x344x1x32xi32>
+// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<32x344x16x32x8xi32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]]
+// CHECK-SAME: outs(%[[INIT1]] :
+// CHECK: %[[GEN0:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[INIT0]] :
+// CHECK: %[[GEN1:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG1]] :
+// CHECK-SAME: outs(%[[INIT2]] :
+// CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[GEN0]], %[[GEN1]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: flow.return %[[MMT4D]] :
+// CHECK: return %[[DISP]]
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
index 9a0b19c..7be213f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
@@ -554,3 +554,55 @@
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: flow.return %[[GEN1]] :
// CHECK: return %[[DISP]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+module {
+ func.func @no_dequantization_like_fusion(%arg0: tensor<32x1x16x1x8xi16>, %arg1: tensor<32x344x16x32x8xi4>) -> tensor<32x1x344x1x32xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %0 = tensor.empty() : tensor<32x1x16x1x8xi32>
+ %1 = linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<32x1x16x1x8xi16>) outs(%0 : tensor<32x1x16x1x8xi32>) {
+ ^bb0(%in: i16, %out: i32):
+ %7 = arith.extsi %in : i16 to i32
+ linalg.yield %7 : i32
+ } -> tensor<32x1x16x1x8xi32>
+ %2 = tensor.empty() : tensor<32x344x16x32x8xi32>
+ %3 = linalg.generic {indexing_maps = [#map, #map],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg1 : tensor<32x344x16x32x8xi4>) outs(%2 : tensor<32x344x16x32x8xi32>) {
+ ^bb0(%in: i4, %out: i32):
+ %7 = arith.extui %in : i4 to i32
+ linalg.yield %7 : i32
+ } -> tensor<32x344x16x32x8xi32>
+ %4 = tensor.empty() : tensor<32x1x344x1x32xi32>
+ %5 = linalg.fill ins(%c0_i32 : i32) outs(%4 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+ %7 = linalg.batch_mmt4d ins(%1, %3 : tensor<32x1x16x1x8xi32>, tensor<32x344x16x32x8xi32>) outs(%5 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+ return %7 : tensor<32x1x344x1x32xi32>
+ }
+}
+// CHECK: func.func @no_dequantization_like_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32x1x16x1x8xi16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x344x16x32x8xi4>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<32x1x16x1x8xi32>
+// CHECK: %[[GEN0:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[INIT0]] :
+// CHECK-DAG: %[[INIT2:.+]] = tensor.empty() : tensor<32x344x16x32x8xi32>
+// CHECK: %[[GEN1:.+]] = linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG1]] :
+// CHECK-SAME: outs(%[[INIT2]] :
+// CHECK: %[[INIT1:.+]] = tensor.empty() : tensor<32x1x344x1x32xi32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]]
+// CHECK-SAME: outs(%[[INIT1]] :
+// CHECK: %[[DISP:.+]] = flow.dispatch.region -> (tensor<32x1x344x1x32xi32>)
+// CHECK: %[[MMT4D:.+]] = linalg.batch_mmt4d
+// CHECK-SAME: ins(%[[GEN0]], %[[GEN1]] :
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: flow.return %[[MMT4D]] :
+// CHECK: return %[[DISP]]