[CPU] Block root op chain fusion in non-root anchored tiling. (#24098)

When tiling from a non-root anchor op at VectorInnerParallelTiles level,
there are two paths, because the intention is to tile dimensions that
are not captured by rootOp:
(a) Start from an anchor op before the root op.
(b) Start from an anchor op after the root op.

In both cases, we should not fuse rootOp into the anchor's tiled loop;
in case (a) we must also avoid fusing rootOp's transitive consumers,
and in case (b) its transitive producers.

The revision fixes the case where the reduction dimension of the root op
is not tiled because the root op was fused unintentionally; this results
in numeric issues. It can happen when the reduction dimension is much
smaller than the native vector sizes.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp
index 0d26f70..1295c8b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp
@@ -71,6 +71,52 @@
   return nullptr;
 }
 
+/// Returns the root op and all its transitive consumers in `computeOps`.
+static llvm::SmallDenseSet<Operation *>
+getRootAndTransitiveConsumers(ArrayRef<Operation *> computeOps,
+                              Operation *rootOp) {
+  llvm::SmallDenseSet<Operation *> result;
+  if (!rootOp) {
+    return result;
+  }
+  result.insert(rootOp);
+  for (Operation *op : computeOps) {
+    if (result.contains(op)) {
+      continue;
+    }
+    for (Value operand : op->getOperands()) {
+      if (auto *def = operand.getDefiningOp(); def && result.contains(def)) {
+        result.insert(op);
+        break;
+      }
+    }
+  }
+  return result;
+}
+
+/// Returns the root op and all its transitive producers in `computeOps`.
+static llvm::SmallDenseSet<Operation *>
+getRootAndTransitiveProducers(ArrayRef<Operation *> computeOps,
+                              Operation *rootOp) {
+  llvm::SmallDenseSet<Operation *> result;
+  if (!rootOp) {
+    return result;
+  }
+  result.insert(rootOp);
+  for (Operation *op : llvm::reverse(computeOps)) {
+    if (result.contains(op)) {
+      continue;
+    }
+    for (auto user : op->getUsers()) {
+      if (result.contains(user)) {
+        result.insert(op);
+        break;
+      }
+    }
+  }
+  return result;
+}
+
 /// Returns the last operation that has `level` tiling level in lowering config
 /// before the root op (or ukernel ops) in the compute sequence.
 static Operation *getLastAnchorOpBeforeRootOp(ArrayRef<Operation *> computeOps,
@@ -99,11 +145,13 @@
 /// the root operation and fuse the producers of the root operation then
 /// consumers (finds any missing fusion opportunities, then apply producer
 /// fusion). If `onlyFuseProducerInputOperands` is set, only fuse producer input
-/// operands.
-static FailureOr<Operation *>
-tileRootAndFuseProducerConsumer(IRRewriter &rewriter, TilingInterface rootOp,
-                                IREE::CPU::TilingLevel tilingLevel,
-                                bool onlyFuseProducerInputOperands) {
+/// operands. `unfusableOps` contains operations that must not be fused as
+/// consumers (e.g., root ops from other anchor chains whose reduction
+/// dimensions would be incorrectly tiled as parallel).
+static FailureOr<Operation *> tileRootAndFuseProducerConsumer(
+    IRRewriter &rewriter, TilingInterface rootOp,
+    IREE::CPU::TilingLevel tilingLevel, bool onlyFuseProducerInputOperands,
+    const llvm::SmallDenseSet<Operation *> &unfusableOps = {}) {
   auto *context = rewriter.getContext();
   mlir::DominanceInfo dominanceInfo(rootOp);
   llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
@@ -209,10 +257,12 @@
 
   if (!onlyFuseProducerInputOperands) {
     FailureOr<std::queue<Operation *>> newFusionOpportunities =
-        fuseConsumersIntoForall(rewriter, *rootTiledOp, tilingLoops,
-                                [&tiledAndFusedOps](Operation *op) {
-                                  return tiledAndFusedOps.contains(op);
-                                });
+        fuseConsumersIntoForall(
+            rewriter, *rootTiledOp, tilingLoops,
+            [&tiledAndFusedOps, &unfusableOps](Operation *op) {
+              return tiledAndFusedOps.contains(op) &&
+                     !unfusableOps.contains(op);
+            });
 
     if (failed(newFusionOpportunities)) {
       LDBG() << "failed to fuse consumers, skip";
@@ -258,34 +308,48 @@
   IRRewriter rewriter(funcOp);
 
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
-  SmallVector<Operation *> anchorOps;
-  if (anchorOnRootOp) {
-    Operation *anchorOp = getRootOp(computeOps, tilingLevel);
-    if (anchorOp) {
-      anchorOps.push_back(anchorOp);
-    }
 
+  Operation *rootOp =
+      getRootOp(computeOps, IREE::CPU::TilingLevel::DistributionTiles);
+
+  // Anchor op paired with the set of ops that must not be fused as consumers
+  // when tiling from that anchor. When anchoring before the root, the root
+  // and its transitive consumers are unfusable; when after, the root and its
+  // transitive producers are unfusable; when on the root itself, nothing is
+  // restricted.
+  struct AnchorInfo {
+    Operation *anchorOp;
+    llvm::SmallDenseSet<Operation *> unfusableOps;
+  };
+  SmallVector<AnchorInfo> anchors;
+
+  if (anchorOnRootOp) {
+    if (Operation *anchorOp = getRootOp(computeOps, tilingLevel)) {
+      anchors.push_back({anchorOp, {}});
+    }
   } else {
     if (Operation *anchorOp =
             getLastAnchorOpAfterRootOp(computeOps, tilingLevel)) {
-      anchorOps.push_back(anchorOp);
+      anchors.push_back(
+          {anchorOp, getRootAndTransitiveProducers(computeOps, rootOp)});
     }
     if (Operation *anchorOp =
             getLastAnchorOpBeforeRootOp(computeOps, tilingLevel)) {
-      anchorOps.push_back(anchorOp);
+      anchors.push_back(
+          {anchorOp, getRootAndTransitiveConsumers(computeOps, rootOp)});
     }
   }
-  if (anchorOps.empty()) {
+  if (anchors.empty()) {
     LDBG() << "unable to find an anchor operation that has "
            << IREE::CPU::getTilingLevelName(tilingLevel) << " config";
     return;
   }
 
-  for (auto anchorOp : anchorOps) {
+  for (auto &[anchorOp, unfusable] : anchors) {
     LDBG() << "anchorOp: " << *anchorOp;
     if (failed(tileRootAndFuseProducerConsumer(
             rewriter, cast<TilingInterface>(anchorOp), tilingLevel,
-            onlyFuseProducerInputOperands))) {
+            onlyFuseProducerInputOperands, unfusable))) {
       funcOp.emitError() << "tiling of level "
                          << IREE::CPU::getTilingLevelName(tilingLevel)
                          << " failed\n";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir
index 57fb529..409e529 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir
@@ -375,3 +375,81 @@
 // INNER-PARALLEL:           linalg.fill
 // INNER-PARALLEL:           linalg.mmt4d
 // INNER-PARALLEL:           iree_linalg_ext.map_store
+
+// -----
+
+// Verify that the root op chain (reduction and its consumers) is NOT fused
+// into the before-root anchor's parallel forall, and that the root's producers
+// are NOT fused into the after-root anchor's parallel forall.
+// Chain: elementwise -> reduction(root) -> broadcast -> pack
+// The elementwise is the before-root anchor (has vector_inner_parallel).
+// The broadcast is the after-root anchor (has vector_inner_parallel).
+#map_ew = affine_map<(d0, d1) -> (d0, d1)>
+#map_red = affine_map<(d0, d1) -> (d0)>
+#map_bcast_in = affine_map<(d0, d1) -> (d0)>
+#map_bcast_out = affine_map<(d0, d1) -> (d0, d1)>
+#config_ew = #iree_cpu.lowering_config<vector_common_parallel = [4, 0], vector_inner_parallel = [0, 4], vector_reduction = [0, 16]>
+#config_red = #iree_cpu.lowering_config<distribution = [4, 0], vector_common_parallel = [4, 0], vector_reduction = [0, 16]>
+#config_fill = #iree_cpu.lowering_config<vector_common_parallel = [4]>
+#config_bcast = #iree_cpu.lowering_config<vector_common_parallel = [4, 0], vector_inner_parallel = [0, 8]>
+#config_pack = #iree_cpu.lowering_config<vector_common_parallel = [1, 0], vector_inner_parallel = [0, 1]>
+func.func @no_fuse_reduction_into_parallel_forall(
+    %input: tensor<4x16xi8>,
+    %init_ew: tensor<4x16xi8>,
+    %init_red: tensor<4xi32>,
+    %init_bcast: tensor<4x8xi32>,
+    %init_pack: tensor<1x4x4x2xi32>) -> tensor<1x4x4x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c-128_i8 = arith.constant -128 : i8
+  %ew = linalg.generic {
+      indexing_maps = [#map_ew, #map_ew],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%input : tensor<4x16xi8>) outs(%init_ew : tensor<4x16xi8>)
+      attrs = {lowering_config = #config_ew} {
+    ^bb0(%in: i8, %out: i8):
+      %0 = arith.addi %in, %c-128_i8 : i8
+      linalg.yield %0 : i8
+  } -> tensor<4x16xi8>
+  %fill = linalg.fill {lowering_config = #config_fill}
+      ins(%c0_i32 : i32) outs(%init_red : tensor<4xi32>) -> tensor<4xi32>
+  %red = linalg.generic {
+      indexing_maps = [#map_ew, #map_red],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%ew : tensor<4x16xi8>) outs(%fill : tensor<4xi32>)
+      attrs = {lowering_config = #config_red} {
+    ^bb0(%in: i8, %out: i32):
+      %0 = arith.extsi %in : i8 to i32
+      %1 = arith.addi %0, %out : i32
+      linalg.yield %1 : i32
+  } -> tensor<4xi32>
+  %bcast = linalg.generic {
+      indexing_maps = [#map_bcast_in, #map_bcast_out],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%red : tensor<4xi32>) outs(%init_bcast : tensor<4x8xi32>)
+      attrs = {lowering_config = #config_bcast} {
+    ^bb0(%in: i32, %out: i32):
+      linalg.yield %in : i32
+  } -> tensor<4x8xi32>
+  %pack = linalg.pack %bcast outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+      inner_tiles = [4, 2] into %init_pack
+      {lowering_config = #config_pack}
+      : tensor<4x8xi32> -> tensor<1x4x4x2xi32>
+  return %pack : tensor<1x4x4x2xi32>
+}
+// Before-root anchor (elementwise): the root and its consumers must not be
+// fused into its forall.
+// After-root anchor (broadcast): the root and its producers must not be
+// fused into its forall.
+// INNER-PARALLEL-LABEL: func.func @no_fuse_reduction_into_parallel_forall
+//       INNER-PARALLEL:   scf.forall
+//       INNER-PARALLEL:     linalg.generic
+//       INNER-PARALLEL:     scf.forall.in_parallel
+//   INNER-PARALLEL-NOT:   scf.forall
+//       INNER-PARALLEL:   linalg.fill
+//       INNER-PARALLEL:   linalg.generic
+//  INNER-PARALLEL-SAME:     iterator_types = ["parallel", "reduction"]
+//       INNER-PARALLEL:   scf.forall
+//       INNER-PARALLEL:     linalg.generic
+//  INNER-PARALLEL-SAME:     iterator_types = ["parallel", "parallel"]
+//       INNER-PARALLEL:     linalg.pack
+//       INNER-PARALLEL:     scf.forall.in_parallel