blob: e3f2a48934a22f073e08ab46d032718764cc44ed [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- ConvertToSPIRVPass.cpp - Performs the final SPIR-V conversion ------===//
//
// This file implements a pass to perform the final conversion to SPIR-V.
// This pass converts remaining interface ops into SPIR-V global variables,
// GPU processor ID ops into SPIR-V global variables, loop/standard ops into
// corresponding SPIR-V ops.
//
//===----------------------------------------------------------------------===//
#include <tuple>
#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace iree_compiler {
namespace {
//===----------------------------------------------------------------------===//
// Resource utilities
//===----------------------------------------------------------------------===//
/// Map from hal.interface.binding.subspan ops to their corresponding
/// spv.GlobalVariable ops.
///
/// Keys are the generic `Operation *` handles of the subspan ops. The map is
/// filled by `createResourceVariables` and read by
/// `HALInterfaceBindingSubspanConverter` when it rewrites each subspan op.
using InterfaceResourceMap =
llvm::DenseMap<Operation *, spirv::GlobalVariableOp>;
/// Builds a spv.GlobalVariable of the given `type` bound to `set` and
/// `binding`, and inserts it at the very beginning of `moduleOp`'s body
/// through `symbolTable`. When `alias` is true the variable additionally
/// carries a unit `aliased` attribute.
///
/// Returns the newly created variable op.
spirv::GlobalVariableOp createResourceVariable(Location loc, Type type,
                                               unsigned set, unsigned binding,
                                               bool alias, ModuleOp moduleOp,
                                               SymbolTable *symbolTable) {
  // The builder is given no insertion point: the op is created detached and
  // is placed into the module by the symbol table insertion below.
  OpBuilder builder(moduleOp.getContext());

  const std::string symbolName =
      llvm::formatv("__resource_var_{0}_{1}_", set, binding);
  spirv::GlobalVariableOp resourceVar =
      builder.create<spirv::GlobalVariableOp>(loc, type, symbolName, set,
                                              binding);
  if (alias) {
    resourceVar->setAttr("aliased", builder.getUnitAttr());
  }

  symbolTable->insert(resourceVar, moduleOp.getBody()->begin());
  return resourceVar;
}
/// Queries the hal.interface.binding op referenced by the given
/// hal.interface.binding.subspan `op` and returns its (set, binding) pair.
///
/// `op` must be a hal.interface.binding.subspan op; this is enforced with
/// `cast`.
std::pair<int32_t, int32_t> getInterfaceSetAndBinding(Operation *op) {
  auto subspanOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
  IREE::HAL::InterfaceBindingOp bindingOp = subspanOp.queryBindingOp();
  const int32_t set = bindingOp.set().getSExtValue();
  const int32_t binding = bindingOp.binding().getSExtValue();
  return {set, binding};
}
/// Visits every function in `module`, collects the
/// hal.interface.binding.subspan ops that still have uses, and makes sure each
/// of them is backed by a spv.GlobalVariable. Returns the map from subspan op
/// to its variable.
///
/// The variables are created with the subspan ops' own (unconverted) result
/// types; those types need to be fixed up later.
InterfaceResourceMap createResourceVariables(mlir::ModuleOp module) {
  using SetBinding = std::pair<uint32_t, uint32_t>;

  InterfaceResourceMap resourceMap;
  SymbolTable symbolTable(module);

  // Snapshot the functions first, then visit them back to front.
  auto funcOps = llvm::to_vector<1>(module.getOps<FuncOp>());
  for (FuncOp funcOp : llvm::reverse(funcOps)) {
    // Parallel vectors (rather than a map) holding the used subspan ops of
    // this function and their (set, binding) pairs, so that the iteration
    // order below is deterministic.
    SmallVector<IREE::HAL::InterfaceBindingSubspanOp, 8> subspanOps;
    SmallVector<SetBinding, 8> setBindings;
    // All distinct result types seen per (set, binding) pair. More than one
    // type means several SPIR-V global variables share that pair.
    llvm::DenseMap<SetBinding, llvm::DenseSet<Type>> typesPerSetBinding;

    funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp subspanOp) {
      if (subspanOp.use_empty()) return;
      SetBinding setBinding = getInterfaceSetAndBinding(subspanOp);
      subspanOps.push_back(subspanOp);
      setBindings.push_back(setBinding);
      typesPerSetBinding[setBinding].insert(subspanOp.getType());
    });

    // Variables created so far for this function, keyed by
    // (set, binding, type). Sharing them between identical subspan ops keeps
    // the generated SPIR-V blob smaller.
    llvm::DenseMap<std::tuple<uint32_t, uint32_t, Type>,
                   spirv::GlobalVariableOp>
        createdVars;

    // Go through the collected ops from last to first.
    for (size_t idx = subspanOps.size(); idx-- > 0;) {
      IREE::HAL::InterfaceBindingSubspanOp subspanOp = subspanOps[idx];
      const SetBinding &setBinding = setBindings[idx];
      Type resourceType = subspanOp.getType();

      auto key =
          std::make_tuple(setBinding.first, setBinding.second, resourceType);
      // Default-inserts a null variable op the first time a key is seen.
      spirv::GlobalVariableOp &resourceVar = createdVars[key];
      if (!resourceVar) {
        // Several SPIR-V global variables bound to the same (set, binding)
        // pair and used in the same function need the alias decoration.
        bool needsAlias = typesPerSetBinding[setBinding].size() > 1;
        // The variable is created with the interface op's type. That is fine
        // because the correctness boundary is the pass: the type is fixed up
        // during conversion so it won't leak.
        resourceVar = createResourceVariable(
            subspanOp.getLoc(), resourceType, setBinding.first,
            setBinding.second, needsAlias, module, &symbolTable);
      }
      resourceMap[subspanOp] = resourceVar;
    }
  }

  return resourceMap;
}
} // namespace
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
namespace {
/// Converts hal.interface.load.constant into the SPIR-V ops that read the
/// requested element from the global variable representing the push constant
/// storage.
struct HALInterfaceLoadConstantConverter final
    : public OpConversionPattern<IREE::HAL::InterfaceLoadConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      IREE::HAL::InterfaceLoadConstantOp loadOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // TODO(#1519): hal.interface.load.constant should point to the
    // hal.interface op.
    // Until then, the enclosing module is expected to contain exactly one
    // hal.interface op that declares push constants.
    ModuleOp parentModule = loadOp->getParentOfType<ModuleOp>();
    auto interfaceOps =
        llvm::to_vector<1>(parentModule.getOps<IREE::HAL::InterfaceOp>());
    assert(interfaceOps.size() == 1);

    IREE::HAL::InterfaceOp interfaceOp = interfaceOps.front();
    auto pushConstants = interfaceOp.push_constants();
    assert(pushConstants.hasValue());

    const uint64_t elementCount = pushConstants->getZExtValue();
    const unsigned offset = loadOp.offset().getZExtValue();

    // The helper below emits SPIR-V ops with i32 types, so the
    // index -> i32 type "conversion" happens implicitly here.
    auto i32Type = rewriter.getIntegerType(32);
    Value pushConstantValue = spirv::getPushConstantValue(
        loadOp, elementCount, offset, i32Type, rewriter);

    rewriter.replaceOp(loadOp, pushConstantValue);
    return success();
  }
};
/// Converts hal.interface.workgroup.id/count into a read of the matching
/// SPIR-V builtin followed by an extraction of the requested dimension.
///
/// `InterfaceOpTy` is the HAL op to convert and `builtin` the SPIR-V builtin
/// that provides its value.
template <typename InterfaceOpTy, spirv::BuiltIn builtin>
struct HALInterfaceWorkgroupIdAndCountConverter final
    : public OpConversionPattern<InterfaceOpTy> {
  using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      InterfaceOpTy op, typename InterfaceOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The op's dimension selects which component of the builtin to extract.
    const auto dimension =
        static_cast<int32_t>(op.dimension().getSExtValue());

    auto i32Type = rewriter.getIntegerType(32);
    Value builtinValue =
        spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);

    ArrayAttr extractIndices = rewriter.getI32ArrayAttr({dimension});
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        op, i32Type, builtinValue, extractIndices);
    return success();
  }
};
/// Converts hal.interface.binding.subspan into a spv.mlir.addressof-style
/// reference (spirv::AddressOfOp) to the global variable that represents the
/// resource buffer.
///
/// The variables are looked up in the `interfaceToResourceVars` map handed to
/// the constructor. The pattern only stores a reference to that map, so the map
/// must outlive the pattern.
struct HALInterfaceBindingSubspanConverter final
    : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> {
  HALInterfaceBindingSubspanConverter(
      TypeConverter &typeConverter, MLIRContext *context,
      const InterfaceResourceMap &interfaceToResourceVars,
      PatternBenefit benefit = 1)
      : OpConversionPattern(typeConverter, context, benefit),
        resourceVarMap(interfaceToResourceVars) {}

  LogicalResult matchAndRewrite(
      IREE::HAL::InterfaceBindingSubspanOp interfaceOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Subspan ops without uses are simply dropped.
    if (interfaceOp.use_empty()) {
      rewriter.eraseOp(interfaceOp);
      return success();
    }

    Type originalType = interfaceOp.getOperation()->getResult(0).getType();
    Type spirvType = getTypeConverter()->convertType(originalType);
    if (!spirvType) {
      return interfaceOp.emitError()
             << "failed to convert SPIR-V type: " << originalType;
    }

    spirv::GlobalVariableOp resourceVar = resourceVarMap.lookup(interfaceOp);
    // The variable was created with the unconverted type; give it the
    // converted one now.
    resourceVar.typeAttr(TypeAttr::get(spirvType));

    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(interfaceOp, resourceVar);
    return success();
  }

 private:
  /// Subspan op -> resource variable map owned by the creator of the pattern.
  const InterfaceResourceMap &resourceVarMap;
};
/// Lowers an op that turns into a no-op at this level by replacing it with
/// its (already converted) operands.
template <typename OpTy>
struct FoldAsNoOp final : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    ValueRange convertedOperands = adaptor.getOperands();
    rewriter.replaceOp(op, convertedOperands);
    return success();
  }
};
/// Drops unrealized_conversion_cast ops left over from progressive lowering
/// whenever the cast has a single operand and a single result and the
/// converted operand already has the result type. Any other cast is left
/// alone (the pattern fails to match).
struct RemoveIdentityConversionCast final
    : public OpConversionPattern<UnrealizedConversionCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      UnrealizedConversionCastOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only 1:1 casts can be identities.
    if (op->getNumOperands() != 1 || op->getNumResults() != 1) {
      return failure();
    }

    ValueRange convertedOperands = adaptor.getOperands();
    Type sourceType = convertedOperands.front().getType();
    Type targetType = op->getResultTypes().front();
    if (sourceType != targetType) {
      return failure();
    }

    rewriter.replaceOp(op, convertedOperands);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Conversion pass
//===----------------------------------------------------------------------===//
/// The pass that performs the final conversion to SPIR-V.
///
/// It turns the remaining interface ops into SPIR-V global variables, GPU
/// processor ID ops into SPIR-V global variables, and loop/standard ops into
/// their corresponding SPIR-V ops.
struct ConvertToSPIRVPass : public ConvertToSPIRVBase<ConvertToSPIRVPass> {
  ConvertToSPIRVPass() {}
  // NOTE: the copy constructor has an empty body and no initializer list, so
  // nothing is copied from the source pass.
  ConvertToSPIRVPass(const ConvertToSPIRVPass &) {}

  /// Declares the SPIR-V dialect as a dependency, since this pass creates
  /// SPIR-V ops.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<spirv::SPIRVDialect>();
  }

  void runOnOperation() override;
};
} // namespace
/// Runs the conversion on the module this pass is scheduled on:
///   1. attaches the SPIR-V entry point ABI attribute (workgroup size) to
///      entry point functions that do not carry one yet;
///   2. attaches the SPIR-V target environment attribute to the module;
///   3. collects the conversion patterns and creates the resource variables;
///   4. applies a full conversion to every public function;
///   5. moves all top-level SPIR-V dialect ops into a newly created
///      spv.module.
///
/// Signals pass failure if an entry point has no workgroup size or if the
/// conversion of any function fails.
void ConvertToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp moduleOp = getOperation();
llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints =
getAllEntryPoints(moduleOp);
// Step 1: make sure every entry point function has the SPIR-V entry point ABI
// attribute. Functions without a matching entry point op are skipped.
for (auto funcOp : moduleOp.getOps<FuncOp>()) {
auto entryPointOp = entryPoints.lookup(funcOp.getName());
if (!entryPointOp) continue;
// TODO(ravishankarm): This needs to be removed after ConvertToGPU is
// deprecated. All passes must set the `workgroup_size` on the
// `hal.executable.entry_point` directly and not on the function.
if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) continue;
SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp);
if (workgroupSize.empty()) {
entryPointOp.emitOpError(
"expected workgroup_size attribute to be set for SPIR-V lowering");
return signalPassFailure();
}
// The ABI attribute is built from 32-bit values, so narrow the 64-bit
// workgroup sizes.
auto workgroupSize32 = llvm::to_vector<4>(llvm::map_range(
workgroupSize, [](int64_t v) { return static_cast<int32_t>(v); }));
funcOp->setAttr(spirv::getEntryPointABIAttrName(),
spirv::getEntryPointABIAttr(workgroupSize32, context));
}
// Step 2: record the target environment on the module; the type converter and
// the conversion target below are both derived from it.
spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(moduleOp);
moduleOp->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
// Step 3: collect all conversion patterns.
OwningRewritePatternList patterns(&getContext());
ScfToSPIRVContext scfToSPIRVContext;
// Pull in GPU patterns to convert processor ID ops and loop ops.
populateGPUToSPIRVPatterns(typeConverter, patterns);
// Pull in SCF patterns to convert control flow ops.
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
// Use the default 64-bit lowering for TOSA's ApplyScale operator:
// This lowering widens integer types to 64-bit and performs the non-fused
// operations, specifically multiply, add, and shift. Bit-widening
// is used to guarantee higher-order bits are not truncated during the
// multiply or add.
//
// TODO(antiagainst): Use a lowering that uses specific SPIRV intrinsics.
tosa::populateTosaRescaleToStandardConversionPatterns(&patterns);
// Pull in MemRef patterns to convert load/store ops.
populateMemRefToSPIRVPatterns(typeConverter, patterns);
// Pull in standard/math patterns to convert arithmetic ops and others.
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
populateStandardToSPIRVPatterns(typeConverter, patterns);
populateMathToSPIRVPatterns(typeConverter, patterns);
// Pull in standard patterns to convert tensor operations to SPIR-V. These are
// primarily used to handle tensor-type constants and contain a
// threshold. Only those constants that are below the threshold are converted
// to SPIR-V. In IREE we want to control this threshold at Flow level. So set
// this value arbitrarily high to make sure that everything within a dispatch
// region is converted.
mlir::populateTensorToSPIRVPatterns(
typeConverter, std::numeric_limits<int64_t>::max() / 8, patterns);
// Pull in vector patterns to convert vector ops.
mlir::populateVectorToSPIRVPatterns(typeConverter, patterns);
// Pull in builtin func to spv.func conversion.
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
// Add IREE HAL interface op conversions.
patterns.insert<
HALInterfaceLoadConstantConverter,
HALInterfaceWorkgroupIdAndCountConverter<
IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>,
HALInterfaceWorkgroupIdAndCountConverter<
IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>(
typeConverter, context);
// Perform a preliminary step to analyze all hal.interface.binding.subspan ops
// and create spv.GlobalVariables.
auto interfaceToResourceVars = createResourceVariables(moduleOp);
// Hand the created variables to the subspan conversion pattern. The pattern
// only keeps a reference to this map, so the map has to stay alive until the
// conversion below has finished.
patterns.insert<HALInterfaceBindingSubspanConverter>(typeConverter, context,
interfaceToResourceVars);
/// Fold certain operations as no-ops:
/// - memref.collapse_shape / memref.expand_shape become no-ops since all
///   memrefs are linearized in SPIR-V.
/// - bufferization.to_memref can become a no-op since tensors are lowered to
///   !spv.array.
/// - unrealized_conversion_cast with the same source and target type.
patterns.insert<
FoldAsNoOp<memref::CollapseShapeOp>, FoldAsNoOp<memref::ExpandShapeOp>,
FoldAsNoOp<bufferization::ToMemrefOp>, RemoveIdentityConversionCast>(
typeConverter, context);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
// Disallow all other ops.
target->markUnknownOpDynamicallyLegal([](Operation *) { return false; });
// Step 4: convert the public functions. They are collected up front, before
// any conversion is applied.
SmallVector<FuncOp, 1> functions;
for (FuncOp fn : moduleOp.getOps<FuncOp>()) {
if (!fn.isPublic()) continue;
functions.push_back(fn);
}
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (FuncOp fn : functions) {
if (failed(applyFullConversion(fn, *target, frozenPatterns))) {
return signalPassFailure();
}
}
// Step 5: collect all SPIR-V ops into a spv.module.
auto builder = OpBuilder::atBlockBegin(moduleOp.getBody());
auto spvModule = builder.create<spirv::ModuleOp>(
moduleOp.getLoc(), spirv::AddressingModel::Logical,
spirv::MemoryModel::GLSL450);
Block *body = spvModule.getBody();
Dialect *spvDialect = spvModule->getDialect();
// Ops are moved while iterating, hence the early-increment range.
for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) {
// Skip the newly created spv.module itself.
if (&op == spvModule) continue;
if (op.getDialect() == spvDialect) op.moveBefore(body, body->end());
}
}
//===----------------------------------------------------------------------===//
// Pass entry point and registration
//===----------------------------------------------------------------------===//
/// Creates a new instance of the pass that performs the final conversion to
/// SPIR-V. The caller owns the returned pass.
std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertToSPIRVPass>();
  return pass;
}
} // namespace iree_compiler
} // namespace mlir