[GPU] Add lowering configuration logic for scatter (#19624)
This replaces LLVMGPUDistribute with LLVMGPUTileAndFuse for scatter ops that
appear on their own (i.e. not fused with other operations). This enables
vectorization where possible and is a step towards deprecating
LLVMGPUDistribute.
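
As a sketch of the new heuristic (derived from the added config logic and the
@large_scatter test below, assuming a preferred subgroup size of 64):

  elemBits = 32 (f32)        -> vectorSize = 128 / 32 = 4
  innermost bound = 2048     -> residualInnerSize = 2048 / 4 = 512 >= 64
  thread tile sizes          = [1, 1, 4]
  workgroup tile sizes       = [1, 1, 4 * 64] = [1, 1, 256]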
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 4f4d778..3c514c5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -32,6 +32,7 @@
namespace mlir::iree_compiler::IREE::GPU {
constexpr int64_t kCacheLineSizeBits = 128 * 8;
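+// Preferred number of bits per thread to copy in a single vectorized access.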
+constexpr int64_t kPreferredCopyNumBits = 128;
LogicalResult setDataTiledMultiMmaLoweringConfig(
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
@@ -732,6 +733,90 @@
{flatWorkgroupSize, 1, 1}, subgroupSize, DictionaryAttr());
}
+LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target,
+ mlir::FunctionOpInterface entryPoint,
+ Operation *op) {
+ auto scatter = dyn_cast<IREE::LinalgExt::ScatterOp>(op);
+ if (!scatter) {
+ return failure();
+ }
+
+ // TODO: Support non-unique indices.
+ if (!scatter.getUniqueIndices()) {
+ return failure();
+ }
+
+ // Various problem parameters.
+ int64_t loopDepth = scatter.getLoopIteratorTypes().size();
+ int64_t elemBits = scatter.getOriginalType().getElementTypeBitWidth();
+ SmallVector<int64_t> loopBounds = scatter.getStaticLoopRanges().value_or(
+ SmallVector<int64_t>(loopDepth, ShapedType::kDynamic));
+
+ // Configurations we need to decide.
+ int64_t flatWorkgroupSize = target.getPreferredSubgroupSize();
+ SmallVector<int64_t> workgroupTileSizes(loopDepth, 1);
+ SmallVector<int64_t> threadTileSizes(loopDepth, 1);
+ int64_t vectorSize = kPreferredCopyNumBits / elemBits;
+
+ bool innerDynamic = ShapedType::isDynamic(loopBounds.back());
+
+ // Do not bother trying to vectorize if there are no vectorizable dims.
+ if (loopDepth == 1) {
+ vectorSize = 1;
+ } else if (!innerDynamic) {
+    // Use the largest power of 2 that divides the innermost non-scattered
+    // dim, capped at the preferred vector size.
+ vectorSize = std::gcd(vectorSize, loopBounds.back());
+ }
+
+ threadTileSizes.back() = vectorSize;
+ int64_t residualInnerSize =
+ innerDynamic ? loopBounds.back() : loopBounds.back() / vectorSize;
+
+  // If the innermost dim is dynamic or at least as large as the expected
+  // number of threads, only distribute threads along the innermost dimension.
+ if (ShapedType::isDynamic(residualInnerSize) ||
+ residualInnerSize >= flatWorkgroupSize) {
+ workgroupTileSizes.back() = vectorSize * flatWorkgroupSize;
+ } else { // residualInnerSize < flatWorkgroupSize
+ // Floordiv to overestimate the required number of threads.
+ int64_t residualThreads = flatWorkgroupSize / residualInnerSize;
+ workgroupTileSizes.back() = residualInnerSize * vectorSize;
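+    // Distribute the remaining threads over the outer dims, starting from the
+    // second-innermost dim and working outward.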
+ for (int64_t i = loopDepth - 2, e = 0; i >= e; --i) {
+ if (residualThreads <= 1) {
+ break;
+ }
+
+ bool dynamicDim = ShapedType::isDynamic(loopBounds[i]);
+ workgroupTileSizes[i] = dynamicDim
+ ? residualThreads
+ : std::min(residualThreads, loopBounds[i]);
+ residualThreads = dynamicDim ? 1 : residualThreads / loopBounds[i];
+ }
+ }
+
+  // Attach the tile sizes as a lowering config attribute for later access in
+  // the pipeline.
+ MLIRContext *context = scatter.getContext();
+ SmallVector<NamedAttribute, 1> attrs;
+ Builder b(context);
+ attrs.emplace_back(StringAttr::get(context, "workgroup"),
+ b.getI64ArrayAttr(workgroupTileSizes));
+
+ attrs.emplace_back(StringAttr::get(context, "thread"),
+ b.getI64ArrayAttr(threadTileSizes));
+
+ auto configDict = DictionaryAttr::get(context, attrs);
+ auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
+
+ LDBG("Selected tile and fuse lowering config: " << loweringConfig << "\n");
+
+ // TODO(qedawkins): Use a shared pipeline identifier here.
+ return setOpConfigAndEntryPointFnTranslation(
+ entryPoint, scatter, loweringConfig,
+ IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse,
+ {flatWorkgroupSize, 1, 1}, flatWorkgroupSize, DictionaryAttr());
+}
+
//===----------------------------------------------------------------------===//
// Lowering Config Attributes
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
index 636ffe5..ad9aa42 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
@@ -42,6 +42,11 @@
mlir::FunctionOpInterface entryPoint,
Operation *op);
+// Helper for setting tile sizes for scatter.
+LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target,
+ mlir::FunctionOpInterface entryPoint,
+ Operation *op);
+
//===----------------------------------------------------------------------===//
// Pass Pipeline Options
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 2c96a56..e09a72b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -2425,6 +2425,14 @@
return setDefaultCustomOpLoweringConfig(entryPointFn, customOp,
initGPULaunchConfig);
})
+ .Case<IREE::LinalgExt::ScatterOp>([&](auto scatterOp) {
+ LDBG("ScatterOp Config");
+ if (failed(IREE::GPU::setScatterLoweringConfig(target, entryPointFn,
+ scatterOp))) {
+ return setRootDefaultConfig(target, entryPointFn, computeOp);
+ }
+ return success();
+ })
.Default([&](auto op) {
LDBG("Default Config");
if (!clLLVMGPUVectorizePipeline) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index f31ac9a..125ad8b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -310,6 +310,90 @@
// -----
+func.func @large_scatter(%arg0: tensor<3x2048x2048xf32>,
+ %arg1: tensor<3x1xi32>) -> tensor<3x2048x2048xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<3x2048x2048xf32>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%arg0, %arg1 : tensor<3x2048x2048xf32>, tensor<3x1xi32>) outs(%0 : tensor<3x2048x2048xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ iree_linalg_ext.yield %arg2 : f32
+ } -> tensor<3x2048x2048xf32>
+ return %1 : tensor<3x2048x2048xf32>
+}
+
+// CHECK-LABEL: func.func @large_scatter
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: thread = [1, 1, 4]
+// CHECK-SAME: workgroup = [1, 1, 256]
+
+// -----
+
+func.func @small_scatter(%arg0: tensor<3x32x16xf32>,
+ %arg1: tensor<3x1xi32>) -> tensor<3x32x16xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<3x32x16xf32>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%arg0, %arg1 : tensor<3x32x16xf32>, tensor<3x1xi32>) outs(%0 : tensor<3x32x16xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ iree_linalg_ext.yield %arg2 : f32
+ } -> tensor<3x32x16xf32>
+ return %1 : tensor<3x32x16xf32>
+}
+
+// CHECK-LABEL: func.func @small_scatter
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: thread = [1, 1, 4]
+// CHECK-SAME: workgroup = [1, 16, 16]
+
+// -----
+
+func.func @only_scattered_dim(%arg0: tensor<48xf32>,
+ %arg1: tensor<48x2xi32>) -> tensor<100x100xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<100x100xf32>
+ %1 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
+ ins(%arg0, %arg1 : tensor<48xf32>, tensor<48x2xi32>) outs(%0 : tensor<100x100xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32):
+ iree_linalg_ext.yield %arg2 : f32
+ } -> tensor<100x100xf32>
+ return %1 : tensor<100x100xf32>
+}
+
+// CHECK-LABEL: func.func @only_scattered_dim
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: thread = [1]
+// CHECK-SAME: workgroup = [48]
+
+// -----
+
+func.func @dynamic_scatter(%arg0: tensor<3x32x?xf32>,
+ %arg1: tensor<3x1xi32>,
+ %arg2: tensor<3x32x?xf32>) -> tensor<3x32x?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%arg0, %arg1 : tensor<3x32x?xf32>, tensor<3x1xi32>) outs(%arg2 : tensor<3x32x?xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ iree_linalg_ext.yield %arg3 : f32
+ } -> tensor<3x32x?xf32>
+ return %1 : tensor<3x32x?xf32>
+}
+
+// CHECK-LABEL: func.func @dynamic_scatter
+// CHECK-SAME: #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+// CHECK: linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+// CHECK-SAME: thread = [1, 1, 4]
+// CHECK-SAME: workgroup = [1, 1, 256]
+
+// -----
+
func.func @elementwise_scatter(%arg0: tensor<3x2048x2048xf32>,
%arg1: tensor<3x2048x2048xf32>,
%arg2: tensor<3x1xi32>) -> tensor<3x2048x2048xf32> {