Revert "[LLVMGPU] Fall back to scalar lowering for tiny attention shapes (#24239)" (#24356)
This breaks a test in the MI355 CI. See, for example:
https://github.com/iree-org/iree/actions/runs/25333663639/job/74274843822#step:12:661
Reverting while we work on a fix.
This reverts commit 81f4decfba8e2b8d43e9f55084802638ef7e55bb.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index a42400e..876a86f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -1104,33 +1104,6 @@
op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
.value();
- // Avoid the known misaligned K1/N tail cases in this reduction path.
- // TODO: Remove the K1/N alignment checks once masked tails are verified.
- int64_t k1Size = bounds[opInfo.getK1Dims().back()];
- int64_t nSize = bounds[opInfo.getNDims().back()];
- int64_t nWorkgroupTile = seeds.numValueVectors * seeds.valueVectorSize;
- if (!ShapedType::isDynamic(k1Size) && k1Size % seeds.keyVectorSize != 0) {
- LDBG() << "Bailing out: K1 not a multiple of key vector size ("
- << seeds.keyVectorSize << "): " << k1Size;
- return failure();
- }
- if (!ShapedType::isDynamic(nSize) && nSize % nWorkgroupTile != 0) {
- LDBG() << "Bailing out: N not a multiple of value workgroup tile ("
- << nWorkgroupTile << "): " << nSize;
- return failure();
- }
-
- // Bail out on very skinny K2 shapes; smaller K2 collapses the reduction below
- // what this path is tuned for.
- // K2 may be empty; treat it as size 1, which trips the skinny-K2 check
- // below.
- int64_t k2Size =
- opInfo.getK2Dims().empty() ? 1 : bounds[opInfo.getK2Dims().back()];
- if (!ShapedType::isDynamic(k2Size) && k2Size <= kVerySkinnyDimThreshold) {
- LDBG() << "Bailing out due to very skinny K2 dimension: " << k2Size;
- return failure();
- }
-
// Distribute the 'available' resource to the basis on the given dimensions.
// `currDim` tracks number of dims on which resources have already been
// distributed (to keep track of order of dimension distribution).
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 6831461..718d8ac 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -867,13 +867,7 @@
}
void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
- tileAndDistributeToWorkgroup(funcPassManager);
- // OnlineAttentionOp has no bufferization path of its own. Decompose only on
- // the generic Distribute fallback path after workgroup tiling.
- funcPassManager.addPass(IREE::LinalgExt::createDecomposeAttentionPass());
- funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
- funcPassManager.addPass(createCSEPass());
- addBufferizePasses(funcPassManager);
+ tileAndBufferize(funcPassManager);
// Distribute linalg onto threads within the workgroup.
funcPassManager.addPass(createLLVMGPUTileAndDistributePass(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
index 8554670..090dc89 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel
@@ -38,7 +38,6 @@
"configure_buffer_instructions.mlir",
"pipeline_argcompare_vector_distribute.mlir",
"pipeline_direct_conv_tile_and_fuse.mlir",
- "pipeline_distribute.mlir",
"pipeline_elementwise_f8fnuz.mlir",
"pipeline_elementwise_f8ocp.mlir",
"pipeline_full_smoketests.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
index 4e008a3..3d46941 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt
@@ -33,7 +33,6 @@
"configure_buffer_instructions.mlir"
"pipeline_argcompare_vector_distribute.mlir"
"pipeline_direct_conv_tile_and_fuse.mlir"
- "pipeline_distribute.mlir"
"pipeline_elementwise_f8fnuz.mlir"
"pipeline_elementwise_f8ocp.mlir"
"pipeline_full_smoketests.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
index 5f37ae4..ba36719 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
@@ -603,48 +603,3 @@
// CHECK-SAME: serial = [128, 0],
// CHECK-SAME: subgroup_basis = {{\[}}[1, 1], [0, 1]{{\]}},
// CHECK-SAME: thread = [2, 0]
-
-// -----
-
-// Shapes that are too small or misaligned for the VectorDistribute attention
-// configs must fall back to the generic Distribute pipeline.
-
-// CHECK: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
-// CHECK-LABEL: func.func @attention_skinny_K2
-// CHECK: iree_linalg_ext.online_attention
-// CHECK-SAME: lowering_config = #config
-
-#pipeline_layout_skinny_k2 = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-func.func @attention_skinny_K2() {
- %cst = arith.constant 1.0 : f16
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>>
- %1 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
- %2 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
- %3 = hal.interface.binding.subspan layout(#pipeline_layout_skinny_k2) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
- %q = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>> -> tensor<2x64x64xf16>
- %k = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
- %v = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
- %empty = tensor.empty() : tensor<2x64x64xf32>
- %empty_red = tensor.empty() : tensor<2x64xf32>
- %att:3 = iree_linalg_ext.online_attention
- {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>]}
- ins(%q, %k, %v, %cst : tensor<2x64x64xf16>, tensor<2x2x64xf16>, tensor<2x2x64xf16>, f16)
- outs(%empty, %empty_red, %empty_red : tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>) {
- ^bb0(%score: f32):
- iree_linalg_ext.yield %score : f32
- } -> tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>
- iree_tensor_ext.dispatch.tensor.store %att#0, %3, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : tensor<2x64x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
- return
-}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir
deleted file mode 100644
index 579f0bb..0000000
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_distribute.mlir
+++ /dev/null
@@ -1,58 +0,0 @@
-// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
-// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-llvmgpu-lower-executable-target)))))" \
-// RUN: %s | FileCheck %s
-
-#distribute_attention_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 0, 0, 128]]>
-#distribute_attention_translation = #iree_codegen.translation_info<
- pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
-#distribute_attention_layout = #hal.pipeline.layout<bindings = [
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>,
- #hal.pipeline.binding<storage_buffer>
-]>
-
-hal.executable private @attention_skinny_k2_distribute {
- hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
- hal.executable.export public @attention_skinny_k2_distribute ordinal(0) layout(#distribute_attention_layout) count(%arg0: !hal.device) -> (index, index, index) {
- %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice()
- hal.return %x, %y, %z : index, index, index
- }
- builtin.module {
- func.func @attention_skinny_k2_distribute() attributes {translation_info = #distribute_attention_translation} {
- %cst = arith.constant 1.0 : f16
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>>
- %1 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
- %2 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>>
- %3 = hal.interface.binding.subspan layout(#distribute_attention_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
- %q = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x64x64xf16>> -> tensor<2x64x64xf16>
- %k = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
- %v = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [2, 2, 64], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x2x64xf16>> -> tensor<2x2x64xf16>
- %empty = tensor.empty() : tensor<2x64x64xf32>
- %empty_red = tensor.empty() : tensor<2x64xf32>
- %att:3 = iree_linalg_ext.online_attention
- {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> ()>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
- affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>],
- lowering_config = #distribute_attention_config}
- ins(%q, %k, %v, %cst : tensor<2x64x64xf16>, tensor<2x2x64xf16>, tensor<2x2x64xf16>, f16)
- outs(%empty, %empty_red, %empty_red : tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>) {
- ^bb0(%score: f32):
- iree_linalg_ext.yield %score : f32
- } -> tensor<2x64x64xf32>, tensor<2x64xf32>, tensor<2x64xf32>
- iree_tensor_ext.dispatch.tensor.store %att#0, %3, offsets = [0, 0, 0], sizes = [2, 64, 64], strides = [1, 1, 1] : tensor<2x64x64xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x64x64xf32>>
- return
- }
- }
- }
-}
-
-// CHECK: #translation = #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<Distribute> workgroup_size = [128, 1, 1] subgroup_size = 64>
-// CHECK-LABEL: func.func @attention_skinny_k2_distribute
-// CHECK-NOT: iree_linalg_ext.online_attention
-// CHECK: linalg.generic