// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-llvmcpu-tile-fuse-and-vectorize"
namespace mlir {
namespace iree_compiler {
namespace {
// This could just be linalg::LinalgTilingPattern with a ContractionOpInterface
// filter, but that pattern is always templated on a specific op type.
struct TileWorkgroups : public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter marker)
: LinalgBaseTilingPattern(context, options, marker) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op);
if (!contractionOp) return failure();
linalg::TiledLinalgOp tiledLinalgOp;
if (failed(Base::matchAndRewriteBase(op, rewriter, tiledLinalgOp))) {
return failure();
}
rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
return success();
}
};

} // namespace

namespace {
struct LLVMCPUTileFuseAndVectorizePass
: public LLVMCPUTileFuseAndVectorizeBase<LLVMCPUTileFuseAndVectorizePass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, memref::MemRefDialect,
vector::VectorDialect>();
}
void runOnOperation() override;
};

LogicalResult applyTileAndFuseCanonicalizationPatterns(FuncOp funcOp) {
auto context = funcOp.getContext();
OwningRewritePatternList patterns(context);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
tensor::DimOp::getCanonicalizationPatterns(patterns, context);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
} // namespace

void LLVMCPUTileFuseAndVectorizePass::runOnOperation() {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- Before LLVMCPUTileFuseAndVectorize ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
  // Assume there is a single op with a lowering config that we use to drive
  // the tiling decisions.
  // TODO(hanchung): Specify a callback to get tile sizes in tile+fuse once the
  // upstream method supports it. Then we won't need to extract the config
  // here.
IREE::Codegen::LoweringConfigAttr config;
  funcOp.walk([&](linalg::LinalgOp linalgOp) {
    if (auto opConfig = getLoweringConfig(linalgOp)) {
      // Bail out if more than one op carries a lowering config.
      if (config) return signalPassFailure();
      config = opConfig;
    }
  });
// Tile and fuse Linalg ops.
{
OpBuilder builder(funcOp.getContext());
SmallVector<Operation *> computeOps;
SmallVector<LoopTilingAndDistributionInfo> tiledLoops;
if (failed(getComputeOps(funcOp, computeOps, tiledLoops))) {
return signalPassFailure();
}
auto tileSizes =
config.getTileSizeVals(static_cast<unsigned>(TilingLevel::L1Tiles));
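    // Treat the last Linalg op in the dispatch region as the root consumer;
    // its producers are fused into the generated tile loop nest.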
linalg::LinalgOp consumerOp;
for (auto iter : llvm::reverse(computeOps)) {
if (auto op = dyn_cast<linalg::LinalgOp>(iter)) {
consumerOp = op;
break;
}
}
assert(consumerOp && "can't find consumerOp");
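    // Only the parallel loop tile sizes are used at this level; the reduction
    // loops are tiled separately below.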
SmallVector<int64_t> consumerTileSize(
tileSizes.begin(),
tileSizes.begin() + consumerOp.getNumParallelLoops());
auto identityIndicesOrder =
llvm::to_vector<4>(llvm::seq<int64_t>(0, consumerTileSize.size()));
FailureOr<linalg::TileLoopNest> tileLoopNest =
linalg::tileConsumerAndFuseProducers(
builder, consumerOp, consumerTileSize, identityIndicesOrder);
if (failed(tileLoopNest)) return signalPassFailure();
consumerOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
    // Apply canonicalization patterns.
if (failed(applyTileAndFuseCanonicalizationPatterns(funcOp))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After tile and fuse paralell loops ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
  // Tile and fuse for vector sizes, then tile the reduction loops. We don't
  // rely on the vector unrolling pass because it could introduce register
  // pressure.
bool hasMatmulAndIsVectorizable = true;
{
OwningRewritePatternList tileReductionPatterns(&getContext());
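    // A contraction is only considered vectorizable here if it passes the
    // vectorization precondition and all of its reduction dimensions are
    // static and evenly divisible by the vector-level tile sizes.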
funcOp.walk([&](linalg::ContractionOpInterface op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
if (failed(linalg::vectorizeLinalgOpPrecondition(linalgOp))) {
hasMatmulAndIsVectorizable = false;
}
auto loopRanges = linalgOp.getStaticLoopRanges();
if (loopRanges) {
auto tiles =
getTileSizes(op, static_cast<unsigned>(TilingLevel::VectorTiles));
for (int i = linalgOp.getNumParallelLoops(); i < tiles.size(); ++i) {
if (loopRanges.getValue()[i] == ShapedType::kDynamicSize ||
(tiles[i] && loopRanges.getValue()[i] % tiles[i] != 0)) {
hasMatmulAndIsVectorizable = false;
}
}
}
});
    // If the matmul op is not vectorizable, stop here. It is fine if a
    // following generic op is not vectorizable; but even if it is, we cannot
    // vectorize it on its own, because that would create an extra allocation
    // op to temporarily store the result of the matmul.
if (!hasMatmulAndIsVectorizable) return;
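    // Tile only the reduction dimensions: parallel dimensions get a tile size
    // of zero (i.e. they are left untouched), while reduction dimensions reuse
    // the L1 tile sizes.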
tileReductionPatterns.insert<TileWorkgroups>(
context,
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
auto tiles =
getTileSizes(builder, operation,
static_cast<unsigned>(TilingLevel::L1Tiles));
auto numParallelLoops =
dyn_cast<linalg::LinalgOp>(operation).getNumParallelLoops();
auto zeroTileVal = builder.create<arith::ConstantIndexOp>(
operation->getLoc(), 0);
SmallVector<Value> reductionTiles(tiles.size(), zeroTileVal);
for (int i = numParallelLoops; i < tiles.size(); ++i) {
reductionTiles[i] = tiles[i];
}
return std::move(reductionTiles);
}),
linalg::LinalgTransformationFilter(
ArrayRef<Identifier>{},
Identifier::get(getVectorizeMarker(), context)));
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(tileReductionPatterns)))) {
return signalPassFailure();
}
    // Apply canonicalization patterns.
if (failed(applyTileAndFuseCanonicalizationPatterns(funcOp))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After tiling reduction loops ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
{
    // Set the vectorization marker on all Linalg ops in the function.
OpBuilder builder(funcOp.getContext());
funcOp.walk(
[&](linalg::LinalgOp op) { setMarker(op, getVectorizeMarker()); });
}
// Apply vectorization patterns.
{
OwningRewritePatternList vectorizationPatterns(&getContext());
linalg::insertVectorizationPatterns<linalg::ContractionOpInterface,
linalg::GenericOp, linalg::CopyOp,
linalg::FillOp>(
vectorizationPatterns, linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter(
Identifier::get(getVectorizeMarker(), context)));
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorizationPatterns);
vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorizationPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After vectorization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
{
// Fold consumer add ops into the contraction op itself.
RewritePatternSet canonicalizationPatterns(context);
vector::ContractionOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns));
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs()
<< "\n--- After folding consumer add ops into contraction op "
"iteself ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// Apply vector unroll
{
RewritePatternSet vectorUnrollPatterns(context);
// TODO(hanchung): Set different vector sizes for different operations. Also
// it seems that `{16, 16, 16}` is not a good config. We should tune it.
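    // Unroll vector ops to the native vector sizes recorded in the lowering
    // config.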
vector::populateVectorUnrollPatterns(
vectorUnrollPatterns, vector::UnrollVectorOptions().setNativeShape(
config.getNativeVectorSizeVals()));
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(vectorUnrollPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After vector unrolling ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
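  // Hoist redundant vector transfer ops out of the surrounding loops so that
  // loads and stores are not repeated on every iteration.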
linalg::hoistRedundantVectorTransfersOnTensor(funcOp);
linalg::hoistRedundantVectorTransfers(funcOp);
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "--- After hoisting vector transfers ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
  // Apply vector-specific operation lowering.
{
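    // Lower vector.contract ops, preferring the outer-product form; the other
    // patterns act as fallbacks.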
vector::VectorTransformsOptions vectorTransformsOptions =
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct);
OwningRewritePatternList vectorContractLoweringPatterns(&getContext());
vectorContractLoweringPatterns.insert<
vector::ContractionOpToOuterProductOpLowering,
vector::ContractionOpToMatmulOpLowering, vector::ContractionOpLowering>(
vectorTransformsOptions, context);
vector::populateVectorTransferPermutationMapLoweringPatterns(
vectorContractLoweringPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorContractLoweringPatterns)))) {
return signalPassFailure();
}
DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After vector specific operatrion lowering ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
}

std::unique_ptr<OperationPass<FuncOp>>
createLLVMCPUTileFuseAndVectorizePass() {
  return std::make_unique<LLVMCPUTileFuseAndVectorizePass>();
}

} // namespace iree_compiler
} // namespace mlir