blob: 930920182a88a6842f0f319b432307f91a84b59d [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <stdio.h>
#include "iree/runtime/api.h"
// Host-side copy of the values exchanged with the compiled module: train()
// uploads w, b, X and y as call inputs and reads w, b and loss back from the
// call outputs. main() initializes the fields positionally, so their order
// must not change.
struct State {
// Weights; printed under "Weights:" by print_state().
float w[3];
// Bias; printed under "Bias:" and uploaded by train() as a rank-0 buffer view.
float b[1];
// Third call input; uploaded by train() with shape {1, 3}.
float X[3];
// Fourth call input; uploaded by train() with shape {1}.
float y[1];
// Loss; read back by train() from the third call output.
float loss[1];
};
// Writes a one-line summary of |state| to stdout: each weight in order,
// followed by the bias and the loss, terminated by a newline.
void print_state(struct State* state) {
  const float* weight = state->w;
  const float* const weights_end = weight + IREE_ARRAYSIZE(state->w);

  fprintf(stdout, "Weights:");
  while (weight != weights_end) {
    fprintf(stdout, " %f", *weight);
    ++weight;
  }
  fprintf(stdout, ", Bias: %f", *state->b);
  fprintf(stdout, ", Loss: %f\n", *state->loss);
}
iree_status_t train(iree_runtime_session_t* session, struct State* state) {
iree_status_t status = iree_ok_status();
// Lookup the entry point function.
iree_runtime_call_t call;
status = iree_runtime_call_initialize_by_name(
session, iree_make_cstring_view("module.forward"), &call);
// Allocate buffers in device-local memory so that if the device has an
// independent address space they live on the fast side of the fence.
iree_hal_dim_t shape_w[1] = {IREE_ARRAYSIZE(state->w)};
iree_hal_dim_t shape_b[0] = {};
iree_hal_dim_t shape_X[2] = {1, IREE_ARRAYSIZE(state->X)};
iree_hal_dim_t shape_y[1] = {IREE_ARRAYSIZE(state->y)};
iree_hal_buffer_view_t* arg0 = NULL;
iree_hal_buffer_view_t* arg1 = NULL;
iree_hal_buffer_view_t* arg2 = NULL;
iree_hal_buffer_view_t* arg3 = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_view_allocate_buffer(
iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_w),
shape_w, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
},
iree_make_const_byte_span(state->w, sizeof(state->w)), &arg0);
}
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_view_allocate_buffer(
iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_b),
shape_b, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
},
iree_make_const_byte_span(state->b, sizeof(state->b)), &arg1);
}
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_view_allocate_buffer(
iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_X),
shape_X, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
},
iree_make_const_byte_span(state->X, sizeof(state->X)), &arg2);
}
if (iree_status_is_ok(status)) {
status = iree_hal_buffer_view_allocate_buffer(
iree_runtime_session_device_allocator(session), IREE_ARRAYSIZE(shape_y),
shape_y, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.usage = IREE_HAL_BUFFER_USAGE_DEFAULT,
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
},
iree_make_const_byte_span(state->y, sizeof(state->y)), &arg3);
}
// Setup call inputs with our buffers.
if (iree_status_is_ok(status)) {
status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg0);
}
iree_hal_buffer_view_release(arg0);
if (iree_status_is_ok(status)) {
status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg1);
}
iree_hal_buffer_view_release(arg1);
if (iree_status_is_ok(status)) {
status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg2);
}
iree_hal_buffer_view_release(arg2);
if (iree_status_is_ok(status)) {
status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg3);
}
iree_hal_buffer_view_release(arg3);
// Invoke the function
IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
// Update weights
iree_hal_buffer_view_t* result = NULL;
if (iree_status_is_ok(status)) {
status = iree_runtime_call_outputs_pop_front_buffer_view(&call, &result);
}
if (iree_status_is_ok(status)) {
status = iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
iree_hal_buffer_view_buffer(result), 0, &state->w, sizeof(state->w),
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
}
// Update bias
if (iree_status_is_ok(status)) {
status = iree_runtime_call_outputs_pop_front_buffer_view(&call, &result);
}
if (iree_status_is_ok(status)) {
status = iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
iree_hal_buffer_view_buffer(result), 0, &state->b, sizeof(state->b),
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
}
// Update loss
if (iree_status_is_ok(status)) {
status = iree_runtime_call_outputs_pop_front_buffer_view(&call, &result);
}
if (iree_status_is_ok(status)) {
status = iree_hal_device_transfer_d2h(iree_runtime_session_device(session),
iree_hal_buffer_view_buffer(result),
0, &state->loss, sizeof(state->loss),
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout());
}
iree_hal_buffer_view_release(result);
return status;
}
// Creates a runtime instance, a default device for |driver_name| and a
// session, loads the bytecode module at |bytecode_module_path| into the
// session, then calls train() up to 10 times on |state|. The state is printed
// once before training and once after every step, including a failing one.
// Returns the first failing status, or OK.
iree_status_t run_sample(iree_string_view_t bytecode_module_path,
                         iree_string_view_t driver_name, struct State* state) {
  iree_runtime_instance_t* instance = NULL;
  iree_runtime_session_t* session = NULL;
  iree_hal_device_t* device = NULL;

  // Instance configuration (this should be shared across sessions).
  iree_runtime_instance_options_t instance_options;
  iree_runtime_instance_options_initialize(&instance_options);
  iree_runtime_instance_options_use_all_available_drivers(&instance_options);

  // Session configuration (one per loaded module to hold module state).
  iree_runtime_session_options_t session_options;
  iree_runtime_session_options_initialize(&session_options);

  fprintf(stdout, "Configuring IREE runtime instance and '%s' device\n",
          driver_name.data);
  iree_status_t status = iree_runtime_instance_create(
      &instance_options, iree_allocator_system(), &instance);
  if (!iree_status_is_ok(status)) {
    goto cleanup;
  }

  // TODO(#5724): move device selection into the compiled modules.
  status = iree_runtime_instance_try_create_default_device(
      instance, driver_name, &device);
  if (iree_status_is_ok(status)) {
    fprintf(stdout, "Creating IREE runtime session\n");
    status = iree_runtime_session_create_with_device(
        instance, &session_options, device,
        iree_runtime_instance_host_allocator(instance), &session);
  }
  // The local device reference is dropped whether or not the session exists.
  iree_hal_device_release(device);
  if (!iree_status_is_ok(status)) {
    goto cleanup;
  }

  fprintf(stdout, "Loading bytecode module at '%s'\n",
          bytecode_module_path.data);
  status = iree_runtime_session_append_bytecode_module_from_file(
      session, bytecode_module_path.data);
  if (!iree_status_is_ok(status)) {
    goto cleanup;
  }

  fprintf(stdout, "Training...\n");
  print_state(state);
  // Stops after the first failing step; that step's state is still printed.
  for (int step = 0; step < 10 && iree_status_is_ok(status); ++step) {
    status = train(session, state);
    print_state(state);
  }

cleanup:
  iree_runtime_session_release(session);
  iree_runtime_instance_release(instance);
  return status;
}
// Entry point: native-training </path/to/native_training.vmfb> [<driver_name>]
//
// Runs the training sample on a fixed initial state using the given driver
// ("local-sync" when none is given). Returns 0 when training succeeds and the
// final loss is at most 0.1, and -1 on bad usage, runtime failure, or a loss
// that is too high or not a number.
int main(int argc, char** argv) {
  // Parse args
  if (argc < 2) {
    fprintf(stderr,
            "Usage: native-training </path/to/native_training.vmfb> "
            "[<driver_name>]\n");
    fprintf(stderr, " (See the README for this sample for details)\n ");
    return -1;
  }
  iree_string_view_t bytecode_module_path = iree_make_cstring_view(argv[1]);
  iree_string_view_t driver_name;
  if (argc >= 3) {
    driver_name = iree_make_cstring_view(argv[2]);
  } else {
    driver_name = iree_make_cstring_view("local-sync");
  }

  // Run training
  struct State state = {
      {4.0f, 4.0f, 5.0f},  // w
      {2.0f},              // b
      {1.0f, 1.0f, 1.0f},  // X
      {14.0f},             // y
      {1.0f},              // loss
  };
  iree_status_t result = run_sample(bytecode_module_path, driver_name, &state);
  if (!iree_status_is_ok(result)) {
    fprintf(stdout, "Failed!\n");
    iree_status_fprint(stderr, result);
    iree_status_ignore(result);
    return -1;
  }

  // Validate result. Written as !(loss <= limit) because every ordered
  // comparison with NaN is false: `loss > limit` would let a NaN loss pass.
  if (!(state.loss[0] <= 0.1f)) {
    fprintf(stdout, "Loss unexpectedly high\n");
    return -1;
  }
  fprintf(stdout, "\nSuccess!\n");
  return 0;
}