Removing nested command buffers and adding indirect execution. (#17724)
The compiler is unlikely to need nested execution, and the concept of
nested command buffers makes HAL driver implementations trickier.
Instead, all command buffers now share a single type that may
optionally reference indirect bindings. HAL devices now optionally take
a binding table per command buffer to use when scheduling it. The
compiler also gains a `hal.device.queue.execute.indirect` op that takes
a single command buffer and a binding table for it, since nested
variadic imports are not possible.
This is a non-breaking binary change, as the compiler has never emitted
the removed execute call or nested command buffers.
Future changes will add command buffer recording and validation of
indirect bindings; today binding tables are ignored, as there is no way
yet to record command buffers with indirect bindings.
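The new `iree_hal_buffer_ref_t` in command_buffer.h carries both direct
and indirect references; a rough illustration of the two helper
constructors added below (values and surrounding variables are made up):

  // Direct: the buffer is known when the command is recorded.
  iree_hal_buffer_ref_t direct_ref =
      iree_hal_make_buffer_ref(buffer, /*offset=*/0, /*length=*/4096);
  // Indirect: buffer is NULL and slot 3 of the submission's binding
  // table is resolved at execution time; the offset is added to the
  // bound buffer's base offset.
  iree_hal_buffer_ref_t indirect_ref =
      iree_hal_make_indirect_buffer_ref(/*buffer_slot=*/3, /*offset=*/0,
                                        /*length=*/4096);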
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
index ad62782..915e5f2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertDeviceOps.cpp
@@ -115,6 +115,61 @@
mutable IREE::VM::ImportOp importOp;
};
+class DeviceQueueExecuteIndirectOpConversion
+ : public OpConversionPattern<IREE::HAL::DeviceQueueExecuteIndirectOp> {
+public:
+ DeviceQueueExecuteIndirectOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult
+ matchAndRewrite(IREE::HAL::DeviceQueueExecuteIndirectOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ auto i64Type = rewriter.getI64Type();
+
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getDevice(),
+ castToImportType(adaptor.getQueueAffinity(), i64Type, rewriter),
+ adaptor.getWaitFence(),
+ adaptor.getSignalFence(),
+ adaptor.getCommandBuffer(),
+ };
+ SmallVector<int16_t, 5> segmentSizes = {
+ /*device=*/-1,
+ /*queue_affinity=*/-1,
+ /*wait_fence=*/-1,
+ /*signal_fence=*/-1,
+ /*command_buffer=*/-1,
+ /*bindings=*/
+ static_cast<int16_t>(adaptor.getBindingBuffers().size()),
+ };
+ for (auto [bindingBuffer, bindingOffset, bindingLength] : llvm::zip_equal(
+ adaptor.getBindingBuffers(), adaptor.getBindingOffsets(),
+ adaptor.getBindingLengths())) {
+ callOperands.push_back(bindingBuffer);
+ callOperands.push_back(
+ castToImportType(bindingOffset, i64Type, rewriter));
+ callOperands.push_back(
+ castToImportType(bindingLength, i64Type, rewriter));
+ }
+
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes,
+ importType.getInputs(), callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
void populateHALDeviceToVMPatterns(MLIRContext *context,
SymbolTable &importSymbols,
TypeConverter &typeConverter,
@@ -136,6 +191,9 @@
context, importSymbols, typeConverter, "hal.device.queue.write");
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueExecuteOp>>(
context, importSymbols, typeConverter, "hal.device.queue.execute");
+ patterns.insert<DeviceQueueExecuteIndirectOpConversion>(
+ context, importSymbols, typeConverter,
+ "hal.device.queue.execute.indirect");
patterns.insert<VMImportOpConversion<IREE::HAL::DeviceQueueFlushOp>>(
context, importSymbols, typeConverter, "hal.device.queue.flush");
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir
index 0dbba92..a0052cb 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/device_ops.mlir
@@ -197,6 +197,38 @@
// -----
+// CHECK-LABEL: @device_queue_execute_indirect
+util.func public @device_queue_execute_indirect(
+ // CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64,
+ %device: !hal.device, %affinity: i64,
+ // CHECK-SAME: %[[WAIT_FENCE:.+]]: !vm.ref<!hal.fence>, %[[SIGNAL_FENCE:.+]]: !vm.ref<!hal.fence>,
+ %wait_fence: !hal.fence, %signal_fence: !hal.fence,
+ // CHECK-SAME: %[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+ %cmd: !hal.command_buffer,
+ // CHECK-SAME: %[[BUFFER0:.+]]: !vm.ref<!hal.buffer>, %[[BUFFER1:.+]]: !vm.ref<!hal.buffer>
+ %buffer0: !hal.buffer, %buffer1: !hal.buffer) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c1000 = arith.constant 1000 : index
+ %c2000 = arith.constant 2000 : index
+ // CHECK: vm.call.variadic @hal.device.queue.execute.indirect(
+ // CHECK-SAME: %[[DEVICE]], %[[AFFINITY]],
+ // CHECK-SAME: %[[WAIT_FENCE]], %[[SIGNAL_FENCE]],
+ // CHECK-SAME: %[[CMD]],
+ // CHECK-SAME: [(%[[BUFFER0]], %c100, %c1000), (%[[BUFFER1]], %c200, %c2000)])
+ hal.device.queue.execute.indirect<%device : !hal.device>
+ affinity(%affinity)
+ wait(%wait_fence) signal(%signal_fence)
+ commands(%cmd)
+ bindings([
+ (%buffer0 : !hal.buffer)[%c100, %c1000],
+ (%buffer1 : !hal.buffer)[%c200, %c2000]
+ ])
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @device_queue_flush
util.func public @device_queue_flush(
// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 37ab489..e2d56d2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -136,13 +136,11 @@
def HAL_CommandBufferMode_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_CommandBufferMode_OneShot : I32BitEnumAttrCase<"OneShot", 0x0001>;
-def HAL_CommandBufferMode_Nested : I32BitEnumAttrCase<"Nested", 0x0002>;
def HAL_CommandBufferMode_AllowInlineExecution : I32BitEnumAttrCase<"AllowInlineExecution", 0x0010>;
def HAL_CommandBufferModeBitfieldAttr :
I32BitEnumAttr<"CommandBufferModeBitfield", "valid CommandBufferMode", [
HAL_CommandBufferMode_None,
HAL_CommandBufferMode_OneShot,
- HAL_CommandBufferMode_Nested,
HAL_CommandBufferMode_AllowInlineExecution,
]> {
let cppNamespace = "mlir::iree_compiler::IREE::HAL";
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 4a0c36a..88ef71a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -112,6 +112,63 @@
}
//===----------------------------------------------------------------------===//
+// custom<BindingTable>($binding_buffers,
+// type($binding_buffers),
+// $binding_offsets,
+// $binding_lengths)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseBindingTable(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers,
+ SmallVectorImpl<Type> &bufferTypes,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) {
+ do {
+ OpAsmParser::UnresolvedOperand buffer;
+ Type bufferType;
+ OpAsmParser::UnresolvedOperand bufferOffset;
+ OpAsmParser::UnresolvedOperand bufferLength;
+ if (failed(parser.parseLParen()) || failed(parser.parseOperand(buffer)) ||
+ failed(parser.parseColonType(bufferType)) ||
+ failed(parser.parseRParen()) || failed(parser.parseLSquare()) ||
+ failed(parser.parseOperand(bufferOffset)) ||
+ failed(parser.parseComma()) ||
+ failed(parser.parseOperand(bufferLength)) ||
+ failed(parser.parseRSquare())) {
+ return failure();
+ }
+ buffers.push_back(buffer);
+ bufferTypes.push_back(bufferType);
+ bufferOffsets.push_back(bufferOffset);
+ bufferLengths.push_back(bufferLength);
+ } while (succeeded(parser.parseOptionalComma()));
+ return success();
+}
+
+static void printBindingTable(OpAsmPrinter &p, Operation *op,
+ ValueRange buffers, TypeRange bufferTypes,
+ ValueRange bufferOffsets,
+ ValueRange bufferLengths) {
+ llvm::interleaveComma(
+ llvm::zip_equal(buffers, bufferTypes, bufferOffsets, bufferLengths), p,
+ [&](std::tuple<Value, Type, Value, Value> it) {
+ p.printNewline();
+ p << " ";
+ p << "(";
+ p.printOperand(std::get<0>(it));
+ p << " : ";
+ p.printType(std::get<1>(it));
+ p << ")[";
+ p.printOperand(std::get<2>(it));
+ p << ", ";
+ p.printOperand(std::get<3>(it));
+ p << "]";
+ });
+ p.printNewline();
+}
+
+//===----------------------------------------------------------------------===//
// custom<TargetConditionRegion>($body)
//===----------------------------------------------------------------------===//
@@ -1056,6 +1113,30 @@
return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence());
}
+void DeviceQueueExecuteIndirectOp::build(OpBuilder &builder,
+ OperationState &state, Value device,
+ Value queueAffinity, Value waitFence,
+ Value signalFence, Value commandBuffer,
+ ArrayRef<BindingTableValue> bindings) {
+ state.addOperands(
+ {device, queueAffinity, waitFence, signalFence, commandBuffer});
+ SmallVector<Value> bindingBuffers;
+ SmallVector<Value> bindingOffsets;
+ SmallVector<Value> bindingLengths;
+ for (auto binding : bindings) {
+ bindingBuffers.push_back(binding.buffer);
+ bindingOffsets.push_back(binding.byteOffset);
+ bindingLengths.push_back(binding.byteLength);
+ }
+ state.addOperands(bindingBuffers);
+ state.addOperands(bindingOffsets);
+ state.addOperands(bindingLengths);
+}
+
+LogicalResult DeviceQueueExecuteIndirectOp::verify() {
+ return verifyDeviceQueueFences(*this, getWaitFence(), getSignalFence());
+}
+
//===----------------------------------------------------------------------===//
// hal.devices.*
//===----------------------------------------------------------------------===//
@@ -1650,6 +1731,7 @@
//===----------------------------------------------------------------------===//
// hal.interface.binding.subspan
//===----------------------------------------------------------------------===//
+
void InterfaceBindingSubspanOp::build(
OpBuilder &builder, OperationState &result, Type resultType, APInt set,
APInt binding, IREE::HAL::DescriptorType descriptor_type, Value byte_offset,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 32551f6..1889f59 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1881,6 +1881,58 @@
let hasVerifier = 1;
}
+def HAL_DeviceQueueExecuteIndirectOp : HAL_Op<"device.queue.execute.indirect", [
+ SameVariadicOperandSize,
+]> {
+ let summary = [{enqueues command buffer execution}];
+ let description = [{
+ Executes a command buffer on a device queue with the given binding table.
+ No commands will execute until the wait fence has been reached and the
+ signal fence will be signaled when all commands have completed.
+ }];
+
+ let arguments = (ins
+ HAL_Device:$device,
+ HAL_DeviceQueueAffinity:$queue_affinity,
+ HAL_Fence:$wait_fence,
+ HAL_Fence:$signal_fence,
+ HAL_CommandBuffer:$command_buffer,
+ Variadic<HAL_BufferType>:$binding_buffers,
+ Variadic<HAL_DeviceSize>:$binding_offsets,
+ Variadic<HAL_DeviceSize>:$binding_lengths
+ );
+ let results = (outs);
+
+ let assemblyFormat = [{
+ `<` $device `:` type($device) `>`
+ `affinity` `(` $queue_affinity `)`
+ `wait` `(` $wait_fence `)`
+ `signal` `(` $signal_fence `)`
+ `commands` `(` $command_buffer `)`
+ `bindings` `(` `[`
+ custom<BindingTable>($binding_buffers,
+ type($binding_buffers),
+ $binding_offsets,
+ $binding_lengths)
+ `]` `)`
+ attr-dict-with-keyword
+ }];
+
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins
+ "Value":$device,
+ "Value":$queueAffinity,
+ "Value":$waitFence,
+ "Value":$signalFence,
+ "Value":$commandBuffer,
+ "ArrayRef<BindingTableValue>":$bindings
+ )>,
+ ];
+
+ let hasVerifier = 1;
+}
+
def HAL_DeviceQueueFlushOp : HAL_Op<"device.queue.flush"> {
let summary = [{flushes locally-pending submissions to the queue}];
let description = [{
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
index e50a9d2..ef67024 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -177,6 +177,12 @@
Value byteLength;
};
+struct BindingTableValue {
+ Value buffer;
+ Value byteOffset;
+ Value byteLength;
+};
+
template <typename T>
struct StaticRange {
T min;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
index 206c3bb..46be846 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/device_ops.mlir
@@ -158,6 +158,40 @@
// -----
+// CHECK-LABEL: @device_queue_execute_indirect
+util.func public @device_queue_execute_indirect(
+ // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64,
+ %device: !hal.device, %affinity: i64,
+ // CHECK-SAME: %[[WAIT_FENCE:.+]]: !hal.fence, %[[SIGNAL_FENCE:.+]]: !hal.fence,
+ %wait_fence: !hal.fence, %signal_fence: !hal.fence,
+ // CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
+ %cmd: !hal.command_buffer,
+ // CHECK-SAME: %[[BUFFER0:.+]]: !hal.buffer, %[[BUFFER1:.+]]: !hal.buffer
+ %buffer0: !hal.buffer, %buffer1: !hal.buffer) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c1000 = arith.constant 1000 : index
+ %c2000 = arith.constant 2000 : index
+ // CHECK: hal.device.queue.execute.indirect<%[[DEVICE]] : !hal.device>
+ hal.device.queue.execute.indirect<%device : !hal.device>
+ // CHECK-SAME: affinity(%[[AFFINITY]])
+ affinity(%affinity)
+ // CHECK-SAME: wait(%[[WAIT_FENCE]]) signal(%[[SIGNAL_FENCE]])
+ wait(%wait_fence) signal(%signal_fence)
+ // CHECK-SAME: commands(%[[CMD]])
+ commands(%cmd)
+ // CHECK-SAME: bindings([
+ bindings([
+ // CHECK-NEXT: (%[[BUFFER0]] : !hal.buffer)[%c100, %c1000]
+ (%buffer0 : !hal.buffer)[%c100, %c1000],
+ // CHECK-NEXT: (%[[BUFFER1]] : !hal.buffer)[%c200, %c2000]
+ (%buffer1 : !hal.buffer)[%c200, %c2000]
+ ])
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @device_queue_flush
util.func public @device_queue_flush(
// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64)
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 445c508..1d21fb3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -302,14 +302,6 @@
%workgroups_offset : i64
)
-// Executes a secondary command buffer with the given binding table.
-vm.import private @command_buffer.execute.commands(
- %command_buffer : !vm.ref<!hal.command_buffer>,
- %commands : !vm.ref<!hal.command_buffer>,
- // <buffer, offset, length>
- %bindings : tuple<!vm.ref<!hal.buffer>, i64, i64>...
-)
-
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout_t
//===----------------------------------------------------------------------===//
@@ -407,6 +399,19 @@
%command_buffers : !vm.ref<!hal.command_buffer>...
)
+// Executes a command buffer on a device queue with the given binding table.
+// No commands will execute until the wait fence has been reached and the signal
+// fence will be signaled when all commands have completed.
+vm.import private @device.queue.execute.indirect(
+ %device : !vm.ref<!hal.device>,
+ %queue_affinity : i64,
+ %wait_fence : !vm.ref<!hal.fence>,
+ %signal_fence : !vm.ref<!hal.fence>,
+ %command_buffer : !vm.ref<!hal.command_buffer>,
+ // <buffer, offset, length>
+ %binding_table : tuple<!vm.ref<!hal.buffer>, i64, i64>...
+)
+
// Flushes any locally-pending submissions in the queue.
// When submitting many queue operations this can be used to eagerly flush
// earlier submissions while later ones are still being constructed.
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index dff9cbd..9bde337 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -466,14 +466,6 @@
"need rocm implementation");
}
-static iree_status_t iree_hal_rocm_direct_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
static const iree_hal_command_buffer_vtable_t
iree_hal_rocm_direct_command_buffer_vtable = {
.destroy = iree_hal_rocm_direct_command_buffer_destroy,
@@ -498,6 +490,4 @@
.dispatch = iree_hal_rocm_direct_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_rocm_direct_command_buffer_dispatch_indirect,
- .execute_commands =
- iree_hal_rocm_direct_command_buffer_execute_commands,
};
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index 7200668..fabc2f7 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -385,7 +385,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_rocm_device_t* device = iree_hal_rocm_device_cast(base_device);
// TODO(raikonenfnu): Once semaphore is implemented wait for semaphores
// TODO(thomasraoux): implement semaphores - for now this conservatively
diff --git a/experimental/web/sample_webgpu/main.c b/experimental/web/sample_webgpu/main.c
index d1fad75..e7016eb 100644
--- a/experimental/web/sample_webgpu/main.c
+++ b/experimental/web/sample_webgpu/main.c
@@ -793,7 +793,8 @@
};
status = iree_hal_device_queue_execute(
device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
- signal_semaphores, 1, &transfer_command_buffer);
+ signal_semaphores, 1, &transfer_command_buffer,
+ /*binding_tables=*/NULL);
}
// TODO(scotttodd): Make this async - pass a wait source to iree_loop_wait_one
// 1. create iree_hal_fence_t, iree_hal_fence_insert(fance, semaphore)
diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c
index 84af3a1..9a88b21 100644
--- a/experimental/webgpu/command_buffer.c
+++ b/experimental/webgpu/command_buffer.c
@@ -912,18 +912,6 @@
return iree_ok_status();
}
-static iree_status_t iree_hal_webgpu_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): support indirect command buffers via deferred command buffers
- // as WebGPU has no concept of reusable dispatch command encoders. One day
- // hopefully there's an equivalent of GPURenderBundle but given WebGPU's other
- // limitations it may not be useful.
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = {
.destroy = iree_hal_webgpu_command_buffer_destroy,
.begin = iree_hal_webgpu_command_buffer_begin,
@@ -942,5 +930,4 @@
.push_descriptor_set = iree_hal_webgpu_command_buffer_push_descriptor_set,
.dispatch = iree_hal_webgpu_command_buffer_dispatch,
.dispatch_indirect = iree_hal_webgpu_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_webgpu_command_buffer_execute_commands,
};
diff --git a/experimental/webgpu/webgpu_device.c b/experimental/webgpu/webgpu_device.c
index 8af38c0..e3246e0 100644
--- a/experimental/webgpu/webgpu_device.c
+++ b/experimental/webgpu/webgpu_device.c
@@ -394,7 +394,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_webgpu_device_t* device = iree_hal_webgpu_device_cast(base_device);
// TODO(benvanik): this currently assumes we are synchronizing on semaphores
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index a49faa1..f77800e 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -588,7 +588,7 @@
/*wait_semaphore_list=*/iree_hal_fence_semaphore_list(ready_fence_.get()),
/*signal_semaphore_list=*/
iree_hal_fence_semaphore_list(dst_buffer_ready_fence.get()),
- /*command_buffer_count=*/1, &transfer_cb));
+ /*command_buffer_count=*/1, &transfer_cb, NULL));
*out_done_event = copy_done_event;
return iree_ok_status();
@@ -844,7 +844,7 @@
{1, &transfer_timeline_, &signal_alloca_complete},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_copy_complete},
- /*command_buffer_count=*/1, &transfer_cb));
+ /*command_buffer_count=*/1, &transfer_cb, NULL));
// Wrap in a buffer view and return:
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
@@ -1188,7 +1188,7 @@
{1, &transfer_timeline_, &signal_alloca_complete},
/*signal_semaphore_list=*/
{1, &transfer_timeline_, &signal_copy_complete},
- /*command_buffer_count=*/1, &transfer_cb));
+ /*command_buffer_count=*/1, &transfer_cb, NULL));
// Wrap in a buffer view and return.
iree::vm::ref<iree_hal_buffer_view_t> result_buffer_view;
diff --git a/integrations/pjrt/src/iree_pjrt/common/iree_helpers.h b/integrations/pjrt/src/iree_pjrt/common/iree_helpers.h
index 1268758..fce48ba 100644
--- a/integrations/pjrt/src/iree_pjrt/common/iree_helpers.h
+++ b/integrations/pjrt/src/iree_pjrt/common/iree_helpers.h
@@ -149,7 +149,7 @@
return HandleStatus(__func__, iree_hal_device_queue_execute(
device, queue_affinity, wait_semaphore_list,
signal_semaphore_list, command_buffer_count,
- command_buffers));
+ command_buffers, /*binding_tables=*/NULL));
}
iree_status_t hal_fence_create(iree_host_size_t capacity,
diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc
index 8799e48..ccbf8b0 100644
--- a/runtime/bindings/python/hal.cc
+++ b/runtime/bindings/python/hal.cc
@@ -549,10 +549,10 @@
cb_list[i] = py::cast<HalCommandBuffer*>(command_buffers[i])->raw_ptr();
}
- CheckApiStatus(
- iree_hal_device_queue_execute(raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY,
- wait_list, signal_list, cb_count, cb_list),
- "executing command buffers");
+ CheckApiStatus(iree_hal_device_queue_execute(
+ raw_ptr(), IREE_HAL_QUEUE_AFFINITY_ANY, wait_list,
+ signal_list, cb_count, cb_list, /*binding_tables=*/NULL),
+ "executing command buffers");
}
void HalDevice::QueueCopy(HalBuffer& source_buffer, HalBuffer& target_buffer,
diff --git a/runtime/src/iree/hal/buffer.c b/runtime/src/iree/hal/buffer.c
index 3ec2049..c66835b 100644
--- a/runtime/src/iree/hal/buffer.c
+++ b/runtime/src/iree/hal/buffer.c
@@ -399,7 +399,7 @@
return iree_ok_status();
}
-static iree_status_t iree_hal_buffer_calculate_range(
+IREE_API_EXPORT iree_status_t iree_hal_buffer_calculate_range(
iree_device_size_t base_offset, iree_device_size_t max_length,
iree_device_size_t offset, iree_device_size_t length,
iree_device_size_t* out_adjusted_offset,
diff --git a/runtime/src/iree/hal/buffer.h b/runtime/src/iree/hal/buffer.h
index dfbdd45..370af34 100644
--- a/runtime/src/iree/hal/buffer.h
+++ b/runtime/src/iree/hal/buffer.h
@@ -573,6 +573,14 @@
iree_hal_buffer_t* buffer, iree_device_size_t byte_offset,
iree_device_size_t byte_length);
+// Adjusts the offset and length of a buffer subrange and returns the new
+// subrange. Fails if the range is invalid.
+IREE_API_EXPORT iree_status_t iree_hal_buffer_calculate_range(
+ iree_device_size_t base_offset, iree_device_size_t max_length,
+ iree_device_size_t offset, iree_device_size_t length,
+ iree_device_size_t* out_adjusted_offset,
+ iree_device_size_t* out_adjusted_length);
+
// Tests whether the given buffers overlap, including support for subspans.
// IREE_WHOLE_BUFFER may be used for |lhs_length| and/or |rhs_length| to use the
// lengths of those buffers, respectively.
diff --git a/runtime/src/iree/hal/buffer_transfer.c b/runtime/src/iree/hal/buffer_transfer.c
index b7980ad..2832ef8 100644
--- a/runtime/src/iree/hal/buffer_transfer.c
+++ b/runtime/src/iree/hal/buffer_transfer.c
@@ -76,9 +76,9 @@
.semaphores = &fence_semaphore,
.payload_values = &signal_value,
};
- status = iree_hal_device_queue_execute(device, IREE_HAL_QUEUE_AFFINITY_ANY,
- wait_semaphores, signal_semaphores,
- 1, &command_buffer);
+ status = iree_hal_device_queue_execute(
+ device, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores,
+ 1, &command_buffer, /*binding_tables=*/NULL);
}
if (iree_status_is_ok(status)) {
status = iree_hal_semaphore_wait(fence_semaphore, signal_value, timeout);
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c
index 49c60dc..e2bba37 100644
--- a/runtime/src/iree/hal/command_buffer.c
+++ b/runtime/src/iree/hal/command_buffer.c
@@ -201,17 +201,8 @@
if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"inline command buffers must be one-shot");
- } else if (iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "inline command buffers cannot be nested");
}
}
- if (binding_capacity > 0 &&
- !iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "command buffer bindings are only supported for "
- "nested command buffers (today)");
- }
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status =
@@ -607,24 +598,31 @@
return status;
}
-IREE_API_EXPORT iree_status_t iree_hal_command_buffer_execute_commands(
+//===----------------------------------------------------------------------===//
+// Validation support
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_binding_table(
iree_hal_command_buffer_t* command_buffer,
- iree_hal_command_buffer_t* commands,
- iree_hal_buffer_binding_table_t binding_table) {
+ const iree_hal_buffer_binding_table_t* binding_table) {
IREE_ASSERT_ARGUMENT(command_buffer);
- IREE_ASSERT_ARGUMENT(commands);
- IREE_ASSERT_ARGUMENT(!binding_table.count || binding_table.bindings);
- IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_command_buffer_execute_commands_validation(
- command_buffer, VALIDATION_STATE(command_buffer), commands,
- binding_table));
+ // Only check binding tables when one is required and otherwise ignore any
+ // bindings provided.
+ if (command_buffer->binding_capacity == 0) {
+ return iree_ok_status();
+ } else if (!binding_table ||
+ binding_table->count < command_buffer->binding_capacity) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "indirect command buffer requires at least %u "
+ "bindings but only %" PRIhsz " were provided ",
+ command_buffer->binding_capacity,
+ binding_table ? binding_table->count : 0);
+ }
+ // TODO(benvanik): validate each binding against the requirements of the
+ // command buffer.
});
- iree_status_t status = _VTABLE_DISPATCH(command_buffer, execute_commands)(
- command_buffer, commands, binding_table);
- IREE_TRACE_ZONE_END(z0);
- return status;
+ return iree_ok_status();
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index 5aa699c..8729c94 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -37,12 +37,6 @@
// If this bit is not set the command buffer may be submitted multiple times.
IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT = 1u << 0,
- // Command buffer is executed nested within a primary command buffer via
- // iree_hal_command_buffer_execute_commands.
- // May not be directly submitted to queues for execution and only one level
- // of nested command buffers are allowed.
- IREE_HAL_COMMAND_BUFFER_MODE_NESTED = 1u << 1,
-
// Indicates that the command buffer execution is allowed to execute inline
// with recording. The exact execution behavior is unspecified by the API and
// intentionally unknowable and must always assume to happen entirely
@@ -59,8 +53,7 @@
// Remote backends can use this to flush the command buffer more aggressively
// to begin early execution and overlap with continued recording.
//
- // Requires IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT and is not compatible with
- // IREE_HAL_COMMAND_BUFFER_MODE_NESTED.
+ // Requires IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT.
IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION = 1u << 4,
// Disables additional command buffer validation (if present).
@@ -86,6 +79,52 @@
};
typedef uint32_t iree_hal_command_category_t;
+// Specifies a direct or indirect buffer binding.
+// The range specified by [offset, length) of either the specified buffer or
+// a buffer slot in the binding table will be used at the time the command is
+// executed.
+//
+// The IREE HAL buffer type may internally be offset; such offset is applied
+// here as if it were the base address of the buffer. Note that the offset will
+// be applied at the time the binding is recorded into the command buffer.
+//
+// Roughly maps to VkDescriptorSetBinding.
+typedef struct iree_hal_buffer_ref_t {
+  // The binding number of this entry; it corresponds to a resource of the
+  // same binding number in the executable interface. Only used by certain
+ // calls.
+ uint32_t ordinal : 8;
+ // Binding table slot the buffer will be sourced from if buffer is NULL.
+ // Only valid on command buffers that support indirect execution.
+ uint32_t buffer_slot : 24;
+ // Buffer bound to the binding number.
+ // If NULL then the buffer_slot will be used to resolve the buffer at command
+ // buffer execution time from the binding table.
+ iree_hal_buffer_t* buffer;
+ // Offset, in bytes, into the buffer that the binding starts at.
+ // When indirectly referencing a binding table buffer this will be added to
+ // the base offset of the bound buffer.
+ iree_device_size_t offset;
+ // Length, in bytes, of the buffer that is available to the executable.
+ // This can be IREE_WHOLE_BUFFER, however note that if the entire buffer
+ // contents are larger than supported by the device (~128MiB, usually) this
+ // will fail. If the descriptor type is dynamic this will be used for all
+ // ranges regardless of offset.
+ iree_device_size_t length;
+} iree_hal_buffer_ref_t;
+
+static inline iree_hal_buffer_ref_t iree_hal_make_buffer_ref(
+ iree_hal_buffer_t* buffer, iree_device_size_t offset,
+ iree_device_size_t length) {
+ return (iree_hal_buffer_ref_t){0, 0, buffer, offset, length};
+}
+
+static inline iree_hal_buffer_ref_t iree_hal_make_indirect_buffer_ref(
+ uint32_t buffer_slot, iree_device_size_t offset,
+ iree_device_size_t length) {
+ return (iree_hal_buffer_ref_t){0, buffer_slot, NULL, offset, length};
+}
+
// Bitfield specifying which execution stage a barrier should start/end at.
//
// Maps to VkPipelineStageFlagBits.
@@ -383,32 +422,6 @@
IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count(
iree_hal_collective_element_type_t element_type);
-// Describes a subrange of a buffer that can be bound to a binding slot.
-typedef struct iree_hal_buffer_binding_t {
- // Buffer being bound to the slot, if any.
- iree_hal_buffer_t* buffer;
- // Offset, in bytes, into the buffer that the binding starts at.
- // This will be added to the offset specified on each usage of the slot.
- iree_device_size_t offset;
- // Length, in bytes, of the buffer that is available to the executable.
- // This can be IREE_WHOLE_BUFFER, however note that if the entire buffer
- // contents are larger than supported by the device (~128MiB, usually) this
- // will fail. If the descriptor type is dynamic this will be used for all
- // ranges regardless of offset.
- iree_device_size_t length;
-} iree_hal_buffer_binding_t;
-
-typedef struct iree_hal_buffer_binding_table_t {
- iree_host_size_t count;
- const iree_hal_buffer_binding_t* bindings;
-} iree_hal_buffer_binding_table_t;
-
-static inline iree_hal_buffer_binding_table_t
-iree_hal_buffer_binding_table_empty(void) {
- iree_hal_buffer_binding_table_t table = {0, NULL};
- return table;
-}
-
// An RGBA color.
typedef struct iree_hal_label_color_t {
uint8_t r;
@@ -459,6 +472,36 @@
((iree_device_size_t)(64 * 1024))
//===----------------------------------------------------------------------===//
+// iree_hal_buffer_binding_table_t
+//===----------------------------------------------------------------------===//
+
+// Describes a subrange of a buffer that can be bound to a binding slot.
+typedef struct iree_hal_buffer_binding_t {
+ // Buffer being bound to the slot, if any.
+ iree_hal_buffer_t* buffer;
+ // Offset, in bytes, into the buffer that the binding starts at.
+ // This will be added to the offset specified on each usage of the slot.
+ iree_device_size_t offset;
+ // Length, in bytes, of the buffer that is available to the executable.
+ // This can be IREE_WHOLE_BUFFER, however note that if the entire buffer
+ // contents are larger than supported by the device (~128MiB, usually) this
+ // will fail. If the descriptor type is dynamic this will be used for all
+ // ranges regardless of offset.
+ iree_device_size_t length;
+} iree_hal_buffer_binding_t;
+
+typedef struct iree_hal_buffer_binding_table_t {
+ iree_host_size_t count;
+ const iree_hal_buffer_binding_t* bindings;
+} iree_hal_buffer_binding_table_t;
+
+static inline iree_hal_buffer_binding_table_t
+iree_hal_buffer_binding_table_empty(void) {
+ iree_hal_buffer_binding_table_t table = {0, NULL};
+ return table;
+}
+
+//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_t
//===----------------------------------------------------------------------===//
@@ -679,11 +722,6 @@
// The descriptor set will remain bound and valid so long as the executable
// layouts used by dispatches are compatible (same descriptor layouts and push
// constant sizes).
-//
-// When the command buffer is IREE_HAL_COMMAND_BUFFER_MODE_NESTED zero or more
-// bindings may omit their buffer reference and instead specify a slot within
-// the command buffer binding table that will contain the buffer when executed
-// via the iree_hal_command_buffer_execute_commands command.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_descriptor_set(
iree_hal_command_buffer_t* command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
@@ -719,22 +757,17 @@
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset);
-// Executes a secondary command buffer using the provided set of indirect
-// buffer bindings. The commands will be executed as if they were recorded
-// directly into the command buffer but push constant and descriptor state will
-// not be inherited.
-//
-// The |commands| buffer must have the IREE_HAL_COMMAND_BUFFER_MODE_NESTED flag.
-// Only valid to use on primary command buffers that are themselves not marked
-// nested.
-//
-// The |binding_table| provided will be available for indirect binding usage
-// within the nested command buffer but may not reference the parent command
-// buffer binding table. This restriction may be lifted in the future.
-IREE_API_EXPORT iree_status_t iree_hal_command_buffer_execute_commands(
+//===----------------------------------------------------------------------===//
+// Validation support
+//===----------------------------------------------------------------------===//
+
+// Validates that all bindings in the provided |binding_table| match the
+// requirements of |command_buffer| as recorded. If the command buffer does not
+// use any indirect bindings the table will be ignored. If more bindings than
+// are used by the command buffer are provided they will be ignored.
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_binding_table(
iree_hal_command_buffer_t* command_buffer,
- iree_hal_command_buffer_t* commands,
- iree_hal_buffer_binding_table_t binding_table);
+ const iree_hal_buffer_binding_table_t* binding_table);
//===----------------------------------------------------------------------===//
// Utilities for command buffer creation
@@ -890,11 +923,6 @@
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_t* workgroups_buffer,
iree_device_size_t workgroups_offset);
-
- iree_status_t(IREE_API_PTR* execute_commands)(
- iree_hal_command_buffer_t* command_buffer,
- iree_hal_command_buffer_t* commands,
- iree_hal_buffer_binding_table_t binding_table);
} iree_hal_command_buffer_vtable_t;
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t);
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c
index 71eccb9..80648f0 100644
--- a/runtime/src/iree/hal/command_buffer_validation.c
+++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -498,34 +498,6 @@
// TODO(benvanik): validate set index.
- // TODO(benvanik): allow indirect bindings on primary command buffers?
- const bool has_binding_table =
- iree_all_bits_set(iree_hal_command_buffer_mode(command_buffer),
- IREE_HAL_COMMAND_BUFFER_MODE_NESTED);
- for (iree_host_size_t i = 0; i < binding_count; ++i) {
- const iree_hal_descriptor_set_binding_t* binding = &bindings[i];
- // TODO(benvanik): validate binding index.
- // TODO(benvanik): validate binding buffer parameters/access.
- // TODO(benvanik): validate binding range (if possible).
-
- // Validate that indirect buffer references are supported and in bounds.
- if (!binding->buffer) {
- if (!has_binding_table) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "bindings[%" PRIhsz
- "] is indirect but the command buffer does not "
- "support binding tables",
- i);
- } else if (binding->buffer_slot >= command_buffer->binding_capacity) {
- return iree_make_status(
- IREE_STATUS_OUT_OF_RANGE,
- "bindings[%" PRIhsz
- "] references binding table slot %u but table capacity is %u",
- i, binding->buffer_slot, command_buffer->binding_capacity);
- }
- }
- }
-
return iree_ok_status();
}
@@ -571,25 +543,3 @@
return iree_ok_status();
}
-
-iree_status_t iree_hal_command_buffer_execute_commands_validation(
- iree_hal_command_buffer_t* command_buffer,
- iree_hal_command_buffer_validation_state_t* validation_state,
- iree_hal_command_buffer_t* commands,
- iree_hal_buffer_binding_table_t binding_table) {
- if (iree_all_bits_set(command_buffer->mode,
- IREE_HAL_COMMAND_BUFFER_MODE_NESTED)) {
- return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
- "command buffers can only be nested one level "
- "(nested cannot execute nested)");
- }
- if (!iree_all_bits_set(commands->mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED)) {
- return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
- "only nested command buffers can be executed as "
- "part of a primary command buffer");
- }
-
- // TODO(benvanik): validate bindings as with push descriptor sets.
-
- return iree_ok_status();
-}
diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h
index 687acd0..33e502d 100644
--- a/runtime/src/iree/hal/command_buffer_validation.h
+++ b/runtime/src/iree/hal/command_buffer_validation.h
@@ -122,10 +122,4 @@
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset);
-iree_status_t iree_hal_command_buffer_execute_commands_validation(
- iree_hal_command_buffer_t* command_buffer,
- iree_hal_command_buffer_validation_state_t* validation_state,
- iree_hal_command_buffer_t* commands,
- iree_hal_buffer_binding_table_t binding_table);
-
#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
diff --git a/runtime/src/iree/hal/cts/cts_test_base.h b/runtime/src/iree/hal/cts/cts_test_base.h
index 3d50b8f..1116eee 100644
--- a/runtime/src/iree/hal/cts/cts_test_base.h
+++ b/runtime/src/iree/hal/cts/cts_test_base.h
@@ -120,7 +120,8 @@
iree_status_t status = iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores,
- signal_semaphores, command_buffer_count, command_buffers);
+ signal_semaphores, command_buffer_count, command_buffers,
+ /*binding_tables=*/NULL);
if (iree_status_is_ok(status)) {
status = iree_hal_semaphore_wait(signal_semaphore, target_payload_value,
iree_infinite_timeout());
diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h
index 4e194a7..10ed5dc 100644
--- a/runtime/src/iree/hal/cts/semaphore_submission_test.h
+++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h
@@ -33,10 +33,10 @@
signal_payload_values,
};
- IREE_ASSERT_OK(iree_hal_device_queue_execute(device_,
+ IREE_ASSERT_OK(iree_hal_device_queue_barrier(device_,
/*queue_affinity=*/0,
iree_hal_semaphore_list_empty(),
- signal_semaphores, 0, NULL));
+ signal_semaphores));
IREE_ASSERT_OK(
iree_hal_semaphore_wait(signal_semaphore, 1, iree_infinite_timeout()));
@@ -58,7 +58,7 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_,
/*queue_affinity=*/0, iree_hal_semaphore_list_empty(), signal_semaphores,
- 1, &command_buffer));
+ 1, &command_buffer, /*binding_tables=*/NULL));
IREE_ASSERT_OK(
iree_hal_semaphore_wait(signal_semaphore, 1, iree_infinite_timeout()));
@@ -87,10 +87,10 @@
signal_payload_values,
};
- IREE_ASSERT_OK(
- iree_hal_device_queue_execute(device_,
- /*queue_affinity=*/0, wait_semaphores,
- signal_semaphores, 1, &command_buffer));
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_,
+ /*queue_affinity=*/0, wait_semaphores, signal_semaphores, 1,
+ &command_buffer, /*binding_tables=*/NULL));
// Work shouldn't start until the wait semaphore reaches its payload value.
CheckSemaphoreValue(signal_semaphore, 100);
@@ -130,10 +130,10 @@
signal_payload_values,
};
- IREE_ASSERT_OK(
- iree_hal_device_queue_execute(device_,
- /*queue_affinity=*/0, wait_semaphores,
- signal_semaphores, 1, &command_buffer));
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_,
+ /*queue_affinity=*/0, wait_semaphores, signal_semaphores, 1,
+ &command_buffer, /*binding_tables=*/NULL));
// Work shouldn't start until all wait semaphores reach their payload values.
CheckSemaphoreValue(signal_semaphore_1, 0);
@@ -178,7 +178,7 @@
// Dispatch the device command buffer to have it wait.
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY, device_wait_semaphores,
- device_signal_semaphores, 1, &command_buffer));
+ device_signal_semaphores, 1, &command_buffer, /*binding_tables=*/NULL));
// Start another thread and have it wait.
std::thread thread([&]() {
@@ -242,7 +242,7 @@
// Dispatch the device command buffer to have it wait.
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY, device_wait_semaphores,
- device_signal_semaphores, 1, &command_buffer));
+ device_signal_semaphores, 1, &command_buffer, /*binding_tables=*/NULL));
// Start another thread and have it wait.
std::thread thread([&]() {
@@ -311,7 +311,7 @@
// Dispatch the device command buffer to have it wait.
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY, device_wait_semaphores,
- device_signal_semaphores, 1, &command_buffer));
+ device_signal_semaphores, 1, &command_buffer, /*binding_tables=*/NULL));
// Start another thread and have it wait.
std::thread thread([&]() {
@@ -382,7 +382,8 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/semaphore1_list,
- /*signal_semaphore_list=*/semaphore2_list, 1, &command_buffer2));
+ /*signal_semaphore_list=*/semaphore2_list, 1, &command_buffer2,
+ /*binding_tables=*/NULL));
// Make sure that the intermediate and second semaphores have not advanced
// since only command_buffer2 is queued.
@@ -395,7 +396,8 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer1_wait_semaphore_list,
- /*signal_semaphore_list=*/semaphore1_list, 1, &command_buffer1));
+ /*signal_semaphore_list=*/semaphore1_list, 1, &command_buffer1,
+ /*binding_tables=*/NULL));
// Wait on the intermediate semaphore and check its value.
IREE_ASSERT_OK(
@@ -449,15 +451,18 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/semaphore11_list,
- /*signal_semaphore_list=*/semaphore22_list, 1, &command_buffer22));
+ /*signal_semaphore_list=*/semaphore22_list, 1, &command_buffer22,
+ /*binding_tables=*/NULL));
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/semaphore11_list,
- /*signal_semaphore_list=*/semaphore21_list, 1, &command_buffer21));
+ /*signal_semaphore_list=*/semaphore21_list, 1, &command_buffer21,
+ /*binding_tables=*/NULL));
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/empty_semaphore_list,
- /*signal_semaphore_list=*/empty_semaphore_list, 1, &command_buffer12));
+ /*signal_semaphore_list=*/empty_semaphore_list, 1, &command_buffer12,
+ /*binding_tables=*/NULL));
// Assert that semaphores have not advance since we have not yet submitted
// command_buffer11.
@@ -469,7 +474,8 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/empty_semaphore_list,
- /*signal_semaphore_list=*/semaphore11_list, 1, &command_buffer11));
+ /*signal_semaphore_list=*/semaphore11_list, 1, &command_buffer11,
+ /*binding_tables=*/NULL));
// Wait and check that semaphore values have advanced.
IREE_ASSERT_OK(
@@ -541,13 +547,13 @@
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer22_semaphore_wait_list,
/*signal_semaphore_list=*/command_buffer22_signal_list, 1,
- &command_buffer22));
+ &command_buffer22, /*binding_tables=*/NULL));
// We submit the command buffers in reverse order.
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer21_semaphore_wait_list,
/*signal_semaphore_list=*/command_buffer21_signal_list, 1,
- &command_buffer21));
+ &command_buffer21, /*binding_tables=*/NULL));
// Semaphores have not advance since we have not yet submitted
// command_buffer11.
@@ -559,7 +565,7 @@
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer11_semaphore_wait_list,
/*signal_semaphore_list=*/command_buffer11_semaphore_signal_list, 1,
- &command_buffer11));
+ &command_buffer11, /*binding_tables=*/NULL));
// Wait and check that semaphore values have advanced.
IREE_ASSERT_OK(
@@ -617,7 +623,7 @@
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer2_wait_list,
/*signal_semaphore_list=*/command_buffer2_signal_list, 1,
- &command_buffer2));
+ &command_buffer2, /*binding_tables=*/NULL));
// semaphore3 must not have advanced, because it depends on semaphore1 and
// semaphore2, which have not been signaled yet.
@@ -632,7 +638,7 @@
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer1_wait_list,
/*signal_semaphore_list=*/command_buffer1_signal_list, 1,
- &command_buffer1));
+ &command_buffer1, /*binding_tables=*/NULL));
// semaphore3 must not have advanced still, because it depends on semaphore2,
// which has not been signaled yet.
@@ -689,7 +695,7 @@
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer2_wait_list,
/*signal_semaphore_list=*/command_buffer2_signal_list, 1,
- &command_buffer2));
+ &command_buffer2, /*binding_tables=*/NULL));
// Semaphores have not advance since we have not yet submitted
// command_buffer1.
@@ -727,7 +733,7 @@
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer1_wait_list,
/*signal_semaphore_list=*/command_buffer1_signal_list, 1,
- &command_buffer1));
+ &command_buffer1, /*binding_tables=*/NULL));
thread11.join();
thread12.join();
@@ -776,8 +782,8 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer_wait_list,
- /*signal_semaphore_list=*/command_buffer_signal_list, 1,
- &command_buffer));
+ /*signal_semaphore_list=*/command_buffer_signal_list, 1, &command_buffer,
+ /*binding_tables=*/NULL));
IREE_ASSERT_OK(
iree_hal_semaphore_wait(semaphore2, semaphore2_signal_value,
@@ -818,8 +824,8 @@
IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_, IREE_HAL_QUEUE_AFFINITY_ANY,
/*wait_semaphore_list=*/command_buffer_wait_list,
- /*signal_semaphore_list=*/command_buffer_signal_list, 1,
- &command_buffer));
+ /*signal_semaphore_list=*/command_buffer_signal_list, 1, &command_buffer,
+ /*binding_tables=*/NULL));
std::thread signal_thread(
[&]() { IREE_ASSERT_OK(iree_hal_semaphore_signal(semaphore1, 2)); });
diff --git a/runtime/src/iree/hal/device.c b/runtime/src/iree/hal/device.c
index 4097a8d..07bd660 100644
--- a/runtime/src/iree/hal/device.c
+++ b/runtime/src/iree/hal/device.c
@@ -169,9 +169,9 @@
queue_affinity, 1, &command,
&command_buffer));
- iree_status_t status =
- iree_hal_device_queue_execute(device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, 1, &command_buffer);
+ iree_status_t status = iree_hal_device_queue_execute(
+ device, queue_affinity, wait_semaphore_list, signal_semaphore_list, 1,
+ &command_buffer, /*binding_tables=*/NULL);
iree_hal_command_buffer_release(command_buffer);
@@ -218,9 +218,9 @@
queue_affinity, 1, &command,
&command_buffer));
- iree_status_t status =
- iree_hal_device_queue_execute(device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, 1, &command_buffer);
+ iree_status_t status = iree_hal_device_queue_execute(
+ device, queue_affinity, wait_semaphore_list, signal_semaphore_list, 1,
+ &command_buffer, /*binding_tables=*/NULL);
iree_hal_command_buffer_release(command_buffer);
@@ -281,7 +281,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(
!wait_semaphore_list.count ||
@@ -313,9 +314,20 @@
}
}
+ // Validate command buffer bindings against the provided binding tables.
+ // This will error out if a binding table is required but not provided or if
+ // any binding in the table does not match the requirements of the command
+ // buffer as recorded.
+ for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_hal_command_buffer_validate_binding_table(
+ command_buffers[i], binding_tables ? &binding_tables[i] : NULL));
+ }
+
iree_status_t status = _VTABLE_DISPATCH(device, queue_execute)(
device, queue_affinity, wait_semaphore_list, signal_semaphore_list,
- command_buffer_count, command_buffers);
+ command_buffer_count, command_buffers, binding_tables);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -329,7 +341,7 @@
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status =
iree_hal_device_queue_execute(device, queue_affinity, wait_semaphore_list,
- signal_semaphore_list, 0, NULL);
+ signal_semaphore_list, 0, NULL, NULL);
IREE_TRACE_ZONE_END(z0);
return status;
}
diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h
index 7a04654..13c8263 100644
--- a/runtime/src/iree/hal/device.h
+++ b/runtime/src/iree/hal/device.h
@@ -395,6 +395,11 @@
// placed on to the same queue. Note that the exact hashing function is
// implementation dependent.
//
+// A list of binding tables matching the list of command buffers must be
+// provided if any command buffer has indirect bindings and may otherwise be
+// NULL. The binding table contents will be captured during the call and need
+// not persist after the call returns.
+//
// The submission behavior matches Vulkan's vkQueueSubmit, with each submission
// executing its command buffers in the order they are defined but allowing the
// command buffers to complete out-of-order. See:
@@ -404,7 +409,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers);
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables);
// Enqueues a barrier waiting for |wait_semaphore_list| and signaling
// |signal_semaphore_list| when reached.
@@ -611,7 +617,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers);
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables);
iree_status_t(IREE_API_PTR* queue_flush)(
iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity);
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 50cbb26..879573a 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -757,7 +757,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
index 44d6c2b..c5748b6 100644
--- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
@@ -843,19 +843,6 @@
"indirect dispatch not yet implemented");
}
-static iree_status_t iree_hal_cuda_graph_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): support indirect command buffers by adding subgraph nodes and
- // tracking the binding table for future cuGraphExecKernelNodeSetParams usage.
- // Need to look into how to update the params of the subgraph nodes - is the
- // graph exec the outer one and if so will it allow node handles from the
- // subgraphs?
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
static const iree_hal_command_buffer_vtable_t
iree_hal_cuda_graph_command_buffer_vtable = {
.destroy = iree_hal_cuda_graph_command_buffer_destroy,
@@ -880,5 +867,4 @@
.dispatch = iree_hal_cuda_graph_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_cuda_graph_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_cuda_graph_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
index 52e69bf..784d906 100644
--- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -618,16 +618,6 @@
"need cuda implementation of dispatch indirect");
}
-static iree_status_t iree_hal_cuda_stream_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): support indirect command buffers with deferred command
- // buffers or graphs. We likely just want to switch to graphs.
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
static const iree_hal_command_buffer_vtable_t
iree_hal_cuda_stream_command_buffer_vtable = {
.destroy = iree_hal_cuda_stream_command_buffer_destroy,
@@ -652,6 +642,4 @@
.dispatch = iree_hal_cuda_stream_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_cuda_stream_command_buffer_dispatch_indirect,
- .execute_commands =
- iree_hal_cuda_stream_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
index 2b00a83..64ab55c 100644
--- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
@@ -850,19 +850,6 @@
"indirect dispatch not yet implemented");
}
-static iree_status_t iree_hal_hip_graph_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): support indirect command buffers by adding subgraph nodes and
- // tracking the binding table for future hipGraphExecKernelNodeSetParams
- // usage. Need to look into how to update the params of the subgraph nodes -
- // is the graph exec the outer one and if so will it allow node handles from
- // the subgraphs?
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
static const iree_hal_command_buffer_vtable_t
iree_hal_hip_graph_command_buffer_vtable = {
.destroy = iree_hal_hip_graph_command_buffer_destroy,
@@ -887,5 +874,4 @@
.dispatch = iree_hal_hip_graph_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_hip_graph_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_hip_graph_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c
index c41851c..13e41a7 100644
--- a/runtime/src/iree/hal/drivers/hip/hip_device.c
+++ b/runtime/src/iree/hal/drivers/hip/hip_device.c
@@ -764,7 +764,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
index 5ec1555..9dc7786 100644
--- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
@@ -596,16 +596,6 @@
"need hip implementation of dispatch indirect");
}
-static iree_status_t iree_hal_hip_stream_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): support indirect command buffers with deferred command
- // buffers or graphs. We likely just want to switch to graphs.
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
static const iree_hal_command_buffer_vtable_t
iree_hal_hip_stream_command_buffer_vtable = {
.destroy = iree_hal_hip_stream_command_buffer_destroy,
@@ -630,5 +620,4 @@
.dispatch = iree_hal_hip_stream_command_buffer_dispatch,
.dispatch_indirect =
iree_hal_hip_stream_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_hip_stream_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
index 711704c..f2cc6df 100644
--- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c
+++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
@@ -429,7 +429,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_sync_device_t* device = iree_hal_sync_device_cast(base_device);
// TODO(#4680): there is some better error handling here needed; we should
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
index a88d887..d065cac 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@ -1023,26 +1023,6 @@
}
//===----------------------------------------------------------------------===//
-// iree_hal_command_buffer_execute_commands
-//===----------------------------------------------------------------------===//
-
-static iree_status_t iree_hal_task_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): support indirect command buffers by using deferred command
- // buffers or caching the task topology (probably not worth the tracking).
- // If we could separate the topology that referenced the binding table we'd
- // be able to reissue but not concurrently (as each task can only be in flight
- // as a singleton) - which may be enough in many cases but adds complexity to
- // tracking as we'd need to either enforce serialization of subsequent
- // submissions or copy-on-write-style clone the topology for each additional
- // concurrent submission.
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
-//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_vtable_t
//===----------------------------------------------------------------------===//
@@ -1066,5 +1046,4 @@
.push_descriptor_set = iree_hal_task_command_buffer_push_descriptor_set,
.dispatch = iree_hal_task_command_buffer_dispatch,
.dispatch_indirect = iree_hal_task_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_task_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c
index 5a99037..d90b3a2 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_device.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_device.c
@@ -459,7 +459,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_task_device_t* device = iree_hal_task_device_cast(base_device);
// NOTE: today we are not discriminating queues based on command type.
iree_host_size_t queue_index = iree_hal_task_device_select_queue(
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index 3951c48..4a2a07a 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -342,7 +342,6 @@
IREE_ASSERT_ARGUMENT(device);
IREE_ASSERT_ARGUMENT(out_command_buffer);
IREE_ASSERT_TRUE(iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT));
- IREE_ASSERT_TRUE(!iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED));
*out_command_buffer = NULL;
if (binding_capacity > 0) {
@@ -1101,12 +1100,6 @@
return iree_ok_status();
}
-static iree_status_t iree_hal_metal_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer, iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "secondary command buffer not yet supported");
-}
-
static iree_status_t iree_hal_metal_command_segment_record(
iree_hal_metal_command_buffer_t* command_buffer) {
IREE_ASSERT_ARGUMENT(command_buffer);
@@ -1182,5 +1175,4 @@
.push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set,
.dispatch = iree_hal_metal_command_buffer_prepare_dispatch,
.dispatch_indirect = iree_hal_metal_command_buffer_prepare_dispatch_indirect,
- .execute_commands = iree_hal_metal_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m
index 05878b0..e16a883 100644
--- a/runtime/src/iree/hal/drivers/metal/metal_device.m
+++ b/runtime/src/iree/hal/drivers/metal/metal_device.m
@@ -247,8 +247,6 @@
iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) {
iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
- if (iree_any_bit_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_NESTED))
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "nested command buffer not yet supported");
if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT))
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"multi-shot command buffer not yet supported");
@@ -392,7 +390,8 @@
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list, iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
index c177e9e..000584a 100644
--- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
@@ -792,38 +792,6 @@
return iree_ok_status();
}
-static iree_status_t iree_hal_vulkan_direct_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- iree_hal_vulkan_direct_command_buffer_t* command_buffer =
- iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
-
- if (binding_table.count > 0) {
- // TODO(#10144): support indirect command buffers with binding tables.
- // Since Vulkan doesn't natively support this we'd need to emulate things
- // with an iree_hal_vulkan_indirect_command_buffer_t type that captured the
- // command buffer using deferred command buffer and allowed replay with a
- // binding table. If we wanted to actually reuse the command buffers we'd
- // need to use update-after-bind (where supported), device pointers (where
- // supported), or descriptor indexing and a big ringbuffer (make a 1024
- // element descriptor array and cycle through it with each submission).
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
- }
-
- IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
- command_buffer->resource_set, 1, &base_commands));
-
- iree_hal_vulkan_direct_command_buffer_t* commands =
- iree_hal_vulkan_direct_command_buffer_cast(base_commands);
-
- command_buffer->syms->vkCmdExecuteCommands(command_buffer->handle, 1,
- &commands->handle);
-
- return iree_ok_status();
-}
-
namespace {
const iree_hal_command_buffer_vtable_t
iree_hal_vulkan_direct_command_buffer_vtable = {
@@ -855,7 +823,5 @@
/*.dispatch=*/iree_hal_vulkan_direct_command_buffer_dispatch,
/*.dispatch_indirect=*/
iree_hal_vulkan_direct_command_buffer_dispatch_indirect,
- /*.execute_commands=*/
- iree_hal_vulkan_direct_command_buffer_execute_commands,
};
} // namespace
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index f370fdd..23e6d2d 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -1709,7 +1709,8 @@
const iree_hal_semaphore_list_t wait_semaphore_list,
const iree_hal_semaphore_list_t signal_semaphore_list,
iree_host_size_t command_buffer_count,
- iree_hal_command_buffer_t* const* command_buffers) {
+ iree_hal_command_buffer_t* const* command_buffers,
+ iree_hal_buffer_binding_table_t const* binding_tables) {
iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device);
// NOTE: today we are not discriminating queues based on command type.
CommandQueue* queue = iree_hal_vulkan_device_select_queue(
diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c
index 0f5071a..22760bc 100644
--- a/runtime/src/iree/hal/local/inline_command_buffer.c
+++ b/runtime/src/iree/hal/local/inline_command_buffer.c
@@ -572,22 +572,6 @@
}
//===----------------------------------------------------------------------===//
-// iree_hal_command_buffer_execute_commands
-//===----------------------------------------------------------------------===//
-
-static iree_status_t iree_hal_inline_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* base_commands,
- iree_hal_buffer_binding_table_t binding_table) {
- // TODO(#10144): decide how to execute the inline command buffer; it is
- // definitely a deferred command buffer but we don't want to force that
- // dependency here. We could allow injection of a function to call to execute
- // command buffers so that the device can decide how it wants to handle them.
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "indirect command buffers not yet implemented");
-}
-
-//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_vtable_t
//===----------------------------------------------------------------------===//
@@ -612,5 +596,4 @@
iree_hal_inline_command_buffer_push_descriptor_set,
.dispatch = iree_hal_inline_command_buffer_dispatch,
.dispatch_indirect = iree_hal_inline_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_inline_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/hal/utils/debug_allocator.c b/runtime/src/iree/hal/utils/debug_allocator.c
index 06f969c..8ca388d 100644
--- a/runtime/src/iree/hal/utils/debug_allocator.c
+++ b/runtime/src/iree/hal/utils/debug_allocator.c
@@ -167,9 +167,9 @@
.semaphores = &semaphore,
.payload_values = &signal_value,
};
- status = iree_hal_device_queue_execute(device, IREE_HAL_QUEUE_AFFINITY_ANY,
- iree_hal_semaphore_list_empty(),
- signal_list, 1, &command_buffer);
+ status = iree_hal_device_queue_execute(
+ device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+ signal_list, 1, &command_buffer, /*binding_tables=*/NULL);
}
if (iree_status_is_ok(status)) {
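
For callers the updated API has two shapes, both visible in this patch's call-site updates: direct submissions pass NULL for binding_tables, while an indirect submission supplies one table per command buffer. A hedged usage sketch follows; the submit_both helper and its arguments are hypothetical, and the binding table only needs to outlive the call since the device copies it.

#include "iree/hal/api.h"

static iree_status_t submit_both(
    iree_hal_device_t* device, iree_hal_command_buffer_t* direct_cb,
    iree_hal_command_buffer_t* indirect_cb, iree_hal_buffer_t* buffer,
    iree_hal_semaphore_list_t wait_list,
    iree_hal_semaphore_list_t signal_list) {
  // Direct command buffer: no indirect bindings, so binding_tables is NULL.
  IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
      device, IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, signal_list,
      /*command_buffer_count=*/1, &direct_cb, /*binding_tables=*/NULL));

  // Indirect command buffer: one binding table per submitted command buffer.
  // The device copies the table during the call, so stack storage is fine.
  iree_hal_buffer_binding_t bindings[1] = {{
      .buffer = buffer,
      .offset = 0,
      .length = IREE_WHOLE_BUFFER,
  }};
  iree_hal_buffer_binding_table_t binding_table = {
      .count = IREE_ARRAYSIZE(bindings),
      .bindings = bindings,
  };
  return iree_hal_device_queue_execute(
      device, IREE_HAL_QUEUE_AFFINITY_ANY, wait_list, signal_list,
      /*command_buffer_count=*/1, &indirect_cb, &binding_table);
}
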
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c
index 05562a8..3b5cc22 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.c
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -27,7 +27,6 @@
IREE_HAL_CMD_PUSH_DESCRIPTOR_SET,
IREE_HAL_CMD_DISPATCH,
IREE_HAL_CMD_DISPATCH_INDIRECT,
- IREE_HAL_CMD_EXECUTE_COMMANDS,
} iree_hal_cmd_type_t;
// Header prefixed to all commands, forming a linked-list.
@@ -821,56 +820,6 @@
}
//===----------------------------------------------------------------------===//
-// IREE_HAL_CMD_EXECUTE_COMMANDS
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_cmd_execute_commands_t {
- iree_hal_cmd_header_t header;
- iree_hal_command_buffer_t* commands;
- iree_host_size_t binding_count;
- iree_hal_buffer_binding_t bindings[];
-} iree_hal_cmd_execute_commands_t;
-
-static iree_status_t iree_hal_deferred_command_buffer_execute_commands(
- iree_hal_command_buffer_t* base_command_buffer,
- iree_hal_command_buffer_t* commands,
- iree_hal_buffer_binding_table_t binding_table) {
- iree_hal_deferred_command_buffer_t* command_buffer =
- iree_hal_deferred_command_buffer_cast(base_command_buffer);
- iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
- IREE_RETURN_IF_ERROR(
- iree_hal_resource_set_insert(command_buffer->resource_set, 1, &commands));
- iree_hal_cmd_execute_commands_t* cmd = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
- cmd_list, IREE_HAL_CMD_EXECUTE_COMMANDS,
- sizeof(*cmd) + sizeof(cmd->bindings[0]) * binding_table.count,
- (void**)&cmd));
- cmd->commands = commands;
- cmd->binding_count = binding_table.count;
- for (iree_host_size_t i = 0; i < binding_table.count; ++i) {
- const iree_hal_buffer_binding_t binding = binding_table.bindings[i];
- cmd->bindings[i] = binding;
- if (binding.buffer) {
- IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
- command_buffer->resource_set, 1, &binding.buffer));
- }
- }
- return iree_ok_status();
-}
-
-static iree_status_t iree_hal_deferred_command_buffer_apply_execute_commands(
- iree_hal_command_buffer_t* target_command_buffer,
- iree_hal_buffer_binding_table_t binding_table,
- const iree_hal_cmd_execute_commands_t* cmd) {
- const iree_hal_buffer_binding_table_t child_binding_table = {
- .count = cmd->binding_count,
- .bindings = cmd->bindings,
- };
- return iree_hal_command_buffer_execute_commands(
- target_command_buffer, cmd->commands, child_binding_table);
-}
-
-//===----------------------------------------------------------------------===//
// Dynamic replay dispatch
//===----------------------------------------------------------------------===//
@@ -901,8 +850,6 @@
iree_hal_deferred_command_buffer_apply_dispatch,
[IREE_HAL_CMD_DISPATCH_INDIRECT] = (iree_hal_cmd_apply_fn_t)
iree_hal_deferred_command_buffer_apply_dispatch_indirect,
- [IREE_HAL_CMD_EXECUTE_COMMANDS] = (iree_hal_cmd_apply_fn_t)
- iree_hal_deferred_command_buffer_apply_execute_commands,
};
IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_apply(
@@ -961,5 +908,4 @@
iree_hal_deferred_command_buffer_push_descriptor_set,
.dispatch = iree_hal_deferred_command_buffer_dispatch,
.dispatch_indirect = iree_hal_deferred_command_buffer_dispatch_indirect,
- .execute_commands = iree_hal_deferred_command_buffer_execute_commands,
};
diff --git a/runtime/src/iree/io/parameter_index_provider.c b/runtime/src/iree/io/parameter_index_provider.c
index 38b83c4..c144bfd 100644
--- a/runtime/src/iree/io/parameter_index_provider.c
+++ b/runtime/src/iree/io/parameter_index_provider.c
@@ -589,7 +589,8 @@
if (iree_status_is_ok(status)) {
status = iree_hal_device_queue_execute(
batch->device, batch->queue_affinity, step.wait_semaphore_list,
- step.signal_semaphore_list, 1, &batch->transfer_command_buffer);
+ step.signal_semaphore_list, 1, &batch->transfer_command_buffer,
+ /*binding_tables=*/NULL);
}
IREE_TRACE_ZONE_END(z_transfer);
}
diff --git a/runtime/src/iree/modules/check/module.cc b/runtime/src/iree/modules/check/module.cc
index 41ab445..039d2d5 100644
--- a/runtime/src/iree/modules/check/module.cc
+++ b/runtime/src/iree/modules/check/module.cc
@@ -222,7 +222,8 @@
semaphore.get(), 1ull, iree_hal_device_host_allocator(device), &fence));
IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
- iree_hal_fence_semaphore_list(fence.get()), 1, &command_buffer));
+ iree_hal_fence_semaphore_list(fence.get()), 1, &command_buffer,
+ /*binding_tables=*/NULL));
IREE_RETURN_IF_ERROR(
iree_hal_fence_wait(fence.get(), iree_infinite_timeout()));
return std::move(target_views);
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index 8100bd4..13f9d09 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -53,7 +53,6 @@
EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v)
EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rrirI, v)
EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
-EXPORT_FN("command_buffer.execute.commands", iree_hal_module_command_buffer_execute_commands, rrCrIID, v)
EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIii, v)
EXPORT_FN("command_buffer.finalize", iree_hal_module_command_buffer_finalize, r, v)
@@ -67,6 +66,7 @@
EXPORT_FN("device.queue.alloca", iree_hal_module_device_queue_alloca, rIrriiiI, r)
EXPORT_FN("device.queue.dealloca", iree_hal_module_device_queue_dealloca, rIrrr, v)
EXPORT_FN("device.queue.execute", iree_hal_module_device_queue_execute, rIrrCrD, v)
+EXPORT_FN("device.queue.execute.indirect", iree_hal_module_device_queue_execute_indirect, rIrrrCrIID, v)
EXPORT_FN("device.queue.flush", iree_hal_module_device_queue_flush, rI, v)
EXPORT_FN("device.queue.read", iree_hal_module_device_queue_read, rIrrrIrIIi, v)
EXPORT_FN("device.queue.write", iree_hal_module_device_queue_write, rIrrrIrIIi, v)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 824766b..790c616 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -20,13 +20,13 @@
// in the future but right now guards the stack from blowing up during calls.
#define IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT ((iree_host_size_t)32)
-// Limit the number of execution bindings in a binding table. This today limits
-// our number of unique indirect buffers used within a command buffer but the
-// compiler is very good at coalescing those and we often end up with 1-3. If in
-// the future we want to use more from compiled programs we could change from
-// using a stack allocation to a heap allocation when many bindings are
-// provided.
-#define IREE_HAL_MODULE_MAX_COMMAND_BUFFER_BINDING_COUNT ((iree_host_size_t)256)
+// Limit the number of bindings in a binding table that we allocate on the stack
+// while marshaling from the VM. Counts over this amount will result in heap
+// allocations to avoid blowing the native stack. In most programs we expect
+// at most a dozen buffers but programs with individually stored parameters may
+// need hundreds or even thousands. Yuck.
+#define IREE_HAL_MODULE_MAX_STACK_COMMAND_BUFFER_BINDING_COUNT \
+ ((iree_host_size_t)64)
//===----------------------------------------------------------------------===//
// Module type definitions
@@ -679,14 +679,6 @@
(iree_hal_command_category_t)args->i2;
iree_host_size_t binding_capacity = (iree_host_size_t)args->i3;
- if (IREE_UNLIKELY(binding_capacity >
- IREE_HAL_MODULE_MAX_COMMAND_BUFFER_BINDING_COUNT)) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "binding table capacity %" PRIhsz " > %" PRIhsz,
- binding_capacity,
- IREE_HAL_MODULE_MAX_COMMAND_BUFFER_BINDING_COUNT);
- }
-
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
device, modes, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
@@ -920,40 +912,6 @@
workgroups_offset);
}
-IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_execute_commands, //
- iree_hal_module_state_t, //
- rrCrIID, v) {
- iree_hal_command_buffer_t* command_buffer = NULL;
- IREE_RETURN_IF_ERROR(
- iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
- iree_hal_command_buffer_t* commands = NULL;
- IREE_RETURN_IF_ERROR(
- iree_hal_command_buffer_check_deref(args->r1, &commands));
-
- iree_host_size_t binding_count = args->a2_count;
- if (IREE_UNLIKELY(binding_count >
- IREE_HAL_MODULE_MAX_COMMAND_BUFFER_BINDING_COUNT)) {
- return iree_make_status(
- IREE_STATUS_OUT_OF_RANGE, "binding table count %" PRIhsz " > %" PRIhsz,
- binding_count, IREE_HAL_MODULE_MAX_COMMAND_BUFFER_BINDING_COUNT);
- }
- iree_hal_buffer_binding_t* bindings = (iree_hal_buffer_binding_t*)iree_alloca(
- binding_count * sizeof(iree_hal_buffer_binding_t));
- for (iree_host_size_t i = 0; i < binding_count; ++i) {
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
- args->a2[i].r0, &bindings[i].buffer));
- bindings[i].offset = iree_hal_cast_device_size(args->a2[i].i1);
- bindings[i].length = iree_hal_cast_device_size(args->a2[i].i2);
- }
-
- const iree_hal_buffer_binding_table_t binding_table = {
- .count = binding_count,
- .bindings = bindings,
- };
- return iree_hal_command_buffer_execute_commands(command_buffer, commands,
- binding_table);
-}
-
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout
//===----------------------------------------------------------------------===//
@@ -1128,7 +1086,69 @@
return iree_hal_device_queue_execute(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
iree_hal_fence_semaphore_list(signal_fence), command_buffer_count,
- command_buffers);
+ command_buffers, /*binding_tables=*/NULL);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_execute_indirect, //
+ iree_hal_module_state_t, //
+ rIrrrCrIID, v) {
+ iree_hal_device_t* device = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i1;
+ iree_hal_fence_t* wait_fence = iree_hal_fence_deref(args->r2);
+ iree_hal_fence_t* signal_fence = iree_hal_fence_deref(args->r3);
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r4, &command_buffer));
+
+ // Allocate temporary storage for the binding table in order to marshal VM
+ // refs and 64-bit offsets/lengths into the types required by the HAL C API.
+ iree_host_size_t binding_count = args->a5_count;
+ iree_hal_buffer_binding_t* bindings = NULL;
+ if (binding_count > IREE_HAL_MODULE_MAX_STACK_COMMAND_BUFFER_BINDING_COUNT) {
+ // Heap allocate when using a large number of bindings to avoid blowing the
+ // native stack. Note that we have to free it before returning from the
+ // function.
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc_uninitialized(
+ state->host_allocator, binding_count * sizeof(*bindings),
+ (void**)&bindings));
+ } else {
+ // Stack allocate when using a small number of bindings (common).
+ bindings = (iree_hal_buffer_binding_t*)iree_alloca(binding_count *
+ sizeof(*bindings));
+ }
+
+ // Ensure all buffers are valid (may be NULL) and build the binding table.
+ iree_status_t status = iree_ok_status();
+ for (iree_host_size_t i = 0; i < binding_count; ++i) {
+ status = iree_hal_buffer_check_deref_or_null(args->a5[i].r0,
+ &bindings[i].buffer);
+ if (!iree_status_is_ok(status)) break;
+ bindings[i].offset = iree_hal_cast_device_size(args->a5[i].i1);
+ bindings[i].length = iree_hal_cast_device_size(args->a5[i].i2);
+ }
+
+ // Schedule execution with the binding table - it will be copied by the device
+ // and need not live longer than the call.
+ if (iree_status_is_ok(status)) {
+ iree_hal_buffer_binding_table_t binding_table = {
+ .count = binding_count,
+ .bindings = bindings,
+ };
+ status = iree_hal_device_queue_execute(
+ device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
+ iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer,
+ &binding_table);
+ }
+
+ // If we had to heap-allocate the binding table storage it must be freed
+ // before returning to the VM.
+ if (binding_count > IREE_HAL_MODULE_MAX_STACK_COMMAND_BUFFER_BINDING_COUNT) {
+ iree_allocator_free(state->host_allocator, bindings);
+ }
+
+ return status;
}
IREE_VM_ABI_EXPORT(iree_hal_module_device_queue_flush, //
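
The marshaling code above switches between stack and heap storage based on the binding count. A standalone sketch of that pattern is below, with hypothetical my_* names and a placeholder element type; iree_alloca and the iree_allocator_* helpers are assumed to come from the IREE base headers.

#include <stdint.h>
#include "iree/base/api.h"  // assumed to provide iree_alloca + iree_allocator_*

typedef struct {
  uint64_t offset;
  uint64_t length;
} my_element_t;

// Hypothetical threshold mirroring the stack/heap split above.
#define MY_MAX_STACK_ELEMENT_COUNT ((iree_host_size_t)64)

static iree_status_t my_marshal_and_use(iree_allocator_t host_allocator,
                                        iree_host_size_t count) {
  my_element_t* elements = NULL;
  if (count > MY_MAX_STACK_ELEMENT_COUNT) {
    // Large counts go to the heap so the native stack stays bounded; the
    // storage must be freed on every path out of the function.
    IREE_RETURN_IF_ERROR(iree_allocator_malloc(
        host_allocator, count * sizeof(*elements), (void**)&elements));
  } else {
    // Common case: a small count fits comfortably on the stack.
    elements = (my_element_t*)iree_alloca(count * sizeof(*elements));
  }

  iree_status_t status = iree_ok_status();
  for (iree_host_size_t i = 0; i < count && iree_status_is_ok(status); ++i) {
    elements[i].offset = 0;  // ... populate from caller-provided arguments ...
    elements[i].length = 0;
  }

  // ... consume elements while status is still ok ...

  if (count > MY_MAX_STACK_ELEMENT_COUNT) {
    iree_allocator_free(host_allocator, elements);
  }
  return status;
}
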
diff --git a/runtime/src/iree/tooling/function_util.c b/runtime/src/iree/tooling/function_util.c
index 4e4d71d..aa06445 100644
--- a/runtime/src/iree/tooling/function_util.c
+++ b/runtime/src/iree/tooling/function_util.c
@@ -117,7 +117,8 @@
if (iree_status_is_ok(status)) {
status = iree_hal_device_queue_execute(
device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
- iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer);
+ iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer,
+ /*binding_tables=*/NULL);
}
if (iree_status_is_ok(status) && needs_wait) {
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 80f11f5..9c8b68e 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -71,6 +71,7 @@
IREE_VM_ABI_DEFINE_SHIM(rIrrrIiirrr, r);
IREE_VM_ABI_DEFINE_SHIM(rIrrr, v);
IREE_VM_ABI_DEFINE_SHIM(rIrrCrD, v);
+IREE_VM_ABI_DEFINE_SHIM(rIrrrCrIID, v);
IREE_VM_ABI_DEFINE_SHIM(CrID, r);
IREE_VM_ABI_DEFINE_SHIM(CrD, r);
IREE_VM_ABI_DEFINE_SHIM(iCrD, i);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index 860d96f..ab2f012 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -480,6 +480,16 @@
iree_vm_abi_r_t a4[0];
});
+IREE_VM_ABI_VLA_STRUCT(rIrrrCrIID, a5_count, a5, {
+ iree_vm_ref_t r0;
+ int64_t i1;
+ iree_vm_ref_t r2;
+ iree_vm_ref_t r3;
+ iree_vm_ref_t r4;
+ iree_vm_size_t a5_count;
+ iree_vm_abi_rII_t a5[0];
+});
+
IREE_VM_ABI_VLA_STRUCT(rCiD, a1_count, a1, {
iree_vm_ref_t r0;
iree_vm_size_t a1_count;
@@ -678,6 +688,7 @@
IREE_VM_ABI_DECLARE_SHIM(rIrrrIiirrr, r);
IREE_VM_ABI_DECLARE_SHIM(rIrrr, v);
IREE_VM_ABI_DECLARE_SHIM(rIrrCrD, v);
+IREE_VM_ABI_DECLARE_SHIM(rIrrrCrIID, v);
IREE_VM_ABI_DECLARE_SHIM(CrID, r);
IREE_VM_ABI_DECLARE_SHIM(CrD, r);
IREE_VM_ABI_DECLARE_SHIM(iCrD, i);
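
The rIrrrCrIID layout added above carries the hal.device.queue.execute.indirect arguments across the VM ABI: four refs and an i64, followed by a variadic span of (ref, i64, i64) binding tuples. The annotated copy below is illustrative only; the field meanings are inferred from the import and the module.c marshaling in this patch, and the example_ type name is invented.

#include "iree/vm/shims.h"

typedef struct {
  iree_vm_ref_t r0;         // r: !hal.device
  int64_t i1;               // I: i64 queue affinity mask
  iree_vm_ref_t r2;         // r: !hal.fence to wait on
  iree_vm_ref_t r3;         // r: !hal.fence to signal
  iree_vm_ref_t r4;         // r: !hal.command_buffer to schedule
  iree_vm_size_t a5_count;  // C ... D: number of variadic binding tuples
  iree_vm_abi_rII_t a5[0];  // each tuple: buffer ref, byte offset, byte length
} example_execute_indirect_args_t;
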
diff --git a/tools/iree-benchmark-executable-main.c b/tools/iree-benchmark-executable-main.c
index 60f5395..2fd1066 100644
--- a/tools/iree-benchmark-executable-main.c
+++ b/tools/iree-benchmark-executable-main.c
@@ -290,7 +290,7 @@
++fence_value;
IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute(
args->device, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphore_list,
- signal_semaphore_list, 1, &command_buffer));
+ signal_semaphore_list, 1, &command_buffer, /*binding_tables=*/NULL));
// Block and wait for the submission to complete.
// Note that this will include round-trip overhead and if the dispatch or