[Flow] Allow element-wise fusion of multi-reduction ops (#16503)

This enables element-wise fusion of ops with multiple reduction dimensions in FusionOfTensorOps. The behavior is controlled by a new `fuse-multi-reduction` pass option, exposed as the `--iree-flow-element-wise-fuse-multi-reduction` flag (default `true`). Consumers that match the contraction or convolution op interfaces are explicitly excluded from this fusion.
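For illustration, a hypothetical input (shapes and names are made up, not taken from the test suite) where the element-wise producer `%exp` can now fuse into a consumer with two reduction iterators; previously only consumers with exactly one reduction loop qualified:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#proj = affine_map<(d0, d1, d2) -> (d0)>
util.func public @multi_reduction_fusion(%arg0 : tensor<8x32x128xf16>) -> tensor<8xf16> {
  %cst = arith.constant 0.000000e+00 : f16
  %empty = tensor.empty() : tensor<8x32x128xf16>
  // Element-wise producer (all-parallel iterators).
  %exp = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<8x32x128xf16>) outs(%empty : tensor<8x32x128xf16>) {
  ^bb0(%in : f16, %out : f16):
    %0 = math.exp %in : f16
    linalg.yield %0 : f16
  } -> tensor<8x32x128xf16>
  %init = tensor.empty() : tensor<8xf16>
  %fill = linalg.fill ins(%cst : f16) outs(%init : tensor<8xf16>) -> tensor<8xf16>
  // Consumer with two reduction iterators; with this change, %exp fuses here.
  %sum = linalg.generic {indexing_maps = [#map, #proj], iterator_types = ["parallel", "reduction", "reduction"]} ins(%exp : tensor<8x32x128xf16>) outs(%fill : tensor<8xf16>) {
  ^bb0(%in : f16, %out : f16):
    %1 = arith.addf %in, %out : f16
    linalg.yield %1 : f16
  } -> tensor<8xf16>
  util.return %sum : tensor<8xf16>
}
```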
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index ea4f55a..060439e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -61,7 +62,8 @@
 }
 
 /// Check if the producer generic op is fusable with the consumer generic op.
-static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand) {
+static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand,
+                          bool fuseMultiReduction) {
   Operation *producerOp = fusedOperand->get().getDefiningOp();
   Operation *consumerOp = fusedOperand->getOwner();
   if (!producerOp)
@@ -115,11 +117,19 @@
         linalgConsumerOp.getNumLoops()) {
       return true;
     }
-    if (linalgConsumerOp.getNumReductionLoops() != 1 ||
-        !linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
+    if (!linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
              .isPermutation()) {
       return false;
     }
+    if (!fuseMultiReduction && linalgConsumerOp.getNumReductionLoops() != 1) {
+      return false;
+    }
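+    // Leave consumers that match the contraction or convolution interfaces
+    // intact so they remain recognizable for dedicated handling.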
+    if (linalg::isaContractionOpInterface(linalgConsumerOp) ||
+        linalg::isaConvolutionOpInterface(linalgConsumerOp)) {
+      return false;
+    }
     return true;
   }
 
@@ -304,13 +312,15 @@
     registry.insert<affine::AffineDialect, arith::ArithDialect,
                     linalg::LinalgDialect, math::MathDialect>();
   }
-  FusionOfTensorOpsPass(bool fuseMultiUse, unsigned multiUseFusionIteration) {
+  FusionOfTensorOpsPass(bool fuseMultiUse, bool fuseMultiReduction,
+                        unsigned multiUseFusionIteration) {
     this->fuseMultiUse = fuseMultiUse;
+    this->fuseMultiReduction = fuseMultiReduction;
     this->multiUseFusionIteration = multiUseFusionIteration;
   }
   FusionOfTensorOpsPass(const FusionOfTensorOpsPass &pass)
-      : FusionOfTensorOpsPass(pass.fuseMultiUse, pass.multiUseFusionIteration) {
-  }
+      : FusionOfTensorOpsPass(pass.fuseMultiUse, pass.fuseMultiReduction,
+                              pass.multiUseFusionIteration) {}
 
   void runOnOperation() override {
     Operation *funcOp = getOperation();
@@ -345,7 +355,7 @@
             if (operands.size() >= kIreeMaxOperandCount)
               return false;
 
-            return areFusableOps(context, fusedOperand);
+            return areFusableOps(context, fusedOperand, fuseMultiReduction);
           };
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
                                                    fuseElementwiseOpsControlFn);
@@ -501,10 +511,10 @@
 } // namespace
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createFusionOfTensorOpsPass(bool fuseMultiUse,
+createFusionOfTensorOpsPass(bool fuseMultiUse, bool fuseMultiReduction,
                             unsigned multiUseFusionIteration) {
-  return std::make_unique<FusionOfTensorOpsPass>(fuseMultiUse,
-                                                 multiUseFusionIteration);
+  return std::make_unique<FusionOfTensorOpsPass>(
+      fuseMultiUse, fuseMultiReduction, multiUseFusionIteration);
 }
 
 } // namespace mlir::iree_compiler::IREE::Flow
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 2cdc605..56edc9e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -80,6 +80,11 @@
                          llvm::cl::desc("Fuse multi-use ops."),
                          llvm::cl::init(false));
 
+static llvm::cl::opt<bool> clEnableElementWiseFuseMultiReduction(
+    "iree-flow-element-wise-fuse-multi-reduction",
+    llvm::cl::desc("Enable element-wise fusion of multi-reduction loop ops."),
+    llvm::cl::init(true));
+
 static llvm::cl::opt<bool> clDispatchGenerateWorkloadRegion(
     "iree-flow-dispatch-generate-workload-region",
     llvm::cl::desc("Generate the workload region."), llvm::cl::init(true));
@@ -135,8 +140,10 @@
       .addPass(mlir::createCanonicalizerPass)
       .addPass(mlir::createCSEPass)
       // Elementwise fusion.
-      .addPass(
-          []() { return createFusionOfTensorOpsPass(clEnableFuseMultiUse); })
+      .addPass([]() {
+        return createFusionOfTensorOpsPass(
+            clEnableFuseMultiUse, clEnableElementWiseFuseMultiReduction);
+      })
       .addPredicatedPass(clDetensoring,
                          [&]() { return mlir::createLinalgDetensorizePass(); })
       .addPass(mlir::createCanonicalizerPass)
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
index 2bfc4bc..60cf678 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.h
@@ -65,6 +65,7 @@
 // Creates a pass to fuse Linalg operations on tensors.
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createFusionOfTensorOpsPass(bool fuseMultiUse = false,
+                            bool fuseMultiReduction = true,
                             unsigned multiUseFusionIteration = 2);
 
 // Create a pass to initialize all empty tensors after dispatch formation to
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
index 55041e1..743b19e 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
@@ -168,6 +168,8 @@
   let options = [
     Option<"fuseMultiUse", "fuse-multi-use", "bool",
            /*default=*/"false", "Fuse ops with multiuse">,
+    Option<"fuseMultiReduction", "fuse-multi-reduction", "bool",
+           /*default=*/"true", "Fuse ops that have multiple reduction iterators">,
     Option<"multiUseFusionIteration", "multi-use-fusion-iteration", "unsigned",
            /*default=*/"2", "Number of iterations to fuse multiuse ops">
   ];
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir
index f3741b6..da68ca1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/fusion_of_tensor_ops.mlir
@@ -366,10 +366,9 @@
 
 // -----
 
-util.func public @nofuse_by_expand_dequant(%arg0 : tensor<11008x4096xi4>, %arg1 : tensor<11008x32x1xf16>, %arg2 : tensor<11008x32x1xf16>) -> (tensor<11008xf16>) {
+util.func public @nofuse_by_expand_dequant(%arg0 : tensor<11008x4096xi4>, %arg1 : tensor<11008x32x1xf16>, %arg2 : tensor<11008x32x1xf16>, %arg3 : tensor<1x1x32x128xf16>) -> (tensor<11008xf16>) {
   %cst_1 = arith.constant 0.000000e+00 : f16
   %0 = tensor.empty() : tensor<11008x32x128xf16>
-  %1 = arith.constant dense<0.000000e+00> : tensor<1x1x32x128xf16>
   %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] : tensor<11008x4096xi4> into tensor<11008x32x128xi4>
   %collapsed = tensor.collapse_shape %arg2 [[0], [1, 2]] : tensor<11008x32x1xf16> into tensor<11008x32xf16>
   %collapsed_2 = tensor.collapse_shape %arg1 [[0], [1, 2]] : tensor<11008x32x1xf16> into tensor<11008x32xf16>
@@ -381,7 +380,7 @@
     %15 = arith.mulf %14, %in_8 : f16
     linalg.yield %15 : f16
   } -> tensor<11008x32x128xf16>
-  %collapsed_3 = tensor.collapse_shape %1 [[0, 1, 2], [3]] : tensor<1x1x32x128xf16> into tensor<32x128xf16>
+  %collapsed_3 = tensor.collapse_shape %arg3 [[0, 1, 2], [3]] : tensor<1x1x32x128xf16> into tensor<32x128xf16>
   %3 = tensor.empty() : tensor<11008xf16>
   %4 = linalg.fill ins(%cst_1 : f16) outs(%3 : tensor<11008xf16>) -> tensor<11008xf16>
   %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%collapsed_3, %2 : tensor<32x128xf16>, tensor<11008x32x128xf16>) outs(%4 : tensor<11008xf16>) {
@@ -399,7 +398,7 @@
 //       CHECK-NOT:   tensor.collapse_shape %[[DEQUANT]]
 //           CHECK:     %[[MATVEC:.+]] = linalg.generic
 //      CHECK-SAME:         iterator_types = ["parallel", "reduction", "reduction"]
-//      CHECK-SAME:         ins(%[[DEQUANT]] : tensor<11008x32x128xf16>)
+//      CHECK-SAME:         ins({{.+}}%[[DEQUANT]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
index 5485f59..290d2d7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/pipeline_tests.mlir
@@ -44,16 +44,11 @@
 //       CHECK:   flow.executable private @[[EXECUTABLE0:[a-zA-Z0-9_]+]]
 //       CHECK:     func.func @[[FUNC0:[a-zA-Z0-9_x]+]]
 //       CHECK:       linalg.generic
-//  CHECK-SAME:         ["parallel", "parallel"]
-//       CHECK:   flow.executable private @[[EXECUTABLE1:[a-zA-Z0-9_]+]]
-//       CHECK:     func.func @[[FUNC1:[a-zA-Z0-9_x]+]]
-//       CHECK:       linalg.generic
-//  CHECK-SAME:         ["reduction"]
+//  CHECK-SAME:         ["reduction", "reduction"]
+//   CHECK-NOT:       linalg.generic
 //       CHECK:   util.func public @main(
 //       CHECK:     %[[T0:.+]] = flow.dispatch @[[EXECUTABLE0]]::@[[FUNC0]]
-//       CHECK:     %[[T1:.+]] = flow.tensor.reshape %[[T0]] : tensor<833x833xf32> -> tensor<693889xf32>
-//       CHECK:     %[[T2:.+]] = flow.dispatch @[[EXECUTABLE1]]::@[[FUNC1]](%[[T1]])
-//       CHECK:     util.return %[[T2]]
+//       CHECK:     util.return %[[T0]]
 
 // -----