Make MaterializeResourceCaches support multiple devices. The pass will fail in cases where a query can't be tracked to a single device; in the future it may be possible to hoist/propagate across CFG edges before running this pass so that this doesn't happen. Today we inline most things and don't deduplicate functions, so it should (hopefully) be rare that we end up unable to memoize.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index 9761580..de22093 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -7,11 +7,13 @@ #include <memory> #include <utility> +#include "iree/compiler/Dialect/HAL/Analysis/DeviceAnalysis.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" @@ -20,6 +22,9 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" +#define DEBUG_TYPE "iree-hal-materialize-resource-caches" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + namespace mlir::iree_compiler::IREE::HAL { #define GEN_PASS_DEF_MATERIALIZERESOURCECACHESPASS @@ -27,315 +32,683 @@ namespace { -// TODO(multi-device): rewrite this to shard resources per device. +//===----------------------------------------------------------------------===// +// --iree-hal-materialize-resource-caches +//===----------------------------------------------------------------------===// + +struct DescriptorSetLayout { + // All locations that use the layout. + SetVector<Location> locs; + // Value within the initializer once materialized. + Value initializerValue; +}; +using DescriptorSetLayoutKey = + std::pair<ArrayAttr, IREE::HAL::DescriptorSetLayoutFlags>; + +struct PipelineLayout { + // All locations that use the layout. + SetVector<Location> locs; + // Lookup ops for this layout. + SmallVector<IREE::HAL::PipelineLayoutLookupOp> lookupOps; + // Global once materialized. + IREE::Util::GlobalOpInterface globalOp; + // Value within the initializer once materialized. + Value initializerValue; +}; + +struct Executable { + // All locations that use the executable. + SetVector<Location> locs; + // Executable representing the program to load. + IREE::HAL::ExecutableOp executableOp; + // Lookup ops for this executable. 
+ SmallVector<IREE::HAL::ExecutableLookupOp> lookupOps; + // Global once materialized. + IREE::Util::GlobalOpInterface globalOp; +}; + +struct DeviceResources { + DeviceResources() = default; + explicit DeviceResources(IREE::Util::GlobalOpInterface deviceOp) + : deviceOp(deviceOp) {} + + // Global !hal.device. + IREE::Util::GlobalOpInterface deviceOp; + + // Fallback devices that should be checked for resources. + // These are derived from the transitive set of #hal.device.fallback attrs. + SetVector<DeviceResources *> fallbackDeviceResources; + + // Descriptor set layouts used on the device, keyed by [bindingAttrs, flags]. + llvm::MapVector<DescriptorSetLayoutKey, DescriptorSetLayout> + descriptorSetLayouts; + // Pipeline layouts used on the device, keyed by layout attr. + llvm::MapVector<IREE::HAL::PipelineLayoutAttr, PipelineLayout> + pipelineLayouts; + // Executables used on the device, keyed by name. + llvm::MapVector<StringAttr, Executable> executables; +}; + +static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) { + StringRef deviceName = deviceOp.getGlobalName().getValue(); + if (deviceName.starts_with("__")) { + // Already prefixed. + return deviceName.str(); + } + auto prefixedName = "__" + deviceName; + return prefixedName.str(); +} + +static void declareDevicePipelineLayout(IREE::Util::GlobalOpInterface deviceOp, + PipelineLayout &pipelineLayout, + size_t pipelineLayoutIndex, + OpBuilder &moduleBuilder) { + // Create global in the module. 
+ auto symbolName = getDeviceNamePrefix(deviceOp) + "_pipeline_layout_" + + std::to_string(pipelineLayoutIndex); + LLVM_DEBUG(DBGS() << "+ creating device `" + << deviceOp.getGlobalName().getValue() + << "` pipeline global `" << symbolName << "`\n"); + auto layoutType = moduleBuilder.getType<PipelineLayoutType>(); + auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>( + moduleBuilder.getFusedLoc(llvm::to_vector(pipelineLayout.locs)), + symbolName, + /*isMutable=*/false, layoutType); + globalOp.setPrivate(); + pipelineLayout.globalOp = globalOp; + + // Replace lookups with the global. + for (auto lookupOp : pipelineLayout.lookupOps) { + LLVM_DEBUG({ + DBGS() << " - replacing lookup: "; + lookupOp.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + OpBuilder lookupBuilder(lookupOp); + auto loadedValue = + pipelineLayout.globalOp.createLoadOp(lookupOp.getLoc(), lookupBuilder) + .getLoadedGlobalValue(); + lookupOp.replaceAllUsesWith(loadedValue); + lookupOp.erase(); + } + pipelineLayout.lookupOps.clear(); +} + +static void declareDeviceExecutable(IREE::Util::GlobalOpInterface deviceOp, + Executable &executable, + size_t executableIndex, + OpBuilder &moduleBuilder) { + // Create global in the module. + auto symbolName = (getDeviceNamePrefix(deviceOp) + "_executable_" + + std::to_string(executableIndex) + "_" + + executable.executableOp.getName()) + .str(); + LLVM_DEBUG(DBGS() << "+ creating device `" + << deviceOp.getGlobalName().getValue() + << "` executable global `" << symbolName << "`\n"); + auto executableType = moduleBuilder.getType<IREE::HAL::ExecutableType>(); + auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>( + moduleBuilder.getFusedLoc(llvm::to_vector(executable.locs)), symbolName, + /*isMutable=*/false, executableType); + globalOp.setPrivate(); + executable.globalOp = globalOp; + + // Replace lookups with the global. 
+ for (auto lookupOp : executable.lookupOps) { + LLVM_DEBUG({ + DBGS() << " - replacing lookup: "; + lookupOp.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + OpBuilder lookupBuilder(lookupOp); + auto loadedValue = + executable.globalOp.createLoadOp(lookupOp.getLoc(), lookupBuilder) + .getLoadedGlobalValue(); + lookupOp.replaceAllUsesWith(loadedValue); + lookupOp.erase(); + } + executable.lookupOps.clear(); +} + +static DescriptorSetLayoutKey +getDescriptorSetLayoutKey(IREE::HAL::DescriptorSetLayoutAttr setLayoutAttr) { + auto bindingAttrs = + llvm::to_vector_of<Attribute>(setLayoutAttr.getBindings()); + return DescriptorSetLayoutKey{ + ArrayAttr::get(setLayoutAttr.getContext(), bindingAttrs), + setLayoutAttr.getFlags().value_or( + IREE::HAL::DescriptorSetLayoutFlags::None), + }; +} + +// Inlines a constant block as a function in |moduleBuilder| and then inserts +// a call to it in |callerBuilder|. +static SmallVector<Value> inlineConstantBlockOp( + StringRef funcName, IREE::HAL::ExecutableConstantBlockOp blockOp, + OpBuilder &moduleBuilder, OpBuilder &callerBuilder, Value callerDevice) { + LLVM_DEBUG(DBGS() << "- inlining constant block `" << funcName << "`\n"); + + // Create the function with the region contents of the constant block. + auto funcOp = moduleBuilder.create<IREE::Util::FuncOp>( + blockOp.getLoc(), funcName, blockOp.getFunctionType()); + funcOp.setPrivate(); + IRMapping mapping; + blockOp.getRegion().cloneInto(&funcOp.getRegion(), mapping); + + // Replace the hal.return with a func.return. + for (auto returnOp : + llvm::make_early_inc_range(funcOp.getOps<IREE::HAL::ReturnOp>())) { + OpBuilder(returnOp).create<IREE::Util::ReturnOp>(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + } + + // Create the call passing in the device if needed. 
+ SmallVector<Value> callOperands; + if (funcOp.getNumArguments() > 0) { + callOperands.push_back(callerDevice); + } + auto callOp = callerBuilder.create<IREE::Util::CallOp>(blockOp.getLoc(), + funcOp, callOperands); + return llvm::to_vector_of<Value>(callOp.getResults()); +} + +static Value initializeExecutable(DeviceResources &deviceResources, + Executable &executable, + OpBuilder &moduleBuilder, + Value initializerDevice, + OpBuilder &initializerBuilder) { + auto loc = executable.globalOp.getLoc(); + auto executableType = moduleBuilder.getType<IREE::HAL::ExecutableType>(); + + // Create a switch statement with a case for each variant. + // Each case should then cache only executables which contain a matching + // ExecutableVariantOp. + // Afterwards, canonicalization will take care of de-duping/etc. + SmallVector<int64_t> caseIndices; + SmallVector<IREE::HAL::ExecutableVariantOp> caseVariantOps; + for (auto variantOp : + executable.executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) { + caseIndices.push_back(caseIndices.size()); + caseVariantOps.push_back(variantOp); + } + + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseVariantOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + return caseVariantOps[i].buildCondition(initializerDevice, builder); + }, + initializerBuilder); + + // Allow each variant to define how it is loaded and what pipeline it has. + auto switchOp = initializerBuilder.create<scf::IndexSwitchOp>( + loc, executableType, selectedIndex, caseIndices, caseIndices.size()); + for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) { + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + + // Gather each of the pipeline layouts needed for each entry point in + // the executable. 
+ SmallVector<Value> pipelineLayoutValues; + for (auto exportOp : variantOp.getExportOps()) { + auto &pipelineLayout = + deviceResources.pipelineLayouts[exportOp.getLayoutAttr()]; + pipelineLayoutValues.push_back(pipelineLayout.initializerValue); + } + + // Inline constant initializer from the variant. + // We want these to all happen inside of this device switch case; they'll + // get deduplicated/hoisted if possible in future canonicalization passes. + SmallVector<Value> constantValues; + for (auto [blockIndex, blockOp] : + llvm::enumerate(variantOp.getConstantBlockOps())) { + auto blockName = (executable.globalOp.getGlobalName().getValue() + + "_constant_block_" + std::to_string(blockIndex)) + .str(); + constantValues.append(inlineConstantBlockOp( + blockName, blockOp, moduleBuilder, caseBuilder, initializerDevice)); + } + + Value executableValue = + caseBuilder.createOrFold<IREE::HAL::ExecutableCreateOp>( + loc, executableType, initializerDevice, + SymbolRefAttr::get( + executable.executableOp.getSymNameAttr(), + {SymbolRefAttr::get(variantOp.getSymNameAttr())}), + pipelineLayoutValues, constantValues); + + caseBuilder.create<scf::YieldOp>(loc, executableValue); + } + + // Fallback for no available variant. 
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + Value status = defaultBuilder.create<arith::ConstantIntOp>( + loc, static_cast<int>(IREE::Util::StatusCode::Unavailable), 32); + { + std::string errorStr; + llvm::raw_string_ostream errorStream(errorStr); + errorStream << "HAL device `" + << deviceResources.deviceOp.getGlobalName().getValue() + << "` does not support any variant of executable `" + << executable.executableOp.getName() + << "`; available formats: ["; + llvm::interleaveComma(caseVariantOps, errorStream, [&](auto variantOp) { + errorStream << variantOp.getTargetAttr().getFormat().getValue(); + }); + errorStream << "]"; + defaultBuilder.create<IREE::Util::StatusCheckOkOp>(loc, status, errorStr); + } + auto nullValue = + defaultBuilder.createOrFold<IREE::Util::NullOp>(loc, executableType); + defaultBuilder.create<scf::YieldOp>(loc, nullValue); + + return switchOp.getResult(0); +} + +static void initializeDeviceResources(DeviceResources &deviceResources, + OpBuilder &moduleBuilder, + Value initializerDevice, + OpBuilder &initializerBuilder) { + // Initialize all descriptor set layouts for use by the pipeline layouts. + auto setLayoutType = initializerBuilder.getType<DescriptorSetLayoutType>(); + for (auto [i, it] : llvm::enumerate(deviceResources.descriptorSetLayouts)) { + auto [bindingAttrs, flags] = it.first; + auto &descriptorSetLayout = it.second; + descriptorSetLayout.initializerValue = + initializerBuilder.createOrFold<IREE::HAL::DescriptorSetLayoutCreateOp>( + initializerBuilder.getFusedLoc( + llvm::to_vector(descriptorSetLayout.locs)), + setLayoutType, initializerDevice, flags, bindingAttrs); + } + + // Initialize all pipeline layouts required for executable creation. 
+ auto pipelineLayoutType = initializerBuilder.getType<PipelineLayoutType>(); + for (auto [i, it] : llvm::enumerate(deviceResources.pipelineLayouts)) { + auto &[layoutAttr, pipelineLayout] = it; + SmallVector<Value> setLayoutValues; + for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { + auto key = getDescriptorSetLayoutKey(setLayoutAttr); + setLayoutValues.push_back( + deviceResources.descriptorSetLayouts[key].initializerValue); + } + pipelineLayout.initializerValue = + initializerBuilder.createOrFold<IREE::HAL::PipelineLayoutCreateOp>( + pipelineLayout.globalOp.getLoc(), pipelineLayoutType, + initializerDevice, + initializerBuilder.getIndexAttr(layoutAttr.getPushConstants()), + setLayoutValues); + pipelineLayout.globalOp.createStoreOp(pipelineLayout.globalOp.getLoc(), + pipelineLayout.initializerValue, + initializerBuilder); + } + + // Initialize all executables. + for (auto [i, it] : llvm::enumerate(deviceResources.executables)) { + auto &[executableName, executable] = it; + executable.globalOp.createStoreOp( + executable.globalOp.getLoc(), + initializeExecutable(deviceResources, executable, moduleBuilder, + initializerDevice, initializerBuilder), + initializerBuilder); + } +} + +static void reuseFallbackDeviceResources(DeviceResources &deviceResources, + DeviceResources &fallbackResources, + Value initializerDevice, + OpBuilder &initializerBuilder) { + // Load fallback pipeline layouts for all required by this device. 
+ for (auto &[layoutAttr, pipelineLayout] : deviceResources.pipelineLayouts) { + auto fallbackGlobalOp = + fallbackResources.pipelineLayouts[layoutAttr].globalOp; + assert(fallbackGlobalOp && "should have created global"); + Value fallbackPipelineLayout = + fallbackGlobalOp + .createLoadOp(pipelineLayout.globalOp.getLoc(), initializerBuilder) + .getLoadedGlobalValue(); + pipelineLayout.globalOp.createStoreOp(pipelineLayout.globalOp.getLoc(), + fallbackPipelineLayout, + initializerBuilder); + } + + // Load fallback executables for all required by this device. + for (auto &[executableName, executable] : deviceResources.executables) { + auto fallbackGlobalOp = + fallbackResources.executables[executable.executableOp.getNameAttr()] + .globalOp; + assert(fallbackGlobalOp && "should have created global"); + Value fallbackExecutable = + fallbackGlobalOp + .createLoadOp(executable.globalOp.getLoc(), initializerBuilder) + .getLoadedGlobalValue(); + executable.globalOp.createStoreOp(executable.globalOp.getLoc(), + fallbackExecutable, initializerBuilder); + } +} + +static void buildDeviceResourceInitializer(DeviceResources &deviceResources, + OpBuilder &moduleBuilder) { + auto loc = deviceResources.deviceOp.getLoc(); + auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc); + OpBuilder initializerBuilder = + OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); + Value initializerDevice = + deviceResources.deviceOp.createLoadOp(loc, initializerBuilder) + .getLoadedGlobalValue(); + + // If there are any fallbacks then we need to handle referencing their + // resources and otherwise will initialize our own. 
+ if (deviceResources.fallbackDeviceResources.empty()) { + initializeDeviceResources(deviceResources, moduleBuilder, initializerDevice, + initializerBuilder); + } else { + SmallVector<int64_t> caseIndices; + Value selectedIndex = buildIfElseTree( + loc, deviceResources.fallbackDeviceResources.size(), + [&](Location loc, size_t i, OpBuilder &caseBuilder) { + caseIndices.push_back(caseIndices.size()); + auto *fallbackResources = deviceResources.fallbackDeviceResources[i]; + Value fallbackDevice = + fallbackResources->deviceOp.createLoadOp(loc, caseBuilder) + .getLoadedGlobalValue(); + return caseBuilder.create<IREE::Util::CmpEQOp>(loc, initializerDevice, + fallbackDevice); + }, + initializerBuilder); + auto switchOp = initializerBuilder.create<scf::IndexSwitchOp>( + loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size()); + for (auto [fallbackResources, caseRegion] : + llvm::zip_equal(deviceResources.fallbackDeviceResources, + switchOp.getCaseRegions())) { + auto &caseBlock = caseRegion.emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + reuseFallbackDeviceResources(deviceResources, *fallbackResources, + initializerDevice, caseBuilder); + caseBuilder.create<scf::YieldOp>(loc); + } + auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + initializeDeviceResources(deviceResources, moduleBuilder, initializerDevice, + defaultBuilder); + defaultBuilder.create<scf::YieldOp>(loc); + } + + initializerBuilder.create<IREE::Util::ReturnOp>(loc); +} + +// Returns zero or more devices globals that may act as fallbacks for the +// given device, if analyzed. The result is in selection order. 
+static std::optional<SetVector<IREE::Util::GlobalOpInterface>> +getDeviceFallbackGlobals(IREE::Util::GlobalOpInterface deviceGlobal, + SymbolTable &symbolTable) { + SetVector<IREE::Util::GlobalOpInterface> resultSet; + auto processAttr = [&](Attribute attr) { + if (!attr) + return true; // ignore uninitialized devices + return TypeSwitch<Attribute, bool>(attr) + .Case<IREE::HAL::DeviceOrdinalAttr>([](auto attr) { return true; }) + .Case<IREE::HAL::DeviceTargetAttr>([](auto attr) { return true; }) + .Case<IREE::HAL::DeviceFallbackAttr>([&](auto fallbackAttr) { + resultSet.insert(symbolTable.lookup<IREE::Util::GlobalOpInterface>( + fallbackAttr.getName().getValue())); + return true; + }) + .Default([](auto attr) { return false; }); + }; + auto initialValue = deviceGlobal.getGlobalInitialValue(); + if (auto selectAttr = + dyn_cast_if_present<IREE::HAL::DeviceSelectAttr>(initialValue)) { + for (auto deviceAttr : selectAttr.getDevices()) { + if (!processAttr(deviceAttr)) { + // Fails if unsupported/unhandled device attribute type. + return std::nullopt; + } + } + } else { + if (!processAttr(initialValue)) { + // Fails if unsupported/unhandled device attribute type. + return std::nullopt; + } + } + return resultSet; +} + +static LogicalResult gatherDeviceResources( + ModuleOp &moduleOp, SymbolTable &symbolTable, + DeviceAnalysis &deviceAnalysis, + llvm::MapVector<Attribute, DeviceResources> &allDeviceResources) { + // Allocate storage for the resource sets. + for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + LLVM_DEBUG(DBGS() << "Gathering device `" + << deviceOp.getGlobalName().getValue() + << "` resources...\n"); + allDeviceResources.try_emplace(deviceOp.getGlobalName(), + DeviceResources(deviceOp)); + } + + // Link fallbacks between the resources. 
+ for (auto deviceOp : deviceAnalysis.getDeviceGlobals()) { + auto fallbackOps = getDeviceFallbackGlobals(deviceOp, symbolTable); + if (!fallbackOps) { + return deviceOp->emitOpError() + << "analysis failed on device; currently analysis must succeed"; + } + auto &deviceResources = allDeviceResources[deviceOp.getGlobalName()]; + for (auto fallbackOp : *fallbackOps) { + LLVM_DEBUG(DBGS() << "* linking to fallback `" + << fallbackOp.getGlobalName().getValue() << "`\n"); + deviceResources.fallbackDeviceResources.insert( + &allDeviceResources[fallbackOp.getGlobalName()]); + } + } + + // Find all relevant ops. If we don't find any we skip the pass as it's + // likely it's already been run. We could fix the pass to better support + // partial materialization but there's no use cases for that today. + auto tryGetDeviceResources = [&](Operation *op, + Value device) -> DeviceResources * { + auto deviceGlobals = deviceAnalysis.lookupDeviceGlobals(device); + if (!deviceGlobals || deviceGlobals->size() != 1) { + op->emitOpError() << "analysis failed on device; currently analysis " + "must succeed with a single device"; + return nullptr; + } + auto deviceOp = deviceGlobals->front(); + return &allDeviceResources.find(deviceOp.getGlobalName())->second; + }; + for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) { + for (auto &block : funcOp.getFunctionBody()) { + if (block + .walk([&](Operation *op) -> WalkResult { + if (auto lookupOp = dyn_cast<PipelineLayoutLookupOp>(op)) { + auto *deviceResources = + tryGetDeviceResources(lookupOp, lookupOp.getDevice()); + if (!deviceResources) { + return WalkResult::interrupt(); + } + auto layoutAttr = lookupOp.getLayoutAttr(); + LLVM_DEBUG(DBGS() + << "+ requiring pipeline layout from lookup: `" + << layoutAttr << "`\n"); + auto &pipelineLayout = + deviceResources->pipelineLayouts[layoutAttr]; + pipelineLayout.locs.insert(lookupOp.getLoc()); + pipelineLayout.lookupOps.push_back(lookupOp); + for (auto setLayoutAttr : 
layoutAttr.getSetLayouts()) { + LLVM_DEBUG( + DBGS() + << "+ requiring descriptor set layout from lookup: `" + << setLayoutAttr << "`\n"); + auto key = getDescriptorSetLayoutKey(setLayoutAttr); + auto &setLayout = + deviceResources->descriptorSetLayouts[key]; + setLayout.locs.insert(lookupOp.getLoc()); + } + } else if (auto lookupOp = dyn_cast<ExecutableLookupOp>(op)) { + auto *deviceResources = + tryGetDeviceResources(lookupOp, lookupOp.getDevice()); + if (!deviceResources) { + return WalkResult::interrupt(); + } + auto executableAttr = lookupOp.getExecutableAttr().getAttr(); + LLVM_DEBUG(DBGS() << "+ requiring executable from lookup: `" + << executableAttr.getValue() << "`\n"); + auto &executable = + deviceResources->executables[executableAttr]; + executable.locs.insert(lookupOp.getLoc()); + executable.lookupOps.push_back(lookupOp); + } + return WalkResult::advance(); + }) + .wasInterrupted()) { + return failure(); + } + } + } + + // Gather the executables referenced by all lookup ops. 
+ for (auto &[deviceName, deviceResources] : allDeviceResources) { + for (auto &[executableName, executable] : deviceResources.executables) { + executable.executableOp = + symbolTable.lookup<IREE::HAL::ExecutableOp>(executableName); + for (auto variantOp : + executable.executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) { + for (auto exportOp : variantOp.getExportOps()) { + auto layoutAttr = exportOp.getLayoutAttr(); + LLVM_DEBUG(DBGS() << "+ requiring pipeline layout from export: `" + << layoutAttr << "`\n"); + auto &pipelineLayout = deviceResources.pipelineLayouts[layoutAttr]; + pipelineLayout.locs.insert(exportOp.getLoc()); + for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { + LLVM_DEBUG(DBGS() + << "+ requiring descriptor set layout from export: `" + << setLayoutAttr << "`\n"); + auto key = getDescriptorSetLayoutKey(setLayoutAttr); + auto &setLayout = deviceResources.descriptorSetLayouts[key]; + setLayout.locs.insert(exportOp.getLoc()); + } + } + } + } + } + + // Merge all resources that may be used by way of fallbacks into each fallback + // device. We could make this optional to improve startup performance by + // adding these as optional and create them on demand but that's more complex. + // For now we just always ensure the resources are available even if they end + // up unused. 
+ for (auto &[deviceName, deviceResources] : + llvm::reverse(allDeviceResources)) { + for (auto *fallbackResources : deviceResources.fallbackDeviceResources) { + LLVM_DEBUG( + DBGS() << "-> requiring fallback resources from device `" + << fallbackResources->deviceOp.getGlobalName().getValue() + << "`\n"); + for (auto [setKey, setLayout] : deviceResources.descriptorSetLayouts) { + auto &fallbackSetLayout = + fallbackResources->descriptorSetLayouts[setKey]; + fallbackSetLayout.locs.insert(setLayout.locs.begin(), + setLayout.locs.end()); + } + for (auto [layoutAttr, pipelineLayout] : + deviceResources.pipelineLayouts) { + auto &fallbackPipelineLayout = + fallbackResources->pipelineLayouts[layoutAttr]; + fallbackPipelineLayout.locs.insert(pipelineLayout.locs.begin(), + pipelineLayout.locs.end()); + } + for (auto [executableName, executable] : deviceResources.executables) { + auto &fallbackExecutable = + fallbackResources->executables[executableName]; + fallbackExecutable.locs.insert(executable.locs.begin(), + executable.locs.end()); + fallbackExecutable.executableOp = executable.executableOp; + } + } + } + + return success(); +} + struct MaterializeResourceCachesPass : public IREE::HAL::impl::MaterializeResourceCachesPassBase< MaterializeResourceCachesPass> { void runOnOperation() override { auto moduleOp = getOperation(); - if (moduleOp.getBody()->empty()) - return; - moduleBuilder = OpBuilder(&moduleOp.getBody()->front()); + SymbolTable symbolTable(moduleOp); - // Find all relevant ops. If we don't find any we skip the pass as it's - // likely it's already been run. We could fix the pass to better support - // partial materialization but there's no use cases for that today. 
- auto executableOps = llvm::to_vector<8>(moduleOp.getOps<ExecutableOp>()); - SmallVector<IREE::HAL::PipelineLayoutLookupOp> pipelineLayoutLookupOps; - SmallVector<IREE::HAL::ExecutableLookupOp> executableLookupOps; - for (auto funcOp : moduleOp.getOps<mlir::FunctionOpInterface>()) { - for (auto &block : funcOp.getFunctionBody()) { - block.walk([&](Operation *op) { - if (auto lookupOp = dyn_cast<PipelineLayoutLookupOp>(op)) { - pipelineLayoutLookupOps.push_back(lookupOp); - } else if (auto lookupOp = dyn_cast<ExecutableLookupOp>(op)) { - executableLookupOps.push_back(lookupOp); - } - }); + // Analyze the module to determine which devices are used where. + LLVM_DEBUG(DBGS() << "Running device analysis...\n"); + DeviceAnalysis deviceAnalysis(moduleOp); + if (failed(deviceAnalysis.run())) { + return signalPassFailure(); + } + + // Build a table of all resources used by all devices in the program. + LLVM_DEBUG(DBGS() << "Gathering device resources...\n"); + llvm::MapVector<Attribute, DeviceResources> allDeviceResources; + if (failed(gatherDeviceResources(moduleOp, symbolTable, deviceAnalysis, + allDeviceResources))) { + return signalPassFailure(); + } + + // Materialize resources for each device (if any) and replace lookups. + for (auto &[nameAttr, deviceResources] : allDeviceResources) { + LLVM_DEBUG(DBGS() << "Materializing device `" + << deviceResources.deviceOp.getGlobalName().getValue() + << "` resources...\n"); + // Skip devices with no resources. + if (deviceResources.pipelineLayouts.empty() && + deviceResources.executables.empty()) { + LLVM_DEBUG(DBGS() << "~ skipping device with no resources\n"); + continue; } - } - if (pipelineLayoutLookupOps.empty() && executableLookupOps.empty()) { - return; + + // TODO(benvanik): proper insertion order if devices are initialized via + // an initializer. Today this assumes the device hasn't been materialized + // yet if there are any lookups to them. 
+ if (!deviceResources.deviceOp.getGlobalInitialValue()) { + deviceResources.deviceOp.emitOpError() + << "is expected to be initialized with an attribute and not yet " + "via a util.initializer"; + return signalPassFailure(); + } + + // Declare globals for each pipeline layout and executable and replace all + // lookup ops to reference them. + OpBuilder moduleBuilder(moduleOp); + moduleBuilder.setInsertionPointAfter(deviceResources.deviceOp); + for (auto [i, it] : llvm::enumerate(deviceResources.pipelineLayouts)) { + auto &[layoutAttr, pipelineLayout] = it; + declareDevicePipelineLayout(deviceResources.deviceOp, pipelineLayout, i, + moduleBuilder); + } + for (auto [i, it] : llvm::enumerate(deviceResources.executables)) { + auto &[executableName, executable] = it; + declareDeviceExecutable(deviceResources.deviceOp, executable, i, + moduleBuilder); + } + + // Create an initializer after the declared globals. + buildDeviceResourceInitializer(deviceResources, moduleBuilder); } - // Declare all layouts used by the executables. This will ensure that the - // initialization order is correct as any pipeline layout needed (and its - // dependencies) will be created prior to the executable cache below. The - // other nice thing is that we get ordering similar to the executable - // variables above. - for (auto executableOp : executableOps) { + // Remove ops that are no longer required after materialization. + for (auto executableOp : moduleOp.getOps<IREE::HAL::ExecutableOp>()) { for (auto variantOp : executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) { - for (auto exportOp : variantOp.getExportOps()) { - definePipelineLayoutOp(exportOp.getLoc(), exportOp.getLayout()); + if (auto conditionOp = variantOp.getConditionOp()) { + conditionOp.erase(); + } + for (auto blockOp : + llvm::make_early_inc_range(variantOp.getConstantBlockOps())) { + blockOp.erase(); } } } - - // Declare executable variables so that we can reference them during lookup - // replacement. 
- for (auto executableOp : executableOps) { - defineExecutableOp(executableOp); - } - - // Generate cached resource singletons and replace lookup ops with direct - // loads from variables. - for (auto lookupOp : pipelineLayoutLookupOps) { - replacePipelineLayoutLookupOp(lookupOp); - } - for (auto lookupOp : executableLookupOps) { - replaceExecutableLookupOp(lookupOp); - } } - -private: - IREE::Util::GlobalOp - defineDescriptorSetLayoutOp(Location loc, ArrayAttr bindingAttrs, - IREE::HAL::DescriptorSetLayoutFlags flags) { - std::pair<Attribute, IREE::HAL::DescriptorSetLayoutFlags> key = { - bindingAttrs, flags}; - auto existingIt = descriptorSetLayoutCache_.find(key); - if (existingIt != descriptorSetLayoutCache_.end()) { - return existingIt->second; - } - - auto symbolName = (StringRef("_descriptor_set_layout_") + - std::to_string(nextUniqueDescriptorSetLayoutId++)) - .str(); - - auto layoutType = DescriptorSetLayoutType::get(loc.getContext()); - auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>( - loc, symbolName, - /*isMutable=*/false, layoutType); - globalOp.setPrivate(); - descriptorSetLayoutCache_.try_emplace(key, globalOp); - - auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc); - OpBuilder blockBuilder = - OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); - // TODO(multi-device): pass in resolve info to the call and reuse. 
- Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder); - Value layout = blockBuilder.createOrFold<DescriptorSetLayoutCreateOp>( - loc, layoutType, device, flags, bindingAttrs); - globalOp.createStoreOp(loc, layout, blockBuilder); - blockBuilder.create<IREE::Util::ReturnOp>(loc); - - return globalOp; - } - - IREE::Util::GlobalOp - definePipelineLayoutOp(Location loc, - IREE::HAL::PipelineLayoutAttr layoutAttr) { - auto existingIt = pipelineLayoutCache_.find(layoutAttr); - if (existingIt != pipelineLayoutCache_.end()) { - return existingIt->second; - } - - // First lookup (or create) all the required descriptor sets. This ensures - // they end up in the proper initialization order. - SmallVector<IREE::Util::GlobalOp> setLayoutGlobalOps; - for (auto setLayoutAttr : layoutAttr.getSetLayouts()) { - SmallVector<Attribute> bindingAttrs; - for (auto bindingAttr : setLayoutAttr.getBindings()) { - bindingAttrs.push_back(bindingAttr); - } - setLayoutGlobalOps.push_back(defineDescriptorSetLayoutOp( - loc, ArrayAttr::get(loc.getContext(), bindingAttrs), - setLayoutAttr.getFlags().value_or( - IREE::HAL::DescriptorSetLayoutFlags::None))); - } - - auto symbolName = (StringRef("_pipeline_layout_") + - std::to_string(nextUniquePipelineLayoutId++)) - .str(); - - auto layoutType = PipelineLayoutType::get(loc.getContext()); - auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>( - loc, symbolName, /*isMutable=*/false, layoutType); - globalOp.setPrivate(); - pipelineLayoutCache_.try_emplace(layoutAttr, globalOp); - - auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc); - OpBuilder blockBuilder = - OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); - SmallVector<Value> setLayoutValues; - for (auto setLayoutGlobalOp : setLayoutGlobalOps) { - setLayoutValues.push_back( - setLayoutGlobalOp.createLoadOp(loc, blockBuilder) - .getLoadedGlobalValue()); - } - // TODO(multi-device): pass in resolve info to the call and reuse. 
- Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder); - Value layout = blockBuilder.createOrFold<PipelineLayoutCreateOp>( - loc, layoutType, device, - blockBuilder.getIndexAttr(layoutAttr.getPushConstants()), - setLayoutValues); - globalOp.createStoreOp(loc, layout, blockBuilder); - blockBuilder.create<IREE::Util::ReturnOp>(loc); - - return globalOp; - } - - void defineExecutableOp(ExecutableOp executableOp) { - auto loc = executableOp.getLoc(); - auto symbolName = - (StringRef("_executable_") + executableOp.getSymName()).str(); - - auto executableType = ExecutableType::get(executableOp.getContext()); - auto globalOp = moduleBuilder.create<IREE::Util::GlobalOp>( - loc, symbolName, /*isMutable=*/false, executableType); - globalOp.setPrivate(); - executableCache_.try_emplace(executableOp.getSymName(), globalOp); - - auto initializerOp = moduleBuilder.create<IREE::Util::InitializerOp>(loc); - OpBuilder blockBuilder = - OpBuilder::atBlockEnd(initializerOp.addEntryBlock()); - // TODO(multi-device): pass in resolve info to the call and reuse. - Value device = IREE::HAL::DeviceType::resolveAny(loc, blockBuilder); - - // Create a switch statement with a case for each variant. - // Each case should then cache only executables which contain a matching - // ExecutableVariantOp. - // Afterwards, canonicalization will take care of de-duping/etc. - SmallVector<int64_t> caseIndices; - SmallVector<IREE::HAL::ExecutableVariantOp> caseVariantOps; - for (auto variantOp : - executableOp.getOps<IREE::HAL::ExecutableVariantOp>()) { - caseIndices.push_back(caseIndices.size()); - caseVariantOps.push_back(variantOp); - } - - // Select the variant index. - Value selectedIndex = buildIfElseTree( - loc, caseVariantOps.size(), - [&](Location loc, size_t i, OpBuilder &builder) { - return caseVariantOps[i].buildCondition(device, builder); - }, - blockBuilder); - - // Allow each variant to define how it is loaded and what pipeline it has. 
- auto switchOp = blockBuilder.create<scf::IndexSwitchOp>( - loc, executableType, selectedIndex, caseIndices, caseIndices.size()); - for (auto [i, variantOp] : llvm::enumerate(caseVariantOps)) { - auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); - auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); - - // Gather each of the pipeline layouts needed for each entry point in - // the executable. - SmallVector<Value, 8> pipelineLayoutValues; - for (auto exportOp : variantOp.getExportOps()) { - auto pipelineLayoutGlobalOp = - definePipelineLayoutOp(executableOp.getLoc(), exportOp.getLayout()); - pipelineLayoutValues.push_back( - pipelineLayoutGlobalOp.createLoadOp(loc, caseBuilder) - .getLoadedGlobalValue()); - } - - // Inline constant initializer from the variant. - // We want these to all happen inside of this device switch case; they'll - // get deduplicated/hoisted if possible in future canonicalization passes. - SmallVector<Value> constantValues; - for (auto blockOp : - llvm::make_early_inc_range(variantOp.getConstantBlockOps())) { - constantValues.append( - inlineConstantBlockOp(blockOp, moduleBuilder, caseBuilder, device)); - blockOp.erase(); - } - - Value executable = caseBuilder.createOrFold<ExecutableCreateOp>( - loc, executableType, device, - SymbolRefAttr::get(executableOp.getSymNameAttr(), - {SymbolRefAttr::get(variantOp.getSymNameAttr())}), - pipelineLayoutValues, constantValues); - - caseBuilder.create<scf::YieldOp>(loc, executable); - } - - // Fallback for no available variant. 
- auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); - auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); - Value status = defaultBuilder.create<arith::ConstantIntOp>( - loc, static_cast<int>(IREE::Util::StatusCode::Unavailable), 32); - defaultBuilder.create<IREE::Util::StatusCheckOkOp>( - loc, status, - "none of the executable binaries in the module are supported by the " - "runtime"); - auto nullValue = - defaultBuilder.createOrFold<IREE::Util::NullOp>(loc, executableType); - defaultBuilder.create<scf::YieldOp>(loc, nullValue); - - auto executableValue = switchOp.getResult(0); - globalOp.createStoreOp(loc, executableValue, blockBuilder); - blockBuilder.create<IREE::Util::ReturnOp>(loc); - } - - // Inlines a constant block as a function in |moduleBuilder| and then inserts - // a call to it in |callerBuilder|. - SmallVector<Value> inlineConstantBlockOp(ExecutableConstantBlockOp blockOp, - OpBuilder &moduleBuilder, - OpBuilder &callerBuilder, - Value device) { - // Create the function with the region contents of the constant block. - auto funcName = (StringRef("__constant_block_") + - std::to_string(nextUniqueConstantBlockId++)) - .str(); - auto funcOp = moduleBuilder.create<IREE::Util::FuncOp>( - blockOp.getLoc(), funcName, blockOp.getFunctionType()); - funcOp.setPrivate(); - funcOp.getRegion().takeBody(blockOp.getRegion()); - - // Replace the hal.return with a func.return. - for (auto returnOp : - llvm::make_early_inc_range(funcOp.getOps<IREE::HAL::ReturnOp>())) { - OpBuilder(returnOp).create<IREE::Util::ReturnOp>(returnOp.getLoc(), - returnOp.getOperands()); - returnOp.erase(); - } - - // Create the call passing in the device if needed. 
- SmallVector<Value> callOperands; - if (funcOp.getNumArguments() > 0) { - callOperands.push_back(device); - } - auto callOp = callerBuilder.create<IREE::Util::CallOp>( - blockOp.getLoc(), funcOp, callOperands); - - return llvm::map_to_vector(callOp.getResults(), - [](OpResult result) -> Value { return result; }); - } - - void replacePipelineLayoutLookupOp(PipelineLayoutLookupOp &lookupOp) { - OpBuilder builder(lookupOp); - auto globalOp = - definePipelineLayoutOp(lookupOp.getLoc(), lookupOp.getLayout()); - auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder) - .getLoadedGlobalValue(); - lookupOp.replaceAllUsesWith(loadedValue); - lookupOp.erase(); - } - - void replaceExecutableLookupOp(ExecutableLookupOp &lookupOp) { - OpBuilder builder(lookupOp); - auto executableIt = executableCache_.find(lookupOp.getExecutable()); - assert(executableIt != executableCache_.end() && - "executable must have been cached"); - auto globalOp = executableIt->second; - auto loadedValue = globalOp.createLoadOp(lookupOp.getLoc(), builder) - .getLoadedGlobalValue(); - lookupOp.replaceAllUsesWith(loadedValue); - lookupOp.erase(); - } - - OpBuilder moduleBuilder{static_cast<MLIRContext *>(nullptr)}; - DenseMap<std::pair<Attribute, IREE::HAL::DescriptorSetLayoutFlags>, - IREE::Util::GlobalOp> - descriptorSetLayoutCache_; - DenseMap<Attribute, IREE::Util::GlobalOp> pipelineLayoutCache_; - DenseMap<StringRef, IREE::Util::GlobalOp> executableCache_; - - int nextUniqueConstantBlockId = 0; - int nextUniquePipelineLayoutId = 0; - int nextUniqueDescriptorSetLayoutId = 0; }; } // namespace
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp index 2068a87..096b7bf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MemoizeDeviceQueries.cpp
@@ -52,9 +52,11 @@ static std::string getDeviceNamePrefix(IREE::Util::GlobalOpInterface deviceOp) { StringRef deviceName = deviceOp.getGlobalName().getValue(); - if (deviceName.starts_with("__")) + if (deviceName.starts_with("__")) { return deviceName.str(); - return ("__" + deviceName).str(); + } + auto prefixedName = "__" + deviceName; + return prefixedName.str(); } // NOTE: this implementation is just for a single active device. As we start to
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir index 3cb5f76..4e562f6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_resource_caches.mlir
@@ -1,27 +1,34 @@ // RUN: iree-opt --split-input-file --iree-hal-materialize-resource-caches %s | FileCheck %s -// CHECK: util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout - -// CHECK: util.global private @_pipeline_layout_0 : !hal.pipeline_layout +// CHECK: util.global private @device = #hal.device.ordinal<0> +util.global private @device = #hal.device.ordinal<0> : !hal.device +// CHECK: util.global private @__device_pipeline_layout_0 : !hal.pipeline_layout // CHECK-NEXT: util.initializer { -// CHECK-DAG: %[[SET0:.+]] = util.global.load @_descriptor_set_layout_0 : !hal.descriptor_set_layout -// CHECK-DAG: %[[DEVICE:.+]] = hal.devices.get %{{.+}} -// CHECK-NEXT: %[[LAYOUT:.+]] = hal.pipeline_layout.create -// CHECK-SAME: device(%[[DEVICE]] : !hal.device) -// CHECK-SAME: push_constants(1) -// CHECK-SAME: layouts([%[[SET0]]]) : !hal.pipeline_layout -// CHECK-NEXT: util.global.store %[[LAYOUT]], @_pipeline_layout_0 : !hal.pipeline_layout +// CHECK-DAG: %[[DEVICE:.+]] = util.global.load @device +// CHECK-DAG: %[[SET_LAYOUT_0:.+]] = hal.descriptor_set_layout.create +// CHECK-SAME: device(%[[DEVICE]] : !hal.device) +// CHECK-SAME: flags("None") +// CHECK-SAME: bindings([ +// CHECK-SAME: #hal.descriptor_set.binding<0, storage_buffer>, +// CHECK-SAME: #hal.descriptor_set.binding<1, storage_buffer> +// CHECK-SAME: ]) : !hal.descriptor_set_layout +// CHECK-NEXT: %[[PIPELINE_LAYOUT:.+]] = hal.pipeline_layout.create +// CHECK-SAME: device(%[[DEVICE]] : !hal.device) +// CHECK-SAME: push_constants(1) +// CHECK-SAME: layouts([%[[SET_LAYOUT_0]]]) : !hal.pipeline_layout +// CHECK-NEXT: util.global.store %[[PIPELINE_LAYOUT]], @__device_pipeline_layout_0 : !hal.pipeline_layout // CHECK-LABEL: @exeLayoutLookup -util.func public @exeLayoutLookup(%device : !hal.device) -> !hal.pipeline_layout { - // CHECK: %[[LAYOUT:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout +util.func public @exeLayoutLookup() -> !hal.pipeline_layout { + %device = 
util.global.load @device : !hal.device + // CHECK: %[[LOADED_LAYOUT:.+]] = util.global.load @__device_pipeline_layout_0 : !hal.pipeline_layout %0 = hal.pipeline_layout.lookup device(%device : !hal.device) layout(#hal.pipeline.layout<push_constants = 1, sets = [ #hal.descriptor_set.layout<0, bindings = [ #hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer> ]> ]>) : !hal.pipeline_layout - // CHECK-NEXT: util.return %[[LAYOUT]] + // CHECK-NEXT: util.return %[[LOADED_LAYOUT]] util.return %0 : !hal.pipeline_layout } @@ -41,28 +48,25 @@ ]> ]> -// TODO(scotttodd): Test without depending on a specific HAL target? Or move to HAL/Target/*/test/? -// - If there is no matching hal.executable.variant then the executable will not be cached -hal.executable @exe { +// CHECK: hal.executable private @exe +hal.executable private @exe { + // CHECK: hal.executable.variant public @vmvx hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + // CHECK-NOT: hal.executable.condition hal.executable.condition(%device: !hal.device) -> i1 { %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 hal.return %selected : i1 } - hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) attributes { - workgroup_size = [32 : index, 1 : index, 1 : index] - } - hal.executable.export @entry0_alias ordinal(0) layout(#pipeline_layout_0) attributes { - workgroup_size = [32 : index, 1 : index, 1 : index] - } - hal.executable.export @entry1 ordinal(1) layout(#pipeline_layout_1) attributes { - workgroup_size = [32 : index, 1 : index, 1 : index] - } + hal.executable.export @entry0 ordinal(0) layout(#pipeline_layout_0) + hal.executable.export @entry0_alias ordinal(0) layout(#pipeline_layout_0) + hal.executable.export @entry1 ordinal(1) layout(#pipeline_layout_1) + // CHECK-NOT: hal.executable.constant.block hal.executable.constant.block() -> (i32, i32) as ("foo", "bar") { %c123 = arith.constant 123 : i32 %c456 = 
arith.constant 456 : i32 hal.return %c123, %c456 : i32, i32 } + // CHECK-NOT: hal.executable.constant.block hal.executable.constant.block(%device: !hal.device) -> i32 as "baz" { %ok, %query = hal.device.query<%device : !hal.device> key("sys" :: "baz") : i1, i32 cf.cond_br %ok, ^bb_ok, ^bb_fail @@ -75,16 +79,27 @@ } } -// CHECK-DAG: util.global private @_descriptor_set_layout_0 -// CHECK-DAG: util.global private @_pipeline_layout_0 -// CHECK-DAG: util.global private @_descriptor_set_layout_1 -// CHECK-DAG: util.global private @_pipeline_layout_1 +// CHECK: util.global private @device = #hal.device.ordinal<0> +util.global private @device = #hal.device.ordinal<0> : !hal.device -// CHECK: util.global private @_executable_exe : !hal.executable -// CHECK-NEXT: util.initializer { +// Cached resources for the device. +// CHECK: util.global private @__device_pipeline_layout_0 : !hal.pipeline_layout +// CHECK: util.global private @__device_pipeline_layout_1 : !hal.pipeline_layout +// CHECK: util.global private @__device_executable_0_exe : !hal.executable + +// Device initializer for all resources used with the device: +// CHECK: util.initializer +// CHECK: %[[DEVICE:.+]] = util.global.load @device + +// Create pipeline layouts (and required descriptor set layouts): +// CHECK: %[[SET_LAYOUT_0:.+]] = hal.descriptor_set_layout.create device(%[[DEVICE]] : !hal.device) +// CHECK: %[[SET_LAYOUT_1:.+]] = hal.descriptor_set_layout.create device(%[[DEVICE]] : !hal.device) +// CHECK: %[[PIPELINE_LAYOUT_0:.+]] = hal.pipeline_layout.create device(%[[DEVICE]] : !hal.device) push_constants(0) layouts([%[[SET_LAYOUT_0]]]) : !hal.pipeline_layout +// CHECK: util.global.store %[[PIPELINE_LAYOUT_0]], @__device_pipeline_layout_0 +// CHECK: %[[PIPELINE_LAYOUT_1:.+]] = hal.pipeline_layout.create device(%[[DEVICE]] : !hal.device) push_constants(0) layouts([%[[SET_LAYOUT_1]]]) : !hal.pipeline_layout +// CHECK: util.global.store %[[PIPELINE_LAYOUT_1]], @__device_pipeline_layout_1 // Switch on the
supported formats: -// CHECK: %[[DEVICE:.+]] = hal.devices.get %{{.+}} // CHECK: %{{.+}}, %[[FORMAT_VMVX:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("hal.executable.format" :: "vmvx-bytecode-fb") // CHECK: %[[VMVX_CONDITION:.+]] = scf.execute_region -> i1 { // CHECK: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[DEVICE]] : !hal.device> key("some" :: "feature") @@ -97,20 +112,15 @@ // CHECK: %[[RET:.+]] = scf.index_switch %[[VARIANT_INDEX]] -> !hal.executable // CHECK: case 0 { -// Dependent layouts: -// CHECK: %[[LAYOUT0:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout -// CHECK: %[[LAYOUT0_2:.+]] = util.global.load @_pipeline_layout_0 : !hal.pipeline_layout -// CHECK: %[[LAYOUT1:.+]] = util.global.load @_pipeline_layout_1 : !hal.pipeline_layout - // Constant block initializers: -// CHECK: %[[CONST_01:.+]]:2 = util.call @__constant_block_0() -// CHECK: %[[CONST_2:.+]] = util.call @__constant_block_1(%[[DEVICE]]) +// CHECK: %[[CONST_01:.+]]:2 = util.call @__device_executable_0_exe_constant_block_0() +// CHECK: %[[CONST_2:.+]] = util.call @__device_executable_0_exe_constant_block_1(%[[DEVICE]]) // Executable creation: // CHECK: %[[EXE:.+]] = hal.executable.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: target(@exe::@vmvx) -// CHECK-SAME: layouts([%[[LAYOUT0]], %[[LAYOUT0_2]], %[[LAYOUT1]]]) +// CHECK-SAME: layouts([%[[PIPELINE_LAYOUT_0]], %[[PIPELINE_LAYOUT_0]], %[[PIPELINE_LAYOUT_1]]]) // CHECK-SAME: constants([%[[CONST_01]]#0, %[[CONST_01]]#1, %[[CONST_2]]]) // CHECK-SAME: : !hal.executable @@ -118,18 +128,18 @@ // CHECK: } // CHECK: default { // CHECK: %[[C14:.+]] = arith.constant 14 : i32 -// CHECK: util.status.check_ok %[[C14]], "none of the executable binaries in the module are supported by the runtime" +// CHECK: util.status.check_ok %[[C14]], "HAL device `device` does not support any variant of executable `exe`; available formats: [vmvx-bytecode-fb]" // CHECK: %[[NULL:.+]] = util.null : !hal.executable // 
CHECK: scf.yield %[[NULL]] : !hal.executable // CHECK: } -// CHECK: util.global.store %[[RET]], @_executable_exe : !hal.executable +// CHECK: util.global.store %[[RET]], @__device_executable_0_exe : !hal.executable -// Inlined constant block functions (here we ensure all blocks are cloned): -// CHECK: util.func private @__constant_block_0() -> (i32, i32) +// Constant block functions (here we ensure all blocks are cloned): +// CHECK: util.func private @__device_executable_0_exe_constant_block_0() -> (i32, i32) // CHECK-DAG: %[[C0:.+]] = arith.constant 123 // CHECK-DAG: %[[C1:.+]] = arith.constant 456 // CHECK: util.return %[[C0]], %[[C1]] -// CHECK: util.func private @__constant_block_1(%[[BLOCK_DEVICE:.+]]: !hal.device) -> i32 +// CHECK: util.func private @__device_executable_0_exe_constant_block_1(%[[BLOCK_DEVICE:.+]]: !hal.device) -> i32 // CHECK: %[[OK:.+]], %[[VALUE:.+]] = hal.device.query<%[[BLOCK_DEVICE]] : !hal.device> key("sys" :: "baz") // CHECK: cf.cond_br %[[OK]], ^bb1, ^bb2 // CHECK: ^bb1: @@ -139,16 +149,172 @@ // CHECK: util.return %[[DUMMY]] // CHECK-LABEL: @exeLookup -util.func public @exeLookup(%device : !hal.device) -> !hal.executable { - // CHECK: %[[EXE:.+]] = util.global.load @_executable_exe : !hal.executable +util.func public @exeLookup() -> !hal.executable { + %device = util.global.load @device : !hal.device + // CHECK: %[[EXE:.+]] = util.global.load @__device_executable_0_exe : !hal.executable %0 = hal.executable.lookup device(%device : !hal.device) - executable(@exe) : !hal.executable + executable(@exe) : !hal.executable // CHECK-NEXT: util.return %[[EXE]] util.return %0 : !hal.executable } // ----- +// Tests that fallback resources are reused instead of being created again +// when a device selects a fallback. 
+ +// CHECK: hal.executable private @exe +hal.executable private @exe { + // CHECK: hal.executable.variant public @vmvx + hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + // CHECK-NOT: hal.executable.condition + hal.executable.condition(%device: !hal.device) -> i1 { + %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 + hal.return %selected : i1 + } + hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [ + #hal.descriptor_set.layout<0, bindings = [ + #hal.descriptor_set.binding<0, storage_buffer>, + #hal.descriptor_set.binding<1, storage_buffer> + ]> + ]>) + // CHECK-NOT: hal.executable.constant.block + hal.executable.constant.block() -> (i32, i32) as ("foo", "bar") { + %c123 = arith.constant 123 : i32 + %c456 = arith.constant 456 : i32 + hal.return %c123, %c456 : i32, i32 + } + } +} + +// CHECK: util.global private @primary_device +util.global private @primary_device = #hal.device.ordinal<0> : !hal.device +// CHECK-NEXT: util.global private @__primary_device_pipeline_layout_0 +// CHECK-NEXT: util.global private @__primary_device_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK: util.global.load @primary_device +// CHECK: hal.descriptor_set_layout.create +// CHECK: hal.pipeline_layout.create +// CHECK: util.global.store {{.+}}, @__primary_device_pipeline_layout_0 +// CHECK: hal.executable.create +// CHECK: util.global.store {{.+}}, @__primary_device_executable_0_exe +// CHECK: util.func private @__primary_device_executable_0_exe_constant_block_0 + +// CHECK: util.global private @optional_device +util.global private @optional_device = #hal.device.select<[ + #hal.device.ordinal<1> : !hal.device, + #hal.device.fallback<@primary_device> : !hal.device +]> : !hal.device +// CHECK-NEXT: util.global private @__optional_device_pipeline_layout_0 +// CHECK-NEXT: util.global private @__optional_device_executable_0_exe +// CHECK-NEXT: util.initializer +// 
CHECK-DAG: %[[OPTIONAL_DEVICE:.+]] = util.global.load @optional_device +// CHECK-DAG: %[[PRIMARY_DEVICE:.+]] = util.global.load @primary_device +// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE]], %[[PRIMARY_DEVICE]] +// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]] +// CHECK-DAG: scf.index_switch %[[INDEX]] +// CHECK: case 0 +// CHECK: %[[PRIMARY_LAYOUT:.+]] = util.global.load @__primary_device_pipeline_layout_0 +// CHECK: util.global.store %[[PRIMARY_LAYOUT]], @__optional_device_pipeline_layout_0 +// CHECK: %[[PRIMARY_EXE:.+]] = util.global.load @__primary_device_executable_0_exe +// CHECK: util.global.store %[[PRIMARY_EXE]], @__optional_device_executable_0_exe +// CHECK: default +// CHECK: hal.descriptor_set_layout.create +// CHECK: hal.pipeline_layout.create +// CHECK: util.global.store {{.+}}, @__optional_device_pipeline_layout_0 +// CHECK: hal.executable.create +// CHECK: util.global.store {{.+}}, @__optional_device_executable_0_exe +// CHECK: util.func private @__optional_device_executable_0_exe_constant_block_0 + +// CHECK-LABEL: @fallbackLookup +util.func public @fallbackLookup() -> (!hal.executable, !hal.executable) { + %primary_device = util.global.load @primary_device : !hal.device + // CHECK: %[[PRIMARY_EXE_LOOKUP:.+]] = util.global.load @__primary_device_executable_0_exe + %0 = hal.executable.lookup device(%primary_device : !hal.device) + executable(@exe) : !hal.executable + %optional_device = util.global.load @optional_device : !hal.device + // CHECK: %[[OPTIONAL_EXE_LOOKUP:.+]] = util.global.load @__optional_device_executable_0_exe + %1 = hal.executable.lookup device(%optional_device : !hal.device) + executable(@exe) : !hal.executable + util.return %0, %1 : !hal.executable, !hal.executable +} + +// ----- + +// Tests that resources only used by optional devices force the resources to +// be created on fallbacks. 
This isn't optimal as we should really only be +// creating them if the fallback is selected but that's more complex than it's +// worth today given the limited usage of fallbacks. + +hal.executable private @exe { + hal.executable.variant @vmvx target(<"vmvx", "vmvx-bytecode-fb">) { + hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [ + #hal.descriptor_set.layout<0, bindings = [ + #hal.descriptor_set.binding<0, storage_buffer> + ]> + ]>) + } +} + +// CHECK-LABEL: util.global private @primary_device +util.global private @primary_device = #hal.device.ordinal<0> : !hal.device +// CHECK-NEXT: util.global private @__primary_device_pipeline_layout_0 +// CHECK-NEXT: util.global private @__primary_device_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK: util.global.load @primary_device +// CHECK: hal.descriptor_set_layout.create +// CHECK: hal.pipeline_layout.create +// CHECK: util.global.store {{.+}}, @__primary_device_pipeline_layout_0 +// CHECK: hal.executable.create +// CHECK: util.global.store {{.+}}, @__primary_device_executable_0_exe + +// CHECK-LABEL: util.global private @optional_device_0 +util.global private @optional_device_0 = #hal.device.select<[ + #hal.device.ordinal<1> : !hal.device, + #hal.device.fallback<@primary_device> : !hal.device +]> : !hal.device +// CHECK-NEXT: util.global private @__optional_device_0_pipeline_layout_0 +// CHECK-NEXT: util.global private @__optional_device_0_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[OPTIONAL_DEVICE_0:.+]] = util.global.load @optional_device_0 +// CHECK-DAG: %[[PRIMARY_DEVICE:.+]] = util.global.load @primary_device +// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE_0]], %[[PRIMARY_DEVICE]] +// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]] +// CHECK-DAG: scf.index_switch %[[INDEX]] +// CHECK: util.global.load @__primary_device_pipeline_layout_0 +// CHECK: util.global.store {{.+}}, 
@__optional_device_0_pipeline_layout_0 +// CHECK: util.global.load @__primary_device_executable_0_exe +// CHECK: util.global.store {{.+}}, @__optional_device_0_executable_0_exe + +// CHECK-LABEL: util.global private @optional_device_1 +util.global private @optional_device_1 = #hal.device.select<[ + #hal.device.ordinal<2> : !hal.device, + #hal.device.fallback<@optional_device_0> : !hal.device +]> : !hal.device +// CHECK-NEXT: util.global private @__optional_device_1_pipeline_layout_0 +// CHECK-NEXT: util.global private @__optional_device_1_executable_0_exe +// CHECK-NEXT: util.initializer +// CHECK-DAG: %[[OPTIONAL_DEVICE_1:.+]] = util.global.load @optional_device_1 +// CHECK-DAG: %[[OPTIONAL_DEVICE_0:.+]] = util.global.load @optional_device_0 +// CHECK-DAG: %[[DEVICE_EQ:.+]] = util.cmp.eq %[[OPTIONAL_DEVICE_1]], %[[OPTIONAL_DEVICE_0]] +// CHECK-DAG: %[[INDEX:.+]] = arith.select %[[DEVICE_EQ]] +// CHECK-DAG: scf.index_switch %[[INDEX]] +// CHECK: util.global.load @__optional_device_0_pipeline_layout_0 +// CHECK: util.global.store {{.+}}, @__optional_device_1_pipeline_layout_0 +// CHECK: util.global.load @__optional_device_0_executable_0_exe +// CHECK: util.global.store {{.+}}, @__optional_device_1_executable_0_exe + +// CHECK-LABEL: @fallbackOnlyLookup +util.func public @fallbackOnlyLookup() -> !hal.executable { + %optional_device_1 = util.global.load @optional_device_1 : !hal.device + // CHECK: util.global.load @__optional_device_1_executable_0_exe + %0 = hal.executable.lookup device(%optional_device_1 : !hal.device) + executable(@exe) : !hal.executable + util.return %0 : !hal.executable +} + +// ----- + // Tests that materialization no-ops when resource caches have already been // materialized. 
Today this is rather simplistic and just bails if the names // match with the expectation being that users are mostly just running through @@ -163,6 +329,8 @@ ]> ]> +util.global private @device : !hal.device + util.global private @_descriptor_set_layout_0 : !hal.descriptor_set_layout util.initializer { %c0 = arith.constant 0 : index
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp index 65ad8c9..871087c 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -97,8 +97,18 @@ // No uses - erase the global entirely. deadOps.push_back(globalInfo->op); } else { - // If there are stores mark the global as mutable. - globalInfo->op.setGlobalMutable(!globalInfo->getStores().empty()); + // TODO(benvanik): verify we want this behavior - we likely want to change + // this to be mutable only if stores exist outside of initializers. + // + // If there are stores mark the global as mutable. We need to update all + // of the loads if this changes anything. + bool hasStores = !globalInfo->getStores().empty(); + bool didChange = globalInfo->op.isGlobalMutable() != hasStores; + globalInfo->op.setGlobalMutable(hasStores); + if (didChange) { + for (auto loadOp : globalInfo->getLoads()) + loadOp.setGlobalImmutable(!hasStores); + } } for (auto loadOp : globalInfo->getLoads()) loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable());