Changing descriptor ops to use index instead of int32 and SSA values. (#4765)
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp index 59eb6bc..af4c7c3 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/ConvertStreamOps.cpp
@@ -305,8 +305,10 @@ auto byteLength = value->getByteLength(); if (!byteLength) return failure(); - bindings.push_back(std::make_tuple(bindingOrdinal++, value->getBuffer(), - zeroOffset, byteLength)); + bindings.push_back( + std::make_tuple(rewriter.createOrFold<mlir::ConstantIndexOp>( + dispatchOp.getLoc(), bindingOrdinal++), + value->getBuffer(), zeroOffset, byteLength)); return success(); }; for (auto it : llvm::enumerate(dispatchOp.operands())) {
diff --git a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir index 2ff6847..abc7d9f 100644 --- a/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir
@@ -1,4 +1,4 @@ -// RUN: iree-opt -print-ir-after-all -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s +// RUN: iree-opt -split-input-file -iree-convert-to-hal -canonicalize %s | IreeFileCheck %s hal.executable @ex0 { hal.interface @interface { @@ -26,7 +26,7 @@ // CHECK-NEXT: hal.command_buffer.begin %[[CMD]] %0 = flow.ex.stream.fragment(%arg1 = %cst : index, %arg2 = %arg0 : tensor<128xf32>) -> tensor<128xf32> { // CHECK-DAG: %[[EXE_LAYOUT:.+]] = hal.executable_layout.lookup - // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set=0, bindings=[0 = (%arg0, %c0, %c512), 1 = (%[[TMP_BUF]], %c0, %c512)] + // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %[[EXE_LAYOUT]], set = %c0, bindings = [%c0 = (%arg0, %c0, %c512), %c1 = (%[[TMP_BUF]], %c0, %c512)] // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex0::@vmla::@entry0, workgroup_xyz // CHECK: hal.command_buffer.execution_barrier %1 = flow.dispatch @ex0::@entry0[%arg1] (%arg2) : (tensor<128xf32>) -> tensor<128xf32> @@ -137,7 +137,7 @@ %arg6 = %c1024 : index, %arg7 = %c512 : index ) -> tensor<4x7x1024xf32> { - // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set=0, bindings=[0 = (%arg0, %c0, %c2688), 1 = (%buffer, %c0, %c114688)] + // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %c2688), %c1 = (%buffer, %c0, %c114688)] // CHECK: hal.command_buffer.dispatch.symbol {{.+}}, @ex::@tgt::@entry, workgroup_xyz %0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7] (%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> flow.return %0 : tensor<4x7x1024xf32> @@ -180,7 +180,7 @@ %4 = shapex.make_ranked_shape %arg5, %arg4 : (index, index) -> !shapex.ranked_shape<[?,?,1024]> %5 = shapex.tie_shape %arg3, %3 : tensor<7x?x24x?xf32>, !shapex.ranked_shape<[7,?,24,?]> // CHECK: hal.command_buffer.push_constants %[[CMD]], %executable_layout, offset = 0, values = [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : i32 - // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set=0, bindings=[0 = (%arg0, %c0, %9), 1 = (%buffer, %c0, %12)] + // CHECK: hal.command_buffer.push_descriptor_set %[[CMD]], %executable_layout, set = %c0, bindings = [%c0 = (%arg0, %c0, %9), %c1 = (%buffer, %c0, %12)] // CHECK: #hal.device.match.id<"dylib*">( // CHECK-SAME: %[[CMD_INNER:.+]] = %cmd : !hal.command_buffer,
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index e2c44da..eff6109 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -43,35 +43,20 @@ SmallVector<Value, 8> callOperands = { newOperands.command_buffer(), newOperands.executable_layout(), - rewriter.create<mlir::ConstantOp>( - op.getLoc(), rewriter.getI32IntegerAttr( - static_cast<int32_t>(op.setAttr().getInt()))), + newOperands.set(), }; SmallVector<int16_t, 5> segmentSizes = { /*command_buffer=*/-1, /*executable_layout=*/-1, /*set=*/-1, - /*bindings_ordinals=*/ - static_cast<int16_t>(op.bindings().size()), - /*bindings_buffers=*/ - static_cast<int16_t>(op.bindings().size()), - /*bindings_offsets=*/ - static_cast<int16_t>(op.bindings().size()), - /*bindings_lengths=*/ - static_cast<int16_t>(op.bindings().size()), + /*bindings=*/ + static_cast<int16_t>(newOperands.binding_ordinals().size()), }; - for (auto bindingAttr : op.bindings()) { - callOperands.push_back( - rewriter.create<mlir::ConstantOp>(op.getLoc(), bindingAttr)); - } - for (auto bindingBuffer : newOperands.binding_buffers()) { - callOperands.push_back(bindingBuffer); - } - for (auto bindingOffset : newOperands.binding_offsets()) { - callOperands.push_back(bindingOffset); - } - for (auto bindingLength : newOperands.binding_lengths()) { - callOperands.push_back(bindingLength); + for (size_t i = 0; i < newOperands.binding_ordinals().size(); ++i) { + callOperands.push_back(newOperands.binding_ordinals()[i]); + callOperands.push_back(newOperands.binding_buffers()[i]); + callOperands.push_back(newOperands.binding_offsets()[i]); + callOperands.push_back(newOperands.binding_lengths()[i]); } rewriter.replaceOpWithNewOp<IREE::VM::CallVariadicOp>(
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index d0615eb..f46de67 100644 --- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -58,11 +58,12 @@ %arg0 : !hal.command_buffer, %arg1 : !hal.executable_layout, %arg2 : !hal.descriptor_set) { + %c0 = constant 0 : index %c100 = constant 100 : index // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero, %arg2, []) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable_layout>, i32, !vm.ref<!hal.descriptor_set>, i32 ...) - hal.command_buffer.bind_descriptor_set %arg0, %arg1, set=0, %arg2 - // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero_0, %arg2, [%c100]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable_layout>, i32, !vm.ref<!hal.descriptor_set>, i32 ...) - hal.command_buffer.bind_descriptor_set %arg0, %arg1, set=0, %arg2, offsets=[%c100] + hal.command_buffer.bind_descriptor_set %arg0, %arg1, set = %c0, %arg2 + // CHECK: vm.call.variadic @hal.command_buffer.bind_descriptor_set(%arg0, %arg1, %zero, %arg2, [%c100]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable_layout>, i32, !vm.ref<!hal.descriptor_set>, i32 ...) + hal.command_buffer.bind_descriptor_set %arg0, %arg1, set = %c0, %arg2, offsets = [%c100] return }
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp index 86f8521..0a1d25d 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -724,11 +724,18 @@ void CommandBufferPushDescriptorSetOp::build( OpBuilder &builder, OperationState &state, Value commandBuffer, - Value executableLayout, uint32_t set, + Value executableLayout, int64_t set, ArrayRef<DescriptorSetBindingValue> bindings) { - state.addOperands({commandBuffer, executableLayout}); - state.addAttribute("set", builder.getI32IntegerAttr(set)); - SmallVector<int32_t, 4> bindingOrdinals; + build(builder, state, commandBuffer, executableLayout, + builder.createOrFold<ConstantIndexOp>(state.location, set), bindings); +} + +void CommandBufferPushDescriptorSetOp::build( + OpBuilder &builder, OperationState &state, Value commandBuffer, + Value executableLayout, Value set, + ArrayRef<DescriptorSetBindingValue> bindings) { + state.addOperands({commandBuffer, executableLayout, set}); + SmallVector<Value, 4> bindingOrdinals; SmallVector<Value, 4> bindingBuffers; SmallVector<Value, 4> bindingOffsets; SmallVector<Value, 4> bindingLengths; @@ -738,7 +745,7 @@ bindingOffsets.push_back(std::get<2>(binding)); bindingLengths.push_back(std::get<3>(binding)); } - state.addAttribute("bindings", builder.getI32ArrayAttr(bindingOrdinals)); + state.addOperands(bindingOrdinals); state.addOperands(bindingBuffers); state.addOperands(bindingOffsets); state.addOperands(bindingLengths); @@ -746,20 +753,19 @@ static ParseResult parseDescriptorSetBindings(OpAsmParser &parser, OperationState *result) { - auto i32Type = parser.getBuilder().getIntegerType(32); auto indexType = parser.getBuilder().getIndexType(); - SmallVector<Attribute, 4> bindingAttrs; + SmallVector<Value, 4> bindingOrdinals; SmallVector<Value, 4> bindingBuffers; SmallVector<Value, 4> bindingOffsets; SmallVector<Value, 4> bindingLengths; do { - IntegerAttr bindingAttr; NamedAttrList attrList; + OpAsmParser::OperandType ordinal; OpAsmParser::OperandType buffer; OpAsmParser::OperandType bufferOffset; OpAsmParser::OperandType bufferLength; - if (failed( - parser.parseAttribute(bindingAttr, i32Type, "binding", attrList)) || + if (failed(parser.parseOperand(ordinal)) || + failed(parser.resolveOperand(ordinal, indexType, bindingOrdinals)) || failed(parser.parseEqual()) || failed(parser.parseLParen()) || failed(parser.parseOperand(buffer)) || failed(parser.resolveOperand( @@ -775,10 +781,8 @@ failed(parser.parseRParen())) { return failure(); } - bindingAttrs.push_back(bindingAttr); } while (succeeded(parser.parseOptionalComma())); - result->addAttribute("bindings", - parser.getBuilder().getArrayAttr(bindingAttrs)); + result->addOperands(bindingOrdinals); result->addOperands(bindingBuffers); result->addOperands(bindingOffsets); result->addOperands(bindingLengths); @@ -789,25 +793,24 @@ OpAsmParser &parser, OperationState *result) { OpAsmParser::OperandType commandBuffer; OpAsmParser::OperandType executableLayout; - IntegerAttr setAttr; + OpAsmParser::OperandType set; auto operandsLoc = parser.getCurrentLocation(); if (failed(parser.parseOperand(commandBuffer)) || failed(parser.parseComma()) || failed(parser.parseOperand(executableLayout)) || failed(parser.parseComma()) || failed(parser.parseKeyword("set")) || - failed(parser.parseEqual()) || - failed(parser.parseAttribute(setAttr, - parser.getBuilder().getIntegerType(32), - "set", result->attributes)) || + failed(parser.parseEqual()) || failed(parser.parseOperand(set)) || failed(parser.parseComma()) || failed(parser.resolveOperands( ArrayRef<OpAsmParser::OperandType>{ commandBuffer, executableLayout, + set, }, ArrayRef<Type>{ CommandBufferType::get(result->getContext()), ExecutableLayoutType::get(result->getContext()), + IndexType::get(result->getContext()), }, operandsLoc, result->operands)) || failed(parser.parseKeyword("bindings")) || failed(parser.parseEqual()) || @@ -822,8 +825,8 @@ template <typename T> static void printDescriptorSetBindings(OpAsmPrinter &p, T op) { - for (int i = 0; i < op.bindings().size(); ++i) { - p << op.bindings()[i].template cast<IntegerAttr>().getValue(); + for (int i = 0; i < op.binding_ordinals().size(); ++i) { + p.printOperand(op.binding_ordinals()[i]); p << " = ("; p.printOperand(op.binding_buffers()[i]); p << ", "; @@ -831,7 +834,7 @@ p << ", "; p.printOperand(op.binding_lengths()[i]); p << ")"; - if (i < op.bindings().size() - 1) p << ", "; + if (i < op.binding_ordinals().size() - 1) p << ", "; } } @@ -841,8 +844,9 @@ p.printOperand(op.command_buffer()); p << ", "; p.printOperand(op.executable_layout()); - p << ", set=" << op.set(); - p << ", bindings=["; + p << ", set = "; + p.printOperand(op.set()); + p << ", bindings = ["; printDescriptorSetBindings(p, op); p << "]"; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ @@ -859,11 +863,20 @@ OperationState &state, Value commandBuffer, Value executableLayout, - uint32_t set, Value descriptorSet, + int64_t set, Value descriptorSet, ValueRange dynamicOffsets) { - state.addOperands({commandBuffer, executableLayout, descriptorSet}); - state.addAttribute("set", - builder.getIntegerAttr(builder.getIntegerType(32), set)); + build(builder, state, commandBuffer, executableLayout, + builder.createOrFold<ConstantIndexOp>(state.location, set), + descriptorSet, dynamicOffsets); +} + +void CommandBufferBindDescriptorSetOp::build(OpBuilder &builder, + OperationState &state, + Value commandBuffer, + Value executableLayout, Value set, + Value descriptorSet, + ValueRange dynamicOffsets) { + state.addOperands({commandBuffer, executableLayout, set, descriptorSet}); state.addOperands(dynamicOffsets); } @@ -966,7 +979,7 @@ OpBuilder &builder, OperationState &state, Value device, Value setLayout, ArrayRef<DescriptorSetBindingValue> bindings) { state.addOperands({device, setLayout}); - SmallVector<int32_t, 4> bindingOrdinals; + SmallVector<Value, 4> bindingOrdinals; SmallVector<Value, 4> bindingBuffers; SmallVector<Value, 4> bindingOffsets; SmallVector<Value, 4> bindingLengths; @@ -976,7 +989,7 @@ bindingOffsets.push_back(std::get<2>(binding)); bindingLengths.push_back(std::get<3>(binding)); } - state.addAttribute("bindings", builder.getI32ArrayAttr(bindingOrdinals)); + state.addOperands(bindingOrdinals); state.addOperands(bindingBuffers); state.addOperands(bindingOffsets); state.addOperands(bindingLengths); @@ -1015,7 +1028,7 @@ p.printOperand(op.device()); p << ", "; p.printOperand(op.set_layout()); - p << ", bindings=["; + p << ", bindings = ["; printDescriptorSetBindings(p, op); p << "]"; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td index bea5b3e..68f052c 100644 --- a/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1139,10 +1139,10 @@ Pushes an inline-defined descriptor set to the command buffer. ```mlir - hal.command_buffer.push_descriptor_set %cmd, %executable_layout, set = 0, bindings = [ - 0 = (%buffer_0, %buffer_offset_0, %buffer_length_0), - 1 = (%buffer_1, %buffer_offset_1, %buffer_length_1), - 2 = (%buffer_2, %buffer_offset_2, %buffer_length_2) + hal.command_buffer.push_descriptor_set %cmd, %executable_layout, set = %c0, bindings = [ + %c0 = (%buffer_0, %buffer_offset_0, %buffer_length_0), + %c1 = (%buffer_1, %buffer_offset_1, %buffer_length_1), + %c2 = (%buffer_2, %buffer_offset_2, %buffer_length_2) ] ``` }]; @@ -1150,8 +1150,8 @@ let arguments = (ins HAL_CommandBuffer:$command_buffer, HAL_ExecutableLayout:$executable_layout, - I32Attr:$set, - I32ArrayAttr:$bindings, + Index:$set, + Variadic<Index>:$binding_ordinals, Variadic<HAL_Buffer>:$binding_buffers, Variadic<HAL_DeviceSize>:$binding_offsets, Variadic<HAL_DeviceSize>:$binding_lengths @@ -1160,7 +1160,9 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout, - "uint32_t":$set, "ArrayRef<DescriptorSetBindingValue>":$bindings)>, + "int64_t":$set, "ArrayRef<DescriptorSetBindingValue>":$bindings)>, + OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout, + "Value":$set, "ArrayRef<DescriptorSetBindingValue>":$bindings)>, ]; let hasCanonicalizer = 1; @@ -1177,7 +1179,7 @@ let arguments = (ins HAL_CommandBuffer:$command_buffer, HAL_ExecutableLayout:$executable_layout, - I32Attr:$set, + Index:$set, HAL_DescriptorSet:$descriptor_set, Variadic<HAL_DeviceSize>:$dynamic_offsets ); @@ -1191,7 +1193,10 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout, - "uint32_t":$set, "Value":$descriptorSet, + "int64_t":$set, "Value":$descriptorSet, + CArg<"ValueRange", "{}">:$dynamicOffsets)>, + OpBuilderDAG<(ins "Value":$commandBuffer, "Value":$executableLayout, + "Value":$set, "Value":$descriptorSet, CArg<"ValueRange", "{}">:$dynamicOffsets)>, ]; } @@ -1551,7 +1556,7 @@ let arguments = (ins HAL_Device:$device, HAL_DescriptorSetLayout:$set_layout, - I32ArrayAttr:$bindings, + Variadic<Index>:$binding_ordinals, Variadic<HAL_Buffer>:$binding_buffers, Variadic<HAL_DeviceSize>:$binding_offsets, Variadic<HAL_DeviceSize>:$binding_lengths
diff --git a/iree/compiler/Dialect/HAL/IR/HALTypes.h b/iree/compiler/Dialect/HAL/IR/HALTypes.h index 752ce04..1e0bed6 100644 --- a/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/iree/compiler/Dialect/HAL/IR/HALTypes.h
@@ -182,7 +182,7 @@ // A tuple containing runtime values for a descriptor set binding: // <binding ordinal, hal.buffer, buffer byte offset, buffer byte length> -using DescriptorSetBindingValue = std::tuple<uint32_t, Value, Value, Value>; +using DescriptorSetBindingValue = std::tuple<Value, Value, Value, Value>; } // namespace HAL } // namespace IREE
diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir index 6ede3e1..84f8ed0 100644 --- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -23,23 +23,25 @@ %buffer : !hal.buffer ) { %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index %c4 = constant 4 : index %c4096 = constant 4096 : index %c8000 = constant 8000 : index %c262140 = constant 262140 : index %c262144 = constant 262144 : index %subspan = hal.buffer.subspan %buffer, %c4096, %c262144 : !hal.buffer - // CHECK: hal.command_buffer.push_descriptor_set {{.+}}, bindings=[ - hal.command_buffer.push_descriptor_set %cmd, %layout, set=0, bindings=[ + // CHECK: hal.command_buffer.push_descriptor_set {{.+}}, set = %c0, bindings = [ + hal.command_buffer.push_descriptor_set %cmd, %layout, set = %c0, bindings = [ // 0 + 4096: - // CHECK-SAME: 0 = ([[BASE_BUFFER]], %c4096, %c8000) - 0 = (%subspan, %c0, %c8000), + // CHECK-SAME: %c0 = ([[BASE_BUFFER]], %c4096, %c8000) + %c0 = (%subspan, %c0, %c8000), // 4096 + 4: - // CHECK-SAME: 1 = ([[BASE_BUFFER]], %c4100, %c262140) - 1 = (%subspan, %c4, %c262140), + // CHECK-SAME: %c1 = ([[BASE_BUFFER]], %c4100, %c262140) + %c1 = (%subspan, %c4, %c262140), // No change: - // CHECK-SAME: 2 = ([[BASE_BUFFER]], %c4096, %c262144) - 2 = (%buffer, %c4096, %c262144) + // CHECK-SAME: %c2 = ([[BASE_BUFFER]], %c4096, %c262144) + %c2 = (%buffer, %c4096, %c262144) ] return }
diff --git a/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index 5291bf8..9730bfe 100644 --- a/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -91,10 +91,11 @@ %0 = "test_hal.executable_layout"() : () -> !hal.executable_layout %1 = "test_hal.descriptor_set"() : () -> !hal.descriptor_set %2 = "test_hal.offset"() : () -> index - // CHECK: hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1 - hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1 - // CHECK-NEXT: hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1, offsets = [%2] - hal.command_buffer.bind_descriptor_set %arg0, %0, set = 0, %1, offsets = [%2] + %c0 = constant 0 : index + // CHECK: hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1 + hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1 + // CHECK-NEXT: hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1, offsets = [%2] + hal.command_buffer.bind_descriptor_set %arg0, %0, set = %c0, %1, offsets = [%2] return }
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir index 39de1ec..c75e756 100644 --- a/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -194,10 +194,8 @@ %command_buffer : !vm.ref<!hal.command_buffer>, %executable_layout : !vm.ref<!hal.executable_layout>, %set : i32, - %bindings : i32 ..., - %binding_buffers : !vm.ref<!hal.buffer>..., - %binding_offsets : i32 ..., - %binding_lengths : i32 ... + // <binding, buffer, offset, length> + %bindings : tuple<i32, !vm.ref<!hal.buffer>, i32, i32>... ) // Binds a descriptor set to the given set number. @@ -237,10 +235,8 @@ vm.import @descriptor_set.create( %device : !vm.ref<!hal.device>, %set_layout : !vm.ref<!hal.descriptor_set_layout>, - %bindings : i32 ..., - %binding_buffers : !vm.ref<!hal.buffer>..., - %binding_offsets : i32 ..., - %binding_lengths : i32 ... + // <binding, buffer, offset, length> + %bindings : tuple<i32, !vm.ref<!hal.buffer>, i32, i32>... ) -> !vm.ref<!hal.descriptor_set> //===----------------------------------------------------------------------===//
diff --git a/iree/modules/hal/hal_module.cc b/iree/modules/hal/hal_module.cc index 9994903..9f3c164 100644 --- a/iree/modules/hal/hal_module.cc +++ b/iree/modules/hal/hal_module.cc
@@ -467,19 +467,19 @@ Status CommandBufferPushDescriptorSet( const vm::ref<iree_hal_command_buffer_t>& command_buffer, const vm::ref<iree_hal_executable_layout_t>& executable_layout, - uint32_t set, absl::Span<const uint32_t> binding_ordinals, - absl::Span<const vm::ref<iree_hal_buffer_t>> binding_buffers, - absl::Span<const int32_t> binding_offsets, - absl::Span<const int32_t> binding_lengths) { + uint32_t set, + absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t, + int32_t>> + bindings) { ExDeferRelease(executable_layout); absl::InlinedVector<iree_hal_descriptor_set_binding_t, 16> binding_structs( - binding_ordinals.size()); - for (int i = 0; i < binding_ordinals.size(); ++i) { + bindings.size()); + for (int i = 0; i < bindings.size(); ++i) { binding_structs[i] = { - binding_ordinals[i], binding_buffers[i].get(), - static_cast<iree_device_size_t>(binding_offsets[i]), - static_cast<iree_device_size_t>(binding_lengths[i])}; - ExDeferRelease(binding_buffers[i]); + std::get<0>(bindings[i]), std::get<1>(bindings[i]).get(), + static_cast<iree_device_size_t>(std::get<2>(bindings[i])), + static_cast<iree_device_size_t>(std::get<3>(bindings[i]))}; + ExDeferRelease(std::get<1>(bindings[i])); } return iree_hal_command_buffer_push_descriptor_set( command_buffer.get(), executable_layout.get(), set, @@ -534,18 +534,17 @@ StatusOr<vm::ref<iree_hal_descriptor_set_t>> DescriptorSetCreate( const vm::ref<iree_hal_device_t>& device, const vm::ref<iree_hal_descriptor_set_layout_t>& set_layout, - absl::Span<const uint32_t> binding_ordinals, - absl::Span<const vm::ref<iree_hal_buffer_t>> binding_buffers, - absl::Span<const uint32_t> binding_offsets, - absl::Span<const uint32_t> binding_lengths) { + absl::Span<const std::tuple<uint32_t, vm::ref<iree_hal_buffer_t>, int32_t, + int32_t>> + bindings) { absl::InlinedVector<iree_hal_descriptor_set_binding_t, 4> binding_structs( - binding_ordinals.size()); - for (int i = 0; i < binding_ordinals.size(); ++i) { + bindings.size()); + for (int i = 0; i < bindings.size(); ++i) { binding_structs[i] = { - binding_ordinals[i], // binding - binding_buffers[i].get(), // buffer - static_cast<iree_device_size_t>(binding_offsets[i]), // offset - static_cast<iree_device_size_t>(binding_lengths[i])}; // length + /*ordinal=*/std::get<0>(bindings[i]), + /*buffer=*/std::get<1>(bindings[i]).get(), + /*offset=*/static_cast<iree_device_size_t>(std::get<2>(bindings[i])), + /*length=*/static_cast<iree_device_size_t>(std::get<3>(bindings[i]))}; } vm::ref<iree_hal_descriptor_set_t> descriptor_set; IREE_RETURN_IF_ERROR(iree_hal_descriptor_set_create(