blob: be9760b81a4c0cded3683143c19c6ae6a73d880d [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Conversion/CodegenUtils/GetNumWorkgroups.h"
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Utils/TypeUtils.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#define DEBUG_TYPE "workgroup-calculation"
namespace mlir {
namespace iree_compiler {
/// Looks up the function referenced by the symbol-reference attribute named
/// `numWorkgroupsFnAttr` on `entryPointFn`. The symbol is resolved in the
/// `ModuleOp` enclosing `entryPointFn`.
///
/// Returns a null `FuncOp` (after emitting an error on `entryPointFn`) when
/// the attribute is missing or does not resolve to a `FuncOp`.
FuncOp getNumWorkgroupsFn(FuncOp entryPointFn,
                          llvm::StringRef numWorkgroupsFnAttr) {
  SymbolRefAttr fnSymbol =
      entryPointFn.getAttrOfType<SymbolRefAttr>(numWorkgroupsFnAttr);
  if (!fnSymbol) {
    entryPointFn.emitError("missing attribute '") << numWorkgroupsFnAttr << "'";
    return nullptr;
  }

  // Resolve the symbol relative to the module that contains the entry point.
  Operation *symbolOp = SymbolTable::lookupSymbolIn(
      entryPointFn.getParentOfType<ModuleOp>(), fnSymbol);
  if (FuncOp resolvedFn = dyn_cast_or_null<FuncOp>(symbolOp)) {
    return resolvedFn;
  }

  entryPointFn.emitError("unable to find num workgroups fn ") << fnSymbol;
  return nullptr;
}
// TODO: This method is templated on the builder type since the `OpBuilder`
// doesnt have an erase method. Just erasing the op leads to segfaults when the
// builder is `PatternRewriter` since the rewriter doesn't know the op was
// deleted. This can be simplified a lot when this issue is fixed.
template <typename BuilderTy>
static void eraseOp(BuilderTy &builder, Operation *op) {
builder.eraseOp(op);
}
template <>
void eraseOp(OpBuilder &builder, Operation *op) {
op->erase();
}
/// Computes the upper bounds of the loops of `linalgOp`, emitting the IR that
/// computes them into `numWorkgroupsFn`.
///
/// `numWorkgroupsFn` must be empty on entry. A block is created in it whose
/// arguments match the function's input types, the backward slice of
/// `linalgOp` is cloned into that block, and a temporary clone of `linalgOp`
/// is used to create the loop ranges. The temporary clone is erased before
/// returning on every path.
///
/// The builder's insertion point is changed by this function; the callers in
/// this file wrap the call in an `OpBuilder::InsertionGuard`.
///
/// Returns the `size` of each loop range, or `llvm::None` (after emitting an
/// error) if `numWorkgroupsFn` is not empty or if any loop range has a
/// non-zero offset or a non-unit stride.
template <typename BuilderTy>
static Optional<SmallVector<Value, 4>> getLoopUpperBounds(
    BuilderTy &builder, Location loc, FuncOp numWorkgroupsFn,
    linalg::LinalgOp linalgOp) {
  if (!numWorkgroupsFn.empty()) {
    numWorkgroupsFn.emitError("num workgroups fn expected to be empty");
    return llvm::None;
  }
  LLVM_DEBUG({
    llvm::dbgs() << "Found num workgroups function : "
                 << numWorkgroupsFn.getName() << "\n";
  });

  builder.createBlock(&numWorkgroupsFn.getBody(), /*insertPt=*/{},
                      numWorkgroupsFn.getType().getInputs());

  // Clone everything the linalg operation transitively depends on.
  llvm::SetVector<Operation *> slice;
  getBackwardSlice(linalgOp, &slice);
  BlockAndValueMapping mapper;
  for (Operation *op : slice) {
    builder.clone(*op, mapper);
  }

  // Clone the linalg operation just to compute the loop bounds.
  linalg::LinalgOp clonedLinalgOp =
      builder.clone(*linalgOp.getOperation(), mapper);
  auto loopRange = clonedLinalgOp.createLoopRanges(builder, loc);

  // The clone is only needed to create the ranges. Erase it here so that it is
  // removed on the error path below as well as on the success path.
  eraseOp<BuilderTy>(builder, clonedLinalgOp);

  if (llvm::any_of(loopRange, [](Range range) {
        return !matchPattern(range.stride, m_One()) ||
               !matchPattern(range.offset, m_Zero());
      })) {
    linalgOp.emitError("unhandled non-unit stride loop range");
    return llvm::None;
  }

  SmallVector<Value, 4> bounds = llvm::to_vector<4>(
      llvm::map_range(loopRange, [](Range range) { return range.size; }));
  return bounds;
}
/// Emits IR computing `(numerator + (denominator - 1)) / denominator` using a
/// signed integer division, i.e. ceil(`numerator` / `denominator`) for a
/// non-negative numerator and a positive denominator.
static Value buildCeilDiv(OpBuilder &builder, Location loc, Value numerator,
                          Value denominator) {
  Value unit = builder.create<ConstantIndexOp>(loc, 1);
  Value denominatorMinusOne = builder.create<SubIOp>(loc, denominator, unit);
  Value roundedUpNumerator =
      builder.create<AddIOp>(loc, numerator, denominatorMinusOne);
  return builder.create<SignedDivIOp>(loc, roundedUpNumerator, denominator);
}
/// Overload of `buildCeilDiv` for a compile-time constant `denominator`: the
/// constant is materialized as an index constant and the computation is
/// delegated to the `Value`-denominator overload.
static Value buildCeilDiv(OpBuilder &builder, Location loc, Value numerator,
                          int64_t denominator) {
  Value denominatorValue =
      builder.create<ConstantIndexOp>(loc, denominator).getResult();
  return buildCeilDiv(builder, loc, numerator, denominatorValue);
}
/// Populates the num-workgroups function associated with the function that
/// contains `linalgOp` so that it returns three index values.
///
/// For every loop `i` that has a non-zero entry in `tileSizes` and is listed
/// in `distributedLoops`, ceil(loop bound / tileSizes[i]) is computed. These
/// values are returned in reverse loop order, then truncated or padded with
/// the constant 1 to exactly three results.
///
/// NOTE: `entryPointFn` is currently unused; the function enclosing `linalgOp`
/// is the one queried for the `numWorkgroupsFnAttr` attribute.
///
/// Returns failure if the num-workgroups function cannot be found, if the
/// loop bounds cannot be computed, or if a distributed loop index does not
/// correspond to a loop of `linalgOp`.
template <class BuilderTy>
static LogicalResult createNumWorkgroupsFromResultShapeImpl(
    BuilderTy &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes,
    ArrayRef<unsigned> distributedLoops) {
  FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
      linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
  if (!numWorkgroupsFn) return failure();

  Location loc = linalgOp.getLoc();
  // getLoopUpperBounds moves the insertion point into `numWorkgroupsFn`;
  // restore it when leaving this function.
  OpBuilder::InsertionGuard guard(builder);
  auto loopRange = getLoopUpperBounds(builder, loc, numWorkgroupsFn, linalgOp);
  if (!loopRange) return failure();

  SmallVector<Value, 4> numWorkgroups;
  DenseSet<unsigned> distributedLoopsSet(distributedLoops.begin(),
                                         distributedLoops.end());
  for (auto size : enumerate(tileSizes)) {
    if (!size.value() || !distributedLoopsSet.count(size.index())) continue;
    // Guard against tile sizes / distributed loops that refer to a loop the
    // operation does not have, which would index past the computed bounds.
    if (size.index() >= loopRange->size()) {
      linalgOp.emitError(
          "distributed loop index exceeds the number of loops of the "
          "operation");
      return failure();
    }
    Value num =
        buildCeilDiv(builder, loc, (*loopRange)[size.index()], size.value());
    numWorkgroups.push_back(num);
  }

  // Reverse the order so that the last distributed loop is returned first.
  SmallVector<Value, 4> resultValues =
      llvm::to_vector<4>(llvm::reverse(numWorkgroups));
  Value one = builder.template create<ConstantIndexOp>(loc, 1);
  // Exactly three values are returned: missing entries are filled with 1 and
  // any entries beyond the third are dropped.
  resultValues.resize(3, one);
  builder.template create<mlir::ReturnOp>(loc, resultValues);
  return success();
}
/// `OpBuilder` entry point: forwards all arguments, including the explicit
/// list of `distributedLoops`, to the shared implementation.
LogicalResult createNumWorkgroupsFromResultShape(
    OpBuilder &builder, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes,
    ArrayRef<unsigned> distributedLoops) {
  // `builder` has static type `OpBuilder`, so the implementation is
  // instantiated for `OpBuilder`.
  LogicalResult status = createNumWorkgroupsFromResultShapeImpl(
      builder, linalgOp, entryPointFn, numWorkgroupsFnAttr, tileSizes,
      distributedLoops);
  return status;
}
/// `PatternRewriter` entry point: distributes the outer parallel loops of
/// `linalgOp` (at most the first three) and forwards to the shared
/// implementation.
LogicalResult createNumWorkgroupsFromResultShape(
    PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
    llvm::StringRef numWorkgroupsFnAttr, ArrayRef<int64_t> tileSizes) {
  // Distributed loops are 0, 1, ..., up to (but excluding) min(3, number of
  // outer parallel loops).
  const unsigned numDistributedLoops =
      std::min<unsigned>(3, getNumOuterParallelLoops(linalgOp));
  SmallVector<unsigned, 4> distributedLoops;
  for (unsigned loopIdx = 0; loopIdx < numDistributedLoops; ++loopIdx) {
    distributedLoops.push_back(loopIdx);
  }
  // `rewriter` has static type `PatternRewriter`, so the implementation is
  // instantiated for `PatternRewriter`.
  return createNumWorkgroupsFromResultShapeImpl(
      rewriter, linalgOp, entryPointFn, numWorkgroupsFnAttr, tileSizes,
      distributedLoops);
}
/// Populates the num-workgroups function associated with the function that
/// contains `linalgOp` so that it returns
/// {ceil(product of outer parallel loop bounds / `workgroupSizeX`), 1, 1}.
///
/// If the num-workgroups function already has a body, it is left untouched
/// and success is returned. Returns failure if the function cannot be found
/// or the loop bounds cannot be computed.
LogicalResult createNumWorkgroupsFromLinearizedResultShape(
    ConversionPatternRewriter &rewriter, linalg::LinalgOp linalgOp,
    FuncOp entryPointFn, llvm::StringRef numWorkgroupsFnAttr,
    int64_t workgroupSizeX) {
  FuncOp numWorkgroupsFn = getNumWorkgroupsFn(
      linalgOp.getParentOfType<FuncOp>(), numWorkgroupsFnAttr);
  if (!numWorkgroupsFn) return failure();

  // TODO(ravishankarm): A dispatch region can contain multiple linalg
  // operations (typically linalg.generic operations) that have the same
  // workload. In that case the first linalg.generic creates the body of the
  // number-of-workgroups function. For now, simply return if the body is not
  // empty, assuming it is correct for all the ops in the dispatch region.
  // This needs to be enforced somehow.
  if (!numWorkgroupsFn.empty()) return success();

  Location loc = linalgOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  Optional<SmallVector<Value, 4>> loopRange =
      getLoopUpperBounds(rewriter, loc, numWorkgroupsFn, linalgOp);
  if (!loopRange) return failure();

  unsigned numParallelLoops = getNumOuterParallelLoops(linalgOp);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);

  // Multiply together the bounds of the outer parallel loops, starting at 1.
  Value linearizedSize = one;
  for (Value bound :
       ArrayRef<Value>(*loopRange).take_front(numParallelLoops)) {
    linearizedSize = rewriter.create<MulIOp>(loc, bound, linearizedSize);
  }

  Value numWorkgroupsX =
      buildCeilDiv(rewriter, loc, linearizedSize, workgroupSizeX);
  SmallVector<Value, 3> returnValues = {numWorkgroupsX, one, one};
  rewriter.create<mlir::ReturnOp>(loc, returnValues);
  return success();
}
/// The code generation emits a function `numWorkgroupsFn` for each entry point
/// function. Its arguments are the !shapex.ranked_shape values for all the
/// input and output shaped types, and it returns the number of workgroups to
/// use. To use this function on the host side, generate the shape dimension
/// values that describe the inputs and outputs of the dispatch region and
/// "inline" the function body.
///
/// The shape dimensions of every present entry of `operands`, followed by
/// every present entry of `results`, are collected; absent (`None`) entries
/// are skipped. The i-th collected shape stands in for the i-th block argument
/// of `numWorkgroupsFn`. Each `Shape::RankedDimOp` on a block argument is
/// replaced by the matching dimension value, and every other operation whose
/// operands are all already mapped is cloned through `rewriter`.
///
/// Returns the (up to three) mapped operands of the function's return
/// operation. Entries are null if they could not be mapped or if the return
/// has fewer than three operands. All three entries are null if
/// `numWorkgroupsFn` is null or does not have exactly one block, if a shape
/// computation fails, or if a `Shape::RankedDimOp` refers to something other
/// than an in-range dimension of a known block argument.
///
/// NOTE: `interface` is currently unused.
std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
    Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
    ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> operands,
    ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> results,
    ConversionPatternRewriter &rewriter) {
  std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
  // TODO: This is really just inlining a function. For now assume that the
  // `numWorkgroupsFn` has a single block to make inlining easier.
  if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
    return returnValue;

  SmallVector<SmallVector<Value, 4>, 4> shapeValues;
  shapeValues.reserve(operands.size() + results.size());
  auto getShapeValuesFn =
      [&](ArrayRef<Optional<IREE::HAL::TensorRewriteAdaptor>> values)
      -> LogicalResult {
    for (auto val : values) {
      if (!val) continue;
      Optional<SmallVector<Value, 4>> shape = val->getShapeDims(rewriter);
      if (!shape) return emitError(loc, "shape computation for operand failed");
      shapeValues.push_back(shape.getValue());
    }
    return success();
  };
  if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
    return returnValue;

  BlockAndValueMapping mapper;
  for (Operation &op : numWorkgroupsFn.front()) {
    if (isa<mlir::ReturnOp>(op)) {
      for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
           i != e; ++i) {
        returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
      }
      break;
    }
    if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
      BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>();
      // `returnValue` is still all-null here, which signals failure.
      if (!arg) return returnValue;
      // The function may have more arguments than there are collected shapes;
      // do not index past the end.
      if (arg.getArgNumber() >= shapeValues.size()) return returnValue;
      auto &dimValues = shapeValues[arg.getArgNumber()];
      // Likewise reject a dimension index outside the rank of the shape. The
      // cast to unsigned also rejects a negative index.
      auto dimIndex = shapeOp.getIndex();
      if (static_cast<uint64_t>(dimIndex) >= dimValues.size())
        return returnValue;
      mapper.map(shapeOp.result(), dimValues[dimIndex]);
      continue;
    }
    // If all its operands are mapped, clone it. Operations with unmapped
    // operands are skipped; values depending on them stay unmapped.
    if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
          return mapper.contains(operand);
        })) {
      rewriter.clone(op, mapper);
    }
  }
  return returnValue;
}
} // namespace iree_compiler
} // namespace mlir