// Source blob: 0746c4e0073d5ca512c985c834b92d5730a2cd18
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// NOTE: it is not safe to use the standard library functions in here.
// Attempting to allocate memory (outside of small stack allocations), make
// syscalls, use thread-local state, or pull in externally defined standard
// library functions will result in a bad time.
//
// It is safe to include definitions/macros/etc but be veeery careful!
// Embedded ELF shared libraries are intended to be portable across operating
// systems and environments (including bare-metal systems). The IREE compiler
// can ensure that it does not pull in anything that would break that
// portability, but when injecting code like this it is the user's
// responsibility to do the same.
#include <stddef.h>
#include <stdint.h>
// NOTE: kernels must be exported with C naming (no C++ mangling) in order to
// match the names used in the IR declarations.
// NOTE: IREE ensures all bindings don't alias their active subranges and
// it is safe to mark them as restrict. This is critical as the C compiler can't
// analyze the codegen when compiling and has to play it safe by assuming any
// write to any binding could be visible through other bindings.
// NOTE: memref lowering in MLIR -> LLVM currently expands to two pointers and
// three ints - do not rely on this behavior and only use the first pointer.
// At some point someone will fix upstream to allow for passing raw base
// pointers and the function signatures here will become much less verbose.
// NOTE: MLIR's index type will map to either i32 or i64 based on the target
// pointer width. size_t (or ssize_t) can be used in source to match that type.
// `ret = lhs * rhs`
//
// Conforms to ABI:
// #hal.pipeline.layout<push_constants = 1, sets = [
// <0, bindings = [
// <0, storage_buffer, ReadOnly>,
// <1, storage_buffer, ReadOnly>,
// <2, storage_buffer>
// ]>
// ]>
// With a workgroup size of 64x1x1.
void simple_mul_workgroup(
    // vvvv simplification pending (buffer + offset)
    const float* restrict binding0, const float* restrict binding0_aligned,
    size_t binding0_offset, size_t binding0_size, size_t binding0_stride,
    const float* restrict binding1, const float* restrict binding1_aligned,
    size_t binding1_offset, size_t binding1_size, size_t binding1_stride,
    float* restrict binding2, float* restrict binding2_aligned,
    size_t binding2_offset, size_t binding2_size, size_t binding2_stride,
    // ^^^^ simplification pending (buffer + offset)
    size_t dim, size_t tid) {
  // Only the first pointer of each binding is read or written; the
  // *_aligned/offset/size/stride parameters exist to match the expanded
  // memref signature and are intentionally unused.
  //
  // Handles the elements in [tid, tid + 64), clamped so that no index at or
  // beyond `dim` is touched. When `tid >= dim` nothing is written.
  const size_t tile_limit = tid + 64;
  const size_t stop = (tile_limit > dim) ? dim : tile_limit;

  size_t idx = tid;
  while (idx < stop) {
    const float lhs = binding0[idx];
    const float rhs = binding1[idx];
    binding2[idx] = lhs * rhs;
    ++idx;
  }
}
// `rhs *= lhs`
//
// Conforms to ABI:
// #hal.pipeline.layout<push_constants = 1, sets = [
// <0, bindings = [
// <0, storage_buffer, ReadOnly>,
// <1, storage_buffer>
// ]>
// ]>
// With a workgroup size of 64x1x1.
void simple_mul_inplace_workgroup(
    // vvvv simplification pending (buffer + offset)
    const float* restrict binding0, const float* restrict binding0_aligned,
    size_t binding0_offset, size_t binding0_size, size_t binding0_stride,
    float* restrict binding1, float* restrict binding1_aligned,
    size_t binding1_offset, size_t binding1_size, size_t binding1_stride,
    // ^^^^ simplification pending (buffer + offset)
    size_t dim, size_t tid) {
  // Only the first pointer of each binding is accessed; the
  // *_aligned/offset/size/stride parameters exist to match the expanded
  // memref signature and are intentionally unused.
  //
  // Scales binding1 in place by binding0 over [tid, tid + 64), clamped so
  // that no index at or beyond `dim` is touched.
  const size_t tile_limit = tid + 64;
  const size_t stop = (tile_limit > dim) ? dim : tile_limit;

  size_t idx = tid;
  while (idx < stop) {
    binding1[idx] = binding1[idx] * binding0[idx];
    ++idx;
  }
}