[GPU] Add lowering configuration logic for scatter (#19624)

This replaces LLVMGPUDistribute with LLVMGPUTileAndFuse for standalone
scatter ops. This opens up vectorization opportunities where possible and is
a step towards deprecating LLVMGPUDistribute.
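
For example, the @large_scatter test below (a 3x2048x2048 f32 scatter) now
selects LLVMGPUTileAndFuse with a 64-thread workgroup, thread tile sizes
[1, 1, 4] (one 128-bit copy per thread), and workgroup tile sizes
[1, 1, 256].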
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index 4f4d778..3c514c5 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -32,6 +32,7 @@
 namespace mlir::iree_compiler::IREE::GPU {
 
 constexpr int64_t kCacheLineSizeBits = 128 * 8;
+constexpr int64_t kPreferredCopyNumBits = 128;
 
 LogicalResult setDataTiledMultiMmaLoweringConfig(
     IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
@@ -732,6 +733,90 @@
       {flatWorkgroupSize, 1, 1}, subgroupSize, DictionaryAttr());
 }
 
+LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target,
+                                       mlir::FunctionOpInterface entryPoint,
+                                       Operation *op) {
+  auto scatter = dyn_cast<IREE::LinalgExt::ScatterOp>(op);
+  if (!scatter) {
+    return failure();
+  }
+
+  // TODO: Support non-unique indices.
+  if (!scatter.getUniqueIndices()) {
+    return failure();
+  }
+
+  // Various problem parameters.
+  int64_t loopDepth = scatter.getLoopIteratorTypes().size();
+  int64_t elemBits = scatter.getOriginalType().getElementTypeBitWidth();
+  SmallVector<int64_t> loopBounds = scatter.getStaticLoopRanges().value_or(
+      SmallVector<int64_t>(loopDepth, ShapedType::kDynamic));
+
+  // Configurations we need to decide.
+  int64_t flatWorkgroupSize = target.getPreferredSubgroupSize();
+  SmallVector<int64_t> workgroupTileSizes(loopDepth, 1);
+  SmallVector<int64_t> threadTileSizes(loopDepth, 1);
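+  // Aim for 128-bit copies per thread, e.g. 4 f32 or 8 f16 elements.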
+  int64_t vectorSize = kPreferredCopyNumBits / elemBits;
+
+  bool innerDynamic = ShapedType::isDynamic(loopBounds.back());
+
+  // Do not bother trying to vectorize if there are no vectorizable dims.
+  if (loopDepth == 1) {
+    vectorSize = 1;
+  } else if (!innerDynamic) {
+    // Take the largest power of two that divides the innermost non-scattered
+    // dim, capped at the preferred copy width.
+    vectorSize = std::gcd(vectorSize, loopBounds.back());
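+    // e.g. an inner bound of 2048 keeps vectorSize = 4 for f32, while an
+    // inner bound of 6 drops it to gcd(4, 6) = 2.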
+  }
+
+  threadTileSizes.back() = vectorSize;
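+  // Inner-dim iterations that remain after each thread takes its vector copy.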
+  int64_t residualInnerSize =
+      innerDynamic ? loopBounds.back() : loopBounds.back() / vectorSize;
+
+  // If the innermost dim is dynamic or at least as large as the expected
+  // number of threads, only distribute threads along the innermost dimension.
+  if (ShapedType::isDynamic(residualInnerSize) ||
+      residualInnerSize >= flatWorkgroupSize) {
+    workgroupTileSizes.back() = vectorSize * flatWorkgroupSize;
+  } else { // residualInnerSize < flatWorkgroupSize
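+    // Distribute the leftover threads over the outer dims, innermost first.
+    // e.g. for a 3x32x16 f32 scatter with 64 threads: residualInnerSize = 4,
+    // residualThreads = 16, giving workgroup tile sizes [1, 16, 16].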
+    // Floor divide so we never assign more threads than the workgroup has.
+    int64_t residualThreads = flatWorkgroupSize / residualInnerSize;
+    workgroupTileSizes.back() = residualInnerSize * vectorSize;
+    for (int64_t i = loopDepth - 2, e = 0; i >= e; --i) {
+      if (residualThreads <= 1) {
+        break;
+      }
+
+      bool dynamicDim = ShapedType::isDynamic(loopBounds[i]);
+      workgroupTileSizes[i] = dynamicDim
+                                  ? residualThreads
+                                  : std::min(residualThreads, loopBounds[i]);
+      residualThreads = dynamicDim ? 1 : residualThreads / loopBounds[i];
+    }
+  }
+
+  // Encode the chosen tile sizes as a lowering config on the scatter op for
+  // later access in the pipeline.
+  MLIRContext *context = scatter.getContext();
+  SmallVector<NamedAttribute, 1> attrs;
+  Builder b(context);
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+
+  attrs.emplace_back(StringAttr::get(context, "thread"),
+                     b.getI64ArrayAttr(threadTileSizes));
+
+  auto configDict = DictionaryAttr::get(context, attrs);
+  auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
+
+  LDBG("Selected tile and fuse lowering config: " << loweringConfig << "\n");
+
+  // TODO(qedawkins): Use a shared pipeline identifier here.
+  return setOpConfigAndEntryPointFnTranslation(
+      entryPoint, scatter, loweringConfig,
+      IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse,
+      {flatWorkgroupSize, 1, 1}, flatWorkgroupSize, DictionaryAttr());
+}
+
 //===----------------------------------------------------------------------===//
 // Lowering Config Attributes
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
index 636ffe5..ad9aa42 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h
@@ -42,6 +42,11 @@
                                            mlir::FunctionOpInterface entryPoint,
                                            Operation *op);
 
+// Helper for setting the lowering configuration on scatter ops.
+LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target,
+                                       mlir::FunctionOpInterface entryPoint,
+                                       Operation *op);
+
 //===----------------------------------------------------------------------===//
 // Pass Pipeline Options
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index 2c96a56..e09a72b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -2425,6 +2425,14 @@
         return setDefaultCustomOpLoweringConfig(entryPointFn, customOp,
                                                 initGPULaunchConfig);
       })
+      .Case<IREE::LinalgExt::ScatterOp>([&](auto scatterOp) {
+        LDBG("ScatterOp Config");
+        if (failed(IREE::GPU::setScatterLoweringConfig(target, entryPointFn,
+                                                       scatterOp))) {
+          return setRootDefaultConfig(target, entryPointFn, computeOp);
+        }
+        return success();
+      })
       .Default([&](auto op) {
         LDBG("Default Config");
         if (!clLLVMGPUVectorizePipeline) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
index f31ac9a..125ad8b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir
@@ -310,6 +310,90 @@
 
 // -----
 
+func.func @large_scatter(%arg0: tensor<3x2048x2048xf32>,
+                   %arg1: tensor<3x1xi32>) -> tensor<3x2048x2048xf32> {
+  %0 = tensor.empty() : tensor<3x2048x2048xf32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+    ins(%arg0, %arg1 : tensor<3x2048x2048xf32>, tensor<3x1xi32>) outs(%0 : tensor<3x2048x2048xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    iree_linalg_ext.yield %arg2 : f32
+  } -> tensor<3x2048x2048xf32>
+  return %1 : tensor<3x2048x2048xf32>
+}
+
+// CHECK-LABEL: func.func @large_scatter
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+//       CHECK:   linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     thread = [1, 1, 4]
+//  CHECK-SAME:     workgroup = [1, 1, 256]
+
+// -----
+
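+// Inner dim 16 gives vector width 4 and residual inner size 4 < 64 threads,
+// so the remaining 16 threads spill into the next dim: workgroup = [1, 16, 16].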
+func.func @small_scatter(%arg0: tensor<3x32x16xf32>,
+                         %arg1: tensor<3x1xi32>) -> tensor<3x32x16xf32> {
+  %0 = tensor.empty() : tensor<3x32x16xf32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+    ins(%arg0, %arg1 : tensor<3x32x16xf32>, tensor<3x1xi32>) outs(%0 : tensor<3x32x16xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    iree_linalg_ext.yield %arg2 : f32
+  } -> tensor<3x32x16xf32>
+  return %1 : tensor<3x32x16xf32>
+}
+
+// CHECK-LABEL: func.func @small_scatter
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+//       CHECK:   linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     thread = [1, 1, 4]
+//  CHECK-SAME:     workgroup = [1, 16, 16]
+
+// -----
+
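+// Only the scattered (batch) dim remains, so there is nothing to vectorize:
+// thread tile = [1] and the workgroup takes the full batch of 48.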
+func.func @only_scattered_dim(%arg0: tensor<48xf32>,
+                              %arg1: tensor<48x2xi32>) -> tensor<100x100xf32> {
+  %0 = tensor.empty() : tensor<100x100xf32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
+    ins(%arg0, %arg1 : tensor<48xf32>, tensor<48x2xi32>) outs(%0 : tensor<100x100xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    iree_linalg_ext.yield %arg2 : f32
+  } -> tensor<100x100xf32>
+  return %1 : tensor<100x100xf32>
+}
+
+// CHECK-LABEL: func.func @only_scattered_dim
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+//       CHECK:   linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     thread = [1]
+//  CHECK-SAME:     workgroup = [48]
+
+// -----
+
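+// The inner dim is dynamic, so the full vector width of 4 is kept and only
+// the innermost dim is distributed: workgroup tile = 4 * 64 = 256.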
+func.func @dynamic_scatter(%arg0: tensor<3x32x?xf32>,
+                           %arg1: tensor<3x1xi32>,
+                           %arg2: tensor<3x32x?xf32>) -> tensor<3x32x?xf32> {
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+    ins(%arg0, %arg1 : tensor<3x32x?xf32>, tensor<3x1xi32>) outs(%arg2 : tensor<3x32x?xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32):
+    iree_linalg_ext.yield %arg3 : f32
+  } -> tensor<3x32x?xf32>
+  return %1 : tensor<3x32x?xf32>
+}
+
+// CHECK-LABEL: func.func @dynamic_scatter
+//  CHECK-SAME:   #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+
+//       CHECK:   linalg_ext.scatter {{.*}}lowering_config = #iree_gpu.lowering_config
+//  CHECK-SAME:     thread = [1, 1, 4]
+//  CHECK-SAME:     workgroup = [1, 1, 256]
+
+// -----
+
 func.func @elementwise_scatter(%arg0: tensor<3x2048x2048xf32>,
                                %arg1: tensor<3x2048x2048xf32>,
                                %arg2: tensor<3x1xi32>) -> tensor<3x2048x2048xf32> {