Adding simplified HAL dispatch methods. (#18189)
These combine push constants and push descriptor sets into the dispatch
calls, as in practice we have a near 1:1 relationship anyway. Pipeline
layouts are still used in HAL interfaces to allow the compiler to map
the information, but are otherwise not used by the new ops.
The `--iree-hal-experimental-dispatch2` flag enables emitting the new
ops. Since executables no longer require pipeline layouts in this
simplified model, the `--iree-hal-experimental-executable-create2` flag
can be used to stop passing them; targets that support dispatch2 will
ignore them if provided. Future changes will start to add support on
targets for the simplified bindings and then remove the existing
pipeline layout-based binding model as a breaking ABI change.
Current target status:
* [x] Local/CPU: executable-create2 and executable-dispatch2 supported
(backward compat)
* [x] CUDA: executable-dispatch2 supported (backward compat)
* [x] HIP: executable-dispatch2 supported (backward compat)
* [x] Metal: executable-dispatch2 supported (backward compat)
* [x] Vulkan: executable-dispatch2 supported (backward compat)
* [x] WebGPU: executable-dispatch2 supported (backward compat)
Reworking the CUDA/HIP/Metal/Vulkan/WebGPU flatbuffers to support
executable-create2 will be done in a follow-up.
Progress on #18154.
diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
index 912b7a8..3a6337e 100644
--- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
+++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
@@ -387,12 +387,21 @@
llvmFunc->addParamAttr(i, align16);
}
- // Optionally entry points may specify that they require workgroup local
+ LibraryBuilder::DispatchAttrs dispatchAttrs = {0};
+
+ // Entry points may optionally specify that they require workgroup local
// memory. We fetch that value here and plumb it through so the runtime
// knows how much memory to reserve and pass in.
- int64_t localMemorySize = exportOp.getWorkgroupLocalMemory()
- .value_or(APInt(64, 0))
- .getSExtValue();
+ dispatchAttrs.localMemorySize = exportOp.getWorkgroupLocalMemory()
+ .value_or(APInt(64, 0))
+ .getSExtValue();
+
+ // Specify the constant and binding information used to validate
+ // dispatches.
+ // TODO(#18189): pack per-binding information bitfields.
+ dispatchAttrs.constantCount = exportOp.getLayout().getPushConstants();
+ dispatchAttrs.bindingCount =
+ exportOp.getLayout().getSetLayout(0).getBindings().size();
LibraryBuilder::SourceLocation sourceLocation;
if (options.debugLevel >= 1) {
@@ -417,8 +426,7 @@
}
libraryBuilder.addExport(exportOp.getName(), std::move(sourceLocation),
std::move(stageLocations), /*tag=*/"",
- LibraryBuilder::DispatchAttrs{localMemorySize},
- llvmFunc);
+ dispatchAttrs, llvmFunc);
}
// Embed source files (if present).
diff --git a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp
index 3c39849..21621b9 100644
--- a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp
+++ b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp
@@ -111,19 +111,22 @@
// %struct.iree_hal_executable_dispatch_attrs_v0_t = type {
// i16,
-// i16
+// i8,
+// i8
// }
static llvm::StructType *makeDispatchAttrsType(llvm::LLVMContext &context) {
if (auto *existingType = llvm::StructType::getTypeByName(
context, "iree_hal_executable_dispatch_attrs_v0_t")) {
return existingType;
}
+ auto *i8Type = llvm::IntegerType::getInt8Ty(context);
auto *i16Type = llvm::IntegerType::getInt16Ty(context);
auto *type =
llvm::StructType::create(context,
{
i16Type,
- i16Type,
+ i8Type,
+ i8Type,
},
"iree_hal_executable_dispatch_attrs_v0_t",
/*isPacked=*/false);
@@ -502,7 +505,7 @@
bool hasNonDefaultAttrs = llvm::any_of(exports, [](const auto &dispatch) {
return !dispatch.attrs.isDefault();
});
- if (!hasNonDefaultAttrs) {
+ if (hasNonDefaultAttrs) {
SmallVector<llvm::Constant *> exportAttrValues;
for (auto dispatch : exports) {
exportAttrValues.push_back(llvm::ConstantStruct::get(
@@ -513,8 +516,10 @@
i16Type, roundUpToAlignment(dispatch.attrs.localMemorySize,
kWorkgroupLocalMemoryPageSize) /
kWorkgroupLocalMemoryPageSize),
- // reserved=
- llvm::ConstantInt::get(i16Type, 0),
+ // constant_count=
+ llvm::ConstantInt::get(i8Type, dispatch.attrs.constantCount),
+ // binding_count=
+ llvm::ConstantInt::get(i8Type, dispatch.attrs.bindingCount),
}));
}
exportAttrs = createArrayConstant(libraryName + "_attrs", dispatchAttrsType,
diff --git a/compiler/plugins/target/LLVMCPU/LibraryBuilder.h b/compiler/plugins/target/LLVMCPU/LibraryBuilder.h
index fd3416b..6b1ee87 100644
--- a/compiler/plugins/target/LLVMCPU/LibraryBuilder.h
+++ b/compiler/plugins/target/LLVMCPU/LibraryBuilder.h
@@ -74,16 +74,22 @@
UNDEFINED = 4u,
};
- // IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
+ // IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
static const int64_t kWorkgroupLocalMemoryPageSize = 4096;
// iree_hal_executable_dispatch_attrs_v0_t
struct DispatchAttrs {
// Required workgroup local memory size, in bytes.
int64_t localMemorySize = 0;
+ // Total number of 32-bit constants used by the dispatch. NOTE(review): 8-bit field, values above 255 do not fit; confirm callers validate before assigning.
+ uint8_t constantCount = 0;
+ // Total number of bindings used by the dispatch. NOTE(review): 8-bit field, values above 255 do not fit; confirm callers validate before assigning.
+ uint8_t bindingCount = 0;
// True if all values are default and the attributes may be omitted.
- constexpr bool isDefault() const { return localMemorySize == 0; }
+ constexpr bool isDefault() const {
+ return localMemorySize == 0 && constantCount == 0 && bindingCount == 0;
+ }
};
// iree_hal_executable_source_location_v0_t
diff --git a/compiler/plugins/target/VMVX/VMVXTarget.cpp b/compiler/plugins/target/VMVX/VMVXTarget.cpp
index b87844d..831eb8c 100644
--- a/compiler/plugins/target/VMVX/VMVXTarget.cpp
+++ b/compiler/plugins/target/VMVX/VMVXTarget.cpp
@@ -116,7 +116,9 @@
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
// Add reflection information used at runtime specific to the HAL interface.
- SymbolTable symbolTable(variantOp.getInnerModule());
+ auto vmModule =
+ *variantOp.getInnerModule().getOps<IREE::VM::ModuleOp>().begin();
+ SymbolTable symbolTable(vmModule);
for (auto exportOp : variantOp.getBlock().getOps<ExecutableExportOp>()) {
auto funcOp = symbolTable.lookup<IREE::VM::FuncOp>(exportOp.getName());
@@ -127,6 +129,24 @@
if (localMemorySizeAttr) {
funcOp.setReflectionAttr("local_memory", localMemorySizeAttr);
}
+
+ // Specify the constant and binding information used to validate
+ // dispatches.
+ // TODO(#18189): pack per-binding information bitfields.
+ if (auto layoutAttr = exportOp.getLayout()) {
+ int64_t constantCount = layoutAttr.getPushConstants();
+ if (constantCount > 0) {
+ funcOp.setReflectionAttr("constant_count",
+ executableBuilder.getI8IntegerAttr(
+ static_cast<uint8_t>(constantCount)));
+ }
+ size_t bindingCount = layoutAttr.getSetLayout(0).getBindings().size();
+ if (bindingCount > 0) {
+ funcOp.setReflectionAttr("binding_count",
+ executableBuilder.getI8IntegerAttr(
+ static_cast<uint8_t>(bindingCount)));
+ }
+ }
}
// Serialize the VM module to bytes and embed it directly.
diff --git a/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp b/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp
index 495a1a5..9b418c3 100644
--- a/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp
+++ b/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp
@@ -98,7 +98,7 @@
SmallVector<IREE::HAL::DescriptorSetBindingAttr> bindingAttrs;
bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get(
originalAttr.getContext(), 0, IREE::HAL::DescriptorType::UniformBuffer,
- std::nullopt));
+ IREE::HAL::DescriptorFlags::None));
setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get(
originalAttr.getContext(), 3, bindingAttrs, std::nullopt));
return IREE::HAL::PipelineLayoutAttr::get(originalAttr.getContext(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
index caf47bd..39b5ec8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
@@ -62,8 +62,7 @@
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = bindingAttr.getOrdinal();
setBinding.type = bindingAttr.getType();
- setBinding.flags =
- bindingAttr.getFlags().value_or(IREE::HAL::DescriptorFlags::None);
+ setBinding.flags = bindingAttr.getFlags();
setLayout.bindings[setBinding.ordinal] = setBinding;
pipelineLayout.resourceMap.emplace_back(setLayout.ordinal,
setBinding.ordinal);
@@ -123,7 +122,6 @@
// Check the usage of each binding at each dispatch site.
struct DescriptorInfo {
- bool isIndirect = false;
DescriptorFlags flags = DescriptorFlags::None;
};
SmallVector<DescriptorInfo> descriptorInfos(bindingCount);
@@ -142,12 +140,18 @@
// Opt into indirect descriptors when dynamic values are used from
// execution regions that may be executed more than once.
if (!isRegionExecutedOnce) {
- auto resource = dispatchOp.getResources()[i];
+ Value resource = dispatchOp.getResources()[i];
+ if (auto blockArg = dyn_cast<BlockArgument>(resource)) {
+ if (blockArg.getOwner()->getParentOp() == parentOp) {
+ resource = parentOp.getResourceOperands()[blockArg.getArgNumber()];
+ }
+ }
switch (categorizeValue(resource)) {
default:
case ValueOrigin::Unknown:
case ValueOrigin::MutableGlobal:
- descriptorInfo.isIndirect |= true;
+ descriptorInfo.flags =
+ descriptorInfo.flags | IREE::HAL::DescriptorFlags::Indirect;
break;
case ValueOrigin::LocalConstant:
case ValueOrigin::ImmutableGlobal:
@@ -173,74 +177,27 @@
pipelineLayout.pushConstantCount = operandCount;
pipelineLayout.resourceMap.resize(bindingCount);
- // Today we use one or two sets based on the composition of bindings we have:
- // we try to put everything in a directly referenced set 0 and spill over any
- // indirectly referenced values into the second set.
- //
- // HACK: the Vulkan HAL implementation currently cannot handle multiple
- // descriptor sets. Ouch. To preserve existing behavior we only use a single
- // set and mark the whole thing as indirect if any bindings are indirect.
- const bool forceSingleSet = true;
- if (forceSingleSet) {
- DescriptorSetLayout setLayout;
- setLayout.ordinal = 0;
- setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
- setLayout.bindings.reserve(bindingCount);
- for (unsigned i = 0; i < bindingCount; ++i) {
- const auto &descriptorInfo = descriptorInfos[i];
- if (descriptorInfo.isIndirect) {
- setLayout.flags =
- setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
- }
- DescriptorSetLayoutBinding setBinding;
- setBinding.ordinal = setLayout.bindings.size();
- setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
- setBinding.flags = descriptorInfo.flags;
- setLayout.bindings.push_back(setBinding);
- pipelineLayout.resourceMap[i] =
- std::make_pair(setLayout.ordinal, setBinding.ordinal);
+ // TODO(#18154): simplify binding setup.
+ DescriptorSetLayout setLayout;
+ setLayout.ordinal = 0;
+ setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
+ setLayout.bindings.reserve(bindingCount);
+ for (unsigned i = 0; i < bindingCount; ++i) {
+ const auto &descriptorInfo = descriptorInfos[i];
+ if (allEnumBitsSet(descriptorInfo.flags,
+ IREE::HAL::DescriptorFlags::Indirect)) {
+ setLayout.flags =
+ setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
- pipelineLayout.setLayouts.push_back(setLayout);
- } else {
- DescriptorSetLayout directSetLayout;
- directSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
- directSetLayout.bindings.reserve(bindingCount);
- DescriptorSetLayout indirectSetLayout;
- indirectSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::Indirect;
- indirectSetLayout.bindings.reserve(bindingCount);
-
- // Ordinals relative to the owning set.
- SmallVector<unsigned> bindingSetOrdinals(bindingCount);
- for (unsigned i = 0; i < bindingCount; ++i) {
- const auto &descriptorInfo = descriptorInfos[i];
- auto &setLayout =
- descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
- DescriptorSetLayoutBinding setBinding;
- setBinding.ordinal = setLayout.bindings.size();
- setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
- setBinding.flags = descriptorInfo.flags;
- setLayout.bindings.push_back(setBinding);
- bindingSetOrdinals[i] = setBinding.ordinal;
- }
- unsigned nextSetOrdinal = 0;
- if (!directSetLayout.bindings.empty()) {
- directSetLayout.ordinal = nextSetOrdinal++;
- pipelineLayout.setLayouts.push_back(directSetLayout);
- }
- if (!indirectSetLayout.bindings.empty()) {
- indirectSetLayout.ordinal = nextSetOrdinal++;
- pipelineLayout.setLayouts.push_back(indirectSetLayout);
- }
-
- // Map each resource to its set/binding ordinals.
- for (unsigned i = 0; i < bindingCount; ++i) {
- const auto &descriptorInfo = descriptorInfos[i];
- auto &setLayout =
- descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
- pipelineLayout.resourceMap[i] =
- std::make_pair(setLayout.ordinal, bindingSetOrdinals[i]);
- }
+ DescriptorSetLayoutBinding setBinding;
+ setBinding.ordinal = setLayout.bindings.size();
+ setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
+ setBinding.flags = descriptorInfo.flags;
+ setLayout.bindings.push_back(setBinding);
+ pipelineLayout.resourceMap[i] =
+ std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
+ pipelineLayout.setLayouts.push_back(setLayout);
LLVM_DEBUG({
auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp
index 07bb527..16f1aee 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp
@@ -12,16 +12,6 @@
namespace mlir::iree_compiler::IREE::HAL {
ValueOrigin categorizeValue(Value value) {
- // If this is a captured argument of an execution region then look up to the
- // SSA value that was captured.
- if (auto blockArg = dyn_cast<BlockArgument>(value)) {
- if (auto closureOp = dyn_cast<IREE::Util::ClosureOpInterface>(
- blockArg.getOwner()->getParentOp())) {
- return categorizeValue(
- closureOp.getClosureOperands()[blockArg.getArgNumber()]);
- }
- }
-
// If we wanted to pull in entire IR slices this would have to use a
// worklist (selects of globals based on globals, etc). For now this analysis
// only looks at the value provided.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index cb4179f..72716cb 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -425,6 +425,162 @@
mutable IREE::VM::ImportOp importOp;
};
+class CommandBufferDispatch2OpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferDispatch2Op> {
+public:
+ CommandBufferDispatch2OpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName); // resolve the target import once, at construction
+ assert(importOp); // only checked in builds with assertions enabled
+ }
+ // Replaces the dispatch2 op with a vm.call.variadic to the import resolved in the constructor.
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferDispatch2Op op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ // Integer types passed to castToImportType for the scalar operands below.
+ auto i32Type = rewriter.getI32Type();
+ auto i64Type = rewriter.getI64Type();
+ Value zeroI32 = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
+ // Flags are passed as an i64 constant: the attribute's value when present, zero otherwise.
+ auto flags = adaptor.getFlagsAttr()
+ ? rewriter
+ .create<IREE::VM::ConstI64Op>(
+ op.getLoc(), adaptor.getFlagsAttr().getInt())
+ .getResult()
+ : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
+ .getResult();
+ SmallVector<Value, 8> callOperands = { // one value per leading operand; constants and bindings are appended below
+ adaptor.getCommandBuffer(),
+ adaptor.getExecutable(),
+ castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
+ castToImportType(adaptor.getWorkgroupX(), i32Type, rewriter),
+ castToImportType(adaptor.getWorkgroupY(), i32Type, rewriter),
+ castToImportType(adaptor.getWorkgroupZ(), i32Type, rewriter),
+ flags,
+ };
+ SmallVector<int16_t, 5> segmentSizes = { // NOTE(review): -1 presumably marks a single (non-variadic) operand; the last two entries are element counts
+ /*command_buffer=*/-1,
+ /*executable=*/-1,
+ /*entry_point=*/-1,
+ /*workgroup_x=*/-1,
+ /*workgroup_y=*/-1,
+ /*workgroup_z=*/-1,
+ /*flags=*/-1,
+ /*constants=*/static_cast<int16_t>(adaptor.getConstants().size()),
+ /*bindings=*/
+ static_cast<int16_t>(adaptor.getBindingBuffers().size()),
+ };
+ llvm::append_range(callOperands, adaptor.getConstants()); // constants segment
+ for (auto [bindingBufferOrSlot, bindingOffset, bindingLength] :
+ llvm::zip_equal(adaptor.getBindingBuffers(),
+ adaptor.getBindingOffsets(),
+ adaptor.getBindingLengths())) {
+ callOperands.push_back(zeroI32); // first i32 of every binding group is always zero in this lowering
+ auto [bindingBufferSlot, bindingBuffer] =
+ splitBufferSlot(op.getLoc(), bindingBufferOrSlot, rewriter); // yields a (slot, buffer) pair for the binding operand
+ callOperands.push_back(bindingBufferSlot);
+ callOperands.push_back(bindingBuffer);
+ callOperands.push_back(
+ castToImportType(bindingOffset, i64Type, rewriter));
+ callOperands.push_back(
+ castToImportType(bindingLength, i64Type, rewriter));
+ }
+ // Replace the op with the variadic call, then run copyImportAttrs on the new call.
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
+ importType.getInputs(), callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+ // Import symbol looked up in the constructor and used by every rewrite.
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+// Conversion pattern that lowers CommandBufferDispatch2IndirectOp to a vm.call.variadic of the named import.
+class CommandBufferDispatch2IndirectOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferDispatch2IndirectOp> {
+public:
+ CommandBufferDispatch2IndirectOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName); // resolve the target import once, at construction
+ assert(importOp); // only checked in builds with assertions enabled
+ }
+ // Replaces the indirect dispatch2 op with a vm.call.variadic to the import resolved in the constructor.
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferDispatch2IndirectOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Function type of the import; supplies the result and input types of the emitted call.
+ auto importType = importOp.getFunctionType();
+ // Integer types passed to castToImportType for the scalar operands below.
+ auto i32Type = rewriter.getI32Type();
+ auto i64Type = rewriter.getI64Type();
+ Value zeroI32 = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
+ // The workgroups buffer operand is split into a (slot, buffer) pair by splitBufferSlot.
+ auto [workgroupsBufferSlot, workgroupsBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
+ auto flags = adaptor.getFlagsAttr()
+ ? rewriter
+ .create<IREE::VM::ConstI64Op>(
+ op.getLoc(), adaptor.getFlagsAttr().getInt())
+ .getResult()
+ : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
+ .getResult(); // i64 flags: the attribute's value when present, zero otherwise
+ SmallVector<Value, 8> callOperands = { // one value per leading operand; constants and bindings are appended below
+ adaptor.getCommandBuffer(),
+ adaptor.getExecutable(),
+ castToImportType(adaptor.getEntryPoint(), i32Type, rewriter),
+ workgroupsBufferSlot,
+ workgroupsBuffer,
+ castToImportType(adaptor.getWorkgroupsOffset(), i64Type, rewriter),
+ flags,
+ };
+ SmallVector<int16_t, 5> segmentSizes = { // NOTE(review): -1 presumably marks a single (non-variadic) operand; the last two entries are element counts
+ /*command_buffer=*/-1,
+ /*executable=*/-1,
+ /*entry_point=*/-1,
+ /*workgroups_buffer_slot=*/-1,
+ /*workgroups_buffer=*/-1,
+ /*workgroups_offset=*/-1,
+ /*flags=*/-1,
+ /*constants=*/static_cast<int16_t>(adaptor.getConstants().size()),
+ /*bindings=*/
+ static_cast<int16_t>(adaptor.getBindingBuffers().size()),
+ };
+ llvm::append_range(callOperands, adaptor.getConstants()); // constants segment
+ for (auto [bindingBufferOrSlot, bindingOffset, bindingLength] :
+ llvm::zip_equal(adaptor.getBindingBuffers(),
+ adaptor.getBindingOffsets(),
+ adaptor.getBindingLengths())) {
+ callOperands.push_back(zeroI32); // first i32 of every binding group is always zero in this lowering
+ auto [bindingBufferSlot, bindingBuffer] =
+ splitBufferSlot(op.getLoc(), bindingBufferOrSlot, rewriter); // yields a (slot, buffer) pair for the binding operand
+ callOperands.push_back(bindingBufferSlot);
+ callOperands.push_back(bindingBuffer);
+ callOperands.push_back(
+ castToImportType(bindingOffset, i64Type, rewriter));
+ callOperands.push_back(
+ castToImportType(bindingLength, i64Type, rewriter));
+ }
+ // Replace the op with the variadic call, then run copyImportAttrs on the new call.
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
+ importType.getInputs(), callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+ // Import symbol looked up in the constructor and used by every rewrite.
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
} // namespace
void populateHALCommandBufferToVMPatterns(MLIRContext *context,
@@ -468,6 +624,11 @@
patterns.insert<CommandBufferDispatchIndirectOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.dispatch.indirect");
+ patterns.insert<CommandBufferDispatch2OpConversion>(
+ context, importSymbols, typeConverter, "hal.command_buffer.dispatch2");
+ patterns.insert<CommandBufferDispatch2IndirectOpConversion>(
+ context, importSymbols, typeConverter,
+ "hal.command_buffer.dispatch2.indirect");
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
index a911de8..7b0372d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp
@@ -150,6 +150,62 @@
mutable IREE::VM::ImportOp importOp;
};
+class ExecutableCreate2OpConversion
+ : public OpConversionPattern<IREE::HAL::ExecutableCreate2Op> {
+public:
+ ExecutableCreate2OpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(context) { // NOTE(review): typeConverter is accepted but not forwarded to the base pattern; confirm this is intended
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName); // resolve the target import once, at construction
+ assert(importOp); // only checked in builds with assertions enabled
+ }
+ // Replaces create2 with a vm.call to the import, passing device, format string, binary data, and packed constants.
+ LogicalResult
+ matchAndRewrite(IREE::HAL::ExecutableCreate2Op createOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Materialize vm.rodata for the binary; its name is built from the executable and binary symbol names.
+ auto executableBinaryOp =
+ SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableBinaryOp>(
+ createOp, createOp.getExecutableTarget()); // NOTE(review): result is used below without a null check; presumably verified earlier
+ auto executableOp = executableBinaryOp.getOperation()
+ ->getParentOfType<IREE::HAL::ExecutableOp>();
+ std::string rodataName = sanitizeSymbolName(
+ (executableOp.getName() + "_" + executableBinaryOp.getName()).str());
+ auto rodataOp = rewriter.create<IREE::VM::RodataInlineOp>(
+ executableBinaryOp.getLoc(),
+ IREE::VM::RefType::get(rewriter.getType<IREE::VM::BufferType>()),
+ rewriter.getStringAttr(rodataName), executableBinaryOp.getData(),
+ rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr()); // 16 is emitted as the rodata alignment
+
+ // Get format string as a rodata blob.
+ auto executableFormatStr = rewriter.create<IREE::VM::RodataInlineOp>(
+ createOp.getLoc(), executableBinaryOp.getFormatAttr());
+
+ // Pack constants, if any (the lit test in this change expects a null !vm.buffer when there are none).
+ auto constantBuffer = createPackedConstantBuffer(
+ createOp.getLoc(), adaptor.getConstants(), rewriter);
+
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getDevice(),
+ executableFormatStr,
+ rodataOp,
+ constantBuffer,
+ };
+ auto importType = importOp.getFunctionType();
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ createOp, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+ copyImportAttrs(importOp, callOp);
+
+ return success();
+ }
+ // Import symbol looked up in the constructor and used by every rewrite.
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
} // namespace
void populateHALExecutableToVMPatterns(MLIRContext *context,
@@ -162,6 +218,8 @@
patterns.insert<ExecutableCreateOpConversion>(
context, importSymbols, typeConverter, "hal.executable.create");
+ patterns.insert<ExecutableCreate2OpConversion>(
+ context, importSymbols, typeConverter, "hal.executable.create2");
patterns.insert<VMImportOpConversion<IREE::HAL::DescriptorSetLayoutCreateOp>>(
context, importSymbols, typeConverter,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
index 2df6959..f005985 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -395,3 +395,103 @@
flags(None)
util.return
}
+
+// -----
+// Lowering of hal.command_buffer.dispatch2 with two constants and two bindings: one direct !hal.buffer and one slot (index).
+// CHECK-LABEL: @command_buffer_dispatch2
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
+// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
+// CHECK-SAME: %[[SLOT:.+]]: i32)
+util.func public @command_buffer_dispatch2(
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable,
+ %buffer: !hal.buffer,
+ %slot: index
+) {
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
+ // CHECK-DAG: %[[C0:.+]] = vm.const.i32.zero
+ %ordinal = arith.constant 123 : index
+ // CHECK-DAG: %[[X:.+]] = vm.const.i32 100
+ %x = arith.constant 100 : index
+ // CHECK-DAG: %[[Y:.+]] = vm.const.i32 200
+ %y = arith.constant 200 : index
+ // CHECK-DAG: %[[Z:.+]] = vm.const.i32 300
+ %z = arith.constant 300 : index
+ // CHECK-DAG: %[[CONSTANT0:.+]] = vm.const.i32 31
+ %constant0 = arith.constant 31 : i32
+ // CHECK-DAG: %[[CONSTANT1:.+]] = vm.const.i32 32
+ %constant1 = arith.constant 32 : i32
+ %c4 = arith.constant 4 : index
+ %c4096 = arith.constant 4096 : index
+ %c8000 = arith.constant 8000 : index
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call.variadic @hal.command_buffer.dispatch2
+ // CHECK-SAME: %[[CMD]],
+ // CHECK-SAME: %[[EXECUTABLE]], %[[ORDINAL]],
+ // CHECK-SAME: %[[X]], %[[Y]], %[[Z]],
+ // CHECK-SAME: %[[FLAGS]],
+ // CHECK-SAME: [%[[CONSTANT0]], %[[CONSTANT1]]],
+ // CHECK-SAME: [(%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000),
+ // CHECK-SAME: (%[[C0]], %[[SLOT]], %[[NULL_BUFFER]], %c4, %c4096)]
+ hal.command_buffer.dispatch2<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
+ workgroups([%x, %y, %z])
+ constants([%constant0, %constant1])
+ bindings([
+ (%buffer : !hal.buffer)[%c4096, %c8000],
+ (%slot : index)[%c4, %c4096]
+ ])
+ flags(None)
+ util.return
+}
+
+// -----
+// Lowering of hal.command_buffer.dispatch2.indirect: the workgroup count is read from a slot (index) at an offset.
+// CHECK-LABEL: vm.func private @command_buffer_dispatch2_indirect
+// CHECK-SAME: (%[[CMD:[a-z0-9]+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:[a-z0-9]+]]: !vm.ref<!hal.executable>,
+// CHECK-SAME: %[[WORKGROUPS_SLOT:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[BUFFER:[a-z0-9]+]]: !vm.ref<!hal.buffer>,
+// CHECK-SAME: %[[SLOT:[a-z0-9]+]]: i32)
+util.func public @command_buffer_dispatch2_indirect(
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable,
+ %workgroups_slot: index,
+ %buffer: !hal.buffer,
+ %slot: index
+) {
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
+ // CHECK-DAG: %[[C0:.+]] = vm.const.i32.zero
+ %ordinal = arith.constant 123 : index
+ // CHECK-DAG: %[[WORKGROUPS_OFFSET:.+]] = vm.const.i64 100
+ %workgroups_offset = arith.constant 100 : index
+ // CHECK-DAG: %[[CONSTANT0:.+]] = vm.const.i32 31
+ %constant0 = arith.constant 31 : i32
+ // CHECK-DAG: %[[CONSTANT1:.+]] = vm.const.i32 32
+ %constant1 = arith.constant 32 : i32
+ %c4 = arith.constant 4 : index
+ %c4096 = arith.constant 4096 : index
+ %c8000 = arith.constant 8000 : index
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call.variadic @hal.command_buffer.dispatch2.indirect
+ // CHECK-SAME: %[[CMD]],
+ // CHECK-SAME: %[[EXECUTABLE]], %[[ORDINAL]],
+ // CHECK-SAME: %[[WORKGROUPS_SLOT]], %[[NULL_BUFFER]], %[[WORKGROUPS_OFFSET]],
+ // CHECK-SAME: %[[FLAGS]],
+ // CHECK-SAME: [%[[CONSTANT0]], %[[CONSTANT1]]],
+ // CHECK-SAME: [(%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000),
+ // CHECK-SAME: (%[[C0]], %[[SLOT]], %[[NULL_BUFFER]], %c4, %c4096)]
+ hal.command_buffer.dispatch2.indirect<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
+ workgroups(%workgroups_slot : index)[%workgroups_offset]
+ constants([%constant0, %constant1])
+ bindings([
+ (%buffer : !hal.buffer)[%c4096, %c8000],
+ (%slot : index)[%c4, %c4096]
+ ])
+ flags(None)
+ util.return
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
index 5dd5341..292cb47 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir
@@ -43,6 +43,45 @@
// -----
+hal.executable @exe {
+ hal.executable.binary @binary1 attributes {
+ data = dense<[0, 1, 2, 3]> : vector<4xi8>,
+ format = "format1"
+ }
+ hal.executable.binary @binary2 attributes {
+ data = dense<[4, 5, 6, 7]> : vector<4xi8>,
+ format = "format2"
+ }
+}
+// Each create2 below targets a different binary of @exe and expects its own format/data rodata plus a null constants buffer.
+// CHECK-LABEL: @executableCreate2
+util.func public @executableCreate2(
+ // CHECK-SAME: %[[DEV:.+]]: !vm.ref<!hal.device>
+ %device: !hal.device
+) -> (!hal.executable, !hal.executable) {
+
+ // CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_
+ // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8>
+ // CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer
+ // CHECK: %[[EXE1:.+]] = vm.call @hal.executable.create2(
+ // CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]]
+ // CHECK-SAME: ) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref<!hal.executable>
+ %0 = hal.executable.create2 device(%device : !hal.device) target(@exe::@binary1) : !hal.executable
+
+ // CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_
+ // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8>
+ // CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer
+ // CHECK: %[[EXE2:.+]] = vm.call @hal.executable.create2(
+ // CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]]
+ // CHECK-SAME: ) {nosideeffects} : (!vm.ref<!hal.device>, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref<!hal.executable>
+ %1 = hal.executable.create2 device(%device : !hal.device) target(@exe::@binary2) : !hal.executable
+
+ // CHECK: vm.return %[[EXE1]], %[[EXE2]]
+ util.return %0, %1 : !hal.executable, !hal.executable
+}
+
+// -----
+
hal.executable @exe1 {
hal.executable.binary @binary1 attributes {
data = dense<[0, 1, 2, 3]> : vector<4xi8>,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index e21a626..109701a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -30,6 +30,13 @@
llvm::cl::init(false),
};
+// Opt-in command-line flag for emitting the dispatch2 ops. TODO(#18154): switch default to true and then remove.
+static llvm::cl::opt<bool> clExperimentalDispatch2{
+ "iree-hal-experimental-dispatch2",
+ llvm::cl::desc("Whether to emit iree_hal_command_buffer_dispatch2 ops."),
+ llvm::cl::init(false), // off unless explicitly requested
+};
+
struct ContextResolveOpPattern
: public StreamConversionPattern<IREE::Stream::ContextResolveOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -623,8 +630,8 @@
ConversionPatternRewriter &rewriter) const override {
auto commandBufferMapping = mapping->lookupCommandBufferFor(op);
- IREE::HAL::BindingTableValue sendBinding;
- IREE::HAL::BindingTableValue recvBinding;
+ IREE::HAL::BindingValue sendBinding;
+ IREE::HAL::BindingValue recvBinding;
switch (adaptor.getOp().getKind()) {
default:
assert(adaptor.getResources().size() == 2 && "should have verified");
@@ -663,6 +670,7 @@
}
};
+// TODO(#18154): switch to dispatch2.
struct CmdDispatchOpPattern
: public StreamConversionPattern<IREE::Stream::CmdDispatchOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -845,6 +853,145 @@
}
};
+// Lowers stream.cmd.dispatch into hal.command_buffer.dispatch2, passing the
+// uniform operands and the resource bindings directly on the dispatch op.
+// When the dispatch names several entry points an scf.index_switch selects
+// among them using each export's parent variant condition.
+struct CmdDispatch2OpPattern
+    : public StreamConversionPattern<IREE::Stream::CmdDispatchOp> {
+  using StreamConversionPattern::StreamConversionPattern;
+  LogicalResult
+  matchAndRewrite(IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = dispatchOp.getLoc();
+    auto commandBufferMapping = mapping->lookupCommandBufferFor(dispatchOp);
+
+    // TODO(multi-device): reusable command buffers done at the stream level may
+    // make this difficult. For now we assume each stream region being lowered
+    // has a singular affinity that may itself reference multiple devices in the
+    // future but currently uniquely identifies a device.
+    auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(dispatchOp);
+
+    // Get the device handle we're executing against in this execution region.
+    // Note that this is a dynamic value: we have to treat the device as unknown
+    // here.
+    Value device = rewriter.create<IREE::HAL::CommandBufferDeviceOp>(
+        loc, rewriter.getType<IREE::HAL::DeviceType>(),
+        commandBufferMapping.getHandle());
+
+    // Prepare for variant switch table by gathering the conditions selecting
+    // each variant.
+    SmallVector<int64_t> caseIndices;
+    SmallVector<std::pair<SymbolRefAttr, IREE::HAL::ExecutableExportOp>>
+        caseExportOps;
+    dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) {
+      // NOTE: slow lookup!
+      auto exportOp =
+          SymbolTable::lookupNearestSymbolFrom<IREE::HAL::ExecutableExportOp>(
+              dispatchOp, entryPointAttr);
+      assert(exportOp && "dispatch target export not found");
+      caseIndices.push_back(caseIndices.size());
+      caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp));
+    });
+
+    // If there is only one variant we can emit that directly without a
+    // conditional check. The same result should occur later on but it saves
+    // a lot of IR during generation if we know we can avoid it.
+    if (caseExportOps.size() == 1) {
+      auto [entryPointAttr, exportOp] = caseExportOps.front();
+      rewriter.replaceOp(dispatchOp,
+                         emitDispatchOp(loc, affinityAttr, device,
+                                        commandBufferMapping, exportOp,
+                                        entryPointAttr, dispatchOp, adaptor,
+                                        rewriter));
+    } else {
+      // Select the variant index.
+      Value selectedIndex = buildIfElseTree(
+          loc, caseExportOps.size(),
+          [&](Location loc, size_t i, OpBuilder &builder) {
+            auto exportOp = caseExportOps[i].second;
+            auto variantOp =
+                exportOp->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+            return variantOp.buildCondition(device, rewriter);
+          },
+          rewriter);
+
+      // Allow each variant to define how it is dispatched.
+      auto switchOp = rewriter.create<scf::IndexSwitchOp>(
+          loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size());
+      for (size_t i = 0; i < caseExportOps.size(); ++i) {
+        auto [entryPointAttr, exportOp] = caseExportOps[i];
+        auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
+        auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
+        emitDispatchOp(loc, affinityAttr, device, commandBufferMapping,
+                       exportOp, entryPointAttr, dispatchOp, adaptor,
+                       caseBuilder);
+        caseBuilder.create<scf::YieldOp>(loc);
+      }
+
+      // Fallback for no available variant. Today we just no-op as executable
+      // loading should have already failed.
+      auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock();
+      auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock);
+      defaultBuilder.create<scf::YieldOp>(loc);
+
+      rewriter.replaceOp(dispatchOp, switchOp);
+    }
+
+    return success();
+  }
+
+  // Emits one hal.command_buffer.dispatch2 for |entryPointAttr| using
+  // |builder| and returns the created op. The workgroup count is calculated
+  // by |exportOp| from the adaptor's workload and each resource on
+  // |dispatchOp| becomes one binding, in resource order.
+  Operation *emitDispatchOp(
+      Location loc, IREE::Stream::AffinityAttr affinityAttr, Value device,
+      CommandBufferConversionMapping &commandBufferMapping,
+      IREE::HAL::ExecutableExportOp exportOp, SymbolRefAttr entryPointAttr,
+      IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor,
+      OpBuilder &builder) const {
+    auto workgroupCount = exportOp.calculateWorkgroupCount(
+        loc, device, adaptor.getWorkload(), builder);
+
+    Value executable = builder.create<IREE::HAL::ExecutableLookupOp>(
+        loc, builder.getType<IREE::HAL::ExecutableType>(), device,
+        entryPointAttr.getRootReference().getValue());
+    Value ordinal = builder.create<IREE::HAL::ExecutableExportOrdinalOp>(
+        loc, builder.getIndexType(), entryPointAttr);
+
+    // TODO(#18154): simplify bindings by removing descriptor sets.
+    auto layoutAttr = exportOp.getLayout();
+    auto bindingAttrs = IREE::HAL::getInterfaceBindingAttrs(
+        exportOp, dispatchOp.getResources().size());
+    SmallVector<IREE::HAL::BindingValue> bindings;
+    for (auto [i, bindingAttr] : llvm::enumerate(bindingAttrs)) {
+      // Look the descriptor up by its ordinal within its own set. The flat
+      // resource index |i| only coincides with that ordinal when all bindings
+      // live in a single densely numbered set.
+      auto descriptorFlags = layoutAttr.getSetLayout(bindingAttr.getSet())
+                                 .getBinding(bindingAttr.getBinding())
+                                 .getFlags();
+      IREE::HAL::BindingValue binding;
+      if (bitEnumContainsAll(descriptorFlags,
+                             IREE::HAL::DescriptorFlags::Indirect)) {
+        // Indirect binding resolved through the cached command buffer binding
+        // table. The buffer recorded in the descriptor is a slot ordinal into
+        // the binding table. Note that the range may be adjusted based on the
+        // range bound to the slot in the table.
+        auto resolvedBinding = commandBufferMapping.resolveBinding(
+            loc, dispatchOp.getResources()[i], adaptor.getResources()[i],
+            adaptor.getResourceOffsets()[i], adaptor.getResourceLengths()[i],
+            builder);
+        binding.buffer = resolvedBinding.buffer;
+        binding.byteOffset = resolvedBinding.byteOffset;
+        binding.byteLength = resolvedBinding.byteLength;
+      } else {
+        // Direct binding referencing the buffer and range provided on the op.
+        binding.buffer = adaptor.getResources()[i];
+        binding.byteOffset = adaptor.getResourceOffsets()[i];
+        binding.byteLength = adaptor.getResourceLengths()[i];
+      }
+      bindings.push_back(binding);
+    }
+
+    auto flags = IREE::HAL::DispatchFlags::None;
+
+    return builder.create<IREE::HAL::CommandBufferDispatch2Op>(
+        loc, commandBufferMapping.getHandle(), executable, ordinal,
+        workgroupCount, adaptor.getUniformOperands(), bindings, flags);
+  }
+};
+
struct CmdFuncOpPattern
: public StreamConversionPattern<IREE::Stream::CmdFuncOp> {
using StreamConversionPattern::StreamConversionPattern;
@@ -1408,9 +1555,15 @@
patterns
.insert<CmdFlushOpPattern, CmdInvalidateOpPattern, CmdDiscardOpPattern,
CmdFillOpPattern, CmdCopyOpPattern, CmdCollectiveOpPattern,
- CmdDispatchOpPattern, CmdFuncOpPattern, CmdCallOpPattern,
- CmdExecuteOpPattern, CmdSerialOpPattern, CmdConcurrentOpPattern>(
+ CmdFuncOpPattern, CmdCallOpPattern, CmdExecuteOpPattern,
+ CmdSerialOpPattern, CmdConcurrentOpPattern>(
mapping, typeConverter, context);
+ // TODO(#18154): drop existing pattern.
+ if (clExperimentalDispatch2) {
+ patterns.insert<CmdDispatch2OpPattern>(mapping, typeConverter, context);
+ } else {
+ patterns.insert<CmdDispatchOpPattern>(mapping, typeConverter, context);
+ }
patterns.insert<TimepointImmediateOpPattern, TimepointImportOpPattern,
TimepointExportOpPattern, TimepointChainExternalOpPattern,
TimepointJoinOpPattern, TimepointBarrierOpPattern,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
index 7cc68a5..c5df454 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.cpp
@@ -275,11 +275,11 @@
// consumer region. This would require emitting ops that track that
// information (probably via util.range.min/max). For now we bind the
// entire buffer range and let the individual commands subrange them.
- IREE::HAL::BindingTableValue bindingTableValue;
- bindingTableValue.buffer = bufferValue;
- bindingTableValue.byteOffset = indexSet.get(0);
- bindingTableValue.byteLength = bufferSize;
- indirectBuffers.push_back(bindingTableValue);
+ IREE::HAL::BindingValue bindingValue;
+ bindingValue.buffer = bufferValue;
+ bindingValue.byteOffset = indexSet.get(0);
+ bindingValue.byteLength = bufferSize;
+ indirectBuffers.push_back(bindingValue);
}
break;
}
@@ -300,26 +300,26 @@
return std::nullopt;
}
-IREE::HAL::BindingTableValue CommandBufferConversionMapping::resolveBinding(
+IREE::HAL::BindingValue CommandBufferConversionMapping::resolveBinding(
Location loc, Value resourceValue, Value bufferValue, Value useOffset,
Value useLength, OpBuilder &builder) {
- IREE::HAL::BindingTableValue bindingTableValue;
+ IREE::HAL::BindingValue bindingValue;
// Try to resolve the resource to a slot. If not found then it's a direct
// reference and we use the buffer provided.
auto slot = bindingTable.lookupResourceSlot(resourceValue);
if (slot.has_value()) {
- bindingTableValue.buffer = slot.value();
+ bindingValue.buffer = slot.value();
} else {
- bindingTableValue.buffer = bufferValue;
+ bindingValue.buffer = bufferValue;
}
// TODO(benvanik): adjust range by the binding table base index. Today all
// binding table entries are the full buffers starting at zero.
- bindingTableValue.byteOffset = useOffset;
- bindingTableValue.byteLength = useLength;
+ bindingValue.byteOffset = useOffset;
+ bindingValue.byteLength = useLength;
- return bindingTableValue;
+ return bindingValue;
}
void StreamConversionMapping::mapCommandBuffer(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h
index 925220d..50f0b0d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h
@@ -84,7 +84,7 @@
size_t size() const { return indirectBuffers.size(); }
// Builds a binding table (buffer, offset, length) based on the analysis.
- ArrayRef<IREE::HAL::BindingTableValue> getValues() { return indirectBuffers; }
+ ArrayRef<IREE::HAL::BindingValue> getValues() { return indirectBuffers; }
// Returns the binding table slot for the given resource, if it's used
// indirectly.
@@ -94,7 +94,7 @@
// True if any ops are nested that may prevent binding table usage.
bool hasUnsupportedOps = false;
// Buffer binding table with <buffer, offset, length>.
- SmallVector<IREE::HAL::BindingTableValue> indirectBuffers;
+ SmallVector<IREE::HAL::BindingValue> indirectBuffers;
// A mapping of resources to binding table slot ordinals.
DenseMap<Value, Value> indirectSlots;
};
@@ -111,10 +111,9 @@
// The returned range may differ from the provided used range in cases where
// an indirect binding table reference may have already factored in the
// offset.
- IREE::HAL::BindingTableValue resolveBinding(Location loc, Value resourceValue,
- Value bufferValue,
- Value useOffset, Value useLength,
- OpBuilder &builder);
+ IREE::HAL::BindingValue resolveBinding(Location loc, Value resourceValue,
+ Value bufferValue, Value useOffset,
+ Value useLength, OpBuilder &builder);
private:
Value handle;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel
index 2d6f777..12dc520 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel
@@ -17,6 +17,7 @@
srcs = enforce_glob(
[
"channel_ops.mlir",
+ "cmd_dispatch2_ops.mlir",
"cmd_ops.mlir",
"context_ops.mlir",
"debug_ops.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt
index b273190..0aeea90 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"channel_ops.mlir"
+ "cmd_dispatch2_ops.mlir"
"cmd_ops.mlir"
"context_ops.mlir"
"debug_ops.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_dispatch2_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_dispatch2_ops.mlir
new file mode 100644
index 0000000..ce9a4ad
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_dispatch2_ops.mlir
@@ -0,0 +1,114 @@
+// RUN: iree-opt --split-input-file --iree-hal-conversion --cse --iree-hal-indirect-command-buffers=true --iree-hal-experimental-dispatch2=true %s | FileCheck %s
+
+#executable_target_aarch64 = #hal.executable.target<"llvm-cpu", "embedded-elf-aarch64">
+#executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64">
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer, Indirect>
+ ]>
+]>
+hal.executable private @ex {
+ hal.executable.variant public @aarch64 target(#executable_target_aarch64) {
+ hal.executable.condition(%device: !hal.device) -> i1 {
+ %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1
+ hal.return %selected : i1
+ }
+ hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ // Opaque at this point (in some target-specific dialects).
+ }
+ }
+ hal.executable.variant public @x86_64 target(#executable_target_x86_64) {
+ hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes {
+ translation_info = #iree_codegen.translation_info<CPUDefault>
+ } {
+ ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors
+ %c1 = arith.constant 1 : index
+ %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0]
+ hal.return %0, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ // Opaque at this point (in some target-specific dialects).
+ }
+ }
+}
+
+util.global private @device : !hal.device
+util.global private @constant_resource : !stream.resource<constant>
+util.global private @constant_size : index
+
+// CHECK-LABEL: @cmdDispatch
+// CHECK-SAME: (%[[ARG_RESOURCE:.+]]: !hal.buffer, %[[ARG_SIZE:.+]]: index)
+util.func public @cmdDispatch(%arg_resource: !stream.resource<external>, %arg_size: index) -> !stream.timepoint {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4_i32 = arith.constant 4 : i32
+ %c5_i32 = arith.constant 5 : i32
+ %c128 = arith.constant 128 : index
+ // CHECK-DAG: %[[CONSTANT_RESOURCE:.+]] = util.global.load immutable @constant_resource
+ %constant_resource = util.global.load immutable @constant_resource : !stream.resource<constant>
+ %constant_size = util.global.load immutable @constant_size : index
+ // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device
+ // CHECK: %[[MEMOIZED_CMD:.+]] = hal.device.memoize
+ // CHECK: %[[CMD:.+]] = hal.command_buffer.create
+ %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%constant_resource as %constant_capture: !stream.resource<constant>{%constant_size}, %arg_resource as %arg_capture: !stream.resource<external>{%arg_size}) {
+ // Switch for each executable variant by checking conditions and ranking:
+ // CHECK: %[[CMD_DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer>
+ // CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[CMD_DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64")
+ // CHECK-DAG: %[[AARCH64_FEATURE:.+]] = scf.execute_region -> i1 {
+ // CHECK-NEXT: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[CMD_DEVICE]] : !hal.device> key("some" :: "feature")
+ // CHECK-NEXT: scf.yield %[[FEATURE]]
+ // CHECK-NEXT: }
+ // CHECK-DAG: %[[AARCH64_SELECTED:.+]] = arith.andi %[[AARCH64_FORMAT]], %[[AARCH64_FEATURE]]
+ // CHECK-DAG: %{{.+}}, %[[X86_64_SELECTED:.+]] = hal.device.query<%[[CMD_DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-x86_64")
+ // CHECK: %[[VARIANT1:.+]] = arith.select %[[X86_64_SELECTED]], %c1
+ // CHECK: %[[VARIANT0:.+]] = arith.select %[[AARCH64_SELECTED]], %c0, %[[VARIANT1]]
+ // CHECK: scf.index_switch %[[VARIANT0]]
+ // CHECK-NEXT: case 0 {
+
+ // Inlined workgroup count calculation:
+ // CHECK: %[[X:.+]] = affine.apply #map()[%c1]
+
+ // Target executable/export:
+ // CHECK-DAG: %[[EXECUTABLE_0:.+]] = hal.executable.lookup
+ // CHECK-SAME: device(%[[CMD_DEVICE]] : !hal.device)
+ // CHECK-SAME: executable(@ex) : !hal.executable
+ // CHECK-DAG: %[[ORDINAL_0:.+]] = hal.executable.export.ordinal
+ // CHECK-SAME: target(@ex::@aarch64::@dispatch) : index
+
+ // Dispatch:
+ // CHECK: hal.command_buffer.dispatch2<%[[CMD]]
+ // CHECK-SAME: target(%[[EXECUTABLE_0]] : !hal.executable)[%[[ORDINAL_0]]]
+ // CHECK-SAME: workgroups([%[[X]], %c1, %c1])
+ // CHECK-SAME: constants([%c4_i32, %c5_i32])
+ // CHECK-SAME: bindings([
+ // CHECK-NEXT: (%[[CONSTANT_RESOURCE]] : !hal.buffer)[%c0, %c128],
+ // CHECK-NEXT: (%c0 : index)[%c0, %c128]
+
+ // Other variant, when selected:
+ // CHECK: case 1 {
+ // CHECK-DAG: %[[ORDINAL_1:.+]] = hal.executable.export.ordinal target(@ex::@x86_64::@dispatch)
+ // CHECK: hal.command_buffer.dispatch2<%[[CMD]]
+ // CHECK-SAME: target({{.+}})[%[[ORDINAL_1]]]
+ stream.cmd.dispatch {@ex::@aarch64::@dispatch, @ex::@x86_64::@dispatch}[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) {
+ ro %constant_capture[%c0 for %c128] : !stream.resource<constant>{%constant_size},
+ wo %arg_capture[%c0 for %c128] : !stream.resource<external>{%arg_size}
+ }
+ // CHECK: hal.command_buffer.execution_barrier<%[[CMD]]
+ } => !stream.timepoint
+ // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]]
+ // CHECK: hal.device.queue.execute.indirect<%[[DEVICE]] : !hal.device> {{.+}} commands(%[[MEMOIZED_CMD]]) bindings([
+ // CHECK-NEXT: (%[[ARG_RESOURCE]] : !hal.buffer)[%c0, %[[ARG_SIZE]]]
+ // CHECK-NEXT: ])
+ util.return %0 : !stream.timepoint
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 2933f27..78fdad3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -168,10 +168,12 @@
def HAL_DescriptorFlags_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_DescriptorFlags_ReadOnly : I32BitEnumAttrCase<"ReadOnly", 0x0001>;
+def HAL_DescriptorFlags_Indirect : I32BitEnumAttrCase<"Indirect", 0x0002>;
def HAL_DescriptorFlagsAttr :
I32BitEnumAttr<"DescriptorFlags", "valid Descriptor flags", [
HAL_DescriptorFlags_None,
HAL_DescriptorFlags_ReadOnly,
+ HAL_DescriptorFlags_Indirect,
]> {
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
@@ -387,7 +389,7 @@
let parameters = (ins
AttrParameter<"int64_t", "">:$ordinal,
AttrParameter<"DescriptorType", "">:$type,
- OptionalParameter<"std::optional<DescriptorFlags>">:$flags
+ OptionalParameter<"DescriptorFlags", "DescriptorFlags::None">:$flags
);
let assemblyFormat = [{
`<` $ordinal `,` $type (`,` $flags^)? `>`
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
index fcf5ae4..811789c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp
@@ -104,9 +104,7 @@
fn(IntegerAttr::get(IndexType::get(context),
APInt(64, bindingAttr.getOrdinal())));
fn(IREE::HAL::DescriptorTypeAttr::get(context, bindingAttr.getType()));
- fn(IREE::HAL::DescriptorFlagsAttr::get(
- context,
- bindingAttr.getFlags().value_or(IREE::HAL::DescriptorFlags::None)));
+ fn(IREE::HAL::DescriptorFlagsAttr::get(context, bindingAttr.getFlags()));
return success();
}
if (auto dtAttr = llvm::dyn_cast<IREE::HAL::DescriptorTypeAttr>(attr)) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index e90d0e4..6b787a7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -112,6 +112,61 @@
}
//===----------------------------------------------------------------------===//
+// custom<Bindings>($binding_buffers,
+// type($binding_buffers),
+// $binding_offsets,
+// $binding_lengths)
+//===----------------------------------------------------------------------===//
+
+// Parses a possibly empty, comma-separated list of bindings, each written as
+// `(%buffer : type)[%offset, %length]`, appending one entry per binding to
+// every output vector. The surrounding `[` `]` belong to the op's assembly
+// format and are not consumed here.
+static ParseResult
+parseBindings(OpAsmParser &parser,
+              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers,
+              SmallVectorImpl<Type> &bufferTypes,
+              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets,
+              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) {
+  // printBindings emits only a newline for an op without bindings, so a list
+  // that does not open with `(` is empty rather than malformed.
+  if (failed(parser.parseOptionalLParen())) {
+    return success();
+  }
+  while (true) {
+    OpAsmParser::UnresolvedOperand buffer;
+    Type bufferType;
+    OpAsmParser::UnresolvedOperand bufferOffset;
+    OpAsmParser::UnresolvedOperand bufferLength;
+    // The opening `(` of this binding has already been consumed.
+    if (failed(parser.parseOperand(buffer)) ||
+        failed(parser.parseColonType(bufferType)) ||
+        failed(parser.parseRParen()) || failed(parser.parseLSquare()) ||
+        failed(parser.parseOperand(bufferOffset)) ||
+        failed(parser.parseComma()) ||
+        failed(parser.parseOperand(bufferLength)) ||
+        failed(parser.parseRSquare())) {
+      return failure();
+    }
+    buffers.push_back(buffer);
+    bufferTypes.push_back(bufferType);
+    bufferOffsets.push_back(bufferOffset);
+    bufferLengths.push_back(bufferLength);
+    if (failed(parser.parseOptionalComma())) {
+      break;
+    }
+    // A comma must be followed by another binding.
+    if (failed(parser.parseLParen())) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+// Prints every binding on its own line as `(%buffer : type)[%offset, %length]`,
+// separated by commas, and finishes with a newline.
+static void printBindings(OpAsmPrinter &p, Operation *op, ValueRange buffers,
+                          TypeRange bufferTypes, ValueRange bufferOffsets,
+                          ValueRange bufferLengths) {
+  bool needsSeparator = false;
+  for (auto [buffer, bufferType, bufferOffset, bufferLength] :
+       llvm::zip_equal(buffers, bufferTypes, bufferOffsets, bufferLengths)) {
+    if (needsSeparator) {
+      p << ", ";
+    }
+    needsSeparator = true;
+    p.printNewline();
+    p << " (";
+    p.printOperand(buffer);
+    p << " : ";
+    p.printType(bufferType);
+    p << ")[";
+    p.printOperand(bufferOffset);
+    p << ", ";
+    p.printOperand(bufferLength);
+    p << "]";
+  }
+  p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
// custom<BindingTable>($binding_buffers,
// type($binding_buffers),
// $binding_offsets,
@@ -1055,6 +1110,108 @@
}
//===----------------------------------------------------------------------===//
+// hal.command_buffer.dispatch2 + .indirect
+//===----------------------------------------------------------------------===//
+
+// Builds a hal.command_buffer.dispatch2 op. |workgroups| is appended as-is and
+// recorded as three single-value operand segments; every entry of |bindings|
+// is split across the buffer, offset and length operand groups.
+void CommandBufferDispatch2Op::build(OpBuilder &builder, OperationState &state,
+                                     Value commandBuffer, Value executable,
+                                     Value entryPoint, ValueRange workgroups,
+                                     ValueRange constants,
+                                     ArrayRef<BindingValue> bindings,
+                                     IREE::HAL::DispatchFlags flags) {
+  // Flatten the binding structs into parallel operand lists.
+  SmallVector<Value> buffers;
+  SmallVector<Value> offsets;
+  SmallVector<Value> lengths;
+  buffers.reserve(bindings.size());
+  offsets.reserve(bindings.size());
+  lengths.reserve(bindings.size());
+  for (const auto &binding : bindings) {
+    buffers.push_back(binding.buffer);
+    offsets.push_back(binding.byteOffset);
+    lengths.push_back(binding.byteLength);
+  }
+
+  // Operand order must match the segment sizes recorded below.
+  state.addOperands({commandBuffer, executable, entryPoint});
+  state.addOperands(workgroups);
+  state.addOperands(constants);
+  state.addOperands(buffers);
+  state.addOperands(offsets);
+  state.addOperands(lengths);
+
+  state.addAttribute("flags",
+                     builder.getAttr<IREE::HAL::DispatchFlagsAttr>(flags));
+
+  const auto bindingCount = static_cast<int32_t>(bindings.size());
+  const int32_t segmentSizes[] = {
+      1,                                      // command_buffer
+      1,                                      // executable
+      1,                                      // entry_point
+      1,                                      // workgroup_x
+      1,                                      // workgroup_y
+      1,                                      // workgroup_z
+      static_cast<int32_t>(constants.size()), // constants
+      bindingCount,                           // binding_buffers
+      bindingCount,                           // binding_offsets
+      bindingCount,                           // binding_lengths
+  };
+  state.addAttribute(getOperandSegmentSizeAttr(),
+                     builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+// Builds a hal.command_buffer.dispatch2.indirect op. Every entry of |bindings|
+// is split across the buffer, offset and length operand groups.
+void CommandBufferDispatch2IndirectOp::build(
+    OpBuilder &builder, OperationState &state, Value commandBuffer,
+    Value executable, Value entryPoint, Value workgroupsBuffer,
+    Value workgroupsOffset, ValueRange constants,
+    ArrayRef<BindingValue> bindings, IREE::HAL::DispatchFlags flags) {
+  // Flatten the binding structs into parallel operand lists.
+  SmallVector<Value> buffers;
+  SmallVector<Value> offsets;
+  SmallVector<Value> lengths;
+  buffers.reserve(bindings.size());
+  offsets.reserve(bindings.size());
+  lengths.reserve(bindings.size());
+  for (const auto &binding : bindings) {
+    buffers.push_back(binding.buffer);
+    offsets.push_back(binding.byteOffset);
+    lengths.push_back(binding.byteLength);
+  }
+
+  // Operand order must match the segment sizes recorded below.
+  state.addOperands({commandBuffer, executable, entryPoint, workgroupsBuffer,
+                     workgroupsOffset});
+  state.addOperands(constants);
+  state.addOperands(buffers);
+  state.addOperands(offsets);
+  state.addOperands(lengths);
+
+  state.addAttribute("flags",
+                     builder.getAttr<IREE::HAL::DispatchFlagsAttr>(flags));
+
+  const auto bindingCount = static_cast<int32_t>(bindings.size());
+  const int32_t segmentSizes[] = {
+      1,                                      // command_buffer
+      1,                                      // executable
+      1,                                      // entry_point
+      1,                                      // workgroups_buffer
+      1,                                      // workgroups_offset
+      static_cast<int32_t>(constants.size()), // constants
+      bindingCount,                           // binding_buffers
+      bindingCount,                           // binding_offsets
+      bindingCount,                           // binding_lengths
+  };
+  state.addAttribute(getOperandSegmentSizeAttr(),
+                     builder.getDenseI32ArrayAttr(segmentSizes));
+}
+
+// Checks that the buffer, offset and length lists of a dispatch2-style op are
+// the same length, emitting an op error on |op| when they are not.
+static LogicalResult verifyDispatch2Bindings(Operation *op,
+                                             ValueRange bindingBuffers,
+                                             ValueRange bindingOffsets,
+                                             ValueRange bindingLengths) {
+  const size_t bindingCount = bindingBuffers.size();
+  const bool sizesMatch = bindingOffsets.size() == bindingCount &&
+                          bindingLengths.size() == bindingCount;
+  if (sizesMatch) {
+    return success();
+  }
+  return op->emitOpError() << "requires that binding fields all have the "
+                              "same number of elements";
+}
+
+// Verifies that the binding operand lists of this op have matching lengths.
+LogicalResult CommandBufferDispatch2Op::verify() {
+  return verifyDispatch2Bindings(getOperation(), getBindingBuffers(),
+                                 getBindingOffsets(), getBindingLengths());
+}
+
+// Verifies that the binding operand lists of this op have matching lengths.
+LogicalResult CommandBufferDispatch2IndirectOp::verify() {
+  return verifyDispatch2Bindings(getOperation(), getBindingBuffers(),
+                                 getBindingOffsets(), getBindingLengths());
+}
+
+//===----------------------------------------------------------------------===//
// hal.descriptor_set_layout.create
//===----------------------------------------------------------------------===//
@@ -1165,7 +1322,7 @@
OperationState &state, Value device,
Value queueAffinity, Value waitFence,
Value signalFence, Value commandBuffer,
- ArrayRef<BindingTableValue> bindings) {
+ ArrayRef<BindingValue> bindings) {
state.addOperands(
{device, queueAffinity, waitFence, signalFence, commandBuffer});
SmallVector<Value> bindingBuffers;
@@ -1749,6 +1906,16 @@
}
//===----------------------------------------------------------------------===//
+// hal.executable.create2
+//===----------------------------------------------------------------------===//
+
+// Suggests the SSA name "executable" for the op result when printing.
+void ExecutableCreate2Op::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  // TODO(benvanik): name after sanitized symbol.
+  Value result = getResult();
+  setNameFn(result, "executable");
+}
+
+//===----------------------------------------------------------------------===//
// hal.executable.lookup
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index a268dbc..a6fe1f2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1469,6 +1469,7 @@
}];
}
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
def HAL_CommandBufferPushConstantsOp : HAL_Op<"command_buffer.push_constants"> {
let summary = [{command buffer push constants operation}];
let description = [{
@@ -1496,6 +1497,7 @@
}];
}
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
def HAL_CommandBufferPushDescriptorSetOp : HAL_Op<"command_buffer.push_descriptor_set", [
SameVariadicOperandSize,
]> {
@@ -1541,6 +1543,7 @@
let hasCanonicalizer = 1;
}
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> {
let summary = [{command buffer dispatch recording operation}];
let description = [{
@@ -1571,6 +1574,7 @@
}];
}
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indirect"> {
let summary = [{command buffer indirect dispatch recording operation}];
let description = [{
@@ -1598,6 +1602,139 @@
}];
}
+def HAL_CommandBufferDispatch2Op : HAL_Op<"command_buffer.dispatch2", [
+ AttrSizedOperandSegments,
+]> {
+ let summary = [{command buffer dispatch recording operation}];
+ let description = [{
+ Dispatches an execution request.
+ The request may execute overlapped with any other transfer operation or
+ dispatch made within the same barrier-defined sequence.
+
+ The provided constant data and binding list will be recorded into the
+ command buffer and need not remain live beyond the call. Push constants are
+ always 4-byte values and treated as opaque, meaning that they may be
+ bit-casted floats, bit-packed booleans, etc. The provided buffers may either
+ be HAL buffers or indirect references into the command buffer binding table.
+ }];
+
+ let arguments = (ins
+ HAL_CommandBuffer:$command_buffer,
+ HAL_Executable:$executable,
+ HAL_Ordinal:$entry_point,
+ HAL_Dim:$workgroup_x,
+ HAL_Dim:$workgroup_y,
+ HAL_Dim:$workgroup_z,
+ Variadic<I32>:$constants,
+ Variadic<AnyTypeOf<[Index, HAL_BufferType]>>:$binding_buffers,
+ Variadic<HAL_DeviceSize>:$binding_offsets,
+ Variadic<HAL_DeviceSize>:$binding_lengths,
+ HAL_DispatchFlagsAttr:$flags
+ );
+
+ let assemblyFormat = [{
+ `<` $command_buffer `:` type($command_buffer) `>`
+ `target` `(` $executable `:` type($executable) `)`
+ `` `[` $entry_point `]`
+ `workgroups` `(` `[`
+ $workgroup_x `,`
+ $workgroup_y `,`
+ $workgroup_z
+ `]` `)`
+ (`constants` `(` `[` $constants^ `]` `)`)?
+ `bindings` `(` `[`
+ custom<Bindings>($binding_buffers,
+ type($binding_buffers),
+ $binding_offsets,
+ $binding_lengths)
+ `]` `)`
+ `flags` `(` $flags `)`
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "Value":$commandBuffer,
+ "Value":$executable,
+ "Value":$entryPoint,
+ "ValueRange":$workgroups,
+ "ValueRange":$constants,
+ "ArrayRef<BindingValue>":$bindings,
+ "IREE::HAL::DispatchFlags":$flags
+ )>,
+ ];
+
+ let hasVerifier = 1;
+}
+
+def HAL_CommandBufferDispatch2IndirectOp : HAL_Op<"command_buffer.dispatch2.indirect", [
+ AttrSizedOperandSegments,
+]> {
+ let summary = [{command buffer indirect dispatch recording operation}];
+ let description = [{
+ Dispatches an execution request with a deferred workgroup count.
+ This is the same as hal.command_buffer.dispatch2 but the workgroup count
+ is read from the given |workgroups_buffer| at |workgroups_offset| as
+ 3 uint32_t XYZ values immediately before performing the dispatch. This
+ allows prior dispatches within the command sequence to populate the
+ workgroup count or the workgroup count to change across submissions of the
+ same reusable command buffer.
+
+ The provided constant data and binding list will be recorded into the
+ command buffer and need not remain live beyond the call. Push constants are
+ always 4-byte values and treated as opaque, meaning that they may be
+ bit-casted floats, bit-packed booleans, etc. The provided buffers may either
+ be HAL buffers or indirect references into the command buffer binding table.
+ }];
+
+ let arguments = (ins
+ HAL_CommandBuffer:$command_buffer,
+ HAL_Executable:$executable,
+ HAL_Ordinal:$entry_point,
+ AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer,
+ HAL_DeviceSize:$workgroups_offset,
+ Variadic<I32>:$constants,
+ Variadic<AnyTypeOf<[Index, HAL_BufferType]>>:$binding_buffers,
+ Variadic<HAL_DeviceSize>:$binding_offsets,
+ Variadic<HAL_DeviceSize>:$binding_lengths,
+ HAL_DispatchFlagsAttr:$flags
+ );
+
+ let assemblyFormat = [{
+ `<` $command_buffer `:` type($command_buffer) `>`
+ `target` `(` $executable `:` type($executable) `)`
+ `` `[` $entry_point `]`
+ `workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)`
+ `` `[` $workgroups_offset `]`
+ (`constants` `(` `[` $constants^ `]` `)`)?
+ `bindings` `(` `[`
+ custom<Bindings>($binding_buffers,
+ type($binding_buffers),
+ $binding_offsets,
+ $binding_lengths)
+ `]` `)`
+ `flags` `(` $flags `)`
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "Value":$commandBuffer,
+ "Value":$executable,
+ "Value":$entryPoint,
+ "Value":$workgroupsBuffer,
+ "Value":$workgroupsOffset,
+ "ValueRange":$constants,
+ "ArrayRef<BindingValue>":$bindings,
+ "IREE::HAL::DispatchFlags":$flags
+ )>,
+ ];
+
+ let hasVerifier = 1;
+}
+
} // OpGroupCommandBufferOps
//===----------------------------------------------------------------------===//
@@ -2060,7 +2197,7 @@
"Value":$waitFence,
"Value":$signalFence,
"Value":$commandBuffer,
- "ArrayRef<BindingTableValue>":$bindings
+ "ArrayRef<IREE::HAL::BindingValue>":$bindings
)>,
];
@@ -2663,6 +2800,7 @@
];
}
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
def HAL_ExecutableCreateOp : HAL_PureOp<"executable.create", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AttrSizedOperandSegments,
@@ -2702,6 +2840,42 @@
}];
}
+// Simplified variant of hal.executable.create: takes only the device, the
+// executable target symbol, and optional i32 specialization constants (no
+// pipeline layout operands).
+def HAL_ExecutableCreate2Op : HAL_PureOp<"executable.create2", [
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+]> {
+ let summary = [{creates an executable}];
+ let description = [{
+ Creates a target-dependent executable cached on the provided device. Entry
+ points contained within the executable can be dispatched using the resulting
+ executable handle.
+
+ Depending on the driver creation may take a non-trivial amount of time
+ (such as when JITing/etc). As the cache is internally synchronized callers
+ can issue preparation requests from multiple threads - even for the same
+ executables - and calls will block until preparation completes.
+
+ Optional constants provide for specialization of the executable based on
+ runtime-derived parameters.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ SymbolRefAttr:$executable_target,
+ Variadic<I32>:$constants
+ );
+ let results = (outs
+ HAL_Executable:$result
+ );
+
+ let assemblyFormat = [{
+ `device` `(` $device `:` type($device) `)`
+ `target` `(` $executable_target `)`
+ (`constants` `(` `[` $constants^ `]` `)`)?
+ `:` type($result)
+ attr-dict-with-keyword
+ }];
+}
+
def HAL_ExecutableLookupOp : HAL_PureOp<"executable.lookup", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
]> {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index ef6417d..cdb29e0 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -185,7 +185,10 @@
Value byteLength;
};
-struct BindingTableValue {
+// A tuple containing runtime values for a binding.
+// The buffer specified may be either a !hal.buffer or an index of a binding
+// table slot to source the buffer from.
+struct BindingValue {
Value buffer;
Value byteOffset;
Value byteLength;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
index ec0cbdd..d04c7d6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
@@ -5,8 +5,8 @@
"descriptor_set_layout_binding.basic"() {
// CHECK: dslb0 = #hal.descriptor_set.binding<0, uniform_buffer>
dslb0 = #hal.descriptor_set.binding<0, uniform_buffer>,
- // CHECK: dslb1 = #hal.descriptor_set.binding<1, storage_buffer>
- dslb1 = #hal.descriptor_set.binding<1, storage_buffer>
+ // CHECK: dslb1 = #hal.descriptor_set.binding<1, storage_buffer, "ReadOnly|Indirect">
+ dslb1 = #hal.descriptor_set.binding<1, storage_buffer, "ReadOnly|Indirect">
} : () -> ()
// -----
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
index 7408ad9..c3ee554 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir
@@ -160,6 +160,19 @@
// -----
+// Checks that hal.executable.create2 prints with its device and target. The
+// CHECK below names the full op so it cannot be satisfied by the legacy
+// hal.executable.create op (which is a textual prefix of create2).
+// CHECK-LABEL: @executable_create2
+// CHECK-SAME: %[[DEVICE:.+]]: !hal.device
+util.func public @executable_create2(%device: !hal.device) {
+  // CHECK: = hal.executable.create2
+  // CHECK-SAME: device(%[[DEVICE]] : !hal.device)
+  // CHECK-SAME: target(@exe::@binary1) : !hal.executable
+  %0 = hal.executable.create2 device(%device : !hal.device)
+                              target(@exe::@binary1) : !hal.executable
+  util.return
+}
+
+// -----
+
// CHECK-LABEL: @pipeline_layout_create
// CHECK-SAME: %[[DEVICE:.+]]: !hal.device,
// CHECK-SAME: %[[LAYOUT0:.+]]: !hal.descriptor_set_layout,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 7cc0471..37d9603 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -268,10 +268,7 @@
SmallVector<IREE::HAL::DescriptorSetBindingAttr> bindingAttrs;
for (const auto &binding : setLayout.bindings) {
bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get(
- builder.getContext(), binding.ordinal, binding.type,
- binding.flags != IREE::HAL::DescriptorFlags::None
- ? binding.flags
- : std::optional<IREE::HAL::DescriptorFlags>{}));
+ builder.getContext(), binding.ordinal, binding.type, binding.flags));
}
setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get(
builder.getContext(), setLayout.ordinal, bindingAttrs,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
index de22093..16c57e2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp
@@ -32,6 +32,13 @@
namespace {
+// TODO(#18154): switch default to true and then remove.
+// When set, executables are created with hal.executable.create2 (no pipeline
+// layout operands) instead of hal.executable.create. Defaults to off.
+static llvm::cl::opt<bool> clExperimentalExecutableCreate2{
+ "iree-hal-experimental-executable-create2",
+ llvm::cl::desc("Whether to emit iree_hal_executable_create2 ops."),
+ llvm::cl::init(false),
+};
+
//===----------------------------------------------------------------------===//
// --iree-hal-materialize-resource-caches
//===----------------------------------------------------------------------===//
@@ -248,15 +255,6 @@
auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock();
auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock);
- // Gather each of the pipeline layouts needed for each entry point in
- // the executable.
- SmallVector<Value> pipelineLayoutValues;
- for (auto exportOp : variantOp.getExportOps()) {
- auto &pipelineLayout =
- deviceResources.pipelineLayouts[exportOp.getLayoutAttr()];
- pipelineLayoutValues.push_back(pipelineLayout.initializerValue);
- }
-
// Inline constant initializer from the variant.
// We want these to all happen inside of this device switch case; they'll
// get deduplicated/hoisted if possible in future canonicalization passes.
@@ -270,13 +268,31 @@
blockName, blockOp, moduleBuilder, caseBuilder, initializerDevice));
}
- Value executableValue =
- caseBuilder.createOrFold<IREE::HAL::ExecutableCreateOp>(
- loc, executableType, initializerDevice,
- SymbolRefAttr::get(
- executable.executableOp.getSymNameAttr(),
- {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
- pipelineLayoutValues, constantValues);
+ Value executableValue;
+ if (clExperimentalExecutableCreate2) {
+ executableValue =
+ caseBuilder.createOrFold<IREE::HAL::ExecutableCreate2Op>(
+ loc, executableType, initializerDevice,
+ SymbolRefAttr::get(
+ executable.executableOp.getSymNameAttr(),
+ {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
+ constantValues);
+ } else {
+ // Gather each of the pipeline layouts needed for each entry point in
+ // the executable.
+ SmallVector<Value> pipelineLayoutValues;
+ for (auto exportOp : variantOp.getExportOps()) {
+ auto &pipelineLayout =
+ deviceResources.pipelineLayouts[exportOp.getLayoutAttr()];
+ pipelineLayoutValues.push_back(pipelineLayout.initializerValue);
+ }
+
+ executableValue = caseBuilder.createOrFold<IREE::HAL::ExecutableCreateOp>(
+ loc, executableType, initializerDevice,
+ SymbolRefAttr::get(executable.executableOp.getSymNameAttr(),
+ {SymbolRefAttr::get(variantOp.getSymNameAttr())}),
+ pipelineLayoutValues, constantValues);
+ }
caseBuilder.create<scf::YieldOp>(loc, executableValue);
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
index d350e0e..5623e7f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
@@ -13,9 +13,9 @@
// CHECK-SAME: push_constants = 1
// CHECK-SAME: sets = [
// CHECK-SAME: <0, bindings = [
-// CHECK-SAME: <0, storage_buffer, ReadOnly>
-// CHECK-SAME: <1, storage_buffer, ReadOnly>
-// CHECK-SAME: <2, storage_buffer>
+// CHECK-SAME: <0, storage_buffer, "ReadOnly|Indirect">
+// CHECK-SAME: <1, storage_buffer, "ReadOnly|Indirect">
+// CHECK-SAME: <2, storage_buffer, Indirect>
// CHECK: hal.executable private @ex
// CHECK: hal.executable.variant public @arm_64 target(#executable_target_arm_64
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 66f8dd7..9cd824d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -286,6 +286,8 @@
%element_count : i64
)
+// TODO(#18154): remove this in favor of inlined constants.
+//
// Pushes constants for consumption by dispatches.
vm.import private @command_buffer.push_constants(
%command_buffer : !vm.ref<!hal.command_buffer>,
@@ -294,6 +296,8 @@
%values : i32 ...
)
+// TODO(#18154): remove this in favor of inlined bindings.
+//
// Pushes a descriptor set to the given set number.
vm.import private @command_buffer.push_descriptor_set(
%command_buffer : !vm.ref<!hal.command_buffer>,
@@ -326,6 +330,45 @@
%flags : i64
)
+// TODO(#18154): replace @command_buffer.dispatch.
+//
+// Dispatches an execution request.
+vm.import private @command_buffer.dispatch2(
+ %command_buffer : !vm.ref<!hal.command_buffer>,
+ %executable : !vm.ref<!hal.executable>,
+ %entry_point : i32,
+ %workgroup_x : i32,
+ %workgroup_y : i32,
+ %workgroup_z : i32,
+ %flags : i64,
+ %constants : i32 ...,
+ // <reserved, slot, buffer, offset, length>
+ %bindings : tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64>...
+)
+attributes {
+ minimum_version = 4 : i32
+}
+
+// TODO(#18154): replace @command_buffer.dispatch.indirect.
+//
+// Dispatches an execution request with the dispatch parameters loaded from the
+// given buffer.
+vm.import private @command_buffer.dispatch2.indirect(
+ %command_buffer : !vm.ref<!hal.command_buffer>,
+ %executable : !vm.ref<!hal.executable>,
+ %entry_point : i32,
+ %workgroups_buffer_slot : i32,
+ %workgroups_buffer : !vm.ref<!hal.buffer>,
+ %workgroups_offset : i64,
+ %flags : i64,
+ %constants : i32 ...,
+ // <reserved, slot, buffer, offset, length>
+ %bindings : tuple<i32, i32, !vm.ref<!hal.buffer>, i64, i64>...
+)
+attributes {
+ minimum_version = 4 : i32
+}
+
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout_t
//===----------------------------------------------------------------------===//
@@ -468,6 +511,19 @@
) -> !vm.ref<!hal.executable>
attributes {nosideeffects}
+// TODO(#18154): replace @executable.create.
+// Creates an executable for use with the specified device.
+vm.import private @executable.create2(
+ %device : !vm.ref<!hal.device>,
+ %executable_format : !vm.buffer,
+ %executable_data : !vm.buffer,
+ %constants : !vm.buffer
+) -> !vm.ref<!hal.executable>
+attributes {
+ minimum_version = 4 : i32,
+ nosideeffects
+}
+
//===----------------------------------------------------------------------===//
// iree_hal_fence_t
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 3f9d680..266b1e1 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -147,9 +147,11 @@
// TODO(benvanik): remove reflection attrs as a concept and use something more
// MLIRish like an attribute interface/dialect interface.
// DictionaryAttr is not very friendly for modification :/
- auto existingAttr =
- getOperation()->getAttrOfType<DictionaryAttr>("iree.reflection");
- SmallVector<NamedAttribute> attrs(existingAttr.begin(), existingAttr.end());
+ SmallVector<NamedAttribute> attrs;
+ if (auto existingAttr =
+ getOperation()->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+ llvm::append_range(attrs, existingAttr);
+ }
bool didFind = false;
for (size_t i = 0; i < attrs.size(); ++i) {
if (attrs[i].getName() == name) {
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index ea5cb95..fbf8876 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -202,12 +202,23 @@
SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
for (auto attr : attrs) {
auto key = attr.getName().strref();
- auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
- if (!value || key.empty())
+ if (key.empty()) {
continue;
+ }
+ std::string value;
+ if (auto stringAttr = dyn_cast<StringAttr>(attr.getValue())) {
+ value = stringAttr.getValue().str();
+ } else if (auto integerAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
+ SmallVector<char> str;
+ integerAttr.getValue().toStringSigned(str);
+ value.append(str.data(), str.size());
+ } else {
+ assert(false && "expected string or integer reflection attr");
+ continue;
+ }
// NOTE: if we actually want to keep these we should dedupe them (as the
// keys and likely several of the values are shared across all functions).
- auto valueRef = fbb.createString(value.getValue());
+ auto valueRef = fbb.createString(value);
auto keyRef = fbb.createString(key);
attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
}
diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index 6e8ecc0..f3b0438 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -78,17 +78,16 @@
}
}
-static std::optional<IREE::HAL::DescriptorFlags>
+static IREE::HAL::DescriptorFlags
convertDescriptorFlags(std::optional<IREE::Input::DescriptorFlags> src) {
if (!src.has_value())
- return std::nullopt;
+ return IREE::HAL::DescriptorFlags::None;
switch (*src) {
+ default:
case IREE::Input::DescriptorFlags::None:
return IREE::HAL::DescriptorFlags::None;
case IREE::Input::DescriptorFlags::ReadOnly:
return IREE::HAL::DescriptorFlags::ReadOnly;
- default:
- return std::nullopt;
}
}
diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c
index de89e4f..2a7047b 100644
--- a/experimental/webgpu/command_buffer.c
+++ b/experimental/webgpu/command_buffer.c
@@ -270,6 +270,16 @@
iree_hal_webgpu_command_segment_list_reset(&command_buffer->segments);
iree_arena_reset(&command_buffer->arena);
+ // Pad up to IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX with empty bind groups.
+ WGPUBindGroup empty_handle = command_buffer->staging_buffer->empty_bind_group;
+ for (iree_host_size_t i = 0; i < IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX;
+ ++i) {
+ wgpuComputePassEncoderSetBindGroup(compute_pass, (uint32_t)i, empty_handle,
+ 0, NULL);
+ command_buffer->state.bind_groups[i].handle = empty_handle;
+ command_buffer->state.bind_groups_empty |= 1ull << i;
+ }
+
IREE_TRACE_ZONE_END(z0);
}
@@ -802,7 +812,8 @@
static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch(
iree_hal_webgpu_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, uint32_t ordinal,
- WGPUComputePassEncoder* out_compute_pass) {
+ iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings,
+ iree_hal_dispatch_flags_t flags, WGPUComputePassEncoder* out_compute_pass) {
const iree_hal_webgpu_entry_point_t* entry_point =
iree_hal_webgpu_executable_lookup_entry_point(executable, ordinal);
@@ -915,6 +926,111 @@
return iree_ok_status();
}
+// Prepares a compute pass for dispatching |executable| entry point |ordinal|
+// with the inline |constants| and |bindings|.
+//
+// Uploads |constants| (if any) as dispatch parameters, acquires a compute pass,
+// sets the entry point pipeline, binds the parameter bind group at
+// IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX, and binds every entry of |bindings|
+// as a storage buffer in bind group 0. The pass is returned in
+// |out_compute_pass| so the caller can record the actual dispatch command.
+// |flags| is currently unused here.
+static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch2(
+    iree_hal_webgpu_command_buffer_t* command_buffer,
+    iree_hal_executable_t* executable, uint32_t ordinal,
+    iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags, WGPUComputePassEncoder* out_compute_pass) {
+  const iree_hal_webgpu_entry_point_t* entry_point =
+      iree_hal_webgpu_executable_lookup_entry_point(executable, ordinal);
+
+  // Upload push constant data - this may incur a segment flush if the staging
+  // buffer is exhausted.
+  uint32_t params_offset = 0;
+  if (!iree_const_byte_span_is_empty(constants)) {
+    IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_append_parameters(
+        command_buffer, constants, &params_offset));
+  }
+
+  // Acquire the compute pass we'll encode the dispatch into - this may be
+  // fresh or reused from prior commands.
+  WGPUComputePassEncoder compute_pass = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_acquire_compute_pass(
+      command_buffer, &compute_pass));
+  wgpuComputePassEncoderSetPipeline(compute_pass, entry_point->pipeline);
+
+  if (!iree_const_byte_span_is_empty(constants)) {
+    // Bind the push constant emulation bind group at the staging buffer
+    // relative offset for this dispatch.
+    wgpuComputePassEncoderSetBindGroup(
+        compute_pass, IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX,
+        command_buffer->staging_buffer->bind_group, 1, &params_offset);
+  }
+
+  // Set all bindings.
+  const iree_hal_webgpu_set_binding_info_t* binding_info =
+      iree_hal_webgpu_pipeline_layout_set_binding_info(entry_point->layout);
+
+  // TODO: change the bind group cache to take the bindings list directly and
+  // avoid this copy.
+  iree_hal_webgpu_bind_group_binding_t* group_bindings =
+      (iree_hal_webgpu_bind_group_binding_t*)iree_alloca(
+          bindings.count * sizeof(iree_hal_webgpu_bind_group_binding_t));
+  iree_hal_webgpu_binding_mask_t binding_mask = 0;
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    // |bindings| is a {count, values} list; entries live in |values|.
+    binding_mask |= 1u << i;
+    group_bindings[i].type = WGPUBufferBindingType_Storage;
+    group_bindings[i].buffer =
+        bindings.values[i].buffer
+            ? iree_hal_webgpu_buffer_handle(bindings.values[i].buffer)
+            : NULL;
+    group_bindings[i].offset = bindings.values[i].offset;
+    group_bindings[i].length = bindings.values[i].length;
+  }
+
+  // Acquire the bind group to use for the current descriptor set.
+  WGPUBindGroup handle = iree_hal_webgpu_bind_group_cache_acquire(
+      command_buffer->bind_group_cache, binding_info->set_layout,
+      group_bindings, binding_mask);
+
+  // NOTE: today we don't support dynamic offsets for push descriptor sets.
+  // This will be a larger change we'll need to handle in the compiler. If we
+  // wanted to improve caching we could make all the bindings dynamic and then
+  // always cache the base offsets, however
+  // maxDynamicStorageBuffersPerPipelineLayout is minimally 4 and that's not
+  // a lot of bindings.
+  wgpuComputePassEncoderSetBindGroup(compute_pass, 0, handle, 0, NULL);
+
+  *out_compute_pass = compute_pass;
+  return iree_ok_status();
+}
+
+// Records a direct dispatch of |executable| entry point |entry_point| using
+// the XYZ values in |workgroup_count|.
+static iree_status_t iree_hal_webgpu_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  // Set up the pipeline, constants, and bindings; yields the pass to encode
+  // the dispatch into.
+  WGPUComputePassEncoder pass = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_prepare_dispatch2(
+      iree_hal_webgpu_command_buffer_cast(base_command_buffer), executable,
+      entry_point, constants, bindings, flags, &pass));
+
+  const uint32_t count_x = workgroup_count[0];
+  const uint32_t count_y = workgroup_count[1];
+  const uint32_t count_z = workgroup_count[2];
+  wgpuComputePassEncoderDispatchWorkgroups(pass, count_x, count_y, count_z);
+  return iree_ok_status();
+}
+
+// Records an indirect dispatch of |executable| entry point |entry_point| whose
+// workgroup count is sourced from |workgroups_ref| (buffer + offset).
+static iree_status_t iree_hal_webgpu_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  // Set up the pipeline, constants, and bindings; yields the pass to encode
+  // the dispatch into.
+  WGPUComputePassEncoder pass = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_prepare_dispatch2(
+      iree_hal_webgpu_command_buffer_cast(base_command_buffer), executable,
+      entry_point, constants, bindings, flags, &pass));
+
+  wgpuComputePassEncoderDispatchWorkgroupsIndirect(
+      pass, iree_hal_webgpu_buffer_handle(workgroups_ref.buffer),
+      workgroups_ref.offset);
+  return iree_ok_status();
+}
+
const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = {
.destroy = iree_hal_webgpu_command_buffer_destroy,
.begin = iree_hal_webgpu_command_buffer_begin,
@@ -933,4 +1049,6 @@
.push_descriptor_set = iree_hal_webgpu_command_buffer_push_descriptor_set,
.dispatch = iree_hal_webgpu_command_buffer_dispatch,
.dispatch_indirect = iree_hal_webgpu_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_webgpu_command_buffer_dispatch2,
+ .dispatch2_indirect = iree_hal_webgpu_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/base/internal/threading_darwin.c b/runtime/src/iree/base/internal/threading_darwin.c
index 8f611b8..537f705 100644
--- a/runtime/src/iree/base/internal/threading_darwin.c
+++ b/runtime/src/iree/base/internal/threading_darwin.c
@@ -26,7 +26,7 @@
iree_atomic_ref_count_t ref_count;
iree_allocator_t allocator;
- char name[16];
+ char name[32];
pthread_t handle;
mach_port_t mach_port;
diff --git a/runtime/src/iree/base/internal/threading_pthreads.c b/runtime/src/iree/base/internal/threading_pthreads.c
index ec0f107..0d5c016 100644
--- a/runtime/src/iree/base/internal/threading_pthreads.c
+++ b/runtime/src/iree/base/internal/threading_pthreads.c
@@ -33,7 +33,7 @@
iree_atomic_ref_count_t ref_count;
iree_allocator_t allocator;
- char name[16];
+ char name[32];
pthread_t handle;
iree_thread_entry_t entry;
diff --git a/runtime/src/iree/base/internal/threading_win32.c b/runtime/src/iree/base/internal/threading_win32.c
index 944c24a..0091af1 100644
--- a/runtime/src/iree/base/internal/threading_win32.c
+++ b/runtime/src/iree/base/internal/threading_win32.c
@@ -25,7 +25,7 @@
iree_atomic_ref_count_t ref_count;
iree_allocator_t allocator;
- char name[16];
+ char name[32];
HANDLE handle;
DWORD id;
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c
index 7f3785d..802330f 100644
--- a/runtime/src/iree/hal/command_buffer.c
+++ b/runtime/src/iree/hal/command_buffer.c
@@ -619,6 +619,77 @@
return status;
}
+// Records a dispatch of |executable| entry point |entry_point| with an inline
+// |workgroup_count|, |constants|, and |bindings|. Returns OK without recording
+// anything when all three workgroup counts are zero; otherwise validates (when
+// validation is enabled) and forwards to the implementation's dispatch2.
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ IREE_ASSERT_ARGUMENT(command_buffer);
+ IREE_ASSERT_ARGUMENT(executable);
+
+ if ((workgroup_count[0] | workgroup_count[1] | workgroup_count[2]) == 0) {
+ // No-op dispatch. All implementations are expected to do this but we ensure
+ // it happens here to avoid the overhead of going all the way down into the
+ // device layer for something we know should have no (intentional)
+ // side-effects. Note that this does mean that validation is skipped and
+ // the executable/etc could be bogus but that's fine.
+ return iree_ok_status();
+ }
+
+ IREE_TRACE_ZONE_BEGIN(z0);
+#if IREE_HAL_VERBOSE_TRACING_ENABLE
+ // TODO(benvanik): add a tracing.h helper that does the snprintf directly
+ // into a tracy_malloc buffer so that we can avoid the memcpy. Today this can
+ // take 4-5us which adds too much overhead when trying to get accurate timings
+ // with tracing enabled. Because benchmarks shouldn't be run with asserts
+ // enabled we only enable these when assertions are enabled. Ideally we'd
+ // slice off a much larger allocation and then suballocate from that ourselves
+ // so that we could avoid the tracy_malloc overheads per-dispatch.
+ IREE_TRACE({
+ char xyz_string[32];
+ int xyz_string_length =
+ snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u",
+ workgroup_count[0], workgroup_count[1], workgroup_count[2]);
+ IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, xyz_string, xyz_string_length);
+ });
+#endif // IREE_HAL_VERBOSE_TRACING_ENABLE
+
+ // Validation failures end the trace zone and return before the vtable call.
+ IF_VALIDATING(command_buffer, {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_command_buffer_dispatch2_validation(
+ command_buffer, VALIDATION_STATE(command_buffer), executable,
+ entry_point, workgroup_count, constants, bindings, flags));
+ });
+
+ iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch2)(
+ command_buffer, executable, entry_point, workgroup_count, constants,
+ bindings, flags);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Records a dispatch of |executable| entry point |entry_point| whose workgroup
+// count is sourced from |workgroups_ref|, with inline |constants| and
+// |bindings|. Validates (when validation is enabled) and then forwards to the
+// implementation's dispatch2_indirect.
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2_indirect(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ IREE_ASSERT_ARGUMENT(command_buffer);
+ IREE_ASSERT_ARGUMENT(executable);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ // Validation failures end the trace zone and return before the vtable call.
+ IF_VALIDATING(command_buffer, {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_command_buffer_dispatch2_indirect_validation(
+ command_buffer, VALIDATION_STATE(command_buffer), executable,
+ entry_point, workgroups_ref, constants, bindings, flags));
+ });
+ iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch2_indirect)(
+ command_buffer, executable, entry_point, workgroups_ref, constants,
+ bindings, flags);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
//===----------------------------------------------------------------------===//
// Validation support
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index 5cd30c6..43a876f 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -91,6 +91,7 @@
//
// Roughly maps to VkDescriptorSetBinding.
typedef struct iree_hal_buffer_ref_t {
+ // TODO(#18154): change ordinal to `reserved` after binding simplification.
// The binding number of this entry and corresponds to a resource of the
// same binding number in the executable interface. Only used by certain
// calls.
@@ -125,6 +126,12 @@
return (iree_hal_buffer_ref_t){0, buffer_slot, NULL, offset, length};
}
+// A list of buffer references.
+typedef struct iree_hal_buffer_ref_list_t {
+ // Number of entries in |values|.
+ iree_host_size_t count;
+ // Buffer references, indexed in [0, count).
+ const iree_hal_buffer_ref_t* values;
+} iree_hal_buffer_ref_list_t;
+
// Bitfield specifying which execution stage a barrier should start/end at.
//
// Maps to VkPipelineStageFlagBits.
@@ -714,6 +721,8 @@
iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref,
iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
+//
// Pushes an inline set of constants that can be accessed by subsequent
// dispatches using a compatible pipeline layout.
//
@@ -725,6 +734,8 @@
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
const void* values, iree_host_size_t values_length);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
+//
// Pushes descriptor set bindings and associates them with |set|.
// This uses an internal ringbuffer inside of the command buffer to avoid the
// need for creating and binding descriptor sets and managing their lifetime.
@@ -745,6 +756,8 @@
iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
+//
// Dispatches an execution request.
// The request may execute overlapped with any other transfer operation or
// dispatch made within the same barrier-defined sequence.
@@ -761,6 +774,8 @@
uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
iree_hal_dispatch_flags_t flags);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
+//
// Dispatches an execution request with deferred workgroup counts.
// This is the same as iree_hal_command_buffer_dispatch but the workgroup counts
// are read from the given |workgroups_buffer| at offset |workgroups_offset| as
@@ -775,6 +790,40 @@
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
+// Dispatches an execution request.
+// The request may execute overlapped with any other transfer operation or
+// dispatch made within the same barrier-defined sequence. The executable
+// specified must be registered for use with the device driver owning this
+// queue.
+//
+// The provided constant data and binding list will be recorded into the command
+// buffer and need not remain live beyond the call.
+//
+// Fails if the queue does not support dispatch operations (as indicated by
+// can_dispatch).
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags);
+
+// Dispatches an execution request with a deferred workgroup count.
+// This is the same as iree_hal_command_buffer_dispatch2 but the workgroup count
+// is read from the given |workgroups_ref| buffer at the specified offset as
+// 3 uint32_t XYZ values immediately before performing the dispatch. This allows
+// prior dispatches within the command sequence to populate the workgroup
+// count or the workgroup count to change across submissions of the same
+// reusable command buffer.
+//
+// The buffer must have been allocated with
+// IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS and be of
+// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE.
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2_indirect(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags);
+
//===----------------------------------------------------------------------===//
// Validation support
//===----------------------------------------------------------------------===//
@@ -937,6 +986,18 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
+
+ iree_status_t(IREE_API_PTR* dispatch2)(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags);
+
+ iree_status_t(IREE_API_PTR* dispatch2_indirect)(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags);
} iree_hal_command_buffer_vtable_t;
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t);
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c
index b27433c..0c5b0dc 100644
--- a/runtime/src/iree/hal/command_buffer_validation.c
+++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -651,6 +651,88 @@
return iree_ok_status();
}
+// Shared validation for the direct and indirect dispatch2 commands.
+static iree_status_t iree_hal_command_buffer_dispatch2_validation_base(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_command_buffer_validation_state_t* validation_state,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags) {
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
+      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
+
+  // Constants are consumed as 32-bit words so the length must be a multiple
+  // of 4 bytes.
+  if (IREE_UNLIKELY((constants.data_length & 3) != 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "invalid alignment %" PRIhsz
+                            ", must be 4-byte aligned",
+                            constants.data_length);
+  }
+
+  // Conservatively treat every binding as read/write (_any_ access) for now.
+  iree_hal_buffer_binding_requirements_t binding_reqs = {
+      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
+      .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE,
+      .access = IREE_HAL_MEMORY_ACCESS_ANY,
+      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+  };
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    const iree_hal_buffer_ref_t binding = bindings.values[i];
+    binding_reqs.max_byte_offset = binding.offset + binding.length;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_command_buffer_validate_buffer_requirements(
+            command_buffer, validation_state, binding, binding_reqs),
+        "binding[%u] (arg[%" PRIhsz "])", binding.ordinal, i);
+  }
+  return iree_ok_status();
+}
+
+// Validates a direct dispatch2 command. Only the checks shared with the
+// indirect variant are performed; |workgroup_count| is not inspected.
+iree_status_t iree_hal_command_buffer_dispatch2_validation(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_command_buffer_validation_state_t* validation_state,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_status_t status = iree_hal_command_buffer_dispatch2_validation_base(
+      command_buffer, validation_state, executable, entry_point, constants,
+      bindings, flags);
+  return status;
+}
+
+// Validates an indirect dispatch2 command: |workgroups_ref| must start at a
+// uint32_t-aligned offset, span at least 3 uint32_t values, and satisfy the
+// indirect-parameter buffer requirements before the validation shared with
+// the direct variant is run.
+iree_status_t iree_hal_command_buffer_dispatch2_indirect_validation(
+    iree_hal_command_buffer_t* command_buffer,
+    iree_hal_command_buffer_validation_state_t* validation_state,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  // Guard: the XYZ counts are read as naturally aligned uint32_t values.
+  if ((workgroups_ref.offset % sizeof(uint32_t)) != 0) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "workgroup count offset does not match the required natural alignment "
+        "of uint32_t (offset=%" PRIdsz ", min_byte_alignment=%" PRIhsz ")",
+        workgroups_ref.offset, sizeof(uint32_t));
+  }
+
+  // Guard: the referenced range must be able to hold all three counts.
+  if (workgroups_ref.length < 3 * sizeof(uint32_t)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "workgroup count buffer does not have the capacity "
+                            "to store the required 3 uint32_t values "
+                            "(length=%" PRIdsz ", min_length=%" PRIhsz ")",
+                            workgroups_ref.length, 3 * sizeof(uint32_t));
+  }
+
+  const iree_hal_buffer_binding_requirements_t indirect_params_reqs = {
+      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
+      .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS,
+      .access = IREE_HAL_MEMORY_ACCESS_READ,
+      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      .max_byte_offset = workgroups_ref.offset + workgroups_ref.length,
+      .min_byte_alignment = sizeof(uint32_t),
+  };
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements(
+      command_buffer, validation_state, workgroups_ref, indirect_params_reqs));
+
+  return iree_hal_command_buffer_dispatch2_validation_base(
+      command_buffer, validation_state, executable, entry_point, constants,
+      bindings, flags);
+}
+
+
iree_status_t iree_hal_command_buffer_binding_table_validation(
iree_hal_command_buffer_t* command_buffer,
const iree_hal_command_buffer_validation_state_t* validation_state,
diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h
index 82ab1c5..505982f 100644
--- a/runtime/src/iree/hal/command_buffer_validation.h
+++ b/runtime/src/iree/hal/command_buffer_validation.h
@@ -126,18 +126,21 @@
iree_hal_buffer_ref_t send_ref, iree_hal_buffer_ref_t recv_ref,
iree_device_size_t element_count);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
iree_status_t iree_hal_command_buffer_push_constants_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
const void* values, iree_host_size_t values_length);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
iree_status_t iree_hal_command_buffer_push_descriptor_set_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
iree_status_t iree_hal_command_buffer_dispatch_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
@@ -145,12 +148,27 @@
uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
iree_hal_dispatch_flags_t flags);
+// TODO(#18154): deprecated and will be replaced with simplified bindings.
iree_status_t iree_hal_command_buffer_dispatch_indirect_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
+iree_status_t iree_hal_command_buffer_dispatch2_validation(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_command_buffer_validation_state_t* validation_state,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags);
+
+iree_status_t iree_hal_command_buffer_dispatch2_indirect_validation(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_command_buffer_validation_state_t* validation_state,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags);
+
iree_status_t iree_hal_command_buffer_binding_table_validation(
iree_hal_command_buffer_t* command_buffer,
const iree_hal_command_buffer_validation_state_t* validation_state,
diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
index c53428a..68d4d34 100644
--- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
@@ -59,9 +59,8 @@
// Iteratively constructed batch of collective operations.
iree_hal_collective_batch_t collective_batch;
+ // TODO(#18189): drop state used by legacy bindings mechanism.
int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];
-
- // The current bound descriptor sets.
struct {
CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT];
@@ -879,6 +878,132 @@
"indirect dispatch not yet implemented");
}
+// Records a direct dispatch as a kernel node in the CUDA graph, passing
+// |bindings| followed by |constants| as the kernel parameters.
+static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_cuda_graph_command_buffer_t* command_buffer =
+      iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
+  // Lookup kernel parameters used for side-channeling additional launch
+  // information from the compiler.
+  iree_hal_cuda_kernel_info_t kernel_info;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_cuda_native_executable_entry_point_kernel_info(
+              executable, entry_point, &kernel_info));
+
+  IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL(
+      command_buffer, kernel_info.source_filename.data,
+      kernel_info.source_filename.size, kernel_info.source_line,
+      kernel_info.function_name.data, kernel_info.function_name.size,
+      /*name=*/NULL, 0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                       &executable));
+  // Kernel arguments form a linear chain: bindings first, then push constants.
+  iree_host_size_t kernel_params_count =
+      kernel_info.binding_count + kernel_info.constant_count;
+  iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+  // TODO: use packed parameters instead of the indirection mechanism - this
+  // would avoid additional driver overhead to reflect and repack them all.
+  //
+  // Per CUDA API requirements, we need two levels of indirection for passing
+  // kernel arguments in.
+  //   "If the kernel has N parameters, then kernelParams needs to be an array
+  //   of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1],
+  //   points to the region of memory from which the actual parameter will be
+  //   copied."
+  //
+  // (From the cuGraphAddKernelNode API doc in
+  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b)
+  //
+  // It means each params_ptr[i] is itself a pointer to payload_ptr[i], which
+  // lives in the *second* half of the arena allocation made below.
+  iree_host_size_t total_size = kernel_params_length * 2;
+  uint8_t* storage_base = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->arena, total_size,
+                              (void**)&storage_base));
+  void** params_ptr = (void**)storage_base;
+  CUdeviceptr* payload_ptr =
+      (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length);
+  for (size_t i = 0; i < kernel_params_count; i++) {
+    params_ptr[i] = &payload_ptr[i];
+  }
+  for (iree_host_size_t i = 0; i < bindings.count; i++) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    CUdeviceptr device_ptr = 0;
+    if (binding->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                           &binding->buffer));
+      CUdeviceptr device_buffer = iree_hal_cuda_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(binding->buffer));
+      iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+      device_ptr = device_buffer + offset + binding->offset;
+    }
+    payload_ptr[i] = device_ptr;
+  }
+
+  // Each kernel parameter points at a CUdeviceptr-sized (pointer-sized) slot
+  // but push constants are only 32-bit values, so on 64-bit machines we must
+  // store them one element at a time.
+  for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) {
+    *((uint32_t*)params_ptr[kernel_info.binding_count + i]) =
+        ((const uint32_t*)constants.data)[i];
+  }
+
+  CUDA_KERNEL_NODE_PARAMS params = {
+      .func = kernel_info.function,
+      .blockDimX = kernel_info.block_size[0],
+      .blockDimY = kernel_info.block_size[1],
+      .blockDimZ = kernel_info.block_size[2],
+      .gridDimX = workgroup_count[0],
+      .gridDimY = workgroup_count[1],
+      .gridDimZ = workgroup_count[2],
+      .kernelParams = params_ptr,
+      .sharedMemBytes = kernel_info.shared_memory_size,
+  };
+
+  if (command_buffer->graph_node_count >=
+      IREE_HAL_CUDA_MAX_CONCURRENT_GRAPH_NODE_COUNT) {
+    IREE_TRACE_ZONE_END(z0);  // Balance the zone opened above on early exit.
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exceeded max concurrent node limit");
+  }
+
+  size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0;
+  IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, command_buffer->symbols,
+      cuGraphAddKernelNode(
+          &command_buffer->cu_graph_nodes[command_buffer->graph_node_count++],
+          command_buffer->cu_graph, &command_buffer->cu_barrier_node,
+          dependency_count, &params),
+      "cuGraphAddKernelNode");
+
+  IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Indirect dispatch is not implemented for CUDA graph command buffers: every
+// call fails with IREE_STATUS_UNIMPLEMENTED and no arguments are inspected.
+static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "indirect dispatch not yet implemented");
+}
+
static const iree_hal_command_buffer_vtable_t
iree_hal_cuda_graph_command_buffer_vtable = {
.destroy = iree_hal_cuda_graph_command_buffer_destroy,
@@ -903,4 +1028,7 @@
.dispatch = iree_hal_cuda_graph_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_cuda_graph_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_cuda_graph_command_buffer_dispatch2,
+ .dispatch2_indirect =
+ iree_hal_cuda_graph_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.c b/runtime/src/iree/hal/drivers/cuda/native_executable.c
index 00b7216..06a7ffc 100644
--- a/runtime/src/iree/hal/drivers/cuda/native_executable.c
+++ b/runtime/src/iree/hal/drivers/cuda/native_executable.c
@@ -224,16 +224,33 @@
}
if (!iree_status_is_ok(status)) break;
+ // TODO(#18189): embed all of this on a single flatbuffer table
+ // per-export.
+ //
// Package required parameters for kernel launches for each entry point.
iree_hal_cuda_kernel_info_t* info = &executable->entry_points[i];
info->layout = executable_params->pipeline_layouts[i];
iree_hal_pipeline_layout_retain(info->layout);
info->function = function;
+ info->constant_count =
+ iree_hal_cuda_pipeline_layout_push_constant_count(info->layout);
+ info->binding_count =
+ iree_hal_cuda_pipeline_layout_total_binding_count(info->layout);
info->block_size[0] = block_sizes_vec[i].x;
info->block_size[1] = block_sizes_vec[i].y;
info->block_size[2] = block_sizes_vec[i].z;
info->shared_memory_size = shared_memory_sizes[i];
+ if (info->binding_count >
+ IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
+ status = iree_make_status(
+ IREE_STATUS_RESOURCE_EXHAUSTED,
+ "exceeded available binding slots; requested %u of maximum %d",
+ info->binding_count,
+ IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT);
+ }
+ if (!iree_status_is_ok(status)) break;
+
// Stash the entry point name in the string table for use when tracing.
IREE_TRACE({
iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name);
diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.h b/runtime/src/iree/hal/drivers/cuda/native_executable.h
index 1dee84d..226ceda 100644
--- a/runtime/src/iree/hal/drivers/cuda/native_executable.h
+++ b/runtime/src/iree/hal/drivers/cuda/native_executable.h
@@ -20,8 +20,12 @@
#endif // __cplusplus
typedef struct iree_hal_cuda_kernel_info_t {
+ // TODO(#18189): remove when using simplified bindings.
iree_hal_pipeline_layout_t* layout;
CUfunction function;
+ uint32_t constant_count;
+ uint32_t binding_count;
+ // TODO(#18189): add bitfield indicating indirect bindings.
uint32_t block_size[3];
uint32_t shared_memory_size;
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
index 5b64dfa..17af968 100644
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
@@ -538,13 +538,13 @@
// Create the ready-list processing worker itself.
iree_thread_create_params_t params;
memset(&params, 0, sizeof(params));
- params.name = IREE_SV("deferque_worker");
+ params.name = IREE_SV("iree-cuda-queue-worker");
params.create_suspended = false;
iree_status_t status = iree_thread_create(
(iree_thread_entry_t)iree_hal_cuda_worker_execute, actions, params,
actions->host_allocator, &actions->worker_thread);
- params.name = IREE_SV("done_worker");
+ params.name = IREE_SV("iree-cuda-queue-completion");
params.create_suspended = false;
if (iree_status_is_ok(status)) {
status = iree_thread_create(
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
index 3369f3b..a9b50fc 100644
--- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -39,10 +39,8 @@
// Iteratively constructed batch of collective operations.
iree_hal_collective_batch_t collective_batch;
- // The current set push constants.
+ // TODO(#18189): drop state used by legacy bindings mechanism.
int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];
-
- // The current bound descriptor sets.
struct {
CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT];
@@ -652,6 +650,120 @@
"need cuda implementation of dispatch indirect");
}
+static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+
+ // Lookup kernel parameters used for side-channeling additional launch
+ // information from the compiler.
+ iree_hal_cuda_kernel_info_t kernel_info;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_cuda_native_executable_entry_point_kernel_info(
+ executable, entry_point, &kernel_info));
+
+ IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
+ command_buffer->tracing_context, &command_buffer->tracing_event_list,
+ command_buffer->cu_stream, kernel_info.source_filename.data,
+ kernel_info.source_filename.size, kernel_info.source_line,
+ kernel_info.function_name.data, kernel_info.function_name.size,
+ /*name=*/NULL, 0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+ &executable));
+
+ // We append push constants to the end of the bindings to form a linear chain
+ // of kernel arguments.
+ iree_host_size_t kernel_params_count =
+ kernel_info.binding_count + kernel_info.constant_count;
+ iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+ // TODO: use packed parameters instead of the indirection mechanism - this
+ // would avoid additional driver overhead to reflect and repack them all.
+ //
+ // Per CUDA API requirements, we need two levels of indirection for passing
+ // kernel arguments in.
+ // "If the kernel has N parameters, then kernelParams needs to be an array
+ // of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1],
+ // points to the region of memory from which the actual parameter will be
+ // copied."
+ //
+ // (From the cuGraphAddKernelNode API doc in
+ // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b)
+ //
+ // It means each params_ptr[i] is itself a pointer to the corresponding
+ // payload_ptr[i] element, which lives in the *second* half of the arena
+ // allocation made below.
+ iree_host_size_t total_size = kernel_params_length * 2;
+ uint8_t* storage_base = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_arena_allocate(&command_buffer->arena, total_size,
+ (void**)&storage_base));
+ void** params_ptr = (void**)storage_base;
+ CUdeviceptr* payload_ptr =
+ (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length);
+ for (size_t i = 0; i < kernel_params_count; i++) {
+ params_ptr[i] = &payload_ptr[i];
+ }
+ for (iree_host_size_t i = 0; i < bindings.count; i++) {
+ const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+ CUdeviceptr device_ptr = 0;
+ if (binding->buffer) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+ &binding->buffer));
+ CUdeviceptr device_buffer = iree_hal_cuda_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(binding->buffer));
+ iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+ device_ptr = device_buffer + offset + binding->offset;
+ }
+ payload_ptr[i] = device_ptr;
+ }
+
+ // As commented above, what each kernel parameter points to is a CUdeviceptr,
+ // which has the size of a pointer on the target machine. We are just storing
+ // a 32-bit value for the push constant here instead, so we must process one
+ // element at a time on 64-bit machines.
+ for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) {
+ *((uint32_t*)params_ptr[kernel_info.binding_count + i]) =
+ ((const uint32_t*)constants.data)[i];
+ }
+
+ IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, command_buffer->cuda_symbols,
+ cuLaunchKernel(kernel_info.function, workgroup_count[0],
+ workgroup_count[1], workgroup_count[2],
+ kernel_info.block_size[0], kernel_info.block_size[1],
+ kernel_info.block_size[2], kernel_info.shared_memory_size,
+ command_buffer->cu_stream, params_ptr, NULL),
+ "cuLaunchKernel");
+
+ IREE_CUDA_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,
+ &command_buffer->tracing_event_list,
+ command_buffer->cu_stream);
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Indirect dispatch is not implemented for CUDA stream command buffers: every
+// call fails with IREE_STATUS_UNIMPLEMENTED and no arguments are inspected.
+static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "indirect dispatch not yet implemented");
+}
+
static const iree_hal_command_buffer_vtable_t
iree_hal_cuda_stream_command_buffer_vtable = {
.destroy = iree_hal_cuda_stream_command_buffer_destroy,
@@ -676,4 +788,7 @@
.dispatch = iree_hal_cuda_stream_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_cuda_stream_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_cuda_stream_command_buffer_dispatch2,
+ .dispatch2_indirect =
+ iree_hal_cuda_stream_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
index ae66cfd..99b3538 100644
--- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
@@ -60,9 +60,8 @@
// Iteratively constructed batch of collective operations.
iree_hal_collective_batch_t collective_batch;
+ // TODO(#18189): drop state used by legacy bindings mechanism.
int32_t push_constants[IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT];
-
- // The current bound descriptor sets.
struct {
hipDeviceptr_t bindings[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_COUNT];
@@ -888,6 +887,123 @@
"indirect dispatch not yet implemented");
}
+// Records a direct dispatch as a kernel node in the HIP graph, passing
+// |bindings| followed by |constants| as the kernel parameters.
+static iree_status_t iree_hal_hip_graph_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_hip_graph_command_buffer_t* command_buffer =
+      iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hip_graph_command_buffer_flush_collectives(command_buffer));
+
+  // Lookup kernel parameters used for side-channeling additional launch
+  // information from the compiler.
+  iree_hal_hip_kernel_info_t kernel_info;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hip_native_executable_entry_point_kernel_info(
+              executable, entry_point, &kernel_info));
+
+  IREE_HIP_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL(
+      command_buffer, kernel_info.source_filename.data,
+      kernel_info.source_filename.size, kernel_info.source_line,
+      kernel_info.function_name.data, kernel_info.function_name.size,
+      /*name=*/NULL, 0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                       &executable));
+
+  // Kernel arguments form a linear chain: bindings first, then push constants.
+  iree_host_size_t kernel_params_count =
+      kernel_info.binding_count + kernel_info.constant_count;
+  iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+  // TODO: use packed parameters instead of the indirection mechanism - this
+  // would avoid additional driver overhead to reflect and repack them all.
+  //
+  // Each params_ptr[i] is itself a pointer to payload_ptr[i], which lives in
+  // the *second* half of the arena allocation made below.
+  iree_host_size_t total_size = kernel_params_length * 2;
+  uint8_t* storage_base = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->arena, total_size,
+                              (void**)&storage_base));
+  void** params_ptr = (void**)storage_base;
+  hipDeviceptr_t* payload_ptr =
+      (hipDeviceptr_t*)((uint8_t*)params_ptr + kernel_params_length);
+  for (size_t i = 0; i < kernel_params_count; i++) {
+    params_ptr[i] = &payload_ptr[i];
+  }
+  for (iree_host_size_t i = 0; i < bindings.count; i++) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    hipDeviceptr_t device_ptr = NULL;
+    if (binding->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                           &binding->buffer));
+      hipDeviceptr_t device_buffer = iree_hal_hip_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(binding->buffer));
+      iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+      device_ptr = (uint8_t*)device_buffer + offset + binding->offset;
+    }
+    payload_ptr[i] = device_ptr;
+  }
+
+  // Each kernel parameter points at a hipDeviceptr_t-sized (pointer-sized) slot
+  // but push constants are only 32-bit values, so on 64-bit machines we must
+  // store them one element at a time.
+  for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) {
+    *((uint32_t*)params_ptr[kernel_info.binding_count + i]) =
+        ((const uint32_t*)constants.data)[i];
+  }
+
+  hipKernelNodeParams params = {
+      .blockDim.x = kernel_info.block_size[0],
+      .blockDim.y = kernel_info.block_size[1],
+      .blockDim.z = kernel_info.block_size[2],
+      .gridDim.x = workgroup_count[0],
+      .gridDim.y = workgroup_count[1],
+      .gridDim.z = workgroup_count[2],
+      .func = kernel_info.function,
+      .kernelParams = params_ptr,
+      .sharedMemBytes = kernel_info.shared_memory_size,
+  };
+
+  if (command_buffer->graph_node_count >=
+      IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT) {
+    IREE_TRACE_ZONE_END(z0);  // Balance the zone opened above on early exit.
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exceeded max concurrent node limit");
+  }
+
+  size_t dependency_count = command_buffer->hip_barrier_node ? 1 : 0;
+  IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, command_buffer->symbols,
+      hipGraphAddKernelNode(
+          &command_buffer->hip_graph_nodes[command_buffer->graph_node_count++],
+          command_buffer->hip_graph, &command_buffer->hip_barrier_node,
+          dependency_count, &params),
+      "hipGraphAddKernelNode");
+
+  IREE_HIP_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer);
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Indirect dispatch is not implemented for HIP graph command buffers: every
+// call fails with IREE_STATUS_UNIMPLEMENTED and no arguments are inspected.
+static iree_status_t iree_hal_hip_graph_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "indirect dispatch not yet implemented");
+}
+
static const iree_hal_command_buffer_vtable_t
iree_hal_hip_graph_command_buffer_vtable = {
.destroy = iree_hal_hip_graph_command_buffer_destroy,
@@ -912,4 +1028,7 @@
.dispatch = iree_hal_hip_graph_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_hip_graph_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_hip_graph_command_buffer_dispatch2,
+ .dispatch2_indirect =
+ iree_hal_hip_graph_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c
index 10b5aac..19caae9 100644
--- a/runtime/src/iree/hal/drivers/hip/native_executable.c
+++ b/runtime/src/iree/hal/drivers/hip/native_executable.c
@@ -10,6 +10,7 @@
#include "iree/base/api.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
+#include "iree/hal/drivers/hip/pipeline_layout.h"
#include "iree/hal/drivers/hip/status_util.h"
// flatcc schemas:
@@ -242,16 +243,33 @@
}
if (!iree_status_is_ok(status)) break;
+ // TODO(#18189): embed all of this on a single flatbuffer table
+ // per-export.
+ //
// Package required parameters for kernel launches for each entry point.
iree_hal_hip_kernel_info_t* kernel_info = &executable->entry_points[i];
kernel_info->layout = executable_params->pipeline_layouts[i];
iree_hal_pipeline_layout_retain(kernel_info->layout);
kernel_info->function = function;
+ iree_hal_hip_dispatch_layout_t dispatch_params =
+ iree_hal_hip_pipeline_layout_dispatch_layout(kernel_info->layout);
+ kernel_info->constant_count = dispatch_params.push_constant_count;
+ kernel_info->binding_count = dispatch_params.total_binding_count;
kernel_info->block_size[0] = block_sizes_vec[i].x;
kernel_info->block_size[1] = block_sizes_vec[i].y;
kernel_info->block_size[2] = block_sizes_vec[i].z;
kernel_info->shared_memory_size = shared_memory_sizes_vec[i];
+ if (kernel_info->binding_count >
+ IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT) {
+ status = iree_make_status(
+ IREE_STATUS_RESOURCE_EXHAUSTED,
+ "exceeded available binding slots; requested %u of maximum %d",
+ kernel_info->binding_count,
+ IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT);
+ }
+ if (!iree_status_is_ok(status)) break;
+
// Stash the entry point name in the string table for use when tracing.
IREE_TRACE({
iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name);
diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.h b/runtime/src/iree/hal/drivers/hip/native_executable.h
index 922f343..d2b1a31 100644
--- a/runtime/src/iree/hal/drivers/hip/native_executable.h
+++ b/runtime/src/iree/hal/drivers/hip/native_executable.h
@@ -20,8 +20,12 @@
#endif // __cplusplus
typedef struct iree_hal_hip_kernel_info_t {
+ // TODO(#18189): remove when using simplified bindings.
iree_hal_pipeline_layout_t* layout;
hipFunction_t function;
+ uint32_t constant_count;
+ uint32_t binding_count;
+ // TODO(#18189): add bitfield indicating indirect bindings.
uint32_t block_size[3];
uint32_t shared_memory_size;
diff --git a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
index 6d72330..88a2e83 100644
--- a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
@@ -537,13 +537,13 @@
// Create the ready-list processing worker itself.
iree_thread_create_params_t params;
memset(&params, 0, sizeof(params));
- params.name = IREE_SV("deferque_worker");
+ params.name = IREE_SV("iree-hip-queue-worker");
params.create_suspended = false;
iree_status_t status = iree_thread_create(
(iree_thread_entry_t)iree_hal_hip_worker_execute, actions, params,
actions->host_allocator, &actions->worker_thread);
- params.name = IREE_SV("done_worker");
+ params.name = IREE_SV("iree-hip-queue-completion");
params.create_suspended = false;
if (iree_status_is_ok(status)) {
status = iree_thread_create(
diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
index 0f08727..e4ffac2 100644
--- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
@@ -41,9 +41,8 @@
// Iteratively constructed batch of collective operations.
iree_hal_collective_batch_t collective_batch;
+ // TODO(#18189): drop state used by legacy bindings mechanism.
int32_t push_constants[IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT];
-
- // The current bound descriptor sets.
struct {
hipDeviceptr_t bindings[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT];
} descriptor_sets[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_COUNT];
@@ -632,6 +631,110 @@
"need hip implementation of dispatch indirect");
}
+// Records a direct dispatch of |entry_point| from |executable| on the HIP
+// stream with |constants| and |bindings| provided inline with the call.
+//
+// The executable and every non-NULL binding buffer are inserted into the
+// command buffer resource set. Returns INVALID_ARGUMENT if more bindings are
+// provided than the kernel declares or if fewer constants are provided than
+// the kernel declares, as either would corrupt or over-read the packed kernel
+// parameters. |flags| is not consulted by this implementation.
+static iree_status_t iree_hal_hip_stream_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_hip_stream_command_buffer_t* command_buffer =
+      iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer));
+
+  // Lookup kernel parameters used for side-channeling additional launch
+  // information from the compiler.
+  iree_hal_hip_kernel_info_t kernel_info;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hip_native_executable_entry_point_kernel_info(
+              executable, entry_point, &kernel_info));
+
+  // The parameter storage below is sized from |kernel_info|; reject inputs
+  // that would write past the binding slots or read past |constants|.
+  if (bindings.count > kernel_info.binding_count) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "binding count mismatch, expected at most %u but was provided %" PRIhsz,
+        kernel_info.binding_count, bindings.count);
+  }
+  if (constants.data_length <
+      (iree_host_size_t)kernel_info.constant_count * sizeof(uint32_t)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "constant count mismatch, expected at least %u but was provided "
+        "%" PRIhsz,
+        kernel_info.constant_count, constants.data_length / sizeof(uint32_t));
+  }
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                       &executable));
+
+  // We append push constants to the end of descriptors to form a linear chain
+  // of kernel arguments.
+  iree_host_size_t kernel_params_count =
+      kernel_info.binding_count + kernel_info.constant_count;
+  iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+  // TODO: use packed parameters instead of the indirection mechanism - this
+  // would avoid additional driver overhead to reflect and repack them all.
+  //
+  // Each kernel_params[i] is itself a pointer to the corresponding
+  // element at the *second* inline allocation at the end of the current
+  // segment.
+  iree_host_size_t total_size = kernel_params_length * 2;
+  uint8_t* storage_base = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->arena, total_size,
+                              (void**)&storage_base));
+  void** params_ptr = (void**)storage_base;
+  hipDeviceptr_t* payload_ptr =
+      (hipDeviceptr_t*)((uint8_t*)params_ptr + kernel_params_length);
+  for (iree_host_size_t i = 0; i < kernel_params_count; i++) {
+    // Clear each slot so bindings that were not provided launch as NULL
+    // instead of whatever bytes the arena storage happened to contain.
+    payload_ptr[i] = NULL;
+    params_ptr[i] = &payload_ptr[i];
+  }
+  for (iree_host_size_t i = 0; i < bindings.count; i++) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    hipDeviceptr_t device_ptr = NULL;
+    if (binding->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                           &binding->buffer));
+      hipDeviceptr_t device_buffer = iree_hal_hip_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(binding->buffer));
+      iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+      device_ptr = (uint8_t*)device_buffer + offset + binding->offset;
+    }
+    payload_ptr[i] = device_ptr;
+  }
+
+  // As commented above, each kernel parameter points to a hipDeviceptr_t
+  // slot, which has the size of a pointer on the target machine. Push
+  // constants are only 32-bit values, so on 64-bit machines we must store
+  // them one element at a time, each into its own pointer-sized slot.
+  for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) {
+    *((uint32_t*)params_ptr[kernel_info.binding_count + i]) =
+        ((const uint32_t*)constants.data)[i];
+  }
+
+  // Begin the stream trace zone only once all fallible recording work is
+  // done so that no early return can leave it without a matching end.
+  IREE_HIP_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
+      command_buffer->tracing_context, &command_buffer->tracing_event_list,
+      command_buffer->hip_stream, kernel_info.source_filename.data,
+      kernel_info.source_filename.size, kernel_info.source_line,
+      kernel_info.function_name.data, kernel_info.function_name.size,
+      /*name=*/NULL, 0);
+
+  iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
+      command_buffer->hip_symbols,
+      hipModuleLaunchKernel(
+          kernel_info.function, workgroup_count[0], workgroup_count[1],
+          workgroup_count[2], kernel_info.block_size[0],
+          kernel_info.block_size[1], kernel_info.block_size[2],
+          kernel_info.shared_memory_size, command_buffer->hip_stream,
+          params_ptr, NULL),
+      "hipModuleLaunchKernel");
+
+  IREE_HIP_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,
+                                 &command_buffer->tracing_event_list,
+                                 command_buffer->hip_stream);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Indirect form of dispatch2 taking the workgroup count from |workgroups_ref|.
+// Not implemented for HIP stream command buffers: none of the arguments are
+// inspected and the call always fails with IREE_STATUS_UNIMPLEMENTED.
+static iree_status_t iree_hal_hip_stream_command_buffer_dispatch2_indirect(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "indirect dispatch not yet implemented");
+}
+
static const iree_hal_command_buffer_vtable_t
iree_hal_hip_stream_command_buffer_vtable = {
.destroy = iree_hal_hip_stream_command_buffer_destroy,
@@ -656,4 +759,7 @@
.dispatch = iree_hal_hip_stream_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_hip_stream_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_hip_stream_command_buffer_dispatch2,
+ .dispatch2_indirect =
+ iree_hal_hip_stream_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
index de8642e..3b0a9ba 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@ -78,6 +78,7 @@
// All execution tasks emitted that must execute after |open_barrier|.
iree_task_list_t open_tasks;
+ // TODO(#18189): remove legacy binding state.
// A flattened list of all available descriptor set bindings.
// As descriptor sets are pushed/bound the bindings will be updated to
// represent the fully-translated binding data pointer.
@@ -89,6 +90,7 @@
binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
+ // TODO(#18189): remove legacy push constant state.
// All available push constants updated each time push_constants is called.
// Reset only with the command buffer and otherwise will maintain its values
// during recording to allow for partial push_constants updates.
@@ -930,7 +932,7 @@
cmd->task.local_memory_size =
local_executable->dispatch_attrs
? local_executable->dispatch_attrs[entry_point].local_memory_pages *
- IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
+ IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
: 0;
// Copy only the push constant range used by the executable.
@@ -1013,6 +1015,234 @@
}
//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_dispatch2
+//===----------------------------------------------------------------------===//
+
+// Recorded state for a single dispatch2 command. The fixed-size header below
+// is followed in the same allocation by the variable-length constant and
+// binding tables described at the end of the struct.
+typedef struct iree_hal_task_cmd_dispatch2_t {
+ // Task system dispatch task; must remain the first field.
+ // NOTE(review): "must be first" is assumed from the task header being
+ // emitted via &cmd->task.header - confirm against the task system contract.
+ iree_task_dispatch_t task;
+ // Executable containing the entry point to call.
+ iree_hal_local_executable_t* executable;
+ // Entry point ordinal within |executable|.
+ int32_t ordinal;
+
+ // Total number of available 4 byte push constant values in |push_constants|.
+ uint16_t push_constant_count;
+
+ // Total number of binding base pointers in |binding_ptrs| and
+ // |binding_lengths|. The set is packed densely based on which bindings are
+ // used (known at compile-time).
+ uint16_t binding_count;
+
+ // Following this structure in memory there are 3 tables:
+ // - const uint32_t push_constants[push_constant_count];
+ // - void* binding_ptrs[binding_count];
+ // - const size_t binding_lengths[binding_count];
+} iree_hal_task_cmd_dispatch2_t;
+
+// Tile callback for dispatch2 commands: executes one workgroup.
+// |user_context| is the iree_hal_task_cmd_dispatch2_t recorded for the
+// dispatch. Dispatch and workgroup state are rebuilt on the stack from the
+// command and |tile_context| and then passed to the executable entry point.
+// |pending_submission| is not used. Returns the status of the issued call.
+static iree_status_t iree_hal_task_cmd_dispatch2_tile(
+ void* user_context, const iree_task_tile_context_t* tile_context,
+ iree_task_submission_t* pending_submission) {
+ const iree_hal_task_cmd_dispatch2_t* cmd =
+ (const iree_hal_task_cmd_dispatch2_t*)user_context;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // We could share this across all workgroups in a dispatch and reduce cache
+ // pressure as all cores would be hitting the same hot read-only cache line.
+ // It'd grow the size of iree_hal_task_cmd_dispatch2_t by a few dozen bytes,
+ // though, and so we'd need some profiling to see if it's worth it (fixed
+ // command buffer cost vs potential for saving a cache miss or two).
+ iree_alignas(64) iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+ .workgroup_size_x = tile_context->workgroup_size[0],
+ .workgroup_size_y = tile_context->workgroup_size[1],
+ .workgroup_size_z = tile_context->workgroup_size[2],
+ .push_constant_count = cmd->push_constant_count,
+ .workgroup_count_x = tile_context->workgroup_count[0],
+ .workgroup_count_y = tile_context->workgroup_count[1],
+ .workgroup_count_z = tile_context->workgroup_count[2],
+ .max_concurrency =
+ iree_task_affinity_set_count_ones(cmd->task.header.affinity_set),
+ .binding_count = cmd->binding_count,
+ };
+ // Walk the tables stored directly after the command header in the order
+ // they were written: push constants, binding pointers, binding lengths.
+ uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd);
+ dispatch_state.push_constants = (uint32_t*)cmd_ptr;
+ cmd_ptr += cmd->push_constant_count * sizeof(*dispatch_state.push_constants);
+ dispatch_state.binding_ptrs = (void**)cmd_ptr;
+ cmd_ptr += cmd->binding_count * sizeof(*dispatch_state.binding_ptrs);
+ dispatch_state.binding_lengths = (size_t*)cmd_ptr;
+ cmd_ptr += cmd->binding_count * sizeof(*dispatch_state.binding_lengths);
+
+ // Per-workgroup state: the workgroup ID, processor ID, and the local memory
+ // span provided by the tile context.
+ const iree_alignas(64)
+ iree_hal_executable_workgroup_state_v0_t workgroup_state = {
+ .workgroup_id_x = tile_context->workgroup_xyz[0],
+ .workgroup_id_y = tile_context->workgroup_xyz[1],
+ .workgroup_id_z = tile_context->workgroup_xyz[2],
+ .reserved = 0,
+ .processor_id = tile_context->processor_id,
+ .local_memory = tile_context->local_memory.data,
+ .local_memory_size = (size_t)tile_context->local_memory.data_length,
+ };
+ iree_status_t status = iree_hal_local_executable_issue_call(
+ cmd->executable, cmd->ordinal, &dispatch_state, &workgroup_state,
+ tile_context->worker_id);
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+// Builds a dispatch2 command for |entry_point| of |executable| in the command
+// buffer arena and emits it as an execution task.
+//
+// The number of |constants| and |bindings| must exactly match the counts in
+// the executable dispatch attributes (treated as zero when the executable has
+// no attributes). Count mismatches fail with INVALID_ARGUMENT, and a constant
+// byte length that is not a multiple of 4 fails with FAILED_PRECONDITION;
+// both are checked before any arena storage is allocated. Every binding must
+// reference a non-NULL buffer; each is persistently mapped and then inserted
+// into the command buffer resource set.
+//
+// On success |out_cmd| receives the recorded command.
+static iree_status_t iree_hal_task_command_buffer_build_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings,
+    iree_hal_task_cmd_dispatch2_t** out_cmd) {
+  iree_hal_task_command_buffer_t* command_buffer =
+      iree_hal_task_command_buffer_cast(base_command_buffer);
+
+  iree_hal_local_executable_t* local_executable =
+      iree_hal_local_executable_cast(executable);
+  iree_hal_executable_dispatch_attrs_v0_t dispatch_attrs = {0};
+  if (local_executable->dispatch_attrs) {
+    dispatch_attrs = local_executable->dispatch_attrs[entry_point];
+  }
+
+  // Validate the arguments up front so that a malformed dispatch neither
+  // consumes arena storage nor leaves a half-initialized task behind.
+  // Push constants require 4 byte alignment and if the input buffer is not
+  // aligned we have to fail.
+  if (IREE_UNLIKELY((constants.data_length % sizeof(uint32_t)) != 0)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "constants must be 4-byte aligned");
+  } else if (IREE_UNLIKELY(constants.data_length !=
+                           dispatch_attrs.constant_count * sizeof(uint32_t))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "constant count mismatch, expected %u but was provided %" PRIhsz,
+        (uint32_t)dispatch_attrs.constant_count,
+        constants.data_length / sizeof(uint32_t));
+  }
+  if (IREE_UNLIKELY(bindings.count != dispatch_attrs.binding_count)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "binding count mismatch, expected %u but was provided %" PRIhsz,
+        (uint32_t)dispatch_attrs.binding_count, bindings.count);
+  }
+
+  // The command header is followed by the constant, binding pointer, and
+  // binding length tables in a single arena allocation.
+  iree_hal_task_cmd_dispatch2_t* cmd = NULL;
+  iree_host_size_t total_cmd_size =
+      sizeof(*cmd) + dispatch_attrs.constant_count * sizeof(uint32_t) +
+      dispatch_attrs.binding_count * sizeof(void*) +
+      dispatch_attrs.binding_count * sizeof(size_t);
+  IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->arena,
+                                           total_cmd_size, (void**)&cmd));
+
+  cmd->executable = local_executable;
+  cmd->ordinal = entry_point;
+  cmd->push_constant_count = dispatch_attrs.constant_count;
+  cmd->binding_count = dispatch_attrs.binding_count;
+
+  // TODO(benvanik): expose on API or keep fixed on executable.
+  const uint32_t workgroup_size[3] = {1, 1, 1};
+  // The closure must use the dispatch2 tile function as |cmd| is an
+  // iree_hal_task_cmd_dispatch2_t and not the legacy dispatch command.
+  iree_task_dispatch_initialize(
+      command_buffer->scope,
+      iree_task_make_dispatch_closure(iree_hal_task_cmd_dispatch2_tile,
+                                      (void*)cmd),
+      workgroup_size, workgroup_count, &cmd->task);
+
+  // Tell the task system how much workgroup local memory is required for the
+  // dispatch; each invocation of the entry point will have at least as much
+  // scratch memory available during execution.
+  cmd->task.local_memory_size =
+      dispatch_attrs.local_memory_pages *
+      IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE;
+
+  // Push constants are pulled directly from the args and copied into the
+  // command buffer.
+  uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd);
+  uint32_t* push_constants = (uint32_t*)cmd_ptr;
+  memcpy(push_constants, constants.data,
+         dispatch_attrs.constant_count * sizeof(*push_constants));
+  cmd_ptr += dispatch_attrs.constant_count * sizeof(*push_constants);
+
+  // Produce the dense binding list based on the declared bindings used.
+  //
+  // The binding data pointers are stored directly; the buffers backing them
+  // are inserted into the resource set after the loop.
+  void** binding_ptrs = (void**)cmd_ptr;
+  cmd_ptr += bindings.count * sizeof(*binding_ptrs);
+  size_t* binding_lengths = (size_t*)cmd_ptr;
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
+    iree_hal_buffer_mapping_t buffer_mapping = {{0}};
+    if (IREE_LIKELY(bindings.values[i].buffer)) {
+      // TODO(benvanik): batch insert by getting the resources in their own
+      // list.
+      const iree_hal_buffer_ref_t binding = bindings.values[i];
+      IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+          binding.buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
+          IREE_HAL_MEMORY_ACCESS_ANY, binding.offset, binding.length,
+          &buffer_mapping));
+    } else {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "required binding %" PRIhsz
+          " is NULL; all bindings must have a valid pointer",
+          i);
+    }
+    binding_ptrs[i] = buffer_mapping.contents.data;
+    binding_lengths[i] = buffer_mapping.contents.data_length;
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+      command_buffer->resource_set, bindings.count, bindings.values,
+      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+  *out_cmd = cmd;
+  return iree_hal_task_command_buffer_emit_execution_task(command_buffer,
+                                                          &cmd->task.header);
+}
+
+// Records a direct dispatch of |entry_point| with a fixed |workgroup_count|.
+// Inserts |executable| into the command buffer resource set and then builds
+// and emits the dispatch2 command from the inline |constants| and |bindings|.
+// |flags| is not consulted.
+static iree_status_t iree_hal_task_command_buffer_dispatch2(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ iree_hal_task_command_buffer_t* command_buffer =
+ iree_hal_task_command_buffer_cast(base_command_buffer);
+
+ IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+ command_buffer->resource_set, 1, &executable));
+
+ // The built command is not needed here; only the indirect variant patches it.
+ iree_hal_task_cmd_dispatch2_t* cmd = NULL;
+ return iree_hal_task_command_buffer_build_dispatch2(
+ base_command_buffer, executable, entry_point, workgroup_count, constants,
+ bindings, &cmd);
+}
+
+// Records an indirect dispatch of |entry_point| whose workgroup count is read
+// from |workgroups_ref|. Inserts |executable| and the workgroups buffer into
+// the resource set, persistently maps 3 uint32_t values at
+// |workgroups_ref.offset|, builds the dispatch2 command, and then points the
+// task at the mapped workgroup count and flags it as indirect. |flags| is not
+// consulted.
+static iree_status_t iree_hal_task_command_buffer_dispatch2_indirect(
+ iree_hal_command_buffer_t* base_command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ iree_hal_task_command_buffer_t* command_buffer =
+ iree_hal_task_command_buffer_cast(base_command_buffer);
+
+ const void* resources[2] = {executable, workgroups_ref.buffer};
+ IREE_RETURN_IF_ERROR(
+ iree_hal_resource_set_insert(command_buffer->resource_set, 2, resources));
+
+ // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
+ iree_hal_buffer_mapping_t buffer_mapping = {{0}};
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+ workgroups_ref.buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
+ IREE_HAL_MEMORY_ACCESS_READ, workgroups_ref.offset, 3 * sizeof(uint32_t),
+ &buffer_mapping));
+
+ uint32_t workgroup_count[3] = {0}; // unused with the indirect flag
+ iree_hal_task_cmd_dispatch2_t* cmd = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_task_command_buffer_build_dispatch2(
+ base_command_buffer, executable, entry_point, workgroup_count, constants,
+ bindings, &cmd));
+ // NOTE(review): the task has already been emitted by the build call above
+ // when these fields are patched; presumably safe because nothing executes
+ // during recording - confirm.
+ cmd->task.workgroup_count.ptr = (const uint32_t*)buffer_mapping.contents.data;
+ cmd->task.header.flags |= IREE_TASK_FLAG_DISPATCH_INDIRECT;
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_vtable_t
//===----------------------------------------------------------------------===//
@@ -1036,4 +1266,6 @@
.push_descriptor_set = iree_hal_task_command_buffer_push_descriptor_set,
.dispatch = iree_hal_task_command_buffer_dispatch,
.dispatch_indirect = iree_hal_task_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_task_command_buffer_dispatch2,
+ .dispatch2_indirect = iree_hal_task_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index eaed4f5..fbf6374 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -49,6 +49,7 @@
typedef enum iree_hal_metal_command_segment_action_e {
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER, // Execution/memory barrier command
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH, // Dispatch command
+ IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH2, // Dispatch command with inline constants/bindings
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER, // Fill buffer command
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER, // Copy buffer command
} iree_hal_metal_command_segment_action_t;
@@ -94,6 +95,30 @@
// + Additional inline allocation for holding all bound descriptors.
// + Additional inline allocation for holding all push constants.
+// API data for dispatch2 command segments, where the descriptors and push constants are captured
+// from the arguments of the dispatch call itself.
+typedef struct iree_hal_metal_dispatch2_segment_t {
+ // Compute kernel information--kernel object, pipeline layout, threadgroup size, etc.
+ iree_hal_metal_kernel_params_t kernel_params;
+
+ // Workgroup count information--if |workgroups_buffer| is not nil, then indirect dispatch;
+ // otherwise uses |workgroup_count| for direct dispatch.
+ id<MTLBuffer> workgroups_buffer;
+ iree_device_size_t workgroups_offset;
+ uint32_t workgroup_count[3];
+
+ // The number of descriptors bound for this dispatch.
+ iree_host_size_t descriptor_count;
+ // The list of bound descriptors, pointing to the end of the segment allocation.
+ iree_hal_metal_descriptor_t* descriptors;
+
+ // The number of push constant values.
+ iree_host_size_t push_constant_count;
+ // The list of push constants, pointing to the end of the segment allocation.
+ int32_t* push_constants;
+} iree_hal_metal_dispatch2_segment_t;
+// + Additional inline allocation for holding all bound descriptors.
+// + Additional inline allocation for holding all push constants.
+
// API data for fill buffer command segments.
typedef struct iree_hal_metal_fill_buffer_segment_t {
id<MTLBuffer> target_buffer;
@@ -121,6 +146,7 @@
union {
iree_hal_metal_barrier_segment_t barrier;
iree_hal_metal_dispatch_segment_t dispatch;
+ iree_hal_metal_dispatch2_segment_t dispatch2;
iree_hal_metal_fill_buffer_segment_t fill_buffer;
iree_hal_metal_copy_buffer_segment_t copy_buffer;
};
@@ -1105,6 +1131,183 @@
return iree_ok_status();
}
+// Allocates and enqueues a dispatch2 command segment for |entry_point| of |executable|.
+//
+// The segment, a copy of the |bindings| as descriptors (all in set 0), and a copy of the
+// |constants| are placed in one arena allocation. The executable and every non-NULL binding buffer
+// are inserted into the command buffer resource set. The workgroup count fields are left zeroed
+// for the caller to fill in through |out_segment|. |flags| is not consulted.
+static iree_status_t iree_hal_metal_command_segment_create_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
+    int32_t entry_point, iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags, iree_hal_metal_dispatch2_segment_t** out_segment) {
+  iree_hal_metal_command_buffer_t* command_buffer =
+      iree_hal_metal_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable));
+
+  iree_hal_metal_kernel_params_t kernel_params;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_kernel_library_entry_point_kernel_params(
+                                            executable, entry_point, &kernel_params));
+
+  // Allocate the command segment and keep track of all necessary API data.
+  uint8_t* storage_base = NULL;
+  iree_hal_metal_command_segment_t* segment = NULL;
+  iree_host_size_t descriptor_length = bindings.count * sizeof(iree_hal_metal_descriptor_t);
+  iree_host_size_t total_size = sizeof(*segment) + descriptor_length + constants.data_length;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base));
+
+  // Compose and push the dispatch segment.
+  segment = (iree_hal_metal_command_segment_t*)storage_base;
+  memset(segment, 0, sizeof(*segment));
+  segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH2;
+  iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
+
+  // All fields are written through the |dispatch2| union member, the same member that is
+  // returned to the caller and read back when recording DISPATCH2 segments.
+  iree_hal_metal_dispatch2_segment_t* dispatch = &segment->dispatch2;
+  dispatch->kernel_params = kernel_params;
+
+  // Copy descriptors from all sets to the end of the current segment for later access.
+  const iree_hal_descriptor_set_layout_t* set_layout =
+      iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, 0);
+  dispatch->descriptor_count = bindings.count;
+  dispatch->descriptors = (iree_hal_metal_descriptor_t*)(storage_base + sizeof(*segment));
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    iree_hal_metal_descriptor_t* descriptor = &dispatch->descriptors[i];
+
+    // Bindings are provided densely for the single set 0; the ordinal is the binding index.
+    descriptor->set = 0;
+    descriptor->binding = i;
+    descriptor->buffer = bindings.values[i].buffer;
+    descriptor->offset = bindings.values[i].offset;
+
+    const iree_hal_descriptor_set_layout_binding_t* binding_params =
+        iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding);
+    descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params);
+
+    if (descriptor->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &descriptor->buffer));
+    }
+  }
+
+  // Copy push constants to the end of the current segment for later access.
+  dispatch->push_constant_count = constants.data_length / sizeof(uint32_t);
+  uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length;
+  dispatch->push_constants = (int32_t*)push_constant_ptr;
+  memcpy(push_constant_ptr, constants.data, constants.data_length);
+
+  *out_segment = dispatch;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Encodes a previously created dispatch2 |segment| into the compute encoder of |command_buffer|.
+//
+// Sets the pipeline state, push constants (when present), and an argument buffer carved out of
+// the shared staging buffer that holds all bound descriptors, then issues either a direct or an
+// indirect threadgroup dispatch depending on whether |segment->workgroups_buffer| is nil.
+// Returns the reservation failure status if the staging buffer cannot hold the argument buffer.
+static iree_status_t iree_hal_metal_command_segment_record_dispatch2(
+    iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_dispatch2_segment_t* segment) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Set the compute kernel to dispatch.
+  id<MTLComputeCommandEncoder> compute_encoder =
+      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
+  [compute_encoder setComputePipelineState:segment->kernel_params.pso];
+
+  // Record push constants.
+  if (segment->push_constant_count != 0) {
+    [compute_encoder setBytes:(void*)segment->push_constants
+                       length:segment->push_constant_count * sizeof(int32_t)
+                      atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];
+  }
+
+  // Record argument buffers for all descriptors and record buffer usages.
+  iree_hal_metal_descriptor_t* descriptors = segment->descriptors;
+
+  // Build argument encoder and argument buffer for the current descriptor set.
+  // TODO(antiagainst): Use a cache layer to cache and reuse argument buffers with the same
+  // content, to avoid duplicating overhead.
+  id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer;
+  id<MTLArgumentEncoder> argument_encoder =
+      [segment->kernel_params.function newArgumentEncoderWithBufferIndex:0];  // +1
+  IREE_ASSERT(argument_encoder != nil);
+
+  // Reserve space for the argument buffer from shared staging buffer.
+  iree_byte_span_t reservation = iree_byte_span_empty();
+  uint32_t argument_buffer_offset = 0;
+  iree_status_t reserve_status = iree_hal_metal_staging_buffer_reserve(
+      command_buffer->staging_buffer, argument_encoder.encodedLength, argument_encoder.alignment,
+      &reservation, &argument_buffer_offset);
+  if (!iree_status_is_ok(reserve_status)) {
+    // The encoder was returned retained above; balance it before bailing so it does not leak.
+    [argument_encoder release];  // -1
+    IREE_TRACE_ZONE_END(z0);
+    return reserve_status;
+  }
+  [argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset];
+
+  // Now record all bound buffers belonging to the current set into the argument buffer.
+  for (iree_host_size_t i = 0; i < segment->descriptor_count; ++i) {
+    uint32_t current_binding = descriptors[i].binding;
+    id<MTLBuffer> current_buffer =
+        iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer));
+    iree_host_size_t offset =
+        iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset;
+    [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding];
+
+    // Also record buffer usages.
+    [compute_encoder useResource:current_buffer usage:descriptors[i].usage];
+  }
+  // Record the argument buffer.
+  [compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:0];
+
+  [argument_encoder release];  // -1
+
+  // Record the dispatch, either direct or indirect.
+  uint32_t* workgroup_size = segment->kernel_params.threadgroup_size;
+  if (segment->workgroups_buffer == nil) {
+    // Direct dispatch of a fixed workgroup count.
+    uint32_t* workgroup_count = segment->workgroup_count;
+    [compute_encoder
+         dispatchThreadgroups:MTLSizeMake(workgroup_count[0], workgroup_count[1],
+                                          workgroup_count[2])
+        threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], workgroup_size[2])];
+  } else {
+    // Indirect dispatch using a workgroup count from buffers.
+    [compute_encoder
+        dispatchThreadgroupsWithIndirectBuffer:segment->workgroups_buffer
+                          indirectBufferOffset:segment->workgroups_offset
+                         threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1],
+                                                           workgroup_size[2])];
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Prepares a direct dispatch2: creates the command segment from the inline |constants| and
+// |bindings| and stores the fixed |workgroup_count| on it. The segment's workgroups buffer is
+// not set here.
+static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch2(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
+ int32_t entry_point, const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_metal_dispatch2_segment_t* segment = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_metal_command_segment_create_dispatch2(
+ base_command_buffer, executable, entry_point, constants, bindings, flags, &segment));
+ segment->workgroup_count[0] = workgroup_count[0];
+ segment->workgroup_count[1] = workgroup_count[1];
+ segment->workgroup_count[2] = workgroup_count[2];
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Prepares an indirect dispatch2: creates the command segment from the inline |constants| and
+// |bindings| and records the Metal buffer and offset that the workgroup count is read from.
+static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch2_indirect(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
+ int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+ iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_metal_dispatch2_segment_t* segment = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_metal_command_segment_create_dispatch2(
+ base_command_buffer, executable, entry_point, constants, bindings, flags, &segment));
+ // NOTE(review): only |workgroups_ref.offset| is stored; the byte offset of the buffer within
+ // its allocation is not added and the buffer is not inserted into the resource set here -
+ // confirm this is intended.
+ segment->workgroups_buffer =
+ iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(workgroups_ref.buffer));
+ segment->workgroups_offset = workgroups_ref.offset;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
static iree_status_t iree_hal_metal_command_segment_record(
iree_hal_metal_command_buffer_t* command_buffer) {
IREE_ASSERT_ARGUMENT(command_buffer);
@@ -1121,6 +1324,10 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_metal_command_segment_record_dispatch(command_buffer, &segment->dispatch));
} break;
+ case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH2: {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_dispatch2(
+ command_buffer, &segment->dispatch2));
+ } break;
case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER: {
IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_fill_buffer(
command_buffer, &segment->fill_buffer));
@@ -1180,4 +1387,6 @@
.push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set,
.dispatch = iree_hal_metal_command_buffer_prepare_dispatch,
.dispatch_indirect = iree_hal_metal_command_buffer_prepare_dispatch_indirect,
+ .dispatch2 = iree_hal_metal_command_buffer_prepare_dispatch2,
+ .dispatch2_indirect = iree_hal_metal_command_buffer_prepare_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
index 8bb9413..03dac80 100644
--- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
@@ -749,7 +749,7 @@
VkPipeline pipeline_handle = VK_NULL_HANDLE;
IREE_RETURN_IF_ERROR(
iree_hal_vulkan_native_executable_pipeline_for_entry_point(
- executable, entry_point, &pipeline_handle));
+ executable, entry_point, &pipeline_handle, NULL));
command_buffer->syms->vkCmdBindPipeline(
command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle);
@@ -787,7 +787,7 @@
VkPipeline pipeline_handle = VK_NULL_HANDLE;
IREE_RETURN_IF_ERROR(
iree_hal_vulkan_native_executable_pipeline_for_entry_point(
- executable, entry_point, &pipeline_handle));
+ executable, entry_point, &pipeline_handle, NULL));
command_buffer->syms->vkCmdBindPipeline(
command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle);
@@ -805,6 +805,120 @@
return iree_ok_status();
}
+// Records the per-dispatch state shared by the direct and indirect dispatch2
+// entry points: looks up the pipeline and pipeline layout for |entry_point|,
+// pushes |constants| (when non-empty), retains and binds |bindings|, and
+// finally binds the compute pipeline.
+//
+// |flags| is accepted for signature parity but is not referenced here.
+// Returns the first failing status; commands already recorded before the
+// failure are not rolled back.
+static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch2_bind(
+ iree_hal_vulkan_direct_command_buffer_t* command_buffer,
+ iree_hal_executable_t* executable, int32_t entry_point,
+ iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings,
+ iree_hal_dispatch_flags_t flags) {
+ // Get the compiled and linked pipeline for the specified entry point.
+ // The layout is fetched alongside the pipeline as it is needed both for
+ // push constants and for binding the descriptor set below.
+ VkPipeline pipeline_handle = VK_NULL_HANDLE;
+ iree_hal_pipeline_layout_t* pipeline_layout = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_vulkan_native_executable_pipeline_for_entry_point(
+ executable, entry_point, &pipeline_handle, &pipeline_layout));
+
+ // Update push constants.
+ // The whole constant span is written starting at byte offset 0 for the
+ // compute stage; nothing is recorded when no constants were provided.
+ if (!iree_const_byte_span_is_empty(constants)) {
+ VkPipelineLayout pipeline_layout_handle =
+ iree_hal_vulkan_native_pipeline_layout_handle(pipeline_layout);
+ command_buffer->syms->vkCmdPushConstants(
+ command_buffer->handle, pipeline_layout_handle,
+ VK_SHADER_STAGE_COMPUTE_BIT, (uint32_t)0,
+ (uint32_t)constants.data_length, constants.data);
+ }
+
+ // Retain bound buffers until the command buffer is reset.
+ // The strided insert walks the `buffer` member of each binding ref.
+ IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+ command_buffer->resource_set, bindings.count, bindings.values,
+ offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+ // Either allocate, update, and bind a descriptor set or use push descriptor
+ // sets to use the command buffer pool when supported.
+ // NOTE(review): the literal 0 is presumably the descriptor set ordinal (all
+ // bindings in a single set) - confirm against BindDescriptorSet.
+ IREE_RETURN_IF_ERROR(command_buffer->descriptor_set_arena.BindDescriptorSet(
+ command_buffer->handle, pipeline_layout, 0, bindings.count,
+ bindings.values));
+
+ command_buffer->syms->vkCmdBindPipeline(
+ command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle);
+
+ return iree_ok_status();
+}
+
+// Records a direct dispatch of |entry_point| in |executable| with an explicit
+// |workgroup_count|.
+//
+// The executable is retained by the command buffer's resource set, then
+// |constants| and |bindings| are applied via the shared dispatch2 bind helper
+// before vkCmdDispatch is recorded. Returns the first failing status.
+static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_vulkan_direct_command_buffer_t* command_buffer =
+      iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
+
+  IREE_TRACE({
+    iree_hal_vulkan_source_location_t source_location;
+    iree_hal_vulkan_native_executable_entry_point_source_location(
+        executable, entry_point, &source_location);
+    IREE_VULKAN_TRACE_ZONE_BEGIN_EXTERNAL(
+        command_buffer->tracing_context, command_buffer->handle,
+        source_location.file_name.data, source_location.file_name.size,
+        source_location.line, source_location.func_name.data,
+        source_location.func_name.size, /*name=*/NULL, 0);
+  });
+
+  // From here on a trace zone may be open: do not return early on failure,
+  // thread the status through so the zone is always closed below.
+  iree_status_t status = iree_hal_resource_set_insert(
+      command_buffer->resource_set, 1, &executable);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_vulkan_direct_command_buffer_dispatch2_bind(
+        command_buffer, executable, entry_point, constants, bindings, flags);
+  }
+
+  if (iree_status_is_ok(status)) {
+    command_buffer->syms->vkCmdDispatch(command_buffer->handle,
+                                        workgroup_count[0], workgroup_count[1],
+                                        workgroup_count[2]);
+  }
+
+  // Balance the zone begun above on both the success and the failure path.
+  IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context,
+                             command_buffer->handle);
+
+  return status;
+}
+
+// Records an indirect dispatch of |entry_point| in |executable| where the
+// workgroup count is read by the device from |workgroups_ref|.
+//
+// The executable and the workgroups buffer are retained by the command
+// buffer's resource set, then |constants| and |bindings| are applied via the
+// shared dispatch2 bind helper before vkCmdDispatchIndirect is recorded.
+// Returns the first failing status.
+static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_vulkan_direct_command_buffer_t* command_buffer =
+      iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
+
+  IREE_TRACE({
+    iree_hal_vulkan_source_location_t source_location;
+    iree_hal_vulkan_native_executable_entry_point_source_location(
+        executable, entry_point, &source_location);
+    IREE_VULKAN_TRACE_ZONE_BEGIN_EXTERNAL(
+        command_buffer->tracing_context, command_buffer->handle,
+        source_location.file_name.data, source_location.file_name.size,
+        source_location.line, source_location.func_name.data,
+        source_location.func_name.size, /*name=*/NULL, 0);
+  });
+
+  // From here on a trace zone may be open: do not return early on failure,
+  // thread the status through so the zone is always closed below.
+  const void* resources[2] = {executable, workgroups_ref.buffer};
+  iree_status_t status = iree_hal_resource_set_insert(
+      command_buffer->resource_set, IREE_ARRAYSIZE(resources), resources);
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_vulkan_direct_command_buffer_dispatch2_bind(
+        command_buffer, executable, entry_point, constants, bindings, flags);
+  }
+
+  if (iree_status_is_ok(status)) {
+    // The device-visible offset is the buffer's own byte offset plus the
+    // offset carried by the reference.
+    VkBuffer workgroups_device_buffer =
+        iree_hal_vulkan_buffer_handle(workgroups_ref.buffer);
+    iree_device_size_t workgroups_offset =
+        iree_hal_buffer_byte_offset(workgroups_ref.buffer) +
+        workgroups_ref.offset;
+    command_buffer->syms->vkCmdDispatchIndirect(
+        command_buffer->handle, workgroups_device_buffer, workgroups_offset);
+  }
+
+  // Balance the zone begun above on both the success and the failure path.
+  IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context,
+                             command_buffer->handle);
+
+  return status;
+}
+
namespace {
const iree_hal_command_buffer_vtable_t
iree_hal_vulkan_direct_command_buffer_vtable = {
@@ -836,5 +950,8 @@
/*.dispatch=*/iree_hal_vulkan_direct_command_buffer_dispatch,
/*.dispatch_indirect=*/
iree_hal_vulkan_direct_command_buffer_dispatch_indirect,
+ /*.dispatch2=*/iree_hal_vulkan_direct_command_buffer_dispatch2,
+ /*.dispatch2_indirect=*/
+ iree_hal_vulkan_direct_command_buffer_dispatch2_indirect,
};
} // namespace
diff --git a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
index b6d8dc6..ebfd006 100644
--- a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
@@ -26,6 +26,7 @@
typedef struct iree_hal_vulkan_entry_point_t {
VkPipeline pipeline;
+ iree_hal_pipeline_layout_t* layout;
iree_string_view_t name;
// Optional debug information.
@@ -107,6 +108,11 @@
iree_hal_spirv_ExecutableDef_subgroup_sizes_get(executable_def);
for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count;
++entry_ordinal) {
+ iree_hal_pipeline_layout_t* pipeline_layout =
+ executable_params->pipeline_layouts[entry_ordinal];
+ iree_hal_pipeline_layout_retain(pipeline_layout);
+ out_entry_points[entry_ordinal].layout = pipeline_layout;
+
VkComputePipelineCreateInfo* create_info = &create_infos[entry_ordinal];
create_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
create_info->pNext = NULL;
@@ -121,8 +127,8 @@
} else {
create_info->flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT;
}
- create_info->layout = iree_hal_vulkan_native_pipeline_layout_handle(
- executable_params->pipeline_layouts[entry_ordinal]);
+ create_info->layout =
+ iree_hal_vulkan_native_pipeline_layout_handle(pipeline_layout);
create_info->basePipelineHandle = VK_NULL_HANDLE;
create_info->basePipelineIndex = 0;
@@ -472,6 +478,7 @@
for (iree_host_size_t i = 0; i < executable->entry_point_count; ++i) {
iree_hal_vulkan_destroy_pipeline(executable->logical_device,
executable->entry_points[i].pipeline);
+ iree_hal_pipeline_layout_release(executable->entry_points[i].layout);
}
iree_allocator_free(host_allocator, executable);
@@ -528,7 +535,8 @@
iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point(
iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal,
- VkPipeline* out_pipeline_handle) {
+ VkPipeline* out_pipeline_handle,
+ iree_hal_pipeline_layout_t** out_pipeline_layout) {
iree_hal_vulkan_native_executable_t* executable =
iree_hal_vulkan_native_executable_cast(base_executable);
if (entry_ordinal >= executable->entry_point_count) {
@@ -537,6 +545,9 @@
entry_ordinal);
}
*out_pipeline_handle = executable->entry_points[entry_ordinal].pipeline;
+ if (out_pipeline_layout) {
+ *out_pipeline_layout = executable->entry_points[entry_ordinal].layout;
+ }
return iree_ok_status();
}
diff --git a/runtime/src/iree/hal/drivers/vulkan/native_executable.h b/runtime/src/iree/hal/drivers/vulkan/native_executable.h
index da6a845..248db1d 100644
--- a/runtime/src/iree/hal/drivers/vulkan/native_executable.h
+++ b/runtime/src/iree/hal/drivers/vulkan/native_executable.h
@@ -43,7 +43,8 @@
// Returns the cached VkPipeline for the given executable |entry_ordinal|.
iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point(
iree_hal_executable_t* executable, iree_host_size_t entry_ordinal,
- VkPipeline* out_pipeline_handle);
+ VkPipeline* out_pipeline_handle,
+ iree_hal_pipeline_layout_t** out_pipeline_layout);
#ifdef __cplusplus
} // extern "C"
diff --git a/runtime/src/iree/hal/executable_cache.h b/runtime/src/iree/hal/executable_cache.h
index bee9bf9..435f01d 100644
--- a/runtime/src/iree/hal/executable_cache.h
+++ b/runtime/src/iree/hal/executable_cache.h
@@ -92,6 +92,9 @@
// to any executable created using it still held by the caller.
iree_const_byte_span_t executable_data;
+ // TODO(#18154): drop pipeline layouts with simplified bindings. Allowed to be
+ // empty for now on targets that support simplified bindings.
+ //
// A set of pipeline layouts for each entry point in the executable.
// The order matches that produced by the compiler. As multiple entry points
// may share the same layout some entries in this list may reference the same
diff --git a/runtime/src/iree/hal/local/executable_library.h b/runtime/src/iree/hal/local/executable_library.h
index d917f0f..d45b477 100644
--- a/runtime/src/iree/hal/local/executable_library.h
+++ b/runtime/src/iree/hal/local/executable_library.h
@@ -372,19 +372,27 @@
// Bytes per page of workgroup local memory.
// This is chosen to match the common page size of devices.
-#define IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE 4096
+#define IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE 4096
+
+// Maximum number of constants that can be used by a single dispatch.
+#define IREE_HAL_EXECUTABLE_MAX_CONSTANT_COUNT 64
+// Maximum number of bindings that can be used by a single dispatch.
+#define IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT 64
// Attributes for exported dispatch functions defining how they are to be
// executed. 0 defaults are well-specified and the entire attributes table may
// be omitted if no dispatch functions require these fields.
typedef struct iree_hal_executable_dispatch_attrs_v0_t {
- // Number of IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE byte pages (or 0)
- // indicating how much workgroup local memory is required for the dispatch.
- // This is the size of the buffer referenced by the `local_memory` argument.
+ // Number of IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE byte pages
+ // (or 0) indicating how much workgroup local memory is required for the
+ // dispatch. This is the size of the buffer referenced by the `local_memory`
+ // argument.
uint16_t local_memory_pages;
- // Must be 0. May be used in the future for flags controlling the dispatch
- // behavior/synchronization requirements.
- uint16_t reserved;
+ // Total number of 32-bit constants used by the dispatch.
+ uint8_t constant_count;
+ // Total number of bindings used by the dispatch.
+ uint8_t binding_count;
+ // TODO(#18189): add ~8 uint64_t fields for binding bits (readonly/indirect).
} iree_hal_executable_dispatch_attrs_v0_t;
static_assert(sizeof(iree_hal_executable_dispatch_attrs_v0_t) == 4, "uint32_t");
diff --git a/runtime/src/iree/hal/local/executable_library_benchmark.c b/runtime/src/iree/hal/local/executable_library_benchmark.c
index 95403de..d87149d 100644
--- a/runtime/src/iree/hal/local/executable_library_benchmark.c
+++ b/runtime/src/iree/hal/local/executable_library_benchmark.c
@@ -186,7 +186,7 @@
local_executable->dispatch_attrs
? local_executable->dispatch_attrs[FLAG_entry_point]
.local_memory_pages *
- IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
+ IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
: 0;
if (local_memory_size > 0) {
IREE_RETURN_IF_ERROR(iree_allocator_malloc(
diff --git a/runtime/src/iree/hal/local/executable_library_demo.c b/runtime/src/iree/hal/local/executable_library_demo.c
index af18875..300d645 100644
--- a/runtime/src/iree/hal/local/executable_library_demo.c
+++ b/runtime/src/iree/hal/local/executable_library_demo.c
@@ -66,9 +66,13 @@
static const iree_hal_executable_dispatch_attrs_v0_t entry_attrs[2] = {
{
.local_memory_pages = 0,
+ .constant_count = 1,
+ .binding_count = 2,
},
{
.local_memory_pages = 0,
+ .constant_count = 0,
+ .binding_count = 0,
},
};
// Names for each entry point.
diff --git a/runtime/src/iree/hal/local/executable_library_util.c b/runtime/src/iree/hal/local/executable_library_util.c
index 5dbe51c..b2d1165 100644
--- a/runtime/src/iree/hal/local/executable_library_util.c
+++ b/runtime/src/iree/hal/local/executable_library_util.c
@@ -39,6 +39,30 @@
executable_params->constant_count);
}
+ // If dispatch attributes are present validate they are in range.
+ if (library->exports.attrs) {
+ for (uint32_t i = 0; i < library->exports.count; ++i) {
+ const iree_hal_executable_dispatch_attrs_v0_t dispatch_attrs =
+ library->exports.attrs[i];
+ if (dispatch_attrs.constant_count >
+ IREE_HAL_EXECUTABLE_MAX_CONSTANT_COUNT) {
+ return iree_make_status(
+ IREE_STATUS_OUT_OF_RANGE,
+ "dispatch requiring %u constants exceeds limit of %d",
+ dispatch_attrs.constant_count,
+ IREE_HAL_EXECUTABLE_MAX_CONSTANT_COUNT);
+ }
+ if (dispatch_attrs.binding_count >
+ IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT) {
+ return iree_make_status(
+ IREE_STATUS_OUT_OF_RANGE,
+ "dispatch requiring %u bindings exceeds limit of %d",
+ dispatch_attrs.binding_count,
+ IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT);
+ }
+ }
+ }
+
return iree_ok_status();
}
diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c
index 3de7c60..2e0465c 100644
--- a/runtime/src/iree/hal/local/inline_command_buffer.c
+++ b/runtime/src/iree/hal/local/inline_command_buffer.c
@@ -28,24 +28,18 @@
iree_allocator_t host_allocator;
struct {
+ // TODO(#18189): remove legacy bindings state.
+ //
// A flattened list of all available descriptor set bindings.
// As descriptor sets are pushed/bound the bindings will be updated to
// represent the fully-translated binding data pointer.
- //
- // TODO(benvanik): support proper mapping semantics and track the
- // iree_hal_buffer_mapping_t and map/unmap where appropriate.
void* full_bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
size_t full_binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
- // Packed bindings scratch space used during dispatch. Executable bindings
- // are packed into a dense list with unused bindings removed.
- void* packed_bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
- IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
- size_t packed_binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
- IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
-
+ // TODO(#18189): remove legacy push constant state.
+ //
// All available push constants updated each time push_constants is called.
// Reset only with the command buffer and otherwise will maintain its values
// during recording to allow for partial push_constants updates.
@@ -55,6 +49,10 @@
// Individual dispatches must populate the dynamically changing fields like
// push_constant_count and binding_count.
iree_alignas(64) iree_hal_executable_dispatch_state_v0_t dispatch_state;
+ // Persistent storage for binding pointers used by dispatch_state.
+ void* binding_ptr_storage[IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT];
+ // Persistent storage for binding lengths used by dispatch_state.
+ size_t binding_length_storage[IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT];
// An opaque tag used to reduce the cost of processor ID queries.
iree_cpu_processor_tag_t processor_tag;
@@ -80,9 +78,9 @@
iree_hal_executable_dispatch_state_v0_t* dispatch_state =
&command_buffer->state.dispatch_state;
dispatch_state->push_constants = command_buffer->state.push_constants;
- dispatch_state->binding_ptrs = command_buffer->state.packed_bindings;
+ dispatch_state->binding_ptrs = command_buffer->state.binding_ptr_storage;
dispatch_state->binding_lengths =
- command_buffer->state.packed_binding_lengths;
+ command_buffer->state.binding_length_storage;
}
iree_host_size_t iree_hal_inline_command_buffer_size(
@@ -461,7 +459,7 @@
iree_host_size_t local_memory_size =
local_executable->dispatch_attrs
? local_executable->dispatch_attrs[entry_point].local_memory_pages *
- IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
+ IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
: 0;
// Update the ID of the processor we are running on.
@@ -489,6 +487,7 @@
// only allow the dispatch to read what we know is initialized based on the
// layout.
dispatch_state->push_constant_count = local_layout->push_constants;
+ dispatch_state->push_constants = command_buffer->state.push_constants;
// Produce the dense binding list based on the declared bindings used.
// This allows us to change the descriptor sets and bindings counts supported
@@ -548,6 +547,123 @@
return status;
}
+// Executes |entry_point| of |executable| inline on the calling thread.
+//
+// |constants| are referenced directly (not copied) as the dispatch push
+// constants and |bindings| are mapped and stored into the command buffer's
+// binding storage. Both counts must exactly match the dispatch attributes
+// recorded on the executable for the entry point.
+//
+// Returns FAILED_PRECONDITION if |constants| is not 4-byte aligned (in either
+// address or length) or if any binding has a NULL buffer, and
+// INVALID_ARGUMENT if the constant or binding counts do not match the
+// dispatch attributes. Otherwise returns the status of the dispatch itself.
+static iree_status_t iree_hal_inline_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_inline_command_buffer_t* command_buffer =
+      iree_hal_inline_command_buffer_cast(base_command_buffer);
+
+  iree_hal_local_executable_t* local_executable =
+      iree_hal_local_executable_cast(executable);
+
+  // Executables without an attribute table are treated as having all-zero
+  // attributes (no local memory, no constants, no bindings).
+  iree_hal_executable_dispatch_attrs_v0_t dispatch_attrs = {0};
+  if (local_executable->dispatch_attrs) {
+    dispatch_attrs = local_executable->dispatch_attrs[entry_point];
+  }
+  const iree_host_size_t local_memory_size =
+      dispatch_attrs.local_memory_pages *
+      IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE;
+
+  // Update the ID of the processor we are running on.
+  // We don't know how much time has passed since we last updated as we are
+  // running inline with the user program; if we knew we were going to be
+  // handling a batch of dispatches we could reduce the amount of times we call
+  // this - but that's what the task system is for.
+  iree_hal_inline_command_buffer_update_processor_id(command_buffer);
+
+  iree_hal_executable_dispatch_state_v0_t* dispatch_state =
+      &command_buffer->state.dispatch_state;
+
+  // TODO(benvanik): expose on API or keep fixed on executable.
+  dispatch_state->workgroup_size_x = 1;
+  dispatch_state->workgroup_size_y = 1;
+  dispatch_state->workgroup_size_z = 1;
+  dispatch_state->workgroup_count_x = workgroup_count[0];
+  dispatch_state->workgroup_count_y = workgroup_count[1];
+  dispatch_state->workgroup_count_z = workgroup_count[2];
+
+  // Single-threaded.
+  dispatch_state->max_concurrency = 1;
+
+  // Push constants are pulled directly from the args. Note that we require 4
+  // byte alignment and if the input buffer is not aligned we have to fail.
+  // Both the length and the base address are checked as the data pointer is
+  // reinterpreted as a uint32_t array below.
+  if (IREE_UNLIKELY((constants.data_length % sizeof(uint32_t)) != 0 ||
+                    ((uintptr_t)constants.data % sizeof(uint32_t)) != 0)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "constants must be 4-byte aligned");
+  } else if (IREE_UNLIKELY(constants.data_length !=
+                           dispatch_attrs.constant_count * sizeof(uint32_t))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "constant count mismatch, expected %u but was provided %" PRIhsz,
+        (uint32_t)dispatch_attrs.constant_count,
+        constants.data_length / sizeof(uint32_t));
+  }
+  dispatch_state->push_constant_count = dispatch_attrs.constant_count;
+  dispatch_state->push_constants = (const uint32_t*)constants.data;
+
+  // Produce the dense binding list based on the declared bindings used.
+  //
+  // Note that we are just directly setting the binding data pointers here with
+  // no ownership/retaining/etc - it's part of the HAL contract that buffers are
+  // kept valid for the duration they may be in use.
+  if (IREE_UNLIKELY(bindings.count != dispatch_attrs.binding_count)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "binding count mismatch, expected %u but was provided %" PRIhsz,
+        (uint32_t)dispatch_attrs.binding_count, bindings.count);
+  }
+  dispatch_state->binding_count = bindings.count;
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
+    iree_hal_buffer_mapping_t buffer_mapping = {{0}};
+    if (IREE_LIKELY(bindings.values[i].buffer)) {
+      IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+          bindings.values[i].buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
+          IREE_HAL_MEMORY_ACCESS_ANY, bindings.values[i].offset,
+          bindings.values[i].length, &buffer_mapping));
+    } else {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "required binding %" PRIhsz
+          " is NULL; all bindings must have a valid pointer",
+          i);
+    }
+    command_buffer->state.binding_ptr_storage[i] = buffer_mapping.contents.data;
+    command_buffer->state.binding_length_storage[i] =
+        buffer_mapping.contents.data_length;
+  }
+
+  // TODO(benvanik): plumb through an arena or fixed-size reservation to use.
+  // For now when deploying to devices where you want something like the
+  // inline command buffer you probably don't want 256KB of transient memory
+  // getting allocated and retained implicitly - this should be a compiler
+  // option. For now we just malloc here to make things work and strongly
+  // encourage the kind of user who wants synchronous inline execution to not
+  // also want tons of scratch memory.
+  iree_byte_span_t local_memory = iree_make_byte_span(NULL, local_memory_size);
+  if (local_memory_size > 0) {
+    IREE_RETURN_IF_ERROR(iree_allocator_malloc(command_buffer->host_allocator,
+                                               local_memory_size,
+                                               (void**)&local_memory.data));
+  }
+
+  // Since we are running on a borrowed thread, we know nothing about the
+  // floating point state. Reset it.
+  iree_fpu_state_t fpu_state =
+      iree_fpu_state_push(IREE_FPU_STATE_FLAG_FLUSH_DENORMALS_TO_ZERO);
+  iree_status_t status = iree_hal_local_executable_issue_dispatch_inline(
+      local_executable, entry_point, dispatch_state,
+      command_buffer->state.processor_id, local_memory);
+  iree_fpu_state_pop(fpu_state);
+
+  // The scratch memory only lives for the duration of this dispatch.
+  if (local_memory.data) {
+    iree_allocator_free(command_buffer->host_allocator, local_memory.data);
+  }
+  return status;
+}
+
typedef union iree_hal_vec3_t {
struct {
uint32_t x;
@@ -574,6 +690,24 @@
workgroup_count.y, workgroup_count.z, flags);
}
+// Indirect flavor of the inline dispatch2: the three uint32_t workgroup
+// counts are read from |workgroups_ref| on the host at call time and then
+// forwarded to the direct dispatch2 implementation.
+static iree_status_t iree_hal_inline_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
+  iree_hal_buffer_mapping_t workgroups_mapping = {{0}};
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+      workgroups_ref.buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
+      IREE_HAL_MEMORY_ACCESS_READ, workgroups_ref.offset, 3 * sizeof(uint32_t),
+      &workgroups_mapping));
+
+  // Snapshot the counts by value before issuing the dispatch.
+  const iree_hal_vec3_t* mapped_count =
+      (const iree_hal_vec3_t*)workgroups_mapping.contents.data;
+  iree_hal_vec3_t resolved_count = *mapped_count;
+
+  return iree_hal_inline_command_buffer_dispatch2(
+      base_command_buffer, executable, entry_point, resolved_count.value,
+      constants, bindings, flags);
+}
+
//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_vtable_t
//===----------------------------------------------------------------------===//
@@ -599,4 +733,6 @@
iree_hal_inline_command_buffer_push_descriptor_set,
.dispatch = iree_hal_inline_command_buffer_dispatch,
.dispatch_indirect = iree_hal_inline_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_inline_command_buffer_dispatch2,
+ .dispatch2_indirect = iree_hal_inline_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c b/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c
index 265c0b6..2675f8e 100644
--- a/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c
+++ b/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c
@@ -302,6 +302,7 @@
.linkage = IREE_VM_FUNCTION_LINKAGE_EXPORT,
.ordinal = executable->entry_fn_ordinals[i],
};
+
iree_string_view_t local_memory_str =
iree_vm_function_lookup_attr_by_name(
&entry_fn, iree_make_cstring_view("local_memory"));
@@ -309,8 +310,26 @@
if (!iree_string_view_is_empty(local_memory_str)) {
iree_string_view_atoi_uint32(local_memory_str, &local_memory_size);
}
- local_memory_size /= IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE;
+ local_memory_size /= IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE;
dispatch_attrs[i].local_memory_pages = (uint16_t)local_memory_size;
+
+ iree_string_view_t constant_count_str =
+ iree_vm_function_lookup_attr_by_name(
+ &entry_fn, iree_make_cstring_view("constant_count"));
+ uint32_t constant_count = 0;
+ if (!iree_string_view_is_empty(constant_count_str)) {
+ iree_string_view_atoi_uint32(constant_count_str, &constant_count);
+ }
+ dispatch_attrs[i].constant_count = (uint8_t)constant_count;
+
+ iree_string_view_t binding_count_str =
+ iree_vm_function_lookup_attr_by_name(
+ &entry_fn, iree_make_cstring_view("binding_count"));
+ uint32_t binding_count = 0;
+ if (!iree_string_view_is_empty(binding_count_str)) {
+ iree_string_view_atoi_uint32(binding_count_str, &binding_count);
+ }
+ dispatch_attrs[i].binding_count = (uint8_t)binding_count;
}
}
diff --git a/runtime/src/iree/hal/local/local_executable.h b/runtime/src/iree/hal/local/local_executable.h
index 6eeb038..b6e2445 100644
--- a/runtime/src/iree/hal/local/local_executable.h
+++ b/runtime/src/iree/hal/local/local_executable.h
@@ -31,8 +31,8 @@
// Defines per-entry point how much workgroup local memory is required.
// Contains entries with 0 to indicate no local memory is required or >0 in
- // units of IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE for the minimum amount
- // of memory required by the function.
+ // units of IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE for the
+ // minimum amount of memory required by the function.
const iree_hal_executable_dispatch_attrs_v0_t* dispatch_attrs;
// Execution environment.
diff --git a/runtime/src/iree/hal/pipeline_layout.h b/runtime/src/iree/hal/pipeline_layout.h
index bd15bb1..d090868 100644
--- a/runtime/src/iree/hal/pipeline_layout.h
+++ b/runtime/src/iree/hal/pipeline_layout.h
@@ -43,9 +43,15 @@
// A bitmask of flags controlling the behavior of a descriptor.
enum iree_hal_descriptor_flag_bits_t {
IREE_HAL_DESCRIPTOR_FLAG_NONE = 0u,
+
// Indicates that the binding is treated as immutable within all dispatches
// using it.
IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY = 1u << 0,
+
+ // Indicates the descriptor is 'bindless' and passed via implementation-
+ // specific parameter buffers stored in memory instead of API-level calls.
+ // Ignored by implementations that don't have a concept of indirect bindings.
+ IREE_HAL_DESCRIPTOR_FLAG_INDIRECT = 1u << 1,
};
typedef uint32_t iree_hal_descriptor_flags_t;
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c
index 49ec334..a1c92bf 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.c
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -27,6 +27,8 @@
IREE_HAL_CMD_PUSH_DESCRIPTOR_SET,
IREE_HAL_CMD_DISPATCH,
IREE_HAL_CMD_DISPATCH_INDIRECT,
+ IREE_HAL_CMD_DISPATCH2,
+ IREE_HAL_CMD_DISPATCH2_INDIRECT,
} iree_hal_cmd_type_t;
// Header prefixed to all commands, forming a linked-list.
@@ -855,6 +857,162 @@
}
//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISPATCH2
+//===----------------------------------------------------------------------===//
+
+// Recorded form of a direct dispatch2 command.
+// The constant data and binding refs referenced by |constants| and |bindings|
+// are copied by the recording function into storage that trails this struct
+// within the same command allocation.
+typedef struct iree_hal_cmd_dispatch2_t {
+ // Common command header linking this command into the command list.
+ iree_hal_cmd_header_t header;
+ // Executable to dispatch; retained via the command buffer resource set.
+ iree_hal_executable_t* executable;
+ // Ordinal of the entry point within |executable|.
+ int32_t entry_point;
+ // Workgroup count captured at record time.
+ uint32_t workgroup_count[3];
+ // Span over the copied constant bytes stored after this struct.
+ iree_const_byte_span_t constants;
+ // List over the copied binding refs stored after the constants.
+ iree_hal_buffer_ref_list_t bindings;
+ // Dispatch flags forwarded unchanged on replay.
+ iree_hal_dispatch_flags_t flags;
+} iree_hal_cmd_dispatch2_t;
+
+// Records a direct dispatch2 command for later replay.
+//
+// The executable and all binding buffers are retained by the command buffer's
+// resource set. |constants| and |bindings| are copied into storage trailing
+// the command struct so the recorded command references its own copies rather
+// than the caller's arrays. Returns the first failing status.
+static iree_status_t iree_hal_deferred_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_deferred_command_buffer_t* command_buffer =
+      iree_hal_deferred_command_buffer_cast(base_command_buffer);
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+      command_buffer->resource_set, 1, &executable));
+
+  // Allocation layout: [cmd struct][constants, padded to max align][bindings].
+  const iree_host_size_t constants_storage_size =
+      iree_host_align(constants.data_length, iree_max_align_t);
+  const iree_host_size_t bindings_storage_size =
+      bindings.count * sizeof(bindings.values[0]);
+
+  iree_hal_cmd_dispatch2_t* cmd = NULL;
+  iree_host_size_t total_size =
+      sizeof(*cmd) + constants_storage_size + bindings_storage_size;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      &command_buffer->cmd_list, IREE_HAL_CMD_DISPATCH2, total_size,
+      (void**)&cmd));
+  cmd->executable = executable;
+  cmd->entry_point = entry_point;
+  memcpy(cmd->workgroup_count, workgroup_count, sizeof(cmd->workgroup_count));
+  cmd->flags = flags;
+
+  uint8_t* cmd_ptr = (uint8_t*)cmd;
+  cmd_ptr += sizeof(*cmd);
+
+  // An empty span may carry a NULL data pointer and memcpy must not be called
+  // with a NULL source even when copying zero bytes.
+  if (constants.data_length > 0) {
+    memcpy(cmd_ptr, constants.data, constants.data_length);
+  }
+  cmd->constants = iree_make_const_byte_span(cmd_ptr, constants.data_length);
+  cmd_ptr += constants_storage_size;
+
+  // Same NULL-source consideration applies to an empty binding list.
+  cmd->bindings.count = bindings.count;
+  if (bindings.count > 0) {
+    memcpy(cmd_ptr, bindings.values, bindings_storage_size);
+  }
+  cmd->bindings.values = (iree_hal_buffer_ref_t*)cmd_ptr;
+
+  // Retain every bound buffer by walking the `buffer` member of each ref.
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+      command_buffer->resource_set, bindings.count, bindings.values,
+      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+  return iree_ok_status();
+}
+
+// Replays a recorded direct dispatch2 command onto |target_command_buffer|.
+// Each recorded binding ref is first resolved through |binding_table| into
+// scratch storage obtained from iree_alloca; the resolved list is then
+// forwarded together with the recorded constants, workgroup count and flags.
+static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch2(
+    iree_hal_command_buffer_t* target_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    const iree_hal_cmd_dispatch2_t* cmd) {
+  const iree_host_size_t binding_count = cmd->bindings.count;
+  iree_hal_buffer_ref_t* resolved_refs = (iree_hal_buffer_ref_t*)iree_alloca(
+      binding_count * sizeof(iree_hal_buffer_ref_t));
+
+  for (iree_host_size_t ordinal = 0; ordinal < binding_count; ++ordinal) {
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+        binding_table, cmd->bindings.values[ordinal],
+        &resolved_refs[ordinal]));
+  }
+
+  const iree_hal_buffer_ref_list_t resolved_list = {
+      .count = binding_count,
+      .values = resolved_refs,
+  };
+  return iree_hal_command_buffer_dispatch2(
+      target_command_buffer, cmd->executable, cmd->entry_point,
+      cmd->workgroup_count, cmd->constants, resolved_list, cmd->flags);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISPATCH2_INDIRECT
+//===----------------------------------------------------------------------===//
+
+// Recorded form of an indirect dispatch2 command.
+// The constant data and binding refs referenced by |constants| and |bindings|
+// are copied by the recording function into storage that trails this struct
+// within the same command allocation.
+typedef struct iree_hal_cmd_dispatch2_indirect_t {
+ // Common command header linking this command into the command list.
+ iree_hal_cmd_header_t header;
+ // Executable to dispatch; retained via the command buffer resource set.
+ iree_hal_executable_t* executable;
+ // Ordinal of the entry point within |executable|.
+ int32_t entry_point;
+ // Reference to the buffer holding the workgroup count; resolved through the
+ // binding table on replay.
+ iree_hal_buffer_ref_t workgroups_ref;
+ // Span over the copied constant bytes stored after this struct.
+ iree_const_byte_span_t constants;
+ // List over the copied binding refs stored after the constants.
+ iree_hal_buffer_ref_list_t bindings;
+ // Dispatch flags forwarded unchanged on replay.
+ iree_hal_dispatch_flags_t flags;
+} iree_hal_cmd_dispatch2_indirect_t;
+
+// Records an indirect dispatch2 command for later replay.
+//
+// The executable, the workgroups buffer (when one is directly referenced) and
+// all binding buffers are retained by the command buffer's resource set.
+// |constants| and |bindings| are copied into storage trailing the command
+// struct so the recorded command references its own copies rather than the
+// caller's arrays. Returns the first failing status.
+static iree_status_t iree_hal_deferred_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_deferred_command_buffer_t* command_buffer =
+      iree_hal_deferred_command_buffer_cast(base_command_buffer);
+
+  // The workgroups ref may not carry a buffer pointer; only retain it when
+  // one is present.
+  iree_host_size_t resource_count = 0;
+  const void* resources[2] = {NULL, NULL};
+  resources[resource_count++] = executable;
+  if (workgroups_ref.buffer) {
+    resources[resource_count++] = workgroups_ref.buffer;
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+      command_buffer->resource_set, resource_count, resources));
+
+  // Allocation layout: [cmd struct][constants, padded to max align][bindings].
+  const iree_host_size_t constants_storage_size =
+      iree_host_align(constants.data_length, iree_max_align_t);
+  const iree_host_size_t bindings_storage_size =
+      bindings.count * sizeof(bindings.values[0]);
+
+  iree_hal_cmd_dispatch2_indirect_t* cmd = NULL;
+  iree_host_size_t total_size =
+      sizeof(*cmd) + constants_storage_size + bindings_storage_size;
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      &command_buffer->cmd_list, IREE_HAL_CMD_DISPATCH2_INDIRECT, total_size,
+      (void**)&cmd));
+  cmd->executable = executable;
+  cmd->entry_point = entry_point;
+  cmd->workgroups_ref = workgroups_ref;
+  cmd->flags = flags;
+
+  uint8_t* cmd_ptr = (uint8_t*)cmd;
+  cmd_ptr += sizeof(*cmd);
+
+  // An empty span may carry a NULL data pointer and memcpy must not be called
+  // with a NULL source even when copying zero bytes.
+  if (constants.data_length > 0) {
+    memcpy(cmd_ptr, constants.data, constants.data_length);
+  }
+  cmd->constants = iree_make_const_byte_span(cmd_ptr, constants.data_length);
+  cmd_ptr += constants_storage_size;
+
+  // Same NULL-source consideration applies to an empty binding list.
+  cmd->bindings.count = bindings.count;
+  if (bindings.count > 0) {
+    memcpy(cmd_ptr, bindings.values, bindings_storage_size);
+  }
+  cmd->bindings.values = (iree_hal_buffer_ref_t*)cmd_ptr;
+
+  // Retain every bound buffer by walking the `buffer` member of each ref.
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+      command_buffer->resource_set, bindings.count, bindings.values,
+      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+  return iree_ok_status();
+}
+
+// Replays a recorded indirect dispatch2 command onto |target_command_buffer|.
+// The workgroups ref is resolved through |binding_table| first, followed by
+// each recorded binding ref (into scratch storage obtained from iree_alloca);
+// the resolved values are then forwarded with the recorded constants and
+// flags.
+static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch2_indirect(
+    iree_hal_command_buffer_t* target_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    const iree_hal_cmd_dispatch2_indirect_t* cmd) {
+  iree_hal_buffer_ref_t resolved_workgroups_ref;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+      binding_table, cmd->workgroups_ref, &resolved_workgroups_ref));
+
+  const iree_host_size_t binding_count = cmd->bindings.count;
+  iree_hal_buffer_ref_t* resolved_refs = (iree_hal_buffer_ref_t*)iree_alloca(
+      binding_count * sizeof(iree_hal_buffer_ref_t));
+  for (iree_host_size_t ordinal = 0; ordinal < binding_count; ++ordinal) {
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+        binding_table, cmd->bindings.values[ordinal],
+        &resolved_refs[ordinal]));
+  }
+
+  const iree_hal_buffer_ref_list_t resolved_list = {
+      .count = binding_count,
+      .values = resolved_refs,
+  };
+  return iree_hal_command_buffer_dispatch2_indirect(
+      target_command_buffer, cmd->executable, cmd->entry_point,
+      resolved_workgroups_ref, cmd->constants, resolved_list, cmd->flags);
+}
+
+//===----------------------------------------------------------------------===//
// Dynamic replay dispatch
//===----------------------------------------------------------------------===//
@@ -885,6 +1043,10 @@
iree_hal_deferred_command_buffer_apply_dispatch,
[IREE_HAL_CMD_DISPATCH_INDIRECT] = (iree_hal_cmd_apply_fn_t)
iree_hal_deferred_command_buffer_apply_dispatch_indirect,
+ [IREE_HAL_CMD_DISPATCH2] = (iree_hal_cmd_apply_fn_t)
+ iree_hal_deferred_command_buffer_apply_dispatch2,
+ [IREE_HAL_CMD_DISPATCH2_INDIRECT] = (iree_hal_cmd_apply_fn_t)
+ iree_hal_deferred_command_buffer_apply_dispatch2_indirect,
};
IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_apply(
@@ -943,4 +1105,7 @@
iree_hal_deferred_command_buffer_push_descriptor_set,
.dispatch = iree_hal_deferred_command_buffer_dispatch,
.dispatch_indirect = iree_hal_deferred_command_buffer_dispatch_indirect,
+ .dispatch2 = iree_hal_deferred_command_buffer_dispatch2,
+ .dispatch2_indirect =
+ iree_hal_deferred_command_buffer_dispatch2_indirect,
};
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index f6f96f2..8d44445 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -50,8 +50,11 @@
EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v)
EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v)
EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r)
+// TODO(#18154): replace base dispatch with new `2` versions.
EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiiiI, v)
EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirII, v)
+EXPORT_FN_CUSTOM("command_buffer.dispatch2", iree_hal_module_command_buffer_dispatch2, rriiiiICiDCiirIID, v)
+EXPORT_FN_CUSTOM("command_buffer.dispatch2.indirect", iree_hal_module_command_buffer_dispatch2_indirect, rriirIICiDCiirIID, v)
EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v)
@@ -77,7 +80,9 @@
EXPORT_FN("ex.file.from_memory", iree_hal_module_ex_file_from_memory, rIirIIi, r)
+// TODO(#18154): replace base executable create with new `2` versions.
EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r)
+EXPORT_FN("executable.create2", iree_hal_module_executable_create2, rrrr, r)
EXPORT_FN("fence.await", iree_hal_module_fence_await, iCrD, i)
EXPORT_FN("fence.create", iree_hal_module_fence_create, ri, r)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index e599d77..f3cac5e 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -32,8 +32,8 @@
// Module type definitions
//===----------------------------------------------------------------------===//
-#define IREE_HAL_MODULE_VERSION_0_3 0x00000003u
-#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_3
+#define IREE_HAL_MODULE_VERSION_0_4 0x00000004u
+#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_4
typedef struct iree_hal_module_t {
iree_allocator_t host_allocator;
@@ -945,6 +945,212 @@
command_buffer, executable, entry_point, workgroups_ref, flags);
}
+// Argument signature: rriiiiICiDCiirIID
+//
+// Unpacked arguments for command_buffer.dispatch2. The fixed leading
+// arguments are stored in a union so they can be copied in one go as the
+// generic `rriiiiI` ABI struct (|params|) and then read through the named
+// fields of the anonymous struct. The two count/pointer pairs describe the
+// variadic constant and binding segments and are filled in by
+// iree_hal_module_command_buffer_dispatch2_shim.
+typedef struct {
+ union {
+ struct {
+ iree_vm_ref_t command_buffer;
+ iree_vm_ref_t executable;
+ int32_t entry_point;
+ uint32_t workgroup_count[3];
+ iree_hal_dispatch_flags_t flags;
+ };
+ // Raw ABI view aliasing the named fields above.
+ iree_vm_abi_rriiiiI_t params;
+ };
+ // Number of 32-bit values available at |constants|.
+ iree_vm_size_t constant_count;
+ const uint32_t* constants;
+ // Number of binding tuples available at |bindings|.
+ iree_vm_size_t binding_count;
+ const iree_vm_abi_iirII_t* bindings;
+} iree_hal_module_command_buffer_dispatch2_args_t;
+// Records a direct dispatch2 on the command buffer referenced by |args|.
+// Both ref arguments are type-checked, the binding count is limited to
+// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT, and every VM binding tuple is
+// converted into an iree_hal_buffer_ref_t before the call is forwarded to
+// iree_hal_command_buffer_dispatch2.
+static iree_status_t iree_hal_module_command_buffer_dispatch2(
+    iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
+    iree_hal_module_state_t* IREE_RESTRICT state,
+    const iree_hal_module_command_buffer_dispatch2_args_t* IREE_RESTRICT args) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_check_deref(args->command_buffer,
+                                                           &command_buffer));
+  iree_hal_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_check_deref(args->executable, &executable));
+
+  // Reject oversized binding lists before sizing the stack allocation below.
+  if (IREE_UNLIKELY(args->binding_count >
+                    IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "binding count %" PRIhsz " > %" PRIhsz,
+                            (iree_host_size_t)args->binding_count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+
+  // Translate each VM binding tuple into a HAL buffer reference.
+  const iree_host_size_t binding_count = (iree_host_size_t)args->binding_count;
+  iree_hal_buffer_ref_t* binding_refs = (iree_hal_buffer_ref_t*)iree_alloca(
+      args->binding_count * sizeof(iree_hal_buffer_ref_t));
+  for (iree_host_size_t i = 0; i < binding_count; ++i) {
+    const iree_vm_abi_iirII_t* source = &args->bindings[i];
+    iree_hal_buffer_ref_t* target = &binding_refs[i];
+    target->ordinal = 0;
+    target->buffer_slot = (uint32_t)source->i1;
+    IREE_RETURN_IF_ERROR(
+        iree_hal_buffer_check_deref_or_null(source->r2, &target->buffer));
+    target->offset = iree_hal_cast_device_size(source->i3);
+    target->length = iree_hal_cast_device_size(source->i4);
+  }
+  const iree_hal_buffer_ref_list_t bindings = {
+      .count = binding_count,
+      .values = binding_refs,
+  };
+
+  const iree_const_byte_span_t constants = iree_make_const_byte_span(
+      args->constants, args->constant_count * sizeof(uint32_t));
+  return iree_hal_command_buffer_dispatch2(
+      command_buffer, executable, args->entry_point, args->workgroup_count,
+      constants, bindings, (iree_hal_dispatch_flags_t)args->flags);
+}
+// Custom shim for command_buffer.dispatch2 (signature rriiiiICiDCiirIID).
+// Unpacks the fixed `rriiiiI` prefix followed by two variadic segments
+// (constants, then bindings), each stored as an iree_vm_size_t count followed
+// by its elements. Every count is validated against the remaining storage
+// before the bytes that follow it are touched so that malformed argument
+// storage is rejected with INVALID_ARGUMENT instead of being read out of
+// bounds. Trailing bytes after the binding segment are tolerated.
+static iree_status_t iree_hal_module_command_buffer_dispatch2_shim(
+    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
+    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
+    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
+    void* IREE_RESTRICT module_state) {
+  // TODO(benvanik): support multiple variadic segments in one call.
+  // For now we inline what it would do in a very painful way.
+
+  // The fixed prefix and both segment counts must be present even when both
+  // segments are empty. This must be verified before the prefix is copied.
+  const iree_host_size_t min_length = sizeof(iree_vm_abi_rriiiiI_t) +
+                                      sizeof(iree_vm_size_t) +
+                                      sizeof(iree_vm_size_t);
+  if (IREE_UNLIKELY(args_storage.data_length < min_length ||
+                    rets_storage.data_length > 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "argument/result signature mismatch");
+  }
+  const uint8_t* end_ptr = args_storage.data + args_storage.data_length;
+
+  iree_hal_module_command_buffer_dispatch2_args_t args = {
+      .params = *(const iree_vm_abi_rriiiiI_t*)args_storage.data,
+  };
+
+  // Constants segment: the elements must fit while still leaving room for the
+  // binding count that follows them. Comparing element counts as unsigned
+  // values avoids overflowing a size multiplication and also rejects negative
+  // counts should iree_vm_size_t be a signed type.
+  const uint8_t* constants_ptr = args_storage.data + sizeof(args.params);
+  args.constant_count = *(const iree_vm_size_t*)constants_ptr;
+  args.constants = (const uint32_t*)(constants_ptr + sizeof(iree_vm_size_t));
+  const iree_host_size_t max_constant_count =
+      ((iree_host_size_t)(end_ptr - (const uint8_t*)args.constants) -
+       sizeof(iree_vm_size_t)) /
+      sizeof(args.constants[0]);
+  bool args_ok = (iree_host_size_t)args.constant_count <= max_constant_count;
+
+  if (args_ok) {
+    // Bindings segment: starts immediately after the last constant.
+    const uint8_t* bindings_ptr =
+        (const uint8_t*)args.constants +
+        (iree_host_size_t)args.constant_count * sizeof(args.constants[0]);
+    args.binding_count = *(const iree_vm_size_t*)bindings_ptr;
+    args.bindings =
+        (const iree_vm_abi_iirII_t*)(bindings_ptr + sizeof(iree_vm_size_t));
+    const iree_host_size_t max_binding_count =
+        (iree_host_size_t)(end_ptr - (const uint8_t*)args.bindings) /
+        sizeof(args.bindings[0]);
+    args_ok = (iree_host_size_t)args.binding_count <= max_binding_count;
+  }
+  if (IREE_UNLIKELY(!args_ok)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "argument/result signature mismatch");
+  }
+
+  IREE_ASSERT(target_fn == (iree_vm_native_function_target2_t)
+                               iree_hal_module_command_buffer_dispatch2);
+  return iree_hal_module_command_buffer_dispatch2(stack, module, module_state,
+                                                  &args);
+}
+
+// Argument signature: rriirIICiDCiirIID
+//
+// Unpacked arguments for command_buffer.dispatch2.indirect. The fixed leading
+// arguments are stored in a union so they can be copied in one go as the
+// generic `rriirII` ABI struct (|params|) and then read through the named
+// fields of the anonymous struct. The two count/pointer pairs describe the
+// variadic constant and binding segments and are filled in by
+// iree_hal_module_command_buffer_dispatch2_indirect_shim.
+typedef struct {
+ union {
+ struct {
+ iree_vm_ref_t command_buffer;
+ iree_vm_ref_t executable;
+ int32_t entry_point;
+ int32_t workgroups_buffer_slot;
+ iree_vm_ref_t workgroups_buffer;
+ int64_t workgroups_offset;
+ iree_hal_dispatch_flags_t flags;
+ };
+ // Raw ABI view aliasing the named fields above.
+ iree_vm_abi_rriirII_t params;
+ };
+ // Number of 32-bit values available at |constants|.
+ iree_vm_size_t constant_count;
+ const uint32_t* constants;
+ // Number of binding tuples available at |bindings|.
+ iree_vm_size_t binding_count;
+ const iree_vm_abi_iirII_t* bindings;
+} iree_hal_module_command_buffer_dispatch2_indirect_args_t;
+// Records an indirect dispatch2 on the command buffer referenced by |args|.
+// The workgroup count is described by a buffer reference (slot, optional
+// buffer, offset) covering 3 uint32_t values. The ref arguments are
+// type-checked, the binding count is limited to
+// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT, and every VM binding tuple is
+// converted into an iree_hal_buffer_ref_t before the call is forwarded to
+// iree_hal_command_buffer_dispatch2_indirect.
+static iree_status_t iree_hal_module_command_buffer_dispatch2_indirect(
+    iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
+    iree_hal_module_state_t* IREE_RESTRICT state,
+    const iree_hal_module_command_buffer_dispatch2_indirect_args_t*
+        IREE_RESTRICT args) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_check_deref(args->command_buffer,
+                                                           &command_buffer));
+  iree_hal_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_check_deref(args->executable, &executable));
+  iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_indirect_buffer_ref(
+      args->workgroups_buffer_slot,
+      iree_hal_cast_device_size(args->workgroups_offset), 3 * sizeof(uint32_t));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
+      args->workgroups_buffer, &workgroups_ref.buffer));
+
+  // Reject oversized binding lists before sizing the stack allocation below.
+  if (IREE_UNLIKELY(args->binding_count >
+                    IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "binding count %" PRIhsz " > %" PRIhsz,
+                            (iree_host_size_t)args->binding_count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+  iree_hal_buffer_ref_list_t bindings = {
+      .count = (iree_host_size_t)args->binding_count,
+      .values = (iree_hal_buffer_ref_t*)iree_alloca(
+          args->binding_count * sizeof(iree_hal_buffer_ref_t)),
+  };
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    iree_hal_buffer_ref_t* binding =
+        (iree_hal_buffer_ref_t*)&bindings.values[i];
+    // The stack storage is uninitialized so every field must be assigned;
+    // the ordinal is zeroed exactly as the direct dispatch2 path does.
+    binding->ordinal = 0;
+    binding->buffer_slot = (uint32_t)args->bindings[i].i1;
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
+        args->bindings[i].r2, &binding->buffer));
+    binding->offset = iree_hal_cast_device_size(args->bindings[i].i3);
+    binding->length = iree_hal_cast_device_size(args->bindings[i].i4);
+  }
+
+  return iree_hal_command_buffer_dispatch2_indirect(
+      command_buffer, executable, args->entry_point, workgroups_ref,
+      iree_make_const_byte_span(args->constants,
+                                args->constant_count * sizeof(uint32_t)),
+      bindings, (iree_hal_dispatch_flags_t)args->flags);
+}
+// Custom shim for command_buffer.dispatch2.indirect (signature
+// rriirIICiDCiirIID). Unpacks the fixed `rriirII` prefix followed by two
+// variadic segments (constants, then bindings), each stored as an
+// iree_vm_size_t count followed by its elements. Every count is validated
+// against the remaining storage before the bytes that follow it are touched so
+// that malformed argument storage is rejected with INVALID_ARGUMENT instead of
+// being read out of bounds. Trailing bytes after the binding segment are
+// tolerated.
+static iree_status_t iree_hal_module_command_buffer_dispatch2_indirect_shim(
+    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
+    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
+    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
+    void* IREE_RESTRICT module_state) {
+  // TODO(benvanik): support multiple variadic segments in one call.
+  // For now we inline what it would do in a very painful way.
+
+  // The fixed prefix and both segment counts must be present even when both
+  // segments are empty. This must be verified before the prefix is copied.
+  const iree_host_size_t min_length = sizeof(iree_vm_abi_rriirII_t) +
+                                      sizeof(iree_vm_size_t) +
+                                      sizeof(iree_vm_size_t);
+  if (IREE_UNLIKELY(args_storage.data_length < min_length ||
+                    rets_storage.data_length > 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "argument/result signature mismatch");
+  }
+  const uint8_t* end_ptr = args_storage.data + args_storage.data_length;
+
+  iree_hal_module_command_buffer_dispatch2_indirect_args_t args = {
+      .params = *(const iree_vm_abi_rriirII_t*)args_storage.data,
+  };
+
+  // Constants segment: the elements must fit while still leaving room for the
+  // binding count that follows them. Comparing element counts as unsigned
+  // values avoids overflowing a size multiplication and also rejects negative
+  // counts should iree_vm_size_t be a signed type.
+  const uint8_t* constants_ptr = args_storage.data + sizeof(args.params);
+  args.constant_count = *(const iree_vm_size_t*)constants_ptr;
+  args.constants = (const uint32_t*)(constants_ptr + sizeof(iree_vm_size_t));
+  const iree_host_size_t max_constant_count =
+      ((iree_host_size_t)(end_ptr - (const uint8_t*)args.constants) -
+       sizeof(iree_vm_size_t)) /
+      sizeof(args.constants[0]);
+  bool args_ok = (iree_host_size_t)args.constant_count <= max_constant_count;
+
+  if (args_ok) {
+    // Bindings segment: starts immediately after the last constant.
+    const uint8_t* bindings_ptr =
+        (const uint8_t*)args.constants +
+        (iree_host_size_t)args.constant_count * sizeof(args.constants[0]);
+    args.binding_count = *(const iree_vm_size_t*)bindings_ptr;
+    args.bindings =
+        (const iree_vm_abi_iirII_t*)(bindings_ptr + sizeof(iree_vm_size_t));
+    const iree_host_size_t max_binding_count =
+        (iree_host_size_t)(end_ptr - (const uint8_t*)args.bindings) /
+        sizeof(args.bindings[0]);
+    args_ok = (iree_host_size_t)args.binding_count <= max_binding_count;
+  }
+  if (IREE_UNLIKELY(!args_ok)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "argument/result signature mismatch");
+  }
+
+  IREE_ASSERT(target_fn ==
+              (iree_vm_native_function_target2_t)
+                  iree_hal_module_command_buffer_dispatch2_indirect);
+  return iree_hal_module_command_buffer_dispatch2_indirect(stack, module,
+                                                           module_state, &args);
+}
+
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout
//===----------------------------------------------------------------------===//
@@ -1289,6 +1495,57 @@
return status;
}
+// Prepares an executable through the executable cache looked up for the
+// device, without passing any pipeline layouts.
+//   r0: device
+//   r1: VM buffer holding the executable format string
+//   r2: VM buffer holding the executable data
+//   r3: optional VM buffer of 32-bit constants (ignored unless it is a buffer)
+// Returns the prepared executable in r0.
+IREE_VM_ABI_EXPORT(iree_hal_module_executable_create2,  //
+                   iree_hal_module_state_t,             //
+                   rrrr, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+  iree_vm_buffer_t* format_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &format_buffer));
+  iree_vm_buffer_t* data_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r2, &data_buffer));
+
+  // Constants are optional; when present they must be whole 4-byte elements.
+  const uint32_t* constants = NULL;
+  iree_host_size_t constant_count = 0;
+  if (iree_vm_buffer_isa(args->r3)) {
+    iree_vm_buffer_t* constant_buffer = NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_vm_buffer_check_deref(args->r3, &constant_buffer));
+    if (constant_buffer->data.data_length % 4 != 0) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "constant buffer data must contain 4-byte "
+                              "elements but data length is %" PRIhsz,
+                              constant_buffer->data.data_length);
+    }
+    constants = (const uint32_t*)constant_buffer->data.data;
+    constant_count = constant_buffer->data.data_length / sizeof(uint32_t);
+  }
+
+  iree_hal_executable_cache_t* executable_cache = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_module_state_lookup_executable_cache(
+      state, device, &executable_cache));
+
+  iree_hal_executable_params_t executable_params;
+  iree_hal_executable_params_initialize(&executable_params);
+  // Only data whose access is exactly ORIGIN_MODULE is marked as aliasable.
+  if (data_buffer->access == IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE) {
+    executable_params.caching_mode |=
+        IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA;
+  }
+  executable_params.executable_format = iree_vm_buffer_as_string(format_buffer);
+  executable_params.executable_data = iree_make_const_byte_span(
+      data_buffer->data.data, data_buffer->data.data_length);
+  executable_params.constant_count = constant_count;
+  executable_params.constants = constants;
+
+  iree_hal_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable(
+      executable_cache, &executable_params, &executable));
+
+  rets->r0 = iree_hal_executable_move_ref(executable);
+  return iree_ok_status();
+}
+
//===----------------------------------------------------------------------===//
// iree_hal_fence_t
//===----------------------------------------------------------------------===//
@@ -1652,8 +1909,14 @@
iree_vm_shim_##arg_types##_##ret_types, \
.target = (iree_vm_native_function_target_t)(target_fn), \
},
+// Function table entry for an export with a hand-written shim: the shim is
+// the function named `<target_fn>_shim` rather than a generated
+// iree_vm_shim_<arg_types>_<ret_types>. |name|, |arg_types|, and |ret_types|
+// are not used by this expansion.
+#define EXPORT_FN_CUSTOM(name, target_fn, arg_types, ret_types) \
+ { \
+ .shim = (iree_vm_native_function_shim_t)(target_fn##_shim), \
+ .target = (iree_vm_native_function_target_t)(target_fn), \
+ },
#include "iree/modules/hal/exports.inl" // IWYU pragma: keep
#undef EXPORT_FN
+#undef EXPORT_FN_CUSTOM
};
// NOTE: 0 length, but can't express that in C.
@@ -1668,8 +1931,10 @@
.attr_count = 0, \
.attrs = NULL, \
},
+// In this table custom-shim exports expand exactly like regular exports.
+#define EXPORT_FN_CUSTOM EXPORT_FN
#include "iree/modules/hal/exports.inl" // IWYU pragma: keep
#undef EXPORT_FN
+#undef EXPORT_FN_CUSTOM
};
static_assert(IREE_ARRAYSIZE(iree_hal_module_funcs_) ==
IREE_ARRAYSIZE(iree_hal_module_exports_),
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 5bd69a7..2509ffa 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -46,6 +46,7 @@
IREE_VM_ABI_DEFINE_SHIM(rIiiI, r);
IREE_VM_ABI_DEFINE_SHIM(riIiirII, r);
IREE_VM_ABI_DEFINE_SHIM(rriiiirrIIIII, v);
+IREE_VM_ABI_DEFINE_SHIM(rrrr, r);
IREE_VM_ABI_DEFINE_SHIM(rrrrCrD, r);
IREE_VM_ABI_DEFINE_SHIM(ririi, v);
IREE_VM_ABI_DEFINE_SHIM(rr, i);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index b47428c..cd14a46 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -585,6 +585,13 @@
iree_vm_abi_r_t a3[0];
});
+// Fixed-size ABI struct for the `rrrr` signature: four ref values (r0-r3).
+IREE_VM_ABI_FIXED_STRUCT(rrrr, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ iree_vm_ref_t r2;
+ iree_vm_ref_t r3;
+});
+
IREE_VM_ABI_VLA_STRUCT(rrrrCrD, a4_count, a4, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
@@ -697,6 +704,7 @@
IREE_VM_ABI_DECLARE_SHIM(rIiiI, r);
IREE_VM_ABI_DECLARE_SHIM(riIiirII, r);
IREE_VM_ABI_DECLARE_SHIM(rriiiirrIIIII, v);
+IREE_VM_ABI_DECLARE_SHIM(rrrr, r);
IREE_VM_ABI_DECLARE_SHIM(rrrrCrD, r);
IREE_VM_ABI_DECLARE_SHIM(ririi, v);
IREE_VM_ABI_DECLARE_SHIM(rr, i);