Making iree_hal_buffer_unmap_range return a status. This is needed for implementations that perform non-trivial writebacks when unmapping (such as webgpu/metal/any rpc system/etc).
diff --git a/experimental/rocm/rocm_buffer.c b/experimental/rocm/rocm_buffer.c index a0096ac..6b5f79b 100644 --- a/experimental/rocm/rocm_buffer.c +++ b/experimental/rocm/rocm_buffer.c
@@ -94,10 +94,11 @@ return iree_ok_status(); } -static void iree_hal_rocm_buffer_unmap_range( +static iree_status_t iree_hal_rocm_buffer_unmap_range( iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { - // nothing to do. + // Nothing to do (today). + return iree_ok_status(); } static iree_status_t iree_hal_rocm_buffer_invalidate_range(
diff --git a/iree/base/status.c b/iree/base/status.c index 4942eb6..7c6784d 100644 --- a/iree/base/status.c +++ b/iree/base/status.c
@@ -525,6 +525,18 @@ return iree_ok_status(); } +IREE_API_EXPORT iree_status_t iree_status_join(iree_status_t base_status, + iree_status_t new_status) { + // TODO(benvanik): annotate |base_status| with |new_status| so we see it? + // This is intended for failure handling and usually the first failure is the + // root cause and most important to see. + if (!iree_status_is_ok(base_status)) { + iree_status_ignore(new_status); + return base_status; + } + return new_status; +} + IREE_API_EXPORT IREE_ATTRIBUTE_NORETURN void iree_status_abort( iree_status_t status) { IREE_ASSERT(!iree_status_is_ok(status),
diff --git a/iree/base/status.h b/iree/base/status.h index 81baad3..a26352c 100644 --- a/iree/base/status.h +++ b/iree/base/status.h
@@ -379,6 +379,16 @@ // Returns an OK status that can be used when chaining. IREE_API_EXPORT iree_status_t iree_status_ignore(iree_status_t status); +// Returns a new status that is |base_status| if not OK and otherwise returns +// |new_status|. This allows for chaining failure handling code that may also +// return statuses. +// +// Example: +// iree_status_t status = do_something(); +// return iree_status_join(status, do_cleanup()); +IREE_API_EXPORT iree_status_t iree_status_join(iree_status_t base_status, + iree_status_t new_status); + // Aborts the program with a failing |status|. // This will trigger a SIGABRT. It's best not to use this at all outside of // demos or tools.
diff --git a/iree/hal/buffer.c b/iree/hal/buffer.c index 0fe0dc9..bd90062 100644 --- a/iree/hal/buffer.c +++ b/iree/hal/buffer.c
@@ -121,12 +121,12 @@ local_byte_length, mapping); } -static void iree_hal_subspan_buffer_unmap_range( +static iree_status_t iree_hal_subspan_buffer_unmap_range( iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { - if (!buffer->allocated_buffer) return; - _VTABLE_DISPATCH(buffer->allocated_buffer, unmap_range) - (buffer->allocated_buffer, local_byte_offset, local_byte_length, mapping); + if (!buffer->allocated_buffer) return iree_ok_status(); + return _VTABLE_DISPATCH(buffer->allocated_buffer, unmap_range)( + buffer->allocated_buffer, local_byte_offset, local_byte_length, mapping); } static iree_status_t iree_hal_subspan_buffer_invalidate_range( @@ -522,7 +522,7 @@ if (IREE_UNLIKELY((byte_offset % pattern_length) != 0) || IREE_UNLIKELY((byte_length % pattern_length) != 0)) { - iree_hal_buffer_unmap_range(&target_mapping); + iree_status_ignore(iree_hal_buffer_unmap_range(&target_mapping)); IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "attempting to fill a range with %zu byte values " @@ -576,7 +576,8 @@ status = iree_hal_buffer_flush_range(&target_mapping, 0, IREE_WHOLE_BUFFER); } - iree_hal_buffer_unmap_range(&target_mapping); + status = + iree_status_join(status, iree_hal_buffer_unmap_range(&target_mapping)); IREE_TRACE_ZONE_END(z0); return status; } @@ -716,7 +717,8 @@ } if (source.device_buffer) { - iree_hal_buffer_unmap_range(&source_mapping); + status = + iree_status_join(status, iree_hal_buffer_unmap_range(&source_mapping)); } if (target.device_buffer) { if (adjusted_data_length > 0 && @@ -726,7 +728,8 @@ status, iree_hal_buffer_flush_range(&target_mapping, 0, adjusted_data_length)); } - iree_hal_buffer_unmap_range(&target_mapping); + status = + iree_status_join(status, iree_hal_buffer_unmap_range(&target_mapping)); } return status; } @@ -791,16 +794,16 @@ return status; } -IREE_API_EXPORT void iree_hal_buffer_unmap_range( - iree_hal_buffer_mapping_t* buffer_mapping) { +IREE_API_EXPORT iree_status_t +iree_hal_buffer_unmap_range(iree_hal_buffer_mapping_t* buffer_mapping) { IREE_ASSERT_ARGUMENT(buffer_mapping); iree_hal_buffer_t* buffer = buffer_mapping->buffer; - if (!buffer) return; + if (!buffer) return iree_ok_status(); IREE_TRACE_ZONE_BEGIN(z0); - _VTABLE_DISPATCH(buffer, unmap_range) - (buffer, buffer_mapping->impl.byte_offset, - buffer_mapping->contents.data_length, buffer_mapping); + iree_status_t status = _VTABLE_DISPATCH(buffer, unmap_range)( + buffer, buffer_mapping->impl.byte_offset, + buffer_mapping->contents.data_length, buffer_mapping); if (!buffer_mapping->impl.is_persistent) { iree_hal_buffer_release(buffer); @@ -808,6 +811,7 @@ memset(buffer_mapping, 0, sizeof(*buffer_mapping)); IREE_TRACE_ZONE_END(z0); + return status; } IREE_API_EXPORT iree_status_t iree_hal_buffer_invalidate_range(
diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h index ed4bb77..50215b2 100644 --- a/iree/hal/buffer.h +++ b/iree/hal/buffer.h
@@ -519,8 +519,12 @@ // // If the buffer is not IREE_HAL_MEMORY_TYPE_HOST_COHERENT then the caller must // flush the byte range they want to make available to other threads/devices. -IREE_API_EXPORT void iree_hal_buffer_unmap_range( - iree_hal_buffer_mapping_t* buffer_mapping); +// +// May fail, though unlikely to do so for read-only mapping and the result can +// be safely ignored using iree_status_ignore. If writing then users must check +// the status to ensure their writes succeeded. +IREE_API_EXPORT iree_status_t +iree_hal_buffer_unmap_range(iree_hal_buffer_mapping_t* buffer_mapping); // Invalidates ranges of non-coherent memory from the host caches. // This guarantees that device writes to the memory ranges provided are @@ -608,10 +612,10 @@ iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping); - void(IREE_API_PTR* unmap_range)(iree_hal_buffer_t* buffer, - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - iree_hal_buffer_mapping_t* mapping); + iree_status_t(IREE_API_PTR* unmap_range)(iree_hal_buffer_t* buffer, + iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, + iree_hal_buffer_mapping_t* mapping); iree_status_t(IREE_API_PTR* invalidate_range)( iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset,
diff --git a/iree/hal/buffer_heap.c b/iree/hal/buffer_heap.c index 67eba93..695d77e 100644 --- a/iree/hal/buffer_heap.c +++ b/iree/hal/buffer_heap.c
@@ -192,10 +192,11 @@ return iree_ok_status(); } -static void iree_hal_heap_buffer_unmap_range( +static iree_status_t iree_hal_heap_buffer_unmap_range( iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { // No-op here as we always have the pointer. + return iree_ok_status(); } static iree_status_t iree_hal_heap_buffer_invalidate_range(
diff --git a/iree/hal/buffer_view.c b/iree/hal/buffer_view.c index 8ae43eb..b120fa5 100644 --- a/iree/hal/buffer_view.c +++ b/iree/hal/buffer_view.c
@@ -213,7 +213,8 @@ status = callback(&buffer_mapping, user_data); } - iree_hal_buffer_unmap_range(&buffer_mapping); + status = + iree_status_join(status, iree_hal_buffer_unmap_range(&buffer_mapping)); if (iree_status_is_ok(status)) { *out_buffer_view = buffer_view; } else { @@ -759,7 +760,8 @@ buffer ? buffer_capacity - buffer_length : 0, buffer ? buffer + buffer_length : NULL, &elements_length); buffer_length += elements_length; - iree_hal_buffer_unmap_range(&buffer_mapping); + status = + iree_status_join(status, iree_hal_buffer_unmap_range(&buffer_mapping)); if (iree_status_is_out_of_range(status)) { status = iree_status_ignore(status); buffer = NULL;
diff --git a/iree/hal/cuda/cuda_buffer.c b/iree/hal/cuda/cuda_buffer.c index c49959b..44fce4b 100644 --- a/iree/hal/cuda/cuda_buffer.c +++ b/iree/hal/cuda/cuda_buffer.c
@@ -94,10 +94,11 @@ return iree_ok_status(); } -static void iree_hal_cuda_buffer_unmap_range( +static iree_status_t iree_hal_cuda_buffer_unmap_range( iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { - // nothing to do. + // Nothing to do (today). + return iree_ok_status(); } static iree_status_t iree_hal_cuda_buffer_invalidate_range(
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc index 4f1ba9d..fa2dec0 100644 --- a/iree/hal/vulkan/vma_buffer.cc +++ b/iree/hal/vulkan/vma_buffer.cc
@@ -135,12 +135,13 @@ return iree_ok_status(); } -static void iree_hal_vulkan_vma_buffer_unmap_range( +static iree_status_t iree_hal_vulkan_vma_buffer_unmap_range( iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, iree_hal_buffer_mapping_t* mapping) { iree_hal_vulkan_vma_buffer_t* buffer = iree_hal_vulkan_vma_buffer_cast(base_buffer); vmaUnmapMemory(buffer->vma, buffer->allocation); + return iree_ok_status(); } static iree_status_t iree_hal_vulkan_vma_buffer_invalidate_range(
diff --git a/iree/modules/check/module.cc b/iree/modules/check/module.cc index 8e1be2f..8c0abc7 100644 --- a/iree/modules/check/module.cc +++ b/iree/modules/check/module.cc
@@ -190,7 +190,7 @@ /*byte_offset=*/0, size, &mapped_memory)); IREE_RETURN_IF_ERROR( ::iree::ExpectAllTrue(mapped_memory.contents, element_type)); - iree_hal_buffer_unmap_range(&mapped_memory); + iree_status_ignore(iree_hal_buffer_unmap_range(&mapped_memory)); return OkStatus(); } @@ -235,8 +235,8 @@ bool shape_eq = lhs_shape == rhs_shape; bool contents_eq = EqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents); - iree_hal_buffer_unmap_range(&lhs_mapped_memory); - iree_hal_buffer_unmap_range(&rhs_mapped_memory); + iree_status_ignore(iree_hal_buffer_unmap_range(&lhs_mapped_memory)); + iree_status_ignore(iree_hal_buffer_unmap_range(&rhs_mapped_memory)); if (!element_types_eq || !shape_eq || !contents_eq) { std::ostringstream os; @@ -317,8 +317,8 @@ AlmostEqByteSpan(lhs_mapped_memory.contents, rhs_mapped_memory.contents, lhs_element_type)); } - iree_hal_buffer_unmap_range(&lhs_mapped_memory); - iree_hal_buffer_unmap_range(&rhs_mapped_memory); + iree_status_ignore(iree_hal_buffer_unmap_range(&lhs_mapped_memory)); + iree_status_ignore(iree_hal_buffer_unmap_range(&rhs_mapped_memory)); if (!element_types_eq || !shape_eq || !contents_could_be_almost_eq) { std::ostringstream os;
diff --git a/iree/samples/simple_embedding/simple_embedding.c b/iree/samples/simple_embedding/simple_embedding.c index 30a4566..2f3aec2 100644 --- a/iree/samples/simple_embedding/simple_embedding.c +++ b/iree/samples/simple_embedding/simple_embedding.c
@@ -7,6 +7,10 @@ // A example of setting up the HAL module to run simple pointwise array // multiplication with the device implemented by different backends via // create_sample_driver(). +// +// NOTE: this file does not properly handle error cases and will leak on +// failure. Applications that are just going to exit()/abort() on failure can +// probably get away with the same thing but really should prefer not to. #include <stdio.h> @@ -139,7 +143,6 @@ return iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches"); } } - iree_hal_buffer_unmap_range(&mapped_memory); iree_vm_list_release(inputs); iree_vm_list_release(outputs);
diff --git a/iree/samples/static_library/static_library_demo.c b/iree/samples/static_library/static_library_demo.c index 2529281..3fc86df 100644 --- a/iree/samples/static_library/static_library_demo.c +++ b/iree/samples/static_library/static_library_demo.c
@@ -195,7 +195,6 @@ } // Cleanup call and buffers. - iree_hal_buffer_unmap_range(&mapped_memory); iree_hal_buffer_view_release(ret_buffer_view); iree_runtime_call_deinitialize(&call);
diff --git a/iree/samples/vision/iree-run-mnist-module.c b/iree/samples/vision/iree-run-mnist-module.c index bfd9d7c..1a3dc7e 100644 --- a/iree/samples/vision/iree-run-mnist-module.c +++ b/iree/samples/vision/iree-run-mnist-module.c
@@ -86,7 +86,7 @@ result_idx = i; } } - iree_hal_buffer_unmap_range(&mapped_memory); + // Get the highest index from the output. fprintf(stdout, "Detected number: %d\n", result_idx); iree_hal_buffer_view_release(ret_buffer_view);