[LLVMGPU] Use LLVMGPUDistribute for small input scatters (#19670)
If the scattered slice is small then we will end up only distributing
the batch dimensions to workgroups. This causes bufferization to fail in
LLVMGPUTileAndFuse because the workgroup level `extract_slice` will fold
away.
Fixes #19639
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 61c325c..a119351 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -787,6 +787,19 @@
}
}
+ int64_t numBatch = scatter.getBatchRank();
+ // Currently bufferization will fail if the only dimension distributed to
+ // workgroups is the batch dims because the workgroup level slice will fold
+ // away and cause a mismatch.
+ // TODO(qedawkins): Support this case.
+ if (llvm::all_of_zip(llvm::drop_begin(workgroupTileSizes, numBatch),
+ llvm::drop_begin(loopBounds, numBatch),
+ [](int64_t tileSize, int64_t bound) {
+ return tileSize == bound || tileSize == 0;
+ })) {
+ return failure();
+ }
+
// Attach the MMA schedule as an attribute to the entry point export function
// for later access in the pipeline.
MLIRContext *context = scatter.getContext();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index 3d137e5..6f94069 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -371,11 +371,10 @@
}
// CHECK-LABEL: func.func @only_scattered_dim
-// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUDistribute workgroup_size = [128, 1, 1] subgroup_size = 64
-// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
-// CHECK-SAME: thread = [1]
-// CHECK-SAME: workgroup = [48]
+// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_codegen.lowering_config
+// CHECK-SAME: tile_sizes = {{\[}}[128]]
// -----