// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "runtime/bindings/tflite/model.h"

#include <stdio.h>
#include <string.h>

#include "iree/modules/hal/module.h"
#include "iree/vm/bytecode/module.h"

// Derives the number of inputs and outputs of a VM function from its calling
// convention string.
//
// |signature| is the function signature to inspect. On success the argument
// and result fragment lengths are written to |out_input_count| and
// |out_output_count|. Returns any error produced while splitting the cconv.
static iree_status_t _TfLiteModelCalculateFunctionIOCounts(
    const iree_vm_function_signature_t* signature, int32_t* out_input_count,
    int32_t* out_output_count) {
  iree_string_view_t argument_cconv;
  iree_string_view_t result_cconv;
  IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments(
      signature, &argument_cconv, &result_cconv));

  // NOTE: today we only pass 1:1 buffer views with what tflite does, so each
  // fragment holds one `r` per buffer view. The character count of a fragment
  // is therefore the number of tensors on that side of the call.
  int32_t input_count = (int32_t)argument_cconv.size;
  int32_t output_count = (int32_t)result_cconv.size;
  *out_input_count = input_count;
  *out_output_count = output_count;
  return iree_ok_status();
}

// Creates the VM instance and bytecode module for |model| and resolves the
// tflite binding exports.
//
// |flatbuffer_data|/|flatbuffer_size| is the module flatbuffer. It is handed to
// the bytecode module with a null allocator, so the module does not free it.
// The caller keeps ownership and (presumably) must keep it alive for the
// model's lifetime; confirm against iree_vm_bytecode_module_create.
// |allocator| is used for the instance and module. Results are stored in
// |model| (instance, module, exports, input/output counts).
//
// Required exports: `_tflite_main` and `_tflite_main_query_output_shape`.
// Optional exports (lookup failures ignored): `_tflite_main_query_input_shape`,
// `_tflite_main_resize_input_shape` and `_tflite_main_reset_variables`.
//
// On failure, partially-created objects are left in |model| for the caller to
// release.
static iree_status_t _TfLiteModelInitializeModule(const void* flatbuffer_data,
                                                  size_t flatbuffer_size,
                                                  iree_allocator_t allocator,
                                                  TfLiteModel* model) {
  IREE_TRACE_ZONE_BEGIN(z0);

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, allocator,
                                  &model->instance));
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_module_register_all_types(model->instance));

  // The flatbuffer memory is not owned by the module (null allocator).
  iree_const_byte_span_t flatbuffer_span =
      iree_make_const_byte_span(flatbuffer_data, flatbuffer_size);
  iree_allocator_t flatbuffer_allocator = iree_allocator_null();
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_bytecode_module_create(model->instance, flatbuffer_span,
                                     flatbuffer_allocator, allocator,
                                     &model->module),
      "error creating bytecode module");

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_module_lookup_function_by_name(
          model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
          iree_make_cstring_view("_tflite_main"), &model->exports._main),
      "unable to find '_tflite_main' export in module, module must be compiled "
      "with tflite bindings support");

  // Get the input and output counts of the function; this is useful for being
  // able to preallocate storage when creating interpreters.
  // The trace zone must be closed on this error path too, like all the others.
  iree_vm_function_signature_t main_signature =
      iree_vm_function_signature(&model->exports._main);
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, _TfLiteModelCalculateFunctionIOCounts(
              &main_signature, &model->input_count, &model->output_count));

  // NOTE: the input shape query is not required, because a model may
  // (rarely, though often in tests) have no inputs at all.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_query_input_shape"),
      &model->exports._query_input_shape));

  // NOTE: the input shape resizing function is only required if the model has
  // dynamic shapes.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_resize_input_shape"),
      &model->exports._resize_input_shape));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_module_lookup_function_by_name(
          model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
          iree_make_cstring_view("_tflite_main_query_output_shape"),
          &model->exports._query_output_shape),
      "unable to find '_tflite_main_query_output_shape' export in module");

  // It's OK for this to fail; the model may not have variables.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_reset_variables"),
      &model->exports._reset_variables));

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreate(const void* model_data,
                                                      size_t model_size) {
  iree_allocator_t allocator = iree_allocator_system();
  IREE_TRACE_ZONE_BEGIN(z0);

  TfLiteModel* model = NULL;
  iree_status_t status =
      iree_allocator_malloc(allocator, sizeof(*model), (void**)&model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    IREE_TRACE_MESSAGE(ERROR, "failed model allocation");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  memset(model, 0, sizeof(*model));
  iree_atomic_ref_count_init(&model->ref_count);
  model->allocator = allocator;

  status =
      _TfLiteModelInitializeModule(model_data, model_size, allocator, model);
  if (!iree_status_is_ok(status)) {
    iree_status_fprint(stderr, status);
    iree_status_free(status);
    TfLiteModelDelete(model);
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  IREE_TRACE_ZONE_END(z0);
  return model;
}

// Creates a model by reading the module flatbuffer at |model_path|.
//
// The file contents are copied into the same allocation as the TfLiteModel
// struct, placed immediately after it. They are therefore owned by the model
// and freed together with it.
// Returns a model holding one reference, or NULL on failure. Initialization
// errors are printed to stderr, matching TfLiteModelCreate.
TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreateFromFile(
    const char* model_path) {
  iree_allocator_t allocator = iree_allocator_system();
  IREE_TRACE_ZONE_BEGIN(z0);

  // TODO(#3909): use file mapping C API.
  // Binary mode: the model is a flatbuffer, and text mode would translate
  // line endings on some platforms (e.g. Windows).
  FILE* file = fopen(model_path, "rb");
  if (!file) {
    IREE_TRACE_MESSAGE(ERROR, "failed to open model file");
    IREE_TRACE_MESSAGE_DYNAMIC(ERROR, model_path, strlen(model_path));
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  // Determine the file size. ftell reports -1 on failure, which must not be
  // converted to a (huge) size_t.
  long file_length = -1;
  if (fseek(file, 0, SEEK_END) == 0) {
    file_length = ftell(file);
  }
  if (file_length < 0 || fseek(file, 0, SEEK_SET) != 0) {
    fclose(file);
    IREE_TRACE_MESSAGE(ERROR, "failed to query model file size");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  size_t file_size = (size_t)file_length;

  // One allocation holds the struct followed by the file bytes.
  TfLiteModel* model = NULL;
  iree_status_t status = iree_allocator_malloc(
      allocator, sizeof(*model) + file_size, (void**)&model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    fclose(file);  // previously leaked on this path
    IREE_TRACE_MESSAGE(ERROR, "failed model+data allocation");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  memset(model, 0, sizeof(*model));
  iree_atomic_ref_count_init(&model->ref_count);
  model->allocator = allocator;
  // The data region starts right after the struct. Offsetting by file_size
  // (as before) wrote past the end of the allocation, or over the struct.
  model->owned_model_data = (uint8_t*)model + sizeof(*model);
  size_t read_size = fread(model->owned_model_data, 1, file_size, file);
  fclose(file);
  if (read_size != file_size) {
    TfLiteModelDelete(model);
    IREE_TRACE_MESSAGE(ERROR, "failed model+data read");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  status = _TfLiteModelInitializeModule(model->owned_model_data, file_size,
                                        allocator, model);
  if (!iree_status_is_ok(status)) {
    // Surface the reason instead of silently discarding it.
    iree_status_fprint(stderr, status);
    iree_status_free(status);
    TfLiteModelDelete(model);
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  IREE_TRACE_ZONE_END(z0);
  return model;
}

// Adds a reference to |model|. A NULL |model| is ignored.
void _TfLiteModelRetain(TfLiteModel* model) {
  if (!model) {
    return;
  }
  iree_atomic_ref_count_inc(&model->ref_count);
}

// Drops a reference to |model| and destroys it when the last reference goes
// away. A NULL |model| is ignored.
void _TfLiteModelRelease(TfLiteModel* model) {
  if (!model) {
    return;
  }
  // The decrement yields the count before decrementing, so 1 means this call
  // released the final reference.
  if (iree_atomic_ref_count_dec(&model->ref_count) != 1) {
    return;
  }

  IREE_TRACE_ZONE_BEGIN(z0);
  // Release in reverse creation order: the module, then the instance.
  iree_vm_module_release(model->module);
  iree_vm_instance_release(model->instance);
  // A single free covers the struct and any model bytes stored after it in
  // the same allocation (file-backed models).
  iree_allocator_free(model->allocator, model);
  IREE_TRACE_ZONE_END(z0);
}

// Public entry point for deleting a model. This only drops the caller's
// reference; resources are freed once no references remain (for example,
// interpreters that retained the model). Passing NULL is a no-op.
TFL_CAPI_EXPORT extern void TfLiteModelDelete(TfLiteModel* model) {
  IREE_TRACE_ZONE_BEGIN(z0);
  _TfLiteModelRelease(model);  // may free |model|; do not touch it afterward
  IREE_TRACE_ZONE_END(z0);
}
