[CPU] Block root op chain fusion in non-root anchored tiling. (#24098)
When tiling from a non-root anchor op at VectorInnerParallelTiles level,
there are two paths, because the intention is to tile dimensions that
are not captured by the rootOp:
(a) Start from an anchor op before the root op.
(b) Start from an anchor op after the root op.
In both cases, we should not fuse the rootOp chain: for (a), the rootOp
and its transitive consumers; for (b), the rootOp and its transitive
producers.
The revision fixes the case where the reduction dimension is not tiled
on the root op because the root op was unintentionally fused; as a
result, it produces numerical issues. This can happen when the reduction
dimension is much smaller than the native vector sizes.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp
index 0d26f70..1295c8b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileAndFuseProducerConsumer.cpp
@@ -71,6 +71,52 @@
return nullptr;
}
+/// Returns the root op and all its transitive consumers in `computeOps`.
+static llvm::SmallDenseSet<Operation *>
+getRootAndTransitiveConsumers(ArrayRef<Operation *> computeOps,
+ Operation *rootOp) {
+ llvm::SmallDenseSet<Operation *> result;
+ if (!rootOp) {
+ return result;
+ }
+ result.insert(rootOp);
+ for (Operation *op : computeOps) {
+ if (result.contains(op)) {
+ continue;
+ }
+ for (Value operand : op->getOperands()) {
+ if (auto *def = operand.getDefiningOp(); def && result.contains(def)) {
+ result.insert(op);
+ break;
+ }
+ }
+ }
+ return result;
+}
+
+/// Returns the root op and all its transitive producers in `computeOps`.
+static llvm::SmallDenseSet<Operation *>
+getRootAndTransitiveProducers(ArrayRef<Operation *> computeOps,
+ Operation *rootOp) {
+ llvm::SmallDenseSet<Operation *> result;
+ if (!rootOp) {
+ return result;
+ }
+ result.insert(rootOp);
+ for (Operation *op : llvm::reverse(computeOps)) {
+ if (result.contains(op)) {
+ continue;
+ }
+ for (auto user : op->getUsers()) {
+ if (result.contains(user)) {
+ result.insert(op);
+ break;
+ }
+ }
+ }
+ return result;
+}
+
/// Returns the last operation that has `level` tiling level in lowering config
/// before the root op (or ukernel ops) in the compute sequence.
static Operation *getLastAnchorOpBeforeRootOp(ArrayRef<Operation *> computeOps,
@@ -99,11 +145,13 @@
/// the root operation and fuse the producers of the root operation then
/// consumers (finds any missing fusion opportunities, then apply producer
/// fusion). If `onlyFuseProducerInputOperands` is set, only fuse producer input
-/// operands.
-static FailureOr<Operation *>
-tileRootAndFuseProducerConsumer(IRRewriter &rewriter, TilingInterface rootOp,
- IREE::CPU::TilingLevel tilingLevel,
- bool onlyFuseProducerInputOperands) {
+/// operands. `unfusableOps` contains operations that must not be fused as
+/// consumers (e.g., root ops from other anchor chains whose reduction
+/// dimensions would be incorrectly tiled as parallel).
+static FailureOr<Operation *> tileRootAndFuseProducerConsumer(
+ IRRewriter &rewriter, TilingInterface rootOp,
+ IREE::CPU::TilingLevel tilingLevel, bool onlyFuseProducerInputOperands,
+ const llvm::SmallDenseSet<Operation *> &unfusableOps = {}) {
auto *context = rewriter.getContext();
mlir::DominanceInfo dominanceInfo(rootOp);
llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
@@ -209,10 +257,12 @@
if (!onlyFuseProducerInputOperands) {
FailureOr<std::queue<Operation *>> newFusionOpportunities =
- fuseConsumersIntoForall(rewriter, *rootTiledOp, tilingLoops,
- [&tiledAndFusedOps](Operation *op) {
- return tiledAndFusedOps.contains(op);
- });
+ fuseConsumersIntoForall(
+ rewriter, *rootTiledOp, tilingLoops,
+ [&tiledAndFusedOps, &unfusableOps](Operation *op) {
+ return tiledAndFusedOps.contains(op) &&
+ !unfusableOps.contains(op);
+ });
if (failed(newFusionOpportunities)) {
LDBG() << "failed to fuse consumers, skip";
@@ -258,34 +308,48 @@
IRRewriter rewriter(funcOp);
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
- SmallVector<Operation *> anchorOps;
- if (anchorOnRootOp) {
- Operation *anchorOp = getRootOp(computeOps, tilingLevel);
- if (anchorOp) {
- anchorOps.push_back(anchorOp);
- }
+ Operation *rootOp =
+ getRootOp(computeOps, IREE::CPU::TilingLevel::DistributionTiles);
+
+ // Anchor op paired with the set of ops that must not be fused as consumers
+ // when tiling from that anchor. When anchoring before the root, the root
+ // and its transitive consumers are unfusable; when after, the root and its
+ // transitive producers are unfusable; when on the root itself, nothing is
+ // restricted.
+ struct AnchorInfo {
+ Operation *anchorOp;
+ llvm::SmallDenseSet<Operation *> unfusableOps;
+ };
+ SmallVector<AnchorInfo> anchors;
+
+ if (anchorOnRootOp) {
+ if (Operation *anchorOp = getRootOp(computeOps, tilingLevel)) {
+ anchors.push_back({anchorOp, {}});
+ }
} else {
if (Operation *anchorOp =
getLastAnchorOpAfterRootOp(computeOps, tilingLevel)) {
- anchorOps.push_back(anchorOp);
+ anchors.push_back(
+ {anchorOp, getRootAndTransitiveProducers(computeOps, rootOp)});
}
if (Operation *anchorOp =
getLastAnchorOpBeforeRootOp(computeOps, tilingLevel)) {
- anchorOps.push_back(anchorOp);
+ anchors.push_back(
+ {anchorOp, getRootAndTransitiveConsumers(computeOps, rootOp)});
}
}
- if (anchorOps.empty()) {
+ if (anchors.empty()) {
LDBG() << "unable to find an anchor operation that has "
<< IREE::CPU::getTilingLevelName(tilingLevel) << " config";
return;
}
- for (auto anchorOp : anchorOps) {
+ for (auto &[anchorOp, unfusable] : anchors) {
LDBG() << "anchorOp: " << *anchorOp;
if (failed(tileRootAndFuseProducerConsumer(
rewriter, cast<TilingInterface>(anchorOp), tilingLevel,
- onlyFuseProducerInputOperands))) {
+ onlyFuseProducerInputOperands, unfusable))) {
funcOp.emitError() << "tiling of level "
<< IREE::CPU::getTilingLevelName(tilingLevel)
<< " failed\n";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir
index 57fb529..409e529 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/tile_and_fuse_producer_consumer_anchoring_non_root_op.mlir
@@ -375,3 +375,81 @@
// INNER-PARALLEL: linalg.fill
// INNER-PARALLEL: linalg.mmt4d
// INNER-PARALLEL: iree_linalg_ext.map_store
+
+// -----
+
+// Verify that the root op chain (reduction and its consumers) is NOT fused
+// into the before-root anchor's parallel forall, and that the root's producers
+// are NOT fused into the after-root anchor's parallel forall.
+// Chain: elementwise -> reduction(root) -> broadcast -> pack
+// The elementwise is the before-root anchor (has vector_inner_parallel).
+// The broadcast is the after-root anchor (has vector_inner_parallel).
+#map_ew = affine_map<(d0, d1) -> (d0, d1)>
+#map_red = affine_map<(d0, d1) -> (d0)>
+#map_bcast_in = affine_map<(d0, d1) -> (d0)>
+#map_bcast_out = affine_map<(d0, d1) -> (d0, d1)>
+#config_ew = #iree_cpu.lowering_config<vector_common_parallel = [4, 0], vector_inner_parallel = [0, 4], vector_reduction = [0, 16]>
+#config_red = #iree_cpu.lowering_config<distribution = [4, 0], vector_common_parallel = [4, 0], vector_reduction = [0, 16]>
+#config_fill = #iree_cpu.lowering_config<vector_common_parallel = [4]>
+#config_bcast = #iree_cpu.lowering_config<vector_common_parallel = [4, 0], vector_inner_parallel = [0, 8]>
+#config_pack = #iree_cpu.lowering_config<vector_common_parallel = [1, 0], vector_inner_parallel = [0, 1]>
+func.func @no_fuse_reduction_into_parallel_forall(
+ %input: tensor<4x16xi8>,
+ %init_ew: tensor<4x16xi8>,
+ %init_red: tensor<4xi32>,
+ %init_bcast: tensor<4x8xi32>,
+ %init_pack: tensor<1x4x4x2xi32>) -> tensor<1x4x4x2xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %c-128_i8 = arith.constant -128 : i8
+ %ew = linalg.generic {
+ indexing_maps = [#map_ew, #map_ew],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%input : tensor<4x16xi8>) outs(%init_ew : tensor<4x16xi8>)
+ attrs = {lowering_config = #config_ew} {
+ ^bb0(%in: i8, %out: i8):
+ %0 = arith.addi %in, %c-128_i8 : i8
+ linalg.yield %0 : i8
+ } -> tensor<4x16xi8>
+ %fill = linalg.fill {lowering_config = #config_fill}
+ ins(%c0_i32 : i32) outs(%init_red : tensor<4xi32>) -> tensor<4xi32>
+ %red = linalg.generic {
+ indexing_maps = [#map_ew, #map_red],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%ew : tensor<4x16xi8>) outs(%fill : tensor<4xi32>)
+ attrs = {lowering_config = #config_red} {
+ ^bb0(%in: i8, %out: i32):
+ %0 = arith.extsi %in : i8 to i32
+ %1 = arith.addi %0, %out : i32
+ linalg.yield %1 : i32
+ } -> tensor<4xi32>
+ %bcast = linalg.generic {
+ indexing_maps = [#map_bcast_in, #map_bcast_out],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%red : tensor<4xi32>) outs(%init_bcast : tensor<4x8xi32>)
+ attrs = {lowering_config = #config_bcast} {
+ ^bb0(%in: i32, %out: i32):
+ linalg.yield %in : i32
+ } -> tensor<4x8xi32>
+ %pack = linalg.pack %bcast outer_dims_perm = [0, 1] inner_dims_pos = [0, 1]
+ inner_tiles = [4, 2] into %init_pack
+ {lowering_config = #config_pack}
+ : tensor<4x8xi32> -> tensor<1x4x4x2xi32>
+ return %pack : tensor<1x4x4x2xi32>
+}
+// Before-root anchor (elementwise): the root and its consumers must not be
+// fused into its forall.
+// After-root anchor (broadcast): the root and its producers must not be
+// fused into its forall.
+// INNER-PARALLEL-LABEL: func.func @no_fuse_reduction_into_parallel_forall
+// INNER-PARALLEL: scf.forall
+// INNER-PARALLEL: linalg.generic
+// INNER-PARALLEL: scf.forall.in_parallel
+// INNER-PARALLEL-NOT: scf.forall
+// INNER-PARALLEL: linalg.fill
+// INNER-PARALLEL: linalg.generic
+// INNER-PARALLEL-SAME: iterator_types = ["parallel", "reduction"]
+// INNER-PARALLEL: scf.forall
+// INNER-PARALLEL: linalg.generic
+// INNER-PARALLEL-SAME: iterator_types = ["parallel", "parallel"]
+// INNER-PARALLEL: linalg.pack
+// INNER-PARALLEL: scf.forall.in_parallel