// blob: 5ca30867d466b866c0084cf94bcdfe35cf426144 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
#include "iree/compiler/Conversion/LinalgToLLVM/KernelDispatch.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-linalg-to-llvm-tile-and-vectorize"
namespace mlir {
namespace iree_compiler {
// TODO(ataei): Use pass options instead of global llvm flags.
// Global command-line flag, off by default. When set, the pass in this file
// runs an extra promotion step before tiling and selects a different marker
// for the first tiling level.
static llvm::cl::opt<bool> clEnablePromoteWorkgroupToFullTiles(
    "iree-codegen-llvm-promote-workgroup-to-full-tiles",
    llvm::cl::desc("Enable promoting workgroup memory to full tiles allocated "
                   "on the stack."),
    llvm::cl::init(false));
namespace {
// Tiling pattern that only fires on ops implementing
// linalg::ContractionOpInterface. linalg::TilingPattern with a
// ContractionOpInterface filter could serve the same purpose, but that pattern
// is always templated on an op.
struct TileWorkgroups : public linalg::LinalgBaseTilingPattern {
  using Base = linalg::LinalgBaseTilingPattern;

  TileWorkgroups(linalg::LinalgTilingOptions options,
                 linalg::LinalgTransformationFilter marker)
      : Base(options, marker) {}

  /// Tiles `op` through the base pattern and erases the original op.
  /// Returns failure if `op` is not a contraction, if the base rewrite fails,
  /// or if the tiled result carries tensor results.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Anything that is not a contraction is left for other patterns.
    if (!isa<linalg::ContractionOpInterface>(op)) return failure();

    linalg::TiledLinalgOp tiled;
    LogicalResult tilingStatus = Base::matchAndRewriteBase(op, rewriter, tiled);
    if (failed(tilingStatus)) return failure();

    // A tiled op that produces tensor results is not accepted here.
    if (!tiled.tensorResults.empty()) return failure();

    // Tiling succeeded without tensor results: drop the original op.
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
namespace {
/// Function pass whose body (runOnFunction) is defined out of line in this
/// file. Declares the Linalg, Affine, SCF and Vector dialects as dependent
/// dialects.
struct TileAndVectorizeWorkgroups
    : public PassWrapper<TileAndVectorizeWorkgroups, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // One insertion per dialect; same set as a single variadic insert.
    registry.insert<linalg::LinalgDialect>();
    registry.insert<AffineDialect>();
    registry.insert<scf::SCFDialect>();
    registry.insert<vector::VectorDialect>();
  }

  void runOnFunction() override;
};
} // namespace
namespace {
/// Promotion pattern for linalg::MatmulOp. Whatever promotion options the
/// caller passes in, operands 0, 1 and 2 are selected for promotion and
/// full-tile buffers are enabled by default.
struct PromoteMatmulSubviewsPattern
    : public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
  using Base = linalg::LinalgPromotionPattern<linalg::MatmulOp>;

  PromoteMatmulSubviewsPattern(MLIRContext *context,
                               linalg::LinalgPromotionOptions options,
                               linalg::LinalgTransformationFilter marker,
                               PatternBenefit benefit = 1)
      : Base(context, withAllOperandsPromoted(options), marker, benefit) {}

  /// Returns a copy of `options` adjusted to promote operands {0, 1, 2} using
  /// full-tile buffers by default.
  static linalg::LinalgPromotionOptions withAllOperandsPromoted(
      linalg::LinalgPromotionOptions options) {
    options.setOperandsToPromote({0, 1, 2});
    options.setUseFullTileBuffersByDefault(true);
    return options;
  }
};
} // namespace
namespace {
// TODO(ataei): Refactor this into a common utility with LinalgToSPIRV.
/// Allocation callback that creates a memref::AllocaOp at the start of the
/// first block of the FuncOp enclosing `subview`.
///
/// The allocation shape is taken from `boundingSubViewSize`; the element type
/// comes from the type of `subview`. Returns llvm::None when `subview` is not
/// nested in a FuncOp (an error is emitted on `subview` in that case) or when
/// any bounding size is not a constant integer. A constant equal to -1 is
/// treated the same as a non-constant size. `folder` is not used. The
/// builder's insertion point is restored before returning.
Optional<Value> allocateWorkgroupMemoryOnStack(
    OpBuilder &b, memref::SubViewOp subview,
    ArrayRef<Value> boundingSubViewSize, OperationFolder *folder) {
  FuncOp funcOp = subview->getParentOfType<FuncOp>();
  if (!funcOp) {
    subview.emitError("expected op to be within std.func");
    return llvm::None;
  }

  // Every bounding size has to be a compile-time constant, since the sizes
  // form the static shape of the allocation.
  SmallVector<int64_t, 2> staticShape;
  for (Value size : boundingSubViewSize) {
    APInt constantSize;
    if (!matchPattern(size, m_ConstantInt(&constantSize))) return llvm::None;
    int64_t extent = constantSize.getSExtValue();
    if (extent == -1) return llvm::None;
    staticShape.push_back(extent);
  }

  // Allocate the memory into the entry block of the parent FuncOp. This better
  // aligns with the semantics of this memory which is available at the entry
  // of the function.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointToStart(&funcOp.getBody().front());
  MemRefType allocType =
      MemRefType::get(staticShape, subview.getType().getElementType(), {});
  Value buffer = b.create<memref::AllocaOp>(subview.getLoc(), allocType);
  return buffer;
}
/// Deallocation callback passed alongside allocateWorkgroupMemoryOnStack.
/// Creates no operations and does not use `b`; the result is success exactly
/// when the type of `buffer` is a MemRefType.
LogicalResult deallocateWorkgroupMemory(OpBuilder &b, Value buffer) {
  const bool isMemRef = buffer.getType().isa<MemRefType>();
  return success(isMemRef);
}
} // namespace
// Runs the stages below, in order, on the function this pass is scheduled on:
// optional promotion, two levels of tiling, canonicalization, vectorization,
// contraction/add folding, vector contract lowering and vector transfer
// lowering. Except for the promotion stage, a failed greedy pattern
// application signals pass failure and returns immediately.
void TileAndVectorizeWorkgroups::runOnFunction() {
auto funcOp = getOperation();
MLIRContext *context = &getContext();
// Promotes workgroups subviews to a full-tile allocated on the stack.
// Only runs when the clEnablePromoteWorkgroupToFullTiles flag is set.
if (clEnablePromoteWorkgroupToFullTiles) {
OwningRewritePatternList promotionPatterns;
// The filter is built from the workgroup marker and the workgroup-memory
// marker (presumably "match" and "replacement" respectively -- confirm
// against LinalgTransformationFilter).
promotionPatterns.insert<PromoteMatmulSubviewsPattern>(
context,
linalg::LinalgPromotionOptions().setAllocationDeallocationFns(
allocateWorkgroupMemoryOnStack, deallocateWorkgroupMemory),
linalg::LinalgTransformationFilter(
Identifier::get(getWorkgroupMarker(), context),
Identifier::get(getWorkgroupMemoryMarker(), context)));
memref::ViewOp::getCanonicalizationPatterns(promotionPatterns, context);
// The result is explicitly discarded: unlike the later stages, a failure
// here does not signal pass failure.
(void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns));
}
// Workgroup first level of tiling.
{
// First level of tiling patterns. (workgroups memory)
OwningRewritePatternList l1patterns;
// Tile sizes come from TileSizeFn at TilingLevel::Level1Tiles. The first
// filter marker depends on the promotion flag: the workgroup-memory marker
// when promotion is enabled, the plain workgroup marker otherwise.
l1patterns.insert<TileWorkgroups>(
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
return TileSizeFn::get<TilingLevel::Level1Tiles>(builder,
operation);
}),
linalg::LinalgTransformationFilter(
Identifier::get(clEnablePromoteWorkgroupToFullTiles
? getWorkgroupMemoryMarker()
: getWorkgroupMarker(),
context),
Identifier::get(getWorkgroupL1TileMarker(), context)));
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l1patterns)))) {
return signalPassFailure();
}
}
// Second level of tiling. (workgroups memory -> vectors)
{
OwningRewritePatternList l2patterns;
// Tile sizes come from TileSizeFn at TilingLevel::Level2Tiles. The filter
// is built from the L1-tile marker and the vectorize marker.
l2patterns.insert<TileWorkgroups>(
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
return TileSizeFn::get<TilingLevel::Level2Tiles>(builder,
operation);
}),
linalg::LinalgTransformationFilter(
Identifier::get(getWorkgroupL1TileMarker(), context),
Identifier::get(getVectorizeMarker(), context)));
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(l2patterns)))) {
return signalPassFailure();
}
}
// Apply canonicalization.
{
// Collects AffineMinCanonicalizationPattern plus the canonicalization
// patterns of AffineApplyOp, AffineMinOp and memref::SubViewOp.
OwningRewritePatternList canonicalizationPatterns;
canonicalizationPatterns.insert<AffineMinCanonicalizationPattern>(context);
AffineApplyOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
AffineMinOp::getCanonicalizationPatterns(canonicalizationPatterns, context);
memref::SubViewOp::getCanonicalizationPatterns(canonicalizationPatterns,
context);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(canonicalizationPatterns)))) {
return signalPassFailure();
}
}
// Apply vectorization patterns.
{
// Patterns are inserted for ContractionOpInterface, CopyOp and FillOp,
// with a filter built from the vectorize marker and default vectorization
// options.
OwningRewritePatternList vectorizationPatterns;
linalg::insertVectorizationPatterns<linalg::ContractionOpInterface,
linalg::CopyOp, linalg::FillOp>(
vectorizationPatterns, context, linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter(
Identifier::get(getVectorizeMarker(), context)));
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorizationPatterns)))) {
return signalPassFailure();
}
}
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
// For every op in the function where canonicalizeContractionAdd returns a
// non-null result, all uses of the op are redirected to that result.
funcOp.walk([&](Operation *op) {
if (auto contract = canonicalizeContractionAdd(op))
op->replaceAllUsesWith(contract);
});
// Apply vector specific operation lowering.
{
// Transform options are set to VectorContractLowering::OuterProduct and
// shared by all three contraction lowering patterns.
vector::VectorTransformsOptions vectorTransformsOptions =
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct);
OwningRewritePatternList vectorContractLoweringPatterns;
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vectorTransformsOptions, context);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorContractLoweringPatterns)))) {
return signalPassFailure();
}
}
// Programmatic controlled lowering of vector.transfer only.
{
// Vector-to-SCF conversion is configured with unrolling enabled.
VectorTransferToSCFOptions vectorToSCFOptions =
VectorTransferToSCFOptions().setUnroll(true);
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
vectorToSCFOptions);
// Hoist hierarchical tiling indexing and other loop invariant transfer
// ops computation.
// This call acts on funcOp directly, before the pattern list collected in
// this scope is applied.
linalg::hoistRedundantVectorTransfers(funcOp);
// TODO(ataei): Move this to common vector dialect patterns.
populateStdLegalizationPatternsForSPIRVLowering(context,
vectorToLoopsPatterns);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(vectorToLoopsPatterns)))) {
return signalPassFailure();
}
}
}
/// Factory for the tile-and-vectorize pass defined in this file. Returns a
/// newly constructed TileAndVectorizeWorkgroups owned by the caller.
std::unique_ptr<FunctionPass> createLinalgTileAndVectorizeWorkgroupsPass() {
  std::unique_ptr<FunctionPass> tileAndVectorizePass =
      std::make_unique<TileAndVectorizeWorkgroups>();
  return tileAndVectorizePass;
}
/// Static registration of the pass under its command-line argument and
/// description. New instances are obtained from
/// createLinalgTileAndVectorizeWorkgroupsPass().
static PassRegistration<TileAndVectorizeWorkgroups> pass(
    "iree-codegen-linalg-to-llvm-workgroups-vectorization-pass",
    "Tile and vectorize llvm workgroups",
    []() -> std::unique_ptr<Pass> {
      return createLinalgTileAndVectorizeWorkgroupsPass();
    });
} // namespace iree_compiler
} // namespace mlir