Port reduction-v3 to C++ (#11539)

Using `TileReductionUsingForeachThreadOp` in
`createReductionStrategyThreadDistribution` allows a
much more general mapping of reductions on GPU.

This PR additionally performs a few refactorings to generalize the
control of the transformation so that it gracefully degrades from
vector<4> to vector<2> to scalar.
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
index 5c72d68..1675374 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategies.h
@@ -85,6 +85,7 @@
                                 ArrayRef<int64_t> blockSize);
 
 static constexpr unsigned kCudaWarpSize = 32;
+static constexpr unsigned kCudaMaxNumThreads = 1024;
 
 /// Post-bufferization vector distribution with rank-reduction.
 /// Takes a handle to a func.func and returns an updated handle to a
diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
index c6b4b52..7cc5134 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformDialectStrategiesGPU.cpp
@@ -8,6 +8,7 @@
 
 #include <numeric>
 #include <type_traits>
+#include <utility>
 
 #include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
 #include "iree-dialects/Transforms/TransformMatchers.h"
@@ -27,6 +28,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -97,19 +99,116 @@
   return std::tuple_cat(a);
 }
 
+/// Matches a C++ callback previously registered under `callbackName` and
+/// taking arguments `args`.
+/// Unpacks a number of handles `N` (asserts there are exactly `N` matched ops
+/// but this could be relaxed if needed).
+/// Returns the tuple of handles.
+template <int N, typename... MatchingArgs>
+auto unpackRegisteredMatchCallback(ImplicitLocOpBuilder &b,
+                                   StringRef callbackName,
+                                   MatchingArgs... args) {
+  SmallVector<Type> matchedTypes(N, pdl::OperationType::get(b.getContext()));
+  auto matchOp = b.create<MatchCallbackOp>(
+      matchedTypes, callbackName, std::forward<decltype(args)>(args)...);
+  assert(matchOp->getNumResults() == N && "Unexpected number of results");
+  std::array<Value, N> a;
+  for (int64_t i = 0; i < N; ++i) a[i] = matchOp->getResult(i);
+  return std::tuple_cat(a);
+}
+
 //===----------------------------------------------------------------------===//
 // Higher-level problem-specific strategy creation APIs, these should favor
 // user-friendliness.
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+/// Compute good tile and vector sizes for the reduction dimension of a 1-D
+/// reduction for a TileReductionUsingForeachThreadOp strategy.
+///
+/// Dynamic case: use as many threads as allowed along threadIdx.x with vector
+/// size of 1 (i.e. coalesced accesses).
+/// This can be further refined with splitting or vector masking when
+/// available.
+///
+/// Static case: perfectly tile by:
+///   - 128 to obtain 32*k threads working on vector<4xf32> with k as high as
+///   possible within the limits of maxNumThreadsToUse, when possible;
+///   - 64 to obtain 32*k threads working on vector<2xf32> with k as high as
+///   possible within the limits of maxNumThreadsToUse, when possible;
+///   - reductionDimensionSize within the limits of maxNumThreadsToUse,
+///   otherwise.
+// TODO: refine even further based on mod 2 and mod 4 only + min
+// canonicalizations.
+// TODO: refine sizes based on the bitwidth of the elemental type.
+class ReductionStrategyThreadDistributionSizes {
+ public:
+  ReductionStrategyThreadDistributionSizes(
+      int64_t reductionDimensionSize = 0,
+      int64_t maxNumThreadsToUse = iree_compiler::kCudaMaxNumThreads)
+      : reductionDimensionSize(reductionDimensionSize),
+        maxNumThreadsToUse(maxNumThreadsToUse) {
+    computeStrategy();
+  }
+  ReductionStrategyThreadDistributionSizes(
+      const ReductionStrategyThreadDistributionSizes &) = default;
+
+  ReductionStrategyThreadDistributionSizes &operator=(
+      const ReductionStrategyThreadDistributionSizes &) = default;
+
+  int64_t reductionTileSize;
+  int64_t vectorTileSize;
+
+ private:
+  void computeStrategy();
+
+  int64_t reductionDimensionSize;
+  // TODO: Characterize shared memory consumption of this strategy and limit
+  // accordingly for good occupancy.
+  int64_t maxNumThreadsToUse;
+};
+
+void ReductionStrategyThreadDistributionSizes::computeStrategy() {
+  vectorTileSize = 1;
+  reductionTileSize = maxNumThreadsToUse;
+  if (reductionDimensionSize <= 0) return;
+
+  // TODO: refine even further based on mod 2 and mod 4 only + min
+  // canonicalizations.
+  int64_t warpVector4Size = 4 * iree_compiler::kCudaWarpSize;
+  int64_t warpVector2Size = 2 * iree_compiler::kCudaWarpSize;
+  if (reductionDimensionSize % warpVector4Size == 0) {
+    int64_t f1 = reductionDimensionSize / warpVector4Size;
+    int64_t f2 = maxNumThreadsToUse / warpVector4Size;
+    reductionTileSize = std::min(f1, f2) * iree_compiler::kCudaWarpSize;
+    vectorTileSize = 4;
+  } else if (reductionDimensionSize % warpVector2Size == 0) {
+    int64_t f1 = reductionDimensionSize / warpVector2Size;
+    int64_t f2 = maxNumThreadsToUse / warpVector2Size;
+    reductionTileSize = std::min(f1, f2) * iree_compiler::kCudaWarpSize;
+    vectorTileSize = 2;
+  } else {
+    reductionTileSize = std::min(maxNumThreadsToUse, reductionDimensionSize);
+    vectorTileSize = 1;
+  }
+}
+
 /// Structure to hold the parameters related to GPU reduction strategy.
 struct GPUReductionStrategyInfos {
+  explicit GPUReductionStrategyInfos(int64_t reductionDimensionSize)
+      : reductionDimensionSize(reductionDimensionSize),
+        threadDistributionSizes(
+            ReductionStrategyThreadDistributionSizes(reductionDimensionSize)) {}
+  int64_t reductionDimensionSize;
+  ReductionStrategyThreadDistributionSizes threadDistributionSizes;
+
   std::array<int64_t, 3> workgroupSize;
   SmallVector<int64_t> workgroupTileSizes;
   SmallVector<int64_t> fillSecondTileSizes;
   SmallVector<int64_t> genericSecondTileSizes;
-  int64_t reductionDimensionSize;
 };
+}  // namespace
 
 static std::pair<Value, Value> createReductionStrategyBlockDistribution(
     ImplicitLocOpBuilder &b, Value maybeLeadingH, Value fillH, Value reductionH,
@@ -140,22 +239,7 @@
 
 static void createReductionStrategyThreadDistribution(
     ImplicitLocOpBuilder &b, Value gridReductionH, Value maybeTiledTrailingH,
-    int64_t reductionDimensionSize) {
-  // Select tile sizes. Perfectly tile by:
-  //   - 128 to obtain 32 threads working on vector<4xf32> when possible;
-  //   - 64 to obtain 32 threads working on vector<2xf32> when possible;
-  //   - 32 otherwise.
-  // TODO: refine sizes based on the bitwidth of the elemental type.
-  int64_t firstReductionSize = iree_compiler::kCudaWarpSize;
-  int64_t vectorTileSize = 1;
-  if (reductionDimensionSize % (4 * iree_compiler::kCudaWarpSize) == 0) {
-    firstReductionSize = 4 * iree_compiler::kCudaWarpSize;
-    vectorTileSize = 4;
-  } else if (reductionDimensionSize % (2 * iree_compiler::kCudaWarpSize) == 0) {
-    firstReductionSize = 2 * iree_compiler::kCudaWarpSize;
-    vectorTileSize = 2;
-  }
-
+    const ReductionStrategyThreadDistributionSizes &sizes) {
   auto pdlOperation = pdl::OperationType::get(b.getContext());
   auto threadX = mlir::gpu::GPUThreadMappingAttr::get(b.getContext(),
                                                       mlir::gpu::Threads::DimX);
@@ -164,26 +248,22 @@
 
   // Split the reduction into a parallel and combiner part, then tile the
   // parallel part and map it to a full warp so it works on vectors.
-  auto tileReduction = b.create<transform::TileReductionUsingScfOp>(
-      gridReductionH, ArrayRef<int64_t>({0, firstReductionSize}));
+  auto tileReduction = b.create<transform::TileReductionUsingForeachThreadOp>(
+      /*target=*/gridReductionH,
+      /*numThreads=*/ArrayRef<int64_t>{0, sizes.reductionTileSize},
+      /*tileSizes=*/ArrayRef<int64_t>{0, sizes.vectorTileSize},
+      /*threadDimMapping=*/b.getArrayAttr(threadX));
+  Value blockParallelForeachThreadOp = tileReduction.getForeachThreadOp();
   Value blockParallelFillH = tileReduction.getFillOp();
-  Value blockParallelOpH = tileReduction.getSplitLinalgOp();
   Value blockCombinerOpH = tileReduction.getCombiningLinalgOp();
-  iree_compiler::buildTileFuseDistToForeachThreadWithNumThreads(
-      b, blockParallelOpH, {},
-      getAsOpFoldResult(b.getI64ArrayAttr({0, iree_compiler::kCudaWarpSize})),
-      b.getArrayAttr(threadX));
 
-  // Tile the fill so it maps to vectors.
-  // TODO: fuse once the support is available
-  // (https://reviews.llvm.org/D139844).
-  iree_compiler::buildTileFuseDistToForeachThreadWithTileSizes(
-      b, blockParallelFillH, {},
-      getAsOpFoldResult(b.getI64ArrayAttr({0, vectorTileSize})),
-      b.getArrayAttr(threadX));
+  // Fuse the fill and pointwise to privatize them.
+  blockParallelFillH = b.create<FuseIntoContainingOp>(
+      blockParallelFillH, blockParallelForeachThreadOp);
 
   // Map the combiner reduction to one thread along y so it can be mapped
-  // further via predication. Fuse it into the trailing elementwise if present.
+  // further via predication. Fuse it into the trailing elementwise if
+  // present.
   auto selector = b.create<TakeFirstOp>(
       pdlOperation, pdlOperation,
       ArrayRef<Value>({maybeTiledTrailingH, blockCombinerOpH}));
@@ -195,33 +275,30 @@
 }
 
 /// Builds the transform IR tiling reductions for CUDA targets. Supports
-/// reductions in the last dimension with static shape divisible by 32 (CUDA
-/// warp size), with optional leading and trailing elementwise operations.
+/// reductions in the last dimension, with optional leading and trailing
+/// elementwise operations.
 static void createReductionCudaStrategy(
     ImplicitLocOpBuilder &b, Value variantH,
     const GPUReductionStrategyInfos &infos) {
   // Step 1. Call the matcher. Note that this is the same matcher as used to
   // trigger this compilation path, so it must always apply.
   b.create<RegisterMatchCallbacksOp>();
-  SmallVector<Type> matchedTypes(4, pdl::OperationType::get(b.getContext()));
-  auto match = b.create<MatchCallbackOp>(
-      matchedTypes, "reduction", transform::FailurePropagationMode::Propagate,
-      variantH);
-  Value maybeLeadingH = match.getResult(0);
-  Value fillH = match.getResult(1);
-  Value reductionH = match.getResult(2);
-  Value maybeTrailingH = match.getResult(3);
+  auto [maybeLeadingH, fillH, reductionH, maybeTrailingH] =
+      unpackRegisteredMatchCallback<4>(
+          b, "reduction", transform::FailurePropagationMode::Propagate,
+          variantH);
 
-  // Step 2. Use tiling to introduce a single-iteration loop mapped to a single
-  // block/workgroup. Keep everything fused.
+  // Step 2. Use tiling to introduce a single-iteration loop mapped to a
+  // single block/workgroup. Keep everything fused.
   auto [gridReductionH, maybeTiledTrailingH] =
       createReductionStrategyBlockDistribution(b, maybeLeadingH, fillH,
                                                reductionH, maybeTrailingH);
 
   // Step 3. Split the reduction and tile the pieces to ensure vector
   // load/stores and mapping to a single warp with shuffles.
-  createReductionStrategyThreadDistribution(
-      b, gridReductionH, maybeTiledTrailingH, infos.reductionDimensionSize);
+  ReductionStrategyThreadDistributionSizes sizes(infos.reductionDimensionSize);
+  createReductionStrategyThreadDistribution(b, gridReductionH,
+                                            maybeTiledTrailingH, sizes);
 
   // Step 4. Bufferize and drop HAL decriptor from memref ops.
   Value funcH = b.create<MatchOp>(variantH, func::FuncOp::getOperationName());
@@ -239,14 +316,13 @@
   iree_compiler::buildDistributeVectors(b, variantH, funcH);
 }
 
-// TODO: consider passing a problem-specific struct to control information.
-static bool matchGPUReduction(linalg::LinalgOp op,
-                              GPUReductionStrategyInfos &info) {
-  // TODO: match the sequence the strategy supports.
+static FailureOr<GPUReductionStrategyInfos> matchGPUReduction(
+    linalg::LinalgOp op) {
   StructuredOpMatcher reduction, fill, leading, trailing;
+  int64_t reductionDimensionSize;
   makeReductionMatcher(reduction, fill, leading, trailing,
-                       info.reductionDimensionSize);
-  if (!matchPattern(op, reduction)) return false;
+                       reductionDimensionSize);
+  if (!matchPattern(op, reduction)) return failure();
 
   //
   // !!We must match exactly all payload ops when the dispatch is pre-formed!!
@@ -262,29 +338,30 @@
       DBGS() << "Failed to match " << mustMatchNumPayloadOps
              << " payload ops, matched " << numMatchedOps << " instead\n";
     });
-    return false;
+    return failure();
   }
 
-  // Hardcoded workgroup size, this could be deduced from the reduction dim.
-  info.workgroupSize = {32, 1, 1};
+  GPUReductionStrategyInfos info(reductionDimensionSize);
+  info.workgroupSize = {info.threadDistributionSizes.reductionTileSize, 1, 1};
   SmallVector<unsigned> partitionedLoops =
       cast<iree_compiler::PartitionableLoopsInterface>(op.getOperation())
           .getPartitionableLoops(iree_compiler::kNumMaxParallelDims);
   size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
+
   // Tile all the parallel dimension to 1.
   info.workgroupTileSizes.append(numLoops, 1);
   info.fillSecondTileSizes = {1, 0, 0};
   info.genericSecondTileSizes = {1, 1, 0};
-  return true;
+  return info;
 }
 
 LogicalResult iree_compiler::matchAndSetGPUReductionTransformStrategy(
     func::FuncOp entryPoint, linalg::LinalgOp op) {
   // 1. Match
-  GPUReductionStrategyInfos infos;
-  if (!matchGPUReduction(op, infos)) return failure();
+  FailureOr<GPUReductionStrategyInfos> maybeInfos = matchGPUReduction(op);
+  if (failed(maybeInfos)) return failure();
   auto strategyBuilder = [&](ImplicitLocOpBuilder &b, Value variant) {
-    return createReductionCudaStrategy(b, variant, infos);
+    return createReductionCudaStrategy(b, variant, *maybeInfos);
   };
   // 2. Add the strategy.
   createTransformRegion(entryPoint, strategyBuilder);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
index dad48b6..01d5f99 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform.mlir
@@ -31,22 +31,25 @@
 //   CHECK-LABEL: func.func @group_reduction
 //     CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
+//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
 //     CHECK-DAG:   %[[TIDX:.]] = gpu.thread_id  x
-//         CHECK:   %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-//         CHECK:   %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
-//         CHECK:   gpu.barrier
 
-// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
-//         CHECK:   vector.transfer_read
-//         CHECK:   vector.transfer_read
-//         CHECK:   arith.addf %{{.*}} : vector<2xf32>
-//         CHECK:   vector.transfer_write
-//         CHECK:   gpu.barrier
+// Fusion occurred, no barrier before the loop
+//     CHECK-NOT: gpu.barrier
+//     CHECK:   vector.transfer_read {{.*}} vector<f32>
+// Local per-thread scf.for-based reduction.
+//         CHECK: scf.for
+//         CHECK:   vector.transfer_read {{.*}} vector<2xf32>
+//         CHECK:   vector.reduction <add>{{.*}} : vector<2xf32> into f32
+//         CHECK:   vector.broadcast {{.*}} : f32 to vector<f32>
+// No barrier within the loop
+//     CHECK-NOT:   gpu.barrier
+//         CHECK:   scf.yield {{.*}} : vector<f32>
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
+//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
 //         CHECK:   %[[RES:.*]] = arith.addf %{{.*}}
@@ -95,29 +98,31 @@
 //   CHECK-LABEL: func.func @group_reduction_elementwise
 //     CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
+//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
 //     CHECK-DAG:   %[[TIDX:.]] = gpu.thread_id  x
-//         CHECK:   %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-//         CHECK:   %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
-//         CHECK:   gpu.barrier
 
-// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
-//         CHECK:   vector.transfer_read
-//         CHECK:   vector.transfer_read
-//         CHECK:   arith.addf %{{.*}} : vector<2xf32>
-//         CHECK:   vector.transfer_write
-//         CHECK:   gpu.barrier
+// Fusion occurred, no barrier before the loop
+//     CHECK-NOT: gpu.barrier
+//     CHECK:   vector.transfer_read {{.*}} vector<f32>
+// Local per-thread scf.for-based reduction.
+//         CHECK: scf.for
+//         CHECK:   vector.transfer_read {{.*}} vector<2xf32>
+//         CHECK:   vector.reduction <add>{{.*}} : vector<2xf32> into f32
+//         CHECK:   vector.broadcast {{.*}} : f32 to vector<f32>
+// No barrier within the loop
+//     CHECK-NOT:   gpu.barrier
+//         CHECK:   scf.yield {{.*}} : vector<f32>
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
+//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
 //         CHECK:   %[[PARTIAL:.*]] = arith.addf %{{.*}}
-//         CHECK:   %[[PARTIAL_VEC:.*]] = vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
-//         CHECK:   %[[ELEM:.*]] = vector.extractelement %[[PARTIAL_VEC]][]
-//         CHECK:   %[[RES:.*]] = math.sqrt %[[ELEM]]
-//         CHECK:   %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
+//         CHECK:   vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
+//         CHECK:   math.sqrt 
+//         CHECK:   %[[RES_VEC:.*]] = vector.broadcast %{{.*}}: f32 to vector<f32>
 //         CHECK:   %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
 //         CHECK:   scf.if %[[CONDXIS0]]
 //         CHECK:     vector.transfer_write %[[RES_VEC]]
@@ -159,27 +164,29 @@
 //   CHECK-LABEL: func.func @group_elementwise_reduction
 //     CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
+//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
 //     CHECK-DAG:   %[[TIDX:.]] = gpu.thread_id  x
-//         CHECK:   %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-//         CHECK:   %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
-//         CHECK:   gpu.barrier
 
-// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
-//         CHECK:   vector.transfer_read
-//         CHECK:   vector.transfer_read
-//         CHECK:   %[[PARTIAL_1:.*]] = arith.addf %[[ARG:.*]], %[[ARG]]
-//         CHECK:   %[[PARTIAL_2:.*]] = arith.addf %[[PARTIAL_1]], %[[PARTIAL_1]]
-//         CHECK:   arith.addf %[[PARTIAL_2]], %{{.*}} : vector<2xf32>
-//         CHECK:   vector.transfer_write
-//         CHECK:   gpu.barrier
+// Fusion occurred, no barrier before the loop
+//     CHECK-NOT: gpu.barrier
+//     CHECK:   vector.transfer_read {{.*}} vector<f32>
+// Local per-thread scf.for-based reduction.
+//         CHECK: scf.for
+//         CHECK:   vector.transfer_read {{.*}} vector<2xf32>
+//         CHECK:   arith.addf{{.*}} : vector<2xf32>
+//         CHECK:   arith.addf{{.*}} : vector<2xf32>
+//         CHECK:   vector.reduction <add>{{.*}} : vector<2xf32> into f32
+//         CHECK:   vector.broadcast {{.*}} : f32 to vector<f32>
+// No barrier within the loop
+//     CHECK-NOT:   gpu.barrier
+//         CHECK:   scf.yield {{.*}} : vector<f32>
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
+//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
-
 //         CHECK:   %[[RES:.*]] = arith.addf %{{.*}}
 //         CHECK:   %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
 //         CHECK:   %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
@@ -228,31 +235,33 @@
 //   CHECK-LABEL: func.func @group_elementwise_reduction_elementwise
 //     CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
+//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32xf32, 3>
 //     CHECK-DAG:   %[[TIDX:.]] = gpu.thread_id  x
-//         CHECK:   %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-//         CHECK:   %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<2xf32, strided<[1], offset: ?>, 3>
-//         CHECK:   gpu.barrier
 
-// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
-//         CHECK:   vector.transfer_read
-//         CHECK:   vector.transfer_read
-//         CHECK:   %[[PARTIAL_1:.*]] = arith.addf %[[ARG:.*]], %[[ARG]]
-//         CHECK:   %[[PARTIAL_2:.*]] = arith.addf %[[PARTIAL_1]], %[[PARTIAL_1]]
-//         CHECK:   arith.addf %[[PARTIAL_2]], %{{.*}} : vector<2xf32>
-//         CHECK:   vector.transfer_write
-//         CHECK:   gpu.barrier
+// Fusion occurred, no barrier before the loop
+//     CHECK-NOT: gpu.barrier
+//     CHECK:   vector.transfer_read {{.*}} vector<f32>
+// Local per-thread scf.for-based reduction.
+//         CHECK: scf.for
+//         CHECK:   vector.transfer_read {{.*}} vector<2xf32>
+//         CHECK:   arith.addf{{.*}} : vector<2xf32>
+//         CHECK:   arith.addf{{.*}} : vector<2xf32>
+//         CHECK:   vector.reduction <add>{{.*}} : vector<2xf32> into f32
+//         CHECK:   vector.broadcast {{.*}} : f32 to vector<f32>
+// No barrier within the loop
+//     CHECK-NOT:   gpu.barrier
+//         CHECK:   scf.yield {{.*}} : vector<f32>
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
-//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
+//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[TIDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
 //         CHECK:   %[[PARTIAL:.*]] = arith.addf %{{.*}}
-//         CHECK:   %[[PARTIAL_VEC:.*]] = vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
-//         CHECK:   %[[ELEM:.*]] = vector.extractelement %[[PARTIAL_VEC]][]
-//         CHECK:   %[[RES:.*]] = math.sqrt %[[ELEM]]
-//         CHECK:   %[[RES_VEC:.*]] = vector.broadcast %[[RES]] : f32 to vector<f32>
+//         CHECK:   vector.broadcast %[[PARTIAL]] : f32 to vector<f32>
+//         CHECK:   math.sqrt
+//         CHECK:   %[[RES_VEC:.*]] = vector.broadcast %{{.*}}: f32 to vector<f32>
 //         CHECK:   %[[CONDXIS0:.*]] = arith.cmpi eq, %[[TIDX]], %[[C0]] : index
 //         CHECK:   scf.if %[[CONDXIS0]]
 //         CHECK:     vector.transfer_write %[[RES_VEC]]
@@ -292,21 +301,25 @@
 //   CHECK-LABEL: func.func @group_reduction_larger
 //     CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //     CHECK-DAG:   %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
-//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x128xf32, 3>
+//     CHECK-DAG:   %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x64xf32, 3>
 //     CHECK-DAG:   %[[TIDX:.]] = gpu.thread_id  x
-//         CHECK:   %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
-//         CHECK:   %[[SHMEM_VIEW_EXPANDED:.*]] = memref.subview %[[SHMEM_ALLOC]][0, %[[IDX]]]{{.*}}to memref<4xf32, strided<[1], offset: ?>, 3>
-//         CHECK:   gpu.barrier
 
-// Local per-thread scf.for-based reduction, after the single-iteration scf.for was canonicalized.
+// Fusion occurred, no barrier before the loop
+//     CHECK-NOT: gpu.barrier
+//     CHECK:   vector.transfer_read {{.*}} vector<f32>
+// Local per-thread scf.for-based reduction.
+//         CHECK: scf.for
 //         CHECK:   vector.transfer_read
-//         CHECK:   vector.transfer_read
-//         CHECK:   arith.addf %{{.*}} : vector<4xf32>
-//         CHECK:   vector.transfer_write
-//         CHECK:   gpu.barrier
+//         CHECK:   vector.reduction <add>{{.*}} : vector<4xf32> into f32
+//         CHECK:   vector.broadcast {{.*}} : f32 to vector<f32>
+// No barrier within the loop
+//     CHECK-NOT:   gpu.barrier
+//         CHECK:   scf.yield {{.*}} : vector<f32>
 
 // Distributed reduction: everyone loads then 5 xor + addf expected.
 //         CHECK: %[[TIDY:.]] = gpu.thread_id  y
+//         CHECK: vector.transfer_read %{{.*}}[]
+//         CHECK: %[[IDX:.*]] = affine.apply{{.*}}%[[TIDX]]
 //         CHECK: vector.transfer_read %{{.*}}[%[[TIDY]], %[[IDX]]]
 // CHECK-COUNT-5: gpu.shuffle  xor{{.*}}{{[[:space:]].*}}{{.*}} arith.addf
 
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
index 0f3ab31..baf04a3 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/set_transform_strategy.mlir
@@ -35,15 +35,12 @@
 //         CHECK:   transform.iree.tile_to_foreach_thread_and_workgroup_count_region {{.*}} tile_sizes [1](mapping = [#gpu.block<x>])
 // CHECK-COUNT-3:   transform.structured.fuse_into_containing_op
 //         CHECK:   transform.iree.take_first
-//         CHECK:   transform.structured.tile_reduction_using_scf %{{.*}} by tile_sizes = [0, 64]
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [0, 32]
-//    CHECK-SAME:      (mapping = [#gpu.thread<x>])
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [0, 2](mapping = [#gpu.thread<x>])
+//         CHECK:   tile_reduction_using_foreach_thread {{.*}} by num_threads = [0, 32], tile_sizes = [0, 2], mapping = [#gpu.thread<x>]
+//         CHECK:   transform.structured.fuse_into_containing_op
 //         CHECK:   transform.iree.take_first
 //         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [1](mapping = [#gpu.thread<y>])
 //         CHECK:   transform.structured.fuse_into_containing_op
 //         CHECK:   transform.structured.match ops{["func.func"]} in %arg0
-//         CHECK:   transform.iree.apply_patterns %{{.*}} {rank_reducing}
 //         CHECK:   transform.structured.vectorize
 //         CHECK:   transform.iree.bufferize {target_gpu}
 //         CHECK:   transform.structured.match ops{["func.func"]} in %{{.*}}
@@ -95,10 +92,8 @@
 
 //   CHECK-LABEL: func.func @group_reduction_128
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
-//         CHECK:   transform.structured.tile_reduction_using_scf %{{.*}} by tile_sizes = [0, 128]
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [0, 32]
-//    CHECK-SAME:      (mapping = [#gpu.thread<x>])
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [0, 4](mapping = [#gpu.thread<x>])
+//         CHECK:   transform.structured.tile_reduction_using_foreach_thread %{{.*}} by num_threads = [0, 32], tile_sizes = [0, 4], mapping = [#gpu.thread<x>]
+//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [32, 1, 1]}
 
 // -----
 
@@ -136,7 +131,5 @@
 
 //   CHECK-LABEL: func.func @group_reduction_32
 //         CHECK:   transform.structured.canonicalized_sequence failures(propagate)
-//         CHECK:   transform.structured.tile_reduction_using_scf %{{.*}} by tile_sizes = [0, 32]
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} num_threads [0, 32]
-//    CHECK-SAME:      (mapping = [#gpu.thread<x>])
-//         CHECK:   transform.structured.tile_to_foreach_thread_op %{{.*}} tile_sizes [0, 1](mapping = [#gpu.thread<x>])
+//         CHECK:   transform.structured.tile_reduction_using_foreach_thread %{{.*}} by num_threads = [0, 32], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
+//         CHECK:   transform.iree.map_nested_foreach_thread_to_gpu_threads %{{.*}} {workgroup_size = [32, 1, 1]}
diff --git a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
index 25b81c0..25ef5ee 100644
--- a/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Transforms/TransformMatchers.cpp
@@ -317,9 +317,10 @@
                  .output(NumEqualsTo(1));
   leading = trailing;
   reduction = m_StructuredOp()
-                  .dim(AllDims(), ShapeKind::Static)
+                  // The CUDA strategy now supports arbitrary mixes of static
+                  // and dynamic sizes and adapts accordingly.
+                  // Consider separating control for other targets if needed.
                   .dim(-1, utils::IteratorType::reduction)
-                  .dim(-1, DivisibleBy(kCudaWarpSize))
                   .dim(-1, CaptureDim(reductionDimensionSize))
                   // Can be extended to projected permutation with broadcast.
                   .input(AllOperands(), IsPermutation())
diff --git a/tests/transform_dialect/cuda/reduction_v2.mlir b/tests/transform_dialect/cuda/reduction_v2.mlir
index da9a90c..6372a66 100644
--- a/tests/transform_dialect/cuda/reduction_v2.mlir
+++ b/tests/transform_dialect/cuda/reduction_v2.mlir
@@ -27,15 +27,6 @@
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
 // RUN: FileCheck %s --check-prefix=CHECK
 
-// RUN: iree-opt %s --iree-hal-target-backends=cuda \
-// RUN:     --iree-abi-transformation-pipeline \
-// RUN:     --iree-flow-transformation-pipeline  \
-// RUN:     --iree-stream-transformation-pipeline \
-// RUN:     --iree-hal-configuration-pipeline | \
-// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-lower-executable-target)))' \
-// RUN:      --iree-codegen-llvmgpu-enable-transform-dialect-jit | \
-// RUN: FileCheck %s --check-prefix=CHECK
-
 // RUN: iree-compile %s --iree-hal-target-backends=cuda \
 // RUN:     --iree-codegen-llvmgpu-use-transform-dialect=%p/reduction_v2_codegen_spec.mlir | \
 // RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="33x1024xf32=1" |\
diff --git a/tests/transform_dialect/cuda/reduction_v3.mlir b/tests/transform_dialect/cuda/reduction_v3.mlir
index 7070b79..e627ef5 100644
--- a/tests/transform_dialect/cuda/reduction_v3.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3.mlir
@@ -34,6 +34,11 @@
 // RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="123x4567xf32=1" |\
 // RUN: FileCheck %s --check-prefix=EXEC
 
+// RUN: iree-compile %s --iree-hal-target-backends=cuda \
+// RUN:     --iree-codegen-llvmgpu-enable-transform-dialect-jit | \
+// RUN: iree-run-module --entry_function=reduce --device=cuda --function_input="123x4567xf32=1" |\
+// RUN: FileCheck %s --check-prefix=EXEC
+
   //     CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
   //     CHECK-DAG: %[[workgroup_id_x:.*]] = hal.interface.workgroup.id[0] : index
   //     CHECK-DAG: %[[SHMEM_ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x1024xf32, 3>
diff --git a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
index 5b02ef6..1b188b9 100644
--- a/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
+++ b/tests/transform_dialect/cuda/reduction_v3_codegen_spec.mlir
@@ -19,7 +19,7 @@
      transform.structured.tile_reduction_using_foreach_thread %grid_reduction 
         by num_threads = [0, 1024], tile_sizes = [0, 1], mapping = [#gpu.thread<x>]
 
-  // Fuse the fill and pointwise to privatize them. 
+  // Fuse the fill and pointwise to privatize them.
   transform.structured.fuse_into_containing_op %block_more_parallel_fill_op_2
     into %foreach_thread