Merge pull request #7934 from google/benvanik-buffer-hooks
Adding iree_hal_allocator_t::deallocate_buffer.
diff --git a/bindings/python/iree/runtime/hal.h b/bindings/python/iree/runtime/hal.h
index 634c741..011e9d6 100644
--- a/bindings/python/iree/runtime/hal.h
+++ b/bindings/python/iree/runtime/hal.h
@@ -102,10 +102,10 @@
IREE_HAL_ELEMENT_TYPE_NONE, element_size * 8);
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
- CheckApiStatus(
- iree_hal_buffer_view_create(raw_ptr(), shape.s.data(), shape.s.size(),
- element_type, encoding_type, &bv),
- "Error creating buffer view");
+ CheckApiStatus(iree_hal_buffer_view_create(
+ raw_ptr(), shape.s.data(), shape.s.size(), element_type,
+ encoding_type, iree_allocator_system(), &bv),
+ "Error creating buffer view");
return HalBufferView::CreateRetained(bv);
}
};
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index 60e5f9a..d3d9dfa 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -247,10 +247,10 @@
std::vector<int> dims(py_view.ndim);
std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin());
iree_hal_buffer_view_t* buffer_view;
- CheckApiStatus(
- iree_hal_buffer_view_create(raw_buffer, dims.data(), dims.size(),
- element_type, encoding_type, &buffer_view),
- "Error allocating buffer_view");
+ CheckApiStatus(iree_hal_buffer_view_create(
+ raw_buffer, dims.data(), dims.size(), element_type,
+ encoding_type, iree_allocator_system(), &buffer_view),
+ "Error allocating buffer_view");
iree_hal_buffer_release(raw_buffer);
iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view);
CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &buffer_view_ref),
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c
index 19b817f..093cd21 100644
--- a/experimental/rocm/rocm_allocator.c
+++ b/experimental/rocm/rocm_allocator.c
@@ -107,6 +107,19 @@
return compatibility;
}
+static void iree_hal_rocm_buffer_free(iree_hal_rocm_context_wrapper_t* context,
+ iree_hal_memory_type_t memory_type,
+ hipDeviceptr_t device_ptr,
+ void* host_ptr) {  // Frees a ROCm allocation; |memory_type| selects which of the two pointers is released. Returns void, so failures are not reported to the caller.
+ if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
+ // Device local: released with hipFree on |device_ptr|; |host_ptr| is not touched.
+ ROCM_IGNORE_ERROR(context->syms, hipFree(device_ptr));
+ } else {
+ // Host local: released with hipHostFree on |host_ptr|; |device_ptr| is not touched.
+ ROCM_IGNORE_ERROR(context->syms, hipHostFree(host_ptr));
+ }
+}
+
static iree_status_t iree_hal_rocm_allocator_allocate_buffer(
iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size,
@@ -148,37 +161,22 @@
}
if (iree_status_is_ok(status)) {
- IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc(
- &allocator->statistics, memory_type, allocation_size));
status = iree_hal_rocm_buffer_wrap(
(iree_hal_allocator_t*)allocator, memory_type,
IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size,
/*byte_offset=*/0,
/*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer);
}
- if (!iree_status_is_ok(status)) {
- iree_hal_rocm_allocator_free(base_allocator, memory_type, device_ptr,
- host_ptr, allocation_size);
+ if (iree_status_is_ok(status)) {
+ IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc(
+ &allocator->statistics, memory_type, allocation_size));
+ } else {
+ iree_hal_rocm_buffer_free(allocator->context, memory_type, device_ptr,
+ host_ptr);
}
return status;
}
-void iree_hal_rocm_allocator_free(iree_hal_allocator_t* base_allocator,
- iree_hal_memory_type_t memory_type,
- hipDeviceptr_t device_ptr, void* host_ptr,
- iree_device_size_t allocation_size) {
- iree_hal_rocm_allocator_t* allocator =
- iree_hal_rocm_allocator_cast(base_allocator);
- if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
- ROCM_IGNORE_ERROR(allocator->context->syms, hipFree(device_ptr));
- } else {
- // Host local.
- ROCM_IGNORE_ERROR(allocator->context->syms, hipHostFree(host_ptr));
- }
- IREE_STATISTICS(iree_hal_allocator_statistics_record_free(
- &allocator->statistics, memory_type, allocation_size));
-}
-
static iree_status_t iree_hal_rocm_allocator_wrap_buffer(
iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
iree_hal_memory_access_t allowed_access,
@@ -188,6 +186,23 @@
"wrapping of external buffers not supported");
}
+static void iree_hal_rocm_allocator_deallocate_buffer(
+ iree_hal_allocator_t* base_allocator, iree_hal_buffer_t* base_buffer) {  // deallocate_buffer vtable entry: frees the ROCm memory behind |base_buffer|, records the free, then destroys the buffer.
+ iree_hal_rocm_allocator_t* allocator =
+ iree_hal_rocm_allocator_cast(base_allocator);
+
+ iree_hal_memory_type_t memory_type = iree_hal_buffer_memory_type(base_buffer);
+ iree_hal_rocm_buffer_free(allocator->context, memory_type,
+ iree_hal_rocm_buffer_device_pointer(base_buffer),
+ iree_hal_rocm_buffer_host_pointer(base_buffer));  // Only the pointer matching |memory_type| is actually freed.
+
+ IREE_STATISTICS(iree_hal_allocator_statistics_record_free(
+ &allocator->statistics, memory_type,
+ iree_hal_buffer_allocation_size(base_buffer)));  // Counterpart of the record_alloc made when allocate_buffer succeeds.
+
+ iree_hal_buffer_destroy(base_buffer);  // Last step: every read of |base_buffer| happens above, before the buffer object is destroyed.
+}
+
static const iree_hal_allocator_vtable_t iree_hal_rocm_allocator_vtable = {
.destroy = iree_hal_rocm_allocator_destroy,
.host_allocator = iree_hal_rocm_allocator_host_allocator,
@@ -196,4 +211,5 @@
iree_hal_rocm_allocator_query_buffer_compatibility,
.allocate_buffer = iree_hal_rocm_allocator_allocate_buffer,
.wrap_buffer = iree_hal_rocm_allocator_wrap_buffer,
+ .deallocate_buffer = iree_hal_rocm_allocator_deallocate_buffer,
};
diff --git a/experimental/rocm/rocm_allocator.h b/experimental/rocm/rocm_allocator.h
index daa8d2a..a2a89ea 100644
--- a/experimental/rocm/rocm_allocator.h
+++ b/experimental/rocm/rocm_allocator.h
@@ -21,12 +21,6 @@
iree_hal_rocm_context_wrapper_t* context,
iree_hal_allocator_t** out_allocator);
-// Free an allocation represent by the given device or host pointer.
-void iree_hal_rocm_allocator_free(iree_hal_allocator_t* allocator,
- iree_hal_memory_type_t memory_type,
- hipDeviceptr_t device_ptr, void* host_ptr,
- iree_device_size_t allocation_size);
-
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/experimental/rocm/rocm_buffer.c b/experimental/rocm/rocm_buffer.c
index 6c1f161..b39a40e 100644
--- a/experimental/rocm/rocm_buffer.c
+++ b/experimental/rocm/rocm_buffer.c
@@ -10,7 +10,6 @@
#include <stdint.h>
#include <string.h>
-#include "experimental/rocm/rocm_allocator.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
@@ -38,21 +37,16 @@
IREE_ASSERT_ARGUMENT(out_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator =
+ iree_hal_allocator_host_allocator(allocator);
iree_hal_rocm_buffer_t* buffer = NULL;
iree_status_t status =
- iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
- sizeof(*buffer), (void**)&buffer);
+ iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer);
if (iree_status_is_ok(status)) {
- iree_hal_resource_initialize(&iree_hal_rocm_buffer_vtable,
- &buffer->base.resource);
- buffer->base.allocator = allocator;
- buffer->base.allocated_buffer = &buffer->base;
- buffer->base.allocation_size = allocation_size;
- buffer->base.byte_offset = byte_offset;
- buffer->base.byte_length = byte_length;
- buffer->base.memory_type = memory_type;
- buffer->base.allowed_access = allowed_access;
- buffer->base.allowed_usage = allowed_usage;
+ iree_hal_buffer_initialize(host_allocator, allocator, &buffer->base,
+ allocation_size, byte_offset, byte_length,
+ memory_type, allowed_access, allowed_usage,
+ &iree_hal_rocm_buffer_vtable, &buffer->base);
buffer->host_ptr = host_ptr;
buffer->device_ptr = device_ptr;
*out_buffer = &buffer->base;
@@ -64,15 +58,9 @@
static void iree_hal_rocm_buffer_destroy(iree_hal_buffer_t* base_buffer) {
iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer);
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+ iree_allocator_t host_allocator = base_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
-
- iree_hal_rocm_allocator_free(buffer->base.allocator, buffer->base.memory_type,
- buffer->device_ptr, buffer->host_ptr,
- buffer->base.allocation_size);
iree_allocator_free(host_allocator, buffer);
-
IREE_TRACE_ZONE_END(z0);
}
@@ -129,6 +117,11 @@
return buffer->device_ptr;
}
+void* iree_hal_rocm_buffer_host_pointer(iree_hal_buffer_t* base_buffer) {  // Accessor for the host pointer stored on the ROCm buffer.
+ iree_hal_rocm_buffer_t* buffer = iree_hal_rocm_buffer_cast(base_buffer);
+ return buffer->host_ptr;  // The |host_ptr| value given to iree_hal_rocm_buffer_wrap, returned unmodified.
+}
+
static const iree_hal_buffer_vtable_t iree_hal_rocm_buffer_vtable = {
.destroy = iree_hal_rocm_buffer_destroy,
.map_range = iree_hal_rocm_buffer_map_range,
diff --git a/experimental/rocm/rocm_buffer.h b/experimental/rocm/rocm_buffer.h
index 85cf294..c87be80 100644
--- a/experimental/rocm/rocm_buffer.h
+++ b/experimental/rocm/rocm_buffer.h
@@ -15,7 +15,7 @@
extern "C" {
#endif // __cplusplus
-// Wraps a rocm allocation in an iree_hal_buffer_t.
+// Wraps a ROCm allocation in an iree_hal_buffer_t.
iree_status_t iree_hal_rocm_buffer_wrap(
iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
iree_hal_memory_access_t allowed_access,
@@ -23,11 +23,14 @@
iree_device_size_t byte_offset, iree_device_size_t byte_length,
hipDeviceptr_t device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer);
-// Returns the rocm base pointer for the given |buffer|.
+// Returns the ROCm base pointer for the given |buffer|.
// This is the entire allocated_buffer and must be offset by the buffer
// byte_offset and byte_length when used.
hipDeviceptr_t iree_hal_rocm_buffer_device_pointer(iree_hal_buffer_t* buffer);
+// Returns the ROCm host pointer for the given |buffer|, if available.
+void* iree_hal_rocm_buffer_host_pointer(iree_hal_buffer_t* buffer);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
index 59ae1ba..5498957 100644
--- a/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
+++ b/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertBufferOps.cpp
@@ -90,8 +90,6 @@
OwningRewritePatternList &patterns) {
patterns.insert<VMImportOpConversion<IREE::HAL::BufferAssertOp>>(
context, importSymbols, typeConverter, "hal.buffer.assert");
- patterns.insert<VMImportOpConversion<IREE::HAL::BufferAllocatorOp>>(
- context, importSymbols, typeConverter, "hal.buffer.allocator");
patterns.insert<VMImportOpConversion<IREE::HAL::BufferSubspanOp>>(
context, importSymbols, typeConverter, "hal.buffer.subspan");
patterns.insert<VMImportOpConversion<IREE::HAL::BufferLengthOp>>(
diff --git a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
index 63f49f3..3ff6ec6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp
@@ -49,40 +49,6 @@
}
//===----------------------------------------------------------------------===//
-// hal.buffer.*
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-/// Skips a hal.buffer.allocator accessor when the buffer view was created in
-/// the same scope and we know the origin buffer.
-struct SkipBufferAllocatorOp : public OpRewritePattern<BufferAllocatorOp> {
- using OpRewritePattern<BufferAllocatorOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(BufferAllocatorOp op,
- PatternRewriter &rewriter) const override {
- if (auto allocateOp = dyn_cast_or_null<AllocatorAllocateOp>(
- op.buffer().getDefiningOp())) {
- rewriter.replaceOp(op, allocateOp.allocator());
- return success();
- } else if (auto subspanOp = dyn_cast_or_null<BufferSubspanOp>(
- op.buffer().getDefiningOp())) {
- rewriter.replaceOpWithNewOp<BufferAllocatorOp>(op, op.result().getType(),
- subspanOp.source_buffer());
- return success();
- }
- return failure();
- }
-};
-
-} // namespace
-
-void BufferAllocatorOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<SkipBufferAllocatorOp>(context);
-}
-
-//===----------------------------------------------------------------------===//
// hal.buffer_view.*
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index 8daad1b..11eb3c6 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -331,15 +331,6 @@
Value AllocatorTryMapOp::getResultSize(unsigned idx) { return length(); }
//===----------------------------------------------------------------------===//
-// hal.buffer.allocator
-//===----------------------------------------------------------------------===//
-
-void BufferAllocatorOp::getAsmResultNames(
- function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(result(), "allocator");
-}
-
-//===----------------------------------------------------------------------===//
// hal.buffer.subspan
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/HAL/IR/HALOps.td b/iree/compiler/Dialect/HAL/IR/HALOps.td
index 66f2158..0c2bf8b 100644
--- a/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -321,30 +321,6 @@
// (such as when we create it ourselves earlier on) or we've already asserted.
}
-def HAL_BufferAllocatorOp : HAL_PureOp<"buffer.allocator", [
- DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- ]> {
- let summary = [{buffer allocator accessor operation}];
- let description = [{
- Returns the allocator this buffer was allocated from.
- }];
-
- let arguments = (ins
- HAL_BufferType:$buffer
- );
- let results = (outs
- HAL_Allocator:$result
- );
-
- let assemblyFormat = [{
- `<` $buffer `:` type($buffer) `>`
- `:` type($result)
- attr-dict-with-keyword
- }];
-
- let hasCanonicalizer = 1;
-}
-
def HAL_BufferSubspanOp : HAL_PureOp<"buffer.subspan", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<Util_SizeAwareOp>,
diff --git a/iree/compiler/Dialect/HAL/IR/test/BUILD b/iree/compiler/Dialect/HAL/IR/test/BUILD
index be92ceb..553f143 100644
--- a/iree/compiler/Dialect/HAL/IR/test/BUILD
+++ b/iree/compiler/Dialect/HAL/IR/test/BUILD
@@ -19,7 +19,6 @@
[
"allocator_ops.mlir",
"attributes.mlir",
- "buffer_folding.mlir",
"buffer_ops.mlir",
"buffer_view_folding.mlir",
"buffer_view_ops.mlir",
diff --git a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
index 51ce062..f538016 100644
--- a/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/IR/test/CMakeLists.txt
@@ -16,7 +16,6 @@
SRCS
"allocator_ops.mlir"
"attributes.mlir"
- "buffer_folding.mlir"
"buffer_ops.mlir"
"buffer_view_folding.mlir"
"buffer_view_ops.mlir"
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
deleted file mode 100644
index ecd59ee..0000000
--- a/iree/compiler/Dialect/HAL/IR/test/buffer_folding.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: iree-opt -split-input-file -canonicalize %s | iree-opt -split-input-file | IreeFileCheck %s
-
-// CHECK-LABEL: @skip_buffer_allocator
-// CHECK-SAME: (%[[ALLOCATOR:.+]]: !hal.allocator)
-func @skip_buffer_allocator(%allocator: !hal.allocator) -> !hal.allocator {
- %sz = arith.constant 4 : index
- %buffer = hal.allocator.allocate<%allocator : !hal.allocator>
- type("HostVisible|HostCoherent")
- usage(Transfer) : !hal.buffer{%sz}
- %1 = hal.buffer.allocator<%buffer : !hal.buffer> : !hal.allocator
- // CHECK: return %[[ALLOCATOR]]
- return %1 : !hal.allocator
-}
-
-// -----
-
-// CHECK-LABEL: @skip_subspan_buffer_allocator
-// CHECK-SAME: (%[[ALLOCATOR:.+]]: !hal.allocator)
-func @skip_subspan_buffer_allocator(%allocator: !hal.allocator) -> !hal.allocator {
- %c0 = arith.constant 0 : index
- %c184 = arith.constant 184 : index
- %c384 = arith.constant 384 : index
- %source_buffer = hal.allocator.allocate<%allocator : !hal.allocator>
- type("HostVisible|HostCoherent")
- usage(Transfer) : !hal.buffer{%c384}
- %span_buffer = hal.buffer.subspan<%source_buffer : !hal.buffer>[%c0, %c184] : !hal.buffer
- %1 = hal.buffer.allocator<%span_buffer : !hal.buffer> : !hal.allocator
- // CHECK: return %[[ALLOCATOR]]
- return %1 : !hal.allocator
-}
diff --git a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
index 2888d16..229097f 100644
--- a/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
+++ b/iree/compiler/Dialect/HAL/IR/test/buffer_ops.mlir
@@ -1,16 +1,5 @@
-// Tests printing and parsing of hal.buffer ops.
-
// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s
-// CHECK-LABEL: @buffer_allocator
-func @buffer_allocator(%arg0: !hal.buffer) -> !hal.allocator {
- // CHECK: %allocator = hal.buffer.allocator<%arg0 : !hal.buffer> : !hal.allocator
- %allocator = hal.buffer.allocator<%arg0 : !hal.buffer> : !hal.allocator
- return %allocator : !hal.allocator
-}
-
-// -----
-
// CHECK-LABEL: @buffer_subspan
func @buffer_subspan(%arg0: !hal.buffer) -> !hal.buffer {
// CHECK-DAG: %[[OFFSET:.+]] = arith.constant 100
diff --git a/iree/compiler/Dialect/HAL/hal.imports.mlir b/iree/compiler/Dialect/HAL/hal.imports.mlir
index 568b409..115e018 100644
--- a/iree/compiler/Dialect/HAL/hal.imports.mlir
+++ b/iree/compiler/Dialect/HAL/hal.imports.mlir
@@ -67,12 +67,6 @@
%buffer_usage : i32
)
-// Returns the allocator the buffer was allocated with.
-vm.import @buffer.allocator(
- %buffer : !vm.ref<!hal.buffer>
-) -> !vm.ref<!hal.allocator>
-attributes {nosideeffects}
-
// Returns a reference to a subspan of the buffer.
vm.import @buffer.subspan(
%source_buffer : !vm.ref<!hal.buffer>,
diff --git a/iree/hal/allocator.c b/iree/hal/allocator.c
index a8688a7..1f92b36 100644
--- a/iree/hal/allocator.c
+++ b/iree/hal/allocator.c
@@ -104,6 +104,15 @@
return status;
}
+IREE_API_EXPORT void iree_hal_allocator_deallocate_buffer(
+ iree_hal_allocator_t* allocator, iree_hal_buffer_t* buffer) {  // Hands |buffer| back to |allocator| for deallocation.
+ IREE_ASSERT_ARGUMENT(allocator);
+ IREE_ASSERT_ARGUMENT(buffer);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ _VTABLE_DISPATCH(allocator, deallocate_buffer)(allocator, buffer);  // Forwards to the allocator implementation's deallocate_buffer entry.
+ IREE_TRACE_ZONE_END(z0);
+}
+
IREE_API_EXPORT iree_status_t iree_hal_allocator_statistics_fprint(
FILE* file, iree_hal_allocator_t* allocator) {
#if IREE_STATISTICS_ENABLE
@@ -111,7 +120,8 @@
iree_hal_allocator_query_statistics(allocator, &statistics);
iree_string_builder_t builder;
- iree_string_builder_initialize(iree_allocator_system(), &builder);
+ iree_string_builder_initialize(iree_hal_allocator_host_allocator(allocator),
+ &builder);
// TODO(benvanik): query identifier for the allocator so we can denote which
// device is being reported.
@@ -126,6 +136,7 @@
fprintf(file, "%.*s", (int)iree_string_builder_size(&builder),
iree_string_builder_buffer(&builder));
}
+
iree_string_builder_deinitialize(&builder);
return status;
#else
diff --git a/iree/hal/allocator.h b/iree/hal/allocator.h
index 03b87fc..9d68e31 100644
--- a/iree/hal/allocator.h
+++ b/iree/hal/allocator.h
@@ -215,11 +215,17 @@
iree_hal_memory_access_t allowed_access,
iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data,
iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer);
+
+ void(IREE_API_PTR* deallocate_buffer)(iree_hal_allocator_t* allocator,
+ iree_hal_buffer_t* buffer);
} iree_hal_allocator_vtable_t;
IREE_API_EXPORT void iree_hal_allocator_destroy(
iree_hal_allocator_t* allocator);
+IREE_API_EXPORT void iree_hal_allocator_deallocate_buffer(
+ iree_hal_allocator_t* allocator, iree_hal_buffer_t* buffer);  // Routes deallocation of |buffer| to |allocator|'s deallocate_buffer vtable entry.
+
#if IREE_STATISTICS_ENABLE
// Records a buffer allocation to |statistics|.
diff --git a/iree/hal/allocator_heap.c b/iree/hal/allocator_heap.c
index dc5d4cc..7954299 100644
--- a/iree/hal/allocator_heap.c
+++ b/iree/hal/allocator_heap.c
@@ -180,6 +180,13 @@
data_allocator, out_buffer);
}
+static void iree_hal_heap_allocator_deallocate_buffer(
+ iree_hal_allocator_t* base_allocator, iree_hal_buffer_t* base_buffer) {  // deallocate_buffer vtable entry for the heap allocator; |base_allocator| is unused here.
+ // We don't do any pooling yet, so the buffer is destroyed immediately.
+ // TODO(benvanik): move stats tracking here.
+ iree_hal_buffer_destroy(base_buffer);
+}
+
static const iree_hal_allocator_vtable_t iree_hal_heap_allocator_vtable = {
.destroy = iree_hal_heap_allocator_destroy,
.host_allocator = iree_hal_heap_allocator_host_allocator,
@@ -188,4 +195,5 @@
iree_hal_heap_allocator_query_buffer_compatibility,
.allocate_buffer = iree_hal_heap_allocator_allocate_buffer,
.wrap_buffer = iree_hal_heap_allocator_wrap_buffer,
+ .deallocate_buffer = iree_hal_heap_allocator_deallocate_buffer,
};
diff --git a/iree/hal/buffer.c b/iree/hal/buffer.c
index ca15374..b83d809 100644
--- a/iree/hal/buffer.c
+++ b/iree/hal/buffer.c
@@ -76,29 +76,24 @@
static const iree_hal_buffer_vtable_t iree_hal_subspan_buffer_vtable;
-static iree_status_t iree_hal_subspan_buffer_create(
+IREE_API_EXPORT iree_status_t iree_hal_subspan_buffer_create(
iree_hal_buffer_t* allocated_buffer, iree_device_size_t byte_offset,
- iree_device_size_t byte_length, iree_hal_buffer_t** out_buffer) {
+ iree_device_size_t byte_length, iree_hal_allocator_t* device_allocator,
+ iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) {
IREE_ASSERT_ARGUMENT(allocated_buffer);
IREE_ASSERT_ARGUMENT(out_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_buffer_t* buffer = NULL;
- iree_status_t status = iree_allocator_malloc(
- iree_hal_allocator_host_allocator(allocated_buffer->allocator),
- sizeof(*buffer), (void**)&buffer);
+ iree_status_t status =
+ iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer);
if (iree_status_is_ok(status)) {
- iree_hal_resource_initialize(&iree_hal_subspan_buffer_vtable,
- &buffer->resource);
- buffer->allocator = allocated_buffer->allocator;
- buffer->allocated_buffer = allocated_buffer;
- iree_hal_buffer_retain(buffer->allocated_buffer);
- buffer->allocation_size = allocated_buffer->allocation_size;
- buffer->byte_offset = byte_offset;
- buffer->byte_length = byte_length;
- buffer->memory_type = allocated_buffer->memory_type;
- buffer->allowed_access = allocated_buffer->allowed_access;
- buffer->allowed_usage = allocated_buffer->allowed_usage;
+ iree_hal_buffer_initialize(
+ host_allocator, device_allocator, allocated_buffer,
+ allocated_buffer->allocation_size, byte_offset, byte_length,
+ allocated_buffer->memory_type, allocated_buffer->allowed_access,
+ allocated_buffer->allowed_usage, &iree_hal_subspan_buffer_vtable,
+ buffer);
*out_buffer = buffer;
}
@@ -107,8 +102,7 @@
}
static void iree_hal_subspan_buffer_destroy(iree_hal_buffer_t* base_buffer) {
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+ iree_allocator_t host_allocator = base_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_buffer_release(base_buffer->allocated_buffer);
@@ -161,7 +155,58 @@
// iree_hal_buffer_t
//===----------------------------------------------------------------------===//
-IREE_HAL_API_RETAIN_RELEASE(buffer);
+IREE_API_EXPORT void iree_hal_buffer_initialize(
+ iree_allocator_t host_allocator, iree_hal_allocator_t* device_allocator,
+ iree_hal_buffer_t* allocated_buffer, iree_device_size_t allocation_size,
+ iree_device_size_t byte_offset, iree_device_size_t byte_length,
+ iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage,
+ const iree_hal_buffer_vtable_t* vtable, iree_hal_buffer_t* buffer) {  // Fills in the base fields of an already-allocated |buffer|; performs no allocation itself.
+ iree_hal_resource_initialize(vtable, &buffer->resource);
+ buffer->host_allocator = host_allocator;  // Buffer implementations read this back in destroy to free the buffer struct.
+ buffer->device_allocator = device_allocator;  // May be NULL; iree_hal_buffer_release then destroys the buffer directly.
+ buffer->allocated_buffer = allocated_buffer;
+ buffer->allocation_size = allocation_size;
+ buffer->byte_offset = byte_offset;
+ buffer->byte_length = byte_length;
+ buffer->memory_type = memory_type;
+ buffer->allowed_access = allowed_access;
+ buffer->allowed_usage = allowed_usage;
+
+ // Retain the base allocated buffer if it's unique from the buffer we are
+ // initializing; a buffer that is its own allocated_buffer is not retained.
+ if (allocated_buffer != buffer) {
+ iree_hal_buffer_retain(buffer->allocated_buffer);
+ }
+}
+
+IREE_API_EXPORT void iree_hal_buffer_destroy(iree_hal_buffer_t* buffer) {  // Destroys |buffer| without consulting its reference count.
+ if (IREE_LIKELY(buffer)) {  // A NULL |buffer| is ignored.
+ IREE_HAL_VTABLE_DISPATCH(buffer, iree_hal_buffer, destroy)
+ (buffer);  // Calls the destroy entry of the buffer's own vtable.
+ }
+}
+
+IREE_API_EXPORT void iree_hal_buffer_retain(iree_hal_buffer_t* buffer) {  // Adds one reference to |buffer|.
+ if (IREE_LIKELY(buffer)) {  // A NULL |buffer| is ignored.
+ iree_atomic_ref_count_inc(&((iree_hal_resource_t*)(buffer))->ref_count);  // The count lives in the resource header the buffer is cast to.
+ }
+}
+
+IREE_API_EXPORT void iree_hal_buffer_release(iree_hal_buffer_t* buffer) {  // Drops one reference to |buffer|; a NULL |buffer| is ignored.
+ if (IREE_LIKELY(buffer) &&
+ iree_atomic_ref_count_dec(&((iree_hal_resource_t*)(buffer))->ref_count) ==
+ 1) {  // NOTE(review): presumably dec returns the pre-decrement count, so 1 means this was the last reference - confirm.
+ // If the buffer comes from an allocator then we route back the destruction
+ // request to that. It may decide to keep the buffer alive in a pool or
+ // do some allocator-specific cleanup.
+ if (buffer->device_allocator) {
+ iree_hal_allocator_deallocate_buffer(buffer->device_allocator, buffer);
+ } else {  // No owning device allocator, e.g. a subspan created with device_allocator=NULL.
+ iree_hal_buffer_destroy(buffer);
+ }
+ }
+}
IREE_API_EXPORT iree_status_t iree_hal_buffer_validate_memory_type(
iree_hal_memory_type_t actual_memory_type,
@@ -387,13 +432,8 @@
}
return iree_hal_subspan_buffer_create(buffer, byte_offset, byte_length,
- out_buffer);
-}
-
-IREE_API_EXPORT iree_hal_allocator_t* iree_hal_buffer_allocator(
- const iree_hal_buffer_t* buffer) {
- IREE_ASSERT_ARGUMENT(buffer);
- return buffer->allocator;
+ /*device_allocator=*/NULL,
+ buffer->host_allocator, out_buffer);
}
IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_allocated_buffer(
diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h
index 1c62385..326efa9 100644
--- a/iree/hal/buffer.h
+++ b/iree/hal/buffer.h
@@ -234,14 +234,14 @@
// it will never be used in a way that requires coherency may occupy address
// space reservations or memory mapping that would otherwise not be needed.
//
-// As buffers may sometimes not be accessible from the host the base Buffer type
+// As buffers may sometimes not be accessible from the host the base buffer type
// does not allow for direct void* access and instead buffers must be either
// manipulated using utility functions (such as ReadData or WriteData) or by
-// mapping them into a host-accessible address space via MapMemory. Buffer must
-// be unmapped before any command may use it.
+// mapping them into a host-accessible address space via MapMemory. Buffers must
+// be unmapped before any command may use them.
//
-// Buffers may map (roughly) 1:1 with an allocation either from the host heap or
-// a device. iree_hal_buffer_subspan can be used to reference subspans of
+// Buffers may equate (roughly) 1:1 with an allocation either from the host heap
+// or a device. iree_hal_buffer_subspan can be used to reference subspans of
// buffers like std::span - though unlike std::span the returned buffer holds
// a reference to the parent buffer.
typedef struct iree_hal_buffer_t iree_hal_buffer_t;
@@ -294,10 +294,6 @@
// Releases the given |buffer| from the caller.
IREE_API_EXPORT void iree_hal_buffer_release(iree_hal_buffer_t* buffer);
-// Returns the allocator this buffer was allocated from.
-IREE_API_EXPORT iree_hal_allocator_t* iree_hal_buffer_allocator(
- const iree_hal_buffer_t* buffer);
-
// Returns a pointer to the buffer containing the actual allocation.
// The buffer represents a span of the allocated bytes defined by byte_offset
// and byte_length. If the provided buffer *is* the allocated buffer then the
@@ -458,6 +454,18 @@
iree_device_size_t byte_length, iree_byte_span_t* out_span);
//===----------------------------------------------------------------------===//
+// iree_hal_subspan_buffer_t
+//===----------------------------------------------------------------------===//
+
+// Creates a buffer referencing a subspan of some base allocation.
+// Optionally |device_allocator| can be provided if this subspan references
+// managed buffers that need deallocation callbacks.
+IREE_API_EXPORT iree_status_t iree_hal_subspan_buffer_create(
+ iree_hal_buffer_t* allocated_buffer, iree_device_size_t byte_offset,
+ iree_device_size_t byte_length, iree_hal_allocator_t* device_allocator,
+ iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer);
+
+//===----------------------------------------------------------------------===//
// iree_hal_heap_buffer_t
//===----------------------------------------------------------------------===//
@@ -507,7 +515,8 @@
struct iree_hal_buffer_t {
iree_hal_resource_t resource;
- iree_hal_allocator_t* allocator;
+ iree_allocator_t host_allocator;
+ iree_hal_allocator_t* device_allocator;
iree_hal_buffer_t* allocated_buffer;
iree_device_size_t allocation_size;
@@ -519,6 +528,14 @@
iree_hal_buffer_usage_t allowed_usage;
};
+IREE_API_EXPORT void iree_hal_buffer_initialize(
+ iree_allocator_t host_allocator, iree_hal_allocator_t* device_allocator,
+ iree_hal_buffer_t* allocated_buffer, iree_device_size_t allocation_size,
+ iree_device_size_t byte_offset, iree_device_size_t byte_length,
+ iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access,
+ iree_hal_buffer_usage_t allowed_usage,
+ const iree_hal_buffer_vtable_t* vtable, iree_hal_buffer_t* buffer);  // Initializes the base fields of |buffer| in place; retains |allocated_buffer| when it differs from |buffer|.
+
IREE_API_EXPORT void iree_hal_buffer_destroy(iree_hal_buffer_t* buffer);
#ifdef __cplusplus
diff --git a/iree/hal/buffer_heap.c b/iree/hal/buffer_heap.c
index db5a229..6649a8c 100644
--- a/iree/hal/buffer_heap.c
+++ b/iree/hal/buffer_heap.c
@@ -95,16 +95,10 @@
host_allocator, &buffer, &data);
if (iree_status_is_ok(status)) {
- iree_hal_resource_initialize(&iree_hal_heap_buffer_vtable,
- &buffer->base.resource);
- buffer->base.allocator = allocator;
- buffer->base.allocated_buffer = &buffer->base;
- buffer->base.allocation_size = allocation_size;
- buffer->base.byte_offset = 0;
- buffer->base.byte_length = allocation_size;
- buffer->base.memory_type = memory_type;
- buffer->base.allowed_access = allowed_access;
- buffer->base.allowed_usage = allowed_usage;
+ iree_hal_buffer_initialize(host_allocator, allocator, &buffer->base,
+ allocation_size, 0, allocation_size, memory_type,
+ allowed_access, allowed_usage,
+ &iree_hal_heap_buffer_vtable, &buffer->base);
buffer->data = data;
buffer->data_allocator =
same_allocator ? iree_allocator_null() : data_allocator;
@@ -136,21 +130,16 @@
IREE_ASSERT_ARGUMENT(out_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator =
+ iree_hal_allocator_host_allocator(allocator);
iree_hal_heap_buffer_t* buffer = NULL;
iree_status_t status =
- iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
- sizeof(*buffer), (void**)&buffer);
+ iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer);
if (iree_status_is_ok(status)) {
- iree_hal_resource_initialize(&iree_hal_heap_buffer_vtable,
- &buffer->base.resource);
- buffer->base.allocator = allocator;
- buffer->base.allocated_buffer = &buffer->base;
- buffer->base.allocation_size = allocation_size;
- buffer->base.byte_offset = 0;
- buffer->base.byte_length = data.data_length;
- buffer->base.memory_type = memory_type;
- buffer->base.allowed_access = allowed_access;
- buffer->base.allowed_usage = allowed_usage;
+ iree_hal_buffer_initialize(host_allocator, allocator, &buffer->base,
+ allocation_size, 0, data.data_length,
+ memory_type, allowed_access, allowed_usage,
+ &iree_hal_heap_buffer_vtable, &buffer->base);
buffer->data = data;
buffer->data_allocator = data_allocator;
*out_buffer = &buffer->base;
@@ -162,8 +151,7 @@
static void iree_hal_heap_buffer_destroy(iree_hal_buffer_t* base_buffer) {
iree_hal_heap_buffer_t* buffer = (iree_hal_heap_buffer_t*)base_buffer;
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+ iree_allocator_t host_allocator = base_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
IREE_STATISTICS({
diff --git a/iree/hal/buffer_view.c b/iree/hal/buffer_view.c
index 8905d84..3c32045 100644
--- a/iree/hal/buffer_view.c
+++ b/iree/hal/buffer_view.c
@@ -17,6 +17,7 @@
struct iree_hal_buffer_view_t {
iree_atomic_ref_count_t ref_count;
+ iree_allocator_t host_allocator;
iree_hal_buffer_t* buffer;
iree_hal_element_type_t element_type;
iree_hal_encoding_type_t encoding_type;
@@ -28,7 +29,7 @@
IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create(
iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
- iree_hal_encoding_type_t encoding_type,
+ iree_hal_encoding_type_t encoding_type, iree_allocator_t host_allocator,
iree_hal_buffer_view_t** out_buffer_view) {
IREE_ASSERT_ARGUMENT(buffer);
IREE_ASSERT_ARGUMENT(out_buffer_view);
@@ -41,9 +42,6 @@
IREE_TRACE_ZONE_BEGIN(z0);
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(buffer));
-
// Allocate and initialize the iree_hal_buffer_view_t struct.
// Note that we have the dynamically-sized shape dimensions on the end.
iree_hal_buffer_view_t* buffer_view = NULL;
@@ -53,6 +51,7 @@
(void**)&buffer_view);
if (iree_status_is_ok(status)) {
iree_atomic_ref_count_init(&buffer_view->ref_count);
+ buffer_view->host_allocator = host_allocator;
buffer_view->buffer = buffer;
iree_hal_buffer_retain(buffer_view->buffer);
buffer_view->element_type = element_type;
@@ -88,10 +87,11 @@
IREE_API_EXPORT void iree_hal_buffer_view_destroy(
iree_hal_buffer_view_t* buffer_view) {
- iree_allocator_t host_allocator = iree_hal_allocator_host_allocator(
- iree_hal_buffer_allocator(buffer_view->buffer));
+ iree_allocator_t host_allocator = buffer_view->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_buffer_release(buffer_view->buffer);
iree_allocator_free(host_allocator, buffer_view);
+ IREE_TRACE_ZONE_END(z0);
}
IREE_API_EXPORT iree_status_t iree_hal_buffer_view_allocate_buffer(
@@ -115,9 +115,9 @@
}
if (iree_status_is_ok(status)) {
- status =
- iree_hal_buffer_view_create(buffer, shape, shape_rank, element_type,
- encoding_type, out_buffer_view);
+ status = iree_hal_buffer_view_create(
+ buffer, shape, shape_rank, element_type, encoding_type,
+ iree_hal_allocator_host_allocator(allocator), out_buffer_view);
}
iree_hal_buffer_release(buffer);
@@ -177,9 +177,9 @@
data_allocator, &buffer);
if (iree_status_is_ok(status)) {
- status =
- iree_hal_buffer_view_create(buffer, shape, shape_rank, element_type,
- encoding_type, out_buffer_view);
+ status = iree_hal_buffer_view_create(
+ buffer, shape, shape_rank, element_type, encoding_type,
+ iree_hal_allocator_host_allocator(allocator), out_buffer_view);
}
iree_hal_buffer_release(buffer);
@@ -592,8 +592,9 @@
}
// Wrap and pass ownership of the buffer to the buffer view.
- status = iree_hal_buffer_view_create(buffer, shape, shape_rank, element_type,
- encoding_type, out_buffer_view);
+ status = iree_hal_buffer_view_create(
+ buffer, shape, shape_rank, element_type, encoding_type,
+ iree_hal_allocator_host_allocator(buffer_allocator), out_buffer_view);
iree_hal_buffer_release(buffer);
return status;
}
@@ -737,8 +738,7 @@
// Allocate scratch space to format in to.
// We should be streaming.
- iree_allocator_t host_allocator = iree_hal_allocator_host_allocator(
- iree_hal_buffer_allocator(iree_hal_buffer_view_buffer(buffer_view)));
+ iree_allocator_t host_allocator = buffer_view->host_allocator;
iree_host_size_t buffer_capacity = buffer_length + 1; // NUL
char* buffer = NULL;
status =
diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h
index 426a7df..9a0c623 100644
--- a/iree/hal/buffer_view.h
+++ b/iree/hal/buffer_view.h
@@ -201,7 +201,7 @@
IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create(
iree_hal_buffer_t* buffer, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
- iree_hal_encoding_type_t encoding_type,
+ iree_hal_encoding_type_t encoding_type, iree_allocator_t host_allocator,
iree_hal_buffer_view_t** out_buffer_view);
// Allocates a buffer from |allocator| and wraps it in a buffer view.
diff --git a/iree/hal/cts/allocator_test.h b/iree/hal/cts/allocator_test.h
index 8377489..506d6ee 100644
--- a/iree/hal/cts/allocator_test.h
+++ b/iree/hal/cts/allocator_test.h
@@ -87,7 +87,6 @@
IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
device_allocator_, memory_type, buffer_usage, kAllocationSize, &buffer));
- EXPECT_EQ(device_allocator_, iree_hal_buffer_allocator(buffer));
// At a mimimum, the requested memory type should be respected.
// Additional bits may be optionally set depending on the allocator.
EXPECT_TRUE(
diff --git a/iree/hal/cts/buffer_mapping_test.h b/iree/hal/cts/buffer_mapping_test.h
index 87a1a41..106b24d 100644
--- a/iree/hal/cts/buffer_mapping_test.h
+++ b/iree/hal/cts/buffer_mapping_test.h
@@ -48,7 +48,6 @@
IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
device_allocator_, memory_type, buffer_usage, kAllocationSize, &buffer));
- EXPECT_EQ(device_allocator_, iree_hal_buffer_allocator(buffer));
EXPECT_TRUE(
iree_all_bits_set(iree_hal_buffer_memory_type(buffer), memory_type));
EXPECT_TRUE(
diff --git a/iree/hal/cuda/cuda_allocator.c b/iree/hal/cuda/cuda_allocator.c
index 1784798..7470282 100644
--- a/iree/hal/cuda/cuda_allocator.c
+++ b/iree/hal/cuda/cuda_allocator.c
@@ -136,6 +136,18 @@
return compatibility;
}
+static void iree_hal_cuda_buffer_free(iree_hal_cuda_context_wrapper_t* context,
+ iree_hal_memory_type_t memory_type,
+ CUdeviceptr device_ptr, void* host_ptr) {
+ if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
+ // Device local.
+ CUDA_IGNORE_ERROR(context->syms, cuMemFree(device_ptr));
+ } else {
+ // Host local.
+ CUDA_IGNORE_ERROR(context->syms, cuMemFreeHost(host_ptr));
+ }
+}
+
static iree_status_t iree_hal_cuda_allocator_allocate_buffer(
iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size,
@@ -199,37 +211,22 @@
}
}
if (iree_status_is_ok(status)) {
- IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc(
- &allocator->statistics, memory_type, allocation_size));
status = iree_hal_cuda_buffer_wrap(
(iree_hal_allocator_t*)allocator, memory_type,
IREE_HAL_MEMORY_ACCESS_ALL, allowed_usage, allocation_size,
/*byte_offset=*/0,
/*byte_length=*/allocation_size, device_ptr, host_ptr, out_buffer);
}
- if (!iree_status_is_ok(status)) {
- iree_hal_cuda_allocator_free(base_allocator, memory_type, device_ptr,
- host_ptr, allocation_size);
+ if (iree_status_is_ok(status)) {
+ IREE_STATISTICS(iree_hal_allocator_statistics_record_alloc(
+ &allocator->statistics, memory_type, allocation_size));
+ } else {
+ iree_hal_cuda_buffer_free(allocator->context, memory_type, device_ptr,
+ host_ptr);
}
return status;
}
-void iree_hal_cuda_allocator_free(iree_hal_allocator_t* base_allocator,
- iree_hal_memory_type_t memory_type,
- CUdeviceptr device_ptr, void* host_ptr,
- iree_device_size_t allocation_size) {
- iree_hal_cuda_allocator_t* allocator =
- iree_hal_cuda_allocator_cast(base_allocator);
- if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) {
- CUDA_IGNORE_ERROR(allocator->context->syms, cuMemFree(device_ptr));
- } else {
- // Host local.
- CUDA_IGNORE_ERROR(allocator->context->syms, cuMemFreeHost(host_ptr));
- }
- IREE_STATISTICS(iree_hal_allocator_statistics_record_free(
- &allocator->statistics, memory_type, allocation_size));
-}
-
static iree_status_t iree_hal_cuda_allocator_wrap_buffer(
iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type,
iree_hal_memory_access_t allowed_access,
@@ -239,6 +236,22 @@
"wrapping of external buffers not supported");
}
+static void iree_hal_cuda_allocator_deallocate_buffer(
+ iree_hal_allocator_t* base_allocator, iree_hal_buffer_t* base_buffer) {
+ iree_hal_cuda_allocator_t* allocator =
+ iree_hal_cuda_allocator_cast(base_allocator);
+ iree_hal_memory_type_t memory_type = iree_hal_buffer_memory_type(base_buffer);
+ iree_hal_cuda_buffer_free(allocator->context, memory_type,
+ iree_hal_cuda_buffer_device_pointer(base_buffer),
+ iree_hal_cuda_buffer_host_pointer(base_buffer));
+
+ IREE_STATISTICS(iree_hal_allocator_statistics_record_free(
+ &allocator->statistics, memory_type,
+ iree_hal_buffer_allocation_size(base_buffer)));
+
+ iree_hal_buffer_destroy(base_buffer);
+}
+
static const iree_hal_allocator_vtable_t iree_hal_cuda_allocator_vtable = {
.destroy = iree_hal_cuda_allocator_destroy,
.host_allocator = iree_hal_cuda_allocator_host_allocator,
@@ -247,4 +260,5 @@
iree_hal_cuda_allocator_query_buffer_compatibility,
.allocate_buffer = iree_hal_cuda_allocator_allocate_buffer,
.wrap_buffer = iree_hal_cuda_allocator_wrap_buffer,
+ .deallocate_buffer = iree_hal_cuda_allocator_deallocate_buffer,
};
diff --git a/iree/hal/cuda/cuda_allocator.h b/iree/hal/cuda/cuda_allocator.h
index 20606d6..0fe2739 100644
--- a/iree/hal/cuda/cuda_allocator.h
+++ b/iree/hal/cuda/cuda_allocator.h
@@ -21,12 +21,6 @@
iree_hal_cuda_context_wrapper_t* context, CUdevice device, CUstream stream,
iree_hal_allocator_t** out_allocator);
-// Free an allocation represent by the given device or host pointer.
-void iree_hal_cuda_allocator_free(iree_hal_allocator_t* allocator,
- iree_hal_memory_type_t memory_type,
- CUdeviceptr device_ptr, void* host_ptr,
- iree_device_size_t allocation_size);
-
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/iree/hal/cuda/cuda_buffer.c b/iree/hal/cuda/cuda_buffer.c
index 1f4ee60..db2a8aa 100644
--- a/iree/hal/cuda/cuda_buffer.c
+++ b/iree/hal/cuda/cuda_buffer.c
@@ -12,7 +12,6 @@
#include "iree/base/api.h"
#include "iree/base/tracing.h"
-#include "iree/hal/cuda/cuda_allocator.h"
typedef struct iree_hal_cuda_buffer_t {
iree_hal_buffer_t base;
@@ -38,21 +37,16 @@
IREE_ASSERT_ARGUMENT(out_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator =
+ iree_hal_allocator_host_allocator(allocator);
iree_hal_cuda_buffer_t* buffer = NULL;
iree_status_t status =
- iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
- sizeof(*buffer), (void**)&buffer);
+ iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer);
if (iree_status_is_ok(status)) {
- iree_hal_resource_initialize(&iree_hal_cuda_buffer_vtable,
- &buffer->base.resource);
- buffer->base.allocator = allocator;
- buffer->base.allocated_buffer = &buffer->base;
- buffer->base.allocation_size = allocation_size;
- buffer->base.byte_offset = byte_offset;
- buffer->base.byte_length = byte_length;
- buffer->base.memory_type = memory_type;
- buffer->base.allowed_access = allowed_access;
- buffer->base.allowed_usage = allowed_usage;
+ iree_hal_buffer_initialize(host_allocator, allocator, &buffer->base,
+ allocation_size, byte_offset, byte_length,
+ memory_type, allowed_access, allowed_usage,
+ &iree_hal_cuda_buffer_vtable, &buffer->base);
buffer->host_ptr = host_ptr;
buffer->device_ptr = device_ptr;
*out_buffer = &buffer->base;
@@ -64,15 +58,9 @@
static void iree_hal_cuda_buffer_destroy(iree_hal_buffer_t* base_buffer) {
iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+ iree_allocator_t host_allocator = base_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
-
- iree_hal_cuda_allocator_free(buffer->base.allocator, buffer->base.memory_type,
- buffer->device_ptr, buffer->host_ptr,
- buffer->base.allocation_size);
iree_allocator_free(host_allocator, buffer);
-
IREE_TRACE_ZONE_END(z0);
}
@@ -129,6 +117,11 @@
return buffer->device_ptr;
}
+void* iree_hal_cuda_buffer_host_pointer(iree_hal_buffer_t* base_buffer) {
+ iree_hal_cuda_buffer_t* buffer = iree_hal_cuda_buffer_cast(base_buffer);
+ return buffer->host_ptr;
+}
+
static const iree_hal_buffer_vtable_t iree_hal_cuda_buffer_vtable = {
.destroy = iree_hal_cuda_buffer_destroy,
.map_range = iree_hal_cuda_buffer_map_range,
diff --git a/iree/hal/cuda/cuda_buffer.h b/iree/hal/cuda/cuda_buffer.h
index 453571d..2aaf037 100644
--- a/iree/hal/cuda/cuda_buffer.h
+++ b/iree/hal/cuda/cuda_buffer.h
@@ -15,7 +15,7 @@
extern "C" {
#endif // __cplusplus
-// Wraps a cuda allocation in an iree_hal_buffer_t.
+// Wraps a CUDA allocation in an iree_hal_buffer_t.
iree_status_t iree_hal_cuda_buffer_wrap(
iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
iree_hal_memory_access_t allowed_access,
@@ -23,11 +23,14 @@
iree_device_size_t byte_offset, iree_device_size_t byte_length,
CUdeviceptr device_ptr, void* host_ptr, iree_hal_buffer_t** out_buffer);
-// Returns the cuda base pointer for the given |buffer|.
+// Returns the CUDA base pointer for the given |buffer|.
// This is the entire allocated_buffer and must be offset by the buffer
// byte_offset and byte_length when used.
CUdeviceptr iree_hal_cuda_buffer_device_pointer(iree_hal_buffer_t* buffer);
+// Returns the CUDA host pointer for the given |buffer|, if available.
+void* iree_hal_cuda_buffer_host_pointer(iree_hal_buffer_t* buffer);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/iree/hal/string_util_test.cc b/iree/hal/string_util_test.cc
index 76945e4..22e9063 100644
--- a/iree/hal/string_util_test.cc
+++ b/iree/hal/string_util_test.cc
@@ -430,9 +430,9 @@
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
BufferView buffer_view;
- iree_status_t status =
- iree_hal_buffer_view_create(buffer, shape.data(), shape.size(),
- element_type, encoding_type, &buffer_view);
+ iree_status_t status = iree_hal_buffer_view_create(
+ buffer, shape.data(), shape.size(), element_type, encoding_type,
+ iree_allocator_system(), &buffer_view);
IREE_RETURN_IF_ERROR(std::move(status));
return std::move(buffer_view);
}
diff --git a/iree/hal/vulkan/vma_allocator.cc b/iree/hal/vulkan/vma_allocator.cc
index 760b80c..50c1dab 100644
--- a/iree/hal/vulkan/vma_allocator.cc
+++ b/iree/hal/vulkan/vma_allocator.cc
@@ -354,6 +354,12 @@
"wrapping of external buffers not supported");
}
+static void iree_hal_vulkan_vma_allocator_deallocate_buffer(
+ iree_hal_allocator_t* base_allocator, iree_hal_buffer_t* base_buffer) {
+ // VMA does the pooling for us so we don't need anything special.
+ iree_hal_buffer_destroy(base_buffer);
+}
+
namespace {
const iree_hal_allocator_vtable_t iree_hal_vulkan_vma_allocator_vtable = {
/*.destroy=*/iree_hal_vulkan_vma_allocator_destroy,
@@ -363,5 +369,6 @@
iree_hal_vulkan_vma_allocator_query_buffer_compatibility,
/*.allocate_buffer=*/iree_hal_vulkan_vma_allocator_allocate_buffer,
/*.wrap_buffer=*/iree_hal_vulkan_vma_allocator_wrap_buffer,
+ /*.deallocate_buffer=*/iree_hal_vulkan_vma_allocator_deallocate_buffer,
};
} // namespace
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc
index cea88b1..b69cd67 100644
--- a/iree/hal/vulkan/vma_buffer.cc
+++ b/iree/hal/vulkan/vma_buffer.cc
@@ -47,21 +47,16 @@
IREE_ASSERT_ARGUMENT(out_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_t host_allocator =
+ iree_hal_allocator_host_allocator(allocator);
iree_hal_vulkan_vma_buffer_t* buffer = NULL;
iree_status_t status =
- iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator),
- sizeof(*buffer), (void**)&buffer);
+ iree_allocator_malloc(host_allocator, sizeof(*buffer), (void**)&buffer);
if (iree_status_is_ok(status)) {
- iree_hal_resource_initialize(&iree_hal_vulkan_vma_buffer_vtable,
- &buffer->base.resource);
- buffer->base.allocator = allocator;
- buffer->base.allocated_buffer = &buffer->base;
- buffer->base.allocation_size = allocation_size;
- buffer->base.byte_offset = byte_offset;
- buffer->base.byte_length = byte_length;
- buffer->base.memory_type = memory_type;
- buffer->base.allowed_access = allowed_access;
- buffer->base.allowed_usage = allowed_usage;
+ iree_hal_buffer_initialize(
+ host_allocator, allocator, &buffer->base, allocation_size, byte_offset,
+ byte_length, memory_type, allowed_access, allowed_usage,
+ &iree_hal_vulkan_vma_buffer_vtable, &buffer->base);
buffer->vma = vma;
buffer->handle = handle;
buffer->allocation = allocation;
@@ -87,8 +82,7 @@
static void iree_hal_vulkan_vma_buffer_destroy(iree_hal_buffer_t* base_buffer) {
iree_hal_vulkan_vma_buffer_t* buffer =
iree_hal_vulkan_vma_buffer_cast(base_buffer);
- iree_allocator_t host_allocator =
- iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer));
+ iree_allocator_t host_allocator = base_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);
// IREE_TRACE_FREE_NAMED("VMA", (void*)buffer->handle);
diff --git a/iree/modules/check/check_test.cc b/iree/modules/check/check_test.cc
index e4793af..5bcd6a9 100644
--- a/iree/modules/check/check_test.cc
+++ b/iree/modules/check/check_test.cc
@@ -91,7 +91,8 @@
buffer.get(), 0, contents.data(), contents.size() * sizeof(int32_t)));
IREE_ASSERT_OK(iree_hal_buffer_view_create(
buffer.get(), shape.data(), shape.size(), IREE_HAL_ELEMENT_TYPE_INT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, &*out_buffer_view));
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, iree_allocator_system(),
+ &*out_buffer_view));
}
void CreateFloat16BufferView(iree::span<const uint16_t> contents,
@@ -115,7 +116,7 @@
IREE_ASSERT_OK(iree_hal_buffer_view_create(
buffer.get(), shape.data(), shape.size(),
IREE_HAL_ELEMENT_TYPE_FLOAT_16, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- &*out_buffer_view));
+ iree_allocator_system(), &*out_buffer_view));
}
void CreateFloat32BufferView(iree::span<const float> contents,
@@ -138,7 +139,7 @@
IREE_ASSERT_OK(iree_hal_buffer_view_create(
buffer.get(), shape.data(), shape.size(),
IREE_HAL_ELEMENT_TYPE_FLOAT_32, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- &*out_buffer_view));
+ iree_allocator_system(), &*out_buffer_view));
}
void CreateFloat64BufferView(iree::span<const double> contents,
@@ -161,7 +162,7 @@
IREE_ASSERT_OK(iree_hal_buffer_view_create(
buffer.get(), shape.data(), shape.size(),
IREE_HAL_ELEMENT_TYPE_FLOAT_64, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
- &*out_buffer_view));
+ iree_allocator_system(), &*out_buffer_view));
}
iree_status_t Invoke(const char* function_name) {
diff --git a/iree/modules/hal/exports.inl b/iree/modules/hal/exports.inl
index 5626034..3bcd242 100644
--- a/iree/modules/hal/exports.inl
+++ b/iree/modules/hal/exports.inl
@@ -28,7 +28,6 @@
EXPORT_FN("allocator.map.byte_buffer", iree_hal_module_allocator_map_byte_buffer, riiirii, r)
EXPORT_FN("allocator.wrap.byte_buffer", iree_hal_module_allocator_wrap_byte_buffer, riirii, r)
-EXPORT_FN("buffer.allocator", iree_hal_module_buffer_allocator, r, r)
EXPORT_FN("buffer.assert", iree_hal_module_buffer_assert, rrriii, v)
EXPORT_FN("buffer.length", iree_hal_module_buffer_length, r, i)
EXPORT_FN("buffer.load", iree_hal_module_buffer_load, rii, i)
diff --git a/iree/modules/hal/module.c b/iree/modules/hal/module.c
index 37de3cd..ccd5cbc 100644
--- a/iree/modules/hal/module.c
+++ b/iree/modules/hal/module.c
@@ -439,13 +439,6 @@
// target device. This needs some iree_hal_allocator_* methods for checking
// whether the external buffer can be used. To start we just compare if the
// allocators are identical.
- if (iree_hal_buffer_allocator(buffer) != allocator) {
- return iree_make_status(
- IREE_STATUS_INVALID_ARGUMENT,
- "%.*s imported buffer allocator mismatch; must be from "
- "the same allocator (today)",
- (int)message_str.size, message_str.data);
- }
// All memory type bits expected (indicating where the program intends to use
// the buffer data) must be set in the buffer while the buffer is allowed to
@@ -508,15 +501,6 @@
return iree_ok_status();
}
-IREE_VM_ABI_EXPORT(iree_hal_module_buffer_allocator, //
- iree_hal_module_state_t, //
- r, r) {
- iree_hal_buffer_t* buffer = NULL;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r0, &buffer));
- rets->r0 = iree_hal_allocator_retain_ref(iree_hal_buffer_allocator(buffer));
- return iree_ok_status();
-}
-
IREE_VM_ABI_EXPORT(iree_hal_module_buffer_subspan, //
iree_hal_module_state_t, //
rii, r) {
@@ -607,9 +591,9 @@
&shape_rank, &shape_dims);
iree_hal_buffer_view_t* buffer_view = NULL;
- IREE_RETURN_IF_ERROR(
- iree_hal_buffer_view_create(source_buffer, shape_dims, shape_rank,
- element_type, encoding_type, &buffer_view));
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
+ source_buffer, shape_dims, shape_rank, element_type, encoding_type,
+ state->host_allocator, &buffer_view));
rets->r0 = iree_hal_buffer_view_move_ref(buffer_view);
return iree_ok_status();
}
diff --git a/iree/samples/simple_embedding/simple_embedding.c b/iree/samples/simple_embedding/simple_embedding.c
index 93b7a1f..db9f2b6 100644
--- a/iree/samples/simple_embedding/simple_embedding.c
+++ b/iree/samples/simple_embedding/simple_embedding.c
@@ -94,10 +94,12 @@
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
arg0_buffer, shape, IREE_ARRAYSIZE(shape), IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, &arg0_buffer_view));
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, iree_allocator_system(),
+ &arg0_buffer_view));
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create(
arg1_buffer, shape, IREE_ARRAYSIZE(shape), IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, &arg1_buffer_view));
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, iree_allocator_system(),
+ &arg1_buffer_view));
iree_hal_buffer_release(arg0_buffer);
iree_hal_buffer_release(arg1_buffer);
diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc
index 70124ed..4939b93 100644
--- a/iree/samples/vulkan/vulkan_inference_gui.cc
+++ b/iree/samples/vulkan/vulkan_inference_gui.cc
@@ -389,12 +389,14 @@
input0_buffer,
/*shape=*/&kElementCount, /*shape_rank=*/1,
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, &input0_buffer_view));
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, iree_allocator_system(),
+ &input0_buffer_view));
IREE_CHECK_OK(iree_hal_buffer_view_create(
input1_buffer,
/*shape=*/&kElementCount, /*shape_rank=*/1,
IREE_HAL_ELEMENT_TYPE_FLOAT_32,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, &input1_buffer_view));
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, iree_allocator_system(),
+ &input1_buffer_view));
iree_hal_buffer_release(input0_buffer);
iree_hal_buffer_release(input1_buffer);
// Marshal inputs through a VM variant list.
diff --git a/iree/tools/utils/image_util.c b/iree/tools/utils/image_util.c
index 6bc1a06..61e1486 100644
--- a/iree/tools/utils/image_util.c
+++ b/iree/tools/utils/image_util.c
@@ -191,9 +191,9 @@
if (iree_status_is_ok(result)) {
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
- result =
- iree_hal_buffer_view_create(buffer, shape, shape_rank, element_type,
- encoding_type, out_buffer_view);
+ result = iree_hal_buffer_view_create(
+ buffer, shape, shape_rank, element_type, encoding_type,
+ iree_hal_allocator_host_allocator(allocator), out_buffer_view);
}
iree_hal_buffer_release(buffer);
stbi_image_free(pixel_data);
diff --git a/iree/tools/utils/trace_replay.c b/iree/tools/utils/trace_replay.c
index 7099b38..c84c9d4 100644
--- a/iree/tools/utils/trace_replay.c
+++ b/iree/tools/utils/trace_replay.c
@@ -732,7 +732,8 @@
iree_hal_buffer_view_t* buffer_view = NULL;
status = iree_hal_buffer_view_create(buffer, shape, shape_rank, element_type,
- encoding_type, &buffer_view);
+ encoding_type, replay->host_allocator,
+ &buffer_view);
iree_hal_buffer_release(buffer);
IREE_RETURN_IF_ERROR(status);