Exposing the iree::rt C++ API as a C API.
PiperOrigin-RevId: 274627570
diff --git a/rt/BUILD b/rt/BUILD
index 9ba5990..8fc7242 100644
--- a/rt/BUILD
+++ b/rt/BUILD
@@ -6,6 +6,33 @@
)
cc_library(
+ name = "api",
+ srcs = ["api.cc"],
+ hdrs = ["api.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":api_hdrs",
+ ":rt",
+ "//iree/base:api",
+ "//iree/base:api_util",
+ "//iree/base:tracing",
+ "//iree/hal:api",
+ "//iree/hal:buffer_view",
+ "//iree/rt/debug:debug_server_interface",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "api_hdrs",
+ hdrs = ["api.h"],
+ deps = [
+ "//iree/base:api_hdrs",
+ "//iree/hal:api_hdrs",
+ ],
+)
+
+cc_library(
name = "rt",
srcs = [
"context.cc",
diff --git a/rt/CMakeLists.txt b/rt/CMakeLists.txt
index f297759..8cf7786 100644
--- a/rt/CMakeLists.txt
+++ b/rt/CMakeLists.txt
@@ -16,6 +16,24 @@
iree_cc_library(
NAME
+ api
+ HDRS
+ "api.h"
+ SRCS
+ "api.cc"
+ DEPS
+ absl::time
+ iree::base::api
+ iree::base::api_util
+ iree::base::tracing
+ iree::hal::api
+ iree::hal::buffer_view
+ iree::rt
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
rt
SRCS
"context.cc"
diff --git a/rt/api.cc b/rt/api.cc
new file mode 100644
index 0000000..6e1e5c5
--- /dev/null
+++ b/rt/api.cc
@@ -0,0 +1,667 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/rt/api.h"
+
+#include "absl/time/time.h"
+#include "iree/base/api.h"
+#include "iree/base/api_util.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/api.h"
+#include "iree/hal/buffer_view.h"
+#include "iree/rt/context.h"
+#include "iree/rt/debug/debug_server.h"
+#include "iree/rt/function.h"
+#include "iree/rt/instance.h"
+#include "iree/rt/invocation.h"
+#include "iree/rt/module.h"
+#include "iree/rt/policy.h"
+
+namespace iree {
+namespace rt {
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Instance
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
+ iree_allocator_t allocator, iree_rt_instance_t** out_instance) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_create");
+
+ if (!out_instance) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_instance = nullptr;
+
+ auto instance = make_ref<Instance>();
+ *out_instance = reinterpret_cast<iree_rt_instance_t*>(instance.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_retain(iree_rt_instance_t* instance) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_retain");
+ auto* handle = reinterpret_cast<Instance*>(instance);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_release(iree_rt_instance_t* instance) {
+ IREE_TRACE_SCOPE0("iree_rt_instance_release");
+ auto* handle = reinterpret_cast<Instance*>(instance);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Module
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class ExternalModule final : public Module {
+ public:
+ ExternalModule(iree_rt_external_module_t impl, iree_allocator_t allocator)
+ : impl_(impl), allocator_(allocator) {
+ IREE_TRACE_SCOPE0("ExternalModule::ctor");
+ }
+
+ ~ExternalModule() override {
+ IREE_TRACE_SCOPE0("ExternalModule::dtor");
+ impl_.destroy(impl_.self);
+ std::memset(&impl_, 0, sizeof(impl_));
+ }
+
+ absl::string_view name() const override {
+ auto result = impl_.name(impl_.self);
+ return absl::string_view{result.data, result.size};
+ }
+
+ const ModuleSignature signature() const override {
+ auto signature = impl_.signature(impl_.self);
+ return ModuleSignature{
+ signature.import_function_count,
+ signature.export_function_count,
+ signature.internal_function_count,
+ signature.state_slot_count,
+ };
+ }
+
+ SourceResolver* source_resolver() const override { return nullptr; }
+
+ std::string DebugStringShort() const override { return std::string(name()); }
+
+ StatusOr<const Function> LookupFunctionByOrdinal(
+ Function::Linkage linkage, int32_t ordinal) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByOrdinal");
+ iree_rt_function_t function;
+ auto status = impl_.lookup_function_by_ordinal(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
+ &function);
+ if (status != IREE_STATUS_OK) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return Function{reinterpret_cast<Module*>(function.module),
+ static_cast<Function::Linkage>(function.linkage),
+ function.ordinal};
+ }
+
+ StatusOr<const Function> LookupFunctionByName(
+ Function::Linkage linkage, absl::string_view name) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByName");
+ iree_rt_function_t function;
+ auto status = impl_.lookup_function_by_name(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage),
+ iree_string_view_t{name.data(), name.size()}, &function);
+ if (status != IREE_STATUS_OK) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return Function{reinterpret_cast<Module*>(function.module),
+ static_cast<Function::Linkage>(function.linkage),
+ function.ordinal};
+ }
+
+ StatusOr<absl::string_view> GetFunctionName(Function::Linkage linkage,
+ int32_t ordinal) const {
+ IREE_TRACE_SCOPE0("ExternalModule::GetFunctionName");
+ iree_string_view_t name;
+ auto status = impl_.get_function_name(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
+ &name);
+ RETURN_IF_ERROR(FromApiStatus(status, IREE_LOC));
+ return absl::string_view{name.data, name.size};
+ }
+
+ StatusOr<const FunctionSignature> GetFunctionSignature(
+ Function::Linkage linkage, int32_t ordinal) const override {
+ IREE_TRACE_SCOPE0("ExternalModule::GetFunctionSignature");
+ iree_rt_function_signature_t signature;
+ auto status = impl_.get_function_signature(
+ impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
+ &signature);
+ if (status != IREE_STATUS_OK) {
+ return FromApiStatus(status, IREE_LOC);
+ }
+ return FunctionSignature{signature.argument_count, signature.result_count};
+ }
+
+ Status Execute(
+ const Function function,
+ absl::InlinedVector<hal::BufferView, 8> arguments,
+ absl::InlinedVector<hal::BufferView, 8>* results) const override {
+ // TODO(benvanik): fn ptr callback to external code. Waiting on fibers.
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "External calls not yet implemented";
+ }
+
+ private:
+ iree_rt_external_module_t impl_;
+ iree_allocator_t allocator_;
+};
+
+} // namespace
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
+ iree_rt_external_module_t impl, iree_allocator_t allocator,
+ iree_rt_module_t** out_module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_create_external");
+
+ if (!out_module) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_module = nullptr;
+
+ auto module = make_ref<ExternalModule>(impl, allocator);
+ *out_module = reinterpret_cast<iree_rt_module_t*>(module.release());
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_retain(iree_rt_module_t* module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_retain");
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_release(iree_rt_module_t* module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_release");
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_module_name(const iree_rt_module_t* module) {
+ IREE_TRACE_SCOPE0("iree_rt_module_name");
+ const auto* handle = reinterpret_cast<const Module*>(module);
+ CHECK(handle) << "NULL module handle";
+ return iree_string_view_t{handle->name().data(), handle->name().size()};
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ int32_t ordinal,
+ iree_rt_function_t* out_function) {
+ IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_ordinal");
+
+ if (!out_function) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ std::memset(out_function, 0, sizeof(*out_function));
+
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto function_or = handle->LookupFunctionByOrdinal(
+ static_cast<Function::Linkage>(linkage), ordinal);
+ IREE_API_RETURN_IF_ERROR(function_or.status());
+ auto function = std::move(function_or).ValueOrDie();
+
+ out_function->module = module;
+ out_function->linkage =
+ static_cast<iree_rt_function_linkage_t>(function.linkage());
+ out_function->ordinal = function.ordinal();
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ iree_string_view_t name,
+ iree_rt_function_t* out_function) {
+ IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_name");
+
+ if (!out_function) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ std::memset(out_function, 0, sizeof(*out_function));
+
+ auto* handle = reinterpret_cast<Module*>(module);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto function_or =
+ handle->LookupFunctionByName(static_cast<Function::Linkage>(linkage),
+ absl::string_view{name.data, name.size});
+ IREE_API_RETURN_IF_ERROR(function_or.status());
+ auto function = std::move(function_or).ValueOrDie();
+
+ out_function->linkage =
+ static_cast<iree_rt_function_linkage_t>(function.linkage());
+ out_function->module = module;
+ out_function->linkage = linkage;
+ out_function->ordinal = function.ordinal();
+
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Function
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_function_name(const iree_rt_function_t* function) {
+ IREE_TRACE_SCOPE0("iree_rt_function_name");
+ CHECK(function && function->module) << "NULL function handle";
+ auto* module = reinterpret_cast<Module*>(function->module);
+ auto name_or = module->GetFunctionName(
+ static_cast<Function::Linkage>(function->linkage), function->ordinal);
+ if (!name_or.ok()) return {};
+ auto name = name_or.ValueOrDie();
+ return iree_string_view_t{name.data(), name.size()};
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_function_signature(const iree_rt_function_t* function,
+ iree_rt_function_signature_t* out_signature) {
+ IREE_TRACE_SCOPE0("iree_rt_function_signature");
+
+ if (!out_signature) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ std::memset(out_signature, 0, sizeof(*out_signature));
+
+ if (!function || !function->module) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto* module = reinterpret_cast<Module*>(function->module);
+ auto signature_or = module->GetFunctionSignature(
+ static_cast<Function::Linkage>(function->linkage), function->ordinal);
+ IREE_API_RETURN_IF_ERROR(signature_or.status());
+ auto signature = std::move(signature_or).ValueOrDie();
+ out_signature->argument_count = signature.argument_count();
+ out_signature->result_count = signature.result_count();
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Policy
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t iree_rt_policy_create(
+ iree_allocator_t allocator, iree_rt_policy_t** out_policy) {
+ IREE_TRACE_SCOPE0("iree_rt_policy_create");
+
+ if (!out_policy) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_policy = nullptr;
+
+ auto policy = make_ref<Policy>();
+
+ *out_policy = reinterpret_cast<iree_rt_policy_t*>(policy.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_retain(iree_rt_policy_t* policy) {
+ IREE_TRACE_SCOPE0("iree_rt_policy_retain");
+ auto* handle = reinterpret_cast<Policy*>(policy);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_release(iree_rt_policy_t* policy) {
+ IREE_TRACE_SCOPE0("iree_rt_policy_release");
+ auto* handle = reinterpret_cast<Policy*>(policy);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Context
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
+ iree_rt_instance_t* instance, iree_rt_policy_t* policy,
+ iree_allocator_t allocator, iree_rt_context_t** out_context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_create");
+
+ if (!out_context) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_context = nullptr;
+
+ if (!instance || !policy) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto context =
+ make_ref<Context>(add_ref(reinterpret_cast<Instance*>(instance)),
+ add_ref(reinterpret_cast<Policy*>(policy)));
+
+ *out_context = reinterpret_cast<iree_rt_context_t*>(context.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_retain(iree_rt_context_t* context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_retain");
+ auto* handle = reinterpret_cast<Context*>(context);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_release(iree_rt_context_t* context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_release");
+ auto* handle = reinterpret_cast<Context*>(context);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT int32_t IREE_API_CALL
+iree_rt_context_id(const iree_rt_context_t* context) {
+ IREE_TRACE_SCOPE0("iree_rt_context_id");
+ const auto* handle = reinterpret_cast<const Context*>(context);
+ CHECK(handle) << "NULL context handle";
+ return handle->id();
+}
+
+IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
+iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
+ iree_string_view_t module_name) {
+ IREE_TRACE_SCOPE0("iree_rt_context_lookup_module_by_name");
+ const auto* handle = reinterpret_cast<const Context*>(context);
+ CHECK(handle) << "NULL context handle";
+ auto module_or = handle->LookupModuleByName(
+ absl::string_view{module_name.data, module_name.size});
+ if (!module_or.ok()) {
+ return nullptr;
+ }
+ return reinterpret_cast<iree_rt_module_t*>(module_or.ValueOrDie());
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
+ const iree_rt_context_t* context, iree_string_view_t full_name,
+ iree_rt_function_t* out_function) {
+ IREE_TRACE_SCOPE0("iree_rt_context_resolve_function");
+
+ if (!out_function) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ std::memset(out_function, 0, sizeof(*out_function));
+
+ const auto* handle = reinterpret_cast<const Context*>(context);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ auto full_name_view = absl::string_view{full_name.data, full_name.size};
+ size_t last_dot = full_name_view.rfind('.');
+ if (last_dot == absl::string_view::npos) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ auto module_name = full_name_view.substr(0, last_dot);
+ auto function_name = full_name_view.substr(last_dot + 1);
+
+ iree_rt_module_t* module = iree_rt_context_lookup_module_by_name(
+ context, iree_string_view_t{module_name.data(), module_name.size()});
+ if (!module) {
+ return IREE_STATUS_NOT_FOUND;
+ }
+
+ return iree_rt_module_lookup_function_by_name(
+ module, IREE_RT_FUNCTION_LINKAGE_EXPORT,
+ iree_string_view_t{function_name.data(), function_name.size()},
+ out_function);
+}
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Invocation
+//===----------------------------------------------------------------------===//
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
+ iree_rt_context_t* context, iree_rt_function_t* function,
+ iree_rt_policy_t* policy,
+ const iree_rt_invocation_dependencies_t* dependencies,
+ const iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
+ const iree_hal_buffer_view_t** results, iree_host_size_t result_count,
+ iree_allocator_t allocator, iree_rt_invocation_t** out_invocation) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_create");
+
+ if (!out_invocation) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_invocation = nullptr;
+
+ if (!context || !function || !function->module) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ } else if (dependencies &&
+ (dependencies->invocation_count && !dependencies->invocations)) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ } else if ((argument_count && !arguments) || (result_count && !results)) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ // TODO(benvanik): unwrap without needing to retain here.
+ absl::InlinedVector<ref_ptr<Invocation>, 4> dependent_invocations;
+ if (dependencies) {
+ dependent_invocations.resize(dependencies->invocation_count);
+ for (int i = 0; i < dependencies->invocation_count; ++i) {
+ dependent_invocations[i] =
+ add_ref(reinterpret_cast<Invocation*>(dependencies->invocations[i]));
+ }
+ }
+
+ // TODO(benvanik): unwrap without needing to retain here.
+ absl::InlinedVector<hal::BufferView, 8> argument_views(argument_count);
+ for (int i = 0; i < argument_count; ++i) {
+ const auto* buffer_view =
+ reinterpret_cast<const hal::BufferView*>(arguments[i]);
+ if (!buffer_view) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ argument_views[i] =
+ hal::BufferView{add_ref(buffer_view->buffer), buffer_view->shape,
+ buffer_view->element_size};
+ }
+
+ // TODO(benvanik): unwrap without needing to retain here.
+ absl::InlinedVector<hal::BufferView, 8> result_views(result_count);
+ for (int i = 0; i < result_count; ++i) {
+ const auto* buffer_view =
+ reinterpret_cast<const hal::BufferView*>(results[i]);
+ if (buffer_view) {
+ result_views[i] =
+ hal::BufferView{add_ref(buffer_view->buffer), buffer_view->shape,
+ buffer_view->element_size};
+ }
+ }
+
+ auto invocation_or = Invocation::Create(
+ add_ref(reinterpret_cast<Context*>(context)),
+ Function{reinterpret_cast<Module*>(function->module),
+ static_cast<Function::Linkage>(function->linkage),
+ function->ordinal},
+ add_ref(reinterpret_cast<Policy*>(policy)),
+ std::move(dependent_invocations), std::move(argument_views),
+ std::move(result_views));
+ IREE_API_RETURN_IF_ERROR(invocation_or.status());
+ auto invocation = std::move(invocation_or).ValueOrDie();
+
+ *out_invocation =
+ reinterpret_cast<iree_rt_invocation_t*>(invocation.release());
+
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_retain(iree_rt_invocation_t* invocation) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_retain");
+ auto* handle = reinterpret_cast<Invocation*>(invocation);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->AddReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_release(iree_rt_invocation_t* invocation) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_release");
+ auto* handle = reinterpret_cast<Invocation*>(invocation);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ handle->ReleaseReference();
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_query_status(iree_rt_invocation_t* invocation) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_query_status");
+ auto* handle = reinterpret_cast<Invocation*>(invocation);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ IREE_API_RETURN_IF_ERROR(handle->QueryStatus());
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
+ iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
+ iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
+ iree_host_size_t* out_result_count) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_consume_results");
+
+ if (!out_result_count) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ *out_result_count = 0;
+ if (!out_results) {
+ std::memset(out_results, 0,
+ sizeof(iree_hal_buffer_view_t*) * result_capacity);
+ }
+
+ auto* handle = reinterpret_cast<Invocation*>(invocation);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+
+ const auto& function = handle->function();
+ int32_t result_count = function.signature().result_count();
+ *out_result_count = result_count;
+ if (!out_results) {
+ return IREE_STATUS_OK;
+ } else if (result_capacity < result_count) {
+ return IREE_STATUS_OUT_OF_RANGE;
+ }
+
+ auto results_or = handle->ConsumeResults();
+ IREE_API_RETURN_IF_ERROR(results_or.status());
+ auto results = std::move(results_or).ValueOrDie();
+ iree_status_t status = IREE_STATUS_OK;
+ int i = 0;
+ for (i = 0; i < results.size(); ++i) {
+ iree_shape_t shape;
+ status = ToApiShape(results[i].shape, &shape);
+ if (status != IREE_STATUS_OK) break;
+ status = iree_hal_buffer_view_create(
+ reinterpret_cast<iree_hal_buffer_t*>(results[i].buffer.get()), shape,
+ results[i].element_size, allocator, &out_results[i]);
+ if (status != IREE_STATUS_OK) break;
+ }
+ if (status != IREE_STATUS_OK) {
+ // Release already-retained buffer views on failure.
+ for (int j = 0; j < i; ++j) {
+ iree_hal_buffer_view_release(out_results[j]);
+ }
+ std::memset(out_results, 0,
+ sizeof(iree_hal_buffer_view_t*) * result_capacity);
+ }
+ return status;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
+ iree_rt_invocation_t* invocation, iree_time_t deadline) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_await");
+ auto* handle = reinterpret_cast<Invocation*>(invocation);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ IREE_API_RETURN_IF_ERROR(handle->Await(ToAbslTime(deadline)));
+ return IREE_STATUS_OK;
+}
+
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_abort(iree_rt_invocation_t* invocation) {
+ IREE_TRACE_SCOPE0("iree_rt_invocation_abort");
+ auto* handle = reinterpret_cast<Invocation*>(invocation);
+ if (!handle) {
+ return IREE_STATUS_INVALID_ARGUMENT;
+ }
+ IREE_API_RETURN_IF_ERROR(handle->Abort());
+ return IREE_STATUS_OK;
+}
+
+} // namespace rt
+} // namespace iree
diff --git a/rt/api.h b/rt/api.h
new file mode 100644
index 0000000..f5ed7bd
--- /dev/null
+++ b/rt/api.h
@@ -0,0 +1,366 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// See iree/base/api.h for documentation on the API conventions used.
+
+#ifndef IREE_RT_API_H_
+#define IREE_RT_API_H_
+
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Types and Enums
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_rt_instance iree_rt_instance_t;
+typedef struct iree_rt_context iree_rt_context_t;
+typedef struct iree_rt_policy iree_rt_policy_t;
+typedef struct iree_rt_module iree_rt_module_t;
+typedef struct iree_rt_invocation iree_rt_invocation_t;
+
+// Describes the type of a function reference.
+typedef enum {
+ // Function is internal to the module and may not be reflectable.
+ IREE_RT_FUNCTION_LINKAGE_INTERNAL = 0,
+ // Function is an import from another module.
+ IREE_RT_FUNCTION_LINKAGE_IMPORT = 1,
+ // Function is an export from the module.
+ IREE_RT_FUNCTION_LINKAGE_EXPORT = 2,
+} iree_rt_function_linkage_t;
+
+// A function reference that can be used with the iree_rt_function_* methods.
+// These should be treated as opaque and the accessor functions should be used
+// instead.
+typedef struct {
+ // Module the function is contained within.
+ iree_rt_module_t* module;
+ // Linkage of the function. Note that IREE_RT_FUNCTION_LINKAGE_INTERNAL
+ // functions may be missing reflection information.
+ iree_rt_function_linkage_t linkage;
+ // Ordinal within the module in the linkage scope.
+ int32_t ordinal;
+} iree_rt_function_t;
+
+// Describes the expected calling convention and arguments/results of a
+// function.
+typedef struct {
+ // Total number of arguments to the function.
+ int32_t argument_count;
+ // Total number of results from the function.
+ int32_t result_count;
+} iree_rt_function_signature_t;
+
+// Describes the imports, exports, and capabilities of a module.
+typedef struct {
+ // Total number of imported functions.
+ int32_t import_function_count;
+ // Total number of exported functions.
+ int32_t export_function_count;
+ // Total number of internal functions, if debugging info is present and they
+ // can be queried.
+ int32_t internal_function_count;
+ // Total number of state block resource slots consumed.
+ int32_t state_slot_count;
+} iree_rt_module_signature_t;
+
+// Dependency information used to order invocations.
+typedef struct {
+ // Prior invocations that must complete before the new invocation begins.
+ iree_rt_invocation_t** invocations;
+ iree_host_size_t invocation_count;
+
+ // TODO(benvanik): wait semaphores/importing.
+} iree_rt_invocation_dependencies_t;
+
+// Defines an external module that can be used to reflect and execute functions.
+// Modules must be thread-safe as lookups and executions may occur in any order
+// from any thread.
+//
+// Modules will have their resolve_imports function called upon registration
+// with a context and may use the provided resolver to find imported functions.
+typedef struct {
+ // User-defined pointer passed to all functions.
+ void* self;
+ // Destroys |self| when all references to the module have been released.
+ iree_status_t(IREE_API_PTR* destroy)(void* self);
+ // Returns the name of the module (used during resolution).
+ iree_string_view_t(IREE_API_PTR* name)(void* self);
+ // Sets |out_module_signature| to the reflected signature of the module.
+ iree_rt_module_signature_t(IREE_API_PTR* signature)(void* self);
+ // Sets |out_function| to a resolved function by ordinal, if found.
+ iree_status_t(IREE_API_PTR* lookup_function_by_ordinal)(
+ void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
+ iree_rt_function_t* out_function);
+ // Sets |out_function| to a resolved function by name, if found.
+ iree_status_t(IREE_API_PTR* lookup_function_by_name)(
+ void* self, iree_rt_function_linkage_t linkage, iree_string_view_t name,
+ iree_rt_function_t* out_function);
+ // Sets |out_name| to the name of the function with the given ordinal, if
+ // found.
+ iree_status_t(IREE_API_PTR* get_function_name)(
+ void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
+ iree_string_view_t* out_name);
+ // Sets |out_signature| to the reflected signature of the given
+ // function, if found.
+ iree_status_t(IREE_API_PTR* get_function_signature)(
+ void* self, iree_rt_function_linkage_t linkage, int32_t ordinal,
+ iree_rt_function_signature_t* out_signature);
+} iree_rt_external_module_t;
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Instance
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a new instance. This should be shared with all contexts in an
+// application to ensure that resources are tracked properly and threads are
+// managed correctly.
+// |out_instance| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
+ iree_allocator_t allocator, iree_rt_instance_t** out_instance);
+
+// Retains the given |instance| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_retain(iree_rt_instance_t* instance);
+
+// Releases the given |instance| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_instance_release(iree_rt_instance_t* instance);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Module
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a module with an external backing implementation.
+// The provided |external_module| definition will be used to query the module
+// state as needed. No caching occurs within the implementation to allow calls
+// to return different values per-invocation.
+//
+// |out_module| must be released by the caller.
+// iree_rt_external_module_t::destroy is called when the last reference to the
+// iree_rt_module_t is released.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
+ iree_rt_external_module_t impl, iree_allocator_t allocator,
+ iree_rt_module_t** out_module);
+
+// Retains the given |module| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_retain(iree_rt_module_t* module);
+
+// Releases the given |module| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_release(iree_rt_module_t* module);
+
+// Returns the name of the module.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_module_name(const iree_rt_module_t* module);
+
+// Sets |out_function| to a function with |ordinal| in the given linkage or
+// returns IREE_STATUS_NOT_FOUND. The function reference is valid for the
+// lifetime of |module|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ int32_t ordinal,
+ iree_rt_function_t* out_function);
+
+// Sets |out_function| to a function with |name| in the given linkage or returns
+// IREE_STATUS_NOT_FOUND. The function reference is valid for the lifetime of
+// |module|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
+ iree_rt_function_linkage_t linkage,
+ iree_string_view_t name,
+ iree_rt_function_t* out_function);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Function
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Returns the name of the function as exported from the module.
+IREE_API_EXPORT iree_string_view_t IREE_API_CALL
+iree_rt_function_name(const iree_rt_function_t* function);
+
+// Sets |out_function_signature| to the reflected signature of the function.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_function_signature(const iree_rt_function_t* function,
+ iree_rt_function_signature_t* out_signature);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Policy
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// TODO(benvanik): define policies. For now they are no-ops.
+IREE_API_EXPORT iree_status_t iree_rt_policy_create(
+ iree_allocator_t allocator, iree_rt_policy_t** out_policy);
+
+// Retains the given |policy| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_retain(iree_rt_policy_t* policy);
+
+// Releases the given |policy| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_policy_release(iree_rt_policy_t* policy);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Context
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a new context that uses the given |instance| for device management.
+// |out_context| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
+ iree_rt_instance_t* instance, iree_rt_policy_t* policy,
+ iree_allocator_t allocator, iree_rt_context_t** out_context);
+
+// Retains the given |context| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_retain(iree_rt_context_t* context);
+
+// Releases the given |context| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_context_release(iree_rt_context_t* context);
+
+// Returns a process-unique ID for the |context|.
+IREE_API_EXPORT int32_t IREE_API_CALL
+iree_rt_context_id(const iree_rt_context_t* context);
+
+// Registers a list of modules with the context and resolves imports.
+// The modules will be retained by the context until destruction.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules(
+ iree_rt_context_t* context, const iree_rt_module_t** modules,
+ iree_host_size_t module_count);
+
+// Returns a reference to the module registered with the given name or nullptr
+// if not found. The caller must retain the returned module if they want to
+// continue using it.
+IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
+iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
+ iree_string_view_t module_name);
+
+// Sets |out_function| to to an exported function with the fully-qualified name
+// of |full_name| or returns IREE_STATUS_NOT_FOUND. The function reference is
+// valid for the lifetime of |context|.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
+ const iree_rt_context_t* context, iree_string_view_t full_name,
+ iree_rt_function_t* out_function);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+//===----------------------------------------------------------------------===//
+// iree::rt::Invocation
+//===----------------------------------------------------------------------===//
+
+#ifndef IREE_API_NO_PROTOTYPES
+
+// Creates a new invocation tracking object for invoking the given |function|
+// from |context|. |arguments| will be retained until the invocation is made.
+// If |dependencies| are provided then the invocation will wait until they are
+// resolved before executing. If a |policy| is provided it will override the
+// context-level policy.
+//
+// Optionally |results| may be provided with preallocated buffers that will
+// receive the outputs of the invocation. Invocation will fail if they do not
+// match expected sizes.
+//
+// Note that it's possible for the invocation to complete prior to the return of
+// this function. Any errors that occur will be set on the invocation and
+// callers should query its state prior to assuming it is in-flight.
+//
+// |out_invocation| must be released by the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
+ iree_rt_context_t* context, iree_rt_function_t* function,
+ iree_rt_policy_t* policy,
+ const iree_rt_invocation_dependencies_t* dependencies,
+ const iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
+ const iree_hal_buffer_view_t** results, iree_host_size_t result_count,
+ iree_allocator_t allocator, iree_rt_invocation_t** out_invocation);
+
+// Retains the given |invocation| for the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_retain(iree_rt_invocation_t* invocation);
+
+// Releases the given |invocation| from the caller.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_release(iree_rt_invocation_t* invocation);
+
+// Queries the completion status of the invocation.
+// Returns one of the following:
+// IREE_STATUS_OK: the invocation completed successfully.
+// IREE_STATUS_UNAVAILABLE: the invocation has not yet completed.
+// IREE_STATUS_CANCELLED: the invocation was cancelled internally.
+// IREE_STATUS_ABORTED: the invocation was aborted.
+// IREE_STATUS_*: an error occurred during invocation.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_query_status(iree_rt_invocation_t* invocation);
+
+// Populates |out_results| to the values of the results.
+// |result_capacity| defines the number of elements available in |out_results|
+// and |out_result_count| will be set with the actual number of results
+// available. If |result_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be
+// returned wtih the required capacity in |out_result_count|. To only query the
+// required capacity |out_results| may be passed as nullptr.
+//
+// Ownership of returned results will be transferred to the caller and they must
+// be released if no longer needed.
+//
+// Returns errors as with iree_rt_invocation_query_status, for example in the
+// case of not-yet-completed or aborted invocations.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
+ iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
+ iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
+ iree_host_size_t* out_result_count);
+
+// Blocks the caller until the invocation completes (successfully or otherwise).
+//
+// Returns IREE_STATUS_DEADLINE_EXCEEDED if |deadline| elapses before the
+// invocation completes and otherwise returns iree_rt_invocation_query_status.
+IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
+ iree_rt_invocation_t* invocation, iree_time_t deadline);
+
+// Attempts to abort the invocation if it is in-flight.
+// A no-op if the invocation has already completed.
+IREE_API_EXPORT iree_status_t IREE_API_CALL
+iree_rt_invocation_abort(iree_rt_invocation_t* invocation);
+
+#endif // IREE_API_NO_PROTOTYPES
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_RT_API_H_