Correct 32bit/64bit separation in ukernel code. (#13878)
* Rename `iree_uk_ssize_t` to `iree_uk_index_t` to make it clear that
the primary requirement on this type is to match the compiler's `index`
type.
* Build non-arch/ ukernel code twice, for 32bit and 64bit architectures.
As in libdevice, wasm_32 / wasm_64 is chosen as a sane 32bit/64bit
architecture to pick when the only thing that matters is the bitness
(which only matters for choosing the right definition of
`iree_uk_index_t`).
* Drop `strip_target_info.py`; instead, adopt `Device.cpp`'s approach of
stripping Clang's per-function target attributes (`target-cpu`, `tune-cpu`,
`target-features`) inside the compiler after loading the bitcode.
diff --git a/build_tools/bazel/iree_bitcode_library.bzl b/build_tools/bazel/iree_bitcode_library.bzl
index 09282f3..0a01299 100644
--- a/build_tools/bazel/iree_bitcode_library.bzl
+++ b/build_tools/bazel/iree_bitcode_library.bzl
@@ -41,20 +41,17 @@
def iree_bitcode_library(
name,
+ arch,
srcs,
internal_hdrs = [],
copts = [],
out = None,
- arch = None,
**kwargs):
"""Builds an LLVM bitcode library from an input file via clang.
Args:
name: Name of the target.
- arch: Target architecture to compile for, in IREE_ARCH format. If left
- empty, will produce architecture-independent bitcode by stripping
- target triple and target attributes; that only makes sense if the
- sources being compiled are truly architecture-independent.
+ arch: Target architecture to compile for, in IREE_ARCH format.
srcs: source files to pass to clang.
internal_hdrs: all headers transitively included by the source files.
Unlike typical Bazel `hdrs`, these are not exposed as
@@ -73,6 +70,10 @@
builtin_headers_path = "external/llvm-project/clang/staging/include/"
base_copts = [
+ # Target architecture
+ "-target",
+ iree_arch_to_llvm_arch(arch),
+
# C17 with no system deps.
"-std=c17",
"-nostdinc",
@@ -97,20 +98,10 @@
"-DIREE_DEVICE_STANDALONE=1",
]
- llvmir_processing_tool = None
- if arch:
- # Compile to the specified target architecture.
- base_copts.extend(["-target", iree_arch_to_llvm_arch(arch)])
- else:
- # Output text rather than binary serialization of LLVM IR for processing
- base_copts.append("-S")
-
- # Strip target information from generated LLVM IR.
- llvmir_processing_tool = "//build_tools/scripts:strip_target_info"
-
bitcode_files = []
for src in srcs:
bitcode_out = "%s_%s.bc" % (name, src)
+ bitcode_files.append(bitcode_out)
native.genrule(
name = "gen_%s" % (bitcode_out),
srcs = [src, builtin_headers_dep] + internal_hdrs,
@@ -134,28 +125,6 @@
**kwargs
)
- if llvmir_processing_tool:
- processed_bitcode_out = "%s_%s.processed.bc" % (name, src)
- native.genrule(
- name = "gen_%s" % (processed_bitcode_out),
- srcs = [bitcode_out],
- outs = [processed_bitcode_out],
- cmd = " ".join([
- "$(location %s)" % (llvmir_processing_tool),
- "< $(location %s)" % bitcode_out,
- "> $(location %s)" % processed_bitcode_out,
- ]),
- tools = [
- llvmir_processing_tool,
- ],
- message = "Processing %s into %s using %s..." % (bitcode_out, processed_bitcode_out, llvmir_processing_tool),
- output_to_bindir = 1,
- **kwargs
- )
- bitcode_files.append(processed_bitcode_out)
- else:
- bitcode_files.append(bitcode_out)
-
if not out:
out = "%s.bc" % (name)
native.genrule(
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index f6b479d..023dc4d 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -484,20 +484,20 @@
def iree_bitcode_library(self,
name,
+ arch,  # Required; emitted as the ARCH argument of the CMake rule.
srcs,
internal_hdrs=None,
- copts=None,
- arch=None):
+ copts=None):  # NOTE: internal_hdrs is accepted but never emitted into the CMake call.
name_block = self._convert_string_arg_block("NAME", name, quote=False)
+ arch_block = self._convert_string_arg_block("ARCH", arch, quote=False)
srcs_block = self._convert_srcs_block(srcs)
copts_block = self._convert_string_list_block("COPTS", copts, sort=False)
- arch_block = self._convert_string_arg_block("ARCH", arch, quote=False)
self._converter.body += (f"iree_bitcode_library(\n"
f"{name_block}"
+ f"{arch_block}"  # ARCH now follows NAME, ahead of SRCS and COPTS.
f"{srcs_block}"
f"{copts_block}"
- f"{arch_block}"
f")\n\n")
def iree_link_bitcode(self, name, bitcode_files):
diff --git a/build_tools/cmake/iree_bitcode_library.cmake b/build_tools/cmake/iree_bitcode_library.cmake
index 836c5ae..e6d2ba3 100644
--- a/build_tools/cmake/iree_bitcode_library.cmake
+++ b/build_tools/cmake/iree_bitcode_library.cmake
@@ -44,7 +44,12 @@
# override this but it should be harmless.
set(_BUILTIN_HEADERS_PATH "${IREE_BINARY_DIR}/llvm-project/lib/clang/${_CLANG_VERSION_MAJOR}/include/")
+ iree_arch_to_llvm_arch(_LLVM_ARCH "${_RULE_ARCH}")
+
set(_COPTS
+ # Target architecture.
+ "-target" "${_LLVM_ARCH}"
+
# C17 with no system deps.
"-std=c17"
"-nostdinc"
@@ -74,21 +79,11 @@
list(APPEND _COPTS "-I" "${IREE_BINARY_DIR}/runtime/src")
list(APPEND _COPTS "${_RULE_COPTS}")
- if(_RULE_ARCH)
- # Compile to the specified target architecture.
- iree_arch_to_llvm_arch(_LLVM_ARCH "${_RULE_ARCH}")
- list(APPEND _COPTS "-target" "${_LLVM_ARCH}")
- else()
- # Output text rather than binary serialization of LLVM IR for processing.
- list(APPEND _COPTS "-S")
- # Strip target information from generated LLVM IR.
- set(_LLVMIR_PROCESSING_TOOL "${IREE_SOURCE_DIR}/build_tools/scripts/strip_target_info.py")
- endif()
-
set(_BITCODE_FILES)
foreach(_SRC ${_RULE_SRCS})
get_filename_component(_BITCODE_SRC_PATH "${_SRC}" REALPATH)
set(_BITCODE_FILE "${_RULE_NAME}_${_SRC}.bc")
+ list(APPEND _BITCODE_FILES ${_BITCODE_FILE})
add_custom_command(
OUTPUT
"${_BITCODE_FILE}"
@@ -106,28 +101,6 @@
"Compiling ${_SRC} to ${_BITCODE_FILE}"
VERBATIM
)
-
- if(_LLVMIR_PROCESSING_TOOL)
- set(_PROCESSED_BITCODE_FILE "${_RULE_NAME}_${_SRC}.processed.bc")
- list(APPEND _BITCODE_FILES ${_PROCESSED_BITCODE_FILE})
- add_custom_command(
- OUTPUT
- "${_PROCESSED_BITCODE_FILE}"
- COMMAND
- "python3"
- "${_LLVMIR_PROCESSING_TOOL}"
- < "${_BITCODE_FILE}"
- > "${_PROCESSED_BITCODE_FILE}"
- DEPENDS
- "${_BITCODE_FILE}"
- "${_LLVMIR_PROCESSING_TOOL}"
- COMMENT
- "Processing ${_BITCODE_FILE} into ${_PROCESSED_BITCODE_FILE} using ${_LLVMIR_PROCESSING_TOOL}"
- VERBATIM
- )
- else() # _LLVMIR_PROCESSING_TOOL
- list(APPEND _BITCODE_FILES ${_BITCODE_FILE})
- endif() # _LLVMIR_PROCESSING_TOOL
endforeach()
add_custom_command(
diff --git a/build_tools/scripts/BUILD.bazel b/build_tools/scripts/BUILD.bazel
deleted file mode 100644
index 6c39341..0000000
--- a/build_tools/scripts/BUILD.bazel
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-package(
- default_visibility = ["//visibility:public"],
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-py_binary(
- name = "strip_target_info",
- srcs = ["strip_target_info.py"],
-)
diff --git a/build_tools/scripts/strip_target_info.py b/build_tools/scripts/strip_target_info.py
deleted file mode 100644
index 7c936d7..0000000
--- a/build_tools/scripts/strip_target_info.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""Strip LLVM IR of target triple and target-specific attributes
-"""
-
-import sys
-import re
-import os
-
-
-def main():
- sys.stdout.write(f";\n")
- sys.stdout.write(f"; Processed by {os.path.basename(__file__)}\n")
- sys.stdout.write(f";\n")
- target_triple_regex = re.compile(r'^\s*target triple\s*=\s*"[^"]*"')
- target_cpu_regex = re.compile(r'"target-cpu"="[^"]*"')
- target_features_regex = re.compile(r'"target-features"="[^"]*"')
- tune_cpu_regex = re.compile(r'"tune-cpu"="[^"]*"')
-
- for line in sys.stdin:
- if "target" in line:
- if re.match(target_triple_regex, line):
- continue
- line = re.sub(target_cpu_regex, '', line)
- line = re.sub(target_features_regex, '', line)
- line = re.sub(tune_cpu_regex, '', line)
-
- sys.stdout.write(line)
-
-
-if __name__ == "__main__":
- main()
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp
index 86df47c..e5f26bf 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp
@@ -34,10 +34,38 @@
}
std::unique_ptr<llvm::Module> loadUKernelBaseBitcode(
- llvm::LLVMContext& context) {
+ llvm::TargetMachine* targetMachine, llvm::LLVMContext& context) {
+ // Loads the non-arch ukernel bitcode that matches the bitness of the target
+ // triple (a 32bit and a 64bit variant are built). Returns nullptr when the
+ // triple is neither 32bit nor 64bit, or when loadUKernelBitcodeFile yields
+ // no module; callers must check the result before dereferencing it.
+ llvm::Triple triple = targetMachine->getTargetTriple();
+ StringRef filename;
+ if (triple.isArch64Bit()) {
+ filename = "ukernel_bitcode_64bit_base.bc";
+ } else if (triple.isArch32Bit()) {
+ filename = "ukernel_bitcode_32bit_base.bc";
+ } else {
+ return nullptr;
+ }
std::unique_ptr<llvm::Module> baseBitcode =
- loadUKernelBitcodeFile("ukernel_bitcode_base.bc", context);
+ loadUKernelBitcodeFile(filename, context);
assert(baseBitcode && "base ukernel bitcode file not found");
+ // The assert above compiles out under NDEBUG; bail out explicitly so the
+ // attribute-stripping loop below never dereferences a null module.
+ if (!baseBitcode) {
+ return nullptr;
+ }
+
+ // Copied from Device.cpp - TODO: move this to a shared utility.
+ // Clang adds its own per-function attributes that we need to strip so that
+ // our current executable variant target is used instead.
+ for (auto& func : baseBitcode->functions()) {
+ func.removeFnAttr("target-cpu");
+ func.removeFnAttr("tune-cpu");
+ func.removeFnAttr("target-features");
+ }
+
return baseBitcode;
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h
index ca9d8f4..8ef2606 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h
@@ -16,7 +16,7 @@
namespace HAL {
std::unique_ptr<llvm::Module> loadUKernelBaseBitcode(
- llvm::LLVMContext &context);
+ llvm::TargetMachine *targetMachine, llvm::LLVMContext &context);
std::unique_ptr<llvm::Module> loadUKernelArchBitcode(
llvm::TargetMachine *targetMachine, llvm::LLVMContext &context);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
index 23f6dc5..4ef8a8b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
@@ -476,7 +476,7 @@
// The baseBitcode module contains weak symbols for fallbacks.
// So we link it after the archBitcode and with LinkOnlyNeeded.
std::unique_ptr<llvm::Module> baseBitcode =
- loadUKernelBaseBitcode(context);
+ loadUKernelBaseBitcode(targetMachine.get(), context);
// Sequence that access before we std::move(baseBitcode)!
StringRef baseBitcodeName = baseBitcode->getName();
if (failed(linkBitcodeModule(variantOp.getLoc(), moduleLinker,
diff --git a/runtime/src/iree/builtins/device/CMakeLists.txt b/runtime/src/iree/builtins/device/CMakeLists.txt
index 427f29f..25b5421 100644
--- a/runtime/src/iree/builtins/device/CMakeLists.txt
+++ b/runtime/src/iree/builtins/device/CMakeLists.txt
@@ -29,19 +29,19 @@
iree_bitcode_library(
NAME
libdevice_wasm32_generic
- SRCS
- "device_generic.c"
ARCH
wasm_32
+ SRCS
+ "device_generic.c"
)
iree_bitcode_library(
NAME
libdevice_wasm64_generic
- SRCS
- "device_generic.c"
ARCH
wasm_64
+ SRCS
+ "device_generic.c"
)
iree_c_embed_data(
diff --git a/runtime/src/iree/builtins/ukernel/BUILD.bazel b/runtime/src/iree/builtins/ukernel/BUILD.bazel
index b97a6a5..2dedda2 100644
--- a/runtime/src/iree/builtins/ukernel/BUILD.bazel
+++ b/runtime/src/iree/builtins/ukernel/BUILD.bazel
@@ -119,19 +119,23 @@
],
)
-iree_bitcode_library(
- name = "ukernel_bitcode_base",
+[iree_bitcode_library(
+ name = "ukernel_bitcode_%sbit_base" % bitness,
srcs = UKERNEL_BASE_SRCS,
+ arch = "wasm_%s" % bitness,
internal_hdrs = [
":bitcode_internal_headers",
- "//runtime/src/iree/builtins/ukernel/arch/x86_64:bitcode_internal_headers",
],
-)
+) for bitness in [
+ "32",
+ "64",
+]]
c_embed_data(
name = "embed_ukernel_bitcode",
srcs = [
- ":ukernel_bitcode_base.bc",
+ ":ukernel_bitcode_32bit_base.bc",
+ ":ukernel_bitcode_64bit_base.bc",
"//runtime/src/iree/builtins/ukernel/arch/x86_64:ukernel_bitcode_x86_64.bc",
],
c_file_output = "ukernel_bitcode.c",
diff --git a/runtime/src/iree/builtins/ukernel/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
index 3292437..73cad50 100644
--- a/runtime/src/iree/builtins/ukernel/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
@@ -108,7 +108,24 @@
iree_bitcode_library(
NAME
- ukernel_bitcode_base
+ ukernel_bitcode_32bit_base
+ ARCH
+ wasm_32
+ SRCS
+ "mmt4d.c"
+ "mmt4d_tile.c"
+ "pack.c"
+ "pack_tile.c"
+ "query_tile_sizes.c"
+ "unpack_tile.c"
+ "weak.c"
+)
+
+iree_bitcode_library(
+ NAME
+ ukernel_bitcode_64bit_base
+ ARCH
+ wasm_64
SRCS
"mmt4d.c"
"mmt4d_tile.c"
@@ -124,7 +141,8 @@
embed_ukernel_bitcode
SRCS
"runtime/src/iree/builtins/ukernel/arch/x86_64/ukernel_bitcode_x86_64.bc"
- "ukernel_bitcode_base.bc"
+ "ukernel_bitcode_32bit_base.bc"
+ "ukernel_bitcode_64bit_base.bc"
DEPS
C_FILE_OUTPUT
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h
index c05307b..1bcb60a 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h
@@ -30,7 +30,7 @@
#endif
static inline int8x16x2_t iree_uk_neon_load_8x4xi8_strided(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+ const iree_uk_int8_t* src, iree_uk_index_t stride) {
int32x4_t v0_i32 = vdupq_n_s32(0);
int32x4_t v1_i32 = vdupq_n_s32(0);
v0_i32 =
@@ -56,7 +56,7 @@
}
static inline int8x16x4_t iree_uk_neon_load_8x8xi8_strided_permute(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride, int p0, int p1, int p2,
+ const iree_uk_int8_t* src, iree_uk_index_t stride, int p0, int p1, int p2,
int p3, int p4, int p5, int p6, int p7) {
int8x8_t row0 = vld1_s8(src + p0 * stride);
int8x8_t row1 = vld1_s8(src + p1 * stride);
@@ -75,7 +75,7 @@
}
static inline int8x16x4_t iree_uk_neon_load_8x8xi8_strided(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+ const iree_uk_int8_t* src, iree_uk_index_t stride) {
return iree_uk_neon_load_8x8xi8_strided_permute(src, stride, 0, 1, 2, 3, 4, 5,
6, 7);
}
@@ -109,7 +109,7 @@
static inline void iree_uk_neon_copy_8x1xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
int8x8_t v = vdup_n_s8(0);
v = vld1_lane_s8(in_ptr + 0 * in_stride, v, 0);
v = vld1_lane_s8(in_ptr + 1 * in_stride, v, 1);
@@ -124,7 +124,7 @@
static inline void iree_uk_neon_copy_8x4xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
int8x16x2_t in = iree_uk_neon_load_8x4xi8_strided(in_ptr, in_stride);
vst1q_s8(out_ptr + 0, in.val[0]);
vst1q_s8(out_ptr + 16, in.val[1]);
@@ -132,7 +132,7 @@
static inline void iree_uk_neon_copy_8x8xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
int8x16x4_t in = iree_uk_neon_load_8x8xi8_strided(in_ptr, in_stride);
vst1q_s8(out_ptr + 0, in.val[0]);
vst1q_s8(out_ptr + 16, in.val[1]);
@@ -143,8 +143,8 @@
static inline void
iree_uk_neon_copy_8x8xi8_tiled_1x4_transpose_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
int8x16x4_t in = iree_uk_neon_load_8x8xi8_strided_permute(
in_ptr, in_stride, 0, 2, 1, 3, 4, 6, 5, 7);
int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s8(in.val[0]),
@@ -159,8 +159,8 @@
static inline void iree_uk_neon_copy_8x32xi8_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
for (int i = 0; i < 8; ++i) {
iree_uk_memcpy(out_ptr + i * out_stride, in_ptr + i * in_stride, 32);
}
@@ -168,8 +168,8 @@
static inline void iree_uk_neon_copy_8x8xi8_transpose_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
int8x16x4_t in = iree_uk_neon_load_8x8xi8_strided_permute(
in_ptr, in_stride, 0, 4, 1, 5, 2, 6, 3, 7);
int16x8x2_t zip_i16_0 = iree_uk_neon_zip_16xi8_as_8xi16(in.val[0], in.val[1]);
@@ -199,7 +199,7 @@
static inline void iree_uk_neon_copy_8x8xi8_transpose_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
// Clang (Android NDK r25) actually produces worse code when this code is
// specialized for out_stride==8 using longer contiguous stores!
iree_uk_neon_copy_8x8xi8_transpose_strided_to_strided(out_ptr, in_ptr, 8,
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c
index b280217..a01c17e 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c
@@ -9,10 +9,10 @@
static void iree_uk_pack_tile_8x1_x8_arm_64_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 1);
@@ -36,10 +36,10 @@
static void iree_uk_pack_tile_8x4_x8_arm_64_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 4);
@@ -60,10 +60,10 @@
static void iree_uk_pack_tile_8x1_x32_arm_64_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 1);
@@ -74,10 +74,10 @@
static void iree_uk_pack_tile_8x8_x8_arm_64_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -92,10 +92,10 @@
static void iree_uk_pack_tile_8x1_x32_arm_64_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 1);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -110,10 +110,10 @@
static void iree_uk_pack_tile_8x1_x8_arm_64_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 1);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -136,10 +136,10 @@
static void iree_uk_pack_tile_8x4_x8_arm_64_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 4);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -163,10 +163,10 @@
static void iree_uk_pack_tile_8x8_x8_arm_64_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -183,10 +183,10 @@
static void iree_uk_pack_tile_8x8_x32_arm_64_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 8);
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c
index 912c4a6..70654a3 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c
@@ -9,10 +9,10 @@
static void iree_uk_unpack_tile_8x8_x32_arm_64_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 8);
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt
index c08fc99..3885e3d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt
@@ -15,18 +15,20 @@
iree_bitcode_library(
NAME
ukernel_bitcode_x86_64_base
+ ARCH
+ x86_64
SRCS
"mmt4d_x86_64.c"
"pack_x86_64.c"
"query_tile_sizes_x86_64.c"
"unpack_x86_64.c"
- ARCH
- x86_64
)
iree_bitcode_library(
NAME
ukernel_bitcode_x86_64_avx2_fma
+ ARCH
+ x86_64
SRCS
"mmt4d_x86_64_avx2_fma.c"
"pack_x86_64_avx2_fma.c"
@@ -35,13 +37,13 @@
"-mavx"
"-mavx2"
"-mfma"
- ARCH
- x86_64
)
iree_bitcode_library(
NAME
ukernel_bitcode_x86_64_avx512_base
+ ARCH
+ x86_64
SRCS
"mmt4d_x86_64_avx512_base.c"
"pack_x86_64_avx512_base.c"
@@ -55,13 +57,13 @@
"-mavx512cd"
"-mavx512bw"
"-mavx512dq"
- ARCH
- x86_64
)
iree_bitcode_library(
NAME
ukernel_bitcode_x86_64_avx512_vnni
+ ARCH
+ x86_64
SRCS
"mmt4d_x86_64_avx512_vnni.c"
COPTS
@@ -74,8 +76,6 @@
"-mavx512bw"
"-mavx512dq"
"-mavx512vnni"
- ARCH
- x86_64
)
iree_link_bitcode(
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
index 08414cc..cebd52b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
@@ -71,8 +71,8 @@
static inline void iree_uk_copy_8x32xi8_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
for (int i = 0; i < 8; ++i) {
iree_uk_memcpy(out_ptr + i * out_stride, in_ptr + i * in_stride, 32);
}
@@ -80,22 +80,22 @@
static inline void iree_uk_copy_16x64xi8_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
for (int i = 0; i < 16; ++i) {
iree_uk_memcpy(out_ptr + i * out_stride, in_ptr + i * in_stride, 64);
}
}
static inline __m256i iree_uk_avx2_load_8x4xi8_strided(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+ const iree_uk_int8_t* src, iree_uk_index_t stride) {
__m256i indices = _mm256_mullo_epi32(
_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(stride));
return _mm256_i32gather_epi32(src, indices, 1);
}
static inline __m128i iree_uk_avx2_load_8x2xi8_strided(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+ const iree_uk_int8_t* src, iree_uk_index_t stride) {
__m128i result = _mm_setzero_si128();
const iree_uk_int16_t* src_i16 = (const iree_uk_int16_t*)src;
result =
@@ -118,7 +118,7 @@
}
static inline __m256i iree_uk_avx2_load_16x2xi8_strided(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+ const iree_uk_int8_t* src, iree_uk_index_t stride) {
__m256i result = _mm256_setzero_si256();
const iree_uk_int16_t* src_i16 = (const iree_uk_int16_t*)src;
result = _mm256_insert_epi16(result,
@@ -158,21 +158,21 @@
static inline void iree_uk_avx2_copy_8x4xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
__m256i in = iree_uk_avx2_load_8x4xi8_strided(in_ptr, in_stride);
_mm256_storeu_si256((__m256i*)out_ptr, in);
}
static inline void iree_uk_avx2_copy_8x2xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
__m128i in = iree_uk_avx2_load_8x2xi8_strided(in_ptr, in_stride);
_mm_storeu_si128((__m128i*)out_ptr, in);
}
static inline void iree_uk_avx2_copy_16x2xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
__m256i in = iree_uk_avx2_load_16x2xi8_strided(in_ptr, in_stride);
_mm256_storeu_si256((__m256i*)out_ptr, in);
}
@@ -180,8 +180,8 @@
static inline void
iree_uk_avx2_copy_8x16xi8_tiled_1x4_transpose_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
__m256i r00004444 =
iree_uk_avx_loadu_2x128(in_ptr + 0 * in_stride, in_ptr + 4 * in_stride);
__m256i r11115555 =
@@ -211,8 +211,8 @@
static inline void
iree_uk_avx2_copy_8x16xi8_tiled_1x2_transpose_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
__m256i r0000000044444444 =
iree_uk_avx_loadu_2x128(in_ptr + 0 * in_stride, in_ptr + 4 * in_stride);
__m256i r1111111155555555 =
@@ -313,7 +313,7 @@
}
static inline __m512i iree_uk_avx512_load_16x4xi8_strided(
- const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+ const iree_uk_int8_t* src, iree_uk_index_t stride) {
__m512i indices = _mm512_mullo_epi32(
_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
_mm512_set1_epi32(stride));
@@ -322,7 +322,7 @@
static inline void iree_uk_avx512_copy_16x4xi8_strided_to_unstrided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
__m512i in = iree_uk_avx512_load_16x4xi8_strided(in_ptr, in_stride);
_mm512_storeu_si512((__m512i*)out_ptr, in);
}
@@ -330,8 +330,8 @@
static inline void
iree_uk_avx512_copy_16x16xi8_tiled_1x4_transpose_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
__m512i r000044448888CCCC = iree_uk_avx512_loadu_4x128(
in_ptr + 0 * in_stride, in_ptr + 4 * in_stride, in_ptr + 8 * in_stride,
in_ptr + 12 * in_stride);
@@ -381,8 +381,8 @@
static inline void
iree_uk_avx512_copy_16x16xi8_tiled_1x2_transpose_strided_to_strided(
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
- const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
- iree_uk_ssize_t in_stride) {
+ const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+ iree_uk_index_t in_stride) {
__m512i r000044448888CCCC = iree_uk_avx512_loadu_4x128(
in_ptr + 0 * in_stride, in_ptr + 4 * in_stride, in_ptr + 8 * in_stride,
in_ptr + 12 * in_stride);
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c
index 1c1dfe5..55e41cf 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c
@@ -11,10 +11,10 @@
void iree_uk_pack_tile_8x8_x32_x86_64_avx2_fma_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -30,10 +30,10 @@
static void iree_uk_pack_tile_8x4_x8_x86_64_avx2_fma_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 4);
@@ -54,10 +54,10 @@
void iree_uk_pack_tile_8x1_x32_x86_64_avx2_fma_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 1);
@@ -68,10 +68,10 @@
void iree_uk_pack_tile_8x1_x32_x86_64_avx2_fma_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 1);
IREE_UK_ASSERT(tile_size1 == 8);
@@ -86,10 +86,10 @@
void iree_uk_pack_tile_8x2_x8_x86_64_avx2_fma_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 2);
@@ -110,16 +110,16 @@
void iree_uk_pack_tile_8x2_x8_x86_64_avx2_fma_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 2);
IREE_UK_ASSERT(tile_size1 == 8);
const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr = in_tile_ptr;
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr = out_tile_ptr;
- iree_uk_ssize_t outer_i1 = 0;
+ iree_uk_index_t outer_i1 = 0;
for (; outer_i1 <= outer_size1 - 4; outer_i1 += 4) {
__m256i in0 = _mm256_loadu_si256((const __m256i*)in_ptr);
__m256i in1 = _mm256_loadu_si256((const __m256i*)(in_ptr + in_stride0));
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c
index 5021d6f..44a3dc3 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c
@@ -11,10 +11,10 @@
void iree_uk_pack_tile_16x16_x32_x86_64_avx512_base_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 16);
IREE_UK_ASSERT(tile_size1 == 16);
@@ -30,10 +30,10 @@
static void iree_uk_pack_tile_16x4_x8_x86_64_avx512_base_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 16);
IREE_UK_ASSERT(tile_size1 == 4);
@@ -55,10 +55,10 @@
void iree_uk_pack_tile_16x1_x32_x86_64_avx512_base_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 16);
IREE_UK_ASSERT(tile_size1 == 1);
@@ -69,10 +69,10 @@
void iree_uk_pack_tile_16x1_x32_x86_64_avx512_base_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 1);
IREE_UK_ASSERT(tile_size1 == 16);
@@ -87,10 +87,10 @@
void iree_uk_pack_tile_16x2_x8_x86_64_avx512_base_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 16);
IREE_UK_ASSERT(tile_size1 == 2);
@@ -111,16 +111,16 @@
void iree_uk_pack_tile_16x2_x8_x86_64_avx512_base_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 1);
IREE_UK_ASSERT(tile_size0 == 2);
IREE_UK_ASSERT(tile_size1 == 16);
const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr = in_tile_ptr;
iree_uk_int8_t* IREE_UK_RESTRICT out_ptr = out_tile_ptr;
- iree_uk_ssize_t outer_i1 = 0;
+ iree_uk_index_t outer_i1 = 0;
for (; outer_i1 <= outer_size1 - 4; outer_i1 += 4) {
__m512i in0 = _mm512_loadu_si512((const __m512i*)in_ptr);
__m512i in1 = _mm512_loadu_si512((const __m512i*)(in_ptr + in_stride0));
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c
index d7cda81..cc2600d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c
@@ -11,10 +11,10 @@
void iree_uk_unpack_tile_8x8_x32_x86_64_avx2_fma_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 8);
IREE_UK_ASSERT(tile_size1 == 8);
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c
index b99e16a..2c9229d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c
@@ -11,10 +11,10 @@
void iree_uk_unpack_tile_16x16_x32_x86_64_avx512_base_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
IREE_UK_ASSERT(elem_size == 4);
IREE_UK_ASSERT(tile_size0 == 16);
IREE_UK_ASSERT(tile_size1 == 16);
diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h
index 1a1db65..ad9630b 100644
--- a/runtime/src/iree/builtins/ukernel/common.h
+++ b/runtime/src/iree/builtins/ukernel/common.h
@@ -259,51 +259,94 @@
}
//===----------------------------------------------------------------------===//
-// iree_uk_ssize_t: signed size/index type.
+// Architecture detection (copied from target_platform.h)
//===----------------------------------------------------------------------===//
-// Don't think of this as "pointer-size" -- there's a reason why this is called
-// ssize_t, not ptrdiff_t.
-//
-// The typical compilation path for this ukernel code is to LLVM bitcode, i.e.
-// LLVM IR. There, there is an opaque pointer type, `ptr`, so pointers don't
-// even have a "size". Size/index types, on the other hand, are explicitly sized
-// integers such as i32 or i64. We have to pick one here.
-//
-// As outer parts of ukernel code (all what's outside of the arch/ subdir) is
-// compiled once to LLVM bitcode that's reused on all architectures, we have to
-// pick one type that will be used across all architectures. In the future, if
-// we care more about optimizing for 32-bit architectures, we could consider
-// revisiting that by having two separate builds of the non-arch/ ukernel code,
-// one for 64-bit and one for 32-bit architectures.
-//
-// For now, it seems reasonable to just use int64 everywhere. Our current focus
-// is on 64-bit architectures anyway, the expected overhead on 32-bit
-// architectures should be marginal, and this buys us simplicity and uniformity
-// (ukernels having the exact same semantics across architectures) that we would
-// lose the moment we allowed iree_uk_ssize_t to vary across architectures.
-typedef iree_uk_int64_t iree_uk_ssize_t;
+#if defined(__arm64) || defined(__aarch64__) || defined(_M_ARM64) || \
+ defined(_M_ARM64EC)
+#define IREE_UK_ARCH_ARM_64 1
+#elif defined(__arm__) || defined(__thumb__) || defined(__TARGET_ARCH_ARM) || \
+ defined(__TARGET_ARCH_THUMB) || defined(_M_ARM)
+#define IREE_UK_ARCH_ARM_32 1
+#endif // ARM
-static inline void iree_uk_ssize_swap(iree_uk_ssize_t* a, iree_uk_ssize_t* b) {
- iree_uk_ssize_t t = *a;
+#if defined(__riscv) && (__riscv_xlen == 32)
+#define IREE_UK_ARCH_RISCV_32 1
+#elif defined(__riscv) && (__riscv_xlen == 64)
+#define IREE_UK_ARCH_RISCV_64 1
+#endif // RISCV
+
+#if defined(__wasm32__)
+#define IREE_UK_ARCH_WASM_32 1
+#elif defined(__wasm64__)
+#define IREE_UK_ARCH_WASM_64 1
+#endif // WASM
+
+#if defined(__i386__) || defined(__i486__) || defined(__i586__) || \
+ defined(__i686__) || defined(__i386) || defined(_M_IX86) || defined(_X86_)
+#define IREE_UK_ARCH_X86_32 1
+#elif defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || \
+ defined(__amd64) || defined(_M_X64)
+#define IREE_UK_ARCH_X86_64 1
+#endif // X86
+
+//===----------------------------------------------------------------------===//
+// Architecture bitness
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_UK_ARCH_ARM_64) || defined(IREE_UK_ARCH_RISCV_64) || \
+ defined(IREE_UK_ARCH_WASM_64) || defined(IREE_UK_ARCH_X86_64)
+#define IREE_UK_ARCH_IS_64_BIT
+#elif defined(IREE_UK_ARCH_ARM_32) || defined(IREE_UK_ARCH_RISCV_32) || \
+ defined(IREE_UK_ARCH_WASM_32) || defined(IREE_UK_ARCH_X86_32)
+#define IREE_UK_ARCH_IS_32_BIT
+#else
+#error Unknown architecture
+#endif
+
+//===----------------------------------------------------------------------===//
+// iree_uk_index_t: signed integer type, same size as MLIR `index`.
+//===----------------------------------------------------------------------===//
+
+// When ukernels are built as bitcode to embed in the compiler, the requirement
+// here is that the size of iree_uk_index_t equals the size of the compiler's
+// `index` type.
+//
+// So when we define iree_uk_index_t here as a pointer-sized type, there is an
+// implicit assumption about how this ukernel code is built. When building as
+// bitcode to embed in the compiler, this code must be built twice, once for
+// some 32-bit-pointers architecture and once for some 64-bit-pointers
+// architecture. Then the compiler must select the bitcode module that matches
+// the size of the `index` type.
+
+#if defined(IREE_UK_ARCH_IS_64_BIT)
+typedef iree_uk_int64_t iree_uk_index_t;
+#elif defined(IREE_UK_ARCH_IS_32_BIT)
+typedef iree_uk_int32_t iree_uk_index_t;
+#else
+#error Unknown architecture
+#endif
+
+static inline void iree_uk_index_swap(iree_uk_index_t* a, iree_uk_index_t* b) {
+ iree_uk_index_t t = *a;
*a = *b;
*b = t;
}
-static inline iree_uk_ssize_t iree_uk_ssize_min(iree_uk_ssize_t a,
- iree_uk_ssize_t b) {
+static inline iree_uk_index_t iree_uk_index_min(iree_uk_index_t a,
+ iree_uk_index_t b) {
return a <= b ? a : b;
}
-static inline iree_uk_ssize_t iree_uk_ssize_max(iree_uk_ssize_t a,
- iree_uk_ssize_t b) {
+static inline iree_uk_index_t iree_uk_index_max(iree_uk_index_t a,
+ iree_uk_index_t b) {
return a >= b ? a : b;
}
-static inline iree_uk_ssize_t iree_uk_ssize_clamp(iree_uk_ssize_t val,
- iree_uk_ssize_t min,
- iree_uk_ssize_t max) {
- return iree_uk_ssize_min(max, iree_uk_ssize_max(min, val));
+static inline iree_uk_index_t iree_uk_index_clamp(iree_uk_index_t val,
+ iree_uk_index_t min,
+ iree_uk_index_t max) {
+ return iree_uk_index_min(max, iree_uk_index_max(min, val));
}
//===----------------------------------------------------------------------===//
@@ -502,14 +545,14 @@
// as a memcpy call.
static inline void iree_uk_memcpy(void* IREE_UK_RESTRICT dst,
const void* IREE_UK_RESTRICT src,
- iree_uk_ssize_t size) {
- for (iree_uk_ssize_t i = 0; i < size; ++i)
+ iree_uk_index_t size) {
+ for (iree_uk_index_t i = 0; i < size; ++i)
((char*)dst)[i] = ((const char*)src)[i];
}
-static inline void iree_uk_memset(void* buf, int val, iree_uk_ssize_t n) {
+static inline void iree_uk_memset(void* buf, int val, iree_uk_index_t n) {
// This naive loop is lifted to a memset by both clang and gcc on ARM64.
- for (iree_uk_ssize_t i = 0; i < n; ++i) ((char*)buf)[i] = val;
+ for (iree_uk_index_t i = 0; i < n; ++i) ((char*)buf)[i] = val;
}
//===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/builtins/ukernel/elementwise.c b/runtime/src/iree/builtins/ukernel/elementwise.c
index 11570c1..5240469 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise.c
+++ b/runtime/src/iree/builtins/ukernel/elementwise.c
@@ -69,13 +69,13 @@
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category) \
IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
- const dtype* lhs, iree_uk_ssize_t lhs_offset, \
- iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1, \
- const dtype* rhs, iree_uk_ssize_t rhs_offset, \
- iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1, \
- dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset, \
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1, \
- iree_uk_ssize_t size0, iree_uk_ssize_t size1) { \
+ const dtype* lhs, iree_uk_index_t lhs_offset, \
+ iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1, \
+ const dtype* rhs, iree_uk_index_t rhs_offset, \
+ iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1, \
+ dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset, \
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1, \
+ iree_uk_index_t size0, iree_uk_index_t size1) { \
return iree_uk_generic_##category##_2d( \
opcode_t, lhs, lhs_offset, lhs_stride0, lhs_stride1, rhs, rhs_offset, \
rhs_stride0, rhs_stride1, out, out_offset, out_stride0, out_stride1, \
@@ -87,11 +87,11 @@
// Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
#define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category) \
IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
- const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
- iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out, \
- iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0, \
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0, \
- iree_uk_ssize_t size1) { \
+ const dtype* in, iree_uk_index_t in_offset, iree_uk_index_t in_stride0, \
+ iree_uk_index_t in_stride1, dtype* IREE_UK_RESTRICT out, \
+ iree_uk_index_t out_offset, iree_uk_index_t out_stride0, \
+ iree_uk_index_t out_stride1, iree_uk_index_t size0, \
+ iree_uk_index_t size1) { \
return iree_uk_generic_##category##_2d( \
opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset, \
out_stride0, out_stride1, size0, size1); \
@@ -203,20 +203,20 @@
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32b_2d(
iree_uk_x32b_opcode_t opcode,
// LHS.
- const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
- iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
+ const iree_uk_uint32_t* lhs, iree_uk_index_t lhs_offset,
+ iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,
// RHS
- const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
- iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
+ const iree_uk_uint32_t* rhs, iree_uk_index_t rhs_offset,
+ iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,
// OUT.
- iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+ iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
// Sizes.
- iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+ iree_uk_index_t size0, iree_uk_index_t size1) {
int result_code = 0;
// TODO: Manually unroll to x4 to trigger vectorization.
- for (iree_uk_ssize_t i = 0; i < size0; ++i) {
- for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ for (iree_uk_index_t i = 0; i < size0; ++i) {
+ for (iree_uk_index_t j = 0; j < size1; ++j) {
iree_uk_generic_x32b_op(opcode, &result_code,
&lhs[i * lhs_stride0 + j * lhs_stride1],
&rhs[i * rhs_stride0 + j * rhs_stride1],
@@ -230,17 +230,17 @@
IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
iree_uk_x32u_opcode_t opcode,
// IN.
- const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
- iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
+ const iree_uk_uint32_t* in, iree_uk_index_t in_offset,
+ iree_uk_index_t in_stride0, iree_uk_index_t in_stride1,
// OUT.
- iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+ iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
// Sizes.
- iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+ iree_uk_index_t size0, iree_uk_index_t size1) {
int result_code = 0;
// TODO: Manually unroll to x4 to trigger vectorization.
- for (iree_uk_ssize_t i = 0; i < size0; ++i) {
- for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+ for (iree_uk_index_t i = 0; i < size0; ++i) {
+ for (iree_uk_index_t j = 0; j < size1; ++j) {
iree_uk_generic_x32u_op(opcode, &result_code,
&in[i * in_stride0 + j * in_stride1],
&out[i * out_stride0 + j * out_stride1]);
diff --git a/runtime/src/iree/builtins/ukernel/elementwise.h b/runtime/src/iree/builtins/ukernel/elementwise.h
index 34c3974..c266984 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise.h
+++ b/runtime/src/iree/builtins/ukernel/elementwise.h
@@ -21,26 +21,26 @@
// It takes lhs, rhs, out buffers and size, returning 0 on success and !0 on
// error.
typedef int (*iree_uk_x32b_2d_func_t)(
- const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
- iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
- const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
- iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
- iree_uk_uint32_t* out, iree_uk_ssize_t out_offset,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
- iree_uk_ssize_t size0, iree_uk_ssize_t size1);
+ const iree_uk_uint32_t* lhs, iree_uk_index_t lhs_offset,
+ iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,
+ const iree_uk_uint32_t* rhs, iree_uk_index_t rhs_offset,
+ iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,
+ iree_uk_uint32_t* out, iree_uk_index_t out_offset,
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
+ iree_uk_index_t size0, iree_uk_index_t size1);
// Declares a binary 2d microkernel with the following signature:
// int iree_uk_{category}_{opcode}_2d(...)
// of function type iree_uk_{category}_2d_func_t.
#define DECLARE_UKERNEL_BINARY_2D(opcode, dtype, category) \
IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
- const dtype* lhs, iree_uk_ssize_t lhs_offset, \
- iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1, \
- const dtype* rhs, iree_uk_ssize_t rhs_offset, \
- iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1, \
- dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset, \
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1, \
- iree_uk_ssize_t size0, iree_uk_ssize_t size1)
+ const dtype* lhs, iree_uk_index_t lhs_offset, \
+ iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1, \
+ const dtype* rhs, iree_uk_index_t rhs_offset, \
+ iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1, \
+ dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset, \
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1, \
+ iree_uk_index_t size0, iree_uk_index_t size1)
DECLARE_UKERNEL_BINARY_2D(addf, iree_uk_uint32_t, x32b);
DECLARE_UKERNEL_BINARY_2D(addi, iree_uk_uint32_t, x32b);
@@ -66,11 +66,11 @@
// It takes in, out buffers and size, returning 0 on success and !0 on
// error.
typedef int (*iree_uk_x32u_2d_func_t)(
- const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
- iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_uint32_t* out, iree_uk_ssize_t out_offset,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
- iree_uk_ssize_t size0, iree_uk_ssize_t size1);
+ const iree_uk_uint32_t* in, iree_uk_index_t in_offset,
+ iree_uk_index_t in_stride0, iree_uk_index_t in_stride1,
+ iree_uk_uint32_t* out, iree_uk_index_t out_offset,
+ iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
+ iree_uk_index_t size0, iree_uk_index_t size1);
// Declares a binary 2d microkernel with the following signature:
// int iree_uk_{category}_{opcode}_2d(...)
@@ -78,11 +78,11 @@
// error.
#define DECLARE_UKERNEL_UNARY_2D(opcode, dtype, category) \
IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d( \
- const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
- iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out, \
- iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0, \
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0, \
- iree_uk_ssize_t size1)
+ const dtype* in, iree_uk_index_t in_offset, iree_uk_index_t in_stride0, \
+ iree_uk_index_t in_stride1, dtype* IREE_UK_RESTRICT out, \
+ iree_uk_index_t out_offset, iree_uk_index_t out_stride0, \
+ iree_uk_index_t out_stride1, iree_uk_index_t size0, \
+ iree_uk_index_t size1)
DECLARE_UKERNEL_UNARY_2D(absf, iree_uk_uint32_t, x32u);
DECLARE_UKERNEL_UNARY_2D(ceilf, iree_uk_uint32_t, x32u);
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c
index d7ccc44..67f4856 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.c
@@ -63,9 +63,9 @@
const char* rhs_panel_start = (const char*)params->rhs_buffer +
(params->rhs_offset << rhs_elem_size_log2);
iree_uk_int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
- iree_uk_ssize_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2;
- iree_uk_ssize_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2;
- iree_uk_ssize_t out_stride = params->out_stride0 << out_elem_size_log2;
+ iree_uk_index_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2;
+ iree_uk_index_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2;
+ iree_uk_index_t out_stride = params->out_stride0 << out_elem_size_log2;
for (iree_uk_int32_t i = 0; i < M; ++i) {
char* out_tile = out_tile_row;
const char* rhs_panel = rhs_panel_start;
@@ -88,12 +88,12 @@
iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
int out_elem_size_log2 = iree_uk_type_size_log2(out_type);
- iree_uk_ssize_t contiguous_size = params->N * params->M0 * params->N0
+ iree_uk_index_t contiguous_size = params->N * params->M0 * params->N0
<< out_elem_size_log2;
- iree_uk_ssize_t stride = params->out_stride0 << out_elem_size_log2;
+ iree_uk_index_t stride = params->out_stride0 << out_elem_size_log2;
char* out_ptr =
(char*)params->out_buffer + (params->out_offset << out_elem_size_log2);
- for (iree_uk_ssize_t i = 0; i < params->M; ++i) {
+ for (iree_uk_index_t i = 0; i < params->M; ++i) {
iree_uk_memset(out_ptr, 0, contiguous_size);
out_ptr += stride;
}
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.h b/runtime/src/iree/builtins/ukernel/mmt4d.h
index bd956a6..a093e3e 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.h
@@ -15,17 +15,17 @@
typedef struct iree_uk_mmt4d_params_t {
const void* lhs_buffer;
- iree_uk_ssize_t lhs_offset;
- iree_uk_ssize_t lhs_stride0;
+ iree_uk_index_t lhs_offset;
+ iree_uk_index_t lhs_stride0;
const void* rhs_buffer;
- iree_uk_ssize_t rhs_offset;
- iree_uk_ssize_t rhs_stride0;
+ iree_uk_index_t rhs_offset;
+ iree_uk_index_t rhs_stride0;
void* out_buffer;
- iree_uk_ssize_t out_offset;
- iree_uk_ssize_t out_stride0;
- iree_uk_ssize_t M;
- iree_uk_ssize_t N;
- iree_uk_ssize_t K;
+ iree_uk_index_t out_offset;
+ iree_uk_index_t out_stride0;
+ iree_uk_index_t M;
+ iree_uk_index_t N;
+ iree_uk_index_t K;
iree_uk_int32_t M0;
iree_uk_int32_t N0;
iree_uk_int32_t K0;
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
index 36a8c25..440001c 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
@@ -25,10 +25,10 @@
for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
}
// Accumulation loop.
- for (iree_uk_ssize_t k = 0; k < K; ++k) {
- for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
- for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
- for (iree_uk_ssize_t k0 = 0; k0 < K0; ++k0) {
+ for (iree_uk_index_t k = 0; k < K; ++k) {
+ for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
iree_uk_int32_t lhs_val_int32 = lhs_panel[i0 * K0 + k0];
iree_uk_int32_t rhs_val_int32 = rhs_panel[j0 * K0 + k0];
acc[i0 * N0 + j0] += lhs_val_int32 * rhs_val_int32;
@@ -61,10 +61,10 @@
for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
}
// Accumulation loop.
- for (iree_uk_ssize_t k = 0; k < K; ++k) {
- for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
- for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
- for (iree_uk_ssize_t k0 = 0; k0 < K0; ++k0) {
+ for (iree_uk_index_t k = 0; k < K; ++k) {
+ for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+ for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+ for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
float lhs_val = lhs_panel[i0 * K0 + k0];
float rhs_val = rhs_panel[j0 * K0 + k0];
acc[i0 * N0 + j0] += lhs_val * rhs_val;
diff --git a/runtime/src/iree/builtins/ukernel/pack.c b/runtime/src/iree/builtins/ukernel/pack.c
index e7e7a25..7895ef9 100644
--- a/runtime/src/iree/builtins/ukernel/pack.c
+++ b/runtime/src/iree/builtins/ukernel/pack.c
@@ -20,8 +20,8 @@
} iree_uk_pack_tmpbuf_helper_t;
// Return x/y for x>=0 and y>0, with a fast path for when y is a power of two.
-static iree_uk_ssize_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
- iree_uk_ssize_t x, iree_uk_int32_t y) {
+static iree_uk_index_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
+ iree_uk_index_t x, iree_uk_int32_t y) {
IREE_UK_ASSERT(x >= 0);
IREE_UK_ASSERT(y > 0);
return IREE_UK_LIKELY(iree_uk_is_po2_u32(y)) ? (x >> iree_uk_po2_log2_u32(y))
@@ -46,8 +46,8 @@
// Initializes a `iree_uk_pack_tmpbuf_helper_t`. Asserts if the temporary buffer
// is smaller than one tile.
static void iree_uk_pack_tmpbuf_helper_t_init(
- iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
- iree_uk_ssize_t elem_size, iree_uk_uint64_t padding_value,
+ iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+ iree_uk_index_t elem_size, iree_uk_uint64_t padding_value,
iree_uk_pack_tmpbuf_helper_t* helper) {
helper->max_tiles_in_tmp_buf = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
iree_uk_pack_tmp_buf_size, tile_size0 * tile_size1 * elem_size);
@@ -76,15 +76,15 @@
IREE_UK_ASSERT(params->out_size3 >= 0);
// Check that the input and output shapes match, give or take padding that
// must not exceed the inner tile size.
- iree_uk_ssize_t outer_size0 = params->out_size0;
- iree_uk_ssize_t outer_size1 = params->out_size1;
- iree_uk_ssize_t tile_size0 = params->out_size2;
- iree_uk_ssize_t tile_size1 = params->out_size3;
+ iree_uk_index_t outer_size0 = params->out_size0;
+ iree_uk_index_t outer_size1 = params->out_size1;
+ iree_uk_index_t tile_size0 = params->out_size2;
+ iree_uk_index_t tile_size1 = params->out_size3;
if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&outer_size0, &outer_size1);
+ iree_uk_index_swap(&outer_size0, &outer_size1);
}
if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
}
IREE_UK_ASSERT(outer_size0 * tile_size0 >= params->in_size0);
IREE_UK_ASSERT(outer_size1 * tile_size1 >= params->in_size1);
@@ -100,7 +100,7 @@
iree_uk_pack_tmpbuf_helper_t padding_helper;
iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
- iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
+ iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
iree_uk_pack_tmpbuf_helper_t_init(tile_size0, tile_size1, elem_size,
params->padding_value, &padding_helper);
#endif // IREE_UK_ENABLE_ASSERTS
@@ -115,8 +115,8 @@
// Fills `buf` with `num_elems` copies of `padding_value`, each `elem_size`
// bytes wide. If the `elem_size` bytes of that value are all equal, then it is
// legal to pass `is_padding_single_byte=true`, which allows the impl to use
// memset.
-static void iree_uk_fill(char* IREE_UK_RESTRICT buf, iree_uk_ssize_t num_elems,
- iree_uk_ssize_t elem_size,
+static void iree_uk_fill(char* IREE_UK_RESTRICT buf, iree_uk_index_t num_elems,
+ iree_uk_index_t elem_size,
iree_uk_uint64_t padding_value,
bool is_padding_single_byte) {
if (is_padding_single_byte) {
@@ -124,20 +124,20 @@
} else if (elem_size == 2) {
iree_uk_uint16_t padding_value_uint16 = padding_value;
iree_uk_uint16_t* IREE_UK_RESTRICT buf_uint16 = (iree_uk_uint16_t*)buf;
- for (iree_uk_ssize_t i = 0; i < num_elems; ++i) {
+ for (iree_uk_index_t i = 0; i < num_elems; ++i) {
buf_uint16[i] = padding_value_uint16;
}
} else if (elem_size == 4) {
iree_uk_uint32_t padding_value_uint32 = padding_value;
iree_uk_uint32_t* IREE_UK_RESTRICT buf_uint32 = (iree_uk_uint32_t*)buf;
- for (iree_uk_ssize_t i = 0; i < num_elems; ++i) {
+ for (iree_uk_index_t i = 0; i < num_elems; ++i) {
buf_uint32[i] = padding_value_uint32;
}
} else { // elem_size >= 8
// While arbitrary large elem_size is allowed, padding_value remains a
// uint64, so elem_size >= 16 only supports a repeating 8-byte pattern.
iree_uk_uint64_t* IREE_UK_RESTRICT buf_uint64 = (iree_uk_uint64_t*)buf;
- for (iree_uk_ssize_t i = 0; i < num_elems * elem_size / 8; ++i) {
+ for (iree_uk_index_t i = 0; i < num_elems * elem_size / 8; ++i) {
buf_uint64[i] = padding_value;
}
}
@@ -146,14 +146,14 @@
// Copy from a source 2D buffer to a destination 2D buffer, padding to the
// destination size.
static void iree_uk_copy_and_pad(
- iree_uk_ssize_t src_size0, iree_uk_ssize_t src_size1,
- iree_uk_ssize_t src_stride0, const char* src_buf, iree_uk_ssize_t dst_size0,
- iree_uk_ssize_t dst_size1, iree_uk_ssize_t dst_stride0, char* dst_buf,
- iree_uk_ssize_t elem_size, iree_uk_uint64_t padding_value,
+ iree_uk_index_t src_size0, iree_uk_index_t src_size1,
+ iree_uk_index_t src_stride0, const char* src_buf, iree_uk_index_t dst_size0,
+ iree_uk_index_t dst_size1, iree_uk_index_t dst_stride0, char* dst_buf,
+ iree_uk_index_t elem_size, iree_uk_uint64_t padding_value,
bool is_padding_single_byte) {
iree_uk_fill(dst_buf, dst_size1 + (dst_size0 - 1) * dst_stride0, elem_size,
padding_value, is_padding_single_byte);
- for (iree_uk_ssize_t in_i0 = 0; in_i0 < src_size0; in_i0++) {
+ for (iree_uk_index_t in_i0 = 0; in_i0 < src_size0; in_i0++) {
iree_uk_memcpy(dst_buf, src_buf, src_size1 * elem_size);
dst_buf += dst_stride0 * elem_size;
src_buf += src_stride0 * elem_size;
@@ -163,20 +163,20 @@
// Pads and packs an entire row. In cases that are known not to require padding,
// it is more efficient to call tile_func directly.
static void iree_uk_pad_and_pack_row_using_tile_func(
- iree_uk_pack_tile_func_t tile_func, iree_uk_ssize_t dim1_tile_start,
- iree_uk_ssize_t dim1_tile_end, iree_uk_ssize_t dim0_src_read_size,
- iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t in_size1,
- iree_uk_ssize_t in_stride0, iree_uk_ssize_t out_stride1,
+ iree_uk_pack_tile_func_t tile_func, iree_uk_index_t dim1_tile_start,
+ iree_uk_index_t dim1_tile_end, iree_uk_index_t dim0_src_read_size,
+ iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+ iree_uk_index_t elem_size, iree_uk_index_t in_size1,
+ iree_uk_index_t in_stride0, iree_uk_index_t out_stride1,
iree_uk_uint64_t padding_value, iree_uk_pack_tmpbuf_helper_t* helper,
const char* in_buf, char* out_buf) {
- iree_uk_ssize_t dim1_tile = dim1_tile_start;
+ iree_uk_index_t dim1_tile = dim1_tile_start;
while (dim1_tile < dim1_tile_end) {
- iree_uk_ssize_t dim1_chunk_tiles = iree_uk_ssize_clamp(
+ iree_uk_index_t dim1_chunk_tiles = iree_uk_index_clamp(
dim1_tile_end - dim1_tile, 0, helper->max_tiles_in_tmp_buf);
- iree_uk_ssize_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
- iree_uk_ssize_t dim1_chunk_src_pos = dim1_tile * tile_size1;
- iree_uk_ssize_t i1_read_size = iree_uk_ssize_clamp(
+ iree_uk_index_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
+ iree_uk_index_t dim1_chunk_src_pos = dim1_tile * tile_size1;
+ iree_uk_index_t i1_read_size = iree_uk_index_clamp(
in_size1 - dim1_chunk_src_pos, 0, dim1_chunk_src_width);
iree_uk_copy_and_pad(dim0_src_read_size, i1_read_size, in_stride0,
in_buf + dim1_chunk_src_pos * elem_size, tile_size0,
@@ -195,19 +195,19 @@
// For now, the input and output element types are always the same.
iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
- iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
- iree_uk_ssize_t outer_size0 = params->out_size0;
- iree_uk_ssize_t outer_size1 = params->out_size1;
- iree_uk_ssize_t tile_size0 = params->out_size2;
- iree_uk_ssize_t tile_size1 = params->out_size3;
- iree_uk_ssize_t out_stride_l0 = params->out_stride0;
- iree_uk_ssize_t out_stride1 = params->out_size3 * params->out_size2;
+ iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+ iree_uk_index_t outer_size0 = params->out_size0;
+ iree_uk_index_t outer_size1 = params->out_size1;
+ iree_uk_index_t tile_size0 = params->out_size2;
+ iree_uk_index_t tile_size1 = params->out_size3;
+ iree_uk_index_t out_stride_l0 = params->out_stride0;
+ iree_uk_index_t out_stride1 = params->out_size3 * params->out_size2;
if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&outer_size0, &outer_size1);
- iree_uk_ssize_swap(&out_stride_l0, &out_stride1);
+ iree_uk_index_swap(&outer_size0, &outer_size1);
+ iree_uk_index_swap(&out_stride_l0, &out_stride1);
}
if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
}
const char* in_buf =
(const char*)params->in_buffer + (params->in_offset * elem_size);
@@ -223,7 +223,7 @@
// source buffer's boundaries.
int dim1_full_tiles = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
params->in_size1, tile_size1);
- iree_uk_ssize_t i0 = 0;
+ iree_uk_index_t i0 = 0;
for (; i0 <= params->in_size0 - tile_size0; i0 += tile_size0) {
// Pack whole tiles that do not require padding (entirely within the source
// buffer's boundaries).
@@ -239,8 +239,8 @@
}
// Bottom-padding.
for (; i0 < outer_size0 * tile_size0; i0 += tile_size0) {
- iree_uk_ssize_t dim0_src_read_size =
- iree_uk_ssize_clamp(params->in_size0 - i0, 0, tile_size0);
+ iree_uk_index_t dim0_src_read_size =
+ iree_uk_index_clamp(params->in_size0 - i0, 0, tile_size0);
iree_uk_pad_and_pack_row_using_tile_func(
tile_func, 0, outer_size1, dim0_src_read_size, tile_size0, tile_size1,
elem_size, params->in_size1, params->in_stride0, out_stride1,
diff --git a/runtime/src/iree/builtins/ukernel/pack.h b/runtime/src/iree/builtins/ukernel/pack.h
index 8feb64b..c4d0233 100644
--- a/runtime/src/iree/builtins/ukernel/pack.h
+++ b/runtime/src/iree/builtins/ukernel/pack.h
@@ -15,17 +15,17 @@
typedef struct iree_uk_pack_params_t {
const void* in_buffer;
- iree_uk_ssize_t in_offset;
- iree_uk_ssize_t in_stride0;
+ iree_uk_index_t in_offset;
+ iree_uk_index_t in_stride0;
void* out_buffer;
- iree_uk_ssize_t out_offset;
- iree_uk_ssize_t out_stride0;
- iree_uk_ssize_t in_size0;
- iree_uk_ssize_t in_size1;
- iree_uk_ssize_t out_size0;
- iree_uk_ssize_t out_size1;
- iree_uk_ssize_t out_size2;
- iree_uk_ssize_t out_size3;
+ iree_uk_index_t out_offset;
+ iree_uk_index_t out_stride0;
+ iree_uk_index_t in_size0;
+ iree_uk_index_t in_size1;
+ iree_uk_index_t out_size0;
+ iree_uk_index_t out_size1;
+ iree_uk_index_t out_size2;
+ iree_uk_index_t out_size3;
// The least significant bits of `padding_value`, up to element size, are used
// for padding. As this is based solely on bit-significance and not on byte
// addresses, this is independent of endianness.
diff --git a/runtime/src/iree/builtins/ukernel/pack_internal.h b/runtime/src/iree/builtins/ukernel/pack_internal.h
index ce92874..cc381a0 100644
--- a/runtime/src/iree/builtins/ukernel/pack_internal.h
+++ b/runtime/src/iree/builtins/ukernel/pack_internal.h
@@ -38,18 +38,18 @@
typedef void (*iree_uk_pack_tile_func_t)(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1);
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1);
// Tile kernel declarations. Prototype matches iree_uk_pack_tile_func_t.
#define IREE_UK_PACK_TILE_FUNC_DECL(NAME) \
void NAME(void* IREE_UK_RESTRICT out_tile_ptr, \
const void* IREE_UK_RESTRICT in_tile_ptr, \
- iree_uk_ssize_t outer_size1, iree_uk_ssize_t out_stride1, \
- iree_uk_ssize_t in_stride0, iree_uk_ssize_t elem_size, \
- iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1);
+ iree_uk_index_t outer_size1, iree_uk_index_t out_stride1, \
+ iree_uk_index_t in_stride0, iree_uk_index_t elem_size, \
+ iree_uk_index_t tile_size0, iree_uk_index_t tile_size1);
// Returns the tile function to use for the pack op with the given params.
iree_uk_pack_tile_func_t iree_uk_pack_select_tile_func(
diff --git a/runtime/src/iree/builtins/ukernel/pack_tile.c b/runtime/src/iree/builtins/ukernel/pack_tile.c
index c0c9afe..fe87939 100644
--- a/runtime/src/iree/builtins/ukernel/pack_tile.c
+++ b/runtime/src/iree/builtins/ukernel/pack_tile.c
@@ -8,16 +8,16 @@
static void iree_uk_pack_tile_generic_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
- for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+ for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
const char* IREE_UK_RESTRICT in_ptr = in_ptr_l1;
char* IREE_UK_RESTRICT out_ptr = out_ptr_l1;
- for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+ for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
iree_uk_memcpy(out_ptr, in_ptr, tile_size1 * elem_size);
out_ptr += tile_size1 * elem_size;
in_ptr += in_stride0 * elem_size;
@@ -29,19 +29,19 @@
static void iree_uk_pack_tile_generic_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
- for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+ for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
const char* IREE_UK_RESTRICT in_ptr_l2 = in_ptr_l1;
char* IREE_UK_RESTRICT out_ptr_l2 = out_ptr_l1;
- for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+ for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
const char* IREE_UK_RESTRICT in_ptr = in_ptr_l2;
char* IREE_UK_RESTRICT out_ptr = out_ptr_l2;
- for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+ for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
iree_uk_memcpy(out_ptr, in_ptr, elem_size);
out_ptr += tile_size0 * elem_size;
in_ptr += elem_size;
diff --git a/runtime/src/iree/builtins/ukernel/query_tile_sizes.h b/runtime/src/iree/builtins/ukernel/query_tile_sizes.h
index 3f819f8..a7578c7 100644
--- a/runtime/src/iree/builtins/ukernel/query_tile_sizes.h
+++ b/runtime/src/iree/builtins/ukernel/query_tile_sizes.h
@@ -16,14 +16,14 @@
// Parameters for a query_tile_sizes operation.
typedef struct iree_uk_query_tile_sizes_2d_params_t {
iree_uk_uint32_t flags;
- iree_uk_ssize_t size0;
- iree_uk_ssize_t size1;
+ iree_uk_index_t size0;
+ iree_uk_index_t size1;
const iree_uk_uint64_t* cpu_data;
} iree_uk_query_tile_sizes_2d_params_t;
typedef struct iree_uk_query_tile_sizes_2d_out_params_t {
- iree_uk_ssize_t tile_size0;
- iree_uk_ssize_t tile_size1;
+ iree_uk_index_t tile_size0;
+ iree_uk_index_t tile_size1;
} iree_uk_query_tile_sizes_2d_out_params_t;
IREE_UK_EXPORT int iree_uk_query_tile_sizes_2d(
diff --git a/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c
index 32dc799..b572d78 100644
--- a/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c
@@ -274,17 +274,17 @@
.in_stride0 = mmt4d_params.out_stride0,
};
- iree_uk_ssize_t rowmajor_lhs_buffer_size =
+ iree_uk_index_t rowmajor_lhs_buffer_size =
iree_uk_2d_buffer_length(lhs_type, params->M, params->K);
- iree_uk_ssize_t rowmajor_rhs_buffer_size =
+ iree_uk_index_t rowmajor_rhs_buffer_size =
iree_uk_2d_buffer_length(rhs_type, params->K, params->N);
- iree_uk_ssize_t rowmajor_out_buffer_size =
+ iree_uk_index_t rowmajor_out_buffer_size =
iree_uk_2d_buffer_length(out_type, params->M, params->N);
- iree_uk_ssize_t packed_lhs_buffer_size =
+ iree_uk_index_t packed_lhs_buffer_size =
iree_uk_2d_buffer_length(lhs_type, M1, mmt4d_params.lhs_stride0);
- iree_uk_ssize_t packed_rhs_buffer_size =
+ iree_uk_index_t packed_rhs_buffer_size =
iree_uk_2d_buffer_length(rhs_type, N1, mmt4d_params.rhs_stride0);
- iree_uk_ssize_t packed_out_buffer_size =
+ iree_uk_index_t packed_out_buffer_size =
iree_uk_2d_buffer_length(out_type, M1, mmt4d_params.out_stride0);
void* rowmajor_lhs_buffer = malloc(rowmajor_lhs_buffer_size);
void* rowmajor_rhs_buffer = malloc(rowmajor_rhs_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c
index 94f68ea..edf7afc 100644
--- a/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c
@@ -28,10 +28,10 @@
benchmark_def->user_data;
int64_t total_iterations = 0;
- iree_uk_ssize_t buffer_size = user_data->working_set_size / 2;
+ iree_uk_index_t buffer_size = user_data->working_set_size / 2;
uint8_t* in_buffer = malloc(buffer_size);
uint8_t* out_buffer = malloc(buffer_size);
- for (iree_uk_ssize_t i = 0; i < buffer_size; ++i) in_buffer[i] = (i & 0xFF);
+ for (iree_uk_index_t i = 0; i < buffer_size; ++i) in_buffer[i] = (i & 0xFF);
int64_t batch_count = 1;
while (iree_benchmark_keep_running(benchmark_state, batch_count)) {
for (int i = 0; i < batch_count; ++i) {
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 66cc4de..0fabf45 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -47,11 +47,11 @@
iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
- iree_uk_ssize_t lhs_buffer_size =
+ iree_uk_index_t lhs_buffer_size =
iree_uk_2d_buffer_length(lhs_type, params.M, params.lhs_stride0);
- iree_uk_ssize_t rhs_buffer_size =
+ iree_uk_index_t rhs_buffer_size =
iree_uk_2d_buffer_length(rhs_type, params.N, params.rhs_stride0);
- iree_uk_ssize_t out_buffer_size =
+ iree_uk_index_t out_buffer_size =
iree_uk_2d_buffer_length(out_type, params.M, params.out_stride0);
void* lhs_buffer = malloc(lhs_buffer_size);
void* rhs_buffer = malloc(rhs_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 7799dd2..b2a47c0 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -14,8 +14,8 @@
float* out_ptr, const float* lhs_ptr, const float* rhs_ptr,
const iree_uk_mmt4d_params_t* params) {
float acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0.f;
- for (iree_uk_ssize_t k = 0; k < params->K; ++k) {
- for (iree_uk_ssize_t k0 = 0; k0 < params->K0; ++k0) {
+ for (iree_uk_index_t k = 0; k < params->K; ++k) {
+ for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) {
float lhs_val = lhs_ptr[k * params->M0 * params->K0 + k0];
float rhs_val = rhs_ptr[k * params->N0 * params->K0 + k0];
acc += lhs_val * rhs_val;
@@ -28,8 +28,8 @@
int32_t* out_ptr, const int8_t* lhs_ptr, const int8_t* rhs_ptr,
const iree_uk_mmt4d_params_t* params) {
int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0;
- for (iree_uk_ssize_t k = 0; k < params->K; ++k) {
- for (iree_uk_ssize_t k0 = 0; k0 < params->K0; ++k0) {
+ for (iree_uk_index_t k = 0; k < params->K; ++k) {
+ for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) {
int32_t lhs_val = lhs_ptr[k * params->M0 * params->K0 + k0];
int32_t rhs_val = rhs_ptr[k * params->N0 * params->K0 + k0];
acc += lhs_val * rhs_val;
@@ -40,14 +40,14 @@
static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) {
iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
- iree_uk_ssize_t lhs_elem_size =
+ iree_uk_index_t lhs_elem_size =
iree_uk_type_size(iree_uk_mmt4d_lhs_type(mmt4d_type));
- iree_uk_ssize_t rhs_elem_size =
+ iree_uk_index_t rhs_elem_size =
iree_uk_type_size(iree_uk_mmt4d_rhs_type(mmt4d_type));
- iree_uk_ssize_t out_elem_size =
+ iree_uk_index_t out_elem_size =
iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type));
- for (iree_uk_ssize_t i = 0; i < params->M; ++i) {
- for (iree_uk_ssize_t j = 0; j < params->N; ++j) {
+ for (iree_uk_index_t i = 0; i < params->M; ++i) {
+ for (iree_uk_index_t j = 0; j < params->N; ++j) {
void* out_tile_ptr = ((char*)params->out_buffer) +
(params->out_offset + i * params->out_stride0 +
j * params->M0 * params->N0) *
@@ -58,8 +58,8 @@
const void* rhs_panel_ptr =
((const char*)params->rhs_buffer) +
(params->rhs_offset + j * params->rhs_stride0) * rhs_elem_size;
- for (iree_uk_ssize_t i0 = 0; i0 < params->M0; ++i0) {
- for (iree_uk_ssize_t j0 = 0; j0 < params->N0; ++j0) {
+ for (iree_uk_index_t i0 = 0; i0 < params->M0; ++i0) {
+ for (iree_uk_index_t j0 = 0; j0 < params->N0; ++j0) {
void* out_ptr =
((char*)out_tile_ptr) + (i0 * params->N0 + j0) * out_elem_size;
const void* lhs_ptr =
@@ -104,9 +104,9 @@
iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
- iree_uk_ssize_t lhs_buffer_size =
+ iree_uk_index_t lhs_buffer_size =
iree_uk_2d_buffer_length(lhs_type, params.M, params.lhs_stride0);
- iree_uk_ssize_t rhs_buffer_size =
+ iree_uk_index_t rhs_buffer_size =
iree_uk_2d_buffer_length(rhs_type, params.N, params.rhs_stride0);
void* lhs_buffer = malloc(lhs_buffer_size);
void* rhs_buffer = malloc(rhs_buffer_size);
@@ -122,7 +122,7 @@
iree_uk_mmt4d_params_t reference_params;
memcpy(&reference_params, ¶ms, sizeof params);
- iree_uk_ssize_t out_buffer_size =
+ iree_uk_index_t out_buffer_size =
iree_uk_2d_buffer_length(out_type, params.M, params.out_stride0);
void* reference_out_buffer = malloc(out_buffer_size);
iree_uk_write_random_buffer(reference_out_buffer, out_buffer_size, out_type,
diff --git a/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c
index 830c89b..7301169 100644
--- a/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c
@@ -33,15 +33,15 @@
iree_uk_pack_type_t pack_type = iree_uk_pack_type(params.flags);
iree_uk_type_t in_type = iree_uk_pack_in_type(pack_type);
iree_uk_type_t out_type = iree_uk_pack_out_type(pack_type);
- iree_uk_ssize_t in_type_size = iree_uk_type_size(in_type);
- iree_uk_ssize_t out_type_size = iree_uk_type_size(out_type);
+ iree_uk_index_t in_type_size = iree_uk_type_size(in_type);
+ iree_uk_index_t out_type_size = iree_uk_type_size(out_type);
// The inner dims 2, 3 are given to us as part of the benchmark user_data.
// The outer dims 0, 1 are to be determined based on FLAG_working_set_size.
- iree_uk_ssize_t out_size0 = 1;
- iree_uk_ssize_t out_size1 = 1;
- iree_uk_ssize_t out_size2 = params.out_size2;
- iree_uk_ssize_t out_size3 = params.out_size3;
+ iree_uk_index_t out_size0 = 1;
+ iree_uk_index_t out_size1 = 1;
+ iree_uk_index_t out_size2 = params.out_size2;
+ iree_uk_index_t out_size3 = params.out_size3;
int target_matrix_size_in_elems =
FLAG_working_set_size / (in_type_size + out_type_size);
int target_product_of_outer_sizes_0_1 =
@@ -55,18 +55,18 @@
params.out_size0 = out_size0;
params.out_size1 = out_size1;
if (params.flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&out_size0, &out_size1);
+ iree_uk_index_swap(&out_size0, &out_size1);
}
if (params.flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&out_size2, &out_size3);
+ iree_uk_index_swap(&out_size2, &out_size3);
}
params.in_size0 = iree_max(0, out_size0 * out_size2 - FLAG_padding_size);
params.in_size1 = iree_max(0, out_size1 * out_size3 - FLAG_padding_size);
params.in_stride0 = params.in_size1;
params.out_stride0 = params.out_size1 * params.out_size2 * params.out_size3;
- iree_uk_ssize_t in_buffer_size =
+ iree_uk_index_t in_buffer_size =
iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
- iree_uk_ssize_t out_buffer_size =
+ iree_uk_index_t out_buffer_size =
iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
void* in_buffer = malloc(in_buffer_size);
void* out_buffer = malloc(out_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/pack_test.c b/runtime/src/iree/builtins/ukernel/tools/pack_test.c
index 4163b56..295398d 100644
--- a/runtime/src/iree/builtins/ukernel/tools/pack_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/pack_test.c
@@ -14,35 +14,35 @@
// For now, the input and output element types are always the same.
iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
- iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
- iree_uk_ssize_t outer_size0 = params->out_size0;
- iree_uk_ssize_t outer_size1 = params->out_size1;
- iree_uk_ssize_t tile_size0 = params->out_size2;
- iree_uk_ssize_t tile_size1 = params->out_size3;
- iree_uk_ssize_t out_stride_l0 = params->out_stride0;
- iree_uk_ssize_t out_stride_l1 = params->out_size3 * params->out_size2;
- iree_uk_ssize_t out_stride_l2 = params->out_size3;
- iree_uk_ssize_t out_stride_l3 = 1;
+ iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+ iree_uk_index_t outer_size0 = params->out_size0;
+ iree_uk_index_t outer_size1 = params->out_size1;
+ iree_uk_index_t tile_size0 = params->out_size2;
+ iree_uk_index_t tile_size1 = params->out_size3;
+ iree_uk_index_t out_stride_l0 = params->out_stride0;
+ iree_uk_index_t out_stride_l1 = params->out_size3 * params->out_size2;
+ iree_uk_index_t out_stride_l2 = params->out_size3;
+ iree_uk_index_t out_stride_l3 = 1;
if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&outer_size0, &outer_size1);
- iree_uk_ssize_swap(&out_stride_l0, &out_stride_l1);
+ iree_uk_index_swap(&outer_size0, &outer_size1);
+ iree_uk_index_swap(&out_stride_l0, &out_stride_l1);
}
if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
- iree_uk_ssize_swap(&out_stride_l2, &out_stride_l3);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&out_stride_l2, &out_stride_l3);
}
IREE_UK_ASSERT(outer_size0 * tile_size0 >= params->in_size0);
IREE_UK_ASSERT(outer_size1 * tile_size1 >= params->in_size1);
- for (iree_uk_ssize_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
- for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
- for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
- for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
- iree_uk_ssize_t out_offset =
+ for (iree_uk_index_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
+ for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+ for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+ for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+ iree_uk_index_t out_offset =
params->out_offset + outer_i0 * out_stride_l0 +
tile_i0 * out_stride_l2 + outer_i1 * out_stride_l1 +
tile_i1 * out_stride_l3;
- iree_uk_ssize_t i0 = outer_i0 * tile_size0 + tile_i0;
- iree_uk_ssize_t i1 = outer_i1 * tile_size1 + tile_i1;
+ iree_uk_index_t i0 = outer_i0 * tile_size0 + tile_i0;
+ iree_uk_index_t i1 = outer_i1 * tile_size1 + tile_i1;
char* out_ptr = ((char*)params->out_buffer) + out_offset * elem_size;
if (i0 >= params->in_size0 || i1 >= params->in_size1) {
if (elem_size == 1) {
@@ -52,12 +52,12 @@
} else if (elem_size == 4) {
*(iree_uk_uint32_t*)out_ptr = params->padding_value;
} else {
- for (iree_uk_ssize_t k = 0; k < elem_size; k += 8) {
+ for (iree_uk_index_t k = 0; k < elem_size; k += 8) {
*(iree_uk_uint64_t*)(out_ptr + k) = params->padding_value;
}
}
} else {
- iree_uk_ssize_t in_offset =
+ iree_uk_index_t in_offset =
params->in_offset + i1 + i0 * params->in_stride0;
const char* in_ptr =
((char*)params->in_buffer) + in_offset * elem_size;
@@ -80,7 +80,7 @@
params.out_stride0 = params.out_size1 * params.out_size2 * params.out_size3;
iree_uk_pack_type_t pack_type = iree_uk_pack_type(params.flags);
iree_uk_type_t in_type = iree_uk_pack_in_type(pack_type);
- iree_uk_ssize_t in_buffer_size =
+ iree_uk_index_t in_buffer_size =
iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
void* in_buffer = malloc(in_buffer_size);
iree_uk_write_random_buffer(in_buffer, in_buffer_size, in_type, engine);
@@ -92,7 +92,7 @@
iree_uk_pack_params_t reference_params;
memcpy(&reference_params, ¶ms, sizeof reference_params);
iree_uk_type_t out_type = iree_uk_pack_out_type(pack_type);
- iree_uk_ssize_t out_buffer_size =
+ iree_uk_index_t out_buffer_size =
iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
void* reference_out_buffer = malloc(out_buffer_size);
iree_uk_write_random_buffer(reference_out_buffer, out_buffer_size, out_type,
@@ -157,21 +157,21 @@
params.out_size1 = outer_shape.size1;
if (transpose_outer) {
params.flags |= IREE_UK_FLAG_PACK_TRANSPOSE_OUTER;
- iree_uk_ssize_swap(¶ms.out_size0, ¶ms.out_size1);
+ iree_uk_index_swap(¶ms.out_size0, ¶ms.out_size1);
}
- iree_uk_ssize_t tile_size0 = params.out_size2;
- iree_uk_ssize_t tile_size1 = params.out_size3;
+ iree_uk_index_t tile_size0 = params.out_size2;
+ iree_uk_index_t tile_size1 = params.out_size3;
if (transpose_inner) {
params.flags |= IREE_UK_FLAG_PACK_TRANSPOSE_INNER;
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
}
params.in_size0 = outer_shape.size0 * tile_size0;
params.in_size1 = outer_shape.size1 * tile_size1;
iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test);
if (pad == pad_one_incomplete_tile) {
- iree_uk_ssize_t pad_size0 =
+ iree_uk_index_t pad_size0 =
iree_uk_random_engine_get_0_65535(engine) % tile_size0;
- iree_uk_ssize_t pad_size1 =
+ iree_uk_index_t pad_size1 =
iree_uk_random_engine_get_0_65535(engine) % tile_size1;
params.in_size0 = params.in_size0 - pad_size0;
if (params.in_size0 < 0) params.in_size0 = 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c
index 46bac1d..16e200f 100644
--- a/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c
@@ -34,15 +34,15 @@
iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params.flags);
iree_uk_type_t in_type = iree_uk_unpack_in_type(unpack_type);
iree_uk_type_t out_type = iree_uk_unpack_out_type(unpack_type);
- iree_uk_ssize_t in_type_size = iree_uk_type_size(in_type);
- iree_uk_ssize_t out_type_size = iree_uk_type_size(out_type);
+ iree_uk_index_t in_type_size = iree_uk_type_size(in_type);
+ iree_uk_index_t out_type_size = iree_uk_type_size(out_type);
// The inner dims 2, 3 are given to us as part of the benchmark user_data.
// The outer dims 0, 1 are to be determined based on FLAG_working_set_size.
- iree_uk_ssize_t in_size0 = 1;
- iree_uk_ssize_t in_size1 = 1;
- iree_uk_ssize_t in_size2 = params.in_size2;
- iree_uk_ssize_t in_size3 = params.in_size3;
+ iree_uk_index_t in_size0 = 1;
+ iree_uk_index_t in_size1 = 1;
+ iree_uk_index_t in_size2 = params.in_size2;
+ iree_uk_index_t in_size3 = params.in_size3;
int target_matrix_size_in_elems =
FLAG_working_set_size / (in_type_size + out_type_size);
int target_product_of_outer_sizes_0_1 =
@@ -56,18 +56,18 @@
params.in_size0 = in_size0;
params.in_size1 = in_size1;
if (params.flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&in_size0, &in_size1);
+ iree_uk_index_swap(&in_size0, &in_size1);
}
if (params.flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&in_size2, &in_size3);
+ iree_uk_index_swap(&in_size2, &in_size3);
}
params.out_size0 = iree_max(0, in_size0 * in_size2 - FLAG_padding_size);
params.out_size1 = iree_max(0, in_size1 * in_size3 - FLAG_padding_size);
params.out_stride0 = params.out_size1;
params.in_stride0 = params.in_size1 * params.in_size2 * params.in_size3;
- iree_uk_ssize_t in_buffer_size =
+ iree_uk_index_t in_buffer_size =
iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
- iree_uk_ssize_t out_buffer_size =
+ iree_uk_index_t out_buffer_size =
iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
void* in_buffer = malloc(in_buffer_size);
void* out_buffer = malloc(out_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/unpack_test.c b/runtime/src/iree/builtins/ukernel/tools/unpack_test.c
index 8b8fac3..257ff67 100644
--- a/runtime/src/iree/builtins/ukernel/tools/unpack_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/unpack_test.c
@@ -14,35 +14,35 @@
iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
// For now, the input and output element types are always the same.
iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
- iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
- iree_uk_ssize_t outer_size0 = params->in_size0;
- iree_uk_ssize_t outer_size1 = params->in_size1;
- iree_uk_ssize_t tile_size0 = params->in_size2;
- iree_uk_ssize_t tile_size1 = params->in_size3;
- iree_uk_ssize_t in_stride_outer0 = params->in_stride0;
- iree_uk_ssize_t in_stride_outer1 = params->in_size3 * params->in_size2;
- iree_uk_ssize_t in_stride_tile0 = params->in_size3;
- iree_uk_ssize_t in_stride_tile1 = 1;
+ iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+ iree_uk_index_t outer_size0 = params->in_size0;
+ iree_uk_index_t outer_size1 = params->in_size1;
+ iree_uk_index_t tile_size0 = params->in_size2;
+ iree_uk_index_t tile_size1 = params->in_size3;
+ iree_uk_index_t in_stride_outer0 = params->in_stride0;
+ iree_uk_index_t in_stride_outer1 = params->in_size3 * params->in_size2;
+ iree_uk_index_t in_stride_tile0 = params->in_size3;
+ iree_uk_index_t in_stride_tile1 = 1;
if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&outer_size0, &outer_size1);
- iree_uk_ssize_swap(&in_stride_outer0, &in_stride_outer1);
+ iree_uk_index_swap(&outer_size0, &outer_size1);
+ iree_uk_index_swap(&in_stride_outer0, &in_stride_outer1);
}
if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
- iree_uk_ssize_swap(&in_stride_tile0, &in_stride_tile1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&in_stride_tile0, &in_stride_tile1);
}
- for (iree_uk_ssize_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
- for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
- for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
- for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
- iree_uk_ssize_t in_offset =
+ for (iree_uk_index_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
+ for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+ for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+ for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+ iree_uk_index_t in_offset =
params->in_offset + outer_i0 * in_stride_outer0 +
tile_i0 * in_stride_tile0 + outer_i1 * in_stride_outer1 +
tile_i1 * in_stride_tile1;
- iree_uk_ssize_t i0 = outer_i0 * tile_size0 + tile_i0;
- iree_uk_ssize_t i1 = outer_i1 * tile_size1 + tile_i1;
+ iree_uk_index_t i0 = outer_i0 * tile_size0 + tile_i0;
+ iree_uk_index_t i1 = outer_i1 * tile_size1 + tile_i1;
if (!(i0 >= params->out_size0 || i1 >= params->out_size1)) {
- iree_uk_ssize_t out_offset =
+ iree_uk_index_t out_offset =
params->out_offset + i1 + i0 * params->out_stride0;
const char* in_ptr =
((char*)params->in_buffer) + in_offset * elem_size;
@@ -68,7 +68,7 @@
iree_uk_random_engine_get_0_1(engine);
iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params.flags);
iree_uk_type_t in_type = iree_uk_unpack_in_type(unpack_type);
- iree_uk_ssize_t in_buffer_size =
+ iree_uk_index_t in_buffer_size =
iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
void* in_buffer = malloc(in_buffer_size);
iree_uk_write_random_buffer(in_buffer, in_buffer_size, in_type, engine);
@@ -80,7 +80,7 @@
iree_uk_unpack_params_t reference_params;
memcpy(&reference_params, &params, sizeof reference_params);
iree_uk_type_t out_type = iree_uk_unpack_out_type(unpack_type);
- iree_uk_ssize_t out_buffer_size =
+ iree_uk_index_t out_buffer_size =
iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
void* reference_out_buffer = malloc(out_buffer_size);
iree_uk_write_random_buffer(reference_out_buffer, out_buffer_size, out_type,
@@ -139,31 +139,31 @@
memcpy(&params, src_params, sizeof params);
params.cpu_data = iree_uk_test_cpu_data(test);
outer_shape_t outer_shape = outer_shapes[i];
- iree_uk_ssize_t in_size0 = outer_shape.size0;
- iree_uk_ssize_t in_size1 = outer_shape.size1;
+ iree_uk_index_t in_size0 = outer_shape.size0;
+ iree_uk_index_t in_size1 = outer_shape.size1;
params.in_size0 = in_size0;
params.in_size1 = in_size1;
if (pad == pad_a_lot) {
params.in_size0 += 16;
params.in_size1 += 16;
}
- iree_uk_ssize_t tile_size0 = params.in_size2;
- iree_uk_ssize_t tile_size1 = params.in_size3;
+ iree_uk_index_t tile_size0 = params.in_size2;
+ iree_uk_index_t tile_size1 = params.in_size3;
if (transpose_outer) {
params.flags |= IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER;
- iree_uk_ssize_swap(&in_size0, &in_size1);
+ iree_uk_index_swap(&in_size0, &in_size1);
}
if (transpose_inner) {
params.flags |= IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER;
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
}
params.out_size0 = in_size0 * tile_size0;
params.out_size1 = in_size1 * tile_size1;
if (pad == pad_one_incomplete_tile) {
iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test);
- iree_uk_ssize_t pad_size0 =
+ iree_uk_index_t pad_size0 =
iree_uk_random_engine_get_0_65535(engine) % tile_size0;
- iree_uk_ssize_t pad_size1 =
+ iree_uk_index_t pad_size1 =
iree_uk_random_engine_get_0_65535(engine) % tile_size1;
params.out_size0 = params.out_size0 - pad_size0;
if (params.out_size0 < 0) params.out_size0 = 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/util.c b/runtime/src/iree/builtins/ukernel/tools/util.c
index 8aeeff7..aafd030 100644
--- a/runtime/src/iree/builtins/ukernel/tools/util.c
+++ b/runtime/src/iree/builtins/ukernel/tools/util.c
@@ -29,20 +29,20 @@
abort();
}
-iree_uk_ssize_t iree_uk_2d_buffer_length(iree_uk_type_t type,
- iree_uk_ssize_t size0,
- iree_uk_ssize_t stride0) {
+iree_uk_index_t iree_uk_2d_buffer_length(iree_uk_type_t type,
+ iree_uk_index_t size0,
+ iree_uk_index_t stride0) {
// Just for testing purposes, so it's OK to overestimate size.
return size0 * stride0 << iree_uk_type_size_log2(type);
}
bool iree_uk_2d_buffers_equal(const void* buf1, const void* buf2,
- iree_uk_type_t type, iree_uk_ssize_t size0,
- iree_uk_ssize_t size1, iree_uk_ssize_t stride0) {
- iree_uk_ssize_t elem_size = iree_uk_type_size(type);
+ iree_uk_type_t type, iree_uk_index_t size0,
+ iree_uk_index_t size1, iree_uk_index_t stride0) {
+ iree_uk_index_t elem_size = iree_uk_type_size(type);
const char* buf1_ptr = buf1;
const char* buf2_ptr = buf2;
- for (iree_uk_ssize_t i0 = 0; i0 < size0; ++i0) {
+ for (iree_uk_index_t i0 = 0; i0 < size0; ++i0) {
if (memcmp(buf1_ptr, buf2_ptr, elem_size * size1)) return false;
buf1_ptr += elem_size * stride0;
buf2_ptr += elem_size * stride0;
@@ -82,12 +82,12 @@
return (v % 32) - 16;
}
-void iree_uk_write_random_buffer(void* buffer, iree_uk_ssize_t size_in_bytes,
+void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
iree_uk_type_t type,
iree_uk_random_engine_t* engine) {
- iree_uk_ssize_t elem_size = iree_uk_type_size(type);
- iree_uk_ssize_t size_in_elems = size_in_bytes / elem_size;
- for (iree_uk_ssize_t i = 0; i < size_in_elems; ++i) {
+ iree_uk_index_t elem_size = iree_uk_type_size(type);
+ iree_uk_index_t size_in_elems = size_in_bytes / elem_size;
+ for (iree_uk_index_t i = 0; i < size_in_elems; ++i) {
// Small integers, should work for now for all the types we currently have
// and enable exact float arithmetic, allowing to keep tests simpler for
// now. Watch out for when we'll do float16!
diff --git a/runtime/src/iree/builtins/ukernel/tools/util.h b/runtime/src/iree/builtins/ukernel/tools/util.h
index acd9380..37a6df4 100644
--- a/runtime/src/iree/builtins/ukernel/tools/util.h
+++ b/runtime/src/iree/builtins/ukernel/tools/util.h
@@ -10,13 +10,13 @@
#include "iree/builtins/ukernel/api.h"
// Helper to determine the length of test buffers to allocate.
-iree_uk_ssize_t iree_uk_2d_buffer_length(iree_uk_type_t type,
- iree_uk_ssize_t size0,
- iree_uk_ssize_t size1);
+iree_uk_index_t iree_uk_2d_buffer_length(iree_uk_type_t type,
+ iree_uk_index_t size0,
+ iree_uk_index_t size1);
bool iree_uk_2d_buffers_equal(const void* buf1, const void* buf2,
- iree_uk_type_t type, iree_uk_ssize_t size0,
- iree_uk_ssize_t size1, iree_uk_ssize_t stride0);
+ iree_uk_type_t type, iree_uk_index_t size0,
+ iree_uk_index_t size1, iree_uk_index_t stride0);
// Simple deterministic pseudorandom generator. Same as C++'s std::minstd_rand.
typedef struct iree_uk_random_engine_t {
@@ -32,7 +32,7 @@
int iree_uk_random_engine_get_0_65535(iree_uk_random_engine_t* e);
int iree_uk_random_engine_get_0_1(iree_uk_random_engine_t* e);
int iree_uk_random_engine_get_minus16_plus15(iree_uk_random_engine_t* e);
-void iree_uk_write_random_buffer(void* buffer, iree_uk_ssize_t size_in_bytes,
+void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
iree_uk_type_t type,
iree_uk_random_engine_t* engine);
diff --git a/runtime/src/iree/builtins/ukernel/unpack.c b/runtime/src/iree/builtins/ukernel/unpack.c
index 4a3696e..fdeb0bc 100644
--- a/runtime/src/iree/builtins/ukernel/unpack.c
+++ b/runtime/src/iree/builtins/ukernel/unpack.c
@@ -18,8 +18,8 @@
} iree_uk_unpack_tmpbuf_helper_t;
// Return x/y for x>=0 and y>0, with a fast path for when y is a power of two.
-static iree_uk_ssize_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
- iree_uk_ssize_t x, iree_uk_int32_t y) {
+static iree_uk_index_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
+ iree_uk_index_t x, iree_uk_int32_t y) {
IREE_UK_ASSERT(x >= 0);
IREE_UK_ASSERT(y > 0);
return IREE_UK_LIKELY(iree_uk_is_po2_u32(y)) ? (x >> iree_uk_po2_log2_u32(y))
@@ -29,8 +29,8 @@
// Initializes a `iree_uk_unpack_padding_helper`. Asserts if the temporary
// buffer is smaller than one tile.
static void iree_uk_unpack_tmpbuf_helper_init(
- iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
- iree_uk_ssize_t elem_size, iree_uk_unpack_tmpbuf_helper_t* helper) {
+ iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+ iree_uk_index_t elem_size, iree_uk_unpack_tmpbuf_helper_t* helper) {
helper->max_tiles_in_tmp_buf = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
iree_uk_unpack_tmp_buf_size, tile_size0 * tile_size1 * elem_size);
IREE_UK_ASSERT(helper->max_tiles_in_tmp_buf > 0);
@@ -55,15 +55,15 @@
IREE_UK_ASSERT(params->in_size3 >= 0);
// Check that the input and output shapes match, give or take padding that
// must not exceed the inner tile size.s
- iree_uk_ssize_t outer_size0 = params->in_size0;
- iree_uk_ssize_t outer_size1 = params->in_size1;
- iree_uk_ssize_t tile_size0 = params->in_size2;
- iree_uk_ssize_t tile_size1 = params->in_size3;
+ iree_uk_index_t outer_size0 = params->in_size0;
+ iree_uk_index_t outer_size1 = params->in_size1;
+ iree_uk_index_t tile_size0 = params->in_size2;
+ iree_uk_index_t tile_size1 = params->in_size3;
if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&outer_size0, &outer_size1);
+ iree_uk_index_swap(&outer_size0, &outer_size1);
}
if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
}
IREE_UK_ASSERT(outer_size0 * tile_size0 >= params->out_size0);
IREE_UK_ASSERT(outer_size1 * tile_size1 >= params->out_size1);
@@ -79,7 +79,7 @@
iree_uk_unpack_tmpbuf_helper_t helper;
iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
- iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
+ iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
iree_uk_unpack_tmpbuf_helper_init(tile_size0, tile_size1, elem_size, &helper);
#endif // IREE_UK_ENABLE_ASSERTS
}
@@ -89,12 +89,12 @@
return (params->out_size0 == 0 || params->out_size1 == 0);
}
-static void iree_uk_copy_slice(iree_uk_ssize_t src_stride0, const char* src_buf,
- iree_uk_ssize_t dst_size0,
- iree_uk_ssize_t dst_size1,
- iree_uk_ssize_t dst_stride0, char* dst_buf,
- iree_uk_ssize_t elem_size) {
- for (iree_uk_ssize_t in_i0 = 0; in_i0 < dst_size0; in_i0++) {
+static void iree_uk_copy_slice(iree_uk_index_t src_stride0, const char* src_buf,
+ iree_uk_index_t dst_size0,
+ iree_uk_index_t dst_size1,
+ iree_uk_index_t dst_stride0, char* dst_buf,
+ iree_uk_index_t elem_size) {
+ for (iree_uk_index_t in_i0 = 0; in_i0 < dst_size0; in_i0++) {
iree_uk_memcpy(dst_buf, src_buf, dst_size1 * elem_size);
dst_buf += dst_stride0 * elem_size;
src_buf += src_stride0 * elem_size;
@@ -105,19 +105,19 @@
// incomplete tiles. In cases involving only complete tiles, it is faster to
// call tile_func directly.
static void iree_uk_unpack_row_using_tmpbuf(
- iree_uk_unpack_tile_func_t tile_func, iree_uk_ssize_t dim1_tile_start,
- iree_uk_ssize_t dim1_tile_end, iree_uk_ssize_t dim0_write_size,
- iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t out_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
+ iree_uk_unpack_tile_func_t tile_func, iree_uk_index_t dim1_tile_start,
+ iree_uk_index_t dim1_tile_end, iree_uk_index_t dim0_write_size,
+ iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+ iree_uk_index_t elem_size, iree_uk_index_t out_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
iree_uk_unpack_tmpbuf_helper_t* helper, const char* in_buf, char* out_buf) {
- iree_uk_ssize_t dim1_tile = dim1_tile_start;
+ iree_uk_index_t dim1_tile = dim1_tile_start;
while (dim1_tile < dim1_tile_end) {
- iree_uk_ssize_t dim1_chunk_tiles = iree_uk_ssize_clamp(
+ iree_uk_index_t dim1_chunk_tiles = iree_uk_index_clamp(
dim1_tile_end - dim1_tile, 0, helper->max_tiles_in_tmp_buf);
- iree_uk_ssize_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
- iree_uk_ssize_t dim1_chunk_src_pos = dim1_tile * tile_size1;
- iree_uk_ssize_t dim1_write_size = iree_uk_ssize_clamp(
+ iree_uk_index_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
+ iree_uk_index_t dim1_chunk_src_pos = dim1_tile * tile_size1;
+ iree_uk_index_t dim1_write_size = iree_uk_index_clamp(
out_size1 - dim1_chunk_src_pos, 0, dim1_chunk_src_width);
tile_func(helper->tmp_buf, in_buf + dim1_tile * in_stride1 * elem_size,
dim1_chunk_tiles, dim1_chunk_src_width, in_stride1, elem_size,
@@ -135,19 +135,19 @@
// For now, the input and output element types are always the same.
iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
- iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
- iree_uk_ssize_t outer_size0 = params->in_size0;
- iree_uk_ssize_t outer_size1 = params->in_size1;
- iree_uk_ssize_t tile_size0 = params->in_size2;
- iree_uk_ssize_t tile_size1 = params->in_size3;
- iree_uk_ssize_t in_stride0 = params->in_stride0;
- iree_uk_ssize_t in_stride1 = params->in_size3 * params->in_size2;
+ iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+ iree_uk_index_t outer_size0 = params->in_size0;
+ iree_uk_index_t outer_size1 = params->in_size1;
+ iree_uk_index_t tile_size0 = params->in_size2;
+ iree_uk_index_t tile_size1 = params->in_size3;
+ iree_uk_index_t in_stride0 = params->in_stride0;
+ iree_uk_index_t in_stride1 = params->in_size3 * params->in_size2;
if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
- iree_uk_ssize_swap(&outer_size0, &outer_size1);
- iree_uk_ssize_swap(&in_stride0, &in_stride1);
+ iree_uk_index_swap(&outer_size0, &outer_size1);
+ iree_uk_index_swap(&in_stride0, &in_stride1);
}
if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
- iree_uk_ssize_swap(&tile_size0, &tile_size1);
+ iree_uk_index_swap(&tile_size0, &tile_size1);
}
const char* in_buf =
(const char*)params->in_buffer + (params->in_offset * elem_size);
@@ -163,7 +163,7 @@
// source buffer's boundaries.
int dim1_full_tiles = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
params->out_size1, tile_size1);
- iree_uk_ssize_t i0 = 0;
+ iree_uk_index_t i0 = 0;
for (; i0 <= params->out_size0 - tile_size0; i0 += tile_size0) {
// Pack whole tiles that do not require padding (entirely within the source
// buffer's boundaries).
@@ -179,8 +179,8 @@
}
// Bottom-padding.
for (; i0 < outer_size0 * tile_size0; i0 += tile_size0) {
- iree_uk_ssize_t dim0_write_size =
- iree_uk_ssize_clamp(params->out_size0 - i0, 0, tile_size0);
+ iree_uk_index_t dim0_write_size =
+ iree_uk_index_clamp(params->out_size0 - i0, 0, tile_size0);
iree_uk_unpack_row_using_tmpbuf(
tile_func, 0, outer_size1, dim0_write_size, tile_size0, tile_size1,
elem_size, params->out_size1, params->out_stride0, in_stride1,
diff --git a/runtime/src/iree/builtins/ukernel/unpack.h b/runtime/src/iree/builtins/ukernel/unpack.h
index ba86c9e..ff7d9cd 100644
--- a/runtime/src/iree/builtins/ukernel/unpack.h
+++ b/runtime/src/iree/builtins/ukernel/unpack.h
@@ -15,17 +15,17 @@
typedef struct iree_uk_unpack_params_t {
const void* in_buffer;
- iree_uk_ssize_t in_offset;
- iree_uk_ssize_t in_stride0;
+ iree_uk_index_t in_offset;
+ iree_uk_index_t in_stride0;
void* out_buffer;
- iree_uk_ssize_t out_offset;
- iree_uk_ssize_t out_stride0;
- iree_uk_ssize_t in_size0;
- iree_uk_ssize_t in_size1;
- iree_uk_ssize_t in_size2;
- iree_uk_ssize_t in_size3;
- iree_uk_ssize_t out_size0;
- iree_uk_ssize_t out_size1;
+ iree_uk_index_t out_offset;
+ iree_uk_index_t out_stride0;
+ iree_uk_index_t in_size0;
+ iree_uk_index_t in_size1;
+ iree_uk_index_t in_size2;
+ iree_uk_index_t in_size3;
+ iree_uk_index_t out_size0;
+ iree_uk_index_t out_size1;
iree_uk_uint32_t flags;
const iree_uk_uint64_t* cpu_data;
} iree_uk_unpack_params_t;
diff --git a/runtime/src/iree/builtins/ukernel/unpack_internal.h b/runtime/src/iree/builtins/ukernel/unpack_internal.h
index 3d48d82..d92f056 100644
--- a/runtime/src/iree/builtins/ukernel/unpack_internal.h
+++ b/runtime/src/iree/builtins/ukernel/unpack_internal.h
@@ -38,18 +38,18 @@
typedef void (*iree_uk_unpack_tile_func_t)(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1);
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1);
// Tile kernel declarations. Prototype matches iree_uk_unpack_tile_func_t.
#define IREE_UK_UNPACK_TILE_FUNC_DECL(NAME) \
void NAME(void* IREE_UK_RESTRICT out_tile_ptr, \
const void* IREE_UK_RESTRICT in_tile_ptr, \
- iree_uk_ssize_t outer_size1, iree_uk_ssize_t out_stride0, \
- iree_uk_ssize_t in_stride1, iree_uk_ssize_t elem_size, \
- iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1);
+ iree_uk_index_t outer_size1, iree_uk_index_t out_stride0, \
+ iree_uk_index_t in_stride1, iree_uk_index_t elem_size, \
+ iree_uk_index_t tile_size0, iree_uk_index_t tile_size1);
// Returns the tile function to use for the unpack op with the given params.
iree_uk_unpack_tile_func_t iree_uk_unpack_select_tile_func(
diff --git a/runtime/src/iree/builtins/ukernel/unpack_tile.c b/runtime/src/iree/builtins/ukernel/unpack_tile.c
index 3a1b464..eb7266b 100644
--- a/runtime/src/iree/builtins/ukernel/unpack_tile.c
+++ b/runtime/src/iree/builtins/ukernel/unpack_tile.c
@@ -8,16 +8,16 @@
static void iree_uk_unpack_tile_generic_direct(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
- for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+ for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
const char* IREE_UK_RESTRICT in_ptr = in_ptr_l1;
char* IREE_UK_RESTRICT out_ptr = out_ptr_l1;
- for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+ for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
iree_uk_memcpy(out_ptr, in_ptr, tile_size1 * elem_size);
in_ptr += tile_size1 * elem_size;
out_ptr += out_stride0 * elem_size;
@@ -29,19 +29,19 @@
static void iree_uk_unpack_tile_generic_transpose(
void* IREE_UK_RESTRICT out_tile_ptr,
- const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
- iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
- iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
- iree_uk_ssize_t tile_size1) {
+ const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+ iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+ iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+ iree_uk_index_t tile_size1) {
const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
- for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+ for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
const char* IREE_UK_RESTRICT in_ptr_l2 = in_ptr_l1;
char* IREE_UK_RESTRICT out_ptr_l2 = out_ptr_l1;
- for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+ for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
const char* IREE_UK_RESTRICT in_ptr = in_ptr_l2;
char* IREE_UK_RESTRICT out_ptr = out_ptr_l2;
- for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+ for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
iree_uk_memcpy(out_ptr, in_ptr, elem_size);
in_ptr += tile_size0 * elem_size;
out_ptr += elem_size;