Correct 32bit/64bit separation in ukernel code. (#13878)

* Rename `iree_uk_ssize_t` to `iree_uk_index_t` to make it clear that
the primary requirement on this type is to match the compiler's `index`
type.
* Build non-arch/ ukernel code twice, for 32bit and 64bit architectures.
As in libdevice, wasm_32 / wasm_64 is chosen as a sane 32bit/64bit
architecture to pick when the only thing that matters is the bitness
(which only matters for choosing the right definition of
`iree_uk_index_t`).
* Drop `strip_target_info.py`, instead adopt `Device.cpp`'s code
dropping function attributes inside of the compiler after loading the
bitcode.
diff --git a/build_tools/bazel/iree_bitcode_library.bzl b/build_tools/bazel/iree_bitcode_library.bzl
index 09282f3..0a01299 100644
--- a/build_tools/bazel/iree_bitcode_library.bzl
+++ b/build_tools/bazel/iree_bitcode_library.bzl
@@ -41,20 +41,17 @@
 
 def iree_bitcode_library(
         name,
+        arch,
         srcs,
         internal_hdrs = [],
         copts = [],
         out = None,
-        arch = None,
         **kwargs):
     """Builds an LLVM bitcode library from an input file via clang.
 
     Args:
         name: Name of the target.
-        arch: Target architecture to compile for, in IREE_ARCH format. If left
-              empty, will produce architecture-independent bitcode by stripping
-              target triple and target attributes; that only makes sense if the
-              sources being compiled are truly architecture-independent.
+        arch: Target architecture to compile for, in IREE_ARCH format.
         srcs: source files to pass to clang.
         internal_hdrs: all headers transitively included by the source files.
                        Unlike typical Bazel `hdrs`, these are not exposed as
@@ -73,6 +70,10 @@
     builtin_headers_path = "external/llvm-project/clang/staging/include/"
 
     base_copts = [
+        # Target architecture
+        "-target",
+        iree_arch_to_llvm_arch(arch),
+
         # C17 with no system deps.
         "-std=c17",
         "-nostdinc",
@@ -97,20 +98,10 @@
         "-DIREE_DEVICE_STANDALONE=1",
     ]
 
-    llvmir_processing_tool = None
-    if arch:
-        # Compile to the specified target architecture.
-        base_copts.extend(["-target", iree_arch_to_llvm_arch(arch)])
-    else:
-        # Output text rather than binary serialization of LLVM IR for processing
-        base_copts.append("-S")
-
-        # Strip target information from generated LLVM IR.
-        llvmir_processing_tool = "//build_tools/scripts:strip_target_info"
-
     bitcode_files = []
     for src in srcs:
         bitcode_out = "%s_%s.bc" % (name, src)
+        bitcode_files.append(bitcode_out)
         native.genrule(
             name = "gen_%s" % (bitcode_out),
             srcs = [src, builtin_headers_dep] + internal_hdrs,
@@ -134,28 +125,6 @@
             **kwargs
         )
 
-        if llvmir_processing_tool:
-            processed_bitcode_out = "%s_%s.processed.bc" % (name, src)
-            native.genrule(
-                name = "gen_%s" % (processed_bitcode_out),
-                srcs = [bitcode_out],
-                outs = [processed_bitcode_out],
-                cmd = " ".join([
-                    "$(location %s)" % (llvmir_processing_tool),
-                    "< $(location %s)" % bitcode_out,
-                    "> $(location %s)" % processed_bitcode_out,
-                ]),
-                tools = [
-                    llvmir_processing_tool,
-                ],
-                message = "Processing %s into %s using %s..." % (bitcode_out, processed_bitcode_out, llvmir_processing_tool),
-                output_to_bindir = 1,
-                **kwargs
-            )
-            bitcode_files.append(processed_bitcode_out)
-        else:
-            bitcode_files.append(bitcode_out)
-
     if not out:
         out = "%s.bc" % (name)
     native.genrule(
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index f6b479d..023dc4d 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -484,20 +484,20 @@
 
   def iree_bitcode_library(self,
                            name,
+                           arch,
                            srcs,
                            internal_hdrs=None,
-                           copts=None,
-                           arch=None):
+                           copts=None):
     name_block = self._convert_string_arg_block("NAME", name, quote=False)
+    arch_block = self._convert_string_arg_block("ARCH", arch, quote=False)
     srcs_block = self._convert_srcs_block(srcs)
     copts_block = self._convert_string_list_block("COPTS", copts, sort=False)
-    arch_block = self._convert_string_arg_block("ARCH", arch, quote=False)
 
     self._converter.body += (f"iree_bitcode_library(\n"
                              f"{name_block}"
+                             f"{arch_block}"
                              f"{srcs_block}"
                              f"{copts_block}"
-                             f"{arch_block}"
                              f")\n\n")
 
   def iree_link_bitcode(self, name, bitcode_files):
diff --git a/build_tools/cmake/iree_bitcode_library.cmake b/build_tools/cmake/iree_bitcode_library.cmake
index 836c5ae..e6d2ba3 100644
--- a/build_tools/cmake/iree_bitcode_library.cmake
+++ b/build_tools/cmake/iree_bitcode_library.cmake
@@ -44,7 +44,12 @@
   # override this but it should be harmless.
   set(_BUILTIN_HEADERS_PATH "${IREE_BINARY_DIR}/llvm-project/lib/clang/${_CLANG_VERSION_MAJOR}/include/")
 
+  iree_arch_to_llvm_arch(_LLVM_ARCH "${_RULE_ARCH}")
+
   set(_COPTS
+    # Target architecture.
+    "-target" "${_LLVM_ARCH}"
+
     # C17 with no system deps.
     "-std=c17"
     "-nostdinc"
@@ -74,21 +79,11 @@
   list(APPEND _COPTS "-I" "${IREE_BINARY_DIR}/runtime/src")
   list(APPEND _COPTS "${_RULE_COPTS}")
 
-  if(_RULE_ARCH)
-    # Compile to the specified target architecture.
-    iree_arch_to_llvm_arch(_LLVM_ARCH "${_RULE_ARCH}")
-    list(APPEND _COPTS "-target" "${_LLVM_ARCH}")
-  else()
-    # Output text rather than binary serialization of LLVM IR for processing.
-    list(APPEND _COPTS "-S")
-    # Strip target information from generated LLVM IR.
-    set(_LLVMIR_PROCESSING_TOOL "${IREE_SOURCE_DIR}/build_tools/scripts/strip_target_info.py")
-  endif()
-
   set(_BITCODE_FILES)
   foreach(_SRC ${_RULE_SRCS})
     get_filename_component(_BITCODE_SRC_PATH "${_SRC}" REALPATH)
     set(_BITCODE_FILE "${_RULE_NAME}_${_SRC}.bc")
+    list(APPEND _BITCODE_FILES ${_BITCODE_FILE})
     add_custom_command(
       OUTPUT
         "${_BITCODE_FILE}"
@@ -106,28 +101,6 @@
         "Compiling ${_SRC} to ${_BITCODE_FILE}"
       VERBATIM
     )
-
-    if(_LLVMIR_PROCESSING_TOOL)
-      set(_PROCESSED_BITCODE_FILE "${_RULE_NAME}_${_SRC}.processed.bc")
-      list(APPEND _BITCODE_FILES ${_PROCESSED_BITCODE_FILE})
-      add_custom_command(
-        OUTPUT
-          "${_PROCESSED_BITCODE_FILE}"
-        COMMAND
-          "python3"
-          "${_LLVMIR_PROCESSING_TOOL}"
-          < "${_BITCODE_FILE}"
-          > "${_PROCESSED_BITCODE_FILE}"
-        DEPENDS
-          "${_BITCODE_FILE}"
-          "${_LLVMIR_PROCESSING_TOOL}"
-        COMMENT
-          "Processing ${_BITCODE_FILE} into ${_PROCESSED_BITCODE_FILE} using ${_LLVMIR_PROCESSING_TOOL}"
-        VERBATIM
-      )
-    else()  # _LLVMIR_PROCESSING_TOOL
-      list(APPEND _BITCODE_FILES ${_BITCODE_FILE})
-    endif()  # _LLVMIR_PROCESSING_TOOL
   endforeach()
 
   add_custom_command(
diff --git a/build_tools/scripts/BUILD.bazel b/build_tools/scripts/BUILD.bazel
deleted file mode 100644
index 6c39341..0000000
--- a/build_tools/scripts/BUILD.bazel
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-package(
-    default_visibility = ["//visibility:public"],
-    features = ["layering_check"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-py_binary(
-    name = "strip_target_info",
-    srcs = ["strip_target_info.py"],
-)
diff --git a/build_tools/scripts/strip_target_info.py b/build_tools/scripts/strip_target_info.py
deleted file mode 100644
index 7c936d7..0000000
--- a/build_tools/scripts/strip_target_info.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""Strip LLVM IR of target triple and target-specific attributes
-"""
-
-import sys
-import re
-import os
-
-
-def main():
-  sys.stdout.write(f";\n")
-  sys.stdout.write(f"; Processed by {os.path.basename(__file__)}\n")
-  sys.stdout.write(f";\n")
-  target_triple_regex = re.compile(r'^\s*target triple\s*=\s*"[^"]*"')
-  target_cpu_regex = re.compile(r'"target-cpu"="[^"]*"')
-  target_features_regex = re.compile(r'"target-features"="[^"]*"')
-  tune_cpu_regex = re.compile(r'"tune-cpu"="[^"]*"')
-
-  for line in sys.stdin:
-    if "target" in line:
-      if re.match(target_triple_regex, line):
-        continue
-      line = re.sub(target_cpu_regex, '', line)
-      line = re.sub(target_features_regex, '', line)
-      line = re.sub(tune_cpu_regex, '', line)
-
-    sys.stdout.write(line)
-
-
-if __name__ == "__main__":
-  main()
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp
index 86df47c..e5f26bf 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.cpp
@@ -34,10 +34,29 @@
 }
 
 std::unique_ptr<llvm::Module> loadUKernelBaseBitcode(
-    llvm::LLVMContext& context) {
+    llvm::TargetMachine* targetMachine, llvm::LLVMContext& context) {
+  llvm::Triple triple = targetMachine->getTargetTriple();
+  StringRef filename;
+  if (triple.isArch64Bit()) {
+    filename = "ukernel_bitcode_64bit_base.bc";
+  } else if (triple.isArch32Bit()) {
+    filename = "ukernel_bitcode_32bit_base.bc";
+  } else {
+    return nullptr;
+  }
   std::unique_ptr<llvm::Module> baseBitcode =
-      loadUKernelBitcodeFile("ukernel_bitcode_base.bc", context);
+      loadUKernelBitcodeFile(filename, context);
   assert(baseBitcode && "base ukernel bitcode file not found");
+
+  // Copied from Device.cpp - TODO: move this to a shared utility.
+  // Clang adds its own per-function attributes that we need to strip so that
+  // our current executable variant target is used instead.
+  for (auto& func : baseBitcode->functions()) {
+    func.removeFnAttr("target-cpu");
+    func.removeFnAttr("tune-cpu");
+    func.removeFnAttr("target-features");
+  }
+
   return baseBitcode;
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h
index ca9d8f4..8ef2606 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/Builtins/UKernel.h
@@ -16,7 +16,7 @@
 namespace HAL {
 
 std::unique_ptr<llvm::Module> loadUKernelBaseBitcode(
-    llvm::LLVMContext &context);
+    llvm::TargetMachine *targetMachine, llvm::LLVMContext &context);
 
 std::unique_ptr<llvm::Module> loadUKernelArchBitcode(
     llvm::TargetMachine *targetMachine, llvm::LLVMContext &context);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
index 23f6dc5..4ef8a8b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/LLVMCPUTarget.cpp
@@ -476,7 +476,7 @@
       // The baseBitcode module contains weak symbols for fallbacks.
       // So we link it after the archBitcode and with LinkOnlyNeeded.
       std::unique_ptr<llvm::Module> baseBitcode =
-          loadUKernelBaseBitcode(context);
+          loadUKernelBaseBitcode(targetMachine.get(), context);
       // Sequence that access before we std::move(baseBitcode)!
       StringRef baseBitcodeName = baseBitcode->getName();
       if (failed(linkBitcodeModule(variantOp.getLoc(), moduleLinker,
diff --git a/runtime/src/iree/builtins/device/CMakeLists.txt b/runtime/src/iree/builtins/device/CMakeLists.txt
index 427f29f..25b5421 100644
--- a/runtime/src/iree/builtins/device/CMakeLists.txt
+++ b/runtime/src/iree/builtins/device/CMakeLists.txt
@@ -29,19 +29,19 @@
 iree_bitcode_library(
   NAME
     libdevice_wasm32_generic
-  SRCS
-    "device_generic.c"
   ARCH
     wasm_32
+  SRCS
+    "device_generic.c"
 )
 
 iree_bitcode_library(
   NAME
     libdevice_wasm64_generic
-  SRCS
-    "device_generic.c"
   ARCH
     wasm_64
+  SRCS
+    "device_generic.c"
 )
 
 iree_c_embed_data(
diff --git a/runtime/src/iree/builtins/ukernel/BUILD.bazel b/runtime/src/iree/builtins/ukernel/BUILD.bazel
index b97a6a5..2dedda2 100644
--- a/runtime/src/iree/builtins/ukernel/BUILD.bazel
+++ b/runtime/src/iree/builtins/ukernel/BUILD.bazel
@@ -119,19 +119,23 @@
     ],
 )
 
-iree_bitcode_library(
-    name = "ukernel_bitcode_base",
+[iree_bitcode_library(
+    name = "ukernel_bitcode_%sbit_base" % bitness,
     srcs = UKERNEL_BASE_SRCS,
+    arch = "wasm_%s" % bitness,
     internal_hdrs = [
         ":bitcode_internal_headers",
-        "//runtime/src/iree/builtins/ukernel/arch/x86_64:bitcode_internal_headers",
     ],
-)
+) for bitness in [
+    "32",
+    "64",
+]]
 
 c_embed_data(
     name = "embed_ukernel_bitcode",
     srcs = [
-        ":ukernel_bitcode_base.bc",
+        ":ukernel_bitcode_32bit_base.bc",
+        ":ukernel_bitcode_64bit_base.bc",
         "//runtime/src/iree/builtins/ukernel/arch/x86_64:ukernel_bitcode_x86_64.bc",
     ],
     c_file_output = "ukernel_bitcode.c",
diff --git a/runtime/src/iree/builtins/ukernel/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
index 3292437..73cad50 100644
--- a/runtime/src/iree/builtins/ukernel/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/CMakeLists.txt
@@ -108,7 +108,24 @@
 
 iree_bitcode_library(
   NAME
-    ukernel_bitcode_base
+    ukernel_bitcode_32bit_base
+  ARCH
+    wasm_32
+  SRCS
+    "mmt4d.c"
+    "mmt4d_tile.c"
+    "pack.c"
+    "pack_tile.c"
+    "query_tile_sizes.c"
+    "unpack_tile.c"
+    "weak.c"
+)
+
+iree_bitcode_library(
+  NAME
+    ukernel_bitcode_64bit_base
+  ARCH
+    wasm_64
   SRCS
     "mmt4d.c"
     "mmt4d_tile.c"
@@ -124,7 +141,8 @@
     embed_ukernel_bitcode
   SRCS
     "runtime/src/iree/builtins/ukernel/arch/x86_64/ukernel_bitcode_x86_64.bc"
-    "ukernel_bitcode_base.bc"
+    "ukernel_bitcode_32bit_base.bc"
+    "ukernel_bitcode_64bit_base.bc"
   DEPS
 
   C_FILE_OUTPUT
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h
index c05307b..1bcb60a 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/common_arm_64.h
@@ -30,7 +30,7 @@
 #endif
 
 static inline int8x16x2_t iree_uk_neon_load_8x4xi8_strided(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+    const iree_uk_int8_t* src, iree_uk_index_t stride) {
   int32x4_t v0_i32 = vdupq_n_s32(0);
   int32x4_t v1_i32 = vdupq_n_s32(0);
   v0_i32 =
@@ -56,7 +56,7 @@
 }
 
 static inline int8x16x4_t iree_uk_neon_load_8x8xi8_strided_permute(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride, int p0, int p1, int p2,
+    const iree_uk_int8_t* src, iree_uk_index_t stride, int p0, int p1, int p2,
     int p3, int p4, int p5, int p6, int p7) {
   int8x8_t row0 = vld1_s8(src + p0 * stride);
   int8x8_t row1 = vld1_s8(src + p1 * stride);
@@ -75,7 +75,7 @@
 }
 
 static inline int8x16x4_t iree_uk_neon_load_8x8xi8_strided(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+    const iree_uk_int8_t* src, iree_uk_index_t stride) {
   return iree_uk_neon_load_8x8xi8_strided_permute(src, stride, 0, 1, 2, 3, 4, 5,
                                                   6, 7);
 }
@@ -109,7 +109,7 @@
 
 static inline void iree_uk_neon_copy_8x1xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   int8x8_t v = vdup_n_s8(0);
   v = vld1_lane_s8(in_ptr + 0 * in_stride, v, 0);
   v = vld1_lane_s8(in_ptr + 1 * in_stride, v, 1);
@@ -124,7 +124,7 @@
 
 static inline void iree_uk_neon_copy_8x4xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   int8x16x2_t in = iree_uk_neon_load_8x4xi8_strided(in_ptr, in_stride);
   vst1q_s8(out_ptr + 0, in.val[0]);
   vst1q_s8(out_ptr + 16, in.val[1]);
@@ -132,7 +132,7 @@
 
 static inline void iree_uk_neon_copy_8x8xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   int8x16x4_t in = iree_uk_neon_load_8x8xi8_strided(in_ptr, in_stride);
   vst1q_s8(out_ptr + 0, in.val[0]);
   vst1q_s8(out_ptr + 16, in.val[1]);
@@ -143,8 +143,8 @@
 static inline void
 iree_uk_neon_copy_8x8xi8_tiled_1x4_transpose_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   int8x16x4_t in = iree_uk_neon_load_8x8xi8_strided_permute(
       in_ptr, in_stride, 0, 2, 1, 3, 4, 6, 5, 7);
   int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s8(in.val[0]),
@@ -159,8 +159,8 @@
 
 static inline void iree_uk_neon_copy_8x32xi8_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   for (int i = 0; i < 8; ++i) {
     iree_uk_memcpy(out_ptr + i * out_stride, in_ptr + i * in_stride, 32);
   }
@@ -168,8 +168,8 @@
 
 static inline void iree_uk_neon_copy_8x8xi8_transpose_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   int8x16x4_t in = iree_uk_neon_load_8x8xi8_strided_permute(
       in_ptr, in_stride, 0, 4, 1, 5, 2, 6, 3, 7);
   int16x8x2_t zip_i16_0 = iree_uk_neon_zip_16xi8_as_8xi16(in.val[0], in.val[1]);
@@ -199,7 +199,7 @@
 
 static inline void iree_uk_neon_copy_8x8xi8_transpose_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   // Clang (Android NDK r25) actually produces worse code when this code is
   // specialized for out_stride==8 using longer contiguous stores!
   iree_uk_neon_copy_8x8xi8_transpose_strided_to_strided(out_ptr, in_ptr, 8,
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c
index b280217..a01c17e 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c
@@ -9,10 +9,10 @@
 
 static void iree_uk_pack_tile_8x1_x8_arm_64_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 1);
@@ -36,10 +36,10 @@
 
 static void iree_uk_pack_tile_8x4_x8_arm_64_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 4);
@@ -60,10 +60,10 @@
 
 static void iree_uk_pack_tile_8x1_x32_arm_64_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 1);
@@ -74,10 +74,10 @@
 
 static void iree_uk_pack_tile_8x8_x8_arm_64_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -92,10 +92,10 @@
 
 static void iree_uk_pack_tile_8x1_x32_arm_64_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 1);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -110,10 +110,10 @@
 
 static void iree_uk_pack_tile_8x1_x8_arm_64_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 1);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -136,10 +136,10 @@
 
 static void iree_uk_pack_tile_8x4_x8_arm_64_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 4);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -163,10 +163,10 @@
 
 static void iree_uk_pack_tile_8x8_x8_arm_64_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -183,10 +183,10 @@
 
 static void iree_uk_pack_tile_8x8_x32_arm_64_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 8);
diff --git a/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c b/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c
index 912c4a6..70654a3 100644
--- a/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c
+++ b/runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c
@@ -9,10 +9,10 @@
 
 static void iree_uk_unpack_tile_8x8_x32_arm_64_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 8);
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt b/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt
index c08fc99..3885e3d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/CMakeLists.txt
@@ -15,18 +15,20 @@
 iree_bitcode_library(
   NAME
     ukernel_bitcode_x86_64_base
+  ARCH
+    x86_64
   SRCS
     "mmt4d_x86_64.c"
     "pack_x86_64.c"
     "query_tile_sizes_x86_64.c"
     "unpack_x86_64.c"
-  ARCH
-    x86_64
 )
 
 iree_bitcode_library(
   NAME
     ukernel_bitcode_x86_64_avx2_fma
+  ARCH
+    x86_64
   SRCS
     "mmt4d_x86_64_avx2_fma.c"
     "pack_x86_64_avx2_fma.c"
@@ -35,13 +37,13 @@
     "-mavx"
     "-mavx2"
     "-mfma"
-  ARCH
-    x86_64
 )
 
 iree_bitcode_library(
   NAME
     ukernel_bitcode_x86_64_avx512_base
+  ARCH
+    x86_64
   SRCS
     "mmt4d_x86_64_avx512_base.c"
     "pack_x86_64_avx512_base.c"
@@ -55,13 +57,13 @@
     "-mavx512cd"
     "-mavx512bw"
     "-mavx512dq"
-  ARCH
-    x86_64
 )
 
 iree_bitcode_library(
   NAME
     ukernel_bitcode_x86_64_avx512_vnni
+  ARCH
+    x86_64
   SRCS
     "mmt4d_x86_64_avx512_vnni.c"
   COPTS
@@ -74,8 +76,6 @@
     "-mavx512bw"
     "-mavx512dq"
     "-mavx512vnni"
-  ARCH
-    x86_64
 )
 
 iree_link_bitcode(
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
index 08414cc..cebd52b 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
@@ -71,8 +71,8 @@
 
 static inline void iree_uk_copy_8x32xi8_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   for (int i = 0; i < 8; ++i) {
     iree_uk_memcpy(out_ptr + i * out_stride, in_ptr + i * in_stride, 32);
   }
@@ -80,22 +80,22 @@
 
 static inline void iree_uk_copy_16x64xi8_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   for (int i = 0; i < 16; ++i) {
     iree_uk_memcpy(out_ptr + i * out_stride, in_ptr + i * in_stride, 64);
   }
 }
 
 static inline __m256i iree_uk_avx2_load_8x4xi8_strided(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+    const iree_uk_int8_t* src, iree_uk_index_t stride) {
   __m256i indices = _mm256_mullo_epi32(
       _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7), _mm256_set1_epi32(stride));
   return _mm256_i32gather_epi32(src, indices, 1);
 }
 
 static inline __m128i iree_uk_avx2_load_8x2xi8_strided(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+    const iree_uk_int8_t* src, iree_uk_index_t stride) {
   __m128i result = _mm_setzero_si128();
   const iree_uk_int16_t* src_i16 = (const iree_uk_int16_t*)src;
   result =
@@ -118,7 +118,7 @@
 }
 
 static inline __m256i iree_uk_avx2_load_16x2xi8_strided(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+    const iree_uk_int8_t* src, iree_uk_index_t stride) {
   __m256i result = _mm256_setzero_si256();
   const iree_uk_int16_t* src_i16 = (const iree_uk_int16_t*)src;
   result = _mm256_insert_epi16(result,
@@ -158,21 +158,21 @@
 
 static inline void iree_uk_avx2_copy_8x4xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   __m256i in = iree_uk_avx2_load_8x4xi8_strided(in_ptr, in_stride);
   _mm256_storeu_si256((__m256i*)out_ptr, in);
 }
 
 static inline void iree_uk_avx2_copy_8x2xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   __m128i in = iree_uk_avx2_load_8x2xi8_strided(in_ptr, in_stride);
   _mm_storeu_si128((__m128i*)out_ptr, in);
 }
 
 static inline void iree_uk_avx2_copy_16x2xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   __m256i in = iree_uk_avx2_load_16x2xi8_strided(in_ptr, in_stride);
   _mm256_storeu_si256((__m256i*)out_ptr, in);
 }
@@ -180,8 +180,8 @@
 static inline void
 iree_uk_avx2_copy_8x16xi8_tiled_1x4_transpose_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   __m256i r00004444 =
       iree_uk_avx_loadu_2x128(in_ptr + 0 * in_stride, in_ptr + 4 * in_stride);
   __m256i r11115555 =
@@ -211,8 +211,8 @@
 static inline void
 iree_uk_avx2_copy_8x16xi8_tiled_1x2_transpose_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   __m256i r0000000044444444 =
       iree_uk_avx_loadu_2x128(in_ptr + 0 * in_stride, in_ptr + 4 * in_stride);
   __m256i r1111111155555555 =
@@ -313,7 +313,7 @@
 }
 
 static inline __m512i iree_uk_avx512_load_16x4xi8_strided(
-    const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
+    const iree_uk_int8_t* src, iree_uk_index_t stride) {
   __m512i indices = _mm512_mullo_epi32(
       _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
       _mm512_set1_epi32(stride));
@@ -322,7 +322,7 @@
 
 static inline void iree_uk_avx512_copy_16x4xi8_strided_to_unstrided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t in_stride) {
   __m512i in = iree_uk_avx512_load_16x4xi8_strided(in_ptr, in_stride);
   _mm512_storeu_si512((__m512i*)out_ptr, in);
 }
@@ -330,8 +330,8 @@
 static inline void
 iree_uk_avx512_copy_16x16xi8_tiled_1x4_transpose_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   __m512i r000044448888CCCC = iree_uk_avx512_loadu_4x128(
       in_ptr + 0 * in_stride, in_ptr + 4 * in_stride, in_ptr + 8 * in_stride,
       in_ptr + 12 * in_stride);
@@ -381,8 +381,8 @@
 static inline void
 iree_uk_avx512_copy_16x16xi8_tiled_1x2_transpose_strided_to_strided(
     iree_uk_int8_t* IREE_UK_RESTRICT out_ptr,
-    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_ssize_t out_stride,
-    iree_uk_ssize_t in_stride) {
+    const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr, iree_uk_index_t out_stride,
+    iree_uk_index_t in_stride) {
   __m512i r000044448888CCCC = iree_uk_avx512_loadu_4x128(
       in_ptr + 0 * in_stride, in_ptr + 4 * in_stride, in_ptr + 8 * in_stride,
       in_ptr + 12 * in_stride);
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c
index 1c1dfe5..55e41cf 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c
@@ -11,10 +11,10 @@
 
 void iree_uk_pack_tile_8x8_x32_x86_64_avx2_fma_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -30,10 +30,10 @@
 
 static void iree_uk_pack_tile_8x4_x8_x86_64_avx2_fma_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 4);
@@ -54,10 +54,10 @@
 
 void iree_uk_pack_tile_8x1_x32_x86_64_avx2_fma_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 1);
@@ -68,10 +68,10 @@
 
 void iree_uk_pack_tile_8x1_x32_x86_64_avx2_fma_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 1);
   IREE_UK_ASSERT(tile_size1 == 8);
@@ -86,10 +86,10 @@
 
 void iree_uk_pack_tile_8x2_x8_x86_64_avx2_fma_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 2);
@@ -110,16 +110,16 @@
 
 void iree_uk_pack_tile_8x2_x8_x86_64_avx2_fma_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 2);
   IREE_UK_ASSERT(tile_size1 == 8);
   const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr = in_tile_ptr;
   iree_uk_int8_t* IREE_UK_RESTRICT out_ptr = out_tile_ptr;
-  iree_uk_ssize_t outer_i1 = 0;
+  iree_uk_index_t outer_i1 = 0;
   for (; outer_i1 <= outer_size1 - 4; outer_i1 += 4) {
     __m256i in0 = _mm256_loadu_si256((const __m256i*)in_ptr);
     __m256i in1 = _mm256_loadu_si256((const __m256i*)(in_ptr + in_stride0));
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c
index 5021d6f..44a3dc3 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c
@@ -11,10 +11,10 @@
 
 void iree_uk_pack_tile_16x16_x32_x86_64_avx512_base_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 16);
   IREE_UK_ASSERT(tile_size1 == 16);
@@ -30,10 +30,10 @@
 
 static void iree_uk_pack_tile_16x4_x8_x86_64_avx512_base_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 16);
   IREE_UK_ASSERT(tile_size1 == 4);
@@ -55,10 +55,10 @@
 
 void iree_uk_pack_tile_16x1_x32_x86_64_avx512_base_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 16);
   IREE_UK_ASSERT(tile_size1 == 1);
@@ -69,10 +69,10 @@
 
 void iree_uk_pack_tile_16x1_x32_x86_64_avx512_base_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 1);
   IREE_UK_ASSERT(tile_size1 == 16);
@@ -87,10 +87,10 @@
 
 void iree_uk_pack_tile_16x2_x8_x86_64_avx512_base_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 16);
   IREE_UK_ASSERT(tile_size1 == 2);
@@ -111,16 +111,16 @@
 
 void iree_uk_pack_tile_16x2_x8_x86_64_avx512_base_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 1);
   IREE_UK_ASSERT(tile_size0 == 2);
   IREE_UK_ASSERT(tile_size1 == 16);
   const iree_uk_int8_t* IREE_UK_RESTRICT in_ptr = in_tile_ptr;
   iree_uk_int8_t* IREE_UK_RESTRICT out_ptr = out_tile_ptr;
-  iree_uk_ssize_t outer_i1 = 0;
+  iree_uk_index_t outer_i1 = 0;
   for (; outer_i1 <= outer_size1 - 4; outer_i1 += 4) {
     __m512i in0 = _mm512_loadu_si512((const __m512i*)in_ptr);
     __m512i in1 = _mm512_loadu_si512((const __m512i*)(in_ptr + in_stride0));
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c
index d7cda81..cc2600d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c
@@ -11,10 +11,10 @@
 
 void iree_uk_unpack_tile_8x8_x32_x86_64_avx2_fma_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 8);
   IREE_UK_ASSERT(tile_size1 == 8);
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c
index b99e16a..2c9229d 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c
@@ -11,10 +11,10 @@
 
 void iree_uk_unpack_tile_16x16_x32_x86_64_avx512_base_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   IREE_UK_ASSERT(elem_size == 4);
   IREE_UK_ASSERT(tile_size0 == 16);
   IREE_UK_ASSERT(tile_size1 == 16);
diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h
index 1a1db65..ad9630b 100644
--- a/runtime/src/iree/builtins/ukernel/common.h
+++ b/runtime/src/iree/builtins/ukernel/common.h
@@ -259,51 +259,94 @@
 }
 
 //===----------------------------------------------------------------------===//
-// iree_uk_ssize_t: signed size/index type.
+// Architecture detection (copied from target_platorm.h)
 //===----------------------------------------------------------------------===//
 
-// Don't think of this as "pointer-size" -- there's a reason why this is called
-// ssize_t, not ptrdiff_t.
-//
-// The typical compilation path for this ukernel code is to LLVM bitcode, i.e.
-// LLVM IR. There, there is an opaque pointer type, `ptr`, so pointers don't
-// even have a "size". Size/index types, on the other hand, are explicitly sized
-// integers such as i32 or i64. We have to pick one here.
-//
-// As outer parts of ukernel code (all what's outside of the arch/ subdir) is
-// compiled once to LLVM bitcode that's reused on all architectures, we have to
-// pick one type that will be used across all architectures. In the future, if
-// we care more about optimizing for 32-bit architectures, we could consider
-// revisiting that by having two separate builds of the non-arch/ ukernel code,
-// one for 64-bit and one for 32-bit architectures.
-//
-// For now, it seems reasonable to just use int64 everywhere. Our current focus
-// is on 64-bit architectures anyway, the expected overhead on 32-bit
-// architectures should be marginal, and this buys us simplicity and uniformity
-// (ukernels having the exact same semantics across architectures) that we would
-// lose the moment we allowed iree_uk_ssize_t to vary across architectures.
-typedef iree_uk_int64_t iree_uk_ssize_t;
+#if defined(__arm64) || defined(__aarch64__) || defined(_M_ARM64) || \
+    defined(_M_ARM64EC)
+#define IREE_UK_ARCH_ARM_64 1
+#elif defined(__arm__) || defined(__thumb__) || defined(__TARGET_ARCH_ARM) || \
+    defined(__TARGET_ARCH_THUMB) || defined(_M_ARM)
+#define IREE_UK_ARCH_ARM_32 1
+#endif  // ARM
 
-static inline void iree_uk_ssize_swap(iree_uk_ssize_t* a, iree_uk_ssize_t* b) {
-  iree_uk_ssize_t t = *a;
+#if defined(__riscv) && (__riscv_xlen == 32)
+#define IREE_UK_ARCH_RISCV_32 1
+#elif defined(__riscv) && (__riscv_xlen == 64)
+#define IREE_UK_ARCH_RISCV_64 1
+#endif  // RISCV
+
+#if defined(__wasm32__)
+#define IREE_UK_ARCH_WASM_32 1
+#elif defined(__wasm64__)
+#define IREE_UK_ARCH_WASM_64 1
+#endif  // WASM
+
+#if defined(__i386__) || defined(__i486__) || defined(__i586__) || \
+    defined(__i686__) || defined(__i386) || defined(_M_IX86) || defined(_X86_)
+#define IREE_UK_ARCH_X86_32 1
+#elif defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || \
+    defined(__amd64) || defined(_M_X64)
+#define IREE_UK_ARCH_X86_64 1
+#endif  // X86
+
+//===----------------------------------------------------------------------===//
+// Architecture bitness
+//===----------------------------------------------------------------------===//
+
+#if defined(IREE_UK_ARCH_ARM_64) || defined(IREE_UK_ARCH_RISCV_64) || \
+    defined(IREE_UK_ARCH_WASM_64) || defined(IREE_UK_ARCH_X86_64)
+#define IREE_UK_ARCH_IS_64_BIT
+#elif defined(IREE_UK_ARCH_ARM_32) || defined(IREE_UK_ARCH_RISCV_32) || \
+    defined(IREE_UK_ARCH_WASM_32) || defined(IREE_UK_ARCH_X86_32)
+#define IREE_UK_ARCH_IS_32_BIT
+#else
+#error Unknown architecture
+#endif
+
+//===----------------------------------------------------------------------===//
+// iree_uk_index_t: signed integer type, same size as MLIR `index`.
+//===----------------------------------------------------------------------===//
+
+// When ukernels are built as bitcode to embed in the compiler, the requirement
+// here is that the size of iree_uk_index_t equals the size of the compiler's
+// `index` type.
+//
+// So when here we define iree_uk_index_t as a pointer-sized type, there is an
+// implicit assumption about how this ukernel code is built. When building as
+// bitcode to embed in the compiler, this code must be built twice, once for
+// some 32-bit-pointers architecture and once for some 64-bit-pointers
+// architecture. Then the compiler must select the bitcode module that matches
+// the size of the `index` type.
+
+#if defined(IREE_UK_ARCH_IS_64_BIT)
+typedef iree_uk_int64_t iree_uk_index_t;
+#elif defined(IREE_UK_ARCH_IS_32_BIT)
+typedef iree_uk_int32_t iree_uk_index_t;
+#else
+#error Unknown architecture
+#endif
+
+static inline void iree_uk_index_swap(iree_uk_index_t* a, iree_uk_index_t* b) {
+  iree_uk_index_t t = *a;
   *a = *b;
   *b = t;
 }
 
-static inline iree_uk_ssize_t iree_uk_ssize_min(iree_uk_ssize_t a,
-                                                iree_uk_ssize_t b) {
+static inline iree_uk_index_t iree_uk_index_min(iree_uk_index_t a,
+                                                iree_uk_index_t b) {
   return a <= b ? a : b;
 }
 
-static inline iree_uk_ssize_t iree_uk_ssize_max(iree_uk_ssize_t a,
-                                                iree_uk_ssize_t b) {
+static inline iree_uk_index_t iree_uk_index_max(iree_uk_index_t a,
+                                                iree_uk_index_t b) {
   return a >= b ? a : b;
 }
 
-static inline iree_uk_ssize_t iree_uk_ssize_clamp(iree_uk_ssize_t val,
-                                                  iree_uk_ssize_t min,
-                                                  iree_uk_ssize_t max) {
-  return iree_uk_ssize_min(max, iree_uk_ssize_max(min, val));
+static inline iree_uk_index_t iree_uk_index_clamp(iree_uk_index_t val,
+                                                  iree_uk_index_t min,
+                                                  iree_uk_index_t max) {
+  return iree_uk_index_min(max, iree_uk_index_max(min, val));
 }
 
 //===----------------------------------------------------------------------===//
@@ -502,14 +545,14 @@
 // as a memcpy call.
 static inline void iree_uk_memcpy(void* IREE_UK_RESTRICT dst,
                                   const void* IREE_UK_RESTRICT src,
-                                  iree_uk_ssize_t size) {
-  for (iree_uk_ssize_t i = 0; i < size; ++i)
+                                  iree_uk_index_t size) {
+  for (iree_uk_index_t i = 0; i < size; ++i)
     ((char*)dst)[i] = ((const char*)src)[i];
 }
 
-static inline void iree_uk_memset(void* buf, int val, iree_uk_ssize_t n) {
+static inline void iree_uk_memset(void* buf, int val, iree_uk_index_t n) {
   // This naive loop is lifted to a memset by both clang and gcc on ARM64.
-  for (iree_uk_ssize_t i = 0; i < n; ++i) ((char*)buf)[i] = val;
+  for (iree_uk_index_t i = 0; i < n; ++i) ((char*)buf)[i] = val;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/runtime/src/iree/builtins/ukernel/elementwise.c b/runtime/src/iree/builtins/ukernel/elementwise.c
index 11570c1..5240469 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise.c
+++ b/runtime/src/iree/builtins/ukernel/elementwise.c
@@ -69,13 +69,13 @@
 // Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
 #define DISPATCH_UKERNEL_BINARY_2D(opcode, opcode_t, dtype, category)         \
   IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                      \
-      const dtype* lhs, iree_uk_ssize_t lhs_offset,                           \
-      iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,               \
-      const dtype* rhs, iree_uk_ssize_t rhs_offset,                           \
-      iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,               \
-      dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,                \
-      iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,               \
-      iree_uk_ssize_t size0, iree_uk_ssize_t size1) {                         \
+      const dtype* lhs, iree_uk_index_t lhs_offset,                           \
+      iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,               \
+      const dtype* rhs, iree_uk_index_t rhs_offset,                           \
+      iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,               \
+      dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,                \
+      iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,               \
+      iree_uk_index_t size0, iree_uk_index_t size1) {                         \
     return iree_uk_generic_##category##_2d(                                   \
         opcode_t, lhs, lhs_offset, lhs_stride0, lhs_stride1, rhs, rhs_offset, \
         rhs_stride0, rhs_stride1, out, out_offset, out_stride0, out_stride1,  \
@@ -87,11 +87,11 @@
 // Corresponds to the header macro DECLARE_UKERNEL_BINARY_2D.
 #define DISPATCH_UKERNEL_UNARY_2D(opcode, opcode_t, dtype, category)          \
   IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                      \
-      const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
-      iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out,                \
-      iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0,                \
-      iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0,                     \
-      iree_uk_ssize_t size1) {                                                \
+      const dtype* in, iree_uk_index_t in_offset, iree_uk_index_t in_stride0, \
+      iree_uk_index_t in_stride1, dtype* IREE_UK_RESTRICT out,                \
+      iree_uk_index_t out_offset, iree_uk_index_t out_stride0,                \
+      iree_uk_index_t out_stride1, iree_uk_index_t size0,                     \
+      iree_uk_index_t size1) {                                                \
     return iree_uk_generic_##category##_2d(                                   \
         opcode_t, in, in_offset, in_stride0, in_stride1, out, out_offset,     \
         out_stride0, out_stride1, size0, size1);                              \
@@ -203,20 +203,20 @@
 IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32b_2d(
     iree_uk_x32b_opcode_t opcode,
     // LHS.
-    const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
-    iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
+    const iree_uk_uint32_t* lhs, iree_uk_index_t lhs_offset,
+    iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,
     // RHS
-    const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
-    iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
+    const iree_uk_uint32_t* rhs, iree_uk_index_t rhs_offset,
+    iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,
     // OUT.
-    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,
+    iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
     // Sizes.
-    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+    iree_uk_index_t size0, iree_uk_index_t size1) {
   int result_code = 0;
   // TODO: Manually unroll to x4 to trigger vectorization.
-  for (iree_uk_ssize_t i = 0; i < size0; ++i) {
-    for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+  for (iree_uk_index_t i = 0; i < size0; ++i) {
+    for (iree_uk_index_t j = 0; j < size1; ++j) {
       iree_uk_generic_x32b_op(opcode, &result_code,
                               &lhs[i * lhs_stride0 + j * lhs_stride1],
                               &rhs[i * rhs_stride0 + j * rhs_stride1],
@@ -230,17 +230,17 @@
 IREE_UK_ATTRIBUTE_NOINLINE static int iree_uk_generic_x32u_2d(
     iree_uk_x32u_opcode_t opcode,
     // IN.
-    const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
-    iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
+    const iree_uk_uint32_t* in, iree_uk_index_t in_offset,
+    iree_uk_index_t in_stride0, iree_uk_index_t in_stride1,
     // OUT.
-    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
+    iree_uk_uint32_t* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,
+    iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
     // Sizes.
-    iree_uk_ssize_t size0, iree_uk_ssize_t size1) {
+    iree_uk_index_t size0, iree_uk_index_t size1) {
   int result_code = 0;
   // TODO: Manually unroll to x4 to trigger vectorization.
-  for (iree_uk_ssize_t i = 0; i < size0; ++i) {
-    for (iree_uk_ssize_t j = 0; j < size1; ++j) {
+  for (iree_uk_index_t i = 0; i < size0; ++i) {
+    for (iree_uk_index_t j = 0; j < size1; ++j) {
       iree_uk_generic_x32u_op(opcode, &result_code,
                               &in[i * in_stride0 + j * in_stride1],
                               &out[i * out_stride0 + j * out_stride1]);
diff --git a/runtime/src/iree/builtins/ukernel/elementwise.h b/runtime/src/iree/builtins/ukernel/elementwise.h
index 34c3974..c266984 100644
--- a/runtime/src/iree/builtins/ukernel/elementwise.h
+++ b/runtime/src/iree/builtins/ukernel/elementwise.h
@@ -21,26 +21,26 @@
 // It takes lhs, rhs, out buffers and size, returning 0 on success and !0 on
 // error.
 typedef int (*iree_uk_x32b_2d_func_t)(
-    const iree_uk_uint32_t* lhs, iree_uk_ssize_t lhs_offset,
-    iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1,
-    const iree_uk_uint32_t* rhs, iree_uk_ssize_t rhs_offset,
-    iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1,
-    iree_uk_uint32_t* out, iree_uk_ssize_t out_offset,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
-    iree_uk_ssize_t size0, iree_uk_ssize_t size1);
+    const iree_uk_uint32_t* lhs, iree_uk_index_t lhs_offset,
+    iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1,
+    const iree_uk_uint32_t* rhs, iree_uk_index_t rhs_offset,
+    iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1,
+    iree_uk_uint32_t* out, iree_uk_index_t out_offset,
+    iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
+    iree_uk_index_t size0, iree_uk_index_t size1);
 
 // Declares a binary 2d microkernel with the following signature:
 //   int iree_uk_{category}_{opcode}_2d(...)
 // of function type iree_uk_{category}_2d_func_t.
 #define DECLARE_UKERNEL_BINARY_2D(opcode, dtype, category)      \
   IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(        \
-      const dtype* lhs, iree_uk_ssize_t lhs_offset,             \
-      iree_uk_ssize_t lhs_stride0, iree_uk_ssize_t lhs_stride1, \
-      const dtype* rhs, iree_uk_ssize_t rhs_offset,             \
-      iree_uk_ssize_t rhs_stride0, iree_uk_ssize_t rhs_stride1, \
-      dtype* IREE_UK_RESTRICT out, iree_uk_ssize_t out_offset,  \
-      iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1, \
-      iree_uk_ssize_t size0, iree_uk_ssize_t size1)
+      const dtype* lhs, iree_uk_index_t lhs_offset,             \
+      iree_uk_index_t lhs_stride0, iree_uk_index_t lhs_stride1, \
+      const dtype* rhs, iree_uk_index_t rhs_offset,             \
+      iree_uk_index_t rhs_stride0, iree_uk_index_t rhs_stride1, \
+      dtype* IREE_UK_RESTRICT out, iree_uk_index_t out_offset,  \
+      iree_uk_index_t out_stride0, iree_uk_index_t out_stride1, \
+      iree_uk_index_t size0, iree_uk_index_t size1)
 
 DECLARE_UKERNEL_BINARY_2D(addf, iree_uk_uint32_t, x32b);
 DECLARE_UKERNEL_BINARY_2D(addi, iree_uk_uint32_t, x32b);
@@ -66,11 +66,11 @@
 // It takes in, out buffers and size, returning 0 on success and !0 on
 // error.
 typedef int (*iree_uk_x32u_2d_func_t)(
-    const iree_uk_uint32_t* in, iree_uk_ssize_t in_offset,
-    iree_uk_ssize_t in_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_uint32_t* out, iree_uk_ssize_t out_offset,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t out_stride1,
-    iree_uk_ssize_t size0, iree_uk_ssize_t size1);
+    const iree_uk_uint32_t* in, iree_uk_index_t in_offset,
+    iree_uk_index_t in_stride0, iree_uk_index_t in_stride1,
+    iree_uk_uint32_t* out, iree_uk_index_t out_offset,
+    iree_uk_index_t out_stride0, iree_uk_index_t out_stride1,
+    iree_uk_index_t size0, iree_uk_index_t size1);
 
 // Declares a binary 2d microkernel with the following signature:
 //   int iree_uk_{category}_{opcode}_2d(...)
@@ -78,11 +78,11 @@
 // error.
 #define DECLARE_UKERNEL_UNARY_2D(opcode, dtype, category)                     \
   IREE_UK_EXPORT int iree_uk_##category##_##opcode##_2d(                      \
-      const dtype* in, iree_uk_ssize_t in_offset, iree_uk_ssize_t in_stride0, \
-      iree_uk_ssize_t in_stride1, dtype* IREE_UK_RESTRICT out,                \
-      iree_uk_ssize_t out_offset, iree_uk_ssize_t out_stride0,                \
-      iree_uk_ssize_t out_stride1, iree_uk_ssize_t size0,                     \
-      iree_uk_ssize_t size1)
+      const dtype* in, iree_uk_index_t in_offset, iree_uk_index_t in_stride0, \
+      iree_uk_index_t in_stride1, dtype* IREE_UK_RESTRICT out,                \
+      iree_uk_index_t out_offset, iree_uk_index_t out_stride0,                \
+      iree_uk_index_t out_stride1, iree_uk_index_t size0,                     \
+      iree_uk_index_t size1)
 
 DECLARE_UKERNEL_UNARY_2D(absf, iree_uk_uint32_t, x32u);
 DECLARE_UKERNEL_UNARY_2D(ceilf, iree_uk_uint32_t, x32u);
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.c b/runtime/src/iree/builtins/ukernel/mmt4d.c
index d7ccc44..67f4856 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.c
@@ -63,9 +63,9 @@
   const char* rhs_panel_start = (const char*)params->rhs_buffer +
                                 (params->rhs_offset << rhs_elem_size_log2);
   iree_uk_int32_t out_tile_size = (M0 * N0) << out_elem_size_log2;
-  iree_uk_ssize_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2;
-  iree_uk_ssize_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2;
-  iree_uk_ssize_t out_stride = params->out_stride0 << out_elem_size_log2;
+  iree_uk_index_t lhs_panel_stride = params->lhs_stride0 << lhs_elem_size_log2;
+  iree_uk_index_t rhs_panel_stride = params->rhs_stride0 << rhs_elem_size_log2;
+  iree_uk_index_t out_stride = params->out_stride0 << out_elem_size_log2;
   for (iree_uk_int32_t i = 0; i < M; ++i) {
     char* out_tile = out_tile_row;
     const char* rhs_panel = rhs_panel_start;
@@ -88,12 +88,12 @@
   iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
   iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
   int out_elem_size_log2 = iree_uk_type_size_log2(out_type);
-  iree_uk_ssize_t contiguous_size = params->N * params->M0 * params->N0
+  iree_uk_index_t contiguous_size = params->N * params->M0 * params->N0
                                     << out_elem_size_log2;
-  iree_uk_ssize_t stride = params->out_stride0 << out_elem_size_log2;
+  iree_uk_index_t stride = params->out_stride0 << out_elem_size_log2;
   char* out_ptr =
       (char*)params->out_buffer + (params->out_offset << out_elem_size_log2);
-  for (iree_uk_ssize_t i = 0; i < params->M; ++i) {
+  for (iree_uk_index_t i = 0; i < params->M; ++i) {
     iree_uk_memset(out_ptr, 0, contiguous_size);
     out_ptr += stride;
   }
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d.h b/runtime/src/iree/builtins/ukernel/mmt4d.h
index bd956a6..a093e3e 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d.h
+++ b/runtime/src/iree/builtins/ukernel/mmt4d.h
@@ -15,17 +15,17 @@
 
 typedef struct iree_uk_mmt4d_params_t {
   const void* lhs_buffer;
-  iree_uk_ssize_t lhs_offset;
-  iree_uk_ssize_t lhs_stride0;
+  iree_uk_index_t lhs_offset;
+  iree_uk_index_t lhs_stride0;
   const void* rhs_buffer;
-  iree_uk_ssize_t rhs_offset;
-  iree_uk_ssize_t rhs_stride0;
+  iree_uk_index_t rhs_offset;
+  iree_uk_index_t rhs_stride0;
   void* out_buffer;
-  iree_uk_ssize_t out_offset;
-  iree_uk_ssize_t out_stride0;
-  iree_uk_ssize_t M;
-  iree_uk_ssize_t N;
-  iree_uk_ssize_t K;
+  iree_uk_index_t out_offset;
+  iree_uk_index_t out_stride0;
+  iree_uk_index_t M;
+  iree_uk_index_t N;
+  iree_uk_index_t K;
   iree_uk_int32_t M0;
   iree_uk_int32_t N0;
   iree_uk_int32_t K0;
diff --git a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
index 36a8c25..440001c 100644
--- a/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
+++ b/runtime/src/iree/builtins/ukernel/mmt4d_tile.c
@@ -25,10 +25,10 @@
     for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
   }
   // Accumulation loop.
-  for (iree_uk_ssize_t k = 0; k < K; ++k) {
-    for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
-      for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
-        for (iree_uk_ssize_t k0 = 0; k0 < K0; ++k0) {
+  for (iree_uk_index_t k = 0; k < K; ++k) {
+    for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+      for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+        for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
           iree_uk_int32_t lhs_val_int32 = lhs_panel[i0 * K0 + k0];
           iree_uk_int32_t rhs_val_int32 = rhs_panel[j0 * K0 + k0];
           acc[i0 * N0 + j0] += lhs_val_int32 * rhs_val_int32;
@@ -61,10 +61,10 @@
     for (int i = 0; i < M0 * N0; ++i) acc[i] = 0;
   }
   // Accumulation loop.
-  for (iree_uk_ssize_t k = 0; k < K; ++k) {
-    for (iree_uk_ssize_t i0 = 0; i0 < M0; ++i0) {
-      for (iree_uk_ssize_t j0 = 0; j0 < N0; ++j0) {
-        for (iree_uk_ssize_t k0 = 0; k0 < K0; ++k0) {
+  for (iree_uk_index_t k = 0; k < K; ++k) {
+    for (iree_uk_index_t i0 = 0; i0 < M0; ++i0) {
+      for (iree_uk_index_t j0 = 0; j0 < N0; ++j0) {
+        for (iree_uk_index_t k0 = 0; k0 < K0; ++k0) {
           float lhs_val = lhs_panel[i0 * K0 + k0];
           float rhs_val = rhs_panel[j0 * K0 + k0];
           acc[i0 * N0 + j0] += lhs_val * rhs_val;
diff --git a/runtime/src/iree/builtins/ukernel/pack.c b/runtime/src/iree/builtins/ukernel/pack.c
index e7e7a25..7895ef9 100644
--- a/runtime/src/iree/builtins/ukernel/pack.c
+++ b/runtime/src/iree/builtins/ukernel/pack.c
@@ -20,8 +20,8 @@
 } iree_uk_pack_tmpbuf_helper_t;
 
 // Return x/y for x>=0 and y>0, with a fast path for when y is a power of two.
-static iree_uk_ssize_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
-    iree_uk_ssize_t x, iree_uk_int32_t y) {
+static iree_uk_index_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
+    iree_uk_index_t x, iree_uk_int32_t y) {
   IREE_UK_ASSERT(x >= 0);
   IREE_UK_ASSERT(y > 0);
   return IREE_UK_LIKELY(iree_uk_is_po2_u32(y)) ? (x >> iree_uk_po2_log2_u32(y))
@@ -46,8 +46,8 @@
 // Initializes a `iree_uk_pack_tmpbuf_helper_t`. Asserts if the temporary buffer
 // is smaller than one tile.
 static void iree_uk_pack_tmpbuf_helper_t_init(
-    iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
-    iree_uk_ssize_t elem_size, iree_uk_uint64_t padding_value,
+    iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+    iree_uk_index_t elem_size, iree_uk_uint64_t padding_value,
     iree_uk_pack_tmpbuf_helper_t* helper) {
   helper->max_tiles_in_tmp_buf = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
       iree_uk_pack_tmp_buf_size, tile_size0 * tile_size1 * elem_size);
@@ -76,15 +76,15 @@
   IREE_UK_ASSERT(params->out_size3 >= 0);
   // Check that the input and output shapes match, give or take padding that
   // must not exceed the inner tile size.s
-  iree_uk_ssize_t outer_size0 = params->out_size0;
-  iree_uk_ssize_t outer_size1 = params->out_size1;
-  iree_uk_ssize_t tile_size0 = params->out_size2;
-  iree_uk_ssize_t tile_size1 = params->out_size3;
+  iree_uk_index_t outer_size0 = params->out_size0;
+  iree_uk_index_t outer_size1 = params->out_size1;
+  iree_uk_index_t tile_size0 = params->out_size2;
+  iree_uk_index_t tile_size1 = params->out_size3;
   if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&outer_size0, &outer_size1);
+    iree_uk_index_swap(&outer_size0, &outer_size1);
   }
   if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&tile_size0, &tile_size1);
+    iree_uk_index_swap(&tile_size0, &tile_size1);
   }
   IREE_UK_ASSERT(outer_size0 * tile_size0 >= params->in_size0);
   IREE_UK_ASSERT(outer_size1 * tile_size1 >= params->in_size1);
@@ -100,7 +100,7 @@
   iree_uk_pack_tmpbuf_helper_t padding_helper;
   iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
   iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
-  iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
+  iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
   iree_uk_pack_tmpbuf_helper_t_init(tile_size0, tile_size1, elem_size,
                                     params->padding_value, &padding_helper);
 #endif  // IREE_UK_ENABLE_ASSERTS
@@ -115,8 +115,8 @@
 // Fills `buf` with `num_elems` times the `pattern` of size `elem_size`.
 // If this pattern's `elem_size` bytes are all equal, then it is legal to pass
 // `is_single_byte_pattern=true`, which allows the impl to use memset.
-static void iree_uk_fill(char* IREE_UK_RESTRICT buf, iree_uk_ssize_t num_elems,
-                         iree_uk_ssize_t elem_size,
+static void iree_uk_fill(char* IREE_UK_RESTRICT buf, iree_uk_index_t num_elems,
+                         iree_uk_index_t elem_size,
                          iree_uk_uint64_t padding_value,
                          bool is_padding_single_byte) {
   if (is_padding_single_byte) {
@@ -124,20 +124,20 @@
   } else if (elem_size == 2) {
     iree_uk_uint16_t padding_value_uint16 = padding_value;
     iree_uk_uint16_t* IREE_UK_RESTRICT buf_uint16 = (iree_uk_uint16_t*)buf;
-    for (iree_uk_ssize_t i = 0; i < num_elems; ++i) {
+    for (iree_uk_index_t i = 0; i < num_elems; ++i) {
       buf_uint16[i] = padding_value_uint16;
     }
   } else if (elem_size == 4) {
     iree_uk_uint32_t padding_value_uint32 = padding_value;
     iree_uk_uint32_t* IREE_UK_RESTRICT buf_uint32 = (iree_uk_uint32_t*)buf;
-    for (iree_uk_ssize_t i = 0; i < num_elems; ++i) {
+    for (iree_uk_index_t i = 0; i < num_elems; ++i) {
       buf_uint32[i] = padding_value_uint32;
     }
   } else {  // elem_size >= 8
     // While arbitrary large elem_size is allowed, padding_value remains a
     // uint64, so elem_size >= 16 only support a repeating 8-byte pattern.
     iree_uk_uint64_t* IREE_UK_RESTRICT buf_uint64 = (iree_uk_uint64_t*)buf;
-    for (iree_uk_ssize_t i = 0; i < num_elems * elem_size / 8; ++i) {
+    for (iree_uk_index_t i = 0; i < num_elems * elem_size / 8; ++i) {
       buf_uint64[i] = padding_value;
     }
   }
@@ -146,14 +146,14 @@
 // Copy from a source 2D buffer to a destination 2D buffer, padding to the
 // destination size.
 static void iree_uk_copy_and_pad(
-    iree_uk_ssize_t src_size0, iree_uk_ssize_t src_size1,
-    iree_uk_ssize_t src_stride0, const char* src_buf, iree_uk_ssize_t dst_size0,
-    iree_uk_ssize_t dst_size1, iree_uk_ssize_t dst_stride0, char* dst_buf,
-    iree_uk_ssize_t elem_size, iree_uk_uint64_t padding_value,
+    iree_uk_index_t src_size0, iree_uk_index_t src_size1,
+    iree_uk_index_t src_stride0, const char* src_buf, iree_uk_index_t dst_size0,
+    iree_uk_index_t dst_size1, iree_uk_index_t dst_stride0, char* dst_buf,
+    iree_uk_index_t elem_size, iree_uk_uint64_t padding_value,
     bool is_padding_single_byte) {
   iree_uk_fill(dst_buf, dst_size1 + (dst_size0 - 1) * dst_stride0, elem_size,
                padding_value, is_padding_single_byte);
-  for (iree_uk_ssize_t in_i0 = 0; in_i0 < src_size0; in_i0++) {
+  for (iree_uk_index_t in_i0 = 0; in_i0 < src_size0; in_i0++) {
     iree_uk_memcpy(dst_buf, src_buf, src_size1 * elem_size);
     dst_buf += dst_stride0 * elem_size;
     src_buf += src_stride0 * elem_size;
@@ -163,20 +163,20 @@
 // Pads and packs an entire row. In cases that are known not to require padding,
 // it is more efficient to call tile_func directly.
 static void iree_uk_pad_and_pack_row_using_tile_func(
-    iree_uk_pack_tile_func_t tile_func, iree_uk_ssize_t dim1_tile_start,
-    iree_uk_ssize_t dim1_tile_end, iree_uk_ssize_t dim0_src_read_size,
-    iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t in_size1,
-    iree_uk_ssize_t in_stride0, iree_uk_ssize_t out_stride1,
+    iree_uk_pack_tile_func_t tile_func, iree_uk_index_t dim1_tile_start,
+    iree_uk_index_t dim1_tile_end, iree_uk_index_t dim0_src_read_size,
+    iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+    iree_uk_index_t elem_size, iree_uk_index_t in_size1,
+    iree_uk_index_t in_stride0, iree_uk_index_t out_stride1,
     iree_uk_uint64_t padding_value, iree_uk_pack_tmpbuf_helper_t* helper,
     const char* in_buf, char* out_buf) {
-  iree_uk_ssize_t dim1_tile = dim1_tile_start;
+  iree_uk_index_t dim1_tile = dim1_tile_start;
   while (dim1_tile < dim1_tile_end) {
-    iree_uk_ssize_t dim1_chunk_tiles = iree_uk_ssize_clamp(
+    iree_uk_index_t dim1_chunk_tiles = iree_uk_index_clamp(
         dim1_tile_end - dim1_tile, 0, helper->max_tiles_in_tmp_buf);
-    iree_uk_ssize_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
-    iree_uk_ssize_t dim1_chunk_src_pos = dim1_tile * tile_size1;
-    iree_uk_ssize_t i1_read_size = iree_uk_ssize_clamp(
+    iree_uk_index_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
+    iree_uk_index_t dim1_chunk_src_pos = dim1_tile * tile_size1;
+    iree_uk_index_t i1_read_size = iree_uk_index_clamp(
         in_size1 - dim1_chunk_src_pos, 0, dim1_chunk_src_width);
     iree_uk_copy_and_pad(dim0_src_read_size, i1_read_size, in_stride0,
                          in_buf + dim1_chunk_src_pos * elem_size, tile_size0,
@@ -195,19 +195,19 @@
   // For now, the input and output element types are always the same.
   iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
   iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
-  iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
-  iree_uk_ssize_t outer_size0 = params->out_size0;
-  iree_uk_ssize_t outer_size1 = params->out_size1;
-  iree_uk_ssize_t tile_size0 = params->out_size2;
-  iree_uk_ssize_t tile_size1 = params->out_size3;
-  iree_uk_ssize_t out_stride_l0 = params->out_stride0;
-  iree_uk_ssize_t out_stride1 = params->out_size3 * params->out_size2;
+  iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+  iree_uk_index_t outer_size0 = params->out_size0;
+  iree_uk_index_t outer_size1 = params->out_size1;
+  iree_uk_index_t tile_size0 = params->out_size2;
+  iree_uk_index_t tile_size1 = params->out_size3;
+  iree_uk_index_t out_stride_l0 = params->out_stride0;
+  iree_uk_index_t out_stride1 = params->out_size3 * params->out_size2;
   if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&outer_size0, &outer_size1);
-    iree_uk_ssize_swap(&out_stride_l0, &out_stride1);
+    iree_uk_index_swap(&outer_size0, &outer_size1);
+    iree_uk_index_swap(&out_stride_l0, &out_stride1);
   }
   if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&tile_size0, &tile_size1);
+    iree_uk_index_swap(&tile_size0, &tile_size1);
   }
   const char* in_buf =
       (const char*)params->in_buffer + (params->in_offset * elem_size);
@@ -223,7 +223,7 @@
   // source buffer's boundaries.
   int dim1_full_tiles = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
       params->in_size1, tile_size1);
-  iree_uk_ssize_t i0 = 0;
+  iree_uk_index_t i0 = 0;
   for (; i0 <= params->in_size0 - tile_size0; i0 += tile_size0) {
     // Pack whole tiles that do not require padding (entirely within the source
     // buffer's boundaries).
@@ -239,8 +239,8 @@
   }
   // Bottom-padding.
   for (; i0 < outer_size0 * tile_size0; i0 += tile_size0) {
-    iree_uk_ssize_t dim0_src_read_size =
-        iree_uk_ssize_clamp(params->in_size0 - i0, 0, tile_size0);
+    iree_uk_index_t dim0_src_read_size =
+        iree_uk_index_clamp(params->in_size0 - i0, 0, tile_size0);
     iree_uk_pad_and_pack_row_using_tile_func(
         tile_func, 0, outer_size1, dim0_src_read_size, tile_size0, tile_size1,
         elem_size, params->in_size1, params->in_stride0, out_stride1,
diff --git a/runtime/src/iree/builtins/ukernel/pack.h b/runtime/src/iree/builtins/ukernel/pack.h
index 8feb64b..c4d0233 100644
--- a/runtime/src/iree/builtins/ukernel/pack.h
+++ b/runtime/src/iree/builtins/ukernel/pack.h
@@ -15,17 +15,17 @@
 
 typedef struct iree_uk_pack_params_t {
   const void* in_buffer;
-  iree_uk_ssize_t in_offset;
-  iree_uk_ssize_t in_stride0;
+  iree_uk_index_t in_offset;
+  iree_uk_index_t in_stride0;
   void* out_buffer;
-  iree_uk_ssize_t out_offset;
-  iree_uk_ssize_t out_stride0;
-  iree_uk_ssize_t in_size0;
-  iree_uk_ssize_t in_size1;
-  iree_uk_ssize_t out_size0;
-  iree_uk_ssize_t out_size1;
-  iree_uk_ssize_t out_size2;
-  iree_uk_ssize_t out_size3;
+  iree_uk_index_t out_offset;
+  iree_uk_index_t out_stride0;
+  iree_uk_index_t in_size0;
+  iree_uk_index_t in_size1;
+  iree_uk_index_t out_size0;
+  iree_uk_index_t out_size1;
+  iree_uk_index_t out_size2;
+  iree_uk_index_t out_size3;
   // The least significant bits of `padding_value`, up to element size, are used
   // for padding. As this is based solely on bit-significance and not on byte
   // addresses, this is independent of endianness.
diff --git a/runtime/src/iree/builtins/ukernel/pack_internal.h b/runtime/src/iree/builtins/ukernel/pack_internal.h
index ce92874..cc381a0 100644
--- a/runtime/src/iree/builtins/ukernel/pack_internal.h
+++ b/runtime/src/iree/builtins/ukernel/pack_internal.h
@@ -38,18 +38,18 @@
 
 typedef void (*iree_uk_pack_tile_func_t)(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1);
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1);
 
 // Tile kernel declarations. Prototype matches iree_uk_unpack_tile_func_t.
 #define IREE_UK_PACK_TILE_FUNC_DECL(NAME)                             \
   void NAME(void* IREE_UK_RESTRICT out_tile_ptr,                      \
             const void* IREE_UK_RESTRICT in_tile_ptr,                 \
-            iree_uk_ssize_t outer_size1, iree_uk_ssize_t out_stride1, \
-            iree_uk_ssize_t in_stride0, iree_uk_ssize_t elem_size,    \
-            iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1);
+            iree_uk_index_t outer_size1, iree_uk_index_t out_stride1, \
+            iree_uk_index_t in_stride0, iree_uk_index_t elem_size,    \
+            iree_uk_index_t tile_size0, iree_uk_index_t tile_size1);
 
 // Returns the tile function to use for the pack op with the given params.
 iree_uk_pack_tile_func_t iree_uk_pack_select_tile_func(
diff --git a/runtime/src/iree/builtins/ukernel/pack_tile.c b/runtime/src/iree/builtins/ukernel/pack_tile.c
index c0c9afe..fe87939 100644
--- a/runtime/src/iree/builtins/ukernel/pack_tile.c
+++ b/runtime/src/iree/builtins/ukernel/pack_tile.c
@@ -8,16 +8,16 @@
 
 static void iree_uk_pack_tile_generic_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
   char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
-  for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+  for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
     const char* IREE_UK_RESTRICT in_ptr = in_ptr_l1;
     char* IREE_UK_RESTRICT out_ptr = out_ptr_l1;
-    for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+    for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
       iree_uk_memcpy(out_ptr, in_ptr, tile_size1 * elem_size);
       out_ptr += tile_size1 * elem_size;
       in_ptr += in_stride0 * elem_size;
@@ -29,19 +29,19 @@
 
 static void iree_uk_pack_tile_generic_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride1, iree_uk_ssize_t in_stride0,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride1, iree_uk_index_t in_stride0,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
   char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
-  for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+  for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
     const char* IREE_UK_RESTRICT in_ptr_l2 = in_ptr_l1;
     char* IREE_UK_RESTRICT out_ptr_l2 = out_ptr_l1;
-    for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+    for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
       const char* IREE_UK_RESTRICT in_ptr = in_ptr_l2;
       char* IREE_UK_RESTRICT out_ptr = out_ptr_l2;
-      for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+      for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
         iree_uk_memcpy(out_ptr, in_ptr, elem_size);
         out_ptr += tile_size0 * elem_size;
         in_ptr += elem_size;
diff --git a/runtime/src/iree/builtins/ukernel/query_tile_sizes.h b/runtime/src/iree/builtins/ukernel/query_tile_sizes.h
index 3f819f8..a7578c7 100644
--- a/runtime/src/iree/builtins/ukernel/query_tile_sizes.h
+++ b/runtime/src/iree/builtins/ukernel/query_tile_sizes.h
@@ -16,14 +16,14 @@
 // Parameters for a query_tile_sizes operation.
 typedef struct iree_uk_query_tile_sizes_2d_params_t {
   iree_uk_uint32_t flags;
-  iree_uk_ssize_t size0;
-  iree_uk_ssize_t size1;
+  iree_uk_index_t size0;
+  iree_uk_index_t size1;
   const iree_uk_uint64_t* cpu_data;
 } iree_uk_query_tile_sizes_2d_params_t;
 
 typedef struct iree_uk_query_tile_sizes_2d_out_params_t {
-  iree_uk_ssize_t tile_size0;
-  iree_uk_ssize_t tile_size1;
+  iree_uk_index_t tile_size0;
+  iree_uk_index_t tile_size1;
 } iree_uk_query_tile_sizes_2d_out_params_t;
 
 IREE_UK_EXPORT int iree_uk_query_tile_sizes_2d(
diff --git a/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c
index 32dc799..b572d78 100644
--- a/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/e2e_matmul_benchmark.c
@@ -274,17 +274,17 @@
       .in_stride0 = mmt4d_params.out_stride0,
   };
 
-  iree_uk_ssize_t rowmajor_lhs_buffer_size =
+  iree_uk_index_t rowmajor_lhs_buffer_size =
       iree_uk_2d_buffer_length(lhs_type, params->M, params->K);
-  iree_uk_ssize_t rowmajor_rhs_buffer_size =
+  iree_uk_index_t rowmajor_rhs_buffer_size =
       iree_uk_2d_buffer_length(rhs_type, params->K, params->N);
-  iree_uk_ssize_t rowmajor_out_buffer_size =
+  iree_uk_index_t rowmajor_out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params->M, params->N);
-  iree_uk_ssize_t packed_lhs_buffer_size =
+  iree_uk_index_t packed_lhs_buffer_size =
       iree_uk_2d_buffer_length(lhs_type, M1, mmt4d_params.lhs_stride0);
-  iree_uk_ssize_t packed_rhs_buffer_size =
+  iree_uk_index_t packed_rhs_buffer_size =
       iree_uk_2d_buffer_length(rhs_type, N1, mmt4d_params.rhs_stride0);
-  iree_uk_ssize_t packed_out_buffer_size =
+  iree_uk_index_t packed_out_buffer_size =
       iree_uk_2d_buffer_length(out_type, M1, mmt4d_params.out_stride0);
   void* rowmajor_lhs_buffer = malloc(rowmajor_lhs_buffer_size);
   void* rowmajor_rhs_buffer = malloc(rowmajor_rhs_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c
index 94f68ea..edf7afc 100644
--- a/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/memcpy_benchmark.c
@@ -28,10 +28,10 @@
       benchmark_def->user_data;
 
   int64_t total_iterations = 0;
-  iree_uk_ssize_t buffer_size = user_data->working_set_size / 2;
+  iree_uk_index_t buffer_size = user_data->working_set_size / 2;
   uint8_t* in_buffer = malloc(buffer_size);
   uint8_t* out_buffer = malloc(buffer_size);
-  for (iree_uk_ssize_t i = 0; i < buffer_size; ++i) in_buffer[i] = (i & 0xFF);
+  for (iree_uk_index_t i = 0; i < buffer_size; ++i) in_buffer[i] = (i & 0xFF);
   int64_t batch_count = 1;
   while (iree_benchmark_keep_running(benchmark_state, batch_count)) {
     for (int i = 0; i < batch_count; ++i) {
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
index 66cc4de..0fabf45 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
@@ -47,11 +47,11 @@
   iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
   iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
   iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
-  iree_uk_ssize_t lhs_buffer_size =
+  iree_uk_index_t lhs_buffer_size =
       iree_uk_2d_buffer_length(lhs_type, params.M, params.lhs_stride0);
-  iree_uk_ssize_t rhs_buffer_size =
+  iree_uk_index_t rhs_buffer_size =
       iree_uk_2d_buffer_length(rhs_type, params.N, params.rhs_stride0);
-  iree_uk_ssize_t out_buffer_size =
+  iree_uk_index_t out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params.M, params.out_stride0);
   void* lhs_buffer = malloc(lhs_buffer_size);
   void* rhs_buffer = malloc(rhs_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
index 7799dd2..b2a47c0 100644
--- a/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
@@ -14,8 +14,8 @@
     float* out_ptr, const float* lhs_ptr, const float* rhs_ptr,
     const iree_uk_mmt4d_params_t* params) {
   float acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0.f;
-  for (iree_uk_ssize_t k = 0; k < params->K; ++k) {
-    for (iree_uk_ssize_t k0 = 0; k0 < params->K0; ++k0) {
+  for (iree_uk_index_t k = 0; k < params->K; ++k) {
+    for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) {
       float lhs_val = lhs_ptr[k * params->M0 * params->K0 + k0];
       float rhs_val = rhs_ptr[k * params->N0 * params->K0 + k0];
       acc += lhs_val * rhs_val;
@@ -28,8 +28,8 @@
     int32_t* out_ptr, const int8_t* lhs_ptr, const int8_t* rhs_ptr,
     const iree_uk_mmt4d_params_t* params) {
   int32_t acc = params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE ? *out_ptr : 0;
-  for (iree_uk_ssize_t k = 0; k < params->K; ++k) {
-    for (iree_uk_ssize_t k0 = 0; k0 < params->K0; ++k0) {
+  for (iree_uk_index_t k = 0; k < params->K; ++k) {
+    for (iree_uk_index_t k0 = 0; k0 < params->K0; ++k0) {
       int32_t lhs_val = lhs_ptr[k * params->M0 * params->K0 + k0];
       int32_t rhs_val = rhs_ptr[k * params->N0 * params->K0 + k0];
       acc += lhs_val * rhs_val;
@@ -40,14 +40,14 @@
 
 static void iree_mmt4d_reference(const iree_uk_mmt4d_params_t* params) {
   iree_uk_mmt4d_type_t mmt4d_type = iree_uk_mmt4d_type(params->flags);
-  iree_uk_ssize_t lhs_elem_size =
+  iree_uk_index_t lhs_elem_size =
       iree_uk_type_size(iree_uk_mmt4d_lhs_type(mmt4d_type));
-  iree_uk_ssize_t rhs_elem_size =
+  iree_uk_index_t rhs_elem_size =
       iree_uk_type_size(iree_uk_mmt4d_rhs_type(mmt4d_type));
-  iree_uk_ssize_t out_elem_size =
+  iree_uk_index_t out_elem_size =
       iree_uk_type_size(iree_uk_mmt4d_out_type(mmt4d_type));
-  for (iree_uk_ssize_t i = 0; i < params->M; ++i) {
-    for (iree_uk_ssize_t j = 0; j < params->N; ++j) {
+  for (iree_uk_index_t i = 0; i < params->M; ++i) {
+    for (iree_uk_index_t j = 0; j < params->N; ++j) {
       void* out_tile_ptr = ((char*)params->out_buffer) +
                            (params->out_offset + i * params->out_stride0 +
                             j * params->M0 * params->N0) *
@@ -58,8 +58,8 @@
       const void* rhs_panel_ptr =
           ((const char*)params->rhs_buffer) +
           (params->rhs_offset + j * params->rhs_stride0) * rhs_elem_size;
-      for (iree_uk_ssize_t i0 = 0; i0 < params->M0; ++i0) {
-        for (iree_uk_ssize_t j0 = 0; j0 < params->N0; ++j0) {
+      for (iree_uk_index_t i0 = 0; i0 < params->M0; ++i0) {
+        for (iree_uk_index_t j0 = 0; j0 < params->N0; ++j0) {
           void* out_ptr =
               ((char*)out_tile_ptr) + (i0 * params->N0 + j0) * out_elem_size;
           const void* lhs_ptr =
@@ -104,9 +104,9 @@
   iree_uk_type_t lhs_type = iree_uk_mmt4d_lhs_type(mmt4d_type);
   iree_uk_type_t rhs_type = iree_uk_mmt4d_rhs_type(mmt4d_type);
   iree_uk_type_t out_type = iree_uk_mmt4d_out_type(mmt4d_type);
-  iree_uk_ssize_t lhs_buffer_size =
+  iree_uk_index_t lhs_buffer_size =
       iree_uk_2d_buffer_length(lhs_type, params.M, params.lhs_stride0);
-  iree_uk_ssize_t rhs_buffer_size =
+  iree_uk_index_t rhs_buffer_size =
       iree_uk_2d_buffer_length(rhs_type, params.N, params.rhs_stride0);
   void* lhs_buffer = malloc(lhs_buffer_size);
   void* rhs_buffer = malloc(rhs_buffer_size);
@@ -122,7 +122,7 @@
 
   iree_uk_mmt4d_params_t reference_params;
   memcpy(&reference_params, &params, sizeof params);
-  iree_uk_ssize_t out_buffer_size =
+  iree_uk_index_t out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params.M, params.out_stride0);
   void* reference_out_buffer = malloc(out_buffer_size);
   iree_uk_write_random_buffer(reference_out_buffer, out_buffer_size, out_type,
diff --git a/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c
index 830c89b..7301169 100644
--- a/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/pack_benchmark.c
@@ -33,15 +33,15 @@
   iree_uk_pack_type_t pack_type = iree_uk_pack_type(params.flags);
   iree_uk_type_t in_type = iree_uk_pack_in_type(pack_type);
   iree_uk_type_t out_type = iree_uk_pack_out_type(pack_type);
-  iree_uk_ssize_t in_type_size = iree_uk_type_size(in_type);
-  iree_uk_ssize_t out_type_size = iree_uk_type_size(out_type);
+  iree_uk_index_t in_type_size = iree_uk_type_size(in_type);
+  iree_uk_index_t out_type_size = iree_uk_type_size(out_type);
 
   // The inner dims 2, 3 are given to us as part of the benchmark user_data.
   // The outer dims 0, 1 are to be determined based on FLAG_working_set_size.
-  iree_uk_ssize_t out_size0 = 1;
-  iree_uk_ssize_t out_size1 = 1;
-  iree_uk_ssize_t out_size2 = params.out_size2;
-  iree_uk_ssize_t out_size3 = params.out_size3;
+  iree_uk_index_t out_size0 = 1;
+  iree_uk_index_t out_size1 = 1;
+  iree_uk_index_t out_size2 = params.out_size2;
+  iree_uk_index_t out_size3 = params.out_size3;
   int target_matrix_size_in_elems =
       FLAG_working_set_size / (in_type_size + out_type_size);
   int target_product_of_outer_sizes_0_1 =
@@ -55,18 +55,18 @@
   params.out_size0 = out_size0;
   params.out_size1 = out_size1;
   if (params.flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&out_size0, &out_size1);
+    iree_uk_index_swap(&out_size0, &out_size1);
   }
   if (params.flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&out_size2, &out_size3);
+    iree_uk_index_swap(&out_size2, &out_size3);
   }
   params.in_size0 = iree_max(0, out_size0 * out_size2 - FLAG_padding_size);
   params.in_size1 = iree_max(0, out_size1 * out_size3 - FLAG_padding_size);
   params.in_stride0 = params.in_size1;
   params.out_stride0 = params.out_size1 * params.out_size2 * params.out_size3;
-  iree_uk_ssize_t in_buffer_size =
+  iree_uk_index_t in_buffer_size =
       iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
-  iree_uk_ssize_t out_buffer_size =
+  iree_uk_index_t out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
   void* in_buffer = malloc(in_buffer_size);
   void* out_buffer = malloc(out_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/pack_test.c b/runtime/src/iree/builtins/ukernel/tools/pack_test.c
index 4163b56..295398d 100644
--- a/runtime/src/iree/builtins/ukernel/tools/pack_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/pack_test.c
@@ -14,35 +14,35 @@
   // For now, the input and output element types are always the same.
   iree_uk_pack_type_t pack_type = iree_uk_pack_type(params->flags);
   iree_uk_type_t elem_type = iree_uk_pack_in_type(pack_type);
-  iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
-  iree_uk_ssize_t outer_size0 = params->out_size0;
-  iree_uk_ssize_t outer_size1 = params->out_size1;
-  iree_uk_ssize_t tile_size0 = params->out_size2;
-  iree_uk_ssize_t tile_size1 = params->out_size3;
-  iree_uk_ssize_t out_stride_l0 = params->out_stride0;
-  iree_uk_ssize_t out_stride_l1 = params->out_size3 * params->out_size2;
-  iree_uk_ssize_t out_stride_l2 = params->out_size3;
-  iree_uk_ssize_t out_stride_l3 = 1;
+  iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+  iree_uk_index_t outer_size0 = params->out_size0;
+  iree_uk_index_t outer_size1 = params->out_size1;
+  iree_uk_index_t tile_size0 = params->out_size2;
+  iree_uk_index_t tile_size1 = params->out_size3;
+  iree_uk_index_t out_stride_l0 = params->out_stride0;
+  iree_uk_index_t out_stride_l1 = params->out_size3 * params->out_size2;
+  iree_uk_index_t out_stride_l2 = params->out_size3;
+  iree_uk_index_t out_stride_l3 = 1;
   if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&outer_size0, &outer_size1);
-    iree_uk_ssize_swap(&out_stride_l0, &out_stride_l1);
+    iree_uk_index_swap(&outer_size0, &outer_size1);
+    iree_uk_index_swap(&out_stride_l0, &out_stride_l1);
   }
   if (params->flags & IREE_UK_FLAG_PACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&tile_size0, &tile_size1);
-    iree_uk_ssize_swap(&out_stride_l2, &out_stride_l3);
+    iree_uk_index_swap(&tile_size0, &tile_size1);
+    iree_uk_index_swap(&out_stride_l2, &out_stride_l3);
   }
   IREE_UK_ASSERT(outer_size0 * tile_size0 >= params->in_size0);
   IREE_UK_ASSERT(outer_size1 * tile_size1 >= params->in_size1);
-  for (iree_uk_ssize_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
-    for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
-      for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
-        for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
-          iree_uk_ssize_t out_offset =
+  for (iree_uk_index_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
+    for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+      for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+        for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+          iree_uk_index_t out_offset =
               params->out_offset + outer_i0 * out_stride_l0 +
               tile_i0 * out_stride_l2 + outer_i1 * out_stride_l1 +
               tile_i1 * out_stride_l3;
-          iree_uk_ssize_t i0 = outer_i0 * tile_size0 + tile_i0;
-          iree_uk_ssize_t i1 = outer_i1 * tile_size1 + tile_i1;
+          iree_uk_index_t i0 = outer_i0 * tile_size0 + tile_i0;
+          iree_uk_index_t i1 = outer_i1 * tile_size1 + tile_i1;
           char* out_ptr = ((char*)params->out_buffer) + out_offset * elem_size;
           if (i0 >= params->in_size0 || i1 >= params->in_size1) {
             if (elem_size == 1) {
@@ -52,12 +52,12 @@
             } else if (elem_size == 4) {
               *(iree_uk_uint32_t*)out_ptr = params->padding_value;
             } else {
-              for (iree_uk_ssize_t k = 0; k < elem_size; k += 8) {
+              for (iree_uk_index_t k = 0; k < elem_size; k += 8) {
                 *(iree_uk_uint64_t*)(out_ptr + k) = params->padding_value;
               }
             }
           } else {
-            iree_uk_ssize_t in_offset =
+            iree_uk_index_t in_offset =
                 params->in_offset + i1 + i0 * params->in_stride0;
             const char* in_ptr =
                 ((char*)params->in_buffer) + in_offset * elem_size;
@@ -80,7 +80,7 @@
   params.out_stride0 = params.out_size1 * params.out_size2 * params.out_size3;
   iree_uk_pack_type_t pack_type = iree_uk_pack_type(params.flags);
   iree_uk_type_t in_type = iree_uk_pack_in_type(pack_type);
-  iree_uk_ssize_t in_buffer_size =
+  iree_uk_index_t in_buffer_size =
       iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
   void* in_buffer = malloc(in_buffer_size);
   iree_uk_write_random_buffer(in_buffer, in_buffer_size, in_type, engine);
@@ -92,7 +92,7 @@
   iree_uk_pack_params_t reference_params;
   memcpy(&reference_params, &params, sizeof reference_params);
   iree_uk_type_t out_type = iree_uk_pack_out_type(pack_type);
-  iree_uk_ssize_t out_buffer_size =
+  iree_uk_index_t out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
   void* reference_out_buffer = malloc(out_buffer_size);
   iree_uk_write_random_buffer(reference_out_buffer, out_buffer_size, out_type,
@@ -157,21 +157,21 @@
           params.out_size1 = outer_shape.size1;
           if (transpose_outer) {
             params.flags |= IREE_UK_FLAG_PACK_TRANSPOSE_OUTER;
-            iree_uk_ssize_swap(&params.out_size0, &params.out_size1);
+            iree_uk_index_swap(&params.out_size0, &params.out_size1);
           }
-          iree_uk_ssize_t tile_size0 = params.out_size2;
-          iree_uk_ssize_t tile_size1 = params.out_size3;
+          iree_uk_index_t tile_size0 = params.out_size2;
+          iree_uk_index_t tile_size1 = params.out_size3;
           if (transpose_inner) {
             params.flags |= IREE_UK_FLAG_PACK_TRANSPOSE_INNER;
-            iree_uk_ssize_swap(&tile_size0, &tile_size1);
+            iree_uk_index_swap(&tile_size0, &tile_size1);
           }
           params.in_size0 = outer_shape.size0 * tile_size0;
           params.in_size1 = outer_shape.size1 * tile_size1;
           iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test);
           if (pad == pad_one_incomplete_tile) {
-            iree_uk_ssize_t pad_size0 =
+            iree_uk_index_t pad_size0 =
                 iree_uk_random_engine_get_0_65535(engine) % tile_size0;
-            iree_uk_ssize_t pad_size1 =
+            iree_uk_index_t pad_size1 =
                 iree_uk_random_engine_get_0_65535(engine) % tile_size1;
             params.in_size0 = params.in_size0 - pad_size0;
             if (params.in_size0 < 0) params.in_size0 = 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c b/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c
index 46bac1d..16e200f 100644
--- a/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c
+++ b/runtime/src/iree/builtins/ukernel/tools/unpack_benchmark.c
@@ -34,15 +34,15 @@
   iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params.flags);
   iree_uk_type_t in_type = iree_uk_unpack_in_type(unpack_type);
   iree_uk_type_t out_type = iree_uk_unpack_out_type(unpack_type);
-  iree_uk_ssize_t in_type_size = iree_uk_type_size(in_type);
-  iree_uk_ssize_t out_type_size = iree_uk_type_size(out_type);
+  iree_uk_index_t in_type_size = iree_uk_type_size(in_type);
+  iree_uk_index_t out_type_size = iree_uk_type_size(out_type);
 
   // The inner dims 2, 3 are given to us as part of the benchmark user_data.
   // The outer dims 0, 1 are to be determined based on FLAG_working_set_size.
-  iree_uk_ssize_t in_size0 = 1;
-  iree_uk_ssize_t in_size1 = 1;
-  iree_uk_ssize_t in_size2 = params.in_size2;
-  iree_uk_ssize_t in_size3 = params.in_size3;
+  iree_uk_index_t in_size0 = 1;
+  iree_uk_index_t in_size1 = 1;
+  iree_uk_index_t in_size2 = params.in_size2;
+  iree_uk_index_t in_size3 = params.in_size3;
   int target_matrix_size_in_elems =
       FLAG_working_set_size / (in_type_size + out_type_size);
   int target_product_of_outer_sizes_0_1 =
@@ -56,18 +56,18 @@
   params.in_size0 = in_size0;
   params.in_size1 = in_size1;
   if (params.flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&in_size0, &in_size1);
+    iree_uk_index_swap(&in_size0, &in_size1);
   }
   if (params.flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&in_size2, &in_size3);
+    iree_uk_index_swap(&in_size2, &in_size3);
   }
   params.out_size0 = iree_max(0, in_size0 * in_size2 - FLAG_padding_size);
   params.out_size1 = iree_max(0, in_size1 * in_size3 - FLAG_padding_size);
   params.out_stride0 = params.out_size1;
   params.in_stride0 = params.in_size1 * params.in_size2 * params.in_size3;
-  iree_uk_ssize_t in_buffer_size =
+  iree_uk_index_t in_buffer_size =
       iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
-  iree_uk_ssize_t out_buffer_size =
+  iree_uk_index_t out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
   void* in_buffer = malloc(in_buffer_size);
   void* out_buffer = malloc(out_buffer_size);
diff --git a/runtime/src/iree/builtins/ukernel/tools/unpack_test.c b/runtime/src/iree/builtins/ukernel/tools/unpack_test.c
index 8b8fac3..257ff67 100644
--- a/runtime/src/iree/builtins/ukernel/tools/unpack_test.c
+++ b/runtime/src/iree/builtins/ukernel/tools/unpack_test.c
@@ -14,35 +14,35 @@
   iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
   // For now, the input and output element types are always the same.
   iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
-  iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
-  iree_uk_ssize_t outer_size0 = params->in_size0;
-  iree_uk_ssize_t outer_size1 = params->in_size1;
-  iree_uk_ssize_t tile_size0 = params->in_size2;
-  iree_uk_ssize_t tile_size1 = params->in_size3;
-  iree_uk_ssize_t in_stride_outer0 = params->in_stride0;
-  iree_uk_ssize_t in_stride_outer1 = params->in_size3 * params->in_size2;
-  iree_uk_ssize_t in_stride_tile0 = params->in_size3;
-  iree_uk_ssize_t in_stride_tile1 = 1;
+  iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+  iree_uk_index_t outer_size0 = params->in_size0;
+  iree_uk_index_t outer_size1 = params->in_size1;
+  iree_uk_index_t tile_size0 = params->in_size2;
+  iree_uk_index_t tile_size1 = params->in_size3;
+  iree_uk_index_t in_stride_outer0 = params->in_stride0;
+  iree_uk_index_t in_stride_outer1 = params->in_size3 * params->in_size2;
+  iree_uk_index_t in_stride_tile0 = params->in_size3;
+  iree_uk_index_t in_stride_tile1 = 1;
   if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&outer_size0, &outer_size1);
-    iree_uk_ssize_swap(&in_stride_outer0, &in_stride_outer1);
+    iree_uk_index_swap(&outer_size0, &outer_size1);
+    iree_uk_index_swap(&in_stride_outer0, &in_stride_outer1);
   }
   if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&tile_size0, &tile_size1);
-    iree_uk_ssize_swap(&in_stride_tile0, &in_stride_tile1);
+    iree_uk_index_swap(&tile_size0, &tile_size1);
+    iree_uk_index_swap(&in_stride_tile0, &in_stride_tile1);
   }
-  for (iree_uk_ssize_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
-    for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
-      for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
-        for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
-          iree_uk_ssize_t in_offset =
+  for (iree_uk_index_t outer_i0 = 0; outer_i0 < outer_size0; ++outer_i0) {
+    for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+      for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+        for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+          iree_uk_index_t in_offset =
               params->in_offset + outer_i0 * in_stride_outer0 +
               tile_i0 * in_stride_tile0 + outer_i1 * in_stride_outer1 +
               tile_i1 * in_stride_tile1;
-          iree_uk_ssize_t i0 = outer_i0 * tile_size0 + tile_i0;
-          iree_uk_ssize_t i1 = outer_i1 * tile_size1 + tile_i1;
+          iree_uk_index_t i0 = outer_i0 * tile_size0 + tile_i0;
+          iree_uk_index_t i1 = outer_i1 * tile_size1 + tile_i1;
           if (!(i0 >= params->out_size0 || i1 >= params->out_size1)) {
-            iree_uk_ssize_t out_offset =
+            iree_uk_index_t out_offset =
                 params->out_offset + i1 + i0 * params->out_stride0;
             const char* in_ptr =
                 ((char*)params->in_buffer) + in_offset * elem_size;
@@ -68,7 +68,7 @@
                       iree_uk_random_engine_get_0_1(engine);
   iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params.flags);
   iree_uk_type_t in_type = iree_uk_unpack_in_type(unpack_type);
-  iree_uk_ssize_t in_buffer_size =
+  iree_uk_index_t in_buffer_size =
       iree_uk_2d_buffer_length(in_type, params.in_size0, params.in_stride0);
   void* in_buffer = malloc(in_buffer_size);
   iree_uk_write_random_buffer(in_buffer, in_buffer_size, in_type, engine);
@@ -80,7 +80,7 @@
   iree_uk_unpack_params_t reference_params;
   memcpy(&reference_params, &params, sizeof reference_params);
   iree_uk_type_t out_type = iree_uk_unpack_out_type(unpack_type);
-  iree_uk_ssize_t out_buffer_size =
+  iree_uk_index_t out_buffer_size =
       iree_uk_2d_buffer_length(out_type, params.out_size0, params.out_stride0);
   void* reference_out_buffer = malloc(out_buffer_size);
   iree_uk_write_random_buffer(reference_out_buffer, out_buffer_size, out_type,
@@ -139,31 +139,31 @@
           memcpy(&params, src_params, sizeof params);
           params.cpu_data = iree_uk_test_cpu_data(test);
           outer_shape_t outer_shape = outer_shapes[i];
-          iree_uk_ssize_t in_size0 = outer_shape.size0;
-          iree_uk_ssize_t in_size1 = outer_shape.size1;
+          iree_uk_index_t in_size0 = outer_shape.size0;
+          iree_uk_index_t in_size1 = outer_shape.size1;
           params.in_size0 = in_size0;
           params.in_size1 = in_size1;
           if (pad == pad_a_lot) {
             params.in_size0 += 16;
             params.in_size1 += 16;
           }
-          iree_uk_ssize_t tile_size0 = params.in_size2;
-          iree_uk_ssize_t tile_size1 = params.in_size3;
+          iree_uk_index_t tile_size0 = params.in_size2;
+          iree_uk_index_t tile_size1 = params.in_size3;
           if (transpose_outer) {
             params.flags |= IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER;
-            iree_uk_ssize_swap(&in_size0, &in_size1);
+            iree_uk_index_swap(&in_size0, &in_size1);
           }
           if (transpose_inner) {
             params.flags |= IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER;
-            iree_uk_ssize_swap(&tile_size0, &tile_size1);
+            iree_uk_index_swap(&tile_size0, &tile_size1);
           }
           params.out_size0 = in_size0 * tile_size0;
           params.out_size1 = in_size1 * tile_size1;
           if (pad == pad_one_incomplete_tile) {
             iree_uk_random_engine_t* engine = iree_uk_test_random_engine(test);
-            iree_uk_ssize_t pad_size0 =
+            iree_uk_index_t pad_size0 =
                 iree_uk_random_engine_get_0_65535(engine) % tile_size0;
-            iree_uk_ssize_t pad_size1 =
+            iree_uk_index_t pad_size1 =
                 iree_uk_random_engine_get_0_65535(engine) % tile_size1;
             params.out_size0 = params.out_size0 - pad_size0;
             if (params.out_size0 < 0) params.out_size0 = 0;
diff --git a/runtime/src/iree/builtins/ukernel/tools/util.c b/runtime/src/iree/builtins/ukernel/tools/util.c
index 8aeeff7..aafd030 100644
--- a/runtime/src/iree/builtins/ukernel/tools/util.c
+++ b/runtime/src/iree/builtins/ukernel/tools/util.c
@@ -29,20 +29,20 @@
   abort();
 }
 
-iree_uk_ssize_t iree_uk_2d_buffer_length(iree_uk_type_t type,
-                                         iree_uk_ssize_t size0,
-                                         iree_uk_ssize_t stride0) {
+iree_uk_index_t iree_uk_2d_buffer_length(iree_uk_type_t type,
+                                         iree_uk_index_t size0,
+                                         iree_uk_index_t stride0) {
   // Just for testing purposes, so it's OK to overestimate size.
   return size0 * stride0 << iree_uk_type_size_log2(type);
 }
 
 bool iree_uk_2d_buffers_equal(const void* buf1, const void* buf2,
-                              iree_uk_type_t type, iree_uk_ssize_t size0,
-                              iree_uk_ssize_t size1, iree_uk_ssize_t stride0) {
-  iree_uk_ssize_t elem_size = iree_uk_type_size(type);
+                              iree_uk_type_t type, iree_uk_index_t size0,
+                              iree_uk_index_t size1, iree_uk_index_t stride0) {
+  iree_uk_index_t elem_size = iree_uk_type_size(type);
   const char* buf1_ptr = buf1;
   const char* buf2_ptr = buf2;
-  for (iree_uk_ssize_t i0 = 0; i0 < size0; ++i0) {
+  for (iree_uk_index_t i0 = 0; i0 < size0; ++i0) {
     if (memcmp(buf1_ptr, buf2_ptr, elem_size * size1)) return false;
     buf1_ptr += elem_size * stride0;
     buf2_ptr += elem_size * stride0;
@@ -82,12 +82,12 @@
   return (v % 32) - 16;
 }
 
-void iree_uk_write_random_buffer(void* buffer, iree_uk_ssize_t size_in_bytes,
+void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
                                  iree_uk_type_t type,
                                  iree_uk_random_engine_t* engine) {
-  iree_uk_ssize_t elem_size = iree_uk_type_size(type);
-  iree_uk_ssize_t size_in_elems = size_in_bytes / elem_size;
-  for (iree_uk_ssize_t i = 0; i < size_in_elems; ++i) {
+  iree_uk_index_t elem_size = iree_uk_type_size(type);
+  iree_uk_index_t size_in_elems = size_in_bytes / elem_size;
+  for (iree_uk_index_t i = 0; i < size_in_elems; ++i) {
     // Small integers, should work for now for all the types we currently have
     // and enable exact float arithmetic, allowing to keep tests simpler for
     // now. Watch out for when we'll do float16!
diff --git a/runtime/src/iree/builtins/ukernel/tools/util.h b/runtime/src/iree/builtins/ukernel/tools/util.h
index acd9380..37a6df4 100644
--- a/runtime/src/iree/builtins/ukernel/tools/util.h
+++ b/runtime/src/iree/builtins/ukernel/tools/util.h
@@ -10,13 +10,13 @@
 #include "iree/builtins/ukernel/api.h"
 
 // Helper to determine the length of test buffers to allocate.
-iree_uk_ssize_t iree_uk_2d_buffer_length(iree_uk_type_t type,
-                                         iree_uk_ssize_t size0,
-                                         iree_uk_ssize_t size1);
+iree_uk_index_t iree_uk_2d_buffer_length(iree_uk_type_t type,
+                                         iree_uk_index_t size0,
+                                         iree_uk_index_t size1);
 
 bool iree_uk_2d_buffers_equal(const void* buf1, const void* buf2,
-                              iree_uk_type_t type, iree_uk_ssize_t size0,
-                              iree_uk_ssize_t size1, iree_uk_ssize_t stride0);
+                              iree_uk_type_t type, iree_uk_index_t size0,
+                              iree_uk_index_t size1, iree_uk_index_t stride0);
 
 // Simple deterministic pseudorandom generator. Same as C++'s std::minstd_rand.
 typedef struct iree_uk_random_engine_t {
@@ -32,7 +32,7 @@
 int iree_uk_random_engine_get_0_65535(iree_uk_random_engine_t* e);
 int iree_uk_random_engine_get_0_1(iree_uk_random_engine_t* e);
 int iree_uk_random_engine_get_minus16_plus15(iree_uk_random_engine_t* e);
-void iree_uk_write_random_buffer(void* buffer, iree_uk_ssize_t size_in_bytes,
+void iree_uk_write_random_buffer(void* buffer, iree_uk_index_t size_in_bytes,
                                  iree_uk_type_t type,
                                  iree_uk_random_engine_t* engine);
 
diff --git a/runtime/src/iree/builtins/ukernel/unpack.c b/runtime/src/iree/builtins/ukernel/unpack.c
index 4a3696e..fdeb0bc 100644
--- a/runtime/src/iree/builtins/ukernel/unpack.c
+++ b/runtime/src/iree/builtins/ukernel/unpack.c
@@ -18,8 +18,8 @@
 } iree_uk_unpack_tmpbuf_helper_t;
 
 // Return x/y for x>=0 and y>0, with a fast path for when y is a power of two.
-static iree_uk_ssize_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
-    iree_uk_ssize_t x, iree_uk_int32_t y) {
+static iree_uk_index_t iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
+    iree_uk_index_t x, iree_uk_int32_t y) {
   IREE_UK_ASSERT(x >= 0);
   IREE_UK_ASSERT(y > 0);
   return IREE_UK_LIKELY(iree_uk_is_po2_u32(y)) ? (x >> iree_uk_po2_log2_u32(y))
@@ -29,8 +29,8 @@
 // Initializes a `iree_uk_unpack_padding_helper`. Asserts if the temporary
 // buffer is smaller than one tile.
 static void iree_uk_unpack_tmpbuf_helper_init(
-    iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
-    iree_uk_ssize_t elem_size, iree_uk_unpack_tmpbuf_helper_t* helper) {
+    iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+    iree_uk_index_t elem_size, iree_uk_unpack_tmpbuf_helper_t* helper) {
   helper->max_tiles_in_tmp_buf = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
       iree_uk_unpack_tmp_buf_size, tile_size0 * tile_size1 * elem_size);
   IREE_UK_ASSERT(helper->max_tiles_in_tmp_buf > 0);
@@ -55,15 +55,15 @@
   IREE_UK_ASSERT(params->in_size3 >= 0);
   // Check that the input and output shapes match, give or take padding that
   // must not exceed the inner tile size.s
-  iree_uk_ssize_t outer_size0 = params->in_size0;
-  iree_uk_ssize_t outer_size1 = params->in_size1;
-  iree_uk_ssize_t tile_size0 = params->in_size2;
-  iree_uk_ssize_t tile_size1 = params->in_size3;
+  iree_uk_index_t outer_size0 = params->in_size0;
+  iree_uk_index_t outer_size1 = params->in_size1;
+  iree_uk_index_t tile_size0 = params->in_size2;
+  iree_uk_index_t tile_size1 = params->in_size3;
   if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&outer_size0, &outer_size1);
+    iree_uk_index_swap(&outer_size0, &outer_size1);
   }
   if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&tile_size0, &tile_size1);
+    iree_uk_index_swap(&tile_size0, &tile_size1);
   }
   IREE_UK_ASSERT(outer_size0 * tile_size0 >= params->out_size0);
   IREE_UK_ASSERT(outer_size1 * tile_size1 >= params->out_size1);
@@ -79,7 +79,7 @@
   iree_uk_unpack_tmpbuf_helper_t helper;
   iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
   iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
-  iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
+  iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
   iree_uk_unpack_tmpbuf_helper_init(tile_size0, tile_size1, elem_size, &helper);
 #endif  // IREE_UK_ENABLE_ASSERTS
 }
@@ -89,12 +89,12 @@
   return (params->out_size0 == 0 || params->out_size1 == 0);
 }
 
-static void iree_uk_copy_slice(iree_uk_ssize_t src_stride0, const char* src_buf,
-                               iree_uk_ssize_t dst_size0,
-                               iree_uk_ssize_t dst_size1,
-                               iree_uk_ssize_t dst_stride0, char* dst_buf,
-                               iree_uk_ssize_t elem_size) {
-  for (iree_uk_ssize_t in_i0 = 0; in_i0 < dst_size0; in_i0++) {
+static void iree_uk_copy_slice(iree_uk_index_t src_stride0, const char* src_buf,
+                               iree_uk_index_t dst_size0,
+                               iree_uk_index_t dst_size1,
+                               iree_uk_index_t dst_stride0, char* dst_buf,
+                               iree_uk_index_t elem_size) {
+  for (iree_uk_index_t in_i0 = 0; in_i0 < dst_size0; in_i0++) {
     iree_uk_memcpy(dst_buf, src_buf, dst_size1 * elem_size);
     dst_buf += dst_stride0 * elem_size;
     src_buf += src_stride0 * elem_size;
@@ -105,19 +105,19 @@
 // incomplete tiles. In cases involving only complete tiles, it is faster to
 // call tile_func directly.
 static void iree_uk_unpack_row_using_tmpbuf(
-    iree_uk_unpack_tile_func_t tile_func, iree_uk_ssize_t dim1_tile_start,
-    iree_uk_ssize_t dim1_tile_end, iree_uk_ssize_t dim0_write_size,
-    iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t out_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
+    iree_uk_unpack_tile_func_t tile_func, iree_uk_index_t dim1_tile_start,
+    iree_uk_index_t dim1_tile_end, iree_uk_index_t dim0_write_size,
+    iree_uk_index_t tile_size0, iree_uk_index_t tile_size1,
+    iree_uk_index_t elem_size, iree_uk_index_t out_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
     iree_uk_unpack_tmpbuf_helper_t* helper, const char* in_buf, char* out_buf) {
-  iree_uk_ssize_t dim1_tile = dim1_tile_start;
+  iree_uk_index_t dim1_tile = dim1_tile_start;
   while (dim1_tile < dim1_tile_end) {
-    iree_uk_ssize_t dim1_chunk_tiles = iree_uk_ssize_clamp(
+    iree_uk_index_t dim1_chunk_tiles = iree_uk_index_clamp(
         dim1_tile_end - dim1_tile, 0, helper->max_tiles_in_tmp_buf);
-    iree_uk_ssize_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
-    iree_uk_ssize_t dim1_chunk_src_pos = dim1_tile * tile_size1;
-    iree_uk_ssize_t dim1_write_size = iree_uk_ssize_clamp(
+    iree_uk_index_t dim1_chunk_src_width = dim1_chunk_tiles * tile_size1;
+    iree_uk_index_t dim1_chunk_src_pos = dim1_tile * tile_size1;
+    iree_uk_index_t dim1_write_size = iree_uk_index_clamp(
         out_size1 - dim1_chunk_src_pos, 0, dim1_chunk_src_width);
     tile_func(helper->tmp_buf, in_buf + dim1_tile * in_stride1 * elem_size,
               dim1_chunk_tiles, dim1_chunk_src_width, in_stride1, elem_size,
@@ -135,19 +135,19 @@
   // For now, the input and output element types are always the same.
   iree_uk_unpack_type_t unpack_type = iree_uk_unpack_type(params->flags);
   iree_uk_type_t elem_type = iree_uk_unpack_in_type(unpack_type);
-  iree_uk_ssize_t elem_size = iree_uk_type_size(elem_type);
-  iree_uk_ssize_t outer_size0 = params->in_size0;
-  iree_uk_ssize_t outer_size1 = params->in_size1;
-  iree_uk_ssize_t tile_size0 = params->in_size2;
-  iree_uk_ssize_t tile_size1 = params->in_size3;
-  iree_uk_ssize_t in_stride0 = params->in_stride0;
-  iree_uk_ssize_t in_stride1 = params->in_size3 * params->in_size2;
+  iree_uk_index_t elem_size = iree_uk_type_size(elem_type);
+  iree_uk_index_t outer_size0 = params->in_size0;
+  iree_uk_index_t outer_size1 = params->in_size1;
+  iree_uk_index_t tile_size0 = params->in_size2;
+  iree_uk_index_t tile_size1 = params->in_size3;
+  iree_uk_index_t in_stride0 = params->in_stride0;
+  iree_uk_index_t in_stride1 = params->in_size3 * params->in_size2;
   if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_OUTER) {
-    iree_uk_ssize_swap(&outer_size0, &outer_size1);
-    iree_uk_ssize_swap(&in_stride0, &in_stride1);
+    iree_uk_index_swap(&outer_size0, &outer_size1);
+    iree_uk_index_swap(&in_stride0, &in_stride1);
   }
   if (params->flags & IREE_UK_FLAG_UNPACK_TRANSPOSE_INNER) {
-    iree_uk_ssize_swap(&tile_size0, &tile_size1);
+    iree_uk_index_swap(&tile_size0, &tile_size1);
   }
   const char* in_buf =
       (const char*)params->in_buffer + (params->in_offset * elem_size);
@@ -163,7 +163,7 @@
   // source buffer's boundaries.
   int dim1_full_tiles = iree_uk_div_nonneg_by_pos_and_likely_po2_i32(
       params->out_size1, tile_size1);
-  iree_uk_ssize_t i0 = 0;
+  iree_uk_index_t i0 = 0;
   for (; i0 <= params->out_size0 - tile_size0; i0 += tile_size0) {
     // Pack whole tiles that do not require padding (entirely within the source
     // buffer's boundaries).
@@ -179,8 +179,8 @@
   }
   // Bottom-padding.
   for (; i0 < outer_size0 * tile_size0; i0 += tile_size0) {
-    iree_uk_ssize_t dim0_write_size =
-        iree_uk_ssize_clamp(params->out_size0 - i0, 0, tile_size0);
+    iree_uk_index_t dim0_write_size =
+        iree_uk_index_clamp(params->out_size0 - i0, 0, tile_size0);
     iree_uk_unpack_row_using_tmpbuf(
         tile_func, 0, outer_size1, dim0_write_size, tile_size0, tile_size1,
         elem_size, params->out_size1, params->out_stride0, in_stride1,
diff --git a/runtime/src/iree/builtins/ukernel/unpack.h b/runtime/src/iree/builtins/ukernel/unpack.h
index ba86c9e..ff7d9cd 100644
--- a/runtime/src/iree/builtins/ukernel/unpack.h
+++ b/runtime/src/iree/builtins/ukernel/unpack.h
@@ -15,17 +15,17 @@
 
 typedef struct iree_uk_unpack_params_t {
   const void* in_buffer;
-  iree_uk_ssize_t in_offset;
-  iree_uk_ssize_t in_stride0;
+  iree_uk_index_t in_offset;
+  iree_uk_index_t in_stride0;
   void* out_buffer;
-  iree_uk_ssize_t out_offset;
-  iree_uk_ssize_t out_stride0;
-  iree_uk_ssize_t in_size0;
-  iree_uk_ssize_t in_size1;
-  iree_uk_ssize_t in_size2;
-  iree_uk_ssize_t in_size3;
-  iree_uk_ssize_t out_size0;
-  iree_uk_ssize_t out_size1;
+  iree_uk_index_t out_offset;
+  iree_uk_index_t out_stride0;
+  iree_uk_index_t in_size0;
+  iree_uk_index_t in_size1;
+  iree_uk_index_t in_size2;
+  iree_uk_index_t in_size3;
+  iree_uk_index_t out_size0;
+  iree_uk_index_t out_size1;
   iree_uk_uint32_t flags;
   const iree_uk_uint64_t* cpu_data;
 } iree_uk_unpack_params_t;
diff --git a/runtime/src/iree/builtins/ukernel/unpack_internal.h b/runtime/src/iree/builtins/ukernel/unpack_internal.h
index 3d48d82..d92f056 100644
--- a/runtime/src/iree/builtins/ukernel/unpack_internal.h
+++ b/runtime/src/iree/builtins/ukernel/unpack_internal.h
@@ -38,18 +38,18 @@
 
 typedef void (*iree_uk_unpack_tile_func_t)(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1);
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1);
 
 // Tile kernel declarations. Prototype matches iree_uk_unpack_tile_func_t.
 #define IREE_UK_UNPACK_TILE_FUNC_DECL(NAME)                           \
   void NAME(void* IREE_UK_RESTRICT out_tile_ptr,                      \
             const void* IREE_UK_RESTRICT in_tile_ptr,                 \
-            iree_uk_ssize_t outer_size1, iree_uk_ssize_t out_stride0, \
-            iree_uk_ssize_t in_stride1, iree_uk_ssize_t elem_size,    \
-            iree_uk_ssize_t tile_size0, iree_uk_ssize_t tile_size1);
+            iree_uk_index_t outer_size1, iree_uk_index_t out_stride0, \
+            iree_uk_index_t in_stride1, iree_uk_index_t elem_size,    \
+            iree_uk_index_t tile_size0, iree_uk_index_t tile_size1);
 
 // Returns the tile function to use for the unpack op with the given params.
 iree_uk_unpack_tile_func_t iree_uk_unpack_select_tile_func(
diff --git a/runtime/src/iree/builtins/ukernel/unpack_tile.c b/runtime/src/iree/builtins/ukernel/unpack_tile.c
index 3a1b464..eb7266b 100644
--- a/runtime/src/iree/builtins/ukernel/unpack_tile.c
+++ b/runtime/src/iree/builtins/ukernel/unpack_tile.c
@@ -8,16 +8,16 @@
 
 static void iree_uk_unpack_tile_generic_direct(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
   char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
-  for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+  for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
     const char* IREE_UK_RESTRICT in_ptr = in_ptr_l1;
     char* IREE_UK_RESTRICT out_ptr = out_ptr_l1;
-    for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+    for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
       iree_uk_memcpy(out_ptr, in_ptr, tile_size1 * elem_size);
       in_ptr += tile_size1 * elem_size;
       out_ptr += out_stride0 * elem_size;
@@ -29,19 +29,19 @@
 
 static void iree_uk_unpack_tile_generic_transpose(
     void* IREE_UK_RESTRICT out_tile_ptr,
-    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_ssize_t outer_size1,
-    iree_uk_ssize_t out_stride0, iree_uk_ssize_t in_stride1,
-    iree_uk_ssize_t elem_size, iree_uk_ssize_t tile_size0,
-    iree_uk_ssize_t tile_size1) {
+    const void* IREE_UK_RESTRICT in_tile_ptr, iree_uk_index_t outer_size1,
+    iree_uk_index_t out_stride0, iree_uk_index_t in_stride1,
+    iree_uk_index_t elem_size, iree_uk_index_t tile_size0,
+    iree_uk_index_t tile_size1) {
   const char* IREE_UK_RESTRICT in_ptr_l1 = in_tile_ptr;
   char* IREE_UK_RESTRICT out_ptr_l1 = out_tile_ptr;
-  for (iree_uk_ssize_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
+  for (iree_uk_index_t outer_i1 = 0; outer_i1 < outer_size1; ++outer_i1) {
     const char* IREE_UK_RESTRICT in_ptr_l2 = in_ptr_l1;
     char* IREE_UK_RESTRICT out_ptr_l2 = out_ptr_l1;
-    for (iree_uk_ssize_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
+    for (iree_uk_index_t tile_i0 = 0; tile_i0 < tile_size0; ++tile_i0) {
       const char* IREE_UK_RESTRICT in_ptr = in_ptr_l2;
       char* IREE_UK_RESTRICT out_ptr = out_ptr_l2;
-      for (iree_uk_ssize_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
+      for (iree_uk_index_t tile_i1 = 0; tile_i1 < tile_size1; ++tile_i1) {
         iree_uk_memcpy(out_ptr, in_ptr, elem_size);
         in_ptr += tile_size0 * elem_size;
         out_ptr += elem_size;