NFC: Add VectorLayoutInterface method for getting the layout rank (#17071)
The rank of a layout, especially the NestedLayout, is commonly queried
when manipulating the layout. An interface method makes this slightly
easier and more uniform than just checking the size of one of the nested
layout fields.
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
index ced5cfe..c2a2233 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp
@@ -40,7 +40,7 @@
return rewriter.notifyMatchFailure(
contractOp, "missing nested layout for contraction result");
}
- int64_t rank = resultLayout.getBatchOrder().size();
+ int64_t rank = resultLayout.getRank();
NestedLayoutAttr lhsLayout =
dyn_cast<NestedLayoutAttr>(signature[contractOp.getLhs()]);
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index fba911a..20a3d54 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -79,7 +79,7 @@
return constExpr.getValue() == 0;
return false;
};
- int64_t rank = vectorLayout.getBatchOrder().size();
+ int64_t rank = vectorLayout.getRank();
// Permute the batch and outer vector offsets to match the order of
// the vector dimensions using the inverse of the batch/offset order.
SmallVector<int64_t> batchOffsets =
@@ -113,7 +113,7 @@
}
static SmallVector<int64_t> getLoopOrder(NestedLayoutAttr vectorLayout) {
- int64_t rank = vectorLayout.getBatchOrder().size();
+ int64_t rank = vectorLayout.getRank();
// Let the unroll order first unroll the batch dimensions, then the
// outer vector dimensions. We unroll in the order specified by the
// layout.
@@ -137,7 +137,7 @@
static SmallVector<int64_t>
getElementVectorTileShape(NestedLayoutAttr vectorLayout) {
- int64_t rank = vectorLayout.getBatchOrder().size();
+ int64_t rank = vectorLayout.getRank();
SmallVector<int64_t> tileShape = vectorLayout.getDistributedShape();
// We tile to a vector with BATCH, OUTER, and ELEMENT dimensions. So to access
// the subvector only containing elements, we need indices in all BATCH and
@@ -217,7 +217,7 @@
SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
- int64_t rank = vectorLayout.getBatchOrder().size();
+ int64_t rank = vectorLayout.getRank();
Type elementType = readOp.getSource().getType().getElementType();
auto vectorType = VectorType::get(distShape, elementType);
@@ -304,7 +304,7 @@
SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
SmallVector<int64_t> loopOrder = getLoopOrder(vectorLayout);
- int64_t rank = vectorLayout.getBatchOrder().size();
+ int64_t rank = vectorLayout.getRank();
SmallVector<Value> warpIndices, threadIndices;
populateWarpAndThreadIndices(rewriter, threadId, vectorLayout, warpIndices,
@@ -388,8 +388,8 @@
Value accumulator = rewriter.create<arith::ConstantOp>(
loc, vectorType, rewriter.getZeroAttr(vectorType));
- int64_t rank = vectorLayout.getBatchOrder().size();
- int64_t sourceRank = sourceLayout.getBatchOrder().size();
+ int64_t rank = vectorLayout.getRank();
+ int64_t sourceRank = sourceLayout.getRank();
// We unroll along both the batch and outer dimensions for a similar reason
// to the transfer ops. `vector.broadcast` can only broadcast along outer
// dims, so mixing broadcasted and un-broadcasted element/outer dims can't
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
index 065cf4d..71d4481 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/VectorExt/IR/VectorExtInterfaces.td
@@ -41,6 +41,12 @@
/*retTy=*/"SmallVector<int64_t>",
/*methodName=*/"getDistributedShape",
/*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*description=*/"Get the rank of the undistributed vector for this layout.",
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getRank",
+ /*args=*/(ins)
>
];
}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
index d2a6dfb..6e59a85 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/VectorExt/IR/VectorExtAttrs.cpp
@@ -54,11 +54,11 @@
// specified, then return an empty vector.
LogicalResult LayoutAttr::isValidLayout(VectorValue vector) const {
ArrayRef<int64_t> shape = vector.getType().getShape();
- if (shape.size() != getLayouts().size()) {
- return emitError(vector.getLoc(),
- "Rank of vector (" + std::to_string(shape.size()) +
- ") does not match rank of layout (" +
- std::to_string(getLayouts().size()) + ").");
+ if (shape.size() != getRank()) {
+ return emitError(vector.getLoc(), "Rank of vector (" +
+ std::to_string(shape.size()) +
+ ") does not match rank of layout (" +
+ std::to_string(getRank()) + ").");
}
for (auto [idx, layout] : llvm::enumerate(getLayouts())) {
ArrayRef<int64_t> layoutShape = layout.getShapes();
@@ -87,11 +87,10 @@
// Project out the layout for the specified dimensions
// resulting in the layout for a lower dimensional vector.
VectorLayoutInterface LayoutAttr::project(ArrayRef<bool> droppedDims) const {
- assert(droppedDims.size() == getLayouts().size() &&
+ assert(droppedDims.size() == getRank() &&
"droppedDims size must match layout size");
ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
- assert(droppedDims.size() == layouts.size());
SmallVector<PerDimLayoutAttr> newLayouts;
for (auto pair : llvm::zip(droppedDims, layouts)) {
if (!std::get<0>(pair))
@@ -103,14 +102,13 @@
// Permute the layout according to the provided permutation
// vector. The dimensionality of the layout remains the same.
VectorLayoutInterface LayoutAttr::permute(ArrayRef<int64_t> permutation) const {
- assert(permutation.size() == getLayouts().size() &&
- "permutation size must match layout size");
+ assert(permutation.size() == getRank() &&
+ "permutation size must match layout rank");
ArrayRef<PerDimLayoutAttr> layouts = getLayouts();
- assert(permutation.size() == layouts.size());
SmallVector<PerDimLayoutAttr> newLayouts;
for (unsigned index : permutation) {
- assert(index >= 0 && index < layouts.size());
+ assert(index >= 0 && index < getRank());
newLayouts.push_back(layouts[index]);
}
return LayoutAttr::get(getContext(), newLayouts);
@@ -146,12 +144,12 @@
}
PerDimLayoutAttr LayoutAttr::getDimLayout(int64_t dim) const {
- assert(dim >= 0 && dim < getLayouts().size());
+ assert(dim >= 0 && dim < getRank());
return getLayouts()[dim];
}
std::optional<int64_t> LayoutAttr::getBatchDim(int64_t dim) {
- assert(dim < getLayouts().size());
+ assert(dim < getRank());
PerDimLayoutAttr layout = getDimLayout(dim);
for (auto [name, shape] :
llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
@@ -162,7 +160,7 @@
}
std::optional<int64_t> LayoutAttr::getLaneDim(int64_t dim) {
- assert(dim < getLayouts().size());
+ assert(dim < getRank());
PerDimLayoutAttr layout = getDimLayout(dim);
for (auto [name, shape] :
llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
@@ -173,7 +171,7 @@
}
std::optional<LayoutDimension> LayoutAttr::getLane(int64_t dim) {
- assert(dim < getLayouts().size());
+ assert(dim < getRank());
PerDimLayoutAttr layout = getDimLayout(dim);
for (auto [name, shape] :
llvm::zip_equal(layout.getLabels(), layout.getShapes())) {
@@ -183,6 +181,8 @@
return std::nullopt;
}
+int64_t LayoutAttr::getRank() const { return getLayouts().size(); }
+
std::tuple<int64_t, int64_t, int64_t> LayoutAttr::getLaneGrid() {
int64_t laneX = 1;
int64_t laneY = 1;
@@ -245,7 +245,7 @@
// the rank of the layout in the process.
VectorLayoutInterface
NestedLayoutAttr::project(ArrayRef<bool> droppedDims) const {
- assert(droppedDims.size() == getBatchesPerSubgroup().size() &&
+ assert(droppedDims.size() == getRank() &&
"droppedDims size must match layout rank");
// Projection for this layout simply means the sizes along the projected
@@ -384,16 +384,25 @@
return shape;
}
+// Gets the rank of the undistributed vector for this layout.
+int64_t NestedLayoutAttr::getRank() const {
+ // The layout requires that all size lists are the same length and match
+ // the rank of the undistributed vector, so just return the length of one
+ // of the fields.
+ return getBatchesPerSubgroup().size();
+}
+
LogicalResult NestedLayoutAttr::isValidLayout(VectorValue vector) const {
+ int64_t rank = getRank();
ArrayRef<int64_t> shape = vector.getType().getShape();
- if (shape.size() != getBatchOrder().size()) {
- return emitError(vector.getLoc(),
- "Rank of vector (" + std::to_string(shape.size()) +
- ") does not match rank of layout (" +
- std::to_string(getBatchOrder().size()) + ").");
+ if (shape.size() != rank) {
+ return emitError(vector.getLoc(), "Rank of vector (" +
+ std::to_string(shape.size()) +
+ ") does not match rank of layout (" +
+ std::to_string(rank) + ").");
}
// Multiply all shapes in the layout.
- for (int i = 0, e = shape.size(); i < e; ++i) {
+ for (int i = 0, e = rank; i < e; ++i) {
int64_t expectedShape = getSubgroupsPerWorkgroup()[i] *
getBatchesPerSubgroup()[i] *
getOutersPerBatch()[i] * getThreadsPerOuter()[i] *