blob: f795185bf1ef305f0fb1eef59504ee08a4478566 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#define _GNU_SOURCE
#include "iree/modules/vmvx/module.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/vm/api.h"
// Include the ukernel support library so that we can use its implementations
// as fixed-function components of the runtime.
#include "iree/base/internal/cpu.h"
#include "iree/builtins/ukernel/api.h"
// Module ABI version reported in the descriptor; bump when export signatures
// change incompatibly.
#define IREE_VMVX_MODULE_VERSION_0_0 0x00000000u
#define IREE_VMVX_MODULE_VERSION_LATEST IREE_VMVX_MODULE_VERSION_0_0
// Implementation of iree_uk_assert_fail. Handling the failure is deferred to
// user's code, i.e. to us here, as core ukernel/ code can't do I/O or depend
// on anything.
void iree_uk_assert_fail(const char* file, int line, const char* function,
                         const char* condition) {
#if (!defined(NDEBUG)) && (IREE_FILE_IO_ENABLE == 1)
  // Flush any buffered stdout first so the message doesn't interleave with it.
  fflush(stdout);
  // Must be a single fprintf call (which must make a single write) - typically
  // called from multiple worker threads concurrently.
  fprintf(stderr, "%s:%d: %s: assertion failed: %s\n", file, line, function,
          condition);
  fflush(stderr);
#endif  // (!defined(NDEBUG)) && (IREE_FILE_IO_ENABLE == 1)
  // Trap in debug builds even when file I/O is unavailable.
  IREE_ASSERT(false);
}
//===----------------------------------------------------------------------===//
// Module type definitions
//===----------------------------------------------------------------------===//
// Per-instance module object; stored in the same allocation as the base
// native module and recovered via IREE_VMVX_MODULE_CAST.
typedef struct iree_vmvx_module_t {
  // Allocator the combined base-module + module allocation came from.
  iree_allocator_t host_allocator;
  // TODO(benvanik): types when we are not registering them globally.
} iree_vmvx_module_t;
// Resolves the iree_vmvx_module_t* stored immediately after the base native
// module storage within the same allocation. Fully parenthesized and without
// a trailing semicolon so it behaves as an ordinary expression; the previous
// definition ended in ';' which restricted it to statement position and
// expanded to a stray empty statement at its call site.
#define IREE_VMVX_MODULE_CAST(module) \
  ((iree_vmvx_module_t*)((uint8_t*)(module) + iree_vm_native_module_size()))
// Per-context state allocated for each VM context using the module.
typedef struct iree_vmvx_module_state_t {
  // Allocator the state was allocated from; used to free it again.
  iree_allocator_t host_allocator;
  // Logical processor identifier used to index into processor info fields.
  // Depending on the implementation this may be an ordinal, a bitfield, or an
  // opaque unique identifier.
  uint32_t processor_id;
  // If we have any external libraries we want to interact with that are
  // stateful we could store their state here. Note that VMVX invocations may
  // happen from any thread and concurrently and if the state is not thread-safe
  // we'll have to perform the synchronization ourselves here. That'd be bad,
  // of course, and an indication that whatever is being called is not suited
  // for this use.
} iree_vmvx_module_state_t;
// iree_vm_module_t::destroy hook. The module storage itself is part of the
// base native module allocation and is freed by the base implementation.
static void IREE_API_PTR iree_vmvx_module_destroy(void* base_module) {
  // No state to clean up (yet).
}
// iree_vm_module_t::alloc_state hook: allocates zero-initialized per-context
// state from |host_allocator| and hands it back via |out_module_state|.
static iree_status_t IREE_API_PTR
iree_vmvx_module_alloc_state(void* self, iree_allocator_t host_allocator,
                             iree_vm_module_state_t** out_module_state) {
  iree_vmvx_module_state_t* new_state = NULL;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(host_allocator, sizeof(*new_state),
                                             (void**)&new_state));
  memset(new_state, 0, sizeof(*new_state));
  // Remember the allocator so free_state can release the memory later.
  new_state->host_allocator = host_allocator;
  *out_module_state = (iree_vm_module_state_t*)new_state;
  return iree_ok_status();
}
// iree_vm_module_t::free_state hook: releases state produced by
// iree_vmvx_module_alloc_state using the allocator captured at creation.
static void IREE_API_PTR
iree_vmvx_module_free_state(void* self, iree_vm_module_state_t* module_state) {
  iree_vmvx_module_state_t* vmvx_state =
      (iree_vmvx_module_state_t*)module_state;
  iree_allocator_free(vmvx_state->host_allocator, vmvx_state);
}
//===----------------------------------------------------------------------===//
// Argument validation and marshalling
//===----------------------------------------------------------------------===//
// Returns an upper bound, in bytes, of the memory span touched by a strided
// 2D access of |size0| x |size1| elements (strides in elements) of
// |element_size| bytes each, ORing any detected arithmetic overflow into
// |*overflow| (checked by the caller, see BUFFER_2D_DECLS).
static iree_host_size_t iree_vmvx_2d_length_bound(
    iree_host_size_t element_size, uint64_t size0, uint64_t size1,
    uint64_t stride0, uint64_t stride1, uint64_t* overflow) {
  // Check for 2d size/stride overflow conditions for the equation:
  //   (size0 - 1) * stride0 + (size1 - 1) * stride1
  // This limits each (multiplicand + 1) to the 32bit range. We can get
  // smarter about this later or when scaling to >2D, but while limited, this
  // is easy and correct.
  *overflow |= (size0 & 0xffffffff00000000) | (size1 & 0xffffffff00000000) |
               ((stride0 + 1) & 0xffffffff00000000) |
               ((stride1 + 1) & 0xffffffff00000000);
  // NOTE(review): a zero size would wrap (size - 1) to UINT64_MAX here;
  // callers appear to always pass non-zero extents - confirm before relying
  // on size==0 being rejected.
  uint64_t last_index = (size0 - 1) * stride0 + (size1 - 1) * stride1;
  uint64_t max_size = (last_index + 1) * element_size;
  iree_host_size_t max_size_size_t = (iree_host_size_t)max_size;
  *overflow |= (max_size_size_t != max_size);  // No-op for 64bit size_t.
  return max_size_size_t;
}
// Narrows an int64 VM value to iree_host_size_t. On 32-bit hosts any bits
// that would be truncated are ORed into |*overflow| so the caller can reject
// out-of-range values; on 64-bit hosts the cast is lossless bit-wise.
static iree_host_size_t iree_vmvx_cast_host_size(int64_t value,
                                                 uint64_t* overflow) {
  if (sizeof(iree_host_size_t) == 4) {
    const uint64_t truncated_bits = (uint64_t)value & 0xffffffff00000000ul;
    *overflow |= truncated_bits;
  }
  return (iree_host_size_t)value;
}
#define BUFFER_2D_DECLS(name, dtype_size, offset, stride0, stride1, size0, \
size1) \
uint64_t name##_overflow = 0; \
iree_host_size_t name##_size0 = \
iree_vmvx_cast_host_size(size0, &name##_overflow); \
iree_host_size_t name##_size1 = \
iree_vmvx_cast_host_size(size1, &name##_overflow); \
iree_host_size_t name##_stride0 = \
iree_vmvx_cast_host_size(stride0, &name##_overflow); \
iree_host_size_t name##_stride1 = \
iree_vmvx_cast_host_size(stride1, &name##_overflow); \
iree_host_size_t name##_length_bound = iree_vmvx_2d_length_bound( \
dtype_size, name##_size0, name##_size1, name##_stride0, name##_stride1, \
&name##_overflow); \
iree_host_size_t name##_offset = \
dtype_size * iree_vmvx_cast_host_size(offset, &name##_overflow); \
if (name##_overflow) { \
IREE_TRACE_ZONE_END(z0); \
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, \
"buffer overflow for " #name); \
}
#define MAP_BUFFER_2D_IMPL(mode, ptr_type, span_type, name, dtype_size, \
buffer_ref, offset, stride0, stride1, size0, size1) \
iree_vm_buffer_t* name##_buffer; \
span_type name##_span; \
BUFFER_2D_DECLS(name, dtype_size, offset, stride0, stride1, size0, size1); \
IREE_RETURN_AND_END_ZONE_IF_ERROR( \
z0, iree_vm_buffer_check_deref(buffer_ref, &name##_buffer)) \
IREE_RETURN_AND_END_ZONE_IF_ERROR( \
z0, iree_vm_buffer_map_##mode(name##_buffer, /*offset=*/ \
name##_offset, /*length=*/ \
name##_length_bound, /*alignment=*/ \
dtype_size, &name##_span)); \
ptr_type name = (ptr_type)name##_span.data
// Read-only mapping with an explicit element size (void* data pointer).
#define MAP_BUFFER_2D_UNTYPED_RO(name, dtype_size, ...)            \
  MAP_BUFFER_2D_IMPL(ro, const void*, iree_const_byte_span_t, name, \
                     dtype_size, __VA_ARGS__)
// Read-write mapping with an explicit element size (void* data pointer).
#define MAP_BUFFER_2D_UNTYPED_RW(name, dtype_size, ...) \
  MAP_BUFFER_2D_IMPL(rw, void*, iree_byte_span_t, name, dtype_size, __VA_ARGS__)
// Read-only mapping typed as |dtype| (element size derived via sizeof).
#define MAP_BUFFER_2D_RO(name, dtype, ...)                          \
  MAP_BUFFER_2D_IMPL(ro, const dtype*, iree_const_byte_span_t, name, \
                     sizeof(dtype), __VA_ARGS__)
// Read-write mapping typed as |dtype| (element size derived via sizeof).
#define MAP_BUFFER_2D_RW(name, dtype, ...)                             \
  MAP_BUFFER_2D_IMPL(rw, dtype*, iree_byte_span_t, name, sizeof(dtype), \
                     __VA_ARGS__)
//===----------------------------------------------------------------------===//
// Shared argument shims
//===----------------------------------------------------------------------===//
// Declares an export function bound to iree_vmvx_module_state_t as its module
// state type via the shared VM ABI helper.
#define IREE_VMVX_ABI_EXPORT(function_name, arg_types, ret_types)       \
  IREE_VM_ABI_EXPORT(function_name, iree_vmvx_module_state_t, arg_types, \
                     ret_types)
// Declares the packed argument struct for an export. |types| documents the VM
// calling convention string; only |body| feeds the underlying VM macro.
#define IREE_VMVX_ABI_FIXED_STRUCT(name, types, body) \
  IREE_VM_ABI_FIXED_STRUCT(name, body)
// Defines a file-local marshaling shim for the arg/ret type pair.
#define IREE_VMVX_ABI_DEFINE_SHIM(arg_types, ret_types) \
  static IREE_VM_ABI_DEFINE_SHIM(arg_types, ret_types)
// Arguments for strided elementwise 2D ops with a single input (convention
// "rIIIrIIIII"): buffer ref, element offset, and element strides for the
// input and output views, followed by the shared size0/size1 extents.
IREE_VMVX_ABI_FIXED_STRUCT(unary2d, rIIIrIIIII, {
  iree_vm_ref_t in_ref;
  int64_t in_offset;
  int64_t in_stride0;
  int64_t in_stride1;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t out_stride1;
  int64_t size0;
  int64_t size1;
});
// Shim for exports taking unary2d args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(unary2d, v);
// Arguments for strided elementwise 2D ops with two inputs (convention
// "rIIIrIIIrIIIII"): lhs/rhs/out views plus the shared size0/size1 extents.
IREE_VMVX_ABI_FIXED_STRUCT(binary2d, rIIIrIIIrIIIII, {
  iree_vm_ref_t lhs_ref;
  int64_t lhs_offset;
  int64_t lhs_stride0;
  int64_t lhs_stride1;
  iree_vm_ref_t rhs_ref;
  int64_t rhs_offset;
  int64_t rhs_stride0;
  int64_t rhs_stride1;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t out_stride1;
  int64_t size0;
  int64_t size1;
});
// Shim for exports taking binary2d args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(binary2d, v);
//===----------------------------------------------------------------------===//
// Ukernel shims. These shims are a bit different in that they directly marshal
// to a low level ukernel target function.
//===----------------------------------------------------------------------===//
// Arguments for binary (two-input) 32-bit elementwise ukernel exports
// (convention "rIIIrIIIrIIIII"); layout mirrors binary2d.
IREE_VMVX_ABI_FIXED_STRUCT(ukernel_x32b_2d, rIIIrIIIrIIIII, {
  iree_vm_ref_t lhs_ref;
  int64_t lhs_offset;
  int64_t lhs_stride0;
  int64_t lhs_stride1;
  iree_vm_ref_t rhs_ref;
  int64_t rhs_offset;
  int64_t rhs_stride0;
  int64_t rhs_stride1;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t out_stride1;
  int64_t size0;
  int64_t size1;
});
// Custom shim that marshals VM arguments directly to a two-input elementwise
// 32-bit ukernel (iree_uk_x32b_2d_func_t): maps all three buffers, then
// forwards the mapped base pointers plus the byte offsets and element strides
// produced by the mapping macros.
static iree_status_t iree_vm_shim_ukernel_x32b_2d_v(
    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
    void* IREE_RESTRICT module_state) {
  // TODO: Figure out how to identify this with the actual target fn.
  IREE_TRACE_ZONE_BEGIN(z0);
  const iree_vm_abi_ukernel_x32b_2d_t* args =
      iree_vm_abi_ukernel_x32b_2d_checked_deref(args_storage);
  if (IREE_UNLIKELY(!((flags & IREE_VM_NATIVE_FUNCTION_CALL_RESUME) || args))) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argument/result signature mismatch");
  }
  MAP_BUFFER_2D_RO(lhs, uint32_t,
                   /*buffer_ref=*/args->lhs_ref,
                   /*offset=*/args->lhs_offset,
                   /*stride0=*/args->lhs_stride0,
                   /*stride1=*/args->lhs_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RO(rhs, uint32_t,
                   /*buffer_ref=*/args->rhs_ref,
                   /*offset=*/args->rhs_offset,
                   /*stride0=*/args->rhs_stride0,
                   /*stride1=*/args->rhs_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RW(out, uint32_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_stride0,
                   /*stride1=*/args->out_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  iree_uk_x32b_2d_func_t ukernel_func = (iree_uk_x32b_2d_func_t)target_fn;
  int ret = ukernel_func(
      // LHS
      lhs, lhs_offset, lhs_stride0, lhs_stride1,
      // RHS
      rhs, rhs_offset, rhs_stride0, rhs_stride1,
      // OUT
      out, out_offset, out_stride0, out_stride1,
      // SIZE
      out_size0, out_size1);
  IREE_TRACE_ZONE_END(z0);
  // Any non-zero ukernel return code is surfaced as an argument error.
  return ret == 0
             ? iree_ok_status()
             : iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "illegal x32b ukernel return code (%d)", ret);
}
// Arguments for unary (single-input) 32-bit elementwise ukernel exports
// (convention "rIIIrIIIII"); layout mirrors unary2d.
IREE_VMVX_ABI_FIXED_STRUCT(ukernel_x32u_2d, rIIIrIIIII, {
  iree_vm_ref_t in_ref;
  int64_t in_offset;
  int64_t in_stride0;
  int64_t in_stride1;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t out_stride1;
  int64_t size0;
  int64_t size1;
});
// Custom shim that marshals VM arguments directly to a single-input
// elementwise 32-bit ukernel (iree_uk_x32u_2d_func_t); see the x32b shim
// above for the marshaling scheme.
static iree_status_t iree_vm_shim_ukernel_x32u_2d_v(
    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
    void* IREE_RESTRICT module_state) {
  // TODO: Figure out how to identify this with the actual target fn.
  IREE_TRACE_ZONE_BEGIN(z0);
  const iree_vm_abi_ukernel_x32u_2d_t* args =
      iree_vm_abi_ukernel_x32u_2d_checked_deref(args_storage);
  if (IREE_UNLIKELY(!((flags & IREE_VM_NATIVE_FUNCTION_CALL_RESUME) || args))) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "argument/result signature mismatch");
  }
  MAP_BUFFER_2D_RO(in, uint32_t,
                   /*buffer_ref=*/args->in_ref,
                   /*offset=*/args->in_offset,
                   /*stride0=*/args->in_stride0,
                   /*stride1=*/args->in_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RW(out, uint32_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_stride0,
                   /*stride1=*/args->out_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  iree_uk_x32u_2d_func_t ukernel_func = (iree_uk_x32u_2d_func_t)target_fn;
  int ret = ukernel_func(
      // IN
      in, in_offset, in_stride0, in_stride1,
      // OUT
      out, out_offset, out_stride0, out_stride1,
      // SIZE
      out_size0, out_size1);
  IREE_TRACE_ZONE_END(z0);
  // Any non-zero ukernel return code is surfaced as an argument error.
  return ret == 0
             ? iree_ok_status()
             : iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                                "illegal x32u ukernel return code (%d)", ret);
}
//===----------------------------------------------------------------------===//
// Exported copy function definitions
//===----------------------------------------------------------------------===//
// Copies a strided 2D region of 8-bit elements from |in| to |out|.
IREE_VMVX_ABI_EXPORT(iree_vmvx_copy2d_x8, unary2d, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RO(in, int8_t,
                   /*buffer_ref=*/args->in_ref,
                   /*offset=*/args->in_offset,
                   /*stride0=*/args->in_stride0,
                   /*stride1=*/args->in_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RW(out, int8_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_stride0,
                   /*stride1=*/args->out_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  // Walk row-major over the output extents, applying both strides.
  for (iree_host_size_t row = 0; row < out_size0; ++row) {
    const int8_t* in_row = in + row * in_stride0;
    int8_t* out_row = out + row * out_stride0;
    for (iree_host_size_t col = 0; col < out_size1; ++col) {
      out_row[col * out_stride1] = in_row[col * in_stride1];
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Copies a strided 2D region of 16-bit elements from |in| to |out|.
IREE_VMVX_ABI_EXPORT(iree_vmvx_copy2d_x16, unary2d, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RO(in, int16_t,
                   /*buffer_ref=*/args->in_ref,
                   /*offset=*/args->in_offset,
                   /*stride0=*/args->in_stride0,
                   /*stride1=*/args->in_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RW(out, int16_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_stride0,
                   /*stride1=*/args->out_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  // Walk row-major over the output extents, applying both strides.
  for (iree_host_size_t row = 0; row < out_size0; ++row) {
    const int16_t* in_row = in + row * in_stride0;
    int16_t* out_row = out + row * out_stride0;
    for (iree_host_size_t col = 0; col < out_size1; ++col) {
      out_row[col * out_stride1] = in_row[col * in_stride1];
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Copies a strided 2D region of 32-bit elements from |in| to |out|.
IREE_VMVX_ABI_EXPORT(iree_vmvx_copy2d_x32, unary2d, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RO(in, int32_t,
                   /*buffer_ref=*/args->in_ref,
                   /*offset=*/args->in_offset,
                   /*stride0=*/args->in_stride0,
                   /*stride1=*/args->in_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RW(out, int32_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_stride0,
                   /*stride1=*/args->out_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  // Walk row-major over the output extents, applying both strides.
  for (iree_host_size_t row = 0; row < out_size0; ++row) {
    const int32_t* in_row = in + row * in_stride0;
    int32_t* out_row = out + row * out_stride0;
    for (iree_host_size_t col = 0; col < out_size1; ++col) {
      out_row[col * out_stride1] = in_row[col * in_stride1];
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Copies a strided 2D region of 64-bit elements from |in| to |out|.
IREE_VMVX_ABI_EXPORT(iree_vmvx_copy2d_x64, unary2d, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RO(in, int64_t,
                   /*buffer_ref=*/args->in_ref,
                   /*offset=*/args->in_offset,
                   /*stride0=*/args->in_stride0,
                   /*stride1=*/args->in_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  MAP_BUFFER_2D_RW(out, int64_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_stride0,
                   /*stride1=*/args->out_stride1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  // Walk row-major over the output extents, applying both strides.
  for (iree_host_size_t row = 0; row < out_size0; ++row) {
    const int64_t* in_row = in + row * in_stride0;
    int64_t* out_row = out + row * out_stride0;
    for (iree_host_size_t col = 0; col < out_size1; ++col) {
      out_row[col * out_stride1] = in_row[col * in_stride1];
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// Exported fill function definitions
//===----------------------------------------------------------------------===//
// Arguments for 32-bit 2D fill (convention "irIIII"): fill value, output view
// (ref + element offset + row stride; inner dimension contiguous), extents.
IREE_VMVX_ABI_FIXED_STRUCT(fill2d_x32, irIIII, {
  int32_t fill_value;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_row_stride;
  int64_t size0;
  int64_t size1;
});
// Shim for exports taking fill2d_x32 args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(fill2d_x32, v);
// Fills a 2D region of 32-bit elements with args->fill_value. The inner
// dimension is contiguous (stride1 == 1).
IREE_VMVX_ABI_EXPORT(iree_vmvx_fill2d_x32, fill2d_x32, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RW(out, int32_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->size0,
                   /*size1=*/args->size1);
  for (iree_host_size_t row = 0; row < out_size0; ++row) {
    int32_t* out_row = out + row * out_stride0;
    for (iree_host_size_t col = 0; col < out_size1; ++col) {
      out_row[col] = args->fill_value;
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// Exported matmul function definitions
//===----------------------------------------------------------------------===//
// Arguments for row-major matmul exports (convention "rIIrIIrIIIIIi"):
// lhs/rhs/out views (ref + element offset + row stride; inner dimension
// contiguous), the m/n/k problem sizes, and IREE_UK_FLAG_* bits.
IREE_VMVX_ABI_FIXED_STRUCT(matmul, rIIrIIrIIIIIi, {
  iree_vm_ref_t lhs_ref;
  int64_t lhs_offset;
  int64_t lhs_row_stride;
  iree_vm_ref_t rhs_ref;
  int64_t rhs_offset;
  int64_t rhs_row_stride;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_row_stride;
  int64_t m;
  int64_t n;
  int64_t k;
  int32_t flags;
});
// Shim for exports taking matmul args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(matmul, v);
// Reference row-major f32 matmul: out[i,j] (+)= sum_k lhs[i,k] * rhs[k,j].
// Only IREE_UK_FLAG_ACCUMULATE is honored; any other flag bit is rejected.
IREE_VMVX_ABI_EXPORT(iree_vmvx_matmul_f32f32f32, matmul, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RO(lhs, float,
                   /*buffer_ref=*/args->lhs_ref,
                   /*offset=*/args->lhs_offset,
                   /*stride0=*/args->lhs_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->m,
                   /*size1=*/args->k);
  MAP_BUFFER_2D_RO(rhs, float,
                   /*buffer_ref=*/args->rhs_ref,
                   /*offset=*/args->rhs_offset,
                   /*stride0=*/args->rhs_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->k,
                   /*size1=*/args->n);
  MAP_BUFFER_2D_RW(out, float,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->m,
                   /*size1=*/args->n);
  iree_host_size_t M = (iree_host_size_t)args->m;
  iree_host_size_t N = (iree_host_size_t)args->n;
  iree_host_size_t K = (iree_host_size_t)args->k;
  // TODO: define flags more robustly
  unsigned accumulate_flag = args->flags & IREE_UK_FLAG_ACCUMULATE;
  unsigned unhandled_flags = args->flags ^ accumulate_flag;
  if (unhandled_flags) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unsupported matmul flags: 0x%x", unhandled_flags);
  }
  for (iree_host_size_t i = 0; i < M; ++i) {
    const float* lhs_row = lhs + i * lhs_stride0;
    float* out_row = out + i * out_stride0;
    for (iree_host_size_t j = 0; j < N; ++j) {
      // Keep the k-ascending accumulation order so float results match the
      // previous implementation exactly.
      float acc = accumulate_flag ? out_row[j] : 0.f;
      for (iree_host_size_t k = 0; k < K; ++k) {
        acc += lhs_row[k] * rhs[k * rhs_stride0 + j];
      }
      out_row[j] = acc;
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Reference row-major i8 x i8 -> i32 matmul:
//   out[i,j] (+)= sum_k lhs[i,k] * rhs[k,j]
// Only IREE_UK_FLAG_ACCUMULATE is honored; any other flag bit is rejected.
IREE_VMVX_ABI_EXPORT(iree_vmvx_matmul_i8i8i32, matmul, v) {
  IREE_TRACE_ZONE_BEGIN(z0);
  MAP_BUFFER_2D_RO(lhs, int8_t,
                   /*buffer_ref=*/args->lhs_ref,
                   /*offset=*/args->lhs_offset,
                   /*stride0=*/args->lhs_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->m,
                   /*size1=*/args->k);
  MAP_BUFFER_2D_RO(rhs, int8_t,
                   /*buffer_ref=*/args->rhs_ref,
                   /*offset=*/args->rhs_offset,
                   /*stride0=*/args->rhs_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->k,
                   /*size1=*/args->n);
  MAP_BUFFER_2D_RW(out, int32_t,
                   /*buffer_ref=*/args->out_ref,
                   /*offset=*/args->out_offset,
                   /*stride0=*/args->out_row_stride,
                   /*stride1=*/1,
                   /*size0=*/args->m,
                   /*size1=*/args->n);
  iree_host_size_t M = (iree_host_size_t)args->m;
  iree_host_size_t N = (iree_host_size_t)args->n;
  iree_host_size_t K = (iree_host_size_t)args->k;
  // TODO: define flags more robustly
  unsigned accumulate_flag = args->flags & IREE_UK_FLAG_ACCUMULATE;
  unsigned unhandled_flags = args->flags ^ accumulate_flag;
  if (unhandled_flags) {
    IREE_TRACE_ZONE_END(z0);
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unsupported matmul flags: 0x%x", unhandled_flags);
  }
  for (iree_host_size_t i = 0; i < M; ++i) {
    for (iree_host_size_t j = 0; j < N; ++j) {
      int32_t* out_ptr = out + i * out_stride0 + j;
      // BUGFIX: previously initialized with the float literal `0.f`, which
      // made the conditional expression float-typed and round-tripped
      // *out_ptr through float, silently losing precision for accumulator
      // magnitudes >= 2^24 when accumulating. Use an integer zero so the
      // accumulator stays exact.
      int32_t acc = accumulate_flag ? *out_ptr : 0;
      for (iree_host_size_t k = 0; k < K; ++k) {
        // C's implicit promotion to int saves skin, but let's be explicit.
        int32_t lhs_val_int32 = lhs[i * lhs_stride0 + k];
        int32_t rhs_val_int32 = rhs[k * rhs_stride0 + j];
        acc += lhs_val_int32 * rhs_val_int32;
      }
      *out_ptr = acc;
    }
  }
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// Exported mmt4d function definitions
//===----------------------------------------------------------------------===//
// Arguments for tiled mmt4d exports (convention "rIIrIIrIIIIIiiii"):
// lhs/rhs/out views (ref + element offset + outer row stride), outer sizes
// m/n/k (in tiles), inner tile sizes m0/n0/k0 (in elements), and flags.
IREE_VMVX_ABI_FIXED_STRUCT(mmt4d, rIIrIIrIIIIIiiii, {
  iree_vm_ref_t lhs_ref;
  int64_t lhs_offset;
  int64_t lhs_row_stride;
  iree_vm_ref_t rhs_ref;
  int64_t rhs_offset;
  int64_t rhs_row_stride;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_row_stride;
  int64_t m;
  int64_t n;
  int64_t k;
  int32_t m0;
  int32_t n0;
  int32_t k0;
  uint32_t flags;
});
// Shim for exports taking mmt4d args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(mmt4d, v);
// Shared implementation for the mmt4d exports: validates and maps the tiled
// 4D buffers then invokes the iree_uk_mmt4d ukernel with the current CPU's
// feature data.
static iree_status_t iree_vmvx_mmt4d(iree_uk_mmt4d_type_t type,
                                     const iree_vm_abi_mmt4d_t* args) {
  IREE_TRACE_ZONE_BEGIN(z0);
  // Outer problem sizes (in tiles) and inner tile sizes (in elements).
  iree_host_size_t M = (iree_host_size_t)args->m;
  iree_host_size_t N = (iree_host_size_t)args->n;
  iree_host_size_t K = (iree_host_size_t)args->k;
  iree_host_size_t M0 = (iree_host_size_t)args->m0;
  iree_host_size_t N0 = (iree_host_size_t)args->n0;
  iree_host_size_t K0 = (iree_host_size_t)args->k0;
  // Per-tile element counts for each operand.
  iree_host_size_t lhs_tile_size = M0 * K0;
  iree_host_size_t rhs_tile_size = N0 * K0;
  iree_host_size_t out_tile_size = M0 * N0;
  int lhs_elem_size = iree_uk_type_size(iree_uk_mmt4d_lhs_type(type));
  int rhs_elem_size = iree_uk_type_size(iree_uk_mmt4d_rhs_type(type));
  int out_elem_size = iree_uk_type_size(iree_uk_mmt4d_out_type(type));
  // Here we are abusing the 2D-specific macros MAP_BUFFER_2D_* to query 4D
  // arrays. Thanks to the requirement that all dimensions but the outer-most
  // one are contiguous row-major, the outer-most stride is the only nontrivial
  // stride, we can correctly coalesce the inner 3 dimensions without changing
  // the mapped span.
  MAP_BUFFER_2D_UNTYPED_RO(lhs,
                           /*dtype_size=*/lhs_elem_size,
                           /*buffer_ref=*/args->lhs_ref,
                           /*offset=*/args->lhs_offset,
                           /*stride0=*/args->lhs_row_stride,
                           /*stride1=*/1,
                           /*size0=*/M,
                           /*size1=*/K * lhs_tile_size);
  MAP_BUFFER_2D_UNTYPED_RO(rhs, /*dtype_size=*/rhs_elem_size,
                           /*buffer_ref=*/args->rhs_ref,
                           /*offset=*/args->rhs_offset,
                           /*stride0=*/args->rhs_row_stride,
                           /*stride1=*/1,
                           /*size0=*/N,
                           /*size1=*/K * rhs_tile_size);
  MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/out_elem_size,
                           /*buffer_ref=*/args->out_ref,
                           /*offset=*/args->out_offset,
                           /*stride0=*/args->out_row_stride,
                           /*stride1=*/1,
                           /*size0=*/M,
                           /*size1=*/N * out_tile_size);
  iree_uk_mmt4d_params_t ukernel_params = {
      .type = type,
      .flags = args->flags,
      .lhs_buffer = lhs,
      .rhs_buffer = rhs,
      .out_buffer = out,
      .lhs_stride = lhs_stride0,
      .rhs_stride = rhs_stride0,
      .out_stride = out_stride0,
      .M = M,
      .N = N,
      .K = K,
      .M0 = M0,
      .N0 = N0,
      .K0 = K0,
      .cpu_data = (const iree_uk_uint64_t*)iree_cpu_data_fields(),
  };
  iree_uk_mmt4d(&ukernel_params);
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Typed entry points dispatching to the shared mmt4d implementation.
IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_f32f32f32, mmt4d, v) {
  return iree_vmvx_mmt4d(iree_uk_mmt4d_type_f32f32f32, args);
}
IREE_VMVX_ABI_EXPORT(iree_vmvx_mmt4d_i8i8i32, mmt4d, v) {
  return iree_vmvx_mmt4d(iree_uk_mmt4d_type_i8i8i32, args);
}
//===----------------------------------------------------------------------===//
// Exported pack function definitions
//===----------------------------------------------------------------------===//
// Arguments for pack exports with a float padding value (convention
// "rIIIrIIIIIIIIfi"): 2D input view, 4D output sizes (outer stride only),
// padding value, and flags. pack_f/pack_i differ only in padding value type.
IREE_VMVX_ABI_FIXED_STRUCT(pack_f, rIIIrIIIIIIIIfi, {
  iree_vm_ref_t in_ref;
  int64_t in_offset;
  int64_t in_stride0;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t in_size0;
  int64_t in_size1;
  int64_t out_size0;
  int64_t out_size1;
  int64_t out_size2;
  int64_t out_size3;
  float padding_value;
  uint32_t flags;
});
// Shim for exports taking pack_f args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(pack_f, v);
// Arguments for pack exports with an integer padding value (convention
// "rIIIrIIIIIIIIii"); layout otherwise mirrors pack_f.
IREE_VMVX_ABI_FIXED_STRUCT(pack_i, rIIIrIIIIIIIIii, {
  iree_vm_ref_t in_ref;
  int64_t in_offset;
  int64_t in_stride0;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t in_size0;
  int64_t in_size1;
  int64_t out_size0;
  int64_t out_size1;
  int64_t out_size2;
  int64_t out_size3;
  int32_t padding_value;
  uint32_t flags;
});
// Shim for exports taking pack_i args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(pack_i, v);
// Shared implementation for float-padded pack exports: maps the input 2D view
// and the 4D output (inner dimensions coalesced into the mapped size1) and
// invokes the iree_uk_pack ukernel.
static iree_status_t iree_vmvx_pack_f(iree_uk_pack_type_t type,
                                      const iree_vm_abi_pack_f_t* args) {
  IREE_TRACE_ZONE_BEGIN(z0);
  // Element count of one output tile (inner two dimensions).
  iree_host_size_t out_tile_size = args->out_size2 * args->out_size3;
  int in_elem_size = iree_uk_type_size(iree_uk_pack_in_type(type));
  int out_elem_size = iree_uk_type_size(iree_uk_pack_out_type(type));
  MAP_BUFFER_2D_UNTYPED_RO(in,
                           /*dtype_size=*/in_elem_size,
                           /*buffer_ref=*/args->in_ref,
                           /*offset=*/args->in_offset,
                           /*stride0=*/args->in_stride0,
                           /*stride1=*/1,
                           /*size0=*/args->in_size0,
                           /*size1=*/args->in_size1);
  MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/out_elem_size,
                           /*buffer_ref=*/args->out_ref,
                           /*offset=*/args->out_offset,
                           /*stride0=*/args->out_stride0,
                           /*stride1=*/1,
                           /*size0=*/args->out_size0,
                           /*size1=*/args->out_size1 * out_tile_size);
  iree_uk_pack_params_t ukernel_params = {
      .type = type,
      .in_buffer = in,
      .out_buffer = out,
      .in_stride0 = args->in_stride0,
      .out_stride0 = args->out_stride0,
      .in_size0 = args->in_size0,
      .in_size1 = args->in_size1,
      .out_size0 = args->out_size0,
      .out_size1 = args->out_size1,
      .out_size2 = args->out_size2,
      .out_size3 = args->out_size3,
      .padding_value = &args->padding_value,
      .flags = args->flags,
  };
  iree_uk_pack(&ukernel_params);
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Shared implementation for integer-padded pack exports; mirrors
// iree_vmvx_pack_f but reads the i32 padding value from pack_i args.
static iree_status_t iree_vmvx_pack_i(iree_uk_pack_type_t type,
                                      const iree_vm_abi_pack_i_t* args) {
  IREE_TRACE_ZONE_BEGIN(z0);
  // Element count of one output tile (inner two dimensions).
  iree_host_size_t out_tile_size = args->out_size2 * args->out_size3;
  int in_elem_size = iree_uk_type_size(iree_uk_pack_in_type(type));
  int out_elem_size = iree_uk_type_size(iree_uk_pack_out_type(type));
  MAP_BUFFER_2D_UNTYPED_RO(in,
                           /*dtype_size=*/in_elem_size,
                           /*buffer_ref=*/args->in_ref,
                           /*offset=*/args->in_offset,
                           /*stride0=*/args->in_stride0,
                           /*stride1=*/1,
                           /*size0=*/args->in_size0,
                           /*size1=*/args->in_size1);
  MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/out_elem_size,
                           /*buffer_ref=*/args->out_ref,
                           /*offset=*/args->out_offset,
                           /*stride0=*/args->out_stride0,
                           /*stride1=*/1,
                           /*size0=*/args->out_size0,
                           /*size1=*/args->out_size1 * out_tile_size);
  iree_uk_pack_params_t ukernel_params = {
      .type = type,
      .in_buffer = in,
      .out_buffer = out,
      .in_stride0 = args->in_stride0,
      .out_stride0 = args->out_stride0,
      .in_size0 = args->in_size0,
      .in_size1 = args->in_size1,
      .out_size0 = args->out_size0,
      .out_size1 = args->out_size1,
      .out_size2 = args->out_size2,
      .out_size3 = args->out_size3,
      .padding_value = &args->padding_value,
      .flags = args->flags,
  };
  iree_uk_pack(&ukernel_params);
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Typed pack entry points dispatching to the shared implementations.
IREE_VMVX_ABI_EXPORT(iree_vmvx_pack_f32f32, pack_f, v) {
  return iree_vmvx_pack_f(iree_uk_pack_type_f32f32, args);
}
IREE_VMVX_ABI_EXPORT(iree_vmvx_pack_i8i8, pack_i, v) {
  return iree_vmvx_pack_i(iree_uk_pack_type_i8i8, args);
}
IREE_VMVX_ABI_EXPORT(iree_vmvx_pack_i32i32, pack_i, v) {
  return iree_vmvx_pack_i(iree_uk_pack_type_i32i32, args);
}
//===----------------------------------------------------------------------===//
// Exported unpack function definitions
//===----------------------------------------------------------------------===//
// Arguments for unpack exports (convention "rIIIrIIIIIIIIi"): 4D tiled input
// sizes (outer stride only), 2D output view, and flags.
IREE_VMVX_ABI_FIXED_STRUCT(unpack, rIIIrIIIIIIIIi, {
  iree_vm_ref_t in_ref;
  int64_t in_offset;
  int64_t in_stride0;
  iree_vm_ref_t out_ref;
  int64_t out_offset;
  int64_t out_stride0;
  int64_t in_size0;
  int64_t in_size1;
  int64_t in_size2;
  int64_t in_size3;
  int64_t out_size0;
  int64_t out_size1;
  uint32_t flags;
});
// Shim for exports taking unpack args and returning nothing.
IREE_VMVX_ABI_DEFINE_SHIM(unpack, v);
// Shared implementation for unpack exports: maps the 4D tiled input (inner
// dimensions coalesced into the mapped size1) and the 2D output view, then
// invokes the iree_uk_unpack ukernel.
static iree_status_t iree_vmvx_unpack(iree_uk_unpack_type_t type,
                                      const iree_vm_abi_unpack_t* args) {
  IREE_TRACE_ZONE_BEGIN(z0);
  // Element count of one tile of the packed *input* (inner two dimensions).
  // Renamed from the misleading "out_tile_size": it is computed from
  // in_size2/in_size3 and only sizes the input mapping below.
  iree_host_size_t in_tile_size = args->in_size2 * args->in_size3;
  int in_elem_size = iree_uk_type_size(iree_uk_unpack_in_type(type));
  int out_elem_size = iree_uk_type_size(iree_uk_unpack_out_type(type));
  MAP_BUFFER_2D_UNTYPED_RO(in,
                           /*dtype_size=*/in_elem_size,
                           /*buffer_ref=*/args->in_ref,
                           /*offset=*/args->in_offset,
                           /*stride0=*/args->in_stride0,
                           /*stride1=*/1,
                           /*size0=*/args->in_size0,
                           /*size1=*/args->in_size1 * in_tile_size);
  MAP_BUFFER_2D_UNTYPED_RW(out, /*dtype_size=*/out_elem_size,
                           /*buffer_ref=*/args->out_ref,
                           /*offset=*/args->out_offset,
                           /*stride0=*/args->out_stride0,
                           /*stride1=*/1,
                           /*size0=*/args->out_size0,
                           /*size1=*/args->out_size1);
  iree_uk_unpack_params_t ukernel_params = {
      .type = type,
      .in_buffer = in,
      .out_buffer = out,
      .in_stride0 = args->in_stride0,
      .out_stride0 = args->out_stride0,
      .in_size0 = args->in_size0,
      .in_size1 = args->in_size1,
      .in_size2 = args->in_size2,
      .in_size3 = args->in_size3,
      .out_size0 = args->out_size0,
      .out_size1 = args->out_size1,
      .flags = args->flags,
  };
  iree_uk_unpack(&ukernel_params);
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
// Typed unpack entry points dispatching to the shared implementation.
IREE_VMVX_ABI_EXPORT(iree_vmvx_unpack_f32f32, unpack, v) {
  return iree_vmvx_unpack(iree_uk_unpack_type_f32f32, args);
}
IREE_VMVX_ABI_EXPORT(iree_vmvx_unpack_i8i8, unpack, v) {
  return iree_vmvx_unpack(iree_uk_unpack_type_i8i8, args);
}
IREE_VMVX_ABI_EXPORT(iree_vmvx_unpack_i32i32, unpack, v) {
  return iree_vmvx_unpack(iree_uk_unpack_type_i32i32, args);
}
//===----------------------------------------------------------------------===//
// Exported query_tile_sizes function definitions
//===----------------------------------------------------------------------===//
// Arguments for query_tile_sizes.2d (convention "IIi"): logical 2D sizes and
// flags selecting the operand/role being tiled.
IREE_VMVX_ABI_FIXED_STRUCT(query_tile_sizes_2d, IIi, {
  int64_t size0;
  int64_t size1;
  uint32_t flags;
});
// Shim for exports taking query_tile_sizes_2d args and returning two i64s.
IREE_VMVX_ABI_DEFINE_SHIM(query_tile_sizes_2d, II);
// Returns the tile sizes chosen by iree_uk_query_tile_sizes_2d for the given
// sizes/flags on the current CPU (two i64 results).
IREE_VMVX_ABI_EXPORT(iree_vmvx_query_tile_sizes_2d, query_tile_sizes_2d, II) {
  IREE_TRACE_ZONE_BEGIN(z0);
  iree_uk_query_tile_sizes_2d_params_t ukernel_params = {
      .size0 = args->size0,
      .size1 = args->size1,
      .flags = args->flags,
      .cpu_data = (const iree_uk_uint64_t*)iree_cpu_data_fields(),
  };
  iree_uk_query_tile_sizes_2d_out_params_t ukernel_out_params;
  iree_uk_query_tile_sizes_2d(&ukernel_params, &ukernel_out_params);
  rets->i0 = ukernel_out_params.tile_size0;
  rets->i1 = ukernel_out_params.tile_size1;
  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// VM module interface implementation
//===----------------------------------------------------------------------===//
// NOTE: this must match the ordering of the iree_vmvx_module_exports_table.
// Function pointer table expanded from exports.inl: each entry pairs a
// marshaling shim (selected by arg/ret type strings) with its target.
// NOTE: this must match the ordering of the iree_vmvx_module_exports_table.
static const iree_vm_native_function_ptr_t iree_vmvx_module_funcs_[] = {
#define EXPORT_FN(name, target_fn, arg_struct, arg_types, ret_types) \
  {                                                                  \
      .shim = (iree_vm_native_function_shim_t)                       \
          iree_vm_shim_##arg_struct##_##ret_types,                   \
      .target = (iree_vm_native_function_target_t)(target_fn),       \
  },
#include "iree/modules/vmvx/exports.inl"  // IWYU pragma: keep
#undef EXPORT_FN
};
// NOTE: 0 length, but can't express that in C.
static const iree_vm_native_import_descriptor_t iree_vmvx_module_imports_[1];
// Export descriptor table expanded from exports.inl; the calling convention
// string is composed from each entry's arg/ret type strings.
static const iree_vm_native_export_descriptor_t iree_vmvx_module_exports_[] = {
#define EXPORT_FN(name, target_fn, arg_struct, arg_types, ret_types) \
  {                                                                  \
      .local_name = iree_string_view_literal(name),                  \
      .calling_convention =                                          \
          iree_string_view_literal("0" #arg_types "_" #ret_types),   \
      .attr_count = 0,                                               \
      .attrs = NULL,                                                 \
  },
#include "iree/modules/vmvx/exports.inl"  // IWYU pragma: keep
#undef EXPORT_FN
};
// Both tables expand the same exports.inl so they stay in sync by size.
static_assert(IREE_ARRAYSIZE(iree_vmvx_module_funcs_) ==
                  IREE_ARRAYSIZE(iree_vmvx_module_exports_),
              "function pointer table must be 1:1 with exports");
// Static module descriptor handed to the base native module at creation.
static const iree_vm_native_module_descriptor_t iree_vmvx_module_descriptor_ = {
    .name = iree_string_view_literal("vmvx"),
    .version = IREE_VMVX_MODULE_VERSION_LATEST,
    .attr_count = 0,
    .attrs = NULL,
    .dependency_count = 0,
    .dependencies = NULL,
    .import_count = 0,  // workaround for 0-length C struct
    .imports = iree_vmvx_module_imports_,
    .export_count = IREE_ARRAYSIZE(iree_vmvx_module_exports_),
    .exports = iree_vmvx_module_exports_,
    .function_count = IREE_ARRAYSIZE(iree_vmvx_module_funcs_),
    .functions = iree_vmvx_module_funcs_,
};
// Creates the VMVX module and returns it in |out_module|. The caller owns the
// returned module reference.
IREE_API_EXPORT iree_status_t iree_vmvx_module_create(
    iree_vm_instance_t* instance, iree_allocator_t host_allocator,
    iree_vm_module_t** out_module) {
  IREE_ASSERT_ARGUMENT(instance);
  IREE_ASSERT_ARGUMENT(out_module);
  *out_module = NULL;
  // Setup the interface with the functions we implement ourselves. Any
  // function we omit will be handled by the base native module.
  static const iree_vm_module_t interface = {
      .destroy = iree_vmvx_module_destroy,
      .alloc_state = iree_vmvx_module_alloc_state,
      .free_state = iree_vmvx_module_free_state,
  };
  // The base native module and our iree_vmvx_module_t share one allocation,
  // with our storage trailing the base storage.
  iree_host_size_t total_length =
      iree_vm_native_module_size() + sizeof(iree_vmvx_module_t);
  iree_vm_module_t* base = NULL;
  IREE_RETURN_IF_ERROR(
      iree_allocator_malloc(host_allocator, total_length, (void**)&base));
  memset(base, 0, total_length);
  iree_status_t init_status = iree_vm_native_module_initialize(
      &interface, &iree_vmvx_module_descriptor_, instance, host_allocator,
      base);
  if (!iree_status_is_ok(init_status)) {
    iree_allocator_free(host_allocator, base);
    return init_status;
  }
  iree_vmvx_module_t* vmvx_module = IREE_VMVX_MODULE_CAST(base);
  vmvx_module->host_allocator = host_allocator;
  *out_module = base;
  return iree_ok_status();
}
// Records the logical processor ID on the per-context state so subsequent
// invocations can use it when indexing processor info fields.
IREE_API_EXPORT void iree_vmvx_module_state_update_workgroup_state(
    iree_vm_module_state_t* module_state, uint32_t processor_id) {
  iree_vmvx_module_state_t* vmvx_state =
      (iree_vmvx_module_state_t*)module_state;
  vmvx_state->processor_id = processor_id;
}