| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| //===- VectorToGPUPass.cpp - Convert vector ops to GPU subgroup ops -----===// |
| // |
| // This file implements a pass to convert vector dialect operations to GPU |
| // operations distributed across a subgroup. |
| // |
| //===----------------------------------------------------------------------===// |
| #include <memory> |
| |
| #include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h" |
| #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h" |
| #include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h" |
| #include "iree/compiler/Conversion/LinalgToSPIRV/CooperativeMatrixAnalysis.h" |
| #include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/GPU/GPUDialect.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Vector/VectorOps.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace { |
| // TODO(thomasraoux): Fetch this value from device properties. |
| static const int subgroupSize = 32; |
| |
| struct ConvertVectorToGPUPass |
| : public PassWrapper<ConvertVectorToGPUPass, OperationPass<FuncOp>> { |
| void getDependentDialects(DialectRegistry &registry) const override { |
| registry.insert<AffineDialect, gpu::GPUDialect, memref::MemRefDialect, |
| scf::SCFDialect, vector::VectorDialect>(); |
| } |
| |
| void runOnOperation() override; |
| |
| private: |
| void tileAndVectorizeLinalgCopy(FuncOp funcOp, MLIRContext *context); |
| void lowerVectorOps(FuncOp funcOp, MLIRContext *context); |
| }; |
| |
| // Common base class for all vector-to-GPU conversion patterns. |
| template <typename OpTy> |
| class VectorToGPUPattern : public OpConversionPattern<OpTy> { |
| public: |
| VectorToGPUPattern( |
| MLIRContext *context, |
| const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis) |
| : OpConversionPattern<OpTy>::OpConversionPattern(context), |
| cooperativeMatrixAnalysis(cooperativeMatrixAnalysis) {} |
| |
| protected: |
| const CooperativeMatrixAnalysis &cooperativeMatrixAnalysis; |
| }; |
| |
| /// Converts unary and binary standard operations to use the new (converted) |
| /// operand types. |
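| /// A sketch of the intended rewrite (the IR below is illustrative and the |
| /// SSA names are hypothetical): |
| ///   %r = addf %a, %b : vector<32xf32> |
| /// is recreated over the already-converted operands, e.g. the scalars |
| /// produced by the transfer_read conversion below: |
| ///   %r = addf %a0, %b0 : f32 |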
| template <typename StdOp> |
| class UnaryAndBinaryOpPattern final : public VectorToGPUPattern<StdOp> { |
| public: |
| using VectorToGPUPattern<StdOp>::VectorToGPUPattern; |
| |
| LogicalResult matchAndRewrite( |
| StdOp operation, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| if (VectorToGPUPattern<StdOp>::cooperativeMatrixAnalysis |
| .usesCooperativeMatrixType(operation)) |
| return failure(); |
| Value newOp = |
| rewriter.create<StdOp>(operation.getLoc(), ValueRange(operands)); |
| rewriter.replaceOp(operation, ValueRange(newOp)); |
| return success(); |
| } |
| }; |
| |
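| // Distributes a 1-D vector.transfer_read whose vector length equals the |
| // subgroup size into one memref.load per invocation. A sketch of the |
| // rewrite, assuming subgroupSize = 32 and thread_id x as the invocation id |
| // (the IR below is illustrative, not verified output): |
| //   %v = vector.transfer_read %src[%i], %cst : memref<32xf32>, vector<32xf32> |
| // becomes, for each invocation: |
| //   %tid = gpu.thread_id x |
| //   %idx = addi %tid, %i : index |
| //   %e   = memref.load %src[%idx] : memref<32xf32> |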
| class VectorTransferReadConversion |
| : public VectorToGPUPattern<vector::TransferReadOp> { |
| public: |
| using VectorToGPUPattern<vector::TransferReadOp>::VectorToGPUPattern; |
| |
| LogicalResult matchAndRewrite( |
| vector::TransferReadOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| if (cooperativeMatrixAnalysis.usesCooperativeMatrixType(op)) |
| return failure(); |
| // Only support identity map for now. |
| if (!op.permutation_map().isIdentity()) return failure(); |
| if (op.getVectorType().getNumElements() != subgroupSize) return failure(); |
| // Only works when a workgroup contains exactly one subgroup. |
| auto wgSize = spirv::lookupLocalWorkGroupSize(op); |
| if (wgSize.getValue<int32_t>(0) != subgroupSize || |
| wgSize.getValue<int32_t>(1) != 1 || wgSize.getValue<int32_t>(2) != 1) |
| return failure(); |
| auto loc = op.getLoc(); |
| SmallVector<Value, 4> indices(op.indices()); |
| // Use threadId.x as the subgroupInvocationId. |
| // TODO(thomasraoux): Replace it once subgroup Ids are working. |
| auto threadIndex = rewriter.create<gpu::ThreadIdOp>( |
| loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); |
| Value index = rewriter.create<AddIOp>(loc, threadIndex, indices.back()); |
| indices.back() = index; |
| Value newOp = rewriter.create<memref::LoadOp>(loc, op.source(), indices); |
| rewriter.replaceOp(op, ValueRange(newOp)); |
| return success(); |
| } |
| }; |
| |
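| // Distributes a 1-D vector.transfer_write whose vector length equals the |
| // subgroup size into one memref.store per invocation (an illustrative |
| // sketch, mirroring the transfer_read conversion above): |
| //   vector.transfer_write %v, %dst[%i] : vector<32xf32>, memref<32xf32> |
| // becomes, for each invocation storing its converted scalar %e: |
| //   %tid = gpu.thread_id x |
| //   %idx = addi %tid, %i : index |
| //   memref.store %e, %dst[%idx] : memref<32xf32> |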
| class VectorTransferWriteConversion |
| : public VectorToGPUPattern<vector::TransferWriteOp> { |
| public: |
| using VectorToGPUPattern<vector::TransferWriteOp>::VectorToGPUPattern; |
| |
| LogicalResult matchAndRewrite( |
| vector::TransferWriteOp op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| if (cooperativeMatrixAnalysis.usesCooperativeMatrixType(op)) |
| return failure(); |
| if (!op.permutation_map().isIdentity()) return failure(); |
| if (op.getVectorType().getNumElements() != subgroupSize) return failure(); |
| auto loc = op.getLoc(); |
| SmallVector<Value, 4> indices(op.indices()); |
| auto threadIndex = rewriter.create<gpu::ThreadIdOp>( |
| loc, rewriter.getIndexType(), rewriter.getStringAttr("x")); |
| Value index = rewriter.create<AddIOp>(loc, threadIndex, indices.back()); |
| indices.back() = index; |
| rewriter.replaceOpWithNewOp<memref::StoreOp>(op, operands[0], operands[1], |
| indices); |
| return success(); |
| } |
| }; |
| |
| class VectorToGPUConversionTarget : public ConversionTarget { |
| public: |
| using ConversionTarget::ConversionTarget; |
| |
| protected: |
| // Standard operations are legal if they operate on scalars; operations on |
| // vectors need to be converted. |
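| // For example, 'addf %a, %b : f32' remains legal, while |
| // 'addf %a, %b : vector<32xf32>' must be rewritten. |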
| bool isDynamicallyLegal(Operation *op) const override { |
| auto isVectorType = [](Type t) { return t.isa<VectorType>(); }; |
| if (llvm::any_of(op->getResultTypes(), isVectorType) || |
| llvm::any_of(op->getOperandTypes(), isVectorType)) |
| return false; |
| return true; |
| } |
| }; |
| |
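| // Tiles linalg.copy ops carrying the copy-to-workgroup-memory marker, |
| // distributes the tiles across invocations, canonicalizes the resulting |
| // affine.min/loop IR, and vectorizes the tiled copies so they become |
| // vector.transfer_read/transfer_write ops. The marker attribute below is an |
| // illustration of what getCopyToWorkgroupMemoryMarker() tags: |
| //   linalg.copy(%src, %dst) |
| //       {__internal_linalg_transform__ = "copy_to_workgroup_memory"} |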
| void ConvertVectorToGPUPass::tileAndVectorizeLinalgCopy(FuncOp funcOp, |
| MLIRContext *context) { |
| // 1. Tile linalg and distribute it on invocations. |
| std::unique_ptr<ConversionTarget> target = |
| std::make_unique<ConversionTarget>(*context); |
| target->addDynamicallyLegalOp<linalg::CopyOp>([&](linalg::CopyOp copy) { |
| return !(hasMarker(copy, getCopyToWorkgroupMemoryMarker())); |
| }); |
| target->markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
| OwningRewritePatternList tileAndDistributePattern(context); |
| populateLinalgTileAndDistributePatterns(context, tileAndDistributePattern); |
| if (failed(applyPartialConversion(funcOp, *target, |
| std::move(tileAndDistributePattern)))) { |
| return signalPassFailure(); |
| } |
| |
| // 2. Canonicalize the IR generated by tiling. |
| OwningRewritePatternList canonicalizePatterns = |
| linalg::getLinalgTilingCanonicalizationPatterns(context); |
| canonicalizePatterns.insert<AffineMinCanonicalizationPattern, |
| linalg::AffineMinSCFCanonicalizationPattern>( |
| context); |
| (void)applyPatternsAndFoldGreedily(funcOp, std::move(canonicalizePatterns)); |
| |
| // 3. Vectorize the tiled linalg ops so they map to vector load/store ops. |
| OwningRewritePatternList vectorizationPatterns(context); |
| linalg::insertVectorizationPatterns<linalg::CopyOp>( |
| vectorizationPatterns, linalg::LinalgVectorizationOptions(), |
| linalg::LinalgTransformationFilter( |
| Identifier::get(getVectorizeMarker(), context), {})); |
| (void)applyPatternsAndFoldGreedily(funcOp, std::move(vectorizationPatterns)); |
| } |
| |
| // Convert vector transfer_read to a load if possible. This is the case only |
| // when the vector has a single element and the memref element type matches |
| // the vector element type. |
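| // Illustrative rewrite: |
| //   %v = vector.transfer_read %m[%i], %cst : memref<?xf32>, vector<1xf32> |
| // becomes: |
| //   %e = memref.load %m[%i] : memref<?xf32> |
| //   %v = vector.broadcast %e : f32 to vector<1xf32> |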
| class VectorTransferReadToLoad |
| : public OpRewritePattern<vector::TransferReadOp> { |
| public: |
| using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::TransferReadOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getVectorType().getNumElements() != 1 || |
| op.getShapedType().getElementType() != |
| op.getVectorType().getElementType()) { |
| return failure(); |
| } |
| auto loc = op.getLoc(); |
| Value newOp = |
| rewriter.create<memref::LoadOp>(loc, op.source(), op.indices()); |
| newOp = |
| rewriter.create<vector::BroadcastOp>(loc, op.getVectorType(), newOp); |
| rewriter.replaceOp(op, newOp); |
| return success(); |
| } |
| }; |
| |
| // Convert vector transfer_write to a store if possible. This is the case |
| // only when the vector has a single element and the memref element type |
| // matches the vector element type. |
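| // Illustrative rewrite: |
| //   vector.transfer_write %v, %m[%i] : vector<1xf32>, memref<?xf32> |
| // becomes: |
| //   %e = vector.extract %v[0] : vector<1xf32> |
| //   memref.store %e, %m[%i] : memref<?xf32> |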
| class VectorTransferWriteToStore |
| : public OpRewritePattern<vector::TransferWriteOp> { |
| public: |
| using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::TransferWriteOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getVectorType().getNumElements() != 1 || |
| op.getShapedType().getElementType() != |
| op.getVectorType().getElementType()) { |
| return failure(); |
| } |
| auto loc = op.getLoc(); |
| SmallVector<int64_t, 2> zero(op.getVectorType().getRank(), 0); |
| Value scalarValue = |
| rewriter.create<vector::ExtractOp>(loc, op.vector(), zero); |
| rewriter.create<memref::StoreOp>(loc, scalarValue, op.source(), |
| op.indices()); |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| }; |
| |
| // Lower vector.contract to a single scalar or vector mulf+addf pair. Insert |
| // casts to convert from an N-D vector to a 1-D vector or a scalar. |
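| // A sketch of the vector case (types illustrative): with |
| //   %lhs : vector<1x1xf32> and %rhs, %acc : vector<1x4xf32> |
| // the contraction is rewritten roughly as: |
| //   %l = vector.broadcast of the extracted lhs element to vector<4xf32> |
| //   %r = vector.shape_cast %rhs : vector<1x4xf32> to vector<4xf32> |
| //   %a = vector.shape_cast %acc : vector<1x4xf32> to vector<4xf32> |
| //   %m = mulf %l, %r : vector<4xf32> |
| //   %s = addf %m, %a : vector<4xf32> |
| //   %res = vector.shape_cast %s : vector<4xf32> to vector<1x4xf32> |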
| class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> { |
| public: |
| using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ContractionOp op, |
| PatternRewriter &rewriter) const override { |
| auto iteratorTypes = op.iterator_types().getValue(); |
| if (!isReductionIterator(iteratorTypes.back()) || |
| op.getContractingDimMap().size() > 1) |
| return failure(); |
| if (op.getLhsType().getNumElements() != 1) return failure(); |
| auto accType = op.getAccType().cast<VectorType>(); |
| auto rhsType = op.getRhsType(); |
| unsigned vecSize = accType.getNumElements(); |
| if (accType != rhsType || !(vecSize >= 1 && vecSize <= 4) || |
| accType.getShape().back() != vecSize) |
| return failure(); |
| auto loc = op.getLoc(); |
| VectorType vecType = VectorType::get( |
| vecSize, op.getResultType().cast<VectorType>().getElementType()); |
| llvm::SmallVector<int64_t, 4> zero(iteratorTypes.size() - 1, 0); |
| Value lhs = rewriter.create<vector::ExtractOp>(loc, op.lhs(), zero); |
| Value rhs, acc; |
| if (vecSize == 1) { |
| rhs = rewriter.create<vector::ExtractOp>(loc, op.rhs(), zero); |
| acc = rewriter.create<vector::ExtractOp>(loc, op.acc(), zero); |
| } else { |
| lhs = rewriter.create<vector::BroadcastOp>(loc, vecType, lhs); |
| rhs = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.rhs()); |
| acc = rewriter.create<vector::ShapeCastOp>(loc, vecType, op.acc()); |
| } |
| Value newOp = rewriter.create<MulFOp>(loc, lhs, rhs); |
| newOp = rewriter.create<AddFOp>(loc, newOp, acc); |
| if (vecSize == 1) |
| newOp = |
| rewriter.create<vector::BroadcastOp>(loc, op.getResultType(), newOp); |
| else |
| newOp = |
| rewriter.create<vector::ShapeCastOp>(loc, op.getResultType(), newOp); |
| rewriter.replaceOp(op, newOp); |
| return success(); |
| } |
| }; |
| |
| // Lower elementwise operations on N-D vectors to operations on 1-D vectors |
| // that can be natively supported. |
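| // Illustrative rewrite: |
| //   %r = addf %a, %b : vector<4x8xf32> |
| // becomes: |
| //   %a1 = vector.shape_cast %a : vector<4x8xf32> to vector<32xf32> |
| //   %b1 = vector.shape_cast %b : vector<4x8xf32> to vector<32xf32> |
| //   %r1 = addf %a1, %b1 : vector<32xf32> |
| //   %r  = vector.shape_cast %r1 : vector<32xf32> to vector<4x8xf32> |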
| class ElementwiseLowering : public RewritePattern { |
| public: |
| ElementwiseLowering(MLIRContext *context) |
| : RewritePattern(0, MatchAnyOpTypeTag()) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
| return failure(); |
| auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>(); |
| if (!vecType || vecType.getRank() == 1) return failure(); |
| |
| SmallVector<Value, 4> newOperands; |
| for (Value operand : op->getOperands()) { |
| if (auto opVecType = operand.getType().dyn_cast<VectorType>()) { |
| auto newType = VectorType::get({opVecType.getNumElements()}, |
| opVecType.getElementType()); |
| newOperands.push_back(rewriter.create<vector::ShapeCastOp>( |
| op->getLoc(), newType, operand)); |
| } else { |
| newOperands.push_back(operand); |
| } |
| } |
| OperationState state(op->getLoc(), op->getName()); |
| state.addAttributes(op->getAttrs()); |
| state.addOperands(newOperands); |
| state.addTypes({VectorType::get({vecType.getNumElements()}, |
| vecType.getElementType())}); |
| Operation *newOp = rewriter.createOperation(state); |
| rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType, |
| newOp->getResult(0)); |
| return success(); |
| } |
| }; |
| |
| // Lower ExtractStridedSliceOp to an ExtractOp that can be natively converted |
| // to SPIR-V. Add a BroadcastOp to keep the type consistent; the broadcast is |
| // expected to be removed by canonicalization. |
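| // Illustrative rewrite: |
| //   %s = vector.extract_strided_slice %v |
| //        {offsets = [2], sizes = [1], strides = [1]} |
| //        : vector<4xf32> to vector<1xf32> |
| // becomes: |
| //   %e = vector.extract %v[2] : vector<4xf32> |
| //   %s = vector.broadcast %e : f32 to vector<1xf32> |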
| class ExtractStridedLowering |
| : public OpRewritePattern<vector::ExtractStridedSliceOp> { |
| public: |
| using OpRewritePattern<vector::ExtractStridedSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op, |
| PatternRewriter &rewriter) const override { |
| // Only handle cases extracting a degenerate (single-element) vector so that |
| // we can generate an extractOp with a scalar destination. |
| if (op.getResult().getType().cast<VectorType>().getNumElements() != 1) |
| return failure(); |
| auto loc = op.getLoc(); |
| SmallVector<int64_t, 4> offsets = llvm::to_vector<4>( |
| llvm::map_range(op.offsets().getAsRange<IntegerAttr>(), |
| [](IntegerAttr attr) { return attr.getInt(); })); |
| offsets.resize(op.getVectorType().getRank(), 0); |
| Value newOp = rewriter.create<vector::ExtractOp>(loc, op.vector(), offsets); |
| newOp = rewriter.create<vector::BroadcastOp>(loc, op.getResult().getType(), |
| newOp); |
| rewriter.replaceOp(op, newOp); |
| return success(); |
| } |
| }; |
| |
| // Lower vector ops to instructions that can later be converted to SPIR-V. |
| void ConvertVectorToGPUPass::lowerVectorOps(FuncOp funcOp, |
| MLIRContext *context) { |
| OwningRewritePatternList patterns(context); |
| patterns.insert<VectorContractLowering, VectorTransferReadToLoad, |
| VectorTransferWriteToStore, ExtractStridedLowering, |
| ElementwiseLowering>(context); |
| (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
| } |
| |
| void ConvertVectorToGPUPass::runOnOperation() { |
| MLIRContext *context = &getContext(); |
| FuncOp funcOp = getOperation(); |
| tileAndVectorizeLinalgCopy(funcOp, context); |
| |
| lowerVectorOps(funcOp, context); |
| |
| auto &cooperativeMatrixAnalysis = getAnalysis<CooperativeMatrixAnalysis>(); |
| OwningRewritePatternList patterns(context); |
| patterns.insert<UnaryAndBinaryOpPattern<AddFOp>, VectorTransferReadConversion, |
| VectorTransferWriteConversion>(context, |
| cooperativeMatrixAnalysis); |
| std::unique_ptr<VectorToGPUConversionTarget> target = |
| std::make_unique<VectorToGPUConversionTarget>(*context); |
| target->addDynamicallyLegalDialect<memref::MemRefDialect>(); |
| target->addDynamicallyLegalDialect<StandardOpsDialect>(); |
| target->addIllegalOp<scf::ParallelOp>(); |
| target->addLegalOp<scf::YieldOp>(); |
| target->addLegalOp<scf::ForOp>(); |
| target->addLegalDialect<gpu::GPUDialect>(); |
| if (failed(applyPartialConversion(funcOp, *target, std::move(patterns)))) |
| return signalPassFailure(); |
| } |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Pass entry point and registration |
| //===----------------------------------------------------------------------===// |
| std::unique_ptr<OperationPass<FuncOp>> createVectorToGPUPass() { |
| return std::make_unique<ConvertVectorToGPUPass>(); |
| } |
| |
| static PassRegistration<ConvertVectorToGPUPass> pass( |
| "iree-codegen-vector-to-gpu", |
| "Convert vector dialect to gpu subgroup level GPU instructions"); |
| } // namespace iree_compiler |
| } // namespace mlir |