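Add an `iree_hal_dispatch_flags_t` parameter to dispatch operations. Both `iree_hal_command_buffer_dispatch` and `iree_hal_command_buffer_dispatch_indirect` (and the matching command buffer vtable entries, validation functions, and driver implementations) now take a trailing `flags` argument, and the compiler's `hal.command_buffer.dispatch` and `hal.command_buffer.dispatch.indirect` ops carry a corresponding `flags(...)` attribute that is lowered to an extra i64 operand on the VM import calls. The only flag defined so far is `None` (`IREE_HAL_DISPATCH_FLAG_NONE`); the value is currently unused, but it reserves room for future per-dispatch options such as scheduling behavior.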
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir index 7d2977e..bee6557 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir
@@ -94,9 +94,9 @@ %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) return } util.initializer { @@ -111,9 +111,9 @@ %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : 
!hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) util.return } @@ -304,10 +304,10 @@ %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index %dispatch_3_ordinal = hal.executable.export.ordinal target(@dispatch_3::@spirv::@dispatch_3) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + 
hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1]) flags(None) return }
diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir index af83f1b..2baedf6 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir
@@ -87,9 +87,9 @@ %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) return } util.initializer { @@ -104,9 +104,9 @@ %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : 
!hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) util.return }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 71be493..cb4179f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -396,6 +396,13 @@ auto importType = importOp.getFunctionType(); auto [workgroupsBufferSlot, workgroupsBuffer] = splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter); + auto flags = adaptor.getFlagsAttr() + ? rewriter + .create<IREE::VM::ConstI64Op>( + op.getLoc(), adaptor.getFlagsAttr().getInt()) + .getResult() + : rewriter.create<IREE::VM::ConstI64ZeroOp>(op.getLoc()) + .getResult(); SmallVector<Value, 8> callOperands = { adaptor.getCommandBuffer(), adaptor.getExecutable(), @@ -405,6 +412,7 @@ workgroupsBuffer, castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(), rewriter), + flags, }; auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>( op, SymbolRefAttr::get(importOp), importType.getResults(),
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 61ca923..2df6959 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -334,15 +334,17 @@ %cmd: !hal.command_buffer, %executable: !hal.executable ) { - // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123 + // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123 %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300) + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300, %[[FLAGS]]) hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%executable : !hal.executable)[%ordinal] workgroups([%c100, %c200, %c300]) + flags(None) util.return } @@ -361,10 +363,12 @@ %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero - // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100) + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[FLAGS]]) hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable : !hal.executable)[%ordinal] workgroups(%buffer : !hal.buffer)[%c100] + flags(None) util.return } @@ -383,9 +387,11 @@ %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer> - // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100) + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[FLAGS]]) hal.command_buffer.dispatch.indirect<%cmd 
: !hal.command_buffer> target(%executable : !hal.executable)[%ordinal] workgroups(%buffer_slot : index)[%c100] + flags(None) util.return }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 274b348..74c0fc9 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp
@@ -731,9 +731,11 @@ entryPointAttr.getRootReference().getValue()); Value ordinal = caseBuilder.create<IREE::HAL::ExecutableExportOrdinalOp>( loc, caseBuilder.getIndexType(), entryPointAttr); + auto flags = caseBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>( + IREE::HAL::DispatchFlags::None); caseBuilder.create<IREE::HAL::CommandBufferDispatchOp>( loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0], - caseWorkgroupCount[1], caseWorkgroupCount[2]); + caseWorkgroupCount[1], caseWorkgroupCount[2], flags); caseBuilder.create<scf::YieldOp>(loc); }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index e2d56d2..9d85020 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -186,6 +186,14 @@ let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; } +def HAL_DispatchFlags_None : I64BitEnumAttrCase<"None", 0x0000>; +def HAL_DispatchFlagsAttr : + I64BitEnumAttr<"DispatchFlags", "valid dispatch flags", [ + HAL_DispatchFlags_None, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + def HAL_ExecutionStage_None : I32BitEnumAttrCase<"None", 0x0000>; def HAL_ExecutionStage_CommandIssue : I32BitEnumAttrCase<"CommandIssue", 0x0001>; def HAL_ExecutionStage_CommandProcess : I32BitEnumAttrCase<"CommandProcess", 0x0002>;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index d0abb71..599c1ff 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1502,7 +1502,8 @@ HAL_Ordinal:$entry_point, HAL_Dim:$workgroup_x, HAL_Dim:$workgroup_y, - HAL_Dim:$workgroup_z + HAL_Dim:$workgroup_z, + HAL_DispatchFlagsAttr:$flags ); let assemblyFormat = [{ @@ -1514,6 +1515,7 @@ $workgroup_y `,` $workgroup_z `]` `)` + `flags` `(` $flags `)` attr-dict-with-keyword }]; } @@ -1530,7 +1532,8 @@ HAL_Executable:$executable, HAL_Ordinal:$entry_point, AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer, - HAL_DeviceSize:$workgroups_offset + HAL_DeviceSize:$workgroups_offset, + HAL_DispatchFlagsAttr:$flags ); let assemblyFormat = [{ @@ -1539,6 +1542,7 @@ `` `[` $entry_point `]` `workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)` `` `[` $workgroups_offset `]` + `flags` `(` $flags `)` attr-dict-with-keyword }]; }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index 5598e39..77d56e3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -266,9 +266,11 @@ // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] // CHECK-SAME: workgroups([%[[X]], %[[Y]], %[[Z]]]) + // CHECK-SAME: flags("None") hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%executable: !hal.executable)[%ordinal] workgroups([%x, %y, %z]) + flags("None") util.return } @@ -296,9 +298,11 @@ // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] // CHECK-SAME: workgroups(%[[BUFFER]] : !hal.buffer)[%[[OFFSET]]] + // CHECK-SAME: flags("None") hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable: !hal.executable)[%ordinal] workgroups(%buffer : !hal.buffer)[%offset] + flags("None") util.return } @@ -326,8 +330,10 @@ // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] // CHECK-SAME: workgroups(%[[BUFFER_SLOT]] : index)[%[[OFFSET]]] + // CHECK-SAME: flags("None") hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable: !hal.executable)[%ordinal] workgroups(%buffer_slot : index)[%offset] + flags("None") util.return }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 4ec14bf..c2487b8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp
@@ -353,10 +353,12 @@ loc, indexSet.get(0), batchSizeArg, indexSet.get(1), ValueRange{}, [&](OpBuilder &forBuilder, Location loc, Value iv, ValueRange iters) { // Dispatch. + auto flags = forBuilder.getAttr<IREE::HAL::DispatchFlagsAttr>( + IREE::HAL::DispatchFlags::None); forBuilder.create<IREE::HAL::CommandBufferDispatchOp>( loc, commandBuffer, executable, ordinal, workgroupCountOp.getWorkgroupX(), workgroupCountOp.getWorkgroupY(), - workgroupCountOp.getWorkgroupZ()); + workgroupCountOp.getWorkgroupZ(), flags); // Barrier following the dispatch to block the next dispatch. auto sourceStage = IREE::HAL::ExecutionStageBitfield::CommandRetire |
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir index e7b80e1..a139ece 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir
@@ -13,12 +13,12 @@ %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index - hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None) hal.command_buffer.execution_barrier<%cmd1 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None") - hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2]) + hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2]) flags(None) - hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2]) + hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2]) flags(None) hal.command_buffer.execution_barrier<%cmd2 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None") util.return @@ -59,7 +59,7 @@ %c1 = arith.constant 1 : index scf.index_switch %idx case 0 { - hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None) scf.yield } default {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 0a923c9..d68b865 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -309,7 +309,8 @@ %entry_point : i32, %workgroup_x : i32, %workgroup_y : i32, - %workgroup_z : i32 + %workgroup_z : i32, + %flags : i64 ) // Dispatches an execution request with the dispatch parameters loaded from the @@ -320,7 +321,8 @@ %entry_point : i32, %workgroups_buffer_slot : i32, %workgroups_buffer : !vm.ref<!hal.buffer>, - %workgroups_offset : i64 + %workgroups_offset : i64, + %flags : i64 ) //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index 8c3c502..fa837d4 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -127,11 +127,11 @@ // one of the vm constant ops. auto constValue = builder.create<mlir::arith::ConstantOp>( loc, inputType, - IntegerAttr::get(inputType, - APInt(32, static_cast<int32_t>(intAttr.getInt())))); + IntegerAttr::get(inputType, APInt(inputType.getIntOrFloatBitWidth(), + intAttr.getValue().getSExtValue()))); return {{constValue}}; - } - if (auto elementsAttr = llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) { + } else if (auto elementsAttr = + llvm::dyn_cast<DenseIntElementsAttr>(attrValue)) { SmallVector<Value> elementValues; elementValues.reserve(elementsAttr.getNumElements()); for (auto intAttr : elementsAttr.getValues<Attribute>()) { @@ -140,8 +140,7 @@ cast<TypedAttr>(intAttr))); } return elementValues; - } - if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) { + } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attrValue)) { SmallVector<Value> allValues; for (auto elementAttr : arrayAttr) { auto flattenedValues = @@ -151,8 +150,7 @@ allValues.append(flattenedValues->begin(), flattenedValues->end()); } return allValues; - } - if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) { + } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) { return {{builder.create<IREE::VM::RodataInlineOp>(loc, strAttr)}}; }
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c index fe165e5..42b476b 100644 --- a/experimental/rocm/direct_command_buffer.c +++ b/experimental/rocm/direct_command_buffer.c
@@ -412,7 +412,8 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_rocm_direct_command_buffer_t* command_buffer = iree_hal_rocm_direct_command_buffer_cast(base_command_buffer); // Lookup kernel parameters used for side-channeling additional launch @@ -463,7 +464,7 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "need rocm implementation"); }
diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c index d57dee4..de89e4f 100644 --- a/experimental/webgpu/command_buffer.c +++ b/experimental/webgpu/command_buffer.c
@@ -884,7 +884,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_webgpu_command_buffer_t* command_buffer = iree_hal_webgpu_command_buffer_cast(base_command_buffer); @@ -900,7 +901,7 @@ static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_webgpu_command_buffer_t* command_buffer = iree_hal_webgpu_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c index 38619a1..7f3785d 100644 --- a/runtime/src/iree/hal/command_buffer.c +++ b/runtime/src/iree/hal/command_buffer.c
@@ -558,7 +558,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { IREE_ASSERT_ARGUMENT(command_buffer); IREE_ASSERT_ARGUMENT(executable); if ((workgroup_x | workgroup_y | workgroup_z) == 0) { @@ -574,7 +575,7 @@ IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_command_buffer_dispatch_validation( command_buffer, VALIDATION_STATE(command_buffer), executable, - entry_point, workgroup_x, workgroup_y, workgroup_z)); + entry_point, workgroup_x, workgroup_y, workgroup_z, flags)); }); #if IREE_HAL_VERBOSE_TRACING_ENABLE // TODO(benvanik): add a tracing.h helper that does the snprintf directly @@ -594,7 +595,7 @@ #endif // IREE_HAL_VERBOSE_TRACING_ENABLE iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)( command_buffer, executable, entry_point, workgroup_x, workgroup_y, - workgroup_z); + workgroup_z, flags); IREE_TRACE_ZONE_END(z0); return status; } @@ -602,7 +603,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { IREE_ASSERT_ARGUMENT(command_buffer); IREE_ASSERT_ARGUMENT(executable); IREE_TRACE_ZONE_BEGIN(z0); @@ -610,10 +611,10 @@ IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_command_buffer_dispatch_indirect_validation( command_buffer, VALIDATION_STATE(command_buffer), executable, - entry_point, workgroups_ref)); + entry_point, workgroups_ref, flags)); }); iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch_indirect)( - command_buffer, executable, entry_point, workgroups_ref); + command_buffer, executable, 
entry_point, workgroups_ref, flags); IREE_TRACE_ZONE_END(z0); return status; }
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h index 38ac03d..c9c6037 100644 --- a/runtime/src/iree/hal/command_buffer.h +++ b/runtime/src/iree/hal/command_buffer.h
@@ -384,6 +384,12 @@ IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count( iree_hal_collective_element_type_t element_type); +// Bitfield specifying flags controlling a dispatch operation. +enum iree_hal_dispatch_flag_bits_t { + IREE_HAL_DISPATCH_FLAG_NONE = 0, +}; +typedef uint64_t iree_hal_dispatch_flags_t; + // An RGBA color. typedef struct iree_hal_label_color_t { uint8_t r; @@ -751,7 +757,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags); // Dispatches an execution request with deferred workgroup counts. // This is the same as iree_hal_command_buffer_dispatch but the workgroup counts @@ -765,7 +772,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref); + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); //===----------------------------------------------------------------------===// // Validation support @@ -922,12 +929,13 @@ iree_status_t(IREE_API_PTR* dispatch)( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags); iree_status_t(IREE_API_PTR* dispatch_indirect)( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref); + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); } iree_hal_command_buffer_vtable_t; 
IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t);
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c index 87f3a4b..b27433c 100644 --- a/runtime/src/iree/hal/command_buffer_validation.c +++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -603,7 +603,8 @@ iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_dispatch_bindings( @@ -615,7 +616,7 @@ iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h index 036d666..82ab1c5 100644 --- a/runtime/src/iree/hal/command_buffer_validation.h +++ b/runtime/src/iree/hal/command_buffer_validation.h
@@ -142,13 +142,14 @@ iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags); iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref); + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); iree_status_t iree_hal_command_buffer_binding_table_validation( iree_hal_command_buffer_t* command_buffer,
diff --git a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h index 4d71207..6d19793 100644 --- a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h +++ b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
@@ -154,7 +154,8 @@ IREE_ASSERT_OK(iree_hal_command_buffer_dispatch( command_buffer, executable_, /*entry_point=*/0, - /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1)); + /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1, + IREE_HAL_DISPATCH_FLAG_NONE)); IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier( command_buffer, /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
diff --git a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h index af99ee1..06fa747 100644 --- a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h +++ b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h
@@ -120,7 +120,8 @@ IREE_ASSERT_OK(iree_hal_command_buffer_dispatch( command_buffer, executable_, /*entry_point=*/0, - /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1)); + /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1, + IREE_HAL_DISPATCH_FLAG_NONE)); IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier( command_buffer, /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH |
diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c index c747c3c..c53428a 100644 --- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
@@ -747,7 +747,8 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_cuda_graph_command_buffer_t* command_buffer = iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -873,7 +874,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect dispatch not yet implemented"); }
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c index 8f9cd2d..3369f3b 100644 --- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -533,7 +533,8 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_cuda_stream_command_buffer_t* command_buffer = iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -646,7 +647,7 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "need cuda implementation of dispatch indirect"); }
diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c index 65b2c4c..ae66cfd 100644 --- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
@@ -771,7 +771,8 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_hip_graph_command_buffer_t* command_buffer = iree_hal_hip_graph_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -882,7 +883,7 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect dispatch not yet implemented"); }
diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index a250299..0f08727 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c
@@ -525,7 +525,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -626,7 +627,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "need hip implementation of dispatch indirect"); }
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c index 95c503e..50d56c8 100644 --- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c +++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@ -973,7 +973,8 @@ static iree_status_t iree_hal_task_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_task_command_buffer_t* command_buffer = iree_hal_task_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert( @@ -987,7 +988,7 @@ static iree_status_t iree_hal_task_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_task_command_buffer_t* command_buffer = iree_hal_task_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m index 50d01e7..eaed4f5 100644 --- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m +++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -1073,7 +1073,7 @@ static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y, - uint32_t workgroup_count_z) { + uint32_t workgroup_count_z, iree_hal_dispatch_flags_t flags) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_metal_dispatch_segment_t* segment = NULL; @@ -1090,7 +1090,7 @@ static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, - int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref) { + int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_metal_dispatch_segment_t* segment = NULL;
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc index b66be80..8bb9413 100644 --- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc +++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
@@ -725,7 +725,8 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_vulkan_direct_command_buffer_t* command_buffer = iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); @@ -764,7 +765,7 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_vulkan_direct_command_buffer_t* command_buffer = iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer);
diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c index f69f9c2..3de7c60 100644 --- a/runtime/src/iree/hal/local/inline_command_buffer.c +++ b/runtime/src/iree/hal/local/inline_command_buffer.c
@@ -442,7 +442,8 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_inline_command_buffer_t* command_buffer = iree_hal_inline_command_buffer_cast(base_command_buffer); @@ -559,7 +560,7 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc. iree_hal_buffer_mapping_t buffer_mapping = {{0}}; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( @@ -570,7 +571,7 @@ *(const iree_hal_vec3_t*)buffer_mapping.contents.data; return iree_hal_inline_command_buffer_dispatch( base_command_buffer, executable, entry_point, workgroup_count.x, - workgroup_count.y, workgroup_count.z); + workgroup_count.y, workgroup_count.z, flags); } //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c index a4b805a..49ec334 100644 --- a/runtime/src/iree/hal/utils/deferred_command_buffer.c +++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -771,12 +771,14 @@ uint32_t workgroup_x; uint32_t workgroup_y; uint32_t workgroup_z; + iree_hal_dispatch_flags_t flags; } iree_hal_cmd_dispatch_t; static iree_status_t iree_hal_deferred_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_deferred_command_buffer_t* command_buffer = iree_hal_deferred_command_buffer_cast(base_command_buffer); iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list; @@ -790,6 +792,7 @@ cmd->workgroup_x = workgroup_x; cmd->workgroup_y = workgroup_y; cmd->workgroup_z = workgroup_z; + cmd->flags = flags; return iree_ok_status(); } @@ -799,7 +802,7 @@ const iree_hal_cmd_dispatch_t* cmd) { return iree_hal_command_buffer_dispatch( target_command_buffer, cmd->executable, cmd->entry_point, - cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z); + cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z, cmd->flags); } //===----------------------------------------------------------------------===// @@ -811,12 +814,13 @@ iree_hal_executable_t* executable; int32_t entry_point; iree_hal_buffer_ref_t workgroups_ref; + iree_hal_dispatch_flags_t flags; } iree_hal_cmd_dispatch_indirect_t; static iree_status_t iree_hal_deferred_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_deferred_command_buffer_t* command_buffer = iree_hal_deferred_command_buffer_cast(base_command_buffer); iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list; @@ -834,6 +838,7 @@ cmd->executable = executable; cmd->entry_point = entry_point; cmd->workgroups_ref = workgroups_ref; + cmd->flags = flags; return 
iree_ok_status(); } @@ -845,7 +850,8 @@ IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref( binding_table, cmd->workgroups_ref, &workgroups_ref)); return iree_hal_command_buffer_dispatch_indirect( - target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref); + target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref, + cmd->flags); } //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index b87a8cb..f6f96f2 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl
@@ -50,8 +50,8 @@ EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v) EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v) EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r) -EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v) -EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirI, v) +EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiiiI, v) +EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirII, v) EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v) EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v) EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index c0db04b..777bcb3 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c
@@ -907,7 +907,7 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, // iree_hal_module_state_t, // - rriiii, v) { + rriiiiI, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); @@ -917,15 +917,16 @@ uint32_t workgroup_x = (uint32_t)args->i3; uint32_t workgroup_y = (uint32_t)args->i4; uint32_t workgroup_z = (uint32_t)args->i5; + iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6; return iree_hal_command_buffer_dispatch(command_buffer, executable, entry_point, workgroup_x, workgroup_y, - workgroup_z); + workgroup_z, flags); } IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, // iree_hal_module_state_t, // - rriirI, v) { + rriirII, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); @@ -938,9 +939,10 @@ workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t)); IREE_RETURN_IF_ERROR( iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer)); + iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6; - return iree_hal_command_buffer_dispatch_indirect(command_buffer, executable, - entry_point, workgroups_ref); + return iree_hal_command_buffer_dispatch_indirect( + command_buffer, executable, entry_point, workgroups_ref, flags); } //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index c89a8b7..5bd69a7 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c
@@ -59,11 +59,11 @@ IREE_VM_ABI_DEFINE_SHIM(rriCiD, v); IREE_VM_ABI_DEFINE_SHIM(rriiCID, v); IREE_VM_ABI_DEFINE_SHIM(rriCiirIID, v); -IREE_VM_ABI_DEFINE_SHIM(rriiii, v); +IREE_VM_ABI_DEFINE_SHIM(rriiiiI, v); IREE_VM_ABI_DEFINE_SHIM(rrIIiii, v); IREE_VM_ABI_DEFINE_SHIM(rrirCID, v); IREE_VM_ABI_DEFINE_SHIM(rrirI, v); -IREE_VM_ABI_DEFINE_SHIM(rriirI, v); +IREE_VM_ABI_DEFINE_SHIM(rriirII, v); IREE_VM_ABI_DEFINE_SHIM(rrIrIIi, v); IREE_VM_ABI_DEFINE_SHIM(riirIrII, v); IREE_VM_ABI_DEFINE_SHIM(rrIii, v);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index 1d5a3aa..b47428c 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h
@@ -371,13 +371,14 @@ int64_t i12; }); -IREE_VM_ABI_FIXED_STRUCT(rriiii, { +IREE_VM_ABI_FIXED_STRUCT(rriiiiI, { iree_vm_ref_t r0; iree_vm_ref_t r1; int32_t i2; int32_t i3; int32_t i4; int32_t i5; + int64_t i6; }); IREE_VM_ABI_FIXED_STRUCT(rrIIiii, { @@ -398,13 +399,14 @@ int64_t i4; }); -IREE_VM_ABI_FIXED_STRUCT(rriirI, { +IREE_VM_ABI_FIXED_STRUCT(rriirII, { iree_vm_ref_t r0; iree_vm_ref_t r1; int32_t i2; int32_t i3; iree_vm_ref_t r4; int64_t i5; + int64_t i6; }); IREE_VM_ABI_FIXED_STRUCT(rrIrIIi, { @@ -708,11 +710,11 @@ IREE_VM_ABI_DECLARE_SHIM(rriCiD, v); IREE_VM_ABI_DECLARE_SHIM(rriiCID, v); IREE_VM_ABI_DECLARE_SHIM(rriCiirIID, v); -IREE_VM_ABI_DECLARE_SHIM(rriiii, v); +IREE_VM_ABI_DECLARE_SHIM(rriiiiI, v); IREE_VM_ABI_DECLARE_SHIM(rrIIiii, v); IREE_VM_ABI_DECLARE_SHIM(rrirCID, v); IREE_VM_ABI_DECLARE_SHIM(rrirI, v); -IREE_VM_ABI_DECLARE_SHIM(rriirI, v); +IREE_VM_ABI_DECLARE_SHIM(rriirII, v); IREE_VM_ABI_DECLARE_SHIM(rrIrIIi, v); IREE_VM_ABI_DECLARE_SHIM(riirIrII, v); IREE_VM_ABI_DECLARE_SHIM(rrIii, v);
diff --git a/tools/iree-benchmark-executable-main.c b/tools/iree-benchmark-executable-main.c index f3a0b14..c603cb8 100644 --- a/tools/iree-benchmark-executable-main.c +++ b/tools/iree-benchmark-executable-main.c
@@ -277,7 +277,7 @@ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_dispatch( command_buffer, args->executable, FLAG_entry_point, args->workgroup_count[0], args->workgroup_count[1], - args->workgroup_count[2])); + args->workgroup_count[2], IREE_HAL_DISPATCH_FLAG_NONE)); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_execution_barrier( command_buffer, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE, IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE,