Updating HAL/VM ABI to support indirect bindings (COMPATIBILITY BREAKING). (#17951)
This bumps the HAL module version to 0.3 and makes a few changes
required to support reusable/indirect command buffers.
This is the first breaking change in the HAL ABI in 2024 and hopefully
the only one required this year!
Summary of changes:
* `hal.command_buffer.create` now takes a set of allowed queue
  affinities to match `iree_hal_command_buffer_create`
* `hal.command_buffer.update_buffer` was added and maps to
  `iree_hal_command_buffer_update_buffer`
* `hal.command_buffer.fill_buffer`, `hal.command_buffer.copy_buffer`,
  `hal.command_buffer.collective`, and
  `hal.command_buffer.dispatch.indirect` were updated to take binding
  table slots in addition to buffer pointers
Progress on #17875.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
index cda0bbb..bee6557 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
@@ -81,7 +81,8 @@
} {
%c0 = arith.constant 0 : index
%device = hal.devices.get %c0 : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes {
+ %affinity = arith.constant -1 : i64
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer attributes {
testing.op.a = @dispatch_0,
testing.op.b = @dispatch_0::@spirv,
testing.op.c = @dispatch_0::@spirv::@dispatch_0
@@ -93,15 +94,16 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
return
}
util.initializer {
%c0 = arith.constant 0 : index
%device = hal.devices.get %c0 : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
+ %affinity = arith.constant -1 : i64
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
%c1 = arith.constant 1 : index
%dispatch_0_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_0) : !hal.executable
%dispatch_1_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_1) : !hal.executable
@@ -109,9 +111,9 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
util.return
}
@@ -291,7 +293,8 @@
func.func @two_target_environments() -> () {
%c0 = arith.constant 0 : index
%device = hal.devices.get %c0 : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
+ %affinity = arith.constant -1 : i64
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
%c1 = arith.constant 1 : index
%dispatch_0_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_0) : !hal.executable
%dispatch_1_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_1) : !hal.executable
@@ -301,10 +304,10 @@
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index
%dispatch_3_ordinal = hal.executable.export.ordinal target(@dispatch_3::@spirv::@dispatch_3) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
return
}
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
index d6d6d15..2baedf6 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
@@ -74,7 +74,8 @@
} {
%c0 = arith.constant 0 : index
%device = hal.devices.get %c0 : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes {
+ %affinity = arith.constant -1 : i64
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer attributes {
testing.op.a = @dispatch_0,
testing.op.b = @dispatch_0::@vmvx,
testing.op.c = @dispatch_0::@vmvx::@dispatch_0
@@ -86,15 +87,16 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
return
}
util.initializer {
%c0 = arith.constant 0 : index
%device = hal.devices.get %c0 : !hal.device
- %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
+ %affinity = arith.constant -1 : i64
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
%c1 = arith.constant 1 : index
%dispatch_0_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_0) : !hal.executable
%dispatch_1_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_1) : !hal.executable
@@ -102,9 +104,9 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index 4de1867..cb4179f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -16,6 +16,31 @@
namespace {
+// Returns a (slot value, buffer ref value) tuple split out of |bufferOrSlot|.
+// |bufferOrSlot| is intended to be an `AnyTypeOf<[Index, HAL_BufferType]>` in
+// the op definition.
+static std::tuple<Value, Value>
+splitBufferSlot(Location loc, Value bufferOrSlot, OpBuilder &builder) {
+ if (!bufferOrSlot) {
+ return std::make_tuple(
+ builder.create<IREE::VM::ConstI32ZeroOp>(loc),
+ builder.create<IREE::VM::ConstRefZeroOp>(
+ loc,
+ IREE::VM::RefType::get(builder.getType<IREE::HAL::BufferType>())));
+ } else if (isa<IREE::VM::RefType>(bufferOrSlot.getType())) {
+ // Direct buffer binding; pass 0 for table slot.
+ return std::make_tuple(builder.create<IREE::VM::ConstI32ZeroOp>(loc),
+ bufferOrSlot);
+ } else {
+ // Indirect binding table reference; pass null for the buffer.
+ return std::make_tuple(
+ castToImportType(bufferOrSlot, builder.getI32Type(), builder),
+ builder.create<IREE::VM::ConstRefZeroOp>(
+ loc,
+ IREE::VM::RefType::get(builder.getType<IREE::HAL::BufferType>())));
+ }
+}
+
// TODO(benvanik): import op handling of optional values.
// It'd be nice if the std::optional<Index>:$binding_capacity could be emitted
// as 0 when not present; today it'll be omitted entirely (as it's not in the
@@ -51,6 +76,7 @@
if (!categoriesValue.has_value())
return failure();
callOperands.append(categoriesValue.value());
+ callOperands.push_back(adaptor.getQueueAffinity());
if (adaptor.getBindingCapacity()) {
callOperands.push_back(castToImportType(adaptor.getBindingCapacity(),
rewriter.getI32Type(), rewriter));
@@ -88,12 +114,15 @@
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
+ auto [targetBufferSlot, targetBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter);
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
- adaptor.getTargetBuffer(),
+ targetBuffer,
castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
rewriter),
castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
+ targetBufferSlot,
};
// Record the original pattern length then extend it to a 32 bit integer.
@@ -127,6 +156,88 @@
mutable IREE::VM::ImportOp importOp;
};
+class CommandBufferUpdateBufferOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferUpdateBufferOp> {
+public:
+ CommandBufferUpdateBufferOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferUpdateBufferOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ auto [targetBufferSlot, targetBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter);
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getCommandBuffer(),
+ adaptor.getSourceBuffer(),
+ castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(),
+ rewriter),
+ targetBuffer,
+ castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
+ rewriter),
+ castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
+ targetBufferSlot};
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+class CommandBufferCopyBufferOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferCopyBufferOp> {
+public:
+ CommandBufferCopyBufferOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferCopyBufferOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ auto [sourceBufferSlot, sourceBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getSourceBuffer(), rewriter);
+ auto [targetBufferSlot, targetBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter);
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getCommandBuffer(),
+ sourceBufferSlot,
+ targetBufferSlot,
+ sourceBuffer,
+ castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(),
+ rewriter),
+ targetBuffer,
+ castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
+ rewriter),
+ castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
+ };
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
class CommandBufferCollectiveOpConversion
: public OpConversionPattern<IREE::HAL::CommandBufferCollectiveOp> {
public:
@@ -144,15 +255,6 @@
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
- Value nullBuffer;
- auto getNullBuffer = [&]() {
- if (!nullBuffer) {
- nullBuffer = rewriter.create<IREE::VM::ConstRefZeroOp>(
- op.getLoc(),
- IREE::VM::RefType::get(rewriter.getType<IREE::HAL::BufferType>()));
- }
- return nullBuffer;
- };
Value zeroI64;
auto getZeroI64 = [&]() {
if (!zeroI64) {
@@ -165,10 +267,12 @@
// %channel : !vm.ref<!hal.channel>,
// %op : i32,
// %param : i32,
+ // %send_buffer_slot : i32,
+ // %recv_buffer_slot : i32,
// %send_buffer : !vm.ref<!hal.buffer>,
+ // %recv_buffer : !vm.ref<!hal.buffer>,
// %send_offset : i64,
// %send_length : i64,
- // %recv_buffer : !vm.ref<!hal.buffer>,
// %recv_offset : i64,
// %recv_length : i64,
// %element_count : i64
@@ -184,25 +288,22 @@
rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc()));
}
- if (adaptor.getSendBuffer()) {
- callOperands.push_back(adaptor.getSendBuffer());
- callOperands.push_back(adaptor.getSendOffset());
- callOperands.push_back(adaptor.getSendLength());
- } else {
- callOperands.push_back(getNullBuffer());
- callOperands.push_back(getZeroI64());
- callOperands.push_back(getZeroI64());
- }
-
- if (adaptor.getRecvBuffer()) {
- callOperands.push_back(adaptor.getRecvBuffer());
- callOperands.push_back(adaptor.getRecvOffset());
- callOperands.push_back(adaptor.getRecvLength());
- } else {
- callOperands.push_back(getNullBuffer());
- callOperands.push_back(getZeroI64());
- callOperands.push_back(getZeroI64());
- }
+ auto [sendBufferSlot, sendBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getSendBuffer(), rewriter);
+ auto [recvBufferSlot, recvBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getRecvBuffer(), rewriter);
+ callOperands.push_back(sendBufferSlot);
+ callOperands.push_back(recvBufferSlot);
+ callOperands.push_back(sendBuffer);
+ callOperands.push_back(recvBuffer);
+ callOperands.push_back(adaptor.getSendOffset() ? adaptor.getSendOffset()
+ : getZeroI64());
+ callOperands.push_back(adaptor.getSendLength() ? adaptor.getSendLength()
+ : getZeroI64());
+ callOperands.push_back(adaptor.getRecvOffset() ? adaptor.getRecvOffset()
+ : getZeroI64());
+ callOperands.push_back(adaptor.getRecvLength() ? adaptor.getRecvLength()
+ : getZeroI64());
callOperands.push_back(castToImportType(adaptor.getElementCount(),
rewriter.getI64Type(), rewriter));
@@ -237,29 +338,6 @@
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
- // Memoize zeros/nulls ala IndexSet.
- // Since there are usually hundreds to thousands of these push ops and each
- // one can have 5-10 of these this saves us a tremendous amount of time
- // creating/verifying/pattern matching/folding/CSE'ing.
- // We could extend IndexSet into a ConstantSet that could use these custom
- // VM ops instead of just arith.constant in order to make this more
- // reusable.
- Value zero;
- auto getI32Zero = [&]() {
- if (!zero) {
- zero = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
- }
- return zero;
- };
- Value null;
- auto getNull = [&]() {
- if (!null) {
- null = rewriter.create<IREE::VM::ConstRefZeroOp>(
- op.getLoc(),
- IREE::VM::RefType::get(rewriter.getType<IREE::HAL::BufferType>()));
- }
- return null;
- };
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
@@ -278,16 +356,10 @@
for (size_t i = 0; i < adaptor.getBindingOrdinals().size(); ++i) {
callOperands.push_back(
castToImportType(adaptor.getBindingOrdinals()[i], i32Type, rewriter));
- auto bindingBuffer = adaptor.getBindingBuffers()[i];
- if (llvm::isa<IREE::VM::RefType>(bindingBuffer.getType())) {
- // Buffer binding; pass 0 for table slot.
- callOperands.push_back(getI32Zero());
- callOperands.push_back(bindingBuffer);
- } else {
- // Binding table reference; pass null for the buffer.
- callOperands.push_back(bindingBuffer);
- callOperands.push_back(getNull());
- }
+ auto [bindingBufferSlot, bindingBuffer] = splitBufferSlot(
+ op.getLoc(), adaptor.getBindingBuffers()[i], rewriter);
+ callOperands.push_back(bindingBufferSlot);
+ callOperands.push_back(bindingBuffer);
callOperands.push_back(
castToImportType(adaptor.getBindingOffsets()[i], i64Type, rewriter));
callOperands.push_back(
@@ -305,6 +377,54 @@
mutable IREE::VM::ImportOp importOp;
};
+class CommandBufferDispatchIndirectOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferDispatchIndirectOp> {
+public:
+ CommandBufferDispatchIndirectOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferDispatchIndirectOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ auto [workgroupsBufferSlot, workgroupsBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
+ auto flags = adaptor.getFlagsAttr()
+ ? rewriter
+ .create<IREE::VM::ConstI64Op>(
+ op.getLoc(), adaptor.getFlagsAttr().getInt())
+ .getResult()
+ : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
+ .getResult();
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getCommandBuffer(),
+ adaptor.getExecutable(),
+ castToImportType(adaptor.getEntryPoint(), rewriter.getI32Type(),
+ rewriter),
+ workgroupsBufferSlot,
+ workgroupsBuffer,
+ castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(),
+ rewriter),
+ flags,
+ };
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
} // namespace
void populateHALCommandBufferToVMPatterns(MLIRContext *context,
@@ -329,7 +449,10 @@
"hal.command_buffer.execution_barrier");
patterns.insert<CommandBufferFillBufferOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer");
- patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCopyBufferOp>>(
+ patterns.insert<CommandBufferUpdateBufferOpConversion>(
+ context, importSymbols, typeConverter,
+ "hal.command_buffer.update_buffer");
+ patterns.insert<CommandBufferCopyBufferOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
patterns.insert<CommandBufferCollectiveOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.collective");
@@ -342,10 +465,9 @@
"hal.command_buffer.push_descriptor_set");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch");
- patterns
- .insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchIndirectOp>>(
- context, importSymbols, typeConverter,
- "hal.command_buffer.dispatch.indirect");
+ patterns.insert<CommandBufferDispatchIndirectOpConversion>(
+ context, importSymbols, typeConverter,
+ "hal.command_buffer.dispatch.indirect");
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
index 711e860..2df6959 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -1,18 +1,20 @@
// RUN: iree-opt --split-input-file --iree-vm-conversion --canonicalize --iree-vm-target-index-bits=32 %s | FileCheck %s
// CHECK-LABEL: @command_buffer_create
-util.func public @command_buffer_create(%arg0: !hal.device) {
- // CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3, %zero) : (!vm.ref<!hal.device>, i32, i32, i32) -> !vm.ref<!hal.command_buffer>
- %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64)
+util.func public @command_buffer_create(%device: !hal.device, %affinity: i64) {
+ // CHECK: = vm.call @hal.command_buffer.create(%[[DEVICE]], %c1, %c3, %[[AFFINITY]], %zero) : (!vm.ref<!hal.device>, i32, i32, i64, i32) -> !vm.ref<!hal.command_buffer>
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
// -----
// CHECK-LABEL: @command_buffer_create_bindings
-util.func public @command_buffer_create_bindings(%arg0: !hal.device, %arg1: index) {
- // CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3, %arg1) : (!vm.ref<!hal.device>, i32, i32, i32) -> !vm.ref<!hal.command_buffer>
- %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") bindings(%arg1) : !hal.command_buffer
+// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref<!hal.device>, %[[AFFINITY:.+]]: i64, %[[CAPACITY:.+]]: i32)
+util.func public @command_buffer_create_bindings(%device: !hal.device, %affinity: i64, %capacity: index) {
+ // CHECK: = vm.call @hal.command_buffer.create(%[[DEVICE]], %c1, %c3, %[[AFFINITY]], %[[CAPACITY]]) : (!vm.ref<!hal.device>, i32, i32, i64, i32) -> !vm.ref<!hal.command_buffer>
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) bindings(%capacity) : !hal.command_buffer
util.return
}
@@ -50,9 +52,10 @@
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 1
// CHECK-DAG: %[[EXTEND:.+]] = vm.ext.i8.i32.u %arg2 : i32 -> i32
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32, i32) -> ()
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%arg2 : i8)
@@ -69,9 +72,10 @@
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 2
// CHECK-DAG: %[[EXTEND:.+]] = vm.ext.i16.i32.u %arg2 : i32 -> i32
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %[[EXTEND]], %[[PATTERN_LENGTH]])
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%arg2 : i16)
@@ -88,8 +92,9 @@
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %arg2, %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %arg2, %[[PATTERN_LENGTH]])
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%arg2 : i32)
@@ -98,18 +103,117 @@
// -----
-// CHECK-LABEL: @command_buffer_copy_buffer
-util.func public @command_buffer_copy_buffer(
+// CHECK-LABEL: @command_buffer_fill_buffer_i32_indirect
+util.func public @command_buffer_fill_buffer_i32_indirect(
%arg0: !hal.command_buffer,
- %arg1: !hal.buffer
+ %arg1: index,
+ %arg2: i32
+) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %[[NULL_BUFFER]], %c100, %c200, %arg1, %arg2, %[[PATTERN_LENGTH]])
+ hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
+ target(%arg1 : index)[%c100, %c200]
+ pattern(%arg2 : i32)
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_update_buffer
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[DEVICE_BUFFER:[a-z0-9]+]]: !vm.ref<!hal.buffer>, %[[DST_OFFSET:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: i32)
+util.func public @command_buffer_update_buffer(
+ %cmd: !hal.command_buffer,
+ %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index,
+ %device_buffer: !hal.buffer, %dst_offset: index,
+ %length: index
+ ) {
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]]
+ // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]]
+ // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]]
+ // CHECK: vm.call @hal.command_buffer.update_buffer
+ // CHECK-SAME: (%[[CMD]],
+ // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]],
+ // CHECK-SAME: %[[DEVICE_BUFFER]], %[[DST_OFFSET_I64]],
+ // CHECK-SAME: %[[LENGTH_I64]], %[[UNUSED_SLOT]])
+ hal.command_buffer.update_buffer<%cmd : !hal.command_buffer>
+ source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset]
+ target(%device_buffer : !hal.buffer)[%dst_offset]
+ length(%length)
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_update_buffer_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[DEVICE_BUFFER_SLOT:[a-z0-9]+]]: i32, %[[DST_OFFSET:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: i32)
+util.func public @command_buffer_update_buffer_indirect(
+ %cmd: !hal.command_buffer,
+ %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index,
+ %device_buffer: index, %dst_offset: index,
+ %length: index
+ ) {
+ // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]]
+ // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]]
+ // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]]
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.update_buffer
+ // CHECK-SAME: (%[[CMD]],
+ // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]],
+ // CHECK-SAME: %[[NULL_BUFFER]], %[[DST_OFFSET_I64]],
+ // CHECK-SAME: %[[LENGTH_I64]], %[[DEVICE_BUFFER_SLOT]])
+ hal.command_buffer.update_buffer<%cmd : !hal.command_buffer>
+ source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset]
+ target(%device_buffer : index)[%dst_offset]
+ length(%length)
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_copy_buffer
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>, %[[BUFFER:.+]]: !vm.ref<!hal.buffer>)
+util.func public @command_buffer_copy_buffer(
+ %cmd: !hal.command_buffer,
+ %buffer: !hal.buffer
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK: vm.call @hal.command_buffer.copy_buffer(%arg0, %arg1, %c100, %arg1, %c200, %c300) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
- hal.command_buffer.copy_buffer<%arg0 : !hal.command_buffer>
- source(%arg1 : !hal.buffer)[%c100]
- target(%arg1 : !hal.buffer)[%c200]
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK: vm.call @hal.command_buffer.copy_buffer(%[[CMD]], %[[UNUSED_SLOT]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[BUFFER]], %c200, %c300)
+ hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer>
+ source(%buffer : !hal.buffer)[%c100]
+ target(%buffer : !hal.buffer)[%c200]
+ length(%c300)
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_copy_buffer_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>, %[[BUFFER_SLOT:.+]]: i32)
+util.func public @command_buffer_copy_buffer_indirect(
+ %cmd: !hal.command_buffer,
+ %buffer_slot: index
+) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.copy_buffer(%[[CMD]], %[[BUFFER_SLOT]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[NULL_BUFFER]], %c200, %c300)
+ hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer>
+ source(%buffer_slot : index)[%c100]
+ target(%buffer_slot : index)[%c200]
length(%c300)
util.return
}
@@ -129,16 +233,17 @@
%send_buffer: !hal.buffer, %recv_buffer: !hal.buffer,
%count: index) {
// CHECK-DAG: %[[OP_BITS:.+]] = vm.const.i32 590081
- // CHECK-DAG: %[[PARAM:.+]] = vm.const.i32.zero
+ // CHECK-DAG: %[[ZERO_I32:.+]] = vm.const.i32.zero
%c10 = arith.constant 10 : index
%c20 = arith.constant 20 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
// CHECK-DAG: %[[COUNT_I64:.+]] = vm.ext.i32.i64.s %[[COUNT]]
// CHECK: vm.call @hal.command_buffer.collective
- // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[PARAM]]
- // CHECK-SAME: %[[SEND_BUFFER]], %c10, %c128,
- // CHECK-SAME: %[[RECV_BUFFER]], %c20, %c256,
+ // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[ZERO_I32]]
+ // CHECK-SAME: %[[ZERO_I32]], %[[ZERO_I32]],
+ // CHECK-SAME: %[[SEND_BUFFER]], %[[RECV_BUFFER]],
+ // CHECK-SAME: %c10, %c128, %c20, %c256,
// CHECK-SAME: %[[COUNT_I64]])
hal.command_buffer.collective<%cmd : !hal.command_buffer>
channel(%channel : !hal.channel)
@@ -163,15 +268,18 @@
%param: i32,
%send_buffer: !hal.buffer,
%count: index) {
- // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK-DAG: %[[OP_BITS:.+]] = vm.const.i32 262150
%c10 = arith.constant 10 : index
%c128 = arith.constant 128 : index
// CHECK-DAG: %[[COUNT_I64:.+]] = vm.ext.i32.i64.s %[[COUNT]]
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK-DAG: %[[ZERO_I64:.+]] = vm.const.i64.zero
// CHECK: vm.call @hal.command_buffer.collective
// CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[PARAM]],
- // CHECK-SAME: %[[SEND_BUFFER]], %c10, %c128,
- // CHECK-SAME: %[[NULL_BUFFER]], %zero, %zero,
+ // CHECK-SAME: %[[UNUSED_SLOT]], %[[UNUSED_SLOT]],
+ // CHECK-SAME: %[[SEND_BUFFER]], %[[NULL_BUFFER]],
+ // CHECK-SAME: %c10, %c128, %[[ZERO_I64]], %[[ZERO_I64]],
// CHECK-SAME: %[[COUNT_I64]])
hal.command_buffer.collective<%cmd : !hal.command_buffer>
channel(%channel : !hal.channel)
@@ -185,10 +293,10 @@
// -----
// CHECK-LABEL: @command_buffer_push_descriptor_set
-// CHECK-SAME: %[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
-// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref<!hal.pipeline_layout>,
-// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
-// CHECK-SAME: %[[SLOT:.+]]: i32
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref<!hal.pipeline_layout>,
+// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
+// CHECK-SAME: %[[SLOT:.+]]: i32)
util.func public @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
@@ -220,36 +328,70 @@
// -----
// CHECK-LABEL: @command_buffer_dispatch
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>)
util.func public @command_buffer_dispatch(
- %arg0: !hal.command_buffer,
- %arg1: !hal.executable
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable
) {
- // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK: vm.call @hal.command_buffer.dispatch(%arg0, %arg1, %[[ORDINAL]], %c100, %c200, %c300) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32) -> ()
- hal.command_buffer.dispatch<%arg0 : !hal.command_buffer>
- target(%arg1 : !hal.executable)[%ordinal]
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300, %[[FLAGS]])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
workgroups([%c100, %c200, %c300])
+ flags(None)
util.return
}
// -----
// CHECK-LABEL: @command_buffer_dispatch_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
+// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>)
util.func public @command_buffer_dispatch_indirect(
- %arg0: !hal.command_buffer,
- %arg1: !hal.executable,
- %arg2: !hal.buffer
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable,
+ %buffer: !hal.buffer
) {
- // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
- // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%arg0, %arg1, %[[ORDINAL]], %arg2, %c100) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, !vm.ref<!hal.buffer>, i64) -> ()
- hal.command_buffer.dispatch.indirect<%arg0 : !hal.command_buffer>
- target(%arg1 : !hal.executable)[%ordinal]
- workgroups(%arg2 : !hal.buffer)[%c100]
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[FLAGS]])
+ hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
+ workgroups(%buffer : !hal.buffer)[%c100]
+ flags(None)
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_dispatch_indirect_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
+// CHECK-SAME: %[[BUFFER_SLOT:.+]]: i32)
+util.func public @command_buffer_dispatch_indirect_indirect(
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable,
+ %buffer_slot: index
+) {
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
+ %ordinal = arith.constant 123 : index
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[FLAGS]])
+ hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
+ workgroups(%buffer_slot : index)[%c100]
+ flags(None)
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 771e7d4..74c0fc9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -731,9 +731,11 @@
entryPointAttr.getRootReference().getValue());
Value ordinal = caseBuilder.create<IREE::HAL::ExecutableExportOrdinalOp>(
loc, caseBuilder.getIndexType(), entryPointAttr);
+ auto flags = caseBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>(
+ IREE::HAL::DispatchFlags::None);
caseBuilder.create<IREE::HAL::CommandBufferDispatchOp>(
loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0],
- caseWorkgroupCount[1], caseWorkgroupCount[2]);
+ caseWorkgroupCount[1], caseWorkgroupCount[2], flags);
caseBuilder.create<scf::YieldOp>(loc);
}
@@ -939,7 +941,8 @@
rewriter
.create<IREE::HAL::CommandBufferCreateOp>(
loc, rewriter.getType<IREE::HAL::CommandBufferType>(), device,
- modes, commandCategories, /*binding_capacity=*/Value{})
+ modes, commandCategories, queueAffinity,
+ /*binding_capacity=*/Value{})
.getResult();
mapping->mapCommandBuffer(executeOp, commandBuffer);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index e2d56d2..9d85020 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -186,6 +186,14 @@
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
+def HAL_DispatchFlags_None : I64BitEnumAttrCase<"None", 0x0000>;
+def HAL_DispatchFlagsAttr :
+ I64BitEnumAttr<"DispatchFlags", "valid dispatch flags", [
+ HAL_DispatchFlags_None,
+ ]> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
+}
+
def HAL_ExecutionStage_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_ExecutionStage_CommandIssue : I32BitEnumAttrCase<"CommandIssue", 0x0001>;
def HAL_ExecutionStage_CommandProcess : I32BitEnumAttrCase<"CommandProcess", 0x0002>;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 784ade3..d082847 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -132,10 +132,10 @@
bool needsUpdate = false;
auto newSourceBuffer = op.getSourceBuffer();
auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
- if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
op.getSourceBuffer().getDefiningOp())) {
newSourceBuffer = subspanOp.getSourceBuffer();
- newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ newSourceOffset = rewriter.createOrFold<arith::AddIOp>(
subspanOp.getLoc(), subspanOp.getSourceOffset(),
op.getSourceOffset());
needsUpdate = true;
@@ -220,10 +220,10 @@
bool needsUpdate = false;
auto newTargetBuffer = op.getTargetBuffer();
auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
- if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
op.getTargetBuffer().getDefiningOp())) {
newTargetBuffer = subspanOp.getSourceBuffer();
- newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ newTargetOffset = rewriter.createOrFold<arith::AddIOp>(
subspanOp.getLoc(), subspanOp.getSourceOffset(),
op.getTargetOffset());
needsUpdate = true;
@@ -248,6 +248,46 @@
namespace {
+/// Folds hal.buffer.subspan ops on the target buffer into the target offset.
+struct FoldCommandBufferUpdateBufferSubspans
+ : public OpRewritePattern<CommandBufferUpdateBufferOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CommandBufferUpdateBufferOp op,
+ PatternRewriter &rewriter) const override {
+ auto ip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ bool needsUpdate = false;
+ auto newTargetBuffer = op.getTargetBuffer();
+ auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
+ op.getTargetBuffer().getDefiningOp())) {
+ newTargetBuffer = subspanOp.getSourceBuffer();
+ newTargetOffset = rewriter.createOrFold<arith::AddIOp>(
+ subspanOp.getLoc(), subspanOp.getSourceOffset(),
+ op.getTargetOffset());
+ needsUpdate = true;
+ }
+ rewriter.restoreInsertionPoint(ip);
+ if (!needsUpdate)
+ return failure();
+ rewriter.modifyOpInPlace(op, [&]() {
+ op.getTargetBufferMutable().assign(newTargetBuffer);
+ op.getTargetOffsetMutable().assign(newTargetOffset);
+ });
+ return success();
+ }
+};
+
+} // namespace
+
+void CommandBufferUpdateBufferOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.insert<FoldCommandBufferUpdateBufferSubspans>(context);
+}
+
+namespace {
+
/// Folds hal.buffer.subspans into buffer copy offsets.
struct FoldCommandBufferCopyBufferSubspans
: public OpRewritePattern<CommandBufferCopyBufferOp> {
@@ -260,20 +300,20 @@
bool needsUpdate = false;
auto newSourceBuffer = op.getSourceBuffer();
auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset());
- if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
op.getSourceBuffer().getDefiningOp())) {
newSourceBuffer = subspanOp.getSourceBuffer();
- newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ newSourceOffset = rewriter.createOrFold<arith::AddIOp>(
subspanOp.getLoc(), subspanOp.getSourceOffset(),
op.getSourceOffset());
needsUpdate = true;
}
auto newTargetBuffer = op.getTargetBuffer();
auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset());
- if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
+ if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>(
op.getTargetBuffer().getDefiningOp())) {
newTargetBuffer = subspanOp.getSourceBuffer();
- newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>(
+ newTargetOffset = rewriter.createOrFold<arith::AddIOp>(
subspanOp.getLoc(), subspanOp.getSourceOffset(),
op.getTargetOffset());
needsUpdate = true;
@@ -317,10 +357,10 @@
auto *definingOp = bindingBuffers[i].getDefiningOp();
if (!definingOp)
continue;
- if (auto subspanOp = dyn_cast<BufferSubspanOp>(definingOp)) {
+ if (auto subspanOp = dyn_cast<IREE::HAL::BufferSubspanOp>(definingOp)) {
needsUpdate = true;
bindingBuffers[i] = subspanOp.getSourceBuffer();
- bindingOffsets[i] = rewriter.createOrFold<mlir::arith::AddIOp>(
+ bindingOffsets[i] = rewriter.createOrFold<arith::AddIOp>(
subspanOp.getLoc(), subspanOp.getSourceOffset(), bindingOffsets[i]);
}
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index cc4de30..b8c0bab 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -991,6 +991,32 @@
}
//===----------------------------------------------------------------------===//
+// hal.command_buffer.update_buffer
+//===----------------------------------------------------------------------===//
+
+IREE::Util::SubrangeOperand
+CommandBufferUpdateBufferOp::getSubrangeOperand(unsigned operandIndex) {
+ if (operandIndex == 1) {
+ return IREE::Util::SubrangeOperand{getSourceBuffer(), getSourceSize(),
+ getSourceOffset(), getLength()};
+ } else {
+ assert(false && "only source is a subrange");
+ return {};
+ }
+}
+
+void CommandBufferUpdateBufferOp::setSubrangeOperand(
+ unsigned operandIndex, IREE::Util::SubrangeOperand operand) {
+ if (operandIndex == 1) {
+ getSourceBufferMutable().assign(operand.resource);
+ getSourceSizeMutable().assign(operand.resourceSize);
+ getSourceOffsetMutable().assign(operand.offset);
+ } else {
+ assert(false && "only source is a subrange");
+ }
+}
+
+//===----------------------------------------------------------------------===//
// hal.command_buffer.push_descriptor_set
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 1889f59..599c1ff 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1163,6 +1163,7 @@
HAL_Device:$device,
HAL_CommandBufferModeBitfieldAttr:$modes,
HAL_CommandCategoryBitfieldAttr:$command_categories,
+ HAL_DeviceQueueAffinity:$queue_affinity,
Optional<Index>:$binding_capacity
);
let results = (outs
@@ -1173,6 +1174,7 @@
`device` `(` $device `:` type($device) `)`
`mode` `(` $modes `)`
`categories` `(` $command_categories `)`
+ `affinity` `(` $queue_affinity `)`
(`bindings` `(` $binding_capacity^ `)`)?
`:` type($result)
attr-dict-with-keyword
@@ -1289,7 +1291,7 @@
let arguments = (ins
HAL_CommandBuffer:$command_buffer,
- HAL_BufferType:$target_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer,
HAL_DeviceSize:$target_offset,
HAL_DeviceSize:$length,
HAL_FillPatternType:$pattern
@@ -1306,7 +1308,49 @@
let hasCanonicalizer = 1;
}
-// TODO(benvanik): update buffer op.
+def HAL_CommandBufferUpdateBufferOp : HAL_Op<"command_buffer.update_buffer", [
+ // TODO(benvanik): figure out the right way to model host effects - this is
+ // a host read but a device write; if we make it just MemRead then it gets
+ // DCEd because it has no result. For now we report both to keep analysis
+ // appeased even if incorrect.
+ MemoryEffects<[MemRead, MemWrite]>,
+ Util_SizeAwareOp,
+ DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>,
+]> {
+ let summary = [{command buffer buffer update recording operation}];
+ let description = [{
+ Copies a range of a host buffer into a device buffer. The host buffer
+ contents will be captured at the time of the call and embedded in the
+ command buffer.
+ }];
+
+ let arguments = (ins
+ HAL_CommandBuffer:$command_buffer,
+ Util_BufferType:$source_buffer,
+ Util_Size:$source_size,
+ Util_Size:$source_offset,
+ AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer,
+ HAL_DeviceSize:$target_offset,
+ HAL_DeviceSize:$length
+ );
+
+ let assemblyFormat = [{
+ `<` $command_buffer `:` type($command_buffer) `>`
+ `source` `(` $source_buffer `:` type($source_buffer) `{` $source_size `}` `)`
+ `` `[` $source_offset `]`
+ `target` `(` $target_buffer `:` type($target_buffer) `)`
+ `` `[` $target_offset `]`
+ `length` `(` $length `)`
+ attr-dict-with-keyword
+ }];
+
+ let extraClassDeclaration = [{
+ Value getOperandSize(unsigned idx) { return idx == 1 ? getSourceSize() : Value{}; }
+ Value getResultSize(unsigned idx) { return {}; }
+ }];
+
+ let hasCanonicalizer = 1;
+}
def HAL_CommandBufferCopyBufferOp : HAL_Op<"command_buffer.copy_buffer"> {
let summary = [{command buffer buffer copy recording operation}];
@@ -1316,9 +1360,9 @@
let arguments = (ins
HAL_CommandBuffer:$command_buffer,
- HAL_BufferType:$source_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$source_buffer,
HAL_DeviceSize:$source_offset,
- HAL_BufferType:$target_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer,
HAL_DeviceSize:$target_offset,
HAL_DeviceSize:$length
);
@@ -1352,10 +1396,10 @@
Optional<I32>:$param,
// TODO(benvanik): change this to take descriptor set + binding instead.
// This would let us use indirect bindings.
- Optional<HAL_BufferType>:$send_buffer,
+ Optional<AnyTypeOf<[Index, HAL_BufferType]>>:$send_buffer,
Optional<HAL_DeviceSize>:$send_offset,
Optional<HAL_DeviceSize>:$send_length,
- Optional<HAL_BufferType>:$recv_buffer,
+ Optional<AnyTypeOf<[Index, HAL_BufferType]>>:$recv_buffer,
Optional<HAL_DeviceSize>:$recv_offset,
Optional<HAL_DeviceSize>:$recv_length
);
@@ -1458,7 +1502,8 @@
HAL_Ordinal:$entry_point,
HAL_Dim:$workgroup_x,
HAL_Dim:$workgroup_y,
- HAL_Dim:$workgroup_z
+ HAL_Dim:$workgroup_z,
+ HAL_DispatchFlagsAttr:$flags
);
let assemblyFormat = [{
@@ -1470,6 +1515,7 @@
$workgroup_y `,`
$workgroup_z
`]` `)`
+ `flags` `(` $flags `)`
attr-dict-with-keyword
}];
}
@@ -1485,8 +1531,9 @@
HAL_CommandBuffer:$command_buffer,
HAL_Executable:$executable,
HAL_Ordinal:$entry_point,
- HAL_BufferType:$workgroups_buffer,
- HAL_DeviceSize:$workgroups_offset
+ AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer,
+ HAL_DeviceSize:$workgroups_offset,
+ HAL_DispatchFlagsAttr:$flags
);
let assemblyFormat = [{
@@ -1495,6 +1542,7 @@
`` `[` $entry_point `]`
`workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)`
`` `[` $workgroups_offset `]`
+ `flags` `(` $flags `)`
attr-dict-with-keyword
}];
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
index 5ced86a..3adbce8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -1,11 +1,12 @@
// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
// CHECK-LABEL: @skip_command_buffer_device
-// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
-util.func public @skip_command_buffer_device(%device: !hal.device) -> !hal.executable {
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64)
+util.func public @skip_command_buffer_device(%device: !hal.device, %affinity: i64) -> !hal.executable {
%cmd = hal.command_buffer.create device(%device : !hal.device)
mode(OneShot)
- categories("Transfer|Dispatch") : !hal.command_buffer
+ categories("Transfer|Dispatch")
+ affinity(%affinity) : !hal.command_buffer
// CHECK-NOT: hal.command_buffer.device
// CHECK: = hal.executable.lookup device(%[[DEVICE]] : !hal.device)
@@ -42,6 +43,35 @@
// -----
+// CHECK-LABEL: @fold_buffer_subspans_into_update_buffer
+// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !util.buffer, %[[SOURCE_BUFFER_SIZE:.+]]: index,
+// CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer
+util.func public @fold_buffer_subspans_into_update_buffer(
+ %cmd: !hal.command_buffer,
+ %source_buffer: !util.buffer, %source_buffer_size: index,
+ %target_buffer: !hal.buffer
+ ) {
+ %c0 = arith.constant 0 : index
+ %c4096 = arith.constant 4096 : index
+ %c8192 = arith.constant 8192 : index
+ %c100000 = arith.constant 100000 : index
+ %c262144 = arith.constant 262144 : index
+ %source_subspan = util.buffer.subspan %source_buffer[%c4096] : !util.buffer{%source_buffer_size} -> !util.buffer{%c262144}
+ %target_subspan = hal.buffer.subspan<%target_buffer : !hal.buffer>[%c8192, %c262144] : !hal.buffer
+ // CHECK: hal.command_buffer.update_buffer
+ hal.command_buffer.update_buffer<%cmd : !hal.command_buffer>
+ // CHECK-SAME: source(%[[SOURCE_BUFFER]] : !util.buffer{%[[SOURCE_BUFFER_SIZE]]})[%c4096]
+ source(%source_subspan : !util.buffer{%c262144})[%c0]
+ // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%c108192]
+ target(%target_subspan : !hal.buffer)[%c100000]
+ // CHECK-SAME: length(%c8192)
+ length(%c8192)
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @fold_buffer_subspan_into_copy_buffer
// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
// CHECK-SAME: %[[BASE_BUFFER:.+]]: !hal.buffer
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
index 766b39a..77d56e3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -1,15 +1,17 @@
// RUN: iree-opt --split-input-file %s | FileCheck %s
// CHECK-LABEL: @command_buffer_create
-// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device)
-util.func public @command_buffer_create(%device: !hal.device) {
+// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64)
+util.func public @command_buffer_create(%device: !hal.device, %affinity: i64) {
// CHECK: %cmd = hal.command_buffer.create
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: mode(OneShot)
- // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer
+ // CHECK-SAME: categories("Transfer|Dispatch")
+ // CHECK-SAME: affinity(%[[AFFINITY]]) : !hal.command_buffer
%cmd = hal.command_buffer.create device(%device : !hal.device)
mode(OneShot)
- categories("Transfer|Dispatch") : !hal.command_buffer
+ categories("Transfer|Dispatch")
+ affinity(%affinity) : !hal.command_buffer
util.return
}
@@ -73,11 +75,36 @@
// -----
+// CHECK-LABEL: @command_buffer_update_buffer
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !util.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: index, %[[SRC_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DEVICE_BUFFER:[a-z0-9]+]]: !hal.buffer, %[[DST_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index)
+util.func public @command_buffer_update_buffer(
+ %cmd: !hal.command_buffer,
+ %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index,
+ %device_buffer: !hal.buffer, %dst_offset: index,
+ %length: index
+ ) {
+ // CHECK: hal.command_buffer.update_buffer<%[[CMD]] : !hal.command_buffer>
+ // CHECK-SAME: source(%[[HOST_BUFFER]] : !util.buffer{%[[HOST_BUFFER_SIZE]]})[%[[SRC_OFFSET]]]
+ // CHECK-SAME: target(%[[DEVICE_BUFFER]] : !hal.buffer)[%[[DST_OFFSET]]]
+ // CHECK-SAME: length(%[[LENGTH]])
+ hal.command_buffer.update_buffer<%cmd : !hal.command_buffer>
+ source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset]
+ target(%device_buffer : !hal.buffer)[%dst_offset]
+ length(%length)
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @command_buffer_copy_buffer
// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
-// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
-// CHECK-SAME: %[[SRC_OFFSET:.+]]: index, %[[DST_OFFSET:.+]]: index,
-// CHECK-SAME: %[[LENGTH:.+]]: index)
+// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
+// CHECK-SAME: %[[SRC_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DST_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index)
util.func public @command_buffer_copy_buffer(
%cmd: !hal.command_buffer,
%buffer: !hal.buffer,
@@ -98,11 +125,38 @@
// -----
+// CHECK-LABEL: @command_buffer_copy_buffer_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[BUFFER_SLOT:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[SRC_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DST_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index)
+util.func public @command_buffer_copy_buffer_indirect(
+ %cmd: !hal.command_buffer,
+ %buffer_slot: index,
+ %src_offset: index,
+ %dst_offset: index,
+ %length: index
+ ) {
+ // CHECK: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer>
+ // CHECK-SAME: source(%[[BUFFER_SLOT]] : index)[%[[SRC_OFFSET]]]
+ // CHECK-SAME: target(%[[BUFFER_SLOT]] : index)[%[[DST_OFFSET]]]
+ // CHECK-SAME: length(%[[LENGTH]])
+ hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer>
+ source(%buffer_slot : index)[%src_offset]
+ target(%buffer_slot : index)[%dst_offset]
+ length(%length)
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @command_buffer_collective
// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
// CHECK-SAME: %[[CHANNEL:.+]]: !hal.channel,
// CHECK-SAME: %[[PARAM:.+]]: i32,
-// CHECK-SAME: %[[SEND_BUFFER:.+]]: !hal.buffer, %[[RECV_BUFFER:.+]]: !hal.buffer,
+// CHECK-SAME: %[[SEND_BUFFER:[a-z0-9]+]]: !hal.buffer,
+// CHECK-SAME: %[[RECV_BUFFER:[a-z0-9]+]]: !hal.buffer,
// CHECK-SAME: %[[COUNT:.+]]: index)
util.func public @command_buffer_collective(
%cmd: !hal.command_buffer,
@@ -160,10 +214,10 @@
// -----
// CHECK-LABEL: @command_buffer_push_descriptor_set
-// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
-// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout,
-// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
-// CHECK-SAME: %[[SLOT:.+]]: index
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout,
+// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
+// CHECK-SAME: %[[SLOT:.+]]: index)
util.func public @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
@@ -212,9 +266,11 @@
// CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]
// CHECK-SAME: workgroups([%[[X]], %[[Y]], %[[Z]]])
+ // CHECK-SAME: flags("None")
hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
target(%executable: !hal.executable)[%ordinal]
workgroups([%x, %y, %z])
+ flags("None")
util.return
}
@@ -242,8 +298,42 @@
// CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]
// CHECK-SAME: workgroups(%[[BUFFER]] : !hal.buffer)[%[[OFFSET]]]
+ // CHECK-SAME: flags("None")
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable: !hal.executable)[%ordinal]
workgroups(%buffer : !hal.buffer)[%offset]
+ flags("None")
+ util.return
+}
+
+// -----
+
+hal.executable @ex {
+ hal.executable.variant @backend target(<"backend", "format">) {
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+ ]>)
+ }
+}
+
+// CHECK-LABEL: @command_buffer_dispatch_indirect_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[EXECUTABLE:[a-z0-9]+]]: !hal.executable, %[[ORDINAL:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[BUFFER_SLOT:[a-z0-9]+]]: index, %[[OFFSET:[a-z0-9]+]]: index)
+util.func public @command_buffer_dispatch_indirect_indirect(
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable, %ordinal: index,
+ %buffer_slot: index, %offset: index) {
+ // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer>
+ // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]]
+ // CHECK-SAME: workgroups(%[[BUFFER_SLOT]] : index)[%[[OFFSET]]]
+ // CHECK-SAME: flags("None")
+ hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
+ target(%executable: !hal.executable)[%ordinal]
+ workgroups(%buffer_slot : index)[%offset]
+ flags("None")
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 588f560..c2487b8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -256,6 +256,7 @@
// TODO(multi-device): support multiple devices in benchmark generation.
// For now we should just use the affinityAttr to resolve the device.
Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder);
+ Value queueAffinity = funcBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
// Create and begin command buffer.
// TODO(benvanik): reuse the command buffer (initialize once and store).
@@ -267,6 +268,7 @@
.create<IREE::HAL::CommandBufferCreateOp>(
loc, funcBuilder.getType<IREE::HAL::CommandBufferType>(), device,
commandBufferModes, IREE::HAL::CommandCategoryBitfield::Dispatch,
+ queueAffinity,
/*binding_capacity=*/Value{})
.getResult();
@@ -351,10 +353,12 @@
loc, indexSet.get(0), batchSizeArg, indexSet.get(1), ValueRange{},
[&](OpBuilder &forBuilder, Location loc, Value iv, ValueRange iters) {
// Dispatch.
+ auto flags = forBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>(
+ IREE::HAL::DispatchFlags::None);
forBuilder.create<IREE::HAL::CommandBufferDispatchOp>(
loc, commandBuffer, executable, ordinal,
workgroupCountOp.getWorkgroupX(), workgroupCountOp.getWorkgroupY(),
- workgroupCountOp.getWorkgroupZ());
+ workgroupCountOp.getWorkgroupZ(), flags);
// Barrier following the dispatch to block the next dispatch.
auto sourceStage = IREE::HAL::ExecutionStageBitfield::CommandRetire |
@@ -379,7 +383,6 @@
IREE::HAL::FenceFlagBitfield::None);
// Queue execution.
- auto queueAffinity = funcBuilder.create<arith::ConstantIntOp>(loc, -1, 64);
funcBuilder.create<IREE::HAL::DeviceQueueExecuteOp>(
loc, device, queueAffinity, waitFence, signalFence,
ValueRange{commandBuffer});
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
index 3bf0a7b..d504c5e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir
@@ -109,7 +109,7 @@
// CHECK: %[[CMD:.+]] = hal.command_buffer.create
// CHECK-SAME: device(%[[DEVICE]] : !hal.device)
// CHECK-SAME: mode("OneShot|AllowInlineExecution")
- // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer
+ // CHECK-SAME: categories("Transfer|Dispatch")
%timepoint = stream.cmd.execute
with(%arg0_resource as %arg0_capture: !stream.resource<external>{%c16},
%arg1_resource as %arg1_capture: !stream.resource<external>{%c16},
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
index f47bcd2..29de091 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir
@@ -5,9 +5,9 @@
module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} {
// CHECK-LABEL: @command_buffer_reusable
-util.func public @command_buffer_reusable(%arg0: !hal.device) {
- // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode("None")
- %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("None") categories("Transfer|Dispatch") : !hal.command_buffer
+util.func public @command_buffer_reusable(%device: !hal.device, %affinity: i64) {
+ // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("None")
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
} // module
@@ -18,9 +18,9 @@
module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} {
// CHECK-LABEL: @command_buffer_oneshot
-util.func public @command_buffer_oneshot(%arg0: !hal.device) {
- // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot|AllowInlineExecution")
- %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) categories("Transfer|Dispatch") : !hal.command_buffer
+util.func public @command_buffer_oneshot(%device: !hal.device, %affinity: i64) {
+ // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("OneShot|AllowInlineExecution")
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
} // module
@@ -34,9 +34,9 @@
#hal.device.target<"vulkan", {}>
]} {
// CHECK-LABEL: @legacy_mode_not_required
-util.func public @legacy_mode_not_required(%arg0: !hal.device) {
- // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot)
- %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) categories("Transfer|Dispatch") : !hal.command_buffer
+util.func public @legacy_mode_not_required(%device: !hal.device, %affinity: i64) {
+ // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode(OneShot)
+ %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer
util.return
}
} // module
@@ -51,7 +51,7 @@
]} {
// CHECK-LABEL: @mixed_legacy_mode_required
util.func public @mixed_legacy_mode_required(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
- %affinity = arith.constant 0 : i64
+ %affinity = arith.constant 1 : i64
// CHECK: hal.fence.await
// CHECK: hal.device.queue.execute
// CHECK: hal.fence.await
@@ -71,7 +71,7 @@
// CHECK-LABEL: @blocking_execute
// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence)
util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) {
- %affinity = arith.constant 0 : i64
+ %affinity = arith.constant 1 : i64
// CHECK-DAG: %[[NULL:.+]] = util.null : !hal.fence
// CHECK-DAG: hal.fence.await until([%[[WAIT]]])
// CHECK-NEXT: hal.device.queue.execute<%[[DEVICE]] : !hal.device>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
index e7b80e1..a139ece 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
@@ -13,12 +13,12 @@
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
- hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None)
hal.command_buffer.execution_barrier<%cmd1 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
- hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2])
+ hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2]) flags(None)
- hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2])
+ hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2]) flags(None)
hal.command_buffer.execution_barrier<%cmd2 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
util.return
@@ -59,7 +59,7 @@
%c1 = arith.constant 1 : index
scf.index_switch %idx
case 0 {
- hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None)
scf.yield
}
default {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 1d21fb3..66f8dd7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -33,7 +33,6 @@
%buffer_usage : i32,
%allocation_size : i64
) -> !vm.ref<!hal.buffer>
-attributes {minimum_version = 1 : i32}
// Imports a host byte buffer into a device visible buffer.
// If try!=0 then returns null if the given memory type cannot be mapped.
@@ -48,7 +47,6 @@
%offset : i64,
%length : i64
) -> !vm.ref<!hal.buffer>
-attributes {minimum_version = 1 : i32}
//===----------------------------------------------------------------------===//
// iree_hal_buffer_t
@@ -199,8 +197,12 @@
%device : !vm.ref<!hal.device>,
%modes : i32,
%command_categories : i32,
+ %queue_affinity : i64,
%binding_capacity : i32
) -> !vm.ref<!hal.command_buffer>
+attributes {
+ minimum_version = 3 : i32 // command buffer API version
+}
// Finalizes recording into the command buffer and prepares it for submission.
// No more commands can be recorded afterward.
@@ -230,18 +232,35 @@
)
// Fills the target buffer with the given repeating value.
+// NOTE: order slightly differs from op in order to get better arg alignment.
vm.import private @command_buffer.fill_buffer(
%command_buffer : !vm.ref<!hal.command_buffer>,
%target_buffer : !vm.ref<!hal.buffer>,
%target_offset : i64,
%length : i64,
+ %target_buffer_slot : i32,
%pattern : i32,
%pattern_length: i32
)
+// Updates a device buffer with the captured contents of a host buffer.
+// NOTE: order slightly differs from op in order to get better arg alignment.
+vm.import private @command_buffer.update_buffer(
+ %command_buffer : !vm.ref<!hal.command_buffer>,
+ %source_buffer : !vm.buffer,
+ %source_offset : i64,
+ %target_buffer : !vm.ref<!hal.buffer>,
+ %target_offset : i64,
+ %length : i64,
+ %target_buffer_slot : i32
+)
+
// Copies a range of one buffer to another.
+// NOTE: order slightly differs from op in order to get better arg alignment.
vm.import private @command_buffer.copy_buffer(
%command_buffer : !vm.ref<!hal.command_buffer>,
+ %source_buffer_slot : i32,
+ %target_buffer_slot : i32,
%source_buffer : !vm.ref<!hal.buffer>,
%source_offset : i64,
%target_buffer : !vm.ref<!hal.buffer>,
@@ -256,10 +275,12 @@
%channel : !vm.ref<!hal.channel>,
%op : i32,
%param : i32,
+ %send_buffer_slot : i32,
+ %recv_buffer_slot : i32,
%send_buffer : !vm.ref<!hal.buffer>,
+ %recv_buffer : !vm.ref<!hal.buffer>,
%send_offset : i64,
%send_length : i64,
- %recv_buffer : !vm.ref<!hal.buffer>,
%recv_offset : i64,
%recv_length : i64,
%element_count : i64
@@ -289,7 +310,8 @@
%entry_point : i32,
%workgroup_x : i32,
%workgroup_y : i32,
- %workgroup_z : i32
+ %workgroup_z : i32,
+ %flags : i64
)
// Dispatches an execution request with the dispatch parameters loaded from the
@@ -298,8 +320,10 @@
%command_buffer : !vm.ref<!hal.command_buffer>,
%executable : !vm.ref<!hal.executable>,
%entry_point : i32,
+ %workgroups_buffer_slot : i32,
%workgroups_buffer : !vm.ref<!hal.buffer>,
- %workgroups_offset : i64
+ %workgroups_offset : i64,
+ %flags : i64
)
//===----------------------------------------------------------------------===//
@@ -425,16 +449,10 @@
//===----------------------------------------------------------------------===//
vm.import private @devices.count() -> i32
-attributes {
- minimum_version = 2 : i32,
- nosideeffects
-}
+attributes {nosideeffects}
vm.import private @devices.get(%index : i32) -> !vm.ref<!hal.device>
-attributes {
- minimum_version = 2 : i32,
- nosideeffects
-}
+attributes {nosideeffects}
//===----------------------------------------------------------------------===//
// iree_hal_executable_t
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
index 3ba86d9..fa837d4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -58,8 +58,7 @@
return success();
}
-Value castToImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter) {
+Value castToImportType(Value value, Type targetType, OpBuilder &builder) {
auto sourceType = value.getType();
if (sourceType == targetType)
return value;
@@ -70,36 +69,35 @@
if (llvm::isa<FloatType>(sourceType) && llvm::isa<IntegerType>(targetType) &&
sourceType.getIntOrFloatBitWidth() ==
targetType.getIntOrFloatBitWidth()) {
- return rewriter.create<mlir::arith::BitcastOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::BitcastOp>(value.getLoc(), targetType,
+ value);
} else if (sourceIsInteger &&
(targetType.isSignedInteger() || targetType.isSignlessInteger())) {
if (targetType.getIntOrFloatBitWidth() >
sourceType.getIntOrFloatBitWidth()) {
- return rewriter.create<mlir::arith::ExtSIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::ExtSIOp>(value.getLoc(), targetType,
+ value);
} else {
- return rewriter.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
+ value);
}
} else if (sourceIsInteger && targetType.isUnsignedInteger()) {
if (targetType.getIntOrFloatBitWidth() >
sourceType.getIntOrFloatBitWidth()) {
- return rewriter.create<mlir::arith::ExtUIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::ExtUIOp>(value.getLoc(), targetType,
+ value);
} else {
- return rewriter.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
+ value);
}
} else {
return value;
}
}
-Value castFromImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter) {
+Value castFromImportType(Value value, Type targetType, OpBuilder &builder) {
// Right now the to-import and from-import types are the same.
- return castToImportType(value, targetType, rewriter);
+ return castToImportType(value, targetType, builder);
}
void copyImportAttrs(IREE::VM::ImportOp importOp, Operation *callOp) {
@@ -118,43 +116,42 @@
}
}
-std::optional<SmallVector<Value>>
-rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType,
- ConversionPatternRewriter &rewriter) {
+std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc,
+ Attribute attrValue,
+ Type inputType,
+ OpBuilder &builder) {
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrValue)) {
// NOTE: we intentionally go to std.constant ops so that the standard
// conversions can do their job. If we want to remove the dependency
// from standard ops in the future we could instead go directly to
// one of the vm constant ops.
- auto constValue = rewriter.createOrFold<mlir::arith::ConstantOp>(
+ auto constValue = builder.create<mlir::arith::ConstantOp>(
loc, inputType,
- IntegerAttr::get(inputType,
- APInt(32, static_cast<int32_t>(intAttr.getInt()))));
+ IntegerAttr::get(inputType, APInt(inputType.getIntOrFloatBitWidth(),
+ intAttr.getValue().getSExtValue())));
return {{constValue}};
- }
- if (auto elementsAttr = llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) {
+ } else if (auto elementsAttr =
+ llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) {
SmallVector<Value> elementValues;
elementValues.reserve(elementsAttr.getNumElements());
for (auto intAttr : elementsAttr.getValues<Attribute>()) {
- elementValues.push_back(rewriter.createOrFold<mlir::arith::ConstantOp>(
+ elementValues.push_back(builder.create<mlir::arith::ConstantOp>(
loc, elementsAttr.getType().getElementType(),
cast<TypedAttr>(intAttr)));
}
return elementValues;
- }
- if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) {
+ } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) {
SmallVector<Value> allValues;
for (auto elementAttr : arrayAttr) {
auto flattenedValues =
- rewriteAttrToOperands(loc, elementAttr, inputType, rewriter);
+ rewriteAttrToOperands(loc, elementAttr, inputType, builder);
if (!flattenedValues)
return std::nullopt;
allValues.append(flattenedValues->begin(), flattenedValues->end());
}
return allValues;
- }
- if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) {
- return {{rewriter.create<IREE::VM::RodataInlineOp>(loc, strAttr)}};
+ } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) {
+ return {{builder.create<IREE::VM::RodataInlineOp>(loc, strAttr)}};
}
// This may be a custom dialect type. As we can't trivially access the storage
@@ -176,7 +173,7 @@
return;
auto elementType = tupleTypes[ordinal++];
auto flattenedValues =
- rewriteAttrToOperands(loc, elementAttr, elementType, rewriter);
+ rewriteAttrToOperands(loc, elementAttr, elementType, builder);
if (!flattenedValues) {
anyFailed = true;
return;
@@ -192,7 +189,7 @@
if (anyFailed)
return;
auto flattenedValues =
- rewriteAttrToOperands(loc, elementAttr, inputType, rewriter);
+ rewriteAttrToOperands(loc, elementAttr, inputType, builder);
if (!flattenedValues) {
anyFailed = true;
return;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index 3e80906..b2f0a8f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -30,20 +30,19 @@
namespace detail {
size_t getSegmentSpanSize(Type spanType);
-std::optional<SmallVector<Value>>
-rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType,
- ConversionPatternRewriter &rewriter);
+std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc,
+ Attribute attrValue,
+ Type inputType,
+ OpBuilder &builder);
} // namespace detail
// Casts |value| to |targetType| ala static_cast for when the declared type
// differs from the type provided by the input dialect.
-Value castToImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter);
+Value castToImportType(Value value, Type targetType, OpBuilder &builder);
// Casts |value| to |targetType| ala static_cast for when the declared return
// type of an import does not match the required output type.
-Value castFromImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter);
+Value castFromImportType(Value value, Type targetType, OpBuilder &builder);
// Copies known attributes from the |importOp| to the |callOp|.
// This allows for passes to quickly query the properties of the import such as
@@ -56,8 +55,7 @@
template <typename T, typename Adaptor = typename T::Adaptor>
std::optional<SmallVector<Value>>
rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp,
- const TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
+ const TypeConverter &typeConverter, OpBuilder &builder) {
auto *operation = op.getOperation();
bool isOpVariadic = importOp.isVariadic();
OperationState state{
@@ -76,7 +74,7 @@
auto inputName = importOp.getFuncArgumentName(input.index());
if (auto attrValue = op->getAttr(inputName)) {
auto flattenedAttrs = detail::rewriteAttrToOperands(
- op.getLoc(), attrValue, inputType, rewriter);
+ op.getLoc(), attrValue, inputType, builder);
if (!flattenedAttrs)
return std::nullopt;
state.addOperands(*flattenedAttrs);
@@ -101,11 +99,11 @@
}
for (auto [newOperand, inputType] :
llvm::zip_equal(newOperands, inputTupleType.getTypes())) {
- state.addOperands(castToImportType(newOperand, inputType, rewriter));
+ state.addOperands(castToImportType(newOperand, inputType, builder));
}
} else {
for (auto &operand : newOperands) {
- state.addOperands(castToImportType(operand, inputType, rewriter));
+ state.addOperands(castToImportType(operand, inputType, builder));
}
}
@@ -121,16 +119,16 @@
"segment_sizes",
DenseIntElementsAttr::get(
VectorType::get({static_cast<int64_t>(segmentSizes.size())},
- rewriter.getIntegerType(16)),
+ builder.getIntegerType(16)),
segmentSizes));
state.addAttribute("segment_types",
- rewriter.getArrayAttr(llvm::map_to_vector(
+ builder.getArrayAttr(llvm::map_to_vector(
importType.getInputs(), [&](Type type) {
return cast<Attribute>(TypeAttr::get(type));
})));
}
- auto *callOp = rewriter.create(state);
+ auto *callOp = builder.create(state);
copyImportAttrs(importOp, callOp);
SmallVector<Value> results;
@@ -139,7 +137,7 @@
targetType = typeConverter.convertType(targetType);
if (!targetType)
return std::nullopt;
- results.push_back(castFromImportType(result, targetType, rewriter));
+ results.push_back(castFromImportType(result, targetType, builder));
}
return results;
}
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index fe165e5..42b476b 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -412,7 +412,8 @@
static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_rocm_direct_command_buffer_t* command_buffer =
iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
// Lookup kernel parameters used for side-channeling additional launch
@@ -463,7 +464,7 @@
static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need rocm implementation");
}
diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c
index d57dee4..de89e4f 100644
--- a/experimental/webgpu/command_buffer.c
+++ b/experimental/webgpu/command_buffer.c
@@ -884,7 +884,8 @@
static iree_status_t iree_hal_webgpu_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_webgpu_command_buffer_t* command_buffer =
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
@@ -900,7 +901,7 @@
static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_webgpu_command_buffer_t* command_buffer =
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c
index 38619a1..7f3785d 100644
--- a/runtime/src/iree/hal/command_buffer.c
+++ b/runtime/src/iree/hal/command_buffer.c
@@ -558,7 +558,8 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(executable);
if ((workgroup_x | workgroup_y | workgroup_z) == 0) {
@@ -574,7 +575,7 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_dispatch_validation(
command_buffer, VALIDATION_STATE(command_buffer), executable,
- entry_point, workgroup_x, workgroup_y, workgroup_z));
+ entry_point, workgroup_x, workgroup_y, workgroup_z, flags));
});
#if IREE_HAL_VERBOSE_TRACING_ENABLE
// TODO(benvanik): add a tracing.h helper that does the snprintf directly
@@ -594,7 +595,7 @@
#endif // IREE_HAL_VERBOSE_TRACING_ENABLE
iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)(
command_buffer, executable, entry_point, workgroup_x, workgroup_y,
- workgroup_z);
+ workgroup_z, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -602,7 +603,7 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(executable);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -610,10 +611,10 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_dispatch_indirect_validation(
command_buffer, VALIDATION_STATE(command_buffer), executable,
- entry_point, workgroups_ref));
+ entry_point, workgroups_ref, flags));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch_indirect)(
- command_buffer, executable, entry_point, workgroups_ref);
+ command_buffer, executable, entry_point, workgroups_ref, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index 38ac03d..c9c6037 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -384,6 +384,12 @@
IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count(
iree_hal_collective_element_type_t element_type);
+// Bitfield specifying flags controlling a dispatch operation.
+enum iree_hal_dispatch_flag_bits_t {
+ IREE_HAL_DISPATCH_FLAG_NONE = 0,
+};
+typedef uint64_t iree_hal_dispatch_flags_t;
+
// An RGBA color.
typedef struct iree_hal_label_color_t {
uint8_t r;
@@ -751,7 +757,8 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z);
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags);
// Dispatches an execution request with deferred workgroup counts.
// This is the same as iree_hal_command_buffer_dispatch but the workgroup counts
@@ -765,7 +772,7 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref);
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
//===----------------------------------------------------------------------===//
// Validation support
@@ -922,12 +929,13 @@
iree_status_t(IREE_API_PTR* dispatch)(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z);
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags);
iree_status_t(IREE_API_PTR* dispatch_indirect)(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref);
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
} iree_hal_command_buffer_vtable_t;
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t);
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c
index 87f3a4b..b27433c 100644
--- a/runtime/src/iree/hal/command_buffer_validation.c
+++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -603,7 +603,8 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_dispatch_bindings(
@@ -615,7 +616,7 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h
index 036d666..82ab1c5 100644
--- a/runtime/src/iree/hal/command_buffer_validation.h
+++ b/runtime/src/iree/hal/command_buffer_validation.h
@@ -142,13 +142,14 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z);
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags);
iree_status_t iree_hal_command_buffer_dispatch_indirect_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref);
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
iree_status_t iree_hal_command_buffer_binding_table_validation(
iree_hal_command_buffer_t* command_buffer,
diff --git a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
index 4d71207..6d19793 100644
--- a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
+++ b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
@@ -154,7 +154,8 @@
IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
command_buffer, executable_, /*entry_point=*/0,
- /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1));
+ /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1,
+ IREE_HAL_DISPATCH_FLAG_NONE));
IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
command_buffer,
/*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
diff --git a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
index af99ee1..06fa747 100644
--- a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
+++ b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
@@ -120,7 +120,8 @@
IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
command_buffer, executable_, /*entry_point=*/0,
- /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1));
+ /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1,
+ IREE_HAL_DISPATCH_FLAG_NONE));
IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
command_buffer,
/*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
index c747c3c..c53428a 100644
--- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
@@ -747,7 +747,8 @@
static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -873,7 +874,7 @@
static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"indirect dispatch not yet implemented");
}
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
index 8f9cd2d..3369f3b 100644
--- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -533,7 +533,8 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -646,7 +647,7 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need cuda implementation of dispatch indirect");
}
diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
index 65b2c4c..ae66cfd 100644
--- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
@@ -771,7 +771,8 @@
static iree_status_t iree_hal_hip_graph_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_hip_graph_command_buffer_t* command_buffer =
iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -882,7 +883,7 @@
static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"indirect dispatch not yet implemented");
}
diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
index a250299..0f08727 100644
--- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
@@ -525,7 +525,8 @@
static iree_status_t iree_hal_hip_stream_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -626,7 +627,7 @@
static iree_status_t iree_hal_hip_stream_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need hip implementation of dispatch indirect");
}
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
index 95c503e..50d56c8 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@ -973,7 +973,8 @@
static iree_status_t iree_hal_task_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_task_command_buffer_t* command_buffer =
iree_hal_task_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
@@ -987,7 +988,7 @@
static iree_status_t iree_hal_task_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_task_command_buffer_t* command_buffer =
iree_hal_task_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index 50d01e7..eaed4f5 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -1073,7 +1073,7 @@
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y,
- uint32_t workgroup_count_z) {
+ uint32_t workgroup_count_z, iree_hal_dispatch_flags_t flags) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_dispatch_segment_t* segment = NULL;
@@ -1090,7 +1090,7 @@
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
- int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref) {
+ int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_dispatch_segment_t* segment = NULL;
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
index b66be80..8bb9413 100644
--- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
@@ -725,7 +725,8 @@
static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_vulkan_direct_command_buffer_t* command_buffer =
iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
@@ -764,7 +765,7 @@
static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_vulkan_direct_command_buffer_t* command_buffer =
iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c
index f69f9c2..3de7c60 100644
--- a/runtime/src/iree/hal/local/inline_command_buffer.c
+++ b/runtime/src/iree/hal/local/inline_command_buffer.c
@@ -442,7 +442,8 @@
static iree_status_t iree_hal_inline_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_inline_command_buffer_t* command_buffer =
iree_hal_inline_command_buffer_cast(base_command_buffer);
@@ -559,7 +560,7 @@
static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
// TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
iree_hal_buffer_mapping_t buffer_mapping = {{0}};
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
@@ -570,7 +571,7 @@
*(const iree_hal_vec3_t*)buffer_mapping.contents.data;
return iree_hal_inline_command_buffer_dispatch(
base_command_buffer, executable, entry_point, workgroup_count.x,
- workgroup_count.y, workgroup_count.z);
+ workgroup_count.y, workgroup_count.z, flags);
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c
index a4b805a..49ec334 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.c
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -771,12 +771,14 @@
uint32_t workgroup_x;
uint32_t workgroup_y;
uint32_t workgroup_z;
+ iree_hal_dispatch_flags_t flags;
} iree_hal_cmd_dispatch_t;
static iree_status_t iree_hal_deferred_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_deferred_command_buffer_t* command_buffer =
iree_hal_deferred_command_buffer_cast(base_command_buffer);
iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
@@ -790,6 +792,7 @@
cmd->workgroup_x = workgroup_x;
cmd->workgroup_y = workgroup_y;
cmd->workgroup_z = workgroup_z;
+ cmd->flags = flags;
return iree_ok_status();
}
@@ -799,7 +802,7 @@
const iree_hal_cmd_dispatch_t* cmd) {
return iree_hal_command_buffer_dispatch(
target_command_buffer, cmd->executable, cmd->entry_point,
- cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z);
+ cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z, cmd->flags);
}
//===----------------------------------------------------------------------===//
@@ -811,12 +814,13 @@
iree_hal_executable_t* executable;
int32_t entry_point;
iree_hal_buffer_ref_t workgroups_ref;
+ iree_hal_dispatch_flags_t flags;
} iree_hal_cmd_dispatch_indirect_t;
static iree_status_t iree_hal_deferred_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_deferred_command_buffer_t* command_buffer =
iree_hal_deferred_command_buffer_cast(base_command_buffer);
iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
@@ -834,6 +838,7 @@
cmd->executable = executable;
cmd->entry_point = entry_point;
cmd->workgroups_ref = workgroups_ref;
+ cmd->flags = flags;
return iree_ok_status();
}
@@ -845,7 +850,8 @@
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->workgroups_ref, &workgroups_ref));
return iree_hal_command_buffer_dispatch_indirect(
- target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref);
+ target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref,
+ cmd->flags);
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index 13f9d09..f6f96f2 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -47,17 +47,18 @@
EXPORT_FN("channel.split", iree_hal_module_channel_split, riii, r)
EXPORT_FN("command_buffer.begin_debug_group", iree_hal_module_command_buffer_begin_debug_group, rr, v)
-EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriirIIrIII, v)
-EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, rrIrII, v)
-EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riii, r)
-EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v)
-EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rrirI, v)
+EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v)
+EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v)
+EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r)
+EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiiiI, v)
+EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirII, v)
EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
-EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIii, v)
+EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v)
EXPORT_FN("command_buffer.finalize", iree_hal_module_command_buffer_finalize, r, v)
EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v)
EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiirIID, v)
+EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrIIi, v)
EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 1b9f8df..fad75d0 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -32,8 +32,8 @@
// Module type definitions
//===----------------------------------------------------------------------===//
-#define IREE_HAL_MODULE_VERSION_0_2 0x00000002u
-#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_2
+#define IREE_HAL_MODULE_VERSION_0_3 0x00000003u
+#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_3
typedef struct iree_hal_module_t {
iree_allocator_t host_allocator;
@@ -670,19 +670,21 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_create, //
iree_hal_module_state_t, //
- riii, r) {
+ riiIi, r) {
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
iree_hal_command_buffer_mode_t modes =
(iree_hal_command_buffer_mode_t)args->i1;
iree_hal_command_category_t command_categories =
(iree_hal_command_category_t)args->i2;
- iree_host_size_t binding_capacity = (iree_host_size_t)args->i3;
+ iree_hal_queue_affinity_t queue_affinity =
+ (iree_hal_queue_affinity_t)args->i3;
+ iree_host_size_t binding_capacity = (iree_host_size_t)args->i4;
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
- device, modes, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
- binding_capacity, &command_buffer));
+ device, modes, command_categories, queue_affinity, binding_capacity,
+ &command_buffer));
iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
if (iree_status_is_ok(status)) {
@@ -757,48 +759,76 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, //
iree_hal_module_state_t, //
- rrIIii, v) {
+ rrIIiii, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
- iree_hal_buffer_t* target_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer));
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i2);
iree_device_size_t length = iree_hal_cast_device_size(args->i3);
- uint32_t pattern = (uint32_t)args->i4;
- uint32_t pattern_length = (uint32_t)args->i5;
+ uint32_t target_buffer_slot = (uint32_t)args->i4;
+ iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
+ target_buffer_slot, target_offset, length);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r1, &target_ref.buffer));
+ uint32_t pattern = (uint32_t)args->i5;
+ uint32_t pattern_length = (uint32_t)args->i6;
- iree_hal_buffer_ref_t target_ref =
- iree_hal_make_buffer_ref(target_buffer, target_offset, length);
return iree_hal_command_buffer_fill_buffer(command_buffer, target_ref,
&pattern, pattern_length);
}
-IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, //
- iree_hal_module_state_t, //
- rrIrII, v) {
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_update_buffer, //
+ iree_hal_module_state_t, //
+ rrIrIIi, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
- iree_hal_buffer_t* source_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &source_buffer));
- iree_device_size_t source_offset = iree_hal_cast_device_size(args->i2);
- iree_hal_buffer_t* target_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer));
+ iree_vm_buffer_t* source_buffer = NULL;
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer));
+ iree_host_size_t source_offset = iree_hal_cast_host_size(args->i2);
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4);
iree_device_size_t length = iree_hal_cast_device_size(args->i5);
+ uint32_t target_buffer_slot = (uint32_t)args->i6;
+ iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
+ target_buffer_slot, target_offset, length);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r3, &target_ref.buffer));
- iree_hal_buffer_ref_t source_ref =
- iree_hal_make_buffer_ref(source_buffer, source_offset, length);
- iree_hal_buffer_ref_t target_ref =
- iree_hal_make_buffer_ref(target_buffer, target_offset, length);
+ iree_const_byte_span_t source_span = iree_const_byte_span_empty();
+ IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(
+ source_buffer, source_offset, (iree_host_size_t)length, 1, &source_span));
+
+ return iree_hal_command_buffer_update_buffer(command_buffer, source_span.data,
+ /*source_offset=*/0, target_ref);
+}
+
+IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, //
+ iree_hal_module_state_t, //
+ riirIrII, v) {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
+ uint32_t source_buffer_slot = (uint32_t)args->i1;
+ uint32_t target_buffer_slot = (uint32_t)args->i2;
+ iree_device_size_t source_offset = iree_hal_cast_device_size(args->i4);
+ iree_device_size_t target_offset = iree_hal_cast_device_size(args->i6);
+ iree_device_size_t length = iree_hal_cast_device_size(args->i7);
+ iree_hal_buffer_ref_t source_ref = iree_hal_make_indirect_buffer_ref(
+ source_buffer_slot, source_offset, length);
+ iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
+ target_buffer_slot, target_offset, length);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r3, &source_ref.buffer));
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r5, &target_ref.buffer));
+
return iree_hal_command_buffer_copy_buffer(command_buffer, source_ref,
target_ref);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_collective, //
iree_hal_module_state_t, //
- rriirIIrIII, v) {
+ rriiiirrIIIII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
@@ -806,17 +836,19 @@
IREE_RETURN_IF_ERROR(iree_hal_channel_check_deref(args->r1, &channel));
iree_hal_collective_op_t op = {.packed = args->i2};
uint32_t param = args->i3;
- iree_hal_buffer_ref_t send_ref =
- iree_hal_make_buffer_ref(NULL, iree_hal_cast_device_size(args->i5),
- iree_hal_cast_device_size(args->i6));
+ uint32_t send_buffer_slot = (uint32_t)args->i4;
+ uint32_t recv_buffer_slot = (uint32_t)args->i5;
+ iree_hal_buffer_ref_t send_ref = iree_hal_make_indirect_buffer_ref(
+ send_buffer_slot, iree_hal_cast_device_size(args->i8),
+ iree_hal_cast_device_size(args->i9));
IREE_RETURN_IF_ERROR(
- iree_hal_buffer_check_deref_or_null(args->r4, &send_ref.buffer));
- iree_hal_buffer_ref_t recv_ref =
- iree_hal_make_buffer_ref(NULL, iree_hal_cast_device_size(args->i8),
- iree_hal_cast_device_size(args->i9));
+ iree_hal_buffer_check_deref_or_null(args->r6, &send_ref.buffer));
+ iree_hal_buffer_ref_t recv_ref = iree_hal_make_indirect_buffer_ref(
+ recv_buffer_slot, iree_hal_cast_device_size(args->i10),
+ iree_hal_cast_device_size(args->i11));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r7, &recv_ref.buffer));
- iree_device_size_t element_count = iree_hal_cast_device_size(args->i10);
+ iree_device_size_t element_count = iree_hal_cast_device_size(args->i12);
return iree_hal_command_buffer_collective(command_buffer, channel, op, param,
send_ref, recv_ref, element_count);
@@ -875,7 +907,7 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, //
iree_hal_module_state_t, //
- rriiii, v) {
+ rriiiiI, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
@@ -885,30 +917,32 @@
uint32_t workgroup_x = (uint32_t)args->i3;
uint32_t workgroup_y = (uint32_t)args->i4;
uint32_t workgroup_z = (uint32_t)args->i5;
+ iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6;
return iree_hal_command_buffer_dispatch(command_buffer, executable,
entry_point, workgroup_x, workgroup_y,
- workgroup_z);
+ workgroup_z, flags);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, //
iree_hal_module_state_t, //
- rrirI, v) {
+ rriirII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_executable_t* executable = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable));
uint32_t entry_point = (uint32_t)args->i2;
- iree_hal_buffer_t* workgroups_buffer = NULL;
+ uint32_t workgroups_buffer_slot = (uint32_t)args->i3;
+ iree_device_size_t workgroups_offset = iree_hal_cast_device_size(args->i5);
+ iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_indirect_buffer_ref(
+ workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t));
IREE_RETURN_IF_ERROR(
- iree_hal_buffer_check_deref(args->r3, &workgroups_buffer));
- iree_device_size_t workgroups_offset = iree_hal_cast_device_size(args->i4);
+ iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer));
+ iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6;
- iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_buffer_ref(
- workgroups_buffer, workgroups_offset, 3 * sizeof(uint32_t));
- return iree_hal_command_buffer_dispatch_indirect(command_buffer, executable,
- entry_point, workgroups_ref);
+ return iree_hal_command_buffer_dispatch_indirect(
+ command_buffer, executable, entry_point, workgroups_ref, flags);
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 9c8b68e..5bd69a7 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -42,9 +42,10 @@
IREE_VM_ABI_DEFINE_SHIM(riii, r);
IREE_VM_ABI_DEFINE_SHIM(riiI, r);
IREE_VM_ABI_DEFINE_SHIM(riii, v);
+IREE_VM_ABI_DEFINE_SHIM(riiIi, r);
IREE_VM_ABI_DEFINE_SHIM(rIiiI, r);
IREE_VM_ABI_DEFINE_SHIM(riIiirII, r);
-IREE_VM_ABI_DEFINE_SHIM(rriirIIrIII, v);
+IREE_VM_ABI_DEFINE_SHIM(rriiiirrIIIII, v);
IREE_VM_ABI_DEFINE_SHIM(rrrrCrD, r);
IREE_VM_ABI_DEFINE_SHIM(ririi, v);
IREE_VM_ABI_DEFINE_SHIM(rr, i);
@@ -58,11 +59,13 @@
IREE_VM_ABI_DEFINE_SHIM(rriCiD, v);
IREE_VM_ABI_DEFINE_SHIM(rriiCID, v);
IREE_VM_ABI_DEFINE_SHIM(rriCiirIID, v);
-IREE_VM_ABI_DEFINE_SHIM(rriiii, v);
-IREE_VM_ABI_DEFINE_SHIM(rrIIii, v);
+IREE_VM_ABI_DEFINE_SHIM(rriiiiI, v);
+IREE_VM_ABI_DEFINE_SHIM(rrIIiii, v);
IREE_VM_ABI_DEFINE_SHIM(rrirCID, v);
IREE_VM_ABI_DEFINE_SHIM(rrirI, v);
-IREE_VM_ABI_DEFINE_SHIM(rrIrII, v);
+IREE_VM_ABI_DEFINE_SHIM(rriirII, v);
+IREE_VM_ABI_DEFINE_SHIM(rrIrIIi, v);
+IREE_VM_ABI_DEFINE_SHIM(riirIrII, v);
IREE_VM_ABI_DEFINE_SHIM(rrIii, v);
IREE_VM_ABI_DEFINE_SHIM(rrrIii, v);
IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index ab2f012..b47428c 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -313,6 +313,14 @@
int32_t i4;
});
+IREE_VM_ABI_FIXED_STRUCT(riiIi, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+ int64_t i3;
+ int32_t i4;
+});
+
IREE_VM_ABI_FIXED_STRUCT(riiI, {
iree_vm_ref_t r0;
int32_t i1;
@@ -347,36 +355,40 @@
int64_t i7;
});
-IREE_VM_ABI_FIXED_STRUCT(rriirIIrIII, {
- iree_vm_ref_t r0;
- iree_vm_ref_t r1;
- int32_t i2;
- int32_t i3;
- iree_vm_ref_t r4;
- int64_t i5;
- int64_t i6;
- iree_vm_ref_t r7;
- int64_t i8;
- int64_t i9;
- int64_t i10;
-});
-
-IREE_VM_ABI_FIXED_STRUCT(rriiii, {
+IREE_VM_ABI_FIXED_STRUCT(rriiiirrIIIII, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int32_t i2;
int32_t i3;
int32_t i4;
int32_t i5;
+ iree_vm_ref_t r6;
+ iree_vm_ref_t r7;
+ int64_t i8;
+ int64_t i9;
+ int64_t i10;
+ int64_t i11;
+ int64_t i12;
});
-IREE_VM_ABI_FIXED_STRUCT(rrIIii, {
+IREE_VM_ABI_FIXED_STRUCT(rriiiiI, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ int32_t i3;
+ int32_t i4;
+ int32_t i5;
+ int64_t i6;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rrIIiii, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int64_t i2;
int64_t i3;
int32_t i4;
int32_t i5;
+ int32_t i6;
});
IREE_VM_ABI_FIXED_STRUCT(rrirI, {
@@ -387,13 +399,35 @@
int64_t i4;
});
-IREE_VM_ABI_FIXED_STRUCT(rrIrII, {
+IREE_VM_ABI_FIXED_STRUCT(rriirII, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ int32_t i3;
+ iree_vm_ref_t r4;
+ int64_t i5;
+ int64_t i6;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rrIrIIi, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int64_t i2;
iree_vm_ref_t r3;
int64_t i4;
int64_t i5;
+ int32_t i6;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(riirIrII, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+ iree_vm_ref_t r3;
+ int64_t i4;
+ iree_vm_ref_t r5;
+ int64_t i6;
+ int64_t i7;
});
IREE_VM_ABI_FIXED_STRUCT(rrIii, {
@@ -659,9 +693,10 @@
IREE_VM_ABI_DECLARE_SHIM(riii, r);
IREE_VM_ABI_DECLARE_SHIM(riiI, r);
IREE_VM_ABI_DECLARE_SHIM(riii, v);
+IREE_VM_ABI_DECLARE_SHIM(riiIi, r);
IREE_VM_ABI_DECLARE_SHIM(rIiiI, r);
IREE_VM_ABI_DECLARE_SHIM(riIiirII, r);
-IREE_VM_ABI_DECLARE_SHIM(rriirIIrIII, v);
+IREE_VM_ABI_DECLARE_SHIM(rriiiirrIIIII, v);
IREE_VM_ABI_DECLARE_SHIM(rrrrCrD, r);
IREE_VM_ABI_DECLARE_SHIM(ririi, v);
IREE_VM_ABI_DECLARE_SHIM(rr, i);
@@ -675,11 +710,13 @@
IREE_VM_ABI_DECLARE_SHIM(rriCiD, v);
IREE_VM_ABI_DECLARE_SHIM(rriiCID, v);
IREE_VM_ABI_DECLARE_SHIM(rriCiirIID, v);
-IREE_VM_ABI_DECLARE_SHIM(rriiii, v);
-IREE_VM_ABI_DECLARE_SHIM(rrIIii, v);
+IREE_VM_ABI_DECLARE_SHIM(rriiiiI, v);
+IREE_VM_ABI_DECLARE_SHIM(rrIIiii, v);
IREE_VM_ABI_DECLARE_SHIM(rrirCID, v);
IREE_VM_ABI_DECLARE_SHIM(rrirI, v);
-IREE_VM_ABI_DECLARE_SHIM(rrIrII, v);
+IREE_VM_ABI_DECLARE_SHIM(rriirII, v);
+IREE_VM_ABI_DECLARE_SHIM(rrIrIIi, v);
+IREE_VM_ABI_DECLARE_SHIM(riirIrII, v);
IREE_VM_ABI_DECLARE_SHIM(rrIii, v);
IREE_VM_ABI_DECLARE_SHIM(rrrIii, v);
IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r);
diff --git a/tools/iree-benchmark-executable-main.c b/tools/iree-benchmark-executable-main.c
index f3a0b14..c603cb8 100644
--- a/tools/iree-benchmark-executable-main.c
+++ b/tools/iree-benchmark-executable-main.c
@@ -277,7 +277,7 @@
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_dispatch(
command_buffer, args->executable, FLAG_entry_point,
args->workgroup_count[0], args->workgroup_count[1],
- args->workgroup_count[2]));
+ args->workgroup_count[2], IREE_HAL_DISPATCH_FLAG_NONE));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_execution_barrier(
command_buffer, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE,
IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE,