blob: 6354eaf763a5f593daee3a091b882ecd12a840e5 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "iree/hal/vmla/vmla_module.h"
#include <cstdint>
#include "absl/types/span.h"
#include "iree/base/tracing.h"
#include "iree/hal/vmla/op_kernels.h"
#include "iree/vm/module_abi_packing.h"
//===----------------------------------------------------------------------===//
// Type registration
//===----------------------------------------------------------------------===//
// Static descriptors for the ref types exported by this module; filled in and
// registered by ModuleRegisterTypes() below.
static iree_vm_ref_type_descriptor_t Buffer_descriptor = {0};
static iree_vm_ref_type_descriptor_t Interface_descriptor = {0};
IREE_VM_DEFINE_TYPE_ADAPTERS(Buffer, iree::hal::vmla::Buffer);
IREE_VM_DEFINE_TYPE_ADAPTERS(Interface, iree::hal::vmla::Interface);
// Populates |descriptor| for C++ ref type |type| under the VM name |name| and
// registers it with the VM ref system.
// NOTE: expands to multiple statements (no do/while wrapper) and returns on
// failure; only use inside a Status-returning function at statement scope.
#define IREE_VMLA_REGISTER_CC_TYPE(type, name, descriptor) \
descriptor.type_name = iree_make_cstring_view(name); \
descriptor.offsetof_counter = type::offsetof_counter(); \
descriptor.destroy = type::DirectDestroy; \
IREE_RETURN_IF_ERROR(iree_vm_ref_register_type(&descriptor)) \
<< "Failed to register type " << name;
namespace iree {
namespace hal {
namespace vmla {
// Registers the vmla.buffer and vmla.interface ref types with the VM.
// Idempotent: subsequent calls are no-ops once registration has succeeded.
// NOTE(review): the has_registered flag is an unsynchronized static; this
// assumes registration happens from a single thread — confirm with callers.
Status ModuleRegisterTypes() {
static bool has_registered = false;
if (has_registered) return OkStatus();
IREE_VMLA_REGISTER_CC_TYPE(Buffer, "vmla.buffer", Buffer_descriptor);
IREE_VMLA_REGISTER_CC_TYPE(Interface, "vmla.interface", Interface_descriptor);
has_registered = true;
return OkStatus();
}
//===----------------------------------------------------------------------===//
// API type implementations
//===----------------------------------------------------------------------===//
// static
// Allocates a new Buffer of |byte_length| bytes owned by |allocator|.
// The storage is freed via the same allocator when the Buffer is destroyed.
StatusOr<vm::ref<Buffer>> Buffer::Allocate(size_t byte_length,
                                           iree_allocator_t allocator) {
  // Acquire the storage first so an allocation failure produces no Buffer.
  void* storage = nullptr;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(allocator, byte_length, &storage))
      << "Failed to allocate buffer of size " << byte_length;
  auto result = vm::assign_ref(new Buffer());
  result->data_ = storage;
  result->data_length_ = byte_length;
  result->allocator_ = allocator;
  return std::move(result);
}
// static
// Wraps existing read-only memory without copying. |allocator| is invoked to
// release the wrapped pointer when the Buffer is destroyed.
StatusOr<vm::ref<Buffer>> Buffer::Wrap(const void* data, size_t data_length,
                                       iree_allocator_t allocator) {
  auto wrapped = vm::assign_ref(new Buffer());
  // Constness is dropped for storage; callers must not mutate memory wrapped
  // through this read-only entry point.
  wrapped->data_ = const_cast<void*>(data);
  wrapped->data_length_ = data_length;
  wrapped->allocator_ = allocator;
  return std::move(wrapped);
}
// static
// Wraps existing mutable memory without copying. |allocator| is invoked to
// release the wrapped pointer when the Buffer is destroyed.
StatusOr<vm::ref<Buffer>> Buffer::WrapMutable(void* data, size_t data_length,
                                              iree_allocator_t allocator) {
  auto wrapped = vm::assign_ref(new Buffer());
  wrapped->data_ = data;
  wrapped->data_length_ = data_length;
  wrapped->allocator_ = allocator;
  return std::move(wrapped);
}
Buffer::~Buffer() {
  // Views (buffers with a parent_) do not own their bytes; only free storage
  // when this buffer is the owner.
  const bool owns_data = !parent_;
  if (owns_data) {
    iree_allocator_free(allocator_, data_);
    data_ = nullptr;
  }
  parent_.reset();
}
// Resolves a validated [byte_offset, byte_offset+byte_length) range within
// the buffer and returns a span over the backing storage.
// Passing kVMLAWholeBuffer as |byte_length| selects everything from
// |byte_offset| to the end. Returns OutOfRangeError when the requested range
// exceeds the buffer bounds; zero-length ranges at any valid offset succeed.
StatusOr<absl::Span<uint8_t>> Buffer::MakeRange(
    iree_vmla_size_t byte_offset, iree_vmla_size_t byte_length) const {
  if (byte_length == kVMLAWholeBuffer) {
    byte_length = size() - byte_offset;
  }
  if (byte_offset > size()) {
    return OutOfRangeErrorBuilder(IREE_LOC)
           << "Attempted to access an address off the end of the valid "
              "buffer range (offset="
           << byte_offset << ", length=" << byte_length
           << ", buffer byte_length=" << size() << ")";
  }
  // Compute the exclusive end in 64 bits. The previous inclusive-end
  // arithmetic (offset + length - 1) underflowed for zero-length ranges and
  // could wrap around for very large byte_length values, defeating the
  // bounds check entirely.
  uint64_t end = static_cast<uint64_t>(byte_offset) + byte_length;
  if (end > size()) {
    return OutOfRangeErrorBuilder(IREE_LOC)
           << "Attempted to access an address outside of the valid buffer "
              "range (offset="
           << byte_offset << ", length=" << byte_length << ", end=" << end - 1
           << ", buffer byte_length=" << size() << ")";
  }
  uint8_t* data = reinterpret_cast<uint8_t*>(data_) + byte_offset;
  size_t data_length = byte_length;
  return absl::MakeSpan(data, data_length);
}
// Out-of-line definitions for the class constants; required before C++17
// inline variables when the constants are ODR-used (e.g. bound to a
// reference or streamed into an error builder).
constexpr int Interface::kMaxConstants;
constexpr int Interface::kMaxSets;
constexpr int Interface::kMaxBindings;
// Clears every binding slot back to its default (empty) state.
void Interface::Reset() {
  for (auto& set_bindings : bindings_) {
    for (auto& binding : set_bindings) {
      binding = {};
    }
  }
}
// Returns the push-constant stored at |offset|, or InvalidArgument when the
// offset is outside [0, kMaxConstants).
StatusOr<uint32_t> Interface::GetConstant(uint32_t offset) const {
  if (offset < kMaxConstants) {
    return constants_[offset];
  }
  return InvalidArgumentErrorBuilder(IREE_LOC)
         << "Invalid constant offset=" << offset;
}
// Copies |values| into the interface constant table, failing when more
// values are provided than the table can hold.
Status Interface::SetConstants(absl::Span<const uint32_t> values) {
  if (values.size() > kMaxConstants) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Constant value overflow; have " << values.size()
           << " but max is " << kMaxConstants;
  }
  size_t slot = 0;
  for (uint32_t value : values) {
    constants_[slot++] = value;
  }
  return OkStatus();
}
// Returns the binding stored at |set|/|binding|, or InvalidArgument when
// either index is out of range.
StatusOr<const Interface::Binding> Interface::GetBinding(
    int32_t set, int32_t binding) const {
  // NOTE: the upper-bound checks must be exclusive (>=); the previous >
  // comparisons allowed set == kMaxSets / binding == kMaxBindings, indexing
  // one element past the bindings_ arrays.
  if (set < 0 || set >= kMaxSets || binding < 0 || binding >= kMaxBindings) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Invalid binding set=" << set << ", binding=" << binding;
  }
  return bindings_[set][binding];
}
// Stores |value| at |set|/|binding|, or returns InvalidArgument when either
// index is out of range.
Status Interface::SetBinding(int32_t set, int32_t binding, Binding value) {
  // NOTE: the upper-bound checks must be exclusive (>=); the previous >
  // comparisons allowed set == kMaxSets / binding == kMaxBindings, writing
  // one element past the bindings_ arrays.
  if (set < 0 || set >= kMaxSets || binding < 0 || binding >= kMaxBindings) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Invalid binding set=" << set << ", binding=" << binding;
  }
  bindings_[set][binding] = std::move(value);
  return OkStatus();
}
//===----------------------------------------------------------------------===//
// Module state and method implementation
//===----------------------------------------------------------------------===//
namespace {
// Per-executable VMLA module state.
// This provides the exported kernel functions to the VM and is instantiated
// one or more times per executable used within a device. Any state here can be
// treated as workgroup-local memory.
//
// Thread-compatible.
class VMLAModuleState final {
public:
// |allocator| is used for transient buffer allocations made by ops;
// |kernel_state| is shared scratch state (see kernel_state_ below for the
// synchronization requirements).
VMLAModuleState(iree_allocator_t allocator,
kernels::RuntimeState* kernel_state)
: allocator_(allocator), kernel_state_(kernel_state) {}
~VMLAModuleState() = default;
//===--------------------------------------------------------------------===//
// vmla.interface.*
//===--------------------------------------------------------------------===//
// Reads the push-constant at |offset| from the dispatch interface.
StatusOr<uint32_t> InterfaceConst(vm::ref<Interface> interface,
uint32_t offset) {
IREE_TRACE_SCOPE0("VMLAModuleState::InterfaceConst");
return interface->GetConstant(offset);
}
// Resolves the buffer bound at |set|/|binding| and returns it retained.
StatusOr<vm::ref<Buffer>> InterfaceBinding(vm::ref<Interface> interface,
int32_t set, int32_t binding) {
IREE_TRACE_SCOPE0("VMLAModuleState::InterfaceBinding");
IREE_ASSIGN_OR_RETURN(const auto& value,
interface->GetBinding(set, binding));
return vm::retain_ref(value.buffer);
}
//===--------------------------------------------------------------------===//
// vmla.buffer.*
//===--------------------------------------------------------------------===//
// Wraps a read-only byte buffer from the module without copying. The byte
// buffer is retained via the allocator self pointer and released by the
// free callback when the returned Buffer is destroyed.
StatusOr<vm::ref<Buffer>> BufferConst(
vm::ref<iree_vm_ro_byte_buffer_t> value) {
IREE_TRACE_SCOPE0("VMLAModuleState::BufferConst");
iree_allocator_t external_allocator = {0};
external_allocator.self = vm::retain_ref(value).release();
external_allocator.free = +[](void* self, void* ptr) {
vm::assign_ref(reinterpret_cast<iree_vm_ro_byte_buffer_t*>(self)).reset();
};
return Buffer::Wrap(value->data.data, value->data.data_length,
external_allocator);
}
// Allocates a new buffer of |byte_length| bytes from the module allocator.
StatusOr<vm::ref<Buffer>> BufferAlloc(iree_vmla_size_t byte_length) {
IREE_TRACE_SCOPE0("VMLAModuleState::BufferAlloc");
return Buffer::Allocate(byte_length, allocator_);
}
// Allocates a new buffer and copies the full contents of |src| into it.
StatusOr<vm::ref<Buffer>> BufferClone(vm::ref<Buffer> src) {
IREE_TRACE_SCOPE0("VMLAModuleState::BufferClone");
IREE_ASSIGN_OR_RETURN(auto dst, Buffer::Allocate(src->size(), allocator_));
std::memcpy(dst->data(), src->data(), dst->size());
return std::move(dst);
}
// Returns the total size of |buffer| in bytes.
StatusOr<iree_vmla_size_t> BufferByteLength(vm::ref<Buffer> buffer) {
IREE_TRACE_SCOPE0("VMLAModuleState::BufferByteLength");
return buffer->size();
}
// Returns a view into |src| covering [byte_offset, byte_offset+byte_length).
// Passing kVMLAWholeBuffer as |byte_length| selects the remainder of the
// buffer. The view retains |src| so the underlying storage stays live for
// the lifetime of the returned buffer.
StatusOr<vm::ref<Buffer>> BufferView(vm::ref<Buffer> src,
                                     iree_vmla_size_t byte_offset,
                                     iree_vmla_size_t byte_length) {
  IREE_TRACE_SCOPE0("VMLAModuleState::BufferView");
  if (byte_length == kVMLAWholeBuffer) {
    byte_length = src->size() - byte_offset;
  }
  if (byte_offset == 0 && byte_length == src->size()) {
    // Asking for the same buffer.
    return vm::retain_ref(src);
  } else if (byte_offset > src->size()) {
    return OutOfRangeErrorBuilder(IREE_LOC)
           << "Attempted to access an address off the end of the valid "
              "buffer range (offset="
           << byte_offset << ", length=" << byte_length
           << ", buffer byte_length=" << src->size() << ")";
  }
  // Compute the exclusive end in 64 bits. The previous inclusive-end
  // arithmetic (offset + length - 1) underflowed for zero-length views and
  // could wrap around for very large byte_length values, defeating the
  // bounds check entirely.
  uint64_t end = static_cast<uint64_t>(byte_offset) + byte_length;
  if (end > src->size()) {
    return OutOfRangeErrorBuilder(IREE_LOC)
           << "Attempted to access an address outside of the valid buffer "
              "range (offset="
           << byte_offset << ", length=" << byte_length << ", end=" << end - 1
           << ", buffer byte_length=" << src->size() << ")";
  }
  uint8_t* data = reinterpret_cast<uint8_t*>(src->data()) + byte_offset;
  size_t data_length = byte_length;
  // Keep |src| alive for as long as the view exists by stashing a retained
  // reference in the allocator self pointer; the free callback drops it.
  iree_allocator_t external_allocator = {0};
  external_allocator.self = vm::retain_ref(src).release();
  external_allocator.free = +[](void* self, void* ptr) {
    vm::assign_ref(reinterpret_cast<Buffer*>(self)).reset();
  };
  return Buffer::Wrap(data, data_length, external_allocator);
}
// Copies |byte_length| bytes from |src| at |src_byte_offset| into |dst| at
// |dst_byte_offset|. kVMLAWholeBuffer selects everything from the source
// offset onward. Both ranges are bounds-checked by RangeAs.
Status BufferCopy(vm::ref<Buffer> src, iree_vmla_size_t src_byte_offset,
                  vm::ref<Buffer> dst, iree_vmla_size_t dst_byte_offset,
                  iree_vmla_size_t byte_length) {
  IREE_TRACE_SCOPE0("VMLAModuleState::BufferCopy");
  if (byte_length == kVMLAWholeBuffer) {
    byte_length = src->size() - src_byte_offset;
  }
  IREE_ASSIGN_OR_RETURN(auto source_range, src->RangeAs<const uint8_t>(
                                               src_byte_offset, byte_length));
  IREE_ASSIGN_OR_RETURN(auto target_range,
                        dst->RangeAs<uint8_t>(dst_byte_offset, byte_length));
  std::memcpy(target_range.data(), source_range.data(), target_range.size());
  return OkStatus();
}
// Fills |dst| by repeating the contents of |value|.
// |value| must be non-empty and its length must evenly divide |dst|'s
// length; a single-byte value takes a memset fast path.
Status BufferFill(vm::ref<Buffer> value, vm::ref<Buffer> dst) {
  IREE_TRACE_SCOPE0("VMLAModuleState::BufferFill");
  if (value->size() == 0) {
    // Guard the modulo and the copy loop below against a zero-length fill
    // pattern (division by zero / infinite loop).
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Fill value must not be empty";
  }
  if (value->size() == 1) {
    // Fast-path for single-byte memset values.
    std::memset(dst->data(), value->As<uint8_t>()[0], dst->size());
    return OkStatus();
  } else if (dst->size() % value->size() != 0) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "Fill value length (" << value->size()
           << ") must divide evenly into buffer length (" << dst->size()
           << ")";
  }
  auto value_bytes = value->As<uint8_t>();
  auto dst_bytes = dst->As<uint8_t>();
  for (size_t i = 0; i < dst_bytes.size(); i += value_bytes.size()) {
    std::memcpy(dst_bytes.data() + i, value_bytes.data(), value_bytes.size());
  }
  return OkStatus();
}
// Reads a single int32 (host byte order) from |src| at |byte_offset|.
// Fails with OutOfRange when the 4-byte read would exceed the buffer.
StatusOr<int32_t> BufferLoadI32(vm::ref<Buffer> src,
iree_vmla_size_t byte_offset) {
IREE_TRACE_SCOPE0("VMLAModuleState::BufferLoadI32");
IREE_ASSIGN_OR_RETURN(auto data,
src->RangeAs<int32_t>(byte_offset, sizeof(int32_t)));
return data[0];
}
//===--------------------------------------------------------------------===//
// Common helpers for defining ops
//===--------------------------------------------------------------------===//
// Each IREE_VMLA_*_OP macro stamps out a member function |name| that
// reinterprets its Buffer arguments as |type| spans and forwards to the
// given kernel. (Comments cannot appear inside the macro bodies because of
// the line continuations.)
// Op with no inputs: writes |dst| only.
#define IREE_VMLA_NONARY_OP(name, kernel, type) \
Status name(vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernel::Execute<type>(dst->As<type>()); \
}
// Op with one input buffer.
#define IREE_VMLA_UNARY_OP(name, kernel, type) \
Status name(vm::ref<Buffer> src, vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernel::Execute<type>(src->As<type>(), dst->As<type>()); \
}
// Op with two input buffers.
#define IREE_VMLA_BINARY_OP(name, kernel, type) \
Status name(vm::ref<Buffer> lhs, vm::ref<Buffer> rhs, vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernel::Execute<type>(lhs->As<type>(), rhs->As<type>(), \
dst->As<type>()); \
}
// Op with three input buffers.
#define IREE_VMLA_TERNARY_OP(name, kernel, type) \
Status name(vm::ref<Buffer> a, vm::ref<Buffer> b, vm::ref<Buffer> c, \
vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernel::Execute<type>(a->As<type>(), b->As<type>(), c->As<type>(), \
dst->As<type>()); \
}
//===--------------------------------------------------------------------===//
// VMLA Ops: comparison
//===--------------------------------------------------------------------===//
// Comparison predicate selector passed at runtime by the compiled module.
// NOTE(review): the numeric values must stay in sync with the compiler-side
// encoding of cmp ops — confirm against the VMLA op definitions.
enum class CmpPredicate : uint32_t {
kEQ = 0,
kNE = 1,
kLT = 2,
kLE = 3,
kGT = 4,
kGE = 5,
};
// Element-wise comparison writing one uint8 truth value per element.
#define IREE_VMLA_COMPARE_OP(name, type) \
Status name(int32_t predicate, vm::ref<Buffer> lhs, vm::ref<Buffer> rhs, \
vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
switch (static_cast<CmpPredicate>(predicate)) { \
case CmpPredicate::kEQ: \
return kernels::CompareEQ::Execute<type>( \
lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>()); \
case CmpPredicate::kNE: \
return kernels::CompareNE::Execute<type>( \
lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>()); \
case CmpPredicate::kLT: \
return kernels::CompareLT::Execute<type>( \
lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>()); \
case CmpPredicate::kLE: \
return kernels::CompareLE::Execute<type>( \
lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>()); \
case CmpPredicate::kGT: \
return kernels::CompareGT::Execute<type>( \
lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>()); \
case CmpPredicate::kGE: \
return kernels::CompareGE::Execute<type>( \
lhs->As<type>(), rhs->As<type>(), dst->As<uint8_t>()); \
default: \
return InvalidArgumentErrorBuilder(IREE_LOC) \
<< "Unsupported predicate " << predicate; \
} \
}
IREE_VMLA_COMPARE_OP(CmpI8, int8_t);
IREE_VMLA_COMPARE_OP(CmpI16, int16_t);
IREE_VMLA_COMPARE_OP(CmpI32, int32_t);
IREE_VMLA_COMPARE_OP(CmpF32, float);
// Element-wise select: dst[i] = cond[i] ? lhs[i] : rhs[i], with a uint8
// condition buffer.
#define IREE_VMLA_SELECT_OP(name, type) \
Status name(vm::ref<Buffer> cond, vm::ref<Buffer> lhs, vm::ref<Buffer> rhs, \
vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Select::Execute<type>(cond->As<uint8_t>(), \
lhs->As<type>(), rhs->As<type>(), \
dst->As<type>()); \
}
IREE_VMLA_SELECT_OP(SelectX8, uint8_t);
IREE_VMLA_SELECT_OP(SelectX16, uint16_t);
IREE_VMLA_SELECT_OP(SelectX32, uint32_t);
//===--------------------------------------------------------------------===//
// VMLA Ops: shape/structure
//===--------------------------------------------------------------------===//
// The X-suffixed ops below are bit-width generic: the unsigned element type
// only determines the element size moved around, not any arithmetic.
// Strided copy of a sub-region between shaped buffers.
#define IREE_VMLA_COPY_OP(name, size) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
absl::Span<const int32_t> src_indices, vm::ref<Buffer> dst, \
iree_vmla_shape_t dst_shape, \
absl::Span<const int32_t> dst_indices, \
absl::Span<const int32_t> lengths) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Copy::Execute<size>(src->As<uint8_t>(), src_shape, \
src_indices, dst->As<uint8_t>(), \
dst_shape, dst_indices, lengths); \
}
IREE_VMLA_COPY_OP(CopyX8, sizeof(uint8_t));
IREE_VMLA_COPY_OP(CopyX16, sizeof(uint16_t));
IREE_VMLA_COPY_OP(CopyX32, sizeof(uint32_t));
// Dimension permutation of |src| into |dst|.
#define IREE_VMLA_TRANSPOSE_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
absl::Span<const int32_t> permutation, vm::ref<Buffer> dst, \
iree_vmla_shape_t dst_shape) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Transpose::Execute<type>(src->As<type>(), dst->As<type>(), \
src_shape, permutation); \
}
IREE_VMLA_TRANSPOSE_OP(TransposeX8, uint8_t);
IREE_VMLA_TRANSPOSE_OP(TransposeX16, uint16_t);
IREE_VMLA_TRANSPOSE_OP(TransposeX32, uint32_t);
// Reverses |src| along the listed dimensions.
#define IREE_VMLA_REVERSE_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
absl::Span<const int32_t> dims, vm::ref<Buffer> dst, \
iree_vmla_shape_t dst_shape) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Reverse::Execute<type>(src->As<type>(), dst->As<type>(), \
src_shape, dims); \
}
IREE_VMLA_REVERSE_OP(ReverseX8, uint8_t);
IREE_VMLA_REVERSE_OP(ReverseX16, uint16_t);
IREE_VMLA_REVERSE_OP(ReverseX32, uint32_t);
// Pads |src| with |value| using edge/interior padding amounts per dim.
#define IREE_VMLA_PAD_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> value, iree_vmla_shape_t value_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape, \
absl::Span<const int32_t> edge_padding_low, \
absl::Span<const int32_t> edge_padding_high, \
absl::Span<const int32_t> interior_padding) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Pad::Execute<type>( \
src->As<type>(), value->As<type>(), dst->As<type>(), src_shape, \
dst_shape, edge_padding_low, edge_padding_high, interior_padding); \
}
IREE_VMLA_PAD_OP(PadX8, uint8_t);
IREE_VMLA_PAD_OP(PadX16, uint16_t);
IREE_VMLA_PAD_OP(PadX32, uint32_t);
// Gathers slices of |src| selected by an int buffer of |indices|.
#define IREE_VMLA_GATHER_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> indices, iree_vmla_shape_t indices_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape, \
const int32_t dim, const int32_t batch_dims) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Gather::Execute<type>( \
src->As<type>(), indices->As<int>(), dst->As<type>(), src_shape, \
indices_shape, dst_shape, dim, batch_dims); \
}
IREE_VMLA_GATHER_OP(GatherX8, uint8_t);
IREE_VMLA_GATHER_OP(GatherX16, uint16_t);
IREE_VMLA_GATHER_OP(GatherX32, uint32_t);
// Scatters |src| into |dst| at positions given by |indices|.
#define IREE_VMLA_SCATTER_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> indices, iree_vmla_shape_t indices_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Scatter::Execute<type>( \
src->As<type>(), indices->As<int>(), dst->As<type>(), src_shape, \
indices_shape, dst_shape); \
}
IREE_VMLA_SCATTER_OP(ScatterX8, uint8_t);
IREE_VMLA_SCATTER_OP(ScatterX16, uint16_t);
IREE_VMLA_SCATTER_OP(ScatterX32, uint32_t);
// Broadcasts a value across |dst|; shapes are accepted for interface
// symmetry but only the element data is forwarded to the kernel.
#define IREE_VMLA_BROADCAST_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Broadcast::Execute<type>(src->As<type>(), \
dst->As<type>()); \
}
IREE_VMLA_BROADCAST_OP(BroadcastX8, uint8_t);
IREE_VMLA_BROADCAST_OP(BroadcastX16, uint16_t);
IREE_VMLA_BROADCAST_OP(BroadcastX32, uint32_t);
// Iota: fills dst with sequentially increasing values.
IREE_VMLA_NONARY_OP(IotaI8, kernels::Iota, int8_t);
IREE_VMLA_NONARY_OP(IotaI16, kernels::Iota, int16_t);
IREE_VMLA_NONARY_OP(IotaI32, kernels::Iota, int32_t);
// Use float explicitly: float_t (<cmath>) is implementation-defined and may
// alias double, which would mismatch the 4-byte f32 element layout every
// other *F32 op in this module assumes.
IREE_VMLA_NONARY_OP(IotaF32, kernels::Iota, float);
// Tiles |src| repeatedly to fill |dst| according to the two shapes.
#define IREE_VMLA_TILE_OP(name, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Tile::Execute<type>(src->As<type>(), dst->As<type>(), \
src_shape, dst_shape); \
}
IREE_VMLA_TILE_OP(TileX8, uint8_t);
IREE_VMLA_TILE_OP(TileX16, uint16_t);
IREE_VMLA_TILE_OP(TileX32, uint32_t);
//===--------------------------------------------------------------------===//
// VMLA Ops: bit manipulation
//===--------------------------------------------------------------------===//
// X-suffixed ops are sign-agnostic (pure bit ops); U/I suffixes pick the
// unsigned (logical) vs signed (arithmetic) behavior, e.g. for ShiftRight.
IREE_VMLA_UNARY_OP(NotX8, kernels::Not, uint8_t);
IREE_VMLA_UNARY_OP(NotX16, kernels::Not, uint16_t);
IREE_VMLA_UNARY_OP(NotX32, kernels::Not, uint32_t);
IREE_VMLA_BINARY_OP(AndX8, kernels::And, uint8_t);
IREE_VMLA_BINARY_OP(AndX16, kernels::And, uint16_t);
IREE_VMLA_BINARY_OP(AndX32, kernels::And, uint32_t);
IREE_VMLA_BINARY_OP(OrX8, kernels::Or, uint8_t);
IREE_VMLA_BINARY_OP(OrX16, kernels::Or, uint16_t);
IREE_VMLA_BINARY_OP(OrX32, kernels::Or, uint32_t);
IREE_VMLA_BINARY_OP(XorX8, kernels::Xor, uint8_t);
IREE_VMLA_BINARY_OP(XorX16, kernels::Xor, uint16_t);
IREE_VMLA_BINARY_OP(XorX32, kernels::Xor, uint32_t);
IREE_VMLA_BINARY_OP(ShlX8, kernels::ShiftLeft, uint8_t);
IREE_VMLA_BINARY_OP(ShlX16, kernels::ShiftLeft, uint16_t);
IREE_VMLA_BINARY_OP(ShlX32, kernels::ShiftLeft, uint32_t);
IREE_VMLA_BINARY_OP(ShrU8, kernels::ShiftRight, uint8_t);
IREE_VMLA_BINARY_OP(ShrU16, kernels::ShiftRight, uint16_t);
IREE_VMLA_BINARY_OP(ShrU32, kernels::ShiftRight, uint32_t);
IREE_VMLA_BINARY_OP(ShrI8, kernels::ShiftRight, int8_t);
IREE_VMLA_BINARY_OP(ShrI16, kernels::ShiftRight, int16_t);
IREE_VMLA_BINARY_OP(ShrI32, kernels::ShiftRight, int32_t);
//===--------------------------------------------------------------------===//
// VMLA Ops: arithmetic
//===--------------------------------------------------------------------===//
// Element-wise math; the type suffix selects element width and signedness.
IREE_VMLA_BINARY_OP(AddI8, kernels::Add, int8_t);
IREE_VMLA_BINARY_OP(AddI16, kernels::Add, int16_t);
IREE_VMLA_BINARY_OP(AddI32, kernels::Add, int32_t);
IREE_VMLA_BINARY_OP(AddF32, kernels::Add, float);
IREE_VMLA_BINARY_OP(SubI8, kernels::Sub, int8_t);
IREE_VMLA_BINARY_OP(SubI16, kernels::Sub, int16_t);
IREE_VMLA_BINARY_OP(SubI32, kernels::Sub, int32_t);
IREE_VMLA_BINARY_OP(SubF32, kernels::Sub, float);
IREE_VMLA_UNARY_OP(AbsI8, kernels::Abs, int8_t);
IREE_VMLA_UNARY_OP(AbsI16, kernels::Abs, int16_t);
IREE_VMLA_UNARY_OP(AbsI32, kernels::Abs, int32_t);
IREE_VMLA_UNARY_OP(AbsF32, kernels::Abs, float);
IREE_VMLA_UNARY_OP(NegI8, kernels::Neg, int8_t);
IREE_VMLA_UNARY_OP(NegI16, kernels::Neg, int16_t);
IREE_VMLA_UNARY_OP(NegI32, kernels::Neg, int32_t);
IREE_VMLA_UNARY_OP(NegF32, kernels::Neg, float);
IREE_VMLA_BINARY_OP(MulI8, kernels::Mul, int8_t);
IREE_VMLA_BINARY_OP(MulI16, kernels::Mul, int16_t);
IREE_VMLA_BINARY_OP(MulI32, kernels::Mul, int32_t);
IREE_VMLA_BINARY_OP(MulF32, kernels::Mul, float);
IREE_VMLA_BINARY_OP(DivI8, kernels::Div, int8_t);
IREE_VMLA_BINARY_OP(DivI16, kernels::Div, int16_t);
IREE_VMLA_BINARY_OP(DivI32, kernels::Div, int32_t);
IREE_VMLA_BINARY_OP(DivU8, kernels::Div, uint8_t);
IREE_VMLA_BINARY_OP(DivU16, kernels::Div, uint16_t);
IREE_VMLA_BINARY_OP(DivU32, kernels::Div, uint32_t);
IREE_VMLA_BINARY_OP(DivF32, kernels::Div, float);
IREE_VMLA_BINARY_OP(RemI8, kernels::Rem, int8_t);
IREE_VMLA_BINARY_OP(RemI16, kernels::Rem, int16_t);
IREE_VMLA_BINARY_OP(RemI32, kernels::Rem, int32_t);
IREE_VMLA_BINARY_OP(RemU8, kernels::Rem, uint8_t);
IREE_VMLA_BINARY_OP(RemU16, kernels::Rem, uint16_t);
IREE_VMLA_BINARY_OP(RemU32, kernels::Rem, uint32_t);
IREE_VMLA_BINARY_OP(RemF32, kernels::Rem, float);
IREE_VMLA_BINARY_OP(PowF32, kernels::Pow, float);
IREE_VMLA_UNARY_OP(ExpF32, kernels::Exp, float);
IREE_VMLA_UNARY_OP(LogF32, kernels::Log, float);
IREE_VMLA_UNARY_OP(RsqrtF32, kernels::Rsqrt, float);
IREE_VMLA_UNARY_OP(SqrtF32, kernels::Sqrt, float);
IREE_VMLA_UNARY_OP(CosF32, kernels::Cos, float);
IREE_VMLA_UNARY_OP(SinF32, kernels::Sin, float);
IREE_VMLA_UNARY_OP(TanhF32, kernels::Tanh, float);
IREE_VMLA_BINARY_OP(Atan2F32, kernels::Atan2, float);
IREE_VMLA_BINARY_OP(MinI8, kernels::Min, int8_t);
IREE_VMLA_BINARY_OP(MinI16, kernels::Min, int16_t);
IREE_VMLA_BINARY_OP(MinI32, kernels::Min, int32_t);
IREE_VMLA_BINARY_OP(MinF32, kernels::Min, float);
IREE_VMLA_BINARY_OP(MaxI8, kernels::Max, int8_t);
IREE_VMLA_BINARY_OP(MaxI16, kernels::Max, int16_t);
IREE_VMLA_BINARY_OP(MaxI32, kernels::Max, int32_t);
IREE_VMLA_BINARY_OP(MaxF32, kernels::Max, float);
IREE_VMLA_TERNARY_OP(ClampI8, kernels::Clamp, int8_t);
IREE_VMLA_TERNARY_OP(ClampI16, kernels::Clamp, int16_t);
IREE_VMLA_TERNARY_OP(ClampI32, kernels::Clamp, int32_t);
IREE_VMLA_TERNARY_OP(ClampF32, kernels::Clamp, float);
IREE_VMLA_UNARY_OP(FloorF32, kernels::Floor, float);
IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);
//===--------------------------------------------------------------------===//
// VMLA Ops: conversion
//===--------------------------------------------------------------------===//
// Element-wise type conversion from src_type buffers to dst_type buffers.
#define IREE_VMLA_CONVERSION_OP(name, src_type, dst_type) \
Status name(vm::ref<Buffer> src, vm::ref<Buffer> dst) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernels::Convert::Execute<src_type, dst_type>(src->As<src_type>(), \
dst->As<dst_type>()); \
}
IREE_VMLA_CONVERSION_OP(ConvertI8I16, int8_t, int16_t);
IREE_VMLA_CONVERSION_OP(ConvertI8I32, int8_t, int32_t);
IREE_VMLA_CONVERSION_OP(ConvertI8F32, int8_t, float);
IREE_VMLA_CONVERSION_OP(ConvertI16I8, int16_t, int8_t);
IREE_VMLA_CONVERSION_OP(ConvertI16I32, int16_t, int32_t);
IREE_VMLA_CONVERSION_OP(ConvertI16F32, int16_t, float);
IREE_VMLA_CONVERSION_OP(ConvertI32I8, int32_t, int8_t);
IREE_VMLA_CONVERSION_OP(ConvertI32I16, int32_t, int16_t);
IREE_VMLA_CONVERSION_OP(ConvertI32F32, int32_t, float);
IREE_VMLA_CONVERSION_OP(ConvertF32I8, float, int8_t);
IREE_VMLA_CONVERSION_OP(ConvertF32I16, float, int16_t);
IREE_VMLA_CONVERSION_OP(ConvertF32I32, float, int32_t);
//===--------------------------------------------------------------------===//
// VMLA Ops: Convolution
//===--------------------------------------------------------------------===//
// 2-D convolution over f32 tensors. All three tensors must be rank 4 with
// dim 0 treated as the batch dimension; each batch example is convolved
// independently by kernels::Conv2D.
// |padding| is consumed as [pad_h_low, pad_h_high, pad_w_low, pad_w_high].
// NOTE(review): only the first two lhs_dilation values are used and
// rhs_dilation is ignored entirely; batch_group_count is accepted but never
// forwarded to the kernel — confirm these are compiler-guaranteed to be
// defaults.
Status ConvF32F32F32(vm::ref<Buffer> input, iree_vmla_shape_t input_shape,
vm::ref<Buffer> filter, iree_vmla_shape_t filter_shape,
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape,
absl::Span<const int32_t> window_strides,
absl::Span<const int32_t> padding,
absl::Span<const int32_t> lhs_dilation,
absl::Span<const int32_t> rhs_dilation,
const int32_t feature_group_count,
const int32_t batch_group_count) {
IREE_TRACE_SCOPE0("VMLAModuleState::ConvF32F32F32");
if (input_shape.size() != 4 || filter_shape.size() != 4 ||
dst_shape.size() != 4) {
return InvalidArgumentErrorBuilder(IREE_LOC)
<< "Expecting 4-d tensors for Conv2D kernel";
}
const int32_t batch_size = input_shape[0];
// Per-example shapes strip the leading batch dimension.
const auto input_example_shape = input_shape.subspan(1, 3);
const auto output_example_shape = dst_shape.subspan(1, 3);
const auto filter_shape_4d = filter_shape.subspan(0, 4);
const auto dilation = lhs_dilation.subspan(0, 2);
const auto pad_h = padding.subspan(0, 2);
const auto pad_w = padding.subspan(2, 2);
const auto window_strides_2d = window_strides.subspan(0, 2);
const float* raw_inputs_data = input->As<float>().data();
const float* raw_filter_data = filter->As<float>().data();
float* raw_dst_data = dst->As<float>().data();
auto filter_buffer = absl::MakeConstSpan(
raw_filter_data, kernels::GetElementCount(filter_shape_4d));
// Element strides between consecutive batch examples.
const size_t input_stride = kernels::GetElementCount(input_example_shape);
const size_t output_stride = kernels::GetElementCount(output_example_shape);
for (int i = 0; i < batch_size; ++i) {
auto input_example =
absl::MakeConstSpan(raw_inputs_data + i * input_stride, input_stride);
auto output_example =
absl::MakeSpan(raw_dst_data + i * output_stride, output_stride);
IREE_RETURN_IF_ERROR(kernels::Conv2D::Execute(
input_example, input_example_shape, filter_buffer, filter_shape_4d,
output_example, output_example_shape, window_strides_2d, pad_h, pad_w,
dilation, feature_group_count));
}
return OkStatus();
}
//===--------------------------------------------------------------------===//
// VMLA Ops: GEMM/GEMV
//===--------------------------------------------------------------------===//
// Batched f32 matrix multiply: for each batch index b,
// dst[b] = lhs[b] x rhs[b], with all tensors rank 3 and sharing dim 0.
// Uses the shared (externally synchronized) mat_mul_state scratch.
Status BatchMatMulF32F32F32(vm::ref<Buffer> lhs, iree_vmla_shape_t lhs_shape,
vm::ref<Buffer> rhs, iree_vmla_shape_t rhs_shape,
vm::ref<Buffer> dst,
iree_vmla_shape_t dst_shape) {
IREE_TRACE_SCOPE0("VMLAModuleState::BatchMatMulF32F32F32");
// Compiler guarantees. Here for documentation purposes.
assert(lhs_shape.size() == 3 && rhs_shape.size() == 3 &&
dst_shape.size() == 3);
assert(lhs_shape[0] == rhs_shape[0] && rhs_shape[0] == dst_shape[0]);
// Per-batch-element 2-D shapes (batch dimension stripped).
iree_vmla_shape_t lhs_batch_element_shape = lhs_shape.subspan(1);
iree_vmla_shape_t rhs_batch_element_shape = rhs_shape.subspan(1);
iree_vmla_shape_t dst_batch_element_shape = dst_shape.subspan(1);
auto lhs_batch_element_shape2 = lhs_batch_element_shape.subspan(0, 2);
auto rhs_batch_element_shape2 = rhs_batch_element_shape.subspan(0, 2);
auto dst_batch_element_shape2 = dst_batch_element_shape.subspan(0, 2);
// Element strides between consecutive batch elements.
size_t lhs_batch_stride = kernels::GetElementCount(lhs_batch_element_shape);
size_t rhs_batch_stride = kernels::GetElementCount(rhs_batch_element_shape);
size_t dst_batch_stride = kernels::GetElementCount(dst_batch_element_shape);
float* lhs_batch_base = lhs->As<float>().data();
float* rhs_batch_base = rhs->As<float>().data();
float* dst_batch_base = dst->As<float>().data();
int32_t batch_dim = lhs_shape[0];
for (int i = 0; i < batch_dim; i++) {
kernels::MatMul::Buffers<float, float> buffers;
buffers.lhs_buffer = absl::MakeSpan(lhs_batch_base + i * lhs_batch_stride,
lhs_batch_stride);
buffers.lhs_shape = lhs_batch_element_shape2;
buffers.rhs_buffer = absl::MakeSpan(rhs_batch_base + i * rhs_batch_stride,
rhs_batch_stride);
buffers.rhs_shape = rhs_batch_element_shape2;
buffers.dst_buffer = absl::MakeSpan(dst_batch_base + i * dst_batch_stride,
dst_batch_stride);
buffers.dst_shape = dst_batch_element_shape2;
IREE_RETURN_IF_ERROR(kernels::MatMul::Execute(
kernel_state_->mat_mul_state.get(), buffers));
}
return OkStatus();
}
//===--------------------------------------------------------------------===//
// VMLA Ops: reduction
//===--------------------------------------------------------------------===//
// Reduces |src| along |dimension| into |dst|, seeded with |init| values.
#define IREE_VMLA_REDUCTION_OP(name, kernel, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> init, iree_vmla_shape_t init_shape, \
int32_t dimension, vm::ref<Buffer> dst, \
iree_vmla_shape_t dst_shape) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernel::Execute<type>(src->As<type>(), init->As<type>(), \
dst->As<type>(), dimension, src_shape, \
dst_shape); \
}
IREE_VMLA_REDUCTION_OP(ReduceSumI8, kernels::ReduceSum, int8_t);
IREE_VMLA_REDUCTION_OP(ReduceSumI16, kernels::ReduceSum, int16_t);
IREE_VMLA_REDUCTION_OP(ReduceSumI32, kernels::ReduceSum, int32_t);
IREE_VMLA_REDUCTION_OP(ReduceSumF32, kernels::ReduceSum, float);
IREE_VMLA_REDUCTION_OP(ReduceMinI8, kernels::ReduceMin, int8_t);
IREE_VMLA_REDUCTION_OP(ReduceMinI16, kernels::ReduceMin, int16_t);
IREE_VMLA_REDUCTION_OP(ReduceMinI32, kernels::ReduceMin, int32_t);
IREE_VMLA_REDUCTION_OP(ReduceMinF32, kernels::ReduceMin, float);
IREE_VMLA_REDUCTION_OP(ReduceMaxI8, kernels::ReduceMax, int8_t);
IREE_VMLA_REDUCTION_OP(ReduceMaxI16, kernels::ReduceMax, int16_t);
IREE_VMLA_REDUCTION_OP(ReduceMaxI32, kernels::ReduceMax, int32_t);
IREE_VMLA_REDUCTION_OP(ReduceMaxF32, kernels::ReduceMax, float);
// Windowed pooling over |src| with the given window/stride/low-padding.
#define IREE_VMLA_POOLING_OP(name, kernel, type) \
Status name(vm::ref<Buffer> src, iree_vmla_shape_t src_shape, \
vm::ref<Buffer> init, iree_vmla_shape_t init_shape, \
vm::ref<Buffer> dst, iree_vmla_shape_t dst_shape, \
iree_vmla_shape_t window_dimensions, iree_vmla_shape_t strides, \
iree_vmla_shape_t pad_low) { \
IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
return kernel::Execute<type>(src->As<type>(), init->As<type>(), \
dst->As<type>(), src_shape, dst_shape, \
window_dimensions, strides, pad_low); \
}
IREE_VMLA_POOLING_OP(PoolingSumI8, kernels::PoolingSum, int8_t);
IREE_VMLA_POOLING_OP(PoolingSumI16, kernels::PoolingSum, int16_t);
IREE_VMLA_POOLING_OP(PoolingSumI32, kernels::PoolingSum, int32_t);
IREE_VMLA_POOLING_OP(PoolingSumF32, kernels::PoolingSum, float);
IREE_VMLA_POOLING_OP(PoolingMinI8, kernels::PoolingMin, int8_t);
IREE_VMLA_POOLING_OP(PoolingMinI16, kernels::PoolingMin, int16_t);
IREE_VMLA_POOLING_OP(PoolingMinI32, kernels::PoolingMin, int32_t);
IREE_VMLA_POOLING_OP(PoolingMinF32, kernels::PoolingMin, float);
IREE_VMLA_POOLING_OP(PoolingMaxI8, kernels::PoolingMax, int8_t);
IREE_VMLA_POOLING_OP(PoolingMaxI16, kernels::PoolingMax, int16_t);
IREE_VMLA_POOLING_OP(PoolingMaxI32, kernels::PoolingMax, int32_t);
IREE_VMLA_POOLING_OP(PoolingMaxF32, kernels::PoolingMax, float);
private:
// Allocator used for buffers created by ops (e.g. buffer.alloc/clone).
iree_allocator_t allocator_;
// NOTE: kernel state must be externally synchronized as it is shared across
// all contexts using the VMLA module. This is fine in our current design as
// we only ever execute a single context at a time but if we start to allow
// concurrency across contexts we'll need to introduce locks.
kernels::RuntimeState* kernel_state_ = nullptr;
};
//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//
// Table of all functions exported by the VMLA module, mapping each VM-visible
// op name (qualified by the module name "vmla" at registration — see the
// VMLAModule constructor) to the VMLAModuleState member that implements it.
// NOTE(review): the array order may define the function ordinals compiled
// modules resolve against — confirm before reordering entries.
static const vm::NativeFunction<VMLAModuleState> kVMLAModuleFunctions[] = {
// Interface and buffer management.
vm::MakeNativeFunction("interface.const", &VMLAModuleState::InterfaceConst),
vm::MakeNativeFunction("interface.binding",
&VMLAModuleState::InterfaceBinding),
vm::MakeNativeFunction("buffer.const", &VMLAModuleState::BufferConst),
vm::MakeNativeFunction("buffer.alloc", &VMLAModuleState::BufferAlloc),
vm::MakeNativeFunction("buffer.clone", &VMLAModuleState::BufferClone),
vm::MakeNativeFunction("buffer.view", &VMLAModuleState::BufferView),
vm::MakeNativeFunction("buffer.copy", &VMLAModuleState::BufferCopy),
vm::MakeNativeFunction("buffer.fill", &VMLAModuleState::BufferFill),
vm::MakeNativeFunction("buffer.load.i32", &VMLAModuleState::BufferLoadI32),
// Comparison and selection.
vm::MakeNativeFunction("cmp.i8", &VMLAModuleState::CmpI8),
vm::MakeNativeFunction("cmp.i16", &VMLAModuleState::CmpI16),
vm::MakeNativeFunction("cmp.i32", &VMLAModuleState::CmpI32),
vm::MakeNativeFunction("cmp.f32", &VMLAModuleState::CmpF32),
vm::MakeNativeFunction("select.x8", &VMLAModuleState::SelectX8),
vm::MakeNativeFunction("select.x16", &VMLAModuleState::SelectX16),
vm::MakeNativeFunction("select.x32", &VMLAModuleState::SelectX32),
// Shape/layout manipulation (type-erased "xN" width-only variants).
vm::MakeNativeFunction("broadcast.x8", &VMLAModuleState::BroadcastX8),
vm::MakeNativeFunction("broadcast.x16", &VMLAModuleState::BroadcastX16),
vm::MakeNativeFunction("broadcast.x32", &VMLAModuleState::BroadcastX32),
vm::MakeNativeFunction("copy.x8", &VMLAModuleState::CopyX8),
vm::MakeNativeFunction("copy.x16", &VMLAModuleState::CopyX16),
vm::MakeNativeFunction("copy.x32", &VMLAModuleState::CopyX32),
vm::MakeNativeFunction("transpose.x8", &VMLAModuleState::TransposeX8),
vm::MakeNativeFunction("transpose.x16", &VMLAModuleState::TransposeX16),
vm::MakeNativeFunction("transpose.x32", &VMLAModuleState::TransposeX32),
vm::MakeNativeFunction("reverse.x8", &VMLAModuleState::ReverseX8),
vm::MakeNativeFunction("reverse.x16", &VMLAModuleState::ReverseX16),
vm::MakeNativeFunction("reverse.x32", &VMLAModuleState::ReverseX32),
vm::MakeNativeFunction("pad.x8", &VMLAModuleState::PadX8),
vm::MakeNativeFunction("pad.x16", &VMLAModuleState::PadX16),
vm::MakeNativeFunction("pad.x32", &VMLAModuleState::PadX32),
vm::MakeNativeFunction("gather.x8", &VMLAModuleState::GatherX8),
vm::MakeNativeFunction("gather.x16", &VMLAModuleState::GatherX16),
vm::MakeNativeFunction("gather.x32", &VMLAModuleState::GatherX32),
vm::MakeNativeFunction("scatter.x8", &VMLAModuleState::ScatterX8),
vm::MakeNativeFunction("scatter.x16", &VMLAModuleState::ScatterX16),
vm::MakeNativeFunction("scatter.x32", &VMLAModuleState::ScatterX32),
vm::MakeNativeFunction("iota.i8", &VMLAModuleState::IotaI8),
vm::MakeNativeFunction("iota.i16", &VMLAModuleState::IotaI16),
vm::MakeNativeFunction("iota.i32", &VMLAModuleState::IotaI32),
vm::MakeNativeFunction("iota.f32", &VMLAModuleState::IotaF32),
vm::MakeNativeFunction("tile.x8", &VMLAModuleState::TileX8),
vm::MakeNativeFunction("tile.x16", &VMLAModuleState::TileX16),
vm::MakeNativeFunction("tile.x32", &VMLAModuleState::TileX32),
// Bitwise and shift ops.
vm::MakeNativeFunction("not.x8", &VMLAModuleState::NotX8),
vm::MakeNativeFunction("not.x16", &VMLAModuleState::NotX16),
vm::MakeNativeFunction("not.x32", &VMLAModuleState::NotX32),
vm::MakeNativeFunction("and.x8", &VMLAModuleState::AndX8),
vm::MakeNativeFunction("and.x16", &VMLAModuleState::AndX16),
vm::MakeNativeFunction("and.x32", &VMLAModuleState::AndX32),
vm::MakeNativeFunction("or.x8", &VMLAModuleState::OrX8),
vm::MakeNativeFunction("or.x16", &VMLAModuleState::OrX16),
vm::MakeNativeFunction("or.x32", &VMLAModuleState::OrX32),
vm::MakeNativeFunction("xor.x8", &VMLAModuleState::XorX8),
vm::MakeNativeFunction("xor.x16", &VMLAModuleState::XorX16),
vm::MakeNativeFunction("xor.x32", &VMLAModuleState::XorX32),
vm::MakeNativeFunction("shl.x8", &VMLAModuleState::ShlX8),
vm::MakeNativeFunction("shl.x16", &VMLAModuleState::ShlX16),
vm::MakeNativeFunction("shl.x32", &VMLAModuleState::ShlX32),
vm::MakeNativeFunction("shr.u8", &VMLAModuleState::ShrU8),
vm::MakeNativeFunction("shr.u16", &VMLAModuleState::ShrU16),
vm::MakeNativeFunction("shr.u32", &VMLAModuleState::ShrU32),
vm::MakeNativeFunction("shr.i8", &VMLAModuleState::ShrI8),
vm::MakeNativeFunction("shr.i16", &VMLAModuleState::ShrI16),
vm::MakeNativeFunction("shr.i32", &VMLAModuleState::ShrI32),
// Element-wise arithmetic.
vm::MakeNativeFunction("add.i8", &VMLAModuleState::AddI8),
vm::MakeNativeFunction("add.i16", &VMLAModuleState::AddI16),
vm::MakeNativeFunction("add.i32", &VMLAModuleState::AddI32),
vm::MakeNativeFunction("add.f32", &VMLAModuleState::AddF32),
vm::MakeNativeFunction("sub.i8", &VMLAModuleState::SubI8),
vm::MakeNativeFunction("sub.i16", &VMLAModuleState::SubI16),
vm::MakeNativeFunction("sub.i32", &VMLAModuleState::SubI32),
vm::MakeNativeFunction("sub.f32", &VMLAModuleState::SubF32),
vm::MakeNativeFunction("abs.i8", &VMLAModuleState::AbsI8),
vm::MakeNativeFunction("abs.i16", &VMLAModuleState::AbsI16),
vm::MakeNativeFunction("abs.i32", &VMLAModuleState::AbsI32),
vm::MakeNativeFunction("abs.f32", &VMLAModuleState::AbsF32),
vm::MakeNativeFunction("neg.i8", &VMLAModuleState::NegI8),
vm::MakeNativeFunction("neg.i16", &VMLAModuleState::NegI16),
vm::MakeNativeFunction("neg.i32", &VMLAModuleState::NegI32),
vm::MakeNativeFunction("neg.f32", &VMLAModuleState::NegF32),
vm::MakeNativeFunction("mul.i8", &VMLAModuleState::MulI8),
vm::MakeNativeFunction("mul.i16", &VMLAModuleState::MulI16),
vm::MakeNativeFunction("mul.i32", &VMLAModuleState::MulI32),
vm::MakeNativeFunction("mul.f32", &VMLAModuleState::MulF32),
vm::MakeNativeFunction("div.i8", &VMLAModuleState::DivI8),
vm::MakeNativeFunction("div.i16", &VMLAModuleState::DivI16),
vm::MakeNativeFunction("div.i32", &VMLAModuleState::DivI32),
vm::MakeNativeFunction("div.u8", &VMLAModuleState::DivU8),
vm::MakeNativeFunction("div.u16", &VMLAModuleState::DivU16),
vm::MakeNativeFunction("div.u32", &VMLAModuleState::DivU32),
vm::MakeNativeFunction("div.f32", &VMLAModuleState::DivF32),
vm::MakeNativeFunction("rem.i8", &VMLAModuleState::RemI8),
vm::MakeNativeFunction("rem.i16", &VMLAModuleState::RemI16),
vm::MakeNativeFunction("rem.i32", &VMLAModuleState::RemI32),
vm::MakeNativeFunction("rem.u8", &VMLAModuleState::RemU8),
vm::MakeNativeFunction("rem.u16", &VMLAModuleState::RemU16),
vm::MakeNativeFunction("rem.u32", &VMLAModuleState::RemU32),
vm::MakeNativeFunction("rem.f32", &VMLAModuleState::RemF32),
// Transcendental/float-only math.
vm::MakeNativeFunction("pow.f32", &VMLAModuleState::PowF32),
vm::MakeNativeFunction("exp.f32", &VMLAModuleState::ExpF32),
vm::MakeNativeFunction("log.f32", &VMLAModuleState::LogF32),
vm::MakeNativeFunction("rsqrt.f32", &VMLAModuleState::RsqrtF32),
vm::MakeNativeFunction("sqrt.f32", &VMLAModuleState::SqrtF32),
vm::MakeNativeFunction("cos.f32", &VMLAModuleState::CosF32),
vm::MakeNativeFunction("sin.f32", &VMLAModuleState::SinF32),
vm::MakeNativeFunction("tanh.f32", &VMLAModuleState::TanhF32),
vm::MakeNativeFunction("atan2.f32", &VMLAModuleState::Atan2F32),
// Min/max/clamp and rounding.
vm::MakeNativeFunction("min.i8", &VMLAModuleState::MinI8),
vm::MakeNativeFunction("min.i16", &VMLAModuleState::MinI16),
vm::MakeNativeFunction("min.i32", &VMLAModuleState::MinI32),
vm::MakeNativeFunction("min.f32", &VMLAModuleState::MinF32),
vm::MakeNativeFunction("max.i8", &VMLAModuleState::MaxI8),
vm::MakeNativeFunction("max.i16", &VMLAModuleState::MaxI16),
vm::MakeNativeFunction("max.i32", &VMLAModuleState::MaxI32),
vm::MakeNativeFunction("max.f32", &VMLAModuleState::MaxF32),
vm::MakeNativeFunction("clamp.i8", &VMLAModuleState::ClampI8),
vm::MakeNativeFunction("clamp.i16", &VMLAModuleState::ClampI16),
vm::MakeNativeFunction("clamp.i32", &VMLAModuleState::ClampI32),
vm::MakeNativeFunction("clamp.f32", &VMLAModuleState::ClampF32),
vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
// Type conversions ("convert.<src>.<dst>").
vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
vm::MakeNativeFunction("convert.i8.i32", &VMLAModuleState::ConvertI8I32),
vm::MakeNativeFunction("convert.i8.f32", &VMLAModuleState::ConvertI8F32),
vm::MakeNativeFunction("convert.i16.i8", &VMLAModuleState::ConvertI16I8),
vm::MakeNativeFunction("convert.i16.i32", &VMLAModuleState::ConvertI16I32),
vm::MakeNativeFunction("convert.i16.f32", &VMLAModuleState::ConvertI16F32),
vm::MakeNativeFunction("convert.i32.i8", &VMLAModuleState::ConvertI32I8),
vm::MakeNativeFunction("convert.i32.i16", &VMLAModuleState::ConvertI32I16),
vm::MakeNativeFunction("convert.i32.f32", &VMLAModuleState::ConvertI32F32),
vm::MakeNativeFunction("convert.f32.i8", &VMLAModuleState::ConvertF32I8),
vm::MakeNativeFunction("convert.f32.i16", &VMLAModuleState::ConvertF32I16),
vm::MakeNativeFunction("convert.f32.i32", &VMLAModuleState::ConvertF32I32),
// Reductions and pooling.
vm::MakeNativeFunction("reduce.sum.i8", &VMLAModuleState::ReduceSumI8),
vm::MakeNativeFunction("reduce.sum.i16", &VMLAModuleState::ReduceSumI16),
vm::MakeNativeFunction("reduce.sum.i32", &VMLAModuleState::ReduceSumI32),
vm::MakeNativeFunction("reduce.sum.f32", &VMLAModuleState::ReduceSumF32),
vm::MakeNativeFunction("reduce.min.i8", &VMLAModuleState::ReduceMinI8),
vm::MakeNativeFunction("reduce.min.i16", &VMLAModuleState::ReduceMinI16),
vm::MakeNativeFunction("reduce.min.i32", &VMLAModuleState::ReduceMinI32),
vm::MakeNativeFunction("reduce.min.f32", &VMLAModuleState::ReduceMinF32),
vm::MakeNativeFunction("reduce.max.i8", &VMLAModuleState::ReduceMaxI8),
vm::MakeNativeFunction("reduce.max.i16", &VMLAModuleState::ReduceMaxI16),
vm::MakeNativeFunction("reduce.max.i32", &VMLAModuleState::ReduceMaxI32),
vm::MakeNativeFunction("reduce.max.f32", &VMLAModuleState::ReduceMaxF32),
vm::MakeNativeFunction("pooling.sum.i8", &VMLAModuleState::PoolingSumI8),
vm::MakeNativeFunction("pooling.sum.i16", &VMLAModuleState::PoolingSumI16),
vm::MakeNativeFunction("pooling.sum.i32", &VMLAModuleState::PoolingSumI32),
vm::MakeNativeFunction("pooling.sum.f32", &VMLAModuleState::PoolingSumF32),
vm::MakeNativeFunction("pooling.min.i8", &VMLAModuleState::PoolingMinI8),
vm::MakeNativeFunction("pooling.min.i16", &VMLAModuleState::PoolingMinI16),
vm::MakeNativeFunction("pooling.min.i32", &VMLAModuleState::PoolingMinI32),
vm::MakeNativeFunction("pooling.min.f32", &VMLAModuleState::PoolingMinF32),
vm::MakeNativeFunction("pooling.max.i8", &VMLAModuleState::PoolingMaxI8),
vm::MakeNativeFunction("pooling.max.i16", &VMLAModuleState::PoolingMaxI16),
vm::MakeNativeFunction("pooling.max.i32", &VMLAModuleState::PoolingMaxI32),
vm::MakeNativeFunction("pooling.max.f32", &VMLAModuleState::PoolingMaxF32),
// Linear algebra.
vm::MakeNativeFunction("batch.matmul.f32f32.f32",
&VMLAModuleState::BatchMatMulF32F32F32),
vm::MakeNativeFunction("conv.f32f32.f32", &VMLAModuleState::ConvF32F32F32)};
// Per-device VMLA module.
// A single instance is created per device and shared by every executable made
// within that device. Large shared kernel state (thread pools, caches, etc.)
// belongs here, but must be thread-safe or internally synchronized.
//
// Thread-safe.
class VMLAModule final : public vm::NativeModule<VMLAModuleState> {
 public:
  explicit VMLAModule(iree_allocator_t allocator)
      : vm::NativeModule<VMLAModuleState>(
            "vmla", allocator, absl::MakeConstSpan(kVMLAModuleFunctions)) {}
  ~VMLAModule() = default;

  // One-time module setup hook; currently nothing to do beyond tracing.
  Status Initialize() {
    IREE_TRACE_SCOPE0("VMLAModule::Initialize");
    return OkStatus();
  }

  // Allocates per-context state. Each context gets its own state object but
  // all of them point back at this module's shared |kernel_state_|.
  StatusOr<std::unique_ptr<VMLAModuleState>> CreateState(
      iree_allocator_t allocator) override {
    IREE_TRACE_SCOPE0("VMLAModule::CreateState");
    std::unique_ptr<VMLAModuleState> context_state =
        std::make_unique<VMLAModuleState>(allocator, &kernel_state_);
    return context_state;
  }

 private:
  // NOTE: shared across all contexts with the VMLA module loaded. See
  // VMLAModuleState::kernel_state_ for more information.
  kernels::RuntimeState kernel_state_;
};
} // namespace
// Creates the VMLA module and hands ownership of its VM interface to the
// caller via |out_module|. Returns InvalidArgument when |out_module| is null.
Status ModuleCreate(iree_allocator_t allocator, iree_vm_module_t** out_module) {
  // Guard against a null output pointer before any work is done.
  if (!out_module) {
    return InvalidArgumentErrorBuilder(IREE_LOC)
           << "out_module must not be null";
  }
  // Clear the output so callers never observe a stale pointer on failure.
  *out_module = nullptr;

  auto vmla_module = std::make_unique<VMLAModule>(allocator);
  IREE_RETURN_IF_ERROR(vmla_module->Initialize());

  // Success: release ownership to the caller through the C interface.
  *out_module = vmla_module.release()->interface();
  return OkStatus();
}
} // namespace vmla
} // namespace hal
} // namespace iree