| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/LLVMCPU/DispatchABI.h" |
| |
| #include "iree/compiler/Codegen/Utils/Utils.h" |
| #include "iree/schemas/cpu_data.h" |
| #include "llvm/BinaryFormat/Dwarf.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/Path.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| |
| static llvm::cl::opt<bool> clVerboseDebugInfo( |
| "iree-codegen-llvm-verbose-debug-info", |
| llvm::cl::desc("Emit verbose debug information in LLVM IR."), |
| llvm::cl::init(false)); |
| |
| namespace mlir::iree_compiler { |
| |
| //------------------------------------------------------------------------------ |
| // ExecutableLibraryDI |
| //------------------------------------------------------------------------------ |
| |
| // NOTE: the debug information used here is only present as a compiler developer |
| // aid. It may get out of sync with published versions of the executable ABI and |
| // may not be very clean. For example, we size push constant and binding arrays |
| // not based on the actual layout of the pipeline layout but with some |
| // reasonable limit that allows for use in a debugger. If we wanted to improve |
| // this we could customize per-scope the structures and look up the |
| // IREE::HAL::PipelineLayoutAttr for each entry point to discover the real |
| // limits. |
| // |
| // NOTE: MLIR and subsequent LLVM optimizations will often remove a lot of this |
| // debug information (or at least make it less useful). This can happen even in |
| // LLVM modes of -O0 as MLIR has no such configurability at this time. |
| // |
| // It'd be nice to have an automatic sync of the debug information and structs |
| // in this file such that we'd get matching source file/line information with |
| // the runtime headers and be able to delete most of this hand-authored code. |
| // The current debug information and types were constructed by compiling |
| // executable_library_demo.c to LLVM IR and then importing it into MLIR to see |
| // what the attributes look like. We could automate this and embed the |
| // attributes in the compiler binary but due to the differences in 32/64-bit |
| // pointer widths some manual massaging may still be required (or we just embed |
| // both). |
| // |
| // $ clang -emit-llvm -Iruntime/src/ \ |
| // runtime/src/iree/hal/local/executable_library_demo.c -g -S \ |
| // --target=x86_64-pc-windows-elf |
| // $ mlir-translate --import-llvm executable_library_demo.ll |
| |
| // Returns the size, in bits, of |typeAttr|. |
| static unsigned getDITypeSizeInBits(LLVM::DITypeAttr typeAttr) { |
| if (auto basicTypeAttr = llvm::dyn_cast<LLVM::DIBasicTypeAttr>(typeAttr)) { |
| return basicTypeAttr.getSizeInBits(); |
| } else if (auto derivedTypeAttr = |
| llvm::dyn_cast<LLVM::DIDerivedTypeAttr>(typeAttr)) { |
| if (unsigned derivedSize = derivedTypeAttr.getSizeInBits()) { |
| return derivedSize; |
| } else { |
| return getDITypeSizeInBits(derivedTypeAttr.getBaseType()); |
| } |
| } else { |
| return 0; |
| } |
| } |
| |
| ExecutableLibraryDI::ExecutableLibraryDI(const LLVMTypeConverter *typeConverter) |
| : builder(&typeConverter->getContext()) { |
| auto *context = builder.getContext(); |
| fileAttr = LLVM::DIFileAttr::get( |
| context, "runtime/src/iree/hal/local/executable_library.h", "."); |
| ptrBitwidth = typeConverter->getPointerBitwidth(); |
| |
| voidPtr = getPtrOf(LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "void", |
| /*sizeInBits=*/0, llvm::dwarf::DW_ATE_address)); |
| int8T = getTypedefOf("int8_t", |
| LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "char", |
| /*sizeInBits=*/8, llvm::dwarf::DW_ATE_signed_char)); |
| uint8T = getTypedefOf( |
| "uint8_t", LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "unsigned char", |
| /*sizeInBits=*/8, llvm::dwarf::DW_ATE_unsigned_char)); |
| int16T = getTypedefOf("int16_t", |
| LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "short", |
| /*sizeInBits=*/16, llvm::dwarf::DW_ATE_signed)); |
| uint16T = getTypedefOf( |
| "uint16_t", LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "unsigned short", |
| /*sizeInBits=*/16, llvm::dwarf::DW_ATE_unsigned)); |
| int32T = getTypedefOf("int32_t", |
| LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "int", |
| /*sizeInBits=*/32, llvm::dwarf::DW_ATE_signed)); |
| uint32T = getTypedefOf( |
| "uint32_t", LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "unsigned int", |
| /*sizeInBits=*/32, llvm::dwarf::DW_ATE_unsigned)); |
| int64T = getTypedefOf( |
| "int64_t", LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, "long long int", |
| /*sizeInBits=*/64, llvm::dwarf::DW_ATE_signed)); |
| uint64T = getTypedefOf("uint64_t", |
| LLVM::DIBasicTypeAttr::get( |
| context, llvm::dwarf::DW_TAG_base_type, |
| "long long unsigned int", |
| /*sizeInBits=*/64, llvm::dwarf::DW_ATE_unsigned)); |
| intptrT = |
| getTypedefOf("intptr_t", ptrBitwidth == 32 ? getInt32T() : getInt64T()); |
| sizeT = |
| getTypedefOf("size_t", ptrBitwidth == 32 ? getUint32T() : getUint64T()); |
| } |
| |
| LLVM::DIDerivedTypeAttr |
| ExecutableLibraryDI::getConstOf(LLVM::DITypeAttr typeAttr) { |
| return LLVM::DIDerivedTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_const_type, |
| /*name=*/nullptr, typeAttr, /*sizeInBits=*/0, /*alignInBits=*/0, |
| /*offsetInBits=*/0, /*dwarfAddressSpace=*/std::nullopt, |
| /*extraData=*/nullptr); |
| } |
| |
| LLVM::DIDerivedTypeAttr |
| ExecutableLibraryDI::getPtrOf(LLVM::DITypeAttr typeAttr) { |
| return LLVM::DIDerivedTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_pointer_type, |
| /*name=*/nullptr, typeAttr, /*sizeInBits=*/ptrBitwidth, |
| /*alignInBits=*/0, |
| /*offsetInBits=*/0, |
| /*dwarfAddressSpace=*/std::nullopt, |
| /*extraData=*/nullptr); |
| } |
| |
| LLVM::DICompositeTypeAttr |
| ExecutableLibraryDI::getArrayOf(LLVM::DITypeAttr typeAttr, int64_t count) { |
| return LLVM::DICompositeTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_array_type, /*recId=*/{}, |
| /*name=*/builder.getStringAttr(""), fileAttr, |
| /*line=*/227, fileAttr, |
| /*baseType=*/typeAttr, LLVM::DIFlags::Zero, |
| /*sizeInBits=*/getDITypeSizeInBits(typeAttr) * count, |
| /*alignInBits=*/0, |
| { |
| LLVM::DISubrangeAttr::get( |
| builder.getContext(), builder.getI64IntegerAttr(count), |
| /*lowerBound=*/nullptr, /*upperBound=*/nullptr, |
| /*stride=*/nullptr), |
| }); |
| } |
| |
| LLVM::DIDerivedTypeAttr |
| ExecutableLibraryDI::getTypedefOf(StringRef name, LLVM::DITypeAttr typeAttr) { |
| return LLVM::DIDerivedTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_typedef, |
| builder.getStringAttr(name), typeAttr, /*sizeInBits=*/0, |
| /*alignInBits=*/0, /*offsetInBits=*/0, /*dwarfAddressSpace=*/std::nullopt, |
| /*extraData=*/nullptr); |
| } |
| |
| LLVM::DIDerivedTypeAttr |
| ExecutableLibraryDI::getMemberOf(StringRef name, LLVM::DITypeAttr typeAttr, |
| unsigned *offsetInBits) { |
| unsigned memberOffsetInBits = *offsetInBits; |
| unsigned memberSizeInBits = getDITypeSizeInBits(typeAttr); |
| *offsetInBits += memberSizeInBits; |
| return LLVM::DIDerivedTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_member, |
| builder.getStringAttr(name), typeAttr, |
| /*sizeInBits=*/memberSizeInBits, /*alignInBits=*/0, |
| /*offsetInBits=*/memberOffsetInBits, /*dwarfAddressSpace=*/std::nullopt, |
| /*extraData=*/nullptr); |
| } |
| |
| LLVM::DITypeAttr ExecutableLibraryDI::getBasicType(Type type) { |
| return TypeSwitch<Type, LLVM::DITypeAttr>(type) |
| .Case([&](IndexType) { return getIntptrT(); }) |
| .Case([&](IntegerType integerType) -> LLVM::DITypeAttr { |
| unsigned bitWidth = integerType.getIntOrFloatBitWidth(); |
| switch (bitWidth) { |
| case 8: |
| return integerType.isUnsigned() ? getUint8T() : getInt8T(); |
| case 16: |
| return integerType.isUnsigned() ? getUint16T() : getInt16T(); |
| case 32: |
| return integerType.isUnsigned() ? getUint32T() : getInt32T(); |
| case 64: |
| return integerType.isUnsigned() ? getUint64T() : getInt64T(); |
| default: |
| return LLVM::DIBasicTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_base_type, |
| StringRef("int") + std::to_string(bitWidth), |
| /*sizeInBits=*/bitWidth, |
| integerType.isUnsigned() ? llvm::dwarf::DW_ATE_unsigned |
| : llvm::dwarf::DW_ATE_signed); |
| } |
| }) |
| .Case([&](FloatType floatType) -> LLVM::DITypeAttr { |
| unsigned bitWidth = floatType.getIntOrFloatBitWidth(); |
| return LLVM::DIBasicTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_base_type, |
| StringRef("float") + std::to_string(bitWidth), |
| /*sizeInBits=*/bitWidth, llvm::dwarf::DW_ATE_float); |
| }) |
| .Default([](Type) { |
| assert(false && "unhandled basic type"); |
| return nullptr; |
| }); |
| } |
| |
| LLVM::DICompositeTypeAttr ExecutableLibraryDI::getProcessorV0T() { |
| unsigned offsetInBits = 0; |
| return LLVM::DICompositeTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_structure_type, /*recId=*/{}, |
| builder.getStringAttr("iree_hal_processor_v0_t"), fileAttr, |
| /*line=*/227, fileAttr, |
| /*baseType=*/nullptr, LLVM::DIFlags::Zero, /*sizeInBits=*/512, |
| /*alignInBits=*/0, |
| { |
| getMemberOf("data", getArrayOf(getUint64T(), 8), &offsetInBits), |
| }); |
| } |
| |
| LLVM::DIDerivedTypeAttr ExecutableLibraryDI::getEnvironmentV0T() { |
| unsigned offsetInBits = 0; |
| return getTypedefOf( |
| "iree_hal_executable_environment_v0_t", |
| LLVM::DICompositeTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_structure_type, |
| /*recId=*/{}, |
| builder.getStringAttr("iree_hal_executable_environment_v0_t"), |
| fileAttr, |
| /*line=*/246, fileAttr, |
| /*baseType=*/nullptr, LLVM::DIFlags::Zero, /*sizeInBits=*/768, |
| /*alignInBits=*/0, |
| { |
| getMemberOf("constants", |
| getPtrOf(getConstOf(getArrayOf(getUint32T(), 64))), |
| &offsetInBits), |
| getMemberOf("import_thunk", getVoidPtr(), &offsetInBits), |
| getMemberOf("import_funcs", getPtrOf(getConstOf(getVoidPtr())), |
| &offsetInBits), |
| getMemberOf("import_contexts", |
| getPtrOf(getPtrOf(getConstOf(getVoidPtr()))), |
| &offsetInBits), |
| getMemberOf("processor", getProcessorV0T(), &offsetInBits), |
| })); |
| } |
| |
| LLVM::DIDerivedTypeAttr ExecutableLibraryDI::getDispatchStateV0T() { |
| unsigned offsetInBits = 0; |
| return getTypedefOf( |
| "iree_hal_executable_dispatch_state_v0_t", |
| LLVM::DICompositeTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_structure_type, |
| /*recId=*/{}, |
| builder.getStringAttr("iree_hal_executable_dispatch_state_v0_t"), |
| fileAttr, /*line=*/275, fileAttr, |
| /*baseType=*/nullptr, LLVM::DIFlags::Zero, /*sizeInBits=*/384, |
| /*alignInBits=*/0, |
| { |
| getMemberOf("workgroup_size_x", getUint32T(), &offsetInBits), |
| getMemberOf("workgroup_size_y", getUint32T(), &offsetInBits), |
| getMemberOf("workgroup_size_z", getUint16T(), &offsetInBits), |
| getMemberOf("push_constant_count", getUint16T(), &offsetInBits), |
| getMemberOf("workgroup_count_x", getUint32T(), &offsetInBits), |
| getMemberOf("workgroup_count_y", getUint32T(), &offsetInBits), |
| getMemberOf("workgroup_count_z", getUint16T(), &offsetInBits), |
| getMemberOf("max_concurrency", getUint8T(), &offsetInBits), |
| getMemberOf("binding_count", getUint8T(), &offsetInBits), |
| getMemberOf("push_constants", |
| getPtrOf(getConstOf(getArrayOf(getUint32T(), 64))), |
| &offsetInBits), |
| getMemberOf( |
| "binding_ptrs", |
| getPtrOf(getConstOf(getArrayOf(getPtrOf(getUint8T()), 64))), |
| &offsetInBits), |
| getMemberOf("binding_lengths", |
| getPtrOf(getConstOf(getArrayOf(getSizeT(), 64))), |
| &offsetInBits), |
| })); |
| } |
| |
| LLVM::DIDerivedTypeAttr ExecutableLibraryDI::getWorkgroupStateV0T() { |
| unsigned offsetInBits = 0; |
| return getTypedefOf( |
| "iree_hal_executable_workgroup_state_v0_t", |
| LLVM::DICompositeTypeAttr::get( |
| builder.getContext(), llvm::dwarf::DW_TAG_structure_type, |
| /*recId=*/{}, |
| builder.getStringAttr("iree_hal_executable_workgroup_state_v0_t"), |
| fileAttr, /*line=*/321, fileAttr, |
| /*baseType=*/nullptr, LLVM::DIFlags::Zero, /*sizeInBits=*/256, |
| /*alignInBits=*/0, |
| { |
| getMemberOf("workgroup_id_x", getUint32T(), &offsetInBits), |
| getMemberOf("workgroup_id_y", getUint32T(), &offsetInBits), |
| getMemberOf("workgroup_id_z", getUint16T(), &offsetInBits), |
| getMemberOf("reserved", getUint16T(), &offsetInBits), |
| getMemberOf("processor_id", getUint32T(), &offsetInBits), |
| getMemberOf("local_memory", getVoidPtr(), &offsetInBits), |
| getMemberOf("local_memory_size", getUint32T(), &offsetInBits), |
| })); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // HALDispatchABI |
| //------------------------------------------------------------------------------ |
| |
| // static |
| llvm::sys::Mutex HALDispatchABI::sMutex; |
| |
| // static |
| LLVM::LLVMStructType |
| HALDispatchABI::getProcessorType(MLIRContext *context, |
| const LLVMTypeConverter *typeConverter) { |
| llvm::sys::ScopedLock lock(sMutex); |
| auto structType = |
| LLVM::LLVMStructType::getIdentified(context, "iree_hal_processor_v0_t"); |
| if (structType.isInitialized()) |
| return structType; |
| |
| auto uint64Type = IntegerType::get(context, 64); |
| SmallVector<Type> fieldTypes; |
| |
| // uint64_t data[IREE_HAL_PROCESSOR_DATA_CAPACITY_V0]; |
| fieldTypes.push_back( |
| LLVM::LLVMArrayType::get(uint64Type, ProcessorDataCapacity)); |
| |
| LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false); |
| assert(succeeded(bodySet) && |
| "could not set the body of an identified struct"); |
| (void)bodySet; |
| |
| return structType; |
| } |
| |
| // static |
| LLVM::LLVMStructType |
| HALDispatchABI::getEnvironmentType(MLIRContext *context, |
| const LLVMTypeConverter *typeConverter, |
| LLVM::LLVMStructType processorType) { |
| llvm::sys::ScopedLock lock(sMutex); |
| auto structType = LLVM::LLVMStructType::getIdentified( |
| context, "iree_hal_executable_environment_v0_t"); |
| if (structType.isInitialized()) |
| return structType; |
| |
| auto opaquePtrType = LLVM::LLVMPointerType::get(context); |
| SmallVector<Type> fieldTypes; |
| |
| // const uint32_t* constants; |
| fieldTypes.push_back(opaquePtrType); |
| |
| // iree_hal_executable_import_thunk_v0_t import_thunk; |
| // const iree_hal_executable_import_v0_t* import_funcs; |
| // const void** import_contexts; |
| fieldTypes.push_back(LLVM::LLVMPointerType::get(context)); |
| fieldTypes.push_back(LLVM::LLVMPointerType::get(context)); |
| fieldTypes.push_back(LLVM::LLVMPointerType::get(context)); |
| |
| // iree_hal_processor_v0_t processor; |
| fieldTypes.push_back(processorType); |
| |
| LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false); |
| assert(succeeded(bodySet) && |
| "could not set the body of an identified struct"); |
| (void)bodySet; |
| |
| return structType; |
| } |
| |
| // static |
| LLVM::LLVMStructType |
| HALDispatchABI::getDispatchStateType(MLIRContext *context, |
| const LLVMTypeConverter *typeConverter) { |
| llvm::sys::ScopedLock lock(sMutex); |
| auto structType = LLVM::LLVMStructType::getIdentified( |
| context, "iree_hal_executable_dispatch_state_v0_t"); |
| if (structType.isInitialized()) |
| return structType; |
| |
| auto uint8Type = IntegerType::get(context, 8); |
| auto uint16Type = IntegerType::get(context, 16); |
| auto uint32Type = IntegerType::get(context, 32); |
| auto opaquePtrType = LLVM::LLVMPointerType::get(context); |
| SmallVector<Type> fieldTypes; |
| |
| // uint32_t workgroup_size_x; |
| // uint32_t workgroup_size_y; |
| // uint16_t workgroup_size_z; |
| fieldTypes.push_back(uint32Type); |
| fieldTypes.push_back(uint32Type); |
| fieldTypes.push_back(uint16Type); |
| |
| // uint16_t push_constant_count; |
| fieldTypes.push_back(uint16Type); |
| |
| // uint32_t workgroup_count_x; |
| // uint32_t workgroup_count_y; |
| // uint16_t workgroup_count_z; |
| fieldTypes.push_back(uint32Type); |
| fieldTypes.push_back(uint32Type); |
| fieldTypes.push_back(uint16Type); |
| |
| // uint8_t max_concurrency; |
| fieldTypes.push_back(uint8Type); |
| |
| // uint8_t binding_count; |
| fieldTypes.push_back(uint8Type); |
| |
| // const uint32_t * push_constants; |
| // void *const * binding_ptrs; |
| // const size_t * binding_lengths; |
| fieldTypes.push_back(opaquePtrType); |
| fieldTypes.push_back(opaquePtrType); |
| fieldTypes.push_back(opaquePtrType); |
| |
| LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false); |
| assert(succeeded(bodySet) && |
| "could not set the body of an identified struct"); |
| (void)bodySet; |
| |
| return structType; |
| } |
| |
| // static |
| LLVM::LLVMStructType |
| HALDispatchABI::getWorkgroupStateType(MLIRContext *context, |
| const LLVMTypeConverter *typeConverter) { |
| llvm::sys::ScopedLock lock(sMutex); |
| auto structType = LLVM::LLVMStructType::getIdentified( |
| context, "iree_hal_executable_workgroup_state_v0_t"); |
| if (structType.isInitialized()) |
| return structType; |
| |
| auto uint16Type = IntegerType::get(context, 16); |
| auto uint32Type = IntegerType::get(context, 32); |
| auto opaquePtrType = LLVM::LLVMPointerType::get(context); |
| SmallVector<Type> fieldTypes; |
| |
| // uint32_t workgroup_id_x; |
| // uint32_t workgroup_id_y; |
| // uint16_t workgroup_id_z; |
| fieldTypes.push_back(uint32Type); |
| fieldTypes.push_back(uint32Type); |
| fieldTypes.push_back(uint16Type); |
| |
| // uint16_t reserved; |
| fieldTypes.push_back(uint16Type); |
| |
| // uint32_t processor_id; |
| fieldTypes.push_back(uint32Type); |
| |
| // void* local_memory; |
| // uint32_t local_memory_size; |
| fieldTypes.push_back(opaquePtrType); |
| fieldTypes.push_back(uint32Type); |
| |
| LogicalResult bodySet = structType.setBody(fieldTypes, /*isPacked=*/false); |
| assert(succeeded(bodySet) && |
| "could not set the body of an identified struct"); |
| (void)bodySet; |
| |
| return structType; |
| } |
| |
| // static |
| SmallVector<Type, 5> |
| HALDispatchABI::getInputTypes(MLIRContext *context, |
| const LLVMTypeConverter *typeConverter) { |
| return SmallVector<Type, 5>{ |
| // const iree_hal_executable_environment_v0_t* IREE_RESTRICT |
| // environment |
| LLVM::LLVMPointerType::get(context), |
| // const iree_hal_executable_dispatch_state_v0_t* IREE_RESTRICT |
| // dispatch_state |
| LLVM::LLVMPointerType::get(context), |
| // const iree_hal_executable_workgroup_state_v0_t* IREE_RESTRICT |
| // workgroup_state |
| LLVM::LLVMPointerType::get(context), |
| }; |
| } |
| |
| // static |
| LLVM::DISubprogramAttr |
| HALDispatchABI::buildScopeAttr(mlir::ModuleOp moduleOp, |
| LLVM::LLVMFuncOp llvmFuncOp, |
| const LLVMTypeConverter *typeConverter) { |
| auto *context = &typeConverter->getContext(); |
| Builder builder(context); |
| |
| std::string inputFilePath("-"); |
| if (auto fileLoc = llvm::dyn_cast<mlir::FileLineColLoc>(moduleOp.getLoc())) { |
| inputFilePath = fileLoc.getFilename().getValue(); |
| } |
| |
| auto fileAttr = |
| LLVM::DIFileAttr::get(context, llvm::sys::path::filename(inputFilePath), |
| llvm::sys::path::parent_path(inputFilePath)); |
| auto compileUnitAttr = LLVM::DICompileUnitAttr::get( |
| DistinctAttr::create(UnitAttr::get(context)), llvm::dwarf::DW_LANG_C17, |
| fileAttr, builder.getStringAttr("IREE"), |
| /*isOptimized=*/true, LLVM::DIEmissionKind::Full); |
| |
| auto int32TypeAttr = |
| LLVM::DIBasicTypeAttr::get(context, llvm::dwarf::DW_TAG_base_type, "int", |
| /*sizeInBits=*/32, llvm::dwarf::DW_ATE_signed); |
| ExecutableLibraryDI di(typeConverter); |
| auto subroutineTypeAttr = LLVM::DISubroutineTypeAttr::get( |
| context, llvm::dwarf::DW_CC_normal, |
| { |
| int32TypeAttr, |
| di.getPtrOf(di.getConstOf(di.getEnvironmentV0T())), |
| di.getPtrOf(di.getConstOf(di.getDispatchStateV0T())), |
| di.getPtrOf(di.getConstOf(di.getWorkgroupStateV0T())), |
| }); |
| |
| auto funcNameAttr = builder.getStringAttr(llvmFuncOp.getName()); |
| DistinctAttr id; |
| if (!llvmFuncOp.isExternal()) { |
| id = DistinctAttr::create(UnitAttr::get(context)); |
| } |
| return LLVM::DISubprogramAttr::get(context, id, compileUnitAttr, fileAttr, |
| funcNameAttr, funcNameAttr, fileAttr, |
| /*line=*/1, |
| /*scopeline=*/1, |
| LLVM::DISubprogramFlags::Definition | |
| LLVM::DISubprogramFlags::Optimized, |
| subroutineTypeAttr); |
| } |
| |
| // Returns the most local DISubprogramAttr starting from |forOp|. |
| static LLVM::DISubprogramAttr getLocalScopeAttr(Operation *forOp) { |
| auto funcOp = forOp->getParentOfType<LLVM::LLVMFuncOp>(); |
| assert(funcOp && "usage requires an enclosing LLVMFuncOp"); |
| auto scopeLocAttr = |
| funcOp.getLoc() |
| ->findInstanceOf<mlir::FusedLocWith<LLVM::DISubprogramAttr>>(); |
| assert(scopeLocAttr && |
| "must have attached a DISubprogramAttr to the parent function"); |
| return scopeLocAttr.getMetadata(); |
| } |
| |
| // Returns the argument at |argIndex| in the parent function of |forOp|. |
| static Value getLocalArgument(Operation *forOp, unsigned argIndex) { |
| auto funcOp = forOp->getParentOfType<LLVM::LLVMFuncOp>(); |
| assert(funcOp && "usage requires an enclosing LLVMFuncOp"); |
| return funcOp.getArgument(argIndex); |
| } |
| |
| // Returns "x" "y" or "z" based on |dim|. |
| static StringRef getDimName(int32_t dim) { |
| assert(dim >= 0 && dim <= 2 && "must be x, y, z"); |
| static const char *dims[3] = {"x", "y", "z"}; |
| return StringRef(dims[dim]); |
| } |
| |
| // Debug intrinsics require valid location information to pass LLVM's verifier. |
| // Since nothing checks these cases in MLIR before converting we avoid creating |
| // the ops if MLIR or LLVM is likely to reject them. |
| static bool isLocationValidForDI(Location loc) { |
| // Unknown locations are passed as null and DI doesn't like that. |
| if (llvm::isa<UnknownLoc>(loc)) |
| return false; |
| // MLIR currently can't handle name-only locations. We do this check to ensure |
| // there's at least one real location MLIR can pass along. |
| if (auto callLoc = llvm::dyn_cast<CallSiteLoc>(loc)) { |
| return isLocationValidForDI(callLoc.getCaller()) && |
| isLocationValidForDI(callLoc.getCallee()); |
| } else if (auto fileLoc = llvm::dyn_cast<FileLineColLoc>(loc)) { |
| return true; |
| } else if (auto fusedLoc = llvm::dyn_cast<FusedLoc>(loc)) { |
| return llvm::all_of(fusedLoc.getLocations(), isLocationValidForDI); |
| } else if (auto namedLoc = llvm::dyn_cast<NameLoc>(loc)) { |
| return isLocationValidForDI(namedLoc.getChildLoc()); |
| } else if (auto opaqueLoc = llvm::dyn_cast<OpaqueLoc>(loc)) { |
| return isLocationValidForDI(opaqueLoc.getFallbackLocation()); |
| } |
| return false; |
| } |
| |
| static Value buildArgDI(Operation *forOp, int argNum, Value value, Twine name, |
| LLVM::DITypeAttr type, OpBuilder &builder) { |
| if (!clVerboseDebugInfo) |
| return value; |
| auto loc = forOp->getLoc(); |
| if (!isLocationValidForDI(loc)) |
| return value; |
| auto scopeAttr = getLocalScopeAttr(forOp); |
| builder.create<LLVM::DbgValueOp>( |
| loc, value, |
| LLVM::DILocalVariableAttr::get(scopeAttr, builder.getStringAttr(name), |
| scopeAttr.getFile(), |
| /*line=*/1, /*arg=*/argNum + 1, |
| /*alignInBits=*/0, type)); |
| return value; |
| } |
| |
| static Value buildValueDI(Operation *forOp, Value value, Twine name, |
| LLVM::DITypeAttr type, OpBuilder &builder) { |
| if (!clVerboseDebugInfo) |
| return value; |
| auto loc = forOp->getLoc(); |
| if (!isLocationValidForDI(loc)) |
| return value; |
| auto scopeAttr = getLocalScopeAttr(forOp); |
| builder.create<LLVM::DbgValueOp>( |
| loc, value, |
| LLVM::DILocalVariableAttr::get(scopeAttr, builder.getStringAttr(name), |
| scopeAttr.getFile(), |
| /*line=*/1, /*arg=*/0, |
| /*alignInBits=*/0, type)); |
| return value; |
| } |
| |
| Value HALDispatchABI::loadWorkgroupID(Operation *forOp, int32_t dim, |
| Type resultType, OpBuilder &builder) { |
| auto dimValue = |
| loadFieldValue(forOp, WorkgroupStateField::workgroup_id_x + dim, builder); |
| auto resultValue = |
| castValueToType(forOp->getLoc(), dimValue, resultType, builder); |
| return buildValueDI(forOp, resultValue, |
| StringRef("workgroup_id_") + getDimName(dim), |
| di.getBasicType(resultType), builder); |
| } |
| |
| Value HALDispatchABI::loadWorkgroupCount(Operation *forOp, int32_t dim, |
| Type resultType, OpBuilder &builder) { |
| auto dimValue = loadFieldValue( |
| forOp, DispatchStateField::workgroup_count_x + dim, builder); |
| auto resultValue = |
| castValueToType(forOp->getLoc(), dimValue, resultType, builder); |
| return buildValueDI(forOp, resultValue, |
| StringRef("workgroup_count_") + getDimName(dim), |
| di.getBasicType(resultType), builder); |
| } |
| |
| Value HALDispatchABI::loadWorkgroupSize(Operation *forOp, int32_t dim, |
| Type resultType, OpBuilder &builder) { |
| auto dimValue = loadFieldValue( |
| forOp, DispatchStateField::workgroup_size_x + dim, builder); |
| auto resultValue = |
| castValueToType(forOp->getLoc(), dimValue, resultType, builder); |
| return buildValueDI(forOp, resultValue, |
| StringRef("workgroup_size_") + getDimName(dim), |
| di.getBasicType(resultType), builder); |
| } |
| |
| Value HALDispatchABI::loadMaxConcurrency(Operation *forOp, OpBuilder &builder) { |
| auto maxValue = |
| loadFieldValue(forOp, DispatchStateField::max_concurrency, builder); |
| auto resultValue = castValueToType( |
| forOp->getLoc(), maxValue, |
| typeConverter->convertType(builder.getIndexType()), builder); |
| return buildValueDI(forOp, resultValue, "max_concurrency", di.getIntptrT(), |
| builder); |
| } |
| |
| Value HALDispatchABI::loadWorkgroupLocalMemorySize(Operation *forOp, |
| OpBuilder &builder) { |
| auto sizeValue = |
| loadFieldValue(forOp, WorkgroupStateField::local_memory_size, builder); |
| auto resultValue = castValueToType( |
| forOp->getLoc(), sizeValue, |
| typeConverter->convertType(builder.getIndexType()), builder); |
| return buildValueDI(forOp, resultValue, "local_memory_size", di.getSizeT(), |
| builder); |
| } |
| |
| Value HALDispatchABI::loadWorkgroupLocalMemoryPtr(Operation *forOp, |
| OpBuilder &builder) { |
| auto resultValue = |
| loadFieldValue(forOp, WorkgroupStateField::local_memory, builder); |
| return buildValueDI(forOp, resultValue, "local_memory", di.getVoidPtr(), |
| builder); |
| } |
| |
| Value HALDispatchABI::loadPushConstantCount(Operation *forOp, |
| OpBuilder &builder) { |
| auto countValue = |
| loadFieldValue(forOp, DispatchStateField::push_constant_count, builder); |
| auto resultValue = castValueToType( |
| forOp->getLoc(), countValue, |
| typeConverter->convertType(builder.getIndexType()), builder); |
| return buildValueDI(forOp, resultValue, "push_constant_count", di.getSizeT(), |
| builder); |
| } |
| |
| Value HALDispatchABI::loadPushConstant(Operation *forOp, int64_t offset, |
| Type resultType, OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto constantsPtrValue = |
| loadFieldValue(forOp, DispatchStateField::push_constants, builder); |
| auto pushConstantType = IntegerType::get(context, 32); |
| Value constantPtrValue = builder.create<LLVM::GEPOp>( |
| loc, constantsPtrValue.getType(), pushConstantType, constantsPtrValue, |
| LLVM::GEPArg(int32_t(offset))); |
| Value constantValue = |
| builder.create<LLVM::LoadOp>(loc, pushConstantType, constantPtrValue); |
| auto resultValue = castValueToType(loc, constantValue, resultType, builder); |
| return buildValueDI(forOp, resultValue, |
| StringRef("push_constant[") + std::to_string(offset) + |
| "]", |
| di.getBasicType(resultType), builder); |
| } |
| |
| Value HALDispatchABI::loadBindingCount(Operation *forOp, OpBuilder &builder) { |
| auto countValue = |
| loadFieldValue(forOp, DispatchStateField::binding_count, builder); |
| auto resultValue = castValueToType( |
| forOp->getLoc(), countValue, |
| typeConverter->convertType(builder.getIndexType()), builder); |
| return buildValueDI(forOp, resultValue, "binding_count", di.getSizeT(), |
| builder); |
| } |
| |
| Value HALDispatchABI::loadBindingPtr(Operation *forOp, int64_t ordinal, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto ptrsPtrValue = |
| loadFieldValue(forOp, DispatchStateField::binding_ptrs, builder); |
| auto elementPtrValue = builder.create<LLVM::GEPOp>( |
| loc, ptrsPtrValue.getType(), |
| mlir::LLVM::LLVMPointerType::get(builder.getContext()), ptrsPtrValue, |
| LLVM::GEPArg(int32_t(ordinal))); |
| auto elementValue = builder.create<LLVM::LoadOp>( |
| loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), |
| elementPtrValue); |
| return buildValueDI(forOp, elementValue, |
| StringRef("binding_ptrs[") + std::to_string(ordinal) + |
| "]", |
| di.getPtrOf(di.getUint8T()), builder); |
| } |
| |
| Value HALDispatchABI::loadBindingLength(Operation *forOp, int64_t ordinal, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto lengthsPtrValue = |
| loadFieldValue(forOp, DispatchStateField::binding_lengths, builder); |
| auto indexType = typeConverter->convertType(IndexType::get(context)); |
| auto elementPtrValue = builder.create<LLVM::GEPOp>( |
| loc, lengthsPtrValue.getType(), indexType, lengthsPtrValue, |
| LLVM::GEPArg(int32_t(ordinal))); |
| auto elementValue = |
| builder.create<LLVM::LoadOp>(loc, indexType, elementPtrValue); |
| return buildValueDI(forOp, elementValue, |
| StringRef("binding_lengths[") + std::to_string(ordinal) + |
| "]", |
| di.getSizeT(), builder); |
| } |
| |
| MemRefDescriptor HALDispatchABI::loadBinding(Operation *forOp, int64_t ordinal, |
| Value baseOffsetValue, |
| MemRefType memRefType, |
| ValueRange dynamicDims, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| |
| // Load the base buffer pointer in the appropriate type (f32*, etc). |
| Value basePtrValue = loadBindingPtr(forOp, ordinal, builder); |
| |
| // NOTE: if we wanted to check the range was in bounds here would be the |
| // place to do it. |
| |
| // Construct the MemRefDescriptor type based on the information we have. |
| // NOTE: we could use the binding length to clamp this/check that the |
| // requested range is valid. |
| auto [strides, offset] = getStridesAndOffset(memRefType); |
| if (memRefType.hasStaticShape() && |
| !llvm::any_of(strides, ShapedType::isDynamic) && |
| !ShapedType::isDynamic(offset)) { |
| return MemRefDescriptor::fromStaticShape(builder, loc, *typeConverter, |
| memRefType, basePtrValue); |
| } else { |
| assert(memRefType.getNumDynamicDims() == dynamicDims.size()); |
| int64_t rank = memRefType.getRank(); |
| |
| // Build MemRef descriptor for this interface binding. |
| auto desc = MemRefDescriptor::undef(builder, loc, |
| typeConverter->convertType(memRefType)); |
| desc.setAllocatedPtr(builder, loc, basePtrValue); |
| desc.setAlignedPtr(builder, loc, basePtrValue); |
| auto llvmIndexType = typeConverter->convertType(builder.getIndexType()); |
| if (ShapedType::isDynamic(offset)) { |
| // The offset in the subspan is byteoffset. It is converted to element |
| // offset here. It is assumed that the byte offset is a multiple of |
| // the element type byte width. |
| int32_t elementBitWidth = |
| IREE::Util::getTypeBitWidth(memRefType.getElementType()); |
| Value elementWidthVal = |
| builder.create<LLVM::ConstantOp>(loc, llvmIndexType, elementBitWidth); |
| Value eight = builder.create<LLVM::ConstantOp>(loc, llvmIndexType, 8); |
| Value bitOffset = |
| builder.create<LLVM::MulOp>(loc, baseOffsetValue, eight); |
| Value elementOffsetVal = |
| builder.create<LLVM::UDivOp>(loc, bitOffset, elementWidthVal); |
| desc.setOffset(builder, loc, elementOffsetVal); |
| } else { |
| desc.setConstantOffset(builder, loc, offset); |
| } |
| |
| // Update memref descriptor shape. Dynamic dimensions can be mixed with |
| // static dimensions, like [128, ?, 128]. |
| int dynamicDimIndex = 0; |
| for (int i = 0; i < rank; ++i) { |
| if (memRefType.isDynamicDim(i)) { |
| desc.setSize(builder, loc, i, dynamicDims[dynamicDimIndex++]); |
| } else { |
| desc.setConstantSize(builder, loc, i, memRefType.getDimSize(i)); |
| } |
| } |
| |
| // Compute and update strides. Assume that MemRefs are row-major, that is, |
| // following index linearization: |
| // x[i, j, k] = i * x.dim[1] * x.dim[2] + j * x.dim[2] + k |
| if (!strides.empty()) { |
| assert(strides.back() == 1 && |
| "unexpected non-unit stride for innermost dimension"); |
| desc.setConstantStride(builder, loc, rank - 1, 1); |
| OpFoldResult currentStride = builder.getIndexAttr(1); |
| for (int i = rank - 1; i > 0; --i) { |
| if (ShapedType::isDynamic(strides[i - 1])) { |
| auto dim = desc.size(builder, loc, i); |
| Value currentStrideVal; |
| if (std::optional<int64_t> currentStrideInt = |
| getConstantIntValue(currentStride)) { |
| currentStrideVal = builder.create<LLVM::ConstantOp>( |
| loc, llvmIndexType, currentStrideInt.value()); |
| } else { |
| currentStrideVal = currentStride.get<Value>(); |
| } |
| currentStride = |
| builder.create<LLVM::MulOp>(loc, currentStrideVal, dim) |
| .getResult(); |
| desc.setStride(builder, loc, i - 1, currentStride.get<Value>()); |
| } else { |
| currentStride = builder.getIndexAttr(strides[i - 1]); |
| desc.setConstantStride(builder, loc, i - 1, strides[i - 1]); |
| } |
| } |
| } |
| |
| return desc; |
| } |
| } |
| |
| Type HALDispatchABI::getProcessorIDType() { |
| return getFieldType(WorkgroupStateField::processor_id); |
| } |
| |
| Value HALDispatchABI::loadProcessorID(Operation *forOp, OpBuilder &builder) { |
| auto resultValue = |
| loadFieldValue(forOp, WorkgroupStateField::processor_id, builder); |
| return buildValueDI(forOp, resultValue, "processor_id", |
| di.getBasicType(resultValue.getType()), builder); |
| } |
| |
| Value HALDispatchABI::updateProcessorDataFromTargetAttr( |
| Operation *forOp, Value processorDataPtrValue, OpBuilder &builder) { |
| // Get the target attr. |
| IREE::HAL::ExecutableTargetAttr targetAttr = |
| IREE::HAL::ExecutableTargetAttr::lookup(forOp); |
| if (!targetAttr) { |
| return processorDataPtrValue; |
| } |
| |
| // Lookup CPU features. |
| std::optional<NamedAttribute> cpuFeatures = |
| targetAttr.getConfiguration().getNamed("cpu_features"); |
| if (!cpuFeatures) { |
| return processorDataPtrValue; |
| } |
| |
| // Currently requiring all CPU feature bits to be in field 0. Generalize as |
| // needed when other CPU feature fields start to be used. |
| uint64_t specifiedCpuDataField0 = 0; |
| { |
| // Map llvm feature-name to bit used to represent it in IREE_CPUDATA_FIELD0. |
| // |
| // TODO(ravishankarm): This link to the runtime schemas needs to be broken. |
| // Instead we should use a reflection callback to resolve arch guarded |
| // features directly in the compiler. |
| llvm::StringMap<uint64_t> featureToBitPattern; |
| auto targetTriple = getTargetTriple(targetAttr); |
| if (!targetTriple) { |
| return processorDataPtrValue; |
| } |
| std::string targetArchUppercase = |
| StringRef(getIreeArchNameForTargetTriple(targetTriple.value())).upper(); |
| #define IREE_CPU_FEATURE_BIT(arch, field_index, bit_pos, bit_name, llvm_name) \ |
| if (targetArchUppercase == #arch) { \ |
| assert(field_index == 0); \ |
| featureToBitPattern[llvm_name] = 1ull << bit_pos; \ |
| } |
| #include "iree/schemas/cpu_feature_bits.inl" |
| #undef IREE_CPU_FEATURE_BIT |
| |
| // Find CPU features in featureToBitPattern |
| SmallVector<StringRef> cpuFeatureStrings; |
| llvm::cast<StringAttr>(cpuFeatures->getValue()) |
| .getValue() |
| .split(cpuFeatureStrings, ',', /*MakeSplit=*/-1, /*KeepEmpty=*/false); |
| for (auto featureString : cpuFeatureStrings) { |
| // CPU features are typically prefixed with a +, e.g. +avx,+avx2,+fma. |
| featureString.consume_front("+"); |
| // Silently skip unknown CPU features, more flexible for now. Note that |
| // some featurs occurring here are not standard CPU features but internal |
| // things such as the "+reserve-x18" that we add on arm64. |
| if (featureToBitPattern.count(featureString)) { |
| specifiedCpuDataField0 |= featureToBitPattern.lookup(featureString); |
| } |
| } |
| } |
| if (specifiedCpuDataField0 == 0) { |
| return processorDataPtrValue; |
| } |
| |
| // Create a new stack allocation for the bit pattern. |
| Location loc = forOp->getLoc(); |
| MLIRContext *context = forOp->getContext(); |
| auto ptrType = LLVM::LLVMPointerType::get(context); |
| auto i64Ty = builder.getI64Type(); |
| Value arraySize = builder.create<LLVM::ConstantOp>( |
| loc, i64Ty, builder.getI64IntegerAttr(ProcessorDataCapacity)); |
| Value alloca = builder.create<LLVM::AllocaOp>(loc, ptrType, i64Ty, arraySize, |
| /*alignment=*/sizeof(uint64_t)); |
| // Load the 0-th value. |
| Value srcData0 = |
| builder.create<LLVM::LoadOp>(loc, i64Ty, processorDataPtrValue); |
| // Set the specified CPU arch data. |
| Value bitPatternVal = builder.create<LLVM::ConstantOp>( |
| loc, i64Ty, builder.getI64IntegerAttr(specifiedCpuDataField0)); |
| srcData0 = builder.create<LLVM::OrOp>(loc, srcData0, bitPatternVal); |
| builder.create<LLVM::StoreOp>(loc, srcData0, alloca); |
| // Copy over the rest. |
| for (int64_t i = 1, e = ProcessorDataCapacity; i < e; ++i) { |
| Value loadPtr = builder.create<LLVM::GEPOp>( |
| loc, processorDataPtrValue.getType(), i64Ty, processorDataPtrValue, |
| LLVM::GEPArg(int32_t(i)), /*inbounds =*/true); |
| Value loadVal = builder.create<LLVM::LoadOp>(loc, i64Ty, loadPtr); |
| Value storePtr = builder.create<LLVM::GEPOp>( |
| loc, alloca.getType(), i64Ty, alloca, LLVM::GEPArg(int32_t(i)), |
| /*inbounds =*/true); |
| builder.create<LLVM::StoreOp>(loc, loadVal, storePtr); |
| } |
| return alloca; |
| } |
| |
| Type HALDispatchABI::getProcessorDataType() { |
| return LLVM::LLVMPointerType::get(processorType.getContext()); |
| } |
| |
| Value HALDispatchABI::loadProcessorData(Operation *forOp, OpBuilder &builder) { |
| // To get a pointer to the processor data we need to track pointers all the |
| // way from the environment argument. This is redundant with loadFieldValue |
| // but that returns values instead. |
| auto loc = forOp->getLoc(); |
| auto environmentPtrValue = |
| buildArgDI(forOp, /*argNum=*/0, getLocalArgument(forOp, 0), "environment", |
| di.getPtrOf(di.getConstOf(di.getEnvironmentV0T())), builder); |
| Value processorPtrValue = builder.create<LLVM::GEPOp>( |
| loc, LLVM::LLVMPointerType::get(context), |
| LLVM::LLVMPointerType::get(context), environmentPtrValue, |
| LLVM::GEPArg(int32_t(EnvironmentField::processor)), |
| /*inbounds=*/true); |
| Value processorDataPtrValue = builder.create<LLVM::GEPOp>( |
| loc, LLVM::LLVMPointerType::get(context), |
| LLVM::LLVMPointerType::get(context), processorPtrValue, |
| LLVM::GEPArg(int32_t(ProcessorField::data)), |
| /*inbounds=*/true); |
| Value updatedProcessorData = |
| updateProcessorDataFromTargetAttr(forOp, processorDataPtrValue, builder); |
| return buildValueDI(forOp, updatedProcessorData, "processor_data", |
| di.getPtrOf(di.getConstOf(di.getArrayOf( |
| di.getUint64T(), ProcessorDataCapacity))), |
| builder); |
| } |
| |
| Value HALDispatchABI::loadProcessorData(Operation *forOp, int64_t index, |
| OpBuilder &builder) { |
| // Load the value; it should always be in bounds. |
| Value dataArrayValue = loadFieldValue(forOp, ProcessorField::data, builder); |
| SmallVector<int64_t, 1> position = {index}; |
| Value dataValue = builder.create<LLVM::ExtractValueOp>( |
| forOp->getLoc(), dataArrayValue, position); |
| return buildValueDI(forOp, dataValue, |
| StringRef("processor_data[") + std::to_string(index) + |
| "]", |
| di.getBasicType(dataValue.getType()), builder); |
| } |
| |
| Value HALDispatchABI::loadExecutableConstant(Operation *forOp, StringRef key, |
| Type resultType, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| |
| // Create top-level global placeholder. |
| // The magic attribute is used by future assignment passes. |
| std::string globalName = ("__constant_ordinal_" + key).str(); |
| auto moduleOp = |
| builder.getInsertionPoint()->getParentOfType<mlir::ModuleOp>(); |
| LLVM::GlobalOp globalOp; |
| if (!(globalOp = moduleOp.lookupSymbol<LLVM::GlobalOp>(globalName))) { |
| auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); |
| globalOp = moduleBuilder.create<LLVM::GlobalOp>( |
| loc, builder.getI32Type(), |
| /*isConstant=*/false, LLVM::Linkage::Internal, globalName, Attribute{}); |
| globalOp->setAttr(IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName(), |
| builder.getStringAttr(key)); |
| } |
| |
| // Load the placeholder global ordinal. |
| Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, globalOp); |
| Value ordinalValue = |
| builder.create<LLVM::LoadOp>(loc, globalOp.getType(), globalPtr); |
| |
| // Load constant from the executable constants struct. |
| auto constantsPtrValue = |
| loadFieldValue(forOp, EnvironmentField::constants, builder); |
| Value constantPtrValue = |
| builder.create<LLVM::GEPOp>(loc, constantsPtrValue.getType(), resultType, |
| constantsPtrValue, ordinalValue); |
| Value constantValue = |
| builder.create<LLVM::LoadOp>(loc, resultType, constantPtrValue); |
| auto resultValue = castValueToType(loc, constantValue, resultType, builder); |
| return buildValueDI(forOp, resultValue, |
| StringRef("executable_constant['") + key + "']", |
| di.getBasicType(resultValue.getType()), builder); |
| } |
| |
| Value HALDispatchABI::loadImportOrdinal(Operation *forOp, StringRef importName, |
| bool weak, OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| |
| // Create top-level global placeholder. |
| // The magic attribute is used by future assignment passes. |
| std::string globalName = ("__import_ordinal_" + importName).str(); |
| auto moduleOp = |
| builder.getInsertionPoint()->getParentOfType<mlir::ModuleOp>(); |
| LLVM::GlobalOp globalOp; |
| if (!(globalOp = moduleOp.lookupSymbol<LLVM::GlobalOp>(globalName))) { |
| auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); |
| globalOp = moduleBuilder.create<LLVM::GlobalOp>( |
| loc, builder.getI32Type(), |
| /*isConstant=*/false, LLVM::Linkage::Internal, globalName, Attribute{}); |
| globalOp->setAttr("hal.executable.import.key", |
| builder.getStringAttr(importName)); |
| if (weak) { |
| globalOp->setAttr("hal.executable.import.weak", builder.getUnitAttr()); |
| } |
| } |
| |
| // Load the placeholder global ordinal. |
| Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, globalOp); |
| return builder.create<LLVM::LoadOp>(loc, globalOp.getType(), globalPtr); |
| } |
| |
| std::pair<Value, Value> HALDispatchABI::loadImportFunc(Operation *forOp, |
| Value importOrdinal, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto funcPtrsValue = |
| loadFieldValue(forOp, EnvironmentField::import_funcs, builder); |
| auto opaquePtrType = LLVM::LLVMPointerType::get(builder.getContext()); |
| auto funcPtrValue = |
| builder.create<LLVM::GEPOp>(loc, funcPtrsValue.getType(), opaquePtrType, |
| funcPtrsValue, importOrdinal); |
| auto contextPtrsValue = |
| loadFieldValue(forOp, EnvironmentField::import_contexts, builder); |
| auto contextPtrValue = builder.create<LLVM::GEPOp>( |
| loc, contextPtrsValue.getType(), opaquePtrType, contextPtrsValue, |
| importOrdinal); |
| return std::make_pair( |
| builder.create<LLVM::LoadOp>(loc, opaquePtrType, funcPtrValue), |
| builder.create<LLVM::LoadOp>(loc, opaquePtrType, contextPtrValue)); |
| } |
| |
| Value HALDispatchABI::isImportFuncAvailable(Operation *forOp, |
| StringRef importName, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto importOrdinal = |
| loadImportOrdinal(forOp, importName, /*weak=*/true, builder); |
| auto importFunc = loadImportFunc(forOp, importOrdinal, builder); |
| Value nullPtrValue = |
| builder.create<LLVM::ZeroOp>(loc, importFunc.first.getType()); |
| return builder.create<LLVM::ICmpOp>(loc, builder.getI1Type(), |
| LLVM::ICmpPredicate::ne, importFunc.first, |
| nullPtrValue); |
| } |
| |
| Value HALDispatchABI::callImport(Operation *forOp, StringRef importName, |
| bool weak, Value params, OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto importOrdinal = loadImportOrdinal(forOp, importName, weak, builder); |
| auto thunkPtrValue = |
| loadFieldValue(forOp, EnvironmentField::import_thunk, builder); |
| auto importFunc = loadImportFunc(forOp, importOrdinal, builder); |
| |
| // TODO(benvanik): if weak is set then we should bail if the import is not |
| // found. Since we've loaded the import func here we can just compare for |
| // null as in isImportFuncAvailable but we'll need to make the control flow. |
| assert(!weak && "calls to weak imports not yet implemented"); |
| |
| Value nullPtrValue = builder.create<LLVM::ZeroOp>( |
| loc, LLVM::LLVMPointerType::get(builder.getContext())); |
| auto callOp = |
| builder.create<LLVM::CallOp>(loc, TypeRange{builder.getI32Type()}, |
| ValueRange{ |
| /*thunk_func_ptr=*/thunkPtrValue, |
| /*import_func_ptr=*/importFunc.first, |
| /*params=*/params, |
| /*context=*/importFunc.second, |
| /*reserved=*/nullPtrValue, |
| }); |
| return callOp.getResult(); |
| } |
| |
| // static |
| std::optional<Type> |
| HALDispatchABI::getParameterStructType(TypeRange resultTypes, ValueRange args, |
| TypeRange extraFieldsTypes) { |
| // Struct types are ordered [results..., args...]. |
| SmallVector<Type> types(resultTypes); |
| types.reserve(resultTypes.size() + args.size()); |
| for (Value arg : args) { |
| types.push_back(typeConverter->convertType(arg.getType())); |
| } |
| types.append(extraFieldsTypes.begin(), extraFieldsTypes.end()); |
| |
| if (types.empty()) { |
| return std::nullopt; |
| } |
| return LLVM::LLVMStructType::getLiteral(context, types); |
| } |
| |
| // static |
| std::tuple<Type, Value> |
| HALDispatchABI::packIntoParameterStruct(Operation *forOp, TypeRange resultTypes, |
| ValueRange args, ValueRange extraFields, |
| OpBuilder &builder) { |
| Location loc = forOp->getLoc(); |
| MLIRContext *context = builder.getContext(); |
| |
| // Query any extra fields that were requested and append them to the struct. |
| auto extraFieldsTypes = |
| llvm::map_to_vector(extraFields, [](Value v) { return v.getType(); }); |
| |
| std::optional<Type> structType = |
| getParameterStructType(resultTypes, args, extraFieldsTypes); |
| |
| if (!structType) { |
| Type voidPtrType = LLVM::LLVMPointerType::get(context); |
| return {voidPtrType, |
| builder.create<LLVM::UndefOp>(loc, voidPtrType).getResult()}; |
| } |
| |
| auto ptrStructType = LLVM::LLVMPointerType::get(context); |
| Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), |
| builder.getIndexAttr(1)); |
| Value paramsPtr = |
| builder.create<LLVM::AllocaOp>(loc, ptrStructType, *structType, one, |
| /*alignment=*/0); |
| Value structVal = builder.create<LLVM::UndefOp>(loc, *structType); |
| for (int64_t i = 0, e = args.size(); i < e; ++i) { |
| structVal = builder.create<LLVM::InsertValueOp>(loc, structVal, args[i], |
| i + resultTypes.size()); |
| } |
| for (int64_t i = 0, e = extraFields.size(); i < e; ++i) { |
| structVal = builder.create<LLVM::InsertValueOp>( |
| loc, structVal, extraFields[i], i + resultTypes.size() + args.size()); |
| } |
| // Store into the alloca'ed descriptor. |
| builder.create<LLVM::StoreOp>(loc, structVal, paramsPtr); |
| return {*structType, paramsPtr}; |
| } |
| |
| // static |
| FailureOr<LLVM::LLVMFunctionType> HALDispatchABI::getABIFunctionType( |
| Operation *forOp, IREE::HAL::CallingConvention cConv, TypeRange resultTypes, |
| TypeRange argTypes, ArrayRef<StringRef> extraFields) { |
| MLIRContext *context = forOp->getContext(); |
| SmallVector<Type> extraFieldsTypes = llvm::map_to_vector( |
| extraFields, [&](StringRef name) { return getExtraFieldType(name); }); |
| |
| // Check for extra fields already added. |
| if (argTypes.size() >= extraFieldsTypes.size()) { |
| if (llvm::all_of(llvm::zip(argTypes.take_back(extraFieldsTypes.size()), |
| extraFieldsTypes), |
| [](auto it) { |
| auto lhsType = std::get<0>(it); |
| auto rhsType = std::get<1>(it); |
| return (llvm::isa<LLVM::LLVMPointerType>(lhsType) && |
| llvm::isa<LLVM::LLVMPointerType>(rhsType)) || |
| std::get<0>(it) == std::get<1>(it); |
| })) { |
| // Extra fields already added. Drop them. |
| extraFieldsTypes.clear(); |
| } |
| } |
| |
| switch (cConv) { |
| case IREE::HAL::CallingConvention::Default: { |
| if (resultTypes.size() > 1) { |
| return forOp->emitOpError( |
| "Cannot have multiple return values for function"); |
| } |
| Type resultType = resultTypes.size() == 1 |
| ? resultTypes[0] |
| : LLVM::LLVMVoidType::get(context); |
| SmallVector<Type> allArgTypes = argTypes; |
| allArgTypes.append(extraFieldsTypes.begin(), extraFieldsTypes.end()); |
| return LLVM::LLVMFunctionType::get(resultType, allArgTypes); |
| } |
| default: |
| llvm_unreachable("unhandled calling convention"); |
| return failure(); |
| } |
| } |
| |
| // static |
| bool HALDispatchABI::hasCompatibleFunctionSignature( |
| MLIRContext *context, LLVM::LLVMFunctionType funcType, |
| TypeRange resultTypes, TypeRange paramTypes) { |
| TypeRange funcParamTypes = funcType.getParams(); |
| if (funcParamTypes.size() != paramTypes.size()) { |
| return false; |
| } |
| if (!llvm::all_of(llvm::zip(funcParamTypes, paramTypes), [](auto it) { |
| auto lhsType = std::get<0>(it); |
| auto rhsType = std::get<1>(it); |
| return (llvm::isa<LLVM::LLVMPointerType>(lhsType) && |
| llvm::isa<LLVM::LLVMPointerType>(rhsType)) || |
| std::get<0>(it) == std::get<1>(it); |
| })) { |
| return false; |
| } |
| if (resultTypes.size() > 1) { |
| return false; |
| } |
| Type funcResultType = funcType.getReturnType(); |
| if (resultTypes.empty() && |
| funcResultType != LLVM::LLVMVoidType::get(context)) { |
| return false; |
| } |
| if (resultTypes.size() == 1 && resultTypes[0] != funcResultType) { |
| return false; |
| } |
| return true; |
| } |
| |
| FailureOr<SmallVector<Value>> HALDispatchABI::materializeABI( |
| Operation *forOp, StringRef symbolName, IREE::HAL::CallingConvention cConv, |
| TypeRange resultTypes, ValueRange args, ArrayRef<StringRef> extraFields, |
| RewriterBase &rewriter) { |
| auto argTypes = |
| llvm::map_to_vector(args, [](Value v) { return v.getType(); }); |
| FailureOr<LLVM::LLVMFunctionType> abiFunctionType = |
| getABIFunctionType(forOp, cConv, resultTypes, argTypes, extraFields); |
| if (failed(abiFunctionType)) { |
| return forOp->emitOpError( |
| "failed to get function type for calling convention"); |
| } |
| if (hasCompatibleFunctionSignature(rewriter.getContext(), |
| abiFunctionType.value(), resultTypes, |
| argTypes)) { |
| return rewriter.notifyMatchFailure( |
| forOp, "no change in function signature. skipping"); |
| } |
| |
| // Combined args list. |
| SmallVector<Value> allArgsList = llvm::to_vector(args); |
| SmallVector<Value> extraFieldVals = |
| llvm::map_to_vector(extraFields, [&](StringRef fieldName) { |
| return getExtraField(forOp, fieldName, rewriter); |
| }); |
| allArgsList.append(extraFieldVals); |
| |
| Location loc = forOp->getLoc(); |
| if (cConv == IREE::HAL::CallingConvention::Default) { |
| auto callOp = rewriter.create<LLVM::CallOp>( |
| loc, abiFunctionType->getReturnTypes(), allArgsList, forOp->getAttrs()); |
| return llvm::map_to_vector(callOp.getResults(), |
| [](OpResult v) -> Value { return v; }); |
| } |
| |
| return forOp->emitOpError("unhandled calling convention"); |
| } |
| |
| SmallVector<Value> HALDispatchABI::wrapAndCallImport( |
| Operation *forOp, StringRef importName, bool weak, TypeRange resultTypes, |
| ValueRange args, ArrayRef<StringRef> extraFields, OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| |
| SmallVector<Value> extraFieldVals = |
| llvm::map_to_vector(extraFields, [&](StringRef fieldName) { |
| return getExtraField(forOp, fieldName, builder); |
| }); |
| |
| auto [structType, paramsPtr] = packIntoParameterStruct( |
| forOp, resultTypes, args, extraFieldVals, builder); |
| |
| // Calls return 0 (success) or non-zero (failure). |
| auto callResult = callImport(forOp, importName, weak, paramsPtr, builder); |
| Block *trueDest = |
| builder.getInsertionBlock()->splitBlock(++builder.getInsertionPoint()); |
| Block *falseDest = builder.createBlock(trueDest); |
| |
| // Check the call results and branch to exit if it failed. |
| // Note that we weight the true branch (call successful) higher. |
| builder.setInsertionPointAfterValue(callResult); |
| Value zeroI32 = builder.create<LLVM::ConstantOp>( |
| loc, builder.getI32Type(), builder.getI32IntegerAttr(0)); |
| Value cmpZero = builder.create<LLVM::ICmpOp>( |
| loc, builder.getI1Type(), LLVM::ICmpPredicate::eq, callResult, zeroI32); |
| builder.create<LLVM::CondBrOp>(loc, cmpZero, trueDest, ValueRange{}, |
| falseDest, ValueRange{callResult}, |
| std::make_pair(1u, 0u)); |
| |
| // Failure return block. |
| // Return the call result to the runtime. |
| builder.setInsertionPointToStart(falseDest); |
| builder.create<LLVM::ReturnOp>( |
| loc, falseDest->addArgument(builder.getI32Type(), loc)); |
| |
| // Successful continuation block. |
| // Marshal results out of the params struct. |
| builder.setInsertionPointToStart(trueDest); |
| SmallVector<Value> results; |
| if (!resultTypes.empty()) { |
| results.reserve(resultTypes.size()); |
| Value structVal = builder.create<LLVM::LoadOp>(loc, structType, paramsPtr); |
| for (int64_t i = 0, e = resultTypes.size(); i < e; ++i) { |
| results.push_back( |
| builder.create<LLVM::ExtractValueOp>(loc, structVal, i)); |
| } |
| } |
| return results; |
| } |
| |
| Value HALDispatchABI::getIndexValue(Location loc, int64_t value, |
| OpBuilder &builder) { |
| return builder.create<LLVM::ConstantOp>( |
| loc, typeConverter->convertType(builder.getIndexType()), |
| builder.getI64IntegerAttr(value)); |
| } |
| |
| Value HALDispatchABI::castValueToType(Location loc, Value value, |
| Type resultType, OpBuilder &builder) { |
| // NOTE: we should handle more cases here (and proper sign extension). |
| if (value.getType() == resultType) |
| return value; |
| return builder.createOrFold<LLVM::ZExtOp>(loc, resultType, value); |
| } |
| |
| Value HALDispatchABI::loadFieldValue(Operation *forOp, EnvironmentField field, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto environmentPtrValue = |
| buildArgDI(forOp, /*argNum=*/0, getLocalArgument(forOp, 0), "environment", |
| di.getPtrOf(di.getConstOf(di.getEnvironmentV0T())), builder); |
| Value environmentValue = |
| builder.create<LLVM::LoadOp>(loc, environmentType, environmentPtrValue); |
| SmallVector<int64_t, 1> position = {int64_t(field)}; |
| return builder.create<LLVM::ExtractValueOp>(loc, environmentValue, position); |
| } |
| |
| Value HALDispatchABI::loadFieldValue(Operation *forOp, ProcessorField field, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| Value processorValue = |
| loadFieldValue(forOp, EnvironmentField::processor, builder); |
| SmallVector<int64_t, 1> position = {int64_t(field)}; |
| return builder.create<LLVM::ExtractValueOp>(loc, processorValue, position); |
| } |
| |
| Value HALDispatchABI::loadFieldValue(Operation *forOp, DispatchStateField field, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto statePtrValue = buildArgDI( |
| forOp, /*argNum=*/1, getLocalArgument(forOp, 1), "dispatch_state", |
| di.getPtrOf(di.getConstOf(di.getDispatchStateV0T())), builder); |
| Value stateValue = |
| builder.create<LLVM::LoadOp>(loc, dispatchStateType, statePtrValue); |
| SmallVector<int64_t, 1> position = {int64_t(field)}; |
| return builder.create<LLVM::ExtractValueOp>(loc, stateValue, position); |
| } |
| |
| Type HALDispatchABI::getFieldType(WorkgroupStateField field) { |
| return workgroupStateType.getBody()[int64_t(field)]; |
| } |
| |
| Value HALDispatchABI::loadFieldValue(Operation *forOp, |
| WorkgroupStateField field, |
| OpBuilder &builder) { |
| auto loc = forOp->getLoc(); |
| auto statePtrValue = buildArgDI( |
| forOp, /*argNum=*/2, getLocalArgument(forOp, 2), "workgroup_state", |
| di.getPtrOf(di.getConstOf(di.getWorkgroupStateV0T())), builder); |
| Value stateValue = |
| builder.create<LLVM::LoadOp>(loc, workgroupStateType, statePtrValue); |
| SmallVector<int64_t, 1> position = {int64_t(field)}; |
| return builder.create<LLVM::ExtractValueOp>(loc, stateValue, position); |
| } |
| |
| Type HALDispatchABI::getExtraFieldType(StringRef extraField) { |
| if (extraField == "processor_id") { |
| return getProcessorIDType(); |
| } |
| if (extraField == "processor_data") { |
| return getProcessorDataType(); |
| } |
| assert(false && "unhandled extra filed"); |
| return {}; |
| } |
| |
| Value HALDispatchABI::getExtraField(Operation *forOp, StringRef extraField, |
| OpBuilder &builder) { |
| if (extraField == "processor_id") { |
| return loadProcessorID(forOp, builder); |
| } else if (extraField == "processor_data") { |
| return loadProcessorData(forOp, builder); |
| } else { |
| assert(false && "unhandled extra field"); |
| return {}; |
| } |
| } |
| |
| } // namespace mlir::iree_compiler |