[Codegen][GPU] Add DistributeArgCompare pattern (#23793)
This PR adds the `DistributeArgCompare `pattern to distribute
`iree_vector_ext.arg_compare` operations across GPU threads and
subgroups.
For supported comparators, we use a ballot-based approach that leverages
`gpu.subgroup_reduce` + `gpu.ballot` for reduction.
Supported comparators include:
1. Direct comparison on values (e.g., arith.cmpf ogt for argmax)
2. Same unary op applied to both arguments before comparison (e.g.,
math.absf for argmax of absolute values)
Unsupported comparators fall back to the portable butterfly shuffle
approach. Currently, this is mainly used for argmax/argmin operations,
but we can extend support for additional comparators as needed.
Issue: #23005
Assisted-by: [Claude Code](https://claude.ai/code)
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 686b074..30ab59e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -158,6 +158,7 @@
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:LoopLikeInterface",
+ "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:MemRefUtils",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 851acee..3883159 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -122,6 +122,7 @@
MLIRLinalgTransforms
MLIRLinalgUtils
MLIRLoopLikeInterface
+ MLIRMathDialect
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRMemRefUtils
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index a24daa3..ccd3c5e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -21,7 +21,9 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -936,16 +938,311 @@
return layout.getThreadTile()[dim];
}
-/// The lowering for multi_reduction is done in two steps:
+/// Computes the distributed shape after local (per-thread) reduction, with
+/// reduction dimensions set to 1 across all tile groups (batch, outer,
+/// element).
+static SmallVector<int64_t>
+getLocalReducedDistributedShape(NestedLayoutAttr srcLayout,
+ ArrayRef<int64_t> reductionDims) {
+ int64_t rank = srcLayout.getRank();
+ SmallVector<int64_t> shape = srcLayout.getDistributedShape();
+ // Iterate over 3 tile groups: batch, outer, element.
+ for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
+ int64_t tileGroupOffset = tileGroupIdx * rank;
+ for (int64_t rDim : reductionDims) {
+ shape[tileGroupOffset + rDim] = 1;
+ }
+ }
+ return shape;
+}
+
+/// Computes the undistributed shape after subgroup-level reduction, where
+/// reduction dimensions retain only the subgroup tile size.
+static SmallVector<int64_t>
+getSubgroupReducedShape(NestedLayoutAttr srcLayout, ArrayRef<int64_t> srcShape,
+ ArrayRef<int64_t> reductionDims) {
+ SmallVector<int64_t> preDistrShape = srcLayout.getUndistributedPackedShape();
+ SmallVector<int64_t> shape = llvm::to_vector(srcShape);
+ for (int64_t rDim : reductionDims) {
+ shape[rDim] = preDistrShape[rDim];
+ }
+ return shape;
+}
+
+/// Reshapes a flat 1-d vector back to a target type. For 0-d vectors,
+/// uses extract+broadcast since shape_cast to 0-d is not supported.
+static VectorValue reshapeFlatToTarget(RewriterBase &rewriter, Location loc,
+ VectorValue flat,
+ VectorType targetType) {
+ if (targetType.getRank() == 0) {
+ Value scalar =
+ vector::ExtractOp::create(rewriter, loc, flat, ArrayRef<int64_t>{0});
+ return vector::BroadcastOp::create(rewriter, loc, targetType, scalar);
+ }
+ return vector::ShapeCastOp::create(rewriter, loc, targetType, flat);
+}
+
+static LogicalResult checkBitwidthForShuffle(Operation *op, Type type,
+ int64_t maxBitsPerShuffle,
+ StringRef typeName,
+ PatternRewriter &rewriter) {
+ unsigned bitwidth = type.getIntOrFloatBitWidth();
+ if (bitwidth > maxBitsPerShuffle) {
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("{0} bitwidth {1} greater than maxBitsPerShuffle {2}",
+ typeName, bitwidth, maxBitsPerShuffle));
+ }
+ return success();
+}
+
+/// Creates an equality comparison operation for the given values.
+/// Returns arith.cmpf for floating-point types and arith.cmpi for integers.
+static Value createEqualityComparison(RewriterBase &rewriter, Location loc,
+ Value lhs, Value rhs) {
+ if (isa<FloatType>(lhs.getType())) {
+ return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ, lhs,
+ rhs);
+ }
+ return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, lhs,
+ rhs);
+}
+
+/// Result of analyzing a comparator region for ballot-based reduction.
+struct ComparatorAnalysis {
+ gpu::AllReduceOperation reduceOp;
+ /// The unary transformation operation applied to both comparison operands,
+ /// or nullptr if comparing block arguments directly.
+ Operation *transformOp = nullptr;
+};
+
+/// Maps a floating-point comparison predicate to the corresponding
+/// gpu::AllReduceOperation for reduction.
+static std::optional<gpu::AllReduceOperation>
+mapFCmpPredicateToReduceOp(arith::CmpFPredicate pred) {
+ switch (pred) {
+ case arith::CmpFPredicate::OGT:
+ case arith::CmpFPredicate::OGE:
+ case arith::CmpFPredicate::UGT:
+ case arith::CmpFPredicate::UGE:
+ return gpu::AllReduceOperation::MAXNUMF;
+ case arith::CmpFPredicate::OLT:
+ case arith::CmpFPredicate::OLE:
+ case arith::CmpFPredicate::ULT:
+ case arith::CmpFPredicate::ULE:
+ return gpu::AllReduceOperation::MINNUMF;
+ default:
+ return std::nullopt;
+ }
+}
+
+/// Maps an integer comparison predicate to the corresponding
+/// gpu::AllReduceOperation for reduction.
+static std::optional<gpu::AllReduceOperation>
+mapICmpPredicateToReduceOp(arith::CmpIPredicate pred) {
+ switch (pred) {
+ case arith::CmpIPredicate::sgt:
+ case arith::CmpIPredicate::sge:
+ return gpu::AllReduceOperation::MAXSI;
+ case arith::CmpIPredicate::slt:
+ case arith::CmpIPredicate::sle:
+ return gpu::AllReduceOperation::MINSI;
+ case arith::CmpIPredicate::ugt:
+ case arith::CmpIPredicate::uge:
+ return gpu::AllReduceOperation::MAXUI;
+ case arith::CmpIPredicate::ult:
+ case arith::CmpIPredicate::ule:
+ return gpu::AllReduceOperation::MINUI;
+ default:
+ return std::nullopt;
+ }
+}
+
+/// Analyzes the comparator region of an arg_compare operation to determine
+/// if it can use the efficient ballot-based reduction approach.
+///
+/// This function detects two patterns:
+/// 1. Simple comparison: direct comparison on block arguments
+/// Example: arith.cmpf ogt, %lhs, %rhs
+///
+/// 2. Transformed comparison: same unary transformation applied to both
+/// arguments before comparison
+/// Example: arith.cmpf ogt, (math.absf %lhs), (math.absf %rhs)
+///
+/// Returns the ComparatorAnalysis if the pattern is detected, or std::nullopt
+/// for comparators that require the shuffle-based fallback.
+static std::optional<ComparatorAnalysis>
+analyzeComparatorForThreadReduction(Region &comparatorRegion) {
+ Block &block = comparatorRegion.front();
+
+ // Get the yield operation. The verifier guarantees:
+ // - Terminator is YieldOp (SingleBlockImplicitTerminator trait)
+ // - YieldOp has exactly 1 operand of type i1
+ auto yieldOp = cast<IREE::VectorExt::YieldOp>(block.getTerminator());
+ Value yieldedValue = yieldOp.getValues()[0];
+
+ // The yielded value must have a defining op (not a block argument).
+ Operation *cmpOp = yieldedValue.getDefiningOp();
+ if (!cmpOp) {
+ return std::nullopt;
+ }
+
+ Value blockArg0 = block.getArgument(0);
+ Value blockArg1 = block.getArgument(1);
+
+ // Helper lambda to analyze comparison operands and build ComparatorAnalysis.
+ // Works for both floating-point and integer comparisons.
+ auto analyzeComparisonOperands =
+ [&](Value lhs, Value rhs, std::optional<gpu::AllReduceOperation> reduceOp)
+ -> std::optional<ComparatorAnalysis> {
+ if (!reduceOp) {
+ return std::nullopt;
+ }
+
+ // Case 1: Direct comparison on block arguments.
+ if (lhs == blockArg0 && rhs == blockArg1) {
+ return ComparatorAnalysis{*reduceOp, /*transformOp=*/nullptr};
+ }
+
+ // Case 2: Same unary transformation applied to both arguments.
+ // Check if both operands come from the same type of unary op applied
+ // to the block arguments.
+ Operation *lhsOp = lhs.getDefiningOp();
+ Operation *rhsOp = rhs.getDefiningOp();
+ if (lhsOp && rhsOp && lhsOp->getName() == rhsOp->getName() &&
+ lhsOp->getNumOperands() == 1 && lhsOp->getNumResults() == 1 &&
+ rhsOp->getNumOperands() == 1 && rhsOp->getNumResults() == 1 &&
+ lhsOp->getOperand(0) == blockArg0 &&
+ rhsOp->getOperand(0) == blockArg1) {
+ return ComparatorAnalysis{*reduceOp, /*transformOp=*/lhsOp};
+ }
+
+ return std::nullopt;
+ };
+
+ // Check for floating-point comparison.
+ if (auto cmpfOp = dyn_cast<arith::CmpFOp>(cmpOp)) {
+ auto reduceOp = mapFCmpPredicateToReduceOp(cmpfOp.getPredicate());
+ return analyzeComparisonOperands(cmpfOp.getLhs(), cmpfOp.getRhs(),
+ reduceOp);
+ }
+
+ // Check for integer comparison.
+ if (auto cmpiOp = dyn_cast<arith::CmpIOp>(cmpOp)) {
+ auto reduceOp = mapICmpPredicateToReduceOp(cmpiOp.getPredicate());
+ return analyzeComparisonOperands(cmpiOp.getLhs(), cmpiOp.getRhs(),
+ reduceOp);
+ }
+
+ return std::nullopt;
+}
+
+/// Clones the body of a comparator region, mapping block arguments to the
+/// given lhs/rhs values. Returns the yielded i1 comparison result.
+static Value cloneComparatorRegion(RewriterBase &rewriter, Region ®ion,
+ Value lhs, Value rhs) {
+ Block &block = region.front();
+ IRMapping mapper;
+ mapper.map(block.getArgument(0), lhs);
+ mapper.map(block.getArgument(1), rhs);
+ for (Operation &op : block.without_terminator()) {
+ Operation *clonedOp = rewriter.clone(op, mapper);
+ for (const auto &[origResult, clonedResult] :
+ llvm::zip_equal(op.getResults(), clonedOp->getResults())) {
+ mapper.map(origResult, clonedResult);
+ }
+ }
+ auto yieldOp = cast<IREE::VectorExt::YieldOp>(block.getTerminator());
+ return mapper.lookup(yieldOp.getValues()[0]);
+}
+
+/// Computes the layout for reading reduction results from shared memory.
+/// Shared by DistributeMultiReduction and DistributeArgCompare.
+static NestedLayoutAttr
+computeLayoutForReductionFromBuffer(NestedLayoutAttr srcLayout,
+ ArrayRef<int64_t> reductionDims) {
+ auto subgroupTileLens = llvm::to_vector(srcLayout.getSubgroupTile());
+ auto batchTileLens = llvm::to_vector(srcLayout.getBatchTile());
+ auto outerTileLens = llvm::to_vector(srcLayout.getOuterTile());
+ auto threadTileLens = llvm::to_vector(srcLayout.getThreadTile());
+ auto elementTileLens = llvm::to_vector(srcLayout.getElementTile());
+ auto subgroupStrides = llvm::to_vector(srcLayout.getSubgroupStrides());
+ auto threadStrides = llvm::to_vector(srcLayout.getThreadStrides());
+
+ int64_t threadsRequired = 1;
+ for (int64_t rDim : reductionDims) {
+ threadsRequired *= llvm::PowerOf2Ceil(subgroupTileLens[rDim]);
+ }
+
+ std::optional<int64_t> availableThreads;
+ int64_t threadStride = 0;
+ for (int64_t rDim : reductionDims) {
+ if (threadTileLens[rDim] >= threadsRequired) {
+ availableThreads = threadTileLens[rDim];
+ threadStride = threadStrides[rDim];
+ break;
+ }
+ }
+
+ for (int64_t rDim : reductionDims) {
+ batchTileLens[rDim] = 1;
+ outerTileLens[rDim] = 1;
+ elementTileLens[rDim] = 1;
+ if (availableThreads.has_value()) {
+ int64_t used = llvm::PowerOf2Ceil(subgroupTileLens[rDim]);
+ threadStrides[rDim] = threadStride;
+ threadTileLens[rDim] = used;
+ availableThreads.value() /= used;
+ threadStride *= used;
+ } else {
+ threadStrides[rDim] = 0;
+ threadTileLens[rDim] = 1;
+ }
+ subgroupTileLens[rDim] = 1;
+ subgroupStrides[rDim] = 0;
+ }
+
+ return IREE::VectorExt::NestedLayoutAttr::get(
+ srcLayout.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
+ threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+}
+
+/// Computes the inter-subgroup write layout by replacing reduced tiles with
+/// unit dimensions. Shared by DistributeMultiReduction and
+/// DistributeArgCompare.
+static NestedLayoutAttr
+computeInterSubgroupWriteLayout(NestedLayoutAttr srcLayout,
+ ArrayRef<int64_t> reductionDims) {
+ auto subgroupTileLens = llvm::to_vector(srcLayout.getSubgroupTile());
+ auto batchTileLens = llvm::to_vector(srcLayout.getBatchTile());
+ auto outerTileLens = llvm::to_vector(srcLayout.getOuterTile());
+ auto threadTileLens = llvm::to_vector(srcLayout.getThreadTile());
+ auto elementTileLens = llvm::to_vector(srcLayout.getElementTile());
+ auto subgroupStrides = llvm::to_vector(srcLayout.getSubgroupStrides());
+ auto threadStrides = llvm::to_vector(srcLayout.getThreadStrides());
+
+ for (int64_t rDim : reductionDims) {
+ batchTileLens[rDim] = 1;
+ outerTileLens[rDim] = 1;
+ threadTileLens[rDim] = 1;
+ elementTileLens[rDim] = 1;
+ threadStrides[rDim] = 0;
+ }
+
+ return IREE::VectorExt::NestedLayoutAttr::get(
+ srcLayout.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
+ threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+}
+
+/// The lowering for multi_reduction is done in four steps:
/// 1. Local Reduce: Each thread reduces all elements carried by it along
/// the reduction dimensions. This is the batch, outer and element dims.
/// 2. Thread Reduce: Each thread reduces result of step 1 across threads
/// by doing a butterfly shuffle.
-/// 3. Accumulator Reduce: Each thread reduces it's intermediate reduced
+/// 3. Accumulator Reduce: Each thread reduces its intermediate reduced
/// results with the accumulator it holds.
-/// 4. Subgroup reduce : each subgroup will store the partial reductions
-/// to shared memory and will be reloaded into a layout where partial
-/// reductions will be placed inside threads.
+/// 4. Subgroup Reduce: Each subgroup stores the partial reductions
+/// to shared memory and reloads them into a layout where partial
+/// reductions are placed inside threads.
struct DistributeMultiReduction final
: MaskedOpDistributionPattern<vector::MultiDimReductionOp> {
using MaskedOpDistributionPattern::MaskedOpDistributionPattern;
@@ -975,12 +1272,9 @@
}
Type elemTy = srcVector.getType().getElementType();
- unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
- if (elemBitwidth > maxBitsPerShuffle) {
- return rewriter.notifyMatchFailure(
- multiReduceOp,
- llvm::formatv("element bitwidth greater than maxBitsPerShuffle",
- elemBitwidth, maxBitsPerShuffle));
+ if (failed(checkBitwidthForShuffle(multiReduceOp, elemTy, maxBitsPerShuffle,
+ "element", rewriter))) {
+ return failure();
}
VectorValue disSrc =
@@ -1159,65 +1453,7 @@
NestedLayoutAttr
getLayoutForReductionFromBuffer(NestedLayoutAttr srcLayout,
ArrayRef<int64_t> reductionDims) const {
- // Create new layout where the elements of a subgroup are
- // distributed to every threads.
- IREE::VectorExt::NestedLayoutAttr bufferReduceLayout;
- auto subgroupTileLens =
- llvm::to_vector_of<int64_t>(srcLayout.getSubgroupTile());
- auto batchTileLens = llvm::to_vector_of<int64_t>(srcLayout.getBatchTile());
- auto outerTileLens = llvm::to_vector_of<int64_t>(srcLayout.getOuterTile());
- auto threadTileLens =
- llvm::to_vector_of<int64_t>(srcLayout.getThreadTile());
- auto elementTileLens =
- llvm::to_vector_of<int64_t>(srcLayout.getElementTile());
- auto subgroupStrides =
- llvm::to_vector_of<int64_t>(srcLayout.getSubgroupStrides());
- auto threadStrides =
- llvm::to_vector_of<int64_t>(srcLayout.getThreadStrides());
-
- // Check if we had enough threads on one of the reduction dimensions
- // to use for a subgroup reduction. If not, do a serialized reduction.
- // This usually works, because we would be distributing the reduction
- // dimension on atleast more threads than number of subgroups, and if we
- // aren't, it's probably best to do a serialized reduction anyway.
- int64_t threadsRequired = 1;
- for (int64_t rDim : reductionDims) {
- // The size or #lanes needs to be a power of 2.
- threadsRequired *= llvm::PowerOf2Ceil(subgroupTileLens[rDim]);
- }
- std::optional<int64_t> availableThreads;
- int64_t threadStride = 0;
- for (int64_t rDim : reductionDims) {
- // TODO: We could merge two different dimension threads into one, but they
- // can be disjoint.
- if (threadTileLens[rDim] >= threadsRequired) {
- availableThreads = threadTileLens[rDim];
- threadStride = threadStrides[rDim];
- break;
- }
- }
-
- for (int64_t rDim : reductionDims) {
- batchTileLens[rDim] = 1;
- outerTileLens[rDim] = 1;
- elementTileLens[rDim] = 1;
- if (availableThreads.has_value()) {
- int64_t used = llvm::PowerOf2Ceil(subgroupTileLens[rDim]);
- threadStrides[rDim] = threadStride;
- threadTileLens[rDim] = used;
- availableThreads.value() /= used;
- threadStride *= used;
- } else {
- threadStrides[rDim] = 0;
- threadTileLens[rDim] = 1;
- }
- subgroupTileLens[rDim] = 1;
- subgroupStrides[rDim] = 0;
- }
- bufferReduceLayout = IREE::VectorExt::NestedLayoutAttr::get(
- srcLayout.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
- threadTileLens, elementTileLens, subgroupStrides, threadStrides);
- return bufferReduceLayout;
+ return computeLayoutForReductionFromBuffer(srcLayout, reductionDims);
}
void writePartialResultToBuffer(RewriterBase &rewriter, Location loc,
@@ -1230,31 +1466,8 @@
SmallVector<bool> inBounds(unDistributedType.getRank(), true);
auto write = vector::TransferWriteOp::create(rewriter, loc, valueToWrite,
buffer, indices, inBounds);
- // Set layouts signature for write.
- // We need to set the layout on the srcVector/first operand.
- auto subgroupTileLens =
- llvm::to_vector_of<int64_t>(srcLayout.getSubgroupTile());
- auto batchTileLens = llvm::to_vector_of<int64_t>(srcLayout.getBatchTile());
- auto outerTileLens = llvm::to_vector_of<int64_t>(srcLayout.getOuterTile());
- auto threadTileLens =
- llvm::to_vector_of<int64_t>(srcLayout.getThreadTile());
- auto elementTileLens =
- llvm::to_vector_of<int64_t>(srcLayout.getElementTile());
- auto subgroupStrides =
- llvm::to_vector_of<int64_t>(srcLayout.getSubgroupStrides());
- auto threadStrides =
- llvm::to_vector_of<int64_t>(srcLayout.getThreadStrides());
- // Replace the reduced tiles with unit dimension.
- for (int64_t rDim : reductionDims) {
- batchTileLens[rDim] = 1;
- outerTileLens[rDim] = 1;
- threadTileLens[rDim] = 1;
- elementTileLens[rDim] = 1;
- threadStrides[rDim] = 0;
- }
- auto interSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
- rewriter.getContext(), subgroupTileLens, batchTileLens, outerTileLens,
- threadTileLens, elementTileLens, subgroupStrides, threadStrides);
+ auto interSubGroupLayout =
+ computeInterSubgroupWriteLayout(srcLayout, reductionDims);
setSignatureForRedistribution(rewriter, write, {interSubGroupLayout}, {});
}
@@ -1328,41 +1541,25 @@
// Thus, we re-insert the reduction dimensions in
// their original positions as :
// p1 x p2 x p3 -> p1 x p2 x p3 x 1 x 1
- int64_t rank = srcLayout.getRank();
- SmallVector<int64_t> partialReducedDistributedShape =
- srcLayout.getDistributedShape();
- for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
- int64_t tileGroupOffset = tileGroupIdx * rank;
- for (int64_t rDim : reductionDims) {
- partialReducedDistributedShape[tileGroupOffset + rDim] = 1;
- }
- }
- VectorType partialReducedDistributedType = VectorType::get(
- partialReducedDistributedShape, srcVector.getType().getElementType());
+ SmallVector<int64_t> localReducedDistributedShape =
+ getLocalReducedDistributedShape(srcLayout, reductionDims);
+ VectorType localReducedDistributedType = VectorType::get(
+ localReducedDistributedShape, srcVector.getType().getElementType());
Value isoRankThreadReduced = vector::ShapeCastOp::create(
- rewriter, loc, partialReducedDistributedType, threadReduced);
+ rewriter, loc, localReducedDistributedType, threadReduced);
- SmallVector<int64_t> preDistrShape =
- srcLayout.getUndistributedPackedShape();
- SmallVector<int64_t> partialReductionShape =
- llvm::to_vector(srcVector.getType().getShape());
- for (int64_t rDim : reductionDims) {
- // The first #rank elements will form the subgroup tile
- // Here we replace the input shape with subgroup tile
- // because every other tile is reduced except the subgroup
- // tile.
- partialReductionShape[rDim] = preDistrShape[rDim];
- }
+ SmallVector<int64_t> subgroupReducedShape = getSubgroupReducedShape(
+ srcLayout, srcVector.getType().getShape(), reductionDims);
auto unDistributedType = VectorType::get(
- partialReductionShape, srcVector.getType().getElementType());
+ subgroupReducedShape, srcVector.getType().getElementType());
VectorValue valueToWrite = IREE::VectorExt::ToSIMDOp::create(
rewriter, loc, unDistributedType, isoRankThreadReduced);
auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
rewriter.getContext(), gpu::AddressSpace::Workgroup));
MemRefType allocType = MemRefType::get(
- partialReductionShape, srcVector.getType().getElementType(),
- AffineMap(), workgroupMemoryAddressSpace);
+ subgroupReducedShape, srcVector.getType().getElementType(), AffineMap(),
+ workgroupMemoryAddressSpace);
auto alloc =
getBufferForSubgroupReduction(rewriter, allocType, valueToWrite);
writePartialResultToBuffer(rewriter, loc, valueToWrite, alloc, srcLayout,
@@ -1377,6 +1574,672 @@
int64_t maxBitsPerShuffle;
};
+/// Distributes `iree_vector_ext.arg_compare` ops with nested layouts.
+/// Follows the same local -> thread -> subgroup reduction approach as
+/// DistributeMultiReduction, but tracks both values and indices.
+struct DistributeArgCompare final
+ : MaskedOpDistributionPattern<IREE::VectorExt::ArgCompareOp> {
+
+ DistributeArgCompare(MLIRContext *context, int64_t subgroupSize,
+ int64_t maxBitsPerShuffle, int64_t benefit = 1)
+ : MaskedOpDistributionPattern(context, benefit),
+ subgroupSize(subgroupSize), maxBitsPerShuffle(maxBitsPerShuffle) {}
+
+ LogicalResult
+ matchAndRewrite(IREE::VectorExt::ArgCompareOp argCompareOp,
+ DistributionSignature &signature, vector::MaskOp maskOp,
+ std::optional<DistributionSignature> &maskSignature,
+ PatternRewriter &rewriter) const override {
+ Location loc = argCompareOp.getLoc();
+ VectorValue inputValue = argCompareOp.getInputValue();
+ Value inputIndex = argCompareOp.getInputIndex();
+ VectorValue initValue = argCompareOp.getInitValue();
+ VectorValue initIndex = argCompareOp.getInitIndex();
+ int64_t reductionDim = argCompareOp.getDimension();
+ int64_t rank = inputValue.getType().getRank();
+
+ // TODO(Bangtian): Implement masked arg_compare distribution.
+ if (maskOp) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp, "masked arg_compare distribution not yet implemented");
+ }
+
+ auto valueLayout =
+ dyn_cast_if_present<NestedLayoutAttr>(signature[inputValue]);
+ if (!valueLayout) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp, "expected nested layout attr for input value");
+ }
+
+ auto initValueLayout =
+ dyn_cast_if_present<NestedLayoutAttr>(signature[initValue]);
+ if (!initValueLayout) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp, "expected nested layout attr for init value");
+ }
+ auto initIndexLayout =
+ dyn_cast_if_present<NestedLayoutAttr>(signature[initIndex]);
+ if (!initIndexLayout) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp, "expected nested layout attr for init index");
+ }
+
+ Type elemTy = inputValue.getType().getElementType();
+ if (failed(checkBitwidthForShuffle(argCompareOp, elemTy, maxBitsPerShuffle,
+ "element", rewriter))) {
+ return failure();
+ }
+
+ // No bitwidth check on the index type: the index is only forwarded from
+ // the winning lane via `gpu.shuffle idx`, which handles wider types (i64).
+ // TODO(Bangtian): On AMD, ROCDL decomposes 64-bit shuffles into 32-bit
+ // pairs. Consider dropping the value bitwidth check above too.
+
+ // Only explicit index mode; iota indices are materialized earlier.
+ if (!inputIndex) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp, "expected explicit index mode (indices should be "
+ "materialized by earlier passes)");
+ }
+
+ auto inputIndexVec = cast<VectorValue>(inputIndex);
+ auto indexLayout =
+ dyn_cast_if_present<NestedLayoutAttr>(signature[inputIndexVec]);
+ if (!indexLayout) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp, "expected nested layout attr for input index");
+ }
+
+ VectorValue disValue =
+ getDistributed(rewriter, inputValue, signature[inputValue]);
+ VectorValue disIndex = cast<VectorValue>(
+ getDistributed(rewriter, inputIndexVec, signature[inputIndexVec]));
+
+ // Handle 0-d init distribution. Three cases:
+ // 1. Non-zero rank: distribute normally.
+ // 2. 0-d wrapped by ToSIMDOp: unwrap the identity op.
+ // 3. 0-d with identity layout: use as-is (no-op).
+ // The else-if is intentional — splitting into independent ifs would let
+ // case 1 values defined by ToSIMDOp hit both branches.
+ VectorValue disInitValue = initValue;
+ if (isNonZeroRank(initValue)) {
+ disInitValue = getDistributed(rewriter, initValue, initValueLayout);
+ } else if (auto toSIMD =
+ initValue.getDefiningOp<IREE::VectorExt::ToSIMDOp>()) {
+ disInitValue = cast<VectorValue>(toSIMD.getOperand());
+ }
+ VectorValue disInitIndex = initIndex;
+ if (isNonZeroRank(initIndex)) {
+ disInitIndex = getDistributed(rewriter, initIndex, initIndexLayout);
+ } else if (auto toSIMD =
+ initIndex.getDefiningOp<IREE::VectorExt::ToSIMDOp>()) {
+ disInitIndex = cast<VectorValue>(toSIMD.getOperand());
+ }
+
+ FailureOr<std::pair<VectorValue, VectorValue>> localReduced =
+ doLocalArgCompareReduction(
+ rewriter, loc, disValue, disIndex, disInitValue, disInitIndex,
+ argCompareOp.getRegion(), reductionDim, rank);
+ if (failed(localReduced)) {
+ return rewriter.notifyMatchFailure(
+ argCompareOp,
+ "failed to perform local per-thread reduction for arg_compare");
+ }
+ auto [localValueResult, localIndexResult] = *localReduced;
+
+ bool hasThreadReductions = valueLayout.getThreadTile()[reductionDim] > 1;
+ bool hasSubgroupReductions =
+ valueLayout.getSubgroupTile()[reductionDim] > 1;
+
+ if (!hasThreadReductions && !hasSubgroupReductions) {
+ replaceOpWithDistributedValues(rewriter, argCompareOp,
+ {localValueResult, localIndexResult});
+ return success();
+ }
+
+ std::pair<VectorValue, VectorValue> threadReduced = {localValueResult,
+ localIndexResult};
+ if (hasThreadReductions) {
+ std::optional<ComparatorAnalysis> analysis =
+ analyzeComparatorForThreadReduction(argCompareOp.getRegion());
+
+ FailureOr<std::pair<VectorValue, VectorValue>> result = doThreadReduction(
+ rewriter, loc, valueLayout, localValueResult, localIndexResult,
+ argCompareOp.getRegion(), reductionDim, analysis);
+ if (failed(result)) {
+ return failure();
+ }
+ threadReduced = result.value();
+ }
+
+ if (!hasSubgroupReductions) {
+ replaceOpWithDistributedValues(
+ rewriter, argCompareOp, {threadReduced.first, threadReduced.second});
+ return success();
+ }
+
+ SmallVector<bool> resultReductionMask(rank, false);
+ resultReductionMask[reductionDim] = true;
+ VectorLayoutInterface resultLayout =
+ valueLayout.project(resultReductionMask);
+
+ std::pair<Value, Value> subgroupReduced = doSubgroupReduction(
+ rewriter, loc, inputValue, valueLayout, reductionDim,
+ threadReduced.first, threadReduced.second, argCompareOp.getRegion(),
+ isNonZeroRank(initValue) ? initValue : disInitValue,
+ isNonZeroRank(initIndex) ? initIndex : disInitIndex, resultLayout);
+
+ rewriter.replaceOp(argCompareOp,
+ {subgroupReduced.first, subgroupReduced.second});
+ return success();
+ }
+
+private:
+ /// Per-thread reduction over batch/outer/element tiles in the reduction dim.
+ FailureOr<std::pair<VectorValue, VectorValue>> doLocalArgCompareReduction(
+ RewriterBase &rewriter, Location loc, VectorValue inputVal,
+ VectorValue inputIdx, VectorValue initVal, VectorValue initIdx,
+ Region &comparatorRegion, int64_t reductionDim, int64_t rank) const {
+ VectorType valType = inputVal.getType();
+ int64_t distRank = valType.getRank();
+ assert(
+ distRank == 3 * rank && inputIdx.getType().getRank() == distRank &&
+ "distributed rank must be 3 * original rank for batch/outer/element");
+
+ VectorType outValType = initVal.getType();
+ VectorType outIdxType = initIdx.getType();
+ assert(outValType.getNumElements() == outIdxType.getNumElements() &&
+ "init value and index must have the same number of elements");
+
+ int64_t initRank = rank - 1;
+ assert(outValType.getRank() == 3 * initRank &&
+ "init rank must be 3 * (rank - 1) after dropping reduction dim");
+
+ int64_t batchDimInDist = reductionDim;
+ int64_t outerDimInDist = rank + reductionDim;
+ int64_t elementDimInDist = 2 * rank + reductionDim;
+
+ int64_t batchSize = valType.getShape()[batchDimInDist];
+ int64_t outerSize = valType.getShape()[outerDimInDist];
+ int64_t elementSize = valType.getShape()[elementDimInDist];
+ assert(batchSize > 0 && outerSize > 0 && elementSize > 0 &&
+ "tile sizes along reduction dim must be positive");
+
+ auto applyComparator = [&](Value lhs, Value rhs) -> Value {
+ return cloneComparatorRegion(rewriter, comparatorRegion, lhs, rhs);
+ };
+
+ // Placeholder containers; every element is overwritten by InsertOp below.
+ Value outValVec = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(outValType))
+ .getResult();
+ Value outIdxVec = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(outIdxType))
+ .getResult();
+
+ SmallVector<int64_t> outShape(outValType.getShape());
+ SmallVector<int64_t> outIndices(outValType.getRank(), 0);
+ int64_t outNumElements = outValType.getNumElements();
+
+ for (int64_t linearIdx = 0; linearIdx < outNumElements; ++linearIdx) {
+ int64_t tmp = linearIdx;
+ for (int64_t i = static_cast<int64_t>(outIndices.size()) - 1; i >= 0;
+ --i) {
+ int64_t extent = outShape[i];
+ outIndices[i] = tmp % extent;
+ tmp /= extent;
+ }
+
+ Value accVal =
+ vector::ExtractOp::create(rewriter, loc, initVal, outIndices);
+ Value accIdx =
+ vector::ExtractOp::create(rewriter, loc, initIdx, outIndices);
+
+ SmallVector<int64_t> inputIndices(distRank, 0);
+ for (int64_t tileGroup = 0; tileGroup < 3; ++tileGroup) {
+ for (int64_t dimIdx = 0; dimIdx < rank; ++dimIdx) {
+ if (dimIdx == reductionDim) {
+ continue;
+ }
+ int64_t initDimInOrig = dimIdx < reductionDim ? dimIdx : dimIdx - 1;
+ int64_t outPos = tileGroup * initRank + initDimInOrig;
+ inputIndices[tileGroup * rank + dimIdx] = outIndices[outPos];
+ }
+ }
+
+ int64_t totalReductionIters = batchSize * outerSize * elementSize;
+ for (int64_t ri = 0; ri < totalReductionIters; ++ri) {
+ int64_t e = ri % elementSize;
+ int64_t o = (ri / elementSize) % outerSize;
+ int64_t b = ri / (elementSize * outerSize);
+ inputIndices[batchDimInDist] = b;
+ inputIndices[outerDimInDist] = o;
+ inputIndices[elementDimInDist] = e;
+
+ Value elemVal =
+ vector::ExtractOp::create(rewriter, loc, inputVal, inputIndices);
+ Value elemIdx =
+ vector::ExtractOp::create(rewriter, loc, inputIdx, inputIndices);
+
+ Value cmpResult = applyComparator(elemVal, accVal);
+ accVal =
+ arith::SelectOp::create(rewriter, loc, cmpResult, elemVal, accVal);
+ accIdx =
+ arith::SelectOp::create(rewriter, loc, cmpResult, elemIdx, accIdx);
+ }
+
+ outValVec =
+ vector::InsertOp::create(rewriter, loc, accVal, outValVec, outIndices)
+ .getResult();
+ outIdxVec =
+ vector::InsertOp::create(rewriter, loc, accIdx, outIdxVec, outIndices)
+ .getResult();
+ }
+
+ return std::make_pair(cast<VectorValue>(outValVec),
+ cast<VectorValue>(outIdxVec));
+ }
+
+ /// Flattened reduction state shared between ballot and shuffle impls.
+ struct FlattenedReductionState {
+ VectorValue flatValue;
+ VectorValue flatIndex;
+ VectorValue valueRes;
+ VectorValue indexRes;
+ VectorType valueType;
+ VectorType indexType;
+ int64_t numElements;
+ int64_t threadOffset;
+ int64_t width;
+ };
+
+ /// Flatten inputs and create result vectors for thread reduction.
+ FlattenedReductionState
+ setupThreadReduction(RewriterBase &rewriter, Location loc,
+ NestedLayoutAttr layout, VectorValue value,
+ VectorValue index, int64_t reductionDim) const {
+ FlattenedReductionState state;
+ state.valueType = value.getType();
+ state.indexType = index.getType();
+ assert(state.valueType.getNumElements() ==
+ state.indexType.getNumElements() &&
+ "value and index must have matching shapes");
+ Type elemTy = state.valueType.getElementType();
+ Type indexElemTy = state.indexType.getElementType();
+ state.numElements = state.valueType.getNumElements();
+
+ SmallVector<int64_t> flatShape{state.numElements};
+ VectorType flatValueType = VectorType::get(flatShape, elemTy);
+ VectorType flatIndexType = VectorType::get(flatShape, indexElemTy);
+ state.flatValue =
+ vector::ShapeCastOp::create(rewriter, loc, flatValueType, value);
+ state.flatIndex =
+ vector::ShapeCastOp::create(rewriter, loc, flatIndexType, index);
+
+ auto valueZeroAttr = rewriter.getZeroAttr(flatValueType);
+ auto indexZeroAttr = rewriter.getZeroAttr(flatIndexType);
+ state.valueRes = cast<VectorValue>(
+ arith::ConstantOp::create(rewriter, loc, valueZeroAttr).getResult());
+ state.indexRes = cast<VectorValue>(
+ arith::ConstantOp::create(rewriter, loc, indexZeroAttr).getResult());
+
+ state.threadOffset = getShuffleOffset(layout, reductionDim);
+ state.width = getShuffleWidth(layout, reductionDim);
+ return state;
+ }
+
+ /// Main entry point for thread reduction. Dispatches to ballot-based
+ /// or shuffle-based implementation based on the comparator analysis.
+ /// TODO(Bangtian): Share with DistributeMultiReduction via a common helper.
+ FailureOr<std::pair<VectorValue, VectorValue>> doThreadReduction(
+ RewriterBase &rewriter, Location loc, NestedLayoutAttr layout,
+ VectorValue value, VectorValue index, Region &comparatorRegion,
+ int64_t reductionDim,
+ std::optional<ComparatorAnalysis> analysis = std::nullopt) const {
+ FlattenedReductionState state =
+ setupThreadReduction(rewriter, loc, layout, value, index, reductionDim);
+
+ if (analysis.has_value()) {
+ // Ballot-based O(1) path for standard comparators.
+ return doThreadReductionWithBallot(rewriter, loc, state, *analysis);
+ }
+ // Butterfly shuffle fallback for custom comparators.
+ return doThreadReductionWithShuffles(rewriter, loc, state,
+ comparatorRegion);
+ }
+
+ /// Butterfly shuffle fallback: reduces (value, index) pairs for custom
+ /// comparators.
+ FailureOr<std::pair<VectorValue, VectorValue>>
+ doThreadReductionWithShuffles(RewriterBase &rewriter, Location loc,
+ const FlattenedReductionState &state,
+ Region &comparatorRegion) const {
+ VectorValue valueRes = state.valueRes;
+ VectorValue indexRes = state.indexRes;
+
+ Value subgroupSizeVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(subgroupSize));
+
+ for (int64_t elemIdx = 0; elemIdx < state.numElements; ++elemIdx) {
+ Value currentValue =
+ vector::ExtractOp::create(rewriter, loc, state.flatValue, elemIdx);
+ Value currentIndex =
+ vector::ExtractOp::create(rewriter, loc, state.flatIndex, elemIdx);
+
+ for (int64_t stride = state.width / 2; stride > 0; stride /= 2) {
+ int64_t shuffleOffset = stride * state.threadOffset;
+ Value shuffleOffsetVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(shuffleOffset));
+
+ Value shuffledValue = gpu::ShuffleOp::create(
+ rewriter, loc, currentValue, shuffleOffsetVal,
+ subgroupSizeVal, gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+
+ Value cmpResult = cloneComparatorRegion(rewriter, comparatorRegion,
+ currentValue, shuffledValue);
+
+ currentValue = arith::SelectOp::create(rewriter, loc, cmpResult,
+ currentValue, shuffledValue)
+ .getResult();
+
+ Value shuffledIndex = gpu::ShuffleOp::create(
+ rewriter, loc, currentIndex, shuffleOffsetVal,
+ subgroupSizeVal, gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ currentIndex = arith::SelectOp::create(rewriter, loc, cmpResult,
+ currentIndex, shuffledIndex)
+ .getResult();
+ }
+
+ valueRes = vector::InsertOp::create(rewriter, loc, currentValue, valueRes,
+ elemIdx);
+ indexRes = vector::InsertOp::create(rewriter, loc, currentIndex, indexRes,
+ elemIdx);
+ }
+
+ VectorValue reshapedValue =
+ reshapeFlatToTarget(rewriter, loc, valueRes, state.valueType);
+ VectorValue reshapedIndex =
+ reshapeFlatToTarget(rewriter, loc, indexRes, state.indexType);
+ return std::pair{reshapedValue, reshapedIndex};
+ }
+
+ /// Ballot-based thread reduction for analyzable comparators. Elects the
+ /// first winning lane and forwards its index via gpu.shuffle idx.
+ FailureOr<std::pair<VectorValue, VectorValue>>
+ doThreadReductionWithBallot(RewriterBase &rewriter, Location loc,
+ const FlattenedReductionState &state,
+ const ComparatorAnalysis &analysis) const {
+ VectorValue valueRes = state.valueRes;
+ VectorValue indexRes = state.indexRes;
+
+ int64_t threadStride = state.threadOffset;
+ int64_t width = state.width;
+
+ // Ballot type: i32 for subgroup_size <= 32, i64 otherwise.
+ Type ballotType =
+ subgroupSize <= 32 ? rewriter.getI32Type() : rewriter.getI64Type();
+
+ Value subgroupSizeVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getI32IntegerAttr(subgroupSize));
+
+ for (int64_t elemIdx = 0; elemIdx < state.numElements; ++elemIdx) {
+ Value localValue =
+ vector::ExtractOp::create(rewriter, loc, state.flatValue, elemIdx);
+ Value localIndex =
+ vector::ExtractOp::create(rewriter, loc, state.flatIndex, elemIdx);
+
+ // Apply transformation if present (e.g., abs for argmax of abs).
+ Value valueToReduce = localValue;
+ if (analysis.transformOp) {
+ IRMapping mapper;
+ mapper.map(analysis.transformOp->getOperand(0), localValue);
+ Operation *clonedOp = rewriter.clone(*analysis.transformOp, mapper);
+ valueToReduce = clonedOp->getResult(0);
+ }
+
+ Value reducedValue = gpu::SubgroupReduceOp::create(
+ rewriter, loc, valueToReduce, analysis.reduceOp,
+ /*uniform=*/false,
+ /*cluster_size=*/
+ std::optional<uint32_t>(static_cast<uint32_t>(width)),
+ /*cluster_stride=*/static_cast<uint32_t>(threadStride));
+
+ Value isWinner =
+ createEqualityComparison(rewriter, loc, valueToReduce, reducedValue);
+
+ Value ballotMask =
+ gpu::BallotOp::create(rewriter, loc, ballotType, isWinner);
+
+ // Mask ballot to only consider threads in our cluster for non-unit
+ // stride.
+ if (threadStride > 1) {
+ uint64_t clusterMask = 0;
+ for (int64_t i = 0; i < width; ++i) {
+ clusterMask |= (1ULL << (i * threadStride));
+ }
+ Value clusterMaskVal = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(ballotType, clusterMask));
+ ballotMask =
+ arith::AndIOp::create(rewriter, loc, ballotMask, clusterMaskVal);
+ }
+
+ Value winningLane =
+ math::CountTrailingZerosOp::create(rewriter, loc, ballotMask);
+
+ if (ballotType != rewriter.getI32Type()) {
+ winningLane = arith::TruncIOp::create(
+ rewriter, loc, rewriter.getI32Type(), winningLane);
+ }
+
+ // For transformed comparators, broadcast the original value from the
+ // winning lane. For simple comparators, use reducedValue directly.
+ Value resultValue = reducedValue;
+ if (analysis.transformOp) {
+ resultValue =
+ gpu::ShuffleOp::create(rewriter, loc, localValue, winningLane,
+ subgroupSizeVal, gpu::ShuffleMode::IDX)
+ .getShuffleResult();
+ }
+
+ Value resultIndex =
+ gpu::ShuffleOp::create(rewriter, loc, localIndex, winningLane,
+ subgroupSizeVal, gpu::ShuffleMode::IDX)
+ .getShuffleResult();
+
+ valueRes = vector::InsertOp::create(rewriter, loc, resultValue, valueRes,
+ elemIdx);
+ indexRes = vector::InsertOp::create(rewriter, loc, resultIndex, indexRes,
+ elemIdx);
+ }
+
+ VectorValue reshapedValue =
+ reshapeFlatToTarget(rewriter, loc, valueRes, state.valueType);
+ VectorValue reshapedIndex =
+ reshapeFlatToTarget(rewriter, loc, indexRes, state.indexType);
+ return std::pair{reshapedValue, reshapedIndex};
+ }
+
+ /// Perform inter-subgroup reduction via shared memory.
+ std::pair<Value, Value>
+ doSubgroupReduction(RewriterBase &rewriter, Location loc,
+ VectorValue srcVector, NestedLayoutAttr srcLayout,
+ int64_t reductionDim, VectorValue threadReducedValue,
+ VectorValue threadReducedIndex, Region &comparatorRegion,
+ VectorValue initValue, VectorValue initIndex,
+ VectorLayoutInterface resLayout) const {
+ SmallVector<int64_t> localReducedDistributedShape =
+ getLocalReducedDistributedShape(srcLayout, {reductionDim});
+ VectorType localReducedDistributedType = VectorType::get(
+ localReducedDistributedShape, srcVector.getType().getElementType());
+ VectorType localReducedIndexType = VectorType::get(
+ localReducedDistributedShape, initIndex.getType().getElementType());
+
+ Value isoRankThreadReducedValue = vector::ShapeCastOp::create(
+ rewriter, loc, localReducedDistributedType, threadReducedValue);
+ Value isoRankThreadReducedIndex = vector::ShapeCastOp::create(
+ rewriter, loc, localReducedIndexType, threadReducedIndex);
+
+ SmallVector<int64_t> subgroupReducedShape = getSubgroupReducedShape(
+ srcLayout, srcVector.getType().getShape(), {reductionDim});
+ auto unDistributedValueType = VectorType::get(
+ subgroupReducedShape, srcVector.getType().getElementType());
+ auto unDistributedIndexType = VectorType::get(
+ subgroupReducedShape, initIndex.getType().getElementType());
+
+ VectorValue valueToWrite = IREE::VectorExt::ToSIMDOp::create(
+ rewriter, loc, unDistributedValueType, isoRankThreadReducedValue);
+ VectorValue indexToWrite = IREE::VectorExt::ToSIMDOp::create(
+ rewriter, loc, unDistributedIndexType, isoRankThreadReducedIndex);
+
+ auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
+ rewriter.getContext(), gpu::AddressSpace::Workgroup));
+
+ MemRefType valueAllocType = MemRefType::get(
+ subgroupReducedShape, srcVector.getType().getElementType(), AffineMap(),
+ workgroupMemoryAddressSpace);
+ MemRefType indexAllocType = MemRefType::get(
+ subgroupReducedShape, initIndex.getType().getElementType(), AffineMap(),
+ workgroupMemoryAddressSpace);
+
+ auto valueAlloc = memref::AllocOp::create(rewriter, loc, valueAllocType);
+ auto indexAlloc = memref::AllocOp::create(rewriter, loc, indexAllocType);
+
+ gpu::BarrierOp::create(rewriter, loc, valueAlloc);
+ writePartialArgCompareResultToBuffer(rewriter, loc, valueToWrite,
+ indexToWrite, valueAlloc, indexAlloc,
+ srcLayout, reductionDim);
+ gpu::BarrierOp::create(rewriter, loc, valueAlloc);
+
+ return doSubgroupReductionFromBuffer(
+ rewriter, loc, valueAlloc, indexAlloc, srcLayout, resLayout,
+ reductionDim, comparatorRegion, initValue, initIndex);
+ }
+
+ /// Write partial arg_compare results to shared memory with redistribution
+ /// signatures.
+ void writePartialArgCompareResultToBuffer(
+ RewriterBase &rewriter, Location loc, VectorValue valueToWrite,
+ VectorValue indexToWrite, Value valueBuffer, Value indexBuffer,
+ NestedLayoutAttr srcLayout, int64_t reductionDim) const {
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ VectorType valueType = valueToWrite.getType();
+ SmallVector<Value> indices(valueType.getRank(), c0);
+ SmallVector<bool> inBounds(valueType.getRank(), true);
+
+ auto valueWrite = vector::TransferWriteOp::create(
+ rewriter, loc, valueToWrite, valueBuffer, indices, inBounds);
+ auto indexWrite = vector::TransferWriteOp::create(
+ rewriter, loc, indexToWrite, indexBuffer, indices, inBounds);
+
+ auto interSubGroupLayout =
+ computeInterSubgroupWriteLayout(srcLayout, {reductionDim});
+ setSignatureForRedistribution(rewriter, valueWrite, {interSubGroupLayout},
+ {});
+ setSignatureForRedistribution(rewriter, indexWrite, {interSubGroupLayout},
+ {});
+ }
+
+ /// Read from shared memory and complete subgroup reduction.
+ std::pair<Value, Value> doSubgroupReductionFromBuffer(
+ RewriterBase &rewriter, Location loc, Value valueBuffer,
+ Value indexBuffer, NestedLayoutAttr srcLayout,
+ VectorLayoutInterface resLayout, int64_t reductionDim,
+ Region &comparatorRegion, VectorValue initValue,
+ VectorValue initIndex) const {
+ NestedLayoutAttr readLayout =
+ getLayoutForReductionFromBuffer(srcLayout, {reductionDim});
+
+ Type valueElemType = getElementTypeOrSelf(valueBuffer);
+ Type indexElemType = getElementTypeOrSelf(indexBuffer);
+
+ auto valueReadTy =
+ VectorType::get(readLayout.getUndistributedShape(), valueElemType);
+ auto indexReadTy =
+ VectorType::get(readLayout.getUndistributedShape(), indexElemType);
+
+ auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ auto inBounds = rewriter.getBoolArrayAttr(
+ SmallVector<bool>(readLayout.getRank(), true));
+
+ // No mask needed: inBounds is all-true, so the mask would be all-true
+ // and redundant. The padding value handles any out-of-bounds lanes
+ // after distribution.
+ //
+ // Pad with init values so out-of-bounds lanes don't affect the
+ // subsequent ArgCompareOp (e.g., -inf for argmax, +inf for argmin).
+ // Unlike MultiDimReductionOp which has a CombiningKind to derive the
+ // identity via getCombiningIdentityValue, ArgCompareOp uses an opaque
+ // comparator region, so we rely on the init values instead.
+ SmallVector<int64_t> zeroIdx(initValue.getType().getRank(), 0);
+ Value valuePad =
+ vector::ExtractOp::create(rewriter, loc, initValue, zeroIdx);
+ Value indexPad =
+ vector::ExtractOp::create(rewriter, loc, initIndex, zeroIdx);
+
+ auto valueRead = vector::TransferReadOp::create(
+ rewriter, loc, valueReadTy, valueBuffer,
+ SmallVector<Value>(readLayout.getRank(), zero),
+ rewriter.getMultiDimIdentityMap(readLayout.getRank()), valuePad,
+ /*mask=*/Value(), inBounds);
+
+ auto indexRead = vector::TransferReadOp::create(
+ rewriter, loc, indexReadTy, indexBuffer,
+ SmallVector<Value>(readLayout.getRank(), zero),
+ rewriter.getMultiDimIdentityMap(readLayout.getRank()), indexPad,
+ /*mask=*/Value(), inBounds);
+
+ setSignatureForRedistribution(rewriter, valueRead, {}, {readLayout});
+ setSignatureForRedistribution(rewriter, indexRead, {}, {readLayout});
+
+ VectorType initValueType = initValue.getType();
+ VectorType initIndexType = initIndex.getType();
+
+ // This new ArgCompareOp will be distributed by a subsequent application
+ // of the same DistributeArgCompare pattern.
+ auto secondArgCompare = IREE::VectorExt::ArgCompareOp::create(
+ rewriter, loc, initValueType, initIndexType, valueRead, indexRead,
+ initValue, initIndex, Value(), reductionDim);
+
+ IRMapping mapper;
+ comparatorRegion.cloneInto(&secondArgCompare.getRegion(), mapper);
+
+ SmallVector<VectorLayoutInterface> inputLayouts = {readLayout, readLayout};
+ SmallVector<VectorLayoutInterface> resultLayouts;
+ if (isNonZeroRank(initValue)) {
+ inputLayouts.push_back(resLayout);
+ inputLayouts.push_back(resLayout);
+ resultLayouts.push_back(resLayout);
+ resultLayouts.push_back(resLayout);
+ } else {
+ ArrayRef<int64_t> empty = {};
+ auto emptyLayout =
+ NestedLayoutAttr::get(rewriter.getContext(), empty, empty, empty,
+ empty, empty, empty, empty);
+ inputLayouts.push_back(emptyLayout);
+ inputLayouts.push_back(emptyLayout);
+ resultLayouts.push_back(emptyLayout);
+ resultLayouts.push_back(emptyLayout);
+ }
+ setSignatureForRedistribution(rewriter, secondArgCompare.getOperation(),
+ inputLayouts, resultLayouts);
+
+ return {cast<VectorValue>(secondArgCompare.getResultValue()),
+ cast<VectorValue>(secondArgCompare.getResultIndex())};
+ }
+
+ /// Get layout for reading reduction results from shared memory.
+ NestedLayoutAttr
+ getLayoutForReductionFromBuffer(NestedLayoutAttr srcLayout,
+ ArrayRef<int64_t> reductionDims) const {
+ return computeLayoutForReductionFromBuffer(srcLayout, reductionDims);
+ }
+
+ int64_t subgroupSize;
+ int64_t maxBitsPerShuffle;
+};
+
/// The distribution of contract is performed by doing a local contraction where
/// each thread performs operations on its locally distributed elements. Then,
/// the resulting vector is interpreted in undistributed domain. The said
@@ -2220,6 +3083,8 @@
patterns.getContext());
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
+ patterns.add<DistributeArgCompare>(patterns.getContext(), subgroupSize,
+ maxBitsPerShuffle);
patterns.add<DistributeContract>(patterns.getContext());
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
patterns.add<DistributeInnerTiled>(patterns.getContext());
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index b22d4dc..216b5bb 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -46,6 +46,7 @@
"gpu_infer_memory_space.mlir",
"gpu_lower_coalesced_dma_to_global_loads.mlir",
"gpu_nested_layout_vector_distribution.mlir",
+ "gpu_nested_layout_vector_distribution_argcompare.mlir",
"gpu_nested_layout_vector_distribution_inner_tiled.mlir",
"gpu_nested_layout_vector_distribution_mask.mlir",
"gpu_nested_layout_vector_distribution_multi_reduce.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 7eab23e..094701e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -41,6 +41,7 @@
"gpu_infer_memory_space.mlir"
"gpu_lower_coalesced_dma_to_global_loads.mlir"
"gpu_nested_layout_vector_distribution.mlir"
+ "gpu_nested_layout_vector_distribution_argcompare.mlir"
"gpu_nested_layout_vector_distribution_inner_tiled.mlir"
"gpu_nested_layout_vector_distribution_mask.mlir"
"gpu_nested_layout_vector_distribution_multi_reduce.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_argcompare.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_argcompare.mlir
new file mode 100644
index 0000000..f9e860f
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_argcompare.mlir
@@ -0,0 +1,566 @@
+// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize --cse %s | FileCheck %s
+
+// Tests for DistributeArgCompare pattern with explicit index mode.
+
+#layout_2d_element_only = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [1, 1],
+ element_tile = [16, 8],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [0, 0]
+>
+
+#layout_1d_element_only = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1],
+ batch_tile = [1],
+ outer_tile = [1],
+ thread_tile = [1],
+ element_tile = [16],
+
+ subgroup_strides = [0],
+ thread_strides = [0]
+>
+
+// Test: Element-tile-only argmax with explicit index. No thread shuffles needed
+// since thread_tile = [1, 1].
+// CHECK-LABEL: func @argmax_element_only
+// CHECK-SAME: %[[INPUT:.*]]: vector<16x8xf16>
+// CHECK-SAME: %[[INPUT_IDX:.*]]: vector<16x8xi32>
+// CHECK-SAME: %[[INIT_VAL:.*]]: vector<16xf16>
+// CHECK-SAME: %[[INIT_IDX:.*]]: vector<16xi32>
+func.func @argmax_element_only(
+ %input: vector<16x8xf16>,
+ %input_idx: vector<16x8xi32>,
+ %init_val: vector<16xf16>,
+ %init_idx: vector<16xi32>) -> (vector<16xf16>, vector<16xi32>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_2d_element_only) : vector<16x8xf16>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_2d_element_only) : vector<16x8xi32>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_1d_element_only) : vector<16xf16>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_1d_element_only) : vector<16xi32>
+
+ // Local inline reduction within element tile using scalar extract/cmpf/select.
+ // No shuffles needed since thread_tile[1] = 1.
+ // CHECK-DAG: %[[DIS_INPUT:.*]] = iree_vector_ext.to_simt %[[INPUT]] : vector<16x8xf16> -> vector<1x1x1x1x16x8xf16>
+ // CHECK: %[[ELEM0:.*]] = vector.extract %[[DIS_INPUT]][0, 0, 0, 0, 0, 0]
+ // CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[ELEM0]], %{{.*}} : f16
+ // CHECK: arith.select %[[CMP]],
+ // CHECK-NOT: gpu.shuffle
+ %result:2 = iree_vector_ext.arg_compare dimension(1)
+ ins(%input_layout, %input_idx_layout : vector<16x8xf16>, vector<16x8xi32>)
+ inits(%init_val_layout, %init_idx_layout : vector<16xf16>, vector<16xi32>) {
+ ^bb0(%lhs: f16, %rhs: f16):
+ %cmp = arith.cmpf ogt, %lhs, %rhs : f16
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<16xf16>, vector<16xi32>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_1d_element_only) : vector<16xf16>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_1d_element_only) : vector<16xi32>
+
+ func.return %result_val_layout, %result_idx_layout : vector<16xf16>, vector<16xi32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#layout_2d_thread_reduce = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [2, 2],
+ outer_tile = [1, 1],
+ thread_tile = [16, 4],
+ element_tile = [1, 4],
+
+ subgroup_strides = [1, 1],
+ thread_strides = [1, 16]
+>
+
+#layout_1d_thread_reduce = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1],
+ batch_tile = [2],
+ outer_tile = [1],
+ thread_tile = [16],
+ element_tile = [1],
+
+ subgroup_strides = [1],
+ thread_strides = [1]
+>
+
+// Test: Ballot-based thread reduction with i64 indices.
+// CHECK-LABEL: func @argmax_i64_index
+func.func @argmax_i64_index(
+ %input: vector<32x32xf32>,
+ %input_idx: vector<32x32xi64>,
+ %init_val: vector<32xf32>,
+ %init_idx: vector<32xi64>) -> (vector<32xf32>, vector<32xi64>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_2d_thread_reduce) : vector<32x32xf32>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_2d_thread_reduce) : vector<32x32xi64>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_1d_thread_reduce) : vector<32xf32>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_1d_thread_reduce) : vector<32xi64>
+
+ // Local reduction.
+ // CHECK: %[[ELEM0:.*]] = vector.extract %{{.*}}[0, 0, 0, 0, 0, 0]
+ // CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[ELEM0]], %{{.*}}
+ // CHECK: arith.select %[[CMP]],
+ // Ballot-based thread reduction (4 threads on dim 1).
+ // CHECK: %[[REDUCED:.*]] = gpu.subgroup_reduce maxnumf %[[LOCAL_VAL:[a-z0-9]+]]
+ // CHECK: %[[IS_WINNER:.*]] = arith.cmpf oeq, %[[LOCAL_VAL]], %[[REDUCED]] : f32
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[IS_WINNER]]
+ // CHECK: %[[MASKED:.*]] = arith.andi %[[BALLOT]],
+ // CHECK: %[[WINNER_LANE_I64:.*]] = math.cttz %[[MASKED]]
+ // CHECK: %[[WINNER_LANE:.*]] = arith.trunci %[[WINNER_LANE_I64]]
+ // CHECK: gpu.shuffle idx {{.*}}, %[[WINNER_LANE]], {{.*}} : i64
+ %result:2 = iree_vector_ext.arg_compare dimension(1)
+ ins(%input_layout, %input_idx_layout : vector<32x32xf32>, vector<32x32xi64>)
+ inits(%init_val_layout, %init_idx_layout : vector<32xf32>, vector<32xi64>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf ogt, %lhs, %rhs : f32
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<32xf32>, vector<32xi64>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_1d_thread_reduce) : vector<32xf32>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_1d_thread_reduce) : vector<32xi64>
+
+ func.return %result_val_layout, %result_idx_layout : vector<32xf32>, vector<32xi64>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// Test: Custom comparator falls back to butterfly shuffles.
+// CHECK-LABEL: func @argmax_custom_comparator_i64_index
+func.func @argmax_custom_comparator_i64_index(
+ %input: vector<32x32xf32>,
+ %input_idx: vector<32x32xi64>,
+ %init_val: vector<32xf32>,
+ %init_idx: vector<32xi64>) -> (vector<32xf32>, vector<32xi64>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_2d_thread_reduce) : vector<32x32xf32>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_2d_thread_reduce) : vector<32x32xi64>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_1d_thread_reduce) : vector<32xf32>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_1d_thread_reduce) : vector<32xi64>
+
+ // CHECK: gpu.shuffle xor {{.*}} : f32
+ // CHECK: gpu.shuffle xor {{.*}} : i64
+ %result:2 = iree_vector_ext.arg_compare dimension(1)
+ ins(%input_layout, %input_idx_layout : vector<32x32xf32>, vector<32x32xi64>)
+ inits(%init_val_layout, %init_idx_layout : vector<32xf32>, vector<32xi64>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %lhs2 = arith.mulf %lhs, %lhs : f32
+ %rhs2 = arith.mulf %rhs, %rhs : f32
+ %cmp = arith.cmpf ogt, %lhs2, %rhs2 : f32
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<32xf32>, vector<32xi64>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_1d_thread_reduce) : vector<32xf32>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_1d_thread_reduce) : vector<32xi64>
+
+ func.return %result_val_layout, %result_idx_layout : vector<32xf32>, vector<32xi64>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// Test: Integer argmax with ballot-based thread reduction (arith.cmpi + maxsi).
+// CHECK-LABEL: func @argmax_integer_thread_reduction
+func.func @argmax_integer_thread_reduction(
+ %input: vector<32x32xi32>,
+ %input_idx: vector<32x32xi32>,
+ %init_val: vector<32xi32>,
+ %init_idx: vector<32xi32>) -> (vector<32xi32>, vector<32xi32>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_2d_thread_reduce) : vector<32x32xi32>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_2d_thread_reduce) : vector<32x32xi32>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_1d_thread_reduce) : vector<32xi32>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_1d_thread_reduce) : vector<32xi32>
+
+ // CHECK: %[[ELEM0:.*]] = vector.extract %{{.*}}[0, 0, 0, 0, 0, 0]
+ // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ELEM0]], %{{.*}}
+ // CHECK: arith.select %[[CMP]],
+ // CHECK: vector.extract %{{.*}}[0, 1, 0, 0, 0, 0]
+ // CHECK: vector.extract %{{.*}}[1, 0, 0, 0, 0, 0]
+ // CHECK: %[[REDUCED:.*]] = gpu.subgroup_reduce maxsi %[[LOCAL_VAL:[a-z0-9]+]]
+ // CHECK: %[[IS_WINNER:.*]] = arith.cmpi eq, %[[LOCAL_VAL]], %[[REDUCED]] : i32
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[IS_WINNER]]
+ // CHECK: %[[MASKED:.*]] = arith.andi %[[BALLOT]],
+ // CHECK: %[[WINNER_LANE_I64:.*]] = math.cttz %[[MASKED]]
+ // CHECK: %[[WINNER_LANE:.*]] = arith.trunci %[[WINNER_LANE_I64]]
+ // CHECK: gpu.shuffle idx {{.*}}, %[[WINNER_LANE]],
+ %result:2 = iree_vector_ext.arg_compare dimension(1)
+ ins(%input_layout, %input_idx_layout : vector<32x32xi32>, vector<32x32xi32>)
+ inits(%init_val_layout, %init_idx_layout : vector<32xi32>, vector<32xi32>) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<32xi32>, vector<32xi32>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_1d_thread_reduce) : vector<32xi32>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_1d_thread_reduce) : vector<32xi32>
+
+ func.return %result_val_layout, %result_idx_layout : vector<32xi32>, vector<32xi32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#layout_2d_explicit = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 1],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [4, 4],
+ element_tile = [1, 8],
+
+ subgroup_strides = [0, 0],
+ thread_strides = [4, 1]
+>
+
+#layout_1d_explicit = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1],
+ batch_tile = [1],
+ outer_tile = [1],
+ thread_tile = [4],
+ element_tile = [1],
+
+ subgroup_strides = [0],
+ thread_strides = [4]
+>
+
+// Test: Argmin with explicit index mode using ballot-based thread reduction.
+// CHECK-LABEL: func @argmin_explicit_index
+// CHECK-SAME: %[[INPUT_VAL:.*]]: vector<4x32xf32>
+// CHECK-SAME: %[[INPUT_IDX:.*]]: vector<4x32xi32>
+// CHECK-SAME: %[[INIT_VAL:.*]]: vector<4xf32>
+// CHECK-SAME: %[[INIT_IDX:.*]]: vector<4xi32>
+func.func @argmin_explicit_index(
+ %input_val: vector<4x32xf32>,
+ %input_idx: vector<4x32xi32>,
+ %init_val: vector<4xf32>,
+ %init_idx: vector<4xi32>) -> (vector<4xf32>, vector<4xi32>) {
+
+ %input_val_layout = iree_vector_ext.to_layout %input_val to layout(#layout_2d_explicit) : vector<4x32xf32>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_2d_explicit) : vector<4x32xi32>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_1d_explicit) : vector<4xf32>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_1d_explicit) : vector<4xi32>
+
+ // Inline local element reduction with explicit index input.
+ // CHECK: %[[ELEM0:.*]] = vector.extract %{{.*}}[0, 0, 0, 0, 0, 0]
+ // CHECK: %[[CMP:.*]] = arith.cmpf olt, %[[ELEM0]], %{{.*}}
+ // CHECK: arith.select %[[CMP]],
+ // Thread reduction with ballot-based approach.
+ // CHECK: %[[REDUCED:.*]] = gpu.subgroup_reduce minnumf %[[LOCAL_VAL:[a-z0-9]+]]
+ // CHECK: %[[IS_WINNER:.*]] = arith.cmpf oeq, %[[LOCAL_VAL]], %[[REDUCED]] : f32
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[IS_WINNER]]
+ // CHECK: %[[WINNER_LANE_I64:.*]] = math.cttz %[[BALLOT]]
+ // CHECK: %[[WINNER_LANE:.*]] = arith.trunci %[[WINNER_LANE_I64]]
+ // CHECK: gpu.shuffle idx {{.*}}, %[[WINNER_LANE]],
+ %result:2 = iree_vector_ext.arg_compare dimension(1)
+ ins(%input_val_layout, %input_idx_layout : vector<4x32xf32>, vector<4x32xi32>)
+ inits(%init_val_layout, %init_idx_layout : vector<4xf32>, vector<4xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf olt, %lhs, %rhs : f32
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<4xf32>, vector<4xi32>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_1d_explicit) : vector<4xf32>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_1d_explicit) : vector<4xi32>
+
+ func.return %result_val_layout, %result_idx_layout : vector<4xf32>, vector<4xi32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#layout_2d_subgroup = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1, 2],
+ batch_tile = [1, 1],
+ outer_tile = [1, 1],
+ thread_tile = [8, 4],
+ element_tile = [1, 4],
+
+ subgroup_strides = [0, 1],
+ thread_strides = [1, 8]
+>
+
+#layout_1d_subgroup = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1],
+ batch_tile = [1],
+ outer_tile = [1],
+ thread_tile = [8],
+ element_tile = [1],
+
+ subgroup_strides = [0],
+ thread_strides = [1]
+>
+
+// Test: Subgroup reduction via shared memory with explicit index.
+// 2 subgroups participate in reduction on dim 1.
+// CHECK-LABEL: func @argmax_subgroup_reduction
+// CHECK-SAME: %[[INPUT:.*]]: vector<8x32xf32>
+// CHECK-SAME: %[[INPUT_IDX:.*]]: vector<8x32xi32>
+// CHECK-SAME: %[[INIT_VAL:.*]]: vector<8xf32>
+// CHECK-SAME: %[[INIT_IDX:.*]]: vector<8xi32>
+func.func @argmax_subgroup_reduction(
+ %input: vector<8x32xf32>,
+ %input_idx: vector<8x32xi32>,
+ %init_val: vector<8xf32>,
+ %init_idx: vector<8xi32>) -> (vector<8xf32>, vector<8xi32>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_2d_subgroup) : vector<8x32xf32>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_2d_subgroup) : vector<8x32xi32>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_1d_subgroup) : vector<8xf32>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_1d_subgroup) : vector<8xi32>
+
+ // Inline local element reduction + thread reduction (ballot-based) + subgroup reduction via shared memory.
+ // CHECK: %[[ELEM0:.*]] = vector.extract %{{.*}}[0, 0, 0, 0, 0, 0]
+ // CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[ELEM0]], %{{.*}}
+ // CHECK: arith.select %[[CMP]],
+ // Thread reduction with ballot-based approach (4 threads on dim 1).
+ // CHECK: %[[REDUCED:.*]] = gpu.subgroup_reduce maxnumf %[[LOCAL_VAL:[a-z0-9]+]]
+ // CHECK: %[[IS_WINNER:.*]] = arith.cmpf oeq, %[[LOCAL_VAL]], %[[REDUCED]] : f32
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[IS_WINNER]]
+ // CHECK: %[[MASKED:.*]] = arith.andi %[[BALLOT]],
+ // CHECK: %[[WINNER_LANE_I64:.*]] = math.cttz %[[MASKED]]
+ // CHECK: %[[WINNER_LANE:.*]] = arith.trunci %[[WINNER_LANE_I64]]
+ // CHECK: gpu.shuffle idx {{.*}}, %[[WINNER_LANE]],
+ // Shared memory operations for subgroup reduction.
+ // CHECK: %[[VAL_ALLOC:.*]] = memref.alloc
+ // CHECK: %[[IDX_ALLOC:.*]] = memref.alloc
+ // CHECK: gpu.barrier
+ // CHECK: vector.transfer_write {{.*}}, %[[VAL_ALLOC]]
+ // CHECK: vector.transfer_write {{.*}}, %[[IDX_ALLOC]]
+ // CHECK: gpu.barrier
+ // Padding uses init values so out-of-bounds lanes don't affect the result.
+ // CHECK: %[[VAL_PAD:.*]] = vector.extract %[[INIT_VAL]][0]
+ // CHECK: %[[IDX_PAD:.*]] = vector.extract %[[INIT_IDX]][0]
+ // CHECK: vector.transfer_read %[[VAL_ALLOC]]{{.*}}, %[[VAL_PAD]]
+ // CHECK: vector.transfer_read %[[IDX_ALLOC]]{{.*}}, %[[IDX_PAD]]
+ // Second reduction across subgroups uses ballot-based approach.
+ // CHECK: gpu.subgroup_reduce maxnumf
+ // CHECK: gpu.shuffle idx
+ %result:2 = iree_vector_ext.arg_compare dimension(1)
+ ins(%input_layout, %input_idx_layout : vector<8x32xf32>, vector<8x32xi32>)
+ inits(%init_val_layout, %init_idx_layout : vector<8xf32>, vector<8xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf ogt, %lhs, %rhs : f32
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<8xf32>, vector<8xi32>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_1d_subgroup) : vector<8xf32>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_1d_subgroup) : vector<8xi32>
+
+ func.return %result_val_layout, %result_idx_layout : vector<8xf32>, vector<8xi32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Test: Full reduction to 0-d (scalar) result with thread reduction.
+// This tests the special handling for 0-d vectors where:
+// 1. Init values are unwrapped from ToSIMD (identity op for 0-d)
+// 2. reshapeFlatToTarget uses extract+broadcast instead of shape_cast
+
+#layout_1d_full_reduce = #iree_vector_ext.nested_layout<
+ subgroup_tile = [1],
+ batch_tile = [1],
+ outer_tile = [1],
+ thread_tile = [4],
+ element_tile = [8],
+
+ subgroup_strides = [0],
+ thread_strides = [1]
+>
+
+#layout_0d = #iree_vector_ext.nested_layout<
+ subgroup_tile = [],
+ batch_tile = [],
+ outer_tile = [],
+ thread_tile = [],
+ element_tile = [],
+
+ subgroup_strides = [],
+ thread_strides = []
+>
+
+// CHECK-LABEL: func @argmax_full_reduce_to_scalar
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xf32>
+// CHECK-SAME: %[[INPUT_IDX:.*]]: vector<32xi32>
+// CHECK-SAME: %[[INIT_VAL:.*]]: vector<f32>
+// CHECK-SAME: %[[INIT_IDX:.*]]: vector<i32>
+func.func @argmax_full_reduce_to_scalar(
+ %input: vector<32xf32>,
+ %input_idx: vector<32xi32>,
+ %init_val: vector<f32>,
+ %init_idx: vector<i32>) -> (vector<f32>, vector<i32>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_1d_full_reduce) : vector<32xf32>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_1d_full_reduce) : vector<32xi32>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_0d) : vector<f32>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_0d) : vector<i32>
+
+ // Full reduction to scalar. Init values are 0-d vectors (scalars).
+ // The pattern should:
+ // 1. Unwrap ToSIMD for 0-d init values (identity op)
+ // 2. Use extract+broadcast instead of shape_cast for 0-d reshape
+ // Distributed shape: vector<1x1x8xf32> (batch=1, outer=1, element=8)
+ // CHECK: %[[ELEM0:.*]] = vector.extract %{{.*}}[0, 0, 0] : f32 from vector<1x1x8xf32>
+ // CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[ELEM0]], %{{.*}}
+ // CHECK: arith.select %[[CMP]],
+ // Thread reduction with ballot-based approach (4 threads).
+ // CHECK: %[[REDUCED:.*]] = gpu.subgroup_reduce maxnumf %[[LOCAL_VAL:[a-z0-9]+]]
+ // CHECK: %[[IS_WINNER:.*]] = arith.cmpf oeq, %[[LOCAL_VAL]], %[[REDUCED]] : f32
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[IS_WINNER]]
+ // CHECK: %[[WINNER_LANE_I64:.*]] = math.cttz %[[BALLOT]]
+ // CHECK: %[[WINNER_LANE:.*]] = arith.trunci %[[WINNER_LANE_I64]]
+ // CHECK: gpu.shuffle idx {{.*}}, %[[WINNER_LANE]],
+ // For 0-d result, use extract+broadcast instead of shape_cast.
+ // CHECK: vector.broadcast {{.*}} : f32 to vector<f32>
+ %result:2 = iree_vector_ext.arg_compare dimension(0)
+ ins(%input_layout, %input_idx_layout : vector<32xf32>, vector<32xi32>)
+ inits(%init_val_layout, %init_idx_layout : vector<f32>, vector<i32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf ogt, %lhs, %rhs : f32
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<f32>, vector<i32>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_0d) : vector<f32>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_0d) : vector<i32>
+
+ func.return %result_val_layout, %result_idx_layout : vector<f32>, vector<i32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Test: Full reduction to 0-d with subgroup reduction.
+// Exercises the combination of 0-d result + subgroup_tile > 1, which requires
+// shared memory exchange between subgroups before producing a scalar output.
+
+#layout_1d_full_reduce_subgroup = #iree_vector_ext.nested_layout<
+ subgroup_tile = [2],
+ batch_tile = [1],
+ outer_tile = [1],
+ thread_tile = [4],
+ element_tile = [8],
+
+ subgroup_strides = [1],
+ thread_strides = [1]
+>
+
+#layout_0d_subgroup = #iree_vector_ext.nested_layout<
+ subgroup_tile = [],
+ batch_tile = [],
+ outer_tile = [],
+ thread_tile = [],
+ element_tile = [],
+
+ subgroup_strides = [],
+ thread_strides = []
+>
+
+// CHECK-LABEL: func @argmax_full_reduce_with_subgroup
+// CHECK-SAME: %[[INPUT:.*]]: vector<64xf16>
+// CHECK-SAME: %[[INPUT_IDX:.*]]: vector<64xi32>
+// CHECK-SAME: %[[INIT_VAL:.*]]: vector<f16>
+// CHECK-SAME: %[[INIT_IDX:.*]]: vector<i32>
+func.func @argmax_full_reduce_with_subgroup(
+ %input: vector<64xf16>,
+ %input_idx: vector<64xi32>,
+ %init_val: vector<f16>,
+ %init_idx: vector<i32>) -> (vector<f16>, vector<i32>) {
+
+ %input_layout = iree_vector_ext.to_layout %input to layout(#layout_1d_full_reduce_subgroup) : vector<64xf16>
+ %input_idx_layout = iree_vector_ext.to_layout %input_idx to layout(#layout_1d_full_reduce_subgroup) : vector<64xi32>
+ %init_val_layout = iree_vector_ext.to_layout %init_val to layout(#layout_0d_subgroup) : vector<f16>
+ %init_idx_layout = iree_vector_ext.to_layout %init_idx to layout(#layout_0d_subgroup) : vector<i32>
+
+ // Init value scalars are extracted once and reused for both local reduction
+ // and as padding for the shared memory transfer_read.
+ // CHECK: %[[VAL_PAD:.*]] = vector.extract %[[INIT_VAL]][]
+ // CHECK: %[[IDX_PAD:.*]] = vector.extract %[[INIT_IDX]][]
+ // CHECK: %[[ELEM0:.*]] = vector.extract %{{.*}}[0, 0, 0] : f16 from vector<1x1x8xf16>
+ // CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[ELEM0]], %{{.*}}
+ // CHECK: arith.select %[[CMP]],
+ // CHECK: %[[REDUCED:.*]] = gpu.subgroup_reduce maxnumf %[[LOCAL_VAL:[a-z0-9]+]]
+ // CHECK: %[[IS_WINNER:.*]] = arith.cmpf oeq, %[[LOCAL_VAL]], %[[REDUCED]] : f16
+ // CHECK: %[[BALLOT:.*]] = gpu.ballot %[[IS_WINNER]]
+ // CHECK: gpu.shuffle idx {{.*}} : i32
+ // CHECK: %[[VAL_ALLOC:.*]] = memref.alloc
+ // CHECK: %[[IDX_ALLOC:.*]] = memref.alloc
+ // CHECK: gpu.barrier
+ // CHECK: vector.transfer_write {{.*}}, %[[VAL_ALLOC]]
+ // CHECK: vector.transfer_write {{.*}}, %[[IDX_ALLOC]]
+ // CHECK: gpu.barrier
+ // CHECK: vector.transfer_read %[[VAL_ALLOC]]{{.*}}, %[[VAL_PAD]]
+ // CHECK: vector.transfer_read %[[IDX_ALLOC]]{{.*}}, %[[IDX_PAD]]
+ // CHECK: gpu.subgroup_reduce maxnumf
+ // CHECK: gpu.shuffle idx
+ // CHECK: vector.broadcast {{.*}} : f16 to vector<f16>
+ %result:2 = iree_vector_ext.arg_compare dimension(0)
+ ins(%input_layout, %input_idx_layout : vector<64xf16>, vector<64xi32>)
+ inits(%init_val_layout, %init_idx_layout : vector<f16>, vector<i32>) {
+ ^bb0(%lhs: f16, %rhs: f16):
+ %cmp = arith.cmpf ogt, %lhs, %rhs : f16
+ iree_vector_ext.yield %cmp : i1
+ } -> vector<f16>, vector<i32>
+
+ %result_val_layout = iree_vector_ext.to_layout %result#0 to layout(#layout_0d_subgroup) : vector<f16>
+ %result_idx_layout = iree_vector_ext.to_layout %result#1 to layout(#layout_0d_subgroup) : vector<i32>
+
+ func.return %result_val_layout, %result_idx_layout : vector<f16>, vector<i32>
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+ transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+ %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+ transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+ transform.yield
+ }
+}