blob: 6e1e5c5ff032dbc8666ced9e74e1bf8f069cc6f8 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/rt/api.h"
#include "absl/time/time.h"
#include "iree/base/api.h"
#include "iree/base/api_util.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
#include "iree/hal/buffer_view.h"
#include "iree/rt/context.h"
#include "iree/rt/debug/debug_server.h"
#include "iree/rt/function.h"
#include "iree/rt/instance.h"
#include "iree/rt/invocation.h"
#include "iree/rt/module.h"
#include "iree/rt/policy.h"
namespace iree {
namespace rt {
//===----------------------------------------------------------------------===//
// iree::rt::Instance
//===----------------------------------------------------------------------===//
// Constructs a new runtime Instance and hands one reference to the caller
// through |out_instance| (which is cleared first). |allocator| is not used by
// this implementation.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create(
    iree_allocator_t allocator, iree_rt_instance_t** out_instance) {
  IREE_TRACE_SCOPE0("iree_rt_instance_create");
  if (out_instance == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_instance = nullptr;
  auto created = make_ref<Instance>();
  Instance* raw_instance = created.release();
  *out_instance = reinterpret_cast<iree_rt_instance_t*>(raw_instance);
  return IREE_STATUS_OK;
}
// Adds a reference to |instance|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_instance_retain(iree_rt_instance_t* instance) {
  IREE_TRACE_SCOPE0("iree_rt_instance_retain");
  if (Instance* target = reinterpret_cast<Instance*>(instance)) {
    target->AddReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Drops a reference from |instance|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_instance_release(iree_rt_instance_t* instance) {
  IREE_TRACE_SCOPE0("iree_rt_instance_release");
  if (Instance* target = reinterpret_cast<Instance*>(instance)) {
    target->ReleaseReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
//===----------------------------------------------------------------------===//
// iree::rt::Module
//===----------------------------------------------------------------------===//
namespace {

// Adapts a caller-supplied iree_rt_external_module_t function table to the
// rt::Module interface. Every query is forwarded to the matching callback in
// the table, passing the table's |self| pointer back. The table's |destroy|
// callback is invoked from this object's destructor.
class ExternalModule final : public Module {
 public:
  // |impl| is copied by value. |allocator| is stored but not otherwise used.
  ExternalModule(iree_rt_external_module_t impl, iree_allocator_t allocator)
      : impl_(impl), allocator_(allocator) {
    IREE_TRACE_SCOPE0("ExternalModule::ctor");
  }

  ~ExternalModule() override {
    IREE_TRACE_SCOPE0("ExternalModule::dtor");
    impl_.destroy(impl_.self);
    // Clear the copied function table so no stale callback/self pair remains.
    std::memset(&impl_, 0, sizeof(impl_));
  }

  // Returns the name reported by the external |name| callback. The returned
  // view aliases storage owned by the external implementation.
  absl::string_view name() const override {
    auto result = impl_.name(impl_.self);
    return absl::string_view{result.data, result.size};
  }

  // Converts the external signature struct field-by-field.
  const ModuleSignature signature() const override {
    auto signature = impl_.signature(impl_.self);
    return ModuleSignature{
        signature.import_function_count,
        signature.export_function_count,
        signature.internal_function_count,
        signature.state_slot_count,
    };
  }

  // External modules provide no source-level debugging information.
  SourceResolver* source_resolver() const override { return nullptr; }

  std::string DebugStringShort() const override { return std::string(name()); }

  // Resolves a function by ordinal via the external callback. A non-OK API
  // status from the callback is converted into an error Status.
  StatusOr<const Function> LookupFunctionByOrdinal(
      Function::Linkage linkage, int32_t ordinal) const override {
    IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByOrdinal");
    iree_rt_function_t function{};
    auto status = impl_.lookup_function_by_ordinal(
        impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
        &function);
    if (status != IREE_STATUS_OK) {
      return FromApiStatus(status, IREE_LOC);
    }
    return Function{reinterpret_cast<Module*>(function.module),
                    static_cast<Function::Linkage>(function.linkage),
                    function.ordinal};
  }

  // Resolves a function by name via the external callback. A non-OK API
  // status from the callback is converted into an error Status.
  StatusOr<const Function> LookupFunctionByName(
      Function::Linkage linkage, absl::string_view name) const override {
    IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByName");
    iree_rt_function_t function{};
    auto status = impl_.lookup_function_by_name(
        impl_.self, static_cast<iree_rt_function_linkage_t>(linkage),
        iree_string_view_t{name.data(), name.size()}, &function);
    if (status != IREE_STATUS_OK) {
      return FromApiStatus(status, IREE_LOC);
    }
    return Function{reinterpret_cast<Module*>(function.module),
                    static_cast<Function::Linkage>(function.linkage),
                    function.ordinal};
  }

  // Returns the name of the function at |ordinal|. This overrides the Module
  // virtual used by iree_rt_function_name. The returned view aliases storage
  // owned by the external implementation.
  StatusOr<absl::string_view> GetFunctionName(Function::Linkage linkage,
                                              int32_t ordinal) const override {
    IREE_TRACE_SCOPE0("ExternalModule::GetFunctionName");
    iree_string_view_t name{};
    auto status = impl_.get_function_name(
        impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
        &name);
    if (status != IREE_STATUS_OK) {
      return FromApiStatus(status, IREE_LOC);
    }
    return absl::string_view{name.data, name.size};
  }

  // Returns the argument/result counts for the function at |ordinal|.
  StatusOr<const FunctionSignature> GetFunctionSignature(
      Function::Linkage linkage, int32_t ordinal) const override {
    IREE_TRACE_SCOPE0("ExternalModule::GetFunctionSignature");
    iree_rt_function_signature_t signature{};
    auto status = impl_.get_function_signature(
        impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal,
        &signature);
    if (status != IREE_STATUS_OK) {
      return FromApiStatus(status, IREE_LOC);
    }
    return FunctionSignature{signature.argument_count, signature.result_count};
  }

  // Calling into external modules is not supported yet; always returns an
  // Unimplemented error.
  Status Execute(
      const Function function,
      absl::InlinedVector<hal::BufferView, 8> arguments,
      absl::InlinedVector<hal::BufferView, 8>* results) const override {
    // TODO(benvanik): fn ptr callback to external code. Waiting on fibers.
    return UnimplementedErrorBuilder(IREE_LOC)
           << "External calls not yet implemented";
  }

 private:
  iree_rt_external_module_t impl_;
  iree_allocator_t allocator_;
};

}  // namespace
// Wraps the caller-provided |impl| function table in an ExternalModule and
// returns one reference to it via |out_module| (which is cleared first).
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external(
    iree_rt_external_module_t impl, iree_allocator_t allocator,
    iree_rt_module_t** out_module) {
  IREE_TRACE_SCOPE0("iree_rt_module_create_external");
  if (out_module == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_module = nullptr;
  auto external = make_ref<ExternalModule>(impl, allocator);
  ExternalModule* raw_module = external.release();
  *out_module = reinterpret_cast<iree_rt_module_t*>(raw_module);
  return IREE_STATUS_OK;
}
// Adds a reference to |module|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_module_retain(iree_rt_module_t* module) {
  IREE_TRACE_SCOPE0("iree_rt_module_retain");
  if (Module* target = reinterpret_cast<Module*>(module)) {
    target->AddReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Drops a reference from |module|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_module_release(iree_rt_module_t* module) {
  IREE_TRACE_SCOPE0("iree_rt_module_release");
  if (Module* target = reinterpret_cast<Module*>(module)) {
    target->ReleaseReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Returns the name of |module| as a non-owning string view. CHECK-fails on a
// NULL handle.
IREE_API_EXPORT iree_string_view_t IREE_API_CALL
iree_rt_module_name(const iree_rt_module_t* module) {
  IREE_TRACE_SCOPE0("iree_rt_module_name");
  const auto* handle = reinterpret_cast<const Module*>(module);
  CHECK(handle) << "NULL module handle";
  // Query the (virtual) name once. For external modules each call round-trips
  // through a user callback, so data() and size() must come from the same call.
  absl::string_view name = handle->name();
  return iree_string_view_t{name.data(), name.size()};
}
// Resolves the function at |ordinal| with the given |linkage| in |module|.
// |out_function| is zeroed before any other validation. On success it receives
// |module|, the resolved linkage and the resolved ordinal.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module,
                                          iree_rt_function_linkage_t linkage,
                                          int32_t ordinal,
                                          iree_rt_function_t* out_function) {
  IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_ordinal");
  if (out_function == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  std::memset(out_function, 0, sizeof(*out_function));
  Module* module_impl = reinterpret_cast<Module*>(module);
  if (module_impl == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  const auto requested_linkage = static_cast<Function::Linkage>(linkage);
  auto lookup_result =
      module_impl->LookupFunctionByOrdinal(requested_linkage, ordinal);
  IREE_API_RETURN_IF_ERROR(lookup_result.status());
  auto resolved = std::move(lookup_result).ValueOrDie();

  out_function->ordinal = resolved.ordinal();
  out_function->linkage =
      static_cast<iree_rt_function_linkage_t>(resolved.linkage());
  out_function->module = module;
  return IREE_STATUS_OK;
}
// Resolves the function called |name| with the given |linkage| in |module|.
// |out_function| is zeroed before any other validation. On success it receives
// |module| together with the linkage and ordinal of the function the module
// actually resolved. This matches
// iree_rt_module_lookup_function_by_ordinal.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_module_lookup_function_by_name(iree_rt_module_t* module,
                                       iree_rt_function_linkage_t linkage,
                                       iree_string_view_t name,
                                       iree_rt_function_t* out_function) {
  IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_name");
  if (!out_function) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  std::memset(out_function, 0, sizeof(*out_function));
  auto* handle = reinterpret_cast<Module*>(module);
  if (!handle) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  auto function_or =
      handle->LookupFunctionByName(static_cast<Function::Linkage>(linkage),
                                   absl::string_view{name.data, name.size});
  IREE_API_RETURN_IF_ERROR(function_or.status());
  auto function = std::move(function_or).ValueOrDie();
  out_function->module = module;
  // Report the linkage of the resolved function. Previously this was
  // immediately overwritten with the requested |linkage|.
  out_function->linkage =
      static_cast<iree_rt_function_linkage_t>(function.linkage());
  out_function->ordinal = function.ordinal();
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::rt::Function
//===----------------------------------------------------------------------===//
// Returns the name of |function| as reported by its owning module, or an
// empty view if the module fails to produce one. CHECK-fails when |function|
// or its module is NULL.
IREE_API_EXPORT iree_string_view_t IREE_API_CALL
iree_rt_function_name(const iree_rt_function_t* function) {
  IREE_TRACE_SCOPE0("iree_rt_function_name");
  CHECK(function && function->module) << "NULL function handle";
  Module* owner = reinterpret_cast<Module*>(function->module);
  const auto linkage = static_cast<Function::Linkage>(function->linkage);
  auto lookup = owner->GetFunctionName(linkage, function->ordinal);
  if (lookup.ok()) {
    absl::string_view found = lookup.ValueOrDie();
    return iree_string_view_t{found.data(), found.size()};
  }
  return {};
}
// Writes the argument/result counts of |function| into |out_signature|.
// |out_signature| is zeroed before |function| is validated.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_function_signature(const iree_rt_function_t* function,
                           iree_rt_function_signature_t* out_signature) {
  IREE_TRACE_SCOPE0("iree_rt_function_signature");
  if (out_signature == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  std::memset(out_signature, 0, sizeof(*out_signature));
  if (function == nullptr || function->module == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  Module* owner = reinterpret_cast<Module*>(function->module);
  auto lookup = owner->GetFunctionSignature(
      static_cast<Function::Linkage>(function->linkage), function->ordinal);
  IREE_API_RETURN_IF_ERROR(lookup.status());
  auto sig = std::move(lookup).ValueOrDie();
  out_signature->result_count = sig.result_count();
  out_signature->argument_count = sig.argument_count();
  return IREE_STATUS_OK;
}
//===----------------------------------------------------------------------===//
// iree::rt::Policy
//===----------------------------------------------------------------------===//
// Constructs a new Policy and returns one reference to it via |out_policy|
// (which is cleared first). |allocator| is not used by this implementation.
//
// Declared with IREE_API_CALL like every other exported entry point in this
// file, so the calling convention matches the public declaration.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_policy_create(
    iree_allocator_t allocator, iree_rt_policy_t** out_policy) {
  IREE_TRACE_SCOPE0("iree_rt_policy_create");
  if (!out_policy) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_policy = nullptr;
  auto policy = make_ref<Policy>();
  *out_policy = reinterpret_cast<iree_rt_policy_t*>(policy.release());
  return IREE_STATUS_OK;
}
// Adds a reference to |policy|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_policy_retain(iree_rt_policy_t* policy) {
  IREE_TRACE_SCOPE0("iree_rt_policy_retain");
  if (Policy* target = reinterpret_cast<Policy*>(policy)) {
    target->AddReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Drops a reference from |policy|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_policy_release(iree_rt_policy_t* policy) {
  IREE_TRACE_SCOPE0("iree_rt_policy_release");
  if (Policy* target = reinterpret_cast<Policy*>(policy)) {
    target->ReleaseReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
//===----------------------------------------------------------------------===//
// iree::rt::Context
//===----------------------------------------------------------------------===//
// Creates a Context over |instance| and |policy|, both of which must be
// non-NULL. The context is handed its own references (via add_ref). The
// caller's references are left as-is. |out_context| is cleared before the
// input handles are validated. |allocator| is not used here.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create(
    iree_rt_instance_t* instance, iree_rt_policy_t* policy,
    iree_allocator_t allocator, iree_rt_context_t** out_context) {
  IREE_TRACE_SCOPE0("iree_rt_context_create");
  if (out_context == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  *out_context = nullptr;
  if (instance == nullptr || policy == nullptr) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  auto instance_ref = add_ref(reinterpret_cast<Instance*>(instance));
  auto policy_ref = add_ref(reinterpret_cast<Policy*>(policy));
  auto created =
      make_ref<Context>(std::move(instance_ref), std::move(policy_ref));
  Context* raw_context = created.release();
  *out_context = reinterpret_cast<iree_rt_context_t*>(raw_context);
  return IREE_STATUS_OK;
}
// Adds a reference to |context|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_context_retain(iree_rt_context_t* context) {
  IREE_TRACE_SCOPE0("iree_rt_context_retain");
  if (Context* target = reinterpret_cast<Context*>(context)) {
    target->AddReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Drops a reference from |context|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_context_release(iree_rt_context_t* context) {
  IREE_TRACE_SCOPE0("iree_rt_context_release");
  if (Context* target = reinterpret_cast<Context*>(context)) {
    target->ReleaseReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Returns the id of |context|. CHECK-fails on a NULL handle.
IREE_API_EXPORT int32_t IREE_API_CALL
iree_rt_context_id(const iree_rt_context_t* context) {
  IREE_TRACE_SCOPE0("iree_rt_context_id");
  const Context* target = reinterpret_cast<const Context*>(context);
  CHECK(target) << "NULL context handle";
  const auto context_id = target->id();
  return context_id;
}
// Returns the module registered in |context| under |module_name|, or NULL if
// the lookup fails. No additional reference is taken on the returned module.
// CHECK-fails on a NULL context handle.
IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL
iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context,
                                      iree_string_view_t module_name) {
  IREE_TRACE_SCOPE0("iree_rt_context_lookup_module_by_name");
  const Context* target = reinterpret_cast<const Context*>(context);
  CHECK(target) << "NULL context handle";
  const absl::string_view wanted{module_name.data, module_name.size};
  auto lookup = target->LookupModuleByName(wanted);
  return lookup.ok()
             ? reinterpret_cast<iree_rt_module_t*>(lookup.ValueOrDie())
             : nullptr;
}
// Resolves a fully-qualified "module.function" name against |context|.
// The text after the last '.' is the function name and everything before it
// is the module name. Only exported functions are searched. Returns
// IREE_STATUS_INVALID_ARGUMENT when there is no '.' and IREE_STATUS_NOT_FOUND
// when the module is not found.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function(
    const iree_rt_context_t* context, iree_string_view_t full_name,
    iree_rt_function_t* out_function) {
  IREE_TRACE_SCOPE0("iree_rt_context_resolve_function");
  if (out_function == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  std::memset(out_function, 0, sizeof(*out_function));
  if (context == nullptr) return IREE_STATUS_INVALID_ARGUMENT;

  const absl::string_view qualified_name{full_name.data, full_name.size};
  const size_t split_at = qualified_name.rfind('.');
  if (split_at == absl::string_view::npos) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  const absl::string_view module_part = qualified_name.substr(0, split_at);
  const absl::string_view function_part = qualified_name.substr(split_at + 1);

  iree_rt_module_t* owning_module = iree_rt_context_lookup_module_by_name(
      context, iree_string_view_t{module_part.data(), module_part.size()});
  if (owning_module == nullptr) return IREE_STATUS_NOT_FOUND;

  return iree_rt_module_lookup_function_by_name(
      owning_module, IREE_RT_FUNCTION_LINKAGE_EXPORT,
      iree_string_view_t{function_part.data(), function_part.size()},
      out_function);
}
//===----------------------------------------------------------------------===//
// iree::rt::Invocation
//===----------------------------------------------------------------------===//
// Creates an Invocation of |function| within |context| and returns one
// reference to it via |out_invocation| (which is cleared first).
//
// |dependencies| is optional. When present, its |invocation_count| entries are
// retained and passed to the invocation. Every entry of |arguments| must be
// non-NULL. NULL entries in |results| are allowed and leave the corresponding
// result slot default-constructed. |allocator| is not used here.
//
// Returns IREE_STATUS_INVALID_ARGUMENT on NULL required handles or on
// count/pointer mismatches. Otherwise it returns the status of
// Invocation::Create.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create(
    iree_rt_context_t* context, iree_rt_function_t* function,
    iree_rt_policy_t* policy,
    const iree_rt_invocation_dependencies_t* dependencies,
    const iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count,
    const iree_hal_buffer_view_t** results, iree_host_size_t result_count,
    iree_allocator_t allocator, iree_rt_invocation_t** out_invocation) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_create");
  if (!out_invocation) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_invocation = nullptr;
  if (!context || !function || !function->module) {
    return IREE_STATUS_INVALID_ARGUMENT;
  } else if (dependencies &&
             (dependencies->invocation_count && !dependencies->invocations)) {
    return IREE_STATUS_INVALID_ARGUMENT;
  } else if ((argument_count && !arguments) || (result_count && !results)) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }

  // Loop indices use iree_host_size_t to match the counts they are compared
  // against. A signed int would mix signedness and truncate large counts.

  // TODO(benvanik): unwrap without needing to retain here.
  absl::InlinedVector<ref_ptr<Invocation>, 4> dependent_invocations;
  if (dependencies) {
    dependent_invocations.resize(dependencies->invocation_count);
    for (iree_host_size_t i = 0; i < dependencies->invocation_count; ++i) {
      dependent_invocations[i] =
          add_ref(reinterpret_cast<Invocation*>(dependencies->invocations[i]));
    }
  }

  // TODO(benvanik): unwrap without needing to retain here.
  absl::InlinedVector<hal::BufferView, 8> argument_views(argument_count);
  for (iree_host_size_t i = 0; i < argument_count; ++i) {
    const auto* buffer_view =
        reinterpret_cast<const hal::BufferView*>(arguments[i]);
    if (!buffer_view) {
      // Unlike results, every argument slot must be populated.
      return IREE_STATUS_INVALID_ARGUMENT;
    }
    argument_views[i] =
        hal::BufferView{add_ref(buffer_view->buffer), buffer_view->shape,
                        buffer_view->element_size};
  }

  // TODO(benvanik): unwrap without needing to retain here.
  absl::InlinedVector<hal::BufferView, 8> result_views(result_count);
  for (iree_host_size_t i = 0; i < result_count; ++i) {
    const auto* buffer_view =
        reinterpret_cast<const hal::BufferView*>(results[i]);
    if (buffer_view) {
      result_views[i] =
          hal::BufferView{add_ref(buffer_view->buffer), buffer_view->shape,
                          buffer_view->element_size};
    }
  }

  auto invocation_or = Invocation::Create(
      add_ref(reinterpret_cast<Context*>(context)),
      Function{reinterpret_cast<Module*>(function->module),
               static_cast<Function::Linkage>(function->linkage),
               function->ordinal},
      add_ref(reinterpret_cast<Policy*>(policy)),
      std::move(dependent_invocations), std::move(argument_views),
      std::move(result_views));
  IREE_API_RETURN_IF_ERROR(invocation_or.status());
  auto invocation = std::move(invocation_or).ValueOrDie();
  *out_invocation =
      reinterpret_cast<iree_rt_invocation_t*>(invocation.release());
  return IREE_STATUS_OK;
}
// Adds a reference to |invocation|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_invocation_retain(iree_rt_invocation_t* invocation) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_retain");
  if (Invocation* target = reinterpret_cast<Invocation*>(invocation)) {
    target->AddReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Drops a reference from |invocation|. A NULL handle yields
// IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_invocation_release(iree_rt_invocation_t* invocation) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_release");
  if (Invocation* target = reinterpret_cast<Invocation*>(invocation)) {
    target->ReleaseReference();
    return IREE_STATUS_OK;
  }
  return IREE_STATUS_INVALID_ARGUMENT;
}
// Reports the current status of |invocation| as an API status code.
// A NULL handle yields IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_invocation_query_status(iree_rt_invocation_t* invocation) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_query_status");
  Invocation* target = reinterpret_cast<Invocation*>(invocation);
  if (target == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  IREE_API_RETURN_IF_ERROR(target->QueryStatus());
  return IREE_STATUS_OK;
}
// Moves the results of |invocation| out into newly-created buffer views.
//
// |out_result_count| receives the result count declared by the invoked
// function's signature. Behavior depends on |out_results|:
// - NULL: only the count is queried.
// - Non-NULL: the array must hold |result_capacity| pointers. It is zeroed up
//   front, and on success receives one buffer view per consumed result, which
//   the caller owns.
//
// If creating any view fails, the views already created are released and
// |out_results| is zeroed again before the error is returned.
//
// Returns IREE_STATUS_OUT_OF_RANGE when |result_capacity| cannot hold the
// results.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results(
    iree_rt_invocation_t* invocation, iree_host_size_t result_capacity,
    iree_allocator_t allocator, iree_hal_buffer_view_t** out_results,
    iree_host_size_t* out_result_count) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_consume_results");
  if (!out_result_count) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  *out_result_count = 0;
  // Clear the caller's array (when provided) before anything can fail.
  if (out_results) {
    std::memset(out_results, 0,
                sizeof(iree_hal_buffer_view_t*) * result_capacity);
  }
  auto* handle = reinterpret_cast<Invocation*>(invocation);
  if (!handle) {
    return IREE_STATUS_INVALID_ARGUMENT;
  }
  const auto& function = handle->function();
  const int32_t result_count = function.signature().result_count();
  *out_result_count = result_count;
  if (!out_results) {
    return IREE_STATUS_OK;
  } else if (result_count < 0 ||
             result_capacity < static_cast<iree_host_size_t>(result_count)) {
    return IREE_STATUS_OUT_OF_RANGE;
  }

  auto results_or = handle->ConsumeResults();
  IREE_API_RETURN_IF_ERROR(results_or.status());
  auto results = std::move(results_or).ValueOrDie();
  // Never write past the caller's array, even if more results were produced
  // than the signature check above accounted for.
  if (results.size() > result_capacity) {
    return IREE_STATUS_OUT_OF_RANGE;
  }

  iree_status_t status = IREE_STATUS_OK;
  iree_host_size_t i = 0;
  for (; i < results.size(); ++i) {
    iree_shape_t shape;
    status = ToApiShape(results[i].shape, &shape);
    if (status != IREE_STATUS_OK) break;
    status = iree_hal_buffer_view_create(
        reinterpret_cast<iree_hal_buffer_t*>(results[i].buffer.get()), shape,
        results[i].element_size, allocator, &out_results[i]);
    if (status != IREE_STATUS_OK) break;
  }
  if (status != IREE_STATUS_OK) {
    // Release the buffer views created before the failing index.
    for (iree_host_size_t j = 0; j < i; ++j) {
      iree_hal_buffer_view_release(out_results[j]);
    }
    std::memset(out_results, 0,
                sizeof(iree_hal_buffer_view_t*) * result_capacity);
  }
  return status;
}
// Blocks until |invocation| completes or |deadline| passes, reporting the
// outcome as an API status. A NULL handle yields IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await(
    iree_rt_invocation_t* invocation, iree_time_t deadline) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_await");
  Invocation* target = reinterpret_cast<Invocation*>(invocation);
  if (target == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  const auto absl_deadline = ToAbslTime(deadline);
  IREE_API_RETURN_IF_ERROR(target->Await(absl_deadline));
  return IREE_STATUS_OK;
}
// Requests that |invocation| abort, reporting the outcome as an API status.
// A NULL handle yields IREE_STATUS_INVALID_ARGUMENT.
IREE_API_EXPORT iree_status_t IREE_API_CALL
iree_rt_invocation_abort(iree_rt_invocation_t* invocation) {
  IREE_TRACE_SCOPE0("iree_rt_invocation_abort");
  Invocation* target = reinterpret_cast<Invocation*>(invocation);
  if (target == nullptr) return IREE_STATUS_INVALID_ARGUMENT;
  IREE_API_RETURN_IF_ERROR(target->Abort());
  return IREE_STATUS_OK;
}
} // namespace rt
} // namespace iree