[VectorDistribute] Consider all compute ops for thread tile size (#23394)

So far, the thread tile size (`threadLoads`) chosen during reduction
VectorDistribute lowering strategy selection was based purely on the
root operation. If the selected size wasn't suitable for other compute
operations in the same dispatch, selection of the VectorDistribute
strategy would fail, and compilation would fall back to other strategies.

This change addresses an existing TODO to consider the constraints of
the other compute ops in the same dispatch when choosing the thread tile
size. This may lead to smaller tile sizes (worst case one element per
thread) if the other compute ops have additional constraints, but it
allows the VectorDistribute pipeline to be used for such reductions.

This fixes https://github.com/iree-org/iree/issues/23340.

Assisted-by: Claude Code

ci-extra: test_torch

---------

Signed-off-by: Lukas Sommer <lukas.sommer@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp
index c9866b9..79d21ac 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp
@@ -130,6 +130,21 @@
   return largestInput.elementBitwidth;
 }
 
+/// Returns the size of the loop dimension that threadLoads must divide for
+/// `op`, or ShapedType::kDynamic (no constraint) if `op` has no loops.
+static int64_t getThreadLoadsConstraint(linalg::LinalgOp op) {
+  SmallVector<unsigned> parallelDims, reductionDims;
+  op.getParallelDims(parallelDims);
+  op.getReductionDims(reductionDims);
+  // Reduction ops are constrained by their last reduction dim, parallel-only
+  // ops by their last parallel dim.
+  const auto &dims = reductionDims.empty() ? parallelDims : reductionDims;
+  if (dims.empty()) {
+    return ShapedType::kDynamic; // No loops: calling back() would be UB.
+  }
+  return op.getStaticLoopRanges()[dims.back()];
+}
+
 /// Check if the reduction op has a single combiner operation.
 static LogicalResult checkSingleCombiner(linalg::LinalgOp op) {
   bool foundSingleReductionOutput = false;
@@ -661,7 +676,7 @@
   // reductions in scaled matmul with the last dimension being the block size
   // (32 for gfx950).
   int64_t reductionSize = bounds[reductionDims.back()];
-  if (!ShapedType::isDynamic(reductionSize) &&
+  if (ShapedType::isStatic(reductionSize) &&
       reductionSize % target.getPreferredSubgroupSize() != 0) {
     // Consider the entire reduction dimension.
     reductionSize = 1;
@@ -717,6 +732,20 @@
     }
   }
 
+  // Adjust threadLoads to satisfy constraints from all compute ops in the
+  // dispatch.
+  for (linalg::LinalgOp linalgOp : *computeOps) {
+    int64_t constraint = getThreadLoadsConstraint(linalgOp);
+    if (ShapedType::isStatic(constraint)) {
+      while (threadLoads > 1 && constraint % threadLoads != 0) {
+        threadLoads /= 2;
+      }
+    }
+    if (threadLoads <= 1) {
+      break;
+    }
+  }
+
   std::optional<int64_t> parallelSize = 1;
   for (int64_t dim : parallelDims) {
     if (ShapedType::isDynamic(bounds[dim])) {
@@ -767,9 +796,6 @@
     *parallelSize /= 2;
   }
 
-  // TODO(pashu123): Currently, the threadLoads is done on the basis of
-  // the root operation and ignores other operation within a dispatch.
-  // Extend it to use per operation within a dispatch.
   if (failed(populateConfigInfo(*computeOps, target, workgroupSize,
                                 subgroupSize, threadLoads))) {
     return failure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
index 10f7982..731bf0d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir
@@ -179,78 +179,78 @@
 
 // -----
 
-// The same IR as 'test_multiple_reduction' but with an unsupported operation,
-// preventing this from going down vector distribute. Previously lowering configs
-// would be attached to the supported operations even though the full dispatch
-// is unsupported.
-func.func @test_negative_multiple_reduction() {
-  %cst = arith.constant 0.000000e+00 : f32
-  %cst_0 = arith.constant 1.638400e+05 : f32
+// Test derived from NHWC layernorm and the bug reported in
+// https://github.com/iree-org/iree/issues/23340.
+// Previously, the VectorDistribute pipeline wasn't chosen for this input
+// because the `threadLoads` value, determined purely from the root reduction
+// operation, didn't divide the number of channels (3). Now `threadLoads`
+// is determined across all compute ops in the dispatch and VectorDistribute
+// is used.
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+  #hal.pipeline.binding<storage_buffer, Indirect>
+], flags = Indirect>
+func.func @nhwc_layernorm_small_channel() {
+  %cst = arith.constant 0.000000e+00 : bf16
+  %cst_0 = arith.constant 4.915200e+04 : bf16
   %cst_1 = arith.constant 9.99999974E-6 : f32
   %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x32x10x16384xf16>>
-  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x32x10x16384x1x1xf16>>
-  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x32x10x16384xf32>>
-  %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 32, 10, 16384], strides = [1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x32x10x16384xf16>> -> tensor<2x32x10x16384xf16>
-  %unitdims_4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0, 0, 0, 0], sizes = [2, 32, 10, 16384, 1, 1], strides = [1, 1, 1, 1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x32x10x16384x1x1xf16>> -> tensor<2x32x10x16384x1x1xf16>
-
-  // Operation unsupported by reduction vector distribution.
-  %outs_4 = tensor.empty() : tensor<2x32x10x16384xf16>
-  %4 = linalg.reduce ins(%unitdims_4 : tensor<2x32x10x16384x1x1xf16>) outs(%outs_4 : tensor<2x32x10x16384xf16>) dimensions = [4, 5]
-    (%in: f16, %init: f16) {
-      %20 = arith.addf %in, %init : f16
-      linalg.yield %20 : f16
-    }
-
-  %5 = tensor.empty() : tensor<2x32x10x16384xf32>
-  %6 = tensor.empty() : tensor<2x32xf32>
-  %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3 : tensor<2x32x10x16384xf16>) outs(%5 : tensor<2x32x10x16384xf32>) {
-  ^bb0(%in: f16, %out: f32):
-    %13 = arith.extf %in : f16 to f32
-    linalg.yield %13 : f32
-  } -> tensor<2x32x10x16384xf32>
-  %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<2x32xf32>) -> tensor<2x32xf32>
-  %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%7 : tensor<2x32x10x16384xf32>) outs(%8 : tensor<2x32xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %13 = arith.addf %in, %out : f32
-    linalg.yield %13 : f32
-  } -> tensor<2x32xf32>
-  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9 : tensor<2x32xf32>) outs(%6 : tensor<2x32xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %13 = arith.divf %in, %cst_0 : f32
-    linalg.yield %13 : f32
-  } -> tensor<2x32xf32>
-  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%7, %10 : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) outs(%8 : tensor<2x32xf32>) {
-  ^bb0(%in: f32, %in_2: f32, %out: f32):
-    %13 = arith.subf %in, %in_2 : f32
-    %14 = arith.mulf %13, %13 : f32
-    %15 = arith.addf %14, %out : f32
-    linalg.yield %15 : f32
-  } -> tensor<2x32xf32>
-  %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4, %10, %11 : tensor<2x32x10x16384xf16>, tensor<2x32xf32>, tensor<2x32xf32>) outs(%5 : tensor<2x32x10x16384xf32>) {
-  ^bb0(%in: f16, %in_2: f32, %in_3: f32, %out: f32):
-    %13 = arith.divf %in_3, %cst_0 : f32
-    %14 = arith.addf %13, %cst_1 : f32
-    %15 = math.rsqrt %14 : f32
-    %16 = arith.extf %in : f16 to f32
-    %17 = arith.subf %16, %in_2 : f32
-    %18 = arith.mulf %17, %15 : f32
-    linalg.yield %18 : f32
-  } -> tensor<2x32x10x16384xf32>
-  iree_tensor_ext.dispatch.tensor.store %12, %2, offsets = [0, 0, 0, 0], sizes = [2, 32, 10, 16384], strides = [1, 1, 1, 1] : tensor<2x32x10x16384xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x32x10x16384xf32>>
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x16384x3xbf16>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16384x3xbf16>>
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16384x3xbf16>>
+  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags(Indirect) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x16384x3xbf16>>
+  %4 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [2, 16384, 3], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<2x16384x3xbf16>> -> tensor<2x16384x3xbf16>
+  %5 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [16384, 3], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16384x3xbf16>> -> tensor<16384x3xbf16>
+  %6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0], sizes = [16384, 3], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16384x3xbf16>> -> tensor<16384x3xbf16>
+  %7 = tensor.empty() : tensor<2xbf16>
+  %8 = linalg.fill ins(%cst : bf16) outs(%7 : tensor<2xbf16>) -> tensor<2xbf16>
+  %9 = tensor.empty() : tensor<2x3x16384xbf16>
+  %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4 : tensor<2x16384x3xbf16>) outs(%9 : tensor<2x3x16384xbf16>) {
+  ^bb0(%in: bf16, %out: bf16):
+    linalg.yield %in : bf16
+  } -> tensor<2x3x16384xbf16>
+  %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%10 : tensor<2x3x16384xbf16>) outs(%8 : tensor<2xbf16>) {
+  ^bb0(%in: bf16, %out: bf16):
+    %16 = arith.addf %in, %out : bf16
+    linalg.yield %16 : bf16
+  } -> tensor<2xbf16>
+  %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%10, %11 : tensor<2x3x16384xbf16>, tensor<2xbf16>) outs(%8 : tensor<2xbf16>) {
+  ^bb0(%in: bf16, %mean: bf16, %out: bf16):
+    %16 = arith.divf %mean, %cst_0 : bf16
+    %17 = arith.subf %in, %16 : bf16
+    %18 = arith.mulf %17, %17 : bf16
+    %19 = arith.addf %18, %out : bf16
+    linalg.yield %19 : bf16
+  } -> tensor<2xbf16>
+  %13 = tensor.empty() : tensor<2x16384x3xbf16>
+  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2, d1)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %12, %5, %6 : tensor<2x3x16384xbf16>, tensor<2xbf16>, tensor<16384x3xbf16>, tensor<16384x3xbf16>) outs(%13 : tensor<2x16384x3xbf16>) {
+  ^bb0(%in: bf16, %var: bf16, %scale: bf16, %bias: bf16, %out: bf16):
+    %16 = arith.divf %var, %cst_0 : bf16
+    %17 = arith.truncf %cst_1 : f32 to bf16
+    %18 = arith.addf %16, %17 : bf16
+    %19 = math.rsqrt %18 : bf16
+    %20 = arith.mulf %in, %19 : bf16
+    %21 = arith.mulf %20, %scale : bf16
+    %22 = arith.addf %21, %bias : bf16
+    linalg.yield %22 : bf16
+  } -> tensor<2x16384x3xbf16>
+  iree_tensor_ext.dispatch.tensor.store %14, %3, offsets = [0, 0, 0], sizes = [2, 16384, 3], strides = [1, 1, 1] : tensor<2x16384x3xbf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<2x16384x3xbf16>>
   return
 }
 
-// Shouldn't go down vector distribute.
-// CHECK:       #iree_codegen.translation_info<pipeline =
-// CHECK-NOT:   LLVMGPUVectorDistribute
+// Verify VectorDistribute is selected despite small channel dimension.
+// The thread tile sizes should be 1 (not 8) because threadLoads was reduced
+// to satisfy the constraint from the parallel operations' last parallel
+// dim (3).
 
-// CHECK-LABEL: func.func @test_negative_multiple_reduction
+// CHECK:       #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
 
-// Only one lowering config should be present.
-// CHECK:       #iree_gpu.lowering_config
-// CHECK-NOT:   #iree_gpu.lowering_config
-// CHECK:       return
+// CHECK-LABEL: func.func @nhwc_layernorm_small_channel
+// CHECK:       linalg.generic {{.*}} iterator_types = ["parallel", "reduction", "reduction"]{{.*}} thread = [0, 1, 1]
+// CHECK:       linalg.generic {{.*}} iterator_types = ["parallel", "reduction", "reduction"]{{.*}} thread = [0, 1, 1]
+// CHECK:       linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel"]{{.*}} thread = [0, 0, 1]
 
 // -----