Builtin ukernels as system/standalone plugins (#13433)
Making builtin ukernels available to llvm-cpu modules compiled with
`--iree-llvmcpu-enable-microkernels`.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index fc82c38..038d0a3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -175,7 +175,7 @@
# Experimental project flags
#-------------------------------------------------------------------------------
-option(IREE_BUILD_EXPERIMENTAL_VMVX_MMT4D "Enables MMT4D methods in the VMVX module." OFF)
+option(IREE_BUILD_EXPERIMENTAL_CPU_UKERNEL_PLUGINS "Build experimental plugins making builtin ukernels available to llvm-cpu modules compiled with --iree-llvmcpu-enable-microkernels" OFF)
option(IREE_BUILD_EXPERIMENTAL_WEB_SAMPLES "Builds experimental web samples." OFF)
#-------------------------------------------------------------------------------
@@ -967,6 +967,10 @@
add_subdirectory(experimental/web)
endif()
+if(IREE_BUILD_EXPERIMENTAL_CPU_UKERNEL_PLUGINS)
+ add_subdirectory(experimental/cpu_ukernel)
+endif()
+
set(IREE_PUBLIC_INCLUDE_DIRS "${IREE_COMMON_INCLUDE_DIRS}"
CACHE INTERNAL "IREE: Include Directories" FORCE)
diff --git a/build_tools/cmake/iree_macros.cmake b/build_tools/cmake/iree_macros.cmake
index fd2f8cb..cf2102f 100644
--- a/build_tools/cmake/iree_macros.cmake
+++ b/build_tools/cmake/iree_macros.cmake
@@ -52,33 +52,33 @@
string(TOLOWER "${_IREE_UNNORMALIZED_ARCH}" _IREE_UNNORMALIZED_ARCH_LOWERCASE)
# Normalize _IREE_UNNORMALIZED_ARCH into IREE_ARCH.
-if (EMSCRIPTEN)
+if(EMSCRIPTEN)
# TODO: figure what to do about the wasm target, which masquerades as x86.
# This is the one case where the IREE_ARCH CMake variable is currently
# inconsistent with the IREE_ARCH C preprocessor token.
- set (IREE_ARCH "")
-elseif ((_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "aarch64") OR
+ set(IREE_ARCH "")
+elseif((_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "aarch64") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "arm64") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "arm64e") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "arm64ec"))
- set (IREE_ARCH "arm_64")
-elseif ((_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "arm") OR
+ set(IREE_ARCH "arm_64")
+elseif((_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "arm") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE MATCHES "^armv[5-8]"))
- set (IREE_ARCH "arm_32")
-elseif ((_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "x86_64") OR
+ set(IREE_ARCH "arm_32")
+elseif((_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "x86_64") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "amd64") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "x64"))
- set (IREE_ARCH "x86_64")
-elseif ((_IREE_UNNORMALIZED_ARCH_LOWERCASE MATCHES "^i[3-7]86$") OR
+ set(IREE_ARCH "x86_64")
+elseif((_IREE_UNNORMALIZED_ARCH_LOWERCASE MATCHES "^i[3-7]86$") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "x86") OR
(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "win32"))
- set (IREE_ARCH "x86_32")
-elseif (_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "riscv64")
- set (IREE_ARCH "riscv_64")
-elseif (_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "riscv32")
- set (IREE_ARCH "riscv_32")
-elseif (_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "")
- set (IREE_ARCH "")
+ set(IREE_ARCH "x86_32")
+elseif(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "riscv64")
+ set(IREE_ARCH "riscv_64")
+elseif(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "riscv32")
+ set(IREE_ARCH "riscv_32")
+elseif(_IREE_UNNORMALIZED_ARCH_LOWERCASE STREQUAL "")
+ set(IREE_ARCH "")
message(WARNING "Performance advisory: architecture-specific code paths "
"disabled because no target architecture was specified or we didn't know "
"which CMake variable to read. Some relevant CMake variables:\n"
@@ -88,10 +88,33 @@
"CMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES}\n"
)
else()
- set (IREE_ARCH "")
+ set(IREE_ARCH "")
message(SEND_ERROR "Unrecognized target architecture ${_IREE_UNNORMALIZED_ARCH_LOWERCASE}")
endif()
+# iree_arch_to_llvm_arch(DST_LLVM_ARCH_VARIABLE SRC_ARCH)
+#
+# Sets the variable named by DST_LLVM_ARCH_VARIABLE, in the caller's scope, to
+# the LLVM name (as in LLVM target triples) of SRC_ARCH (named as in IREE_ARCH).
+function(iree_arch_to_llvm_arch DST_LLVM_ARCH_VARIABLE SRC_ARCH)
+ if("${SRC_ARCH}" STREQUAL "arm_64")
+ set(${DST_LLVM_ARCH_VARIABLE} "aarch64" PARENT_SCOPE)
+ elseif("${SRC_ARCH}" STREQUAL "arm_32")
+ set(${DST_LLVM_ARCH_VARIABLE} "arm" PARENT_SCOPE)
+ elseif("${SRC_ARCH}" STREQUAL "x86_64")
+ set(${DST_LLVM_ARCH_VARIABLE} "x86_64" PARENT_SCOPE)
+ elseif("${SRC_ARCH}" STREQUAL "x86_32")
+ set(${DST_LLVM_ARCH_VARIABLE} "i386" PARENT_SCOPE)
+ elseif("${SRC_ARCH}" STREQUAL "riscv_64")
+ set(${DST_LLVM_ARCH_VARIABLE} "riscv64" PARENT_SCOPE)
+ elseif("${SRC_ARCH}" STREQUAL "riscv_32")
+ set(${DST_LLVM_ARCH_VARIABLE} "riscv32" PARENT_SCOPE)
+ else()  # Unrecognized SRC_ARCH: report an error and fall back to "unknown".
+ message(SEND_ERROR "What is the LLVM name of the architecture that we call ${SRC_ARCH} ?")
+ set(${DST_LLVM_ARCH_VARIABLE} "unknown" PARENT_SCOPE)
+ endif()
+endfunction()
+
#-------------------------------------------------------------------------------
# General utilities
#-------------------------------------------------------------------------------
@@ -141,7 +164,7 @@
set(_PACKAGE "")
endif()
if(IREE_PACKAGE_ROOT_PREFIX)
- if ("${_PACKAGE}" STREQUAL "")
+ if("${_PACKAGE}" STREQUAL "")
set(_PACKAGE "${IREE_PACKAGE_ROOT_PREFIX}")
else()
set(_PACKAGE "${IREE_PACKAGE_ROOT_PREFIX}/${_PACKAGE}")
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
index 610bd67..54b26b8 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/UKernelOps.cpp
@@ -9,6 +9,7 @@
#include "iree/builtins/ukernel/exported_bits.h"
#include "iree/compiler/Codegen/Dialect/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Utils/EncodingInfo.h"
+#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -55,12 +56,6 @@
for (auto attr : fnDefAttrs) {
fnDecl->setAttr(attr.getName(), attr.getValue());
}
- // TODO(#12327): Based on description in the issue, add an attribute
- // `vm.import.module` and set it to `vmvx`. This only works on `vmvx`
- // backend (obviously), but is enough to unblock while the proper fix lands.
- // For now there are a bunch of attributes set on the function, but this
- // should be made more controllable based on the backend.
- fnDecl->setAttr("vm.import.module", rewriter.getStringAttr("vmvx"));
fnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true));
} else if (fnDecl.getFunctionType() != functionType) {
return rewriter.notifyMatchFailure(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
index 7524f49..4de389a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULowerToUKernels.cpp
@@ -37,6 +37,36 @@
matchPattern(fillVal, m_AnyZeroFloat());
}
+/// Holds the name of a ukernel function and the attributes that go with it.
+struct FnNameAndDefAttrs {
+ std::string name;  // Function name given to the ukernel.generic op.
+ SmallVector<NamedAttribute> defAttrs;  // Passed as its `fn_def_attrs`.
+};
+
+/// Returns the function name ("vmvx."- or "ukernel."-prefixed `ukernelName`)
+/// and the attributes to use for it on the target described by `targetAttr`.
+static FnNameAndDefAttrs getFnNameAndDefAttrs(
+ const char *ukernelName, RewriterBase &rewriter,
+ IREE::HAL::ExecutableTargetAttr targetAttr) {
+ FnNameAndDefAttrs result;
+ if (isVMVXBackend(targetAttr)) {
+ result.name = std::string("vmvx.") + ukernelName;
+ // TODO(#12327): Based on description in the issue, add an attribute
+ // `vm.import.module` and set it to `vmvx`. This only works on `vmvx`
+ // backend (obviously), but is enough to unblock while the proper fix
+ // lands. For now there are a bunch of attributes set on the function, but
+ // this should be made more controllable based on the backend.
+ result.defAttrs.emplace_back(rewriter.getStringAttr("vm.import.module"),
+ rewriter.getStringAttr("vmvx"));
+ } else {  // Non-VMVX: tag with hal.import.fields = ["processor_data"].
+ result.name = std::string("ukernel.") + ukernelName;
+ result.defAttrs.emplace_back(
+ rewriter.getStringAttr("hal.import.fields"),
+ rewriter.getArrayAttr({rewriter.getStringAttr("processor_data")}));
+ }
+ return result;
+}
+
/// Matches an (linalg.fill -> )? linalg.mmt4d operation sequence and converts
/// it into a iree_codegen.ukernel.mmt4d operation, that is later lowered
/// into a call to the microkernel.
@@ -78,15 +108,24 @@
Value m = rewriter.create<tensor::DimOp>(loc, lhs, 0);
Value n = rewriter.create<tensor::DimOp>(loc, rhs, 0);
Value k = rewriter.create<tensor::DimOp>(loc, rhs, 1);
- Value m0 = rewriter.create<tensor::DimOp>(loc, lhs, 2);
- Value n0 = rewriter.create<tensor::DimOp>(loc, rhs, 2);
- Value k0 = rewriter.create<tensor::DimOp>(loc, rhs, 3);
+
+ auto getDimAsI32 = [](RewriterBase &rewriter, Location loc, Value value,
+ int dim) -> Value {
+ return rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(),
+ rewriter.create<tensor::DimOp>(loc, value, dim));
+ };
+ Value m0 = getDimAsI32(rewriter, loc, lhs, 2);
+ Value n0 = getDimAsI32(rewriter, loc, rhs, 2);
+ Value k0 = getDimAsI32(rewriter, loc, rhs, 3);
Value flagsVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(flags));
+ auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
+ auto fn = getFnNameAndDefAttrs("mmt4d", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
- loc, outType, "vmvx.mmt4d", ValueRange{lhs, rhs}, out,
+ loc, outType, fn.name, ValueRange{lhs, rhs}, out,
ValueRange{m, n, k, m0, n0, k0, flagsVal},
- /*fn_def_attrs=*/nullptr,
+ /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
/*strided_outer_dims=*/rewriter.getIndexAttr(1));
return cast<IREE::Codegen::UKernelOpInterface>(
genericMicroKernelOp.getOperation());
@@ -192,11 +231,13 @@
Value out_size3 = rewriter.create<tensor::DimOp>(loc, out, 3);
Value flagsVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(flags));
+ auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
+ auto fn = getFnNameAndDefAttrs("pack", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
- loc, outType, "vmvx.pack", in, out,
+ loc, outType, fn.name, in, out,
ValueRange{in_size0, in_size1, out_size0, out_size1, out_size2, out_size3,
paddingVal, flagsVal},
- /*fn_def_attrs=*/nullptr,
+ /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
/*strided_outer_dims=*/rewriter.getIndexAttr(1));
return cast<IREE::Codegen::UKernelOpInterface>(
genericMicroKernelOp.getOperation());
@@ -267,11 +308,13 @@
Value out_size1 = rewriter.create<tensor::DimOp>(loc, out, 1);
Value flagsVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(flags));
+ auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
+ auto fn = getFnNameAndDefAttrs("unpack", rewriter, targetAttr);
auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
- loc, outType, "vmvx.unpack", in, out,
+ loc, outType, fn.name, in, out,
ValueRange{in_size0, in_size1, in_size2, in_size3, out_size0, out_size1,
flagsVal},
- /*fn_def_attrs=*/nullptr,
+ /*fn_def_attrs=*/rewriter.getDictionaryAttr(fn.defAttrs),
/*strided_outer_dims=*/rewriter.getIndexAttr(1));
return cast<IREE::Codegen::UKernelOpInterface>(
genericMicroKernelOp.getOperation());
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
index 2017e83..758168d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/lower_to_ukernel_ops.mlir
@@ -18,10 +18,13 @@
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[M0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[K0:.+]] = tensor.dim %[[ARG1]], %[[C3]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.mmt4d"
+// CHECK-DAG: %[[M0_index:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[M0:.+]] = arith.index_cast %[[M0_index]] : index to i32
+// CHECK-DAG: %[[N0_index:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[N0:.+]] = arith.index_cast %[[N0_index]] : index to i32
+// CHECK-DAG: %[[K0_index:.+]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[K0:.+]] = arith.index_cast %[[K0_index]] : index to i32
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "ukernel.mmt4d"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[M]], %[[N]], %[[K]], %[[M0]], %[[N0]], %[[K0]], %[[FLAGS]] :
@@ -44,10 +47,13 @@
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[M0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[K0:.+]] = tensor.dim %[[ARG1]], %[[C3]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.mmt4d"
+// CHECK-DAG: %[[M0_index:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[M0:.+]] = arith.index_cast %[[M0_index]] : index to i32
+// CHECK-DAG: %[[N0_index:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[N0:.+]] = arith.index_cast %[[N0_index]] : index to i32
+// CHECK-DAG: %[[K0_index:.+]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[K0:.+]] = arith.index_cast %[[K0_index]] : index to i32
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "ukernel.mmt4d"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[M]], %[[N]], %[[K]], %[[M0]], %[[N0]], %[[K0]], %[[FLAGS]] :
@@ -73,10 +79,13 @@
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[M0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[K0:.+]] = tensor.dim %[[ARG1]], %[[C3]]
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "vmvx.mmt4d"
+// CHECK-DAG: %[[M0_index:.+]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[M0:.+]] = arith.index_cast %[[M0_index]] : index to i32
+// CHECK-DAG: %[[N0_index:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[N0:.+]] = arith.index_cast %[[N0_index]] : index to i32
+// CHECK-DAG: %[[K0_index:.+]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK-DAG: %[[K0:.+]] = arith.index_cast %[[K0_index]] : index to i32
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "ukernel.mmt4d"
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK-SAME: (%[[M]], %[[N]], %[[K]], %[[M0]], %[[N0]], %[[K0]], %[[FLAGS]] :
@@ -98,7 +107,7 @@
// CHECK-DAG: %[[OUT_SIZE1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[OUT_SIZE2:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[OUT_SIZE3:.+]] = arith.constant 8 : index
-// CHECK: ukernel.generic "vmvx.pack"
+// CHECK: ukernel.generic "ukernel.pack"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[ARG1]] :
// CHECK-SAME: (%[[IN_SIZE0]], %[[IN_SIZE1]], %[[OUT_SIZE0]], %[[OUT_SIZE1]], %[[OUT_SIZE2]], %[[OUT_SIZE3]], %[[PAD]], %[[FLAGS]] :
@@ -124,7 +133,7 @@
// CHECK-DAG: %[[OUT_SIZE1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[OUT_SIZE2:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[OUT_SIZE3:.+]] = arith.constant 8 : index
-// CHECK: ukernel.generic "vmvx.pack"
+// CHECK: ukernel.generic "ukernel.pack"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[ARG1]] :
// CHECK-SAME: (%[[IN_SIZE0]], %[[IN_SIZE1]], %[[OUT_SIZE0]], %[[OUT_SIZE1]], %[[OUT_SIZE2]], %[[OUT_SIZE3]], %[[PAD]], %[[FLAGS]] :
@@ -151,7 +160,7 @@
// CHECK-DAG: %[[OUT_SIZE1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[OUT_SIZE2:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[OUT_SIZE3:.+]] = arith.constant 8 : index
-// CHECK: ukernel.generic "vmvx.pack"
+// CHECK: ukernel.generic "ukernel.pack"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[ARG1]] :
// CHECK-SAME: (%[[IN_SIZE0]], %[[IN_SIZE1]], %[[OUT_SIZE0]], %[[OUT_SIZE1]], %[[OUT_SIZE2]], %[[OUT_SIZE3]], %[[PAD]], %[[FLAGS]] :
@@ -175,7 +184,7 @@
// CHECK-DAG: %[[OUT_SIZE1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[IN_SIZE2:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[IN_SIZE3:.+]] = arith.constant 8 : index
-// CHECK: ukernel.generic "vmvx.unpack"
+// CHECK: ukernel.generic "ukernel.unpack"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[ARG1]] :
// CHECK-SAME: (%[[IN_SIZE0]], %[[IN_SIZE1]], %[[IN_SIZE2]], %[[IN_SIZE3]], %[[OUT_SIZE0]], %[[OUT_SIZE1]], %[[FLAGS]] :
@@ -199,7 +208,7 @@
// CHECK-DAG: %[[OUT_SIZE1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[IN_SIZE2:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[IN_SIZE3:.+]] = arith.constant 8 : index
-// CHECK: ukernel.generic "vmvx.unpack"
+// CHECK: ukernel.generic "ukernel.unpack"
// CHECK-SAME: ins(%[[ARG0]] :
// CHECK-SAME: outs(%[[ARG1]] :
// CHECK-SAME: (%[[IN_SIZE0]], %[[IN_SIZE1]], %[[IN_SIZE2]], %[[IN_SIZE3]], %[[OUT_SIZE0]], %[[OUT_SIZE1]], %[[FLAGS]] :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
index fe1adbf..dd543e6 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/pipeline_tests.mlir
@@ -472,7 +472,7 @@
}
}
// CHECK-LABEL: func @ukernel_dispatch()
-// CHECK: iree_codegen.ukernel.generic "vmvx.mmt4d"
+// CHECK: iree_codegen.ukernel.generic "ukernel.mmt4d"
// -----
diff --git a/experimental/cpu_ukernel/CMakeLists.txt b/experimental/cpu_ukernel/CMakeLists.txt
new file mode 100644
index 0000000..583dfb6
--- /dev/null
+++ b/experimental/cpu_ukernel/CMakeLists.txt
@@ -0,0 +1,105 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if(NOT IREE_BUILD_EXPERIMENTAL_CPU_UKERNEL_PLUGINS)
+ return()
+endif()
+
+add_subdirectory(test)
+
+include(iree_experimental_system_plugin.cmake)
+include(iree_experimental_standalone_plugin.cmake)
+
+iree_experimental_system_plugin(
+ NAME
+ builtin_ukernel_system_plugin
+ SRCS
+ plugin.c
+ DEPS
+ iree::builtins::ukernel
+)
+
+set(IREE_UK_COMMON_COPTS  # Included in both per-architecture COPTS lists below.
+ "-DIREE_UK_ENABLE_INLINE_ASM"
+)
+
+set(IREE_UK_X86_64_COPTS  # Architecture-wide options named by the "x86_64" ARCHS entry.
+ ${IREE_UK_COMMON_COPTS}
+ "-DIREE_UK_ARCH_X86_64"
+ "-DIREE_UK_POINTER_SIZE=8"
+ "-DIREE_UK_BUILD_X86_64_AVX2_FMA"
+ "-DIREE_UK_BUILD_X86_64_AVX512_BASE"
+ "-DIREE_UK_BUILD_X86_64_AVX512_VNNI"
+)
+
+set(IREE_UK_ARM_64_COPTS  # Architecture-wide options named by the "arm_64" ARCHS entry.
+ ${IREE_UK_COMMON_COPTS}
+ "-DIREE_UK_ARCH_ARM_64"
+ "-DIREE_UK_POINTER_SIZE=8"
+ "-DIREE_UK_BUILD_ARM_64_DOTPROD"
+ "-DIREE_UK_BUILD_ARM_64_I8MM"
+)
+
+set(IREE_UK_X86_64_AVX2_FMA_COPTS  # Per-file options for the *_avx2_fma.c sources.
+ "-mavx2"
+ "-mfma"
+)
+
+set(IREE_UK_X86_64_AVX512_BASE_COPTS  # Per-file options for the *_avx512_base.c sources.
+ ${IREE_UK_X86_64_AVX2_FMA_COPTS}
+ "-mavx512f"
+ "-mavx512vl"
+ "-mavx512cd"
+ "-mavx512bw"
+ "-mavx512dq"
+)
+
+set(IREE_UK_X86_64_AVX512_VNNI_COPTS  # Per-file options for mmt4d_x86_64_avx512_vnni.c.
+ ${IREE_UK_X86_64_AVX512_BASE_COPTS}
+ "-mavx512vnni"
+)
+
+set(IREE_UK_ARM_64_DOTPROD_COPTS  # Per-file options for mmt4d_arm_64_dotprod.c.
+ "-march=armv8.2-a+dotprod"
+)
+
+set(IREE_UK_ARM_64_I8MM_COPTS  # Per-file options for mmt4d_arm_64_i8mm.c.
+ "-march=armv8.2-a+i8mm"
+)
+
+iree_experimental_standalone_plugin(
+ NAME
+ builtin_ukernel_standalone_plugin
+ ARCHS
+ "x86_64:IREE_UK_X86_64_COPTS"
+ "arm_64:IREE_UK_ARM_64_COPTS"
+ SRCS
+ plugin.c
+ runtime/src/iree/builtins/ukernel/mmt4d.c
+ runtime/src/iree/builtins/ukernel/mmt4d_tile.c
+ runtime/src/iree/builtins/ukernel/unpack_tile.c
+ runtime/src/iree/builtins/ukernel/pack.c
+ runtime/src/iree/builtins/ukernel/query_tile_sizes.c
+ runtime/src/iree/builtins/ukernel/unpack.c
+ runtime/src/iree/builtins/ukernel/pack_tile.c
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/query_tile_sizes_x86_64.c"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64.c"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64.c"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64.c"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c:IREE_UK_X86_64_AVX2_FMA_COPTS"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx2_fma.c:IREE_UK_X86_64_AVX2_FMA_COPTS"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx2_fma.c:IREE_UK_X86_64_AVX2_FMA_COPTS"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/unpack_x86_64_avx512_base.c:IREE_UK_X86_64_AVX512_BASE_COPTS"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_base.c:IREE_UK_X86_64_AVX512_BASE_COPTS"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/pack_x86_64_avx512_base.c:IREE_UK_X86_64_AVX512_BASE_COPTS"
+ "x86_64:runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c:IREE_UK_X86_64_AVX512_VNNI_COPTS"
+ "arm_64:runtime/src/iree/builtins/ukernel/arch/arm_64/query_tile_sizes_arm_64.c"
+ "arm_64:runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64.c"
+ "arm_64:runtime/src/iree/builtins/ukernel/arch/arm_64/pack_arm_64.c"
+ "arm_64:runtime/src/iree/builtins/ukernel/arch/arm_64/unpack_arm_64.c"
+ "arm_64:runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_dotprod.c:IREE_UK_ARM_64_DOTPROD_COPTS"
+ "arm_64:runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c:IREE_UK_ARM_64_I8MM_COPTS"
+)
diff --git a/experimental/cpu_ukernel/iree_experimental_standalone_plugin.cmake b/experimental/cpu_ukernel/iree_experimental_standalone_plugin.cmake
new file mode 100644
index 0000000..079ffc4
--- /dev/null
+++ b/experimental/cpu_ukernel/iree_experimental_standalone_plugin.cmake
@@ -0,0 +1,223 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# iree_experimental_standalone_plugin_arch()
+#
+# Helper for iree_experimental_standalone_plugin, building one
+# architecture.
+#
+# Parameters:
+# NAME: Name of the standalone plugin to create (output: `<NAME>.<ARCH>.so`).
+# ARCH: Name of architecture (as in IREE_ARCH) to build for.
+# Example: "arm_64".
+# COPTS: List of compiler options to be applied to all source files.
+# SRCS: List of source files. Each list entry may be of one of two forms:
+# * Each entry that does not contain a colon is interpreted as a source
+# file path, to be built unconditionally, with the compiler options
+# specified in `COPTS`.
+# * Each entry that contains a colon is interpreted as a colon-separated
+# list of length either 2 or 3. Format:
+# `ARCH:FILE[:FILE_COPTS_VAR_NAME]`.
+# Any entry whose `ARCH` does not match this rule's `ARCH` parameter
+# is filtered out. Remaining files are compiled with the
+# architecture-wide compiler options (see `COPTS`) and, if provided,
+# with the file-specific compiler options from expanding the variable
+# specified in `FILE_COPTS_VAR_NAME`.
+# Example: "x86_64:some_file_for_x86_64_using_avx512_instructions.c:NAME_OF_VARIABLE_CONTAINING_COPTS_FOR_X86_64_AVX512".
+function(iree_experimental_standalone_plugin_arch)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME;ARCH"
+ "SRCS;COPTS"
+ ${ARGN}
+ )
+
+ iree_package_name(_PACKAGE_NAME)
+ set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}_${_RULE_ARCH}")
+ iree_arch_to_llvm_arch(LLVM_ARCH "${_RULE_ARCH}")  # Used in the -target triple below.
+
+ foreach(_SRC_ENTRY_COLON_SEPARATED IN LISTS _RULE_SRCS)
+ string(REPLACE ":" ";" _SRC_ENTRY_LIST "${_SRC_ENTRY_COLON_SEPARATED}")
+ list(LENGTH _SRC_ENTRY_LIST _SRC_ENTRY_LIST_LENGTH)
+ set(_SRC_COPTS_VAR_NAME "")
+ set(_SRC_FILE "")
+ if(_SRC_ENTRY_LIST_LENGTH EQUAL 1)
+ set(_SRC_FILE "${_SRC_ENTRY_LIST}")
+ else() # NOT _SRC_ENTRY_LIST_LENGTH EQUAL 1
+ list(GET _SRC_ENTRY_LIST 0 _SRC_ARCH)
+ if(NOT _SRC_ARCH STREQUAL _RULE_ARCH)
+ continue()
+ endif()
+ list(GET _SRC_ENTRY_LIST 1 _SRC_FILE)
+ if(_SRC_ENTRY_LIST_LENGTH EQUAL 3)
+ list(GET _SRC_ENTRY_LIST 2 _SRC_COPTS_VAR_NAME)
+ endif()
+ endif() # NOT _SRC_ENTRY_LIST_LENGTH EQUAL 1
+
+ set(_SRC_COPTS "${${_SRC_COPTS_VAR_NAME}}")  # Value of the named variable, if any.
+
+ get_filename_component(_SRC_FILE_BASENAME "${_SRC_FILE}" NAME)
+
+ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${_SRC_FILE}")  # Path relative to this directory.
+ set(_SRC_FILE "${CMAKE_CURRENT_SOURCE_DIR}/${_SRC_FILE}")
+ endif()
+
+ if(EXISTS "${PROJECT_SOURCE_DIR}/${_SRC_FILE}")  # Path relative to the project root.
+ set(_SRC_FILE "${PROJECT_SOURCE_DIR}/${_SRC_FILE}")
+ endif()
+
+ set(_OBJECT_FILE "${_SRC_FILE_BASENAME}.${_RULE_ARCH}.o")  # NOTE: keyed on the basename only.
+ list(APPEND _OBJECT_FILES "${CMAKE_CURRENT_BINARY_DIR}/${_OBJECT_FILE}")
+ add_custom_command(
+ OUTPUT
+ "${_OBJECT_FILE}"
+ DEPENDS
+ "${_SRC_FILE}"
+ "${IREE_CLANG_TARGET}"
+ COMMAND "${IREE_CLANG_TARGET}"
+ # Flags copied from
+ # compiler/src/iree/compiler/Dialect/HAL/Target/LLVMCPU/internal/EmbeddedLinkerTool.cpp
+ -target "${LLVM_ARCH}-unknown-unknown-eabi-elf"
+ -isystem "${IREE_BINARY_DIR}/third_party/llvm-project/llvm/lib/clang/17/include"  # NOTE(review): hard-coded Clang version 17 - confirm on LLVM bumps.
+ -std=c17
+ -fasm # Added for inline-asm support.
+ -fPIC
+ -ffreestanding
+ -fvisibility=hidden
+ -fno-plt
+ -fno-rtti
+ -fno-exceptions
+ -fdata-sections
+ -ffunction-sections
+ -funique-section-names
+ -DIREE_UK_STANDALONE
+ -I "${IREE_SOURCE_DIR}/runtime/src/"
+ -c "${_SRC_FILE}"
+ -o "${CMAKE_CURRENT_BINARY_DIR}/${_OBJECT_FILE}"
+ ${_RULE_COPTS}
+ ${_SRC_COPTS}
+ VERBATIM
+ )
+ endforeach()
+ set(_OUTPUT_SO_FILE "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.${_RULE_ARCH}.so")
+ add_custom_command(
+ OUTPUT
+ ${_OUTPUT_SO_FILE}
+ DEPENDS
+ ${_OBJECT_FILES}
+ ${IREE_LLD_TARGET}
+ COMMAND ${IREE_LLD_TARGET}
+ -flavor gnu
+ --build-id=none
+ -nostdlib
+ -static
+ -shared
+ --no-undefined
+ --no-allow-shlib-undefined
+ --allow-multiple-definition
+ --gc-sections
+ -z now
+ -z relro
+ --discard-all
+ --icf=all
+ --ignore-data-address-equality
+ --ignore-function-address-equality
+ --hash-style=sysv
+ --strip-debug
+ ${_OBJECT_FILES}
+ -o "${_OUTPUT_SO_FILE}"
+ VERBATIM
+ )
+ add_custom_target(${_NAME} DEPENDS
+ "${_OUTPUT_SO_FILE}"
+ )
+endfunction()
+
+# iree_experimental_standalone_plugin()
+#
+# Creates a standalone plugin library, that is built using our in-tree Clang
+# toolchain for multiple target architectures, generating a fat embedded-elf,
+# and may be loaded with the embedded dynamic library loader.
+#
+# Contrast with: iree_experimental_system_plugin.
+#
+# Parameters:
+# NAME: Name of the standalone plugin to create (output: `<NAME>.sos`).
+# ARCHS: List of architectures (as in IREE_ARCH) to build. Format:
+# `ARCH[:ARCH_COPTS_VAR_NAME]`. If provided, `ARCH_COPTS_VAR_NAME` is
+# interpreted as the name of a variable to be expanded into all compiler
+# command lines used for architecture `ARCH`.
+# Example: "arm_64:NAME_OF_VARIABLE_CONTAINING_COPTS_FOR_ARM_64".
+# SRCS: List of source files. Each list entry may be of one of two forms:
+# * Each entry that does not contain a colon is interpreted as a source
+# file path, to be built for all architectures with the
+# architecture-wide compiler options provided for each architecture
+# (see `ARCHS`).
+# * Each entry that contains a colon is interpreted as a colon-separated
+# list of length either 2 or 3. Format:
+# `ARCH:FILE[:FILE_COPTS_VAR_NAME]`.
+# The specified source `FILE` is compiled only for the specified
+# architecture `ARCH` and is skipped on other architectures. It is
+# compiled with the architecture-wide compiler options
+# (see `ARCHS`) and, if provided, with the file-specific compiler
+# options from expanding the variable specified in
+# `FILE_COPTS_VAR_NAME`.
+# Example: "x86_64:some_file_for_x86_64_using_avx512_instructions.c:NAME_OF_VARIABLE_CONTAINING_COPTS_FOR_X86_64_AVX512".
+function(iree_experimental_standalone_plugin)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME"
+ "SRCS;ARCHS"
+ ${ARGN}
+ )
+
+ # Iterate over architectures. For each of them, build the architecture-specific
+ # shared library (iree_experimental_standalone_plugin_arch).
+ foreach(_ARCH_ENTRY_COLON_SEPARATED IN LISTS _RULE_ARCHS)
+ # Turn the colon-separated ARCH entry into a CMake list (semicolon-separated)
+ string(REPLACE ":" ";" _ARCH_ENTRY_LIST "${_ARCH_ENTRY_COLON_SEPARATED}")
+ list(GET _ARCH_ENTRY_LIST 0 _ARCH)
+ list(LENGTH _ARCH_ENTRY_LIST _ARCH_ENTRY_LIST_LENGTH)
+ # Get optional architecture-wide copts into _COPTS.
+ set(_COPTS_VAR_NAME "")
+ if(_ARCH_ENTRY_LIST_LENGTH EQUAL 2)
+ list(GET _ARCH_ENTRY_LIST 1 _COPTS_VAR_NAME)
+ endif()
+ set(_COPTS "${${_COPTS_VAR_NAME}}")
+ # Build the architecture-specific shared library.
+ iree_experimental_standalone_plugin_arch(
+ NAME
+ "${_RULE_NAME}"
+ ARCH
+ "${_ARCH}"
+ SRCS
+ ${_RULE_SRCS}
+ COPTS
+ ${_COPTS}
+ )
+ list(APPEND _ARCH_SO_FILES "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.${_ARCH}.so")
+ endforeach()
+ # Generate the multi-architecture ELF file from the per-architecture .so files.
+ add_custom_command(
+ OUTPUT
+ "${_RULE_NAME}.sos"
+ DEPENDS
+ ${_ARCH_SO_FILES}
+ iree-fatelf
+ COMMAND iree-fatelf join
+ ${_ARCH_SO_FILES}
+ > ${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.sos
+ VERBATIM
+ )
+ iree_package_name(_PACKAGE_NAME)
+ set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
+ add_custom_target("${_NAME}" DEPENDS
+ "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.sos"
+ )
+ add_dependencies(iree-test-deps "${_NAME}")
+endfunction()
diff --git a/experimental/cpu_ukernel/iree_experimental_system_plugin.cmake b/experimental/cpu_ukernel/iree_experimental_system_plugin.cmake
new file mode 100644
index 0000000..4b8acbc
--- /dev/null
+++ b/experimental/cpu_ukernel/iree_experimental_system_plugin.cmake
@@ -0,0 +1,53 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# iree_experimental_system_plugin()
+#
+# Creates a system plugin library, that is built using the host toolchain for
+# the host architecture and may be loaded with the system dynamic library
+# loader.
+#
+# Contrast with: iree_experimental_standalone_plugin.
+#
+# Parameters:
+# NAME: Name of the system plugin to create.
+# SRCS: List of source files.
+# DEPS: List of dependencies.
+function(iree_experimental_system_plugin)
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "NAME"
+ "SRCS;DEPS"
+ ${ARGN}
+ )
+
+ iree_cc_library(
+ NAME
+ ${_RULE_NAME}
+ SRCS
+ ${_RULE_SRCS}
+ DEPS
+ ${_RULE_DEPS}
+ iree::hal::local::executable_plugin
+ INCLUDES
+ "${IREE_SOURCE_DIR}/runtime/src/"
+ SHARED
+ )
+
+ # NOTE: this is only required because we want this plugin to run on all
+ # platforms without needing to change the library name (libfoo.so/foo.dll).
+ iree_package_name(_PACKAGE_NAME)
+ set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
+ set_target_properties("${_NAME}"
+ PROPERTIES
+ WINDOWS_EXPORT_ALL_SYMBOLS ON
+ PREFIX ""
+ OUTPUT_NAME "${_RULE_NAME}"
+ )
+
+ add_dependencies(iree-test-deps "${_NAME}")
+endfunction()
diff --git a/experimental/cpu_ukernel/plugin.c b/experimental/cpu_ukernel/plugin.c
new file mode 100644
index 0000000..b93619f
--- /dev/null
+++ b/experimental/cpu_ukernel/plugin.c
@@ -0,0 +1,152 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/builtins/ukernel/api.h"
+#include "iree/hal/local/executable_plugin.h"
+
+// Defining iree_uk_assert_fail is left to the user of the ukernel code, i.e.
+// to us here, because the core ukernel/ code can't use the standard library.
+#if !defined(IREE_UK_STANDALONE)  // Building a system plugin.
+#include <stdio.h>
+#include <stdlib.h>
+#endif  // !defined(IREE_UK_STANDALONE)
+
+void iree_uk_assert_fail(const char* file, int line, const char* function,
+                         const char* condition) {
+#if defined(IREE_UK_STANDALONE)
+  // Building a standalone plugin: doing nothing at the moment.
+#else
+  fflush(stdout);
+  // The report must be a single fprintf call (which must make a single
+  // write): this is typically called from multiple worker threads concurrently.
+  fprintf(stderr, "%s:%d: %s: assertion failed: %s\n", file, line, function,
+          condition);
+  fflush(stderr);
+  abort();
+#endif  // defined(IREE_UK_STANDALONE)
+}
+
+// Plugin entry points wrapping the actual ukernels.
+static int iree_uk_plugin_mmt4d(void* context, void* params_ptr,
+ void* reserved) {  // context and reserved are not read
+ iree_uk_mmt4d((const iree_uk_mmt4d_params_t*)params_ptr);  // forward params
+ return 0;  // unconditionally returns 0 to the caller
+}
+
+static int iree_uk_plugin_pack(void* context, void* params_ptr,
+ void* reserved) {  // context and reserved are not read
+ iree_uk_pack((const iree_uk_pack_params_t*)params_ptr);  // forward params
+ return 0;  // unconditionally returns 0 to the caller
+}
+
+static int iree_uk_plugin_unpack(void* context, void* params_ptr,
+ void* reserved) {  // context and reserved are not read
+ iree_uk_unpack((const iree_uk_unpack_params_t*)params_ptr);  // forward params
+ return 0;  // unconditionally returns 0 to the caller
+}
+
+static iree_hal_executable_plugin_status_t iree_uk_plugin_load(
+ const iree_hal_executable_plugin_environment_v0_t* environment,
+ size_t param_count, const iree_hal_executable_plugin_string_pair_t* params,
+ void** out_self) {  // environment, param_count and params are not read
+ *out_self = NULL; // no state in this plugin
+ return iree_hal_executable_plugin_ok_status();  // always returns the OK status
+}
+
+// Called to free any plugin state allocated in load; load sets none, so no-op.
+static void iree_uk_plugin_unload(void* self) {}
+
+#define ARRAYSIZE(arr) (sizeof(arr) / sizeof(arr[0]))  // element count of an array
+
+// Called to resolve one or more imports by symbol name against the ukernel
+// entry point table below. Imports that are already resolved are skipped; a
+// missing optional import only sets the MISSING_OPTIONAL resolution bit.
+static iree_hal_executable_plugin_status_t iree_uk_plugin_resolve(
+    void* self, const iree_hal_executable_plugin_resolve_params_v0_t* params,
+    iree_hal_executable_plugin_resolution_t* out_resolution) {
+  typedef struct {
+    const char* symbol_name;
+    const void* fn_ptr;
+  } plugin_entry_point_t;
+  static const plugin_entry_point_t entry_points[] = {
+      {"ukernel.mmt4d", iree_uk_plugin_mmt4d},
+      {"ukernel.pack", iree_uk_plugin_pack},
+      {"ukernel.unpack", iree_uk_plugin_unpack},
+  };
+  *out_resolution = 0;
+  bool any_required_not_found = false;
+  for (size_t i = 0; i < params->count; ++i) {
+    if (params->out_fn_ptrs[i]) continue;  // already resolved
+    const char* symbol_name = params->symbol_names[i];
+    bool is_optional =
+        iree_hal_executable_plugin_import_is_optional(symbol_name);
+    if (is_optional) ++symbol_name;  // step past the leading optional marker
+    bool found = false;
+    for (size_t ep_idx = 0; ep_idx < ARRAYSIZE(entry_points); ++ep_idx) {
+      const plugin_entry_point_t* entry_point = &entry_points[ep_idx];
+      if (iree_hal_executable_plugin_strcmp(symbol_name,
+                                            entry_point->symbol_name) == 0) {
+        params->out_fn_ptrs[i] = (void*)(entry_point->fn_ptr);
+        params->out_fn_contexts[i] = NULL;  // entry points take no context
+        found = true;
+        break;
+      }
+    }
+    if (!found) {
+      if (is_optional) {
+        *out_resolution |=
+            IREE_HAL_EXECUTABLE_PLUGIN_RESOLUTION_MISSING_OPTIONAL;
+      } else {
+        any_required_not_found = true;  // keep going; reported after the loop
+      }
+    }
+  }
+  return any_required_not_found
+             ? iree_hal_executable_plugin_status_from_code(
+                   IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND)
+             : iree_hal_executable_plugin_ok_status();
+}
+
+// Entry point exported on the shared library, which the runtime uses to query
+// the plugin interface. A dynamically linked plugin must export exactly this
+// symbol name, with no C++ name mangling. A statically linked plugin may use
+// any name instead (allowing for multiple plugins), as this is then just a
+// function that can be called.
+IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t**
+iree_hal_executable_plugin_query(
+    iree_hal_executable_plugin_version_t max_version, void* reserved) {
+  static const iree_hal_executable_plugin_header_t header = {
+      // The library version that is present: newer runtimes may support
+      // loading older plugins but newer plugins cannot load on older runtimes.
+      .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST,
+#if !defined(IREE_UK_STANDALONE)  // Building a system plugin.
+      // Name and description are used for tracing/logging/diagnostics.
+      .name = "builtin_ukernel_system_plugin",
+      .description = "builtin ukernels as system plugin (" __FILE__ ")",
+      .features = 0,
+      // Let the runtime know what sanitizer this plugin was compiled with.
+      .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND,
+#else  // Building a standalone plugin.
+      // Name and description are used for tracing/logging/diagnostics.
+      .name = "builtin_ukernel_standalone_plugin",
+      .description = "builtin ukernels as standalone plugin (" __FILE__ ")",
+      // A standalone plugin has to declare itself as such, so that the
+      // runtime can verify that it supports standalone plugins.
+      .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE,
+      // Standalone plugins don't support sanitizers.
+      .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_NONE,
+#endif  // !defined(IREE_UK_STANDALONE)
+  };
+  static const iree_hal_executable_plugin_v0_t plugin = {
+      .header = &header,
+      .load = iree_uk_plugin_load,
+      .unload = iree_uk_plugin_unload,
+      .resolve = iree_uk_plugin_resolve,
+  };
+  // The interface is only returned when max_version is at most the latest.
+  if (max_version > IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST) return NULL;
+  return (const iree_hal_executable_plugin_header_t**)&plugin;
+}
diff --git a/experimental/cpu_ukernel/test/CMakeLists.txt b/experimental/cpu_ukernel/test/CMakeLists.txt
new file mode 100644
index 0000000..53a81b4
--- /dev/null
+++ b/experimental/cpu_ukernel/test/CMakeLists.txt
@@ -0,0 +1,39 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_check_single_backend_test_suite(
+  NAME
+    builtin_ukernel_system_plugin_test
+  SRCS
+    "mmt4d.mlir"
+  TARGET_BACKEND
+    "llvm-cpu"
+  DRIVER
+    "local-sync"
+  COMPILER_FLAGS
+    "--iree-llvmcpu-enable-microkernels"
+  RUNNER_ARGS  # The shared library suffix is platform-defined (.so/.dylib/.dll).
+    "--executable_plugin=${PROJECT_BINARY_DIR}/experimental/cpu_ukernel/builtin_ukernel_system_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}"
+  LABELS
+    "hostonly"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ builtin_ukernel_standalone_plugin_test
+ SRCS
+ "mmt4d.mlir"  # same source as builtin_ukernel_system_plugin_test
+ TARGET_BACKEND
+ "llvm-cpu"
+ DRIVER
+ "local-sync"
+ COMPILER_FLAGS
+ "--iree-llvmcpu-enable-microkernels"  # compile with microkernels enabled
+ RUNNER_ARGS
+ "--executable_plugin=${PROJECT_BINARY_DIR}/experimental/cpu_ukernel/builtin_ukernel_standalone_plugin.sos"  # standalone plugin file (.sos)
+ LABELS
+ "hostonly"
+)
diff --git a/experimental/cpu_ukernel/test/mmt4d.mlir b/experimental/cpu_ukernel/test/mmt4d.mlir
new file mode 100644
index 0000000..06e0272
--- /dev/null
+++ b/experimental/cpu_ukernel/test/mmt4d.mlir
@@ -0,0 +1,71 @@
+func.func @test_mmt4d() {  // checks result[i][j] = init_acc[i][j] + lhs[i] * rhs[j]
+  %lhs = util.unfoldable_constant
+    dense<
+      [
+        [
+          [
+            [1.0],
+            [2.0],
+            [3.0],
+            [4.0],
+            [5.0],
+            [6.0],
+            [7.0],
+            [8.0]
+          ]
+        ]
+      ]> : tensor<1x1x8x1xf32>
+  %rhs = util.unfoldable_constant
+    dense<
+      [
+        [
+          [
+            [1.0e-1],
+            [1.0e-2],
+            [1.0e-3],
+            [1.0e-4],
+            [1.0e-5],
+            [1.0e-6],
+            [1.0e-7],
+            [1.0e-8]
+          ]
+        ]
+      ]> : tensor<1x1x8x1xf32>
+  %init_acc = util.unfoldable_constant
+    dense<
+      [
+        [
+          [
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
+          ]
+        ]
+      ]> : tensor<1x1x8x8xf32>
+
+  %result = linalg.mmt4d ins(%lhs, %rhs : tensor<1x1x8x1xf32>, tensor<1x1x8x1xf32>)
+                         outs(%init_acc : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32>
+
+  check.expect_almost_eq_const(%result, dense<
+      [
+        [
+          [
+            [1.0e-1, 1.0e-2, 1.0e-3, 1.0e-4, 1.0e-5, 1.0e-6, 1.0e-7, 1.0e-8],
+            [2.0e-1, 2.0e-2, 2.0e-3, 2.0e-4, 2.0e-5, 2.0e-6, 2.0e-7, 2.0e-8],
+            [3.0e-1, 3.0e-2, 3.0e-3, 3.0e-4, 3.0e-5, 3.0e-6, 3.0e-7, 3.0e-8],
+            [4.0e-1, 4.0e-2, 4.0e-3, 4.0e-4, 4.0e-5, 4.0e-6, 4.0e-7, 4.0e-8],
+            [5.0e-1, 5.0e-2, 5.0e-3, 5.0e-4, 5.0e-5, 5.0e-6, 5.0e-7, 5.0e-8],
+            [6.0e-1, 6.0e-2, 6.0e-3, 6.0e-4, 6.0e-5, 6.0e-6, 6.0e-7, 6.0e-8],
+            [7.0e-1, 7.0e-2, 7.0e-3, 7.0e-4, 7.0e-5, 7.0e-6, 7.0e-7, 7.0e-8],
+            [8.0e-1, 8.0e-2, 8.0e-3, 8.0e-4, 8.0e-5, 8.0e-6, 8.0e-7, 1.0]  // 8.0e-8 + init_acc 1.0
+          ]
+        ]
+      ]> : tensor<1x1x8x8xf32>) : tensor<1x1x8x8xf32>
+
+  return
+}
diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
index b66391b..79c6e87 100644
--- a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
+++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64.h
@@ -129,7 +129,7 @@
static inline __m128i iree_uk_avx2_load_8x2xi8_strided(
const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
__m128i result = _mm_setzero_si128();
- const iree_uk_uint16_t* src_i16 = (const iree_uk_int16_t*)src;
+ const iree_uk_int16_t* src_i16 = (const iree_uk_int16_t*)src;
result =
_mm_insert_epi16(result, *(const iree_uk_int16_t*)(src + 0 * stride), 0);
result =
@@ -152,7 +152,7 @@
static inline __m256i iree_uk_avx2_load_16x2xi8_strided(
const iree_uk_int8_t* src, iree_uk_ssize_t stride) {
__m256i result = _mm256_setzero_si256();
- const iree_uk_uint16_t* src_i16 = (const iree_uk_int16_t*)src;
+ const iree_uk_int16_t* src_i16 = (const iree_uk_int16_t*)src;
result = _mm256_insert_epi16(result,
*(const iree_uk_int16_t*)(src + 0 * stride), 0);
result = _mm256_insert_epi16(result,
diff --git a/runtime/src/iree/builtins/ukernel/common.h b/runtime/src/iree/builtins/ukernel/common.h
index adfa79f..8e38d61 100644
--- a/runtime/src/iree/builtins/ukernel/common.h
+++ b/runtime/src/iree/builtins/ukernel/common.h
@@ -86,6 +86,10 @@
// there are portability, reliability and performance concerns with
// that.
+// Configured headers generated by CMake based on the host toolchain.
+// The condition `!defined(IREE_UK_STANDALONE)` means "host toolchain".
+#if !defined(IREE_UK_STANDALONE)
+
// Include the build-system-generated configured header and use it as the only
// source of information about the target we're compiling against, as opposed to
// including iree/base/target_platform.h.
@@ -107,6 +111,8 @@
#include "iree/builtins/ukernel/arch/x86_64/config.h"
#endif
+#endif // !defined(IREE_UK_STANDALONE)
+
// Include common flag values, shared with the compiler.
#include "iree/builtins/ukernel/exported_bits.h"