// An example based on iree/samples/simple_embedding.

#include "samples/util/util.h"

#include <springbok.h>
#include <stdio.h>

#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "samples/device/device.h"

extern const MlModel kModel;

// Fills the model's raw input buffers via load_input_data() and wraps each one
// in a HAL buffer view shaped by the model's per-input shape and element type.
//
// |arg_buffers| and |arg_buffer_views| must each have room for
// model->num_input entries. Whatever this function populates in them must be
// released by the caller, on both success and failure. On failure, only the
// entries produced before the first error are populated.
//
// Returns the first non-OK status encountered, or iree_ok_status().
static iree_status_t prepare_input_hal_buffer_views(
    const MlModel *model, iree_hal_device_t *device, void **arg_buffers,
    iree_hal_buffer_view_t **arg_buffer_views) {
  // Allocate the raw input buffers and populate their initial values.
  IREE_RETURN_IF_ERROR(load_input_data(model, arg_buffers));

  // The buffers are host-local memory that is also device-visible, so the
  // device can consume them directly. Not all devices support this, but the
  // ones currently targeted do.
  const iree_hal_memory_type_t host_visible_type =
      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;

  for (int idx = 0; idx < model->num_input; ++idx) {
    const iree_byte_span_t contents = iree_make_byte_span(
        arg_buffers[idx],
        model->input_size_bytes[idx] * model->input_length[idx]);
    // A null deallocator is passed, so the view does not free |contents|.
    // The caller keeps ownership of arg_buffers[idx].
    IREE_RETURN_IF_ERROR(iree_hal_buffer_view_wrap_or_clone_heap_buffer(
        iree_hal_device_allocator(device), model->input_shape[idx],
        model->num_input_dim[idx], model->hal_element_type,
        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, host_visible_type,
        IREE_HAL_MEMORY_ACCESS_READ, IREE_HAL_BUFFER_USAGE_ALL, contents,
        iree_allocator_null(), &arg_buffer_views[idx]));
  }
  return iree_ok_status();
}

iree_status_t run(const MlModel *model) {
  IREE_RETURN_IF_ERROR(iree_hal_module_register_types());

  iree_vm_instance_t *instance = NULL;
  iree_status_t result =
      iree_vm_instance_create(iree_allocator_system(), &instance);

  iree_hal_device_t *device = NULL;
  if (iree_status_is_ok(result)) {
    result = create_sample_device(iree_allocator_system(), &device);
  }
  iree_vm_module_t *hal_module = NULL;
  if (iree_status_is_ok(result)) {
    result =
        iree_hal_module_create(device, iree_allocator_system(), &hal_module);
  }
  // Load bytecode module from the embedded data.
  const iree_const_byte_span_t module_data = load_bytecode_module_data();

  iree_vm_module_t *bytecode_module = NULL;
  if (iree_status_is_ok(result)) {
    result = iree_vm_bytecode_module_create(module_data, iree_allocator_null(),
                                            iree_allocator_system(),
                                            &bytecode_module);
  }

  // Allocate a context that will hold the module state across invocations.
  iree_vm_context_t *context = NULL;
  iree_vm_module_t *modules[] = {hal_module, bytecode_module};
  if (iree_status_is_ok(result)) {
    result = iree_vm_context_create_with_modules(
        instance, IREE_VM_CONTEXT_FLAG_NONE, &modules[0],
        IREE_ARRAYSIZE(modules), iree_allocator_system(), &context);
  }
  iree_vm_module_release(hal_module);
  iree_vm_module_release(bytecode_module);

  // Lookup the entry point function.
  // Note that we use the synchronous variant which operates on pure type/shape
  // erased buffers.
  iree_vm_function_t main_function;
  if (iree_status_is_ok(result)) {
    result = (iree_vm_context_resolve_function(
        context, iree_make_cstring_view(model->entry_func), &main_function));
  }

  // Prepare the input buffers.
  void *arg_buffers[MAX_MODEL_INPUT_NUM] = {NULL};
  iree_hal_buffer_view_t *arg_buffer_views[MAX_MODEL_INPUT_NUM] = {NULL};
  if (iree_status_is_ok(result)) {
    result = prepare_input_hal_buffer_views(model, device, arg_buffers,
                                            arg_buffer_views);
  }

  // Setup call inputs with our buffers.
  iree_vm_list_t *inputs = NULL;
  if (iree_status_is_ok(result)) {
    result = iree_vm_list_create(
        /*element_type=*/NULL, /*capacity=*/model->num_input,
        iree_allocator_system(), &inputs);
  }
  iree_vm_ref_t arg_buffer_view_ref;
  for (int i = 0; i < model->num_input; ++i) {
    arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_views[i]);
    if (iree_status_is_ok(result)) {
      result = iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref);
    }
  }

  // Prepare outputs list to accept the results from the invocation.
  // The output vm list is allocated statically.
  iree_vm_list_t *outputs = NULL;
  if (iree_status_is_ok(result)) {
    result = iree_vm_list_create(
        /*element_type=*/NULL,
        /*capacity=*/1, iree_allocator_system(), &outputs);
  }

  // Invoke the function.
  if (iree_status_is_ok(result)) {
    result = iree_vm_invoke(context, main_function, IREE_VM_CONTEXT_FLAG_NONE,
                            /*policy=*/NULL, inputs, outputs,
                            iree_allocator_system());
  }

  for (int index_output = 0; index_output < model->num_output; index_output++) {
    iree_hal_buffer_view_t *ret_buffer_view = NULL;
    if (iree_status_is_ok(result)) {
      // Get the result buffers from the invocation.
      ret_buffer_view = (iree_hal_buffer_view_t *)iree_vm_list_get_ref_deref(
          outputs, index_output, iree_hal_buffer_view_get_descriptor());
      if (ret_buffer_view == NULL) {
        result = iree_make_status(IREE_STATUS_NOT_FOUND,
                                  "can't find return buffer view");
      }
    }
    // Read back the results and ensure we got the right values.
    iree_hal_buffer_mapping_t mapped_memory;
    if (iree_status_is_ok(result)) {
      result = iree_hal_buffer_map_range(
          iree_hal_buffer_view_buffer(ret_buffer_view),
          IREE_HAL_MEMORY_ACCESS_READ, 0, IREE_WHOLE_BUFFER, &mapped_memory);
    }
    if (iree_status_is_ok(result)) {
      result = check_output_data(model, &mapped_memory, index_output);
      iree_hal_buffer_unmap_range(&mapped_memory);
    }
  }

  iree_vm_list_release(inputs);
  iree_vm_list_release(outputs);
  for (int i = 0; i < model->num_input; ++i) {
    if (arg_buffers[i] != NULL) {
      iree_aligned_free(arg_buffers[i]);
    }
  }
  iree_vm_context_release(context);
  IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
      stdout, iree_hal_device_allocator(device)));
  iree_hal_device_release(device);
  iree_vm_instance_release(instance);
  return result;
}

int main() {
  const MlModel *model_ptr = &kModel;
  const iree_status_t result = run(model_ptr);
  int ret = (int)iree_status_code(result);
  if (!iree_status_is_ok(result)) {
    iree_status_fprint(stderr, result);
    iree_status_free(result);
  } else {
    LOG_INFO("%s finished successfully", model_ptr->model_name);
  }

  return ret;
}
