| // Copyright 2020 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
//===- ConvertToSPIRVPass.cpp - Pass for the final SPIR-V conversion ------===//
//
// This file implements a pass to perform the final conversion to SPIR-V.
// This pass converts remaining interface ops into SPIR-V global variables,
// GPU processor ID ops into SPIR-V global variables, loop/standard ops into
// corresponding SPIR-V ops.
//
//===----------------------------------------------------------------------===//
| |
| #include "iree/compiler/Conversion/CodegenUtils/MarkerUtils.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/IREE/IR/IREEOps.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" |
| #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/SPIRV/SPIRVLowering.h" |
| #include "mlir/Dialect/SPIRV/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/SPIRVTypes.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Vector/VectorOps.h" |
| #include "mlir/IR/Function.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| namespace { |
| //===----------------------------------------------------------------------===// |
| // Resource and push constant variable utilities |
| //===----------------------------------------------------------------------===// |
| // TODO(antiagainst): move these utilities to MLIR core. |
| |
| /// Returns the pointer type for the push constant storage containing |
| /// `elementCount` 32-bit integer values. |
| spirv::PointerType getPushConstantStorageType(unsigned elementCount, |
| Builder &builder) { |
| auto arrayType = spirv::ArrayType::get( |
| SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount, |
| /*stride=*/4); |
| auto structType = spirv::StructType::get({arrayType}, /*LayoutInfo=*/0); |
| return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); |
| } |
| |
| /// Returns the push constant varible containing `elementCount` 32-bit integer |
| /// values in `body`. Returns null op if such an op does not exit. |
| spirv::GlobalVariableOp getPushConstantVariable(Block &body, |
| unsigned elementCount) { |
| for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { |
| auto ptrType = varOp.type().cast<spirv::PointerType>(); |
| // Note that Vulkan requires "There must be no more than one push constant |
| // block statically used per shader entry point." So we should always reuse |
| // the existing one. |
| if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { |
| auto numElements = ptrType.getPointeeType() |
| .cast<spirv::StructType>() |
| .getElementType(0) |
| .cast<spirv::ArrayType>() |
| .getNumElements(); |
| if (numElements == elementCount) return varOp; |
| } |
| } |
| return nullptr; |
| } |
| |
| /// Gets or inserts a global variable for push constant storage containing |
| /// `elementCount` 32-bit integer values in `block`. |
| spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, |
| Block &block, |
| unsigned elementCount) { |
| if (auto varOp = getPushConstantVariable(block, elementCount)) return varOp; |
| |
| auto builder = OpBuilder::atBlockBegin(&block); |
| auto typeAttr = |
| TypeAttr::get(getPushConstantStorageType(elementCount, builder)); |
| StringRef name = "__push_constant_var__"; |
| return builder.create<spirv::GlobalVariableOp>(loc, typeAttr, name, |
| /*initializer=*/nullptr); |
| } |
| |
| /// Gets the value at the given `offset` of the push constant storage. A global |
| /// variable will be created for the push constant storage if not existing. Load |
| /// ops will be created via the given `builder` to load values from the push |
| /// constant. |
| Value getPushConstantValue(Operation *op, unsigned elementCount, |
| unsigned offset, OpBuilder &builder) { |
| Location loc = op->getLoc(); |
| Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); |
| if (!parent) { |
| op->emitError("expected operation to be within a module-like op"); |
| return nullptr; |
| } |
| |
| spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( |
| loc, parent->getRegion(0).front(), elementCount); |
| |
| auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext()); |
| Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder); |
| Value offsetOp = builder.create<spirv::ConstantOp>( |
| loc, i32Type, builder.getI32IntegerAttr(offset)); |
| auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp); |
| auto acOp = builder.create<spirv::AccessChainOp>( |
| loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); |
| return builder.create<spirv::LoadOp>(loc, acOp); |
| } |
| |
| /// Gets or inserts a resource evariable of the given `type` in `block` and bind |
| /// it to `set` and `binding`. |
| spirv::GlobalVariableOp getOrInsertResourceVariable(Location loc, Type type, |
| unsigned set, |
| unsigned binding, |
| Block &block) { |
| auto name = llvm::formatv("__resource_var_{0}_{1}__", set, binding).str(); |
| for (auto varOp : block.getOps<spirv::GlobalVariableOp>()) { |
| if (varOp.sym_name() == name) return varOp; |
| } |
| |
| auto builder = OpBuilder::atBlockBegin(&block); |
| return builder.create<spirv::GlobalVariableOp>(loc, type, name, set, binding); |
| } |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion patterns and pass declarations |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// Template class for attaching type converter to conversion patterns. |
| template <typename SourceOp> |
| class InterfaceOpConversion : public OpConversionPattern<SourceOp> { |
| public: |
| InterfaceOpConversion(MLIRContext *context, TypeConverter &typeConverter, |
| PatternBenefit benefit = 1) |
| : OpConversionPattern<SourceOp>(context, benefit), |
| typeConverter(typeConverter) {} |
| |
| protected: |
| TypeConverter &typeConverter; |
| }; |
| |
| /// A pattern to convert hal.interface.load.constant into a sequence of SPIR-V |
| /// ops to load from a global variable representing the push constant storage. |
| struct HALInterfaceLoadConstantConverter final |
| : public InterfaceOpConversion<IREE::HAL::InterfaceLoadConstantOp> { |
| using InterfaceOpConversion< |
| IREE::HAL::InterfaceLoadConstantOp>::InterfaceOpConversion; |
| |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceLoadConstantOp loadOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// A pattern to convert iree.placeholdder into a sequence of SPIR-V ops to get |
| /// the address to a global variable representing the resource buffer. |
| struct IREEPlaceholderConverter final |
| : public InterfaceOpConversion<IREE::PlaceholderOp> { |
| using InterfaceOpConversion<IREE::PlaceholderOp>::InterfaceOpConversion; |
| |
| LogicalResult matchAndRewrite( |
| IREE::PlaceholderOp phOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Pattern to lower linalg.reshape to SPIR-V. Since all buffers are linearized |
| /// in SPIR-V lowering, linalg.reshape becomes a no-op. |
| // TODO(ravishankarm): Move this into MLIR Core. |
| struct LinalgReshapeConverter final |
| : public SPIRVOpLowering<linalg::ReshapeOp> { |
| using SPIRVOpLowering<linalg::ReshapeOp>::SPIRVOpLowering; |
| LogicalResult matchAndRewrite( |
| linalg::ReshapeOp reshapeOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOp(reshapeOp, operands); |
| return success(); |
| } |
| }; |
| |
| /// Convert subgroup level vector transfert to SPIR-V cooperative |
| /// matrix load/store if those are supported. |
| /// TODO(thomasraoux): Move to MLIR core once this is stable. |
| template <typename OpTy> |
| class TransferToCoopMatLoadStore final : public SPIRVOpLowering<OpTy> { |
| public: |
| using SPIRVOpLowering<OpTy>::SPIRVOpLowering; |
| |
| LogicalResult matchAndRewrite( |
| OpTy op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!hasCooperativeMatrixMarker(op)) return failure(); |
| auto loc = op.getLoc(); |
| auto vecType = op.getVectorType(); |
| if (vecType.getRank() != 2) return failure(); |
| // TODO(thomasraoux): use coloumn major operand when TransfertRead + |
| // TransposeOp. |
| if (!op.permutation_map().isIdentity()) return failure(); |
| if (op.masked()) return failure(); |
| auto matType = spirv::CooperativeMatrixNVType::get( |
| vecType.getElementType(), spirv::Scope::Subgroup, vecType.getDimSize(0), |
| vecType.getDimSize(1)); |
| SmallVector<Value, 4> remappedIndices; |
| for (auto i : op.indices()) |
| remappedIndices.push_back(rewriter.getRemappedValue(i)); |
| Value ptr = spirv::getElementPtr( |
| SPIRVOpLowering<OpTy>::typeConverter, op.getMemRefType(), |
| rewriter.getRemappedValue(op.memref()), remappedIndices, loc, rewriter); |
| int64_t offset = 0; |
| SmallVector<int64_t, 2> strides; |
| getStridesAndOffset(op.getMemRefType(), strides, offset); |
| auto stride = strides[0]; |
| if (BaseMemRefType::isDynamicStrideOrOffset(stride)) return failure(); |
| auto int32Type = rewriter.getI32Type(); |
| auto strideValue = rewriter.create<spirv::ConstantOp>( |
| loc, int32Type, IntegerAttr::get(int32Type, stride)); |
| auto coloumnMajor = rewriter.create<spirv::ConstantOp>( |
| loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); |
| replaceTransferOp(op, loc, matType, ptr, strideValue, coloumnMajor, |
| rewriter); |
| return success(); |
| } |
| |
| /// Helper to generate the right load/store instruction and replace the |
| /// transfer op. |
| void replaceTransferOp(OpTy op, Location loc, Type matType, Value ptr, |
| Value strideValue, Value coloumnMajor, |
| ConversionPatternRewriter &rewriter) const; |
| }; |
| |
| template <> |
| void TransferToCoopMatLoadStore<vector::TransferReadOp>::replaceTransferOp( |
| vector::TransferReadOp op, Location loc, Type matType, Value ptr, |
| Value strideValue, Value coloumnMajor, |
| ConversionPatternRewriter &rewriter) const { |
| Value load = rewriter.create<spirv::CooperativeMatrixLoadNVOp>( |
| loc, matType, ptr, strideValue, coloumnMajor, IntegerAttr()); |
| rewriter.replaceOp(op, load); |
| } |
| |
| template <> |
| void TransferToCoopMatLoadStore<vector::TransferWriteOp>::replaceTransferOp( |
| vector::TransferWriteOp op, Location loc, Type matType, Value ptr, |
| Value strideValue, Value coloumnMajor, |
| ConversionPatternRewriter &rewriter) const { |
| rewriter.create<spirv::CooperativeMatrixStoreNVOp>( |
| loc, ptr, rewriter.getRemappedValue(op.vector()), strideValue, |
| coloumnMajor, IntegerAttr()); |
| rewriter.eraseOp(op); |
| } |
| |
| /// Convert subgroup level vector contract to SPIR-V cooperative |
| /// matrix matmuladd. |
| class VectorContractToCoopMatmul final |
| : public SPIRVOpLowering<vector::ContractionOp> { |
| public: |
| using SPIRVOpLowering<vector::ContractionOp>::SPIRVOpLowering; |
| |
| LogicalResult matchAndRewrite( |
| vector::ContractionOp contractOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| // TODO(thomasraoux): Check that the size of the matmul is supported by the |
| // target. |
| if (!hasCooperativeMatrixMarker(contractOp)) return failure(); |
| auto loc = contractOp.getLoc(); |
| // Check that all the operands are cooperative matrix. |
| vector::ContractionOp::Adaptor adaptor(operands); |
| auto loadA = adaptor.lhs(); |
| auto loadB = adaptor.rhs(); |
| auto loadC = adaptor.acc(); |
| if (!loadA.getType().isa<spirv::CooperativeMatrixNVType>() || |
| !loadB.getType().isa<spirv::CooperativeMatrixNVType>() || |
| !loadC.getType().isa<spirv::CooperativeMatrixNVType>()) |
| return failure(); |
| if (llvm::size(contractOp.masks()) != 0) return failure(); |
| // Check that this is a matmul operation. |
| auto iteratorTypes = contractOp.iterator_types().getValue(); |
| if (!isParallelIterator(iteratorTypes[0]) || |
| !isParallelIterator(iteratorTypes[1]) || |
| !isReductionIterator(iteratorTypes[2])) |
| return failure(); |
| // Coloumn major matmul should have been lowered to Transpose+contract |
| // by this point. Transpose can be handled by load/stoore operations. |
| if (!isRowMajorMatmul(contractOp.indexing_maps())) return failure(); |
| |
| Value matmul = rewriter.create<spirv::CooperativeMatrixMulAddNVOp>( |
| loc, loadC.getType(), loadA, loadB, loadC); |
| rewriter.replaceOp(contractOp, matmul); |
| return success(); |
| } |
| }; |
| |
| /// A pass to perform the SPIR-V conversion. |
| /// |
| /// This pass converts remaining interface ops into SPIR-V global variables, |
| /// GPU processor ID ops into SPIR-V global variables, loop/standard ops into |
| /// corresponding SPIR-V ops. |
| struct ConvertToSPIRVPass |
| : public PassWrapper<ConvertToSPIRVPass, OperationPass<ModuleOp>> { |
| void runOnOperation() override; |
| ConvertToSPIRVPass() {} |
| ConvertToSPIRVPass(const ConvertToSPIRVPass &pass) {} |
| Option<bool> useCooperativeMatrix{ |
| *this, "use-cooperative-matrix", |
| llvm::cl::desc("Experimental: Lower vector contract to cooperative " |
| "matrix operations"), |
| llvm::cl::init(false)}; |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion patterns and pass implementations |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult HALInterfaceLoadConstantConverter::matchAndRewrite( |
| IREE::HAL::InterfaceLoadConstantOp loadOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const { |
| // TODO(#1519): hal.interface.load.constant should point to the |
| // hal.interface op. |
| auto moduleOp = loadOp.getParentOfType<ModuleOp>(); |
| auto halInterfaceOps = |
| llvm::to_vector<1>(moduleOp.getOps<IREE::HAL::InterfaceOp>()); |
| assert(halInterfaceOps.size() == 1); |
| |
| unsigned elementCount = |
| halInterfaceOps.front().push_constants()->getZExtValue(); |
| unsigned offset = loadOp.offset().getZExtValue(); |
| |
| // The following function generates SPIR-V ops with i32 types. So it does type |
| // "conversion" (index -> i32) implicitly. |
| auto value = getPushConstantValue(loadOp, elementCount, offset, rewriter); |
| |
| rewriter.replaceOp(loadOp, value); |
| return success(); |
| } |
| |
| LogicalResult IREEPlaceholderConverter::matchAndRewrite( |
| IREE::PlaceholderOp phOp, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const { |
| auto moduleOp = phOp.getParentOfType<ModuleOp>(); |
| |
| Type convertedType = typeConverter.convertType(phOp.getType()); |
| if (!convertedType) { |
| return phOp.emitError() |
| << "SPIRV type conversion failed: " << phOp.getType(); |
| } |
| auto bindingOp = dyn_cast_or_null<IREE::HAL::InterfaceBindingOp>( |
| SymbolTable::lookupNearestSymbolFrom( |
| phOp, phOp.getAttrOfType<SymbolRefAttr>("binding"))); |
| |
| spirv::GlobalVariableOp varOp = getOrInsertResourceVariable( |
| phOp.getLoc(), convertedType, bindingOp.set().getZExtValue(), |
| bindingOp.binding().getZExtValue(), *moduleOp.getBody()); |
| |
| rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(phOp, varOp); |
| return success(); |
| } |
| |
| static void populateVectorToSPIRVPatterns(MLIRContext *context, |
| SPIRVTypeConverter &typeConverter, |
| OwningRewritePatternList &patterns) { |
| patterns.insert<TransferToCoopMatLoadStore<vector::TransferReadOp>, |
| TransferToCoopMatLoadStore<vector::TransferWriteOp>, |
| VectorContractToCoopMatmul>(context, typeConverter); |
| } |
| |
| void ConvertToSPIRVPass::runOnOperation() { |
| MLIRContext *context = &getContext(); |
| ModuleOp moduleOp = getOperation(); |
| |
| auto targetAttr = spirv::lookupTargetEnv(moduleOp); |
| SPIRVTypeConverter typeConverter(targetAttr); |
| |
| OwningRewritePatternList patterns; |
| // Pull in GPU patterns to convert processor ID ops and loop ops. |
| populateGPUToSPIRVPatterns(context, typeConverter, patterns); |
| // Pull in standard patterns to convert arithmetic ops and others. |
| populateStandardToSPIRVPatterns(context, typeConverter, patterns); |
| // Pull in builtin func to spv.func conversion. |
| populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); |
| |
| if (useCooperativeMatrix) { |
| populateVectorToSPIRVPatterns(context, typeConverter, patterns); |
| } |
| patterns.insert<HALInterfaceLoadConstantConverter, IREEPlaceholderConverter, |
| LinalgReshapeConverter>(context, typeConverter); |
| |
| std::unique_ptr<ConversionTarget> target = |
| spirv::SPIRVConversionTarget::get(targetAttr); |
| // Disallow all other ops. |
| target->markUnknownOpDynamicallyLegal([](Operation *) { return false; }); |
| SmallVector<FuncOp, 1> functions; |
| for (FuncOp fn : moduleOp.getOps<FuncOp>()) { |
| if (SymbolTable::getSymbolVisibility(fn) != SymbolTable::Visibility::Public) |
| continue; |
| functions.push_back(fn); |
| } |
| |
| for (FuncOp fn : functions) |
| if (failed(applyFullConversion(fn, *target, patterns))) |
| return signalPassFailure(); |
| |
| // Collect all SPIR-V ops into a spv.module. |
| auto builder = OpBuilder::atBlockBegin(moduleOp.getBody()); |
| auto spvModule = builder.create<spirv::ModuleOp>( |
| moduleOp.getLoc(), spirv::AddressingModel::Logical, |
| spirv::MemoryModel::GLSL450); |
| Operation *terminator = spvModule.getBlock().getTerminator(); |
| Dialect *spvDialect = spvModule.getDialect(); |
| for (Operation &op : llvm::make_early_inc_range(*moduleOp.getBody())) { |
| // Skip the newly created spv.module itself. |
| if (&op == spvModule) continue; |
| if (op.getDialect() == spvDialect) op.moveBefore(terminator); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass entry point and registration |
| //===----------------------------------------------------------------------===// |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createConvertToSPIRVPass() { |
| return std::make_unique<ConvertToSPIRVPass>(); |
| } |
| |
| static PassRegistration<ConvertToSPIRVPass> pass( |
| "iree-codegen-convert-to-spirv", |
| "Perform final conversion from builtin/GPU/HAL/standard dialect to SPIR-V " |
| "dialect"); |
| } // namespace iree_compiler |
| } // namespace mlir |