// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- TileAndVectorizeInOneWorkgroup.cpp ---------------------------------===//
//
// This pass tiles and vectorizes Linalg ops on buffers within a single
// workgroup.
//
//===----------------------------------------------------------------------===//
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h"
#include "iree/compiler/Conversion/CodegenUtils/TransformUtils.h"
#include "iree/compiler/Conversion/Common/Transforms.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/CodeGenOptionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
#include "iree/compiler/Conversion/LinalgToVector/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#define DEBUG_TYPE "iree-spirv-tile-and-vectorize-in-one-workgroup"
namespace mlir {
namespace iree_compiler {
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns a Linalg transformation filter that matches ops carrying any of
/// the `matchMarkers` and replaces the marker with `replaceMarker`.
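/// For example, assuming the standard Linalg transformation marker attribute
/// (__internal_linalg_transform__), getLinalgMatchAndReplaceMarker(
///     {"workgroup"}, "vectorize", context)
/// produces a filter that matches ops whose marker is "workgroup" and, when a
/// pattern successfully applies, rewrites the marker to "vectorize".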
static linalg::LinalgTransformationFilter getLinalgMatchAndReplaceMarker(
ArrayRef<StringRef> matchMarkers, StringRef replaceMarker,
MLIRContext *context) {
SmallVector<Identifier, 2> markers;
markers.reserve(matchMarkers.size());
for (StringRef marker : matchMarkers) {
markers.emplace_back(Identifier::get(marker, context));
}
return linalg::LinalgTransformationFilter(
markers, Identifier::get(replaceMarker, context));
}
/// Converts a symbolic GPU processor dimension to its numeric one.
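/// For example, dimToIndex("x") returns 0 and dimToIndex("z") returns 2. Note
/// that the StringSwitch has no Default case, so any other string asserts in
/// debug builds ("Fell off the end of a string-switch").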
static unsigned dimToIndex(StringRef dim) {
return StringSwitch<unsigned>(dim).Case("x", 0).Case("y", 1).Case("z", 2);
}
//===----------------------------------------------------------------------===//
// Main pass
//===----------------------------------------------------------------------===//
namespace {
/// Pass that tiles and vectorizes Linalg ops on buffers within a single
/// workgroup.
class TileAndVectorizeInOneWorkgroupPass
: public PassWrapper<TileAndVectorizeInOneWorkgroupPass,
OperationPass<IREE::HAL::ExecutableTargetOp>> {
public:
TileAndVectorizeInOneWorkgroupPass(const SPIRVCodegenOptions &passOptions)
: options(passOptions) {}
TileAndVectorizeInOneWorkgroupPass(
const TileAndVectorizeInOneWorkgroupPass &pass)
: options(pass.options) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, IREE::HAL::HALDialect, gpu::GPUDialect,
linalg::LinalgDialect, scf::SCFDialect, ShapeDialect,
vector::VectorDialect>();
}
void runOnOperation() override;
private:
SPIRVCodegenOptions options;
};
} // namespace
//===----------------------------------------------------------------------===//
// Patterns to promote subviews to workgroup memory
//===----------------------------------------------------------------------===//
namespace {
/// Pattern to promote matmul operands to workgroup memory.
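/// Only the LHS (operand 0) and RHS (operand 1) subviews are promoted; the
/// output buffer stays in its original memory space.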
struct PromoteMatmulSubviewsPattern
: public linalg::LinalgPromotionPattern<linalg::MatmulOp> {
PromoteMatmulSubviewsPattern(MLIRContext *context,
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: linalg::LinalgPromotionPattern<linalg::MatmulOp>(
context,
options.setOperandsToPromote({0, 1}).setUseFullTileBuffers(
{false, false}),
marker, benefit) {}
};
/// Pattern to promote convolution operands to workgroup memory.
// TODO(ravishankarm): This pattern only promotes the image subview to
// workgroup memory. In reality we should be able to promote the filter
// subview to workgroup memory as well. Since none of the loops used to access
// the filter are tiled, this would mean the entire filter is moved to
// workgroup memory. Two reasons this is not done right now:
// 1) When tiling, Linalg doesn't create a subview for the filter (since none
//    of its dimensions are tiled). This needs to be relaxed (maybe by using an
//    option).
// 2) Maybe there are better alternatives for handling the filter, like using
//    different storage classes, since for inference workloads these are model
//    constants. This is TBD.
template <typename ConvOpTy>
struct PromoteConvSubviewsPattern
: public linalg::LinalgPromotionPattern<ConvOpTy> {
PromoteConvSubviewsPattern(MLIRContext *context,
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: linalg::LinalgPromotionPattern<ConvOpTy>(
context,
options.setOperandsToPromote({0}).setUseFullTileBuffers(
{false, false}),
marker, benefit) {}
};
} // namespace
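// As an illustrative sketch (not exact IR), after promotion a tiled matmul
// reads its operand tiles from fresh workgroup-memory allocations filled by
// explicit copies:
//   %lhs_wg = alloc() : memref<?x?xf32, 3>  // workgroup memory
//   linalg.copy(%lhs_subview, %lhs_wg)
//   linalg.matmul ins(%lhs_wg, %rhs_wg : ...) outs(%result_subview : ...)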
static void populatePromotionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns
.insert<PromoteMatmulSubviewsPattern,
PromoteConvSubviewsPattern<linalg::ConvInputNHWCFilterHWCFOp>>(
context,
linalg::LinalgPromotionOptions()
.setAllocationDeallocationFns(allocateWorkgroupMemory,
deallocateWorkgroupMemory)
.setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory),
getLinalgMatchAndReplaceMarker(getWorkgroupMarker(),
getWorkgroupMemoryMarker(), context));
}
//===----------------------------------------------------------------------===//
// Patterns to tile computation to map to subgroups
//===----------------------------------------------------------------------===//
/// Computes the subgroup ID and count along each dimension, given the number
/// of subgroups `numSubgroups` along each dimension (x-first, y-second,
/// z-third).
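/// As an illustrative example, with numSubgroups = [4, 2] (4 along x, 2 along
/// y) and a linearized subgroup ID of 5, the loop below recovers
/// x = 5 % 4 = 1 and y = (5 / 4) % 2 = 1, returning {{1, 2}, {1, 4}} in
/// loop (y-first) order.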
static SmallVector<linalg::ProcInfo, 2> getSubgroupIdsAndCounts(
OpBuilder &builder, Location loc, ArrayRef<int64_t> numSubgroups) {
Type indexType = builder.getIndexType();
Value subgroupId = builder.create<gpu::SubgroupIdOp>(loc, indexType);
SmallVector<linalg::ProcInfo, 2> procInfo(numSubgroups.size());
// subgroupID
// = id.z * nsubgroups.y * nsubgroups.x + id.y * nsubgroups.x + id.x
for (size_t i = 0, e = numSubgroups.size(); i != e; ++i) {
Value nprocs = builder.create<ConstantIndexOp>(loc, numSubgroups[i]);
AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
Value procId =
makeComposedAffineApply(builder, loc, d0 % s0, {subgroupId, nprocs});
procInfo[e - i - 1] = linalg::ProcInfo{procId, nprocs};
subgroupId = builder.create<SignedDivIOp>(loc, subgroupId, nprocs);
}
return procInfo;
}
namespace {
/// Pattern to tile linalg.matmul for subgroups.
struct TileMatmulSubgroupPattern
: public linalg::LinalgTilingPattern<linalg::MatmulOp> {
using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
TileMatmulSubgroupPattern(MLIRContext *context,
linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: Base(context, options, marker, benefit) {}
};
} // namespace
/// Patterns for second level tiling to target subgroups.
static void populateTilingToSubgroupPatterns(
MLIRContext *context, const LaunchConfig &launchConfig,
OwningRewritePatternList &patterns) {
auto getInnerTileSizeFn = [&launchConfig](
OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 1);
if (tileSizes.empty()) return {};
SmallVector<Value, 4> tileSizesVal;
tileSizesVal.reserve(tileSizes.size());
for (auto val : tileSizes) {
tileSizesVal.push_back(
builder.create<ConstantIndexOp>(operation->getLoc(), val));
}
return tileSizesVal;
};
auto getSubgroupProcInfoFn = [&launchConfig](
OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
ArrayRef<int64_t> numSubgroups =
launchConfig.getNumSubgroups().take_front(parallelLoopRanges.size());
return getSubgroupIdsAndCounts(builder, loc, numSubgroups);
};
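// The CyclicNumProcsEqNumIters method used below assumes the number of
// subgroups equals the number of iterations along each distributed dimension,
// so no bounds-checking loop is generated around the per-subgroup computation.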
linalg::LinalgLoopDistributionOptions subgroupDistributionOptions;
subgroupDistributionOptions.procInfo = getSubgroupProcInfoFn;
subgroupDistributionOptions.distributionMethod = {
{linalg::DistributionMethod::CyclicNumProcsEqNumIters,
linalg::DistributionMethod::CyclicNumProcsEqNumIters}};
patterns.insert<TileMatmulSubgroupPattern>(
context,
linalg::LinalgTilingOptions()
.setLoopType(linalg::LinalgTilingLoopType::ParallelLoops)
.setTileSizeComputationFunction(getInnerTileSizeFn)
.setDistributionOptions(subgroupDistributionOptions),
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getVectorizeMarker(), context));
}
//===----------------------------------------------------------------------===//
// Patterns and methods for thread tiling.
//===----------------------------------------------------------------------===//
/// Patterns for third level tiling to target invocations.
static void populateTilingToInvocationPatterns(
MLIRContext *context, const LaunchConfig &launchConfig,
OwningRewritePatternList &patterns) {
linalg::TileSizeComputationFunction getInnerTileSizeFn =
[&launchConfig](OpBuilder &builder, Operation *operation) {
ArrayRef<int64_t> tileSizes = launchConfig.getTileSizes(operation, 2);
if (tileSizes.empty()) return SmallVector<Value, 4>();
SmallVector<Value, 4> tileSizesVal;
tileSizesVal.reserve(tileSizes.size());
for (auto val : tileSizes) {
tileSizesVal.push_back(
builder.create<ConstantIndexOp>(operation->getLoc(), val));
}
return tileSizesVal;
};
auto getThreadProcInfoFn = [](OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
return getGPUProcessorIdsAndCounts<gpu::ThreadIdOp, gpu::BlockDimOp>(
builder, loc, parallelLoopRanges.size());
};
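// The plain Cyclic method used below keeps a loop around the tiled body, so
// correctness does not rely on the number of threads matching the iteration
// count; trip-one loops are cleaned up afterwards when the tiling is provably
// perfect.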
linalg::LinalgLoopDistributionOptions invocationDistributionOptions;
invocationDistributionOptions.procInfo = getThreadProcInfoFn;
invocationDistributionOptions.distributionMethod = {
{linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
linalg::DistributionMethod::Cyclic}};
auto tilingOptions =
linalg::LinalgTilingOptions()
.setLoopType(linalg::LinalgTilingLoopType::Loops)
.setTileSizeComputationFunction(getInnerTileSizeFn)
.setDistributionOptions(invocationDistributionOptions);
patterns.insert<
linalg::LinalgTilingPattern<linalg::MatmulOp>,
linalg::LinalgTilingPattern<linalg::FillOp>,
linalg::LinalgTilingPattern<linalg::BatchMatmulOp>,
linalg::LinalgTilingPattern<linalg::ConvInputNWCFilterWCFOp>,
linalg::LinalgTilingPattern<linalg::ConvInputNDHWCFilterDHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
linalg::LinalgTilingPattern<linalg::GenericOp>,
linalg::LinalgTilingPattern<linalg::IndexedGenericOp>,
linalg::LinalgTilingPattern<linalg::PoolingNHWCMaxFOp>,
linalg::LinalgTilingPattern<linalg::PoolingNHWCMinFOp>,
linalg::LinalgTilingPattern<linalg::PoolingNHWCSumFOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getVectorizeMarker(), context));
patterns.insert<
linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
context, tilingOptions,
getLinalgMatchAndReplaceMarker(
{getWorkgroupMemoryMarker(), getWorkgroupMarker()},
getConvFilterTileMarker(), context));
}
/// Returns the corresponding range if the given `processorValue` is a GPU
/// thread id or block dim.
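/// For example, with workgroupSize = [32, 4, 1], a gpu.thread_id along x maps
/// to the inclusive range [0, 31], and a gpu.block_dim along x maps to the
/// constant range [32, 32].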
static Optional<std::pair<AffineExpr, AffineExpr>> getThreadRange(
Value processorValue, SmallVectorImpl<Value> & /*dims*/,
SmallVectorImpl<Value> & /*symbols*/, ArrayRef<int64_t> workgroupSize) {
if (auto idOp = processorValue.getDefiningOp<gpu::ThreadIdOp>()) {
OpBuilder builder(processorValue.getContext());
unsigned index = dimToIndex(idOp.dimension());
AffineExpr zero = builder.getAffineConstantExpr(0);
AffineExpr ubExpr = builder.getAffineConstantExpr(workgroupSize[index]);
return std::make_pair(zero, ubExpr - 1);
}
if (auto dimOp = processorValue.getDefiningOp<gpu::BlockDimOp>()) {
OpBuilder builder(processorValue.getContext());
unsigned index = dimToIndex(dimOp.dimension());
AffineExpr bound = builder.getAffineConstantExpr(workgroupSize[index]);
return std::make_pair(bound, bound);
}
return llvm::None;
}
//====---------------------------------------------------------------------===//
// Patterns for vectorization
//====---------------------------------------------------------------------===//
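// As a rough sketch, a linalg.matmul carrying the vectorize marker on small
// static-shape buffers becomes vector.transfer_read ops on its operands, a
// vector.contract, and a vector.transfer_write back to the result buffer.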
static void populateVectorizationPatterns(MLIRContext *context,
const LaunchConfig &launchConfig,
OwningRewritePatternList &patterns) {
linalg::insertVectorizationPatterns<linalg::FillOp, linalg::GenericOp,
linalg::ContractionOpInterface>(
patterns, linalg::LinalgVectorizationOptions(),
linalg::LinalgTransformationFilter(
Identifier::get(getVectorizeMarker(), context)));
}
//====---------------------------------------------------------------------===//
// Patterns for unrolling vectors
//====---------------------------------------------------------------------===//
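// For example (illustrative numbers; the actual native shape comes from
// getNativeVectorSize and depends on the target), a vector.contract with
// shape [16, 16, 8] and a native shape of [8, 8, 4] is unrolled into
// 2 * 2 * 2 = 8 contracts on the native shape.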
static void populateVectorUnrollPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
patterns.insert<vector::UnrollVectorPattern>(
context,
vector::UnrollVectorOptions().setNativeShapeFn(getNativeVectorSize));
}
namespace {
/// Workaround for SPIR-V backend limitations: the SPIR-V vectorization pass
/// relies on unrolling to reduce instructions to a vector size we can convert
/// to SPIR-V. When vectorization creates transpose ops, those block unrolling
/// and result in large vectors that we currently cannot lower. For now we
/// always merge the transpose into the contract op so that it can be unrolled.
// TODO(thomasraoux): Make transpose work with the current unrolling mechanism
// or replace unrolling.
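// As an illustrative example, if the LHS is produced by a vector.transpose
// with permutation [1, 0] and the contraction reads it with the indexing map
// (m, n, k) -> (k, m), the pattern below drops the transpose and rewrites the
// map to (m, n, k) -> (m, k) so the contraction reads the untransposed vector
// directly.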
class CombineContractTranspose final
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
// Fold vector.transpose ops that feed the contraction operands into the
// contraction's indexing maps so the transposes can be dropped.
MLIRContext *ctx = op.getContext();
Location loc = op.getLoc();
bool foundTranspose = false;
std::array<Value, 3> sources = {op.lhs(), op.rhs(), op.acc()};
SmallVector<AffineMap> newMaps;
SmallVector<Value> newSources;
for (auto source : llvm::enumerate(sources)) {
auto map =
op.indexing_maps()[source.index()].cast<AffineMapAttr>().getValue();
auto transposeOp = source.value().getDefiningOp<vector::TransposeOp>();
if (!transposeOp) {
newSources.push_back(source.value());
newMaps.push_back(map);
continue;
}
SmallVector<int64_t, 3> perm;
transposeOp.getTransp(perm);
SmallVector<AffineExpr> exprs(perm.size());
for (auto remap : llvm::enumerate(perm)) {
exprs[remap.value()] = map.getResult(remap.index());
}
newMaps.push_back(
AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs, ctx));
newSources.push_back(transposeOp.vector());
foundTranspose = true;
}
if (!foundTranspose) return failure();
Value res = rewriter.create<vector::ContractionOp>(
loc, newSources[0], newSources[1], newSources[2],
rewriter.getAffineMapArrayAttr(newMaps), op.iterator_types());
rewriter.replaceOp(op, res);
return success();
}
};
} // namespace
//====---------------------------------------------------------------------===//
// Vector patterns
//====---------------------------------------------------------------------===//
static void applyVectorTransformation(FuncOp funcOp) {
auto targetEnv = spirv::TargetEnv(spirv::lookupTargetEnv(funcOp));
bool useCooperativeMatrix =
targetEnv.allows(spirv::Capability::CooperativeMatrixNV) &&
targetEnv.allows(spirv::Extension::SPV_NV_cooperative_matrix);
{
{
OwningRewritePatternList vectorUnrollPatterns(funcOp.getContext());
populateVectorUnrollPatterns(funcOp.getContext(), vectorUnrollPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorUnrollPatterns));
}
{
OwningRewritePatternList canonicalizationPatterns1(funcOp.getContext());
vector::populateVectorToVectorTransformationPatterns(
canonicalizationPatterns1);
vector::populateVectorToVectorCanonicalizationPatterns(
canonicalizationPatterns1);
vector::populateSplitVectorTransferPatterns(canonicalizationPatterns1);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns1));
OwningRewritePatternList canonicalizationPatterns2(funcOp.getContext());
vector::populateVectorSlicesLoweringPatterns(canonicalizationPatterns2);
vector::populateVectorTransferLoweringPatterns(canonicalizationPatterns2);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns2));
if (useCooperativeMatrix) {
// When using cooperative matrix, we don't want to lower the contract;
// instead we want to merge the contract and transpose so that they can be
// converted to a cooperative matrix matmul op.
// TODO(thomasraoux): Remove this once we support cooperative matrix
// lowering in MLIR core.
OwningRewritePatternList combineTransposePatterns(funcOp.getContext());
combineTransposePatterns.add<CombineContractTranspose>(
funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(combineTransposePatterns));
} else {
OwningRewritePatternList contractLoweringPatterns(funcOp.getContext());
vector::populateVectorContractLoweringPatterns(
contractLoweringPatterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct));
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(contractLoweringPatterns));
}
}
LLVM_DEBUG({
llvm::dbgs() << "--- After unrolling vector ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
{
linalg::hoistRedundantVectorTransfers(funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After hoisting vector transfers ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
}
//====---------------------------------------------------------------------===//
// Patterns to tile convolution window dimensions
//====---------------------------------------------------------------------===//
static void populateTilingConvFilterPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
const LaunchConfig &launchConfig,
linalg::LinalgTransformationFilter marker) {
auto getTileSizeFn = [&launchConfig](OpBuilder &builder, Operation *op) {
SmallVector<Value, 4> tileSizes;
ArrayRef<int64_t> fourthLevel = launchConfig.getTileSizes(op, 3);
tileSizes.reserve(fourthLevel.size());
Location loc = op->getLoc();
for (int64_t size : fourthLevel) {
tileSizes.push_back(builder.create<ConstantIndexOp>(loc, size));
}
return tileSizes;
};
auto tilingOptions = linalg::LinalgTilingOptions()
.setLoopType(linalg::LinalgTilingLoopType::Loops)
.setTileSizeComputationFunction(getTileSizeFn);
patterns.insert<
linalg::LinalgTilingPattern<linalg::ConvInputNHWCFilterHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
linalg::LinalgTilingPattern<linalg::DepthwiseConvInputNHWCFilterHWCOp>>(
context, tilingOptions, marker);
}
//====---------------------------------------------------------------------===//
// Patterns to lower linalg ops to loops
//====---------------------------------------------------------------------===//
template <typename OpTy>
struct LowerToLoops final : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Only handle cases where tiling to invocations was already done and tiling
// of convolution filters or vectorization is expected.
if (!hasMarker(op, {getConvFilterTileMarker(), getVectorizeMarker()}))
return failure();
if (linalg::linalgOpToLoops(rewriter, op)) {
rewriter.eraseOp(op);
return success();
}
return failure();
}
};
//====---------------------------------------------------------------------===//
// Main pass implementation
//====---------------------------------------------------------------------===//
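// runOnOperation proceeds in stages on each entry-point function: (1)
// optionally promote matmul/conv operands to workgroup memory, (2) tile to
// subgroups when vectorizing, (3) tile to invocations, (4) tile convolution
// filter window dimensions, (5) vectorize and apply vector transformations,
// and (6) lower any remaining tiled-but-not-vectorized Linalg ops to loops.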
void TileAndVectorizeInOneWorkgroupPass::runOnOperation() {
MLIRContext *context = &getContext();
IREE::HAL::ExecutableTargetOp targetOp = getOperation();
ModuleOp module = targetOp.getInnerModule();
for (FuncOp funcOp : module.getOps<FuncOp>()) {
if (!isEntryPoint(funcOp)) continue;
SmallVector<linalg::LinalgOp, 4> linalgOps;
SmallVector<Operation *, 4> tiledLoops;
if (failed(getLinalgOps(funcOp, linalgOps, tiledLoops))) {
// Nothing to do here.
continue;
}
linalg::Aliases aliases;
linalg::LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
Optional<LaunchConfig> launchConfigOpt =
initGPULaunchConfig(context, dependenceGraph, options, linalgOps);
if (!launchConfigOpt) {
// No configuration to tile and vectorize. Nothing to do here.
continue;
}
LaunchConfig &launchConfig = *launchConfigOpt;
LLVM_DEBUG({
llvm::dbgs() << "\n--- Linalg tile configuration ---\n";
llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
interleaveComma(launchConfig.getWorkgroupSize(), llvm::dbgs());
llvm::dbgs() << "]\n";
for (auto op : linalgOps) {
llvm::dbgs() << "\t" << op.getOperation()->getName() << " : ";
TileSizesListTypeRef tileSizes = launchConfig.getTileSizes(op);
llvm::dbgs() << "{";
std::string sep = "";
for (auto &level : enumerate(tileSizes)) {
llvm::dbgs() << sep << level.index() << " : [";
sep = ", ";
interleaveComma(level.value(), llvm::dbgs());
llvm::dbgs() << "]";
}
llvm::dbgs() << "}\n";
}
});
if (options.useWorkgroupMemory) {
// The promotion patterns are kept separate from the tiling patterns to make
// sure that the allocated scratchspace memory has constant size, which
// requires some folding to trigger.
OwningRewritePatternList promotionPatterns(&getContext());
populatePromotionPatterns(context, promotionPatterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(promotionPatterns));
applyCanonicalizationPatternsForTiling(context, funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After workgroup memory promotion ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// TODO(thomasraoux, antiagainst): Tiling to subgroups shouldn't be
// controlled by vectorization. This is needed due to historical reasons.
// Change the second-level tiling to distribute cyclically to loops and
// remove this.
if (launchConfig.useVectorize()) {
OwningRewritePatternList secondLevelTilingPatterns(&getContext());
populateTilingToSubgroupPatterns(context, launchConfig,
secondLevelTilingPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(secondLevelTilingPatterns));
applyCanonicalizationPatternsForTiling(context, funcOp);
promoteSingleIterationLoops(funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling to subgroups ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
{
OwningRewritePatternList thirdLevelTilingPatterns(&getContext());
populateTilingToInvocationPatterns(context, launchConfig,
thirdLevelTilingPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(thirdLevelTilingPatterns));
// Remove trip-one loops created during cyclic loop distribution if we can
// prove the tiling was perfect.
RewritePatternSet canonicalizationPatterns(context);
populateAffineMinSCFCanonicalizationPattern(canonicalizationPatterns);
ArrayRef<int64_t> workgroupSize = launchConfig.getWorkgroupSize();
auto getThreadRangeFn = [workgroupSize](Value processorValue,
SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &symbols) {
return getThreadRange(processorValue, dims, symbols, workgroupSize);
};
populateRemoveSingleIterationLoopPattern(canonicalizationPatterns,
getThreadRangeFn);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(canonicalizationPatterns));
// Perform generic canonicalization.
applyCanonicalizationPatternsForTiling(context, funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling to invocations ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
{
OwningRewritePatternList tilingPatterns(&getContext());
auto marker = getLinalgMatchAndReplaceMarker(
getConvFilterTileMarker(), getVectorizeMarker(), context);
populateTilingConvFilterPatterns(context, tilingPatterns, launchConfig,
marker);
populateFoldGPUProcessorIDUsesPatterns(context, tilingPatterns);
tilingPatterns.insert<linalg::AffineMinSCFCanonicalizationPattern>(
context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPatterns));
applyCanonicalizationPatternsForTiling(context, funcOp);
LLVM_DEBUG({
llvm::dbgs() << "--- After tiling convolution filter ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
if (launchConfig.useVectorize()) {
{
OwningRewritePatternList vectorizationPatterns(&getContext());
populateVectorizationPatterns(context, launchConfig,
vectorizationPatterns);
populateVectorizeLinalgConvPatterns(context, vectorizationPatterns);
(void)applyPatternsAndFoldGreedily(funcOp,
std::move(vectorizationPatterns));
LLVM_DEBUG({
llvm::dbgs() << "--- After vectorization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
}
// TODO: This should be a folding of Add into Contract in core but while
// they live in different dialects, it is not possible without unnatural
// dependencies.
funcOp.walk([&](Operation *op) {
if (auto contract = canonicalizeContractionAdd(op))
op->replaceAllUsesWith(contract);
});
applyVectorTransformation(funcOp);
}
// Lower to loops the ops that were tiled to invocations but not vectorized.
// TODO(antiagainst): This is here now to simplify the interaction with
// ConvertToGPUPass, where we finally lower away all Linalg ops. Once that
// pass is cleaned up, we can invoke createConvertLinalgToLoopsPass
// directly.
{
RewritePatternSet patterns(context);
patterns
.add<LowerToLoops<linalg::BatchMatmulOp>,
LowerToLoops<linalg::ConvInputNWCFilterWCFOp>,
LowerToLoops<linalg::ConvInputNHWCFilterHWCFOp>,
LowerToLoops<linalg::ConvInputNDHWCFilterDHWCFOp>,
LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCFOp>,
LowerToLoops<linalg::DepthwiseConvInputNHWCFilterHWCOp>,
LowerToLoops<linalg::FillOp>, LowerToLoops<linalg::GenericOp>,
LowerToLoops<linalg::IndexedGenericOp>,
LowerToLoops<linalg::MatmulOp>,
LowerToLoops<linalg::PoolingNHWCMaxFOp>,
LowerToLoops<linalg::PoolingNHWCMinFOp>,
LowerToLoops<linalg::PoolingNHWCSumFOp>>(context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
launchConfig.finalize(funcOp);
}
}
//===----------------------------------------------------------------------===//
// Pass entry point and registration
//===----------------------------------------------------------------------===//
std::unique_ptr<OperationPass<IREE::HAL::ExecutableTargetOp>>
createTileAndVectorizeInOneWorkgroupPass(const SPIRVCodegenOptions &options) {
return std::make_unique<TileAndVectorizeInOneWorkgroupPass>(options);
}
static PassRegistration<TileAndVectorizeInOneWorkgroupPass> pass(
"iree-spirv-tile-and-vectorize-in-one-workgroup",
"Tile and vectorize Linalg operations on buffers in one workgroup", [] {
SPIRVCodegenOptions options = getSPIRVCodegenOptionsFromClOptions();
return std::make_unique<TileAndVectorizeInOneWorkgroupPass>(options);
});
} // namespace iree_compiler
} // namespace mlir