// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "bindings/tflite/model.h"

#include <stdio.h>
#include <string.h>

#include "iree/base/tracing.h"
#include "iree/modules/hal/hal_module.h"
#include "iree/vm/bytecode_module.h"

// Registers the type sets this binding depends on: first the VM builtin types,
// then the HAL module types. Registration stops at the first failure, and that
// failure status is returned to the caller unchanged.
static iree_status_t _TfLiteModelPrepareRuntime() {
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_status_t registration_status = iree_vm_register_builtin_types();
  if (iree_status_is_ok(registration_status)) {
    registration_status = iree_hal_module_register_types();
  }

  IREE_TRACE_ZONE_END(z0);
  return registration_status;
}

// Derives the input and output counts of a function from its calling
// convention (cconv) string.
//
// The cconv is split into its argument and result fragments, and each count is
// the length in characters of the corresponding fragment. NOTE: this relies on
// tflite tensors currently mapping 1:1 onto buffer views, each encoded as a
// single `r` character. Richer cconv encodings would presumably need real
// parsing here.
//
// Returns the error from iree_vm_function_call_get_cconv_fragments, if any; on
// error the output counts are left untouched.
static iree_status_t _TfLiteModelCalculateFunctionIOCounts(
    const iree_vm_function_signature_t* signature, int32_t* out_input_count,
    int32_t* out_output_count) {
  iree_string_view_t input_fragment;
  iree_string_view_t output_fragment;
  IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments(
      signature, &input_fragment, &output_fragment));

  const int32_t input_count = (int32_t)input_fragment.size;
  const int32_t output_count = (int32_t)output_fragment.size;
  *out_input_count = input_count;
  *out_output_count = output_count;
  return iree_ok_status();
}

// Creates the VM bytecode module for |model| from the given flatbuffer bytes.
// It then resolves the module exports that the interpreter invokes.
//
// Required exports (failure to find them is an error):
//   `_tflite_main`, `_tflite_main_query_output_shape`
// Optional exports (lookup errors are ignored):
//   `_tflite_main_query_input_shape`, `_tflite_main_resize_input_shape`,
//   `_tflite_main_reset_variables`
// NOTE(review): when an optional lookup fails, the export field is assumed to
// keep its prior (zeroed by the callers' memset) value. Confirm that lookup
// does not write on failure.
//
// The flatbuffer bytes are not copied. They are handed to the module with a
// null flatbuffer allocator, so the module does not free them. Presumably they
// must outlive |model->module|.
//
// On failure |model->module| may already have been created. This function does
// not release it; that is left to the caller.
static iree_status_t _TfLiteModelInitializeModule(const void* flatbuffer_data,
                                                  size_t flatbuffer_size,
                                                  iree_allocator_t allocator,
                                                  TfLiteModel* model) {
  IREE_TRACE_ZONE_BEGIN(z0);

  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, _TfLiteModelPrepareRuntime());

  iree_const_byte_span_t flatbuffer_span =
      iree_make_const_byte_span(flatbuffer_data, flatbuffer_size);
  iree_allocator_t flatbuffer_allocator = iree_allocator_null();
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_bytecode_module_create(flatbuffer_span, flatbuffer_allocator,
                                     allocator, &model->module),
      "error creating bytecode module");

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_module_lookup_function_by_name(
          model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
          iree_make_cstring_view("_tflite_main"), &model->exports._main),
      "unable to find '_tflite_main' export in module");

  // Get the input and output counts of the function; this is useful for being
  // able to preallocate storage when creating interpreters.
  // The trace zone must be ended on this error path too, like every other
  // early return in this function.
  iree_vm_function_signature_t main_signature =
      iree_vm_function_signature(&model->exports._main);
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, _TfLiteModelCalculateFunctionIOCounts(
              &main_signature, &model->input_count, &model->output_count));

  // NOTE: the input shape query is not required as it's possible (though
  // silly) for a model to have no inputs. In testing this can happen a lot
  // but in the wild it's rare ... says someone who previously filed bugs
  // against tflite because they didn't support models with no inputs when I
  // was being silly and needed them ;)
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_query_input_shape"),
      &model->exports._query_input_shape));

  // NOTE: the input shape resizing function is only required if the model has
  // dynamic shapes.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_resize_input_shape"),
      &model->exports._resize_input_shape));

  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0,
      iree_vm_module_lookup_function_by_name(
          model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
          iree_make_cstring_view("_tflite_main_query_output_shape"),
          &model->exports._query_output_shape),
      "unable to find '_tflite_main_query_output_shape' export in module");

  // It's OK for this to fail; the model may not have variables.
  IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name(
      model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
      iree_make_cstring_view("_tflite_main_reset_variables"),
      &model->exports._reset_variables));

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreate(const void* model_data,
                                                      size_t model_size) {
  iree_allocator_t allocator = iree_allocator_system();
  IREE_TRACE_ZONE_BEGIN(z0);

  TfLiteModel* model = NULL;
  iree_status_t status =
      iree_allocator_malloc(allocator, sizeof(*model), (void**)&model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    IREE_TRACE_MESSAGE(ERROR, "failed model allocation");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  memset(model, 0, sizeof(*model));
  iree_atomic_ref_count_init(&model->ref_count);
  model->allocator = allocator;

  status =
      _TfLiteModelInitializeModule(model_data, model_size, allocator, model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  IREE_TRACE_ZONE_END(z0);
  return model;
}

TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreateFromFile(
    const char* model_path) {
  iree_allocator_t allocator = iree_allocator_system();
  IREE_TRACE_ZONE_BEGIN(z0);

  // TODO(#3909): use file mapping C API.
  FILE* file = fopen(model_path, "r");
  if (!file) {
    IREE_TRACE_MESSAGE(ERROR, "failed to open model file");
    IREE_TRACE_MESSAGE_DYNAMIC(ERROR, model_path, strlen(model_path));
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  fseek(file, 0, SEEK_END);
  size_t file_size = ftell(file);
  fseek(file, 0, SEEK_SET);
  TfLiteModel* model = NULL;
  iree_status_t status = iree_allocator_malloc(
      allocator, sizeof(TfLiteModel) + file_size, (void**)&model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    IREE_TRACE_MESSAGE(ERROR, "failed model+data allocation");
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }
  memset(model, 0, sizeof(*model));
  iree_atomic_ref_count_init(&model->ref_count);
  model->allocator = allocator;
  model->owned_model_data = (uint8_t*)model + file_size;
  (void)fread(model->owned_model_data, 1, file_size, file);
  fclose(file);

  status = _TfLiteModelInitializeModule(model->owned_model_data, file_size,
                                        allocator, model);
  if (!iree_status_is_ok(iree_status_consume_code(status))) {
    IREE_TRACE_ZONE_END(z0);
    return NULL;
  }

  IREE_TRACE_ZONE_END(z0);
  return model;
}

// Adds one reference to |model|; a NULL |model| is ignored.
// Each retain is expected to be balanced by a matching _TfLiteModelRelease.
void _TfLiteModelRetain(TfLiteModel* model) {
  if (model == NULL) {
    return;
  }
  iree_atomic_ref_count_inc(&model->ref_count);
}

// Drops one reference to |model|; a NULL |model| is ignored.
//
// When the last reference goes away, the VM module is released and then the
// model allocation is freed with the allocator recorded in the model. Data
// loaded by TfLiteModelCreateFromFile is carved out of that same allocation,
// so it is freed along with the model. The module is released first because
// it was created over those bytes.
void _TfLiteModelRelease(TfLiteModel* model) {
  if (model == NULL) {
    return;
  }
  // The decrement yields the count held before this release, so any value
  // other than 1 means other references remain.
  if (iree_atomic_ref_count_dec(&model->ref_count) != 1) {
    return;
  }

  IREE_TRACE_ZONE_BEGIN(z0);
  iree_vm_module_release(model->module);
  iree_allocator_free(model->allocator, model);
  IREE_TRACE_ZONE_END(z0);
}

// TFLite C API entry point that drops the caller's reference to |model|.
// The model is destroyed only once no other references remain; NULL is
// accepted and ignored.
TFL_CAPI_EXPORT extern void TfLiteModelDelete(TfLiteModel* model) {
  IREE_TRACE_ZONE_BEGIN(z_delete);
  _TfLiteModelRelease(model);
  IREE_TRACE_ZONE_END(z_delete);
}
