| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Transforms/Transforms.h" |
| #include "iree/compiler/Codegen/Utils/MarkerUtils.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
| #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
| #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| #include "mlir/Dialect/Vector/VectorOps.h" |
| #include "mlir/Dialect/Vector/VectorTransforms.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| #define DEBUG_TYPE "iree-llvmgpu-vectorization" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| //====---------------------------------------------------------------------===// |
| // Patterns for vectorization |
| //====---------------------------------------------------------------------===// |
| |
| static void populateVectorizationPatterns(RewritePatternSet &patterns) { |
| linalg::insertVectorizationPatterns<linalg::FillOp, linalg::CopyOp, |
| linalg::GenericOp, |
| linalg::ContractionOpInterface>( |
| patterns, linalg::LinalgVectorizationOptions(), |
| linalg::LinalgTransformationFilter( |
| Identifier::get(getVectorizeMarker(), patterns.getContext()))); |
| vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); |
| vector::populateVectorReductionToContractPatterns(patterns); |
| } |
| |
| static Optional<SmallVector<int64_t, 4>> getGPUNativeVectorSize(Operation *op) { |
| if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) { |
| if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) { |
| // Map elementwise ops to vec4. |
| SmallVector<int64_t, 4> nativeSize(vecType.getRank(), 1); |
| nativeSize.back() = 4; |
| return nativeSize; |
| } |
| } else if (auto vt = dyn_cast<VectorTransferOpInterface>(op)) { |
| auto rank = vt.getVectorType().getRank(); |
| SmallVector<int64_t, 4> nativeSize(rank, 1); |
| // Load 4 elements on the most inner dimension. |
| for (auto dim : llvm::enumerate(vt.permutation_map().getResults())) { |
| if (auto dimExpr = dim.value().dyn_cast<AffineDimExpr>()) { |
| if (dimExpr.getPosition() == vt.permutation_map().getNumDims() - 1) |
| nativeSize[dim.index()] = 4; |
| } |
| } |
| return nativeSize; |
| } else if (auto contract = dyn_cast<vector::ContractionOp>(op)) { |
| unsigned lastParalleldim = 0; |
| for (auto it : llvm::enumerate(contract.iterator_types())) { |
| if (isParallelIterator(it.value())) lastParalleldim = it.index(); |
| } |
| SmallVector<int64_t, 4> nativeSize(contract.iterator_types().size(), 1); |
| nativeSize[lastParalleldim] = 4; |
| return nativeSize; |
| } |
| return llvm::None; |
| } |
| |
| static void populateVectorUnrollPatterns(RewritePatternSet &patterns) { |
| vector::populateVectorUnrollPatterns( |
| patterns, |
| vector::UnrollVectorOptions().setNativeShapeFn(getGPUNativeVectorSize)); |
| } |
| |
| namespace { |
| struct LLVMGPUVectorizationPass |
| : public LLVMGPUVectorizationBase<LLVMGPUVectorizationPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<vector::VectorDialect>(); |
| } |
| void runOnOperation() override { |
| auto funcOp = getOperation(); |
| MLIRContext *context = &getContext(); |
| |
| { |
| // Step 1. Vectorize |
| RewritePatternSet vectorizationPatterns(context); |
| populateVectorizationPatterns(vectorizationPatterns); |
| (void)applyPatternsAndFoldGreedily(funcOp, |
| std::move(vectorizationPatterns)); |
| |
| // Fold consumer add ops into the contraction op itself. |
| RewritePatternSet canonicalizationPatterns(context); |
| vector::ContractionOp::getCanonicalizationPatterns( |
| canonicalizationPatterns, context); |
| (void)applyPatternsAndFoldGreedily(funcOp, |
| std::move(canonicalizationPatterns)); |
| |
| RewritePatternSet vectorUnrollPatterns(context); |
| populateVectorUnrollPatterns(vectorUnrollPatterns); |
| (void)applyPatternsAndFoldGreedily(funcOp, |
| std::move(vectorUnrollPatterns)); |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After Step 1: Vectorization ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| { |
| // Step 2. Lower transfer op to canonical form. |
| RewritePatternSet lowerTransferOpPatterns(funcOp.getContext()); |
| vector::populateVectorToVectorCanonicalizationPatterns( |
| lowerTransferOpPatterns); |
| vector::populateVectorTransferPermutationMapLoweringPatterns( |
| lowerTransferOpPatterns); |
| (void)applyPatternsAndFoldGreedily(funcOp, |
| std::move(lowerTransferOpPatterns)); |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() |
| << "\n--- After Step 2: Lower transfer op to canonical form. ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| { |
| // Step 3. Canonicalize. |
| RewritePatternSet canonicalizationPatterns(funcOp.getContext()); |
| vector::ExtractStridedSliceOp::getCanonicalizationPatterns( |
| canonicalizationPatterns, canonicalizationPatterns.getContext()); |
| vector::populateVectorToVectorCanonicalizationPatterns( |
| canonicalizationPatterns); |
| (void)applyPatternsAndFoldGreedily(funcOp, |
| std::move(canonicalizationPatterns)); |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After Step 3: Canonicalize. ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| { |
| // Step 4. Lower contract op to outer product. |
| RewritePatternSet contractLoweringPatterns(funcOp.getContext()); |
| vector::populateVectorBroadcastLoweringPatterns(contractLoweringPatterns); |
| vector::populateVectorContractLoweringPatterns( |
| contractLoweringPatterns, |
| vector::VectorTransformsOptions().setVectorTransformsOptions( |
| vector::VectorContractLowering::OuterProduct)); |
| vector::populateVectorMaskOpLoweringPatterns(contractLoweringPatterns); |
| vector::populateVectorShapeCastLoweringPatterns(contractLoweringPatterns); |
| vector::populateVectorMultiReductionLoweringPatterns( |
| contractLoweringPatterns, |
| vector::VectorMultiReductionLowering::InnerParallel); |
| (void)applyPatternsAndFoldGreedily(funcOp, |
| std::move(contractLoweringPatterns)); |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() |
| << "\n--- After Step 4: Lower contract op to outer product. ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| } |
| }; |
| } // namespace |
| |
| std::unique_ptr<OperationPass<FuncOp>> createLLVMGPUVectorizationPass() { |
| return std::make_unique<LLVMGPUVectorizationPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |