| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/LLVMGPU/Passes.h" |
| #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "mlir/Transforms/Passes.h" |
| |
| namespace mlir::iree_compiler { |
| |
| #define GEN_PASS_DEF_LLVMGPUVECTORLOWERINGPASS |
| #include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc" |
| |
| //====---------------------------------------------------------------------===// |
| // Patterns for late vector op lowering. |
| //====---------------------------------------------------------------------===// |
| |
| namespace { |
| struct LLVMGPUVectorLoweringPass final |
| : impl::LLVMGPUVectorLoweringPassBase<LLVMGPUVectorLoweringPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<affine::AffineDialect>(); |
| registry.insert<memref::MemRefDialect>(); |
| registry.insert<vector::VectorDialect>(); |
| registry.insert<scf::SCFDialect>(); |
| } |
| void runOnOperation() override { |
| auto funcOp = getOperation(); |
| |
| { |
| // Lower high level vector operations like contract or multidim reduce ops |
| // to lower level vector ops. |
| RewritePatternSet contractLoweringPatterns(funcOp.getContext()); |
| vector::populateVectorTransferPermutationMapLoweringPatterns( |
| contractLoweringPatterns); |
| vector::TransposeOp::getCanonicalizationPatterns(contractLoweringPatterns, |
| funcOp.getContext()); |
| vector::populateVectorBroadcastLoweringPatterns(contractLoweringPatterns); |
| vector::populateVectorContractLoweringPatterns( |
| contractLoweringPatterns, |
| vector::VectorTransformsOptions().setVectorTransformsOptions( |
| vector::VectorContractLowering::OuterProduct)); |
| vector::populateVectorMaskOpLoweringPatterns(contractLoweringPatterns); |
| vector::populateVectorShapeCastLoweringPatterns(contractLoweringPatterns); |
| vector::populateVectorMultiReductionLoweringPatterns( |
| contractLoweringPatterns, |
| vector::VectorMultiReductionLowering::InnerParallel); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(contractLoweringPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| RewritePatternSet vectorToLoopsPatterns(&getContext()); |
| VectorTransferToSCFOptions vectorToSCFOptions; |
| vectorToSCFOptions.enableFullUnroll(); |
| populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, |
| vectorToSCFOptions); |
| memref::populateFoldMemRefAliasOpPatterns(vectorToLoopsPatterns); |
| vector::populateVectorTransferLoweringPatterns(vectorToLoopsPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(vectorToLoopsPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| }; |
| } // namespace |
| } // namespace mlir::iree_compiler |