[LLVMGPUVectorDistribute] Add support for inter-subgroup multi_reduction (#19596)
This commit adds support for distributing multi_reductions where the
reduction dimension(s) are distributed across subgroups.
We perform the existing reduction distribution; however, this leaves us
with partial reductions across subgroups.
Thereafter, we insert a transfer_write / transfer_read pair through shared
memory to achieve a layout change that re-distributes the reduction
subgroup tile into the element tile. Finally, we perform another
multi_reduction to complete the reduction.
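For illustration, in the inter-subgroup case exercised by the new lit test
(subgroup_tile = [1, 2], reducing dim 1 of a 32x32 <maximumf> reduction),
the distributed IR roughly follows this sketch (condensed from the test's
CHECK lines; value names and elided operands are illustrative):

  %partial = vector.multi_reduction <maximumf>, ...            // subgroup-local partial result
  %alloc = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
  gpu.barrier
  vector.transfer_write %partial, %alloc[...]                  // stash partials in shared memory
  gpu.barrier
  %gathered = vector.transfer_read %alloc[...]                 // subgroup tile re-read as element tile
  %result = vector.multi_reduction <maximumf>, %gathered, %acc [1]  // complete the reduction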
closes: https://github.com/iree-org/iree/issues/19578
---------
Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 0b4d812..49981f1 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -414,8 +414,9 @@
/// by doing a butterfly shuffle.
/// 3. Accumulator Reduce: Each thread reduces it's intermediate reduced
/// results with the accumulator it holds.
-/// Currently, reduction across warps is not supported, but it would just add
-/// another step, Warp Reduce, where threads do an atomic addition on a buffer.
+/// 4. Subgroup Reduce: Each subgroup stores its partial reduction to
+///    shared memory; the partial reductions are then reloaded with a layout
+///    that places them inside threads, and reduced once more.
struct DistributeMultiReduction final
: OpDistributionPattern<vector::MultiDimReductionOp> {
using OpDistributionPattern::OpDistributionPattern;
@@ -460,7 +461,6 @@
}
Location loc = multiReduceOp.getLoc();
-
SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
int64_t rank = srcVector.getType().getRank();
@@ -492,25 +492,32 @@
assert(locallyReduced && "result should have been a vector");
// Flatten the locally reduced value.
+ VectorValue threadReduced = locallyReduced;
VectorType shaped = locallyReduced.getType();
- int64_t numElements = shaped.getNumElements();
- SmallVector<int64_t> flatShape(1, numElements);
- VectorType flatVecType = VectorType::get(flatShape, elemTy);
- VectorValue flat =
- rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
+ bool hasThreadReductions =
+ llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
+ return srcLayout.getThreadTile()[rDim] > 1;
+ });
+ if (hasThreadReductions) {
+ int64_t numElements = shaped.getNumElements();
+ SmallVector<int64_t> flatShape(1, numElements);
+ VectorType flatVecType = VectorType::get(flatShape, elemTy);
+ VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
+ locallyReduced);
- // Do inter-thread/warp reduce.
- FailureOr<VectorValue> threadReduced = doThreadReduction(
- rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
- if (failed(threadReduced)) {
- return failure();
+ // Do inter-thread/warp reduce.
+ FailureOr<VectorValue> threadReducedFlat = doThreadReduction(
+ rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
+ if (failed(threadReducedFlat)) {
+ return failure();
+ }
+
+ // Do reduction against accumulator, which needs to be done after thread
+ // reduction.
+ threadReduced = rewriter.create<vector::ShapeCastOp>(
+ loc, shaped, threadReducedFlat.value());
}
- // Do reduction against accumulator, which needs to be done after thread
- // reduction.
- VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
- loc, shaped, threadReduced.value());
-
if (!accVector) {
// Broadcast the scalar (e.g., f32) to a vector type (e.g., vector<f32>)
// because the following implementation requires the operand to be a
@@ -518,21 +525,32 @@
disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
}
- Value accReduction = vector::makeArithReduction(
- rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc);
- auto accReduced = dyn_cast<VectorValue>(accReduction);
- if (!accReduced) {
- return failure();
+ bool hasSubgroupReductions =
+ llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
+ return srcLayout.getSubgroupTile()[rDim] > 1;
+ });
+    // We can exit here if there are no reductions across subgroups.
+ if (!hasSubgroupReductions) {
+ Value accReduction = vector::makeArithReduction(
+ rewriter, loc, multiReduceOp.getKind(), threadReduced, disAcc);
+ auto accReduced = dyn_cast<VectorValue>(accReduction);
+ if (!accReduced) {
+ return failure();
+ }
+ if (resVector) {
+ replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
+ } else {
+ Value accReducedVal = rewriter.create<vector::ExtractOp>(
+ loc, accReduction, ArrayRef{int64_t(0)});
+ replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
+ }
+ return success();
}
-
- if (resVector) {
- replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
- } else {
- Value accReducedVal = rewriter.create<vector::ExtractOp>(
- loc, accReduction, ArrayRef{int64_t(0)});
- replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
- }
-
+    // Do inter-subgroup reductions.
+ Value subgroupReduced = doSubgroupReduction(
+ rewriter, loc, srcVector, srcLayout, multiReduceOp.getReductionDims(),
+ threadReduced, multiReduceOp.getKind(), acc, signature[resVector]);
+ rewriter.replaceOp(multiReduceOp, subgroupReduced);
return success();
}
@@ -569,10 +587,185 @@
res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
}
-
return res;
}
+  // The reductions across subgroups are performed as follows:
+  // 1) Restore the subgroup-local result to the same rank as the
+  //    input vector.
+  // 2) Write the subgroup-local reduced vector to shared memory.
+  // 3) Read the vector back such that the partially reduced subgroup
+  //    tile is read as the element tile.
+  // 4) Perform a second reduction to complete the reduction.
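+  // For example (see gpu_nested_layout_vector_distribution_multi_reduce.mlir),
+  // with subgroup_tile = [1, 2] along the reduction dim of a 32x32 <maximumf>
+  // reduction, each subgroup's 32-element partial result is written into a
+  // 32x2 workgroup buffer and read back with the subgroup tile mapped to the
+  // element tile before the final multi_reduction.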
+ Value doSubgroupReduction(PatternRewriter &rewriter, Location loc,
+ VectorValue srcVector, NestedLayoutAttr srcLayout,
+ ArrayRef<int64_t> reductionDims,
+ VectorValue threadReduced,
+ vector::CombiningKind kind, Value acc,
+ VectorLayoutInterface resLayout) const {
+    // Subgroup-local / thread-local vector.multi_reduction operations
+    // remove the reduction dimensions by definition, e.g.:
+    //   p1 x p2 x p3 x r2 x r1 --> p1 x p2 x p3
+    // However, the reduction is not complete until the inter-subgroup results
+    // are combined. Therefore, we need to maintain the rank to get the values
+    // back to the SIMD domain and re-layout the vector.
+    // Thus, we re-insert the reduction dimensions in their original
+    // positions as:
+    //   p1 x p2 x p3 --> p1 x p2 x p3 x 1 x 1
+ int64_t rank = srcLayout.getRank();
+ SmallVector<int64_t> partialReducedDistributedShape =
+ srcLayout.getDistributedShape();
+ for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
+ int64_t tileGroupOffset = tileGroupIdx * rank;
+ for (int64_t rDim : reductionDims) {
+ partialReducedDistributedShape[tileGroupOffset + rDim] = 1;
+ }
+ }
+ VectorType partialReducedDistributedType = VectorType::get(
+ partialReducedDistributedShape, srcVector.getType().getElementType());
+ Value isoRankThreadReduced = rewriter.create<vector::ShapeCastOp>(
+ loc, partialReducedDistributedType, threadReduced);
+
+ SmallVector<int64_t> preDistrShape =
+ srcLayout.getUndistributedPackedShape();
+ SmallVector<int64_t> partialReductionShape =
+ llvm::to_vector(srcVector.getType().getShape());
+ for (int64_t rDim : reductionDims) {
+      // The first #rank elements of the packed shape form the subgroup tile.
+      // Here we replace the input size of the reduced dim with the subgroup
+      // tile size, because every tile except the subgroup tile has already
+      // been reduced.
+ partialReductionShape[rDim] = preDistrShape[rDim];
+ }
+ auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
+ rewriter.getContext(), gpu::AddressSpace::Workgroup));
+ MemRefType allocType = MemRefType::get(
+ partialReductionShape, srcVector.getType().getElementType(),
+ AffineMap(), workgroupMemoryAddressSpace);
+ auto alloc = rewriter.create<memref::AllocOp>(loc, allocType);
+ VectorType unDistributedType = VectorType::get(
+ partialReductionShape, srcVector.getType().getElementType());
+ Value undistrWrite = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+ loc, unDistributedType, isoRankThreadReduced);
+ Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> indices(unDistributedType.getRank(), c0);
+ SmallVector<bool> inBounds(unDistributedType.getRank(), true);
+    // Insert gpu.barrier to make sure the previous iteration
+    // of the batch loop has fully read the subgroup partial
+    // reductions.
+ rewriter.create<gpu::BarrierOp>(loc);
+ auto write = rewriter.create<vector::TransferWriteOp>(
+ loc, undistrWrite, alloc, indices, inBounds);
+    // Set the layout signature for the write.
+    // We need to set the layout on the srcVector / first operand.
+ auto unitAttr = UnitAttr::get(rewriter.getContext());
+ {
+ SmallVector<int64_t> subgroupTileLens =
+ llvm::to_vector(srcLayout.getSubgroupTile());
+ SmallVector<int64_t> batchTileLens =
+ llvm::to_vector(srcLayout.getBatchTile());
+ SmallVector<int64_t> outerTileLens =
+ llvm::to_vector(srcLayout.getOuterTile());
+ SmallVector<int64_t> threadTileLens =
+ llvm::to_vector(srcLayout.getThreadTile());
+ SmallVector<int64_t> elementTileLens =
+ llvm::to_vector(srcLayout.getElementTile());
+ SmallVector<int64_t> subgroupStrides =
+ llvm::to_vector(srcLayout.getSubgroupStrides());
+ SmallVector<int64_t> threadStrides =
+ llvm::to_vector(srcLayout.getThreadStrides());
+      // Replace the tiles along the reduced dims with unit dimensions.
+ for (int64_t rDim : reductionDims) {
+ batchTileLens[rDim] = 1;
+ outerTileLens[rDim] = 1;
+ threadTileLens[rDim] = 1;
+ elementTileLens[rDim] = 1;
+ threadStrides[rDim] = 0;
+ }
+ auto interSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
+ rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
+ threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+ auto writeAttrs =
+ SmallVector<Attribute>(write->getNumOperands(), unitAttr);
+ writeAttrs[0] = interSubGroupLayout;
+ ArrayAttr writeOperandsAttr =
+ ArrayAttr::get(rewriter.getContext(), writeAttrs);
+ ArrayAttr writeResultsAttr = ArrayAttr::get(rewriter.getContext(), {});
+ setSignatureForRedistribution(rewriter, write.getOperation(),
+ writeOperandsAttr, writeResultsAttr);
+ }
+ // Insert gpu.barrier
+ rewriter.create<gpu::BarrierOp>(write.getLoc());
+ auto read = rewriter.create<vector::TransferReadOp>(loc, unDistributedType,
+ alloc, indices);
+    // Create a new layout where the subgroup dims are squashed into
+    // the element tile.
+ IREE::VectorExt::NestedLayoutAttr intraSubGroupLayout;
+ {
+      // We intentionally set the subgroup tile of the reduced dims to 1.
+ SmallVector<int64_t> subgroupTileLens =
+ llvm::to_vector(srcLayout.getSubgroupTile());
+ SmallVector<int64_t> batchTileLens =
+ llvm::to_vector(srcLayout.getBatchTile());
+ SmallVector<int64_t> outerTileLens =
+ llvm::to_vector(srcLayout.getOuterTile());
+ SmallVector<int64_t> threadTileLens =
+ llvm::to_vector(srcLayout.getThreadTile());
+ SmallVector<int64_t> elementTileLens =
+ llvm::to_vector(srcLayout.getElementTile());
+ SmallVector<int64_t> subgroupStrides =
+ llvm::to_vector(srcLayout.getSubgroupStrides());
+ SmallVector<int64_t> threadStrides =
+ llvm::to_vector(srcLayout.getThreadStrides());
+ for (int64_t rDim : reductionDims) {
+ subgroupTileLens[rDim] = 1;
+ batchTileLens[rDim] = 1;
+ outerTileLens[rDim] = 1;
+ threadTileLens[rDim] = 1;
+        // The partial reductions that were spread across subgroups will
+        // be loaded as the element tile. We can revisit whether this
+        // needs to be something else, such as the thread tile.
+ elementTileLens[rDim] = srcLayout.getSubgroupTile()[rDim];
+ subgroupStrides[rDim] = 0;
+ threadStrides[rDim] = 0;
+ }
+ intraSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
+ rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
+ threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+ auto readAttrs = SmallVector<Attribute>(read->getNumOperands(), unitAttr);
+ ArrayAttr readOperandsAttr =
+ ArrayAttr::get(rewriter.getContext(), readAttrs);
+ ArrayAttr readResultsAttr =
+ ArrayAttr::get(rewriter.getContext(), {intraSubGroupLayout});
+ setSignatureForRedistribution(rewriter, read.getOperation(),
+ readOperandsAttr, readResultsAttr);
+ }
+
+    // A newly created reduction to complete the reduction,
+    // combining the data that was otherwise spread across
+    // different subgroups.
+ auto secondReduction = rewriter.create<vector::MultiDimReductionOp>(
+ loc, kind, read, acc, reductionDims);
+ {
+ auto reduceAttrs =
+ SmallVector<Attribute>(secondReduction->getNumOperands(), unitAttr);
+ reduceAttrs[0] = intraSubGroupLayout;
+ ArrayAttr reduceResultsAttr =
+ ArrayAttr::get(rewriter.getContext(), {unitAttr});
+ if (auto dstLayout = dyn_cast_or_null<NestedLayoutAttr>(resLayout)) {
+ reduceAttrs[1] = dstLayout;
+ reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
+ }
+ ArrayAttr reduceOperandsAttr =
+ ArrayAttr::get(rewriter.getContext(), reduceAttrs);
+ setSignatureForRedistribution(rewriter, secondReduction.getOperation(),
+ reduceOperandsAttr, reduceResultsAttr);
+ }
+ return secondReduction.getResult();
+ }
+
int64_t subgroupSize;
int64_t maxBitsPerShuffle;
};
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
index 2d9bfef..f36a06f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp
@@ -217,6 +217,7 @@
void notifyOperationModified(Operation *op) override {
if (op->hasAttr(kVectorLayoutRedistributeAttrName) &&
op->hasAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName)) {
+ op->removeAttr(kVectorLayoutRedistributeAttrName);
toBeDistributed.push_back(op);
}
}
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index 473ca95..df2e5fe 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -35,6 +35,7 @@
"gpu_lower_to_ukernels.mlir",
"gpu_nested_layout_contract_amdgpu.mlir",
"gpu_nested_layout_vector_distribution.mlir",
+ "gpu_nested_layout_vector_distribution_multi_reduce.mlir",
"gpu_nested_layout_vector_distribution_step.mlir",
"gpu_pad_operands.mlir",
"gpu_pipeline.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index cf1cf03..2fa5ba2 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -30,6 +30,7 @@
"gpu_lower_to_ukernels.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
+ "gpu_nested_layout_vector_distribution_multi_reduce.mlir"
"gpu_nested_layout_vector_distribution_step.mlir"
"gpu_pack_to_instrinsics.mlir"
"gpu_pad_operands.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
index eb2a853..3e9f155 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir
@@ -1035,124 +1035,6 @@
// -----
-#nested = #iree_vector_ext.nested_layout<
- subgroup_tile = [1, 1],
- // We are reducing along dim=1, so each thread will reduce
- // 2 batches x 4 elements = 8 elements.
- batch_tile = [2, 2],
- outer_tile = [1, 1],
- // We are reducing on dim=1, which is distributed over 4 threads. Based
- // on the subgroup basis and thread order, the shuffle offset is 16.
- thread_tile = [16, 4],
- element_tile = [1, 4],
-
- subgroup_strides = [1, 1],
- thread_strides = [1, 16]
->
-
-func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
- %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
- %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
- return %0 : vector<32xf32>
-}
-
-builtin.module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
- %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
- transform.yield
- }
-}
-
-// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
-// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
-// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
-// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
-// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
-// Global reduction
-// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
-// Accumulator reduction
-// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
-// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
-
-// -----
-
-#nested = #iree_vector_ext.nested_layout<
- subgroup_tile = [1, 1],
- // We are reducing along dim=1, so each thread will reduce
- // 4 batches x 4 elements = 16 elements.
- batch_tile = [1, 4],
- outer_tile = [1, 1],
- // We are reducing on dim=1, which is distributed over 2 threads. Based
- // on the subgroup basis and thread order, the shuffle offset is 32.
- thread_tile = [32, 2],
- element_tile = [1, 4],
-
- subgroup_strides = [1, 1],
- thread_strides = [1, 32]
->
-
-func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
- %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
- %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
- return %0 : vector<32xf32>
-}
-
-builtin.module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
- %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
- transform.yield
- }
-}
-
-// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
-// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
-// Global reduction
-// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
-// Accumulator reduction
-// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
-
-// -----
-
-#nested = #iree_vector_ext.nested_layout<
- subgroup_tile = [1, 1],
- batch_tile = [2, 2],
- outer_tile = [1, 1],
- thread_tile = [16, 4],
- element_tile = [1, 4],
-
- subgroup_strides = [1, 1],
- thread_strides = [1, 16]
->
-
-func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 {
- %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
- %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32
- return %0 : f32
-}
-
-builtin.module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
- %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
- transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
- transform.yield
- }
-}
-
-// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims
-// Local reduction
-// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
-// Global reduction
-// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32
-// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
-// Accumulator reduction
-// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32>
-
-// -----
-
#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [2, 2],
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir
new file mode 100644
index 0000000..2559f10
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir
@@ -0,0 +1,180 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize -mlir-print-local-scope --cse %s | FileCheck %s
+
+#nested = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ // We are reducing along dim=1, so each thread will reduce
+ // 2 batches x 4 elements = 8 elements.
+ batch_tile = [2, 2],
+ outer_tile = [1, 1],
+ // We are reducing on dim=1, which is distributed over 4 threads. Based
+ // on the subgroup basis and thread order, the shuffle offset is 16.
+ thread_tile = [16, 4],
+ element_tile = [1, 4],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [1, 16]
+>
+
+func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+ %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+ %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
+ return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1
+// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32>
+// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32>
+// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32>
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32>
+// Global reduction
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// Accumulator reduction
+// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32>
+// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ // We are reducing along dim=1, so each thread will reduce
+ // 4 batches x 4 elements = 16 elements.
+ batch_tile = [1, 4],
+ outer_tile = [1, 1],
+ // We are reducing on dim=1, which is distributed over 2 threads. Based
+ // on the subgroup basis and thread order, the shuffle offset is 32.
+ thread_tile = [32, 2],
+ element_tile = [1, 4],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [1, 32]
+>
+
+func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+ %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+ %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
+ return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32>
+// Global reduction
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32
+// Accumulator reduction
+// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [2, 2],
+ outer_tile = [1, 1],
+ thread_tile = [16, 4],
+ element_tile = [1, 4],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [1, 16]
+>
+
+func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 {
+ %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+ %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32
+ return %0 : f32
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32
+// Global reduction
+// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32
+// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// Accumulator reduction
+// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32>
+
+// -----
+
+#nested = #iree_vector_ext.nested_layout<
+  // There will be two partial reductions across
+  // two subgroups.
+ subgroup_tile = [1, 2],
+ // We are reducing along dim=1, so each thread will reduce
+  // 1 batch x 4 elements = 4 elements.
+ batch_tile = [2, 1],
+ outer_tile = [1, 1],
+ // We are reducing on dim=1, which is distributed over 4 threads. Based
+ // on the subgroup basis and thread order, the shuffle offset is 16.
+ thread_tile = [16, 4],
+ element_tile = [1, 4],
+
+ subgroup_strides = [2, 1],
+ thread_strides = [1, 16]
+>
+
+func.func @inter_subgroup_reduction(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> {
+ %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32>
+ %0 = vector.multi_reduction <maximumf>, %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32>
+ return %0 : vector<32xf32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func @inter_subgroup_reduction
+// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x1x1x1x1x2xf32>
+// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// Local reduction
+// CHECK: vector.multi_reduction <maximumf>, %{{.*}}, %{{.*}} [1, 3, 5] : vector<2x1x1x1x1x4xf32> to vector<2x1x1xf32>
+// Thread reduction
+// CHECK: %[[THREAD_RED0:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// CHECK: %[[THREAD_RED1:.+]] = vector.insert %[[THREAD_RED0]], %cst_1 [0] : f32 into vector<2xf32>
+// CHECK: %[[THREAD_RED2:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32
+// CHECK: %[[THREAD_RED3:.+]] = vector.insert %[[THREAD_RED2]], %[[THREAD_RED1]] [1] : f32 into vector<2xf32>
+// CHECK: %[[THREAD_RED4:.+]] = vector.shape_cast %[[THREAD_RED3]] : vector<2xf32> to vector<2x1x1xf32>
+// Subgroup reduction
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[TIDX0:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x]
+// CHECK-DAG: %[[TIDX1:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16 + 16)>()[%thread_id_x]
+// CHECK-DAG: %[[SGIDX:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 64) mod 2)>()[%thread_id_x]
+// CHECK-DAG: %[[EXTRACT0:.+]] = vector.extract %[[THREAD_RED4]][0] : vector<1x1xf32> from vector<2x1x1xf32>
+// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[THREAD_RED4]][1] : vector<1x1xf32> from vector<2x1x1xf32>
+// CHECK-DAG: vector.transfer_write %[[EXTRACT0]], %[[ALLOC]][%[[TIDX0]], %[[SGIDX]]]
+// CHECK-DAG: vector.transfer_write %[[EXTRACT1]], %[[ALLOC]][%[[TIDX1]], %[[SGIDX]]]
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %alloc[%[[TIDX0]], %c0], {{.*}} {in_bounds = [false, true]} : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<1x2xf32>
+// CHECK-DAG: %[[GATHER0:.+]] = vector.insert_strided_slice %[[READ0]], %[[CST]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x2xf32> into vector<2x1x1x1x1x2xf32>
+// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %alloc[%[[TIDX1]], %c0], %cst_0 {in_bounds = [false, true]} : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<1x2xf32>
+// CHECK-DAG: %[[GATHER1:.+]] = vector.insert_strided_slice %[[READ1]], %[[GATHER0]] {offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x2xf32> into vector<2x1x1x1x1x2xf32>
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_simt %arg1 : vector<32xf32> -> vector<2x1x1xf32>
+// CHECK-DAG: %[[SGRED:.+]] = vector.multi_reduction <maximumf>, %[[GATHER1]], {{.*}} [1, 3, 5] : vector<2x1x1x1x1x2xf32> to vector<2x1x1xf32>
+// CHECK-DAG: arith.maximumf %[[SGRED]], %[[ACC]] : vector<2x1x1xf32>