Updating HAL VM ABI to pass binding table slots.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
index 8299939..71be493 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -16,6 +16,31 @@
namespace {
+// Returns a slot value and a buffer ref value.
+// |bufferOrSlot| is intended to be a `AnyTypeOf<[Index, HAL_BufferType]>` in
+// the op definition.
+static std::tuple<Value, Value>
+splitBufferSlot(Location loc, Value bufferOrSlot, OpBuilder &builder) {
+ if (!bufferOrSlot) {
+ return std::make_tuple(
+ builder.create<IREE::VM::ConstI32ZeroOp>(loc),
+ builder.create<IREE::VM::ConstRefZeroOp>(
+ loc,
+ IREE::VM::RefType::get(builder.getType<IREE::HAL::BufferType>())));
+ } else if (isa<IREE::VM::RefType>(bufferOrSlot.getType())) {
+ // Direct buffer binding; pass 0 for table slot.
+ return std::make_tuple(builder.create<IREE::VM::ConstI32ZeroOp>(loc),
+ bufferOrSlot);
+ } else {
+ // Indirect binding table reference; pass null for the buffer.
+ return std::make_tuple(
+ castToImportType(bufferOrSlot, builder.getI32Type(), builder),
+ builder.create<IREE::VM::ConstRefZeroOp>(
+ loc,
+ IREE::VM::RefType::get(builder.getType<IREE::HAL::BufferType>())));
+ }
+}
+
// TODO(benvanik): import op handling of optional values.
// It'd be nice if the std::optional<Index>:$binding_capacity could be emitted
// as 0 when not present; today it'll be omitted entirely (as it's not in the
@@ -89,12 +114,15 @@
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
+ auto [targetBufferSlot, targetBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter);
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
- adaptor.getTargetBuffer(),
+ targetBuffer,
castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
rewriter),
castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
+ targetBufferSlot,
};
// Record the original pattern length then extend it to a 32 bit integer.
@@ -144,12 +172,57 @@
matchAndRewrite(IREE::HAL::CommandBufferUpdateBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
+ auto [targetBufferSlot, targetBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter);
SmallVector<Value, 8> callOperands = {
adaptor.getCommandBuffer(),
adaptor.getSourceBuffer(),
castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(),
rewriter),
- adaptor.getTargetBuffer(),
+ targetBuffer,
+ castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
+ rewriter),
+ castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
+ targetBufferSlot};
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
+class CommandBufferCopyBufferOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferCopyBufferOp> {
+public:
+ CommandBufferCopyBufferOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferCopyBufferOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ auto [sourceBufferSlot, sourceBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getSourceBuffer(), rewriter);
+ auto [targetBufferSlot, targetBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter);
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getCommandBuffer(),
+ sourceBufferSlot,
+ targetBufferSlot,
+ sourceBuffer,
+ castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(),
+ rewriter),
+ targetBuffer,
castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(),
rewriter),
castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter),
@@ -182,15 +255,6 @@
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
- Value nullBuffer;
- auto getNullBuffer = [&]() {
- if (!nullBuffer) {
- nullBuffer = rewriter.create<IREE::VM::ConstRefZeroOp>(
- op.getLoc(),
- IREE::VM::RefType::get(rewriter.getType<IREE::HAL::BufferType>()));
- }
- return nullBuffer;
- };
Value zeroI64;
auto getZeroI64 = [&]() {
if (!zeroI64) {
@@ -203,10 +267,12 @@
// %channel : !vm.ref<!hal.channel>,
// %op : i32,
// %param : i32,
+ // %send_buffer_slot : i32,
+ // %recv_buffer_slot : i32,
// %send_buffer : !vm.ref<!hal.buffer>,
+ // %recv_buffer : !vm.ref<!hal.buffer>,
// %send_offset : i64,
// %send_length : i64,
- // %recv_buffer : !vm.ref<!hal.buffer>,
// %recv_offset : i64,
// %recv_length : i64,
// %element_count : i64
@@ -222,25 +288,22 @@
rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc()));
}
- if (adaptor.getSendBuffer()) {
- callOperands.push_back(adaptor.getSendBuffer());
- callOperands.push_back(adaptor.getSendOffset());
- callOperands.push_back(adaptor.getSendLength());
- } else {
- callOperands.push_back(getNullBuffer());
- callOperands.push_back(getZeroI64());
- callOperands.push_back(getZeroI64());
- }
-
- if (adaptor.getRecvBuffer()) {
- callOperands.push_back(adaptor.getRecvBuffer());
- callOperands.push_back(adaptor.getRecvOffset());
- callOperands.push_back(adaptor.getRecvLength());
- } else {
- callOperands.push_back(getNullBuffer());
- callOperands.push_back(getZeroI64());
- callOperands.push_back(getZeroI64());
- }
+ auto [sendBufferSlot, sendBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getSendBuffer(), rewriter);
+ auto [recvBufferSlot, recvBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getRecvBuffer(), rewriter);
+ callOperands.push_back(sendBufferSlot);
+ callOperands.push_back(recvBufferSlot);
+ callOperands.push_back(sendBuffer);
+ callOperands.push_back(recvBuffer);
+ callOperands.push_back(adaptor.getSendOffset() ? adaptor.getSendOffset()
+ : getZeroI64());
+ callOperands.push_back(adaptor.getSendLength() ? adaptor.getSendLength()
+ : getZeroI64());
+ callOperands.push_back(adaptor.getRecvOffset() ? adaptor.getRecvOffset()
+ : getZeroI64());
+ callOperands.push_back(adaptor.getRecvLength() ? adaptor.getRecvLength()
+ : getZeroI64());
callOperands.push_back(castToImportType(adaptor.getElementCount(),
rewriter.getI64Type(), rewriter));
@@ -275,29 +338,6 @@
ConversionPatternRewriter &rewriter) const override {
auto importType = importOp.getFunctionType();
- // Memoize zeros/nulls ala IndexSet.
- // Since there are usually hundreds to thousands of these push ops and each
- // one can have 5-10 of these this saves us a tremendous amount of time
- // creating/verifying/pattern matching/folding/CSE'ing.
- // We could extend IndexSet into a ConstantSet that could use these custom
- // VM ops instead of just arith.constant in order to make this more
- // reusable.
- Value zero;
- auto getI32Zero = [&]() {
- if (!zero) {
- zero = rewriter.create<IREE::VM::ConstI32ZeroOp>(op.getLoc());
- }
- return zero;
- };
- Value null;
- auto getNull = [&]() {
- if (!null) {
- null = rewriter.create<IREE::VM::ConstRefZeroOp>(
- op.getLoc(),
- IREE::VM::RefType::get(rewriter.getType<IREE::HAL::BufferType>()));
- }
- return null;
- };
auto i32Type = rewriter.getI32Type();
auto i64Type = rewriter.getI64Type();
@@ -316,16 +356,10 @@
for (size_t i = 0; i < adaptor.getBindingOrdinals().size(); ++i) {
callOperands.push_back(
castToImportType(adaptor.getBindingOrdinals()[i], i32Type, rewriter));
- auto bindingBuffer = adaptor.getBindingBuffers()[i];
- if (llvm::isa<IREE::VM::RefType>(bindingBuffer.getType())) {
- // Buffer binding; pass 0 for table slot.
- callOperands.push_back(getI32Zero());
- callOperands.push_back(bindingBuffer);
- } else {
- // Binding table reference; pass null for the buffer.
- callOperands.push_back(bindingBuffer);
- callOperands.push_back(getNull());
- }
+ auto [bindingBufferSlot, bindingBuffer] = splitBufferSlot(
+ op.getLoc(), adaptor.getBindingBuffers()[i], rewriter);
+ callOperands.push_back(bindingBufferSlot);
+ callOperands.push_back(bindingBuffer);
callOperands.push_back(
castToImportType(adaptor.getBindingOffsets()[i], i64Type, rewriter));
callOperands.push_back(
@@ -343,6 +377,46 @@
mutable IREE::VM::ImportOp importOp;
};
+class CommandBufferDispatchIndirectOpConversion
+ : public OpConversionPattern<IREE::HAL::CommandBufferDispatchIndirectOp> {
+public:
+ CommandBufferDispatchIndirectOpConversion(MLIRContext *context,
+ SymbolTable &importSymbols,
+ TypeConverter &typeConverter,
+ StringRef importName)
+ : OpConversionPattern(typeConverter, context) {
+ importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName);
+ assert(importOp);
+ }
+
+ LogicalResult
+ matchAndRewrite(IREE::HAL::CommandBufferDispatchIndirectOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto importType = importOp.getFunctionType();
+ auto [workgroupsBufferSlot, workgroupsBuffer] =
+ splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter);
+ SmallVector<Value, 8> callOperands = {
+ adaptor.getCommandBuffer(),
+ adaptor.getExecutable(),
+ castToImportType(adaptor.getEntryPoint(), rewriter.getI32Type(),
+ rewriter),
+ workgroupsBufferSlot,
+ workgroupsBuffer,
+ castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(),
+ rewriter),
+ };
+ auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>(
+ op, SymbolRefAttr::get(importOp), importType.getResults(),
+ callOperands);
+ copyImportAttrs(importOp, callOp);
+ return success();
+ }
+
+private:
+ mutable IREE::VM::ImportOp importOp;
+};
+
} // namespace
void populateHALCommandBufferToVMPatterns(MLIRContext *context,
@@ -370,7 +444,7 @@
patterns.insert<CommandBufferUpdateBufferOpConversion>(
context, importSymbols, typeConverter,
"hal.command_buffer.update_buffer");
- patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCopyBufferOp>>(
+ patterns.insert<CommandBufferCopyBufferOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer");
patterns.insert<CommandBufferCollectiveOpConversion>(
context, importSymbols, typeConverter, "hal.command_buffer.collective");
@@ -383,10 +457,9 @@
"hal.command_buffer.push_descriptor_set");
patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchOp>>(
context, importSymbols, typeConverter, "hal.command_buffer.dispatch");
- patterns
- .insert<VMImportOpConversion<IREE::HAL::CommandBufferDispatchIndirectOp>>(
- context, importSymbols, typeConverter,
- "hal.command_buffer.dispatch.indirect");
+ patterns.insert<CommandBufferDispatchIndirectOpConversion>(
+ context, importSymbols, typeConverter,
+ "hal.command_buffer.dispatch.indirect");
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
index 7402f62..61ca923 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -52,9 +52,10 @@
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 1
// CHECK-DAG: %[[EXTEND:.+]] = vm.ext.i8.i32.u %arg2 : i32 -> i32
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32, i32) -> ()
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%arg2 : i8)
@@ -71,9 +72,10 @@
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 2
// CHECK-DAG: %[[EXTEND:.+]] = vm.ext.i16.i32.u %arg2 : i32 -> i32
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %[[EXTEND]], %[[PATTERN_LENGTH]])
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%arg2 : i16)
@@ -90,8 +92,9 @@
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4
- // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %arg2, %[[PATTERN_LENGTH]]) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, i64, i32, i32) -> ()
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %arg2, %[[PATTERN_LENGTH]])
hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
target(%arg1 : !hal.buffer)[%c100, %c200]
pattern(%arg2 : i32)
@@ -100,6 +103,25 @@
// -----
+// CHECK-LABEL: @command_buffer_fill_buffer_i32_indirect
+util.func public @command_buffer_fill_buffer_i32_indirect(
+ %arg0: !hal.command_buffer,
+ %arg1: index,
+ %arg2: i32
+) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %[[NULL_BUFFER]], %c100, %c200, %arg1, %arg2, %[[PATTERN_LENGTH]])
+ hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer>
+ target(%arg1 : index)[%c100, %c200]
+ pattern(%arg2 : i32)
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @command_buffer_update_buffer
// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32,
@@ -111,6 +133,7 @@
%device_buffer: !hal.buffer, %dst_offset: index,
%length: index
) {
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
// CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]]
// CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]]
// CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]]
@@ -118,7 +141,7 @@
// CHECK-SAME: (%[[CMD]],
// CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]],
// CHECK-SAME: %[[DEVICE_BUFFER]], %[[DST_OFFSET_I64]],
- // CHECK-SAME: %[[LENGTH_I64]])
+ // CHECK-SAME: %[[LENGTH_I64]], %[[UNUSED_SLOT]])
hal.command_buffer.update_buffer<%cmd : !hal.command_buffer>
source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset]
target(%device_buffer : !hal.buffer)[%dst_offset]
@@ -128,18 +151,69 @@
// -----
+// CHECK-LABEL: @command_buffer_update_buffer_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[DEVICE_BUFFER_SLOT:[a-z0-9]+]]: i32, %[[DST_OFFSET:[a-z0-9]+]]: i32,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: i32)
+util.func public @command_buffer_update_buffer_indirect(
+ %cmd: !hal.command_buffer,
+ %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index,
+ %device_buffer: index, %dst_offset: index,
+ %length: index
+ ) {
+ // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]]
+ // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]]
+ // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]]
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.update_buffer
+ // CHECK-SAME: (%[[CMD]],
+ // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]],
+ // CHECK-SAME: %[[NULL_BUFFER]], %[[DST_OFFSET_I64]],
+ // CHECK-SAME: %[[LENGTH_I64]], %[[DEVICE_BUFFER_SLOT]])
+ hal.command_buffer.update_buffer<%cmd : !hal.command_buffer>
+ source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset]
+ target(%device_buffer : index)[%dst_offset]
+ length(%length)
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @command_buffer_copy_buffer
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>, %[[BUFFER:.+]]: !vm.ref<!hal.buffer>)
util.func public @command_buffer_copy_buffer(
- %arg0: !hal.command_buffer,
- %arg1: !hal.buffer
+ %cmd: !hal.command_buffer,
+ %buffer: !hal.buffer
) {
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK: vm.call @hal.command_buffer.copy_buffer(%arg0, %arg1, %c100, %arg1, %c200, %c300) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.buffer>, i64, !vm.ref<!hal.buffer>, i64, i64) -> ()
- hal.command_buffer.copy_buffer<%arg0 : !hal.command_buffer>
- source(%arg1 : !hal.buffer)[%c100]
- target(%arg1 : !hal.buffer)[%c200]
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK: vm.call @hal.command_buffer.copy_buffer(%[[CMD]], %[[UNUSED_SLOT]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[BUFFER]], %c200, %c300)
+ hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer>
+ source(%buffer : !hal.buffer)[%c100]
+ target(%buffer : !hal.buffer)[%c200]
+ length(%c300)
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_copy_buffer_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>, %[[BUFFER_SLOT:.+]]: i32)
+util.func public @command_buffer_copy_buffer_indirect(
+ %cmd: !hal.command_buffer,
+ %buffer_slot: index
+) {
+ %c100 = arith.constant 100 : index
+ %c200 = arith.constant 200 : index
+ %c300 = arith.constant 300 : index
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.copy_buffer(%[[CMD]], %[[BUFFER_SLOT]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[NULL_BUFFER]], %c200, %c300)
+ hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer>
+ source(%buffer_slot : index)[%c100]
+ target(%buffer_slot : index)[%c200]
length(%c300)
util.return
}
@@ -159,16 +233,17 @@
%send_buffer: !hal.buffer, %recv_buffer: !hal.buffer,
%count: index) {
// CHECK-DAG: %[[OP_BITS:.+]] = vm.const.i32 590081
- // CHECK-DAG: %[[PARAM:.+]] = vm.const.i32.zero
+ // CHECK-DAG: %[[ZERO_I32:.+]] = vm.const.i32.zero
%c10 = arith.constant 10 : index
%c20 = arith.constant 20 : index
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
// CHECK-DAG: %[[COUNT_I64:.+]] = vm.ext.i32.i64.s %[[COUNT]]
// CHECK: vm.call @hal.command_buffer.collective
- // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[PARAM]]
- // CHECK-SAME: %[[SEND_BUFFER]], %c10, %c128,
- // CHECK-SAME: %[[RECV_BUFFER]], %c20, %c256,
+ // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[ZERO_I32]]
+ // CHECK-SAME: %[[ZERO_I32]], %[[ZERO_I32]],
+ // CHECK-SAME: %[[SEND_BUFFER]], %[[RECV_BUFFER]],
+ // CHECK-SAME: %c10, %c128, %c20, %c256,
// CHECK-SAME: %[[COUNT_I64]])
hal.command_buffer.collective<%cmd : !hal.command_buffer>
channel(%channel : !hal.channel)
@@ -193,15 +268,18 @@
%param: i32,
%send_buffer: !hal.buffer,
%count: index) {
- // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
// CHECK-DAG: %[[OP_BITS:.+]] = vm.const.i32 262150
%c10 = arith.constant 10 : index
%c128 = arith.constant 128 : index
// CHECK-DAG: %[[COUNT_I64:.+]] = vm.ext.i32.i64.s %[[COUNT]]
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK-DAG: %[[ZERO_I64:.+]] = vm.const.i64.zero
// CHECK: vm.call @hal.command_buffer.collective
// CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[PARAM]],
- // CHECK-SAME: %[[SEND_BUFFER]], %c10, %c128,
- // CHECK-SAME: %[[NULL_BUFFER]], %zero, %zero,
+ // CHECK-SAME: %[[UNUSED_SLOT]], %[[UNUSED_SLOT]],
+ // CHECK-SAME: %[[SEND_BUFFER]], %[[NULL_BUFFER]],
+ // CHECK-SAME: %c10, %c128, %[[ZERO_I64]], %[[ZERO_I64]],
// CHECK-SAME: %[[COUNT_I64]])
hal.command_buffer.collective<%cmd : !hal.command_buffer>
channel(%channel : !hal.channel)
@@ -215,10 +293,10 @@
// -----
// CHECK-LABEL: @command_buffer_push_descriptor_set
-// CHECK-SAME: %[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
-// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref<!hal.pipeline_layout>,
-// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
-// CHECK-SAME: %[[SLOT:.+]]: i32
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref<!hal.pipeline_layout>,
+// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>,
+// CHECK-SAME: %[[SLOT:.+]]: i32)
util.func public @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
@@ -250,18 +328,20 @@
// -----
// CHECK-LABEL: @command_buffer_dispatch
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>)
util.func public @command_buffer_dispatch(
- %arg0: !hal.command_buffer,
- %arg1: !hal.executable
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable
) {
// CHECK: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
%c200 = arith.constant 200 : index
%c300 = arith.constant 300 : index
- // CHECK: vm.call @hal.command_buffer.dispatch(%arg0, %arg1, %[[ORDINAL]], %c100, %c200, %c300) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, i32, i32, i32) -> ()
- hal.command_buffer.dispatch<%arg0 : !hal.command_buffer>
- target(%arg1 : !hal.executable)[%ordinal]
+ // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300)
+ hal.command_buffer.dispatch<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
workgroups([%c100, %c200, %c300])
util.return
}
@@ -269,17 +349,43 @@
// -----
// CHECK-LABEL: @command_buffer_dispatch_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
+// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref<!hal.buffer>)
util.func public @command_buffer_dispatch_indirect(
- %arg0: !hal.command_buffer,
- %arg1: !hal.executable,
- %arg2: !hal.buffer
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable,
+ %buffer: !hal.buffer
) {
- // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
%ordinal = arith.constant 123 : index
%c100 = arith.constant 100 : index
- // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%arg0, %arg1, %[[ORDINAL]], %arg2, %c100) : (!vm.ref<!hal.command_buffer>, !vm.ref<!hal.executable>, i32, !vm.ref<!hal.buffer>, i64) -> ()
- hal.command_buffer.dispatch.indirect<%arg0 : !hal.command_buffer>
- target(%arg1 : !hal.executable)[%ordinal]
- workgroups(%arg2 : !hal.buffer)[%c100]
+ // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero
+ // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100)
+ hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
+ workgroups(%buffer : !hal.buffer)[%c100]
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @command_buffer_dispatch_indirect_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>,
+// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref<!hal.executable>,
+// CHECK-SAME: %[[BUFFER_SLOT:.+]]: i32)
+util.func public @command_buffer_dispatch_indirect_indirect(
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable,
+ %buffer_slot: index
+) {
+ // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123
+ %ordinal = arith.constant 123 : index
+ %c100 = arith.constant 100 : index
+ // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref<!hal.buffer>
+ // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100)
+ hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
+ target(%executable : !hal.executable)[%ordinal]
+ workgroups(%buffer_slot : index)[%c100]
util.return
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 68a4c5f..d0abb71 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1291,7 +1291,7 @@
let arguments = (ins
HAL_CommandBuffer:$command_buffer,
- HAL_BufferType:$target_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer,
HAL_DeviceSize:$target_offset,
HAL_DeviceSize:$length,
HAL_FillPatternType:$pattern
@@ -1360,9 +1360,9 @@
let arguments = (ins
HAL_CommandBuffer:$command_buffer,
- HAL_BufferType:$source_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$source_buffer,
HAL_DeviceSize:$source_offset,
- HAL_BufferType:$target_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer,
HAL_DeviceSize:$target_offset,
HAL_DeviceSize:$length
);
@@ -1396,10 +1396,10 @@
Optional<I32>:$param,
// TODO(benvanik): change this to take descriptor set + binding instead.
// This would let us use indirect bindings.
- Optional<HAL_BufferType>:$send_buffer,
+ Optional<AnyTypeOf<[Index, HAL_BufferType]>>:$send_buffer,
Optional<HAL_DeviceSize>:$send_offset,
Optional<HAL_DeviceSize>:$send_length,
- Optional<HAL_BufferType>:$recv_buffer,
+ Optional<AnyTypeOf<[Index, HAL_BufferType]>>:$recv_buffer,
Optional<HAL_DeviceSize>:$recv_offset,
Optional<HAL_DeviceSize>:$recv_length
);
@@ -1529,7 +1529,7 @@
HAL_CommandBuffer:$command_buffer,
HAL_Executable:$executable,
HAL_Ordinal:$entry_point,
- HAL_BufferType:$workgroups_buffer,
+ AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer,
HAL_DeviceSize:$workgroups_offset
);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
index c3348d9..5598e39 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -101,9 +101,10 @@
// CHECK-LABEL: @command_buffer_copy_buffer
// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
-// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
-// CHECK-SAME: %[[SRC_OFFSET:.+]]: index, %[[DST_OFFSET:.+]]: index,
-// CHECK-SAME: %[[LENGTH:.+]]: index)
+// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
+// CHECK-SAME: %[[SRC_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DST_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index)
util.func public @command_buffer_copy_buffer(
%cmd: !hal.command_buffer,
%buffer: !hal.buffer,
@@ -124,11 +125,38 @@
// -----
+// CHECK-LABEL: @command_buffer_copy_buffer_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[BUFFER_SLOT:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[SRC_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DST_OFFSET:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index)
+util.func public @command_buffer_copy_buffer_indirect(
+ %cmd: !hal.command_buffer,
+ %buffer_slot: index,
+ %src_offset: index,
+ %dst_offset: index,
+ %length: index
+ ) {
+ // CHECK: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer>
+ // CHECK-SAME: source(%[[BUFFER_SLOT]] : index)[%[[SRC_OFFSET]]]
+ // CHECK-SAME: target(%[[BUFFER_SLOT]] : index)[%[[DST_OFFSET]]]
+ // CHECK-SAME: length(%[[LENGTH]])
+ hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer>
+ source(%buffer_slot : index)[%src_offset]
+ target(%buffer_slot : index)[%dst_offset]
+ length(%length)
+ util.return
+}
+
+// -----
+
// CHECK-LABEL: @command_buffer_collective
// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
// CHECK-SAME: %[[CHANNEL:.+]]: !hal.channel,
// CHECK-SAME: %[[PARAM:.+]]: i32,
-// CHECK-SAME: %[[SEND_BUFFER:.+]]: !hal.buffer, %[[RECV_BUFFER:.+]]: !hal.buffer,
+// CHECK-SAME: %[[SEND_BUFFER:[a-z0-9]+]]: !hal.buffer,
+// CHECK-SAME: %[[RECV_BUFFER:[a-z0-9]+]]: !hal.buffer,
// CHECK-SAME: %[[COUNT:.+]]: index)
util.func public @command_buffer_collective(
%cmd: !hal.command_buffer,
@@ -186,10 +214,10 @@
// -----
// CHECK-LABEL: @command_buffer_push_descriptor_set
-// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer,
-// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout,
-// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
-// CHECK-SAME: %[[SLOT:.+]]: index
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout,
+// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
+// CHECK-SAME: %[[SLOT:.+]]: index)
util.func public @command_buffer_push_descriptor_set(
%cmd: !hal.command_buffer,
%layout: !hal.pipeline_layout,
@@ -273,3 +301,33 @@
workgroups(%buffer : !hal.buffer)[%offset]
util.return
}
+
+// -----
+
+hal.executable @ex {
+ hal.executable.variant @backend target(<"backend", "format">) {
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [
+ #hal.descriptor_set.layout<0, bindings = [
+ #hal.descriptor_set.binding<0, storage_buffer>,
+ #hal.descriptor_set.binding<1, storage_buffer>
+ ]>
+ ]>)
+ }
+}
+
+// CHECK-LABEL: @command_buffer_dispatch_indirect_indirect
+// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer,
+// CHECK-SAME: %[[EXECUTABLE:[a-z0-9]+]]: !hal.executable, %[[ORDINAL:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[BUFFER_SLOT:[a-z0-9]+]]: index, %[[OFFSET:[a-z0-9]+]]: index)
+util.func public @command_buffer_dispatch_indirect_indirect(
+ %cmd: !hal.command_buffer,
+ %executable: !hal.executable, %ordinal: index,
+ %buffer_slot: index, %offset: index) {
+ // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer>
+ // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]]
+ // CHECK-SAME: workgroups(%[[BUFFER_SLOT]] : index)[%[[OFFSET]]]
+ hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer>
+ target(%executable: !hal.executable)[%ordinal]
+ workgroups(%buffer_slot : index)[%offset]
+ util.return
+}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
index de1ed18..0a923c9 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -231,28 +231,35 @@
)
// Fills the target buffer with the given repeating value.
+// NOTE: order slightly differs from op in order to get better arg alignment.
vm.import private @command_buffer.fill_buffer(
%command_buffer : !vm.ref<!hal.command_buffer>,
%target_buffer : !vm.ref<!hal.buffer>,
%target_offset : i64,
%length : i64,
+ %target_buffer_slot : i32,
%pattern : i32,
%pattern_length: i32
)
// Updates a device buffer with the captured contents of a host buffer.
+// NOTE: order slightly differs from op in order to get better arg alignment.
vm.import private @command_buffer.update_buffer(
%command_buffer : !vm.ref<!hal.command_buffer>,
%source_buffer : !vm.buffer,
%source_offset : i64,
%target_buffer : !vm.ref<!hal.buffer>,
%target_offset : i64,
- %length : i64
+ %length : i64,
+ %target_buffer_slot : i32
)
// Copies a range of one buffer to another.
+// NOTE: order slightly differs from op in order to get better arg alignment.
vm.import private @command_buffer.copy_buffer(
%command_buffer : !vm.ref<!hal.command_buffer>,
+ %source_buffer_slot : i32,
+ %target_buffer_slot : i32,
%source_buffer : !vm.ref<!hal.buffer>,
%source_offset : i64,
%target_buffer : !vm.ref<!hal.buffer>,
@@ -267,10 +274,12 @@
%channel : !vm.ref<!hal.channel>,
%op : i32,
%param : i32,
+ %send_buffer_slot : i32,
+ %recv_buffer_slot : i32,
%send_buffer : !vm.ref<!hal.buffer>,
+ %recv_buffer : !vm.ref<!hal.buffer>,
%send_offset : i64,
%send_length : i64,
- %recv_buffer : !vm.ref<!hal.buffer>,
%recv_offset : i64,
%recv_length : i64,
%element_count : i64
@@ -309,6 +318,7 @@
%command_buffer : !vm.ref<!hal.command_buffer>,
%executable : !vm.ref<!hal.executable>,
%entry_point : i32,
+ %workgroups_buffer_slot : i32,
%workgroups_buffer : !vm.ref<!hal.buffer>,
%workgroups_offset : i64
)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
index 3ba86d9..8c3c502 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -58,8 +58,7 @@
return success();
}
-Value castToImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter) {
+Value castToImportType(Value value, Type targetType, OpBuilder &builder) {
auto sourceType = value.getType();
if (sourceType == targetType)
return value;
@@ -70,36 +69,35 @@
if (llvm::isa<FloatType>(sourceType) && llvm::isa<IntegerType>(targetType) &&
sourceType.getIntOrFloatBitWidth() ==
targetType.getIntOrFloatBitWidth()) {
- return rewriter.create<mlir::arith::BitcastOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::BitcastOp>(value.getLoc(), targetType,
+ value);
} else if (sourceIsInteger &&
(targetType.isSignedInteger() || targetType.isSignlessInteger())) {
if (targetType.getIntOrFloatBitWidth() >
sourceType.getIntOrFloatBitWidth()) {
- return rewriter.create<mlir::arith::ExtSIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::ExtSIOp>(value.getLoc(), targetType,
+ value);
} else {
- return rewriter.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
+ value);
}
} else if (sourceIsInteger && targetType.isUnsignedInteger()) {
if (targetType.getIntOrFloatBitWidth() >
sourceType.getIntOrFloatBitWidth()) {
- return rewriter.create<mlir::arith::ExtUIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::ExtUIOp>(value.getLoc(), targetType,
+ value);
} else {
- return rewriter.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
- value);
+ return builder.create<mlir::arith::TruncIOp>(value.getLoc(), targetType,
+ value);
}
} else {
return value;
}
}
-Value castFromImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter) {
+Value castFromImportType(Value value, Type targetType, OpBuilder &builder) {
// Right now the to-import and from-import types are the same.
- return castToImportType(value, targetType, rewriter);
+ return castToImportType(value, targetType, builder);
}
void copyImportAttrs(IREE::VM::ImportOp importOp, Operation *callOp) {
@@ -118,15 +116,16 @@
}
}
-std::optional<SmallVector<Value>>
-rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType,
- ConversionPatternRewriter &rewriter) {
+std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc,
+ Attribute attrValue,
+ Type inputType,
+ OpBuilder &builder) {
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrValue)) {
// NOTE: we intentionally go to std.constant ops so that the standard
// conversions can do their job. If we want to remove the dependency
// from standard ops in the future we could instead go directly to
// one of the vm constant ops.
- auto constValue = rewriter.createOrFold<mlir::arith::ConstantOp>(
+ auto constValue = builder.create<mlir::arith::ConstantOp>(
loc, inputType,
IntegerAttr::get(inputType,
APInt(32, static_cast<int32_t>(intAttr.getInt()))));
@@ -136,7 +135,7 @@
SmallVector<Value> elementValues;
elementValues.reserve(elementsAttr.getNumElements());
for (auto intAttr : elementsAttr.getValues<Attribute>()) {
- elementValues.push_back(rewriter.createOrFold<mlir::arith::ConstantOp>(
+ elementValues.push_back(builder.create<mlir::arith::ConstantOp>(
loc, elementsAttr.getType().getElementType(),
cast<TypedAttr>(intAttr)));
}
@@ -146,7 +145,7 @@
SmallVector<Value> allValues;
for (auto elementAttr : arrayAttr) {
auto flattenedValues =
- rewriteAttrToOperands(loc, elementAttr, inputType, rewriter);
+ rewriteAttrToOperands(loc, elementAttr, inputType, builder);
if (!flattenedValues)
return std::nullopt;
allValues.append(flattenedValues->begin(), flattenedValues->end());
@@ -154,7 +153,7 @@
return allValues;
}
if (auto strAttr = llvm::dyn_cast<StringAttr>(attrValue)) {
- return {{rewriter.create<IREE::VM::RodataInlineOp>(loc, strAttr)}};
+ return {{builder.create<IREE::VM::RodataInlineOp>(loc, strAttr)}};
}
// This may be a custom dialect type. As we can't trivially access the storage
@@ -176,7 +175,7 @@
return;
auto elementType = tupleTypes[ordinal++];
auto flattenedValues =
- rewriteAttrToOperands(loc, elementAttr, elementType, rewriter);
+ rewriteAttrToOperands(loc, elementAttr, elementType, builder);
if (!flattenedValues) {
anyFailed = true;
return;
@@ -192,7 +191,7 @@
if (anyFailed)
return;
auto flattenedValues =
- rewriteAttrToOperands(loc, elementAttr, inputType, rewriter);
+ rewriteAttrToOperands(loc, elementAttr, inputType, builder);
if (!flattenedValues) {
anyFailed = true;
return;
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index 3e80906..b2f0a8f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -30,20 +30,19 @@
namespace detail {
size_t getSegmentSpanSize(Type spanType);
-std::optional<SmallVector<Value>>
-rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType,
- ConversionPatternRewriter &rewriter);
+std::optional<SmallVector<Value>> rewriteAttrToOperands(Location loc,
+ Attribute attrValue,
+ Type inputType,
+ OpBuilder &builder);
} // namespace detail
// Casts |value| to |targetType| ala static_cast for when the declared type
// differs from the type provided by the input dialect.
-Value castToImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter);
+Value castToImportType(Value value, Type targetType, OpBuilder &builder);
// Casts |value| to |targetType| ala static_cast for when the declared return
// type of an import does not match the required output type.
-Value castFromImportType(Value value, Type targetType,
- ConversionPatternRewriter &rewriter);
+Value castFromImportType(Value value, Type targetType, OpBuilder &builder);
// Copies known attributes from the |importOp| to the |callOp|.
// This allows for passes to quickly query the properties of the import such as
@@ -56,8 +55,7 @@
template <typename T, typename Adaptor = typename T::Adaptor>
std::optional<SmallVector<Value>>
rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp,
- const TypeConverter &typeConverter,
- ConversionPatternRewriter &rewriter) {
+ const TypeConverter &typeConverter, OpBuilder &builder) {
auto *operation = op.getOperation();
bool isOpVariadic = importOp.isVariadic();
OperationState state{
@@ -76,7 +74,7 @@
auto inputName = importOp.getFuncArgumentName(input.index());
if (auto attrValue = op->getAttr(inputName)) {
auto flattenedAttrs = detail::rewriteAttrToOperands(
- op.getLoc(), attrValue, inputType, rewriter);
+ op.getLoc(), attrValue, inputType, builder);
if (!flattenedAttrs)
return std::nullopt;
state.addOperands(*flattenedAttrs);
@@ -101,11 +99,11 @@
}
for (auto [newOperand, inputType] :
llvm::zip_equal(newOperands, inputTupleType.getTypes())) {
- state.addOperands(castToImportType(newOperand, inputType, rewriter));
+ state.addOperands(castToImportType(newOperand, inputType, builder));
}
} else {
for (auto &operand : newOperands) {
- state.addOperands(castToImportType(operand, inputType, rewriter));
+ state.addOperands(castToImportType(operand, inputType, builder));
}
}
@@ -121,16 +119,16 @@
"segment_sizes",
DenseIntElementsAttr::get(
VectorType::get({static_cast<int64_t>(segmentSizes.size())},
- rewriter.getIntegerType(16)),
+ builder.getIntegerType(16)),
segmentSizes));
state.addAttribute("segment_types",
- rewriter.getArrayAttr(llvm::map_to_vector(
+ builder.getArrayAttr(llvm::map_to_vector(
importType.getInputs(), [&](Type type) {
return cast<Attribute>(TypeAttr::get(type));
})));
}
- auto *callOp = rewriter.create(state);
+ auto *callOp = builder.create(state);
copyImportAttrs(importOp, callOp);
SmallVector<Value> results;
@@ -139,7 +137,7 @@
targetType = typeConverter.convertType(targetType);
if (!targetType)
return std::nullopt;
- results.push_back(castFromImportType(result, targetType, rewriter));
+ results.push_back(castFromImportType(result, targetType, builder));
}
return results;
}
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index bd5adee..b87a8cb 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -47,18 +47,18 @@
EXPORT_FN("channel.split", iree_hal_module_channel_split, riii, r)
EXPORT_FN("command_buffer.begin_debug_group", iree_hal_module_command_buffer_begin_debug_group, rr, v)
-EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriirIIrIII, v)
-EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, rrIrII, v)
+EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v)
+EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v)
EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r)
EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v)
-EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rrirI, v)
+EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirI, v)
EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
-EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIii, v)
+EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v)
EXPORT_FN("command_buffer.finalize", iree_hal_module_command_buffer_finalize, r, v)
EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v)
EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiirIID, v)
-EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrII, v)
+EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrIIi, v)
EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index 52b3fe3..c0db04b 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -759,72 +759,76 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, //
iree_hal_module_state_t, //
- rrIIii, v) {
+ rrIIiii, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
- iree_hal_buffer_t* target_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer));
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i2);
iree_device_size_t length = iree_hal_cast_device_size(args->i3);
- uint32_t pattern = (uint32_t)args->i4;
- uint32_t pattern_length = (uint32_t)args->i5;
+ uint32_t target_buffer_slot = (uint32_t)args->i4;
+ iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
+ target_buffer_slot, target_offset, length);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r1, &target_ref.buffer));
+ uint32_t pattern = (uint32_t)args->i5;
+ uint32_t pattern_length = (uint32_t)args->i6;
- iree_hal_buffer_ref_t target_ref =
- iree_hal_make_buffer_ref(target_buffer, target_offset, length);
return iree_hal_command_buffer_fill_buffer(command_buffer, target_ref,
&pattern, pattern_length);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_update_buffer, //
iree_hal_module_state_t, //
- rrIrII, v) {
+ rrIrIIi, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_vm_buffer_t* source_buffer = NULL;
IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer));
iree_host_size_t source_offset = iree_hal_cast_host_size(args->i2);
- iree_hal_buffer_t* target_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer));
iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4);
iree_device_size_t length = iree_hal_cast_device_size(args->i5);
+ uint32_t target_buffer_slot = (uint32_t)args->i6;
+ iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
+ target_buffer_slot, target_offset, length);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r3, &target_ref.buffer));
iree_const_byte_span_t source_span = iree_const_byte_span_empty();
IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro(
source_buffer, source_offset, (iree_host_size_t)length, 1, &source_span));
- iree_hal_buffer_ref_t target_ref =
- iree_hal_make_buffer_ref(target_buffer, target_offset, length);
return iree_hal_command_buffer_update_buffer(command_buffer, source_span.data,
/*source_offset=*/0, target_ref);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, //
iree_hal_module_state_t, //
- rrIrII, v) {
+ riirIrII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
- iree_hal_buffer_t* source_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &source_buffer));
- iree_device_size_t source_offset = iree_hal_cast_device_size(args->i2);
- iree_hal_buffer_t* target_buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer));
- iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4);
- iree_device_size_t length = iree_hal_cast_device_size(args->i5);
+ uint32_t source_buffer_slot = (uint32_t)args->i1;
+ uint32_t target_buffer_slot = (uint32_t)args->i2;
+ iree_device_size_t source_offset = iree_hal_cast_device_size(args->i4);
+ iree_device_size_t target_offset = iree_hal_cast_device_size(args->i6);
+ iree_device_size_t length = iree_hal_cast_device_size(args->i7);
+ iree_hal_buffer_ref_t source_ref = iree_hal_make_indirect_buffer_ref(
+ source_buffer_slot, source_offset, length);
+ iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref(
+ target_buffer_slot, target_offset, length);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r3, &source_ref.buffer));
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_check_deref_or_null(args->r5, &target_ref.buffer));
- iree_hal_buffer_ref_t source_ref =
- iree_hal_make_buffer_ref(source_buffer, source_offset, length);
- iree_hal_buffer_ref_t target_ref =
- iree_hal_make_buffer_ref(target_buffer, target_offset, length);
return iree_hal_command_buffer_copy_buffer(command_buffer, source_ref,
target_ref);
}
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_collective, //
iree_hal_module_state_t, //
- rriirIIrIII, v) {
+ rriiiirrIIIII, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
@@ -832,17 +836,19 @@
IREE_RETURN_IF_ERROR(iree_hal_channel_check_deref(args->r1, &channel));
iree_hal_collective_op_t op = {.packed = args->i2};
uint32_t param = args->i3;
- iree_hal_buffer_ref_t send_ref =
- iree_hal_make_buffer_ref(NULL, iree_hal_cast_device_size(args->i5),
- iree_hal_cast_device_size(args->i6));
+ uint32_t send_buffer_slot = (uint32_t)args->i4;
+ uint32_t recv_buffer_slot = (uint32_t)args->i5;
+ iree_hal_buffer_ref_t send_ref = iree_hal_make_indirect_buffer_ref(
+ send_buffer_slot, iree_hal_cast_device_size(args->i8),
+ iree_hal_cast_device_size(args->i9));
IREE_RETURN_IF_ERROR(
- iree_hal_buffer_check_deref_or_null(args->r4, &send_ref.buffer));
- iree_hal_buffer_ref_t recv_ref =
- iree_hal_make_buffer_ref(NULL, iree_hal_cast_device_size(args->i8),
- iree_hal_cast_device_size(args->i9));
+ iree_hal_buffer_check_deref_or_null(args->r6, &send_ref.buffer));
+ iree_hal_buffer_ref_t recv_ref = iree_hal_make_indirect_buffer_ref(
+ recv_buffer_slot, iree_hal_cast_device_size(args->i10),
+ iree_hal_cast_device_size(args->i11));
IREE_RETURN_IF_ERROR(
iree_hal_buffer_check_deref_or_null(args->r7, &recv_ref.buffer));
- iree_device_size_t element_count = iree_hal_cast_device_size(args->i10);
+ iree_device_size_t element_count = iree_hal_cast_device_size(args->i12);
return iree_hal_command_buffer_collective(command_buffer, channel, op, param,
send_ref, recv_ref, element_count);
@@ -919,20 +925,20 @@
IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, //
iree_hal_module_state_t, //
- rrirI, v) {
+ rriirI, v) {
iree_hal_command_buffer_t* command_buffer = NULL;
IREE_RETURN_IF_ERROR(
iree_hal_command_buffer_check_deref(args->r0, &command_buffer));
iree_hal_executable_t* executable = NULL;
IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable));
uint32_t entry_point = (uint32_t)args->i2;
- iree_hal_buffer_t* workgroups_buffer = NULL;
+ uint32_t workgroups_buffer_slot = (uint32_t)args->i3;
+ iree_device_size_t workgroups_offset = iree_hal_cast_device_size(args->i5);
+ iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_indirect_buffer_ref(
+ workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t));
IREE_RETURN_IF_ERROR(
- iree_hal_buffer_check_deref(args->r3, &workgroups_buffer));
- iree_device_size_t workgroups_offset = iree_hal_cast_device_size(args->i4);
+ iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer));
- iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_buffer_ref(
- workgroups_buffer, workgroups_offset, 3 * sizeof(uint32_t));
return iree_hal_command_buffer_dispatch_indirect(command_buffer, executable,
entry_point, workgroups_ref);
}
diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c
index 1d5744d..c89a8b7 100644
--- a/runtime/src/iree/vm/shims.c
+++ b/runtime/src/iree/vm/shims.c
@@ -45,7 +45,7 @@
IREE_VM_ABI_DEFINE_SHIM(riiIi, r);
IREE_VM_ABI_DEFINE_SHIM(rIiiI, r);
IREE_VM_ABI_DEFINE_SHIM(riIiirII, r);
-IREE_VM_ABI_DEFINE_SHIM(rriirIIrIII, v);
+IREE_VM_ABI_DEFINE_SHIM(rriiiirrIIIII, v);
IREE_VM_ABI_DEFINE_SHIM(rrrrCrD, r);
IREE_VM_ABI_DEFINE_SHIM(ririi, v);
IREE_VM_ABI_DEFINE_SHIM(rr, i);
@@ -60,10 +60,12 @@
IREE_VM_ABI_DEFINE_SHIM(rriiCID, v);
IREE_VM_ABI_DEFINE_SHIM(rriCiirIID, v);
IREE_VM_ABI_DEFINE_SHIM(rriiii, v);
-IREE_VM_ABI_DEFINE_SHIM(rrIIii, v);
+IREE_VM_ABI_DEFINE_SHIM(rrIIiii, v);
IREE_VM_ABI_DEFINE_SHIM(rrirCID, v);
IREE_VM_ABI_DEFINE_SHIM(rrirI, v);
-IREE_VM_ABI_DEFINE_SHIM(rrIrII, v);
+IREE_VM_ABI_DEFINE_SHIM(rriirI, v);
+IREE_VM_ABI_DEFINE_SHIM(rrIrIIi, v);
+IREE_VM_ABI_DEFINE_SHIM(riirIrII, v);
IREE_VM_ABI_DEFINE_SHIM(rrIii, v);
IREE_VM_ABI_DEFINE_SHIM(rrrIii, v);
IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r);
diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h
index 195daf8..1d5a3aa 100644
--- a/runtime/src/iree/vm/shims.h
+++ b/runtime/src/iree/vm/shims.h
@@ -355,18 +355,20 @@
int64_t i7;
});
-IREE_VM_ABI_FIXED_STRUCT(rriirIIrIII, {
+IREE_VM_ABI_FIXED_STRUCT(rriiiirrIIIII, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int32_t i2;
int32_t i3;
- iree_vm_ref_t r4;
- int64_t i5;
- int64_t i6;
+ int32_t i4;
+ int32_t i5;
+ iree_vm_ref_t r6;
iree_vm_ref_t r7;
int64_t i8;
int64_t i9;
int64_t i10;
+ int64_t i11;
+ int64_t i12;
});
IREE_VM_ABI_FIXED_STRUCT(rriiii, {
@@ -378,13 +380,14 @@
int32_t i5;
});
-IREE_VM_ABI_FIXED_STRUCT(rrIIii, {
+IREE_VM_ABI_FIXED_STRUCT(rrIIiii, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int64_t i2;
int64_t i3;
int32_t i4;
int32_t i5;
+ int32_t i6;
});
IREE_VM_ABI_FIXED_STRUCT(rrirI, {
@@ -395,13 +398,34 @@
int64_t i4;
});
-IREE_VM_ABI_FIXED_STRUCT(rrIrII, {
+IREE_VM_ABI_FIXED_STRUCT(rriirI, {
+ iree_vm_ref_t r0;
+ iree_vm_ref_t r1;
+ int32_t i2;
+ int32_t i3;
+ iree_vm_ref_t r4;
+ int64_t i5;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(rrIrIIi, {
iree_vm_ref_t r0;
iree_vm_ref_t r1;
int64_t i2;
iree_vm_ref_t r3;
int64_t i4;
int64_t i5;
+ int32_t i6;
+});
+
+IREE_VM_ABI_FIXED_STRUCT(riirIrII, {
+ iree_vm_ref_t r0;
+ int32_t i1;
+ int32_t i2;
+ iree_vm_ref_t r3;
+ int64_t i4;
+ iree_vm_ref_t r5;
+ int64_t i6;
+ int64_t i7;
});
IREE_VM_ABI_FIXED_STRUCT(rrIii, {
@@ -670,7 +694,7 @@
IREE_VM_ABI_DECLARE_SHIM(riiIi, r);
IREE_VM_ABI_DECLARE_SHIM(rIiiI, r);
IREE_VM_ABI_DECLARE_SHIM(riIiirII, r);
-IREE_VM_ABI_DECLARE_SHIM(rriirIIrIII, v);
+IREE_VM_ABI_DECLARE_SHIM(rriiiirrIIIII, v);
IREE_VM_ABI_DECLARE_SHIM(rrrrCrD, r);
IREE_VM_ABI_DECLARE_SHIM(ririi, v);
IREE_VM_ABI_DECLARE_SHIM(rr, i);
@@ -685,10 +709,12 @@
IREE_VM_ABI_DECLARE_SHIM(rriiCID, v);
IREE_VM_ABI_DECLARE_SHIM(rriCiirIID, v);
IREE_VM_ABI_DECLARE_SHIM(rriiii, v);
-IREE_VM_ABI_DECLARE_SHIM(rrIIii, v);
+IREE_VM_ABI_DECLARE_SHIM(rrIIiii, v);
IREE_VM_ABI_DECLARE_SHIM(rrirCID, v);
IREE_VM_ABI_DECLARE_SHIM(rrirI, v);
-IREE_VM_ABI_DECLARE_SHIM(rrIrII, v);
+IREE_VM_ABI_DECLARE_SHIM(rriirI, v);
+IREE_VM_ABI_DECLARE_SHIM(rrIrIIi, v);
+IREE_VM_ABI_DECLARE_SHIM(riirIrII, v);
IREE_VM_ABI_DECLARE_SHIM(rrIii, v);
IREE_VM_ABI_DECLARE_SHIM(rrrIii, v);
IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r);