[LLVMCPU] Add option `onlyFuseProducerInputOperands` to tileRootFuseConsumerProducer Pass (#18114)

Previously, we only tilled the reduction tile sizes and did not fuse
them with the producers from the input operands. It led to transfer
read/write with large vector sizes since the dequant operation
materialised its own tensor and wasn't fused inside the reduction loop.
Adds a `onlyFuseProducerInputOperands` option to the
tile-root-and-fuse-consumer-producer-pass.
If the option is set to true, it tiles the reduction dimension and fuses
the operations arising from the input operand of the already tiled
operation. Issue link: https://github.com/iree-org/iree/issues/18005
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
index 3f6c727..78b3b6f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp
@@ -32,34 +32,99 @@
 
 namespace {
 
-/// Implementation of tile root and fuse producers and consumers greedily.
-static LogicalResult tileRootAndFuseProducerConsumerUsingSCF(
-    RewriterBase &rewriter, TilingInterface root,
-    const scf::SCFTileAndFuseOptions &options) {
+/// Starting from `op` walk all operands backwards to find all
+/// potentially fusable operations, i.e. operations that implement
+/// the `TilingInterface`.
+static void collectTiledAndFusedOps(Operation *rootOp,
+                                    llvm::SmallDenseSet<Operation *> &result) {
+  SmallVector<Operation *> worklist;
+  worklist.push_back(rootOp);
+  result.insert(rootOp);
+  while (!worklist.empty()) {
+    Operation *current = worklist.pop_back_val();
+    for (OpOperand &operand : current->getOpOperands()) {
+      Operation *producer = operand.get().getDefiningOp();
+      if (!producer || !isa<TilingInterface>(producer) ||
+          result.count(producer))
+        continue;
+      worklist.push_back(producer);
+      result.insert(producer);
+    }
+  }
+}
 
-  // This transformation is only valid for ops that return values (i.e. not
-  // valid to use with operations that have memref operands).
-  if (!root->getNumResults()) {
-    return rewriter.notifyMatchFailure(
-        root, "invalid pattern for op with no results");
+/// Tile the root operation and fuse the producers of the root operation.
+/// If `onlyFuseProducerInputOperands` is set, only fuse producer input
+/// operands. Returns the tiled operation to be used for fusing consumers.
+FailureOr<Operation *>
+tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
+                         int64_t tilingLevel,
+                         bool onlyFuseProducerInputOperands) {
+  mlir::DominanceInfo dominanceInfo(rootOp);
+  llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
+  collectTiledAndFusedOps(rootOp, tiledAndFusedOps);
+
+  llvm::DenseSet<Operation *> yieldReplacementsFor;
+  for (auto op : tiledAndFusedOps) {
+    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+          return dominanceInfo.properlyDominates(rootOp, user);
+        })) {
+      yieldReplacementsFor.insert(op);
+    }
   }
 
-  // 1. Tile root op and Fuse Producers.
+  SmallVector<OpFoldResult> tileSizes =
+      getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
+                                                    rootOp);
+
+  // Pad the tile sizes with zero.
+  auto zero = rewriter.getIndexAttr(0);
+  int64_t numLoops = rootOp.getLoopIteratorTypes().size();
+  if (tileSizes.size() > numLoops) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "tile sizes size " << tileSizes.size()
+               << " exceeds the number of loops " << numLoops << "\n");
+    return failure();
+  }
+  tileSizes.resize(numLoops, zero);
+
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizes);
+
+  scf::SCFTileAndFuseOptions tileAndFuseOptions;
+  tileAndFuseOptions.setTilingOptions(tilingOptions);
+
+  scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
+      [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+          bool isDestinationOperand) {
+        Operation *owner = originalProducer.getOwner();
+        bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+        // Do not fuse destination operands if onlyFuseProducerInputOperands is
+        // true.
+        bool shouldFuse =
+            !(onlyFuseProducerInputOperands && isDestinationOperand);
+        return std::make_tuple(shouldFuse, yieldProducerReplacement);
+      };
+  tileAndFuseOptions.setFusionControlFn(controlFn);
+
   FailureOr<scf::SCFTileAndFuseResult> tiledResults =
-      scf::tileConsumerAndFuseProducersUsingSCF(rewriter, root, options);
-
+      scf::tileConsumerAndFuseProducersUsingSCF(rewriter, rootOp,
+                                                tileAndFuseOptions);
   if (failed(tiledResults)) {
-    return rewriter.notifyMatchFailure(
-        root, "failed to tile root and fuse producers");
+    return failure();
   }
 
-  // 2. Replace the producers with the tiled verison.
-  SmallVector<Operation *> opsToReplace = {root};
+  // Perform the replacement of tiled and fused values.
+  SmallVector<Operation *> opsToReplace{rootOp};
   llvm::append_range(opsToReplace, tiledResults->fusedProducers);
   for (Operation *toReplace : opsToReplace) {
     for (OpResult res : toReplace->getResults())
       if (auto replacement = tiledResults->replacements.lookup(res)) {
-        rewriter.replaceAllUsesWith(res, replacement);
+        Operation *replacementOp = replacement.getDefiningOp();
+        rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
+          Operation *user = use.getOwner();
+          return dominanceInfo.properlyDominates(replacementOp, user);
+        });
       }
 
     if (toReplace->use_empty()) {
@@ -67,13 +132,18 @@
     }
   }
 
-  // 3. Typically, the consumers of the tiled operation are slices of the
-  //    results of the tiled operation. These are expressed in IR using
-  //    `tensor.insert_slice` operations, whose outputs are the operands of the
-  //    untiled operation. Create a worklist of these `tensor.insert_siices`
-  //    operations. If the consumers of the source of the `tensor.insert_slices`
-  //    can be tiled such that the tiled value is generated in-place, that
-  //    effectively tiles + fuses the operations.
+  return tiledResults->tiledAndFusedOps.front();
+}
+
+static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
+
+  //  Typically, the consumers of the tiled operation are slices of the
+  //  results of the tiled operation. These are expressed in IR using
+  //  `tensor.insert_slice` operations, whose outputs are the operands of the
+  //  untiled operation. Create a worklist of these `tensor.insert_siices`
+  //  operations. If the consumers of the source of the `tensor.insert_slices`
+  //  can be tiled such that the tiled value is generated in-place, that
+  //  effectively tiles + fuses the operations.
   auto addCandidateSlices = [](Operation *fusedOp,
                                std::queue<tensor::InsertSliceOp> &candidates) {
     for (auto *userOp : fusedOp->getResults().getUsers()) {
@@ -86,7 +156,7 @@
   // Collect the candidate slices which can be potential consumers that can be
   // fused.
   std::queue<tensor::InsertSliceOp> candidates;
-  addCandidateSlices(tiledResults->tiledAndFusedOps.front(), candidates);
+  addCandidateSlices(tiledOp, candidates);
 
   while (!candidates.empty()) {
 
@@ -112,42 +182,44 @@
     addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
                        candidates);
   }
-  return success();
 }
 
-static LogicalResult tileRootAndFuseProducerConsumer(IRRewriter &rewriter,
-                                                     TilingInterface rootOp,
-                                                     int64_t tilingLevel) {
+/// Implementation of tile root and fuse producers and consumers greedily.
+/// If `onlyFuseProducerInputOperands` is set, only fuse producer input operands
+/// and disable consumer fusion.
+static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
+                                     TilingInterface rootOp,
+                                     int64_t tilingLevel,
+                                     bool onlyFuseProducerInputOperands) {
 
-  SmallVector<OpFoldResult> tileSizes =
-      getLoweringConfig(rootOp).getTilingLevelSizes(rewriter, tilingLevel,
-                                                    rootOp);
-  int64_t numLoops = rootOp.getLoopIteratorTypes().size();
-  if (tileSizes.size() > numLoops)
+  FailureOr<Operation *> tiledOp = tileRootAndFuseProducers(
+      rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands);
+
+  if (failed(tiledOp))
     return failure();
 
-  scf::SCFTilingOptions tilingOptions;
-  tilingOptions.setTileSizes(tileSizes);
+  if (!onlyFuseProducerInputOperands)
+    fuseConsumers(rewriter, tiledOp.value());
 
-  scf::SCFTileAndFuseOptions tileAndFuseOptions;
-  tileAndFuseOptions.setTilingOptions(tilingOptions);
-
-  return tileRootAndFuseProducerConsumerUsingSCF(rewriter, rootOp,
-                                                 tileAndFuseOptions);
+  return success();
 }
 
 /// This pass starts with the first TilingInterface operation that has
 /// lowering_config attribute, tiles the op and fuses its  consumers and
-/// producers recursively. The `tilingLevel` must be specified. It picks the
-/// `tilingLevel`-th list as tiling sizes from lowering_config.
+/// producers recursively. If the `onlyFuseProducerInputOperands` is set, it
+/// only fuses producer input operands and disables consumer fusion. The
+/// `tilingLevel` must be specified. It picks the `tilingLevel`-th list as
+/// tiling sizes from lowering_config.
 struct LLVMCPUTileRootAndFuseProducerConsumer
     : impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
           LLVMCPUTileRootAndFuseProducerConsumer> {
   using impl::LLVMCPUTileRootAndFuseProducerConsumerPassBase<
       LLVMCPUTileRootAndFuseProducerConsumer>::
       LLVMCPUTileRootAndFuseProducerConsumerPassBase;
-  explicit LLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
+  explicit LLVMCPUTileRootAndFuseProducerConsumer(
+      int64_t tilingLevel, bool onlyFuseProducerInputOperands) {
     this->tilingLevel = tilingLevel;
+    this->onlyFuseProducerInputOperands = onlyFuseProducerInputOperands;
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<arith::ArithDialect, affine::AffineDialect,
@@ -186,9 +258,9 @@
     return signalPassFailure();
   }
 
-  if (failed(tileRootAndFuseProducerConsumer(
+  if (failed(tileRootAndFuse(
           rewriter, dyn_cast<TilingInterface>(rootOp.value()),
-          tilingLevel.getValue()))) {
+          tilingLevel.getValue(), onlyFuseProducerInputOperands.getValue()))) {
     funcOp.emitError() << "tiling of level " << tilingLevel.getValue()
                        << " failed\n";
     return signalPassFailure();
@@ -212,6 +284,12 @@
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel) {
-  return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(tilingLevel);
+  return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
+      tilingLevel, /*onlyFuseProducerInputOperands=*/false);
+}
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel) {
+  return std::make_unique<LLVMCPUTileRootAndFuseProducerConsumer>(
+      tilingLevel, /*onlyFuseProducerInputOperands=*/true);
 }
 } // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index aeb9d7f..70cbe0d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -451,8 +451,8 @@
   funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
   funcPassManager.addPass(createConcretizePadResultShapePass());
 
-  funcPassManager.addPass(
-      createLLVMCPUTilePass(tilingConfig.getVectorReductionLevel()));
+  funcPassManager.addPass(createLLVMCPUTileRootAndFuseInputOperands(
+      tilingConfig.getVectorReductionLevel()));
   funcPassManager.addPass(
       createLLVMCPUTileAndFusePass(tilingConfig.getVectorInnerParallelLevel()));
   funcPassManager.addPass(createDecomposeConvolutionToLowerDimOpsPass());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index bac30de..a8cb91a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -42,6 +42,9 @@
 createLLVMCPUTileRootAndFuseProducerConsumer(int64_t tilingLevel);
 
 std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createLLVMCPUTileRootAndFuseInputOperands(int64_t tilingLevel);
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
 createLLVMCPUVerifyVectorSizeLegalityPass(
     int64_t maxAllowedNumberOfNativeVectors);
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
index c85666b..81dbdcf 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -140,13 +140,20 @@
   ];
 }
 
-def LLVMCPUTileRootAndFuseProducerConsumerPass :
-    InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer", "mlir::FunctionOpInterface"> {
-  let summary = "Pass to tile root op and fuse with producer and consumer TilingInterface ops.";
-  let options = [
-    Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
-      "Use default tiling level used to retrieve the configuration from lowering_config">
-  ];
+def LLVMCPUTileRootAndFuseProducerConsumerPass
+    : InterfacePass<"iree-llvmcpu-tile-root-and-fuse-producer-consumer",
+                    "mlir::FunctionOpInterface"> {
+  let summary = "Pass to tile root op and fuse with producer and consumer "
+                "TilingInterface ops.";
+  let options =
+      [Option<"tilingLevel", "tiling-level", "int64_t", /*default=*/"-1",
+              "Use default tiling level used to retrieve the configuration "
+              "from lowering_config">,
+       Option<"onlyFuseProducerInputOperands",
+              "only-fuse-producer-input-operands", "bool",
+              /*default=*/"false",
+              "Specifies if we only want to fuse producer's input operands. "
+              "This is helpful to tile&fuse in case of reduction dimensions.">];
 }
 
 def LLVMCPUVerifyVectorSizeLegalityPass :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
index 8710da6..4d8805d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile-root-fuse-consumer-producer.mlir
@@ -1,4 +1,6 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=0}), canonicalize)"  --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmcpu-tile-root-and-fuse-producer-consumer{tiling-level=2  only-fuse-producer-input-operands=true}), canonicalize)"  --split-input-file %s | FileCheck %s --check-prefix=CHECK-REDUCTION
+
 
 #config1 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -30,7 +32,8 @@
 //      CHECK:   }
 
 // -----
-#config = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
+
+#config2 = #iree_codegen.lowering_config<tile_sizes = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 16, 16, 0], [0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0]]>
 #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
 func.func @quantized_matmul(%arg0: tensor<2x4x128x16x1xi8>, %arg1: tensor<2x4x16xf32>, %arg2: tensor<2x4x16xf32>, %arg3: tensor<2x688x128x16x1xi8>, %arg4: tensor<2x688x16xf32>, %arg5: tensor<2x688x16xf32>) -> tensor<2x11008x64xf32> {
@@ -61,7 +64,7 @@
   } -> tensor<2x688x128x16x1xf32>
   %4 = tensor.empty() : tensor<2x4x688x16x16xf32>
   %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
-  %6 = linalg.batch_mmt4d {lowering_config = #config} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
+  %6 = linalg.batch_mmt4d {lowering_config = #config2} ins(%1, %3 : tensor<2x4x128x16x1xf32>, tensor<2x688x128x16x1xf32>) outs(%5 : tensor<2x4x688x16x16xf32>) -> tensor<2x4x688x16x16xf32>
   %7 = tensor.empty() : tensor<2x11008x64xf32>
   %unpack = tensor.unpack %6 outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 16] into %7 : tensor<2x4x688x16x16xf32> -> tensor<2x11008x64xf32>
   return %unpack : tensor<2x11008x64xf32>
@@ -75,3 +78,42 @@
 //      CHECK:       linalg.batch_mmt4d
 //      CHECK:       tensor.unpack
 //      CHECK:   }
+
+
+// -----
+
+#config3 = #iree_codegen.lowering_config<tile_sizes = [[0, 32, 0, 0, 0, 0], [1, 16, 1, 1, 0, 0], [0, 0, 0, 0, 1, 5], [0, 0, 0, 0, 0, 0]]>
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @dequant_avgpool(%arg0: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf32> {
+  %cst = arith.constant 1.250000e-01 : f32
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %c5408000 = arith.constant 5408000 : index
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<1x320x1x1xf32>
+  %1 = tensor.empty() : tensor<65x65xf32>
+  %2 = linalg.fill ins(%cst_0 : f32) outs(%1 : tensor<65x65xf32>) -> tensor<65x65xf32>
+  %3 = tensor.empty() : tensor<1x320x65x65xf32>
+  %4 = tensor.empty() : tensor<1x320x1x1xf32>
+  %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x320x65x65xi8>) outs(%3 : tensor<1x320x65x65xf32>) {
+  ^bb0(%in: i8, %out: f32):
+    %7 = arith.extsi %in : i8 to i32
+    %8 = arith.sitofp %7 : i32 to f32
+    %9 = arith.mulf %8, %cst : f32
+    linalg.yield %9 : f32
+  } -> tensor<1x320x65x65xf32>
+  %6 = linalg.pooling_nchw_sum {lowering_config = #config3} ins(%5, %2 : tensor<1x320x65x65xf32>, tensor<65x65xf32>) outs(%4 : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32>
+  return %6 : tensor<1x320x1x1xf32>
+}
+
+// CHECK-REDUCTION-LABEL:   func.func @dequant_avgpool(
+// CHECK-REDUCTION-SAME:      {
+// CHECK-REDUCTION:           scf.for
+// CHECK-REDUCTION-SAME:        {
+// CHECK-REDUCTION:             scf.for
+// CHECK-REDUCTION-SAME:          {
+// CHECK-REDUCTION:                 linalg.generic
+// CHECK-REDUCTION:                 %[[POOL:.+]] = linalg.pooling_nchw_sum
+// CHECK-REDUCTION:                 scf.yield %[[POOL]]
+// CHECK-REDUCTION:               }
+// CHECK-REDUCTION:             }
+// CHECK-REDUCTION:           }