| // Copyright 2020 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/LLVMCPU/DispatchABI.h" |
| #include "iree/compiler/Codegen/LLVMCPU/LLVMCPUPasses.h" |
| #include "iree/compiler/Codegen/LLVMCPU/Utils.h" |
| #include "iree/compiler/Codegen/PassDetail.h" |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" |
| #include "iree/compiler/Dialect/HAL/IR/HALOps.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" |
| #include "iree/compiler/Dialect/Util/IR/UtilOps.h" |
| #include "iree/schemas/instruments/dispatch.h" |
| #include "llvm/Support/Mutex.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "llvm/TargetParser/Triple.h" |
| #include "mlir/Analysis/DataLayoutAnalysis.h" |
| #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" |
| #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" |
| #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" |
| #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" |
| #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
| #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
| #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" |
| #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" |
| #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
| #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" |
| #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
| #include "mlir/Conversion/TosaToArith/TosaToArith.h" |
| #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" |
| #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Arith/Transforms/Passes.h" |
| #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Func/Transforms/Passes.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Math/Transforms/Passes.h" |
| #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace iree_compiler { |
| |
| namespace { |
| |
/// Base class for conversion patterns that need access to the HAL dispatch
/// ABI. Stores a reference to the shared HALDispatchABI helper that knows how
/// to load per-dispatch state (workgroup IDs/counts, push constants, bindings,
/// executable constants) from the ABI structs passed to each entry point.
template <typename OpT>
struct ConvertOpToLLVMWithABIPattern : public ConvertOpToLLVMPattern<OpT> {
  ConvertOpToLLVMWithABIPattern(HALDispatchABI &abi,
                                LLVMTypeConverter &typeConverter,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<OpT>(typeConverter, benefit), abi(abi) {}
  // Shared ABI helper; held by reference and must outlive the pattern set.
  HALDispatchABI &abi;
};
| |
| /// Converts Standard MLIR FuncOps to LLVMFuncOps matching the IREE HAL ABI. |
| /// This is an IREE-specific conversion that assumes the input function is |
| /// `() -> ()` and that hal.interface.* ops are used to access all state. |
| /// |
| /// Source function: |
| /// |
| /// ``` |
| /// func.func @foo() { |
| /// %0 = hal.interface.binding.subspan ... |
| /// } |
| /// ``` |
| /// |
| /// into: |
| /// |
| /// ``` |
| /// llvm.func foo(%state: !llvm.ptr<!...>, |
| /// %workgroup_id : !llvm.ptr<!llvm.array<i32, 3>>) { |
| /// %0 = <GEP/loads to access binding in %state> |
| /// } |
| /// ``` |
| /// |
| /// See `iree/hal/local/executable_library.h` for more information. |
| /// |
| /// NOTE: we bump the benefit of the pattern to 100 to pick this pattern instead |
| /// of a competing pattern inserted by `populateFuncToLLVMConversionPatterns`. |
struct ConvertHALEntryPointFuncOp
    : public ConvertOpToLLVMWithABIPattern<func::FuncOp> {
  ConvertHALEntryPointFuncOp(HALDispatchABI &abi,
                             LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMWithABIPattern(abi, typeConverter,
                                      /*benefit=*/100) {}
  LogicalResult matchAndRewrite(
      func::FuncOp stdFuncOp, func::FuncOpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    // Only public functions are dispatch entry points; private functions are
    // left to the default func-to-LLVM conversion patterns.
    if (!stdFuncOp.isPublic()) return failure();
    FunctionType fnType = stdFuncOp.getFunctionType();
    if (fnType.getNumInputs() != 0 || fnType.getNumResults() != 0) {
      stdFuncOp->emitWarning()
          << "public functions on executables must be () -> ()";
      return failure();
    }

    // Convert the function signature to take the HAL ABI LLVM pointers
    // (environment, dispatch_state, workgroup_state).
    TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
    MLIRContext *context = rewriter.getContext();
    auto abiInputTypes =
        HALDispatchABI::getInputTypes(context, getTypeConverter());
    signatureConverter.addInputs(abiInputTypes);

    // Copy all attributes onto the LLVM function except the ones handled by
    // MLIR implicitly (the symbol name and the function type attribute).
    SmallVector<NamedAttribute, 4> funcAttrs;
    for (auto attr : stdFuncOp->getAttrs()) {
      if (attr.getName() == SymbolTable::getSymbolAttrName() ||
          attr.getName() == stdFuncOp.getFunctionTypeAttrName()) {
        continue;
      }
      funcAttrs.push_back(attr);
    }

    // Clone the function as an LLVMFuncOp and convert all interior types.
    // The exported function returns an i32 status code checked by the runtime
    // (see the return-value rewrite below).
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto llvmFuncType = LLVM::LLVMFunctionType::get(int32Type, abiInputTypes);
    auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        stdFuncOp.getLoc(), stdFuncOp.getName(), llvmFuncType,
        LLVM::Linkage::External, /*dso_local=*/false, /*cconv*/ LLVM::CConv::C,
        funcAttrs);
    rewriter.inlineRegionBefore(stdFuncOp.getFunctionBody(),
                                llvmFuncOp.getFunctionBody(), llvmFuncOp.end());
    if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getFunctionBody(),
                                           *typeConverter,
                                           &signatureConverter))) {
      return failure();
    }

    // Tag all arguments so LLVM can reason about our exports it otherwise
    // cannot analyze. We do this early on so that MLIR-based LLVM transforms
    // can use the attributes.
    // (%arg0: environment, %arg1: dispatch_state, %arg2: workgroup_state)
    for (unsigned i = 0; i <= 2; ++i) {
      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getNoAliasAttrName(),
                            rewriter.getUnitAttr());
      llvmFuncOp.setArgAttr(i, LLVM::LLVMDialect::getAlignAttrName(),
                            rewriter.getI64IntegerAttr(16));
    }

    // Add default zero return value.
    // TODO(ataei): do something meaningful with the return value; non-zero will
    // have the runtime bail out with an error.
    for (auto returnOp : llvm::make_early_inc_range(
             llvmFuncOp.getOps<mlir::func::ReturnOp>())) {
      rewriter.setInsertionPoint(returnOp);
      auto returnValue = rewriter.createOrFold<mlir::arith::ConstantIntOp>(
          returnOp.getLoc(), 0, 32);
      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(returnOp, returnValue);
    }

    // Populate debug info for the subprogram signature. This is required in
    // order to get any debug information (including just line tables) from MLIR
    // into LLVM IR.
    auto scopeAttr = HALDispatchABI::buildScopeAttr(
        llvmFuncOp->getParentOfType<mlir::ModuleOp>(), llvmFuncOp.getName(),
        getTypeConverter());
    llvmFuncOp->setLoc(FusedLoc::get(llvmFuncOp.getContext(),
                                     {llvmFuncOp->getLoc()}, scopeAttr));

    rewriter.eraseOp(stdFuncOp);
    return success();
  }
};
| |
| /// Rewrites hal.interface.constant.load to ops loading from the ABI structs. |
| /// Because ordinals are not yet available we emit a placeholder global that |
| /// later gets updated with the value after linking. |
| /// |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| struct ConvertHALExecutableConstantLoadOp |
| : public ConvertOpToLLVMWithABIPattern< |
| IREE::HAL::ExecutableConstantLoadOp> { |
| using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern; |
| LogicalResult matchAndRewrite( |
| IREE::HAL::ExecutableConstantLoadOp loadOp, |
| IREE::HAL::ExecutableConstantLoadOpAdaptor operands, |
| ConversionPatternRewriter &rewriter) const override { |
| auto resultType = |
| typeConverter->convertType(loadOp->getResult(0).getType()); |
| rewriter.replaceOp( |
| loadOp, abi.loadExecutableConstant(loadOp, loadOp.getKey(), resultType, |
| rewriter)); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.workgroup.id to ops loading from the ABI structs. |
| /// |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| struct ConvertHALInterfaceWorkgroupIDOp |
| : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InterfaceWorkgroupIDOp> { |
| using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern; |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceWorkgroupIDOp idOp, |
| IREE::HAL::InterfaceWorkgroupIDOpAdaptor operands, |
| ConversionPatternRewriter &rewriter) const override { |
| int32_t dim = (int32_t)idOp.getDimension().getZExtValue(); |
| auto resultType = typeConverter->convertType(idOp->getResult(0).getType()); |
| rewriter.replaceOp(idOp, |
| abi.loadWorkgroupID(idOp, dim, resultType, rewriter)); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.workgroup.size to ops loading from the ABI structs. |
| /// |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| struct ConvertHALInterfaceWorkgroupSizeOp |
| : public ConvertOpToLLVMWithABIPattern< |
| IREE::HAL::InterfaceWorkgroupSizeOp> { |
| using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern; |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceWorkgroupSizeOp sizeOp, |
| IREE::HAL::InterfaceWorkgroupSizeOpAdaptor operands, |
| ConversionPatternRewriter &rewriter) const override { |
| int32_t dim = (int32_t)sizeOp.getDimension().getZExtValue(); |
| auto resultType = |
| typeConverter->convertType(sizeOp->getResult(0).getType()); |
| rewriter.replaceOp( |
| sizeOp, abi.loadWorkgroupSize(sizeOp, dim, resultType, rewriter)); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.workgroup.count to ops loading from the ABI structs. |
| /// |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| struct ConvertHALInterfaceWorkgroupCountOp |
| : public ConvertOpToLLVMWithABIPattern< |
| IREE::HAL::InterfaceWorkgroupCountOp> { |
| using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern; |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceWorkgroupCountOp countOp, |
| IREE::HAL::InterfaceWorkgroupCountOpAdaptor operands, |
| ConversionPatternRewriter &rewriter) const override { |
| int32_t dim = (int32_t)countOp.getDimension().getZExtValue(); |
| auto resultType = |
| typeConverter->convertType(countOp->getResult(0).getType()); |
| rewriter.replaceOp( |
| countOp, abi.loadWorkgroupCount(countOp, dim, resultType, rewriter)); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.constant.load to ops loading from the ABI structs. |
| /// |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| struct ConvertHALInterfaceConstantLoadOp |
| : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InterfaceConstantLoadOp> { |
| using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern; |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceConstantLoadOp loadOp, |
| IREE::HAL::InterfaceConstantLoadOpAdaptor operands, |
| ConversionPatternRewriter &rewriter) const override { |
| int64_t index = loadOp.getIndex().getZExtValue(); |
| auto resultType = |
| typeConverter->convertType(loadOp->getResult(0).getType()); |
| rewriter.replaceOp( |
| loadOp, abi.loadPushConstant(loadOp, index, resultType, rewriter)); |
| return success(); |
| } |
| }; |
| |
| /// Rewrites hal.interface.binding.subspan to ops loading from the ABI structs. |
| /// |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| struct ConvertHALInterfaceBindingSubspanOp |
| : public ConvertOpToLLVMWithABIPattern< |
| IREE::HAL::InterfaceBindingSubspanOp> { |
| using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern; |
| LogicalResult matchAndRewrite( |
| IREE::HAL::InterfaceBindingSubspanOp subspanOp, |
| IREE::HAL::InterfaceBindingSubspanOpAdaptor operands, |
| ConversionPatternRewriter &rewriter) const override { |
| MemRefType memRefType = |
| llvm::dyn_cast<MemRefType>(subspanOp->getResult(0).getType()); |
| if (!memRefType) { |
| return rewriter.notifyMatchFailure( |
| subspanOp, |
| "failed to convert interface.binding.subspan result to memref type"); |
| } |
| auto memRefDesc = abi.loadBinding( |
| subspanOp, operands.getBindingAttr().getInt(), operands.getByteOffset(), |
| memRefType, operands.getDynamicDims(), rewriter); |
| rewriter.replaceOp(subspanOp, {memRefDesc}); |
| return success(); |
| } |
| }; |
| |
/// Pointers and offset describing a freshly-reserved entry in the dispatch
/// instrumentation ringbuffer (see acquireInstrumentationEntry below).
struct InstrumentationEntry {
  // !llvm.ptr<i8> pointing at the base of the ringbuffer.
  Value basePtr;
  // !llvm.ptr<i8> pointing at the start of the entry (basePtr + offset).
  Value entryPtr;
  // i64 offset within the ringbuffer of the entry.
  Value offset;
};
| |
/// Reserves `entrySize` bytes in the instrumentation ringbuffer backing
/// `buffer` and returns the base pointer, entry pointer, and wrapped offset.
///
/// The reservation head counter lives in the trailing 8 bytes of the buffer
/// and is bumped with a monotonic atomic add so concurrent workgroups can
/// reserve entries without locking; the raw (monotonically increasing) offset
/// is wrapped to the power-of-two ring storage size with a mask.
///
/// entrySize must be 16-byte aligned.
static InstrumentationEntry acquireInstrumentationEntry(Location loc,
                                                        Value buffer,
                                                        Value bufferPtr,
                                                        Value entrySize,
                                                        OpBuilder &builder) {
  auto i64Type = builder.getI64Type();
  auto bufferType = llvm::cast<MemRefType>(buffer.getType());
  // Total statically-known byte size of the buffer; the head counter occupies
  // the last 8 bytes and the ring storage excludes the trailing padding.
  int64_t totalBufferSize =
      (bufferType.getNumElements() * bufferType.getElementTypeBitWidth()) / 8;
  int64_t headOffset = totalBufferSize - 8;
  int64_t ringSize = totalBufferSize - IREE_INSTRUMENT_DISPATCH_PADDING;
  assert(llvm::isPowerOf2_64(ringSize) &&
         "ringbuffer storage size must be a power-of-two");

  Value basePtr = MemRefDescriptor(bufferPtr).alignedPtr(builder, loc);

  // Atomically bump the head counter by entrySize to claim a range.
  Value offsetIndex =
      builder.create<LLVM::ConstantOp>(loc, i64Type, headOffset);
  Value offsetPtr = builder.create<LLVM::GEPOp>(
      loc, basePtr.getType(), LLVM::LLVMPointerType::get(builder.getContext()),
      basePtr, offsetIndex,
      /*inbounds=*/true);
  Value rawOffset = builder.create<LLVM::AtomicRMWOp>(
      loc, LLVM::AtomicBinOp::add, offsetPtr, entrySize,
      LLVM::AtomicOrdering::monotonic);
  // Wrap into the ring: power-of-two size lets us mask instead of modulo.
  Value offsetMask =
      builder.create<LLVM::ConstantOp>(loc, i64Type, ringSize - 1);
  Value wrappedOffset = builder.create<LLVM::AndOp>(loc, rawOffset, offsetMask);

  Value entryPtr = builder.create<LLVM::GEPOp>(
      loc, basePtr.getType(), LLVM::LLVMPointerType::get(builder.getContext()),
      basePtr, wrappedOffset);

  return {basePtr, entryPtr, wrappedOffset};
}
| |
| static InstrumentationEntry appendInstrumentationEntry( |
| Location loc, Value buffer, Value bufferPtr, LLVM::LLVMStructType entryType, |
| ArrayRef<Value> entryValues, DataLayout &dataLayout, OpBuilder &builder) { |
| auto i64Type = builder.getI64Type(); |
| |
| Value entrySize = builder.create<LLVM::ConstantOp>( |
| loc, i64Type, dataLayout.getTypeSize(entryType)); |
| auto entry = |
| acquireInstrumentationEntry(loc, buffer, bufferPtr, entrySize, builder); |
| |
| Value entryStruct = builder.create<LLVM::UndefOp>(loc, entryType); |
| for (auto entryValue : llvm::enumerate(entryValues)) { |
| entryStruct = builder.create<LLVM::InsertValueOp>( |
| loc, entryStruct, entryValue.value(), entryValue.index()); |
| } |
| |
| builder.create<LLVM::StoreOp>( |
| loc, entryStruct, |
| builder.create<LLVM::BitcastOp>( |
| loc, LLVM::LLVMPointerType::get(builder.getContext()), |
| entry.entryPtr), |
| /*alignment=*/16); |
| |
| return entry; |
| } |
| |
| static int64_t getMemoryAccessByteSize(Type type) { |
| if (auto vectorType = llvm::dyn_cast<VectorType>(type)) { |
| return (vectorType.getNumElements() * vectorType.getElementTypeBitWidth()) / |
| 8; |
| } else { |
| return type.getIntOrFloatBitWidth() / 8; |
| } |
| } |
| |
/// Rewrites hal.instrument.workgroup to an instrumentation ringbuffer append
/// recording the workgroup ID/count and processor ID, and replaces the op
/// with the 64-bit workgroup key consumed by subsequent instrument ops.
struct ConvertHALInstrumentWorkgroupOp
    : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentWorkgroupOp> {
  using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
  LogicalResult matchAndRewrite(
      IREE::HAL::InstrumentWorkgroupOp instrumentOp,
      IREE::HAL::InstrumentWorkgroupOpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = instrumentOp.getLoc();
    auto dataLayout =
        getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
    auto i32Type = rewriter.getI32Type();
    auto i64Type = rewriter.getI64Type();

    // Entry record layout; must match the runtime-side decoder.
    auto entryType = LLVM::LLVMStructType::getLiteral(
        getContext(), {
                          i32Type,  // header
                          i32Type,  // workgroup_id_x
                          i32Type,  // workgroup_id_y
                          i32Type,  // workgroup_id_z
                          i32Type,  // workgroup_count_x
                          i32Type,  // workgroup_count_y
                          i32Type,  // workgroup_count_z
                          i32Type,  // processor_id
                      });

    // 8 bit tag = 00 | 24 bit dispatch id
    // NOTE: we could pre-shift this to avoid needing to do it in each group.
    // We just need to do the shift - the bottom two bits will be the 00 tag.
    Value rawDispatchId = instrumentOp.getDispatchId();
    Value header = rewriter.create<LLVM::ShlOp>(
        loc, i32Type, rawDispatchId,
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, 8));  // | 8bit tag

    // Append the workgroup entry; entry.offset identifies this workgroup in
    // the stream.
    auto entry = appendInstrumentationEntry(
        loc, instrumentOp.getBuffer(), operands.getBuffer(), entryType,
        {
            header,
            abi.loadWorkgroupID(instrumentOp, 0, i32Type, rewriter),
            abi.loadWorkgroupID(instrumentOp, 1, i32Type, rewriter),
            abi.loadWorkgroupID(instrumentOp, 2, i32Type, rewriter),
            abi.loadWorkgroupCount(instrumentOp, 0, i32Type, rewriter),
            abi.loadWorkgroupCount(instrumentOp, 1, i32Type, rewriter),
            abi.loadWorkgroupCount(instrumentOp, 2, i32Type, rewriter),
            abi.loadProcessorID(instrumentOp, rewriter),
        },
        dataLayout, rewriter);

    // Prepare the 40-bit key used by all accesses - we do this once so that we
    // can ensure it's hoisted.
    // Consumers expect 40 bits of offset << 24 bits.
    Value workgroupKey = rewriter.create<LLVM::ShlOp>(
        loc,
        rewriter.create<LLVM::AndOp>(
            loc, entry.offset,
            rewriter.create<LLVM::ConstantOp>(loc, i64Type, 0xFFFFFFFFFFll)),
        rewriter.create<LLVM::ConstantOp>(loc, i64Type, 24));

    rewriter.replaceOp(instrumentOp, workgroupKey);
    return success();
  }
};
| |
| static std::optional<uint64_t> mapValueType(Type type) { |
| return TypeSwitch<Type, std::optional<uint64_t>>(type) |
| .Case<IntegerType>([&](Type type) -> std::optional<uint64_t> { |
| if (type.isUnsignedInteger()) { |
| switch (type.getIntOrFloatBitWidth()) { |
| case 8: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_8; |
| case 16: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_16; |
| case 32: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_32; |
| case 64: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_UINT_64; |
| default: |
| return std::nullopt; |
| } |
| } |
| switch (type.getIntOrFloatBitWidth()) { |
| case 8: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_8; |
| case 16: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_16; |
| case 32: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_32; |
| case 64: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_64; |
| default: |
| return std::nullopt; |
| } |
| }) |
| .Case<FloatType>([&](Type type) -> std::optional<uint64_t> { |
| if (type.isBF16()) { |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_BFLOAT_16; |
| } |
| switch (type.getIntOrFloatBitWidth()) { |
| case 16: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_16; |
| case 32: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_32; |
| case 64: |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_FLOAT_64; |
| default: |
| return std::nullopt; |
| } |
| }) |
| .Case<IndexType>([&](Type type) -> std::optional<uint64_t> { |
| return IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_SINT_64; |
| }) |
| .Default([&](Type) -> std::optional<uint64_t> { return std::nullopt; }); |
| } |
| |
/// Rewrites hal.instrument.value to an instrumentation ringbuffer append
/// recording the value bits tagged with its type and ordinal, keyed by the
/// workgroup key produced by hal.instrument.workgroup. Unsupported value
/// types are dropped with a warning (the op is replaced by its operand).
struct ConvertHALInstrumentValueOp
    : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentValueOp> {
  using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
  LogicalResult matchAndRewrite(
      IREE::HAL::InstrumentValueOp instrumentOp,
      IREE::HAL::InstrumentValueOpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = instrumentOp.getLoc();

    // Only convert ops we can handle, otherwise warn and discard.
    std::optional<uint64_t> valueType;
    if (llvm::isa<LLVM::LLVMPointerType>(operands.getOperand().getType())) {
      valueType = IREE_INSTRUMENT_DISPATCH_VALUE_TYPE_POINTER;
    } else {
      valueType = mapValueType(instrumentOp.getType());
    }
    if (!valueType) {
      mlir::emitWarning(loc,
                        "skipping hal.instrument.value on unsupported type: ")
          << instrumentOp.getType();
      rewriter.replaceOp(instrumentOp, {operands.getOperand()});
      return success();
    }

    auto dataLayout =
        getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
    auto i64Type = rewriter.getI64Type();

    // Entry record layout; must match the runtime-side decoder.
    auto entryType =
        LLVM::LLVMStructType::getLiteral(getContext(), {
                                                           i64Type,  // header
                                                           i64Type,  // value
                                                       });

    // Header layout (low to high bits):
    //   8 bit tag (IREE_INSTRUMENT_DISPATCH_TYPE_VALUE)
    //   8 bit type
    //   8 bit ordinal
    //  40 bit workgroup offset (pre-shifted in the workgroup key)
    Value header = rewriter.create<LLVM::OrOp>(
        loc, operands.getWorkgroupKey(),
        rewriter.create<LLVM::ConstantOp>(
            loc, i64Type,
            (instrumentOp.getOrdinal().getZExtValue() << 16) |
                (valueType.value() << 8) |
                IREE_INSTRUMENT_DISPATCH_TYPE_VALUE));

    // Bitcast to an integer and widen to 64 bits.
    // NOTE(review): this calls getIntOrFloatBitWidth() on the original result
    // type even in the pointer-typed-operand case above — presumably such
    // operands always have an int/float original type; confirm against the op
    // verifier.
    Value bits = rewriter.create<LLVM::ZExtOp>(
        loc, i64Type,
        rewriter.create<LLVM::BitcastOp>(
            loc,
            rewriter.getIntegerType(
                instrumentOp.getType().getIntOrFloatBitWidth()),
            operands.getOperand()));

    appendInstrumentationEntry(loc, instrumentOp.getBuffer(),
                               operands.getBuffer(), entryType,
                               {
                                   header,
                                   bits,
                               },
                               dataLayout, rewriter);

    rewriter.replaceOp(instrumentOp, operands.getOperand());
    return success();
  }
};
| |
/// Rewrites hal.instrument.memory.load to an instrumentation ringbuffer
/// append recording the accessed address and access size, keyed by the
/// workgroup key. The op is replaced with the loaded value it wraps.
struct ConvertHALInstrumentMemoryLoadOp
    : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentMemoryLoadOp> {
  using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
  LogicalResult matchAndRewrite(
      IREE::HAL::InstrumentMemoryLoadOp instrumentOp,
      IREE::HAL::InstrumentMemoryLoadOpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = instrumentOp.getLoc();
    auto dataLayout =
        getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
    auto i64Type = rewriter.getI64Type();

    // Entry record layout; must match the runtime-side decoder.
    auto entryType =
        LLVM::LLVMStructType::getLiteral(getContext(), {
                                                           i64Type,  // header
                                                           i64Type,  // address
                                                       });

    // Header layout (low to high bits):
    //   8 bit tag (IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_LOAD)
    //  16 bit access length in bytes
    //  40 bit workgroup offset (pre-shifted in the workgroup key)
    int64_t loadSize = getMemoryAccessByteSize(instrumentOp.getType());
    assert(loadSize <= UINT16_MAX && "16-bit length maximum");
    Value header = rewriter.create<LLVM::OrOp>(
        loc, operands.getWorkgroupKey(),
        rewriter.create<LLVM::ConstantOp>(
            loc, i64Type,
            (loadSize << 8) | IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_LOAD));

    // Resolve the accessed element address and record it as an i64.
    Value loadPtr = getStridedElementPtr(
        loc, llvm::cast<MemRefType>(instrumentOp.getBase().getType()),
        operands.getBase(), operands.getIndices(), rewriter);
    Value addressI64 = rewriter.create<LLVM::PtrToIntOp>(loc, i64Type, loadPtr);

    appendInstrumentationEntry(loc, instrumentOp.getBuffer(),
                               operands.getBuffer(), entryType,
                               {
                                   header,
                                   addressI64,
                               },
                               dataLayout, rewriter);

    rewriter.replaceOp(instrumentOp, operands.getLoadValue());
    return success();
  }
};
| |
/// Rewrites hal.instrument.memory.store to an instrumentation ringbuffer
/// append recording the accessed address and access size, keyed by the
/// workgroup key. The op is replaced with the stored value it wraps.
struct ConvertHALInstrumentMemoryStoreOp
    : public ConvertOpToLLVMWithABIPattern<IREE::HAL::InstrumentMemoryStoreOp> {
  using ConvertOpToLLVMWithABIPattern::ConvertOpToLLVMWithABIPattern;
  LogicalResult matchAndRewrite(
      IREE::HAL::InstrumentMemoryStoreOp instrumentOp,
      IREE::HAL::InstrumentMemoryStoreOpAdaptor operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = instrumentOp.getLoc();
    auto dataLayout =
        getTypeConverter()->getDataLayoutAnalysis()->getAbove(instrumentOp);
    auto i64Type = rewriter.getI64Type();

    // Entry record layout; must match the runtime-side decoder.
    auto entryType =
        LLVM::LLVMStructType::getLiteral(getContext(), {
                                                           i64Type,  // header
                                                           i64Type,  // address
                                                       });

    // Header layout (low to high bits):
    //   8 bit tag (IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_STORE)
    //  16 bit access length in bytes
    //  40 bit workgroup offset (pre-shifted in the workgroup key)
    int64_t storeSize = getMemoryAccessByteSize(instrumentOp.getType());
    assert(storeSize <= UINT16_MAX && "16-bit length maximum");
    Value header = rewriter.create<LLVM::OrOp>(
        loc, operands.getWorkgroupKey(),
        rewriter.create<LLVM::ConstantOp>(
            loc, i64Type,
            (storeSize << 8) | IREE_INSTRUMENT_DISPATCH_TYPE_MEMORY_STORE));

    // Resolve the accessed element address and record it as an i64.
    Value storePtr = getStridedElementPtr(
        loc, llvm::cast<MemRefType>(instrumentOp.getBase().getType()),
        operands.getBase(), operands.getIndices(), rewriter);
    Value addressI64 =
        rewriter.create<LLVM::PtrToIntOp>(loc, i64Type, storePtr);

    appendInstrumentationEntry(loc, instrumentOp.getBuffer(),
                               operands.getBuffer(), entryType,
                               {
                                   header,
                                   addressI64,
                               },
                               dataLayout, rewriter);

    rewriter.replaceOp(instrumentOp, operands.getStoreValue());
    return success();
  }
};
| |
| /// Helper method to get information about extra operands that need to be |
| /// appended to a function defn/call operation. |
| static SmallVector<StringRef> getExtraFields(Operation *forOp) { |
| SmallVector<StringRef> extraFields; |
| if (auto extraFieldsAttr = |
| forOp->getAttrOfType<ArrayAttr>("hal.import.fields")) { |
| extraFields = llvm::map_to_vector( |
| extraFieldsAttr.getValue(), |
| [](Attribute attr) { return llvm::cast<StringAttr>(attr).getValue(); }); |
| } |
| return extraFields; |
| } |
| |
| /// Return calling convention to use for the operation. |
| static IREE::HAL::CallingConvention getCallingConvention(Operation *forOp) { |
| auto cConv = IREE::HAL::CallingConvention::Default; |
| if (auto cConvAttr = forOp->getAttrOfType<IREE::HAL::CallingConventionAttr>( |
| "hal.import.cconv")) { |
| cConv = cConvAttr.getValue(); |
| } |
| return cConv; |
| } |
| |
/// Lower func ops with specified ABI. Currently this pattern is triggered
/// only for operations with the `hal.import.bitcode` attribute set.
///
/// Note: this is an LLVM::LLVMFuncOp -> LLVM::LLVMFuncOp rewrite that is
/// introduced after all conversions are done. Importantly, this is not a
/// conversion pattern.
struct RewriteFuncOpABI : public OpRewritePattern<LLVM::LLVMFuncOp> {
  RewriteFuncOpABI(HALDispatchABI &abi, LLVMTypeConverter &typeConverter)
      : OpRewritePattern(&typeConverter.getContext()),
        abi(abi),
        typeConverter(typeConverter) {}

  LogicalResult matchAndRewrite(LLVM::LLVMFuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only external declarations imported via bitcode linkage are rewritten.
    if (!funcOp.isExternal()) {
      return rewriter.notifyMatchFailure(funcOp, "skipping non-external calls");
    }
    if (!funcOp->hasAttr("hal.import.bitcode")) {
      return rewriter.notifyMatchFailure(
          funcOp, "callee is not imported using bitcode linkage; skipping");
    }
    IREE::HAL::CallingConvention cConv = getCallingConvention(funcOp);

    // Compute the signature the ABI expects for this declaration (possibly
    // with extra internal fields appended).
    SmallVector<StringRef> extraFields = getExtraFields(funcOp);
    auto funcType = funcOp.getFunctionType();
    FailureOr<LLVM::LLVMFunctionType> expectedType =
        abi.getABIFunctionType(funcOp, cConv, funcType.getReturnTypes(),
                               funcType.getParams(), extraFields);
    if (failed(expectedType)) {
      return rewriter.notifyMatchFailure(
          funcOp,
          "unable to get function type to match the calling convention");
    }
    // Already matching the expected ABI signature: nothing to do (this also
    // keeps the pattern from looping forever on its own output).
    if (abi.hasCompatibleFunctionSignature(
            rewriter.getContext(), expectedType.value(),
            funcType.getReturnTypes(), funcType.getParams())) {
      return failure();
    }
    // Re-declare the function with the ABI signature, preserving all
    // non-inherent attributes and per-argument attributes.
    auto attrs = getPrunedAttributeList(
        funcOp, llvm::to_vector(LLVM::LLVMFuncOp::getAttributeNames()));
    SmallVector<DictionaryAttr> argAttrs;
    if (auto currArgAttrs = funcOp.getArgAttrsAttr()) {
      argAttrs = llvm::map_to_vector(currArgAttrs, [](Attribute attr) {
        return llvm::cast<DictionaryAttr>(attr);
      });
    }
    rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), expectedType.value(),
        funcOp.getLinkage(), funcOp.getDsoLocal(), funcOp.getCConv(), attrs,
        argAttrs, funcOp.getFunctionEntryCount());
    rewriter.eraseOp(funcOp);
    return success();
  }

 private:
  HALDispatchABI &abi;
  LLVMTypeConverter &typeConverter;
};
| |
| /// Lower call ops with specified ABI. The ABI to use is looked up from the |
| /// callee. Currently this pattern is triggered only for operations where the |
| /// callee has the `hal.import.bitcode` attribute set. |
| /// |
| /// Note: this is an LLVM::CallOp -> LLVM::CallOp rewrite that is introduced |
| /// after all conversions are done. Importantly, this is not a conversion |
| /// pattern. |
| struct RewriteCallOpABI : public OpRewritePattern<LLVM::CallOp> { |
| RewriteCallOpABI(HALDispatchABI &abi, LLVMTypeConverter &typeConverter) |
| : OpRewritePattern(&typeConverter.getContext()), |
| abi(abi), |
| typeConverter(typeConverter) {} |
| |
| LogicalResult matchAndRewrite(LLVM::CallOp callOp, |
| PatternRewriter &rewriter) const override { |
| auto symbol = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); |
| auto flatSymbol = llvm::dyn_cast_if_present<FlatSymbolRefAttr>(symbol); |
| if (!flatSymbol) return failure(); |
| |
| // Ensure the target function is extern. |
| // To support conversion inserting calls in local patterns that can't add |
| // global function symbols we assume any missing callee is extern. |
| auto calleeOp = |
| SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(callOp, symbol); |
| if (!calleeOp || !calleeOp->hasAttr("hal.import.bitcode") || |
| !calleeOp.isExternal()) { |
| return rewriter.notifyMatchFailure( |
| callOp, "callee is not imported using bitcode linakge; skipping"); |
| } |
| |
| IREE::HAL::CallingConvention cConv = getCallingConvention(calleeOp); |
| SmallVector<StringRef> extraFields = getExtraFields(calleeOp); |
| |
| FailureOr<SmallVector<Value>> results = abi.materializeABI( |
| callOp, calleeOp.getSymName(), cConv, callOp->getResultTypes(), |
| callOp->getOperands(), extraFields, rewriter); |
| if (failed(results)) { |
| return failure(); |
| } |
| rewriter.replaceOp(callOp, *results); |
| return success(); |
| } |
| |
| private: |
| HALDispatchABI &abi; |
| LLVMTypeConverter &typeConverter; |
| }; |
| |
| /// Rewrites calls to extern functions to dynamic library import calls. |
| /// The parent LLVMFuncOp must be compatible with HALDispatchABI. |
| /// |
| /// Note: this is an LLVM::CallOp -> LLVM::CallOp rewrite that is introduced |
| /// after all conversions are done. Importantly, this is not a conversion |
| /// pattern. |
| struct RewriteExternCallOpToDynamicImportCallOp |
| : public OpRewritePattern<LLVM::CallOp> { |
| RewriteExternCallOpToDynamicImportCallOp(HALDispatchABI &abi, |
| LLVMTypeConverter &typeConverter) |
| : OpRewritePattern(&typeConverter.getContext()), |
| abi(abi), |
| typeConverter(typeConverter) {} |
| LogicalResult matchAndRewrite(LLVM::CallOp callOp, |
| PatternRewriter &rewriter) const override { |
| // Ignore indirect calls (they're probably already converted imports). |
| auto symbol = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>(); |
| auto flatSymbol = llvm::dyn_cast_if_present<FlatSymbolRefAttr>(symbol); |
| if (!flatSymbol) return failure(); |
| |
| // Ensure the target function is extern. |
| // To support conversion inserting calls in local patterns that can't add |
| // global function symbols we assume any missing callee is extern. |
| auto calleeOp = |
| SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(callOp, symbol); |
| if (calleeOp && !calleeOp.isExternal()) { |
| return rewriter.notifyMatchFailure( |
| callOp, |
| "callee is not external; treating as a normal call and skipping " |
| "import logic"); |
| } |
| |
| // If the function is marked as statically linked we don't touch it. That'll |
| // let it fall through to the linker stage where it can be picked up either |
| // from the runtime build (in the case of us producing static libraries) or |
| // the user-specified object files (when producing dynamic libraries). |
| if (calleeOp->hasAttr("hal.import.static") || |
| calleeOp->hasAttr("hal.import.bitcode")) { |
| return rewriter.notifyMatchFailure(callOp, |
| "external function is marked static " |
| "and does not need an import wrapper"); |
| } |
| |
| // The call may need some additional internal fields appended. |
| SmallVector<StringRef> extraFields; |
| if (auto extraFieldsAttr = |
| calleeOp->getAttrOfType<ArrayAttr>("hal.import.fields")) { |
| for (auto extraFieldAttr : extraFieldsAttr) { |
| extraFields.push_back( |
| llvm::cast<StringAttr>(extraFieldAttr).getValue()); |
| } |
| } |
| |
| // Allow multiple imports to alias by having their name explicitly |
| // specified. |
| StringRef importName = flatSymbol.getValue(); |
| if (auto importNameAttr = |
| calleeOp->getAttrOfType<StringAttr>("hal.import.name")) { |
| importName = importNameAttr.getValue(); |
| } |
| |
| // TODO(benvanik): way to determine if weak (maybe via linkage?). |
| bool weak = false; |
| |
| // Rewrite the call to a dynamic import call. |
| SmallVector<Value> results = abi.wrapAndCallImport( |
| callOp, importName, weak, callOp->getResultTypes(), |
| callOp->getOperands(), extraFields, rewriter); |
| |
| rewriter.replaceOp(callOp, results); |
| return success(); |
| } |
| HALDispatchABI &abi; |
| LLVMTypeConverter &typeConverter; |
| }; |
| |
| /// The 32-bit RISC-V backend is very sensitive to how extended multiplication |
| /// is lowered. This pattern lowers `arith.mulsi_extended` before going to the |
| /// LLVM dialect, in a way compatible with that backend, so that we break down |
| /// any 64-bit constants that would otherwise prevent the code from being |
| /// vectorized. |
| class ExpandMulSIExtended : public OpRewritePattern<arith::MulSIExtendedOp> { |
| public: |
| using OpRewritePattern<arith::MulSIExtendedOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(arith::MulSIExtendedOp op, |
| PatternRewriter &rewriter) const override { |
| Type resultType = op.getLhs().getType(); |
| if (getElementTypeOrSelf(resultType).getIntOrFloatBitWidth() != 32) { |
| return failure(); |
| } |
| |
| Location loc = op.getLoc(); |
| |
| Type wideType = rewriter.getIntegerType(64); |
| // Shift amount necessary to extract the high bits from widened result. |
| TypedAttr shiftValAttr = rewriter.getI64IntegerAttr(32); |
| if (auto vecTy = llvm::dyn_cast<VectorType>(resultType)) { |
| wideType = VectorType::get(vecTy.getShape(), wideType); |
| shiftValAttr = |
| SplatElementsAttr::get(cast<ShapedType>(wideType), shiftValAttr); |
| } |
| Value shiftVal = rewriter.create<arith::ConstantOp>(loc, shiftValAttr); |
| |
| Value lhsExt = rewriter.create<arith::ExtSIOp>(loc, wideType, op.getLhs()); |
| Value rhsExt = rewriter.create<arith::ExtSIOp>(loc, wideType, op.getRhs()); |
| Value mulExt = |
| rewriter.create<arith::MulIOp>(loc, wideType, lhsExt, rhsExt); |
| Value low = rewriter.create<arith::MulIOp>(loc, resultType, op.getLhs(), |
| op.getRhs()); |
| |
| // Produce two 32-bit results. |
| Value highExt = rewriter.create<arith::ShRUIOp>(loc, mulExt, shiftVal); |
| Value high = rewriter.create<arith::TruncIOp>(loc, resultType, highExt); |
| |
| rewriter.replaceOp(op, {low, high}); |
| return success(); |
| } |
| }; |
| |
| class ConvertToLLVMPass : public ConvertToLLVMBase<ConvertToLLVMPass> { |
| public: |
| ConvertToLLVMPass(bool reassociateFpReductions) { |
| targetReassociateFpReductions.setValue(reassociateFpReductions); |
| } |
| ConvertToLLVMPass(const ConvertToLLVMPass &pass) {} |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| registry.insert<LLVM::LLVMDialect, arm_neon::ArmNeonDialect>(); |
| } |
| |
| void runOnOperation() override; |
| |
| private: |
| Option<std::string> targetTriple{ |
| *this, "target-triple", llvm::cl::desc("Code generation target triple."), |
| llvm::cl::init("")}; |
| Option<std::string> targetDataLayout{ |
| *this, "target-data-layout", |
| llvm::cl::desc("Code generation target data layout."), |
| llvm::cl::init("")}; |
| Option<bool> targetReassociateFpReductions{ |
| *this, "target-reassociate-fp-reductions", |
| llvm::cl::desc("Code generation target reassociate FP reductions."), |
| llvm::cl::init("false")}; |
| }; |
| |
| } // namespace |
| |
| static std::string getStringAttrFromTargetAttr(ModuleOp module, |
| StringRef attrName) { |
| auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(module); |
| auto stringAttr = getConfigStringAttr(targetAttr, attrName); |
| return stringAttr ? stringAttr.value().str() : std::string(""); |
| } |
| |
| void ConvertToLLVMPass::runOnOperation() { |
| auto module = getOperation(); |
| std::string dataLayoutStr = targetDataLayout.getValue(); |
| if (targetDataLayout.empty()) { |
| dataLayoutStr = getStringAttrFromTargetAttr(module, "data_layout"); |
| } |
| std::string targetTripleStr = targetTriple.getValue(); |
| if (targetTripleStr.empty()) { |
| targetTripleStr = getStringAttrFromTargetAttr(module, "target_triple"); |
| } |
| |
| // Add required attributes to the module so that the lowering knows how to |
| // handle structs and data layouts. |
| module->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(), |
| StringAttr::get(module->getContext(), targetTripleStr)); |
| module->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), |
| StringAttr::get(module->getContext(), dataLayoutStr)); |
| |
| // Run Vector -> Vector transformations ahead of conversion to LLVM. |
| { |
| RewritePatternSet patterns(&getContext()); |
| vector::populateVectorToVectorCanonicalizationPatterns(patterns); |
| vector::populateVectorBroadcastLoweringPatterns(patterns); |
| // TODO: doubtful that the "default" does what one want here, it is likely |
| // better to use outerproduct. |
| vector::populateVectorContractLoweringPatterns( |
| patterns, vector::VectorTransformsOptions()); |
| vector::populateVectorMaskMaterializationPatterns( |
| patterns, /*force32BitVectorIndices=*/false); |
| vector::populateVectorMaskOpLoweringPatterns(patterns); |
| vector::populateVectorShapeCastLoweringPatterns(patterns); |
| // TODO: doubtful that the "default" does what one want here, it is likely |
| // better to use shuffle. |
| vector::populateVectorTransposeLoweringPatterns( |
| patterns, vector::VectorTransformsOptions()); |
| populateConvertArmNeon2dToIntrPatterns(patterns); |
| if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| std::move(patterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| { |
| RewritePatternSet vectorToLoopsPatterns(&getContext()); |
| populateVectorToSCFConversionPatterns( |
| vectorToLoopsPatterns, VectorTransferToSCFOptions().enableFullUnroll()); |
| if (failed(applyPatternsAndFoldGreedily( |
| getOperation(), std::move(vectorToLoopsPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); |
| LowerToLLVMOptions options(&getContext(), |
| dataLayoutAnalysis.getAtOrAbove(module)); |
| options.dataLayout = llvm::DataLayout(dataLayoutStr); |
| options.overrideIndexBitwidth(options.dataLayout.getPointerSizeInBits()); |
| LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis); |
| |
| RewritePatternSet patterns(&getContext()); |
| |
| // Use the default 64-bit lowering for TOSA's ApplyScale operator: |
| // This lowering widens integer types to 64-bit an performs the non-fused |
| // operations, specifically multiply, add, and shift. Bit-widening |
| // is used to guarantee higher-order bits are not truncated during the |
| // multiply or add. |
| // |
| // TODO(bjacob): Use a lowering that uses specific ARM/X86 intrinsics. |
| bool use32BitImpl = false; |
| auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(module); |
| if (isRISCV(targetAttr)) { |
| // Use the 32-bit lowering for RISC-V if 'zve32*' is specified and there is |
| // no 64-bit integer vector support. |
| // TODO(#9440) Simplify logic when 'cpu_features' is simplified. |
| use32BitImpl = |
| (hasZve32xFeature(targetAttr) || hasZve32fFeature(targetAttr)) && |
| !hasVFeature(targetAttr) && !hasZve64xFeature(targetAttr); |
| } |
| tosa::populateTosaRescaleToArithConversionPatterns(&patterns, use32BitImpl); |
| |
| // Make sure we expand any `arith.mulsi_extended` before going to the LLVM |
| // dialect. |
| if (use32BitImpl) { |
| patterns.add<ExpandMulSIExtended>(patterns.getContext(), /*benefit=*/1024); |
| } |
| |
| populateAffineToStdConversionPatterns(patterns); |
| populateSCFToControlFlowConversionPatterns(patterns); |
| cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); |
| populateExpandTanhPattern(patterns); |
| |
| populateComplexToLLVMConversionPatterns(typeConverter, patterns); |
| populateMathToLLVMConversionPatterns(typeConverter, patterns); |
| memref::populateExpandStridedMetadataPatterns(patterns); |
| populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); |
| populateFuncToLLVMConversionPatterns(typeConverter, patterns); |
| arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); |
| arith::populateExpandBFloat16Patterns(patterns); |
| populateVectorToSCFConversionPatterns(patterns); |
| populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns); |
| populateVectorToLLVMConversionPatterns( |
| typeConverter, patterns, targetReassociateFpReductions.getValue()); |
| populateLinalgToLLVMConversionPatterns(typeConverter, patterns); |
| populateReconcileUnrealizedCastsPatterns(patterns); |
| |
| HALDispatchABI abi(&typeConverter); |
| // clang-format off |
| patterns.insert< |
| ConvertHALEntryPointFuncOp, |
| ConvertHALExecutableConstantLoadOp, |
| ConvertHALInterfaceWorkgroupIDOp, |
| ConvertHALInterfaceWorkgroupSizeOp, |
| ConvertHALInterfaceWorkgroupCountOp, |
| ConvertHALInterfaceConstantLoadOp, |
| ConvertHALInterfaceBindingSubspanOp, |
| ConvertHALInstrumentWorkgroupOp, |
| ConvertHALInstrumentValueOp, |
| ConvertHALInstrumentMemoryLoadOp, |
| ConvertHALInstrumentMemoryStoreOp |
| >(abi, typeConverter); |
| // clang-format on |
| |
| LLVMConversionTarget target(getContext()); |
| target.addLegalOp<ModuleOp>(); |
| target.addIllegalDialect<func::FuncDialect, mlir::arith::ArithDialect, |
| IREE::Util::UtilDialect, IREE::HAL::HALDialect, |
| math::MathDialect, tosa::TosaDialect>(); |
| target.addIllegalOp<UnrealizedConversionCastOp>(); |
| |
| if (failed(applyPartialConversion(module, target, std::move(patterns)))) { |
| signalPassFailure(); |
| return; |
| } |
| |
| // Rewrite any extern calls emitted to dynamic library imports. |
| { |
| RewritePatternSet patterns(&getContext()); |
| patterns.insert<RewriteExternCallOpToDynamicImportCallOp, RewriteCallOpABI, |
| RewriteFuncOpABI>(abi, typeConverter); |
| if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) |
| return signalPassFailure(); |
| } |
| |
| // Post conversion patterns. |
| { |
| RewritePatternSet postPatterns(&getContext()); |
| // TODO(ravishankarm): Move this to a separate pass. |
| llvm::Triple triple(targetTripleStr); |
| if (triple.isWasm()) { |
| populateUnfusedFMAOpsPassPatterns(&getContext(), postPatterns); |
| if (failed( |
| applyPatternsAndFoldGreedily(module, std::move(postPatterns)))) { |
| return signalPassFailure(); |
| } |
| } |
| } |
| } |
| |
| std::unique_ptr<OperationPass<ModuleOp>> createConvertToLLVMPass( |
| bool reassociateFpReductions) { |
| return std::make_unique<ConvertToLLVMPass>(reassociateFpReductions); |
| } |
| |
| } // namespace iree_compiler |
| } // namespace mlir |