blob: a05ac2719fa7e520d67caf85d78dd55cfafa1c06 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "samples/custom_modules/module.h"

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "iree/base/api.h"
#include "iree/base/status_cc.h"
#include "iree/hal/api.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/native_module_cc.h"
#include "iree/vm/ref_cc.h"
//===----------------------------------------------------------------------===//
// !custom.message type
//===----------------------------------------------------------------------===//
// Runtime descriptor for the !custom.message ref type. It starts out zeroed
// and is populated (name, ref-counter offset, destroy callback) and registered
// by iree_custom_native_module_register_types(). The type ID is allocated at
// runtime and does not need to match the compiler ID.
static iree_vm_ref_type_descriptor_t iree_custom_message_descriptor = {};
// The "message" type we use to store string messages to print.
// This could be arbitrarily complex or simply wrap another user-defined type.
// The descriptor that is registered at startup defines how to manage the
// lifetime of the type (such as which destruction function is called, if any).
// See ref.h for more information and additional utilities.
//
// Instances are heap-allocated by iree_custom_message_create() or
// iree_custom_message_wrap() and freed by iree_custom_message_destroy().
// Field layout matters: the registered descriptor records the offset of
// ref_object.counter, and iree_custom_message_create() stores the string bytes
// directly after the struct.
typedef struct iree_custom_message_t {
// Ideally first; used to track the reference count of the object.
// The counter is initialized to 1 when a message is created or wrapped.
iree_vm_ref_object_t ref_object;
// Allocator the message was created from.
// Ideally pools and nested allocators would be used to avoid needing to store
// the allocator with every object.
// iree_custom_message_destroy() frees the message through this allocator.
iree_allocator_t allocator;
// String message value.
// Points either at the bytes stored right after this struct (create) or at
// caller-owned memory (wrap). Not guaranteed to be NUL-terminated; always use
// value.size.
iree_string_view_t value;
} iree_custom_message_t;
// NOTE(review): presumably generates the ref-type adapter helpers that let
// iree_custom_message_t be used with vm::ref<> and the VM calling convention;
// confirm against the macro definition in the IREE VM headers.
IREE_VM_DEFINE_TYPE_ADAPTERS(iree_custom_message, iree_custom_message_t);
// Creates a message that owns a private copy of |value|.
//
// The struct and the string bytes share one allocation made from |allocator|:
// the characters are stored immediately after the struct and message->value
// points at them. The copy is not NUL-terminated.
//
// On success |out_message| receives the new message holding one reference.
// On failure |out_message| is set to NULL and either the allocator's error or
// OUT_OF_RANGE (when header + payload size is not representable) is returned.
iree_status_t iree_custom_message_create(iree_string_view_t value,
                                         iree_allocator_t allocator,
                                         iree_custom_message_t** out_message) {
  IREE_ASSERT_ARGUMENT(out_message);
  // Never hand back an indeterminate pointer if we bail out early.
  *out_message = NULL;

  // Reject sizes where header + payload would wrap and under-allocate.
  if (value.size > SIZE_MAX - sizeof(iree_custom_message_t)) {
    return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
  }

  // Note that we allocate the message and the string value together.
  iree_custom_message_t* message = NULL;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      allocator, sizeof(*message) + value.size,
      reinterpret_cast<void**>(&message)));
  message->ref_object.counter = IREE_ATOMIC_VAR_INIT(1);
  message->allocator = allocator;

  // The payload lives in the trailing bytes of the same allocation.
  char* storage =
      reinterpret_cast<char*>(message) + sizeof(iree_custom_message_t);
  message->value.data = storage;
  message->value.size = value.size;
  // An empty view may carry a NULL data pointer, which must not be passed to
  // memcpy even for a zero-length copy.
  if (value.size > 0) {
    memcpy(storage, value.data, value.size);
  }

  *out_message = message;
  return iree_ok_status();
}
// Creates a message that references |value| without copying it.
//
// Only the message struct itself is allocated from |allocator|; the string
// bytes remain owned by the caller, who must keep them alive and unchanged
// for as long as the message can be referenced.
//
// On success |out_message| receives the new message holding one reference.
// On failure |out_message| is set to NULL and the allocator's error is
// returned.
iree_status_t iree_custom_message_wrap(iree_string_view_t value,
                                       iree_allocator_t allocator,
                                       iree_custom_message_t** out_message) {
  IREE_ASSERT_ARGUMENT(out_message);
  // Never hand back an indeterminate pointer if allocation fails.
  *out_message = NULL;

  iree_custom_message_t* message = NULL;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      allocator, sizeof(*message), reinterpret_cast<void**>(&message)));
  message->ref_object.counter = IREE_ATOMIC_VAR_INIT(1);
  message->allocator = allocator;
  message->value = value;  // Unowned.

  *out_message = message;
  return iree_ok_status();
}
void iree_custom_message_destroy(void* ptr) {
iree_custom_message_t* message = (iree_custom_message_t*)ptr;
iree_allocator_free(message->allocator, ptr);
}
// Copies the message text into |buffer| as a NUL-terminated C string.
//
// |buffer_capacity| must be at least the message length plus one byte for the
// terminator. If it is smaller, nothing is written and OUT_OF_RANGE is
// returned; callers may treat that as a size query rather than a failure.
iree_status_t iree_custom_message_read_value(iree_custom_message_t* message,
                                             char* buffer,
                                             size_t buffer_capacity) {
  IREE_ASSERT_ARGUMENT(message);
  IREE_ASSERT_ARGUMENT(buffer);

  const iree_string_view_t& text = message->value;
  const auto required_capacity = text.size + 1;  // Payload + terminator.
  if (required_capacity > buffer_capacity) {
    // Not an error; just a size query.
    return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE);
  }

  memcpy(buffer, text.data, text.size);
  buffer[text.size] = 0;
  return iree_ok_status();
}
// Registers the !custom.message ref type with the VM.
//
// A descriptor whose type field is already non-zero is treated as registered
// and the call returns OK without touching it. Otherwise the descriptor is
// filled in and handed to iree_vm_ref_register_type(), whose status is
// returned.
iree_status_t iree_custom_native_module_register_types() {
  iree_vm_ref_type_descriptor_t& descriptor = iree_custom_message_descriptor;
  if (!descriptor.type) {
    descriptor.destroy = iree_custom_message_destroy;
    descriptor.offsetof_counter =
        offsetof(iree_custom_message_t, ref_object.counter);
    descriptor.type_name = iree_make_cstring_view("custom.message");
    return iree_vm_ref_register_type(&descriptor);
  }
  return iree_ok_status();  // Already registered.
}
//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//
namespace iree {
namespace samples {
namespace {
// Per-context module state.
// This can contain "globals" and other arbitrary state.
//
// Thread-compatible; the runtime will not issue multiple calls at the same
// time using the same state. If the implementation uses external threads then
// it must synchronize itself.
class CustomModuleState final {
 public:
  // Stores the allocators used by the exported functions. |device_allocator|
  // is moved into the state.
  explicit CustomModuleState(iree_allocator_t host_allocator,
                             vm::ref<iree_hal_allocator_t> device_allocator)
      : host_allocator_(host_allocator),
        device_allocator_(std::move(device_allocator)) {}
  ~CustomModuleState() = default;

  // Builds the per-context unique message "ctx_<unique_id>".
  // Returns the message creation error on failure.
  Status Initialize(int32_t unique_id) {
    // Allocate a unique ID to demonstrate per-context state.
    auto str_buffer = "ctx_" + std::to_string(unique_id);
    // iree_custom_message_create() copies the bytes, so the temporary string
    // may go out of scope afterwards.
    IREE_RETURN_IF_ERROR(
        iree_custom_message_create(iree_make_cstring_view(str_buffer.c_str()),
                                   host_allocator_, &unique_message_));
    return OkStatus();
  }

  // custom.buffer_to_message(%buffer_view) -> %result
  //
  // Formats |buffer_view| as text and returns it in a new message allocated
  // from the host allocator.
  StatusOr<vm::ref<iree_custom_message_t>> BufferToMessage(
      vm::ref<iree_hal_buffer_view_t> buffer_view) {
    // Convert the buffer view to a [shape]x[type]=[contents] string.
    // Start with a 4096 character buffer; whenever the formatter reports
    // OUT_OF_RANGE, resize to the length it reported and try again.
    std::string string_value(4096, '\0');
    iree_status_t status;
    do {
      iree_host_size_t actual_length = 0;
      // The +1 presumably covers the NUL terminator slot that std::string
      // keeps just past size().
      status = iree_hal_buffer_view_format(
          buffer_view.get(), /*max_element_count=*/1024,
          string_value.size() + 1, &string_value[0], &actual_length);
      string_value.resize(actual_length);
    } while (iree_status_is_out_of_range(status));
    IREE_RETURN_IF_ERROR(std::move(status));

    // Pack the string contents into a message.
    vm::ref<iree_custom_message_t> message;
    IREE_RETURN_IF_ERROR(iree_custom_message_create(
        iree_string_view_t{string_value.data(), string_value.size()},
        host_allocator_, &message));
    return std::move(message);
  }

  // custom.message_to_buffer(%message) -> %buffer_view
  //
  // Parses the message text into a buffer view allocated from the device
  // allocator. Parse failures are annotated with the offending text.
  StatusOr<vm::ref<iree_hal_buffer_view_t>> MessageToBuffer(
      vm::ref<iree_custom_message_t> message) {
    // Convert the [shape]x[type]=[contents] string to a buffer view.
    vm::ref<iree_hal_buffer_view_t> buffer_view;
    IREE_RETURN_IF_ERROR(
        iree_hal_buffer_view_parse(message->value, device_allocator_.get(),
                                   &buffer_view),
        "parsing value '%.*s'", (int)message->value.size, message->value.data);
    return std::move(buffer_view);
  }

  // custom.print(%message, %count)
  //
  // Writes the message text followed by a newline to stdout |count| times and
  // then flushes stdout. A non-positive |count| prints nothing.
  Status Print(vm::ref<iree_custom_message_t> message, int32_t count) {
    for (int i = 0; i < count; ++i) {
      // The value is not NUL-terminated, so write exactly value.size bytes.
      fwrite(message->value.data, 1, message->value.size, stdout);
      fputc('\n', stdout);
    }
    fflush(stdout);
    return OkStatus();
  }

  // custom.reverse(%message) -> %result
  //
  // Returns a new message containing the bytes of |message| in reverse order.
  // The input is left untouched; the result is allocated from the input
  // message's own allocator.
  StatusOr<vm::ref<iree_custom_message_t>> Reverse(
      vm::ref<iree_custom_message_t> message) {
    vm::ref<iree_custom_message_t> reversed_message;
    IREE_RETURN_IF_ERROR(iree_custom_message_create(
        message->value, message->allocator, &reversed_message));
    // iree_custom_message_create() gave us a private copy stored in the same
    // allocation as the message, so mutating it in place is safe.
    char* str_ptr = const_cast<char*>(reversed_message->value.data);
    // Pointer-based reversal works for any size, including empty strings and
    // lengths that do not fit in an int.
    std::reverse(str_ptr, str_ptr + reversed_message->value.size);
    return std::move(reversed_message);
  }

  // custom.get_unique_message() -> %result
  //
  // Returns an additional reference to the message built in Initialize().
  StatusOr<vm::ref<iree_custom_message_t>> GetUniqueMessage() {
    return vm::retain_ref(unique_message_);
  }

 private:
  // Allocator that the caller requested we use for any allocations we need to
  // perform during operation.
  iree_allocator_t host_allocator_ = iree_allocator_system();
  // HAL buffer allocator that uses host-local memory. This is just for this
  // test as we don't actually use a HAL device and don't have a real device
  // allocator to use.
  vm::ref<iree_hal_allocator_t> device_allocator_;
  // A unique message owned by the state and returned to the VM.
  // This demonstrates any arbitrary per-context state one may want to store.
  vm::ref<iree_custom_message_t> unique_message_;
};
// Function table mapping imported function names to their implementation.
// The signature of the target function is expected to match that in the
// custom.imports.mlir file.
//
// Each entry pairs the name of an exported function with the
// CustomModuleState method that implements it. The table is handed to the
// CustomModule constructor in iree_custom_native_module_create().
// NOTE(review): entry order may be significant to the native module
// machinery; confirm before reordering.
static const vm::NativeFunction<CustomModuleState> kCustomModuleFunctions[] = {
// custom.buffer_to_message(%buffer_view) -> %result
vm::MakeNativeFunction("buffer_to_message",
&CustomModuleState::BufferToMessage),
// custom.message_to_buffer(%message) -> %buffer_view
vm::MakeNativeFunction("message_to_buffer",
&CustomModuleState::MessageToBuffer),
// custom.print(%message, %count)
vm::MakeNativeFunction("print", &CustomModuleState::Print),
// custom.reverse(%message) -> %result
vm::MakeNativeFunction("reverse", &CustomModuleState::Reverse),
// custom.get_unique_message() -> %result
vm::MakeNativeFunction("get_unique_message",
&CustomModuleState::GetUniqueMessage),
};
// The module instance that will be allocated and reused across contexts.
// Any context-specific state must be stored in a state structure such as
// CustomModuleState above.
//
// Assumed thread-safe: after Initialize() the only state that changes is the
// atomic ID counter. If more state is stored here it will need to be
// synchronized by the implementation.
class CustomModule final : public vm::NativeModule<CustomModuleState> {
 public:
  using vm::NativeModule<CustomModuleState>::NativeModule;

  // Example of global initialization (shared across all contexts), such as
  // loading native libraries or creating shared pools.
  //
  // Retains |device_allocator| for later use by per-context states and resets
  // the unique ID counter to zero.
  Status Initialize(iree_hal_allocator_t* device_allocator) {
    device_allocator_ = vm::retain_ref(device_allocator);
    next_unique_id_ = 0;
    return OkStatus();
  }

  // Creates per-context state when the module is added to a new context.
  // May be called from any thread.
  //
  // Each state receives its own reference to the device allocator and the
  // next value of the unique ID counter. Returns the state's initialization
  // error on failure, in which case the partially built state is destroyed.
  StatusOr<std::unique_ptr<CustomModuleState>> CreateState(
      iree_allocator_t host_allocator) override {
    auto state = std::make_unique<CustomModuleState>(
        host_allocator, vm::retain_ref(device_allocator_));
    // Post-increment on std::atomic is a single atomic read-modify-write, so
    // concurrent callers never observe the same ID.
    IREE_RETURN_IF_ERROR(state->Initialize(next_unique_id_++));
    return state;
  }

 private:
  // Used for any HAL buffer allocations we need to make.
  vm::ref<iree_hal_allocator_t> device_allocator_;
  // The next ID to allocate for states that will be used to form the
  // per-context unique message. This shows state at the module level. Note that
  // this must be thread-safe.
  // Initialized here so the counter has a defined value even if CreateState()
  // runs before Initialize().
  std::atomic<int32_t> next_unique_id_{0};
};
} // namespace
// C entry point that builds the "custom" native module.
//
// The implementation is C++ internally, but only the iree_vm_module_t
// interface is handed out so callers never see those details.
//
// |out_module| is cleared up front and only assigned once the module has been
// initialized successfully; if initialization fails the module is destroyed
// and the error status is returned.
extern "C" iree_status_t iree_custom_native_module_create(
    iree_allocator_t host_allocator, iree_hal_allocator_t* device_allocator,
    iree_vm_module_t** out_module) {
  IREE_ASSERT_ARGUMENT(out_module);
  *out_module = nullptr;

  using FunctionTable =
      iree::span<const vm::NativeFunction<CustomModuleState>>;
  std::unique_ptr<CustomModule> custom_module =
      std::make_unique<CustomModule>("custom", host_allocator,
                                     FunctionTable(kCustomModuleFunctions));
  IREE_RETURN_IF_ERROR(custom_module->Initialize(device_allocator));

  // Ownership moves from the unique_ptr to the returned module interface.
  CustomModule* released_module = custom_module.release();
  *out_module = released_module->interface();
  return iree_ok_status();
}
} // namespace samples
} // namespace iree