Adding vm::ref<T> support for non-iree_vm_ref_t ref types. (#11065)
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.cpp b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
index 7557c16..51d191f 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.cpp
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.cpp
@@ -8,7 +8,6 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/hal/drivers/local_task/registration/driver_module.h"
-#include "iree/modules/hal/module.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -89,17 +88,17 @@
CompiledBinary::~CompiledBinary() {}
void CompiledBinary::deinitialize() {
- iree_vm_module_release(hal_module);
- iree_vm_module_release(main_module);
- iree_vm_context_release(context);
- iree_hal_device_release(device);
+ hal_module.reset();
+ main_module.reset();
+ context.reset();
+ device.reset();
}
LogicalResult CompiledBinary::invokeNullary(Location loc, StringRef name,
ResultsCallback callback) {
iree_vm_function_t function;
if (auto status = iree_vm_module_lookup_function_by_name(
- main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{name.data(), name.size()}, &function)) {
iree_status_ignore(status);
return emitError(loc) << "internal error evaling constant: func '" << name
@@ -114,7 +113,7 @@
iree_allocator_system(), &outputs));
if (auto status =
- iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
+ iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(), outputs.get(),
iree_allocator_system())) {
std::string message;
@@ -272,19 +271,22 @@
iree_hal_driver_release(driver);
// Create hal module.
- IREE_CHECK_OK(iree_hal_module_create(runtime.instance, device,
+ IREE_CHECK_OK(iree_hal_module_create(runtime.instance.get(), device.get(),
IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module));
// Bytecode module.
IREE_CHECK_OK(iree_vm_bytecode_module_create(
- runtime.instance, iree_make_const_byte_span(data, length),
+ runtime.instance.get(), iree_make_const_byte_span(data, length),
iree_allocator_null(), iree_allocator_system(), &main_module));
// Context.
- std::array<iree_vm_module_t*, 2> modules = {hal_module, main_module};
+ std::array<iree_vm_module_t*, 2> modules = {
+ hal_module.get(),
+ main_module.get(),
+ };
IREE_CHECK_OK(iree_vm_context_create_with_modules(
- runtime.instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
+ runtime.instance.get(), IREE_VM_CONTEXT_FLAG_NONE, modules.size(),
modules.data(), iree_allocator_system(), &context));
}
@@ -308,11 +310,11 @@
iree_hal_driver_registry_allocate(iree_allocator_system(), ®istry));
IREE_CHECK_OK(iree_hal_local_task_driver_module_register(registry));
IREE_CHECK_OK(iree_vm_instance_create(iree_allocator_system(), &instance));
- IREE_CHECK_OK(iree_hal_module_register_all_types(instance));
+ IREE_CHECK_OK(iree_hal_module_register_all_types(instance.get()));
}
Runtime::~Runtime() {
- iree_vm_instance_release(instance);
+ instance.reset();
iree_hal_driver_registry_free(registry);
}
diff --git a/compiler/src/iree/compiler/ConstEval/Runtime.h b/compiler/src/iree/compiler/ConstEval/Runtime.h
index 40a1564..6d5935b 100644
--- a/compiler/src/iree/compiler/ConstEval/Runtime.h
+++ b/compiler/src/iree/compiler/ConstEval/Runtime.h
@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.h"
#include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -47,10 +48,10 @@
void deinitialize();
Attribute convertVariantToAttribute(Location loc, iree_vm_variant_t& variant);
- iree_hal_device_t* device = nullptr;
- iree_vm_module_t* hal_module = nullptr;
- iree_vm_module_t* main_module = nullptr;
- iree_vm_context_t* context = nullptr;
+ iree::vm::ref<iree_hal_device_t> device;
+ iree::vm::ref<iree_vm_module_t> hal_module;
+ iree::vm::ref<iree_vm_module_t> main_module;
+ iree::vm::ref<iree_vm_context_t> context;
};
// An in-memory compiled binary and accessors for working with it.
@@ -70,7 +71,7 @@
static Runtime& getInstance();
iree_hal_driver_registry_t* registry = nullptr;
- iree_vm_instance_t* instance = nullptr;
+ iree::vm::ref<iree_vm_instance_t> instance;
private:
Runtime();
diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc
index fefa4a2..b2924ff 100644
--- a/runtime/src/iree/modules/check/check_test.cc
+++ b/runtime/src/iree/modules/check/check_test.cc
@@ -177,8 +177,8 @@
/*outputs=*/nullptr, iree_allocator_system());
}
- iree_status_t Invoke(const char* function_name,
- std::vector<iree_vm_value_t> args) {
+ iree_status_t InvokeValue(const char* function_name,
+ std::vector<iree_vm_value_t> args) {
IREE_RETURN_IF_ERROR(
iree_vm_list_create(/*element_type=*/nullptr, args.size(),
iree_allocator_system(), &inputs_));
@@ -216,28 +216,28 @@
iree_vm_module_t* CheckTest::hal_module_ = nullptr;
TEST_F(CheckTest, ExpectTrueSuccess) {
- IREE_ASSERT_OK(Invoke("expect_true", {iree_vm_value_make_i32(1)}));
+ IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(1)}));
}
TEST_F(CheckTest, ExpectTrueFailure) {
EXPECT_NONFATAL_FAILURE(
- IREE_ASSERT_OK(Invoke("expect_true", {iree_vm_value_make_i32(0)})),
+ IREE_ASSERT_OK(InvokeValue("expect_true", {iree_vm_value_make_i32(0)})),
"Expected 0 to be nonzero");
}
TEST_F(CheckTest, ExpectFalseSuccess) {
- IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(0)}));
+ IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(0)}));
}
TEST_F(CheckTest, ExpectFalseFailure) {
EXPECT_NONFATAL_FAILURE(
- IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(1)})),
+ IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(1)})),
"Expected 1 to be zero");
}
TEST_F(CheckTest, ExpectFalseNotOneFailure) {
EXPECT_NONFATAL_FAILURE(
- IREE_ASSERT_OK(Invoke("expect_false", {iree_vm_value_make_i32(42)})),
+ IREE_ASSERT_OK(InvokeValue("expect_false", {iree_vm_value_make_i32(42)})),
"Expected 42 to be zero");
}
diff --git a/runtime/src/iree/vm/context.h b/runtime/src/iree/vm/context.h
index 3e5822d..4d629d2 100644
--- a/runtime/src/iree/vm/context.h
+++ b/runtime/src/iree/vm/context.h
@@ -12,6 +12,7 @@
#include "iree/base/api.h"
#include "iree/vm/instance.h"
#include "iree/vm/module.h"
+#include "iree/vm/ref.h"
#include "iree/vm/stack.h"
#ifdef __cplusplus
@@ -123,4 +124,6 @@
} // extern "C"
#endif // __cplusplus
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_context, iree_vm_context_t);
+
#endif // IREE_VM_CONTEXT_H_
diff --git a/runtime/src/iree/vm/instance.h b/runtime/src/iree/vm/instance.h
index 16b6a6d..a1f4a1a 100644
--- a/runtime/src/iree/vm/instance.h
+++ b/runtime/src/iree/vm/instance.h
@@ -8,6 +8,7 @@
#define IREE_VM_INSTANCE_H_
#include "iree/base/api.h"
+#include "iree/vm/ref.h"
#ifdef __cplusplus
extern "C" {
@@ -47,4 +48,6 @@
} // extern "C"
#endif // __cplusplus
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_instance, iree_vm_instance_t);
+
#endif // IREE_VM_INSTANCE_H_
diff --git a/runtime/src/iree/vm/invocation.h b/runtime/src/iree/vm/invocation.h
index 3e4ecaa..2ebc654 100644
--- a/runtime/src/iree/vm/invocation.h
+++ b/runtime/src/iree/vm/invocation.h
@@ -13,6 +13,7 @@
#include "iree/vm/context.h"
#include "iree/vm/list.h"
#include "iree/vm/module.h"
+#include "iree/vm/ref.h"
#ifdef __cplusplus
extern "C" {
@@ -344,4 +345,6 @@
} // extern "C"
#endif // __cplusplus
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_invocation, iree_vm_invocation_t);
+
#endif // IREE_VM_INVOCATION_H_
diff --git a/runtime/src/iree/vm/module.h b/runtime/src/iree/vm/module.h
index b991e50..f7c010b 100644
--- a/runtime/src/iree/vm/module.h
+++ b/runtime/src/iree/vm/module.h
@@ -15,6 +15,7 @@
#include "iree/base/api.h"
#include "iree/base/internal/atomics.h"
#include "iree/base/string_builder.h"
+#include "iree/vm/ref.h"
#ifdef __cplusplus
extern "C" {
@@ -552,4 +553,6 @@
} // extern "C"
#endif // __cplusplus
+IREE_VM_DECLARE_CC_TYPE_ADAPTERS(iree_vm_module, iree_vm_module_t);
+
#endif // IREE_VM_MODULE_H_
diff --git a/runtime/src/iree/vm/ref.h b/runtime/src/iree/vm/ref.h
index d73f389..b4cf9be 100644
--- a/runtime/src/iree/vm/ref.h
+++ b/runtime/src/iree/vm/ref.h
@@ -236,37 +236,6 @@
// Type adapter utilities for interfacing with the VM
//===----------------------------------------------------------------------===//
-#ifdef __cplusplus
-namespace iree {
-namespace vm {
-template <typename T>
-struct ref_type_descriptor {
- static const iree_vm_ref_type_descriptor_t* get();
-};
-} // namespace vm
-} // namespace iree
-#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T) \
- namespace iree { \
- namespace vm { \
- template <> \
- struct ref_type_descriptor<T> { \
- static const iree_vm_ref_type_descriptor_t* get() { \
- return name##_get_descriptor(); \
- } \
- }; \
- } \
- }
-
-#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor) \
- descriptor.type_name = iree_make_cstring_view(name); \
- descriptor.offsetof_counter = type::offsetof_counter(); \
- descriptor.destroy = type::DirectDestroy; \
- IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
-#else
-#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
-#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)
-#endif // __cplusplus
-
// TODO(benvanik): make these macros standard/document them.
#define IREE_VM_DECLARE_TYPE_ADAPTERS(name, T) \
IREE_API_EXPORT iree_vm_ref_t name##_retain_ref(T* value); \
@@ -330,6 +299,10 @@
// Optional C++ iree::vm::ref<T> wrapper.
#ifdef __cplusplus
#include "iree/vm/ref_cc.h"
+#else
+#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T)
+#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor)
+#define IREE_VM_DECLARE_CC_TYPE_ADAPTERS(name, T)
#endif // __cplusplus
#endif // IREE_VM_REF_H_
diff --git a/runtime/src/iree/vm/ref_cc.h b/runtime/src/iree/vm/ref_cc.h
index ac5fe02..c43494a 100644
--- a/runtime/src/iree/vm/ref_cc.h
+++ b/runtime/src/iree/vm/ref_cc.h
@@ -31,15 +31,20 @@
// types. We may still need the iree_vm_ref_type_t exposed but that's relatively
// simple compared to getting the typed retain/release functions.
+template <typename T>
+struct ref_type_descriptor {
+ static const iree_vm_ref_type_descriptor_t* get();
+};
+
// Users may override this with their custom types to allow the packing code to
// access their registered type ID at runtime.
template <typename T>
-IREE_ATTRIBUTE_ALWAYS_INLINE void ref_type_retain(T* p) {
+static inline void ref_type_retain(T* p) {
iree_vm_ref_object_retain(p, ref_type_descriptor<T>::get());
}
template <typename T>
-IREE_ATTRIBUTE_ALWAYS_INLINE void ref_type_release(T* p) {
+static inline void ref_type_release(T* p) {
iree_vm_ref_object_release(p, ref_type_descriptor<T>::get());
}
@@ -228,7 +233,6 @@
class ref {
private:
typedef ref this_type;
- typedef T* this_type::*unspecified_bool_type;
public:
IREE_ATTRIBUTE_ALWAYS_INLINE iree_vm_ref_type_t type() const noexcept {
@@ -329,9 +333,8 @@
// Support boolean expression evaluation ala unique_ptr/shared_ptr:
// https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool
- constexpr operator unspecified_bool_type() const noexcept { // NOLINT
- return get() ? reinterpret_cast<unspecified_bool_type>(&this_type::ref_.ptr)
- : nullptr;
+ constexpr operator bool() const noexcept { // NOLINT
+ return get() != nullptr;
}
// Supports unary expression evaluation.
constexpr bool operator!() const noexcept { return !get(); }
@@ -463,4 +466,112 @@
} // namespace vm
} // namespace iree
+//===----------------------------------------------------------------------===//
+// ref-type registration and declaration for generic types
+//===----------------------------------------------------------------------===//
+// This adds vm::ref<T> support for any C type that is registered with the
+// dynamic type registration mechanism and that can be wrapped in an
+// iree_vm_ref_t.
+
+#define IREE_VM_DECLARE_CC_TYPE_LOOKUP(name, T) \
+ namespace iree { \
+ namespace vm { \
+ template <> \
+ struct ref_type_descriptor<T> { \
+ static const iree_vm_ref_type_descriptor_t* get() { \
+ return name##_get_descriptor(); \
+ } \
+ }; \
+ } \
+ }
+
+#define IREE_VM_REGISTER_CC_TYPE(type, name, descriptor) \
+ descriptor.type_name = iree_make_cstring_view(name); \
+ descriptor.offsetof_counter = type::offsetof_counter(); \
+ descriptor.destroy = type::DirectDestroy; \
+ IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor));
+
+//===----------------------------------------------------------------------===//
+// ref-type registration and declaration for core VM types
+//===----------------------------------------------------------------------===//
+// This adds vm::ref<T> support for arbitrary C types that implement retain and
+// release methods and manage their reference count internally. These are not
+// registered with the dynamic type registration mechanism.
+
+#define IREE_VM_DECLARE_CC_TYPE_ADAPTERS(name, T) \
+ namespace iree { \
+ namespace vm { \
+ template <> \
+ inline void ref_type_retain(T* p) { \
+ name##_retain(p); \
+ } \
+ template <> \
+ inline void ref_type_release(T* p) { \
+ name##_release(p); \
+ } \
+ template <> \
+ class ref<T> { \
+ private: \
+ typedef ref this_type; \
+ \
+ public: \
+ IREE_ATTRIBUTE_ALWAYS_INLINE ref() noexcept : ptr_(nullptr) {} \
+ IREE_ATTRIBUTE_ALWAYS_INLINE ref(std::nullptr_t) noexcept \
+ : ptr_(nullptr) {} \
+ IREE_ATTRIBUTE_ALWAYS_INLINE ref(T* p) noexcept : ptr_(nullptr) {} \
+ IREE_ATTRIBUTE_ALWAYS_INLINE ~ref() noexcept { \
+ ref_type_release<T>(get()); \
+ } \
+ ref(const ref& rhs) noexcept : ptr_(rhs.ptr_) { \
+ ref_type_retain<T>(get()); \
+ } \
+ ref& operator=(const ref&) noexcept = delete; \
+ ref(ref&& rhs) noexcept : ptr_(rhs.ptr_) { rhs.release(); } \
+ ref& operator=(ref&& rhs) noexcept { \
+ if (get() != rhs.get()) { \
+ ref_type_release<T>(get()); \
+ ptr_ = rhs.ptr_; \
+ rhs.release(); \
+ } \
+ return *this; \
+ } \
+ template <typename U> \
+ ref(ref<U>&& rhs) noexcept { \
+ ptr_ = static_cast<T*>(rhs.release()); \
+ } \
+ template <typename U> \
+ ref& operator=(ref<U>&& rhs) noexcept { \
+ if (get() != rhs.get()) { \
+ ref_type_release<T>(get()); \
+ ptr_ = static_cast<T*>(rhs.release()); \
+ } \
+ return *this; \
+ } \
+ void reset() noexcept { \
+ ref_type_release<T>(get()); \
+ ptr_ = nullptr; \
+ } \
+ IREE_ATTRIBUTE_ALWAYS_INLINE T* release() noexcept { \
+ T* p = get(); \
+ ptr_ = nullptr; \
+ return p; \
+ } \
+ IREE_ATTRIBUTE_ALWAYS_INLINE void assign(T* value) noexcept { \
+ reset(); \
+ ptr_ = value; \
+ } \
+ constexpr T* get() const noexcept { return ptr_; } \
+ constexpr T& operator*() const noexcept { return *get(); } \
+ constexpr T* operator->() const noexcept { return get(); } \
+ constexpr T** operator&() noexcept { return &ptr_; } \
+ constexpr operator bool() const noexcept { return get() != nullptr; } \
+ constexpr bool operator!() const noexcept { return !get(); } \
+ void swap(ref& rhs) { std::swap(ptr_, rhs.ptr_); } \
+ \
+ private: \
+ mutable T* ptr_ = nullptr; \
+ }; \
+ } \
+ }
+
#endif // IREE_VM_REF_CC_H_
diff --git a/tools/iree-benchmark-module-main.cc b/tools/iree-benchmark-module-main.cc
index eea2d40..a7fc39b 100644
--- a/tools/iree-benchmark-module-main.cc
+++ b/tools/iree-benchmark-module-main.cc
@@ -437,20 +437,20 @@
// Order matters. Tear down modules first to release resources.
inputs_.reset();
- iree_vm_context_release(context_);
- iree_vm_module_release(main_module_);
- iree_vm_instance_release(instance_);
+ context_.reset();
+ main_module_.reset();
+ instance_.reset();
// Tear down device last in order to get accurate statistics.
if (device_allocator_ && FLAG_print_statistics) {
- IREE_IGNORE_ERROR(
- iree_hal_allocator_statistics_fprint(stderr, device_allocator_));
+ IREE_IGNORE_ERROR(iree_hal_allocator_statistics_fprint(
+ stderr, device_allocator_.get()));
}
- iree_hal_allocator_release(device_allocator_);
- iree_hal_device_release(device_);
+ device_allocator_.reset();
+ device_.reset();
};
- iree_hal_device_t* device() const { return device_; }
+ iree_hal_device_t* device() const { return device_.get(); }
iree_status_t Register() {
IREE_TRACE_SCOPE0("IREEBenchmark::Register");
@@ -478,10 +478,11 @@
iree_tooling_create_instance(host_allocator, &instance_));
IREE_RETURN_IF_ERROR(iree_tooling_load_module_from_flags(
- instance_, host_allocator, &main_module_));
+ instance_.get(), host_allocator, &main_module_));
IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
- instance_, /*user_module_count=*/1, /*user_modules=*/&main_module_,
+ instance_.get(), /*user_module_count=*/1,
+ /*user_modules=*/&main_module_,
/*default_device_uri=*/iree_string_view_empty(), host_allocator,
&context_, &device_, &device_allocator_));
@@ -494,25 +495,25 @@
iree_vm_function_t function;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_name(
- main_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ main_module_.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{function_name.data(), function_name.size()},
&function));
IREE_CHECK_OK(ParseToVariantList(
- device_allocator_,
+ device_allocator_.get(),
iree::span<const std::string>{FLAG_function_inputs.data(),
FLAG_function_inputs.size()},
- iree_vm_instance_allocator(instance_), &inputs_));
+ iree_vm_instance_allocator(instance_.get()), &inputs_));
iree_string_view_t invocation_model = iree_vm_function_lookup_attr_by_name(
&function, IREE_SV("iree.abi.model"));
if (iree_string_view_equal(invocation_model, IREE_SV("coarse-fences"))) {
// Asynchronous invocation.
- iree::RegisterAsyncBenchmark(function_name, device_, context_, function,
- inputs_.get());
+ iree::RegisterAsyncBenchmark(function_name, device_.get(), context_.get(),
+ function, inputs_.get());
} else {
// Synchronous invocation.
- iree::RegisterGenericBenchmark(function_name, context_, function,
+ iree::RegisterGenericBenchmark(function_name, context_.get(), function,
inputs_.get());
}
return iree_ok_status();
@@ -521,11 +522,11 @@
iree_status_t RegisterAllExportedFunctions() {
IREE_TRACE_SCOPE0("IREEBenchmark::RegisterAllExportedFunctions");
iree_vm_module_signature_t signature =
- iree_vm_module_signature(main_module_);
+ iree_vm_module_signature(main_module_.get());
for (iree_host_size_t i = 0; i < signature.export_function_count; ++i) {
iree_vm_function_t function;
IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
- main_module_, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
+ main_module_.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
iree_string_view_t function_name = iree_vm_function_name(&function);
// We run anything with the 'benchmark' attribute.
@@ -534,11 +535,11 @@
&function, IREE_SV("iree.benchmark"));
if (iree_string_view_equal(benchmark_type, IREE_SV("dispatch"))) {
iree::RegisterDispatchBenchmark(
- std::string(function_name.data, function_name.size), context_,
+ std::string(function_name.data, function_name.size), context_.get(),
function);
} else if (iree_string_view_equal(benchmark_type, IREE_SV("entry"))) {
iree::RegisterGenericBenchmark(
- std::string(function_name.data, function_name.size), context_,
+ std::string(function_name.data, function_name.size), context_.get(),
function,
/*inputs=*/nullptr);
} else {
@@ -567,8 +568,8 @@
if (argument_count == 2) {
// Only functions taking a (wait, signal) fence pair are run.
iree::RegisterAsyncBenchmark(
- std::string(function_name.data, function_name.size), device_,
- context_, function,
+ std::string(function_name.data, function_name.size),
+ device_.get(), context_.get(), function,
/*inputs=*/nullptr);
}
} else {
@@ -577,8 +578,8 @@
// Only functions with no inputs are run (because we can't pass
// anything).
iree::RegisterGenericBenchmark(
- std::string(function_name.data, function_name.size), context_,
- function,
+ std::string(function_name.data, function_name.size),
+ context_.get(), function,
/*inputs=*/nullptr);
}
}
@@ -587,11 +588,11 @@
return iree_ok_status();
}
- iree_vm_instance_t* instance_ = nullptr;
- iree_vm_context_t* context_ = nullptr;
- iree_hal_device_t* device_ = nullptr;
- iree_hal_allocator_t* device_allocator_ = nullptr;
- iree_vm_module_t* main_module_ = nullptr;
+ iree::vm::ref<iree_vm_instance_t> instance_;
+ iree::vm::ref<iree_vm_context_t> context_;
+ iree::vm::ref<iree_hal_device_t> device_;
+ iree::vm::ref<iree_hal_allocator_t> device_allocator_;
+ iree::vm::ref<iree_vm_module_t> main_module_;
iree::vm::ref<iree_vm_list_t> inputs_;
};
} // namespace
diff --git a/tools/iree-run-mlir-main.cc b/tools/iree-run-mlir-main.cc
index de1937d..29adb39 100644
--- a/tools/iree-run-mlir-main.cc
+++ b/tools/iree-run-mlir-main.cc
@@ -346,7 +346,7 @@
/*policy=*/nullptr, inputs.get(), outputs.get(), host_allocator));
// If the function is async we need to wait for it to complete.
- if (!!finish_fence) {
+ if (finish_fence) {
IREE_RETURN_IF_ERROR(
iree_hal_fence_wait(finish_fence.get(), iree_infinite_timeout()));
}
@@ -366,7 +366,7 @@
// Load the bytecode module from the flatbuffer data.
// We do this first so that if we fail validation we know prior to dealing
// with devices.
- iree_vm_module_t* main_module = nullptr;
+ vm::ref<iree_vm_module_t> main_module;
IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create(
instance,
iree_make_const_byte_span((void*)flatbuffer_data.data(),
@@ -375,17 +375,17 @@
if (!run_flag) {
// Just wanted verification; return without running.
- iree_vm_module_release(main_module);
+ main_module.reset();
return OkStatus();
}
// Evaluate all exported functions.
auto run_function = [&](int ordinal) -> Status {
iree_vm_function_t function;
- IREE_RETURN_IF_ERROR(
- iree_vm_module_lookup_function_by_ordinal(
- main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal, &function),
- "looking up function export %d", ordinal);
+ IREE_RETURN_IF_ERROR(iree_vm_module_lookup_function_by_ordinal(
+ main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ ordinal, &function),
+ "looking up function export %d", ordinal);
iree_string_view_t function_name = iree_vm_function_name(&function);
if (iree_string_view_starts_with(function_name,
iree_make_cstring_view("__")) ||
@@ -398,32 +398,33 @@
// Create the context we'll use for this (ensuring that we can't interfere
// with other running evaluations, such as when in a multithreaded test
// runner).
- iree_vm_context_t* context = nullptr;
- iree_hal_device_t* device = NULL;
- iree_hal_allocator_t* device_allocator = nullptr;
+ vm::ref<iree_vm_context_t> context;
+ vm::ref<iree_hal_device_t> device;
+ vm::ref<iree_hal_allocator_t> device_allocator;
IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
instance, /*user_module_count=*/1, /*user_modules=*/&main_module,
iree_make_string_view(default_device_uri.data(),
default_device_uri.size()),
iree_allocator_system(), &context, &device, &device_allocator));
- IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device));
+ IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
// Invoke the function and print results.
- IREE_RETURN_IF_ERROR(EvaluateFunction(context, device, device_allocator,
- function, function_name),
- "evaluating export function %d", ordinal);
+ IREE_RETURN_IF_ERROR(
+ EvaluateFunction(context.get(), device.get(), device_allocator.get(),
+ function, function_name),
+ "evaluating export function %d", ordinal);
- IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device));
+ IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
- iree_vm_context_release(context);
- iree_hal_allocator_release(device_allocator);
- iree_hal_device_release(device);
+ context.reset();
+ device_allocator.reset();
+ device.reset();
return OkStatus();
};
Status evaluate_status = OkStatus();
- auto module_signature = iree_vm_module_signature(main_module);
+ auto module_signature = iree_vm_module_signature(main_module.get());
for (iree_host_size_t i = 0; i < module_signature.export_function_count;
++i) {
evaluate_status = run_function(i);
@@ -432,7 +433,7 @@
}
}
- iree_vm_module_release(main_module);
+ main_module.reset();
return evaluate_status;
}
@@ -442,7 +443,7 @@
mlir::DialectRegistry& registry) {
IREE_TRACE_SCOPE0("EvaluateFile");
- iree_vm_instance_t* instance = nullptr;
+ vm::ref<iree_vm_instance_t> instance;
IREE_RETURN_IF_ERROR(
iree_tooling_create_instance(iree_allocator_system(), &instance),
"Creating instance");
@@ -463,11 +464,11 @@
std::string default_device_uri =
InferDefaultDeviceFromBackend(target_backend);
IREE_RETURN_IF_ERROR(
- EvaluateFunctions(instance, default_device_uri, flatbuffer_data),
+ EvaluateFunctions(instance.get(), default_device_uri, flatbuffer_data),
"Evaluating functions");
}
- iree_vm_instance_release(instance);
+ instance.reset();
return OkStatus();
}
diff --git a/tools/iree-run-module-main.cc b/tools/iree-run-module-main.cc
index 05aa0ed..ccbca90 100644
--- a/tools/iree-run-module-main.cc
+++ b/tools/iree-run-module-main.cc
@@ -89,19 +89,19 @@
IREE_TRACE_SCOPE0("iree-run-module");
iree_allocator_t host_allocator = iree_allocator_system();
- iree_vm_instance_t* instance = nullptr;
+ vm::ref<iree_vm_instance_t> instance;
IREE_RETURN_IF_ERROR(iree_tooling_create_instance(host_allocator, &instance),
"creating instance");
- iree_vm_module_t* main_module = nullptr;
+ vm::ref<iree_vm_module_t> main_module;
IREE_RETURN_IF_ERROR(iree_tooling_load_module_from_flags(
- instance, host_allocator, &main_module));
+ instance.get(), host_allocator, &main_module));
- iree_vm_context_t* context = NULL;
- iree_hal_device_t* device = NULL;
- iree_hal_allocator_t* device_allocator = NULL;
+ vm::ref<iree_vm_context_t> context;
+ vm::ref<iree_hal_device_t> device;
+ vm::ref<iree_hal_allocator_t> device_allocator;
IREE_RETURN_IF_ERROR(iree_tooling_create_context_from_flags(
- instance, /*user_module_count=*/1, /*user_modules=*/&main_module,
+ instance.get(), /*user_module_count=*/1, /*user_modules=*/&main_module,
/*default_device_uri=*/iree_string_view_empty(), host_allocator, &context,
&device, &device_allocator));
@@ -113,17 +113,17 @@
} else {
IREE_RETURN_IF_ERROR(
iree_vm_module_lookup_function_by_name(
- main_module, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ main_module.get(), IREE_VM_FUNCTION_LINKAGE_EXPORT,
iree_string_view_t{function_name.data(), function_name.size()},
&function),
"looking up function '%s'", function_name.c_str());
}
- IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device));
+ IREE_RETURN_IF_ERROR(iree_hal_begin_profiling_from_flags(device.get()));
vm::ref<iree_vm_list_t> inputs;
IREE_RETURN_IF_ERROR(ParseToVariantList(
- device_allocator,
+ device_allocator.get(),
iree::span<const std::string>{FLAG_function_inputs.data(),
FLAG_function_inputs.size()},
host_allocator, &inputs));
@@ -131,7 +131,8 @@
// If the function is async add fences so we can invoke it synchronously.
vm::ref<iree_hal_fence_t> finish_fence;
IREE_RETURN_IF_ERROR(iree_tooling_append_async_fence_inputs(
- inputs.get(), &function, device, /*wait_fence=*/NULL, &finish_fence));
+ inputs.get(), &function, device.get(), /*wait_fence=*/NULL,
+ &finish_fence));
vm::ref<iree_vm_list_t> outputs;
IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr, 16,
@@ -139,18 +140,18 @@
printf("EXEC @%s\n", function_name.c_str());
IREE_RETURN_IF_ERROR(
- iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
+ iree_vm_invoke(context.get(), function, IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(), outputs.get(),
host_allocator),
"invoking function '%s'", function_name.c_str());
// If the function is async we need to wait for it to complete.
- if (!!finish_fence) {
+ if (finish_fence) {
IREE_RETURN_IF_ERROR(
iree_hal_fence_wait(finish_fence.get(), iree_infinite_timeout()));
}
- IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device));
+ IREE_RETURN_IF_ERROR(iree_hal_end_profiling_from_flags(device.get()));
if (FLAG_expected_outputs.empty()) {
IREE_RETURN_IF_ERROR(
@@ -179,17 +180,17 @@
// Release resources before gathering statistics.
inputs.reset();
outputs.reset();
- iree_vm_module_release(main_module);
- iree_vm_context_release(context);
+ main_module.reset();
+ context.reset();
if (device_allocator && FLAG_print_statistics) {
IREE_IGNORE_ERROR(
- iree_hal_allocator_statistics_fprint(stderr, device_allocator));
+ iree_hal_allocator_statistics_fprint(stderr, device_allocator.get()));
}
- iree_hal_allocator_release(device_allocator);
- iree_hal_device_release(device);
- iree_vm_instance_release(instance);
+ device_allocator.reset();
+ device.reset();
+ instance.reset();
return iree_ok_status();
}