// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/compiler/Dialect/HAL/Target/SPIRVCommon/SPIRVTarget.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"

namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace HAL {

// Records a full execution barrier that forces visibility of all buffers.
static void recordFullExecutionBarrier(Value commandBuffer, Location loc,
                                       OpBuilder &builder) {
  Value memoryBarrier = builder.create<IREE::HAL::MakeMemoryBarrierOp>(
      loc, IREE::HAL::AccessScopeBitfield::DispatchWrite,
      IREE::HAL::AccessScopeBitfield::DispatchRead);
  builder.create<IREE::HAL::CommandBufferExecutionBarrierOp>(
      loc, commandBuffer, IREE::HAL::ExecutionStageBitfield::Dispatch,
      IREE::HAL::ExecutionStageBitfield::Dispatch,
      ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
}

/// The code generation emits a function `numWorkgroupsFn` for each entry point
/// function. This function takes as arguments the !shapex.ranked_shape values
/// for all the input and output shaped types and returns the number of
/// workgroups to use. To use this function on the host side, generate the
/// !shapex.ranked_shape values that describe the shapes of the inputs and
/// outputs of the dispatch region and "inline" the function body.
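/// Schematically (all names and shapes below are illustrative only), such a
/// generated function has the form
///
///   func @entry_fn__num_workgroups__(%arg0: !shapex.ranked_shape<[?,4]>,
///                                    %arg1: !shapex.ranked_shape<[?]>)
///       -> (index, index, index)
///
/// where the body derives the workgroup counts from the individual shape
/// dimensions (queried via Shape::RankedDimOp) and returns them as three
/// index values.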
static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
    Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
    ArrayRef<Optional<TensorRewriteAdaptor>> operands,
    ArrayRef<Optional<TensorRewriteAdaptor>> results,
    ConversionPatternRewriter &rewriter) {
  std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
  // TODO: This is really just inlining a function. For now assume that the
  // `numWorkgroupsFn` has a single block to make inlining easier.
  if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
    return returnValue;
  SmallVector<SmallVector<Value, 4>, 4> shapeValues;
  shapeValues.reserve(operands.size() + results.size());
  auto getShapeValuesFn =
      [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
    for (auto val : values) {
      if (!val) continue;
      Optional<SmallVector<Value, 4>> shape = val->getShapeDims(rewriter);
      if (!shape) return emitError(loc, "shape computation for operand failed");
      shapeValues.push_back(shape.getValue());
    }
    return success();
  };
  if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
    return returnValue;
  BlockAndValueMapping mapper;
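  // Walk the single block of `numWorkgroupsFn` and "inline" it by hand:
  // Shape::RankedDimOp results on the block arguments are mapped to the shape
  // dimension values computed above, any other op whose operands are all
  // mapped is cloned, and the (up to three) values feeding the return become
  // the workgroup counts.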
  for (Operation &op : numWorkgroupsFn.front()) {
    if (isa<mlir::ReturnOp>(op)) {
      for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
           i != e; ++i) {
        returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
      }
      break;
    }
    if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
      if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
        auto &dimValues = shapeValues[arg.getArgNumber()];
        mapper.map(shapeOp.result(), dimValues[shapeOp.getIndex()]);
        continue;
      }
      return returnValue;
    }
    // If all its operands are mapped, clone it.
    if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
          return mapper.contains(operand);
        })) {
      rewriter.clone(op, mapper);
      continue;
    }
  }
  return returnValue;
}

SPIRVTargetBackend::SPIRVTargetBackend(SPIRVCodegenOptions options)
    : spvCodeGenOptions_(std::move(options)) {}

void SPIRVTargetBackend::declareTargetOpsForEnv(
    IREE::Flow::ExecutableOp sourceOp, IREE::HAL::ExecutableOp executableOp,
    spirv::TargetEnvAttr spvTargetEnv) {
  auto targetBuilder = OpBuilder::atBlockTerminator(&executableOp.getBlock());
  auto targetOp = targetBuilder.create<IREE::HAL::ExecutableTargetOp>(
      sourceOp.getLoc(), name(), filter_pattern());
  auto containerBuilder = OpBuilder::atBlockTerminator(&targetOp.getBlock());
  auto innerModuleOp = containerBuilder.create<ModuleOp>(sourceOp.getLoc());
  // Attach the SPIR-V target environment to the target's ModuleOp.
  // If we had multiple target environments, we would generate one target op
  // per environment, with each setting its own environment attribute.
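  // For example (the version, capabilities, and limits below are illustrative
  // only, not what this backend necessarily produces), the attached attribute
  // may look something like:
  //   #spv.target_env<#spv.vce<v1.3, [Shader], []>,
  //                   {max_compute_workgroup_invocations = 128 : i32, ...}>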
  innerModuleOp.setAttr(spirv::getTargetEnvAttrName(), spvTargetEnv);
}

void SPIRVTargetBackend::buildTranslationPassPipeline(
    IREE::HAL::ExecutableTargetOp targetOp, OpPassManager &passManager) {
  buildSPIRVTransformPassPipeline(passManager, spvCodeGenOptions_);
}

LogicalResult SPIRVTargetBackend::recordDispatch(
    Location loc, DispatchState dispatchState,
    DeviceSwitchRewriter &switchRewriter) {
  // Multiple entry points might be generated for a single dispatch function.
  // Under such circumstances, a special attribute indicates the schedule of
  // the split entry points. Check whether such a schedule attribute exists
  // first.
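  // When present, the schedule attribute is an ArrayAttr of StringAttrs that
  // names the entry point functions in the order they must be dispatched.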
  ArrayAttr entryPointScheduleAttr;
  spirv::ModuleOp spvModuleOp;
  IREE::HAL::ExecutableOp executableOp = dispatchState.executableOp;
  for (auto executableTargetOp :
       executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
    if (matchPattern(executableTargetOp.target_backend_filter(),
                     filter_pattern())) {
      ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
      auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
      assert(llvm::hasSingleElement(spvModuleOps));
      spvModuleOp = *spvModuleOps.begin();
      entryPointScheduleAttr = innerModuleOp.getAttrOfType<ArrayAttr>(
          iree_compiler::getEntryPointScheduleAttrName());
      break;
    }
  }
  if (!spvModuleOp) return executableOp.emitError("unable to find spv.module");

  SmallVector<spirv::FuncOp, 2> spvEntryPointFns;
  if (!entryPointScheduleAttr) {
    for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
      if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
          SymbolTable::Visibility::Public)
        spvEntryPointFns.push_back(spvFuncOp);
    }
    if (!llvm::hasSingleElement(spvEntryPointFns)) {
      return spvModuleOp.emitError(
                 "expected a single entry point function, found ")
             << spvEntryPointFns.size();
    }
  } else {
    llvm::StringMap<spirv::FuncOp> publicFns;
    for (spirv::FuncOp spvFuncOp : spvModuleOp.getOps<spirv::FuncOp>()) {
      if (SymbolTable::getSymbolVisibility(spvFuncOp) ==
          SymbolTable::Visibility::Public)
        publicFns[spvFuncOp.sym_name()] = spvFuncOp;
    }
    for (Attribute entryNameAttr : entryPointScheduleAttr) {
      StringRef entryName = entryNameAttr.cast<StringAttr>().getValue();
      spirv::FuncOp spvFuncOp = publicFns.lookup(entryName);
      if (!spvFuncOp)
        return spvModuleOp.emitError("unable to find entry point function ")
               << entryName;
      spvEntryPointFns.push_back(spvFuncOp);
    }
  }

  auto *region = switchRewriter.addConditionRegion(
      IREE::HAL::DeviceMatchIDAttr::get(filter_pattern(), loc.getContext()),
      {
          dispatchState.workload,
          dispatchState.commandBuffer,
      });
  auto &entryBlock = region->front();
  ConversionPatternRewriter &rewriter = switchRewriter.getRewriter();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(&entryBlock);
  auto commandBuffer = entryBlock.getArgument(1);

  // We may have multiple entry points to dispatch. Record them in the order
  // specified by the entry point schedule and insert a full execution barrier
  // between consecutive ones so the writes of one dispatch are visible to the
  // next.
  for (auto it : llvm::enumerate(spvEntryPointFns)) {
    spirv::FuncOp spvFuncOp = it.value();
    FlatSymbolRefAttr numWorkgroupsFnAttr =
        spvFuncOp.getAttrOfType<FlatSymbolRefAttr>(
            getNumWorkgroupsFnAttrName());
    if (!numWorkgroupsFnAttr) {
      return spvFuncOp.emitError(
          "expected vkspv.num_workgroups_fn attribute to refer to function "
          "that computes number of workgroups to use");
    }
    std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
    FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
        spvFuncOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
    if (!numWorkgroupsFn) {
      return spvFuncOp.emitError("unable to find function ")
             << numWorkgroupsFnAttr
             << " that computes the number of workgroups to use";
    }
    workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
        loc, numWorkgroupsFn, dispatchState.executableOp.getFirstInterfaceOp(),
        dispatchState.operands, dispatchState.results, rewriter);
    if (llvm::any_of(workgroupCount,
                     [](Value v) -> bool { return v == nullptr; }))
      return spvFuncOp.emitError("unable to find workgroup count");
    // Ordinals are fixed based on the precomputed schedule, so use
    // CommandBufferDispatchOp instead of CommandBufferDispatchSymbolOp.
    auto executable = rewriter
                          .create<IREE::HAL::ExecutableLookupOp>(
                              loc, dispatchState.device,
                              dispatchState.dispatchOp.executable())
                          .getResult();
    int32_t entryPointOrdinal = it.index();
    rewriter.create<IREE::HAL::CommandBufferDispatchOp>(
        loc, commandBuffer, executable,
        rewriter.getI32IntegerAttr(entryPointOrdinal), workgroupCount[0],
        workgroupCount[1], workgroupCount[2]);
    if (it.index() + 1 != spvEntryPointFns.size()) {
      recordFullExecutionBarrier(commandBuffer, loc, rewriter);
    }
  }
  rewriter.create<IREE::HAL::ReturnOp>(loc);
  return success();
}

// Finds the spv.ExecutionMode operation to get the workgroup size from.
// TODO(ravishankarm): This might not be the only way this is specified. You
// could also have a spec constant, but that is not generated in the
// spv.module right now.
// TODO(ravishankarm): change workgroup size calculation to something we can
// query independently so that we don't need to look up the value here.
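// Schematically, the op being matched has the form (entry point name and sizes
// below are illustrative only):
//   spv.ExecutionMode @entry_fn "LocalSize", 32, 4, 1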
std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
    Location loc, IREE::HAL::ExecutableOp executableOp,
    IREE::HAL::ExecutableEntryPointOp entryPointOp, Value workload,
    OpBuilder &builder) {
  // TODO(ravishankarm): possibly emit different recordDispatch logic if the
  // workgroup sizes differ among targets.
  spirv::ModuleOp spvModuleOp;
  for (auto executableTargetOp :
       executableOp.getBlock().getOps<IREE::HAL::ExecutableTargetOp>()) {
    if (matchPattern(executableTargetOp.target_backend_filter(),
                     filter_pattern())) {
      ModuleOp innerModuleOp = executableTargetOp.getInnerModule();
      assert(!innerModuleOp.getAttr(
          iree_compiler::getEntryPointScheduleAttrName()));
      auto spvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
      assert(llvm::hasSingleElement(spvModuleOps));
      spvModuleOp = *spvModuleOps.begin();
      break;
    }
  }
  return calculateDispatchWorkgroupSize(
      loc, spvModuleOp, entryPointOp.sym_name(), workload, builder);
}

std::array<Value, 3> SPIRVTargetBackend::calculateDispatchWorkgroupSize(
    Location loc, spirv::ModuleOp spvModuleOp, StringRef entryPointName,
    Value workload, OpBuilder &builder) {
  std::array<Value, 3> workgroupSize;
  for (auto executionModeOp :
       spvModuleOp.getBlock().getOps<spirv::ExecutionModeOp>()) {
    if (executionModeOp.fn() == entryPointName &&
        executionModeOp.execution_mode() == spirv::ExecutionMode::LocalSize) {
      for (int i = 0; i < executionModeOp.values().size(); ++i) {
        workgroupSize[i] =
            builder.create<ConstantIndexOp>(loc, executionModeOp.values()[i]
                                                     .cast<IntegerAttr>()
                                                     .getValue()
                                                     .getZExtValue());
      }
      break;
    }
  }
  // Pad out the workgroup size with 1's (if the original rank was < 3).
  for (int i = 0; i < workgroupSize.size(); ++i) {
    if (!workgroupSize[i]) {
      workgroupSize[i] = builder.create<ConstantIndexOp>(loc, 1);
    }
  }
  return workgroupSize;
}

} // namespace HAL
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir