[DispatchCreation] Allow fusion of multi-result producers (#24169)

Enables consumer fusion for multi-result producers like
`iree_linalg_ext.online_attention` whose results flow into a single
consumer via operands with different ranks (e.g. acc and sum of the
normalization `linalg.generic`).

Supports https://github.com/iree-org/iree/pull/24068

---------

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp
index 4f0d116..1182f49 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/LoopMappingUtils.cpp
@@ -48,6 +48,46 @@
   return llvm::SmallBitVector{};
 }
 
+/// Combines two composed iteration-space maps result-wise. A constant `0`
+/// is a broadcast placeholder and yields to a concrete expression; two
+/// different concrete expressions at the same position are a conflict. A
+/// null `accumulated` acts as the identity.
+///
+/// Example (merging results from two operands of the same consumer):
+///   accumulated = (d0, d1) -> (d0, 0)   // first operand only constrained d0
+///   incoming    = (d0, d1) -> (0, d1)   // second operand only constrained d1
+///   result      = (d0, d1) -> (d0, d1)  // concrete exprs override zeros
+///
+/// Conflict example:
+///   accumulated = (d0, d1) -> (d0, d1)
+///   incoming    = (d0, d1) -> (d1, d0)
+///   result      = failure  // d0 vs d1 at position 0 is a real disagreement
+static FailureOr<AffineMap> mergeComposedMaps(AffineMap accumulated,
+                                              AffineMap incoming) {
+  if (!accumulated || accumulated == incoming) {
+    return incoming;
+  }
+  auto isZero = [](AffineExpr e) {
+    auto c = dyn_cast<AffineConstantExpr>(e);
+    return c && c.getValue() == 0;
+  };
+
+  SmallVector<AffineExpr> merged;
+  merged.reserve(accumulated.getNumResults());
+  for (auto [accExpr, inExpr] :
+       llvm::zip_equal(accumulated.getResults(), incoming.getResults())) {
+    if (accExpr == inExpr || isZero(inExpr)) {
+      merged.push_back(accExpr);
+    } else if (isZero(accExpr)) {
+      merged.push_back(inExpr);
+    } else {
+      return failure();
+    }
+  }
+  return AffineMap::get(accumulated.getNumDims(), accumulated.getNumSymbols(),
+                        merged, accumulated.getContext());
+}
+
 static FailureOr<AffineMap>
 computeIterationSpaceMapping(AffineMap producerResultMap,
                              AffineMap consumerOperandMap,
@@ -110,7 +150,7 @@
 
       FailureOr<AffineMap> composedMap = computeIterationSpaceMapping(
           producerMap, consumerMap, producerLoopMap);
-      if (failed(composedMap) || (resultMap && composedMap != resultMap)) {
+      if (failed(composedMap)) {
         return failure();
       }
       // In the composed map, a result of 0 means that candidate dim does not
@@ -122,7 +162,16 @@
           composedMap->getNumResults() == composedMap->getNumOfZeroResults()) {
         return failure();
       }
-      resultMap = *composedMap;
+      // Multi-result producers (e.g. OnlineAttentionOp) can feed a single
+      // consumer via operands with different ranks, so two composed maps may
+      // each be valid while differing on broadcast-only dimensions. Merge them
+      // position-wise, letting concrete expressions override zero broadcast
+      // placeholders, and fail only on genuine conflicts.
+      FailureOr<AffineMap> merged = mergeComposedMaps(resultMap, *composedMap);
+      if (failed(merged)) {
+        return failure();
+      }
+      resultMap = *merged;
     }
   } else {
     // Compute mapping by examining consumer uses.
@@ -141,10 +190,14 @@
       // consumers.
       FailureOr<AffineMap> composedMap = computeIterationSpaceMapping(
           consumerMap, producerMap, consumerLoopMap);
-      if (failed(composedMap) || (resultMap && composedMap != resultMap)) {
+      if (failed(composedMap)) {
         return failure();
       }
-      resultMap = *composedMap;
+      FailureOr<AffineMap> merged = mergeComposedMaps(resultMap, *composedMap);
+      if (failed(merged)) {
+        return failure();
+      }
+      resultMap = *merged;
 
       // Producers cannot be more parallel than consumers.
       if (compressUnusedDims(resultMap).getNumDims() !=
diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
index f5cab75..0b4f77f 100644
--- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp
@@ -393,10 +393,20 @@
 static SmallVector<OpOperand *>
 getFusableUses(MLIRContext *context, Operation *op,
                DominanceInfo const &dominanceInfo, bool aggressiveFusion) {
-  if (!aggressiveFusion && llvm::count_if(op->getUses(), [](OpOperand &use) {
-                             return !isa<tensor::DimOp>(use.getOwner());
-                           }) != 1) {
-    return {};
+  // In non-aggressive mode, restrict fusion to producers whose results flow
+  // to a single consumer. Count distinct consumers rather than operand uses
+  // so that a single consumer reading multiple results from a multi-result
+  // producer (e.g. OnlineAttentionOp) still qualifies.
+  if (!aggressiveFusion) {
+    llvm::SetVector<Operation *> consumers;
+    for (Operation *user : op->getUsers()) {
+      if (!isa<tensor::DimOp>(user)) {
+        consumers.insert(user);
+      }
+    }
+    if (consumers.size() != 1) {
+      return {};
+    }
   }
 
   // Collect all fusable user candidates.
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
index 088afac..f350e73 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir
@@ -1,5 +1,6 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}))" --split-input-file %s | FileCheck %s
 // RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true fuse-multi-use-producers=false}))" --split-input-file %s | FileCheck %s --check-prefix=NO-MULTI-USE
+// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions))" --split-input-file %s | FileCheck %s --check-prefix=DEFAULT
 
 util.func public @pack_elementwise_fusion(%arg0 : tensor<?xf32>,
     %arg1 : tensor<?x?xf32>) -> tensor<?x?x8x32xf32> {
@@ -2798,11 +2799,206 @@
 
 // -----
 
+// A single consumer reading multiple results of a multi-result producer via
+// operands with different ranks (acc is 3D, sum is 2D) should fuse.
+
+util.func public @online_attention_normalize_fusion(
+    %Q: tensor<20x4096x16xf16>,
+    %K: tensor<20x1024x16xf16>,
+    %V: tensor<20x1024x64xf16>) -> tensor<20x4096x64xf16> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_neg = arith.constant -3.40282347E+38 : f32
+  %cst_one = arith.constant 1.000000e+00 : f32
+  %scale = arith.constant 1.0 : f16
+  %acc_e = tensor.empty() : tensor<20x4096x64xf32>
+  %ms_e = tensor.empty() : tensor<20x4096xf32>
+  %out_e = tensor.empty() : tensor<20x4096x64xf16>
+  %acc = linalg.fill ins(%cst : f32) outs(%acc_e : tensor<20x4096x64xf32>) -> tensor<20x4096x64xf32>
+  %max = linalg.fill ins(%cst_neg : f32) outs(%ms_e : tensor<20x4096xf32>) -> tensor<20x4096xf32>
+  %sum = linalg.fill ins(%cst : f32) outs(%ms_e : tensor<20x4096xf32>) -> tensor<20x4096xf32>
+  %r:3 = iree_linalg_ext.online_attention {
+    indexing_maps = [
+      affine_map<(b, m, n, k1, k2) -> (b, m, k1)>,
+      affine_map<(b, m, n, k1, k2) -> (b, k2, k1)>,
+      affine_map<(b, m, n, k1, k2) -> (b, k2, n)>,
+      affine_map<(b, m, n, k1, k2) -> ()>,
+      affine_map<(b, m, n, k1, k2) -> (b, m, n)>,
+      affine_map<(b, m, n, k1, k2) -> (b, m)>,
+      affine_map<(b, m, n, k1, k2) -> (b, m)>
+    ]
+  } ins(%Q, %K, %V, %scale : tensor<20x4096x16xf16>, tensor<20x1024x16xf16>, tensor<20x1024x64xf16>, f16)
+    outs(%acc, %max, %sum : tensor<20x4096x64xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>) {
+  ^bb0(%score: f32):
+    iree_linalg_ext.yield %score : f32
+  } -> tensor<20x4096x64xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>
+  %norm = linalg.generic {
+    indexing_maps = [
+      affine_map<(b, m, n) -> (b, m, n)>,
+      affine_map<(b, m, n) -> (b, m)>,
+      affine_map<(b, m, n) -> (b, m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%r#0, %r#2 : tensor<20x4096x64xf32>, tensor<20x4096xf32>)
+    outs(%out_e : tensor<20x4096x64xf16>) {
+  ^bb0(%a: f32, %s: f32, %_: f16):
+    %inv = arith.divf %cst_one, %s : f32
+    %v = arith.mulf %a, %inv : f32
+    %v16 = arith.truncf %v : f32 to f16
+    linalg.yield %v16 : f16
+  } -> tensor<20x4096x64xf16>
+  util.return %norm : tensor<20x4096x64xf16>
+}
+
+//      DEFAULT-LABEL: @online_attention_normalize_fusion
+//            DEFAULT:   %[[D:.+]] = flow.dispatch.region -> (tensor<20x4096x64xf16>)
+//            DEFAULT:     %[[R:.+]]:3 = iree_linalg_ext.online_attention
+//            DEFAULT:     %[[G:.+]] = linalg.generic
+//       DEFAULT-SAME:       ins(%[[R]]#0, %[[R]]#2
+//            DEFAULT:     flow.return %[[G]]
+//            DEFAULT:   util.return %[[D]]
+
+// -----
+
+// Multi-result producer whose outputs use differently-permuted indexing maps.
+// The two composed iteration-space maps land concrete (non-broadcast)
+// expressions at the same positions, so mergeComposedMaps must reject the
+// merge and producer fusion must not happen.
+
+util.func public @multi_result_conflict_no_fusion(%x: tensor<16x16xf32>) -> tensor<16x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %e0 = tensor.empty() : tensor<16x16xf32>
+  %e1 = tensor.empty() : tensor<16x16xf32>
+  %out_e = tensor.empty() : tensor<16x16xf32>
+  %out = linalg.fill ins(%cst : f32) outs(%out_e : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %p:2 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d1, d0)>
+    ],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%x : tensor<16x16xf32>)
+    outs(%e0, %e1 : tensor<16x16xf32>, tensor<16x16xf32>) {
+  ^bb0(%in: f32, %o0: f32, %o1: f32):
+    linalg.yield %in, %in : f32, f32
+  } -> (tensor<16x16xf32>, tensor<16x16xf32>)
+  %m = linalg.matmul ins(%p#0, %p#1 : tensor<16x16xf32>, tensor<16x16xf32>)
+                    outs(%out : tensor<16x16xf32>) -> tensor<16x16xf32>
+  util.return %m : tensor<16x16xf32>
+}
+
+//      DEFAULT-LABEL: @multi_result_conflict_no_fusion
+//            DEFAULT:   %[[P:.+]]:2 = flow.dispatch.region
+//            DEFAULT:     linalg.generic
+//            DEFAULT:   flow.dispatch.region
+//            DEFAULT:     linalg.matmul ins(%[[P]]#0, %[[P]]#1
+//            DEFAULT:     flow.return
+
+// -----
+
+// Single-result producer whose only result feeds a single consumer via two
+// operand slots. The old count_if(non-DimOp uses) != 1 check bailed because
+// there are two uses; counting distinct consumers instead lets this fuse.
+
+util.func public @single_result_two_uses_same_consumer_default(
+    %a : tensor<?x?xf32>, %b : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %b, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %b, %c1 : tensor<?x?xf32>
+  %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %m = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%m, %m : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%empty : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %out : f32):
+    %s = arith.addf %b0, %b1 : f32
+    linalg.yield %s : f32
+  } -> tensor<?x?xf32>
+  util.return %r : tensor<?x?xf32>
+}
+
+//      DEFAULT-LABEL: @single_result_two_uses_same_consumer_default
+//            DEFAULT:   %[[D:.+]] = flow.dispatch.region
+//            DEFAULT:     %[[M:.+]] = linalg.matmul
+//            DEFAULT:     %[[G:.+]] = linalg.generic
+//       DEFAULT-SAME:       ins(%[[M]], %[[M]]
+//            DEFAULT:     flow.return %[[G]]
+//            DEFAULT:   util.return %[[D]]
+
+// -----
+
+// Two distinct producers absorbed into the same fusion group feed a single
+// consumer; the two composed loop maps differ on broadcast-zero positions,
+// so mergeComposedMaps merges rather than bit-comparing.
+
+util.func public @multi_producer_consumer_fusion(
+    %x: tensor<16x16xf32>, %y: tensor<16x16xf32>) -> tensor<16x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %eA = tensor.empty() : tensor<16x16xf32>
+  %eB = tensor.empty() : tensor<16x16xf32>
+  %eOut = tensor.empty() : tensor<16x16xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%eOut : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %pA = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%x : tensor<16x16xf32>) outs(%eA : tensor<16x16xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %s = arith.addf %in, %cst : f32
+    linalg.yield %s : f32
+  } -> tensor<16x16xf32>
+  %pB = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%y : tensor<16x16xf32>) outs(%eB : tensor<16x16xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %s = arith.addf %in, %cst : f32
+    linalg.yield %s : f32
+  } -> tensor<16x16xf32>
+  %m = linalg.matmul ins(%pA, %pB : tensor<16x16xf32>, tensor<16x16xf32>)
+                    outs(%fill : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%m, %pA : tensor<16x16xf32>, tensor<16x16xf32>)
+    outs(%eOut : tensor<16x16xf32>) {
+  ^bb0(%a: f32, %b: f32, %out: f32):
+    %s = arith.addf %a, %b : f32
+    linalg.yield %s : f32
+  } -> tensor<16x16xf32>
+  util.return %r : tensor<16x16xf32>
+}
+
+// Aggressive-fusion absorbs %pA into the matmul's group, so the consumer
+// generic at the end sees two operands from two different in-group
+// producers (%m and %pA). mergeComposedMaps must succeed on the two
+// composed maps for the fusion to proceed.
+//      CHECK-LABEL: @multi_producer_consumer_fusion
+//            CHECK:   flow.dispatch.region
+//            CHECK:     %[[PA:.+]] = linalg.generic
+//            CHECK:     %[[M:.+]] = linalg.matmul
+//       CHECK-SAME:       ins(%[[PA]], %{{.+}}
+//            CHECK:     %[[R:.+]] = linalg.generic
+//       CHECK-SAME:       ins(%[[M]], %[[PA]]
+//            CHECK:     flow.return %[[R]]
+
+// -----
 // Multi-result producer feeds a pack consumer via two operands. Producer
 // result 0 has a non-identity indexing map while result 1 has an identity
 // map, so the pack-identity per-operand check fails on the source operand
 // but passes on the dest operand. Fusion must be rejected.
 
 util.func public @pack_per_operand_rejection(
     %a: tensor<1x1xf32>, %b: tensor<1x1xf32>) -> tensor<1x1x1x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32