[DispatchCreation] Allow fusion of multi-result producers (#24169)
Enables consumer fusion for multi-result producers like
`iree_linalg_ext.online_attention` whose results flow into a single
consumer via operands with different ranks (e.g. acc and sum of the
normalization `linalg.generic`).
Supports https://github.com/iree-org/iree/pull/24068
---------
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp
index 4f0d116..1182f49 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp
@@ -48,6 +48,46 @@
return llvm::SmallBitVector{};
}
+/// Combines two composed iteration-space maps result-wise. A constant `0`
+/// is a broadcast placeholder and yields to a concrete expression; two
+/// different concrete expressions at the same position are a conflict. A
+/// null `accumulated` acts as the identity.
+///
+/// Example (merging results from two operands of the same consumer):
+/// accumulated = (d0, d1) -> (d0, 0) // first operand only constrained d0
+/// incoming = (d0, d1) -> (0, d1) // second operand only constrained d1
+/// result = (d0, d1) -> (d0, d1) // concrete exprs override zeros
+///
+/// Conflict example:
+/// accumulated = (d0, d1) -> (d0, d1)
+/// incoming = (d0, d1) -> (d1, d0)
+/// result = failure // d0 vs d1 at position 0 is a real disagreement
+static FailureOr<AffineMap> mergeComposedMaps(AffineMap accumulated,
+ AffineMap incoming) {
+ if (!accumulated || accumulated == incoming) {
+ return incoming;
+ }
+ auto isZero = [](AffineExpr e) {
+ auto c = dyn_cast<AffineConstantExpr>(e);
+ return c && c.getValue() == 0;
+ };
+
+ SmallVector<AffineExpr> merged;
+ merged.reserve(accumulated.getNumResults());
+ for (auto [accExpr, inExpr] :
+ llvm::zip_equal(accumulated.getResults(), incoming.getResults())) {
+ if (accExpr == inExpr || isZero(inExpr)) {
+ merged.push_back(accExpr);
+ } else if (isZero(accExpr)) {
+ merged.push_back(inExpr);
+ } else {
+ return failure();
+ }
+ }
+ return AffineMap::get(accumulated.getNumDims(), accumulated.getNumSymbols(),
+ merged, accumulated.getContext());
+}
+
static FailureOr<AffineMap>
computeIterationSpaceMapping(AffineMap producerResultMap,
AffineMap consumerOperandMap,
@@ -110,7 +150,7 @@
FailureOr<AffineMap> composedMap = computeIterationSpaceMapping(
producerMap, consumerMap, producerLoopMap);
- if (failed(composedMap) || (resultMap && composedMap != resultMap)) {
+ if (failed(composedMap)) {
return failure();
}
// In the composed map, a result of 0 means that candidate dim does not
@@ -122,7 +162,16 @@
composedMap->getNumResults() == composedMap->getNumOfZeroResults()) {
return failure();
}
- resultMap = *composedMap;
+ // Multi-result producers (e.g. OnlineAttentionOp) can feed a single
+ // consumer via operands with different ranks, so two composed maps may
+ // each be valid while differing on broadcast-only dimensions. Merge them
+ // position-wise, letting concrete expressions override zero broadcast
+ // placeholders, and fail only on genuine conflicts.
+ FailureOr<AffineMap> merged = mergeComposedMaps(resultMap, *composedMap);
+ if (failed(merged)) {
+ return failure();
+ }
+ resultMap = *merged;
}
} else {
// Compute mapping by examining consumer uses.
@@ -141,10 +190,14 @@
// consumers.
FailureOr<AffineMap> composedMap = computeIterationSpaceMapping(
consumerMap, producerMap, consumerLoopMap);
- if (failed(composedMap) || (resultMap && composedMap != resultMap)) {
+ if (failed(composedMap)) {
return failure();
}
- resultMap = *composedMap;
+ FailureOr<AffineMap> merged = mergeComposedMaps(resultMap, *composedMap);
+ if (failed(merged)) {
+ return failure();
+ }
+ resultMap = *merged;
// Producers cannot be more parallel than consumers.
if (compressUnusedDims(resultMap).getNumDims() !=
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index f5cab75..0b4f77f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -393,10 +393,20 @@
static SmallVector<OpOperand *>
getFusableUses(MLIRContext *context, Operation *op,
DominanceInfo const &dominanceInfo, bool aggressiveFusion) {
- if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
- return !isa<tensor::DimOp>(use.getOwner());
- }) != 1) {
- return {};
+ // In non-aggressive mode, restrict fusion to producers whose results flow
+ // to a single consumer. Count distinct consumers rather than operand uses
+ // so that a single consumer reading multiple results from a multi-result
+ // producer (e.g. OnlineAttentionOp) still qualifies.
+ if (!aggressiveFusion) {
+ llvm::SetVector<Operation *> consumers;
+ for (Operation *user : op->getUsers()) {
+ if (!isa<tensor::DimOp>(user)) {
+ consumers.insert(user);
+ }
+ }
+ if (consumers.size() != 1) {
+ return {};
+ }
}
// Collect all fusable user candidates.
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 088afac..f350e73 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -1,5 +1,6 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true fuse-multi-use-producers=false}))" --split-input-file %s | FileCheck %s --check-prefix=NO-MULTI-USE
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions))" --split-input-file %s | FileCheck %s --check-prefix=DEFAULT
util.func public @pack_elementwise_fusion(%arg0 : tensor<?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?x8x32xf32> {
@@ -2798,11 +2799,206 @@
// -----
+// A single consumer reading multiple results of a multi-result producer via
+// operands with different ranks (acc is 3D, sum is 2D) should fuse.
+
+util.func public @online_attention_normalize_fusion(
+ %Q: tensor<20x4096x16xf16>,
+ %K: tensor<20x1024x16xf16>,
+ %V: tensor<20x1024x64xf16>) -> tensor<20x4096x64xf16> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %cst_neg = arith.constant -3.40282347E+38 : f32
+ %cst_one = arith.constant 1.000000e+00 : f32
+ %scale = arith.constant 1.0 : f16
+ %acc_e = tensor.empty() : tensor<20x4096x64xf32>
+ %ms_e = tensor.empty() : tensor<20x4096xf32>
+ %out_e = tensor.empty() : tensor<20x4096x64xf16>
+ %acc = linalg.fill ins(%cst : f32) outs(%acc_e : tensor<20x4096x64xf32>) -> tensor<20x4096x64xf32>
+ %max = linalg.fill ins(%cst_neg : f32) outs(%ms_e : tensor<20x4096xf32>) -> tensor<20x4096xf32>
+ %sum = linalg.fill ins(%cst : f32) outs(%ms_e : tensor<20x4096xf32>) -> tensor<20x4096xf32>
+ %r:3 = iree_linalg_ext.online_attention {
+ indexing_maps = [
+ affine_map<(b, m, n, k1, k2) -> (b, m, k1)>,
+ affine_map<(b, m, n, k1, k2) -> (b, k2, k1)>,
+ affine_map<(b, m, n, k1, k2) -> (b, k2, n)>,
+ affine_map<(b, m, n, k1, k2) -> ()>,
+ affine_map<(b, m, n, k1, k2) -> (b, m, n)>,
+ affine_map<(b, m, n, k1, k2) -> (b, m)>,
+ affine_map<(b, m, n, k1, k2) -> (b, m)>
+ ]
+ } ins(%Q, %K, %V, %scale : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
+ outs(%acc, %max, %sum : tensor<20x4096x64xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score : f32
+ } -> tensor<20x4096x64xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>
+ %norm = linalg.generic {
+ indexing_maps = [
+ affine_map<(b, m, n) -> (b, m, n)>,
+ affine_map<(b, m, n) -> (b, m)>,
+ affine_map<(b, m, n) -> (b, m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } ins(%r#0, %r#2 : tensor<20x4096x64xf32>, tensor<20x4096xf32>)
+ outs(%out_e : tensor<20x4096x64xf16>) {
+ ^bb0(%a: f32, %s: f32, %_: f16):
+ %inv = arith.divf %cst_one, %s : f32
+ %v = arith.mulf %a, %inv : f32
+ %v16 = arith.truncf %v : f32 to f16
+ linalg.yield %v16 : f16
+ } -> tensor<20x4096x64xf16>
+ util.return %norm : tensor<20x4096x64xf16>
+}
+
+// DEFAULT-LABEL: @online_attention_normalize_fusion
+// DEFAULT: %[[D:.+]] = flow.dispatch.region -> (tensor<20x4096x64xf16>)
+// DEFAULT: %[[R:.+]]:3 = iree_linalg_ext.online_attention
+// DEFAULT: %[[G:.+]] = linalg.generic
+// DEFAULT-SAME: ins(%[[R]]#0, %[[R]]#2
+// DEFAULT: flow.return %[[G]]
+// DEFAULT: util.return %[[D]]
+
+// -----
+
+// Multi-result producer whose outputs use differently-permuted indexing maps.
+// The two composed iteration-space maps land concrete (non-broadcast)
+// expressions at the same positions, so mergeComposedMaps must reject the
+// merge and producer fusion must not happen.
+
+util.func public @multi_result_conflict_no_fusion(%x: tensor<16x16xf32>) -> tensor<16x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %e0 = tensor.empty() : tensor<16x16xf32>
+ %e1 = tensor.empty() : tensor<16x16xf32>
+ %out_e = tensor.empty() : tensor<16x16xf32>
+ %out = linalg.fill ins(%cst : f32) outs(%out_e : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %p:2 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1, d0)>
+ ],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%x : tensor<16x16xf32>)
+ outs(%e0, %e1 : tensor<16x16xf32>, tensor<16x16xf32>) {
+ ^bb0(%in: f32, %o0: f32, %o1: f32):
+ linalg.yield %in, %in : f32, f32
+ } -> (tensor<16x16xf32>, tensor<16x16xf32>)
+ %m = linalg.matmul ins(%p#0, %p#1 : tensor<16x16xf32>, tensor<16x16xf32>)
+ outs(%out : tensor<16x16xf32>) -> tensor<16x16xf32>
+ util.return %m : tensor<16x16xf32>
+}
+
+// DEFAULT-LABEL: @multi_result_conflict_no_fusion
+// DEFAULT: %[[P:.+]]:2 = flow.dispatch.region
+// DEFAULT: linalg.generic
+// DEFAULT: flow.dispatch.region
+// DEFAULT: linalg.matmul ins(%[[P]]#0, %[[P]]#1
+// DEFAULT: flow.return
+
+// -----
+
+// Single-result producer whose only result feeds a single consumer via two
+// operand slots. The old count_if(non-DimOp uses) != 1 check bailed because
+// there are two uses; counting distinct consumers instead lets this fuse.
+
+util.func public @single_result_two_uses_same_consumer_default(
+ %a : tensor<?x?xf32>, %b : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %b, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %b, %c1 : tensor<?x?xf32>
+ %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %m = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %r = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%m, %m : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%empty : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32, %out : f32):
+ %s = arith.addf %b0, %b1 : f32
+ linalg.yield %s : f32
+ } -> tensor<?x?xf32>
+ util.return %r : tensor<?x?xf32>
+}
+
+// DEFAULT-LABEL: @single_result_two_uses_same_consumer_default
+// DEFAULT: %[[D:.+]] = flow.dispatch.region
+// DEFAULT: %[[M:.+]] = linalg.matmul
+// DEFAULT: %[[G:.+]] = linalg.generic
+// DEFAULT-SAME: ins(%[[M]], %[[M]]
+// DEFAULT: flow.return %[[G]]
+// DEFAULT: util.return %[[D]]
+
+// -----
+
+// Two distinct producers absorbed into the same fusion group feed a single
+// consumer; the two composed loop maps differ on broadcast-zero positions,
+// so mergeComposedMaps merges rather than bit-comparing.
+
+util.func public @multi_producer_consumer_fusion(
+ %x: tensor<16x16xf32>, %y: tensor<16x16xf32>) -> tensor<16x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %eA = tensor.empty() : tensor<16x16xf32>
+ %eB = tensor.empty() : tensor<16x16xf32>
+ %eOut = tensor.empty() : tensor<16x16xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%eOut : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %pA = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%x : tensor<16x16xf32>) outs(%eA : tensor<16x16xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %s = arith.addf %in, %cst : f32
+ linalg.yield %s : f32
+ } -> tensor<16x16xf32>
+ %pB = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%y : tensor<16x16xf32>) outs(%eB : tensor<16x16xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %s = arith.addf %in, %cst : f32
+ linalg.yield %s : f32
+ } -> tensor<16x16xf32>
+ %m = linalg.matmul ins(%pA, %pB : tensor<16x16xf32>, tensor<16x16xf32>)
+ outs(%fill : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %r = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%m, %pA : tensor<16x16xf32>, tensor<16x16xf32>)
+ outs(%eOut : tensor<16x16xf32>) {
+ ^bb0(%a: f32, %b: f32, %out: f32):
+ %s = arith.addf %a, %b : f32
+ linalg.yield %s : f32
+ } -> tensor<16x16xf32>
+ util.return %r : tensor<16x16xf32>
+}
+
+// Aggressive-fusion absorbs %pA into the matmul's group, so the consumer
+// generic at the end sees two operands from two different in-group
+// producers (%m and %pA). mergeComposedMaps must succeed on the two
+// composed maps for the fusion to proceed.
+// CHECK-LABEL: @multi_producer_consumer_fusion
+// CHECK: flow.dispatch.region
+// CHECK: %[[PA:.+]] = linalg.generic
+// CHECK: %[[M:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[PA]], %{{.+}}
+// CHECK: %[[R:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[M]], %[[PA]]
+// CHECK: flow.return %[[R]]
+
+// -----
// Multi-result producer feeds a pack consumer via two operands. Producer
// result 0 has a non-identity indexing map while result 1 has an identity
// map, so the pack-identity per-operand check fails on the source operand
// but passes on the dest operand. Fusion must be rejected.
util.func public @pack_per_operand_rejection(
%a: tensor<1x1xf32>, %b: tensor<1x1xf32>) -> tensor<1x1x1x1xf32> {
%cst = arith.constant 0.000000e+00 : f32