// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- GPUTile.cpp ------------------------------------------------------===//
//
// This pass tiles Linalg ops with tensor semantics to invocations.
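//
// The pass proceeds in stages: tile and fuse compute ops to invocations,
// fuse `tensor.pad` ops into the tiled loops and concretize their result
// shapes, tile reduction dimensions, and finally tile and unroll the
// convolution window dimension so 2-D convolutions can later be decomposed
// into 1-D ones.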
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Codegen/Common/GPU/PassDetail.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-codegen-gpu-tile"
namespace mlir::iree_compiler {
//===----------------------------------------------------------------------===//
// Tiling and fusion utilities
//===----------------------------------------------------------------------===//
/// Collects into `computeOps` the computation ops to use as anchors for
/// tiling and fusion, and returns their common lowering configuration.
static FailureOr<IREE::Codegen::LoweringConfigAttr>
collectComputeOps(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<Operation *> &computeOps) {
// If there are `scf.if` ops containing linalg ops, we have both fast and
// slow paths for padding handling. We then need to scan both regions to
// discover such computation ops so that we can tile and fuse both regions.
SmallVector<scf::IfOp, 1> ifOps;
funcOp.walk<WalkOrder::PreOrder>([&ifOps](scf::IfOp ifOp) -> WalkResult {
if (ifOp->getParentOfType<linalg::LinalgOp>()) {
// Skip scf.if ops nested inside linalg ops.
return WalkResult::skip();
} else {
ifOps.push_back(ifOp);
return WalkResult::advance();
}
});
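// Collect the lowering configuration from each anchor candidate; all of
// them must agree, which is verified at the end.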
SmallVector<IREE::Codegen::LoweringConfigAttr> configs;
if (ifOps.empty()) {
computeOps = getComputeOps(funcOp);
for (Operation *op : computeOps) {
if (auto config = getLoweringConfig(op))
configs.push_back(config);
}
if (computeOps.size() > 1) {
// Only keep the last compute op.
std::reverse(computeOps.begin(), computeOps.end());
computeOps.resize(1);
}
} else {
if (ifOps.size() > 1) {
return funcOp.emitError("expected to contain <= 1 scf.if ops");
}
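// Collect the lowering configurations from all tiling ops under the scf.if.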
ifOps.front()->walk([&configs](Operation *op) {
if (isa<linalg::LinalgOp, TilingInterface>(op)) {
if (auto config = getLoweringConfig(op))
configs.push_back(config);
}
});
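// Anchor on the last compute op in each branch so that both the fast and
// slow paths get tiled and fused.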
for (Operation &op : llvm::reverse(*ifOps.front().thenBlock())) {
if (isa<linalg::LinalgOp, TilingInterface>(op)) {
computeOps.push_back(&op);
break;
}
}
if (Block *elseBlock = ifOps.front().elseBlock()) {
for (Operation &op : llvm::reverse(*elseBlock)) {
if (isa<linalg::LinalgOp, TilingInterface>(op)) {
computeOps.push_back(&op);
break;
}
}
}
}
if (configs.empty()) {
return funcOp.emitError("missing lowering configuration");
}
if (!llvm::all_equal(configs)) {
return funcOp.emitError("contains conflicting lowering configuration");
}
return configs.front();
}
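/// Tiles `consumerOp` with `tileSizes`, fuses its producers into the
/// generated loop nest, and annotates the innermost (up to 3) loops for
/// later distribution to invocations.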
static LogicalResult tileAndDistributeToThreads(TilingInterface consumerOp,
ArrayRef<int64_t> tileSizes) {
MLIRContext *context = consumerOp->getContext();
IRRewriter rewriter(context);
SmallVector<OpFoldResult> tileSizesOfr =
getAsIndexOpFoldResult(context, tileSizes);
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducersUsingSCF(
rewriter, consumerOp,
scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizes(tileSizesOfr)));
if (failed(tileAndFuseResult)) {
return consumerOp.emitOpError("failed tiling and fusing producers");
}
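// Replace the original consumer op's results with those of the tiled loop
// nest.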
SmallVector<Value> replacements;
replacements.resize(consumerOp->getNumResults());
for (const auto &[index, result] :
llvm::enumerate(consumerOp->getResults())) {
replacements[index] = tileAndFuseResult->replacements.lookup(result);
}
consumerOp->replaceAllUsesWith(replacements);
// We don't distribute to invocations here; that happens in a later step
// after bufferization. Instead, annotate the tiled loop nest so that the
// later pass knows these loops should be distributed.
ArrayRef<LoopLikeOpInterface> loops = tileAndFuseResult.value().loops;
const char *attrName = getGPUDistributeAttrName();
// More than 3 dimensions may be tiled (e.g., for convolutions with batch
// size > 1), but only the innermost 3 dimensions are distributed.
for (auto [dim, loop] : llvm::zip(llvm::seq(0, 3), llvm::reverse(loops))) {
loop->setAttr(attrName, rewriter.getIndexAttr(dim));
}
return success();
}
/// Tiles reduction dimensions.
static LogicalResult
tileReduction(mlir::FunctionOpInterface funcOp,
const scf::SCFTileSizeComputationFunction &computeFn) {
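// Match all Linalg ops that do not already carry a transformation marker.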
auto filter = LinalgTransformationFilter().setMatchByDefault();
auto options =
scf::SCFTilingOptions().setTileSizeComputationFunction(computeFn);
auto result = tileLinalgOpsWithFilter(funcOp, options, filter);
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling reduction dimensions ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
return result;
}
/// Fuses `tensor.pad` ops into the materialized loop nests containing their
/// consumer ops.
static void fusePadIntoConsumer(mlir::FunctionOpInterface funcOp) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
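// The control function returning false requests the swap rewrite without
// the zero-slice guard.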
patterns.insert<linalg::ExtractSliceOfPadTensorSwapPattern>(
context, [](tensor::ExtractSliceOp) { return false; });
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
LLVM_DEBUG({
llvm::dbgs() << "--- After fusing padding into consumers ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
/// Concretizes `tensor.pad` ops' result shapes.
static void concretizePadShape(mlir::FunctionOpInterface funcOp) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
SmallVector<int64_t> numWorkgroups = getStaticNumWorkgroups(funcOp);
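// Use the static workgroup counts to fold affine.min ops in distributed
// loops so that pad result shapes can be made static.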
populateConcretizePadResultShapePatterns(patterns, numWorkgroups);
populateFoldAffineMinInDistributedLoopsPatterns(patterns, numWorkgroups);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
LLVM_DEBUG({
llvm::dbgs() << "--- After concretizing pad result shape ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
/// Tiles one of the convolution output window dimensions with size 1 to prepare
/// for downsizing 2-D convolution ops into 1-D ones.
static LogicalResult tileAndUnrollConvWindow(mlir::FunctionOpInterface funcOp,
ArrayRef<OpFoldResult> tileSizes) {
SmallVector<linalg::ConvolutionOpInterface, 1> convOps;
funcOp.walk([&convOps](linalg::ConvolutionOpInterface convOp) {
convOps.push_back(convOp);
});
for (linalg::ConvolutionOpInterface convOp : convOps) {
auto consumerOp = cast<linalg::LinalgOp>(*convOp);
IRRewriter rewriter(funcOp.getContext());
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducersUsingSCF(
rewriter, cast<TilingInterface>(consumerOp.getOperation()),
scf::SCFTileAndFuseOptions().setTilingOptions(
scf::SCFTilingOptions().setTileSizes(tileSizes)));
if (failed(tileAndFuseResult)) {
return consumerOp.emitOpError("failed tiling and fusing producers");
}
SmallVector<Value> replacements;
replacements.resize(consumerOp->getNumResults());
for (const auto &[index, result] :
llvm::enumerate(consumerOp->getResults())) {
replacements[index] = tileAndFuseResult->replacements.lookup(result);
}
consumerOp->replaceAllUsesWith(replacements);
// Fully unroll the generated loop. This removes the loop for the parallel
// output window dimension, which helps later vector transformations.
ArrayRef<LoopLikeOpInterface> loops = tileAndFuseResult.value().loops;
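// Only the output window dimension carries a non-zero tile size, so at
// most one loop is expected.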
if (!loops.empty()) {
assert(loops.size() == 1);
scf::ForOp loopOp = cast<scf::ForOp>(loops.front());
IntegerAttr ub;
if (!matchPattern(loopOp.getUpperBound(), m_Constant(&ub))) {
return loopOp.emitOpError("upper bound should be a constant");
}
if (failed(mlir::loopUnrollByFactor(loopOp, ub.getInt()))) {
return loopOp.emitOpError("failed to fully unroll loop");
}
}
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling convolution output window ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
return success();
}
//===----------------------------------------------------------------------===//
// Main pass
//===----------------------------------------------------------------------===//
namespace {
class GPUTilePass final : public GPUTileBase<GPUTilePass> {
public:
GPUTilePass() = default;
GPUTilePass(const GPUTilePass &pass) = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
auto funcOp = getOperation();
// Find the computation ops to use as anchors for tiling and fusion.
SmallVector<Operation *> computeOps;
FailureOr<IREE::Codegen::LoweringConfigAttr> loweringConfig =
collectComputeOps(funcOp, computeOps);
if (failed(loweringConfig))
return signalPassFailure();
assert(computeOps.size() <= 2);
// Now tile the last computation op to invocations and fuse all operand
// computation ops into the materialized loop nest.
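// Tile-size level 1 of the lowering configuration holds the per-invocation
// tile sizes.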
auto threadTileSizes = loweringConfig->getTileSizeVals(1);
for (Operation *computeOp : computeOps) {
auto consumerOp = dyn_cast<TilingInterface>(computeOp);
if (!consumerOp ||
failed(tileAndDistributeToThreads(consumerOp, threadTileSizes)))
return signalPassFailure();
}
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling to invocations ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
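// Fuse tensor.pad ops into the tiled loops and make their result shapes
// static before tiling reduction dimensions.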
fusePadIntoConsumer(funcOp);
concretizePadShape(funcOp);
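// Reduction dimensions are tiled with tile-size level 2 of the lowering
// configuration.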
auto reductionTileComputeFn = getGPUScfTileSizeComputeFn(funcOp, 2);
if (failed(reductionTileComputeFn) ||
failed(tileReduction(funcOp, reductionTileComputeFn.value()))) {
return signalPassFailure();
}
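// Reduction tiling may create new extract_slice ops on padded values; fuse
// them into the loops as well.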
fusePadIntoConsumer(funcOp);
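// Tile-size level 3 of the lowering configuration holds the convolution
// window tile sizes.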
SmallVector<OpFoldResult> windowTileSizes =
getAsIndexOpFoldResult(context, loweringConfig->getTileSizeVals(3));
if (failed(tileAndUnrollConvWindow(funcOp, windowTileSizes))) {
return signalPassFailure();
}
concretizePadShape(funcOp);
{ // Downsize n-D (n > 1) convolutions to 1-D.
RewritePatternSet patterns(context);
linalg::populateDecomposeConvolutionPatterns(patterns);
// Downsizing creates consecutive extract/insert slice ops. Merge them.
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
// Pull in patterns to fold constant insert/extract slice op parameters.
tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, context);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
// Pull in scf.for op canonicalization patterns to help hoisting across
// multiple loops and remove loop carried values unused in the body.
scf::ForOp::getCanonicalizationPatterns(patterns, context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
LLVM_DEBUG({
llvm::dbgs() << "--- After downsizing n-D convolutions to 1-D ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
}
};
} // namespace
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> createGPUTilePass() {
return std::make_unique<GPUTilePass>();
}
} // namespace mlir::iree_compiler