[GlobalOpt] Add quantized matmul reassociation support for f16 types (#15964)

This adds support for reassociating f16-typed quantized matmuls, fixing
a bug reported in https://github.com/openxla/iree/issues/15661.
diff --git a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
index 7ec2388..aa3473b 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/FuseDequantizationMatmul.cpp
@@ -169,7 +169,7 @@
 struct QuantizedMatmulRewriter {
   QuantizedMatmulRewriter(RewriterBase &rewriter, linalg::GenericOp dequant,
                           linalg::GenericOp matmul, int quantizedBitWidth);
-  std::optional<SmallVector<OpOperand *>> getDequantMatmulInputs_f32();
+  std::optional<SmallVector<OpOperand *>> getDequantMatmulInputs();
   std::pair<SmallVector<AffineMap>, SmallVector<utils::IteratorType>>
   getGroupReductionMapsAndIterators(OpOperand *inputOperand);
   Value getGroupReductionInit(Value input);
@@ -219,37 +219,28 @@
 // TODO(#) Have stricter matching on inputs. There may be cases where
 // the current matching fails
 std::optional<SmallVector<OpOperand *>>
-QuantizedMatmulRewriter::getDequantMatmulInputs_f32() {
+QuantizedMatmulRewriter::getDequantMatmulInputs() {
   assert(!failed(isContractionWithTwoReductions(matmul)) &&
          "expected `matmul` to be a contraction with two reduction dimensions");
   assert(!failed(isGroupedDequantizationOp(dequant)) &&
          "expected `dequant` to be a grouped dequantization");
   OpOperand *scales, *zps, *quantMat, *unquantMat, *dequantMat;
-  for (int operandIdx = 0; operandIdx < dequant.getNumDpsInputs();
-       operandIdx++) {
-    OpOperand *operand = dequant.getDpsInputOperand(operandIdx);
-    Value input = operand->get();
-    RankedTensorType inputType =
-        llvm::dyn_cast<RankedTensorType>(input.getType());
-    if (!inputType) {
-      continue;
-    }
-    if (inputType.getElementTypeBitWidth() != 32) {
-      quantMat = operand;
-      continue;
-    }
-    for (Operation &bodyOp : dequant.getBlock()->getOperations()) {
-      if (isa<arith::MulFOp>(bodyOp)) {
-        if (bodyOp.getOperand(1) ==
-            dequant.getBlock()->getArgument(operandIdx)) {
-          scales = operand;
-          break;
-        }
-      } else if (isa<arith::SubFOp>(bodyOp)) {
-        if (bodyOp.getOperand(1) ==
-            dequant.getBlock()->getArgument(operandIdx)) {
-          zps = operand;
-          break;
+  auto maps = dequant.getIndexingMapsArray();
+  for (auto [idx, map] : enumerate(ArrayRef<AffineMap>(maps).drop_back())) {
+    if (map.isIdentity()) {
+      quantMat = dequant.getDpsInputOperand(idx);
+    } else if (map.isProjectedPermutation(true)) {
+      for (Operation &bodyOp : dequant.getBlock()->getOperations()) {
+        if (isa<arith::MulFOp>(bodyOp)) {
+          if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) {
+            scales = dequant.getDpsInputOperand(idx);
+            break;
+          }
+        } else if (isa<arith::SubFOp>(bodyOp)) {
+          if (bodyOp.getOperand(1) == dequant.getBlock()->getArgument(idx)) {
+            zps = dequant.getDpsInputOperand(idx);
+            break;
+          }
         }
       }
     }
@@ -278,7 +269,7 @@
   accType = rewriter.getI32Type();
   mulType = rewriter.getI32Type();
   quantType = rewriter.getIntegerType(quantizedBitWidth);
-  std::optional<SmallVector<OpOperand *>> inputs = getDequantMatmulInputs_f32();
+  std::optional<SmallVector<OpOperand *>> inputs = getDequantMatmulInputs();
   if (inputs) {
     ins = *inputs;
   }
@@ -319,11 +310,13 @@
     return rewriter.notifyMatchFailure(
         matmul, "inner shape of input expected to be reduced in matmul");
   }
-  if (!unquantizedInputType.getElementType().isa<FloatType>()) {
-    return rewriter.notifyMatchFailure(matmul, "expected float type");
-  }
   Value scales = ins[2]->get();
   Value zps = ins[3]->get();
+  if (!unquantizedInputType.getElementType().isa<FloatType>() ||
+      !getElementTypeOrSelf(scales).isa<FloatType>() ||
+      !getElementTypeOrSelf(zps).isa<FloatType>()) {
+    return rewriter.notifyMatchFailure(matmul, "expected float type");
+  }
   OpOperand *matmulDequantizedOperand = ins[4];
   auto matmulDequantizedInputExprs =
       matmul.getMatchingIndexingMap(matmulDequantizedOperand).getResults();
@@ -629,8 +622,7 @@
                                   outputExprs.front().getContext());
   maps.push_back(outputMap);
 
-  Type i32Type = rewriter.getI32Type();
-  Type f32Type = rewriter.getF32Type();
+  Type floatType = getElementTypeOrSelf(scales);
   Value output = matmulOutputOperand->get();
   auto reassociatedDequantizationOp = rewriter.create<linalg::GenericOp>(
       loc, output.getType(),
@@ -638,12 +630,7 @@
       output, maps, iterators,
       [&](OpBuilder &b, Location loc, ValueRange args) {
         Value dq;
-        if (accType == i32Type) {
-          dq = b.create<arith::SIToFPOp>(loc, f32Type, args[0]);
-        } else {
-          Value ext = b.create<arith::ExtSIOp>(loc, i32Type, args[0]);
-          dq = b.create<arith::SIToFPOp>(loc, f32Type, ext);
-        }
+        dq = b.create<arith::SIToFPOp>(loc, floatType, args[0]);
         Value scaledRes0 = b.create<arith::MulFOp>(loc, dq, args[1]);
         Value scaledRes1 = b.create<arith::MulFOp>(loc, scaledRes0, args[3]);
         Value scaledZp0 = b.create<arith::MulFOp>(loc, args[4], args[3]);
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
index 52f6e09..2b59689 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/fuse_dequantization_matmul.mlir
@@ -123,3 +123,130 @@
 //       CHECK:   %[[READDF:.+]] = arith.addf %[[RESUBF]], %[[REOUT0]] : f32
 //       CHECK:   linalg.yield %[[READDF]] : f32
 //       CHECK:   return %[[GENREASSOCIATE]]
+
+// -----
+
+module {
+  func.func @grouped_quantized_matmul_reassociate_f16(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<32x128xf16>, %arg2: tensor<11008x32xf16>, %arg3: tensor<11008x32xf16>) -> tensor<11008xf16> {
+    %cst = arith.constant 0.000000e+00 : f16
+    %0 = tensor.empty() : tensor<11008xf16>
+    %1 = tensor.empty() : tensor<11008x32x128xf16>
+    %2 = linalg.fill ins(%cst : f16) outs(%0 : tensor<11008xf16>) -> tensor<11008xf16>
+    %3 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+        iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%arg0, %arg2, %arg3 : tensor<11008x32x128xi4>, tensor<11008x32xf16>, tensor<11008x32xf16>) outs(%1 : tensor<11008x32x128xf16>) {
+    ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
+      %5 = arith.extui %in : i4 to i32
+      %6 = arith.uitofp %5 : i32 to f16
+      %7 = arith.subf %6, %in_1 : f16
+      %8 = arith.mulf %7, %in_0 : f16
+      linalg.yield %8 : f16
+    } -> tensor<11008x32x128xf16>
+    %4 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
+                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                         affine_map<(d0, d1, d2) -> (d0)>],
+        iterator_types = ["parallel", "reduction", "reduction"]}
+        ins(%arg1, %3 : tensor<32x128xf16>, tensor<11008x32x128xf16>) outs(%2 : tensor<11008xf16>) {
+    ^bb0(%in: f16, %in_0: f16, %out: f16):
+      %5 = arith.mulf %in, %in_0 : f16
+      %6 = arith.addf %5, %out : f16
+      linalg.yield %6 : f16
+    } -> tensor<11008xf16>
+    return %4 : tensor<11008xf16>
+  }
+}
+
+//   CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0)>
+//   CHECK-DAG: #[[MAP2:[a-zA-Z0-9]+]] = affine_map<(d0) -> (d0)>
+//   CHECK-DAG: #[[MAP3:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+//   CHECK-DAG: #[[MAP4:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP5:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP6:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d1)>
+//       CHECK: func.func @grouped_quantized_matmul_reassociate_f16(
+//  CHECK-SAME:   %[[QUANT:[a-zA-Z0-9_]+]]: tensor<11008x32x128xi4>
+//  CHECK-SAME:   %[[UNQUANT:[a-zA-Z0-9_]+]]: tensor<32x128xf16>
+//  CHECK-SAME:   %[[SCALES:[a-zA-Z0-9_]+]]: tensor<11008x32xf16>
+//  CHECK-SAME:   %[[ZPS:[a-zA-Z0-9_]+]]: tensor<11008x32xf16>
+//       CHECK:   %[[C0I32:.+]] = arith.constant 0 : i32
+//       CHECK:   %[[RANGE:.+]] = arith.constant 3.276800e+04 : f16
+//       CHECK:   %[[C0f16:.+]] = arith.constant 0.000000e+00 : f16
+//       CHECK:   %[[INITOUT:.+]] = tensor.empty() : tensor<11008xf16>
+//       CHECK:   %[[FILLOUT:.+]] = linalg.fill ins(%[[C0f16]]
+//  CHECK-SAME:       outs(%[[INITOUT]] :
+//       CHECK:   %[[INITMAX:.+]] = tensor.empty() : tensor<32xf16>
+//       CHECK:   %[[FILLMAX:.+]] = linalg.fill ins(%[[C0f16]]
+//  CHECK-SAME:       outs(%[[INITMAX]] :
+//       CHECK:   %[[GENMAX:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "reduction"]
+//  CHECK-SAME:       ins(%[[UNQUANT]] :
+//  CHECK-SAME:       outs(%[[FILLMAX]] :
+//       CHECK:   ^bb0(%[[MAXIN0:.+]]: f16, %[[MAXOUT0:.+]]: f16):
+//       CHECK:   %[[MAXABSF:.+]] = math.absf %[[MAXIN0]] : f16
+//       CHECK:   %[[MAXMAXIMUMF:.+]] = arith.maximumf %[[MAXABSF]], %[[MAXOUT0]] : f16
+//       CHECK:   linalg.yield %[[MAXMAXIMUMF]] : f16
+//       CHECK:   %[[INITSCALES:.+]] = tensor.empty() : tensor<32xf16>
+//       CHECK:   %[[GENSCALES:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP2]], #[[MAP2]]]
+//  CHECK-SAME:       iterator_types = ["parallel"]
+//  CHECK-SAME:       ins(%[[GENMAX]] :
+//  CHECK-SAME:       outs(%[[INITSCALES]] :
+//       CHECK:   ^bb0(%[[SCALESIN0:.+]]: f16, %[[SCALESOUT0:.+]]: f16):
+//       CHECK:   %[[SCALESDIVF:.+]] = arith.divf %[[SCALESIN0]], %[[RANGE]] : f16
+//       CHECK:   linalg.yield %[[SCALESDIVF]] : f16
+//       CHECK:   %[[INITSUM:.+]] = tensor.empty() : tensor<32xf16>
+//       CHECK:   %[[FILLSUM:.+]] = linalg.fill ins(%[[C0f16]]
+//  CHECK-SAME:       outs(%[[INITSUM]] :
+//       CHECK:   %[[GENSUM:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "reduction"]
+//  CHECK-SAME:       ins(%[[UNQUANT]] :
+//  CHECK-SAME:       outs(%[[FILLSUM]] :
+//       CHECK:   ^bb0(%[[SUMIN0:.+]]: f16, %[[SUMOUT0:.+]]: f16):
+//       CHECK:   %[[SUMADDF:.+]] = arith.addf %[[SUMIN0]], %[[SUMOUT0]] : f16
+//       CHECK:   linalg.yield %[[SUMADDF]] : f16
+//       CHECK:   %[[INITQUANT:.+]] = tensor.empty() : tensor<32x128xi16>
+//       CHECK:   %[[GENQUANT:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[UNQUANT]], %[[GENSCALES]] :
+//  CHECK-SAME:       outs(%[[INITQUANT]] :
+//       CHECK:   ^bb0(%[[QUANTIN0:.+]]: f16, %[[QUANTIN1:.+]]: f16, %[[QUANTOUT0:.+]]: i16):
+//       CHECK:   %[[QUANTDIVF:.+]] = arith.divf %[[QUANTIN0]], %[[QUANTIN1]] : f16
+//       CHECK:   %[[QUANTFPTOSI:.+]] = arith.fptosi %[[QUANTDIVF]] : f16 to i16
+//       CHECK:   linalg.yield %[[QUANTFPTOSI]] : i16
+//       CHECK:   %[[INITMATMUL:.+]] = tensor.empty() : tensor<11008x32xi32>
+//       CHECK:   %[[FILLMATMUL:.+]] = linalg.fill ins(%[[C0I32]]
+//  CHECK-SAME:       outs(%[[INITMATMUL]] :
+//       CHECK:   %[[GENMATMUL:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP3]], #[[MAP4]], #[[MAP5]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"]
+//  CHECK-SAME:       ins(%[[GENQUANT]], %[[QUANT]] :
+//  CHECK-SAME:       outs(%[[FILLMATMUL]] :
+//       CHECK:   ^bb0(%[[MATMULIN0:.+]]: i16, %[[MATMULIN1:.+]]: i4, %[[MATMULOUT0:.+]]: i32):
+//   CHECK-DAG:   %[[MATMULEXTSI:.+]] = arith.extsi %[[MATMULIN0]] : i16 to i32
+//   CHECK-DAG:   %[[MATMULEXTUI:.+]] = arith.extui %[[MATMULIN1]] : i4 to i32
+//       CHECK:   %[[MATMULMULI:.+]] = arith.muli %[[MATMULEXTSI]], %[[MATMULEXTUI]] : i32
+//       CHECK:   %[[MATMULADDI:.+]] = arith.addi %[[MATMULMULI]], %[[MATMULOUT0]] : i32
+//       CHECK:   linalg.yield %[[MATMULADDI]] : i32
+//       CHECK:   %[[GENREASSOCIATE:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP6]], #[[MAP6]], #[[MAP0]], #[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:       iterator_types = ["parallel", "reduction"]
+//  CHECK-SAME:       ins(%[[GENMATMUL]], %[[GENSCALES]], %[[GENSUM]], %[[SCALES]], %[[ZPS]] :
+//  CHECK-SAME:       outs(%[[FILLOUT]] :
+//       CHECK:   ^bb0(%[[REIN0:.+]]: i32, %[[REIN1:.+]]: f16, %[[REIN2:.+]]: f16, %[[REIN3:.+]]: f16, %[[REIN4:.+]]: f16, %[[REOUT0:.+]]: f16):
+//   CHECK-DAG:   %[[RESITOFP:.+]] = arith.sitofp %[[REIN0]] : i32 to f16
+//   CHECK-DAG:   %[[REMULF0:.+]] = arith.mulf %[[RESITOFP]], %[[REIN1]] : f16
+//   CHECK-DAG:   %[[REMULF1:.+]] = arith.mulf %[[REMULF0]], %[[REIN3]] : f16
+//   CHECK-DAG:   %[[REMULF2:.+]] = arith.mulf %[[REIN4]], %[[REIN3]] : f16
+//   CHECK-DAG:   %[[REMULF3:.+]] = arith.mulf %[[REMULF2]], %[[REIN2]] : f16
+//       CHECK:   %[[RESUBF:.+]] = arith.subf %[[REMULF1]], %[[REMULF3]] : f16
+//       CHECK:   %[[READDF:.+]] = arith.addf %[[RESUBF]], %[[REOUT0]] : f16
+//       CHECK:   linalg.yield %[[READDF]] : f16
+//       CHECK:   return %[[GENREASSOCIATE]]