blob: 7bfc774223c4ae3cc7bfb482616df732d01397fb [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#define DEBUG_TYPE "iree-codegen-remove-trivial-loops"
namespace mlir {
namespace iree_compiler {
/// Maps a GPU processor dimension name to its numeric index:
/// "x" -> 0, "y" -> 1, "z" -> 2. No default case is provided, so any other
/// name falls off the end of the StringSwitch.
static unsigned dimToIndex(StringRef dim) {
  return llvm::StringSwitch<unsigned>(dim)
      .Case("x", 0)
      .Case("y", 1)
      .Case("z", 2);
}
/// Returns a constant [min, max] range for `processorValue` when it is
/// produced by one of the processor ops below (d is the op's dimension):
///   - gpu.thread_id                  -> [0, workgroupSize[d] - 1]
///   - gpu.block_dim                  -> [workgroupSize[d], workgroupSize[d]]
///   - HAL InterfaceWorkgroupIDOp     -> [0, workgroupCount[d] - 1]
///   - HAL InterfaceWorkgroupCountOp  -> [workgroupCount[d], workgroupCount[d]]
/// Returns llvm::None for any other value and for any dimension whose extent
/// is unknown: an index past the end of the array (e.g. an empty array), or
/// a non-positive entry. getNumWorkgroup leaves 0 for dimensions whose loop
/// bounds are not constant; reporting the empty range [0, -1] for those would
/// feed a bogus range into the single-iteration loop analysis.
static Optional<std::pair<AffineExpr, AffineExpr>> getWorkgroupRange(
    Value processorValue, SmallVectorImpl<Value> & /*dims*/,
    SmallVectorImpl<Value> & /*symbols*/, ArrayRef<int64_t> workgroupCount,
    ArrayRef<int64_t> workgroupSize) {
  // Returns extents[index] if it is present and strictly positive.
  auto getKnownExtent = [](ArrayRef<int64_t> extents,
                           uint64_t index) -> Optional<int64_t> {
    if (index >= extents.size() || extents[index] <= 0) return llvm::None;
    return extents[index];
  };
  OpBuilder builder(processorValue.getContext());
  // Range of an id taking every value in [0, extent).
  auto idRange = [&](int64_t extent) {
    AffineExpr zero = builder.getAffineConstantExpr(0);
    AffineExpr ubExpr = builder.getAffineConstantExpr(extent);
    return std::make_pair(zero, ubExpr - 1);
  };
  // Range of a value that is exactly `extent`.
  auto exactRange = [&](int64_t extent) {
    AffineExpr bound = builder.getAffineConstantExpr(extent);
    return std::make_pair(bound, bound);
  };

  if (auto idOp = processorValue.getDefiningOp<gpu::ThreadIdOp>()) {
    Optional<int64_t> size =
        getKnownExtent(workgroupSize, dimToIndex(idOp.dimension()));
    if (!size) return llvm::None;
    return idRange(*size);
  }
  if (auto dimOp = processorValue.getDefiningOp<gpu::BlockDimOp>()) {
    Optional<int64_t> size =
        getKnownExtent(workgroupSize, dimToIndex(dimOp.dimension()));
    if (!size) return llvm::None;
    return exactRange(*size);
  }
  if (auto idOp =
          processorValue.getDefiningOp<IREE::HAL::InterfaceWorkgroupIDOp>()) {
    Optional<int64_t> count =
        getKnownExtent(workgroupCount, idOp.dimension().getZExtValue());
    if (!count) return llvm::None;
    return idRange(*count);
  }
  if (auto countOp = processorValue
                         .getDefiningOp<IREE::HAL::InterfaceWorkgroupCountOp>()) {
    Optional<int64_t> count =
        getKnownExtent(workgroupCount, countOp.dimension().getZExtValue());
    if (!count) return llvm::None;
    return exactRange(*count);
  }
  return llvm::None;
}
/// Returns true if the given tiled loop is distributed to workgroups. This is
/// the case when its lower bound is a workgroup id, or an affine.apply that
/// takes a workgroup id among its operands.
static bool isWorkgroupLoop(const LoopTilingAndDistributionInfo &info) {
  auto forOp = cast<scf::ForOp>(info.loop);
  Operation *lbOp = forOp.getLowerBound().getDefiningOp();
  // The lower bound may be a block argument (no defining op). isa<> and
  // dyn_cast<> assert on null pointers, so bail out explicitly.
  if (!lbOp) return false;
  if (isa<IREE::HAL::InterfaceWorkgroupIDOp>(lbOp)) return true;
  auto applyOp = dyn_cast<AffineApplyOp>(lbOp);
  if (!applyOp) return false;
  return llvm::any_of(applyOp.getMapOperands(), [](Value operand) {
    return operand.getDefiningOp<IREE::HAL::InterfaceWorkgroupIDOp>();
  });
}
/// Infers the number of workgroups per dimension from the workgroup-
/// distributed loops and the translation info's workload per workgroup.
/// The result is padded with 1s up to kNumMaxParallelDims. An entry is 0 when
/// the workload for that dimension could not be determined (non-constant or
/// empty loop bounds). Returns an empty vector when inference is not possible
/// at all.
static SmallVector<int64_t> getNumWorkgroup(
    FuncOp funcOp, IREE::HAL::ExecutableEntryPointOp entryPointOp) {
  auto allLoops = getTiledAndDistributedLoopInfo(funcOp);
  auto wgLoops =
      llvm::to_vector<3>(llvm::make_filter_range(allLoops, isWorkgroupLoop));
  // Untiled workload per distributed dimension; stays 0 when the loop bounds
  // are not compile-time constants.
  SmallVector<int64_t> workloadSize(wgLoops.size());
  for (LoopTilingAndDistributionInfo &tileInfo : wgLoops) {
    if (tileInfo.processorDistributionDim >= workloadSize.size()) return {};
    if (!tileInfo.untiledLowerBound.is<Attribute>() ||
        !tileInfo.untiledUpperBound.is<Attribute>() ||
        !tileInfo.untiledStep.is<Attribute>()) {
      continue;
    }
    int64_t lb = tileInfo.untiledLowerBound.get<Attribute>()
                     .cast<IntegerAttr>()
                     .getInt();
    int64_t ub = tileInfo.untiledUpperBound.get<Attribute>()
                     .cast<IntegerAttr>()
                     .getInt();
    int64_t step =
        tileInfo.untiledStep.get<Attribute>().cast<IntegerAttr>().getInt();
    if (step == 0) return SmallVector<int64_t>();
    workloadSize[tileInfo.processorDistributionDim] = (ub - lb) / step;
  }
  auto translationInfo = getTranslationInfo(entryPointOp);
  if (!translationInfo) return SmallVector<int64_t>();
  SmallVector<int64_t> workloadPerWorkgroup =
      translationInfo.getWorkloadPerWorkgroupVals();
  if (workloadSize.size() != workloadPerWorkgroup.size()) {
    return SmallVector<int64_t>();
  }
  SmallVector<int64_t> numWorkgroups;
  for (auto pair : llvm::zip(workloadSize, workloadPerWorkgroup)) {
    int64_t workload = std::get<0>(pair);
    int64_t size = std::get<1>(pair);
    // divideCeil cannot divide by a non-positive size (0 asserts / divides
    // by zero), so give up on inference entirely.
    if (size <= 0) return SmallVector<int64_t>();
    // llvm::divideCeil takes unsigned (uint64_t) operands, so a negative
    // workload (ub < lb) would wrap to a huge count. Report 0 instead, the
    // same value used above for an unknown workload.
    int64_t count =
        workload > 0 ? static_cast<int64_t>(llvm::divideCeil(workload, size))
                     : 0;
    numWorkgroups.push_back(count);
  }
  numWorkgroups.resize(kNumMaxParallelDims, 1);
  return numWorkgroups;
}
/// Greedily applies the single-iteration loop removal patterns to `funcOp`.
/// Processor values are bounded through getWorkgroupRange, using
/// `workgroupSize` for thread ids/block dims and `numWorkgroups` for workgroup
/// ids/counts. Returns the result of the greedy pattern driver.
static LogicalResult removeOneTripTiledLoops(FuncOp funcOp,
                                             ArrayRef<int64_t> workgroupSize,
                                             ArrayRef<int64_t> numWorkgroups) {
  MLIRContext *context = funcOp.getContext();
  OwningRewritePatternList patterns(context);
  // The ArrayRefs are captured by value; they view vectors owned by the
  // caller, which outlive the pattern application below.
  populateRemoveSingleIterationLoopPattern(
      patterns, [=](Value processorValue, SmallVectorImpl<Value> &dims,
                    SmallVectorImpl<Value> &symbols) {
        return getWorkgroupRange(processorValue, dims, symbols, numWorkgroups,
                                 workgroupSize);
      });
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
namespace {
/// Function pass that runs removeOneTripTiledLoops on functions that have an
/// entry point. It uses the entry point's workgroup size and the workgroup
/// count inferred by getNumWorkgroup. Functions without an entry point are
/// left untouched.
class RemoveSingleIterationLoopPass final
    : public RemoveSingleIterationLoopBase<RemoveSingleIterationLoopPass> {
  void runOnOperation() override {
    FuncOp func = getOperation();
    auto entryPoint = getEntryPoint(func);
    if (!entryPoint) return;

    SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPoint);
    SmallVector<int64_t> numWorkgroups = getNumWorkgroup(func, entryPoint);
    LogicalResult status =
        removeOneTripTiledLoops(func, workgroupSize, numWorkgroups);
    if (failed(status)) signalPassFailure();
  }
};
}  // namespace
/// Creates a new instance of the single-iteration loop removal pass.
std::unique_ptr<OperationPass<FuncOp>> createRemoveSingleIterationLoopPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<RemoveSingleIterationLoopPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir