Adding `iree_hal_dispatch_flags_t` to dispatch operations.
This is currently unused but may be useful for specifying scheduling
behavior or something else.
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
index 7d2977e..bee6557 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
@@ -94,9 +94,9 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
return
}
util.initializer {
@@ -111,9 +111,9 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
util.return
}
@@ -304,10 +304,10 @@
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index
%dispatch_3_ordinal = hal.executable.export.ordinal target(@dispatch_3::@spirv::@dispatch_3) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
return
}
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
index af83f1b..2baedf6 100644
--- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
+++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
@@ -87,9 +87,9 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
return
}
util.initializer {
@@ -104,9 +104,9 @@
%dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index
%dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index
%dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None)
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index 71be493..cb4179f 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -396,6 +396,13 @@
auto importType = importOp.getFunctionType();
auto [workgroupsBufferSlot, workgroupsBuffer] =
splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
+ auto flags = adaptor.getFlagsAttr()
+ ? rewriter
+ .create<IREE::VM::ConstI64Op>(
+ op.getLoc(), adaptor.getFlagsAttr().getInt())
+ .getResult()
+ : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc())
+ .getResult();
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getExecutable(),
@@ -405,6 +412,7 @@
workgroupsBuffer,
castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(),
rewriter),
+ flags,
};
auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
op, SymbolRefAttr::get(importOp), importType.getResults(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
index 61ca923..2df6959 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -334,15 +334,17 @@
%cmd: !hal.command_buffer,
%executable: !hal.executable
) {
- // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300)
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300, %[[FLAGS]])
hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups([%c100, %c200, %c300])
+ flags(None)
util.return
}
@@ -361,10 +363,12 @@
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
// CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
- // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100)
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[FLAGS]])
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups(%buffer : !hal.buffer)[%c100]
+ flags(None)
util.return
}
@@ -383,9 +387,11 @@
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
// CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
- // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100)
+ // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[FLAGS]])
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable : !hal.executable)[%ordinal]
workgroups(%buffer_slot : index)[%c100]
+ flags(None)
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
index 274b348..74c0fc9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -731,9 +731,11 @@
entryPointAttr.getRootReference().getValue());
Value ordinal = caseBuilder.create<IREE::HAL::ExecutableExportOrdinalOp>(
loc, caseBuilder.getIndexType(), entryPointAttr);
+ auto flags = caseBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>(
+ IREE::HAL::DispatchFlags::None);
caseBuilder.create<IREE::HAL::CommandBufferDispatchOp>(
loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0],
- caseWorkgroupCount[1], caseWorkgroupCount[2]);
+ caseWorkgroupCount[1], caseWorkgroupCount[2], flags);
caseBuilder.create<scf::YieldOp>(loc);
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index e2d56d2..9d85020 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -186,6 +186,14 @@
let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
}
+def HAL_DispatchFlags_None : I64BitEnumAttrCase<"None", 0x0000>;
+def HAL_DispatchFlagsAttr :
+ I64BitEnumAttr<"DispatchFlags", "valid dispatch flags", [
+ HAL_DispatchFlags_None,
+ ]> {
+ let cppNamespace = "::mlir::iree_compiler::IREE::HAL";
+}
+
def HAL_ExecutionStage_None : I32BitEnumAttrCase<"None", 0x0000>;
def HAL_ExecutionStage_CommandIssue : I32BitEnumAttrCase<"CommandIssue", 0x0001>;
def HAL_ExecutionStage_CommandProcess : I32BitEnumAttrCase<"CommandProcess", 0x0002>;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index d0abb71..599c1ff 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1502,7 +1502,8 @@
HAL_Ordinal:$entry_point,
HAL_Dim:$workgroup_x,
HAL_Dim:$workgroup_y,
- HAL_Dim:$workgroup_z
+ HAL_Dim:$workgroup_z,
+ HAL_DispatchFlagsAttr:$flags
);
let assemblyFormat = [{
@@ -1514,6 +1515,7 @@
$workgroup_y `,`
$workgroup_z
`]` `)`
+ `flags` `(` $flags `)`
attr-dict-with-keyword
}];
}
@@ -1530,7 +1532,8 @@
HAL_Executable:$executable,
HAL_Ordinal:$entry_point,
AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer,
- HAL_DeviceSize:$workgroups_offset
+ HAL_DeviceSize:$workgroups_offset,
+ HAL_DispatchFlagsAttr:$flags
);
let assemblyFormat = [{
@@ -1539,6 +1542,7 @@
`` `[` $entry_point `]`
`workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)`
`` `[` $workgroups_offset `]`
+ `flags` `(` $flags `)`
attr-dict-with-keyword
}];
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
index 5598e39..77d56e3 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -266,9 +266,11 @@
// CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]
// CHECK-SAME: workgroups([%[[X]], %[[Y]], %[[Z]]])
+ // CHECK-SAME: flags("None")
hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
target(%executable: !hal.executable)[%ordinal]
workgroups([%x, %y, %z])
+ flags("None")
util.return
}
@@ -296,9 +298,11 @@
// CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]
// CHECK-SAME: workgroups(%[[BUFFER]] : !hal.buffer)[%[[OFFSET]]]
+ // CHECK-SAME: flags("None")
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable: !hal.executable)[%ordinal]
workgroups(%buffer : !hal.buffer)[%offset]
+ flags("None")
util.return
}
@@ -326,8 +330,10 @@
// CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer>
// CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]
// CHECK-SAME: workgroups(%[[BUFFER_SLOT]] : index)[%[[OFFSET]]]
+ // CHECK-SAME: flags("None")
hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
target(%executable: !hal.executable)[%ordinal]
workgroups(%buffer_slot : index)[%offset]
+ flags("None")
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
index 4ec14bf..c2487b8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -353,10 +353,12 @@
loc, indexSet.get(0), batchSizeArg, indexSet.get(1), ValueRange{},
[&](OpBuilder &forBuilder, Location loc, Value iv, ValueRange iters) {
// Dispatch.
+ auto flags = forBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>(
+ IREE::HAL::DispatchFlags::None);
forBuilder.create<IREE::HAL::CommandBufferDispatchOp>(
loc, commandBuffer, executable, ordinal,
workgroupCountOp.getWorkgroupX(), workgroupCountOp.getWorkgroupY(),
- workgroupCountOp.getWorkgroupZ());
+ workgroupCountOp.getWorkgroupZ(), flags);
// Barrier following the dispatch to block the next dispatch.
auto sourceStage = IREE::HAL::ExecutionStageBitfield::CommandRetire |
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
index e7b80e1..a139ece 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
@@ -13,12 +13,12 @@
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
- hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None)
hal.command_buffer.execution_barrier<%cmd1 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
- hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2])
+ hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2]) flags(None)
- hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1])
- hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2])
+ hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1]) flags(None)
+ hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2]) flags(None)
hal.command_buffer.execution_barrier<%cmd2 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
util.return
@@ -59,7 +59,7 @@
%c1 = arith.constant 1 : index
scf.index_switch %idx
case 0 {
- hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1])
+ hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None)
scf.yield
}
default {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index 0a923c9..d68b865 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -309,7 +309,8 @@
%entry_point : i32,
%workgroup_x : i32,
%workgroup_y : i32,
- %workgroup_z : i32
+ %workgroup_z : i32,
+ %flags : i64
)
// Dispatches an execution request with the dispatch parameters loaded from the
@@ -320,7 +321,8 @@
%entry_point : i32,
%workgroups_buffer_slot : i32,
%workgroups_buffer : !vm.ref<!hal.buffer>,
- %workgroups_offset : i64
+ %workgroups_offset : i64,
+ %flags : i64
)
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
index 8c3c502..fa837d4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -127,11 +127,11 @@
// one of the vm constant ops.
auto constValue = builder.create<mlir::arith::ConstantOp>(
loc, inputType,
- IntegerAttr::get(inputType,
- APInt(32, static_cast<int32_t>(intAttr.getInt()))));
+ IntegerAttr::get(inputType, APInt(inputType.getIntOrFloatBitWidth(),
+ intAttr.getValue().getSExtValue())));
return {{constValue}};
- }
- if (auto elementsAttr = llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) {
+ } else if (auto elementsAttr =
+ llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) {
SmallVector<Value> elementValues;
elementValues.reserve(elementsAttr.getNumElements());
for (auto intAttr : elementsAttr.getValues<Attribute>()) {
@@ -140,8 +140,7 @@
cast<TypedAttr>(intAttr)));
}
return elementValues;
- }
- if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) {
+ } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) {
SmallVector<Value> allValues;
for (auto elementAttr : arrayAttr) {
auto flattenedValues =
@@ -151,8 +150,7 @@
allValues.append(flattenedValues->begin(), flattenedValues->end());
}
return allValues;
- }
- if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) {
+ } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) {
return {{builder.create<IREE::VM::RodataInlineOp>(loc, strAttr)}};
}
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index fe165e5..42b476b 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -412,7 +412,8 @@
static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_rocm_direct_command_buffer_t* command_buffer =
iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
// Lookup kernel parameters used for side-channeling additional launch
@@ -463,7 +464,7 @@
static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need rocm implementation");
}
diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c
index d57dee4..de89e4f 100644
--- a/experimental/webgpu/command_buffer.c
+++ b/experimental/webgpu/command_buffer.c
@@ -884,7 +884,8 @@
static iree_status_t iree_hal_webgpu_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_webgpu_command_buffer_t* command_buffer =
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
@@ -900,7 +901,7 @@
static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_webgpu_command_buffer_t* command_buffer =
iree_hal_webgpu_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c
index 38619a1..7f3785d 100644
--- a/runtime/src/iree/hal/command_buffer.c
+++ b/runtime/src/iree/hal/command_buffer.c
@@ -558,7 +558,8 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(executable);
if ((workgroup_x | workgroup_y | workgroup_z) == 0) {
@@ -574,7 +575,7 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_dispatch_validation(
command_buffer, VALIDATION_STATE(command_buffer), executable,
- entry_point, workgroup_x, workgroup_y, workgroup_z));
+ entry_point, workgroup_x, workgroup_y, workgroup_z, flags));
});
#if IREE_HAL_VERBOSE_TRACING_ENABLE
// TODO(benvanik): add a tracing.h helper that does the snprintf directly
@@ -594,7 +595,7 @@
#endif // IREE_HAL_VERBOSE_TRACING_ENABLE
iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)(
command_buffer, executable, entry_point, workgroup_x, workgroup_y,
- workgroup_z);
+ workgroup_z, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -602,7 +603,7 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(executable);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -610,10 +611,10 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_dispatch_indirect_validation(
command_buffer, VALIDATION_STATE(command_buffer), executable,
- entry_point, workgroups_ref));
+ entry_point, workgroups_ref, flags));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch_indirect)(
- command_buffer, executable, entry_point, workgroups_ref);
+ command_buffer, executable, entry_point, workgroups_ref, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index 38ac03d..c9c6037 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -384,6 +384,12 @@
IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count(
iree_hal_collective_element_type_t element_type);
+// Bitfield specifying flags controlling a dispatch operation.
+enum iree_hal_dispatch_flag_bits_t {
+ IREE_HAL_DISPATCH_FLAG_NONE = 0,
+};
+typedef uint64_t iree_hal_dispatch_flags_t;
+
// An RGBA color.
typedef struct iree_hal_label_color_t {
uint8_t r;
@@ -751,7 +757,8 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z);
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags);
// Dispatches an execution request with deferred workgroup counts.
// This is the same as iree_hal_command_buffer_dispatch but the workgroup counts
@@ -765,7 +772,7 @@
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref);
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
//===----------------------------------------------------------------------===//
// Validation support
@@ -922,12 +929,13 @@
iree_status_t(IREE_API_PTR* dispatch)(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z);
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags);
iree_status_t(IREE_API_PTR* dispatch_indirect)(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref);
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
} iree_hal_command_buffer_vtable_t;
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t);
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c
index 87f3a4b..b27433c 100644
--- a/runtime/src/iree/hal/command_buffer_validation.c
+++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -603,7 +603,8 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_dispatch_bindings(
@@ -615,7 +616,7 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h
index 036d666..82ab1c5 100644
--- a/runtime/src/iree/hal/command_buffer_validation.h
+++ b/runtime/src/iree/hal/command_buffer_validation.h
@@ -142,13 +142,14 @@
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z);
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags);
iree_status_t iree_hal_command_buffer_dispatch_indirect_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref);
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags);
iree_status_t iree_hal_command_buffer_binding_table_validation(
iree_hal_command_buffer_t* command_buffer,
diff --git a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
index 4d71207..6d19793 100644
--- a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
+++ b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
@@ -154,7 +154,8 @@
IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
command_buffer, executable_, /*entry_point=*/0,
- /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1));
+ /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1,
+ IREE_HAL_DISPATCH_FLAG_NONE));
IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
command_buffer,
/*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
diff --git a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
index af99ee1..06fa747 100644
--- a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
+++ b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
@@ -120,7 +120,8 @@
IREE_ASSERT_OK(iree_hal_command_buffer_dispatch(
command_buffer, executable_, /*entry_point=*/0,
- /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1));
+ /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1,
+ IREE_HAL_DISPATCH_FLAG_NONE));
IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier(
command_buffer,
/*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
index c747c3c..c53428a 100644
--- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
@@ -747,7 +747,8 @@
static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -873,7 +874,7 @@
static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"indirect dispatch not yet implemented");
}
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
index 8f9cd2d..3369f3b 100644
--- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -533,7 +533,8 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -646,7 +647,7 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need cuda implementation of dispatch indirect");
}
diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
index 65b2c4c..ae66cfd 100644
--- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
@@ -771,7 +771,8 @@
static iree_status_t iree_hal_hip_graph_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_hip_graph_command_buffer_t* command_buffer =
iree_hal_hip_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -882,7 +883,7 @@
static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"indirect dispatch not yet implemented");
}
diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
index a250299..0f08727 100644
--- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
@@ -525,7 +525,8 @@
static iree_status_t iree_hal_hip_stream_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_hip_stream_command_buffer_t* command_buffer =
iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
@@ -626,7 +627,7 @@
static iree_status_t iree_hal_hip_stream_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"need hip implementation of dispatch indirect");
}
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
index 95c503e..50d56c8 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@ -973,7 +973,8 @@
static iree_status_t iree_hal_task_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_task_command_buffer_t* command_buffer =
iree_hal_task_command_buffer_cast(base_command_buffer);
IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
@@ -987,7 +988,7 @@
static iree_status_t iree_hal_task_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_task_command_buffer_t* command_buffer =
iree_hal_task_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index 50d01e7..eaed4f5 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -1073,7 +1073,7 @@
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y,
- uint32_t workgroup_count_z) {
+ uint32_t workgroup_count_z, iree_hal_dispatch_flags_t flags) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_dispatch_segment_t* segment = NULL;
@@ -1090,7 +1090,7 @@
static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
- int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref) {
+ int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_metal_dispatch_segment_t* segment = NULL;
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
index b66be80..8bb9413 100644
--- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
@@ -725,7 +725,8 @@
static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_vulkan_direct_command_buffer_t* command_buffer =
iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
@@ -764,7 +765,7 @@
static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_vulkan_direct_command_buffer_t* command_buffer =
iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c
index f69f9c2..3de7c60 100644
--- a/runtime/src/iree/hal/local/inline_command_buffer.c
+++ b/runtime/src/iree/hal/local/inline_command_buffer.c
@@ -442,7 +442,8 @@
static iree_status_t iree_hal_inline_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_inline_command_buffer_t* command_buffer =
iree_hal_inline_command_buffer_cast(base_command_buffer);
@@ -559,7 +560,7 @@
static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
// TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
iree_hal_buffer_mapping_t buffer_mapping = {{0}};
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
@@ -570,7 +571,7 @@
*(const iree_hal_vec3_t*)buffer_mapping.contents.data;
return iree_hal_inline_command_buffer_dispatch(
base_command_buffer, executable, entry_point, workgroup_count.x,
- workgroup_count.y, workgroup_count.z);
+ workgroup_count.y, workgroup_count.z, flags);
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c
index a4b805a..49ec334 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.c
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -771,12 +771,14 @@
uint32_t workgroup_x;
uint32_t workgroup_y;
uint32_t workgroup_z;
+ iree_hal_dispatch_flags_t flags;
} iree_hal_cmd_dispatch_t;
static iree_status_t iree_hal_deferred_command_buffer_dispatch(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
+ uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z,
+ iree_hal_dispatch_flags_t flags) {
iree_hal_deferred_command_buffer_t* command_buffer =
iree_hal_deferred_command_buffer_cast(base_command_buffer);
iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
@@ -790,6 +792,7 @@
cmd->workgroup_x = workgroup_x;
cmd->workgroup_y = workgroup_y;
cmd->workgroup_z = workgroup_z;
+ cmd->flags = flags;
return iree_ok_status();
}
@@ -799,7 +802,7 @@
const iree_hal_cmd_dispatch_t* cmd) {
return iree_hal_command_buffer_dispatch(
target_command_buffer, cmd->executable, cmd->entry_point,
- cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z);
+ cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z, cmd->flags);
}
//===----------------------------------------------------------------------===//
@@ -811,12 +814,13 @@
iree_hal_executable_t* executable;
int32_t entry_point;
iree_hal_buffer_ref_t workgroups_ref;
+ iree_hal_dispatch_flags_t flags;
} iree_hal_cmd_dispatch_indirect_t;
static iree_status_t iree_hal_deferred_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
- iree_hal_buffer_ref_t workgroups_ref) {
+ iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) {
iree_hal_deferred_command_buffer_t* command_buffer =
iree_hal_deferred_command_buffer_cast(base_command_buffer);
iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
@@ -834,6 +838,7 @@
cmd->executable = executable;
cmd->entry_point = entry_point;
cmd->workgroups_ref = workgroups_ref;
+ cmd->flags = flags;
return iree_ok_status();
}
@@ -845,7 +850,8 @@
IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
binding_table, cmd->workgroups_ref, &workgroups_ref));
return iree_hal_command_buffer_dispatch_indirect(
- target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref);
+ target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref,
+ cmd->flags);
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index b87a8cb..f6f96f2 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -50,8 +50,8 @@
EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v)
EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v)
EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r)
-EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v)
-EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirI, v)
+EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiiiI, v)
+EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirII, v)
EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index c0db04b..777bcb3 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -907,7 +907,7 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, //
iree_hal_module_state_t, //
- rriiii, v) {
+ rriiiiI, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
@@ -917,15 +917,16 @@
uint32_t workgroup_x = (uint32_t)args->i3;
uint32_t workgroup_y = (uint32_t)args->i4;
uint32_t workgroup_z = (uint32_t)args->i5;
+ iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6;
return iree_hal_command_buffer_dispatch(command_buffer, executable,
entry_point, workgroup_x, workgroup_y,
- workgroup_z);
+ workgroup_z, flags);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, //
iree_hal_module_state_t, //
- rriirI, v) {
+ rriirII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
@@ -938,9 +939,10 @@
workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer));
+ iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6;
- return iree_hal_command_buffer_dispatch_indirect(command_buffer, executable,
- entry_point, workgroups_ref);
+ return iree_hal_command_buffer_dispatch_indirect(
+ command_buffer, executable, entry_point, workgroups_ref, flags);
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index c89a8b7..5bd69a7 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -59,11 +59,11 @@
IREE_VM_ABI_DEFINE_SHIM(rriCiD, v);
IREE_VM_ABI_DEFINE_SHIM(rriiCID, v);
IREE_VM_ABI_DEFINE_SHIM(rriCiirIID, v);
-IREE_VM_ABI_DEFINE_SHIM(rriiii, v);
+IREE_VM_ABI_DEFINE_SHIM(rriiiiI, v);
IREE_VM_ABI_DEFINE_SHIM(rrIIiii, v);
IREE_VM_ABI_DEFINE_SHIM(rrirCID, v);
IREE_VM_ABI_DEFINE_SHIM(rrirI, v);
-IREE_VM_ABI_DEFINE_SHIM(rriirI, v);
+IREE_VM_ABI_DEFINE_SHIM(rriirII, v);
IREE_VM_ABI_DEFINE_SHIM(rrIrIIi, v);
IREE_VM_ABI_DEFINE_SHIM(riirIrII, v);
IREE_VM_ABI_DEFINE_SHIM(rrIii, v);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index 1d5a3aa..b47428c 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -371,13 +371,14 @@
int64_t i12;
});
-IREE_VM_ABI_FIXED_STRUCT(rriiii, {
+IREE_VM_ABI_FIXED_STRUCT(rriiiiI, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int32_t i2;
int32_t i3;
int32_t i4;
int32_t i5;
+ int64_t i6;
});
IREE_VM_ABI_FIXED_STRUCT(rrIIiii, {
@@ -398,13 +399,14 @@
int64_t i4;
});
-IREE_VM_ABI_FIXED_STRUCT(rriirI, {
+IREE_VM_ABI_FIXED_STRUCT(rriirII, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int32_t i2;
int32_t i3;
iree_vm_ref_t r4;
int64_t i5;
+ int64_t i6;
});
IREE_VM_ABI_FIXED_STRUCT(rrIrIIi, {
@@ -708,11 +710,11 @@
IREE_VM_ABI_DECLARE_SHIM(rriCiD, v);
IREE_VM_ABI_DECLARE_SHIM(rriiCID, v);
IREE_VM_ABI_DECLARE_SHIM(rriCiirIID, v);
-IREE_VM_ABI_DECLARE_SHIM(rriiii, v);
+IREE_VM_ABI_DECLARE_SHIM(rriiiiI, v);
IREE_VM_ABI_DECLARE_SHIM(rrIIiii, v);
IREE_VM_ABI_DECLARE_SHIM(rrirCID, v);
IREE_VM_ABI_DECLARE_SHIM(rrirI, v);
-IREE_VM_ABI_DECLARE_SHIM(rriirI, v);
+IREE_VM_ABI_DECLARE_SHIM(rriirII, v);
IREE_VM_ABI_DECLARE_SHIM(rrIrIIi, v);
IREE_VM_ABI_DECLARE_SHIM(riirIrII, v);
IREE_VM_ABI_DECLARE_SHIM(rrIii, v);
diff --git a/tools/iree-benchmark-executable-main.c b/tools/iree-benchmark-executable-main.c
index f3a0b14..c603cb8 100644
--- a/tools/iree-benchmark-executable-main.c
+++ b/tools/iree-benchmark-executable-main.c
@@ -277,7 +277,7 @@
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_dispatch(
command_buffer, args->executable, FLAG_entry_point,
args->workgroup_count[0], args->workgroup_count[1],
- args->workgroup_count[2]));
+ args->workgroup_count[2], IREE_HAL_DISPATCH_FLAG_NONE));
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_execution_barrier(
command_buffer, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE,
IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE,