Fixing up samples and tests.
diff --git a/iree/hal/cts/command_buffer_test.h b/iree/hal/cts/command_buffer_test.h
index 77fe604..4bd0d7b 100644
--- a/iree/hal/cts/command_buffer_test.h
+++ b/iree/hal/cts/command_buffer_test.h
@@ -190,13 +190,13 @@
}
TEST_P(command_buffer_test, CopySubBuffer) {
- iree_hal_command_buffer_t* command_buffer;
+ iree_hal_command_buffer_t* command_buffer = NULL;
IREE_ASSERT_OK(iree_hal_command_buffer_create(
device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY,
&command_buffer));
- iree_hal_buffer_t* device_buffer;
+ iree_hal_buffer_t* device_buffer = NULL;
IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
device_allocator_,
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
@@ -209,14 +209,14 @@
// Create another host buffer with a smaller size.
std::vector<uint8_t> host_buffer_data(kBufferSize, i8_val);
- iree_hal_buffer_t* host_buffer;
+ iree_hal_buffer_t* host_buffer = NULL;
IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
device_allocator_,
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
- IREE_HAL_BUFFER_USAGE_ALL, kBufferSize / 2,
+ IREE_HAL_BUFFER_USAGE_ALL, host_buffer_data.size() / 2,
iree_make_const_byte_span(host_buffer_data.data(),
- host_buffer_data.size()),
+ host_buffer_data.size() / 2),
&host_buffer));
// Copy the host buffer to the device buffer; zero fill the untouched bytes.
diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h
index 1269646..611926d 100644
--- a/iree/hal/cts/cts_test_base.h
+++ b/iree/hal/cts/cts_test_base.h
@@ -41,7 +41,7 @@
iree_status_t status = TryGetDriver(driver_name, &driver);
if (iree_status_is_unavailable(status)) {
iree_status_free(status);
- IREE_LOG(WARNING) << "Skipping test as << '" << driver_name
+ IREE_LOG(WARNING) << "Skipping test as '" << driver_name
<< "' driver is unavailable";
GTEST_SKIP();
return;
diff --git a/iree/hal/cuda/dynamic_symbols_test.cc b/iree/hal/cuda/dynamic_symbols_test.cc
index 47862d3..ab5136c 100644
--- a/iree/hal/cuda/dynamic_symbols_test.cc
+++ b/iree/hal/cuda/dynamic_symbols_test.cc
@@ -27,6 +27,8 @@
iree_status_t status = iree_hal_cuda_dynamic_symbols_initialize(
iree_allocator_system(), &symbols);
if (!iree_status_is_ok(status)) {
+ iree_status_fprint(stderr, status);
+ iree_status_ignore(status);
std::cerr << "Symbols cannot be loaded, skipping test.";
GTEST_SKIP();
}
diff --git a/iree/modules/check/module.cc b/iree/modules/check/module.cc
index 8c0abc7..6837f4f 100644
--- a/iree/modules/check/module.cc
+++ b/iree/modules/check/module.cc
@@ -220,6 +220,9 @@
iree_hal_element_type_t rhs_element_type =
iree_hal_buffer_view_element_type(rhs);
+ // HACK: this is all broken and will leak. Let's kill this entire module
+ // please.
+
iree_hal_buffer_t* lhs_buf = iree_hal_buffer_view_buffer(lhs);
iree_hal_buffer_mapping_t lhs_mapped_memory;
IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
diff --git a/iree/samples/simple_embedding/simple_embedding.c b/iree/samples/simple_embedding/simple_embedding.c
index 2f3aec2..94f2dee 100644
--- a/iree/samples/simple_embedding/simple_embedding.c
+++ b/iree/samples/simple_embedding/simple_embedding.c
@@ -72,12 +72,11 @@
// Initial buffer contents for 4 * 2 = 8.
const float kFloat4[] = {4.0f, 4.0f, 4.0f, 4.0f};
- const float kFloat2[] = {2.0f, 2.0f, 2.0f, .0f};
- const int kElementCount = IREE_ARRAYSIZE(kFloat4);
+ const float kFloat2[] = {2.0f, 2.0f, 2.0f, 2.0f};
// Allocate buffers in device-local memory so that if the device has an
// independent address space they live on the fast side of the fence.
- iree_hal_dim_t shape[1] = {kElementCount};
+ iree_hal_dim_t shape[1] = {IREE_ARRAYSIZE(kFloat4)};
iree_hal_buffer_view_t* arg0_buffer_view = NULL;
iree_hal_buffer_view_t* arg1_buffer_view = NULL;
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer(
@@ -134,12 +133,12 @@
}
// Read back the results and ensure we got the right values.
- iree_hal_buffer_mapping_t mapped_memory;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
- iree_hal_buffer_view_buffer(ret_buffer_view), IREE_HAL_MEMORY_ACCESS_READ,
- 0, IREE_WHOLE_BUFFER, &mapped_memory));
- for (int i = 0; i < mapped_memory.contents.data_length / sizeof(float); ++i) {
- if (((const float*)mapped_memory.contents.data)[i] != 8.0f) {
+ float results[] = {0.0f, 0.0f, 0.0f, 0.0f};
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view), 0,
+ results, sizeof(results)));
+ for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(results); ++i) {
+ if (results[i] != 8.0f) {
return iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
}
}
diff --git a/iree/samples/static_library/static_library_demo.c b/iree/samples/static_library/static_library_demo.c
index 3fc86df..79c802e 100644
--- a/iree/samples/static_library/static_library_demo.c
+++ b/iree/samples/static_library/static_library_demo.c
@@ -171,25 +171,17 @@
}
// Read back the results and ensure we got the right values.
- iree_hal_buffer_mapping_t mapped_memory;
- memset(&mapped_memory, 0, sizeof(mapped_memory));
+ float results[] = {0.0f, 0.0f, 0.0f, 0.0f};
if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_map_range(
- iree_hal_buffer_view_buffer(ret_buffer_view),
- IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &mapped_memory);
+ status =
+ iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view),
+ 0, results, sizeof(results));
}
if (iree_status_is_ok(status)) {
- if (mapped_memory.contents.data_length / sizeof(float) != kElementCount) {
- status = iree_make_status(IREE_STATUS_UNKNOWN,
- "result does not match element count ");
- }
- }
- if (iree_status_is_ok(status)) {
- const float* data = (const float*)mapped_memory.contents.data;
- for (iree_host_size_t i = 0;
- i < mapped_memory.contents.data_length / sizeof(float); ++i) {
- if (data[i] != 8.0f) {
+ for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(results); ++i) {
+ if (results[i] != 8.0f) {
status = iree_make_status(IREE_STATUS_UNKNOWN, "result mismatches");
+ break;
}
}
}
diff --git a/iree/samples/vision/iree-run-mnist-module.c b/iree/samples/vision/iree-run-mnist-module.c
index 1a3dc7e..d525bdd 100644
--- a/iree/samples/vision/iree-run-mnist-module.c
+++ b/iree/samples/vision/iree-run-mnist-module.c
@@ -73,23 +73,22 @@
// Read back the results. The output of the mnist model is a 1x10 prediction
// confidence values for each digit in [0, 9].
- iree_hal_buffer_mapping_t mapped_memory;
- IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
- iree_hal_buffer_view_buffer(ret_buffer_view), IREE_HAL_MEMORY_ACCESS_READ,
- 0, IREE_WHOLE_BUFFER, &mapped_memory));
+ float predictions[1 * 10] = {0.0f};
+ IREE_RETURN_IF_ERROR(
+ iree_hal_buffer_read_data(iree_hal_buffer_view_buffer(ret_buffer_view), 0,
+ predictions, sizeof(predictions)));
+ iree_hal_buffer_view_release(ret_buffer_view);
+
+ // Get the highest index from the output.
float result_val = FLT_MIN;
int result_idx = 0;
- const float* data_ptr = (const float*)mapped_memory.contents.data;
- for (int i = 0; i < mapped_memory.contents.data_length / sizeof(float); ++i) {
- if (data_ptr[i] > result_val) {
- result_val = data_ptr[i];
+ for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(predictions); ++i) {
+ if (predictions[i] > result_val) {
+ result_val = predictions[i];
result_idx = i;
}
}
-
- // Get the highest index from the output.
fprintf(stdout, "Detected number: %d\n", result_idx);
- iree_hal_buffer_view_release(ret_buffer_view);
iree_runtime_call_deinitialize(&call);
iree_runtime_session_release(session);
diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc
index bacf56f..9d895a1 100644
--- a/iree/samples/vulkan/vulkan_inference_gui.cc
+++ b/iree/samples/vulkan/vulkan_inference_gui.cc
@@ -71,9 +71,10 @@
}
// Setup window
- SDL_WindowFlags window_flags =
- (SDL_WindowFlags)(SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE |
- SDL_WINDOW_ALLOW_HIGHDPI);
+ // clang-format off
+ SDL_WindowFlags window_flags = (SDL_WindowFlags)(
+ SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
+ // clang-format on
SDL_Window* window = SDL_CreateWindow(
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
diff --git a/iree/tools/utils/image_util.c b/iree/tools/utils/image_util.c
index b7d99c5..94aaab0 100644
--- a/iree/tools/utils/image_util.c
+++ b/iree/tools/utils/image_util.c
@@ -159,7 +159,7 @@
iree_tools_utils_buffer_view_load_params_t* params =
(iree_tools_utils_buffer_view_load_params_t*)user_data;
return iree_tools_utils_pixel_rescaled_to_buffer(
- params->pixel_data, mapping->contents.data_length, params->input_range,
+ params->pixel_data, params->pixel_data_length, params->input_range,
params->input_range_length, (float*)mapping->contents.data);
}
diff --git a/iree/tools/utils/image_util.h b/iree/tools/utils/image_util.h
index eca9226..fa2e709 100644
--- a/iree/tools/utils/image_util.h
+++ b/iree/tools/utils/image_util.h
@@ -66,7 +66,7 @@
//
// |out_buffer| needs to be allocated before the call.
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
- const uint8_t* pixel_data, iree_host_size_t buffer_length,
+ const uint8_t* pixel_data, iree_host_size_t pixel_count,
const float* input_range, iree_host_size_t input_range_length,
float* out_buffer);
diff --git a/iree/tools/utils/trace_replay.c b/iree/tools/utils/trace_replay.c
index 8478f2f..c3cc197 100644
--- a/iree/tools/utils/trace_replay.c
+++ b/iree/tools/utils/trace_replay.c
@@ -480,7 +480,6 @@
iree_string_view_t value =
iree_string_view_trim(iree_yaml_node_as_string(contents_node));
- iree_status_t status = iree_ok_status();
if (strcmp(contents_node->tag, "tag:yaml.org,2002:binary") == 0) {
return iree_yaml_base64_decode(value, mapping->contents);
} else if (strcmp(contents_node->tag, "tag:yaml.org,2002:str") == 0) {
@@ -647,7 +646,7 @@
iree_host_size_t shape_rank;
} iree_trace_replay_generation_params_t;
-static iree_status_t iree_trace_replay_generate_hal_buffer(
+static iree_status_t iree_trace_replay_generate_hal_buffer_callback(
iree_hal_buffer_mapping_t* mapping, void* user_data) {
iree_trace_replay_generation_params_t* params =
(iree_trace_replay_generation_params_t*)user_data;
@@ -709,6 +708,7 @@
document, value_node, iree_make_cstring_view("contents_generator"),
&generator_node));
+ iree_hal_buffer_view_t* buffer_view = NULL;
if (contents_node && generator_node) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,