| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/Dialect/LoweringConfig.h" |
| #include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h" |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Transforms/Transforms.h" |
| #include "iree/compiler/Codegen/Utils/MarkerUtils.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| #include "mlir/Dialect/SCF/Transforms.h" |
| #include "mlir/Dialect/Vector/VectorTransforms.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "iree-llvmcpu-tile-and-vectorize" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| // Could just be linalg::TilingPattern with a ContractionOpInterface filter, but |
| // that is always templated on an op. |
| struct TileWorkgroups : public linalg::LinalgTilingPattern { |
| using Base = linalg::LinalgTilingPattern; |
| TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options, |
| linalg::LinalgTransformationFilter marker) |
| : LinalgTilingPattern(context, options, marker) {} |
| LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, |
| PatternRewriter &rewriter) const override { |
| if (!isa<linalg::ContractionOpInterface>(linalgOp.getOperation())) |
| return failure(); |
| return Base::returningMatchAndRewrite(linalgOp, rewriter); |
| } |
| }; |
| |
| } // namespace |
| |
| namespace { |
| struct LLVMCPUTileAndVectorizePass |
| : public LLVMCPUTileAndVectorizeBase<LLVMCPUTileAndVectorizePass> { |
| LLVMCPUTileAndVectorizePass(bool vectorize = true) |
| : lowerToVectors(vectorize) {} |
| LLVMCPUTileAndVectorizePass(const LLVMCPUTileAndVectorizePass &pass) { |
| lowerToVectors = pass.lowerToVectors; |
| } |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<linalg::LinalgDialect, memref::MemRefDialect, |
| vector::VectorDialect>(); |
| } |
| void runOnOperation() override; |
| |
| private: |
| bool lowerToVectors; |
| }; |
| } // namespace |
| |
| void LLVMCPUTileAndVectorizePass::runOnOperation() { |
| MLIRContext *context = &getContext(); |
| auto funcOp = getOperation(); |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- Before LLVMCPUTileAndVectorizePass ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| |
| // First level of tiling patterns |
| { |
| OwningRewritePatternList l1patterns(&getContext()); |
| l1patterns.insert<TileWorkgroups>( |
| context, |
| linalg::LinalgTilingOptions().setTileSizeComputationFunction( |
| [](OpBuilder &builder, Operation *op) -> SmallVector<Value, 4> { |
| return getTileSizes(builder, op, |
| static_cast<unsigned>(TilingLevel::L1Tiles)); |
| }), |
| linalg::LinalgTransformationFilter( |
| ArrayRef<Identifier>{}, |
| Identifier::get(getWorkgroupL1TileMarker(), context))); |
| |
| if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l1patterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After first level of tiling patterns ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| // Apply canoncalization |
| { |
| OwningRewritePatternList canonicalizationPatterns(&getContext()); |
| linalg::populateLinalgTilingCanonicalizationPatterns( |
| canonicalizationPatterns); |
| memref::DimOp::getCanonicalizationPatterns(canonicalizationPatterns, |
| context); |
| scf::populateSCFForLoopCanonicalizationPatterns(canonicalizationPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(canonicalizationPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After canonicalization ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| // Second level of tiling patterns{ |
| { |
| OwningRewritePatternList l2patterns(&getContext()); |
| l2patterns.insert<TileWorkgroups>( |
| context, |
| linalg::LinalgTilingOptions().setTileSizeComputationFunction( |
| [](OpBuilder &builder, Operation *op) -> SmallVector<Value, 4> { |
| return getTileSizes( |
| builder, op, static_cast<unsigned>(TilingLevel::VectorTiles)); |
| }), |
| linalg::LinalgTransformationFilter( |
| Identifier::get(getWorkgroupL1TileMarker(), context), |
| Identifier::get(getVectorizeMarker(), context))); |
| |
| if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l2patterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After second level of tiling patterns ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| // Apply canoncalization |
| { |
| OwningRewritePatternList canonicalizationPatterns(&getContext()); |
| linalg::populateLinalgTilingCanonicalizationPatterns( |
| canonicalizationPatterns); |
| memref::DimOp::getCanonicalizationPatterns(canonicalizationPatterns, |
| context); |
| scf::populateSCFForLoopCanonicalizationPatterns(canonicalizationPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(canonicalizationPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After canonicalization ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| if (!lowerToVectors) { |
| return; |
| } |
| |
| // Op specific conversion. |
| { |
| RewritePatternSet vectorizeOpsPattenrs(context); |
| populateLinalgToVectorVectorizeMMT4dPatterns(context, vectorizeOpsPattenrs); |
| if (failed(applyPatternsAndFoldGreedily(funcOp, |
| std::move(vectorizeOpsPattenrs)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| // Apply vectorization patterns. |
| { |
| OwningRewritePatternList vectorizationPatterns(&getContext()); |
| linalg::LinalgVectorizationOptions opt; |
| linalg::LinalgTransformationFilter f( |
| Identifier::get(getVectorizeMarker(), context)); |
| linalg::VectorizationPatterns<linalg::CopyOp, linalg::FillOp>::insert( |
| vectorizationPatterns, opt, f); |
| vectorizationPatterns.add<linalg::LinalgVectorizationPattern>( |
| context, f.addOpFilter<linalg::ContractionOpInterface>(), opt); |
| vector::populateVectorTransferPermutationMapLoweringPatterns( |
| vectorizationPatterns); |
| vector::populateVectorReductionToContractPatterns(vectorizationPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(vectorizationPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After vectorization ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| { |
| // Fold consumer add ops into the contraction op itself. |
| RewritePatternSet canonicalizationPatterns(context); |
| vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns, |
| context); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(canonicalizationPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() |
| << "\n--- After folding consumer add ops into contraction op " |
| "iteself ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| |
| // Apply vector specific operation lowering. |
| { |
| vector::VectorTransformsOptions vectorTransformsOptions = |
| vector::VectorTransformsOptions().setVectorTransformsOptions( |
| vector::VectorContractLowering::OuterProduct); |
| OwningRewritePatternList vectorContractLoweringPatterns(&getContext()); |
| vectorContractLoweringPatterns.insert< |
| vector::ContractionOpToOuterProductOpLowering, |
| vector::ContractionOpToMatmulOpLowering, vector::ContractionOpLowering>( |
| vectorTransformsOptions, context); |
| vector::populateVectorTransferPermutationMapLoweringPatterns( |
| vectorContractLoweringPatterns); |
| if (failed(applyPatternsAndFoldGreedily( |
| funcOp, std::move(vectorContractLoweringPatterns)))) { |
| return signalPassFailure(); |
| } |
| |
| DEBUG_WITH_TYPE(DEBUG_TYPE, { |
| llvm::dbgs() << "\n--- After vector specific operatrion lowering ---\n"; |
| funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); |
| llvm::dbgs() << "\n\n"; |
| }); |
| } |
| } |
| |
| std::unique_ptr<OperationPass<FuncOp>> createLLVMCPUTileAndVectorizePass( |
| bool lowerToVectors) { |
| return std::make_unique<LLVMCPUTileAndVectorizePass>(lowerToVectors); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |