blob: a277b3fda0e05207d8aae95040d810c3a6009b63 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bindings/java/com/google/iree/native/context_wrapper.h"
#include <vector>
#include "iree/base/api.h"
#include "iree/base/logging.h"
#include "iree/vm/ref_cc.h"
namespace iree {
namespace java {
namespace {
// Returns the raw VM module pointer held by each wrapper, in the same order as
// `module_wrappers`. The returned pointers are not retained here; presumably
// they stay owned by the wrappers — confirm against ModuleWrapper.
std::vector<iree_vm_module_t*> GetModulesFromModuleWrappers(
    const std::vector<ModuleWrapper*>& module_wrappers) {
  std::vector<iree_vm_module_t*> modules;
  modules.reserve(module_wrappers.size());
  // Range-for avoids the signed/unsigned index comparison of an `int` loop
  // counter against `size()`.
  for (ModuleWrapper* module_wrapper : module_wrappers) {
    modules.push_back(module_wrapper->module());
  }
  return modules;
}
} // namespace
// Creates an empty VM context for `instance_wrapper`, then creates the default
// modules (currently just the HAL module) and registers them with it.
// Returns the first failing status, if any.
Status ContextWrapper::Create(const InstanceWrapper& instance_wrapper) {
  const iree_allocator_t host_allocator = iree_allocator_system();
  IREE_RETURN_IF_ERROR(iree_vm_context_create(instance_wrapper.instance(),
                                              host_allocator, &context_));
  IREE_RETURN_IF_ERROR(CreateDefaultModules());

  // The set of built-in modules is fixed, so a plain array is sufficient.
  iree_vm_module_t* builtin_modules[] = {hal_module_};
  const auto builtin_module_count =
      sizeof(builtin_modules) / sizeof(builtin_modules[0]);
  IREE_RETURN_IF_ERROR(iree_vm_context_register_modules(
      context_, builtin_modules, builtin_module_count));
  return OkStatus();
}
// Creates the default modules and then a VM context holding them followed by
// the modules wrapped in `module_wrappers`, in their given order.
// Returns the first failing status, if any.
Status ContextWrapper::CreateWithModules(
    const InstanceWrapper& instance_wrapper,
    const std::vector<ModuleWrapper*>& module_wrappers) {
  std::vector<iree_vm_module_t*> user_modules =
      GetModulesFromModuleWrappers(module_wrappers);
  IREE_RETURN_IF_ERROR(CreateDefaultModules());

  // Module order is significant: the default (HAL) module has to come first,
  // ahead of every caller-supplied module.
  std::vector<iree_vm_module_t*> ordered_modules;
  ordered_modules.reserve(user_modules.size() + 1);
  ordered_modules.push_back(hal_module_);
  ordered_modules.insert(ordered_modules.end(), user_modules.begin(),
                         user_modules.end());

  IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules(
      instance_wrapper.instance(), ordered_modules.data(),
      ordered_modules.size(), iree_allocator_system(), &context_));
  return OkStatus();
}
// Registers the modules held by `module_wrappers` with the existing context,
// in the order given.
Status ContextWrapper::RegisterModules(
    const std::vector<ModuleWrapper*>& module_wrappers) {
  std::vector<iree_vm_module_t*> extra_modules =
      GetModulesFromModuleWrappers(module_wrappers);
  iree_status_t register_status = iree_vm_context_register_modules(
      context_, extra_modules.data(), extra_modules.size());
  IREE_RETURN_IF_ERROR(register_status);
  return OkStatus();
}
// Looks up the function called `name` in this context and writes the result
// into the storage exposed by `function_wrapper->function()`.
// Returns the resolution status unchanged.
Status ContextWrapper::ResolveFunction(iree_string_view_t name,
                                       FunctionWrapper* function_wrapper) {
  auto* resolved_function = function_wrapper->function();
  return iree_vm_context_resolve_function(context_, name, resolved_function);
}
// Invokes `function_wrapper` synchronously on this context.
//
// Each pointer in `inputs` is copied into its own host-local, device-visible
// HAL buffer of `input_element_count` floats. That buffer is passed to the
// function as a rank-1 FLOAT_32 buffer view, and inputs keep their given order.
// After the call, the first result is fetched as a buffer view and its entire
// mapped contents are copied into `output`.
//
// NOTE(review): the copy into `output` is `data_length` bytes of the first
// result buffer, with no bound check against the caller's allocation. The
// caller must size `output` for the full result; confirm with the Java side.
//
// Returns an error status if any allocation, write, view creation, invocation
// or mapping fails, or if no buffer view can be read from result slot 0.
Status ContextWrapper::InvokeFunction(const FunctionWrapper& function_wrapper,
                                      const std::vector<float*>& inputs,
                                      int input_element_count, float* output) {
  // The input list holds one buffer view per input, so its capacity is the
  // number of inputs, not the per-input element count.
  vm::ref<iree_vm_list_t> input_list;
  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
                                           inputs.size(),
                                           iree_allocator_system(),
                                           &input_list));
  iree_hal_allocator_t* allocator = iree_hal_device_allocator(device_);
  iree_hal_memory_type_t input_memory_type =
      static_cast<iree_hal_memory_type_t>(IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
                                          IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
  iree_hal_buffer_usage_t input_buffer_usage =
      static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_ALL |
                                           IREE_HAL_BUFFER_USAGE_CONSTANT);
  for (auto input : inputs) {
    // Write the input into a mappable buffer.
    iree_hal_buffer_t* input_buffer = nullptr;
    IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer(
        allocator, input_memory_type, input_buffer_usage,
        sizeof(float) * input_element_count, &input_buffer));
    // From here on, our reference to `input_buffer` must be released on every
    // path. The status is collected and checked only after the release so an
    // error does not leak the buffer.
    iree_hal_buffer_view_t* input_buffer_view = nullptr;
    iree_status_t status = iree_hal_buffer_write_data(
        input_buffer, 0, input, input_element_count * sizeof(float));
    if (iree_status_is_ok(status)) {
      // Wrap the input buffer in a rank-1 buffer view.
      status = iree_hal_buffer_view_create(
          input_buffer, /*shape=*/&input_element_count,
          /*shape_rank=*/1, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
          iree_allocator_system(), &input_buffer_view);
    }
    iree_hal_buffer_release(input_buffer);
    if (!iree_status_is_ok(status)) return status;
    // Marshal the input buffer view through the input VM variant list; the
    // view's reference is moved into the list.
    auto input_buffer_view_ref =
        iree_hal_buffer_view_move_ref(input_buffer_view);
    IREE_RETURN_IF_ERROR(
        iree_vm_list_push_ref_move(input_list.get(), &input_buffer_view_ref));
  }
  // Prepare the outputs list to accept results from the invocation.
  // NOTE(review): this argument is a capacity in list elements, not bytes;
  // `4 * sizeof(float)` (= 16) looks like a units mix-up. Kept as-is.
  vm::ref<iree_vm_list_t> outputs;
  IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
                                           4 * sizeof(float),
                                           iree_allocator_system(), &outputs));
  // Synchronously invoke the function.
  IREE_RETURN_IF_ERROR(iree_vm_invoke(context_, *function_wrapper.function(),
                                      /*policy=*/nullptr, input_list.get(),
                                      outputs.get(), iree_allocator_system()));
  // Fetch the first result as a buffer view. The deref presumably yields null
  // when slot 0 is missing or holds another type. Fail cleanly in that case
  // instead of dereferencing a null pointer below.
  auto* output_buffer_view =
      reinterpret_cast<iree_hal_buffer_view_t*>(iree_vm_list_get_ref_deref(
          outputs.get(), 0, iree_hal_buffer_view_get_descriptor()));
  if (output_buffer_view == nullptr) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                            "expected the first function result to be a "
                            "buffer view");
  }
  // Read back the whole result buffer into the given output array.
  auto* output_buffer = iree_hal_buffer_view_buffer(output_buffer_view);
  iree_hal_mapped_memory_t mapped_memory;
  IREE_RETURN_IF_ERROR(iree_hal_buffer_map(output_buffer,
                                           IREE_HAL_MEMORY_ACCESS_READ, 0,
                                           IREE_WHOLE_BUFFER, &mapped_memory));
  memcpy(output, mapped_memory.contents.data,
         mapped_memory.contents.data_length);
  iree_hal_buffer_unmap(output_buffer, &mapped_memory);
  return OkStatus();
}
int ContextWrapper::id() const { return iree_vm_context_id(context_); }
// Releases the context, HAL module, device and driver held by this wrapper.
// Order matters and follows the dependency chain visible in this file:
//  - the context has hal_module_ registered in it;
//  - hal_module_ is created from device_;
//  - device_ is created from driver_.
// So each object is released before the one it was built from.
// NOTE(review): if Create*() was never called or failed part-way, some of these
// members may be null. Presumably the release functions accept null; confirm
// against the member initializers in the header.
ContextWrapper::~ContextWrapper() {
iree_vm_context_release(context_);
iree_vm_module_release(hal_module_);
iree_hal_device_release(device_);
iree_hal_driver_release(driver_);
}
// TODO(jennik): Also create default string and tensorlist modules.
// Builds the default HAL stack for this context in three steps:
//  1. create the "vmla" driver from the default driver registry;
//  2. create that driver's default device;
//  3. create a HAL VM module bound to the device.
// Results are stored in `driver_`, `device_` and `hal_module_`. Returns the
// first failing status, if any.
Status ContextWrapper::CreateDefaultModules() {
  const iree_allocator_t host_allocator = iree_allocator_system();
  const iree_string_view_t driver_name = iree_make_cstring_view("vmla");

  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create_by_name(
      iree_hal_driver_registry_default(), driver_name, host_allocator,
      &driver_));
  IREE_RETURN_IF_ERROR(iree_hal_driver_create_default_device(
      driver_, host_allocator, &device_));
  IREE_RETURN_IF_ERROR(
      iree_hal_module_create(device_, host_allocator, &hal_module_));
  return OkStatus();
}
} // namespace java
} // namespace iree