[WebGPU] Async, loop-based invoke and output. (#13962)
Redo of https://github.com/openxla/iree/pull/13820, which was closed
because GitHub is silly.
---
Now that https://github.com/openxla/iree/pull/13669 added support for
`iree_loop_wait_all()` to `iree_loop_emscripten`, this builds on the
work that https://github.com/openxla/iree/pull/11017 started to convert
the sample WebGPU application to be fully asynchronous[1].
The code _should_ support multiple return values, but I ran into
https://github.com/openxla/iree/issues/13809 while working on this.
^[1]: `iree_hal_semaphore_wait` is still technically synchronous, see
the notes in the source code
diff --git a/experimental/web/sample_webgpu/CMakeLists.txt b/experimental/web/sample_webgpu/CMakeLists.txt
index d26c00a..61c1641 100644
--- a/experimental/web/sample_webgpu/CMakeLists.txt
+++ b/experimental/web/sample_webgpu/CMakeLists.txt
@@ -24,21 +24,30 @@
# The general purpose libraries link in multiple executable loaders and HAL
# drivers/devices, which include code not compatible with Emscripten.
target_link_libraries(${_NAME}
+ # C sources
+ iree_base_loop_emscripten
iree_runtime_runtime
iree_experimental_webgpu_webgpu
iree_experimental_webgpu_platform_emscripten_emscripten
+ # JS sources
+ "--js-library ${IREE_ROOT_DIR}/runtime/src/iree/base/internal/wait_handle_emscripten.js"
+ "--js-library ${IREE_ROOT_DIR}/runtime/src/iree/base/loop_emscripten.js"
)
target_link_options(${_NAME} PRIVATE
# https://emscripten.org/docs/porting/connecting_cpp_and_javascript/Interacting-with-code.html#interacting-with-code-ccall-cwrap
"-sEXPORTED_FUNCTIONS=['_setup_sample', '_cleanup_sample', '_load_program', '_inspect_program', '_unload_program', '_call_function', '_malloc', '_free']"
- "-sEXPORTED_RUNTIME_METHODS=['ccall','cwrap','UTF8ToString']"
+ "-sEXPORTED_RUNTIME_METHODS=['ccall', 'cwrap', 'UTF8ToString', 'dynCall', 'addFunction']"
#
"-sASSERTIONS=1"
#
# Programs loaded dynamically can require additional memory, so allow growth.
"-sALLOW_MEMORY_GROWTH"
#
+ # Allow table growth to create new function pointers from JS that C can call.
+ # This is used for the callback passed in to 'call_function'.
+ "-sALLOW_TABLE_GROWTH"
+ #
# For https://emscripten.org/docs/debugging/Sanitizers.html#address-sanitizer
# "-fsanitize=address"
# "-sALLOW_MEMORY_GROWTH"
diff --git a/experimental/web/sample_webgpu/build_sample.sh b/experimental/web/sample_webgpu/build_sample.sh
index 58dafa3..4bf7af9 100755
--- a/experimental/web/sample_webgpu/build_sample.sh
+++ b/experimental/web/sample_webgpu/build_sample.sh
@@ -57,16 +57,18 @@
--o ${BINARY_DIR}/$1_webgpu.vmfb
}
-compile_sample "simple_abs" "none" "${ROOT_DIR?}/samples/models/simple_abs.mlir"
-compile_sample "fullyconnected" "mhlo" "${ROOT_DIR?}/tests/e2e/models/fullyconnected.mlir"
+compile_sample "simple_abs" "none" "${ROOT_DIR?}/samples/models/simple_abs.mlir"
+compile_sample "multiple_results" "none" "${SOURCE_DIR?}/multiple_results.mlir"
+compile_sample "fullyconnected" "stablehlo" "${ROOT_DIR?}/tests/e2e/models/fullyconnected.mlir"
# Does not run yet (uses internal readback, which needs async buffer mapping?)
-# compile_sample "collatz" "${ROOT_DIR?}/tests/e2e/models/collatz.mlir"
+# compile_sample "collatz" "stablehlo" "${ROOT_DIR?}/tests/e2e/models/collatz.mlir"
# Slow, so just run on demand
-# compile_sample "mobilebert" "tosa" "D:/dev/projects/iree-data/models/2022_10_28/mobilebertsquad.tflite.mlir"
-# compile_sample "posenet" "tosa" "D:/dev/projects/iree-data/models/2022_10_28/posenet.tflite.mlir"
-# compile_sample "mobilessd" "tosa" "D:/dev/projects/iree-data/models/2022_10_28/mobile_ssd_v2_float_coco.tflite.mlir"
+# TODO(scotttodd): iree-import-tflite (see generate_web_metrics.sh script)
+# compile_sample "mobilebert" "tosa" "MobileBertSquad_fp32.mlir"
+# compile_sample "posenet" "tosa" "Posenet_fp32.mlir"
+# compile_sample "mobilessd" "tosa" "MobileSSD_fp32.mlir"
###############################################################################
# Build the web artifacts using Emscripten #
diff --git a/experimental/web/sample_webgpu/index.html b/experimental/web/sample_webgpu/index.html
index 949d98b..62c4816 100644
--- a/experimental/web/sample_webgpu/index.html
+++ b/experimental/web/sample_webgpu/index.html
@@ -52,8 +52,8 @@
<p>
This tool works similarly to
- <a href="https://github.com/iree-org/iree/blob/main/tools/iree-run-module-main.cc"><code>iree-run-module</code></a>
- (<a href="https://github.com/iree-org/iree/blob/main/docs/developers/developing_iree/developer_overview.md#iree-run-module">docs</a>).
+ <a href="https://github.com/openxla/iree/blob/main/tools/iree-run-module-main.cc"><code>iree-run-module</code></a>
+ (<a href="https://github.com/openxla/iree/blob/main/docs/developers/developing_iree/developer_overview.md#iree-run-module">docs</a>).
<br>It loads a compiled IREE program then lets you call exported functions.
<br><b>Note:</b> Some outputs are logged to the console.</p>
</p>
@@ -84,13 +84,6 @@
style="min-width:400px; width:initial; min-height:100px; resize:both; font-family: monospace;"></textarea>
</p>
- <p>
- <label for="benchmark-iterations-input" class="form-label">
- Benchmark iterations (inner invoke call):</label>
- <input type="number" id="benchmark-iterations-input" class="form-control"
- style="width:400px; font-family: monospace;" value="1" min="1"></input>
- </p>
-
<button id="call-function" class="btn btn-primary" type="button"
onclick="callFunctionWithFormInputs()" disabled>Call function</button>
<button id="update-url" class="btn btn-secondary" type="button"
@@ -107,8 +100,10 @@
<p>Total time (including overheads):
<code id="benchmark-time-js-output" style="font-family: monospace;"></code></p>
- <p>Mean inference time (Wasm only):
- <code id="benchmark-time-wasm-output" style="font-family: monospace;"></code></p>
+ <p>Invoke time:
+ <code id="benchmark-time-invoke-output" style="font-family: monospace;"></code></p>
+ <p>Readback time:
+ <code id="benchmark-time-readback-output" style="font-family: monospace;"></code></p>
<hr>
<h2>Samples</h2>
@@ -122,7 +117,7 @@
<div class="row" style="padding:4px">
<div class="col-sm">
simple_abs
- (<a href="https://github.com/iree-org/iree/blob/main/iree/samples/models/simple_abs.mlir">source</a>)
+ (<a href="https://github.com/openxla/iree/blob/main/samples/models/simple_abs.mlir">source</a>)
</div>
<div class="col-sm-auto">
<button class="btn btn-secondary" onclick="loadSample('simple_abs')">Load sample</button>
@@ -130,8 +125,17 @@
</div>
<div class="row" style="padding:4px">
<div class="col-sm">
+ multiple_results
+ (<a href="https://github.com/openxla/iree/blob/webgpu/experimental/web/sample_webgpu/multiple_results.mlir">source</a>)
+ </div>
+ <div class="col-sm-auto">
+ <button class="btn btn-secondary" onclick="loadSample('multiple_results')">Load sample</button>
+ </div>
+ </div>
+ <div class="row" style="padding:4px">
+ <div class="col-sm">
fullyconnected
- (<a href="https://github.com/iree-org/iree/blob/main/tests/e2e/models/fullyconnected.mlir">source</a>)
+ (<a href="https://github.com/openxla/iree/blob/main/tests/e2e/models/fullyconnected.mlir">source</a>)
</div>
<div class="col-sm-auto">
<button class="btn btn-secondary" onclick="loadSample('fullyconnected')">Load sample</button>
@@ -196,10 +200,10 @@
const callFunctionButton = document.getElementById("call-function");
const functionNameInput = document.getElementById("function-name-input");
const functionArgumentsInput = document.getElementById("function-arguments-input");
- const benchmarkIterationsInput = document.getElementById("benchmark-iterations-input");
const functionOutputsElement = document.getElementById("function-outputs");
const timeJsOutputElement = document.getElementById("benchmark-time-js-output");
- const timeWasmOutputElement = document.getElementById("benchmark-time-wasm-output");
+ const timeInvokeOutputElement = document.getElementById("benchmark-time-invoke-output");
+ const timeReadbackOutputElement = document.getElementById("benchmark-time-readback-output");
async function finishLoadingProgram(newProgram, newProgramName) {
if (loadedProgram !== null) {
@@ -227,10 +231,6 @@
functionArgumentsInput.value = searchParams.get("arguments");
}
- if (searchParams.has("iterations")) {
- benchmarkIterationsInput.value = searchParams.get("iterations");
- }
-
if (searchParams.has("program")) {
const programPath = searchParams.get("program");
@@ -295,10 +295,9 @@
const functionName = functionNameInput.value;
const inputs = functionArgumentsInput.value.split("\n");
- const iterations = benchmarkIterationsInput.value;
const startJsTime = performance.now();
- ireeCallFunction(loadedProgram, functionName, inputs, iterations)
+ ireeCallFunction(loadedProgram, functionName, inputs)
.then((resultObject) => {
functionOutputsElement.value =
resultObject['outputs'].replace(";", "\n");
@@ -307,10 +306,11 @@
const totalJsTime = endJsTime - startJsTime;
timeJsOutputElement.textContent = totalJsTime.toFixed(3) + "ms";
- const totalWasmTimeMs = resultObject['total_invoke_time_ms'];
- const meanWasmTimeMs = totalWasmTimeMs / iterations;
- timeWasmOutputElement.textContent = meanWasmTimeMs.toFixed(3) +
- "ms / iteration over " + iterations + " iteration(s)";
+ const invokeTimeMs = resultObject['invoke_time_ms'];
+ timeInvokeOutputElement.textContent = invokeTimeMs + "ms";
+
+ const readbackTimeMs = resultObject['readback_time_ms'];
+ timeReadbackOutputElement.textContent = readbackTimeMs + "ms";
})
.catch((error) => {
console.error("Function call error: '" + error + "'");
@@ -329,7 +329,6 @@
const searchParams = new URLSearchParams(window.location.search);
searchParams.set("function", functionNameInput.value);
searchParams.set("arguments", functionArgumentsInput.value);
- searchParams.set("iterations", benchmarkIterationsInput.value);
replaceUrlWithSearchParams(searchParams);
}
@@ -353,6 +352,12 @@
if (sampleName === "simple_abs") {
functionNameInput.value = "abs";
functionArgumentsInput.value = "f32=-1.23";
+ } else if (sampleName === "multiple_results") {
+ functionNameInput.value = "multiple_results";
+ functionArgumentsInput.value = [
+ "f32=-1.23",
+ "f32=-4.56",
+ ].join("\n");
} else if (sampleName === "fullyconnected") {
functionNameInput.value = "main";
functionArgumentsInput.value = [
diff --git a/experimental/web/sample_webgpu/iree_api_webgpu.js b/experimental/web/sample_webgpu/iree_api_webgpu.js
index e97afef..9797286 100644
--- a/experimental/web/sample_webgpu/iree_api_webgpu.js
+++ b/experimental/web/sample_webgpu/iree_api_webgpu.js
@@ -45,7 +45,8 @@
//
// Resolves with a parsed JSON object on success:
// {
-// "total_invoke_time_ms": [number],
+// "invoke_time_ms": [number],
+// "readback_time_ms": [number],
// "outputs": [semicolon delimited list of formatted outputs]
// }
async function ireeCallFunction(
@@ -191,9 +192,7 @@
return Promise.resolve();
}
-function _ireeCallFunction(programState, functionName, inputs, iterations) {
- iterations = iterations !== undefined ? iterations : 1;
-
+function _ireeCallFunction(programState, functionName, inputs) {
let inputsJoined;
if (Array.isArray(inputs)) {
inputsJoined = inputs.join(';');
@@ -204,16 +203,22 @@
'Expected \'inputs\' to be a String or an array of Strings');
}
- // Receive as a pointer, convert, then free. This avoids a memory leak, see
- // https://github.com/emscripten-core/emscripten/issues/6484
- const returnValuePtr =
- wasmCallFunctionFn(programState, functionName, inputsJoined, iterations);
- const returnValue = Module.UTF8ToString(returnValuePtr);
+ return new Promise((resolve, reject) => {
+ const completionCallbackFunction = addFunction((resultPtr) => {
+ if (resultPtr === 0) {
+ reject('Error from callback when calling function');
+ return;
+ }
- if (returnValue === '') {
- return Promise.reject('Wasm module error calling function');
- } else {
- Module._free(returnValuePtr);
- return Promise.resolve(JSON.parse(returnValue));
- }
+ const resultStr = Module.UTF8ToString(resultPtr);
+ Module._free(resultPtr);
+ resolve(JSON.parse(resultStr));
+ }, 'vi');
+
+ const immediateResult = wasmCallFunctionFn(
+ programState, functionName, inputsJoined, completionCallbackFunction);
+ if (!immediateResult) {
+ reject('Immediate error calling function');
+ }
+ });
}
diff --git a/experimental/web/sample_webgpu/main.c b/experimental/web/sample_webgpu/main.c
index 0e2ac66..d585c9e 100644
--- a/experimental/web/sample_webgpu/main.c
+++ b/experimental/web/sample_webgpu/main.c
@@ -15,6 +15,8 @@
#include "experimental/webgpu/buffer.h"
#include "experimental/webgpu/webgpu_device.h"
#include "iree/base/api.h"
+#include "iree/base/internal/wait_handle.h"
+#include "iree/base/loop_emscripten.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/runtime/api.h"
@@ -48,21 +50,32 @@
// Unloads a program and frees its state.
void unload_program(iree_program_state_t* program_state);
-// Calls a function synchronously.
+// Callback function passed in from JS.
+// On call failure, is called with the empty string.
+// On call success, is called with a JSON object:
+// {
+// "invoke_time_ms": [number],
+// "readback_time_ms": [number],
+// "outputs": [semicolon delimited list of formatted outputs]
+// }
+// TODO(scotttodd): on error, call with some other JSON / iree_status_fprint
+// TODO(scotttodd): on success, return structured data instead of formatted text
+typedef void(IREE_API_PTR* iree_call_function_callback_fn_t)(char* output);
+
+// Calls a function asynchronously.
//
-// Returns a semicolon-delimited list of formatted outputs on success or the
-// empty string on failure. Note: This is in need of some real API bindings
-// that marshal structured data between C <-> JS.
+// Returns true if call setup was successful, or false if setup failed.
//
-// * |function_name| is the fully qualified function name, like 'module.abs'.
+// * |function_name| is the function name, like 'abs' (not fully qualified).
// * |inputs| is a semicolon delimited list of VM scalars and buffers, as
// described in iree/tooling/vm_util and used in IREE's CLI tools.
// For example, the CLI `--function_input=f32=1 --function_input=f32=2`
// should be passed here as `f32=1;f32=2`.
-// * |iterations| is the number of times to call the function, for benchmarking
-const char* call_function(iree_program_state_t* program_state,
- const char* function_name, const char* inputs,
- int iterations);
+// * |completion_callback_fn| will be called after the call completes.
+const bool call_function(
+ iree_program_state_t* program_state, const char* function_name,
+ const char* inputs,
+ iree_call_function_callback_fn_t completion_callback_fn);
//===----------------------------------------------------------------------===//
// Implementation
@@ -71,21 +84,121 @@
typedef struct iree_sample_state_t {
iree_runtime_instance_t* instance;
iree_hal_device_t* device;
+ iree_loop_emscripten_t* loop;
} iree_sample_state_t;
typedef struct iree_program_state_t {
+ iree_sample_state_t* sample_state;
iree_runtime_session_t* session;
iree_vm_module_t* module;
} iree_program_state_t;
+// Function calls run asynchronously, both for the computations themselves
+// and for mapping the outputs back from GPU/device memory to CPU/host memory.
+//
+// Here is the high level flow:
+//
+// call_function
+// parse_inputs_into_call
+// iree_vm_async_invoke
+// [async] invoke_callback
+// process_call_outputs
+// <batch transfer from output buffers to mappable device buffers>
+// map_call_output[0...n]
+// wgpuBufferMapAsync
+// [async] buffer_map_async_callback
+// <set event>
+// [async] iree_loop_wait_all
+// map_all_callback
+// print_outputs_from_call
+// issue top level callback
+//
+// State tracking uses the iree_call_function_state_t and iree_output_state_t
+// structs defined below, keeping both live until the final map_all_callback.
+// If any of the asynchronous calls fail, all other calls must be treated as
+// failed and all state must be freed.
+
+// State for a single output within an asynchronous function call.
+//
+// If the output is a buffer, optional fields are used to track asynchronous
+// mapping from device memory to host memory.
+typedef struct iree_output_state_t {
+ // Event that will be signaled when the output is ready to access on the host.
+ // * For non-buffer outputs, this will start signaled.
+ // * For buffer outputs, this will be signaled when _either_
+ // |mapped_host_buffer| points to the output data _or_ mapping failed.
+ iree_event_t ready_event;
+
+ // The original buffer_view from the call output, if the output is a buffer.
+ // Not guaranteed to reference mappable memory.
+ iree_hal_buffer_view_t* buffer_view;
+
+ // A mappable device buffer (backed by a WGPUBuffer).
+ // The original buffer will be copied into this buffer by a transfer command.
+ iree_hal_buffer_t* mappable_device_buffer;
+
+ // A mapped host buffer (backed by a iree_hal_heap_buffer_t).
+ // After asynchronous mapping of |mappable_device_buffer|, the output data
+ // will be wrapped into this buffer.
+ // Note: to recover the original shape, use iree_hal_buffer_view_create_like.
+ iree_hal_buffer_t* mapped_host_buffer;
+} iree_output_state_t;
+
+// Aggregate state for an asynchronous function call.
+typedef struct iree_call_function_state_t {
+ iree_runtime_call_t call;
+ iree_loop_emscripten_t* loop;
+ iree_call_function_callback_fn_t callback_fn;
+
+ // Opaque state used by iree_vm_async_invoke.
+ iree_vm_async_invoke_state_t* invoke_state;
+
+ // Timing/statistics metadata (~millisecond precision on the web).
+ // https://developer.mozilla.org/en-US/docs/Web/API/Performance/now#reduced_time_precision
+ iree_time_t invoke_start_time;
+ iree_time_t invoke_end_time;
+ iree_time_t readback_start_time;
+ iree_time_t readback_end_time;
+
+ // Sticky status for the first async error.
+ // If this is not ok, treat all output buffer mappings as having errored and
+ // issue the callback with a failure message.
+ iree_status_t async_status;
+
+ // Output processing state.
+ iree_host_size_t outputs_size;
+ iree_output_state_t* output_states;
+} iree_call_function_state_t;
+
+static void iree_call_function_state_destroy(
+ iree_call_function_state_t* call_state) {
+ // Output processing state.
+ for (iree_host_size_t i = 0; i < call_state->outputs_size; ++i) {
+ if (call_state->output_states[i].buffer_view) {
+ iree_hal_buffer_release(call_state->output_states[i].mapped_host_buffer);
+ iree_hal_buffer_release(
+ call_state->output_states[i].mappable_device_buffer);
+ iree_hal_buffer_view_release(call_state->output_states[i].buffer_view);
+ }
+ iree_event_deinitialize(&call_state->output_states[i].ready_event);
+ }
+ iree_allocator_free(iree_allocator_system(), call_state->output_states);
+ iree_status_free(call_state->async_status);
+
+ // Invoke state.
+ iree_allocator_free(iree_allocator_system(), call_state->invoke_state);
+ iree_runtime_call_deinitialize(&call_state->call);
+
+ iree_allocator_free(iree_allocator_system(), call_state);
+}
+
extern iree_status_t create_device(iree_allocator_t host_allocator,
iree_hal_device_t** out_device);
iree_sample_state_t* setup_sample() {
iree_sample_state_t* sample_state = NULL;
- iree_status_t status =
- iree_allocator_malloc(iree_allocator_system(),
- sizeof(iree_sample_state_t), (void**)&sample_state);
+ iree_status_t status = iree_allocator_malloc(
+ iree_allocator_system(), sizeof(*sample_state), (void**)&sample_state);
iree_runtime_instance_options_t instance_options;
iree_runtime_instance_options_initialize(&instance_options);
@@ -100,6 +213,11 @@
status = create_device(iree_allocator_system(), &sample_state->device);
}
+ if (iree_status_is_ok(status)) {
+ status = iree_loop_emscripten_allocate(iree_allocator_system(),
+ &sample_state->loop);
+ }
+
if (!iree_status_is_ok(status)) {
iree_status_fprint(stderr, status);
iree_status_free(status);
@@ -111,6 +229,7 @@
}
void cleanup_sample(iree_sample_state_t* sample_state) {
+ iree_loop_emscripten_free(sample_state->loop);
iree_hal_device_release(sample_state->device);
iree_runtime_instance_release(sample_state->instance);
free(sample_state);
@@ -119,9 +238,9 @@
iree_program_state_t* load_program(iree_sample_state_t* sample_state,
uint8_t* vmfb_data, size_t length) {
iree_program_state_t* program_state = NULL;
- iree_status_t status = iree_allocator_malloc(iree_allocator_system(),
- sizeof(iree_program_state_t),
- (void**)&program_state);
+ iree_status_t status = iree_allocator_malloc(
+ iree_allocator_system(), sizeof(*program_state), (void**)&program_state);
+ program_state->sample_state = sample_state;
iree_runtime_session_options_t session_options;
iree_runtime_session_options_initialize(&session_options);
@@ -207,6 +326,10 @@
free(program_state);
}
+//===----------------------------------------------------------------------===//
+// Input parsing
+//===----------------------------------------------------------------------===//
+
static iree_status_t parse_input_into_call(
iree_runtime_call_t* call, iree_hal_allocator_t* device_allocator,
iree_string_view_t input) {
@@ -284,9 +407,13 @@
return iree_ok_status();
}
+//===----------------------------------------------------------------------===//
+// Output readback and formatting
+//===----------------------------------------------------------------------===//
+
typedef struct iree_buffer_map_userdata_t {
- iree_hal_buffer_view_t* source_buffer_view;
- iree_hal_buffer_t* readback_buffer;
+ iree_call_function_state_t* call_state;
+ iree_host_size_t buffer_index;
} iree_buffer_map_userdata_t;
static void iree_webgpu_mapped_buffer_release(void* user_data,
@@ -295,43 +422,52 @@
wgpuBufferUnmap(buffer_handle);
}
-// TODO(scotttodd): move async mapping into webgpu/buffer.h/.c?
-static void buffer_map_sync_callback(WGPUBufferMapAsyncStatus map_status,
- void* userdata_ptr) {
+static void buffer_map_async_callback(WGPUBufferMapAsyncStatus map_status,
+ void* userdata_ptr) {
iree_buffer_map_userdata_t* userdata =
(iree_buffer_map_userdata_t*)userdata_ptr;
+ iree_host_size_t buffer_index = userdata->buffer_index;
+ iree_hal_buffer_view_t* output_buffer_view =
+ userdata->call_state->output_states[buffer_index].buffer_view;
+ iree_hal_buffer_t* mappable_device_buffer =
+ userdata->call_state->output_states[buffer_index].mappable_device_buffer;
+ iree_hal_buffer_t** mapped_host_buffer_ptr =
+ &userdata->call_state->output_states[buffer_index].mapped_host_buffer;
+
switch (map_status) {
case WGPUBufferMapAsyncStatus_Success:
break;
case WGPUBufferMapAsyncStatus_Error:
- fprintf(stderr, " buffer_map_sync_callback status: Error\n");
+ fprintf(stderr, " buffer_map_async_callback status: Error\n");
break;
case WGPUBufferMapAsyncStatus_DeviceLost:
- fprintf(stderr, " buffer_map_sync_callback status: DeviceLost\n");
+ fprintf(stderr, " buffer_map_async_callback status: DeviceLost\n");
break;
case WGPUBufferMapAsyncStatus_Unknown:
default:
- fprintf(stderr, " buffer_map_sync_callback status: Unknown\n");
+ fprintf(stderr, " buffer_map_async_callback status: Unknown\n");
break;
}
if (map_status != WGPUBufferMapAsyncStatus_Success) {
- iree_hal_buffer_view_release(userdata->source_buffer_view);
- iree_hal_buffer_release(userdata->readback_buffer);
+ // Set the sticky async error if not already set.
+ userdata->call_state->async_status = iree_status_join(
+ userdata->call_state->async_status,
+ iree_make_status(IREE_STATUS_UNKNOWN,
+ "wgpuBufferMapAsync failed for buffer %" PRIhsz,
+ buffer_index));
+ iree_event_set(
+ &userdata->call_state->output_states[buffer_index].ready_event);
iree_allocator_free(iree_allocator_system(), userdata);
return;
}
- iree_status_t status = iree_ok_status();
-
- // TODO(scotttodd): bubble result(s) up to the caller (async + callback API)
-
iree_device_size_t data_offset = iree_hal_buffer_byte_offset(
- iree_hal_buffer_view_buffer(userdata->source_buffer_view));
+ iree_hal_buffer_view_buffer(output_buffer_view));
iree_device_size_t data_length =
- iree_hal_buffer_view_byte_length(userdata->source_buffer_view);
+ iree_hal_buffer_view_byte_length(output_buffer_view);
WGPUBuffer buffer_handle =
- iree_hal_webgpu_buffer_handle(userdata->readback_buffer);
+ iree_hal_webgpu_buffer_handle(mappable_device_buffer);
// For this sample we want to print arbitrary buffers, which is easiest
// using the |iree_hal_buffer_view_format| function. Internally, that
@@ -343,13 +479,13 @@
const void* data_ptr =
wgpuBufferGetConstMappedRange(buffer_handle, data_offset, data_length);
- iree_hal_buffer_t* heap_buffer = NULL;
+ iree_status_t status = iree_ok_status();
if (iree_status_is_ok(status)) {
// The buffer we get from WebGPU may not be aligned to 64.
iree_hal_memory_access_t memory_access =
IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_UNALIGNED;
status = iree_hal_heap_buffer_wrap(
- userdata->readback_buffer->device_allocator,
+ mappable_device_buffer->device_allocator,
IREE_HAL_MEMORY_TYPE_HOST_LOCAL, memory_access,
IREE_HAL_BUFFER_USAGE_MAPPING, data_length,
iree_make_byte_span((void*)data_ptr, data_length),
@@ -357,153 +493,22 @@
.fn = iree_webgpu_mapped_buffer_release,
.user_data = buffer_handle,
},
- &heap_buffer);
+ mapped_host_buffer_ptr);
}
- // Copy the original buffer_view, backed by the mapped heap buffer instead.
- iree_hal_buffer_view_t* heap_buffer_view = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_hal_buffer_view_create_like(
- heap_buffer, userdata->source_buffer_view, iree_allocator_system(),
- &heap_buffer_view);
- }
+ // Set the sticky async error if not already set.
+ userdata->call_state->async_status =
+ iree_status_join(userdata->call_state->async_status, status);
- if (iree_status_is_ok(status)) {
- fprintf(stdout, "Call output:\n");
- status = iree_hal_buffer_view_fprint(stdout, heap_buffer_view,
- /*max_element_count=*/4096,
- iree_allocator_system());
- fprintf(stdout, "\n");
- }
- iree_hal_buffer_view_release(heap_buffer_view);
- iree_hal_buffer_release(heap_buffer);
-
- if (!iree_status_is_ok(status)) {
- fprintf(stderr, "buffer_map_sync_callback error:\n");
- iree_status_fprint(stderr, status);
- iree_status_free(status);
- }
-
- iree_hal_buffer_view_release(userdata->source_buffer_view);
- iree_hal_buffer_release(userdata->readback_buffer);
+ iree_event_set(
+ &userdata->call_state->output_states[buffer_index].ready_event);
iree_allocator_free(iree_allocator_system(), userdata);
}
-static iree_status_t print_buffer_view(iree_hal_device_t* device,
- iree_hal_buffer_view_t* buffer_view) {
- iree_status_t status = iree_ok_status();
-
- iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view);
- iree_device_size_t data_offset = iree_hal_buffer_byte_offset(buffer);
- iree_device_size_t data_length =
- iree_hal_buffer_view_byte_length(buffer_view);
-
- // ----------------------------------------------
- // Allocate mappable host memory.
- // Note: iree_hal_webgpu_simple_allocator_allocate_buffer only supports
- // CopySrc today, so we'll create the buffer directly with
- // wgpuDeviceCreateBuffer and then wrap it using iree_hal_webgpu_buffer_wrap.
- WGPUBufferDescriptor descriptor = {
- .nextInChain = NULL,
- .label = "IREE_readback",
- .usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst,
- .size = data_length,
- .mappedAtCreation = false,
- };
- WGPUBuffer readback_buffer_handle = NULL;
- if (iree_status_is_ok(status)) {
- readback_buffer_handle = wgpuDeviceCreateBuffer(
- iree_hal_webgpu_device_handle(device), &descriptor);
- if (!readback_buffer_handle) {
- status = iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
- "unable to allocate buffer of size %" PRIdsz,
- data_length);
- }
- }
- iree_device_size_t target_offset = 0;
- const iree_hal_buffer_params_t target_params = {
- .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
- .type =
- IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
- .access = IREE_HAL_MEMORY_ACCESS_ALL,
- };
- iree_hal_buffer_t* readback_buffer = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_hal_webgpu_buffer_wrap(
- device, iree_hal_device_allocator(device), target_params.type,
- target_params.access, target_params.usage, data_length,
- /*byte_offset=*/0,
- /*byte_length=*/data_length, readback_buffer_handle,
- iree_allocator_system(), &readback_buffer);
- }
- // ----------------------------------------------
-
- // ----------------------------------------------
- // Transfer from device memory to mappable host memory.
- const iree_hal_transfer_command_t transfer_command = {
- .type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY,
- .copy =
- {
- .source_buffer = buffer,
- .source_offset = data_offset,
- .target_buffer = readback_buffer,
- .target_offset = target_offset,
- .length = data_length,
- },
- };
- iree_hal_command_buffer_t* command_buffer = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_hal_create_transfer_command_buffer(
- device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
- IREE_HAL_QUEUE_AFFINITY_ANY, /*transfer_count=*/1, &transfer_command,
- &command_buffer);
- }
- iree_hal_semaphore_t* fence_semaphore = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_hal_semaphore_create(device, 0ull, &fence_semaphore);
- }
- uint64_t signal_value = 1ull;
- if (iree_status_is_ok(status)) {
- iree_hal_semaphore_list_t signal_semaphores = {
- .count = 1,
- .semaphores = &fence_semaphore,
- .payload_values = &signal_value,
- };
- status = iree_hal_device_queue_execute(
- device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
- signal_semaphores, 1, &command_buffer);
- }
- // TODO(scotttodd): Make this async - pass a wait source to iree_loop_wait_one
- if (iree_status_is_ok(status)) {
- status = iree_hal_semaphore_wait(fence_semaphore, signal_value,
- iree_infinite_timeout());
- }
- iree_hal_command_buffer_release(command_buffer);
- iree_hal_semaphore_release(fence_semaphore);
- // ----------------------------------------------
-
- iree_buffer_map_userdata_t* userdata = NULL;
- if (iree_status_is_ok(status)) {
- status = iree_allocator_malloc(iree_allocator_system(),
- sizeof(iree_buffer_map_userdata_t),
- (void**)&userdata);
- iree_hal_buffer_view_retain(buffer_view); // Released in the callback.
- userdata->source_buffer_view = buffer_view;
- userdata->readback_buffer = readback_buffer;
- }
-
- if (iree_status_is_ok(status)) {
- wgpuBufferMapAsync(readback_buffer_handle, WGPUMapMode_Read, /*offset=*/0,
- /*size=*/data_length, buffer_map_sync_callback,
- /*userdata=*/userdata);
- }
-
- return status;
-}
-
static iree_status_t print_outputs_from_call(
- iree_runtime_call_t* call, iree_string_builder_t* outputs_builder) {
- iree_vm_list_t* variants_list = iree_runtime_call_outputs(call);
+ iree_call_function_state_t* call_state,
+ iree_string_builder_t* outputs_builder) {
+ iree_vm_list_t* variants_list = iree_runtime_call_outputs(&call_state->call);
for (iree_host_size_t i = 0; i < iree_vm_list_size(variants_list); ++i) {
iree_vm_variant_t variant = iree_vm_variant_empty();
IREE_RETURN_IF_ERROR(
@@ -549,16 +554,15 @@
}
}
} else if (iree_vm_variant_is_ref(variant)) {
- if (iree_hal_buffer_view_isa(variant.ref)) {
- iree_hal_buffer_view_t* buffer_view =
- iree_hal_buffer_view_deref(variant.ref);
- // TODO(scotttodd): join async outputs together and return to caller
- iree_hal_device_t* device = iree_runtime_session_device(call->session);
- IREE_RETURN_IF_ERROR(print_buffer_view(device, buffer_view));
- } else {
- IREE_RETURN_IF_ERROR(iree_string_builder_append_cstring(
- outputs_builder, "(no printer)"));
- }
+ // Interpret the mapped buffer in the same format as the output view.
+ iree_hal_buffer_view_t* heap_buffer_view = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create_like(
+ call_state->output_states[i].mapped_host_buffer,
+ call_state->output_states[i].buffer_view, iree_allocator_system(),
+ &heap_buffer_view));
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_append_to_builder(
+ heap_buffer_view, SIZE_MAX, outputs_builder));
+ iree_hal_buffer_view_release(heap_buffer_view);
} else {
IREE_RETURN_IF_ERROR(
iree_string_builder_append_cstring(outputs_builder, "(null)"));
@@ -575,28 +579,310 @@
return iree_ok_status();
}
+static iree_status_t map_all_callback(void* user_data, iree_loop_t loop,
+ iree_status_t status) {
+ iree_call_function_state_t* call_state =
+ (iree_call_function_state_t*)user_data;
+ call_state->readback_end_time = iree_time_now();
+
+ status = iree_status_join(call_state->async_status, status);
+
+ iree_string_builder_t output_string_builder;
+ iree_string_builder_initialize(iree_allocator_system(),
+ &output_string_builder);
+
+ // Output a JSON object as a string:
+ // {
+ // "invoke_time_ms": [number],
+ // "readback_time_ms": [number],
+ // "outputs": [semicolon delimited list of formatted outputs]
+ // }
+ if (iree_status_is_ok(status)) {
+ iree_time_t invoke_time_ms =
+ (call_state->invoke_end_time - call_state->invoke_start_time) / 1000000;
+ iree_time_t readback_time_ms =
+ (call_state->readback_end_time - call_state->readback_start_time) /
+ 1000000;
+ status = iree_string_builder_append_format(
+ &output_string_builder,
+ "{ \"invoke_time_ms\": %" PRId64 ", \"readback_time_ms\": %" PRId64
+ ", \"outputs\": \"",
+ invoke_time_ms, readback_time_ms);
+ }
+ if (iree_status_is_ok(status)) {
+ status = print_outputs_from_call(call_state, &output_string_builder);
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_string_builder_append_cstring(&output_string_builder, "\"}");
+ }
+
+ if (iree_status_is_ok(status)) {
+ // Note: this leaks the buffer. It's up to the caller to free it after use.
+ char* outputs_string =
+ strdup(iree_string_builder_buffer(&output_string_builder));
+ iree_string_builder_deinitialize(&output_string_builder);
+ call_state->callback_fn(outputs_string);
+ } else {
+ fprintf(stderr, "map_all_callback error:\n");
+ // TODO(scotttodd): return a JSON object with the error message
+ // * free |status| and return a status to the loop with no storage
+ // * the returned string is always freed, so then we'd have no leaks
+ iree_status_fprint(stderr, status);
+ // Note: loop_emscripten.js must free 'status'!
+ call_state->callback_fn(NULL);
+ }
+
+ iree_string_builder_deinitialize(&output_string_builder);
+ iree_call_function_state_destroy(call_state);
+ return status;
+}
+
+static iree_status_t allocate_mappable_device_buffer(
+ iree_hal_device_t* device, iree_hal_buffer_view_t* buffer_view,
+ iree_hal_buffer_t** out_buffer) {
+ *out_buffer = NULL;
+
+ iree_device_size_t data_length =
+ iree_hal_buffer_view_byte_length(buffer_view);
+
+ // Note: iree_hal_webgpu_simple_allocator_allocate_buffer only supports
+ // CopySrc today, so we'll create the buffer directly with
+ // wgpuDeviceCreateBuffer and then wrap it using iree_hal_webgpu_buffer_wrap.
+ WGPUBufferDescriptor descriptor = {
+ .nextInChain = NULL,
+ .label = "IREE_mapping",
+ .usage = WGPUBufferUsage_MapRead | WGPUBufferUsage_CopyDst,
+ .size = data_length,
+ .mappedAtCreation = false,
+ };
+ WGPUBuffer device_buffer_handle = NULL;
+ // Note: wgpuBufferDestroy is called after iree_hal_webgpu_buffer_wrap ->
+ // iree_hal_buffer_release -> iree_hal_webgpu_buffer_destroy
+ device_buffer_handle = wgpuDeviceCreateBuffer(
+ iree_hal_webgpu_device_handle(device), &descriptor);
+ if (!device_buffer_handle) {
+ return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED,
+ "unable to allocate buffer of size %" PRIdsz,
+ data_length);
+ }
+ const iree_hal_buffer_params_t target_params = {
+ .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING,
+ .type =
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+ .access = IREE_HAL_MEMORY_ACCESS_ALL,
+ };
+ return iree_hal_webgpu_buffer_wrap(
+ device, iree_hal_device_allocator(device), target_params.type,
+ target_params.access, target_params.usage, data_length,
+ /*byte_offset=*/0,
+ /*byte_length=*/data_length, device_buffer_handle,
+ iree_allocator_system(), out_buffer);
+}
+
+// Processes outputs from a completed function invocation.
+// Some output data types may require asynchronous mapping.
+static iree_status_t process_call_outputs(
+ iree_call_function_state_t* call_state) {
+ call_state->readback_start_time = iree_time_now();
+
+ iree_vm_list_t* outputs_list = iree_runtime_call_outputs(&call_state->call);
+ iree_host_size_t outputs_size = iree_vm_list_size(outputs_list);
+ iree_hal_device_t* device =
+ iree_runtime_session_device(call_state->call.session);
+
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+ iree_allocator_system(), sizeof(iree_output_state_t) * outputs_size,
+ (void**)&call_state->output_states));
+ call_state->outputs_size = outputs_size;
+
+ // TODO(scotttodd): allocate on the heap and track?
+ // * iree_loop_wait_all claims that wait_sources must live until the
+ // callback is issued
+ // * loop_emscripten uses what it needs (Promise handles/objects)
+ // immediately, before objects go out of scope
+ iree_wait_source_t* wait_sources = (iree_wait_source_t*)iree_alloca(
+ sizeof(iree_wait_source_t) * outputs_size);
+
+ // Loop through the outputs once to find buffers that need readback.
+ iree_host_size_t buffer_count = 0;
+ for (iree_host_size_t i = 0; i < outputs_size; ++i) {
+ iree_vm_variant_t variant = iree_vm_variant_empty();
+ IREE_RETURN_IF_ERROR(
+ iree_vm_list_get_variant_assign(outputs_list, i, &variant),
+ "variant %" PRIhsz " not present", i);
+
+ if (iree_vm_variant_is_ref(variant)) {
+ if (!iree_hal_buffer_view_isa(variant.ref)) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "only buffer_view variants are supported");
+ }
+
+ // Output is a buffer_view ref, add to mapping batch (async).
+ iree_event_initialize(false, &call_state->output_states[i].ready_event);
+ buffer_count++;
+ iree_hal_buffer_view_t* buffer_view =
+ iree_hal_buffer_view_deref(variant.ref);
+ call_state->output_states[i].buffer_view = buffer_view;
+ iree_hal_buffer_view_retain(buffer_view);
+
+ // TODO(scotttodd): signal event if failed
+ IREE_RETURN_IF_ERROR(allocate_mappable_device_buffer(
+ device, buffer_view,
+ &call_state->output_states[i].mappable_device_buffer));
+ } else {
+ // Not a buffer, data is available immediately - start signaled.
+ iree_event_initialize(true, &call_state->output_states[i].ready_event);
+ }
+ wait_sources[i] =
+ iree_event_await(&call_state->output_states[i].ready_event);
+ }
+
+ // Loop through the outputs again to build a batched transfer command buffer.
+ iree_hal_transfer_command_t* transfer_commands =
+ (iree_hal_transfer_command_t*)iree_alloca(
+ sizeof(iree_hal_transfer_command_t) * buffer_count);
+ for (iree_host_size_t i = 0, buffer_index = 0; i < outputs_size; ++i) {
+ iree_hal_buffer_view_t* buffer_view =
+ call_state->output_states[i].buffer_view;
+ if (!buffer_view) continue;
+
+ iree_hal_buffer_t* source_buffer = iree_hal_buffer_view_buffer(buffer_view);
+ iree_device_size_t data_offset = iree_hal_buffer_byte_offset(source_buffer);
+ iree_hal_buffer_t* target_buffer =
+ call_state->output_states[i].mappable_device_buffer;
+ iree_device_size_t target_offset = 0;
+ iree_device_size_t data_length =
+ iree_hal_buffer_view_byte_length(buffer_view);
+
+ transfer_commands[buffer_index++] = (iree_hal_transfer_command_t){
+ .type = IREE_HAL_TRANSFER_COMMAND_TYPE_COPY,
+ .copy =
+ {
+ .source_buffer = source_buffer,
+ .source_offset = data_offset,
+ .target_buffer = target_buffer,
+ .target_offset = target_offset,
+ .length = data_length,
+ },
+ };
+ }
+
+ // Construct and issue the transfer command buffer, then wait on it.
+ iree_status_t status = iree_ok_status();
+ iree_hal_command_buffer_t* transfer_command_buffer = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_create_transfer_command_buffer(
+ device, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+ IREE_HAL_QUEUE_AFFINITY_ANY, buffer_count, transfer_commands,
+ &transfer_command_buffer);
+ }
+ iree_hal_semaphore_t* signal_semaphore = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_semaphore_create(device, 0ull, &signal_semaphore);
+ }
+ uint64_t signal_value = 1ull;
+ if (iree_status_is_ok(status)) {
+ iree_hal_semaphore_list_t signal_semaphores = {
+ .count = 1,
+ .semaphores = &signal_semaphore,
+ .payload_values = &signal_value,
+ };
+ status = iree_hal_device_queue_execute(
+ device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
+ signal_semaphores, 1, &transfer_command_buffer);
+ }
+ // TODO(scotttodd): Make this async - pass a wait source to iree_loop_wait_one
+ // 1. create iree_hal_fence_t, iree_hal_fence_insert(fance, semaphore)
+ // 2. iree_hal_fence_await -> iree_wait_source_t
+ // 3. iree_loop_wait_one(loop, wait_source, ...)
+ // (requires moving off of nop_semaphore and wait source import)
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_semaphore_wait(signal_semaphore, signal_value,
+ iree_infinite_timeout());
+ }
+ iree_hal_command_buffer_release(transfer_command_buffer);
+ iree_hal_semaphore_release(signal_semaphore);
+
+ // Loop through one last time to map the buffers asynchronously.
+ for (iree_host_size_t i = 0; i < outputs_size; ++i) {
+ if (!iree_status_is_ok(status)) break;
+ if (!call_state->output_states[i].mappable_device_buffer) continue;
+
+ iree_buffer_map_userdata_t* map_userdata = NULL;
+ IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+ iree_allocator_system(), sizeof(*map_userdata), (void**)&map_userdata));
+
+ map_userdata->call_state = call_state;
+ map_userdata->buffer_index = i;
+ iree_device_size_t data_length = iree_hal_buffer_view_byte_length(
+ call_state->output_states[i].buffer_view);
+
+ WGPUBuffer device_buffer = iree_hal_webgpu_buffer_handle(
+ call_state->output_states[i].mappable_device_buffer);
+ wgpuBufferMapAsync(device_buffer, WGPUMapMode_Read,
+ /*offset=*/0,
+ /*size=*/data_length, buffer_map_async_callback,
+ /*userdata=*/map_userdata);
+ }
+
+ // Finally, wait on all wait sources.
+ //
+ // If there are any buffer outputs that need asynchronous mapping, those
+ // wait sources will be signaled when the mapping completes.
+ //
+ // Note: call_state (and everything within it) is kept alive until the
+ // callback resolves.
+ IREE_RETURN_IF_ERROR(iree_loop_wait_all(
+ iree_loop_emscripten(call_state->loop), outputs_size, wait_sources,
+ iree_make_timeout_ms(5000), map_all_callback,
+ /*user_data=*/call_state));
+
+ return iree_ok_status();
+}
+
+//===----------------------------------------------------------------------===//
+// Function calling / invocations
+//===----------------------------------------------------------------------===//
+
+// Handles the completion callback from `iree_vm_async_invoke()`.
iree_status_t invoke_callback(void* user_data, iree_loop_t loop,
iree_status_t status, iree_vm_list_t* outputs) {
- iree_vm_async_invoke_state_t* invoke_state =
- (iree_vm_async_invoke_state_t*)user_data;
+ iree_call_function_state_t* call_state =
+ (iree_call_function_state_t*)user_data;
+ call_state->invoke_end_time = iree_time_now();
if (!iree_status_is_ok(status)) {
fprintf(stderr, "iree_vm_async_invoke_callback_fn_t error:\n");
iree_status_fprint(stderr, status);
- iree_status_free(status);
+ iree_call_function_state_destroy(call_state);
+ return status; // Note: loop_emscripten.js must free this!
}
- iree_vm_list_release(outputs);
+ status = process_call_outputs(call_state);
+ if (!iree_status_is_ok(status)) {
+ fprintf(stderr, "process_call_outputs error:\n");
+ iree_status_fprint(stderr, status);
+ iree_call_function_state_destroy(call_state);
+ // Note: loop_emscripten.js must free 'status'!
+ }
- iree_allocator_free(iree_allocator_system(), (void*)invoke_state);
- return iree_ok_status();
+ return status;
}
-const char* call_function(iree_program_state_t* program_state,
- const char* function_name, const char* inputs,
- int iterations) {
+const bool call_function(
+ iree_program_state_t* program_state, const char* function_name,
+ const char* inputs,
+ iree_call_function_callback_fn_t completion_callback_fn) {
iree_status_t status = iree_ok_status();
+ iree_call_function_state_t* call_state = NULL;
+ if (iree_status_is_ok(status)) {
+ status = iree_allocator_malloc(iree_allocator_system(), sizeof(*call_state),
+ (void**)&call_state);
+ }
+ call_state->loop = program_state->sample_state->loop;
+ call_state->callback_fn = completion_callback_fn;
+
// Fully qualify the function name. This sample only supports loading one
// module (i.e. 'program') per session, so we can do this.
iree_string_builder_t name_builder;
@@ -607,85 +893,47 @@
(int)module_name.size,
module_name.data, function_name);
}
-
- iree_runtime_call_t call;
if (iree_status_is_ok(status)) {
status = iree_runtime_call_initialize_by_name(
- program_state->session, iree_string_builder_view(&name_builder), &call);
+ program_state->session, iree_string_builder_view(&name_builder),
+ &call_state->call);
}
iree_string_builder_deinitialize(&name_builder);
if (iree_status_is_ok(status)) {
status = parse_inputs_into_call(
- &call, iree_runtime_session_device_allocator(program_state->session),
+ &call_state->call,
+ iree_runtime_session_device_allocator(program_state->session),
iree_make_cstring_view(inputs));
}
- // Note: Timing has ~millisecond precision on the web to mitigate timing /
- // side-channel security threats.
- // https://developer.mozilla.org/en-US/docs/Web/API/Performance/now#reduced_time_precision
- iree_time_t start_time = iree_time_now();
-
- // TODO(scotttodd): benchmark iterations (somehow with async)
-
- iree_vm_async_invoke_state_t* invoke_state = NULL;
if (iree_status_is_ok(status)) {
status = iree_allocator_malloc(iree_allocator_system(),
- sizeof(iree_vm_async_invoke_state_t),
- (void**)&invoke_state);
- }
- // TODO(scotttodd): emscripten / browser loop here
- iree_status_t loop_status = iree_ok_status();
- iree_loop_t loop = iree_loop_inline(&loop_status);
- if (iree_status_is_ok(status)) {
- iree_vm_context_t* vm_context = iree_runtime_session_context(call.session);
- iree_vm_function_t vm_function = call.function;
- iree_vm_list_t* inputs = call.inputs;
- iree_vm_list_t* outputs = call.outputs;
-
- status = iree_vm_async_invoke(loop, invoke_state, vm_context, vm_function,
- IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL,
- inputs, outputs, iree_allocator_system(),
- invoke_callback,
- /*user_data=*/invoke_state);
+ sizeof(*(call_state->invoke_state)),
+ (void**)&(call_state->invoke_state));
}
- // TODO(scotttodd): record end time in async callback instead of here
- // TODO(scotttodd): print outputs in async callback instead of here
-
- iree_time_t end_time = iree_time_now();
- iree_time_t time_elapsed = end_time - start_time;
-
- iree_string_builder_t outputs_builder;
- iree_string_builder_initialize(iree_allocator_system(), &outputs_builder);
-
- // Output a JSON object as a string:
- // {
- // "total_invoke_time_ms": [number],
- // "outputs": [semicolon delimited list of formatted outputs]
- // }
if (iree_status_is_ok(status)) {
- status = iree_string_builder_append_format(
- &outputs_builder,
- "{ \"total_invoke_time_ms\": %" PRId64 ", \"outputs\": \"",
- time_elapsed / 1000000);
- }
- if (iree_status_is_ok(status)) {
- status = print_outputs_from_call(&call, &outputs_builder);
- }
- if (iree_status_is_ok(status)) {
- status = iree_string_builder_append_cstring(&outputs_builder, "\"}");
+ iree_loop_t loop = iree_loop_emscripten(program_state->sample_state->loop);
+ iree_vm_context_t* vm_context =
+ iree_runtime_session_context(call_state->call.session);
+ iree_vm_function_t vm_function = call_state->call.function;
+ iree_vm_list_t* inputs = call_state->call.inputs;
+ iree_vm_list_t* outputs = call_state->call.outputs;
+
+ call_state->invoke_start_time = iree_time_now();
+ status = iree_vm_async_invoke(
+ loop, call_state->invoke_state, vm_context, vm_function,
+ IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL, inputs, outputs,
+ iree_allocator_system(), invoke_callback, /*user_data=*/call_state);
}
if (!iree_status_is_ok(status)) {
- iree_string_builder_deinitialize(&outputs_builder);
iree_status_fprint(stderr, status);
iree_status_free(status);
- return "";
+ iree_call_function_state_destroy(call_state);
+ return false;
}
- // Note: this leaks the buffer. It's up to the caller to free it after use.
- char* outputs_string = strdup(iree_string_builder_buffer(&outputs_builder));
- iree_string_builder_deinitialize(&outputs_builder);
- return outputs_string;
+ return true;
}
diff --git a/experimental/web/sample_webgpu/multiple_results.mlir b/experimental/web/sample_webgpu/multiple_results.mlir
new file mode 100644
index 0000000..ab82f97
--- /dev/null
+++ b/experimental/web/sample_webgpu/multiple_results.mlir
@@ -0,0 +1,8 @@
+func.func @multiple_results(
+ %input_0 : tensor<f32>,
+ %input_1 : tensor<f32>
+) -> (tensor<f32>, tensor<f32>) {
+ %result_0 = math.absf %input_0 : tensor<f32>
+ %result_1 = math.absf %input_1 : tensor<f32>
+ return %result_0, %result_1 : tensor<f32>, tensor<f32>
+}