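Adding `hal.command_buffer.update_buffer`: an op that records a copy of a range of a host buffer into a device buffer, together with its HAL-to-VM conversion, a canonicalizer that folds `hal.buffer.subspan` into the target offset, the VM import declaration, the runtime HAL module export, and tests.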
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 4de1867..6071614 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp
@@ -127,6 +127,43 @@ mutable IREE::VM::ImportOp importOp; }; +class CommandBufferUpdateBufferOpConversion + : public OpConversionPattern<IREE::HAL::CommandBufferUpdateBufferOp> { +public: + CommandBufferUpdateBufferOpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(typeConverter, context) { + importOp = importSymbols.lookup<IREE::VM::ImportOp>(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::CommandBufferUpdateBufferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto importType = importOp.getFunctionType(); + SmallVector<Value, 8> callOperands = { + adaptor.getCommandBuffer(), + adaptor.getSourceBuffer(), + castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(), + rewriter), + adaptor.getTargetBuffer(), + castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(), + rewriter), + castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter), + }; + auto callOp = rewriter.replaceOpWithNewOp<IREE::VM::CallOp>( + op, SymbolRefAttr::get(importOp), importType.getResults(), + callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + class CommandBufferCollectiveOpConversion : public OpConversionPattern<IREE::HAL::CommandBufferCollectiveOp> { public: @@ -329,6 +366,9 @@ "hal.command_buffer.execution_barrier"); patterns.insert<CommandBufferFillBufferOpConversion>( context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer"); + patterns.insert<CommandBufferUpdateBufferOpConversion>( + context, importSymbols, typeConverter, + "hal.command_buffer.update_buffer"); patterns.insert<VMImportOpConversion<IREE::HAL::CommandBufferCopyBufferOp>>( context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer"); patterns.insert<CommandBufferCollectiveOpConversion>(
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 711e860..9a3e26d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir
@@ -98,6 +98,34 @@ // ----- +// CHECK-LABEL: @command_buffer_update_buffer +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref<!hal.command_buffer>, +// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[DEVICE_BUFFER:[a-z0-9]+]]: !vm.ref<!hal.buffer>, %[[DST_OFFSET:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: i32) +util.func public @command_buffer_update_buffer( + %cmd: !hal.command_buffer, + %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index, + %device_buffer: !hal.buffer, %dst_offset: index, + %length: index + ) { + // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]] + // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]] + // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]] + // CHECK: vm.call @hal.command_buffer.update_buffer + // CHECK-SAME: (%[[CMD]], + // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]], + // CHECK-SAME: %[[DEVICE_BUFFER]], %[[DST_OFFSET_I64]], + // CHECK-SAME: %[[LENGTH_I64]]) + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset] + target(%device_buffer : !hal.buffer)[%dst_offset] + length(%length) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_copy_buffer util.func public @command_buffer_copy_buffer( %arg0: !hal.command_buffer,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 784ade3..d082847 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -132,10 +132,10 @@ bool needsUpdate = false; auto newSourceBuffer = op.getSourceBuffer(); auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset()); - if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>( + if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>( op.getSourceBuffer().getDefiningOp())) { newSourceBuffer = subspanOp.getSourceBuffer(); - newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>( + newSourceOffset = rewriter.createOrFold<arith::AddIOp>( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getSourceOffset()); needsUpdate = true; @@ -220,10 +220,10 @@ bool needsUpdate = false; auto newTargetBuffer = op.getTargetBuffer(); auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset()); - if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>( + if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>( op.getTargetBuffer().getDefiningOp())) { newTargetBuffer = subspanOp.getSourceBuffer(); - newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>( + newTargetOffset = rewriter.createOrFold<arith::AddIOp>( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getTargetOffset()); needsUpdate = true; @@ -248,6 +248,46 @@ namespace { +/// Folds hal.buffer.subspans into buffer update offsets. 
+struct FoldCommandBufferUpdateBufferSubspans + : public OpRewritePattern<CommandBufferUpdateBufferOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CommandBufferUpdateBufferOp op, + PatternRewriter &rewriter) const override { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + bool needsUpdate = false; + auto newTargetBuffer = op.getTargetBuffer(); + auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset()); + if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>( + op.getTargetBuffer().getDefiningOp())) { + newTargetBuffer = subspanOp.getSourceBuffer(); + newTargetOffset = rewriter.createOrFold<arith::AddIOp>( + subspanOp.getLoc(), subspanOp.getSourceOffset(), + op.getTargetOffset()); + needsUpdate = true; + } + rewriter.restoreInsertionPoint(ip); + if (!needsUpdate) + return failure(); + rewriter.modifyOpInPlace(op, [&]() { + op.getTargetBufferMutable().assign(newTargetBuffer); + op.getTargetOffsetMutable().assign(newTargetOffset); + }); + return success(); + } +}; + +} // namespace + +void CommandBufferUpdateBufferOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.insert<FoldCommandBufferUpdateBufferSubspans>(context); +} + +namespace { + /// Folds hal.buffer.subspans into buffer copy offsets. 
struct FoldCommandBufferCopyBufferSubspans : public OpRewritePattern<CommandBufferCopyBufferOp> { @@ -260,20 +300,20 @@ bool needsUpdate = false; auto newSourceBuffer = op.getSourceBuffer(); auto newSourceOffset = llvm::cast<Value>(op.getSourceOffset()); - if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>( + if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>( op.getSourceBuffer().getDefiningOp())) { newSourceBuffer = subspanOp.getSourceBuffer(); - newSourceOffset = rewriter.createOrFold<mlir::arith::AddIOp>( + newSourceOffset = rewriter.createOrFold<arith::AddIOp>( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getSourceOffset()); needsUpdate = true; } auto newTargetBuffer = op.getTargetBuffer(); auto newTargetOffset = llvm::cast<Value>(op.getTargetOffset()); - if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>( + if (auto subspanOp = dyn_cast_or_null<IREE::HAL::BufferSubspanOp>( op.getTargetBuffer().getDefiningOp())) { newTargetBuffer = subspanOp.getSourceBuffer(); - newTargetOffset = rewriter.createOrFold<mlir::arith::AddIOp>( + newTargetOffset = rewriter.createOrFold<arith::AddIOp>( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getTargetOffset()); needsUpdate = true; @@ -317,10 +357,10 @@ auto *definingOp = bindingBuffers[i].getDefiningOp(); if (!definingOp) continue; - if (auto subspanOp = dyn_cast<BufferSubspanOp>(definingOp)) { + if (auto subspanOp = dyn_cast<IREE::HAL::BufferSubspanOp>(definingOp)) { needsUpdate = true; bindingBuffers[i] = subspanOp.getSourceBuffer(); - bindingOffsets[i] = rewriter.createOrFold<mlir::arith::AddIOp>( + bindingOffsets[i] = rewriter.createOrFold<arith::AddIOp>( subspanOp.getLoc(), subspanOp.getSourceOffset(), bindingOffsets[i]); } }
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index cc4de30..b8c0bab 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -991,6 +991,32 @@ } //===----------------------------------------------------------------------===// +// hal.command_buffer.update_buffer +//===----------------------------------------------------------------------===// + +IREE::Util::SubrangeOperand +CommandBufferUpdateBufferOp::getSubrangeOperand(unsigned operandIndex) { + if (operandIndex == 1) { + return IREE::Util::SubrangeOperand{getSourceBuffer(), getSourceSize(), + getSourceOffset(), getLength()}; + } else { + assert(false && "only source is a subrange"); + return {}; + } +} + +void CommandBufferUpdateBufferOp::setSubrangeOperand( + unsigned operandIndex, IREE::Util::SubrangeOperand operand) { + if (operandIndex == 1) { + getSourceBufferMutable().assign(operand.resource); + getSourceSizeMutable().assign(operand.resourceSize); + getSourceOffsetMutable().assign(operand.offset); + } else { + assert(false && "only source is a subrange"); + } +} + +//===----------------------------------------------------------------------===// // hal.command_buffer.push_descriptor_set //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 1889f59..74d0b56 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1306,7 +1306,49 @@ let hasCanonicalizer = 1; } -// TODO(benvanik): update buffer op. +def HAL_CommandBufferUpdateBufferOp : HAL_Op<"command_buffer.update_buffer", [ + // TODO(benvanik): figure out the right way to model host effects - this is + // a host read but a device write; if we make it just MemRead then it gets + // DCEd because it has no result. For now we report both to keep analysis + // appeased even if incorrect. + MemoryEffects<[MemRead, MemWrite]>, + Util_SizeAwareOp, + DeclareOpInterfaceMethods<Util_SubrangeOperandOpInterface>, +]> { + let summary = [{command buffer buffer update recording operation}]; + let description = [{ + Copies a range of a host buffer into a device buffer. The host buffer + contents will be captured at the time of the call and embedded in the + command buffer. + }]; + + let arguments = (ins + HAL_CommandBuffer:$command_buffer, + Util_BufferType:$source_buffer, + Util_Size:$source_size, + Util_Size:$source_offset, + AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer, + HAL_DeviceSize:$target_offset, + HAL_DeviceSize:$length + ); + + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `source` `(` $source_buffer `:` type($source_buffer) `{` $source_size `}` `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + attr-dict-with-keyword + }]; + + let extraClassDeclaration = [{ + Value getOperandSize(unsigned idx) { return idx == 1 ? getSourceSize() : Value{}; } + Value getResultSize(unsigned idx) { return {}; } + }]; + + let hasCanonicalizer = 1; +} def HAL_CommandBufferCopyBufferOp : HAL_Op<"command_buffer.copy_buffer"> { let summary = [{command buffer buffer copy recording operation}];
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir index 5ced86a..a61ea4f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir
@@ -42,6 +42,35 @@ // ----- +// CHECK-LABEL: @fold_buffer_subspans_into_update_buffer +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !util.buffer, %[[SOURCE_BUFFER_SIZE:.+]]: index, +// CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer +util.func public @fold_buffer_subspans_into_update_buffer( + %cmd: !hal.command_buffer, + %source_buffer: !util.buffer, %source_buffer_size: index, + %target_buffer: !hal.buffer + ) { + %c0 = arith.constant 0 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + %c100000 = arith.constant 100000 : index + %c262144 = arith.constant 262144 : index + %source_subspan = util.buffer.subspan %source_buffer[%c4096] : !util.buffer{%source_buffer_size} -> !util.buffer{%c262144} + %target_subspan = hal.buffer.subspan<%target_buffer : !hal.buffer>[%c8192, %c262144] : !hal.buffer + // CHECK: hal.command_buffer.update_buffer + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + // CHECK-SAME: source(%[[SOURCE_BUFFER]] : !util.buffer{%[[SOURCE_BUFFER_SIZE]]})[%c4096] + source(%source_subspan : !util.buffer{%c262144})[%c0] + // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%c108192] + target(%target_subspan : !hal.buffer)[%c100000] + // CHECK-SAME: length(%c8192) + length(%c8192) + util.return +} + +// ----- + // CHECK-LABEL: @fold_buffer_subspan_into_copy_buffer // CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer, // CHECK-SAME: %[[BASE_BUFFER:.+]]: !hal.buffer
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index 766b39a..dc16d45 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir
@@ -73,6 +73,30 @@ // ----- +// CHECK-LABEL: @command_buffer_update_buffer +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !util.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: index, %[[SRC_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[DEVICE_BUFFER:[a-z0-9]+]]: !hal.buffer, %[[DST_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index) +util.func public @command_buffer_update_buffer( + %cmd: !hal.command_buffer, + %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index, + %device_buffer: !hal.buffer, %dst_offset: index, + %length: index + ) { + // CHECK: hal.command_buffer.update_buffer<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: source(%[[HOST_BUFFER]] : !util.buffer{%[[HOST_BUFFER_SIZE]]})[%[[SRC_OFFSET]]] + // CHECK-SAME: target(%[[DEVICE_BUFFER]] : !hal.buffer)[%[[DST_OFFSET]]] + // CHECK-SAME: length(%[[LENGTH]]) + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset] + target(%device_buffer : !hal.buffer)[%dst_offset] + length(%length) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_copy_buffer // CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, // CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 1d21fb3..6731942 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -239,6 +239,16 @@ %pattern_length: i32 ) +// Updates a device buffer with the captured contents of a host buffer. +vm.import private @command_buffer.update_buffer( + %command_buffer : !vm.ref<!hal.command_buffer>, + %source_buffer : !vm.buffer, + %source_offset : i64, + %target_buffer : !vm.ref<!hal.buffer>, + %target_offset : i64, + %length : i64 +) + // Copies a range of one buffer to another. vm.import private @command_buffer.copy_buffer( %command_buffer : !vm.ref<!hal.command_buffer>,
diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index 13f9d09..b808785 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl
@@ -58,6 +58,7 @@ EXPORT_FN("command_buffer.finalize", iree_hal_module_command_buffer_finalize, r, v) EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v) EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiirIID, v) +EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrII, v) EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 1b9f8df..4b84671 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c
@@ -774,6 +774,30 @@ &pattern, pattern_length); } +IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_update_buffer, // + iree_hal_module_state_t, // + rrIrII, v) { + iree_hal_command_buffer_t* command_buffer = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); + iree_vm_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer)); + iree_host_size_t source_offset = iree_hal_cast_host_size(args->i2); + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer)); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4); + iree_device_size_t length = iree_hal_cast_device_size(args->i5); + + iree_const_byte_span_t source_span = iree_const_byte_span_empty(); + IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro( + source_buffer, source_offset, (iree_host_size_t)length, 1, &source_span)); + + iree_hal_buffer_ref_t target_ref = + iree_hal_make_buffer_ref(target_buffer, target_offset, length); + return iree_hal_command_buffer_update_buffer(command_buffer, source_span.data, + /*source_offset=*/0, target_ref); +} + IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, // iree_hal_module_state_t, // rrIrII, v) {