[LLVMGPUVectorDistribute] Add support for inter-subgroup multi_reduction (#19596)

This commit adds support for distributing multi_reductions where the
reduction dimension(s) are distributed across subgroups.

We perform the existing reduction distribution; however, we are left
with partial reductions across subgroups.

Thereafter, we insert transfer_write / transfer_read ops to shared memory to
achieve a layout change that re-distributes the reduction subgroup tiles into
the element tile. Finally, we perform another multi_reduction to complete the
reduction.
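
For illustration, a rough sketch of the SIMD-domain IR this produces, using the
shapes from the new @inter_subgroup_reduction test (the function name is
hypothetical, and the real indices, padding values, and layout attributes are
derived from the nested layout during distribution):

```mlir
// Partial per-subgroup reductions of a 32x32 input: dim 1 has been reduced
// down to its subgroup tile (2 subgroups), so %partial is vector<32x2xf32>.
func.func @inter_subgroup_reduce_sketch(%partial: vector<32x2xf32>,
                                        %acc: vector<32xf32>) -> vector<32xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // Shared-memory staging buffer: parallel dim x subgroup tile of the reduced dim.
  %alloc = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
  // Barrier so a previous batch iteration has finished reading the buffer.
  gpu.barrier
  // Write the partial reductions to shared memory (after distribution, each
  // subgroup writes only its own slice).
  vector.transfer_write %partial, %alloc[%c0, %c0] {in_bounds = [true, true]}
      : vector<32x2xf32>, memref<32x2xf32, #gpu.address_space<workgroup>>
  gpu.barrier
  // Read back with a layout where the former subgroup tile becomes the
  // element tile, so every subgroup now holds all partial results.
  %reloaded = vector.transfer_read %alloc[%c0, %c0], %pad {in_bounds = [true, true]}
      : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<32x2xf32>
  // A second multi_reduction folds the inter-subgroup partials into the
  // accumulator to complete the reduction.
  %res = vector.multi_reduction <maximumf>, %reloaded, %acc [1]
      : vector<32x2xf32> to vector<32xf32>
  return %res : vector<32xf32>
}
```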

closes: https://github.com/iree-org/iree/issues/19578

---------

Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 0b4d812..49981f1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -414,8 +414,9 @@
 ///      by doing a butterfly shuffle.
 ///   3. Accumulator Reduce: Each thread reduces it's intermediate reduced
 ///      results with the accumulator it holds.
-/// Currently, reduction across warps is not supported, but it would just add
-/// another step, Warp Reduce, where threads do an atomic addition on a buffer.
+///   4. Subgroup reduce: each subgroup stores its partial reductions to
+///      shared memory, and they are reloaded into a layout that places the
+///      partial reductions inside threads.
 struct DistributeMultiReduction final
     : OpDistributionPattern<vector::MultiDimReductionOp> {
   using OpDistributionPattern::OpDistributionPattern;
@@ -460,7 +461,6 @@
     }
 
     Location loc = multiReduceOp.getLoc();
-
     SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
     int64_t rank = srcVector.getType().getRank();
 
@@ -492,25 +492,32 @@
     assert(locallyReduced && "result should have been a vector");
 
     // Flatten the locally reduced value.
+    VectorValue threadReduced = locallyReduced;
     VectorType shaped = locallyReduced.getType();
-    int64_t numElements = shaped.getNumElements();
-    SmallVector<int64_t> flatShape(1, numElements);
-    VectorType flatVecType = VectorType::get(flatShape, elemTy);
-    VectorValue flat =
-        rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
+    bool hasThreadReductions =
+        llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
+          return srcLayout.getThreadTile()[rDim] > 1;
+        });
+    if (hasThreadReductions) {
+      int64_t numElements = shaped.getNumElements();
+      SmallVector<int64_t> flatShape(1, numElements);
+      VectorType flatVecType = VectorType::get(flatShape, elemTy);
+      VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
+                                                              locallyReduced);
 
-    // Do inter-thread/warp reduce.
-    FailureOr<VectorValue> threadReduced = doThreadReduction(
-        rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
-    if (failed(threadReduced)) {
-      return failure();
+      // Do inter-thread/warp reduce.
+      FailureOr<VectorValue> threadReducedFlat = doThreadReduction(
+          rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
+      if (failed(threadReducedFlat)) {
+        return failure();
+      }
+
+      // Do reduction against accumulator, which needs to be done after thread
+      // reduction.
+      threadReduced = rewriter.create<vector::ShapeCastOp>(
+          loc, shaped, threadReducedFlat.value());
     }
 
-    // Do reduction against accumulator, which needs to be done after thread
-    // reduction.
-    VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
-        loc, shaped, threadReduced.value());
-
     if (!accVector) {
       // Broadcast the scalar (e.g., f32) to a vector type (e.g., vector<f32>)
       // because the following implementation requires the operand to be a
@@ -518,21 +525,32 @@
       disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
     }
 
-    Value accReduction = vector::makeArithReduction(
-        rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc);
-    auto accReduced = dyn_cast<VectorValue>(accReduction);
-    if (!accReduced) {
-      return failure();
+    bool hasSubgroupReductions =
+        llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
+          return srcLayout.getSubgroupTile()[rDim] > 1;
+        });
+    // We can exit here if the reduction does not cross subgroups.
+    if (!hasSubgroupReductions) {
+      Value accReduction = vector::makeArithReduction(
+          rewriter, loc, multiReduceOp.getKind(), threadReduced, disAcc);
+      auto accReduced = dyn_cast<VectorValue>(accReduction);
+      if (!accReduced) {
+        return failure();
+      }
+      if (resVector) {
+        replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
+      } else {
+        Value accReducedVal = rewriter.create<vector::ExtractOp>(
+            loc, accReduction, ArrayRef{int64_t(0)});
+        replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
+      }
+      return success();
     }
-
-    if (resVector) {
-      replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
-    } else {
-      Value accReducedVal = rewriter.create<vector::ExtractOp>(
-          loc, accReduction, ArrayRef{int64_t(0)});
-      replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
-    }
-
+    // Do inter-subgroup reductions.
+    Value subgroupReduced = doSubgroupReduction(
+        rewriter, loc, srcVector, srcLayout, multiReduceOp.getReductionDims(),
+        threadReduced, multiReduceOp.getKind(), acc, signature[resVector]);
+    rewriter.replaceOp(multiReduceOp, subgroupReduced);
     return success();
   }
 
@@ -569,10 +587,185 @@
 
       res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
     }
-
     return res;
   }
 
+  // The reductions across subgroups are performed
+  // as follows:
+  // 1) Recover the subgroup-local result to the same rank as the
+  //    input vector.
+  // 2) Write the subgroup-local reduced vector to shared memory.
+  // 3) Read it back with a layout where the partially reduced
+  //    subgroup tile becomes the element tile.
+  // 4) Perform a second reduction to complete the reduction.
+  Value doSubgroupReduction(PatternRewriter &rewriter, Location loc,
+                            VectorValue srcVector, NestedLayoutAttr srcLayout,
+                            ArrayRef<int64_t> reductionDims,
+                            VectorValue threadReduced,
+                            vector::CombiningKind kind, Value acc,
+                            VectorLayoutInterface resLayout) const {
+    // Subgroup-local / thread-local vector.multi_reduction operations
+    // will remove the reduction dimensions by definition.
+    // e.g.:
+    // p1 x p2 x p3 x r2 x r1 --> p1 x p2 x p3
+    // However, the reduction is not complete until the inter-subgroup results
+    // are combined. Therefore, we need to maintain the rank to get the values
+    // back to the SIMD domain and re-layout the vector.
+    // Thus, we re-insert the reduction dimensions in
+    // their original positions as:
+    // p1 x p2 x p3 -> p1 x p2 x p3 x 1 x 1
+    int64_t rank = srcLayout.getRank();
+    SmallVector<int64_t> partialReducedDistributedShape =
+        srcLayout.getDistributedShape();
+    for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
+      int64_t tileGroupOffset = tileGroupIdx * rank;
+      for (int64_t rDim : reductionDims) {
+        partialReducedDistributedShape[tileGroupOffset + rDim] = 1;
+      }
+    }
+    VectorType partialReducedDistributedType = VectorType::get(
+        partialReducedDistributedShape, srcVector.getType().getElementType());
+    Value isoRankThreadReduced = rewriter.create<vector::ShapeCastOp>(
+        loc, partialReducedDistributedType, threadReduced);
+
+    SmallVector<int64_t> preDistrShape =
+        srcLayout.getUndistributedPackedShape();
+    SmallVector<int64_t> partialReductionShape =
+        llvm::to_vector(srcVector.getType().getShape());
+    for (int64_t rDim : reductionDims) {
+      // The first #rank elements form the subgroup tile.
+      // Here we replace the input shape with the subgroup tile
+      // because every tile except the subgroup tile has been
+      // reduced.
+      partialReductionShape[rDim] = preDistrShape[rDim];
+    }
+    auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
+        rewriter.getContext(), gpu::AddressSpace::Workgroup));
+    MemRefType allocType = MemRefType::get(
+        partialReductionShape, srcVector.getType().getElementType(),
+        AffineMap(), workgroupMemoryAddressSpace);
+    auto alloc = rewriter.create<memref::AllocOp>(loc, allocType);
+    VectorType unDistributedType = VectorType::get(
+        partialReductionShape, srcVector.getType().getElementType());
+    Value undistrWrite = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+        loc, unDistributedType, isoRankThreadReduced);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> indices(unDistributedType.getRank(), c0);
+    SmallVector<bool> inBounds(unDistributedType.getRank(), true);
+    // Insert a gpu.barrier to make sure the previous iteration
+    // of the batch loop has fully read the subgroup partial
+    // reductions.
+    rewriter.create<gpu::BarrierOp>(loc);
+    auto write = rewriter.create<vector::TransferWriteOp>(
+        loc, undistrWrite, alloc, indices, inBounds);
+    // Set the layout signature for the write.
+    // We need to set the layout on the srcVector/first operand.
+    auto unitAttr = UnitAttr::get(rewriter.getContext());
+    {
+      SmallVector<int64_t> subgroupTileLens =
+          llvm::to_vector(srcLayout.getSubgroupTile());
+      SmallVector<int64_t> batchTileLens =
+          llvm::to_vector(srcLayout.getBatchTile());
+      SmallVector<int64_t> outerTileLens =
+          llvm::to_vector(srcLayout.getOuterTile());
+      SmallVector<int64_t> threadTileLens =
+          llvm::to_vector(srcLayout.getThreadTile());
+      SmallVector<int64_t> elementTileLens =
+          llvm::to_vector(srcLayout.getElementTile());
+      SmallVector<int64_t> subgroupStrides =
+          llvm::to_vector(srcLayout.getSubgroupStrides());
+      SmallVector<int64_t> threadStrides =
+          llvm::to_vector(srcLayout.getThreadStrides());
+      // Replace the reduced tiles with unit dimensions.
+      for (int64_t rDim : reductionDims) {
+        batchTileLens[rDim] = 1;
+        outerTileLens[rDim] = 1;
+        threadTileLens[rDim] = 1;
+        elementTileLens[rDim] = 1;
+        threadStrides[rDim] = 0;
+      }
+      auto interSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
+          rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
+          threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+      auto writeAttrs =
+          SmallVector<Attribute>(write->getNumOperands(), unitAttr);
+      writeAttrs[0] = interSubGroupLayout;
+      ArrayAttr writeOperandsAttr =
+          ArrayAttr::get(rewriter.getContext(), writeAttrs);
+      ArrayAttr writeResultsAttr = ArrayAttr::get(rewriter.getContext(), {});
+      setSignatureForRedistribution(rewriter, write.getOperation(),
+                                    writeOperandsAttr, writeResultsAttr);
+    }
+    // Insert gpu.barrier
+    rewriter.create<gpu::BarrierOp>(write.getLoc());
+    auto read = rewriter.create<vector::TransferReadOp>(loc, unDistributedType,
+                                                        alloc, indices);
+    // Create a new layout where the subgroup dims are squashed into
+    // the element tile.
+    IREE::VectorExt::NestedLayoutAttr intraSubGroupLayout;
+    {
+      // We intentionally set the subgroup tile to 1.
+      SmallVector<int64_t> subgroupTileLens =
+          llvm::to_vector(srcLayout.getSubgroupTile());
+      SmallVector<int64_t> batchTileLens =
+          llvm::to_vector(srcLayout.getBatchTile());
+      SmallVector<int64_t> outerTileLens =
+          llvm::to_vector(srcLayout.getOuterTile());
+      SmallVector<int64_t> threadTileLens =
+          llvm::to_vector(srcLayout.getThreadTile());
+      SmallVector<int64_t> elementTileLens =
+          llvm::to_vector(srcLayout.getElementTile());
+      SmallVector<int64_t> subgroupStrides =
+          llvm::to_vector(srcLayout.getSubgroupStrides());
+      SmallVector<int64_t> threadStrides =
+          llvm::to_vector(srcLayout.getThreadStrides());
+      for (int64_t rDim : reductionDims) {
+        subgroupTileLens[rDim] = 1;
+        batchTileLens[rDim] = 1;
+        outerTileLens[rDim] = 1;
+        threadTileLens[rDim] = 1;
+        // The partial reductions that were across subgroups will
+        // be loaded as the element tile. We can revisit if this
+        // needs to be something else, such as the thread tile.
+        elementTileLens[rDim] = srcLayout.getSubgroupTile()[rDim];
+        subgroupStrides[rDim] = 0;
+        threadStrides[rDim] = 0;
+      }
+      intraSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
+          rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
+          threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+      auto readAttrs = SmallVector<Attribute>(read->getNumOperands(), unitAttr);
+      ArrayAttr readOperandsAttr =
+          ArrayAttr::get(rewriter.getContext(), readAttrs);
+      ArrayAttr readResultsAttr =
+          ArrayAttr::get(rewriter.getContext(), {intraSubGroupLayout});
+      setSignatureForRedistribution(rewriter, read.getOperation(),
+                                    readOperandsAttr, readResultsAttr);
+    }
+
+    // A newly created reduction to complete the reduction,
+    // reducing the data that would otherwise remain on
+    // different subgroups.
+    auto secondReduction = rewriter.create<vector::MultiDimReductionOp>(
+        loc, kind, read, acc, reductionDims);
+    {
+      auto reduceAttrs =
+          SmallVector<Attribute>(secondReduction->getNumOperands(), unitAttr);
+      reduceAttrs[0] = intraSubGroupLayout;
+      ArrayAttr reduceResultsAttr =
+          ArrayAttr::get(rewriter.getContext(), {unitAttr});
+      if (auto dstLayout = dyn_cast_or_null<NestedLayoutAttr>(resLayout)) {
+        reduceAttrs[1] = dstLayout;
+        reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
+      }
+      ArrayAttr reduceOperandsAttr =
+          ArrayAttr::get(rewriter.getContext(), reduceAttrs);
+      setSignatureForRedistribution(rewriter, secondReduction.getOperation(),
+                                    reduceOperandsAttr, reduceResultsAttr);
+    }
+    return secondReduction.getResult();
+  }
+
   int64_t subgroupSize;
   int64_t maxBitsPerShuffle;
 };
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 2d9bfef..f36a06f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -217,6 +217,7 @@
   void notifyOperationModified(Operation *op) override {
     if (op->hasAttr(kVectorLayoutRedistributeAttrName) &&
         op->hasAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName)) {
+      op->removeAttr(kVectorLayoutRedistributeAttrName);
       toBeDistributed.push_back(op);
     }
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 473ca95..df2e5fe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -35,6 +35,7 @@
             "gpu_lower_to_ukernels.mlir",
             "gpu_nested_layout_contract_amdgpu.mlir",
             "gpu_nested_layout_vector_distribution.mlir",
+            "gpu_nested_layout_vector_distribution_multi_reduce.mlir",
             "gpu_nested_layout_vector_distribution_step.mlir",
             "gpu_pad_operands.mlir",
             "gpu_pipeline.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index cf1cf03..2fa5ba2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -30,6 +30,7 @@
     "gpu_lower_to_ukernels.mlir"
     "gpu_nested_layout_contract_amdgpu.mlir"
     "gpu_nested_layout_vector_distribution.mlir"
+    "gpu_nested_layout_vector_distribution_multi_reduce.mlir"
     "gpu_nested_layout_vector_distribution_step.mlir"
     "gpu_pack_to_instrinsics.mlir"
     "gpu_pad_operands.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index eb2a853..3e9f155 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -1035,124 +1035,6 @@
 
 // -----
 
-#nested = #iree_vector_ext.nested_layout<
-  subgroup_tile = [1, 1],
-  // We are reducing along dim=1, so each thread will reduce
-  // 2 batches x 4 elements = 8 elements.
-  batch_tile = [2, 2],
-  outer_tile = [1, 1],
-  // We are reducing on dim=1, which is distributed over 4 threads. Based
-  // on the subgroup basis and thread order, the shuffle offset is 16.
-  thread_tile = [16, 4],
-  element_tile = [1, 4],
-
-  subgroup_strides = [1, 1],
-  thread_strides = [1, 16]
->
-
-func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
-  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
-  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
-  return %0 : vector<32xf32>
-}
-
-builtin.module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
-    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
-    transform.yield
-  }
-}
-
-// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
-// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
-// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
-// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
-// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
-// Global reduction
-// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
-// Accumulator reduction
-// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
-// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
-
-// -----
-
-#nested = #iree_vector_ext.nested_layout<
-  subgroup_tile = [1, 1],
-  // We are reducing along dim=1, so each thread will reduce
-  // 4 batches x 4 elements = 16 elements.
-  batch_tile    = [1, 4],
-  outer_tile        = [1, 1],
-  // We are reducing on dim=1, which is distributed over 2 threads. Based
-  // on the subgroup basis and thread order, the shuffle offset is 32.
-  thread_tile       = [32, 2],
-  element_tile     = [1, 4],
-
-  subgroup_strides        = [1, 1],
-  thread_strides          = [1, 32]
->
-
-func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
-  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
-  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
-  return %0 : vector<32xf32>
-}
-
-builtin.module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
-    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
-    transform.yield
-  }
-}
-
-// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
-// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
-// Global reduction
-// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
-// Accumulator reduction
-// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
-
-// -----
-
-#nested = #iree_vector_ext.nested_layout<
-  subgroup_tile = [1, 1],
-  batch_tile = [2, 2],
-  outer_tile = [1, 1],
-  thread_tile = [16, 4],
-  element_tile = [1, 4],
-
-  subgroup_strides = [1, 1],
-  thread_strides = [1, 16]
->
-
-func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 {
-  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
-  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32
-  return %0 : f32
-}
-
-builtin.module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
-    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
-    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
-    transform.yield
-  }
-}
-
-// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims
-// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
-// Global reduction
-// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32
-// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
-// Accumulator reduction
-// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32>
-
-// -----
-
 #layout = #iree_vector_ext.nested_layout<
   subgroup_tile = [1, 1],
   batch_tile = [2, 2],
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir
new file mode 100644
index 0000000..2559f10
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir
@@ -0,0 +1,180 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize -mlir-print-local-scope --cse %s | FileCheck %s
+
+#nested = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  // We are reducing along dim=1, so each thread will reduce
+  // 2 batches x 4 elements = 8 elements.
+  batch_tile = [2, 2],
+  outer_tile = [1, 1],
+  // We are reducing on dim=1, which is distributed over 4 threads. Based
+  // on the subgroup basis and thread order, the shuffle offset is 16.
+  thread_tile = [16, 4],
+  element_tile = [1, 4],
+
+  subgroup_strides = [1, 1],
+  thread_strides = [1, 16]
+>
+
+func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
+  return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
+// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
+// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
+// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
+// Global reduction
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// Accumulator reduction
+// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
+// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  // We are reducing along dim=1, so each thread will reduce
+  // 4 batches x 4 elements = 16 elements.
+  batch_tile    = [1, 4],
+  outer_tile        = [1, 1],
+  // We are reducing on dim=1, which is distributed over 2 threads. Based
+  // on the subgroup basis and thread order, the shuffle offset is 32.
+  thread_tile       = [32, 2],
+  element_tile     = [1, 4],
+
+  subgroup_strides        = [1, 1],
+  thread_strides          = [1, 32]
+>
+
+func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
+  return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
+// Global reduction
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
+// Accumulator reduction
+// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [2, 2],
+  outer_tile = [1, 1],
+  thread_tile = [16, 4],
+  element_tile = [1, 4],
+
+  subgroup_strides = [1, 1],
+  thread_strides = [1, 16]
+>
+
+func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 {
+  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32
+  return %0 : f32
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
+// Global reduction
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32
+// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// Accumulator reduction
+// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+  // There will be two partial reductions across
+  // two subgroups.
+  subgroup_tile = [1, 2],
+  // We are reducing along dim=1, so each thread will reduce
+  // 1 batch x 4 elements = 4 elements.
+  batch_tile = [2, 1],
+  outer_tile = [1, 1],
+  // We are reducing on dim=1, which is distributed over 4 threads. Based
+  // on the subgroup basis and thread order, the shuffle offset is 16.
+  thread_tile = [16, 4],
+  element_tile = [1, 4],
+
+  subgroup_strides = [2, 1],
+  thread_strides = [1, 16]
+>
+
+func.func @inter_subgroup_reduction(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+  %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+  %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
+  return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @inter_subgroup_reduction
+// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x1x1x1x1x2xf32>
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<2x1x1x1x1x4xf32> to vector<2x1x1xf32>
+// Thread reduction
+// CHECK: %[[THREAD_RED0:.+]] = gpu.subgroup_reduce  maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// CHECK: %[[THREAD_RED1:.+]] = vector.insert %[[THREAD_RED0]], %cst_1 [0] : f32 into vector<2xf32>
+// CHECK: %[[THREAD_RED2:.+]] = gpu.subgroup_reduce  maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// CHECK: %[[THREAD_RED3:.+]] = vector.insert %[[THREAD_RED2]], %[[THREAD_RED1]] [1] : f32 into vector<2xf32>
+// CHECK: %[[THREAD_RED4:.+]] = vector.shape_cast %[[THREAD_RED3]] : vector<2xf32> to vector<2x1x1xf32>
+// Subgroup reduction
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[TIDX0:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
+// CHECK-DAG: %[[TIDX1:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16 + 16)>()[%thread_id_x]
+// CHECK-DAG: %[[SGIDX:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 64) mod 2)>()[%thread_id_x]
+// CHECK-DAG: %[[EXTRACT0:.+]] = vector.extract %[[THREAD_RED4]][0] : vector<1x1xf32> from vector<2x1x1xf32>
+// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[THREAD_RED4]][1] : vector<1x1xf32> from vector<2x1x1xf32>
+// CHECK-DAG: vector.transfer_write %[[EXTRACT0]], %[[ALLOC]][%[[TIDX0]], %[[SGIDX]]]
+// CHECK-DAG: vector.transfer_write %[[EXTRACT1]], %[[ALLOC]][%[[TIDX1]], %[[SGIDX]]]
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %alloc[%[[TIDX0]], %c0], {{.*}} {in_bounds = [false, true]} : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<1x2xf32>
+// CHECK-DAG: %[[GATHER0:.+]] = vector.insert_strided_slice %[[READ0]], %[[CST]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x2xf32> into vector<2x1x1x1x1x2xf32>
+// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %alloc[%[[TIDX1]], %c0], %cst_0 {in_bounds = [false, true]} : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<1x2xf32>
+// CHECK-DAG: %[[GATHER1:.+]] = vector.insert_strided_slice %[[READ1]], %[[GATHER0]] {offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x2xf32> into vector<2x1x1x1x1x2xf32>
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_simt %arg1 : vector<32xf32> -> vector<2x1x1xf32>
+// CHECK-DAG: %[[SGRED:.+]] = vector.multi_reduction <maximumf>, %[[GATHER1]], {{.*}} [1, 3, 5] : vector<2x1x1x1x1x2xf32> to vector<2x1x1xf32>
+// CHECK-DAG: arith.maximumf %[[SGRED]], %[[ACC]] : vector<2x1x1xf32>