[LLVMGPU] Fall back to scalar lowering for tiny attention shapes (#24239)

The attention `VectorDistribute` configs (both the MMA-intrinsic path
and the memory-bound reduction path) assume that the head dimensions and
K2 reach a certain minimum size. For shapes below that threshold (e.g.
Q=K=V=[2,2,2,2] f16), the reduction path still emits a `VectorDistribute`
config, but the tile sizes it picks produce vector ops whose shapes the
layout engine cannot support, causing the failure in
https://github.com/iree-org/iree/issues/24221

Add early bailouts for the shapes that cannot be tiled cleanly, so they
fall back to the generic `Distribute` pipeline instead.
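For illustration, a minimal standalone sketch of the three bailout rules
added below in `KernelConfig.cpp` (K1 divisibility, N divisibility, skinny
K2). The helper name and the threshold value are hypothetical, and dynamic
dims are modeled as negative sizes rather than via `ShapedType::isDynamic`:

```cpp
#include <cstdint>
#include <cstdio>

constexpr int64_t kVerySkinnyDimThreshold = 4; // illustrative value only

// Hypothetical stand-in for the reduction-path config check; sizes < 0
// represent dynamic dimensions.
bool fitsVectorDistributeReduction(int64_t k1Size, int64_t nSize,
                                   int64_t k2Size, int64_t keyVectorSize,
                                   int64_t nWorkgroupTile) {
  auto isDynamic = [](int64_t s) { return s < 0; };
  // Static K1 must be a multiple of the key vector size.
  if (!isDynamic(k1Size) && k1Size % keyVectorSize != 0)
    return false;
  // Static N must be a multiple of the value workgroup tile.
  if (!isDynamic(nSize) && nSize % nWorkgroupTile != 0)
    return false;
  // A static K2 at or below the skinny threshold collapses the reduction
  // below what this path is tuned for (an absent K2 counts as size 1).
  if (!isDynamic(k2Size) && k2Size <= kVerySkinnyDimThreshold)
    return false;
  return true;
}

int main() {
  // The tiny shape from the linked issue: K2 == 2 trips the skinny check.
  std::printf("tiny shape fits: %d\n",
              fitsVectorDistributeReduction(/*k1=*/2, /*n=*/2, /*k2=*/2,
                                            /*keyVec=*/8, /*nTile=*/8));
}
```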

---------

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Co-authored-by: Lukas Sommer <lsommer@amd.com>
Co-authored-by: GPT-5 <noreply@openai.com>
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 876a86f..a42400e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1104,6 +1104,33 @@
           op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
           .value();
 
+  // Avoid the known misaligned K1/N tail cases in this reduction path.
+  // TODO: Remove the K1/N alignment checks once masked tails are verified.
+  int64_t k1Size = bounds[opInfo.getK1Dims().back()];
+  int64_t nSize = bounds[opInfo.getNDims().back()];
+  int64_t nWorkgroupTile = seeds.numValueVectors * seeds.valueVectorSize;
+  if (!ShapedType::isDynamic(k1Size) && k1Size % seeds.keyVectorSize != 0) {
+    LDBG() << "Bailing out: K1 not a multiple of key vector size ("
+           << seeds.keyVectorSize << "): " << k1Size;
+    return failure();
+  }
+  if (!ShapedType::isDynamic(nSize) && nSize % nWorkgroupTile != 0) {
+    LDBG() << "Bailing out: N not a multiple of value workgroup tile ("
+           << nWorkgroupTile << "): " << nSize;
+    return failure();
+  }
+
+  // Bail out on very skinny K2 shapes; a small K2 collapses the reduction
+  // below what this path is tuned for.
+  // K2 may be empty; treat it as size 1 so that it still trips the
+  // skinny-K2 check.
+  int64_t k2Size =
+      opInfo.getK2Dims().empty() ? 1 : bounds[opInfo.getK2Dims().back()];
+  if (!ShapedType::isDynamic(k2Size) && k2Size <= kVerySkinnyDimThreshold) {
+    LDBG() << "Bailing out due to very skinny K2 dimension: " << k2Size;
+    return failure();
+  }
+
   // Distribute the 'available' resource to the basis on the given dimensions.
   // `currDim` tracks number of dims on which resources have already been
   // distributed (to keep track of order of dimension distribution).
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 718d8ac..6831461 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -867,7 +867,13 @@
 }
 
 void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
-  tileAndBufferize(funcPassManager);
+  tileAndDistributeToWorkgroup(funcPassManager);
+  // OnlineAttentionOp has no bufferization path of its own. Decompose only on
+  // the generic Distribute fallback path after workgroup tiling.
+  funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
+  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
+  funcPassManager.addPass(createCSEPass());
+  addBufferizePasses(funcPassManager);
 
   // Distribute linalg onto threads within the workgroup.
   funcPassManager.addPass(createLLVMGPUTileAndDistributePass(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 090dc89..8554670 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -38,6 +38,7 @@
             "configure_buffer_instructions.mlir",
             "pipeline_argcompare_vector_distribute.mlir",
             "pipeline_direct_conv_tile_and_fuse.mlir",
+            "pipeline_distribute.mlir",
             "pipeline_elementwise_f8fnuz.mlir",
             "pipeline_elementwise_f8ocp.mlir",
             "pipeline_full_smoketests.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 3d46941..4e008a3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -33,6 +33,7 @@
     "configure_buffer_instructions.mlir"
     "pipeline_argcompare_vector_distribute.mlir"
     "pipeline_direct_conv_tile_and_fuse.mlir"
+    "pipeline_distribute.mlir"
     "pipeline_elementwise_f8fnuz.mlir"
     "pipeline_elementwise_f8ocp.mlir"
     "pipeline_full_smoketests.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
index ba36719..5f37ae4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
@@ -603,3 +603,48 @@
 // CHECK-SAME:               serial = [128, 0],
 // CHECK-SAME:               subgroup_basis = {{\[}}[1, 1], [0, 1]{{\]}},
 // CHECK-SAME:               thread = [2, 0]
+
+// -----
+
+// Shapes that are too small or misaligned for the VectorDistribute attention
+// configs must fall back to the generic Distribute pipeline.
+
+// CHECK:       #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @attention_skinny_K2
+// CHECK:       iree_linalg_ext.online_attention
+// CHECK-SAME:    lowering_config = #config
+
+#pipeline_layout_skinny_k2 = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+func.func @attention_skinny_K2() {
+  %cst = arith.constant 1.0 : f16
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+  %3 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+  %q = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>> -> tensor<2x64x64xf16>
+  %k = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+  %v = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+  %empty = tensor.empty() : tensor<2x64x64xf32>
+  %empty_red = tensor.empty() : tensor<2x64xf32>
+  %att:3 = iree_linalg_ext.online_attention
+      {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                        affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
+                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>]}
+      ins(%q, %k, %v, %cst : tensor<2x64x64xf16>, tensor<2x2x64xf16>, tensor<2x2x64xf16>, f16)
+      outs(%empty, %empty_red, %empty_red : tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>) {
+    ^bb0(%score: f32):
+      iree_linalg_ext.yield %score : f32
+  } -> tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>
+  iree_tensor_ext.dispatch.tensor.store %att#0, %3, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : tensor<2x64x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+  return
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir
new file mode 100644
index 0000000..579f0bb
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir
@@ -0,0 +1,58 @@
+// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
+// RUN:   --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
+// RUN:   %s | FileCheck %s
+
+#distribute_attention_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 0, 0, 128]]>
+#distribute_attention_translation = #iree_codegen.translation_info<
+  pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
+#distribute_attention_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+
+hal.executable private @attention_skinny_k2_distribute {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @attention_skinny_k2_distribute ordinal(0) layout(#distribute_attention_layout) count(%arg0: !hal.device) -> (index, index, index) {
+      %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @attention_skinny_k2_distribute() attributes {translation_info = #distribute_attention_translation} {
+        %cst = arith.constant 1.0 : f16
+        %c0 = arith.constant 0 : index
+        %0 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>>
+        %1 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+        %2 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
+        %3 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+        %q = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>> -> tensor<2x64x64xf16>
+        %k = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+        %v = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
+        %empty = tensor.empty() : tensor<2x64x64xf32>
+        %empty_red = tensor.empty() : tensor<2x64xf32>
+        %att:3 = iree_linalg_ext.online_attention
+            {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+                              affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
+                              affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
+                              affine_map<(d0, d1, d2, d3, d4) -> ()>,
+                              affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
+                              affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+                              affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
+             lowering_config = #distribute_attention_config}
+            ins(%q, %k, %v, %cst : tensor<2x64x64xf16>, tensor<2x2x64xf16>, tensor<2x2x64xf16>, f16)
+            outs(%empty, %empty_red, %empty_red : tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>) {
+          ^bb0(%score: f32):
+            iree_linalg_ext.yield %score : f32
+        } -> tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>
+        iree_tensor_ext.dispatch.tensor.store %att#0, %3, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : tensor<2x64x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
+        return
+      }
+    }
+  }
+}
+
+// CHECK:       #translation = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
+// CHECK-LABEL: func.func @attention_skinny_k2_distribute
+// CHECK-NOT:   iree_linalg_ext.online_attention
+// CHECK:       linalg.generic