[Flow] Loosen restrictions for dequantization fusion (#15663)

The conditions for cloning dequantization ops into a dispatch
region are too conservative and fail to fuse some
dequantization-like ops. This loosens the restrictions so that
dequantization-like ops are also cloned into dispatch regions.
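
For example, a plain widening-extension generic like the one below
(taken from the added clone_producers_into_dispatch_regions.mlir test)
does not match the old extui -> uitofp -> subf -> mulf pattern, but is
dequantization-like and is now cloned into its consumer's dispatch
region:

  #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
  %1 = linalg.generic {indexing_maps = [#map, #map],
                       iterator_types = ["parallel", "parallel", "parallel",
                                         "parallel", "parallel"]}
                       ins(%arg0 : tensor<32x1x16x1x8xi16>)
                       outs(%0 : tensor<32x1x16x1x8xi32>) {
  ^bb0(%in: i16, %out: i32):
    %7 = arith.extsi %in : i16 to i32
    linalg.yield %7 : i32
  } -> tensor<32x1x16x1x8xi32>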
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
index c89cd6f..1c66c7f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.cpp
@@ -810,7 +810,7 @@
       // materializing large tensors between dispatches.
       if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
                IREE::LinalgExt::SetEncodingOp>(op) ||
-          isa<linalg::FillOp>(op) || isGroupedDequantizationOp(&op)) {
+          isa<linalg::FillOp>(op) || isDequantizationLikeOp(&op)) {
         continue;
       }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 27bb5c5..ea4f55a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -360,7 +360,7 @@
             }
 
             // Do not fuse by expand if consumer is dequant.
-            if (isGroupedDequantizationOp(consumer)) {
+            if (isDequantizationLikeOp(consumer)) {
               return false;
             }
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
index 3a91af5..42fbca6 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.cpp
@@ -504,120 +504,65 @@
   return newRegionOp;
 }
 
-/// Returns true if the operation is an generic op that represents dequant.
-/// This function checks that the genericOp:
-/// 1. Has a body like:
-///      arith.extui
-///      arith.uitofp
-///      arith.subf
-///      arith.mulf
-/// 2. scale, offset and result is f16 or f32, while weight is i4 or i8.
-/// 3. Has 3 parallel dims
-/// 4. Has 2 (weights, scales) or 3 (weights, scales, zero points)
-///    inputs and 1 output
-/// 5. Only weight has same shape as result.
-/// 6. Weight and result have identity indexing map.
-bool Flow::isGroupedDequantizationOp(Operation *op) {
+bool Flow::isDequantizationLikeOp(Operation *op) {
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
-  if (!genericOp)
+  if (!genericOp) {
     return false;
-  if (genericOp.getNumDpsInits() != 1)
+  }
+  if (genericOp.getNumDpsInits() != 1) {
     return false;
-  if (genericOp.getNumDpsInputs() != 2 && genericOp.getNumDpsInputs() != 3)
-    return false;
+  }
 
-  // Check that the rank is at least 3 and all loops are parallel
+  // Check that all loops are parallel.
   unsigned numLoops = genericOp.getNumLoops();
   unsigned numParallelLoops = genericOp.getNumParallelLoops();
-  if (numLoops < 3)
+  if (numLoops != numParallelLoops) {
     return false;
-  if (numLoops != numParallelLoops)
+  }
+
+  // Check that exactly one input has an identity indexing map, and that all
+  // other inputs are projected permutations (but not full permutations).
+  OpOperand *identityInput = nullptr;
+  for (OpOperand *input : genericOp.getDpsInputOperands()) {
+    auto inputMap = genericOp.getMatchingIndexingMap(input);
+    if (inputMap.isIdentity()) {
+      if (identityInput) {
+        return false;
+      }
+      identityInput = input;
+    } else if (!inputMap.isProjectedPermutation(true) ||
+               inputMap.isPermutation()) {
+      return false;
+    }
+  }
+
+  if (!identityInput) {
     return false;
+  }
 
-  auto inputs = genericOp.getInputs();
-  auto weight = inputs[0];
-  auto scales = inputs[1];
-  auto init = genericOp.getDpsInits()[0];
-  Type weightElType = getElementTypeOrSelf(weight.getType());
-  Type scaleElType = getElementTypeOrSelf(scales.getType());
-  Type initElType = getElementTypeOrSelf(init.getType());
-
-  // Check that init and weight have parallel indexing maps.
   auto indexingMaps = genericOp.getIndexingMapsArray();
-  const int kWeightIndex = 0;
-  const int kInitIndex = indexingMaps.size() - 1;
-  if (!indexingMaps[kWeightIndex].isIdentity())
-    return false;
-  if (!indexingMaps[kInitIndex].isIdentity())
-    return false;
-
-  // Dequant weight and init need to be same type and shape.
-  auto weightShape = weight.getType().dyn_cast<ShapedType>().getShape();
-  auto initShape = init.getType().dyn_cast<ShapedType>().getShape();
-  if (weightShape != initShape)
-    return false;
-
-  // Scale and init needs to be of same element type.
-  if (scaleElType != initElType)
-    return false;
-
-  // Check weight is i4 or i8.
-  if (!weightElType.isInteger(4) && !weightElType.isInteger(8)) {
+  if (!indexingMaps.back().isIdentity()) {
     return false;
   }
 
-  // Check scales is f16 or f32.
-  Type f32Type = Float32Type::get(op->getContext());
-  Type f16Type = Float16Type::get(op->getContext());
-  if (scaleElType != f32Type && scaleElType != f16Type) {
+  // Check that the identity input element bitwidth is smaller than the output
+  // element bitwidth.
+  Type inputElementType = getElementTypeOrSelf(identityInput->get().getType());
+  Type outputElementType = getElementTypeOrSelf(genericOp->getResultTypes()[0]);
+  if (!inputElementType.isIntOrFloat() || !outputElementType.isIntOrFloat()) {
+    return false;
+  }
+  if (inputElementType.getIntOrFloatBitWidth() >=
+      outputElementType.getIntOrFloatBitWidth()) {
     return false;
   }
 
-  // Work back from linalg.yield and check body of genericOp.
-  // The genericOp should yield the result of an arith.mulf,
-  // preceded by an arith.subf, arith.uitofp, and arith.extui
-  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());
-  Value producerOutput;
-  Operation *producer;
-
-  // Producer of linalg.yield op is arith.mulf
-  {
-    producerOutput = yieldOp->getOperand(0);
-    producer = producerOutput.getDefiningOp();
-    if (!producer || producer->getNumOperands() == 0)
+  for (auto &bodyOp : genericOp.getBody()->getOperations()) {
+    if (!isa<arith::ExtUIOp, arith::ExtSIOp, arith::MulFOp, arith::MulIOp,
+             arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp,
+             arith::SIToFPOp, arith::UIToFPOp, linalg::YieldOp>(bodyOp)) {
       return false;
-    if (!matchPattern(producer, m_Op<arith::MulFOp>()))
-      return false;
-  }
-
-  // Producer of arith.mulf op is arith.subf
-  {
-    producerOutput = producer->getOperand(0);
-    producer = producerOutput.getDefiningOp();
-    if (!producer || producer->getNumOperands() == 0)
-      return false;
-    if (!matchPattern(producer, m_Op<arith::SubFOp>()))
-      return false;
-  }
-
-  // Producer of arith.subf op is arith.uitofp
-  {
-    producerOutput = producer->getOperand(0);
-    producer = producerOutput.getDefiningOp();
-    if (!producer || producer->getNumOperands() == 0)
-      return false;
-    if (!matchPattern(producer, m_Op<arith::UIToFPOp>()))
-      return false;
-  }
-
-  // Producer of arith.uitofp op is arith.extui
-  {
-    producerOutput = producer->getOperand(0);
-    producer = producerOutput.getDefiningOp();
-    if (!producer)
-      return false;
-    if (!matchPattern(producer, m_Op<arith::ExtUIOp>()))
-      return false;
+    }
   }
 
   return true;
@@ -638,7 +583,7 @@
           tensor::ExtractSliceOp, complex::CreateOp>(op)) {
     return true;
   }
-  if (isGroupedDequantizationOp(op)) {
+  if (isDequantizationLikeOp(op)) {
     return true;
   }
   if (isa<arith::ConstantOp>(op) || isa<complex::ConstantOp>(op)) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
index 30b4d9f..ca57d9e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h
@@ -96,8 +96,15 @@
 /// into a dispatch region.
 bool isClonableIntoDispatchOp(Operation *op);
 
-/// Returns true if the operation is an generic op that represents dequant.
-bool isGroupedDequantizationOp(Operation *op);
+/// Returns true if the operation has dequantization-like properties.
+/// This function checks that the genericOp:
+///     1. Has only one output, and the output has an identity indexing map.
+///     2. Has all parallel loops.
+///     3. Has exactly one input with an identity indexing map.
+///     4. All other inputs are projected permutations, not full permutations.
+///     5. The input with an identity indexing map has a smaller element
+///        bitwidth than the output.
+bool isDequantizationLikeOp(Operation *op);
 
 /// Collect all ops that should be cloned into the given dispatch region op.
 SmallVector<Operation *> getCloneableOps(Flow::DispatchRegionOp regionOp);
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
index dbd34de..ec6baec 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/clone_producers_into_dispatch_regions.mlir
@@ -186,3 +186,58 @@
 //  CHECK-SAME:       outs(%[[FILL]] :
 //       CHECK:   flow.return %[[GEN1]] :
 //       CHECK:   return %[[DISP]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+module {
+  func.func @clone_dequantization_like(%arg0: tensor<32x1x16x1x8xi16>, %arg1: tensor<32x344x16x32x8xi4>) -> tensor<32x1x344x1x32xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tensor.empty() : tensor<32x1x16x1x8xi32>
+    %1 = linalg.generic {indexing_maps = [#map, #map], 
+                         iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
+                         ins(%arg0 : tensor<32x1x16x1x8xi16>) outs(%0 : tensor<32x1x16x1x8xi32>) {
+    ^bb0(%in: i16, %out: i32):
+      %7 = arith.extsi %in : i16 to i32
+      linalg.yield %7 : i32
+    } -> tensor<32x1x16x1x8xi32>
+    %2 = tensor.empty() : tensor<32x344x16x32x8xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map], 
+                         iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
+                         ins(%arg1 : tensor<32x344x16x32x8xi4>) outs(%2 : tensor<32x344x16x32x8xi32>) {
+    ^bb0(%in: i4, %out: i32):
+      %7 = arith.extui %in : i4 to i32
+      linalg.yield %7 : i32
+    } -> tensor<32x344x16x32x8xi32>
+    %4 = tensor.empty() : tensor<32x1x344x1x32xi32>
+    %5 = linalg.fill ins(%c0_i32 : i32) outs(%4 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+    %6 = flow.dispatch.region -> (tensor<32x1x344x1x32xi32>) {
+      %7 = linalg.batch_mmt4d ins(%1, %3 : tensor<32x1x16x1x8xi32>, tensor<32x344x16x32x8xi32>) outs(%5 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+      flow.return %7 : tensor<32x1x344x1x32xi32>
+    }
+    return %6 : tensor<32x1x344x1x32xi32>
+  }
+}
+//       CHECK: func.func @clone_dequantization_like
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32x1x16x1x8xi16>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x344x16x32x8xi4>
+//       CHECK:   %[[DISP:.+]] = flow.dispatch.region -> (tensor<32x1x344x1x32xi32>)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : i32
+//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<32x1x16x1x8xi32>
+//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<32x1x344x1x32xi32>
+//   CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() : tensor<32x344x16x32x8xi32>
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0]]
+//  CHECK-SAME:       outs(%[[INIT1]] :
+//       CHECK:   %[[GEN0:.+]] = linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[ARG0]] :
+//  CHECK-SAME:       outs(%[[INIT0]] :
+//       CHECK:   %[[GEN1:.+]] = linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[ARG1]] :
+//  CHECK-SAME:       outs(%[[INIT2]] :
+//       CHECK:   %[[MMT4D:.+]] = linalg.batch_mmt4d
+//  CHECK-SAME:       ins(%[[GEN0]], %[[GEN1]] :
+//  CHECK-SAME:       outs(%[[FILL]] :
+//       CHECK:   flow.return %[[MMT4D]] :
+//       CHECK:   return %[[DISP]]
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
index 9a0b19c..7be213f 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_dispatch_regions.mlir
@@ -554,3 +554,55 @@
 //  CHECK-SAME:       outs(%[[FILL]] :
 //       CHECK:   flow.return %[[GEN1]] :
 //       CHECK:   return %[[DISP]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+module {
+  func.func @no_dequantization_like_fusion(%arg0: tensor<32x1x16x1x8xi16>, %arg1: tensor<32x344x16x32x8xi4>) -> tensor<32x1x344x1x32xi32> {
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tensor.empty() : tensor<32x1x16x1x8xi32>
+    %1 = linalg.generic {indexing_maps = [#map, #map], 
+                         iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
+                         ins(%arg0 : tensor<32x1x16x1x8xi16>) outs(%0 : tensor<32x1x16x1x8xi32>) {
+    ^bb0(%in: i16, %out: i32):
+      %7 = arith.extsi %in : i16 to i32
+      linalg.yield %7 : i32
+    } -> tensor<32x1x16x1x8xi32>
+    %2 = tensor.empty() : tensor<32x344x16x32x8xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map], 
+                         iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
+                         ins(%arg1 : tensor<32x344x16x32x8xi4>) outs(%2 : tensor<32x344x16x32x8xi32>) {
+    ^bb0(%in: i4, %out: i32):
+      %7 = arith.extui %in : i4 to i32
+      linalg.yield %7 : i32
+    } -> tensor<32x344x16x32x8xi32>
+    %4 = tensor.empty() : tensor<32x1x344x1x32xi32>
+    %5 = linalg.fill ins(%c0_i32 : i32) outs(%4 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+    %7 = linalg.batch_mmt4d ins(%1, %3 : tensor<32x1x16x1x8xi32>, tensor<32x344x16x32x8xi32>) outs(%5 : tensor<32x1x344x1x32xi32>) -> tensor<32x1x344x1x32xi32>
+    return %7 : tensor<32x1x344x1x32xi32>
+  }
+}
+//       CHECK: func.func @no_dequantization_like_fusion
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32x1x16x1x8xi16>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32x344x16x32x8xi4>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : i32
+//   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<32x1x16x1x8xi32>
+//       CHECK:   %[[GEN0:.+]] = linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[ARG0]] :
+//  CHECK-SAME:       outs(%[[INIT0]] :
+//   CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() : tensor<32x344x16x32x8xi32>
+//       CHECK:   %[[GEN1:.+]] = linalg.generic
+//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+//  CHECK-SAME:       ins(%[[ARG1]] :
+//  CHECK-SAME:       outs(%[[INIT2]] :
+//       CHECK:   %[[INIT1:.+]] = tensor.empty() : tensor<32x1x344x1x32xi32>
+//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[C0]]
+//  CHECK-SAME:       outs(%[[INIT1]] :
+//       CHECK:   %[[DISP:.+]] = flow.dispatch.region -> (tensor<32x1x344x1x32xi32>)
+//       CHECK:   %[[MMT4D:.+]] = linalg.batch_mmt4d
+//  CHECK-SAME:       ins(%[[GEN0]], %[[GEN1]] :
+//  CHECK-SAME:       outs(%[[FILL]] :
+//       CHECK:   flow.return %[[MMT4D]] :
+//       CHECK:   return %[[DISP]]