| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/Dialect/LoweringConfig.h" |
| #include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h" |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Passes.h" |
| #include "iree/compiler/Codegen/Transforms/Transforms.h" |
| #include "iree/compiler/Codegen/Utils/MarkerUtils.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| #include "mlir/Dialect/SCF/Transforms.h" |
| #include "mlir/Dialect/Vector/VectorTransforms.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| #define DEBUG_TYPE "iree-llvmcpu-tile-fuse-and-vectorize" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| // Could just be linalg::TilingPattern with a ContractionOpInterface filter, but |
| // that is always templated on an op. |
| struct TileWorkgroups : public linalg::LinalgTilingPattern { |
| using Base = linalg::LinalgTilingPattern; |
| TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options, |
| linalg::LinalgTransformationFilter marker) |
| : LinalgTilingPattern(context, options, marker) {} |
| LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, |
| PatternRewriter &rewriter) const override { |
| if (!isa<linalg::ContractionOpInterface>(linalgOp.getOperation())) |
| return failure(); |
| return Base::returningMatchAndRewrite(linalgOp, rewriter); |
| } |
| }; |
| |
| } // namespace |
| |
| namespace { |
| struct LLVMCPUTileFuseAndVectorizePass |
| : public LLVMCPUTileFuseAndVectorizeBase<LLVMCPUTileFuseAndVectorizePass> { |
| LLVMCPUTileFuseAndVectorizePass(bool vectorize = true) |
| : lowerToVectors(vectorize) {} |
| LLVMCPUTileFuseAndVectorizePass(const LLVMCPUTileFuseAndVectorizePass &pass) { |
| lowerToVectors = pass.lowerToVectors; |
| } |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<linalg::LinalgDialect, memref::MemRefDialect, |
| vector::VectorDialect>(); |
| } |
| void runOnOperation() override; |
| |
| private: |
| bool lowerToVectors; |
| }; |
| |
| LogicalResult applyTileAndFuseCanonicalizationPatterns(FuncOp funcOp) { |
| auto context = funcOp.getContext(); |
| OwningRewritePatternList patterns(context); |
| linalg::populateLinalgTilingCanonicalizationPatterns(patterns); |
| tensor::DimOp::getCanonicalizationPatterns(patterns, context); |
| memref::DimOp::getCanonicalizationPatterns(patterns, context); |
| memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); |
| memref::populateResolveShapedTypeResultDimsPatterns(patterns); |
| scf::populateSCFForLoopCanonicalizationPatterns(patterns); |
| return applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); |
| } |
| } // namespace |
| |
// Entry point. Pipeline stages:
//   1. Tile the root (last) linalg op at the L1 level and fuse its producers
//      into the generated loop nest.
//   2. Tile the remaining reduction loops of contraction ops (skipped when a
//      static reduction dimension already fits in a single L1 tile).
//   3. If vectorization is disabled or a contraction has dynamic shapes,
//      apply a second (vector-level) tiling and return, relying on LLVM
//      auto-vectorization.
//   4. Otherwise vectorize, fold consumer adds into contractions, unroll to
//      the vector tile sizes, hoist redundant transfers, and lower vector
//      contractions (outer-product form preferred).
void LLVMCPUTileFuseAndVectorizePass::runOnOperation() {
  MLIRContext *context = &getContext();
  auto funcOp = getOperation();

  DEBUG_WITH_TYPE(DEBUG_TYPE, {
    llvm::dbgs() << "\n--- Before LLVMCPUTileFuseAndVectorize ---\n";
    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
    llvm::dbgs() << "\n\n";
  });

  // Assume there is a single op with a lowering config we use to drive the
  // tiling decisions.
  // TODO(hanchung): Specify a callback to get tile sizes in tile+fuse after
  // the upstream method supports it. Then we don't need to extract the
  // config.
  IREE::Codegen::LoweringConfigAttr config;
  funcOp.walk([&](linalg::LinalgOp linalgOp) {
    if (auto opConfig = getLoweringConfig(linalgOp)) {
      // NOTE(review): `opConfig` is already known non-null here; this inner
      // check is redundant (but harmless).
      if (opConfig) {
        // Fail on duplicate configurations: exactly one op is expected to
        // carry the lowering config.
        if (config) return signalPassFailure();
        config = opConfig;
      }
    }
  });
  // NOTE(review): if no op carries a config, `config` stays null and the
  // `config.getTileSizeVals` calls below operate on a null attribute --
  // presumably earlier pipeline stages guarantee a config; TODO confirm.

  // Tile and fuse Linalg ops.
  {
    OpBuilder builder(funcOp.getContext());
    SmallVector<Operation *> computeOps;
    SmallVector<LoopTilingAndDistributionInfo> tiledLoops;
    if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
      return signalPassFailure();
    }
    auto tileSizes =
        config.getTileSizeVals(static_cast<unsigned>(TilingLevel::L1Tiles));
    // The root consumer is the last LinalgOp in the compute op list; its
    // producers are fused into the tiled loop nest built around it.
    linalg::LinalgOp consumerOp;
    for (auto iter : llvm::reverse(computeOps)) {
      if (auto op = dyn_cast<linalg::LinalgOp>(iter)) {
        consumerOp = op;
        break;
      }
    }
    assert(consumerOp && "can't find consumerOp");
    // Only the parallel-loop tile sizes are used here; reduction loops are
    // tiled in the next stage.
    SmallVector<int64_t> consumerTileSize(
        tileSizes.begin(),
        tileSizes.begin() + consumerOp.getNumParallelLoops());
    auto identityIndicesOrder =
        llvm::to_vector<4>(llvm::seq<int64_t>(0, consumerTileSize.size()));
    FailureOr<linalg::TileLoopNest> tileLoopNest =
        linalg::tileConsumerAndFuseProducers(
            builder, consumerOp, consumerTileSize, identityIndicesOrder);
    if (failed(tileLoopNest)) return signalPassFailure();
    consumerOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());

    // Apply canonicalization patterns to clean up the IR after tile+fuse.
    if (failed(applyTileAndFuseCanonicalizationPatterns(funcOp))) {
      return signalPassFailure();
    }
    DEBUG_WITH_TYPE(DEBUG_TYPE, {
      llvm::dbgs() << "\n--- After tile and fuse paralell loops ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
  }

  // Tile the reduction loops of contraction ops.
  {
    OwningRewritePatternList tileReductionPatterns(&getContext());

    // TODO(hanchung): Add a pattern to fold the tensor.extract_slice op.
    // One-trip loop can be removed. But weird patterns could be generated and
    // can't be folded atm. E.g.,
    //   %a = linalg.init_tensor [%x, 4] : tensor<?x4xf32>
    //   %b = linalg.fill(%cst0, %a)
    //   %c = tensor.extract_slice %b[0, 0] [%x, 4] [1, 1]
    //
    // In this case, %c should be folded. Otherwise, it introduces memref.alloca
    // in bufferization.
    //
    // Skip reduction tiling when a static reduction dimension already fits in
    // one L1 tile: the resulting one-trip loop would hit the folding issue
    // described above.
    bool shouldTileReductionLoop = true;
    funcOp.walk([&](linalg::ContractionOpInterface op) {
      auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
      auto loopRanges = linalgOp.getStaticLoopRanges();
      if (loopRanges) {
        auto l1Tiles =
            getTileSizes(op, static_cast<unsigned>(TilingLevel::L1Tiles));
        // Loops past the parallel ones are the reduction loops.
        for (int i = linalgOp.getNumParallelLoops(); i < l1Tiles.size(); ++i) {
          if (loopRanges.getValue()[i] != ShapedType::kDynamicSize &&
              l1Tiles[i] && loopRanges.getValue()[i] <= l1Tiles[i]) {
            shouldTileReductionLoop = false;
          }
        }
      }
    });
    if (shouldTileReductionLoop) {
      tileReductionPatterns.insert<TileWorkgroups>(
          context,
          linalg::LinalgTilingOptions().setTileSizeComputationFunction(
              [](OpBuilder &builder,
                 Operation *operation) -> SmallVector<Value, 4> {
                auto tiles =
                    getTileSizes(builder, operation,
                                 static_cast<unsigned>(TilingLevel::L1Tiles));
                auto numParallelLoops =
                    dyn_cast<linalg::LinalgOp>(operation).getNumParallelLoops();
                // Parallel dims get tile size 0 (already tiled above); only
                // the reduction dims receive their L1 tile sizes.
                auto zeroTileVal = builder.create<arith::ConstantIndexOp>(
                    operation->getLoc(), 0);
                SmallVector<Value> reductionTiles(tiles.size(), zeroTileVal);
                for (int i = numParallelLoops; i < tiles.size(); ++i) {
                  reductionTiles[i] = tiles[i];
                }
                return std::move(reductionTiles);
              }),
          linalg::LinalgTransformationFilter(
              ArrayRef<StringAttr>{},
              StringAttr::get(context, getVectorizeMarker())));

      if (failed(applyPatternsAndFoldGreedily(
              funcOp, std::move(tileReductionPatterns)))) {
        return signalPassFailure();
      }
      // Apply canonicalization patterns to clean up after reduction tiling.
      if (failed(applyTileAndFuseCanonicalizationPatterns(funcOp))) {
        return signalPassFailure();
      }
      DEBUG_WITH_TYPE(DEBUG_TYPE, {
        llvm::dbgs() << "\n--- After tiling reduction loops ---\n";
        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
        llvm::dbgs() << "\n\n";
      });
    }
  }

  // Contractions with dynamic shapes are not vectorized here; force the
  // second-level tiling + LLVM auto-vectorization fallback below.
  funcOp.walk([&](linalg::ContractionOpInterface op) {
    if (cast<linalg::LinalgOp>(op.getOperation()).hasDynamicShape()) {
      lowerToVectors = false;
    }
  });

  if (!lowerToVectors) {
    // Apply second level of tiling patterns if they are not vectorizable. This
    // will trigger LLVM auto-vectorization, which gains better performance.
    {
      funcOp.walk([&](linalg::ContractionOpInterface op) {
        setMarker(op, getWorkgroupL1TileMarker());
      });
      OwningRewritePatternList l2patterns(&getContext());
      l2patterns.insert<TileWorkgroups>(
          context,
          linalg::LinalgTilingOptions().setTileSizeComputationFunction(
              [](OpBuilder &builder, Operation *op) -> SmallVector<Value, 4> {
                return getTileSizes(
                    builder, op,
                    static_cast<unsigned>(TilingLevel::VectorTiles));
              }),
          // Only ops marked with the L1-tile marker above are rewritten; they
          // are re-marked with the vectorize marker afterwards.
          linalg::LinalgTransformationFilter(
              StringAttr::get(context, getWorkgroupL1TileMarker()),
              StringAttr::get(context, getVectorizeMarker())));

      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l2patterns)))) {
        return signalPassFailure();
      }
      DEBUG_WITH_TYPE(DEBUG_TYPE, {
        llvm::dbgs() << "\n--- After second level of tiling patterns ---\n";
        funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
        llvm::dbgs() << "\n\n";
      });
    }

    // No vectorization on this path.
    return;
  }

  {
    // Set vectorization marker globally.
    OpBuilder builder(funcOp.getContext());
    funcOp.walk(
        [&](linalg::LinalgOp op) { setMarker(op, getVectorizeMarker()); });
  }

  // Op specific conversion: vectorize linalg.mmt4d with dedicated patterns.
  {
    RewritePatternSet patterns(context);
    populateLinalgToVectorVectorizeMMT4dPatterns(context, patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
  }

  // Apply vectorization patterns.
  {
    OwningRewritePatternList vectorizationPatterns(&getContext());
    linalg::LinalgVectorizationOptions opt;
    // Only ops carrying the vectorize marker set above are vectorized.
    linalg::LinalgTransformationFilter f(
        StringAttr::get(context, getVectorizeMarker()));
    linalg::VectorizationPatterns<linalg::GenericOp, linalg::CopyOp,
                                  linalg::FillOp>::insert(vectorizationPatterns,
                                                          opt, f);
    vectorizationPatterns.add<linalg::LinalgVectorizationPattern>(
        &getContext(), f.addOpFilter<linalg::ContractionOpInterface>(), opt);
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        vectorizationPatterns);
    vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
    if (failed(applyPatternsAndFoldGreedily(
            funcOp, std::move(vectorizationPatterns)))) {
      return signalPassFailure();
    }

    DEBUG_WITH_TYPE(DEBUG_TYPE, {
      llvm::dbgs() << "\n--- After vectorization ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
  }

  {
    // Fold consumer add ops into the contraction op itself.
    RewritePatternSet canonicalizationPatterns(context);
    vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
                                                       context);
    if (failed(applyPatternsAndFoldGreedily(
            funcOp, std::move(canonicalizationPatterns)))) {
      return signalPassFailure();
    }

    DEBUG_WITH_TYPE(DEBUG_TYPE, {
      llvm::dbgs()
          << "\n--- After folding consumer add ops into contraction op "
             "iteself ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
  }

  // Apply vector unrolling down to the configured vector-level tile sizes.
  {
    RewritePatternSet vectorUnrollPatterns(context);
    // TODO(hanchung): Set different vector sizes for different operations. Also
    // it seems that `{16, 16, 16}` is not a good config. We should tune it.
    vector::populateVectorUnrollPatterns(
        vectorUnrollPatterns,
        vector::UnrollVectorOptions().setNativeShape(config.getTileSizeVals(
            static_cast<unsigned>(TilingLevel::VectorTiles))));

    if (failed(applyPatternsAndFoldGreedily(funcOp,
                                            std::move(vectorUnrollPatterns)))) {
      return signalPassFailure();
    }
    DEBUG_WITH_TYPE(DEBUG_TYPE, {
      llvm::dbgs() << "\n--- After vector unrolling ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
  }

  // Hoist redundant vector transfer ops out of the surrounding loops, first
  // on tensors, then on buffers.
  linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
  linalg::hoistRedundantVectorTransfers(funcOp);
  DEBUG_WITH_TYPE(DEBUG_TYPE, {
    llvm::dbgs() << "--- After hoisting vector transfers ---\n";
    funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
    llvm::dbgs() << "\n\n";
  });

  {
    // Special-case vector.contract codegen paths. This needs to happen
    // just before the generic vector ops lowerings.
    CustomKernelsTargetInfo info;
    if (succeeded(InferCustomKernelsTargetInfoFromParent(funcOp, info))) {
      RewritePatternSet patterns(context);
      populateVectorContractCustomKernelsPatterns(info, patterns);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
  }

  // Apply vector specific operation lowering: contractions are lowered to
  // outer products, with matmul and generic contraction lowerings as
  // fallbacks.
  {
    vector::VectorTransformsOptions vectorTransformsOptions =
        vector::VectorTransformsOptions().setVectorTransformsOptions(
            vector::VectorContractLowering::OuterProduct);
    OwningRewritePatternList vectorContractLoweringPatterns(&getContext());
    vectorContractLoweringPatterns.insert<
        vector::ContractionOpToOuterProductOpLowering,
        vector::ContractionOpToMatmulOpLowering, vector::ContractionOpLowering>(
        vectorTransformsOptions, context);
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        vectorContractLoweringPatterns);
    if (failed(applyPatternsAndFoldGreedily(
            funcOp, std::move(vectorContractLoweringPatterns)))) {
      return signalPassFailure();
    }

    DEBUG_WITH_TYPE(DEBUG_TYPE, {
      llvm::dbgs() << "\n--- After vector specific operatrion lowering ---\n";
      funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
      llvm::dbgs() << "\n\n";
    });
  }
}
| |
| std::unique_ptr<OperationPass<FuncOp>> createLLVMCPUTileFuseAndVectorizePass( |
| bool lowerToVectors) { |
| return std::make_unique<LLVMCPUTileFuseAndVectorizePass>(lowerToVectors); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |