Adding vm::ref<T> support for non-iree_vm_ref_t ref types. (#11065)

diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index 7557c16..51d191f 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -8,7 +8,6 @@
 
 #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
 #include "iree/hal/drivers/local_task/registration/driver_module.h"
-#include "iree/modules/hal/module.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 
@@ -89,17 +88,17 @@
 CompiledBinary::~CompiledBinary() {}
 
 void CompiledBinary::deinitialize() {
-  iree_vm_module_release(hal_module);
-  iree_vm_module_release(main_module);
-  iree_vm_context_release(context);
-  iree_hal_device_release(device);
+  hal_module.reset();
+  main_module.reset();
+  context.reset();
+  device.reset();
 }
 
 LogicalResult CompiledBinary::invokeNullary(Location loc, StringRef name,
                                             ResultsCallback callback) {
   iree_vm_function_t function;
   if (auto status = iree_vm_module_lookup_function_by_name(
-          main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+          main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
           iree_string_view_t{name.data(), name.size()}, &function)) {
     iree_status_ignore(status);
     return emitError(loc) << "internal error evaling constant: func '" << name
@@ -114,7 +113,7 @@
                                     iree_allocator_system(), &outputs));
 
   if (auto status =
-          iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
+          iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
                          /*policy=*/nullptr, inputs.get(), outputs.get(),
                          iree_allocator_system())) {
     std::string message;
@@ -272,19 +271,22 @@
   iree_hal_driver_release(driver);
 
   // Create hal module.
-  IREE_CHECK_OK(iree_hal_module_create(runtime.instance, device,
+  IREE_CHECK_OK(iree_hal_module_create(runtime.instance.get(), device.get(),
                                        IREE_HAL_MODULE_FLAG_NONE,
                                        iree_allocator_system(), &hal_module));
 
   // Bytecode module.
   IREE_CHECK_OK(iree_vm_bytecode_module_create(
-      runtime.instance, iree_make_const_byte_span(data, length),
+      runtime.instance.get(), iree_make_const_byte_span(data, length),
       iree_allocator_null(), iree_allocator_system(), &main_module));
 
   // Context.
-  std::array<iree_vm_module_t*, 2> modules = {hal_module, main_module};
+  std::array<iree_vm_module_t*, 2> modules = {
+      hal_module.get(),
+      main_module.get(),
+  };
   IREE_CHECK_OK(iree_vm_context_create_with_modules(
-      runtime.instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
+      runtime.instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
       modules.data(), iree_allocator_system(), &context));
 }
 
@@ -308,11 +310,11 @@
       iree_hal_driver_registry_allocate(iree_allocator_system(), &registry));
   IREE_CHECK_OK(iree_hal_local_task_driver_module_register(registry));
   IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
-  IREE_CHECK_OK(iree_hal_module_register_all_types(instance));
+  IREE_CHECK_OK(iree_hal_module_register_all_types(instance.get()));
 }
 
 Runtime::~Runtime() {
-  iree_vm_instance_release(instance);
+  instance.reset();
   iree_hal_driver_registry_free(registry);
 }
 
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.h b/compiler/src/iree/compiler/ConstEval/Runtime.h
index 40a1564..6d5935b 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.h
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.h
@@ -9,6 +9,7 @@
 
 #include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
 #include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
 #include "iree/vm/api.h"
 #include "iree/vm/bytecode_module.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -47,10 +48,10 @@
   void deinitialize();
   Attribute convertVariantToAttribute(Location loc, iree_vm_variant_t& variant);
 
-  iree_hal_device_t* device = nullptr;
-  iree_vm_module_t* hal_module = nullptr;
-  iree_vm_module_t* main_module = nullptr;
-  iree_vm_context_t* context = nullptr;
+  iree::vm::ref<iree_hal_device_t> device;
+  iree::vm::ref<iree_vm_module_t> hal_module;
+  iree::vm::ref<iree_vm_module_t> main_module;
+  iree::vm::ref<iree_vm_context_t> context;
 };
 
 // An in-memory compiled binary and accessors for working with it.
@@ -70,7 +71,7 @@
   static Runtime& getInstance();
 
   iree_hal_driver_registry_t* registry = nullptr;
-  iree_vm_instance_t* instance = nullptr;
+  iree::vm::ref<iree_vm_instance_t> instance;
 
  private:
   Runtime();
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index fefa4a2..b2924ff 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -177,8 +177,8 @@
                           /*outputs=*/nullptr, iree_allocator_system());
   }
 
-  iree_status_t Invoke(const char* function_name,
-                       std::vector<iree_vm_value_t> args) {
+  iree_status_t InvokeValue(const char* function_name,
+                            std::vector<iree_vm_value_t> args) {
     IREE_RETURN_IF_ERROR(
         iree_vm_list_create(/*element_type=*/nullptr, args.size(),
                             iree_allocator_system(), &inputs_));
@@ -216,28 +216,28 @@
 iree_vm_module_t* CheckTest::hal_module_ = nullptr;
 
 TEST_F(CheckTest, ExpectTrueSuccess) {
-  IREE_ASSERT_OK(Invoke("expect_true", {iree_vm_value_make_i32(1)}));
+  IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(1)}));
 }
 
 TEST_F(CheckTest, ExpectTrueFailure) {
   EXPECT_NONFATAL_FAILURE(
-      IREE_ASSERT_OK(Invoke("expect_true", {iree_vm_value_make_i32(0)})),
+      IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(0)})),
       "Expected 0 to be nonzero");
 }
 
 TEST_F(CheckTest, ExpectFalseSuccess) {
-  IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(0)}));
+  IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(0)}));
 }
 
 TEST_F(CheckTest, ExpectFalseFailure) {
   EXPECT_NONFATAL_FAILURE(
-      IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(1)})),
+      IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(1)})),
       "Expected 1 to be zero");
 }
 
 TEST_F(CheckTest, ExpectFalseNotOneFailure) {
   EXPECT_NONFATAL_FAILURE(
-      IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(42)})),
+      IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(42)})),
       "Expected 42 to be zero");
 }
 
diff --git a/runtime/src/iree/vm/context.h b/runtime/src/iree/vm/context.h
index 3e5822d..4d629d2 100644
--- a/runtime/src/iree/vm/context.h
+++ b/runtime/src/iree/vm/context.h
@@ -12,6 +12,7 @@
 #include "iree/base/api.h"
 #include "iree/vm/instance.h"
 #include "iree/vm/module.h"
+#include "iree/vm/ref.h"
 #include "iree/vm/stack.h"
 
 #ifdef __cplusplus
@@ -123,4 +124,6 @@
 }  // extern "C"
 #endif  // __cplusplus
 
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_context, iree_vm_context_t);
+
 #endif  // IREE_VM_CONTEXT_H_
diff --git a/runtime/src/iree/vm/instance.h b/runtime/src/iree/vm/instance.h
index 16b6a6d..a1f4a1a 100644
--- a/runtime/src/iree/vm/instance.h
+++ b/runtime/src/iree/vm/instance.h
@@ -8,6 +8,7 @@
 #define IREE_VM_INSTANCE_H_
 
 #include "iree/base/api.h"
+#include "iree/vm/ref.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -47,4 +48,6 @@
 }  // extern "C"
 #endif  // __cplusplus
 
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_instance, iree_vm_instance_t);
+
 #endif  // IREE_VM_INSTANCE_H_
diff --git a/runtime/src/iree/vm/invocation.h b/runtime/src/iree/vm/invocation.h
index 3e4ecaa..2ebc654 100644
--- a/runtime/src/iree/vm/invocation.h
+++ b/runtime/src/iree/vm/invocation.h
@@ -13,6 +13,7 @@
 #include "iree/vm/context.h"
 #include "iree/vm/list.h"
 #include "iree/vm/module.h"
+#include "iree/vm/ref.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -344,4 +345,6 @@
 }  // extern "C"
 #endif  // __cplusplus
 
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_invocation, iree_vm_invocation_t);
+
 #endif  // IREE_VM_INVOCATION_H_
diff --git a/runtime/src/iree/vm/module.h b/runtime/src/iree/vm/module.h
index b991e50..f7c010b 100644
--- a/runtime/src/iree/vm/module.h
+++ b/runtime/src/iree/vm/module.h
@@ -15,6 +15,7 @@
 #include "iree/base/api.h"
 #include "iree/base/internal/atomics.h"
 #include "iree/base/string_builder.h"
+#include "iree/vm/ref.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -552,4 +553,6 @@
 }  // extern "C"
 #endif  // __cplusplus
 
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_module, iree_vm_module_t);
+
 #endif  // IREE_VM_MODULE_H_
diff --git a/runtime/src/iree/vm/ref.h b/runtime/src/iree/vm/ref.h
index d73f389..b4cf9be 100644
--- a/runtime/src/iree/vm/ref.h
+++ b/runtime/src/iree/vm/ref.h
@@ -236,37 +236,6 @@
 // Type adapter utilities for interfacing with the VM
 //===----------------------------------------------------------------------===//
 
-#ifdef __cplusplus
-namespace iree {
-namespace vm {
-template <typename T>
-struct ref_type_descriptor {
-  static const iree_vm_ref_type_descriptor_t* get();
-};
-}  // namespace vm
-}  // namespace iree
-#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)         \
-  namespace iree {                                      \
-  namespace vm {                                        \
-  template <>                                           \
-  struct ref_type_descriptor<T> {                       \
-    static const iree_vm_ref_type_descriptor_t* get() { \
-      return name##_get_descriptor();                   \
-    }                                                   \
-  };                                                    \
-  }                                                     \
-  }
-
-#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)  \
-  descriptor.type_name = iree_make_cstring_view(name);    \
-  descriptor.offsetof_counter = type::offsetof_counter(); \
-  descriptor.destroy = type::DirectDestroy;               \
-  IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
-#else
-#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
-#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)
-#endif  // __cplusplus
-
 // TODO(benvanik): make these macros standard/document them.
 #define IREE_VM_DECLARE_TYPE_ADAPTERS(name, T)                              \
   IREE_API_EXPORT iree_vm_ref_t name##_retain_ref(T* value);                \
@@ -330,6 +299,10 @@
 // Optional C++ iree::vm::ref<T> wrapper.
 #ifdef __cplusplus
 #include "iree/vm/ref_cc.h"
+#else
+#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
+#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)
+#define IREE_VM_DECLARE_CC_TYPE_ADAPTERS(name, T)
 #endif  // __cplusplus
 
 #endif  // IREE_VM_REF_H_
diff --git a/runtime/src/iree/vm/ref_cc.h b/runtime/src/iree/vm/ref_cc.h
index ac5fe02..c43494a 100644
--- a/runtime/src/iree/vm/ref_cc.h
+++ b/runtime/src/iree/vm/ref_cc.h
@@ -31,15 +31,20 @@
 // types. We may still need the iree_vm_ref_type_t exposed but that's relatively
 // simple compared to getting the typed retain/release functions.
 
+template <typename T>
+struct ref_type_descriptor {
+  static const iree_vm_ref_type_descriptor_t* get();
+};
+
 // Users may override this with their custom types to allow the packing code to
 // access their registered type ID at runtime.
 template <typename T>
-IREE_ATTRIBUTE_ALWAYS_INLINE void ref_type_retain(T* p) {
+static inline void ref_type_retain(T* p) {
   iree_vm_ref_object_retain(p, ref_type_descriptor<T>::get());
 }
 
 template <typename T>
-IREE_ATTRIBUTE_ALWAYS_INLINE void ref_type_release(T* p) {
+static inline void ref_type_release(T* p) {
   iree_vm_ref_object_release(p, ref_type_descriptor<T>::get());
 }
 
@@ -228,7 +233,6 @@
 class ref {
  private:
   typedef ref this_type;
-  typedef T* this_type::*unspecified_bool_type;
 
  public:
   IREE_ATTRIBUTE_ALWAYS_INLINE iree_vm_ref_type_t type() const noexcept {
@@ -329,9 +333,8 @@
 
   // Support boolean expression evaluation ala unique_ptr/shared_ptr:
   // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool
-  constexpr operator unspecified_bool_type() const noexcept {  // NOLINT
-    return get() ? reinterpret_cast<unspecified_bool_type>(&this_type::ref_.ptr)
-                 : nullptr;
+  constexpr operator bool() const noexcept {  // NOLINT
+    return get() != nullptr;
   }
   // Supports unary expression evaluation.
   constexpr bool operator!() const noexcept { return !get(); }
@@ -463,4 +466,112 @@
 }  // namespace vm
 }  // namespace iree
 
+//===----------------------------------------------------------------------===//
+// ref-type registration and declaration for generic types
+//===----------------------------------------------------------------------===//
+// This adds vm::ref<T> support for any C type that is registered with the
+// dynamic type registration mechanism and that can be wrapped in an
+// iree_vm_ref_t.
+
+#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)         \
+  namespace iree {                                      \
+  namespace vm {                                        \
+  template <>                                           \
+  struct ref_type_descriptor<T> {                       \
+    static const iree_vm_ref_type_descriptor_t* get() { \
+      return name##_get_descriptor();                   \
+    }                                                   \
+  };                                                    \
+  }                                                     \
+  }
+
+#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)  \
+  descriptor.type_name = iree_make_cstring_view(name);    \
+  descriptor.offsetof_counter = type::offsetof_counter(); \
+  descriptor.destroy = type::DirectDestroy;               \
+  IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
+
+//===----------------------------------------------------------------------===//
+// ref-type registration and declaration for core VM types
+//===----------------------------------------------------------------------===//
+// This adds vm::ref<T> support for arbitrary C types that implement retain and
+// release methods and manage their reference count internally. These are not
+// registered with the dynamic type registration mechanism.
+
+#define IREE_VM_DECLARE_CC_TYPE_ADAPTERS(name, T)                         \
+  namespace iree {                                                        \
+  namespace vm {                                                          \
+  template <>                                                             \
+  inline void ref_type_retain(T* p) {                                     \
+    name##_retain(p);                                                     \
+  }                                                                       \
+  template <>                                                             \
+  inline void ref_type_release(T* p) {                                    \
+    name##_release(p);                                                    \
+  }                                                                       \
+  template <>                                                             \
+  class ref<T> {                                                          \
+   private:                                                               \
+    typedef ref this_type;                                                \
+                                                                          \
+   public:                                                                \
+    IREE_ATTRIBUTE_ALWAYS_INLINE ref() noexcept : ptr_(nullptr) {}        \
+    IREE_ATTRIBUTE_ALWAYS_INLINE ref(std::nullptr_t) noexcept             \
+        : ptr_(nullptr) {}                                                \
+    IREE_ATTRIBUTE_ALWAYS_INLINE ref(T* p) noexcept : ptr_(nullptr) {}    \
+    IREE_ATTRIBUTE_ALWAYS_INLINE ~ref() noexcept {                        \
+      ref_type_release<T>(get());                                         \
+    }                                                                     \
+    ref(const ref& rhs) noexcept : ptr_(rhs.ptr_) {                       \
+      ref_type_retain<T>(get());                                          \
+    }                                                                     \
+    ref& operator=(const ref&) noexcept = delete;                         \
+    ref(ref&& rhs) noexcept : ptr_(rhs.ptr_) { rhs.release(); }           \
+    ref& operator=(ref&& rhs) noexcept {                                  \
+      if (get() != rhs.get()) {                                           \
+        ref_type_release<T>(get());                                       \
+        ptr_ = rhs.ptr_;                                                  \
+        rhs.release();                                                    \
+      }                                                                   \
+      return *this;                                                       \
+    }                                                                     \
+    template <typename U>                                                 \
+    ref(ref<U>&& rhs) noexcept {                                          \
+      ptr_ = static_cast<T*>(rhs.release());                              \
+    }                                                                     \
+    template <typename U>                                                 \
+    ref& operator=(ref<U>&& rhs) noexcept {                               \
+      if (get() != rhs.get()) {                                           \
+        ref_type_release<T>(get());                                       \
+        ptr_ = static_cast<T*>(rhs.release());                            \
+      }                                                                   \
+      return *this;                                                       \
+    }                                                                     \
+    void reset() noexcept {                                               \
+      ref_type_release<T>(get());                                         \
+      ptr_ = nullptr;                                                     \
+    }                                                                     \
+    IREE_ATTRIBUTE_ALWAYS_INLINE T* release() noexcept {                  \
+      T* p = get();                                                       \
+      ptr_ = nullptr;                                                     \
+      return p;                                                           \
+    }                                                                     \
+    IREE_ATTRIBUTE_ALWAYS_INLINE void assign(T* value) noexcept {         \
+      reset();                                                            \
+      ptr_ = value;                                                       \
+    }                                                                     \
+    constexpr T* get() const noexcept { return ptr_; }                    \
+    constexpr T& operator*() const noexcept { return *get(); }            \
+    constexpr T* operator->() const noexcept { return get(); }            \
+    constexpr T** operator&() noexcept { return &ptr_; }                  \
+    constexpr operator bool() const noexcept { return get() != nullptr; } \
+    constexpr bool operator!() const noexcept { return !get(); }          \
+    void swap(ref& rhs) { std::swap(ptr_, rhs.ptr_); }                    \
+                                                                          \
+   private:                                                               \
+    mutable T* ptr_ = nullptr;                                            \
+  };                                                                      \
+  }                                                                       \
+  }
+
 #endif  // IREE_VM_REF_CC_H_
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index eea2d40..a7fc39b 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -437,20 +437,20 @@
 
     // Order matters. Tear down modules first to release resources.
     inputs_.reset();
-    iree_vm_context_release(context_);
-    iree_vm_module_release(main_module_);
-    iree_vm_instance_release(instance_);
+    context_.reset();
+    main_module_.reset();
+    instance_.reset();
 
     // Tear down device last in order to get accurate statistics.
     if (device_allocator_ && FLAG_print_statistics) {
-      IREE_IGNORE_ERROR(
-          iree_hal_allocator_statistics_fprint(stderr, device_allocator_));
+      IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
+          stderr, device_allocator_.get()));
     }
-    iree_hal_allocator_release(device_allocator_);
-    iree_hal_device_release(device_);
+    device_allocator_.reset();
+    device_.reset();
   };
 
-  iree_hal_device_t* device() const { return device_; }
+  iree_hal_device_t* device() const { return device_.get(); }
 
   iree_status_t Register() {
     IREE_TRACE_SCOPE0("IREEBenchmark::Register");
@@ -478,10 +478,11 @@
         iree_tooling_create_instance(host_allocator, &instance_));
 
     IREE_RETURN_IF_ERROR(iree_tooling_load_module_from_flags(
-        instance_, host_allocator, &main_module_));
+        instance_.get(), host_allocator, &main_module_));
 
     IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
-        instance_, /*user_module_count=*/1, /*user_modules=*/&main_module_,
+        instance_.get(), /*user_module_count=*/1,
+        /*user_modules=*/&main_module_,
         /*default_device_uri=*/iree_string_view_empty(), host_allocator,
         &context_, &device_, &device_allocator_));
 
@@ -494,25 +495,25 @@
 
     iree_vm_function_t function;
     IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
-        main_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+        main_module_.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
         iree_string_view_t{function_name.data(), function_name.size()},
         &function));
 
     IREE_CHECK_OK(ParseToVariantList(
-        device_allocator_,
+        device_allocator_.get(),
         iree::span<const std::string>{FLAG_function_inputs.data(),
                                       FLAG_function_inputs.size()},
-        iree_vm_instance_allocator(instance_), &inputs_));
+        iree_vm_instance_allocator(instance_.get()), &inputs_));
 
     iree_string_view_t invocation_model = iree_vm_function_lookup_attr_by_name(
         &function, IREE_SV("iree.abi.model"));
     if (iree_string_view_equal(invocation_model, IREE_SV("coarse-fences"))) {
       // Asynchronous invocation.
-      iree::RegisterAsyncBenchmark(function_name, device_, context_, function,
-                                   inputs_.get());
+      iree::RegisterAsyncBenchmark(function_name, device_.get(), context_.get(),
+                                   function, inputs_.get());
     } else {
       // Synchronous invocation.
-      iree::RegisterGenericBenchmark(function_name, context_, function,
+      iree::RegisterGenericBenchmark(function_name, context_.get(), function,
                                      inputs_.get());
     }
     return iree_ok_status();
@@ -521,11 +522,11 @@
   iree_status_t RegisterAllExportedFunctions() {
     IREE_TRACE_SCOPE0("IREEBenchmark::RegisterAllExportedFunctions");
     iree_vm_module_signature_t signature =
-        iree_vm_module_signature(main_module_);
+        iree_vm_module_signature(main_module_.get());
     for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) {
       iree_vm_function_t function;
       IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
-          main_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
+          main_module_.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
       iree_string_view_t function_name = iree_vm_function_name(&function);
 
       // We run anything with the 'benchmark' attribute.
@@ -534,11 +535,11 @@
           &function, IREE_SV("iree.benchmark"));
       if (iree_string_view_equal(benchmark_type, IREE_SV("dispatch"))) {
         iree::RegisterDispatchBenchmark(
-            std::string(function_name.data, function_name.size), context_,
+            std::string(function_name.data, function_name.size), context_.get(),
             function);
       } else if (iree_string_view_equal(benchmark_type, IREE_SV("entry"))) {
         iree::RegisterGenericBenchmark(
-            std::string(function_name.data, function_name.size), context_,
+            std::string(function_name.data, function_name.size), context_.get(),
             function,
             /*inputs=*/nullptr);
       } else {
@@ -567,8 +568,8 @@
           if (argument_count == 2) {
             // Only functions taking a (wait, signal) fence pair are run.
             iree::RegisterAsyncBenchmark(
-                std::string(function_name.data, function_name.size), device_,
-                context_, function,
+                std::string(function_name.data, function_name.size),
+                device_.get(), context_.get(), function,
                 /*inputs=*/nullptr);
           }
         } else {
@@ -577,8 +578,8 @@
             // Only functions with no inputs are run (because we can't pass
             // anything).
             iree::RegisterGenericBenchmark(
-                std::string(function_name.data, function_name.size), context_,
-                function,
+                std::string(function_name.data, function_name.size),
+                context_.get(), function,
                 /*inputs=*/nullptr);
           }
         }
@@ -587,11 +588,11 @@
     return iree_ok_status();
   }
 
-  iree_vm_instance_t* instance_ = nullptr;
-  iree_vm_context_t* context_ = nullptr;
-  iree_hal_device_t* device_ = nullptr;
-  iree_hal_allocator_t* device_allocator_ = nullptr;
-  iree_vm_module_t* main_module_ = nullptr;
+  iree::vm::ref<iree_vm_instance_t> instance_;
+  iree::vm::ref<iree_vm_context_t> context_;
+  iree::vm::ref<iree_hal_device_t> device_;
+  iree::vm::ref<iree_hal_allocator_t> device_allocator_;
+  iree::vm::ref<iree_vm_module_t> main_module_;
   iree::vm::ref<iree_vm_list_t> inputs_;
 };
 }  // namespace
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index de1937d..29adb39 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -346,7 +346,7 @@
       /*policy=*/nullptr, inputs.get(), outputs.get(), host_allocator));
 
   // If the function is async we need to wait for it to complete.
-  if (!!finish_fence) {
+  if (finish_fence) {
     IREE_RETURN_IF_ERROR(
         iree_hal_fence_wait(finish_fence.get(), iree_infinite_timeout()));
   }
@@ -366,7 +366,7 @@
   // Load the bytecode module from the flatbuffer data.
   // We do this first so that if we fail validation we know prior to dealing
   // with devices.
-  iree_vm_module_t* main_module = nullptr;
+  vm::ref<iree_vm_module_t> main_module;
   IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
       instance,
       iree_make_const_byte_span((void*)flatbuffer_data.data(),
@@ -375,17 +375,17 @@
 
   if (!run_flag) {
     // Just wanted verification; return without running.
-    iree_vm_module_release(main_module);
+    main_module.reset();
     return OkStatus();
   }
 
   // Evaluate all exported functions.
   auto run_function = [&](int ordinal) -> Status {
     iree_vm_function_t function;
-    IREE_RETURN_IF_ERROR(
-        iree_vm_module_lookup_function_by_ordinal(
-            main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &function),
-        "looking up function export %d", ordinal);
+    IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
+                             main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
+                             ordinal, &function),
+                         "looking up function export %d", ordinal);
     iree_string_view_t function_name = iree_vm_function_name(&function);
     if (iree_string_view_starts_with(function_name,
                                      iree_make_cstring_view("__")) ||
@@ -398,32 +398,33 @@
     // Create the context we'll use for this (ensuring that we can't interfere
     // with other running evaluations, such as when in a multithreaded test
     // runner).
-    iree_vm_context_t* context = nullptr;
-    iree_hal_device_t* device = NULL;
-    iree_hal_allocator_t* device_allocator = nullptr;
+    vm::ref<iree_vm_context_t> context;
+    vm::ref<iree_hal_device_t> device;
+    vm::ref<iree_hal_allocator_t> device_allocator;
     IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
         instance, /*user_module_count=*/1, /*user_modules=*/&main_module,
         iree_make_string_view(default_device_uri.data(),
                               default_device_uri.size()),
         iree_allocator_system(), &context, &device, &device_allocator));
 
-    IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device));
+    IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
 
     // Invoke the function and print results.
-    IREE_RETURN_IF_ERROR(EvaluateFunction(context, device, device_allocator,
-                                          function, function_name),
-                         "evaluating export function %d", ordinal);
+    IREE_RETURN_IF_ERROR(
+        EvaluateFunction(context.get(), device.get(), device_allocator.get(),
+                         function, function_name),
+        "evaluating export function %d", ordinal);
 
-    IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device));
+    IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
 
-    iree_vm_context_release(context);
-    iree_hal_allocator_release(device_allocator);
-    iree_hal_device_release(device);
+    context.reset();
+    device_allocator.reset();
+    device.reset();
     return OkStatus();
   };
 
   Status evaluate_status = OkStatus();
-  auto module_signature = iree_vm_module_signature(main_module);
+  auto module_signature = iree_vm_module_signature(main_module.get());
   for (iree_host_size_t i = 0; i < module_signature.export_function_count;
        ++i) {
     evaluate_status = run_function(i);
@@ -432,7 +433,7 @@
     }
   }
 
-  iree_vm_module_release(main_module);
+  main_module.reset();
 
   return evaluate_status;
 }
@@ -442,7 +443,7 @@
                     mlir::DialectRegistry& registry) {
   IREE_TRACE_SCOPE0("EvaluateFile");
 
-  iree_vm_instance_t* instance = nullptr;
+  vm::ref<iree_vm_instance_t> instance;
   IREE_RETURN_IF_ERROR(
       iree_tooling_create_instance(iree_allocator_system(), &instance),
       "Creating instance");
@@ -463,11 +464,11 @@
     std::string default_device_uri =
         InferDefaultDeviceFromBackend(target_backend);
     IREE_RETURN_IF_ERROR(
-        EvaluateFunctions(instance, default_device_uri, flatbuffer_data),
+        EvaluateFunctions(instance.get(), default_device_uri, flatbuffer_data),
         "Evaluating functions");
   }
 
-  iree_vm_instance_release(instance);
+  instance.reset();
   return OkStatus();
 }
 
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 05aa0ed..ccbca90 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -89,19 +89,19 @@
   IREE_TRACE_SCOPE0("iree-run-module");
 
   iree_allocator_t host_allocator = iree_allocator_system();
-  iree_vm_instance_t* instance = nullptr;
+  vm::ref<iree_vm_instance_t> instance;
   IREE_RETURN_IF_ERROR(iree_tooling_create_instance(host_allocator, &instance),
                        "creating instance");
 
-  iree_vm_module_t* main_module = nullptr;
+  vm::ref<iree_vm_module_t> main_module;
   IREE_RETURN_IF_ERROR(iree_tooling_load_module_from_flags(
-      instance, host_allocator, &main_module));
+      instance.get(), host_allocator, &main_module));
 
-  iree_vm_context_t* context = NULL;
-  iree_hal_device_t* device = NULL;
-  iree_hal_allocator_t* device_allocator = NULL;
+  vm::ref<iree_vm_context_t> context;
+  vm::ref<iree_hal_device_t> device;
+  vm::ref<iree_hal_allocator_t> device_allocator;
   IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
-      instance, /*user_module_count=*/1, /*user_modules=*/&main_module,
+      instance.get(), /*user_module_count=*/1, /*user_modules=*/&main_module,
       /*default_device_uri=*/iree_string_view_empty(), host_allocator, &context,
       &device, &device_allocator));
 
@@ -113,17 +113,17 @@
   } else {
     IREE_RETURN_IF_ERROR(
         iree_vm_module_lookup_function_by_name(
-            main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+            main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
             iree_string_view_t{function_name.data(), function_name.size()},
             &function),
         "looking up function '%s'", function_name.c_str());
   }
 
-  IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device));
+  IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
 
   vm::ref<iree_vm_list_t> inputs;
   IREE_RETURN_IF_ERROR(ParseToVariantList(
-      device_allocator,
+      device_allocator.get(),
       iree::span<const std::string>{FLAG_function_inputs.data(),
                                     FLAG_function_inputs.size()},
       host_allocator, &inputs));
@@ -131,7 +131,8 @@
   // If the function is async add fences so we can invoke it synchronously.
   vm::ref<iree_hal_fence_t> finish_fence;
   IREE_RETURN_IF_ERROR(iree_tooling_append_async_fence_inputs(
-      inputs.get(), &function, device, /*wait_fence=*/NULL, &finish_fence));
+      inputs.get(), &function, device.get(), /*wait_fence=*/NULL,
+      &finish_fence));
 
   vm::ref<iree_vm_list_t> outputs;
   IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr, 16,
@@ -139,18 +140,18 @@
 
   printf("EXEC @%s\n", function_name.c_str());
   IREE_RETURN_IF_ERROR(
-      iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
+      iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
                      /*policy=*/nullptr, inputs.get(), outputs.get(),
                      host_allocator),
       "invoking function '%s'", function_name.c_str());
 
   // If the function is async we need to wait for it to complete.
-  if (!!finish_fence) {
+  if (finish_fence) {
     IREE_RETURN_IF_ERROR(
         iree_hal_fence_wait(finish_fence.get(), iree_infinite_timeout()));
   }
 
-  IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device));
+  IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
 
   if (FLAG_expected_outputs.empty()) {
     IREE_RETURN_IF_ERROR(
@@ -179,17 +180,17 @@
   // Release resources before gathering statistics.
   inputs.reset();
   outputs.reset();
-  iree_vm_module_release(main_module);
-  iree_vm_context_release(context);
+  main_module.reset();
+  context.reset();
 
   if (device_allocator && FLAG_print_statistics) {
     IREE_IGNORE_ERROR(
-        iree_hal_allocator_statistics_fprint(stderr, device_allocator));
+        iree_hal_allocator_statistics_fprint(stderr, device_allocator.get()));
   }
 
-  iree_hal_allocator_release(device_allocator);
-  iree_hal_device_release(device);
-  iree_vm_instance_release(instance);
+  device_allocator.reset();
+  device.reset();
+  instance.reset();
   return iree_ok_status();
 }