// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "./py_module.h"

#include <string_view>
#include <unordered_map>

#include "./vm.h"

namespace iree::python {

// Low level class for constructing a native VM module from Python. This
// class is mutable while the module is being setup and will typically
// produce a module instance when ready to be used.
//
// This class has a complicated life-cycle and can be in one of several
// states:
//   UNINITIALIZED: Prior to calling Create(). Mutable.
//   INITIALIZED: After calling Create() and prior to the returned reference
//     being released. Immutable.
//   DESTROYED: After the reference from Create() is released. Nothing
//     more can be done with the instance but it is still live until the
//     Python reference to it is released.
class PyModuleInterface {
 public:
  // Creates a mutable (not yet initialized) interface.
  //   module_name: name reported by ModuleName() and the descriptor.
  //   ctor: Python callable invoked by ModuleAllocState() with the retained
  //     Python reference to this interface; its return value becomes the
  //     per-context module state.
  // Initializes |interface_| with `this` as the self pointer and routes every
  // module callback to the static trampolines below (which recover the
  // instance via AsSelf()).
  PyModuleInterface(std::string module_name, py::object ctor)
      : module_name_(std::move(module_name)), ctor_(std::move(ctor)) {
    CheckApiStatus(iree_vm_module_initialize(&interface_, this),
                   "Failed to initialize vm_module");
    interface_.destroy = &PyModuleInterface::ModuleDestroy;
    interface_.name = &PyModuleInterface::ModuleName;
    interface_.signature = &PyModuleInterface::ModuleSignature;
    interface_.enumerate_dependencies =
        &PyModuleInterface::ModuleEnumerateDependencies;
    interface_.get_function = &PyModuleInterface::ModuleGetFunction;
    interface_.lookup_function = &PyModuleInterface::ModuleLookupFunction;
    interface_.alloc_state = &PyModuleInterface::ModuleAllocState;
    interface_.free_state = &PyModuleInterface::ModuleFreeState;
    interface_.resolve_import = &PyModuleInterface::ModuleResolveImport;
    interface_.notify = &PyModuleInterface::ModuleNotify;
    interface_.begin_call = &PyModuleInterface::ModuleBeginCall;
  }
  // Not copyable: |interface_| carries a pointer back to this exact instance
  // and the export descriptors/name map point into storage owned by it.
  PyModuleInterface(const PyModuleInterface&) = delete;
  PyModuleInterface& operator=(const PyModuleInterface&) = delete;
  ~PyModuleInterface() = default;

  // Recovers the instance from the opaque self pointer handed to callbacks.
  static PyModuleInterface* AsSelf(void* vself) {
    return static_cast<PyModuleInterface*>(vself);
  }

  // Module `destroy` callback: drops the self-reference taken in
  // Initialize(). The GIL is acquired first because releasing the reference
  // manipulates Python reference counts.
  static void ModuleDestroy(void* vself) {
    auto self = AsSelf(vself);
    py::gil_scoped_acquire acquire;
    self->retained_self_ref_ = {};
  }

  // Module `name` callback: returns a view over |module_name_|.
  static iree_string_view_t ModuleName(void* vself) {
    auto self = AsSelf(vself);
    return {self->module_name_.data(),
            static_cast<iree_host_size_t>(self->module_name_.size())};
  }

  // Module `signature` callback: reports the descriptor version and the
  // number of declared imports/exports. Attributes and internal functions are
  // always reported as zero.
  static iree_vm_module_signature_t ModuleSignature(void* vself) {
    auto self = AsSelf(vself);
    iree_vm_module_signature_t signature = {0};
    signature.version = self->descriptor_.version;
    signature.attr_count = 0;
    signature.import_function_count = self->imports_.size();
    signature.export_function_count = self->exports_.size();
    signature.internal_function_count = 0;
    return signature;
  }

  // Module `enumerate_dependencies` callback: reports no dependencies (the
  // callback is never invoked).
  static iree_status_t ModuleEnumerateDependencies(
      void* vself, iree_vm_module_dependency_callback_t callback,
      void* user_data) {
    // TODO(laurenzo): python support for declaring dependencies on the module.
    return iree_ok_status();
  }

  // Module `get_function` callback: looks up an exported function by ordinal.
  // Each out parameter is optional and only populated when non-null. Returns
  // NOT_FOUND for any non-export linkage or an out-of-range ordinal.
  static iree_status_t ModuleGetFunction(
      void* vself, iree_vm_function_linkage_t linkage, iree_host_size_t ordinal,
      iree_vm_function_t* out_function, iree_string_view_t* out_name,
      iree_vm_function_signature_t* out_signature) {
    auto self = AsSelf(vself);
    if (IREE_LIKELY(linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT ||
                    linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT_OPTIONAL)) {
      if (IREE_LIKELY(ordinal < self->export_functions_.size())) {
        std::unique_ptr<PyFunction>& f = self->export_functions_[ordinal];
        if (IREE_LIKELY(out_function)) {
          out_function->linkage = linkage;
          out_function->module = &self->interface_;
          out_function->ordinal = ordinal;
        }
        if (IREE_LIKELY(out_name)) {
          *out_name = {f->name.data(),
                       static_cast<iree_host_size_t>(f->name.size())};
        }
        if (IREE_LIKELY(out_signature)) {
          out_signature->calling_convention = {
              f->cconv.data(), static_cast<iree_host_size_t>(f->cconv.size())};
        }
        return iree_ok_status();
      }
    }
    return iree_make_status(IREE_STATUS_NOT_FOUND);
  }

  // Module `lookup_function` callback: resolves an exported function by name
  // and fills |out_function|. |expected_signature| is not consulted. Returns
  // NOT_FOUND for non-export linkages or unknown names.
  static iree_status_t ModuleLookupFunction(
      void* vself, iree_vm_function_linkage_t linkage, iree_string_view_t name,
      const iree_vm_function_signature_t* expected_signature,
      iree_vm_function_t* out_function) {
    auto self = AsSelf(vself);
    std::string_view name_cpp(name.data, name.size);
    if (linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT ||
        linkage == IREE_VM_FUNCTION_LINKAGE_EXPORT_OPTIONAL) {
      auto found_it = self->export_name_to_ordinals_.find(name_cpp);
      if (found_it != self->export_name_to_ordinals_.end()) {
        out_function->linkage = linkage;
        out_function->module = &self->interface_;
        out_function->ordinal = found_it->second;
        return iree_ok_status();
      }
    }
    return iree_make_status(IREE_STATUS_NOT_FOUND, "function %.*s not exported",
                            static_cast<int>(name.size), name.data);
  }

  // Module `alloc_state` callback: calls the Python |ctor_| with the retained
  // Python reference to this interface and stores the resulting object (as a
  // raw, owned PyObject*) in |out_module_state|. C++ exceptions raised by the
  // call are converted to an UNKNOWN status; |out_module_state| stays null in
  // that case.
  static iree_status_t ModuleAllocState(
      void* vself, iree_allocator_t allocator,
      iree_vm_module_state_t** out_module_state) {
    auto self = AsSelf(vself);
    *out_module_state = nullptr;
    py::gil_scoped_acquire acquire;
    try {
      py::object py_state = self->ctor_(self->retained_self_ref_);
      // Steal the reference and use the raw PyObject* as the state.
      // This will be released in ModuleFreeState.
      *out_module_state =
          reinterpret_cast<iree_vm_module_state_t*>(py_state.release().ptr());
      return iree_ok_status();
    } catch (const std::exception& e) {
      return iree_make_status(IREE_STATUS_UNKNOWN,
                              "Exception in call to PyModule constructor: %s",
                              e.what());
    }
  }

  // Module `free_state` callback: drops the Python reference that
  // ModuleAllocState() stored as the module state.
  static void ModuleFreeState(void* vself,
                              iree_vm_module_state_t* module_state) {
    py::gil_scoped_acquire acquire;
    // Release the reference stolen in ModuleAllocState.
    auto retained_handle =
        py::handle(reinterpret_cast<PyObject*>(module_state));
    retained_handle.dec_ref();
  }

  // Module `resolve_import` callback: imports are not supported, so this
  // always fails with UNIMPLEMENTED.
  static iree_status_t ModuleResolveImport(
      void* vself, iree_vm_module_state_t* module_state,
      iree_host_size_t ordinal, const iree_vm_function_t* function,
      const iree_vm_function_signature_t* signature) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "Python API does not support imports");
  }

  // Module `notify` callback: signals are not handled, so this always fails
  // with UNIMPLEMENTED.
  static iree_status_t ModuleNotify(void* vself,
                                    iree_vm_module_state_t* module_state,
                                    iree_vm_signal_t signal) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "ModuleNotify not implemented");
  }

  // Module `begin_call` callback: validates the function ordinal, enters a
  // native stack frame, dispatches to the Python callable via Invoke() and
  // leaves the frame on success. On failure the status is returned without
  // leaving the frame. C++ exceptions escaping Invoke() are converted to an
  // UNKNOWN status.
  static iree_status_t ModuleBeginCall(void* vself, iree_vm_stack_t* stack,
                                       iree_vm_function_call_t call) {
    auto self = AsSelf(vself);
    if (IREE_UNLIKELY(call.function.ordinal >=
                      self->export_functions_.size())) {
      // Ordinal 0 is valid, hence the inclusive lower bound in the message.
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "function ordinal out of bounds: 0 <= %u < %zu",
                              static_cast<unsigned>(call.function.ordinal),
                              self->export_functions_.size());
    }

    auto& f = self->export_functions_[call.function.ordinal];
    iree_host_size_t frame_size = 0;
    iree_vm_stack_frame_t* callee_frame = nullptr;
    IREE_RETURN_IF_ERROR(iree_vm_stack_function_enter(
        stack, &call.function, IREE_VM_STACK_FRAME_NATIVE, frame_size,
        /*frame_cleanup_fn=*/nullptr, &callee_frame));
    // The module state is the PyObject* stored by ModuleAllocState; it is
    // only borrowed here.
    auto state_object =
        py::handle(reinterpret_cast<PyObject*>(callee_frame->module_state));

    try {
      IREE_RETURN_IF_ERROR(self->Invoke(*f, state_object, stack, call));
    } catch (const std::exception& e) {
      return iree_make_status(IREE_STATUS_UNKNOWN,
                              "Exception raised from Python module: %s",
                              e.what());
    }

    return iree_vm_stack_function_leave(stack);
  }

  // Returns a debug representation including the module name and, once
  // Initialize() has run, whether the instance is "initialized" or
  // "destroyed".
  std::string ToString() {
    std::string s("<iree.runtime.PyModuleInterface '");
    s.append(module_name_);
    s.append("'");
    if (initialized_) {
      if (retained_self_ref_) {
        s.append(" initialized");
      } else {
        s.append(" destroyed");
      }
    }
    s.append(">");
    return s;
  }

  // True once Initialize()/Create() has been called (stays true after the
  // module is destroyed).
  bool initialized() { return initialized_; }

  // True once the module was initialized and its self-reference has since
  // been dropped by ModuleDestroy().
  bool destroyed() { return initialized_ && !retained_self_ref_; }

  // Throws std::runtime_error if the interface has already been initialized
  // and must therefore no longer be mutated.
  void AssertMutable() {
    if (initialized_) {
      throw std::runtime_error("Attempt to mutate a frozen PyModuleInterface");
    }
  }

  // Declares an exported function.
  //   name: export name; must be unique within this module.
  //   cconv: calling convention string; variadic conventions are rejected.
  //   callable: Python callable invoked as callable(state, *arguments).
  // Throws std::runtime_error if the interface is already initialized,
  // std::invalid_argument if |name| is already exported, and whatever
  // CheckApiStatus raises if |cconv| cannot be parsed. No state is modified
  // when any of these checks fail.
  void ExportFunction(std::string name, std::string cconv,
                      py::object callable) {
    // Exports may not change once the interface is live: the descriptor holds
    // pointers into |exports_| and VM callbacks read these containers.
    AssertMutable();

    // Make sure not already defined.
    if (export_name_to_ordinals_.count(name)) {
      std::string msg("PyModule function already defined: ");
      msg.append(name);
      throw std::invalid_argument(std::move(msg));
    }

    // Heap allocate the backing PyFunction so we can reference its pointers.
    auto py_function = std::make_unique<PyFunction>(
        std::move(name), std::move(cconv), std::move(callable));

    // Validate the calling convention before touching any container so that
    // a rejected export cannot leave a descriptor pointing at freed strings
    // or leave |exports_| and |export_functions_| with different sizes.
    CheckApiStatus(py_function->ParseCconv(), "Unparseable calling convention");

    // Ordinals index |export_functions_| (see ModuleGetFunction and
    // ModuleBeginCall).
    const size_t ordinal = export_functions_.size();
    exports_.push_back({});
    iree_vm_native_export_descriptor_t& d = exports_.back();
    d.local_name = {py_function->name.data(),
                    static_cast<iree_host_size_t>(py_function->name.size())};
    d.calling_convention = {
        py_function->cconv.data(),
        static_cast<iree_host_size_t>(py_function->cconv.size())};
    d.attr_count = 0;
    d.attrs = nullptr;

    // The map key views the name owned by the heap allocated PyFunction,
    // whose address does not change when the unique_ptr is moved below.
    const std::string_view name_key(py_function->name);

    // Transfer the PyFunction to its vector now that we are done touching it.
    export_functions_.push_back(std::move(py_function));
    export_name_to_ordinals_.emplace(name_key, static_cast<int>(ordinal));
  }

  // Initializes the internal data structures such that GetInterface() will be
  // valid. After this call, the interface is "live" and this instance will only
  // be deleted when its refcnt goes to 0, which will call ModuleDestroy and
  // release our Python side reference to this.
  // Throws std::runtime_error if called more than once.
  void Initialize() {
    AssertMutable();
    initialized_ = true;
    memset(&descriptor_, 0, sizeof(descriptor_));
    descriptor_.name = {module_name_.data(),
                        static_cast<iree_host_size_t>(module_name_.size())};
    descriptor_.version = version_;
    descriptor_.attr_count = attrs_.size();
    descriptor_.attrs = attrs_.empty() ? nullptr : attrs_.data();
    descriptor_.import_count = imports_.size();
    descriptor_.imports = imports_.empty() ? nullptr : imports_.data();
    descriptor_.export_count = exports_.size();
    descriptor_.exports = exports_.empty() ? nullptr : exports_.data();
    descriptor_.function_count = functions_.size();
    descriptor_.functions = functions_.empty() ? nullptr : functions_.data();
    // Self-reference that keeps this instance alive until ModuleDestroy().
    retained_self_ref_ = py::cast(this);
  }

  // Creates the live Python VmModule reference. This can only be called once.
  VmModule Create() {
    Initialize();
    return VmModule::StealFromRawPtr(&interface_);
  }

 private:
  // An exported function: its name, calling convention string and the Python
  // callable that implements it.
  struct PyFunction {
    std::string name;
    std::string cconv;
    py::object callable;

    // Initialized by ParseCconv (views into |cconv|); empty until then.
    iree_string_view_t cconv_arguments = {nullptr, 0};
    iree_string_view_t cconv_results = {nullptr, 0};

    PyFunction(std::string name, std::string cconv, py::object callable)
        : name(std::move(name)),
          cconv(std::move(cconv)),
          callable(std::move(callable)) {}

    // Splits |cconv| into its argument and result fragments. Fails if the
    // string cannot be split or if either fragment is variadic.
    iree_status_t ParseCconv() {
      iree_vm_function_signature_t signature;
      memset(&signature, 0, sizeof(signature));
      signature.calling_convention = {
          cconv.data(), static_cast<iree_host_size_t>(cconv.size())};
      IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments(
          &signature, &cconv_arguments, &cconv_results));

      if (iree_vm_function_call_is_variadic_cconv(cconv_arguments) ||
          iree_vm_function_call_is_variadic_cconv(cconv_results)) {
        return iree_make_status(
            IREE_STATUS_INVALID_ARGUMENT,
            "PyModules do not yet support variadic arguments/results");
      }

      return iree_ok_status();
    }
  };

  // Calls |f.callable| as callable(state_object, *arguments), where the
  // arguments are unpacked from |call.arguments| according to
  // |f.cconv_arguments|, then packs the returned value(s) into
  // |call.results| according to |f.cconv_results|. A single result is taken
  // from the return value directly; multiple results are read from the
  // returned sequence. Holds the GIL for the duration of the call. May throw
  // C++ exceptions originating from the Python call or from value
  // conversions; the caller is responsible for translating them.
  iree_status_t Invoke(PyFunction& f, py::handle state_object,
                       iree_vm_stack_t* stack, iree_vm_function_call_t call) {
    py::gil_scoped_acquire acquire;
    uint8_t* packed_arguments = call.arguments.data;
    iree_host_size_t packed_arguments_required_size;
    // TODO: Is this validation needed or do we assume it from up-stack?
    IREE_RETURN_IF_ERROR(iree_vm_function_call_compute_cconv_fragment_size(
        f.cconv_arguments, /*segment_size_list=*/nullptr,
        &packed_arguments_required_size));
    if (IREE_UNLIKELY(packed_arguments_required_size !=
                      call.arguments.data_length)) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "mismatched packed argument size: actual=%" PRIhsz
                              ", required=%" PRIhsz,
                              call.arguments.data_length,
                              packed_arguments_required_size);
    }

    // Unpack arguments.
    py::list arguments;
    for (iree_host_size_t i = 0; i < f.cconv_arguments.size; ++i) {
      switch (f.cconv_arguments.data[i]) {
        case IREE_VM_CCONV_TYPE_VOID:
          break;
        case IREE_VM_CCONV_TYPE_I32:
          arguments.append(
              py::cast(*reinterpret_cast<int32_t*>(packed_arguments)));
          packed_arguments += sizeof(int32_t);
          break;
        case IREE_VM_CCONV_TYPE_F32:
          arguments.append(
              py::cast(*reinterpret_cast<float*>(packed_arguments)));
          packed_arguments += sizeof(float);
          break;
        case IREE_VM_CCONV_TYPE_I64:
          arguments.append(
              py::cast(*reinterpret_cast<int64_t*>(packed_arguments)));
          packed_arguments += sizeof(int64_t);
          break;
        case IREE_VM_CCONV_TYPE_F64:
          arguments.append(
              py::cast(*reinterpret_cast<double*>(packed_arguments)));
          packed_arguments += sizeof(double);
          break;
        case IREE_VM_CCONV_TYPE_REF: {
          iree_vm_ref_t ref =
              *reinterpret_cast<iree_vm_ref_t*>(packed_arguments);
          // Since the Python level VmRef can escape, it needs its own ref
          // count.
          VmRef py_ref;
          iree_vm_ref_retain(&ref, &py_ref.ref());
          arguments.append(py::cast(py_ref, py::rv_policy::move));
          packed_arguments += sizeof(iree_vm_ref_t);
          break;
        }
        // TODO: Variadic segments.
        default:
          return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                                  "unsupported cconv type %c",
                                  f.cconv_arguments.data[i]);
      }
    }

    auto results = f.callable(state_object, *arguments);

    // Pack results.
    if (f.cconv_results.size == 0) {
      return iree_ok_status();
    }
    uint8_t* packed_results = call.results.data;
    bool unary_result = f.cconv_results.size == 1;
    // Converts one Python value to |cconv_type|, writes it at the current
    // result cursor and advances the cursor.
    auto pack_result = [&](py::object& value,
                           char cconv_type) -> iree_status_t {
      switch (cconv_type) {
        case IREE_VM_CCONV_TYPE_VOID:
          break;
        case IREE_VM_CCONV_TYPE_I32:
          *reinterpret_cast<int32_t*>(packed_results) =
              py::cast<int32_t>(value);
          packed_results += sizeof(int32_t);
          break;
        case IREE_VM_CCONV_TYPE_F32:
          *reinterpret_cast<float*>(packed_results) = py::cast<float>(value);
          packed_results += sizeof(float);
          break;
        case IREE_VM_CCONV_TYPE_I64:
          *reinterpret_cast<int64_t*>(packed_results) =
              py::cast<int64_t>(value);
          packed_results += sizeof(int64_t);
          break;
        case IREE_VM_CCONV_TYPE_F64:
          *reinterpret_cast<double*>(packed_results) = py::cast<double>(value);
          packed_results += sizeof(double);
          break;
        case IREE_VM_CCONV_TYPE_REF: {
          iree_vm_ref_t* result_ref =
              reinterpret_cast<iree_vm_ref_t*>(packed_results);
          if (value.is_none()) {
            return iree_make_status(
                IREE_STATUS_FAILED_PRECONDITION,
                "expected ref returned from Python function but got None");
          }
          // The result slot gets its own retained reference; the Python
          // VmRef keeps the one it already holds.
          VmRef* py_ref = py::cast<VmRef*>(value);
          iree_vm_ref_retain(&py_ref->ref(), result_ref);
          packed_results += sizeof(iree_vm_ref_t);
          break;
        }
        // TODO: Variadic segments.
        default:
          return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                                  "unsupported cconv type %c", cconv_type);
      }
      return iree_ok_status();
    };

    if (unary_result) {
      return pack_result(results, f.cconv_results.data[0]);
    } else {
      py::sequence results_seq = py::cast<py::sequence>(results);
      int result_index = 0;
      for (iree_host_size_t i = 0; i < f.cconv_results.size; ++i) {
        py::object next_result = results_seq[result_index++];
        IREE_RETURN_IF_ERROR(pack_result(next_result, f.cconv_results.data[i]));
      }
      return iree_ok_status();
    }
  }

  // Descriptor state is built up when mutable and then will be populated
  // on the descriptor when frozen.
  std::string module_name_;
  // Nothing in this class assigns a version, so it must be default
  // initialized: Initialize() copies it into the descriptor.
  uint32_t version_ = 0;
  py::object ctor_;
  std::vector<iree_string_pair_t> attrs_;
  std::vector<iree_vm_native_import_descriptor_t> imports_;
  std::vector<iree_vm_native_export_descriptor_t> exports_;
  std::vector<std::unique_ptr<PyFunction>> export_functions_;
  std::vector<iree_vm_native_function_ptr_t> functions_;

  // Map of names to ordinals. Keys are views of the names owned by the
  // entries of |export_functions_|.
  std::unordered_map<std::string_view, int> export_name_to_ordinals_;

  // Once the builder is frozen, the descriptor will be valid.
  iree_vm_module_t interface_;
  iree_vm_native_module_descriptor_t descriptor_;

  // Read-only and descriptor populated when frozen.
  bool initialized_ = false;
  // Self-reference held from Initialize() until ModuleDestroy().
  py::object retained_self_ref_;
};

// Registers the PyModuleInterface class on the Python extension module |m|,
// exposing its constructor, string conversion, state properties and the
// `create`/`export` methods.
void SetupPyModuleBindings(py::module_& m) {
  py::class_<PyModuleInterface> iface_class(m, "PyModuleInterface");

  // PyModuleInterface(module_name, ctor)
  iface_class.def(py::init<std::string, py::object>(), py::arg("module_name"),
                  py::arg("ctor"));

  // Introspection.
  iface_class.def("__str__", &PyModuleInterface::ToString);
  iface_class.def_prop_ro("initialized", &PyModuleInterface::initialized);
  iface_class.def_prop_ro("destroyed", &PyModuleInterface::destroyed);

  // Building and finalizing the module.
  iface_class.def("create", &PyModuleInterface::Create);
  iface_class.def("export", &PyModuleInterface::ExportFunction,
                  py::arg("name"), py::arg("cconv"), py::arg("callable"));
}

}  // namespace iree::python
