| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
//===- ConvertToSPIRVPass.cpp - Performs the final SPIR-V conversion ------===//
| // |
| // This file implements a pass to perform the final conversion to SPIR-V. |
| // This pass converts remaining interface ops into SPIR-V global variables, |
| // GPU processor ID ops into SPIR-V global variables, loop/standard ops into |
| // corresponding SPIR-V ops. |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include <limits>
#include <tuple>

#include "iree/compiler/Codegen/Dialect/LoweringConfig.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h"
#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace { |
| //===----------------------------------------------------------------------===// |
| // Resource utilities |
| //===----------------------------------------------------------------------===// |
| |
| /// Map from hal.interface.binding.subspan ops to their corresponding |
| /// spv.GlobalVariable ops. |
| using InterfaceResourceMap = |
| llvm::DenseMap<Operation *, spirv::GlobalVariableOp>; |
| |
| /// Creates a resource evariable of the given `type` at the beginning of |
| /// `moduleOp`'s block via `symbolTable` and bind it to `set` and `binding`. |
| spirv::GlobalVariableOp createResourceVariable(Location loc, Type type, |
| unsigned set, unsigned binding, |
| bool alias, ModuleOp moduleOp, |
| SymbolTable *symbolTable) { |
| std::string name = llvm::formatv("__resource_var_{0}_{1}_", set, binding); |
| OpBuilder builder(moduleOp.getContext()); |
| auto variable = |
| builder.create<spirv::GlobalVariableOp>(loc, type, name, set, binding); |
| if (alias) variable->setAttr("aliased", builder.getUnitAttr()); |
| symbolTable->insert(variable, moduleOp.getBody()->begin()); |
| return variable; |
| } |
| |
| /// Returns the (set, binding) pair for the given interface op. |
| std::pair<int32_t, int32_t> getInterfaceSetAndBinding(Operation *op) { |
| IREE::HAL::InterfaceBindingOp bindingOp = |
| cast<IREE::HAL::InterfaceBindingSubspanOp>(op).queryBindingOp(); |
| return {bindingOp.set().getSExtValue(), bindingOp.binding().getSExtValue()}; |
| } |
| |
| /// Scans all hal.interface.binding.subspan ops in `module`, creates their |
| /// corresponding spv.GlobalVariables when needed, and returns the map. |
| /// The created variables need to have their types fixed later. |
| InterfaceResourceMap createResourceVariables(mlir::ModuleOp module) { |
| SymbolTable symbolTable(module); |
| InterfaceResourceMap interfaceToResourceVars; |
| |
| auto fns = llvm::to_vector<1>(module.getOps<FuncOp>()); |
| for (FuncOp func : llvm::reverse(fns)) { |
| // Collect all interface ops and their (set, binding) pairs in this |
| // function. Use SmallVector here for a deterministic order. |
| SmallVector<IREE::HAL::InterfaceBindingSubspanOp, 8> interfaceOps; |
| SmallVector<std::pair<uint32_t, uint32_t>, 8> setBindings; |
| |
| // Use a map to see if we have different types for one (set, binding) pair, |
| // which will require creating multiple SPIR-V global variables. |
| llvm::DenseMap<std::pair<uint32_t, uint32_t>, llvm::DenseSet<Type>> |
| setBindingTypes; |
| |
| func.walk([&](Operation *op) { |
| auto interfaceOp = dyn_cast<IREE::HAL::InterfaceBindingSubspanOp>(op); |
| if (!interfaceOp || interfaceOp.use_empty()) return; |
| interfaceOps.emplace_back(interfaceOp); |
| setBindings.emplace_back(getInterfaceSetAndBinding(interfaceOp)); |
| setBindingTypes[setBindings.back()].insert(interfaceOp.getType()); |
| }); |
| |
| // Keep track of created SPIR-V global variables. This allows us to |
| // deduplicate when possible to reduce generated SPIR-V blob size. |
| llvm::DenseMap<std::tuple<uint32_t, uint32_t, Type>, |
| spirv::GlobalVariableOp> |
| resourceVars; |
| |
| for (int i = interfaceOps.size() - 1; i >= 0; --i) { |
| auto interfaceOp = interfaceOps[i]; |
| const auto &setBinding = setBindings[i]; |
| |
| auto key = std::make_tuple(setBinding.first, setBinding.second, |
| interfaceOp.getType()); |
| auto var = resourceVars.lookup(key); |
| if (!var) { |
| // If we have multiple SPIR-V global variables bound to the same (set, |
| // binding) pair and they are used in the same function, those variables |
| // need to have alias decoration. |
| bool alias = setBindingTypes[setBindings[i]].size() > 1; |
| |
| // We are using the interface op's type for creating the global |
| // variable. It's fine. The correctness boundary is the pass. |
| // We will fix it up during conversion so it won't leak. |
| var = createResourceVariable( |
| interfaceOp.getLoc(), interfaceOp.getType(), setBinding.first, |
| setBinding.second, alias, module, &symbolTable); |
| resourceVars[key] = var; |
| } |
| |
| interfaceToResourceVars[interfaceOp] = var; |
| } |
| } |
| |
| return interfaceToResourceVars; |
| } |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion patterns |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// A pattern to convert hal.interface.load.constant into a sequence of SPIR-V |
| /// ops to load from a global variable representing the push constant storage. |
| struct HALInterfaceLoadConstantConverter final |
| : public OpConversionPattern<IREE::HAL::InterfaceLoadConstantOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceLoadConstantOp loadOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // TODO(#1519): hal.interface.load.constant should point to the |
| // hal.interface op. |
| auto moduleOp = loadOp->getParentOfType<ModuleOp>(); |
| auto halInterfaceOps = |
| llvm::to_vector<1>(moduleOp.getOps<IREE::HAL::InterfaceOp>()); |
| assert(halInterfaceOps.size() == 1); |
| assert(halInterfaceOps.front().push_constants().hasValue()); |
| |
| uint64_t elementCount = |
| (*halInterfaceOps.front().push_constants()).getZExtValue(); |
| unsigned offset = loadOp.offset().getZExtValue(); |
| |
| // The following function generates SPIR-V ops with i32 types. So it does |
| // type "conversion" (index -> i32) implicitly. |
| auto i32Type = rewriter.getIntegerType(32); |
| auto value = spirv::getPushConstantValue(loadOp, elementCount, offset, |
| i32Type, rewriter); |
| |
| rewriter.replaceOp(loadOp, value); |
| return success(); |
| } |
| }; |
| |
| /// A pattern to convert hal.interface.workgroup.id/count into corresponding |
| /// SPIR-V Builtin ops. |
| template <typename InterfaceOpTy, spirv::BuiltIn builtin> |
| struct HALInterfaceWorkgroupIdAndCountConverter final |
| : public OpConversionPattern<InterfaceOpTy> { |
| using OpConversionPattern<InterfaceOpTy>::OpConversionPattern; |
| |
| LogicalResult matchAndRewrite( |
| InterfaceOpTy op, typename InterfaceOpTy::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| int32_t index = static_cast<int32_t>(op.dimension().getSExtValue()); |
| auto i32Type = rewriter.getIntegerType(32); |
| Value spirvBuiltin = |
| spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( |
| op, i32Type, spirvBuiltin, rewriter.getI32ArrayAttr({index})); |
| return success(); |
| } |
| }; |
| |
| /// A pattern to convert hal.interface.binding.subspan into a sequence of SPIR-V |
| /// ops to get the address to a global variable representing the resource |
| /// buffer. |
| struct HALInterfaceBindingSubspanConverter final |
| : public OpConversionPattern<IREE::HAL::InterfaceBindingSubspanOp> { |
| HALInterfaceBindingSubspanConverter( |
| TypeConverter &typeConverter, MLIRContext *context, |
| const InterfaceResourceMap &interfaceToResourceVars, |
| PatternBenefit benefit = 1) |
| : OpConversionPattern(typeConverter, context, benefit), |
| interfaceToResourceVars(interfaceToResourceVars) {} |
| |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceBindingSubspanOp interfaceOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (interfaceOp.use_empty()) { |
| rewriter.eraseOp(interfaceOp); |
| return success(); |
| } |
| |
| Type resultType = interfaceOp.getOperation()->getResult(0).getType(); |
| Type convertedType = this->getTypeConverter()->convertType(resultType); |
| if (!convertedType) { |
| return interfaceOp.emitError() |
| << "failed to convert SPIR-V type: " << resultType; |
| } |
| |
| auto varOp = interfaceToResourceVars.lookup(interfaceOp); |
| // Fix up the variable's type. |
| varOp.typeAttr(TypeAttr::get(convertedType)); |
| |
| rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(interfaceOp, varOp); |
| |
| return success(); |
| } |
| |
| private: |
| const InterfaceResourceMap &interfaceToResourceVars; |
| }; |
| |
| /// Pattern to lower operations that become a no-ops at this level. |
| template <typename OpTy> |
| struct FoldAsNoOp final : public OpConversionPattern<OpTy> { |
| using OpConversionPattern<OpTy>::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| OpTy op, typename OpTy::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp(op, adaptor.getOperands()); |
| return success(); |
| } |
| }; |
| |
| /// Removes unrealized_conversion_cast ops introduced during progressive |
| /// lowering when possible. |
| struct RemoveIdentityConversionCast final |
| : public OpConversionPattern<UnrealizedConversionCastOp> { |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult matchAndRewrite( |
| UnrealizedConversionCastOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (op->getNumOperands() == 1 && op->getNumResults() == 1 && |
| adaptor.getOperands().front().getType() == |
| op->getResultTypes().front()) { |
| rewriter.replaceOp(op, adaptor.getOperands()); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion pass |
| //===----------------------------------------------------------------------===// |
| |
| /// A pass to perform the SPIR-V conversion. |
| /// |
| /// This pass converts remaining interface ops into SPIR-V global variables, |
| /// GPU processor ID ops into SPIR-V global variables, loop/standard ops into |
| /// corresponding SPIR-V ops. |
| struct ConvertToSPIRVPass : public ConvertToSPIRVBase<ConvertToSPIRVPass> { |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<spirv::SPIRVDialect>(); |
| } |
| |
| ConvertToSPIRVPass() {} |
| ConvertToSPIRVPass(const ConvertToSPIRVPass &pass) {} |
| |
| void runOnOperation() override; |
| }; |
| } // namespace |
| |
| void ConvertToSPIRVPass::runOnOperation() { |
| MLIRContext *context = &getContext(); |
| ModuleOp moduleOp = getOperation(); |
| |
| llvm::StringMap<IREE::HAL::ExecutableEntryPointOp> entryPoints = |
| getAllEntryPoints(moduleOp); |
| for (auto funcOp : moduleOp.getOps<FuncOp>()) { |
| auto entryPointOp = entryPoints.lookup(funcOp.getName()); |
| if (!entryPointOp) continue; |
| // TODO(ravishankarm): This needs to be removed after ConvertToGPU is |
| // deprecated. All passes must set the `workgroup_size` on the |
| // `hal.executable.entry_point` directly and not on the function. |
| if (funcOp->hasAttr(spirv::getEntryPointABIAttrName())) continue; |
| SmallVector<int64_t> workgroupSize = getWorkgroupSize(entryPointOp); |
| if (workgroupSize.empty()) { |
| entryPointOp.emitOpError( |
| "expected workgroup_size attribute to be set for SPIR-V lowering"); |
| return signalPassFailure(); |
| } |
| auto workgroupSize32 = llvm::to_vector<4>(llvm::map_range( |
| workgroupSize, [](int64_t v) { return static_cast<int32_t>(v); })); |
| funcOp->setAttr(spirv::getEntryPointABIAttrName(), |
| spirv::getEntryPointABIAttr(workgroupSize32, context)); |
| } |
| |
| spirv::TargetEnvAttr targetAttr = getSPIRVTargetEnvAttr(moduleOp); |
| moduleOp->setAttr(spirv::getTargetEnvAttrName(), targetAttr); |
| SPIRVTypeConverter typeConverter(targetAttr); |
| OwningRewritePatternList patterns(&getContext()); |
| ScfToSPIRVContext scfToSPIRVContext; |
| |
| // Pull in GPU patterns to convert processor ID ops and loop ops. |
| populateGPUToSPIRVPatterns(typeConverter, patterns); |
| |
| // Pull in SCF patterns to convert control flow ops. |
| populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); |
| |
| // Use the default 64-bit lowering for TOSA's ApplyScale operator: |
| // This lowering widens integer types to 64-bit an performs the non-fused |
| // operations, specifically multiply, add, and shift. Bit-widening |
| // is used to guarantee higher-order bits are not truncated during the |
| // multiply or add. |
| // |
| // TODO(antiagainst): Use a lowering that uses specific SPIRV intrinsics. |
| tosa::populateTosaRescaleToStandardConversionPatterns(&patterns); |
| |
| // Pull in MemRef patterns to convert load/store ops. |
| populateMemRefToSPIRVPatterns(typeConverter, patterns); |
| |
| // Pull in standard/math patterns to convert arithmetic ops and others. |
| arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); |
| populateStandardToSPIRVPatterns(typeConverter, patterns); |
| populateMathToSPIRVPatterns(typeConverter, patterns); |
| |
| // Pull in standard patterns to convert tensor operations to SPIR-V. These are |
| // primarily used to handle tensor-type constants and contain a |
| // threshold. Only those constants that are below the threshold are converted |
| // to SPIR-V. In IREE we want to control this threshold at Flow level. So set |
| // this value arbitrarily high to make sure that everything within a dispatch |
| // region is converted. |
| mlir::populateTensorToSPIRVPatterns( |
| typeConverter, std::numeric_limits<int64_t>::max() / 8, patterns); |
| |
| // Pull in vector patterns to convert vector ops. |
| mlir::populateVectorToSPIRVPatterns(typeConverter, patterns); |
| |
| // Pull in builtin func to spv.func conversion. |
| populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); |
| |
| // Add IREE HAL interface op conversions. |
| patterns.insert< |
| HALInterfaceLoadConstantConverter, |
| HALInterfaceWorkgroupIdAndCountConverter< |
| IREE::HAL::InterfaceWorkgroupIDOp, spirv::BuiltIn::WorkgroupId>, |
| HALInterfaceWorkgroupIdAndCountConverter< |
| IREE::HAL::InterfaceWorkgroupCountOp, spirv::BuiltIn::NumWorkgroups>>( |
| typeConverter, context); |
| |
| // Performs a prelimiary step to analyze all hal.interface.binding.subspan ops |
| // and create spv.GlobalVariables. |
| auto interfaceToResourceVars = createResourceVariables(moduleOp); |
| // For using use them in conversion. |
| patterns.insert<HALInterfaceBindingSubspanConverter>(typeConverter, context, |
| interfaceToResourceVars); |
| |
| /// Fold certain operations as no-ops: |
| /// - linalg.reshape becomes a no-op since all memrefs are linearized in |
| /// SPIR-V. |
| /// - tensor_to_memref can become a no-op since tensors are lowered to |
| /// !spv.array. |
| /// - unrealized_conversion_cast with the same source and target type. |
| patterns.insert< |
| FoldAsNoOp<memref::CollapseShapeOp>, FoldAsNoOp<memref::ExpandShapeOp>, |
| FoldAsNoOp<bufferization::ToMemrefOp>, RemoveIdentityConversionCast>( |
| typeConverter, context); |
| |
| std::unique_ptr<ConversionTarget> target = |
| SPIRVConversionTarget::get(targetAttr); |
| // Disallow all other ops. |
| target->markUnknownOpDynamicallyLegal([](Operation *) { return false; }); |
| |
| SmallVector<FuncOp, 1> functions; |
| for (FuncOp fn : moduleOp.getOps<FuncOp>()) { |
| if (!fn.isPublic()) continue; |
| functions.push_back(fn); |
| } |
| |
| FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| for (FuncOp fn : functions) { |
| if (failed(applyFullConversion(fn, *target, frozenPatterns))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| // Collect all SPIR-V ops into a spv.module. |
| auto builder = OpBuilder::atBlockBegin(moduleOp.getBody()); |
| auto spvModule = builder.create<spirv::ModuleOp>( |
| moduleOp.getLoc(), spirv::AddressingModel::Logical, |
| spirv::MemoryModel::GLSL450); |
| Block *body = spvModule.getBody(); |
| Dialect *spvDialect = spvModule->getDialect(); |
| for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) { |
| // Skip the newly created spv.module itself. |
| if (&op == spvModule) continue; |
| if (op.getDialect() == spvDialect) op.moveBefore(body, body->end()); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass entry point and registration |
| //===----------------------------------------------------------------------===// |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass() { |
| return std::make_unique<ConvertToSPIRVPass>(); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |