[LLVMGPU] Fall back to scalar lowering for tiny attention shapes (#24239)
The attention `VectorDistribute` configs (both the MMA-intrinsic path
and the memory-bound reduction path) assume the head dimensions and K2
reach a certain size. For shapes below that threshold (e.g.
Q=K=V=[2,2,2,2] f16), the reduction path still succeeds at emitting a
`VectorDistribute` config, but the tile sizes it picks produce vector
ops whose shapes the layout engine cannot support, causing the failure
reported in https://github.com/iree-org/iree/issues/24221.
Add early bailouts for the shapes that cannot be tiled cleanly.
---------
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: Lukas Sommer <lsommer@amd.com>
Co-authored-by: GPT-5 <noreply@openai.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 876a86f..a42400e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1104,6 +1104,33 @@
op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
.value();
+ // Avoid the known misaligned K1/N tail cases in this reduction path.
+ // TODO: Remove the K1/N alignment checks once masked tails are verified.
+ int64_t k1Size = bounds[opInfo.getK1Dims().back()];
+ int64_t nSize = bounds[opInfo.getNDims().back()];
+ int64_t nWorkgroupTile = seeds.numValueVectors * seeds.valueVectorSize;
+ if (!ShapedType::isDynamic(k1Size) && k1Size % seeds.keyVectorSize != 0) {
+ LDBG() << "Bailing out: K1 not a multiple of key vector size ("
+ << seeds.keyVectorSize << "): " << k1Size;
+ return failure();
+ }
+ if (!ShapedType::isDynamic(nSize) && nSize % nWorkgroupTile != 0) {
+ LDBG() << "Bailing out: N not a multiple of value workgroup tile ("
+ << nWorkgroupTile << "): " << nSize;
+ return failure();
+ }
+
+ // Bail out on very skinny K2 shapes; smaller K2 collapses the reduction below
+ // what this path is tuned for.
+ // K2 may be empty; treat it as size 1, which trips the skinny-K2 check
+ // below.
+ int64_t k2Size =
+ opInfo.getK2Dims().empty() ? 1 : bounds[opInfo.getK2Dims().back()];
+ if (!ShapedType::isDynamic(k2Size) && k2Size <= kVerySkinnyDimThreshold) {
+ LDBG() << "Bailing out due to very skinny K2 dimension: " << k2Size;
+ return failure();
+ }
+
// Distribute the 'available' resource to the basis on the given dimensions.
// `currDim` tracks number of dims on which resources have already been
// distributed (to keep track of order of dimension distribution).
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 718d8ac..6831461 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -867,7 +867,13 @@
}
void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
- tileAndBufferize(funcPassManager);
+ tileAndDistributeToWorkgroup(funcPassManager);
+ // OnlineAttentionOp has no bufferization path of its own. Decompose only on
+ // the generic Distribute fallback path after workgroup tiling.
+ funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
+ funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
+ funcPassManager.addPass(createCSEPass());
+ addBufferizePasses(funcPassManager);
// Distribute linalg onto threads within the workgroup.
funcPassManager.addPass(createLLVMGPUTileAndDistributePass(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 090dc89..8554670 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -38,6 +38,7 @@
"configure_buffer_instructions.mlir",
"pipeline_argcompare_vector_distribute.mlir",
"pipeline_direct_conv_tile_and_fuse.mlir",
+ "pipeline_distribute.mlir",
"pipeline_elementwise_f8fnuz.mlir",
"pipeline_elementwise_f8ocp.mlir",
"pipeline_full_smoketests.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 3d46941..4e008a3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -33,6 +33,7 @@
"configure_buffer_instructions.mlir"
"pipeline_argcompare_vector_distribute.mlir"
"pipeline_direct_conv_tile_and_fuse.mlir"
+ "pipeline_distribute.mlir"
"pipeline_elementwise_f8fnuz.mlir"
"pipeline_elementwise_f8ocp.mlir"
"pipeline_full_smoketests.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
index ba36719..5f37ae4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
@@ -603,3 +603,48 @@
// CHECK-SAME: serial = [128, 0],
// CHECK-SAME: subgroup_basis = {{\[}}[1, 1], [0, 1]{{\]}},
// CHECK-SAME: thread = [2, 0]
+
+// -----
+
+// Shapes that are too small or misaligned for the VectorDistribute attention
+// configs must fall back to the generic Distribute pipeline.
+
+// CHECK: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @attention_skinny_K2
+// CHECK: iree_linalg_ext.online_attention
+// CHECK-SAME: lowering_config = #config
+
+#pipeline_layout_skinny_k2 = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @attention_skinny_K2() {
+ %cst = arith.constant 1.0 : f16
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+ %3 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+ %q = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>> -> tensor<2x64x64xf16>
+ %k = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+ %v = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+ %empty = tensor.empty() : tensor<2x64x64xf32>
+ %empty_red = tensor.empty() : tensor<2x64xf32>
+ %att:3 = iree_linalg_ext.online_attention
+ {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>]}
+ ins(%q, %k, %v, %cst : tensor<2x64x64xf16>, tensor<2x2x64xf16>, tensor<2x2x64xf16>, f16)
+ outs(%empty, %empty_red, %empty_red : tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score : f32
+ } -> tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>
+ iree_tensor_ext.dispatch.tensor.store %att#0, %3, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : tensor<2x64x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+ return
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir
new file mode 100644
index 0000000..579f0bb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir
@@ -0,0 +1,58 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
+// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN: %s | FileCheck %s
+
+#distribute_attention_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 0, 0, 128]]>
+#distribute_attention_translation = #iree_codegen.translation_info<
+ pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
+#distribute_attention_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+
+hal.executable private @attention_skinny_k2_distribute {
+ hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.export public @attention_skinny_k2_distribute ordinal(0) layout(#distribute_attention_layout) count(%arg0: !hal.device) -> (index, index, index) {
+ %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+ hal.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @attention_skinny_k2_distribute() attributes {translation_info = #distribute_attention_translation} {
+ %cst = arith.constant 1.0 : f16
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>>
+ %1 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+ %2 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+ %3 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+ %q = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>> -> tensor<2x64x64xf16>
+ %k = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+ %v = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+ %empty = tensor.empty() : tensor<2x64x64xf32>
+ %empty_red = tensor.empty() : tensor<2x64xf32>
+ %att:3 = iree_linalg_ext.online_attention
+ {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> ()>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
+ lowering_config = #distribute_attention_config}
+ ins(%q, %k, %v, %cst : tensor<2x64x64xf16>, tensor<2x2x64xf16>, tensor<2x2x64xf16>, f16)
+ outs(%empty, %empty_red, %empty_red : tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>) {
+ ^bb0(%score: f32):
+ iree_linalg_ext.yield %score : f32
+ } -> tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>
+ iree_tensor_ext.dispatch.tensor.store %att#0, %3, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : tensor<2x64x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+ return
+ }
+ }
+ }
+}
+
+// CHECK: #translation = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @attention_skinny_k2_distribute
+// CHECK-NOT: iree_linalg_ext.online_attention
+// CHECK: linalg.generic