[GlobalOpt] Add quantized matmul reassociation support for f16 types (#15964)
This adds support for reassociating f16-typed quantized matmuls, fixing
a bug reported in https://github.com/openxla/iree/issues/15661.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
index 7ec2388..aa3473b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
@@ -169,7 +169,7 @@
struct QuantizedMatmulRewriter {
QuantizedMatmulRewriter(RewriterBase &rewriter, linalg::GenericOp dequant,
linalg::GenericOp matmul, int quantizedBitWidth);
- std::optional<SmallVector<OpOperand *>> getDequantMatmulInputs_f32();
+ std::optional<SmallVector<OpOperand *>> getDequantMatmulInputs();
std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>>
getGroupReductionMapsAndIterators(OpOperand *inputOperand);
Value getGroupReductionInit(Value input);
@@ -219,37 +219,28 @@
// TODO(#) Have stricter matching on inputs. There may be cases where
// the current matching fails
std::optional<SmallVector<OpOperand *>>
-QuantizedMatmulRewriter::getDequantMatmulInputs_f32() {
+QuantizedMatmulRewriter::getDequantMatmulInputs() {
assert(!failed(isContractionWithTwoReductions(matmul)) &&
"expected `matmul` to be a contraction with two reduction dimensions");
assert(!failed(isGroupedDequantizationOp(dequant)) &&
"expected `dequant` to be a grouped dequantization");
OpOperand *scales, *zps, *quantMat, *unquantMat, *dequantMat;
- for (int operandIdx = 0; operandIdx < dequant.getNumDpsInputs();
- operandIdx++) {
- OpOperand *operand = dequant.getDpsInputOperand(operandIdx);
- Value input = operand->get();
- RankedTensorType inputType =
- llvm::dyn_cast<RankedTensorType>(input.getType());
- if (!inputType) {
- continue;
- }
- if (inputType.getElementTypeBitWidth() != 32) {
- quantMat = operand;
- continue;
- }
- for (Operation &bodyOp : dequant.getBlock()->getOperations()) {
- if (isa<arith::MulFOp>(bodyOp)) {
- if (bodyOp.getOperand(1) ==
- dequant.getBlock()->getArgument(operandIdx)) {
- scales = operand;
- break;
- }
- } else if (isa<arith::SubFOp>(bodyOp)) {
- if (bodyOp.getOperand(1) ==
- dequant.getBlock()->getArgument(operandIdx)) {
- zps = operand;
- break;
+ auto maps = dequant.getIndexingMapsArray();
+ for (auto [idx, map] : enumerate(ArrayRef<AffineMap>(maps).drop_back())) {
+ if (map.isIdentity()) {
+ quantMat = dequant.getDpsInputOperand(idx);
+ } else if (map.isProjectedPermutation(true)) {
+ for (Operation &bodyOp : dequant.getBlock()->getOperations()) {
+ if (isa<arith::MulFOp>(bodyOp)) {
+ if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) {
+ scales = dequant.getDpsInputOperand(idx);
+ break;
+ }
+ } else if (isa<arith::SubFOp>(bodyOp)) {
+ if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) {
+ zps = dequant.getDpsInputOperand(idx);
+ break;
+ }
}
}
}
@@ -278,7 +269,7 @@
accType = rewriter.getI32Type();
mulType = rewriter.getI32Type();
quantType = rewriter.getIntegerType(quantizedBitWidth);
- std::optional<SmallVector<OpOperand *>> inputs = getDequantMatmulInputs_f32();
+ std::optional<SmallVector<OpOperand *>> inputs = getDequantMatmulInputs();
if (inputs) {
ins = *inputs;
}
@@ -319,11 +310,13 @@
return rewriter.notifyMatchFailure(
matmul, "inner shape of input expected to be reduced in matmul");
}
- if (!unquantizedInputType.getElementType().isa<FloatType>()) {
- return rewriter.notifyMatchFailure(matmul, "expected float type");
- }
Value scales = ins[2]->get();
Value zps = ins[3]->get();
+ if (!unquantizedInputType.getElementType().isa<FloatType>() ||
+ !getElementTypeOrSelf(scales).isa<FloatType>() ||
+ !getElementTypeOrSelf(zps).isa<FloatType>()) {
+ return rewriter.notifyMatchFailure(matmul, "expected float type");
+ }
OpOperand *matmulDequantizedOperand = ins[4];
auto matmulDequantizedInputExprs =
matmul.getMatchingIndexingMap(matmulDequantizedOperand).getResults();
@@ -629,8 +622,7 @@
outputExprs.front().getContext());
maps.push_back(outputMap);
- Type i32Type = rewriter.getI32Type();
- Type f32Type = rewriter.getF32Type();
+ Type floatType = getElementTypeOrSelf(scales);
Value output = matmulOutputOperand->get();
auto reassociatedDequantizationOp = rewriter.create<linalg::GenericOp>(
loc, output.getType(),
@@ -638,12 +630,7 @@
output, maps, iterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dq;
- if (accType == i32Type) {
- dq = b.create<arith::SIToFPOp>(loc, f32Type, args[0]);
- } else {
- Value ext = b.create<arith::ExtSIOp>(loc, i32Type, args[0]);
- dq = b.create<arith::SIToFPOp>(loc, f32Type, ext);
- }
+ dq = b.create<arith::SIToFPOp>(loc, floatType, args[0]);
Value scaledRes0 = b.create<arith::MulFOp>(loc, dq, args[1]);
Value scaledRes1 = b.create<arith::MulFOp>(loc, scaledRes0, args[3]);
Value scaledZp0 = b.create<arith::MulFOp>(loc, args[4], args[3]);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
index 52f6e09..2b59689 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
@@ -123,3 +123,130 @@
// CHECK: %[[READDF:.+]] = arith.addf %[[RESUBF]], %[[REOUT0]] : f32
// CHECK: linalg.yield %[[READDF]] : f32
// CHECK: return %[[GENREASSOCIATE]]
+
+// -----
+
+module {
+ func.func @grouped_quantized_matmul_reassociate_f16(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<32x128xf16>, %arg2: tensor<11008x32xf16>, %arg3: tensor<11008x32xf16>) -> tensor<11008xf16> {
+ %cst = arith.constant 0.000000e+00 : f16
+ %0 = tensor.empty() : tensor<11008xf16>
+ %1 = tensor.empty() : tensor<11008x32x128xf16>
+ %2 = linalg.fill ins(%cst : f16) outs(%0 : tensor<11008xf16>) -> tensor<11008xf16>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg2, %arg3 : tensor<11008x32x128xi4>, tensor<11008x32xf16>, tensor<11008x32xf16>) outs(%1 : tensor<11008x32x128xf16>) {
+ ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
+ %5 = arith.extui %in : i4 to i32
+ %6 = arith.uitofp %5 : i32 to f16
+ %7 = arith.subf %6, %in_1 : f16
+ %8 = arith.mulf %7, %in_0 : f16
+ linalg.yield %8 : f16
+ } -> tensor<11008x32x128xf16>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0)>],
+ iterator_types = ["parallel", "reduction", "reduction"]}
+ ins(%arg1, %3 : tensor<32x128xf16>, tensor<11008x32x128xf16>) outs(%2 : tensor<11008xf16>) {
+ ^bb0(%in: f16, %in_0: f16, %out: f16):
+ %5 = arith.mulf %in, %in_0 : f16
+ %6 = arith.addf %5, %out : f16
+ linalg.yield %6 : f16
+ } -> tensor<11008xf16>
+ return %4 : tensor<11008xf16>
+ }
+}
+
+// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[MAP2:[a-zA-Z0-9]+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[MAP3:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP4:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP5:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[MAP6:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: func.func @grouped_quantized_matmul_reassociate_f16(
+// CHECK-SAME: %[[QUANT:[a-zA-Z0-9_]+]]: tensor<11008x32x128xi4>
+// CHECK-SAME: %[[UNQUANT:[a-zA-Z0-9_]+]]: tensor<32x128xf16>
+// CHECK-SAME: %[[SCALES:[a-zA-Z0-9_]+]]: tensor<11008x32xf16>
+// CHECK-SAME: %[[ZPS:[a-zA-Z0-9_]+]]: tensor<11008x32xf16>
+// CHECK: %[[C0I32:.+]] = arith.constant 0 : i32
+// CHECK: %[[RANGE:.+]] = arith.constant 3.276800e+04 : f16
+// CHECK: %[[C0f16:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK: %[[INITOUT:.+]] = tensor.empty() : tensor<11008xf16>
+// CHECK: %[[FILLOUT:.+]] = linalg.fill ins(%[[C0f16]]
+// CHECK-SAME: outs(%[[INITOUT]] :
+// CHECK: %[[INITMAX:.+]] = tensor.empty() : tensor<32xf16>
+// CHECK: %[[FILLMAX:.+]] = linalg.fill ins(%[[C0f16]]
+// CHECK-SAME: outs(%[[INITMAX]] :
+// CHECK: %[[GENMAX:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[UNQUANT]] :
+// CHECK-SAME: outs(%[[FILLMAX]] :
+// CHECK: ^bb0(%[[MAXIN0:.+]]: f16, %[[MAXOUT0:.+]]: f16):
+// CHECK: %[[MAXABSF:.+]] = math.absf %[[MAXIN0]] : f16
+// CHECK: %[[MAXMAXIMUMF:.+]] = arith.maximumf %[[MAXABSF]], %[[MAXOUT0]] : f16
+// CHECK: linalg.yield %[[MAXMAXIMUMF]] : f16
+// CHECK: %[[INITSCALES:.+]] = tensor.empty() : tensor<32xf16>
+// CHECK: %[[GENSCALES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[GENMAX]] :
+// CHECK-SAME: outs(%[[INITSCALES]] :
+// CHECK: ^bb0(%[[SCALESIN0:.+]]: f16, %[[SCALESOUT0:.+]]: f16):
+// CHECK: %[[SCALESDIVF:.+]] = arith.divf %[[SCALESIN0]], %[[RANGE]] : f16
+// CHECK: linalg.yield %[[SCALESDIVF]] : f16
+// CHECK: %[[INITSUM:.+]] = tensor.empty() : tensor<32xf16>
+// CHECK: %[[FILLSUM:.+]] = linalg.fill ins(%[[C0f16]]
+// CHECK-SAME: outs(%[[INITSUM]] :
+// CHECK: %[[GENSUM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[UNQUANT]] :
+// CHECK-SAME: outs(%[[FILLSUM]] :
+// CHECK: ^bb0(%[[SUMIN0:.+]]: f16, %[[SUMOUT0:.+]]: f16):
+// CHECK: %[[SUMADDF:.+]] = arith.addf %[[SUMIN0]], %[[SUMOUT0]] : f16
+// CHECK: linalg.yield %[[SUMADDF]] : f16
+// CHECK: %[[INITQUANT:.+]] = tensor.empty() : tensor<32x128xi16>
+// CHECK: %[[GENQUANT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[UNQUANT]], %[[GENSCALES]] :
+// CHECK-SAME: outs(%[[INITQUANT]] :
+// CHECK: ^bb0(%[[QUANTIN0:.+]]: f16, %[[QUANTIN1:.+]]: f16, %[[QUANTOUT0:.+]]: i16):
+// CHECK: %[[QUANTDIVF:.+]] = arith.divf %[[QUANTIN0]], %[[QUANTIN1]] : f16
+// CHECK: %[[QUANTFPTOSI:.+]] = arith.fptosi %[[QUANTDIVF]] : f16 to i16
+// CHECK: linalg.yield %[[QUANTFPTOSI]] : i16
+// CHECK: %[[INITMATMUL:.+]] = tensor.empty() : tensor<11008x32xi32>
+// CHECK: %[[FILLMATMUL:.+]] = linalg.fill ins(%[[C0I32]]
+// CHECK-SAME: outs(%[[INITMATMUL]] :
+// CHECK: %[[GENMATMUL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[GENQUANT]], %[[QUANT]] :
+// CHECK-SAME: outs(%[[FILLMATMUL]] :
+// CHECK: ^bb0(%[[MATMULIN0:.+]]: i16, %[[MATMULIN1:.+]]: i4, %[[MATMULOUT0:.+]]: i32):
+// CHECK-DAG: %[[MATMULEXTSI:.+]] = arith.extsi %[[MATMULIN0]] : i16 to i32
+// CHECK-DAG: %[[MATMULEXTUI:.+]] = arith.extui %[[MATMULIN1]] : i4 to i32
+// CHECK: %[[MATMULMULI:.+]] = arith.muli %[[MATMULEXTSI]], %[[MATMULEXTUI]] : i32
+// CHECK: %[[MATMULADDI:.+]] = arith.addi %[[MATMULMULI]], %[[MATMULOUT0]] : i32
+// CHECK: linalg.yield %[[MATMULADDI]] : i32
+// CHECK: %[[GENREASSOCIATE:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP6]], #[[MAP6]], #[[MAP0]], #[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[GENMATMUL]], %[[GENSCALES]], %[[GENSUM]], %[[SCALES]], %[[ZPS]] :
+// CHECK-SAME: outs(%[[FILLOUT]] :
+// CHECK: ^bb0(%[[REIN0:.+]]: i32, %[[REIN1:.+]]: f16, %[[REIN2:.+]]: f16, %[[REIN3:.+]]: f16, %[[REIN4:.+]]: f16, %[[REOUT0:.+]]: f16):
+// CHECK-DAG: %[[RESITOFP:.+]] = arith.sitofp %[[REIN0]] : i32 to f16
+// CHECK-DAG: %[[REMULF0:.+]] = arith.mulf %[[RESITOFP]], %[[REIN1]] : f16
+// CHECK-DAG: %[[REMULF1:.+]] = arith.mulf %[[REMULF0]], %[[REIN3]] : f16
+// CHECK-DAG: %[[REMULF2:.+]] = arith.mulf %[[REIN4]], %[[REIN3]] : f16
+// CHECK-DAG: %[[REMULF3:.+]] = arith.mulf %[[REMULF2]], %[[REIN2]] : f16
+// CHECK: %[[RESUBF:.+]] = arith.subf %[[REMULF1]], %[[REMULF3]] : f16
+// CHECK: %[[READDF:.+]] = arith.addf %[[RESUBF]], %[[REOUT0]] : f16
+// CHECK: linalg.yield %[[READDF]] : f16
+// CHECK: return %[[GENREASSOCIATE]]