// blob: 7eb2c56ce8afb382fe33af66b3338d99d393d9f9 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-llvmcpu-tile-and-vectorize"
namespace mlir {
namespace iree_compiler {
namespace {
/// Tiling pattern that only fires on ops implementing
/// linalg::ContractionOpInterface.
///
/// linalg::TilingPattern with a ContractionOpInterface filter would do the
/// same job, but that pattern is always templated on a concrete op, so the
/// interface check is done by hand here on top of the untemplated base.
struct TileWorkgroups : public linalg::LinalgBaseTilingPattern {
  using Base = linalg::LinalgBaseTilingPattern;

  TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options,
                 linalg::LinalgTransformationFilter marker)
      : Base(context, options, marker) {}

  /// Fails on anything that is not a contraction; otherwise delegates the
  /// tiling to the base pattern and replaces `op` with the tensor results of
  /// the tiled op.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<linalg::ContractionOpInterface>(op)) return failure();

    linalg::TiledLinalgOp tiled;
    LogicalResult tilingStatus = Base::matchAndRewriteBase(op, rewriter, tiled);
    if (failed(tilingStatus)) return failure();

    rewriter.replaceOp(op, tiled.tensorResults);
    return success();
  }
};
} // namespace
namespace {
/// Function pass that tiles contraction ops in two levels and, when
/// `lowerToVectors` is set, vectorizes them and lowers the resulting vector
/// contractions. The pipeline itself lives in runOnOperation().
struct LLVMCPUTileAndVectorizePass
    : public LLVMCPUTileAndVectorizeBase<LLVMCPUTileAndVectorizePass> {
  /// `vectorize` selects whether the vectorization and vector lowering stages
  /// run after tiling. Marked explicit so a bare bool never silently converts
  /// into a pass object.
  explicit LLVMCPUTileAndVectorizePass(bool vectorize = true)
      : lowerToVectors(vectorize) {}

  /// Copy constructor. Explicitly chains to the base-class copy constructor
  /// (instead of silently default-constructing the base) and initializes the
  /// flag in the mem-initializer list rather than assigning it afterwards.
  LLVMCPUTileAndVectorizePass(const LLVMCPUTileAndVectorizePass &pass)
      : LLVMCPUTileAndVectorizeBase<LLVMCPUTileAndVectorizePass>(pass),
        lowerToVectors(pass.lowerToVectors) {}

  /// Registers the dialects whose ops this pass may create.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<linalg::LinalgDialect, memref::MemRefDialect,
                    vector::VectorDialect>();
  }

  void runOnOperation() override;

 private:
  /// When false, runOnOperation() stops after tiling and canonicalization.
  bool lowerToVectors;
};
} // namespace
void LLVMCPUTileAndVectorizePass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- Before LLVMCPUTileAndVectorizePass ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
// First level of tiling patterns
{
OwningRewritePatternList l1patterns(&getContext());
l1patterns.insert<TileWorkgroups>(
context,
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder, Operation *op) -> SmallVector<Value, 4> {
return getTileSizes(builder, op,
static_cast<unsigned>(TilingLevel::L1Tiles));
}),
linalg::LinalgTransformationFilter(
ArrayRef<Identifier>{},
Identifier::get(getWorkgroupL1TileMarker(), context)));
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l1patterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After first level of tiling patterns ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// Apply canoncalization
{
OwningRewritePatternList canonicalizationPatterns(&getContext());
linalg::populateLinalgTilingCanonicalizationPatterns(
canonicalizationPatterns);
memref::DimOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
scf::populateSCFForLoopCanonicalizationPatterns(canonicalizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(canonicalizationPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After canonicalization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// Second level of tiling patterns{
{
OwningRewritePatternList l2patterns(&getContext());
l2patterns.insert<TileWorkgroups>(
context,
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder, Operation *op) -> SmallVector<Value, 4> {
return getTileSizes(
builder, op, static_cast<unsigned>(TilingLevel::VectorTiles));
}),
linalg::LinalgTransformationFilter(
Identifier::get(getWorkgroupL1TileMarker(), context),
Identifier::get(getVectorizeMarker(), context)));
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l2patterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After second level of tiling patterns ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// Apply canoncalization
{
OwningRewritePatternList canonicalizationPatterns(&getContext());
linalg::populateLinalgTilingCanonicalizationPatterns(
canonicalizationPatterns);
memref::DimOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
scf::populateSCFForLoopCanonicalizationPatterns(canonicalizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(canonicalizationPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After canonicalization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
if (!lowerToVectors) {
return;
}
// Apply vectorization patterns.
{
OwningRewritePatternList vectorizationPatterns(&getContext());
linalg::insertVectorizationPatterns<linalg::ContractionOpInterface,
linalg::CopyOp, linalg::FillOp>(
vectorizationPatterns, linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter(
Identifier::get(getVectorizeMarker(), context)));
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorizationPatterns);
vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorizationPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After vectorization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
{
// Fold consumer add ops into the contraction op itself.
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns));
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs()
<< "\n--- After folding consumer add ops into contraction op "
"iteself ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// Apply vector specific operation lowering.
{
vector::VectorTransformsOptions vectorTransformsOptions =
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct);
OwningRewritePatternList vectorContractLoweringPatterns(&getContext());
vectorContractLoweringPatterns.insert<
vector::ContractionOpToOuterProductOpLowering,
vector::ContractionOpToMatmulOpLowering, vector::ContractionOpLowering>(
vectorTransformsOptions, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorContractLoweringPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorContractLoweringPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After vector specific operatrion lowering ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
}
/// Factory for the tile-and-vectorize pass. `lowerToVectors` is forwarded to
/// the pass constructor and controls whether the vectorization stages run.
std::unique_ptr<OperationPass<FuncOp>> createLLVMCPUTileAndVectorizePass(
    bool lowerToVectors) {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<LLVMCPUTileAndVectorizePass>(lowerToVectors);
  return pass;
}
} // namespace iree_compiler
} // namespace mlir