Start python bindings for vm2.

Most things don't work yet because they need the HAL module, which is still in flight.

PiperOrigin-RevId: 286416263
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD
index d0d531d..e3d710a 100644
--- a/bindings/python/pyiree/BUILD
+++ b/bindings/python/pyiree/BUILD
@@ -127,6 +127,8 @@
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:variant",
         "//iree/compiler/Utils",
+        "//iree/vm2",
+        "//iree/vm2:bytecode_module",
         "@local_config_mlir//:IR",
         "//iree/base:api",
         "//iree/base:status",
@@ -203,3 +205,14 @@
         "//bindings/python/pyiree",
     ],
 )
+
+# Tests for the VM bindings exposed to Python as pyiree.binding.vm.
+py_test(
+    name = "vm_test",
+    srcs = ["vm_test.py"],
+    python_version = "PY3",
+    # NOTE(review): vm_test.py does not import numpy directly; NUMPY_DEPS may
+    # only be needed transitively by pyiree -- confirm before trimming it.
+    deps = NUMPY_DEPS + [
+        "//bindings/python:pathsetup",  # build_cleaner: keep
+        "@absl_py//absl/testing:absltest",
+        "//bindings/python/pyiree",
+    ],
+)
diff --git a/bindings/python/pyiree/binding.h b/bindings/python/pyiree/binding.h
index c39177c..24e4907 100644
--- a/bindings/python/pyiree/binding.h
+++ b/bindings/python/pyiree/binding.h
@@ -69,6 +69,22 @@
     return {free_fn, holder};
   }
 
+  // Builds an iree_allocator_t whose |free| callback drops a heap-allocated
+  // reference to |blob|, so the blob stays alive until that callback runs.
+  // The extra heap allocation is acceptable: this is not a hot code path.
+  static iree_allocator_t CreateDeallocator(std::shared_ptr<OpaqueBlob> blob) {
+    using BlobRef = std::shared_ptr<OpaqueBlob>;
+    BlobRef* blob_ref = new BlobRef(std::move(blob));
+    iree_status_t (*release_blob_ref)(void*, void*) =
+        [](void* self, void* /*ptr*/) -> iree_status_t {
+      delete static_cast<BlobRef*>(self);
+      return IREE_STATUS_OK;
+    };
+    return iree_allocator_t{blob_ref /* self */, nullptr /* alloc */,
+                            release_blob_ref /* free */};
+  }
+
  protected:
   void* data_;
   size_t size_;
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc
index ec9ab9b..b69082d 100644
--- a/bindings/python/pyiree/vm.cc
+++ b/bindings/python/pyiree/vm.cc
@@ -14,7 +14,10 @@
 
 #include "bindings/python/pyiree/vm.h"
 
+#include "absl/types/optional.h"
 #include "bindings/python/pyiree/status_utils.h"
+#include "iree/base/api.h"
+#include "iree/vm2/instance.h"
 
 namespace iree {
 namespace python {
@@ -29,8 +32,120 @@
   return RtModule::CreateRetained(module);
 }
 
+//------------------------------------------------------------------------------
+// VmInstance
+//------------------------------------------------------------------------------
+
+// Creates a VM instance backed by the system allocator and wraps it.
+VmInstance VmInstance::Create() {
+  iree_vm_instance_t* raw_instance = nullptr;
+  CheckApiStatus(iree_vm_instance_create(IREE_ALLOCATOR_SYSTEM, &raw_instance),
+                 "Error creating instance");
+  return VmInstance::CreateRetained(raw_instance);
+}
+
+//------------------------------------------------------------------------------
+// VmContext
+//------------------------------------------------------------------------------
+
+// Creates a context on |instance|. Without |modules| the context is open to
+// later RegisterModules calls; with |modules| the module set is fixed.
+// Throws ValueError (via py::value_error) for a None instance or module.
+VmContext VmContext::Create(VmInstance* instance,
+                            absl::optional<std::vector<VmModule*>> modules) {
+  // pybind11 converts a Python None argument to nullptr for pointer
+  // parameters, so guard against it instead of dereferencing null.
+  if (!instance) {
+    throw py::value_error("VmContext requires a non-None instance");
+  }
+
+  iree_vm_context_t* context = nullptr;
+  if (!modules) {
+    // Simple create with open allowed modules.
+    auto status = iree_vm_context_create(instance->raw_ptr(),
+                                         IREE_ALLOCATOR_SYSTEM, &context);
+    CheckApiStatus(status, "Error creating vm context");
+  } else {
+    // Closed set of modules.
+    absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
+    module_handles.reserve(modules->size());
+    for (VmModule* module : *modules) {
+      // List elements can also arrive as nullptr when the caller passes None.
+      if (!module) {
+        throw py::value_error("VmContext modules must not contain None");
+      }
+      module_handles.push_back(module->raw_ptr());
+    }
+    auto status = iree_vm_context_create_with_modules(
+        instance->raw_ptr(), module_handles.data(), module_handles.size(),
+        IREE_ALLOCATOR_SYSTEM, &context);
+    CheckApiStatus(status, "Error creating vm context with modules");
+  }
+
+  CHECK(context);
+  return VmContext::CreateRetained(context);
+}
+
+// Registers |modules| with this (non-static) context.
+// Throws ValueError (via py::value_error) if any element is None.
+void VmContext::RegisterModules(std::vector<VmModule*> modules) {
+  absl::InlinedVector<iree_vm_module_t*, 8> module_handles;
+  module_handles.reserve(modules.size());
+  for (VmModule* module : modules) {
+    // pybind11 converts None list elements to nullptr for pointer types.
+    if (!module) {
+      throw py::value_error("register_modules: modules must not contain None");
+    }
+    module_handles.push_back(module->raw_ptr());
+  }
+  // Use data() rather than &module_handles[0]: an empty list can arrive from
+  // Python, and indexing element 0 of an empty vector is undefined behavior.
+  auto status = iree_vm_context_register_modules(
+      raw_ptr(), module_handles.data(), module_handles.size());
+  CheckApiStatus(status, "Error registering modules");
+}
+
+//------------------------------------------------------------------------------
+// VmModule
+//------------------------------------------------------------------------------
+
+// Creates a bytecode module that reads directly from |flatbuffer_blob|.
+// Throws ValueError (via py::value_error) if the blob is None.
+VmModule VmModule::FromFlatbufferBlob(
+    std::shared_ptr<OpaqueBlob> flatbuffer_blob) {
+  // pybind11 converts a Python None to an empty shared_ptr; reject it before
+  // dereferencing.
+  if (!flatbuffer_blob) {
+    throw py::value_error("from_flatbuffer requires a non-None blob");
+  }
+
+  // The deallocator holds its own reference to the blob, so the bytes passed
+  // below remain valid until the deallocator's free callback runs.
+  auto deallocator = OpaqueBlob::CreateDeallocator(flatbuffer_blob);
+  iree_vm_module_t* module = nullptr;
+  auto status = iree_vm_bytecode_module_create(
+      {static_cast<const uint8_t*>(flatbuffer_blob->data()),
+       flatbuffer_blob->size()},
+      deallocator, IREE_ALLOCATOR_SYSTEM, &module);
+  if (status != IREE_STATUS_OK) {
+    // Release the blob reference ourselves on failure.
+    // NOTE(review): assumes iree_vm_bytecode_module_create does not invoke
+    // the deallocator when it fails; confirm to rule out a double free.
+    deallocator.free(deallocator.self, nullptr);
+  }
+
+  CheckApiStatus(status, "Error creating vm module from flatbuffer");
+  return VmModule::CreateRetained(module);
+}
+
+// Looks up |name| with the given |linkage|. Returns nullopt when the module
+// reports IREE_STATUS_NOT_FOUND; any other failure goes through
+// CheckApiStatus.
+absl::optional<iree_vm_function_t> VmModule::LookupFunction(
+    const std::string& name, iree_vm_function_linkage_t linkage) {
+  iree_vm_function_t function;
+  auto status = iree_vm_module_lookup_function(
+      raw_ptr(), linkage, {name.data(), name.size()}, &function);
+  // A missing function is an expected outcome (None in Python); everything
+  // else is either success or an error to surface.
+  if (status != IREE_STATUS_NOT_FOUND) {
+    CheckApiStatus(status, "Error looking up function");
+    return function;
+  }
+  return absl::nullopt;
+}
+
+// Registers the VM bindings on |m|.
+//
+// Registration order matters: Linkage is bound before VmModule because
+// pybind11 converts the default value of lookup_function's `linkage` argument
+// to a Python object when the method is defined, and that conversion requires
+// the enum type to already be registered.
 void SetupVmBindings(pybind11::module m) {
+  // Deprecated: VM1 module.
   m.def("create_module_from_blob", CreateModuleFromBlob);
+
+  // Function linkage kinds. export_values() also exposes INTERNAL, IMPORT and
+  // EXPORT as attributes of |m| itself.
+  py::enum_<iree_vm_function_linkage_t>(m, "Linkage")
+      .value("INTERNAL", IREE_VM_FUNCTION_LINKAGE_INTERNAL)
+      .value("IMPORT", IREE_VM_FUNCTION_LINKAGE_IMPORT)
+      .value("EXPORT", IREE_VM_FUNCTION_LINKAGE_EXPORT)
+      .export_values();
+
+  // Mutation and inspection of the variant list is mostly opaque to python.
+  // The constructor takes a capacity (see VmVariantList::Create).
+  py::class_<VmVariantList>(m, "VmVariantList")
+      .def(py::init(&VmVariantList::Create))
+      .def_property_readonly("size", &VmVariantList::size);
+
+  // Read-only view of a function handle, as returned by
+  // VmModule.lookup_function.
+  py::class_<iree_vm_function_t>(m, "VmFunction")
+      .def_readonly("ordinal", &iree_vm_function_t::ordinal)
+      .def_readonly("linkage", &iree_vm_function_t::linkage);
+
+  py::class_<VmInstance>(m, "VmInstance").def(py::init(&VmInstance::Create));
+
+  // Passing `modules` creates a context with a fixed module set; omitting it
+  // leaves the context open to later register_modules calls.
+  py::class_<VmContext>(m, "VmContext")
+      .def(py::init(&VmContext::Create), py::arg("instance"),
+           py::arg("modules") = absl::nullopt)
+      .def("register_modules", &VmContext::RegisterModules)
+      .def_property_readonly("context_id", &VmContext::context_id);
+
+  // lookup_function returns None when no function with the given name and
+  // linkage exists; linkage defaults to EXPORT.
+  py::class_<VmModule>(m, "VmModule")
+      .def_static("from_flatbuffer", &VmModule::FromFlatbufferBlob)
+      .def("lookup_function", &VmModule::LookupFunction, py::arg("name"),
+           py::arg("linkage") = IREE_VM_FUNCTION_LINKAGE_EXPORT);
 }
 
 }  // namespace python
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/vm.h
index ffb79b1..ee094b3 100644
--- a/bindings/python/pyiree/vm.h
+++ b/bindings/python/pyiree/vm.h
@@ -15,15 +15,125 @@
 #ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_
 #define IREE_BINDINGS_PYTHON_PYIREE_VM_H_
 
+#include "absl/types/optional.h"
 #include "bindings/python/pyiree/binding.h"
 #include "bindings/python/pyiree/rt.h"
+#include "iree/base/api.h"
 #include "iree/vm/api.h"
+#include "iree/vm2/bytecode_module.h"
+#include "iree/vm2/context.h"
+#include "iree/vm2/instance.h"
+#include "iree/vm2/invocation.h"
+#include "iree/vm2/module.h"
 
 namespace iree {
 namespace python {
 
+//------------------------------------------------------------------------------
+// Retain/release bindings
+//------------------------------------------------------------------------------
+
+// Retain/Release hooks for the vm2 handle types. Each one forwards to the
+// matching iree_vm_*_retain / iree_vm_*_release C API call.
+template <>
+struct ApiPtrAdapter<iree_vm_instance_t> {
+  static void Retain(iree_vm_instance_t* instance) {
+    iree_vm_instance_retain(instance);
+  }
+  static void Release(iree_vm_instance_t* instance) {
+    iree_vm_instance_release(instance);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_vm_context_t> {
+  static void Retain(iree_vm_context_t* context) {
+    iree_vm_context_retain(context);
+  }
+  static void Release(iree_vm_context_t* context) {
+    iree_vm_context_release(context);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_vm_module_t> {
+  static void Retain(iree_vm_module_t* module) { iree_vm_module_retain(module); }
+  static void Release(iree_vm_module_t* module) {
+    iree_vm_module_release(module);
+  }
+};
+
+template <>
+struct ApiPtrAdapter<iree_vm_invocation_t> {
+  static void Retain(iree_vm_invocation_t* invocation) {
+    iree_vm_invocation_retain(invocation);
+  }
+  static void Release(iree_vm_invocation_t* invocation) {
+    iree_vm_invocation_release(invocation);
+  }
+};
+
+//------------------------------------------------------------------------------
+// ApiRefCounted types
+//------------------------------------------------------------------------------
+
+// Reference-counted wrapper around iree_vm_instance_t.
+class VmInstance : public ApiRefCounted<VmInstance, iree_vm_instance_t> {
+ public:
+  // Creates a new VM instance.
+  static VmInstance Create();
+};
+
+// Reference-counted wrapper around iree_vm_module_t.
+class VmModule : public ApiRefCounted<VmModule, iree_vm_module_t> {
+ public:
+  // Creates a bytecode module from |flatbuffer_blob|. The blob is kept alive
+  // through the deallocator handed to the module.
+  static VmModule FromFlatbufferBlob(
+      std::shared_ptr<OpaqueBlob> flatbuffer_blob);
+
+  // Looks up a function by |name| and |linkage|. Returns nullopt if it is
+  // not found; other lookup failures are reported via CheckApiStatus.
+  absl::optional<iree_vm_function_t> LookupFunction(
+      const std::string& name, iree_vm_function_linkage_t linkage);
+};
+
+// Reference-counted wrapper around iree_vm_context_t.
+class VmContext : public ApiRefCounted<VmContext, iree_vm_context_t> {
+ public:
+  // Creates a context, optionally with modules, which will make the context
+  // static, disallowing further module registration (and may be more
+  // efficient).
+  static VmContext Create(VmInstance* instance,
+                          absl::optional<std::vector<VmModule*>> modules);
+
+  // Registers additional modules. Only valid for non-static contexts (i.e.
+  // those created without modules).
+  void RegisterModules(std::vector<VmModule*> modules);
+
+  // Unique id for this context. The return type is deduced from
+  // iree_vm_context_id so the id is never narrowed on the way out.
+  auto context_id() const { return iree_vm_context_id(raw_ptr()); }
+};
+
+// Reference-counted wrapper around iree_vm_invocation_t. It has no members
+// yet and is not exposed to Python.
+class VmInvocation : public ApiRefCounted<VmInvocation, iree_vm_invocation_t> {
+};
+
 void SetupVmBindings(pybind11::module m);
 
+//------------------------------------------------------------------------------
+// VmVariantList
+//------------------------------------------------------------------------------
+
+// Move-only owner of an iree_vm_variant_list_t. An empty wrapper
+// (default-constructed or moved-from) owns no list.
+class VmVariantList {
+ public:
+  VmVariantList() = default;
+
+  // Destructors are implicitly noexcept, so throwing here (as CheckApiStatus
+  // would) ends in std::terminate. A failed free cannot be reported from a
+  // destructor, so its status is deliberately discarded (see Reset()).
+  ~VmVariantList() { Reset(); }
+
+  VmVariantList(VmVariantList&& other) noexcept : list_(other.list_) {
+    other.list_ = nullptr;
+  }
+
+  VmVariantList& operator=(VmVariantList&& other) noexcept {
+    if (this != &other) {
+      Reset();
+      list_ = other.list_;
+      other.list_ = nullptr;
+    }
+    return *this;
+  }
+
+  VmVariantList& operator=(const VmVariantList&) = delete;
+  VmVariantList(const VmVariantList&) = delete;
+
+  // Allocates a list with room for |capacity| entries.
+  static VmVariantList Create(iree_host_size_t capacity) {
+    iree_vm_variant_list_t* list = nullptr;
+    CheckApiStatus(
+        iree_vm_variant_list_alloc(capacity, IREE_ALLOCATOR_SYSTEM, &list),
+        "Error allocating variant list");
+    return VmVariantList(list);
+  }
+
+  // Returns 0 for an empty wrapper rather than passing a null list to the
+  // C API.
+  iree_host_size_t size() const {
+    return list_ ? iree_vm_variant_list_size(list_) : 0;
+  }
+
+ private:
+  explicit VmVariantList(iree_vm_variant_list_t* list) : list_(list) {}
+
+  // Frees the owned list, if any, and leaves the wrapper empty. Never throws.
+  void Reset() noexcept {
+    if (list_) {
+      (void)iree_vm_variant_list_free(list_);
+      list_ = nullptr;
+    }
+  }
+
+  iree_vm_variant_list_t* list_ = nullptr;
+};
+
 }  // namespace python
 }  // namespace iree
 
diff --git a/bindings/python/pyiree/vm_test.py b/bindings/python/pyiree/vm_test.py
new file mode 100644
index 0000000..338d64d
--- /dev/null
+++ b/bindings/python/pyiree/vm_test.py
@@ -0,0 +1,73 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=unused-variable
+# pylint: disable=g-unreachable-test-method
+
+from absl.testing import absltest
+import pyiree
+
+
+def create_simple_mul_module():
+  ctx = pyiree.CompilerContext()
+  input_module = ctx.parse_asm("""
+    func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
+          attributes { iree.module.export } {
+        %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "mul.1"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+        return %0 : tensor<4xf32>
+    }
+    """)
+  binary = input_module.compile()
+  m = pyiree.binding.vm.VmModule.from_flatbuffer(binary)
+  return m
+
+
+class VmTest(absltest.TestCase):
+
+  def test_variant_list(self):
+    l = pyiree.binding.vm.VmVariantList(5)
+    print(l)
+    self.assertEqual(l.size, 0)
+
+  def test_context_id(self):
+    instance = pyiree.binding.vm.VmInstance()
+    context1 = pyiree.binding.vm.VmContext(instance)
+    context2 = pyiree.binding.vm.VmContext(instance)
+    self.assertGreater(context2.context_id, context1.context_id)
+
+  def disabled_test_module_basics(self):
+    m = create_simple_mul_module()
+    f = m.lookup_function("simple_mul")
+    self.assertGreater(f.ordinal, 0)
+    notfound = m.lookup_function("notfound")
+    self.assertIs(notfound, None)
+
+  def disabled_test_dynamic_module_context(self):
+    instance = pyiree.binding.vm.VmInstance()
+    context = pyiree.binding.vm.VmContext(instance)
+    m = create_simple_mul_module()
+    context.register_modules([m])
+
+  def disabled_test_static_module_context(self):
+    m = create_simple_mul_module()
+    print(m)
+    instance = pyiree.binding.vm.VmInstance()
+    print(instance)
+    context = pyiree.binding.vm.VmContext(instance, modules=[m])
+    print(context)
+
+
+if __name__ == "__main__":
+  # Run this file's tests via absltest when executed directly.
+  absltest.main()