Tightening ukernel common definitions (#10950)
* Complete rewrite of the big comment at the start of ukernel/common.h.
Please read carefully as I might be saying things that are incorrect.
* Ukernel local types prefixed to disambiguate from standard types, e.g.
`int32_t` --> `iree_ukernel_int32_t`.
* This caught a place in `vmvx/module.c` that now requires an explicit
cast from `unsigned long*` to `unsigned long long*`. In practice both
have the same 64-bit representation, but they are distinct types, so
it's worth an explicit cast.
* Actually no longer including various platform headers in ukernels. In
ukernel/common.h, we were including iree/base/attributes.h, which was
including target_platform.h, which was including windows.h etc.
* Now ukernels that need a RESTRICT keyword can use
IREE_UKERNEL_RESTRICT, forked from IREE_RESTRICT.
* We can almost, but not quite, simply use `restrict`, since ukernels
are C code. The problem is that C++ code may need to include the
declarations, and `restrict` is not a C++ keyword. At least, the macro
makes it easy to experiment with the impact of restrict. And it leaves
the door open to C++ kernels if someone really wants that in the future.
diff --git a/runtime/src/iree/builtins/ukernel/BUILD b/runtime/src/iree/builtins/ukernel/BUILD
index b322d67..6ceb403 100644
--- a/runtime/src/iree/builtins/ukernel/BUILD
+++ b/runtime/src/iree/builtins/ukernel/BUILD
@@ -35,6 +35,7 @@
],
deps = [
":exported_flag_bits",
+ ":static_assert",
"//runtime/src/iree/base:core_headers",
"//runtime/src/iree/builtins/ukernel/arch:config",
],
diff --git a/runtime/src/iree/builtins/ukernel/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
index 49f437b..c28c6f5 100644
--- a/runtime/src/iree/builtins/ukernel/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
@@ -41,6 +41,7 @@
"common.c"
DEPS
::exported_flag_bits
+ ::static_assert
iree::base::core_headers
iree::builtins::ukernel::arch::config
PUBLIC
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S
index 2b033f3..2b5a6da 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64.S
@@ -2,11 +2,11 @@
#include "iree/builtins/ukernel/exported_flag_bits.h"
// Parameters:
-// x0: int32_t* out_tile
-// x1: const int8_t* lhs_panel
-// x2: const int8_t* rhs_panel
-// w3: int32_t K. Note: K>=1, as the K==0 case was handled as an early-return.
-// w4: uint32_t flags
+// x0: float* out_tile  (NOTE(review): types inferred from the f32f32f32 kernel name; confirm)
+// x1: const float* lhs_panel
+// x2: const float* rhs_panel
+// w3: iree_ukernel_int32_t K. Note: K>=1, as the K==0 case was handled as an early-return.
+// w4: iree_ukernel_uint32_t flags
// x5: (UNUSED) params - relevant params K and flags already passed above.
BEGIN_FUNCTION iree_ukernel_mmt4d_f32f32f32_tile_8x8x1_arm_64
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S
index d1395a6..ad67d5b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_dotprod.S
@@ -2,11 +2,11 @@
#include "iree/builtins/ukernel/exported_flag_bits.h"
// Parameters:
-// x0: int32_t* out_tile
-// x1: const int8_t* lhs_panel
-// x2: const int8_t* rhs_panel
-// w3: int32_t K. Note: K>=1, as the K==0 case was handled as an early-return.
-// w4: uint32_t flags
+// x0: iree_ukernel_int32_t* out_tile
+// x1: const iree_ukernel_int8_t* lhs_panel
+// x2: const iree_ukernel_int8_t* rhs_panel
+// w3: iree_ukernel_int32_t K. Note: K>=1, as the K==0 case was handled as an early-return.
+// w4: iree_ukernel_uint32_t flags
// x5: (UNUSED) params - relevant params K and flags already passed above.
BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x4_arm_64_dotprod
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S
index 427770e..c412f81 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_tile_arm_64_i8mm.S
@@ -2,11 +2,11 @@
#include "iree/builtins/ukernel/exported_flag_bits.h"
// Parameters:
-// x0: int32_t* out_tile
-// x1: const int8_t* lhs_panel
-// x2: const int8_t* rhs_panel
-// w3: int32_t K. Note: K>=1, as the K==0 case was handled as an early-return.
-// w4: uint32_t flags
+// x0: iree_ukernel_int32_t* out_tile
+// x1: const iree_ukernel_int8_t* lhs_panel
+// x2: const iree_ukernel_int8_t* rhs_panel
+// w3: iree_ukernel_int32_t K. Note: K>=1, as the K==0 case was handled as an early-return.
+// w4: iree_ukernel_uint32_t flags
// x5: (UNUSED) params - relevant params K and flags already passed above.
BEGIN_FUNCTION iree_ukernel_mmt4d_i8i8i32_tile_8x8x8_arm_64_i8mm
diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h
index cabea96..56bf69d 100644
--- a/runtime/src/iree/builtins/ukernel/common.h
+++ b/runtime/src/iree/builtins/ukernel/common.h
@@ -10,29 +10,81 @@
//===----------------------------------------------------------------------===//
// Generic microkernel library
//===----------------------------------------------------------------------===//
-// This library is focused on supporting usage of tiled microkernels from both
-// runtime libraries (VMVX via the IREE VM) and compiled libraries (LLVM CPU
-// codegen). It is designed to compile standalone as well as to bitcode that
-// can be linked into generated libraries and has support for specialization
-// in the compiler. In general treat the code as portable across architectures
-// but consistently built for bare-metal systems with no stdlib.
//
-// Code here must not use any system headers - as almost all pull in bits/ and
-// various other target-dependent definitions that make the resulting IR
-// non-portable. This means there is no size_t, etc. Any definitions that may
-// come from an std* file must be redefined here with care. Target-specific
-// files may include target-specific headers if carefully managed.
+// Rules summary:
+// 1. Microkernels are bare-metal, excluding even the standard C library.
+// a. Can't #include any system header.
+// b. Can't #include any standard library header.
+// c. Can't interface with the OS in any way.
+// 2. Microkernel code may be specialized for a target CPU architecture, but
+// not for a complete target platform/OS/triple. In particular:
+// a. It's OK to have a `#ifdef __aarch64__` but not a `#ifdef __ANDROID__`.
+// 3. Microkernels are pure/reentrant/stateless.
+// a. Pure: the only effect of calling a ukernel is to write to destination
+// buffers specified by pointers passed as ukernel arguments.
+// b. Reentrant: ukernels may be called concurrently with
+// themselves, other ukernels, or any other code, on any thread.
+// c. Stateless: ukernels can't mutate any global (or static local) variable.
//
-// Code must also not use any mutable global or thread-local state ala
-// errno/rounding modes/etc. Each of the functions in the library will be called
-// concurrently from multiple threads and from multiple source modules. There
-// must be no mutable static values anywhere.
+// Explanation:
+// 1. a. Microkernels will eventually be called from IREE LLVM-CPU codegen
+// modules. So we need to be able to build microkernels for all the target
+// architectures that iree-compile supports. If microkernels included
+// system headers, we would need to compile them not merely for each
+// target architecture but for each target triple, and we would need to
+// have the system headers for each of these.
+// 1. b. Follows from a. because many standard C library headers #include
+// system headers. We can't keep track of which do. Even plausibly "pure"
+// ones such as <stdint.h> have been known to drag in surprising amounts.
+// 1. c. Since we're only targeting a CPU architecture, not a complete target
+// platform/OS, we can't use any features that rely on the OS. For example
+// we can't use TLS (thread-local-storage) or Linux's auxiliary vector, or
+// syscalls.
+// * This means in particular that any CPU feature detection needs
+// to be made ahead of calling the ukernel, and the results passed as
+// ukernel args.
+// 2. We don't want code to depend on platform `#ifdefs` beyond just target CPU
+// architecture ifdefs, in any way --- even if the code paths are not
+// interfacing with the OS (see 1.c.), it's still forbidden to have separate
+// code paths. When we call microkernels from IREE LLVM-CPU codegen in the
+// future, this will make it legal for us to compile them only for
+// each target CPU architecture, which will be easier than having to compile
+// them separately for each supported target triple.
+// 3. Microkernels are typically called on tiles, after the workload has been
+// tiled and distributed to several threads. Keeping microkernels pure,
+// reentrant and stateless keeps them automatically compatible with any
+// tiling and distribution that we may use in the future.
//
-// Avoid #ifdef entirely where possible: they indicate a leakage of host build
-// configuration into what is supposed to be a portable module. Anything that
-// requires target-specific conditional logic must be implemented via an extern
-// that can be substituted by the IREE compiler when producing the final
-// target-specific module.
+// FAQ:
+// Q: Can a microkernel save, change, and restore the CPU float rounding mode?
+// A: Yes, as long as:
+// * It properly restores it in all its return paths.
+// * The CPU rounding mode is accessed in the microkernel's
+// own local code (as opposed to trying to use some standard library
+// header for that).
+// * The CPU architecture treats the rounding mode as a thread-local
+// setting (this tends to be the case on current CPU architectures).
+// Q: How can a microkernel depend on CPU identification information?
+// A: Microkernels that need to know CPU identification information, such as
+// bits indicating support for optional SIMD ISA features, should take
+// such information as arguments. This moves the problem of obtaining the
+// CPU identification information to the caller. This serves multiple
+// purposes:
+// * This allows writing tests that exercise all variants supported by the
+// test machine, not just whichever variant would be selected for that
+// machine.
+// * On CPU architectures where only the OS can directly access CPU
+// identification bits (that includes ARM architectures), this is
+// basically required by rule 1.c. (forbidding microkernels from
+// querying the OS directly).
+// - While other CPU architectures like x86 allow userspace processes to
+// directly query CPU identification, it's best to keep all kernels
+// on all architectures aligned on this.
+// - While some OSes may trap CPU identification instructions to make
+// them succeed in userspace programs
+// (https://www.kernel.org/doc/html/latest/arm64/cpu-feature-registers.html),
+// there are portability, reliability and performance concerns with
+// that.
// Include the build-system-generated configured header and use it as the only
// source of information about the target we're compiling against, as opposed to
@@ -44,13 +96,12 @@
// or stick to generic code.
#include "iree/builtins/ukernel/arch/config.h"
-// We require that this header compile on bare-metal targets with no stdlib.
-// These headers are clean:
-#include "iree/base/attributes.h"
-
// Include common flag values, shared with the compiler.
#include "iree/builtins/ukernel/exported_flag_bits.h"
+// Include IREE_UKERNEL_STATIC_ASSERT.
+#include "iree/builtins/ukernel/static_assert.h"
+
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
@@ -66,58 +117,91 @@
// documentation.
#define IREE_UKERNEL_EXPORT
+// Local fork of IREE_RESTRICT. We can't #include iree/base/attributes.h because
+// it drags in platform headers, via target_platform.h. TODO: consider sharing
+// this and other attributes that can be defined without any #include.
+#if defined(_MSC_VER) && _MSC_VER >= 1900
+#define IREE_UKERNEL_RESTRICT __restrict
+#elif defined(_MSC_VER)
+#define IREE_UKERNEL_RESTRICT
+#elif defined(__cplusplus)
+#define IREE_UKERNEL_RESTRICT __restrict__
+#else
+#define IREE_UKERNEL_RESTRICT restrict
+#endif // _MSC_VER
+
//===----------------------------------------------------------------------===//
-// stdint.h
+// Local replacements for stdint.h types and constants
+// Refer to the comment at the top of this file for why we can't include
+// stdint.h.
//===----------------------------------------------------------------------===//
-// https://pubs.opengroup.org/onlinepubs/009604599/basedefs/stdint.h.html
-// NOTE: prefer not using size_t/ptrdiff_t/etc (as they are target dependent).
-// We avoid including the toolchain file as it may not match with the target
-// information we have in the IREE compiler. They're also generally tire fires
-// that significantly bloat the bitcode we build and link into binaries: stdint
-// on Windows includes windows.h, for example.
-//
-// NOTE: callers must #include <stdint.h> before this header; unfortunately
-// there's not a great way to redefine these in the absence of stdint.h that
-// also operates with stdint.h.
-#if !defined(INT8_MIN)
+// These typedefs make assumptions about the widths of standard C types.
+// These assumptions are guarded by the IREE_UKERNEL_STATIC_ASSERTs below.
+// If someday these assumptions fail, then we can always add #ifs to control
+// these typedefs, perhaps similarly to what is done for iree_ukernel_ssize_t
+// below.
+typedef signed char iree_ukernel_int8_t;
+typedef short iree_ukernel_int16_t;
+typedef int iree_ukernel_int32_t;
+typedef long long iree_ukernel_int64_t;
+typedef unsigned char iree_ukernel_uint8_t;
+typedef unsigned short iree_ukernel_uint16_t;
+typedef unsigned int iree_ukernel_uint32_t;
+typedef unsigned long long iree_ukernel_uint64_t;
-typedef signed char int8_t;
-typedef short int16_t;
-typedef int int32_t;
-typedef long long int64_t;
-typedef unsigned char uint8_t;
-typedef unsigned short uint16_t;
-typedef unsigned int uint32_t;
-typedef unsigned long long uint64_t;
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_int8_t) == 1);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_int16_t) == 2);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_int32_t) == 4);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_int64_t) == 8);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_uint8_t) == 1);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_uint16_t) == 2);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_uint32_t) == 4);
+IREE_UKERNEL_STATIC_ASSERT(sizeof(iree_ukernel_uint64_t) == 8);
-#define INT8_MIN (-127i8 - 1)
-#define INT16_MIN (-32767i16 - 1)
-#define INT32_MIN (-2147483647i32 - 1)
-#define INT64_MIN (-9223372036854775807i64 - 1)
-#define INT8_MAX 127i8
-#define INT16_MAX 32767i16
-#define INT32_MAX 2147483647i32
-#define INT64_MAX 9223372036854775807i64
-#define UINT8_MAX 0xffui8
-#define UINT16_MAX 0xffffui16
-#define UINT32_MAX 0xffffffffui32
-#define UINT64_MAX 0xffffffffffffffffui64
+#define IREE_UKERNEL_INT8_MIN (-127 - 1)  // standard C literals; no MSVC-only iN/uiN suffixes
+#define IREE_UKERNEL_INT16_MIN (-32767 - 1)
+#define IREE_UKERNEL_INT32_MIN (-2147483647 - 1)  // MAX negated minus 1: the positive literal would not fit
+#define IREE_UKERNEL_INT64_MIN (-9223372036854775807LL - 1)
+#define IREE_UKERNEL_INT8_MAX 127
+#define IREE_UKERNEL_INT16_MAX 32767
+#define IREE_UKERNEL_INT32_MAX 2147483647
+#define IREE_UKERNEL_INT64_MAX 9223372036854775807LL
+#define IREE_UKERNEL_UINT8_MAX 0xff
+#define IREE_UKERNEL_UINT16_MAX 0xffff
+#define IREE_UKERNEL_UINT32_MAX 0xffffffffU
+#define IREE_UKERNEL_UINT64_MAX 0xffffffffffffffffULL
-#endif // !INT8_MIN
+//===----------------------------------------------------------------------===//
+// Local replacement for ssize_t
+//===----------------------------------------------------------------------===//
// Use iree_ukernel_ssize_t for all sizes that may need pointer width.
// For any argument that is known to fit in a specific size prefer that to
// ensure this code operates well on systems with small/weird widths (x32/ilp32,
// etc).
#if IREE_UKERNEL_POINTER_SIZE == 4
-typedef int32_t iree_ukernel_ssize_t;
+typedef iree_ukernel_int32_t iree_ukernel_ssize_t;
#elif IREE_UKERNEL_POINTER_SIZE == 8
-typedef int64_t iree_ukernel_ssize_t;
+typedef iree_ukernel_int64_t iree_ukernel_ssize_t;
#else
#error Unexpected pointer size
#endif
+//===----------------------------------------------------------------------===//
+// Local replacement for stdbool.h
+//===----------------------------------------------------------------------===//
+
+#ifndef __cplusplus
+// Exactly as in stdbool.h.
+// As stdbool.h is only macros, not typedefs, and it is standardized how these
+// macros expand, we can simply define them here. We still avoid #including it
+// in case some toolchain's version pulls in other, unexpected headers.
+#define bool _Bool
+#define true 1
+#define false 0
+#endif
+
// Status codes returned by a mmt4d operation.
enum iree_ukernel_status_t {
iree_ukernel_status_ok = 0,
diff --git a/runtime/src/iree/builtins/ukernel/elementwise.h b/runtime/src/iree/builtins/ukernel/elementwise.h
index f4818d3..b005b07 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise.h
+++ b/runtime/src/iree/builtins/ukernel/elementwise.h
@@ -21,11 +21,11 @@
// It takes lhs, rhs, out buffers and size, returning 0 on success and !0 on
// error.
typedef int (*iree_ukernel_x32b_2d_func_t)(
- const uint32_t* lhs, iree_ukernel_ssize_t lhs_offset,
+ const iree_ukernel_uint32_t* lhs, iree_ukernel_ssize_t lhs_offset,
iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1,
- const uint32_t* rhs, iree_ukernel_ssize_t rhs_offset,
+ const iree_ukernel_uint32_t* rhs, iree_ukernel_ssize_t rhs_offset,
iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1,
- uint32_t* out, iree_ukernel_ssize_t out_offset,
+ iree_ukernel_uint32_t* out, iree_ukernel_ssize_t out_offset,
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1,
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1);
@@ -38,25 +38,25 @@
iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1, \
const dtype* rhs, iree_ukernel_ssize_t rhs_offset, \
iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1, \
- dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \
+ dtype* IREE_UKERNEL_RESTRICT out, iree_ukernel_ssize_t out_offset, \
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1)
-DECLARE_UKERNEL_BINARY_2D(addf, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(addi, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(andi, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(divf, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(divsi, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(divui, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(mulf, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(muli, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(ori, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(shli, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(shrsi, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(shrui, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(subf, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(subi, uint32_t, x32b);
-DECLARE_UKERNEL_BINARY_2D(xori, uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(addf, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(addi, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(andi, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(divf, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(divsi, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(divui, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(mulf, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(muli, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(ori, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(shli, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(shrsi, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(shrui, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(subf, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(subi, iree_ukernel_uint32_t, x32b);
+DECLARE_UKERNEL_BINARY_2D(xori, iree_ukernel_uint32_t, x32b);
//===----------------------------------------------------------------------===//
// Public API - Unary kernels.
@@ -66,9 +66,9 @@
// It takes in, out buffers and size, returning 0 on success and !0 on
// error.
typedef int (*iree_ukernel_x32u_2d_func_t)(
- const uint32_t* in, iree_ukernel_ssize_t in_offset,
+ const iree_ukernel_uint32_t* in, iree_ukernel_ssize_t in_offset,
iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1,
- uint32_t* out, iree_ukernel_ssize_t out_offset,
+ iree_ukernel_uint32_t* out, iree_ukernel_ssize_t out_offset,
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1,
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1);
@@ -80,18 +80,18 @@
IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \
const dtype* in, iree_ukernel_ssize_t in_offset, \
iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1, \
- dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \
+ dtype* IREE_UKERNEL_RESTRICT out, iree_ukernel_ssize_t out_offset, \
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1)
-DECLARE_UKERNEL_UNARY_2D(absf, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(ceilf, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(ctlz, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(expf, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(floorf, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(logf, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(negf, uint32_t, x32u);
-DECLARE_UKERNEL_UNARY_2D(rsqrtf, uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(absf, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(ceilf, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(ctlz, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(expf, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(floorf, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(logf, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(negf, iree_ukernel_uint32_t, x32u);
+DECLARE_UKERNEL_UNARY_2D(rsqrtf, iree_ukernel_uint32_t, x32u);
#ifdef __cplusplus
} // extern "C"
diff --git a/runtime/src/iree/builtins/ukernel/elementwise_generic.c b/runtime/src/iree/builtins/ukernel/elementwise_generic.c
index 8969409..9abfbc8 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise_generic.c
+++ b/runtime/src/iree/builtins/ukernel/elementwise_generic.c
@@ -9,27 +9,50 @@
// Include the generic implementation helpers.
#include "elementwise_impl.c.inc"
-DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UKERNEL_X32B_ADDF, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(addi, IREE_UKERNEL_X32B_ADDI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(andi, IREE_UKERNEL_X32B_ANDI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(divf, IREE_UKERNEL_X32B_DIVF, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(divsi, IREE_UKERNEL_X32B_DIVSI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(divui, IREE_UKERNEL_X32B_DIVUI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(mulf, IREE_UKERNEL_X32B_MULF, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(muli, IREE_UKERNEL_X32B_MULI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(ori, IREE_UKERNEL_X32B_ORI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(shli, IREE_UKERNEL_X32B_SHLI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(shrsi, IREE_UKERNEL_X32B_SHRSI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(shrui, IREE_UKERNEL_X32B_SHRUI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(subf, IREE_UKERNEL_X32B_SUBF, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(subi, IREE_UKERNEL_X32B_SUBI, uint32_t, x32b);
-DISPATCH_UKERNEL_BINARY_2D(xori, IREE_UKENREL_X32B_XORI, uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(addf, IREE_UKERNEL_X32B_ADDF, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(addi, IREE_UKERNEL_X32B_ADDI, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(andi, IREE_UKERNEL_X32B_ANDI, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(divf, IREE_UKERNEL_X32B_DIVF, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(divsi, IREE_UKERNEL_X32B_DIVSI,
+ iree_ukernel_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(divui, IREE_UKERNEL_X32B_DIVUI,
+ iree_ukernel_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(mulf, IREE_UKERNEL_X32B_MULF, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(muli, IREE_UKERNEL_X32B_MULI, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(ori, IREE_UKERNEL_X32B_ORI, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(shli, IREE_UKERNEL_X32B_SHLI, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(shrsi, IREE_UKERNEL_X32B_SHRSI,
+ iree_ukernel_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(shrui, IREE_UKERNEL_X32B_SHRUI,
+ iree_ukernel_uint32_t, x32b);
+DISPATCH_UKERNEL_BINARY_2D(subf, IREE_UKERNEL_X32B_SUBF, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(subi, IREE_UKERNEL_X32B_SUBI, iree_ukernel_uint32_t,
+ x32b);
+DISPATCH_UKERNEL_BINARY_2D(xori, IREE_UKENREL_X32B_XORI, iree_ukernel_uint32_t,
+ x32b);
-DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UKERNEL_X32U_ABSF, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(ceilf, IREE_UKERNEL_X32U_CEILF, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(ctlz, IREE_UKERNEL_X32U_CTLZ, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(expf, IREE_UKERNEL_X32U_EXPF, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(floorf, IREE_UKERNEL_X32U_FLOORF, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(logf, IREE_UKERNEL_X32U_LOGF, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(negf, IREE_UKERNEL_X32U_NEGF, uint32_t, x32u);
-DISPATCH_UKERNEL_UNARY_2D(rsqrtf, IREE_UKERNEL_X32U_RSQRTF, uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(absf, IREE_UKERNEL_X32U_ABSF, iree_ukernel_uint32_t,
+ x32u);
+DISPATCH_UKERNEL_UNARY_2D(ceilf, IREE_UKERNEL_X32U_CEILF, iree_ukernel_uint32_t,
+ x32u);
+DISPATCH_UKERNEL_UNARY_2D(ctlz, IREE_UKERNEL_X32U_CTLZ, iree_ukernel_uint32_t,
+ x32u);
+DISPATCH_UKERNEL_UNARY_2D(expf, IREE_UKERNEL_X32U_EXPF, iree_ukernel_uint32_t,
+ x32u);
+DISPATCH_UKERNEL_UNARY_2D(floorf, IREE_UKERNEL_X32U_FLOORF,
+ iree_ukernel_uint32_t, x32u);
+DISPATCH_UKERNEL_UNARY_2D(logf, IREE_UKERNEL_X32U_LOGF, iree_ukernel_uint32_t,
+ x32u);
+DISPATCH_UKERNEL_UNARY_2D(negf, IREE_UKERNEL_X32U_NEGF, iree_ukernel_uint32_t,
+ x32u);
+DISPATCH_UKERNEL_UNARY_2D(rsqrtf, IREE_UKERNEL_X32U_RSQRTF,
+ iree_ukernel_uint32_t, x32u);
diff --git a/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc b/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc
index 3f4ee40..01e5680 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc
+++ b/runtime/src/iree/builtins/ukernel/elementwise_impl.c.inc
@@ -57,8 +57,8 @@
// Macros to access various typed, dereferenced pointers.
#define ASF32(ptr) *((float*)ptr)
-#define ASUI32(ptr) *((uint32_t*)ptr)
-#define ASSI32(ptr) *((int32_t*)ptr)
+#define ASUI32(ptr) *((iree_ukernel_uint32_t*)ptr)
+#define ASSI32(ptr) *((iree_ukernel_int32_t*)ptr)
//===----------------------------------------------------------------------===//
// Math helper functions (extracted from base/internal/math.h and adapted
@@ -73,7 +73,7 @@
#pragma intrinsic(_BitScanForward)
#endif // IREE_COMPILER_MSVC
-static inline int iree_ukernel_count_leading_zeros_u32(const uint32_t n) {
+static inline int iree_ukernel_count_leading_zeros_u32(const iree_ukernel_uint32_t n) {
#if defined(_MSC_VER)
unsigned long result = 0; // NOLINT(runtime/int)
if (_BitScanReverse(&result, n)) {
@@ -115,7 +115,7 @@
iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1, \
const dtype* rhs, iree_ukernel_ssize_t rhs_offset, \
iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1, \
- dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \
+ dtype* IREE_UKERNEL_RESTRICT out, iree_ukernel_ssize_t out_offset, \
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) { \
return iree_ukernel_generic_##category##_2d( \
@@ -131,7 +131,7 @@
IREE_UKERNEL_EXPORT int iree_ukernel_##category##_##opcode##_2d( \
const dtype* in, iree_ukernel_ssize_t in_offset, \
iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1, \
- dtype* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset, \
+ dtype* IREE_UKERNEL_RESTRICT out, iree_ukernel_ssize_t out_offset, \
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1, \
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) { \
return iree_ukernel_generic_##category##_2d( \
@@ -146,8 +146,8 @@
// Computes a single element of an x32b opcode. On error, should set
// |*result_code| to a non-zero value (but should not touch it otherwise).
static void iree_ukernel_generic_x32b_op(iree_ukernel_x32b_opcode_t opcode,
- int* result_code, const uint32_t* lhs,
- const uint32_t* rhs, uint32_t* out) {
+ int* result_code, const iree_ukernel_uint32_t* lhs,
+ const iree_ukernel_uint32_t* rhs, iree_ukernel_uint32_t* out) {
switch (opcode) {
case IREE_UKERNEL_X32B_ADDF:
ASF32(out) = ASF32(lhs) + ASF32(rhs);
@@ -202,8 +202,8 @@
// Computes a single element of an x32u opcode. On error, should set
// |*result_code| to a non-zero value (but should not touch it otherwise).
static void iree_ukernel_generic_x32u_op(iree_ukernel_x32u_opcode_t opcode,
- int* result_code, const uint32_t* in,
- uint32_t* out) {
+ int* result_code, const iree_ukernel_uint32_t* in,
+ iree_ukernel_uint32_t* out) {
switch (opcode) {
case IREE_UKERNEL_X32U_ABSF:
ASF32(out) = fabsf(ASF32(in));
@@ -242,13 +242,13 @@
static int iree_ukernel_generic_x32b_2d(
iree_ukernel_x32b_opcode_t opcode,
// LHS.
- const uint32_t* lhs, iree_ukernel_ssize_t lhs_offset,
+ const iree_ukernel_uint32_t* lhs, iree_ukernel_ssize_t lhs_offset,
iree_ukernel_ssize_t lhs_stride0, iree_ukernel_ssize_t lhs_stride1,
// RHS
- const uint32_t* rhs, iree_ukernel_ssize_t rhs_offset,
+ const iree_ukernel_uint32_t* rhs, iree_ukernel_ssize_t rhs_offset,
iree_ukernel_ssize_t rhs_stride0, iree_ukernel_ssize_t rhs_stride1,
// OUT.
- uint32_t* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset,
+ iree_ukernel_uint32_t* IREE_UKERNEL_RESTRICT out, iree_ukernel_ssize_t out_offset,
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1,
// Sizes.
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) {
@@ -269,10 +269,10 @@
static int iree_ukernel_generic_x32u_2d(
iree_ukernel_x32u_opcode_t opcode,
// IN.
- const uint32_t* in, iree_ukernel_ssize_t in_offset,
+ const iree_ukernel_uint32_t* in, iree_ukernel_ssize_t in_offset,
iree_ukernel_ssize_t in_stride0, iree_ukernel_ssize_t in_stride1,
// OUT.
- uint32_t* IREE_RESTRICT out, iree_ukernel_ssize_t out_offset,
+ iree_ukernel_uint32_t* IREE_UKERNEL_RESTRICT out, iree_ukernel_ssize_t out_offset,
iree_ukernel_ssize_t out_stride0, iree_ukernel_ssize_t out_stride1,
// Sizes.
iree_ukernel_ssize_t size0, iree_ukernel_ssize_t size1) {
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c
index 3db7cf5..17685e1 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.c
@@ -61,29 +61,29 @@
static void iree_ukernel_mmt4d_using_tile_func(
const iree_ukernel_mmt4d_params_t* params,
iree_ukernel_mmt4d_tile_func_t tile_func) {
- const int32_t M = params->M;
- const int32_t N = params->N;
- const int32_t K = params->K;
- const int16_t M0 = params->M0;
- const int16_t N0 = params->N0;
- const int16_t lhs_elem_size_log2 =
+ const iree_ukernel_int32_t M = params->M;
+ const iree_ukernel_int32_t N = params->N;
+ const iree_ukernel_int32_t K = params->K;
+ const iree_ukernel_int16_t M0 = params->M0;
+ const iree_ukernel_int16_t N0 = params->N0;
+ const iree_ukernel_int16_t lhs_elem_size_log2 =
iree_ukernel_mmt4d_lhs_elem_size_log2(params->type);
- const int16_t rhs_elem_size_log2 =
+ const iree_ukernel_int16_t rhs_elem_size_log2 =
iree_ukernel_mmt4d_rhs_elem_size_log2(params->type);
- const int16_t out_elem_size_log2 =
+ const iree_ukernel_int16_t out_elem_size_log2 =
iree_ukernel_mmt4d_out_elem_size_log2(params->type);
char* out_tile_row = params->out_buffer;
const char* lhs_panel = params->lhs_buffer;
- int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
+ iree_ukernel_int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
iree_ukernel_ssize_t lhs_panel_stride = params->lhs_stride
<< lhs_elem_size_log2;
iree_ukernel_ssize_t rhs_panel_stride = params->rhs_stride
<< rhs_elem_size_log2;
iree_ukernel_ssize_t out_stride = params->out_stride << out_elem_size_log2;
- for (int32_t i = 0; i < M; ++i) {
+ for (iree_ukernel_int32_t i = 0; i < M; ++i) {
char* out_tile = out_tile_row;
const char* rhs_panel = params->rhs_buffer;
- for (int32_t j = 0; j < N; ++j) {
+ for (iree_ukernel_int32_t j = 0; j < N; ++j) {
tile_func(out_tile, lhs_panel, rhs_panel, K, params->flags, params);
out_tile += out_tile_size;
rhs_panel += rhs_panel_stride;
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c b/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c
index 4ca51d0..722469b 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_select_tile_generic.c
@@ -27,16 +27,17 @@
// Generic implementation of matmul tile, i8*i8->i32 case.
static void iree_ukernel_mmt4d_tile_i8i8i32_generic(
void* out_tile_untyped, const void* lhs_panel_untyped,
- const void* rhs_panel_untyped, int32_t K, uint32_t flags,
- const iree_ukernel_mmt4d_params_t* params) {
- int32_t* out_tile = out_tile_untyped;
- const int8_t* lhs_panel = lhs_panel_untyped;
- const int8_t* rhs_panel = rhs_panel_untyped;
- int16_t M0 = params->M0;
- int16_t N0 = params->N0;
- int16_t K0 = params->K0;
+ const void* rhs_panel_untyped, iree_ukernel_int32_t K,
+ iree_ukernel_uint32_t flags, const iree_ukernel_mmt4d_params_t* params) {
+ iree_ukernel_int32_t* out_tile = out_tile_untyped;
+ const iree_ukernel_int8_t* lhs_panel = lhs_panel_untyped;
+ const iree_ukernel_int8_t* rhs_panel = rhs_panel_untyped;
+ iree_ukernel_int16_t M0 = params->M0;
+ iree_ukernel_int16_t N0 = params->N0;
+ iree_ukernel_int16_t K0 = params->K0;
// Initialize the local accumulator tile.
- int32_t acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
+ iree_ukernel_int32_t
+ acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
if (flags & IREE_UKERNEL_FLAG_ACCUMULATE) {
for (int i = 0; i < M0 * N0; ++i) acc[i] = out_tile[i];
} else {
@@ -47,8 +48,8 @@
for (iree_ukernel_ssize_t i0 = 0; i0 < M0; ++i0) {
for (iree_ukernel_ssize_t j0 = 0; j0 < N0; ++j0) {
for (iree_ukernel_ssize_t k0 = 0; k0 < K0; ++k0) {
- int32_t lhs_val_int32 = lhs_panel[i0 * K0 + k0];
- int32_t rhs_val_int32 = rhs_panel[j0 * K0 + k0];
+ iree_ukernel_int32_t lhs_val_int32 = lhs_panel[i0 * K0 + k0];
+ iree_ukernel_int32_t rhs_val_int32 = rhs_panel[j0 * K0 + k0];
acc[i0 * N0 + j0] += lhs_val_int32 * rhs_val_int32;
}
}
@@ -63,14 +64,14 @@
// Generic implementation of matmul tile, f32*f32->f32 case.
static void iree_ukernel_mmt4d_tile_f32f32f32_generic(
void* out_tile_untyped, const void* lhs_panel_untyped,
- const void* rhs_panel_untyped, int32_t K, uint32_t flags,
- const iree_ukernel_mmt4d_params_t* params) {
+ const void* rhs_panel_untyped, iree_ukernel_int32_t K,
+ iree_ukernel_uint32_t flags, const iree_ukernel_mmt4d_params_t* params) {
float* out_tile = out_tile_untyped;
const float* lhs_panel = lhs_panel_untyped;
const float* rhs_panel = rhs_panel_untyped;
- int16_t M0 = params->M0;
- int16_t N0 = params->N0;
- int16_t K0 = params->K0;
+ iree_ukernel_int16_t M0 = params->M0;
+ iree_ukernel_int16_t N0 = params->N0;
+ iree_ukernel_int16_t K0 = params->K0;
// Initialize the local accumulator tile.
float acc[iree_ukernel_mmt4d_tile_generic_max_bytes / sizeof(*out_tile)];
if (flags & IREE_UKERNEL_FLAG_ACCUMULATE) {
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_types.h b/runtime/src/iree/builtins/ukernel/mmt4d_types.h
index 6b083fd..de65c9d 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_types.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_types.h
@@ -21,7 +21,7 @@
// Parameters for a mmt4d operation.
struct iree_ukernel_mmt4d_params_t {
iree_ukernel_mmt4d_type_t type;
- uint32_t flags;
+ iree_ukernel_uint32_t flags;
const void* lhs_buffer;
const void* rhs_buffer;
void* out_buffer;
@@ -31,10 +31,10 @@
iree_ukernel_ssize_t M;
iree_ukernel_ssize_t N;
iree_ukernel_ssize_t K;
- int32_t M0;
- int32_t N0;
- int32_t K0;
- const uint64_t* cpu_data;
+ iree_ukernel_int32_t M0;
+ iree_ukernel_int32_t N0;
+ iree_ukernel_int32_t K0;
+ const iree_ukernel_uint64_t* cpu_data;
};
typedef struct iree_ukernel_mmt4d_params_t iree_ukernel_mmt4d_params_t;
@@ -56,13 +56,13 @@
// and keep that in sync with future struct changes.
typedef void (*iree_ukernel_mmt4d_tile_func_t)(
void* /*out_tile*/, const void* /*lhs_panel*/, const void* /*rhs_panel*/,
- int32_t /*K*/, uint32_t /*flags*/,
+ iree_ukernel_int32_t /*K*/, iree_ukernel_uint32_t /*flags*/,
const iree_ukernel_mmt4d_params_t* /*params*/);
// Tile kernel declarations. Prototype matches iree_ukernel_mmt4d_tile_func_t.
#define IREE_UKERNEL_MMT4D_TILE_FUNC_DECL(NAME) \
void NAME(void* out_tile, const void* lhs_panel, const void* rhs_panel, \
- int32_t K, uint32_t flags, \
+ iree_ukernel_int32_t K, iree_ukernel_uint32_t flags, \
const iree_ukernel_mmt4d_params_t* params);
// Log2 of size of LHS matrix element type, e.g. f32 --> size=4 --> log2=2
diff --git a/runtime/src/iree/builtins/ukernel/pack.c b/runtime/src/iree/builtins/ukernel/pack.c
index abb15f6..35ad506 100644
--- a/runtime/src/iree/builtins/ukernel/pack.c
+++ b/runtime/src/iree/builtins/ukernel/pack.c
@@ -8,8 +8,9 @@
static iree_ukernel_status_t iree_ukernel_pack_validate(
const iree_ukernel_pack_params_t* params) {
- const uint32_t allflags = IREE_UKERNEL_FLAG_PACK_TRANSPOSE_INNER |
- IREE_UKERNEL_FLAG_PACK_TRANSPOSE_OUTER;
+ const iree_ukernel_uint32_t allflags =
+ IREE_UKERNEL_FLAG_PACK_TRANSPOSE_INNER |
+ IREE_UKERNEL_FLAG_PACK_TRANSPOSE_OUTER;
if (params->flags & ~allflags) {
return iree_ukernel_status_bad_flags;
}
diff --git a/runtime/src/iree/builtins/ukernel/pack_types.h b/runtime/src/iree/builtins/ukernel/pack_types.h
index 5295b34..f25bf1d 100644
--- a/runtime/src/iree/builtins/ukernel/pack_types.h
+++ b/runtime/src/iree/builtins/ukernel/pack_types.h
@@ -34,7 +34,7 @@
iree_ukernel_ssize_t out_size2;
iree_ukernel_ssize_t out_size3;
const void* padding_value;
- uint32_t flags;
+ iree_ukernel_uint32_t flags;
};
typedef struct iree_ukernel_pack_params_t iree_ukernel_pack_params_t;
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 68b2991..dd15a88 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -38,7 +38,7 @@
int M0;
int N0;
int K0;
- const uint64_t* cpu_data;
+ const iree_ukernel_uint64_t* cpu_data;
};
typedef struct iree_mmt4d_benchmark_user_data_t
@@ -87,7 +87,7 @@
params.lhs_buffer = lhs_buffer;
params.rhs_buffer = rhs_buffer;
params.out_buffer = out_buffer;
- int64_t total_iterations = 0;
+ iree_ukernel_int64_t total_iterations = 0;
while (iree_benchmark_keep_running(benchmark_state,
/*batch_count=*/FLAG_batch_count)) {
for (int i = 0; i < FLAG_batch_count; ++i) {
@@ -136,8 +136,8 @@
#define MMT4D_BENCHMARK_REGISTER(_type, _m0, _n0, _k0, _cpu_data_field_0, \
_label) \
do { \
- static const uint64_t local_cpu_data[IREE_CPU_DATA_FIELD_COUNT] = { \
- _cpu_data_field_0}; \
+ static const iree_ukernel_uint64_t \
+ local_cpu_data[IREE_CPU_DATA_FIELD_COUNT] = {_cpu_data_field_0}; \
static const iree_mmt4d_benchmark_user_data_t user_data = { \
.type = iree_ukernel_mmt4d_type_##_type, \
.M0 = _m0, \
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
index 695fcd5..5b2fd76 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.cc
@@ -99,7 +99,8 @@
iree_mmt4d_reference<float, float, float>(params);
break;
case iree_ukernel_mmt4d_type_i8i8i32:
- iree_mmt4d_reference<int8_t, int8_t, int32_t>(params);
+ iree_mmt4d_reference<iree_ukernel_int8_t, iree_ukernel_int8_t,
+ iree_ukernel_int32_t>(params);
break;
default:
assert(false && "unknown type");
@@ -256,7 +257,7 @@
// and if the CPU supports the corresponding feature, the mmt4d tests are run a
// second time with that CPU feature enabled.
static void mmt4d_test(iree_ukernel_mmt4d_type_t type, int M0, int N0, int K0,
- uint64_t cpu_data_field_0_bit) {
+ iree_ukernel_uint64_t cpu_data_field_0_bit) {
// Letting each test create its own engine makes them independent: a testcase
// succeeds or fails the same way if we isolate it or reorder it. The
// potential downside of repeating the same pseudorandom sequence is OK
@@ -270,7 +271,8 @@
params.M0 = M0;
params.N0 = N0;
params.K0 = K0;
- const uint64_t local_cpu_data_default[IREE_CPU_DATA_FIELD_COUNT] = {0};
+ const iree_ukernel_uint64_t
+ local_cpu_data_default[IREE_CPU_DATA_FIELD_COUNT] = {0};
params.cpu_data = local_cpu_data_default;
// First try without any optional CPU feature. This matters even when the
// feature is supported by the CPU because we want to test the fallback to
@@ -278,8 +280,9 @@
test_matmuls_for_various_MNK_shapes_and_flags(params, engine);
// If this is nonzero, we are asked to test again with this CPU feature.
if (cpu_data_field_0_bit) {
- const uint64_t local_cpu_data_with_bit[IREE_CPU_DATA_FIELD_COUNT] = {
- cpu_data_field_0_bit};
+ const iree_ukernel_uint64_t
+ local_cpu_data_with_bit[IREE_CPU_DATA_FIELD_COUNT] = {
+ cpu_data_field_0_bit};
params.cpu_data = local_cpu_data_with_bit;
// Check if the CPU supports the feature (otherwise, we crash).
bool supported = iree_cpu_data_field(0) & params.cpu_data[0];
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
index 878b996..2982685 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test_utils.cc
@@ -75,7 +75,7 @@
static int iree_mmt4d_test_random_engine_get_in_uint16_range(
iree_mmt4d_test_random_engine_t* e) {
- uint32_t v = e->cpp_random_engine();
+ iree_ukernel_uint32_t v = e->cpp_random_engine();
// return the second-least-signicant out of the 4 bytes of state. It avoids
// some mild issues with the least-significant and most-significant bytes.
return (v >> 8) & 0xffff;
@@ -116,10 +116,12 @@
write_random_buffer(static_cast<float*>(buffer), size_in_bytes, engine);
return;
case iree_mmt4d_scalar_type_i32:
- write_random_buffer(static_cast<int32_t*>(buffer), size_in_bytes, engine);
+ write_random_buffer(static_cast<iree_ukernel_int32_t*>(buffer),
+ size_in_bytes, engine);
return;
case iree_mmt4d_scalar_type_i8:
- write_random_buffer(static_cast<int8_t*>(buffer), size_in_bytes, engine);
+ write_random_buffer(static_cast<iree_ukernel_int8_t*>(buffer),
+ size_in_bytes, engine);
return;
default:
assert(false && "unknown type");
diff --git a/runtime/src/iree/modules/vmvx/module.c b/runtime/src/iree/modules/vmvx/module.c
index 0576652..5fa7ba7 100644
--- a/runtime/src/iree/modules/vmvx/module.c
+++ b/runtime/src/iree/modules/vmvx/module.c
@@ -695,7 +695,7 @@
.M0 = M0,
.N0 = N0,
.K0 = K0,
- .cpu_data = iree_cpu_data_fields(),
+ .cpu_data = (const iree_ukernel_uint64_t*)iree_cpu_data_fields(),
};
iree_ukernel_status_t status = iree_ukernel_mmt4d(&ukernel_params);
IREE_TRACE_ZONE_END(z0);