| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include <numeric> |
| |
| #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h" |
| #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/PatternMatch.h" |
| |
| using namespace mlir; |
| |
| namespace mlir::iree_compiler::IREE::VectorExt { |
| |
| using VectorValue = TypedValue<VectorType>; |
| |
| // Project the nested layout. This take a mask on the dimensions of the vector |
| // associated with this layout and projects out those dimensions. This reduces |
| // the rank of the layout in the process. |
| VectorLayoutInterface |
| NestedLayoutAttr::project(ArrayRef<bool> droppedDims) const { |
| assert(droppedDims.size() == getRank() && |
| "droppedDims size must match layout rank"); |
| |
| // Projection for this layout simply means the sizes along the projected |
| // are dropped. |
| SmallVector<int64_t> subgroupCount; |
| SmallVector<int64_t> batchCount; |
| SmallVector<int64_t> outerCount; |
| SmallVector<int64_t> threadCount; |
| SmallVector<int64_t> elementCount; |
| SmallVector<int64_t> subgroupStrides; |
| SmallVector<int64_t> threadStrides; |
| int64_t count = 0; |
| // Map to track pre-projection -> post-projection indices. Used to update |
| // the dimension orders. |
| llvm::DenseMap<int64_t, int64_t> indexToRankReducedIndexMap; |
| for (auto [idx, isProjected] : llvm::enumerate(droppedDims)) { |
| if (!isProjected) { |
| subgroupCount.push_back(getSubgroupTile()[idx]); |
| batchCount.push_back(getBatchTile()[idx]); |
| outerCount.push_back(getOuterTile()[idx]); |
| threadCount.push_back(getThreadTile()[idx]); |
| elementCount.push_back(getElementTile()[idx]); |
| subgroupStrides.push_back(getSubgroupStrides()[idx]); |
| threadStrides.push_back(getThreadStrides()[idx]); |
| indexToRankReducedIndexMap[idx] = count++; |
| } |
| } |
| // This layout is invalid for rank-0 vectors. |
| assert(count >= 0 && "unimplemented rank-0 vector"); |
| |
| return NestedLayoutAttr::get(getContext(), subgroupCount, batchCount, |
| outerCount, threadCount, elementCount, |
| subgroupStrides, threadStrides); |
| } |
| |
| VectorLayoutInterface |
| NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const { |
| SmallVector<int64_t> invPerm = invertPermutationVector(permutation); |
| SmallVector<int64_t> subgroupCount = |
| applyPermutation(getSubgroupTile(), permutation); |
| SmallVector<int64_t> batchCount = |
| applyPermutation(getBatchTile(), permutation); |
| SmallVector<int64_t> outerCount = |
| applyPermutation(getOuterTile(), permutation); |
| SmallVector<int64_t> threadCount = |
| applyPermutation(getThreadTile(), permutation); |
| SmallVector<int64_t> elementCount = |
| applyPermutation(getElementTile(), permutation); |
| SmallVector<int64_t> subgroupStrides = |
| applyPermutation(getSubgroupStrides(), permutation); |
| SmallVector<int64_t> threadStrides = |
| applyPermutation(getThreadStrides(), permutation); |
| return NestedLayoutAttr::get(getContext(), subgroupCount, batchCount, |
| outerCount, threadCount, elementCount, |
| subgroupStrides, threadStrides); |
| } |
| |
| /// We distribute to: |
| /// <BATCH x OUTER x ELEMENT> |
| SmallVector<int64_t> NestedLayoutAttr::getDistributedShape() const { |
| SmallVector<int64_t> shape; |
| shape.append(getBatchTile().begin(), getBatchTile().end()); |
| shape.append(getOuterTile().begin(), getOuterTile().end()); |
| shape.append(getElementTile().begin(), getElementTile().end()); |
| return shape; |
| } |
| |
| /// Before we distribute, we would like to see this as: |
| /// <SUBGROUP x BATCH x OUTER x THREAD x ELEMENT> |
| SmallVector<int64_t> NestedLayoutAttr::getUndistributedPackedShape() const { |
| SmallVector<int64_t> shape; |
| int64_t rank = getRank(); |
| shape.reserve(rank * 5); |
| shape.append(getSubgroupTile().begin(), getSubgroupTile().end()); |
| shape.append(getBatchTile().begin(), getBatchTile().end()); |
| shape.append(getOuterTile().begin(), getOuterTile().end()); |
| shape.append(getThreadTile().begin(), getThreadTile().end()); |
| shape.append(getElementTile().begin(), getElementTile().end()); |
| return shape; |
| } |
| |
| SmallVector<int64_t> NestedLayoutAttr::getUndistributedShape() const { |
| int64_t rank = getRank(); |
| SmallVector<int64_t> shape; |
| shape.reserve(rank); |
| for (int64_t i : llvm::seq<int64_t>(rank)) { |
| int64_t expectedDimLen = getSubgroupTile()[i] * getBatchTile()[i] * |
| getOuterTile()[i] * getThreadTile()[i] * |
| getElementTile()[i]; |
| shape.push_back(expectedDimLen); |
| } |
| return shape; |
| } |
| |
| // Gets the rank of the undistributed vector for this layout. |
| int64_t NestedLayoutAttr::getRank() const { |
| // The layout requires that all size lists are the same length and match |
| // the rank of the undistributed vector, so just return the length of one |
| // of the fields. |
| return getBatchTile().size(); |
| } |
| |
| LogicalResult NestedLayoutAttr::isValidLayout(ShapedType shapeTy, |
| Location loc) const { |
| int64_t rank = getRank(); |
| ArrayRef<int64_t> shape = shapeTy.getShape(); |
| if (shape.size() != rank) { |
| return emitError(loc, "Rank of vector (") |
| << shape.size() << ") does not match rank of layout (" << rank |
| << ")."; |
| } |
| // Multiply all shapes in the layout. |
| for (int i = 0, e = rank; i < e; ++i) { |
| int64_t expectedShape = getSubgroupTile()[i] * getBatchTile()[i] * |
| getOuterTile()[i] * getThreadTile()[i] * |
| getElementTile()[i]; |
| if (expectedShape != shape[i]) { |
| std::string shapeStr; |
| llvm::raw_string_ostream shapeOs(shapeStr); |
| llvm::interleaveComma(shape, shapeOs); |
| std::string layoutStr; |
| llvm::raw_string_ostream layoutOs(layoutStr); |
| printStripped(layoutOs); |
| return emitError(loc, "Vector shape: [") |
| << shapeStr << "] does not match the layout (" |
| << layoutStr + ") at dim " << i |
| << ". Dimension expected by layout: " << expectedShape |
| << " actual: " << shape[i]; |
| } |
| } |
| return success(); |
| } |
| NestedLayoutAttr NestedLayoutAttr::getChecked( |
| llvm::function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, |
| ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile, |
| ArrayRef<int64_t> outerTile, ArrayRef<int64_t> threadTile, |
| ArrayRef<int64_t> elementTile, ArrayRef<int64_t> subgroupStrides, |
| ArrayRef<int64_t> threadStrides) { |
| if (failed(NestedLayoutAttr::verify(emitError, subgroupTile, batchTile, |
| outerTile, threadTile, elementTile, |
| subgroupStrides, threadStrides))) { |
| return NestedLayoutAttr(); |
| } |
| |
| return NestedLayoutAttr::get(context, subgroupTile, batchTile, outerTile, |
| threadTile, elementTile, subgroupStrides, |
| threadStrides); |
| } |
| |
| NestedLayoutAttr NestedLayoutAttr::get( |
| MLIRContext *context, ArrayRef<int64_t> subgroupTile, |
| ArrayRef<int64_t> batchTile, ArrayRef<int64_t> outerTile, |
| ArrayRef<int64_t> threadTile, ArrayRef<int64_t> elementTile, |
| ArrayRef<int64_t> subgroupStrides, ArrayRef<int64_t> threadStrides) { |
| |
| SmallVector<int64_t> normalizedSubgroupStrides(subgroupStrides); |
| SmallVector<int64_t> normalizedThreadStrides(threadStrides); |
| |
| // Dimension of size 1 only have one element to distribute, so stride can be |
| // anything. We normalize the stride to be 0, to have consistency. |
| |
| for (auto [stride, size] : |
| llvm::zip_equal(normalizedSubgroupStrides, subgroupTile)) { |
| if (size == 1) { |
| stride = 0; |
| } |
| } |
| |
| for (auto [stride, size] : |
| llvm::zip_equal(normalizedThreadStrides, threadTile)) { |
| if (size == 1) { |
| stride = 0; |
| } |
| } |
| |
| return Base::get(context, subgroupTile, batchTile, outerTile, threadTile, |
| elementTile, normalizedSubgroupStrides, |
| normalizedThreadStrides); |
| } |
| |
| static SmallVector<int64_t> appendDims(ArrayRef<int64_t> tileLens, |
| ArrayRef<int64_t> appendLens) { |
| SmallVector<int64_t> tileLensResult = llvm::to_vector(tileLens); |
| tileLensResult.insert(tileLensResult.end(), appendLens.begin(), |
| appendLens.end()); |
| return tileLensResult; |
| } |
| |
| NestedLayoutAttr NestedLayoutAttr::get(MLIRContext *context, |
| NestedLayoutAttr source, |
| ArrayRef<int64_t> appendSubGroupLens, |
| ArrayRef<int64_t> appendBatchLens, |
| ArrayRef<int64_t> appendOuterLens, |
| ArrayRef<int64_t> appendThreadLens, |
| ArrayRef<int64_t> appendElementLens, |
| ArrayRef<int64_t> appendSubgroupStrides, |
| ArrayRef<int64_t> appendThreadStrides) { |
| SmallVector<int64_t> subgroupTile = |
| appendDims(source.getSubgroupTile(), appendSubGroupLens); |
| SmallVector<int64_t> batchTile = |
| appendDims(source.getBatchTile(), appendBatchLens); |
| SmallVector<int64_t> outerTile = |
| appendDims(source.getOuterTile(), appendOuterLens); |
| SmallVector<int64_t> threadTile = |
| appendDims(source.getThreadTile(), appendThreadLens); |
| SmallVector<int64_t> elementTile = |
| appendDims(source.getElementTile(), appendElementLens); |
| SmallVector<int64_t> subgroupStrides = |
| appendDims(source.getSubgroupStrides(), appendSubgroupStrides); |
| SmallVector<int64_t> threadStrides = |
| appendDims(source.getThreadStrides(), appendThreadStrides); |
| return NestedLayoutAttr::get(context, subgroupTile, batchTile, outerTile, |
| threadTile, elementTile, subgroupStrides, |
| threadStrides); |
| } |
| |
| LogicalResult NestedLayoutAttr::verify( |
| llvm::function_ref<InFlightDiagnostic()> emitError, |
| ArrayRef<int64_t> subgroupTile, ArrayRef<int64_t> batchTile, |
| ArrayRef<int64_t> outerTile, ArrayRef<int64_t> threadTile, |
| ArrayRef<int64_t> elementTile, ArrayRef<int64_t> subgroupStrides, |
| ArrayRef<int64_t> threadStrides) { |
| |
| size_t rank = subgroupTile.size(); |
| auto checkTile = [&](ArrayRef<int64_t> tileShape) { |
| if (tileShape.size() != rank) { |
| emitError() << "all fields must have the same rank as the layout"; |
| return failure(); |
| } |
| return success(); |
| }; |
| |
| if (failed(checkTile(subgroupTile)) || failed(checkTile(batchTile)) || |
| failed(checkTile(outerTile)) || failed(checkTile(threadTile)) || |
| failed(checkTile(elementTile)) || failed(checkTile(subgroupStrides)) || |
| failed(checkTile(threadStrides))) { |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Given a single flat thread ID, compute the indices of the distributed |
| /// dimensions (subgroup and thread ids). The only difference between subgroup |
| /// and thread dimensions is the order in which they are "divided out" of the |
| /// underlying vector (i.e. vector_shape /= subgroups -> batches -> outers -> |
| /// threads -> elements). There is no requirement that a subgroup id only |
| /// spans subgroups. |
| SmallVector<Value> |
| NestedLayoutAttr::computeThreadIds(Value threadId, int64_t subgroupSize, |
| RewriterBase &rewriter) const { |
| SmallVector<Value> virtualTids; |
| |
| Location loc = threadId.getLoc(); |
| |
| AffineExpr tidExpr, size, stride; |
| bindDims(rewriter.getContext(), tidExpr); |
| bindSymbols(rewriter.getContext(), size, stride); |
| |
| // (tid floordiv stride) mod size |
| AffineMap threadTidMap = |
| AffineMap::get(/*dims=*/1, /*syms=*/2, tidExpr.floorDiv(stride) % size); |
| |
| // (tid floordiv (stride * subgroup_size)) mod size |
| AffineMap subgroupTidMap = AffineMap::get( |
| /*dims=*/1, /*syms=*/2, tidExpr.floorDiv(stride * subgroupSize) % size); |
| |
| for (auto [dimSize, dimStride] : |
| llvm::zip_equal(getSubgroupTile(), getSubgroupStrides())) { |
| // Dimension is not distributed. |
| if (dimStride == 0) { |
| virtualTids.push_back(rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIndexAttr(dimStride))); |
| continue; |
| } |
| |
| auto sizeVal = |
| rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dimSize)); |
| auto strideVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIndexAttr(dimStride)); |
| virtualTids.push_back(rewriter.create<affine::AffineApplyOp>( |
| loc, subgroupTidMap, ValueRange{threadId, sizeVal, strideVal})); |
| } |
| |
| for (auto [dimSize, dimStride] : |
| llvm::zip_equal(getThreadTile(), getThreadStrides())) { |
| // Dimension is not distributed. |
| if (dimStride == 0) { |
| virtualTids.push_back(rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIndexAttr(dimStride))); |
| continue; |
| } |
| |
| auto sizeVal = |
| rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dimSize)); |
| auto strideVal = rewriter.create<arith::ConstantOp>( |
| loc, rewriter.getIndexAttr(dimStride)); |
| virtualTids.push_back(rewriter.create<affine::AffineApplyOp>( |
| loc, threadTidMap, ValueRange{threadId, sizeVal, strideVal})); |
| } |
| |
| return virtualTids; |
| } |
| |
| } // namespace mlir::iree_compiler::IREE::VectorExt |
| |
| using namespace mlir::iree_compiler::IREE::VectorExt; |
| |
| #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtEnums.cpp.inc" // IWYU pragma: keep |
| |
| #define GET_ATTRDEF_CLASSES |
| #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc" // IWYU pragma: keep |
| |
// Registers all VectorExt attributes (generated by tablegen into
// VectorExtAttrs.cpp.inc via GET_ATTRDEF_LIST) with the dialect.
void IREEVectorExtDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp.inc"
      >();
}