// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "./vm.h"

#include "./status_utils.h"
#include "iree/base/api.h"

// TODO: We shouldn't need the HAL API, but it is used for directly printing
// summaries of HAL objects in lists. We should have a better way of doing this
// dynamically instead of hard-depending on a type switch here.
#include "iree/modules/hal/module.h"
#include "iree/tooling/modules/resolver.h"
#include "iree/vm/api.h"
#include "pybind11/numpy.h"

using namespace pybind11::literals;

namespace iree {
namespace python {

namespace {

// Python-facing docstring describing construction of a VmModule from an
// arbitrary Python buffer (wrapping it when aligned, copying otherwise).
static const char kFromBufferDocstring[] =
    R"(Creates a VmModule from a Python buffer.

This is intended as a quick and dirty way to instantiate a VmModule from
a binary blob. It will implicitly make a copy if alignment is not sufficient.

It is recommended to use one of the other construction methods for maximum
determinism and efficiency:

* `mmap` : To memory map from a file.
* `wrap_buffer` : To directly wrap a Python buffer that is known to be
  aligned properly.
* `copy_buffer` : To always make a copy of a Python buffer such that it is
  aligned properly.

This was historically called `from_flatbuffer`. It is recommended that new
code use `from_buffer`.

Args:
  instance: A VmInstance.
  buffer: An object implementing the Python buffer protocol. Typically a
    bytes, bytearray, memoryview, etc.
  warn_if_copy: Raises a warning if alignment is not sufficient to use the
    buffer directly, resulting in a copy. Defaults to True.
)";

// Python-facing docstring text for creating a VmModule from a copy of a Python
// buffer. NOTE(review): presumably attached to the binding of
// VmModule::CopyBuffer; the binding site is not visible here -- confirm.
static const char kCopyBufferDocstring[] =
    R"(Creates a VmModule by making a copy of a Python buffer.

Args:
  instance: A VmInstance.
  buffer: An object implementing the Python buffer protocol. Typically a
    bytes, bytearray, memoryview, etc.
)";

// Python-facing docstring text for creating a VmModule that directly uses the
// memory of a Python buffer. NOTE(review): presumably attached to the binding
// of VmModule::WrapBuffer; the binding site is not visible here -- confirm.
static const char kWrapBufferDocstring[] =
    R"(Creates a VmModule by directly using the backing memory of a Python buffer.

Args:
  instance: A VmInstance.
  buffer: An object implementing the Python buffer protocol. Typically a
    bytes, bytearray, memoryview, etc.
  destroy_callback: A no-argument callback that is invoked when the backing
    buffer is no longer in use.
  close_buffer: Whether to call the `close` method on the `buffer` (prior to
    invoking `destroy_callback`). Defaults to False.

Raises:
  ValueError if alignment is not satisfied.
)";

// Python-facing docstring text for creating a VmModule by memory mapping a
// file. NOTE(review): presumably attached to the binding of VmModule::MMap;
// the binding site is not visible here -- confirm.
static const char kMMapDocstring[] =
    R"(Create a VmModule by mmap'ing a file.

When backed by a file, this is generally the most effective way to create a
VmModule. Especially for large modules, this will result in the fewest
copies and the most effective use of the system cache across invocations.

Note that mmap behavior differs between Posix and Windows. Whereas the former
will allow the backing file to be open before an mmap call and deleted
immediately after, Windows generally allows neither. For compatibility,
make sure that the backing file is not open for writing before calling this
method and that if it needs to be deleted when done, that is done in a
`destroy_callback`.

Args:
  instance: A VmInstance.
  filepath: Path to the file on the file system.
  destroy_callback: A no-argument callback that is invoked when the backing
    buffer is no longer in use.
)";

// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes
// out of scope.
class PyBufferReleaser {
 public:
  PyBufferReleaser(Py_buffer& b) : b_(b) {}
  ~PyBufferReleaser() { PyBuffer_Release(&b_); }

 private:
  Py_buffer& b_;
};

// Builds a Python dict (str -> str) holding every attribute reported by
// iree_vm_function_get_attr for `f`.
//
// Attributes are fetched by increasing index until the API answers with an
// out-of-range status, which marks the end of the enumeration. Any other
// failing status is forwarded to CheckApiStatus.
py::dict GetFunctionReflectionDict(iree_vm_function_t& f) {
  py::dict result;
  iree_host_size_t attr_index = 0;
  while (true) {
    iree_string_pair_t pair;
    iree_status_t attr_status =
        iree_vm_function_get_attr(f, attr_index, &pair);
    if (iree_status_is_out_of_range(attr_status)) {
      // Running off the end is the normal loop terminator, not an error.
      iree_status_ignore(attr_status);
      return result;
    }
    CheckApiStatus(attr_status, "Error getting reflection attr");
    py::str key(pair.key.data, pair.key.size);
    py::str value(pair.value.data, pair.value.size);
    result[std::move(key)] = std::move(value);
    ++attr_index;
  }
}

}  // namespace

//------------------------------------------------------------------------------
// VmInstance
//------------------------------------------------------------------------------

// Creates a new VM instance using the system allocator and registers all HAL
// types on it.
//
// Returns:
//   A VmInstance wrapper that has taken over the reference produced by
//   iree_vm_instance_create.
// Failures of either IREE call are reported through CheckApiStatus.
VmInstance VmInstance::Create() {
  IREE_TRACE_SCOPE_NAMED("VmInstance::Create");

  iree_vm_instance_t* instance = nullptr;
  auto status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
                                        iree_allocator_system(), &instance);
  CheckApiStatus(status, "Error creating instance");

  // Hand the raw pointer to the wrapper before doing anything else that can
  // fail: if type registration below is reported as an error, the wrapper
  // (rather than nobody) is responsible for the instance reference.
  VmInstance py_instance = VmInstance::StealFromRawPtr(instance);

  // The python bindings assume the HAL is always available for use.
  // We register the types here so modules can be loaded using the HAL types
  // in any order.
  CheckApiStatus(iree_hal_module_register_all_types(py_instance.raw_ptr()),
                 "registering HAL types");

  return py_instance;
}

//------------------------------------------------------------------------------
// VmContext
//------------------------------------------------------------------------------

// Creates a VM context on `instance`.
//
// Args:
//   instance: Instance the context is created on.
//   modules: When provided, the context is created with exactly these modules
//     via iree_vm_context_create_with_modules; when absent, a plain context is
//     created via iree_vm_context_create.
// Returns:
//   A VmContext wrapper that takes over the created context pointer.
VmContext VmContext::Create(VmInstance* instance,
                            std::optional<std::vector<VmModule*>> modules) {
  IREE_TRACE_SCOPE_NAMED("VmContext::Create");
  iree_vm_context_t* context = nullptr;
  if (modules) {
    // Closed set of modules: gather the raw handles the C API expects.
    std::vector<iree_vm_module_t*> raw_modules;
    raw_modules.reserve(modules->size());
    for (VmModule* module : *modules) {
      raw_modules.push_back(module->raw_ptr());
    }
    CheckApiStatus(
        iree_vm_context_create_with_modules(
            instance->raw_ptr(), IREE_VM_CONTEXT_FLAG_NONE, raw_modules.size(),
            raw_modules.data(), iree_allocator_system(), &context),
        "Error creating vm context with modules");
  } else {
    // Simple create with open allowed modules.
    CheckApiStatus(
        iree_vm_context_create(instance->raw_ptr(), IREE_VM_CONTEXT_FLAG_NONE,
                               iree_allocator_system(), &context),
        "Error creating vm context");
  }

  IREE_ASSERT(context);
  return VmContext::StealFromRawPtr(context);
}

// Registers `modules` with this context via iree_vm_context_register_modules.
//
// Args:
//   modules: Modules whose raw handles are passed to the C API, in order.
// A failing status is reported through CheckApiStatus.
void VmContext::RegisterModules(std::vector<VmModule*> modules) {
  std::vector<iree_vm_module_t*> module_handles;
  module_handles.reserve(modules.size());
  for (VmModule* module : modules) {
    module_handles.push_back(module->raw_ptr());
  }
  // Use data() rather than &module_handles[0]: indexing element 0 of an empty
  // vector is undefined behavior, while data() is valid for any size.
  auto status = iree_vm_context_register_modules(
      raw_ptr(), module_handles.size(), module_handles.data());
  CheckApiStatus(status, "Error registering modules");
}

// Invokes `f` in this context with `inputs`, writing results to `outputs`.
//
// The GIL is released only for the duration of the iree_vm_invoke call; the
// resulting status is checked (and possibly turned into an error by
// CheckApiStatus) after the GIL has been re-acquired.
void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs,
                       VmVariantList& outputs) {
  iree_status_t invoke_status = [&]() {
    py::gil_scoped_release no_gil;
    return iree_vm_invoke(raw_ptr(), f, IREE_VM_INVOCATION_FLAG_NONE, nullptr,
                          inputs.raw_ptr(), outputs.raw_ptr(),
                          iree_allocator_system());
  }();
  CheckApiStatus(invoke_status, "Error invoking function");
}

//------------------------------------------------------------------------------
// VmModule
//------------------------------------------------------------------------------

// Resolves a module by name through iree_tooling_resolve_module_dependency.
//
// Args:
//   instance: Instance passed to the resolver.
//   name: Name of the required module.
//   minimum_version: Minimum version placed in the dependency record.
// Returns:
//   A VmModule wrapper that takes over the resolved module pointer.
// A failing status is reported through CheckApiStatus.
VmModule VmModule::ResolveModuleDependency(VmInstance* instance,
                                           const std::string& name,
                                           uint32_t minimum_version) {
  IREE_TRACE_SCOPE_NAMED("VmModule::ResolveModuleDependency");
  iree_vm_module_t* module = nullptr;

  iree_vm_module_dependency_t dependency = {
      iree_make_cstring_view(name.c_str()), minimum_version,
      IREE_VM_MODULE_DEPENDENCY_FLAG_REQUIRED};

  auto status = iree_tooling_resolve_module_dependency(
      instance->raw_ptr(), &dependency, iree_allocator_system(), &module);

  // Report a failed status first. Asserting on `module` before this check
  // would abort debug builds on an ordinary resolution failure instead of
  // surfacing the error to the caller.
  CheckApiStatus(status, "Error resolving module dependency");
  assert(module != nullptr);

  auto py_module = VmModule::StealFromRawPtr(module);
  return py_module;
}

// Creates a VmModule backed by a memory mapping of the file at `filepath`.
//
// The mapping is produced with Python's `io.open` and `mmap.mmap`, then handed
// to WrapBuffer with close_buffer=true so the mapping's `close` method is
// called when the module storage is released.
//
// Args:
//   instance: Instance forwarded to WrapBuffer.
//   filepath: Path of the file to map.
//   destroy_callback: Forwarded to WrapBuffer.
//
// NOTE(review): the file is opened with mode "r+b" even though the mapping is
// requested read-only (PROT_READ / ACCESS_READ) -- confirm whether write
// access to the file is really required.
VmModule VmModule::MMap(VmInstance* instance, std::string filepath,
                        py::object destroy_callback) {
  IREE_TRACE_SCOPE_NAMED("VmModule::MMap");
  auto mmap_py = py::module::import("mmap");
  auto io_open = py::module::import("io").attr("open");
  auto file = io_open(filepath, "r+b");

  // Python's mmap() takes different arguments on Windows and elsewhere. The
  // MAP_SHARED constant does not exist on Windows, so its presence selects
  // the Posix form, which gives explicit control over flags and protection.
  const bool has_posix_signature = py::hasattr(mmap_py, "MAP_SHARED");
  py::object mapping;
  if (!has_posix_signature) {
    // Windows mmap signature.
    mapping =
        mmap_py.attr("mmap")(file.attr("fileno")(), 0,
                             "access"_a = mmap_py.attr("ACCESS_READ"));
  } else {
    // Posix mmap signature.
    auto map_flags = py::cast<int64_t>(mmap_py.attr("MAP_SHARED"));
    // MAP_POPULATE isn't available on all versions/platforms.
    if (py::hasattr(mmap_py, "MAP_POPULATE")) {
      map_flags |= py::cast<int64_t>(mmap_py.attr("MAP_POPULATE"));
    }
    auto protection = py::cast<int64_t>(mmap_py.attr("PROT_READ"));
    mapping = mmap_py.attr("mmap")(file.attr("fileno")(), 0,
                                   "flags"_a = map_flags,
                                   "prot"_a = protection);
  }

  // Backing file can be closed after a successful mmap call.
  file.attr("close")();

  // MADV_RANDOM is not available on Windows (and possibly others?).
  if (py::hasattr(mmap_py, "MADV_RANDOM")) {
    mapping.attr("madvise")(mmap_py.attr("MADV_RANDOM"));
  }

  return WrapBuffer(instance, std::move(mapping), std::move(destroy_callback),
                    /*close_buffer=*/true);
}

// Creates a bytecode VmModule that directly references the memory exposed by
// `buffer_obj` through the Python buffer protocol (no copy is made).
//
// Args:
//   instance: Instance the module is created on.
//   buffer_obj: Object implementing the Python buffer protocol. Its data
//     pointer must satisfy IREE_HAL_HEAP_BUFFER_ALIGNMENT.
//   destroy_callback: None, or a no-argument callable invoked by the
//     deallocator installed on the module.
//   close_buffer: When true, the deallocator calls `buffer_obj.close()` before
//     invoking `destroy_callback`.
// Returns:
//   The created VmModule, with `buffer_obj` stashed on it.
// Throws:
//   std::invalid_argument if the buffer pointer is not sufficiently aligned.
//   Module creation failures are reported through CheckApiStatus.
VmModule VmModule::WrapBuffer(VmInstance* instance, py::object buffer_obj,
                              py::object destroy_callback, bool close_buffer) {
  IREE_TRACE_SCOPE_NAMED("VmModule::FromAlignedMemory");
  auto py_buffer = py::cast<py::buffer>(buffer_obj);
  auto buffer_info = py_buffer.request();
  if (!iree_host_size_has_alignment((uintptr_t)buffer_info.ptr,
                                    IREE_HAL_HEAP_BUFFER_ALIGNMENT)) {
    std::stringstream err;
    err << "VmModule.from_aligned_memory received an unaligned buffer. ";
    // Format the address as an integer: streaming a void* uses an
    // implementation-defined format that commonly adds its own "0x" prefix.
    err << "Got 0x" << std::hex
        << reinterpret_cast<uintptr_t>(buffer_info.ptr) << std::dec
        << ", expected alignment ";
    err << IREE_HAL_HEAP_BUFFER_ALIGNMENT;
    throw std::invalid_argument(err.str());
  }

  // buffer_info.size counts items, not bytes; scale by the item size so that
  // buffers with multi-byte items are wrapped in full (matches CopyBuffer).
  const iree_host_size_t byte_length =
      static_cast<iree_host_size_t>(buffer_info.size) *
      static_cast<iree_host_size_t>(buffer_info.itemsize);

  iree_vm_module_t* module = nullptr;

  // Bridge to the C-based deallocator API.
  struct DeallocateState {
    DeallocateState(py::object buffer_obj, py::object destroy_callback,
                    bool close_buffer)
        : buffer_obj(std::move(buffer_obj)),
          destroy_callback(std::move(destroy_callback)),
          close_buffer(close_buffer) {}
    py::object buffer_obj;
    py::object destroy_callback;
    bool close_buffer;
  };
  DeallocateState* state =
      new DeallocateState(buffer_obj, destroy_callback, close_buffer);
  auto ctl_fn = +([](void* self, iree_allocator_command_t command,
                     const void* params, void** inout_ptr) {
    // Python objects are touched (and destroyed) below, so hold the GIL for
    // the whole callback.
    py::gil_scoped_acquire gil;
    assert(command == IREE_ALLOCATOR_COMMAND_FREE);
    DeallocateState* state = static_cast<DeallocateState*>(self);
    iree_status_t status = iree_ok_status();
    try {
      if (state->close_buffer) {
        state->buffer_obj.attr("close")();
      }
      if (!state->destroy_callback.is_none()) {
        state->destroy_callback();
      }
    } catch (std::exception& e) {
      // There are many situations where deallocation exceptions can be
      // swallowed, so carp loudly. This is almost always a critical issue
      // that needs to be visible.
      fprintf(
          stderr,
          "error: exception raised while deallocating storage for an "
          "iree.runtime.VmModule. This is unrecoverable and likely indicates a "
          "serious problem, minimally resulting in memory leaks: %s",
          e.what());
      status = iree_make_status(
          IREE_STATUS_UNKNOWN,
          "exception raised while deallocating storage for an "
          "iree.runtime.VmModule. This is unrecoverable and likely indicates a "
          "serious problem, minimally resulting in memory leaks: %s",
          e.what());
    }
    // Always free the bridge state (dropping its Python references), even if
    // one of the Python calls above raised.
    delete state;
    return status;
  });
  iree_allocator_t deallocator{/*self=*/state, /*ctl=*/ctl_fn};

  auto status = iree_vm_bytecode_module_create(
      instance->raw_ptr(),
      {static_cast<const uint8_t*>(buffer_info.ptr), byte_length}, deallocator,
      iree_allocator_system(), &module);
  if (!iree_status_is_ok(status)) {
    delete state;
  }

  CheckApiStatus(status, "Error creating vm module from aligned memory");
  auto py_module = VmModule::StealFromRawPtr(module);
  // Stash a reference to the flatbuffer at the Python instance level. This
  // is exposed to the tracing API, allowing it to get at the backing contents.
  py_module.stashed_flatbuffer_blob = buffer_obj;
  return py_module;
}

// Creates a VmModule from a copy of the bytes exposed by `buffer_obj`.
//
// A Python bytearray one `mmap.PAGESIZE` larger than the source is allocated,
// the source bytes are copied to the first PAGESIZE-aligned address inside
// it, and a memoryview slice covering exactly those bytes is handed to
// WrapBuffer (no destroy callback, close_buffer=false).
VmModule VmModule::CopyBuffer(VmInstance* instance, py::object buffer_obj) {
  IREE_TRACE_SCOPE_NAMED("VmModule::CopyBuffer");
  auto page_size =
      py::cast<uintptr_t>(py::module::import("mmap").attr("PAGESIZE"));
  auto make_bytearray = py::module::import("builtins").attr("bytearray");

  auto source = py::cast<py::buffer>(buffer_obj);
  auto source_info = source.request();
  py::ssize_t payload_size = source_info.itemsize * source_info.size;

  // Python gives no control over the alignment of the bytearray storage, so
  // over-allocate by one page and locate the aligned start within it.
  auto storage =
      py::cast<py::buffer>(make_bytearray(payload_size + page_size));
  auto storage_info = storage.request();
  uintptr_t storage_addr = reinterpret_cast<uintptr_t>(storage_info.ptr);
  void* aligned_start =
      reinterpret_cast<void*>(iree_host_align(storage_addr, page_size));
  uintptr_t padding = reinterpret_cast<uintptr_t>(aligned_start) - storage_addr;

  // Slice a memoryview of the (unaligned) bytearray to obtain a Python buffer
  // that starts at the aligned address.
  auto aligned_range = py::slice(padding, padding + payload_size, 1);
  py::object storage_view = py::memoryview(storage);
  py::object aligned_view = storage_view[aligned_range];

  // The copy is deliberately last: had the offset math been wrong, the Python
  // slicing above would already have raised, which guards this memcpy.
  std::memcpy(aligned_start, source_info.ptr, payload_size);
  return WrapBuffer(instance, std::move(aligned_view),
                    /*destroy_callback=*/py::none(),
                    /*close_buffer=*/false);
}

// Creates a VmModule from `buffer_obj`, wrapping it in place when its data
// pointer satisfies IREE_HAL_HEAP_BUFFER_ALIGNMENT and copying it otherwise.
//
// Args:
//   instance: Instance forwarded to WrapBuffer / CopyBuffer.
//   buffer_obj: Object implementing the Python buffer protocol.
//   warn_if_copy: When true, a Python warning is issued before falling back
//     to the copying path.
VmModule VmModule::FromBuffer(VmInstance* instance, py::object buffer_obj,
                              bool warn_if_copy) {
  IREE_TRACE_SCOPE_NAMED("VmModule::FromBuffer");
  auto view = py::cast<py::buffer>(buffer_obj);
  auto view_info = view.request();
  const bool is_aligned = iree_host_size_has_alignment(
      (uintptr_t)view_info.ptr, IREE_HAL_HEAP_BUFFER_ALIGNMENT);

  if (!is_aligned) {
    // Unaligned: optionally tell the user, then take the copying path.
    if (warn_if_copy) {
      py::module::import("warnings")
          .attr("warn")(
              "Making copy of unaligned VmModule buffer. It is recommended to "
              "make this deterministic by calling `copy_buffer` to always make "
              "a copy or `mmap` to efficiently load from a file. This warning "
              "can be silenced by adding `warn_if_copy=False` to "
              "`from_buffer`");
    }
    return CopyBuffer(instance, std::move(buffer_obj));
  }

  // Aligned: use the caller's memory directly.
  return WrapBuffer(instance, std::move(buffer_obj),
                    /*destroy_callback=*/py::none(), /*close_buffer=*/false);
}

std::optional<iree_vm_function_t> VmModule::LookupFunction(
    const std::string& name, iree_vm_function_linkage_t linkage) {
  iree_vm_function_t f;
  auto status = iree_vm_module_lookup_function_by_name(
      raw_ptr(), linkage,
      {name.data(), static_cast<iree_host_size_t>(name.size())}, &f);
  if (iree_status_is_not_found(status)) {
    iree_status_ignore(status);
    return std::nullopt;
  }
  CheckApiStatus(status, "Error looking up function");
  return f;
}

//------------------------------------------------------------------------------
// VmRef
//------------------------------------------------------------------------------

// Names of the Python attributes used by the VmRef protocol in this file:
// - kRefAttr: read from an object by PushRef to obtain its VmRef.
// - kCastAttr: called on a class (Deref, GetAsObject) to convert a VmRef.
// - kTypeAttr: called on a class (IsInstance) to obtain its ref type.
const char* const VmRef::kRefAttr = "__iree_vm_ref__";
const char* const VmRef::kCastAttr = "__iree_vm_cast__";
const char* const VmRef::kTypeAttr = "__iree_vm_type__";

// Converts this ref by calling the `__iree_vm_cast__` attribute of
// `ref_object_class` with it.
//
// Returns the result of that call. When the result is None and `optional` is
// false, a py::type_error is thrown instead.
py::object VmRef::Deref(py::object ref_object_class, bool optional) {
  py::object converted = ref_object_class.attr(kCastAttr)(*this);
  const bool conversion_failed = converted.is_none();
  if (conversion_failed && !optional) {
    throw py::type_error("Cannot dereference to specific type");
  }
  return converted;
}

// Returns true when the ref type obtained by calling the `__iree_vm_type__`
// attribute of `ref_object_class` equals the type of this ref.
bool VmRef::IsInstance(py::object ref_object_class) {
  py::object type_object = ref_object_class.attr(kTypeAttr)();
  const auto expected_type = py::cast<iree_vm_ref_type_t>(type_object);
  return ref_.type == expected_type;
}

std::string VmRef::ToString() {
  if (!ref_.ptr) {
    return "<VmRef NULL>";
  }
  iree_string_view_t type_name = iree_vm_ref_type_name(ref_.type);
  std::stringstream ss;
  ss << "<VmRef ";
  ss.write(type_name.data, type_name.size);
  ss << " at " << std::hex << "0x" << reinterpret_cast<uintptr_t>(ref_.ptr)
     << ">";
  return ss.str();
}

//------------------------------------------------------------------------------
// VmVariantList
//------------------------------------------------------------------------------

void VmVariantList::PushFloat(double fvalue) {
  // Note that Python floats are f64.
  iree_vm_value_t value = iree_vm_value_make_f64(fvalue);
  CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value),
                 "Could not push float");
}

void VmVariantList::PushInt(int64_t ivalue) {
  // Note that Python ints are unbounded, so just use the largest type we
  // have.
  iree_vm_value_t value = iree_vm_value_make_i64(ivalue);
  CheckApiStatus(iree_vm_list_push_value(raw_ptr(), &value),
                 "Could not push int");
}

// Appends `other` to this list as a ref element.
//
// A new reference to `other` is taken and moved into the list. A failing push
// is reported through CheckApiStatus instead of being silently dropped.
void VmVariantList::PushList(VmVariantList& other) {
  iree_vm_ref_t retained = iree_vm_list_retain_ref(other.raw_ptr());
  iree_status_t status = iree_vm_list_push_ref_move(raw_ptr(), &retained);
  if (!iree_status_is_ok(status)) {
    // Drop the reference taken above so a failed push does not leak it.
    // NOTE(review): assumes a failed push leaves `retained` with the caller
    // (or resets it to a null ref, which makes this a no-op) -- confirm.
    iree_vm_ref_release(&retained);
  }
  CheckApiStatus(status, "Could not push list");
}

// Appends the VmRef found under the `__iree_vm_ref__` attribute of
// `ref_or_object` to the list, retaining it.
void VmVariantList::PushRef(py::handle ref_or_object) {
  // Keep the attribute value alive in a py::object while the C++ reference
  // obtained from it is in use.
  py::object ref_holder = ref_or_object.attr(VmRef::kRefAttr);
  VmRef& vm_ref = py::cast<VmRef&>(ref_holder);
  iree_status_t push_status =
      iree_vm_list_push_ref_retain(raw_ptr(), &vm_ref.ref());
  CheckApiStatus(push_status, "Failed to push ref");
}

// Returns the element at `index` as a VmVariantList Python object.
//
// Failures to access the element or to interpret it as a list are reported
// through CheckApiStatus.
py::object VmVariantList::GetAsList(int index) {
  iree_vm_ref_t element = {0};
  CheckApiStatus(iree_vm_list_get_ref_assign(raw_ptr(), index, &element),
                 "Could not access list element");

  iree_vm_list_t* nested = nullptr;
  CheckApiStatus(iree_vm_list_check_deref(element, &nested),
                 "Could not deref list (wrong type?)");

  // Take a reference for the wrapper, which then steals the pointer.
  iree_vm_list_retain(nested);
  VmVariantList wrapper = VmVariantList::StealFromRawPtr(nested);
  return py::cast(std::move(wrapper));
}

// Converts the element at `index` to a Python object.
//
// - Empty variants become None.
// - i8/i16/i32/i64/f32/f64 values are converted with py::cast.
// - Refs are returned as a VmRef holding a retained reference.
// Anything else raises through RaiseValueError.
py::object VmVariantList::GetVariant(int index) {
  iree_vm_variant_t variant = iree_vm_variant_empty();
  CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &variant),
                 "Could not access list element");

  if (iree_vm_variant_is_empty(variant)) {
    return py::none();
  }

  if (iree_vm_variant_is_value(variant)) {
    // Primitive value: pick the union member matching the value type.
    switch (iree_vm_type_def_as_value(variant.type)) {
      case IREE_VM_VALUE_TYPE_I8:
        return py::cast(variant.i8);
      case IREE_VM_VALUE_TYPE_I16:
        return py::cast(variant.i16);
      case IREE_VM_VALUE_TYPE_I32:
        return py::cast(variant.i32);
      case IREE_VM_VALUE_TYPE_I64:
        return py::cast(variant.i64);
      case IREE_VM_VALUE_TYPE_F32:
        return py::cast(variant.f32);
      case IREE_VM_VALUE_TYPE_F64:
        return py::cast(variant.f64);
      default:
        throw RaiseValueError("Unsupported VM value type conversion");
    }
  }

  if (iree_vm_variant_is_ref(variant)) {
    // Reference: retain into a fresh VmRef and move it into Python.
    VmRef retained;
    iree_vm_ref_retain(&variant.ref, &retained.ref());
    return py::cast(retained, py::return_value_policy::move);
  }

  throw RaiseValueError("Unsupported VM to Python Type Conversion");
}

// Converts the element at `index` into a Python dict record.
//
// Record layouts produced:
//   empty variant:   {"type": "null"}
//   primitive value: {"type": "value", "<i8|i16|i32|i64|f32|f64>": value}
//   vm list:         {"type": "vm.list", "items": [record, ...]} (recursive)
//   HAL buffer view: {"type": "hal.buffer_view", "shape": [...],
//                     "element_type": str, "contents": bytes}
// Any other element raises through RaiseValueError; failing IREE calls are
// reported through CheckApiStatus.
py::object VmVariantList::GetAsSerializedTraceValue(int index) {
  iree_vm_variant_t v = iree_vm_variant_empty();
  CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &v),
                 "Could not access list element");
  if (iree_vm_variant_is_empty(v)) {
    py::dict record;
    record["type"] = "null";
    return std::move(record);
  } else if (iree_vm_variant_is_value(v)) {
    // Convert a value type.
    py::dict record;
    switch (iree_vm_type_def_as_value(v.type)) {
      case IREE_VM_VALUE_TYPE_I8:
        record["i8"] = py::cast(v.i8);
        break;
      case IREE_VM_VALUE_TYPE_I16:
        record["i16"] = py::cast(v.i16);
        break;
      case IREE_VM_VALUE_TYPE_I32:
        record["i32"] = py::cast(v.i32);
        break;
      case IREE_VM_VALUE_TYPE_I64:
        record["i64"] = py::cast(v.i64);
        break;
      case IREE_VM_VALUE_TYPE_F32:
        record["f32"] = py::cast(v.f32);
        break;
      case IREE_VM_VALUE_TYPE_F64:
        record["f64"] = py::cast(v.f64);
        break;
      default:
        throw RaiseValueError("Unsupported VM value type conversion");
    }
    record["type"] = py::cast("value");
    return std::move(record);
  } else if (iree_vm_variant_is_ref(v)) {
    // Convert reference type.
    if (iree_vm_list_isa(v.ref)) {
      py::dict record;
      record["type"] = "vm.list";
      py::list items;
      iree_vm_list_t* sub_list = nullptr;
      CheckApiStatus(iree_vm_list_check_deref(v.ref, &sub_list),
                     "Could not deref list (wrong type?)");
      // Take a reference for the temporary wrapper, which steals the pointer.
      iree_vm_list_retain(sub_list);
      VmVariantList sub_list_object = VmVariantList::StealFromRawPtr(sub_list);
      for (int i = 0, e = sub_list_object.size(); i < e; ++i) {
        items.append(sub_list_object.GetAsSerializedTraceValue(i));
      }
      record["items"] = std::move(items);
      return std::move(record);
    } else if (iree_hal_buffer_view_isa(v.ref)) {
      py::dict record;
      record["type"] = "hal.buffer_view";
      iree_hal_buffer_view_t* buffer_view = iree_hal_buffer_view_deref(v.ref);
      if (!buffer_view) {
        throw RaiseValueError(
            "Could not deref result buffer view (wrong type?)");
      }
      iree_hal_buffer_t* raw_buffer = iree_hal_buffer_view_buffer(buffer_view);
      if (!raw_buffer) {
        throw RaiseValueError("Could not deref result buffer (wrong type?)");
      }

      // Extract dims from the buffer view. Start with room for 6 dims and
      // retry with the reported rank if that turns out to be too small.
      iree_host_size_t rank = 0;
      std::vector<iree_hal_dim_t> dims(6);
      // Pass size(), not capacity(): only the first size() elements are
      // valid to write, and a later resize() would overwrite anything stored
      // past size() with zeros.
      iree_status_t status = iree_hal_buffer_view_shape(
          buffer_view, dims.size(), dims.data(), &rank);
      if (iree_status_is_out_of_range(status)) {
        // Dispose of the first status before it is overwritten by the retry.
        iree_status_ignore(status);
        dims.resize(rank);
        status = iree_hal_buffer_view_shape(buffer_view, dims.size(),
                                            dims.data(), &rank);
      }
      CheckApiStatus(status, "Error extracting shape");
      dims.resize(rank);
      record["shape"] = py::cast(std::move(dims));

      // Element type.
      iree_hal_element_type_t element_type =
          iree_hal_buffer_view_element_type(buffer_view);
      char element_type_str[64] = {0};
      iree_host_size_t element_type_length = 0;
      CheckApiStatus(
          iree_hal_format_element_type(element_type, sizeof(element_type_str),
                                       element_type_str, &element_type_length),
          "Formatting element type");
      record["element_type"] =
          std::string(element_type_str, element_type_length);

      // Map the whole buffer for reading and copy its bytes into Python.
      iree_device_size_t byte_length = iree_hal_buffer_byte_length(raw_buffer);
      iree_hal_buffer_mapping_t mapped_memory = {{0}};
      CheckApiStatus(iree_hal_buffer_map_range(
                         raw_buffer, IREE_HAL_MAPPING_MODE_SCOPED,
                         IREE_HAL_MEMORY_ACCESS_READ, 0 /* byte_offset */,
                         byte_length, &mapped_memory),
                     "Could not map memory");
      record["contents"] =
          py::bytes(reinterpret_cast<const char*>(mapped_memory.contents.data),
                    mapped_memory.contents.data_length);
      iree_hal_buffer_unmap_range(&mapped_memory);

      return std::move(record);
    }
  }

  throw RaiseValueError("Unsupported VM to Python Type Conversion");
}

// Returns the element at `index` as a VmRef Python object holding a retained
// reference.
//
// Throws std::invalid_argument when the element is not a ref; access failures
// are reported through CheckApiStatus.
py::object VmVariantList::GetAsRef(int index) {
  iree_vm_variant_t variant = iree_vm_variant_empty();
  CheckApiStatus(iree_vm_list_get_variant_assign(raw_ptr(), index, &variant),
                 "Could not access list element");
  const bool is_ref = iree_vm_variant_is_ref(variant);
  if (!is_ref) {
    throw std::invalid_argument("list element is not a ref");
  }
  VmRef retained;
  iree_vm_ref_retain(&variant.ref, &retained.ref());
  return py::cast(retained, py::return_value_policy::move);
}

// Fetches the element at `index` as a ref (see GetAsRef) and converts it by
// calling the `__iree_vm_cast__` attribute of `clazz` with it.
py::object VmVariantList::GetAsObject(int index, py::object clazz) {
  py::object element_ref = GetAsRef(index);
  return clazz.attr(VmRef::kCastAttr)(element_ref);
}

namespace {

// Returns the lowercase hexadecimal rendering of `length` bytes starting at
// `data`: two characters per byte, high nibble first, bytes in memory order.
static std::string ToHexString(const uint8_t* data, size_t length) {
  static constexpr char kDigits[] = "0123456789abcdef";
  std::string hex;
  hex.reserve(length * 2);
  for (size_t pos = 0; pos != length; ++pos) {
    const uint8_t byte = data[pos];
    hex.push_back(kDigits[byte >> 4]);
    hex.push_back(kDigits[byte & 0x0F]);
  }
  return hex;
}
// Hex-renders the in-memory bytes of `value` (so the character order follows
// the host byte order, not the numeric digit order).
static std::string ToHexString(uint32_t value) {
  const uint8_t* raw_bytes = reinterpret_cast<const uint8_t*>(&value);
  return ToHexString(raw_bytes, sizeof(value));
}

// Appends a comma-separated, human-readable rendering of the elements of
// `list` to `out`.
//
// Args:
//   out: String appended to.
//   list: List whose elements are rendered.
//   visited: Lists already entered during this traversal; a nested list that
//     is already present is rendered as "...circular..." instead of being
//     descended into again.
//
// Rendering per element:
//   - primitive values: std::to_string of the value
//   - HAL buffer: "HalBuffer(<byte length>)"
//   - HAL buffer view: "HalBufferView(<d0>x<d1>...:0x<element type hex>)"
//   - nested list: "List[...]" (recursive)
//   - other refs: "Unknown(<ref type>)"
//   - empty variants: "None"
//   - elements that cannot be read: "Error"
void AppendListContents(std::string& out, iree_vm_list_t* list,
                        std::unordered_set<iree_vm_list_t*>& visited) {
  for (iree_host_size_t i = 0, e = iree_vm_list_size(list); i < e; ++i) {
    // Emit the separator before looking at the element so that elements that
    // fail to load ("Error") are delimited like every other element.
    if (i > 0) out.append(", ");

    iree_vm_variant_t variant = iree_vm_variant_empty();
    iree_status_t status = iree_vm_list_get_variant_assign(list, i, &variant);
    if (!iree_status_is_ok(status)) {
      iree_status_ignore(status);
      out.append("Error");
      continue;
    }

    if (iree_vm_variant_is_value(variant)) {
      // Convert a value type to a string.
      switch (iree_vm_type_def_as_value(variant.type)) {
        case IREE_VM_VALUE_TYPE_I8: {
          out += std::to_string(variant.i8);
          break;
        }
        case IREE_VM_VALUE_TYPE_I16: {
          out += std::to_string(variant.i16);
          break;
        }
        case IREE_VM_VALUE_TYPE_I32: {
          out += std::to_string(variant.i32);
          break;
        }
        case IREE_VM_VALUE_TYPE_I64: {
          out += std::to_string(variant.i64);
          break;
        }
        case IREE_VM_VALUE_TYPE_F32: {
          out += std::to_string(variant.f32);
          break;
        }
        case IREE_VM_VALUE_TYPE_F64: {
          out += std::to_string(variant.f64);
          break;
        }
        default:
          throw RaiseValueError("Unsupported VM value type to string");
      }
    } else if (iree_vm_variant_is_ref(variant)) {
      // Pretty print a subset of ABI impacting known types.
      if (iree_hal_buffer_isa(variant.ref)) {
        auto* hal_buffer = iree_hal_buffer_deref(variant.ref);
        assert(hal_buffer);
        out += std::string("HalBuffer(") +
               std::to_string(iree_hal_buffer_byte_length(hal_buffer)) + ")";
      } else if (iree_hal_buffer_view_isa(variant.ref)) {
        auto hal_bv = iree_hal_buffer_view_deref(variant.ref);
        out += "HalBufferView(";
        std::vector<iree_hal_dim_t> shape(
            iree_hal_buffer_view_shape_rank(hal_bv));
        // The vector is sized from the view's own rank; this is best-effort
        // debug output, so explicitly dispose of the returned status rather
        // than dropping it on the floor.
        iree_status_ignore(iree_hal_buffer_view_shape(
            hal_bv, shape.size(), shape.data(), nullptr));
        for (size_t dim = 0; dim < shape.size(); ++dim) {
          if (dim > 0) out += 'x';
          out += std::to_string(shape[dim]);
        }
        out += ":0x" +
               ToHexString(static_cast<uint32_t>(
                   iree_hal_buffer_view_element_type(hal_bv))) +
               ")";
      } else if (iree_vm_list_isa(variant.ref)) {
        out.append("List[");
        iree_vm_list_t* sub_list = iree_vm_list_deref(variant.ref);
        // Only descend into lists not seen before; this stops cycles.
        if (visited.insert(sub_list).second) {
          AppendListContents(out, sub_list, visited);
        } else {
          out.append("...circular...");
        }
        out.append("]");
      } else {
        out += "Unknown(" +
               std::to_string(iree_vm_type_def_as_ref(variant.type)) + ")";
      }
    } else {
      out.append("None");
    }
  }
}

}  // namespace

std::string VmVariantList::DebugString() const {
  // The variant list API requires mutability, so we const cast to it internally
  // so we can maintain a const DebugString() for callers.
  auto mutable_this = const_cast<VmVariantList*>(this);
  std::string s =
      std::string("<VmVariantList(") + std::to_string(size()) + "): [";
  iree_vm_list_t* list = mutable_this->raw_ptr();
  std::unordered_set<iree_vm_list_t*> visited;
  visited.insert(list);
  AppendListContents(s, list, visited);
  s.append("]>");
  return s;
}

void SetupVmBindings(pybind11::module m) {
  py::enum_<enum iree_vm_function_linkage_e>(m, "Linkage")
      .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
      .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
      .value("IMPORT_OPTIONAL", IREE_VM_FUNCTION_LINKAGE_IMPORT_OPTIONAL)
      .value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
      .export_values();

  auto vm_buffer = py::class_<VmBuffer>(m, "VmBuffer", py::buffer_protocol());
  VmRef::BindRefProtocol(vm_buffer, iree_vm_buffer_type,
                         iree_vm_buffer_retain_ref, iree_vm_buffer_deref,
                         iree_vm_buffer_isa);
  vm_buffer
      .def(py::init([](iree_host_size_t length, iree_host_size_t alignment,
                       bool is_mutable) {
             iree_vm_buffer_access_t access = 0;
             if (is_mutable) {
               access |= IREE_VM_BUFFER_ACCESS_MUTABLE;
             }
             iree_vm_buffer_t* raw_buffer;
             CheckApiStatus(
                 iree_vm_buffer_create(access, length, alignment,
                                       iree_allocator_system(), &raw_buffer),
                 "Error creating buffer");
             return VmBuffer::StealFromRawPtr(raw_buffer);
           }),
           py::arg("length"), py::arg("alignment") = 0,
           py::arg("mutable") = true)
      .def_buffer([](VmBuffer& self) -> py::buffer_info {
        return py::buffer_info(
            /*ptr=*/self.raw_ptr()->data.data,
            /*itemsize=*/sizeof(uint8_t),
            /*format=*/py::format_descriptor<uint8_t>::format(),
            /*ndim=*/1,
            /*shape=*/{self.raw_ptr()->data.data_length},
            /*strides=*/{1},
            /*readonly=*/
            !(self.raw_ptr()->access & IREE_VM_BUFFER_ACCESS_MUTABLE));
      })
      .def("__repr__", [](VmBuffer& self) {
        std::stringstream ss;
        ss << "<VmBuffer size " << self.raw_ptr()->data.data_length << " at 0x"
           << std::hex << reinterpret_cast<uintptr_t>(self.raw_ptr()->data.data)
           << ">";
        return ss.str();
      });

  // Mutation and inspection of the variant list is mostly opaque to python.
  auto vm_list = py::class_<VmVariantList>(m, "VmVariantList");
  VmRef::BindRefProtocol(vm_list, iree_vm_list_type, iree_vm_list_retain_ref,
                         iree_vm_list_deref, iree_vm_list_isa);
  vm_list
      // User Methods.
      .def(py::init(&VmVariantList::Create))
      .def_property_readonly("size", &VmVariantList::size)
      .def("__len__", &VmVariantList::size)
      .def("get_as_ref", &VmVariantList::GetAsRef)
      .def("get_as_object", &VmVariantList::GetAsObject)
      .def("get_as_list", &VmVariantList::GetAsList)
      .def("get_variant", &VmVariantList::GetVariant)
      .def("get_serialized_trace_value",
           &VmVariantList::GetAsSerializedTraceValue)
      .def("push_float", &VmVariantList::PushFloat)
      .def("push_int", &VmVariantList::PushInt)
      .def("push_list", &VmVariantList::PushList)
      .def("push_ref", &VmVariantList::PushRef)
      .def("__repr__", &VmVariantList::DebugString);

  py::class_<iree_vm_function_t>(m, "VmFunction")
      .def_readonly("linkage", &iree_vm_function_t::linkage)
      .def_readonly("ordinal", &iree_vm_function_t::ordinal)
      .def_property_readonly("name",
                             [](iree_vm_function_t& self) {
                               iree_string_view_t name =
                                   iree_vm_function_name(&self);
                               return py::str(name.data, name.size);
                             })
      .def_property_readonly("module_name",
                             [](iree_vm_function_t& self) {
                               iree_string_view_t name =
                                   iree_vm_module_name(self.module);
                               return py::str(name.data, name.size);
                             })
      .def_property_readonly("reflection",
                             [](iree_vm_function_t& self) {
                               return GetFunctionReflectionDict(self);
                             })
      .def("__repr__", [](iree_vm_function_t& self) {
        iree_string_view_t name = iree_vm_function_name(&self);
        std::string repr("<VmFunction ");
        repr.append(name.data, name.size);

        iree_vm_function_signature_t sig = iree_vm_function_signature(&self);
        repr.append("(");
        repr.append(sig.calling_convention.data, sig.calling_convention.size);
        repr.append("), reflection = ");
        py::dict reflection = GetFunctionReflectionDict(self);
        repr.append(py::cast<std::string>(py::repr(reflection)));
        repr.append(">");
        return repr;
      });

  py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));

  py::class_<VmContext>(m, "VmContext")
      .def(py::init(&VmContext::Create), py::arg("instance"),
           py::arg("modules") = std::optional<std::vector<VmModule*>>())
      .def("register_modules", &VmContext::RegisterModules)
      .def_property_readonly("context_id", &VmContext::context_id)
      .def("invoke", &VmContext::Invoke);

  py::class_<VmModule>(m, "VmModule")
      .def_static("resolve_module_dependency",
                  &VmModule::ResolveModuleDependency)
      .def_static("from_flatbuffer", &VmModule::FromBuffer, py::arg("instance"),
                  py::arg("buffer"), py::arg("warn_if_copy") = true,
                  kFromBufferDocstring)
      .def_static("from_buffer", &VmModule::FromBuffer, py::arg("instance"),
                  py::arg("buffer"), py::arg("warn_if_copy") = true,
                  kFromBufferDocstring)
      .def_static("copy_buffer", &VmModule::CopyBuffer, py::arg("instance"),
                  py::arg("buffer"), kCopyBufferDocstring)
      .def_static("wrap_buffer", &VmModule::WrapBuffer, py::arg("instance"),
                  py::arg("buffer"), py::arg("destroy_callback") = py::none(),
                  py::arg("close_buffer") = false, kWrapBufferDocstring)
      .def_static("mmap", &VmModule::MMap, py::arg("instance"),
                  py::arg("filepath"), py::arg("destroy_callback") = py::none(),
                  kMMapDocstring)
      .def_property_readonly("name", &VmModule::name)
      .def_property_readonly("version",
                             [](VmModule& self) {
                               iree_vm_module_signature_t sig =
                                   iree_vm_module_signature(self.raw_ptr());
                               return sig.version;
                             })
      .def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
           py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT)
      .def_property_readonly(
          "stashed_flatbuffer_blob",
          [](VmModule& self) { return self.get_stashed_flatbuffer_blob(); })
      .def_property_readonly(
          "function_names",
          [](VmModule& self) {
            py::list names;
            iree_vm_module_signature_t sig =
                iree_vm_module_signature(self.raw_ptr());
            for (size_t ordinal = 0; ordinal < sig.export_function_count;
                 ++ordinal) {
              iree_vm_function_t f;
              auto status = iree_vm_module_lookup_function_by_ordinal(
                  self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
              if (iree_status_is_not_found(status)) {
                iree_status_ignore(status);
                break;
              }
              CheckApiStatus(status, "Error enumerating module");
              iree_string_view_t fname = iree_vm_function_name(&f);
              py::str name(fname.data, fname.size);
              names.append(name);
            }
            return names;
          })
      .def("__repr__", [](VmModule& self) {
        std::string repr("<VmModule ");
        iree_string_view_t name = iree_vm_module_name(self.raw_ptr());
        repr.append(name.data, name.size);

        iree_vm_module_signature_t sig =
            iree_vm_module_signature(self.raw_ptr());
        repr.append(" : [");
        for (size_t ordinal = 0; ordinal < sig.export_function_count;
             ++ordinal) {
          iree_vm_function_t f;
          auto status = iree_vm_module_lookup_function_by_ordinal(
              self.raw_ptr(), IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &f);
          if (iree_status_is_not_found(status)) {
            iree_status_ignore(status);
            break;
          }
          CheckApiStatus(status, "Error enumerating module");
          iree_string_view_t fname = iree_vm_function_name(&f);
          if (ordinal > 0) {
            repr.append(", ");
          }
          repr.append(fname.data, fname.size);
        }
        repr.append("]");
        repr.append(">");
        return repr;
      });

  py::class_<VmRef>(m, "VmRef")
      .def("isinstance", &VmRef::IsInstance)
      .def("deref", &VmRef::Deref, py::arg("value"),
           py::arg("optional") = false)
      .def("__repr__", &VmRef::ToString)
      .def_property_readonly(VmRef::kRefAttr,
                             [](py::object self) { return self; })
      .def("__eq__",
           [](VmRef& self, VmRef& other) {
             return self.ref().ptr == other.ref().ptr;
           })
      .def("__eq__", [](VmRef& self, py::object& other) { return false; });
}

}  // namespace python
}  // namespace iree
