[LLVMGPUVectorDistribute] Refactor vector.contract distribute (#19631)
Currently, vector.contract distribution is implemented as a standalone
distribution closely following vector.multi_reduce. Therefore, we have
to duplicate code/effort when we improve either one.
This commit changes vector.contract just to distribute the "contract"
part of it. Then it creates a new vector.multi_reduce to be
re-distributed with partial reduction semantics. Thus, allowing the
improvements of vector.multi_reduce to be re-used by vector.contract.
Closes: https://github.com/iree-org/iree/issues/19620
---------
Signed-off-by: Manupa Karunaratne <manupa.karunaratne@amd.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 49981f1..17b9a5b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -442,13 +442,6 @@
}
Type elemTy = srcVector.getType().getElementType();
- unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
- if (elemBitwidth != maxBitsPerShuffle) {
- return rewriter.notifyMatchFailure(
- multiReduceOp, llvm::formatv("unimplemented: packed shuffle",
- elemBitwidth, maxBitsPerShuffle));
- }
-
VectorValue disSrc =
getDistributed(rewriter, srcVector, signature[srcVector]);
@@ -770,24 +763,18 @@
int64_t maxBitsPerShuffle;
};
-/// The lowering for Contract is performed in three steps (similar to above
-/// multi_reduction):
-/// 1. Local Contract: Each thread performs operations on its locally
-/// distributed elements.
-/// 2. Subgroup Reduction: Threads in each subgroup reduce the results from
-/// step 1 across threads using a subgroup reduction if distribution occurs
-/// along the reduction dimension.
-/// 3. Accumulator Reduction: Each thread combines its intermediate results
-/// with its held accumulator.
-///
-/// Currently, reduction across multiple warps is not supported.
+/// The distribution of contract is performed by doing a local contraction where
+/// each thread performs operations on its locally distributed elements. Then,
+/// the resulting vector is interpreted in undistributed domain. The said
+/// undistributed vector is a partial reduction when contraction has been
+/// performed only thread locally. Therefore, a to-be-distributed
+/// vector.multi_reduce
+////is added to complete the contraction.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;
- DistributeContract(MLIRContext *context, int64_t subgroupSize,
- int64_t maxBitsPerShuffle, int64_t benefit = 1)
- : OpDistributionPattern(context, benefit), subgroupSize(subgroupSize),
- maxBitsPerShuffle(maxBitsPerShuffle) {}
+ DistributeContract(MLIRContext *context, int64_t benefit = 1)
+ : OpDistributionPattern(context, benefit) {}
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
DistributionSignature &signature,
@@ -817,6 +804,16 @@
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction rhs");
}
+ NestedLayoutAttr resLayout;
+ if (auto contractRes = dyn_cast<VectorValue>(contractOp.getResult())) {
+ resLayout = dyn_cast<NestedLayoutAttr>(signature[contractRes]);
+ } else {
+    // Create a zero-d layout because we are going to
+    // add the reduction dims back to handle the
+    // partial reduction.
+ resLayout = NestedLayoutAttr::get(
+ contractOp.getContext(), ArrayRef<int64_t>{}, {}, {}, {}, {}, {}, {});
+ }
Value disLhs = getDistributed(rewriter, contractOp.getLhs(), lhsLayout);
Value disRhs = getDistributed(rewriter, contractOp.getRhs(), rhsLayout);
@@ -838,21 +835,10 @@
Location loc = contractOp.getLoc();
// Step 1: local contraction
+ Value localInit = getCombiningIdentityValue(
+ loc, rewriter, contractOp.getKind(), disAcc.getType());
vector::ContractionOp localContractOp = doDistributedContraction(
- rewriter, loc, ctx, contractOp, disLhs, disRhs, disAcc);
-
- int64_t rank = lhsLayout.getRank();
- SmallVector<bool> reducedDims(rank, false);
-
- // Identify the reduction dimension and apply it for subgroup reduction.
- for (auto [index, iteratorType] :
- llvm::enumerate(contractOp.getIteratorTypes())) {
- if (vector::isReductionIterator(iteratorType)) {
- auto map = contractOp.getIndexingMapsArray()[0];
- int64_t redIdx = *(map.getResultPosition(getAffineDimExpr(index, ctx)));
- reducedDims[redIdx] = true;
- }
- }
+ rewriter, loc, ctx, contractOp, disLhs, disRhs, localInit);
VectorValue localContractValue;
if (accVector) {
@@ -865,46 +851,79 @@
assert(localContractValue && "result should have been a vector");
- // Flatten the locally result value.
- VectorType shaped = localContractValue.getType();
- int64_t numElements = shaped.getNumElements();
- SmallVector<int64_t> flatShape(1, numElements);
- VectorType flatVecType = VectorType::get(flatShape, accElemTy);
- VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
- localContractValue);
-
- // Step 2: Do subgroup reduction.
- FailureOr<VectorValue> threadReduced = doThreadReduction(
- rewriter, lhsLayout, flat, contractOp.getKind(), reducedDims);
- if (failed(threadReduced)) {
- return failure();
+    // Identify the reduction dims and collect their subgroup/thread tiles.
+ auto lhsMap = contractOp.getIndexingMapsArray()[0];
+ SmallVector<int64_t> reductionSubGroupTile;
+ SmallVector<int64_t> reductionSubGroupStrides;
+ SmallVector<int64_t> reductionThreadTile;
+ SmallVector<int64_t> reductionThreadStrides;
+ SmallVector<int64_t> partialReductionDims;
+ for (auto [index, iteratorType] :
+ llvm::enumerate(contractOp.getIteratorTypes())) {
+ if (vector::isReductionIterator(iteratorType)) {
+ int64_t redLhsIdx =
+ *(lhsMap.getResultPosition(getAffineDimExpr(index, ctx)));
+ partialReductionDims.push_back(resLayout.getRank() +
+ reductionSubGroupTile.size());
+ reductionSubGroupTile.push_back(lhsLayout.getSubgroupTile()[redLhsIdx]);
+ reductionSubGroupStrides.push_back(
+ lhsLayout.getSubgroupStrides()[redLhsIdx]);
+ reductionThreadTile.push_back(lhsLayout.getThreadTile()[redLhsIdx]);
+ reductionThreadStrides.push_back(
+ lhsLayout.getThreadStrides()[redLhsIdx]);
+ }
}
+ SmallVector<int64_t> unitBroadcastTile(reductionThreadTile.size(), 1);
- // Do reduction against accumulator, which needs to be done after thread
- // reduction.
- VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
- loc, shaped, threadReduced.value());
+    // Manually infer the layout of the partial reduction.
+    // We do this by appending the reduction dims on the
+    // subgroup and thread tiles to the layout of the
+    // result.
+ IREE::VectorExt::NestedLayoutAttr reductionLayout =
+ IREE::VectorExt::NestedLayoutAttr::get(
+ contractOp.getContext(),
+ /*source=*/resLayout,
+ /*appendSubGroupLens=*/reductionSubGroupTile,
+ /*appendBatchLens=*/unitBroadcastTile,
+ /*appendOuterLens=*/unitBroadcastTile,
+ /*appendThreadLens=*/reductionThreadTile,
+ /*appendElementLens=*/unitBroadcastTile,
+ /*appendSubgroupStrides=*/reductionSubGroupStrides,
+ /*appendThreadStrides=*/reductionThreadStrides);
- if (!accVector) {
- disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
+ VectorType partialReducedDistributedType =
+ VectorType::get(reductionLayout.getDistributedShape(),
+ localContractValue.getType().getElementType());
+ Value shapeCasted = rewriter.create<vector::ShapeCastOp>(
+ loc, partialReducedDistributedType, localContractValue);
+ VectorType unDistributedType =
+ VectorType::get(reductionLayout.getUndistributedShape(),
+ localContractValue.getType().getElementType());
+ Value undistrLocalReduced = rewriter.create<IREE::VectorExt::ToSIMDOp>(
+ loc, unDistributedType, shapeCasted);
+
+ // Create the partial reduction
+ auto partialReduction = rewriter.create<vector::MultiDimReductionOp>(
+ loc, contractOp.getKind(), undistrLocalReduced, acc,
+ partialReductionDims);
+ {
+ auto unitAttr = UnitAttr::get(rewriter.getContext());
+ auto reduceAttrs =
+ SmallVector<Attribute>(partialReduction->getNumOperands(), unitAttr);
+ reduceAttrs[0] = reductionLayout;
+ ArrayAttr reduceResultsAttr =
+ ArrayAttr::get(rewriter.getContext(), {unitAttr});
+ if (auto dstLayout =
+ dyn_cast_or_null<NestedLayoutAttr>(signature[resVector])) {
+ reduceAttrs[1] = dstLayout;
+ reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
+ }
+ ArrayAttr reduceOperandsAttr =
+ ArrayAttr::get(rewriter.getContext(), reduceAttrs);
+ setSignatureForRedistribution(rewriter, partialReduction.getOperation(),
+ reduceOperandsAttr, reduceResultsAttr);
}
-
- // Step 3: Accumulator Reduction
- Value accReduction = vector::makeArithReduction(
- rewriter, loc, contractOp.getKind(), unflattened, disAcc);
- auto accReduced = dyn_cast<VectorValue>(accReduction);
- if (!accReduced) {
- return failure();
- }
-
- if (resVector) {
- replaceOpWithDistributedValues(rewriter, contractOp, accReduced);
- } else {
- Value accReducedVal = rewriter.create<vector::ExtractOp>(
- loc, accReduction, SmallVector<int64_t>{0});
- replaceOpWithDistributedValues(rewriter, contractOp, accReducedVal);
- }
-
+ rewriter.replaceOp(contractOp, partialReduction);
return success();
}
@@ -954,46 +973,6 @@
return localContractOp;
}
-
- FailureOr<VectorValue> doThreadReduction(RewriterBase &rewriter,
- NestedLayoutAttr layout,
- VectorValue flat,
- vector::CombiningKind kind,
- ArrayRef<bool> reductionMask) const {
- VectorType flatVecType = flat.getType();
- int64_t numElements = flatVecType.getNumElements();
- Location loc = flat.getLoc();
-
- auto constOp = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(flatVecType));
- auto res = llvm::cast<VectorValue>(constOp.getResult());
-
- for (unsigned i = 0; i < numElements; ++i) {
- Value extracted = rewriter.create<vector::ExtractOp>(loc, flat, i);
-
- // Reduce across all reduction dimensions 1-by-1.
- for (unsigned i = 0, e = reductionMask.size(); i != e; ++i) {
- if (reductionMask[i]) {
- int64_t offset = getShuffleOffset(layout, i);
- int64_t width = getShuffleWidth(layout, i);
- assert(offset <= std::numeric_limits<uint32_t>::max() &&
- width <= std::numeric_limits<uint32_t>::max());
-
- extracted = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, combiningKindToAllReduce(kind),
- /*uniform=*/false, /*cluster_size=*/width,
- /*cluster_stride=*/offset);
- }
- }
-
- res = rewriter.create<vector::InsertOp>(loc, extracted, res, i);
- }
-
- return res;
- }
-
- int64_t subgroupSize;
- int64_t maxBitsPerShuffle;
};
struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
@@ -1344,8 +1323,7 @@
patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
maxBitsPerShuffle);
- patterns.add<DistributeContract>(patterns.getContext(), subgroupSize,
- maxBitsPerShuffle);
+ patterns.add<DistributeContract>(patterns.getContext());
patterns.add<DistributeBatchOuterToLayoutConversions>(patterns.getContext());
patterns.add<DistributeStep>(patterns.getContext(), threadId, subgroupSize);
}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index d4062bb..e1bc3cc 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -112,6 +112,19 @@
return shape;
}
+SmallVector<int64_t> NestedLayoutAttr::getUndistributedShape() const {
+ int64_t rank = getRank();
+ SmallVector<int64_t> shape;
+ shape.reserve(rank);
+ for (int64_t i : llvm::seq<int64_t>(rank)) {
+ int64_t expectedDimLen = getSubgroupTile()[i] * getBatchTile()[i] *
+ getOuterTile()[i] * getThreadTile()[i] *
+ getElementTile()[i];
+ shape.push_back(expectedDimLen);
+ }
+ return shape;
+}
+
// Gets the rank of the undistributed vector for this layout.
int64_t NestedLayoutAttr::getRank() const {
// The layout requires that all size lists are the same length and match
@@ -198,6 +211,42 @@
normalizedThreadStrides);
}
+static SmallVector<int64_t> appendDims(ArrayRef<int64_t> tileLens,
+ ArrayRef<int64_t> appendLens) {
+ SmallVector<int64_t> tileLensResult = llvm::to_vector(tileLens);
+ tileLensResult.insert(tileLensResult.end(), appendLens.begin(),
+ appendLens.end());
+ return tileLensResult;
+}
+
+NestedLayoutAttr NestedLayoutAttr::get(MLIRContext *context,
+ NestedLayoutAttr source,
+ ArrayRef<int64_t> appendSubGroupLens,
+ ArrayRef<int64_t> appendBatchLens,
+ ArrayRef<int64_t> appendOuterLens,
+ ArrayRef<int64_t> appendThreadLens,
+ ArrayRef<int64_t> appendElementLens,
+ ArrayRef<int64_t> appendSubgroupStrides,
+ ArrayRef<int64_t> appendThreadStrides) {
+ SmallVector<int64_t> subgroupTile =
+ appendDims(source.getSubgroupTile(), appendSubGroupLens);
+ SmallVector<int64_t> batchTile =
+ appendDims(source.getBatchTile(), appendBatchLens);
+ SmallVector<int64_t> outerTile =
+ appendDims(source.getOuterTile(), appendOuterLens);
+ SmallVector<int64_t> threadTile =
+ appendDims(source.getThreadTile(), appendThreadLens);
+ SmallVector<int64_t> elementTile =
+ appendDims(source.getElementTile(), appendElementLens);
+ SmallVector<int64_t> subgroupStrides =
+ appendDims(source.getSubgroupStrides(), appendSubgroupStrides);
+ SmallVector<int64_t> threadStrides =
+ appendDims(source.getThreadStrides(), appendThreadStrides);
+ return NestedLayoutAttr::get(context, subgroupTile, batchTile, outerTile,
+ threadTile, elementTile, subgroupStrides,
+ threadStrides);
+}
+
LogicalResult NestedLayoutAttr::verify(
llvm::function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile,
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
index 01cf309..16bad7f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td
@@ -218,7 +218,15 @@
"ArrayRef<int64_t>":$threadTile,
"ArrayRef<int64_t>":$elementTile,
"ArrayRef<int64_t>":$subgroupStrides,
- "ArrayRef<int64_t>":$threadStrides)>
+ "ArrayRef<int64_t>":$threadStrides)>,
+ AttrBuilder<(ins "NestedLayoutAttr":$source,
+ "ArrayRef<int64_t>":$appendSubGroupLens,
+ "ArrayRef<int64_t>":$appendBatchLens,
+ "ArrayRef<int64_t>":$appendOuterLens,
+ "ArrayRef<int64_t>":$appendThreadLens,
+ "ArrayRef<int64_t>":$appendElementLens,
+ "ArrayRef<int64_t>":$appendSubgroupStrides,
+ "ArrayRef<int64_t>":$appendThreadStrides)>
];
let extraClassDeclaration = [{
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td
index fdca63f..b5181b8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td
@@ -36,6 +36,12 @@
/*args=*/(ins "::llvm::ArrayRef<bool>":$droppedDims)
>,
InterfaceMethod<
+ /*description=*/"Get the expected undistributed shape for the given vector type.",
+ /*retTy=*/"SmallVector<int64_t>",
+ /*methodName=*/"getUndistributedShape",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
/*description=*/"Get the distributed shape for the given vector type.",
/*retTy=*/"SmallVector<int64_t>",
/*methodName=*/"getDistributedShape",