Fix ExpandDestinationForallOp rejecting non-store users (#24073)
The pattern only needs to validate store users for expandability.
Non-store users (e.g. tensor.dim) are safe because the hoisted
collapse_shape preserves the original result type.
Previously, any non-store user caused the pattern to bail out. With
dynamic shapes on gfx1100, tensor.dim ops on the forall result blocked
the pattern, preventing the WMMAR3 accumulator reshape from folding into
the output buffer and causing a shared memory overflow (80KB > 65KB).
Signed-off-by: Jorn <jorn.tuyls@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
index cbfa98e..ba7558c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/Transforms.cpp
@@ -1148,26 +1148,41 @@
return failure();
}
- // We only want this pattern if the forall op result is being written to a
- // full slice, or an expandable buffer. Otherwise the hoisted collapse op is
- // not foldable.
+ // This pattern hoists a collapse_shape out of the forall body by
+ // expanding the forall destination and wrapping the result in a
+ // collapse_shape: expand -> new_forall -> collapse. The collapse
+ // result has the exact same type as the original forall result, so
+ // all users of the original result (stores, tensor.dim, etc.) see
+ // an unchanged type. The only requirement is that at least one store
+ // user can absorb the expansion, so the collapse is foldable. Even
+ // if the collapse survives (e.g., non-store users keep it alive),
+ // it is valid IR and strictly better than failing to hoist.
+ bool hasExpandableStore = false;
for (Operation *foralluser : tiedResult.getUsers()) {
auto storeOp =
dyn_cast<IREE::TensorExt::DispatchTensorStoreOp>(foralluser);
- if (storeOp && isFullSlice(storeOp, storeOp.getTargetType(),
- storeOp.getTargetDims())) {
+ if (storeOp) {
+ if (!isFullSlice(storeOp, storeOp.getTargetType(),
+ storeOp.getTargetDims())) {
+ return failure();
+ }
+ hasExpandableStore = true;
continue;
}
auto storeToBufferOp =
dyn_cast<IREE::Codegen::StoreToBufferOp>(foralluser);
- if (!storeToBufferOp) {
- return failure();
+ if (storeToBufferOp) {
+ MemRefType bufferType = storeToBufferOp.getBuffer().getType();
+ if (failed(memref::ExpandShapeOp::computeExpandedType(
+ bufferType, expandedDestShape, reIndices))) {
+ return failure();
+ }
+ hasExpandableStore = true;
+ continue;
}
- MemRefType bufferType = storeToBufferOp.getBuffer().getType();
- if (failed(memref::ExpandShapeOp::computeExpandedType(
- bufferType, expandedDestShape, reIndices))) {
- return failure();
- }
+ }
+ if (!hasExpandableStore) {
+ return failure();
}
// This allows us to assume that the extract/inserts in the loop are
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
index 9bc367d..e2e8028 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -162,6 +162,86 @@
// -----
+// Verify that ExpandDestinationForallOp fires even when the forall result
+// has non-store users. The pattern's collapse_shape preserves the original
+// result type, so non-store users always see the same type. This is needed
+// for dynamic shapes where tensor.dim or other metadata ops appear as
+// users of the forall result.
+func.func @expand_dest_forall_with_non_store_user(
+ %buffer: memref<?x64x32xf32>, %index: index) {
+ %c0 = arith.constant 0 : index
+ %1 = tensor.empty(%index) : tensor<?x64x32xf32>
+ %2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16)
+ shared_outs(%arg2 = %1) -> (tensor<?x64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+ : tensor<?x64x32xf32> to tensor<1x16x16xf32>
+ %expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]]
+ output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
+ %expanded_barrier = util.optimization_barrier %expanded : tensor<1x16x2x4x2xf32>
+ %collapsed = tensor.collapse_shape %expanded_barrier [[0], [1], [2, 3, 4]]
+ : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %collapsed into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+ : tensor<1x16x16xf32> into tensor<?x64x32xf32>
+ }
+ } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+ // Non-store user: previously blocked ExpandDestinationForallOp.
+ %barrier = util.optimization_barrier %2 : tensor<?x64x32xf32>
+ iree_codegen.store_to_buffer %2, %buffer
+ : tensor<?x64x32xf32> into memref<?x64x32xf32>
+ return
+}
+
+// The forall output should be expanded despite the non-store user.
+// CHECK-LABEL: func @expand_dest_forall_with_non_store_user(
+// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]: memref<?x64x32xf32>
+// CHECK-SAME: %[[INDEX:[a-zA-Z0-9]+]]: index
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[INDEX]]) : tensor<?x64x4x4x2xf32>
+// CHECK: %[[FORALL:.+]] = scf.forall
+// CHECK-SAME: shared_outs(%{{.+}} = %[[EMPTY]]) -> (tensor<?x64x4x4x2xf32>)
+// CHECK: tensor.extract_slice
+// CHECK-SAME: tensor<?x64x4x4x2xf32> to tensor<1x16x2x4x2xf32>
+// CHECK: tensor.parallel_insert_slice
+// CHECK-SAME: tensor<1x16x2x4x2xf32> into tensor<?x64x4x4x2xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL]]
+// CHECK-SAME: tensor<?x64x4x4x2xf32> into tensor<?x64x32xf32>
+// CHECK: util.optimization_barrier %[[COLLAPSE]] : tensor<?x64x32xf32>
+// CHECK: %[[EXPAND_BUF:.+]] = memref.expand_shape %[[BUF]]
+// CHECK-SAME: memref<?x64x32xf32> into memref<?x64x4x4x2xf32>
+// CHECK: iree_codegen.store_to_buffer %[[FORALL]], %[[EXPAND_BUF]]
+// CHECK-SAME: tensor<?x64x4x4x2xf32> into memref<?x64x4x4x2xf32>
+
+// -----
+
+// Negative test: the pattern should NOT fire when there are only non-store
+// users (no expandable store to fold the collapse into).
+func.func @noexpand_dest_forall_no_store_user(%index: index) {
+ %c0 = arith.constant 0 : index
+ %1 = tensor.empty(%index) : tensor<?x64x32xf32>
+ %2 = scf.forall (%arg0, %arg1) = (0, 0) to (64, 32) step (16, 16)
+ shared_outs(%arg2 = %1) -> (tensor<?x64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+ : tensor<?x64x32xf32> to tensor<1x16x16xf32>
+ %expanded = tensor.expand_shape %extracted_slice [[0], [1], [2, 3, 4]]
+ output_shape [1, 16, 2, 4, 2] : tensor<1x16x16xf32> into tensor<1x16x2x4x2xf32>
+ %expanded_barrier = util.optimization_barrier %expanded : tensor<1x16x2x4x2xf32>
+ %collapsed = tensor.collapse_shape %expanded_barrier [[0], [1], [2, 3, 4]]
+ : tensor<1x16x2x4x2xf32> into tensor<1x16x16xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %collapsed into %arg2[%c0, %arg0, %arg1] [1, 16, 16] [1, 1, 1]
+ : tensor<1x16x16xf32> into tensor<?x64x32xf32>
+ }
+ } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
+ %barrier = util.optimization_barrier %2 : tensor<?x64x32xf32>
+ return
+}
+
+// The forall output should NOT be expanded (no store to benefit from it).
+// CHECK-LABEL: func @noexpand_dest_forall_no_store_user
+// CHECK: scf.forall{{.*}}-> (tensor<?x64x32xf32>)
+
+// -----
+
#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
#hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
func.func @noexpand_dest_forall_dynamicpacked() {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 107c81f..2028770 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -43,6 +43,7 @@
"pipeline_lower_to_llvmgpu.mlir",
"pipeline_scaled_truncation_gfx950.mlir",
"pipeline_tile_and_fuse.mlir",
+ "pipeline_tile_and_fuse_gfx1100.mlir",
"pipeline_tile_and_fuse_gfx950.mlir",
"pipeline_vector_distribute_dynamic_shapes_gfx942.mlir",
"pipeline_vector_distribute_gfx1100.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 20b682f..ccc003e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -38,6 +38,7 @@
"pipeline_lower_to_llvmgpu.mlir"
"pipeline_scaled_truncation_gfx950.mlir"
"pipeline_tile_and_fuse.mlir"
+ "pipeline_tile_and_fuse_gfx1100.mlir"
"pipeline_tile_and_fuse_gfx950.mlir"
"pipeline_vector_distribute_dynamic_shapes_gfx942.mlir"
"pipeline_vector_distribute_gfx1100.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse_gfx1100.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse_gfx1100.mlir
new file mode 100644
index 0000000..ab44b53
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse_gfx1100.mlir
@@ -0,0 +1,105 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm}, builtin.module(iree-llvmgpu-select-lowering-strategy, func.func(iree-llvmgpu-lower-executable-target, iree-codegen-gpu-check-resource-usage)))))" \
+// RUN: %s | FileCheck %s
+
+// Regression test for dynamic batch_matmul with WMMAR3 on gfx1100.
+//
+// WMMAR3 (RDNA3) has accumulator layout outer={8,1} which requires an
+// expand_shape on the output. With dynamic shapes, tensor.dim users on
+// the forall result previously blocked ExpandDestinationForallOp from
+// folding the expand into the output buffer. This caused a separate shared
+// memory allocation for the output accumulator, exceeding the 65536-byte
+// shared memory limit (LHS 16KB + RHS 32KB + output 32KB = 80KB > 65KB).
+
+#pipeline_layout = #hal.pipeline.layout<constants = 6, bindings = [
+ #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+ #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+ #hal.pipeline.binding<storage_buffer, Indirect>
+], flags = Indirect>
+hal.executable private @batch_matmul_dynamic_wmmar3 {
+ hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export public @batch_matmul_24xDxDx128_f16_f32 ordinal(0)
+ layout(#pipeline_layout)
+ count(%dev: !hal.device, %arg1: index, %arg2: index, %arg3: index)
+ -> (index, index, index) {
+ %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice(%arg1, %arg2, %arg3)
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @batch_matmul_24xDxDx128_f16_f32() {
+ %c32_i64 = arith.constant 32 : i64
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
+ %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
+ %2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32
+ %3 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : i32
+ %4 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32
+ %5 = hal.interface.constant.load layout(#pipeline_layout) ordinal(5) : i32
+ %6 = arith.extui %0 : i32 to i64
+ %7 = arith.extui %1 : i32 to i64
+ %8 = arith.shli %7, %c32_i64 : i64
+ %9 = arith.ori %6, %8 : i64
+ %10 = arith.index_castui %9 : i64 to index
+ %11 = arith.extui %2 : i32 to i64
+ %12 = arith.extui %3 : i32 to i64
+ %13 = arith.shli %12, %c32_i64 : i64
+ %14 = arith.ori %11, %13 : i64
+ %15 = arith.index_castui %14 : i64 to index
+ %16 = arith.extui %4 : i32 to i64
+ %17 = arith.extui %5 : i32 to i64
+ %18 = arith.shli %17, %c32_i64 : i64
+ %19 = arith.ori %16, %18 : i64
+ %20 = arith.index_castui %19 : i64 to index
+ %21:3 = util.assume.int
+ %10<umin = 0, umax = 9007199254740991>,
+ %15<umin = 0, umax = 9007199254740991>,
+ %20<udiv = 128>
+ : index, index, index
+ %22 = iree_tensor_ext.dispatch.workload.ordinal %21#0, 0 : index
+ %23 = iree_tensor_ext.dispatch.workload.ordinal %21#1, 1 : index
+ %24 = iree_tensor_ext.dispatch.workload.ordinal %21#2, 2 : index
+ %25 = hal.interface.binding.subspan layout(#pipeline_layout)
+ binding(0) alignment(64) offset(%c0)
+ flags("ReadOnly|Indirect")
+ : !iree_tensor_ext.dispatch.tensor<readonly:tensor<24x?x128xf16>>{%22}
+ %26 = hal.interface.binding.subspan layout(#pipeline_layout)
+ binding(1) alignment(64) offset(%c0)
+ flags("ReadOnly|Indirect")
+ : !iree_tensor_ext.dispatch.tensor<readonly:tensor<24x?x128xf16>>{%23}
+ %27 = hal.interface.binding.subspan layout(#pipeline_layout)
+ binding(2) alignment(64) offset(%c0)
+ flags(Indirect)
+ : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<24x?x?xf32>>{%24, %24}
+ %28 = iree_tensor_ext.dispatch.tensor.load %25,
+ offsets = [0, 0, 0], sizes = [24, %22, 128], strides = [1, 1, 1]
+ : !iree_tensor_ext.dispatch.tensor<readonly:tensor<24x?x128xf16>>{%22}
+ -> tensor<24x?x128xf16>
+ %29 = iree_tensor_ext.dispatch.tensor.load %26,
+ offsets = [0, 0, 0], sizes = [24, %23, 128], strides = [1, 1, 1]
+ : !iree_tensor_ext.dispatch.tensor<readonly:tensor<24x?x128xf16>>{%23}
+ -> tensor<24x?x128xf16>
+ %30 = tensor.empty(%24, %24) : tensor<24x?x?xf32>
+ %31 = linalg.fill ins(%cst : f32) outs(%30 : tensor<24x?x?xf32>)
+ -> tensor<24x?x?xf32>
+ %32 = linalg.batch_matmul
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
+ ins(%28, %29 : tensor<24x?x128xf16>, tensor<24x?x128xf16>)
+ outs(%31 : tensor<24x?x?xf32>) -> tensor<24x?x?xf32>
+ iree_tensor_ext.dispatch.tensor.store %32, %27,
+ offsets = [0, 0, 0], sizes = [24, %24, %24], strides = [1, 1, 1]
+ : tensor<24x?x?xf32>
+ -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<24x?x?xf32>>{%24, %24}
+ return
+ }
+ }
+ }
+}
+
+// Verify the full codegen pipeline completes (no shared memory overflow)
+// and produces WMMA instructions.
+// CHECK-LABEL: hal.executable private @batch_matmul_dynamic_wmmar3
+// CHECK: func.func @batch_matmul_24xDxDx128_f16_f32
+// CHECK: amdgpu.wmma
diff --git a/tests/e2e/regression/BUILD.bazel b/tests/e2e/regression/BUILD.bazel
index 903d192..ca14281 100644
--- a/tests/e2e/regression/BUILD.bazel
+++ b/tests/e2e/regression/BUILD.bazel
@@ -144,6 +144,7 @@
# TODO(kuhar): Drop the timeout after we switch to testing on the actual gfx1250 chip.
timeout = "moderate",
srcs = [
+ "dynamic_batch_matmul_gfx1100.mlir",
"dynamic_gather_attention.mlir",
"linalg_ops_dynamic.mlir",
"split_reduction_using_tiling.mlir",
diff --git a/tests/e2e/regression/CMakeLists.txt b/tests/e2e/regression/CMakeLists.txt
index 0677fe5..d040e6f 100644
--- a/tests/e2e/regression/CMakeLists.txt
+++ b/tests/e2e/regression/CMakeLists.txt
@@ -161,6 +161,7 @@
NAME
check_regression_hip
SRCS
+ "dynamic_batch_matmul_gfx1100.mlir"
"dynamic_gather_attention.mlir"
"linalg_ops_dynamic.mlir"
"split_reduction_using_tiling.mlir"
diff --git a/tests/e2e/regression/dynamic_batch_matmul_gfx1100.mlir b/tests/e2e/regression/dynamic_batch_matmul_gfx1100.mlir
new file mode 100644
index 0000000..b7be737
--- /dev/null
+++ b/tests/e2e/regression/dynamic_batch_matmul_gfx1100.mlir
@@ -0,0 +1,30 @@
+// Regression test for dynamic batch_matmul with WMMAR3 on RDNA3 (gfx1100).
+//
+// WMMAR3 has accumulator layout outer={8,1} which requires an expand_shape
+// on the output. With dynamic shapes, tensor.dim users on the forall result
+// previously blocked the ExpandDestinationForallOp pattern, causing a
+// separate shared memory allocation for the output accumulator that exceeded
+// the 65536-byte limit.
+
+func.func @dynamic_batch_matmul_transposed_rhs() {
+ %lhs = flow.tensor.dynamic_constant dense<1.0> : tensor<2x128x128xf16> -> tensor<2x?x128xf16>
+ %rhs = flow.tensor.dynamic_constant dense<1.0> : tensor<2x128x128xf16> -> tensor<2x?x128xf16>
+
+ %cst = arith.constant 0.0 : f32
+ %c1 = arith.constant 1 : index
+ %m = tensor.dim %lhs, %c1 : tensor<2x?x128xf16>
+ %n = tensor.dim %rhs, %c1 : tensor<2x?x128xf16>
+ %init = tensor.empty(%m, %n) : tensor<2x?x?xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+ %observed = linalg.batch_matmul
+ indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
+ affine_map<(b, m, n, k) -> (b, n, k)>,
+ affine_map<(b, m, n, k) -> (b, m, n)>]
+ ins(%lhs, %rhs : tensor<2x?x128xf16>, tensor<2x?x128xf16>)
+ outs(%fill : tensor<2x?x?xf32>) -> tensor<2x?x?xf32>
+
+ // Each output element = sum(1.0 * 1.0, K=128) = 128.0
+ %expected = flow.tensor.dynamic_constant dense<128.0> : tensor<2x128x128xf32> -> tensor<2x?x?xf32>
+ check.expect_almost_eq(%observed, %expected, atol 1.0e-01) : tensor<2x?x?xf32>
+ return
+}