Introduce a "vector.contract custom kernels" pass. (#7778)

This refurbishes Ahmed Taei's existing
VectorContractToAArch64InlineAsmOp pass, repurposing it to implement
matrix-times-matrix-transposed kernels, and generalizing it to create a
good insertion point for any other custom "handwritten" vector.contract
kernel, using either inline assembly or intrinsics.
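
For reference, here is the shape of op that the new pass matches, in the
8x4x8 i8*i8->i32 case handled by the new AArch64 +dotprod kernel. This is
condensed from the lit test vector_contract_custom_kernels.mlir added
below, which is the authoritative version:

  %lhs_wide = arith.extsi %lhs : vector<8x4xi8> to vector<8x4xi32>
  %rhs_wide = arith.extsi %rhs : vector<8x4xi8> to vector<8x4xi32>
  %res = vector.contract {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,  // LHS indexed by (M, K)
      affine_map<(d0, d1, d2) -> (d1, d2)>,  // RHS indexed by (N, K): transposed
      affine_map<(d0, d1, d2) -> (d0, d1)>   // accumulator indexed by (M, N)
    ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
  } %lhs_wide, %rhs_wide, %acc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>

The pass can be exercised in isolation with
  iree-opt -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod'
and e2e matmul tests opt in per backend through the new
TARGET_CPU_FEATURES_VARIANTS argument ("default", "aarch64:+dotprod");
entries whose architecture does not match the build are skipped.
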
diff --git a/build_tools/bazel/iree_check_test.bzl b/build_tools/bazel/iree_check_test.bzl
index ddae530..5359e42 100644
--- a/build_tools/bazel/iree_check_test.bzl
+++ b/build_tools/bazel/iree_check_test.bzl
@@ -25,6 +25,7 @@
opt_tool = "//iree/tools:iree-opt",
opt_flags = [],
tags = [],
+ target_cpu_features = None,
timeout = None,
**kwargs):
"""Creates an iree-check-module test for the specified source file.
@@ -44,9 +45,14 @@
these flags.
tags: additional tags to apply to the generated test. A tag "driver=DRIVER" is added
automatically.
+      target_cpu_features: currently unimplemented (must be empty); will eventually allow specifying target CPU features.
timeout: timeout for the generated tests.
**kwargs: any additional attributes to pass to the underlying run_binary_test.
"""
+
+ if target_cpu_features:
+ fail("target_cpu_features must currently be empty")
+
bytecode_module_name = name + "_bytecode_module"
iree_bytecode_module(
name = bytecode_module_name,
@@ -84,6 +90,7 @@
opt_tool = "//iree/tools:iree-opt",
opt_flags = [],
tags = [],
+ target_cpu_features = None,
timeout = None,
**kwargs):
"""Creates a test suite of iree-check-module tests for a single backend/driver pair.
@@ -104,11 +111,20 @@
if opt_flags is specified.
opt_flags: If specified, source files are preprocessed with OPT_TOOL with
these flags.
+      target_cpu_features: currently unimplemented (must be empty); will eventually allow specifying target CPU features.
tags: tags to apply to the generated tests. Note that as in standard test suites, manual
is treated specially and will also apply to the test suite itself.
timeout: timeout for the generated tests.
**kwargs: any additional attributes to pass to the underlying tests and test suite.
"""
+
+  # We haven't implemented this yet because so far we have only used target_cpu_features for
+  # aarch64 targets, for which we use the CMake build. To future people implementing this:
+  # target_cpu_features should be a list, and here it should be joined into a comma-separated
+  # string to be passed to --iree-llvm-target-cpu-features.
+ if target_cpu_features:
+ fail("target_cpu_features must currently be empty")
+
tests = []
for src in srcs:
test_name = "_".join([name, src])
@@ -147,6 +163,7 @@
opt_tool = "//iree/tools:iree-opt",
opt_flags = [],
tags = [],
+ target_cpu_features_variants = [],
**kwargs):
"""Creates a test suite of iree-check-module tests.
@@ -167,9 +184,16 @@
these flags.
tags: tags to apply to the generated tests. Note that as in standard test suites, manual
is treated specially and will also apply to the test suite itself.
+      target_cpu_features_variants: list of target cpu features variants. Currently unimplemented
+        here, so each entry must be either "default" or start with "aarch64:"; since Bazel builds
+        are currently x86-only, it is correct to ignore such entries.
**kwargs: any additional attributes to pass to the underlying tests and test suite.
"""
+ for target_cpu_features in target_cpu_features_variants:
+ if not (target_cpu_features == "default" or target_cpu_features.startswith("aarch64:")):
+ fail("Entry %s in target_cpu_features_variants: unimplemented" % target_cpu_features)
+
# We could have complicated argument override logic for runner_args and such, or... the client
# could just create a test suite. The latter seems simpler and more readable.
tests = []
diff --git a/build_tools/bazel/iree_trace_runner_test.bzl b/build_tools/bazel/iree_trace_runner_test.bzl
index 17387ae..f5e3127 100644
--- a/build_tools/bazel/iree_trace_runner_test.bzl
+++ b/build_tools/bazel/iree_trace_runner_test.bzl
@@ -22,6 +22,7 @@
opt_tool = "//iree/tools:iree-opt",
opt_flags = [],
tags = [],
+ target_cpu_features = None,
timeout = None,
**kwargs):
"""Creates a test running a custom trace-runner on a trace file (yaml).
@@ -46,9 +47,13 @@
      module: specifies the path to use for the generated IREE module (.vmfb). Mandatory,
unlike in iree_check_test, because trace files (.yaml) reference a specific module file path.
timeout: timeout for the generated tests.
+      target_cpu_features: currently unimplemented (must be empty); will eventually allow specifying target CPU features.
**kwargs: any additional attributes to pass to the underlying tests and test suite.
"""
+ if target_cpu_features:
+ fail("target_cpu_features must currently be empty")
+
bytecode_module_name = name + "_bytecode_module"
iree_bytecode_module(
name = bytecode_module_name,
@@ -93,6 +98,7 @@
opt_tool = "//iree/tools:iree-opt",
opt_flags = [],
tags = [],
+ target_cpu_features = None,
timeout = None,
**kwargs):
"""Generates an iree_trace_runner_test using a custom python generator script.
@@ -122,9 +128,13 @@
these flags.
trace_runner: trace-runner program to run.
timeout: timeout for the generated tests.
+      target_cpu_features: currently unimplemented (must be empty); will eventually allow specifying target CPU features.
**kwargs: any additional attributes to pass to the underlying tests and test suite.
"""
+ if target_cpu_features:
+ fail("target_cpu_features must currently be empty")
+
src = "%s.mlir" % (name)
trace = "%s.yaml" % (name)
module = "%s.vmfb" % (name)
@@ -175,6 +185,7 @@
opt_flags = [],
tags = [],
timeout = None,
+ target_cpu_features_variants = [],
**kwargs):
"""Generates a suite of iree_trace_runner_test on multiple backends/drivers.
@@ -199,9 +210,16 @@
these flags.
trace_runner: trace-runner program to run.
timeout: timeout for the generated tests.
+      target_cpu_features_variants: list of target cpu features variants. Currently unimplemented
+        here, so each entry must be either "default" or start with "aarch64:"; since Bazel builds
+        are currently x86-only, it is correct to ignore such entries.
**kwargs: any additional attributes to pass to the underlying tests and test suite.
"""
+ for target_cpu_features in target_cpu_features_variants:
+ if not (target_cpu_features == "default" or target_cpu_features.startswith("aarch64:")):
+ fail("Entry %s in target_cpu_features_variants: unimplemented" % target_cpu_features)
+
tests = []
for backend, driver in target_backends_and_drivers:
suite_entry_name = "_".join([name, backend, driver])
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index c6fe264..7c66d0c 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -238,7 +238,6 @@
self._convert_unimplemented_function("filegroup", name)
-
def sh_binary(self, name, **kwargs):
self._convert_unimplemented_function("sh_binary", name)
@@ -532,6 +531,7 @@
runner_args=None,
tags=None,
opt_flags=None,
+ target_cpu_features=None,
**kwargs):
name_block = _convert_string_arg_block("NAME", name, quote=False)
srcs_block = _convert_srcs_block(srcs)
@@ -543,6 +543,8 @@
runner_args_block = _convert_string_list_block("RUNNER_ARGS", runner_args)
labels_block = _convert_string_list_block("LABELS", tags)
opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
+ target_cpu_features_block = _convert_string_arg_block(
+ "TARGET_CPU_FEATURES", target_cpu_features)
self.converter.body += (f"iree_check_single_backend_test_suite(\n"
f"{name_block}"
@@ -553,6 +555,7 @@
f"{runner_args_block}"
f"{labels_block}"
f"{opt_flags_block}"
+ f"{target_cpu_features_block}"
f")\n\n")
def iree_check_test_suite(self,
@@ -563,6 +566,7 @@
runner_args=None,
tags=None,
opt_flags=None,
+ target_cpu_features_variants=None,
**kwargs):
target_backends = None
drivers = None
@@ -580,6 +584,8 @@
runner_args_block = _convert_string_list_block("RUNNER_ARGS", runner_args)
labels_block = _convert_string_list_block("LABELS", tags)
opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
+ target_cpu_features_variants_block = _convert_string_list_block(
+ "TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants)
self.converter.body += (f"iree_check_test_suite(\n"
f"{name_block}"
@@ -590,6 +596,7 @@
f"{runner_args_block}"
f"{labels_block}"
f"{opt_flags_block}"
+ f"{target_cpu_features_variants_block}"
f")\n\n")
def iree_generated_trace_runner_test(self,
@@ -603,6 +610,7 @@
tags=None,
opt_tool=None,
opt_flags=None,
+ target_cpu_features_variants=None,
**kwargs):
target_backends = None
drivers = None
@@ -628,6 +636,8 @@
runner_args_block = _convert_string_list_block("RUNNER_ARGS", runner_args)
labels_block = _convert_string_list_block("LABELS", tags)
opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
+ target_cpu_features_variants_block = _convert_string_list_block(
+ "TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants)
self.converter.body += (f"iree_generated_trace_runner_test(\n"
f"{name_block}"
@@ -640,9 +650,9 @@
f"{runner_args_block}"
f"{labels_block}"
f"{opt_flags_block}"
+ f"{target_cpu_features_variants_block}"
f")\n\n")
-
def iree_e2e_cartesian_product_test_suite(self,
name,
matrix,
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index 4a75a94..0340ac6 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -19,7 +19,7 @@
_RULE
""
"MODULE_NAME;SRC;TARGET_BACKEND;OPT_TOOL;MODULE_FILE_NAME"
- "FLAGS;OPT_FLAGS"
+ "FLAGS;OPT_FLAGS;TARGET_CPU_FEATURES"
${ARGN}
)
@@ -37,6 +37,16 @@
list(APPEND _RULE_FLAGS "--iree-llvm-target-triple=${_TARGET_TRIPLE}")
endif()
+ if(_RULE_TARGET_CPU_FEATURES)
+ if(NOT _RULE_TARGET_BACKEND STREQUAL "dylib-llvm-aot")
+ message(SEND_ERROR "TARGET_CPU_FEATURES should be empty when \
+TARGET_BACKEND is not dylib-llvm-aot. Actual values: \
+TARGET_CPU_FEATURES=${_RULE_TARGET_CPU_FEATURES}, \
+TARGET_BACKEND=${_RULE_TARGET_BACKEND}.")
+ endif()
+ list(APPEND _RULE_FLAGS "--iree-llvm-target-cpu-features=${_RULE_TARGET_CPU_FEATURES}")
+ endif()
+
iree_bytecode_module(
NAME
"${_RULE_MODULE_NAME}"
@@ -80,6 +90,8 @@
# these flags.
# MODULE_FILE_NAME: Optional, specifies the absolute path to the filename
# to use for the generated IREE module (.vmfb).
+# TARGET_CPU_FEATURES: If specified, a string passed as argument to
+# --iree-llvm-target-cpu-features.
function(iree_check_test)
if(NOT IREE_BUILD_TESTS)
return()
@@ -109,7 +121,7 @@
_RULE
""
"NAME;SRC;TARGET_BACKEND;DRIVER;OPT_TOOL;MODULE_FILE_NAME"
- "COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+ "COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS;TARGET_CPU_FEATURES"
${ARGN}
)
@@ -139,6 +151,8 @@
${_RULE_OPT_TOOL}
OPT_FLAGS
${_RULE_OPT_FLAGS}
+ TARGET_CPU_FEATURES
+ ${_RULE_TARGET_CPU_FEATURES}
)
# iree_bytecode_module does not define a target, only a custom command.
@@ -177,6 +191,7 @@
${_RULE_RUNNER_ARGS}
LABELS
${_RULE_LABELS}
+ ${_RULE_TARGET_CPU_FEATURES}
)
endfunction()
@@ -203,6 +218,8 @@
# if OPT_FLAGS is specified.
# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
# these flags.
+# TARGET_CPU_FEATURES: If specified, a string passed as argument to
+# --iree-llvm-target-cpu-features.
function(iree_check_single_backend_test_suite)
if(NOT IREE_BUILD_TESTS)
return()
@@ -216,7 +233,7 @@
_RULE
""
"NAME;TARGET_BACKEND;DRIVER;OPT_TOOL"
- "SRCS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+ "SRCS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS;TARGET_CPU_FEATURES"
${ARGN}
)
@@ -268,10 +285,76 @@
${_RULE_OPT_TOOL}
OPT_FLAGS
${_RULE_OPT_FLAGS}
+ TARGET_CPU_FEATURES
+ ${_RULE_TARGET_CPU_FEATURES}
)
endforeach()
endfunction()
+# Helper function parsing a string occurring as an entry in TARGET_CPU_FEATURES_VARIANTS.
+#
+# This function has 3 output params, i.e. variables that it sets with PARENT_SCOPE:
+# _ENABLED, _TARGET_CPU_FEATURES, _TARGET_CPU_FEATURES_SUFFIX.
+#
+# "default" is handled specially. _ENABLED is always set to "TRUE" and
+# _TARGET_CPU_FEATURES and _TARGET_CPU_FEATURES_SUFFIX are both set to the
+# empty string.
+#
+# Other values are parsed as "arch:features", the parsed arch is matched with
+# `CMAKE_SYSTEM_PROCESSOR`, `_ENABLED` is set to "TRUE" if and only if they
+# match, and `_TARGET_CPU_FEATURES_SUFFIX` is set to a string based on the
+# features that is appropriate to include in a CMake target or test name. More
+# than one target cpu feature is currently unsupported.
+# aarch64:+dotprod -> _TARGET_CPU_FEATURES="+dotprod", _TARGET_CPU_FEATURES_SUFFIX="_dotprod"
+# default -> _TARGET_CPU_FEATURES="", _TARGET_CPU_FEATURES_SUFFIX="", _ENABLED="TRUE"
+function(process_target_cpu_features _INPUT_TARGET_CPU_FEATURES _ENABLED
+ _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX)
+ if ("${_INPUT_TARGET_CPU_FEATURES}" STREQUAL "default")
+ set(_ENABLED "TRUE" PARENT_SCOPE)
+ set(_TARGET_CPU_FEATURES "" PARENT_SCOPE)
+ set(_TARGET_CPU_FEATURES_SUFFIX "" PARENT_SCOPE)
+ return()
+ endif()
+ string(REGEX MATCHALL "[^:]+" _COMPONENTS "${_INPUT_TARGET_CPU_FEATURES}")
+ list(LENGTH _COMPONENTS _NUM_COMPONENTS)
+ if (NOT _NUM_COMPONENTS EQUAL 2)
+ message (SEND_ERROR "TARGET_CPU_FEATURES should be of the form \
+_FILTER_ARCH:_TARGET_CPU_FEATURES. Got: ${_INPUT_TARGET_CPU_FEATURES}")
+ return()
+ endif()
+  # TARGET_CPU_FEATURES_VARIANT is of the form _FILTER_ARCH:_TARGET_CPU_FEATURES.
+ list(GET _COMPONENTS 0 _FILTER_ARCH)
+ list(GET _COMPONENTS 1 _TARGET_CPU_FEATURES)
+ if (_FILTER_ARCH STREQUAL CMAKE_SYSTEM_PROCESSOR)
+ set(_ENABLED "TRUE" PARENT_SCOPE)
+ set(_TARGET_CPU_FEATURES "${_TARGET_CPU_FEATURES}" PARENT_SCOPE)
+ # TODO: the logic to generate the suffix from the list of target CPU features
+ # will need to be generalized when the lists have more than 1 element, when
+ # some features are being disabled by a "-" sign, and if some features involve
+ # any character that's not wanted in a cmake rule name.
+ # For now, let's just generate errors in those cases:
+ list(LENGTH _TARGET_CPU_FEATURES _NUM_TARGET_CPU_FEATURES)
+ if (NOT _NUM_TARGET_CPU_FEATURES EQUAL 1)
+ message(SEND_ERROR "Current limitation: \
+TARGET_CPU_FEATURES should have length 1")
+ endif()
+ string(SUBSTRING "${_TARGET_CPU_FEATURES}" 0 1 _TARGET_CPU_FEATURES_FIRST_CHAR)
+ string(SUBSTRING "${_TARGET_CPU_FEATURES}" 1 -1 _TARGET_CPU_FEATURES_AFTER_FIRST_CHAR)
+ if (NOT _TARGET_CPU_FEATURES_FIRST_CHAR STREQUAL "+")
+ message(SEND_ERROR "Current limitation: \
+TARGET_CPU_FEATURES should start with a +. Got: ${_TARGET_CPU_FEATURES}.")
+ endif()
+ if (NOT _TARGET_CPU_FEATURES_AFTER_FIRST_CHAR MATCHES "[a-zA-Z0-9_]+")
+ message(SEND_ERROR "Current limitation: \
+TARGET_CPU_FEATURES should match [a-zA-Z0-9_]+ after the initial +. \
+Got: ${_TARGET_CPU_FEATURES}.")
+ endif()
+ string(REPLACE "+" "_" _TARGET_CPU_FEATURES_SUFFIX_LOCAL "${_TARGET_CPU_FEATURES}")
+ set(_TARGET_CPU_FEATURES_SUFFIX "${_TARGET_CPU_FEATURES_SUFFIX_LOCAL}" PARENT_SCOPE)
+ else()
+ set(_ENABLED "FALSE" PARENT_SCOPE)
+ endif()
+endfunction()
# iree_check_test_suite()
#
@@ -300,6 +383,12 @@
# if OPT_FLAGS is specified.
# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
# these flags.
+# TARGET_CPU_FEATURES_VARIANTS: list of target cpu features variants. Only used
+#       for backends whose codegen varies based on the target CPU features. For each list
+# element, a separate test is created, with the list element passed as
+# argument to --iree-llvm-target-cpu-features. The special value "default"
+# is interpreted as no --iree-llvm-target-cpu-features flag to work around
+# corner cases with empty entries in CMake lists.
function(iree_check_test_suite)
if(NOT IREE_BUILD_TESTS)
return()
@@ -309,7 +398,7 @@
_RULE
""
"NAME"
- "SRCS;TARGET_BACKENDS;DRIVERS;RUNNER_ARGS;LABELS"
+ "SRCS;TARGET_BACKENDS;DRIVERS;RUNNER_ARGS;LABELS;TARGET_CPU_FEATURES_VARIANTS"
${ARGN}
)
@@ -330,26 +419,39 @@
foreach(_INDEX RANGE "${_MAX_INDEX}")
list(GET _RULE_TARGET_BACKENDS ${_INDEX} _TARGET_BACKEND)
list(GET _RULE_DRIVERS ${_INDEX} _DRIVER)
- set(_SUITE_NAME "${_RULE_NAME}_${_TARGET_BACKEND}_${_DRIVER}")
- iree_check_single_backend_test_suite(
- NAME
- ${_SUITE_NAME}
- SRCS
- ${_RULE_SRCS}
- TARGET_BACKEND
- ${_TARGET_BACKEND}
- DRIVER
- ${_DRIVER}
- COMPILER_FLAGS
- ${_RULE_COMPILER_FLAGS}
- RUNNER_ARGS
- ${_RULE_RUNNER_ARGS}
- LABELS
- ${_RULE_LABELS}
- OPT_TOOL
- ${_RULE_OPT_TOOL}
- OPT_FLAGS
- ${_RULE_OPT_FLAGS}
- )
+ if (_TARGET_BACKEND STREQUAL "dylib-llvm-aot" AND _RULE_TARGET_CPU_FEATURES_VARIANTS)
+ set(_TARGET_CPU_FEATURES_VARIANTS "${_RULE_TARGET_CPU_FEATURES_VARIANTS}")
+ else()
+ set(_TARGET_CPU_FEATURES_VARIANTS "default")
+ endif()
+ foreach(_TARGET_CPU_FEATURES_LIST_ELEM IN LISTS _TARGET_CPU_FEATURES_VARIANTS)
+ process_target_cpu_features("${_TARGET_CPU_FEATURES_LIST_ELEM}" _ENABLED _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX)
+ if (NOT _ENABLED)
+ # The current entry is disabled on the target CPU architecture.
+ continue()
+ endif()
+ iree_check_single_backend_test_suite(
+ NAME
+ "${_RULE_NAME}_${_TARGET_BACKEND}_${_DRIVER}${_TARGET_CPU_FEATURES_SUFFIX}"
+ SRCS
+ ${_RULE_SRCS}
+ TARGET_BACKEND
+ ${_TARGET_BACKEND}
+ DRIVER
+ ${_DRIVER}
+ COMPILER_FLAGS
+ ${_RULE_COMPILER_FLAGS}
+ RUNNER_ARGS
+ ${_RULE_RUNNER_ARGS}
+ LABELS
+ ${_RULE_LABELS}
+ OPT_TOOL
+ ${_RULE_OPT_TOOL}
+ OPT_FLAGS
+ ${_RULE_OPT_FLAGS}
+ TARGET_CPU_FEATURES
+ ${_TARGET_CPU_FEATURES}
+ )
+ endforeach()
endforeach()
endfunction()
diff --git a/build_tools/cmake/iree_trace_runner_test.cmake b/build_tools/cmake/iree_trace_runner_test.cmake
index 13e97fc..a2457a4 100644
--- a/build_tools/cmake/iree_trace_runner_test.cmake
+++ b/build_tools/cmake/iree_trace_runner_test.cmake
@@ -31,6 +31,8 @@
# MODULE_FILE_NAME: specifies the absolute path to the filename to use for the
# generated IREE module (.vmfb). Mandatory, unlike in iree_check_test,
# because trace files (.yaml) reference a specific module file path.
+# TARGET_CPU_FEATURES: If specified, a string passed as argument to
+# --iree-llvm-target-cpu-features.
function(iree_trace_runner_test)
if(NOT IREE_BUILD_TESTS)
return()
@@ -45,7 +47,7 @@
_RULE
""
"NAME;SRC;TRACE;TARGET_BACKEND;DRIVER;OPT_TOOL;TRACE_RUNNER;MODULE_FILE_NAME"
- "COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+ "COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS;TARGET_CPU_FEATURES"
${ARGN}
)
@@ -69,6 +71,8 @@
${_RULE_OPT_TOOL}
OPT_FLAGS
${_RULE_OPT_FLAGS}
+ TARGET_CPU_FEATURES
+ ${_RULE_TARGET_CPU_FEATURES}
)
# iree_bytecode_module does not define a target, only a custom command.
@@ -107,6 +111,7 @@
${_RULE_RUNNER_ARGS}
LABELS
${_RULE_LABELS}
+ ${_RULE_TARGET_CPU_FEATURES}
)
endfunction()
@@ -138,6 +143,8 @@
# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
# these flags.
# TRACE_RUNNER: trace-runner program to run.
+# TARGET_CPU_FEATURES: If specified, a string passed as argument to
+# --iree-llvm-target-cpu-features.
function(iree_single_backend_generated_trace_runner_test)
if(NOT IREE_BUILD_TESTS)
return()
@@ -159,7 +166,7 @@
_RULE
""
"NAME;GENERATOR;TARGET_BACKEND;DRIVER;OPT_TOOL;TRACE_RUNNER"
- "GENERATOR_ARGS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+ "GENERATOR_ARGS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS;TARGET_CPU_FEATURES"
${ARGN}
)
@@ -246,6 +253,8 @@
${_RULE_OPT_TOOL}
OPT_FLAGS
${_RULE_OPT_FLAGS}
+ TARGET_CPU_FEATURES
+ ${_RULE_TARGET_CPU_FEATURES}
)
# Note we are relying on the fact that the target created by
@@ -291,6 +300,12 @@
# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
# these flags.
# TRACE_RUNNER: trace-runner program to run.
+# TARGET_CPU_FEATURES_VARIANTS: list of target cpu features variants. Only used
+#       for backends whose codegen varies based on the target CPU features. For each list
+# element, a separate test is created, with the list element passed as
+# argument to --iree-llvm-target-cpu-features. The special value "default"
+# is interpreted as no --iree-llvm-target-cpu-features flag to work around
+# corner cases with empty entries in CMake lists.
function(iree_generated_trace_runner_test)
if(NOT IREE_BUILD_TESTS)
return()
@@ -300,7 +315,7 @@
_RULE
""
"NAME;GENERATOR;OPT_TOOL;TRACE_RUNNER"
- "TARGET_BACKENDS;DRIVERS;GENERATOR_ARGS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+ "TARGET_BACKENDS;DRIVERS;GENERATOR_ARGS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS;TARGET_CPU_FEATURES_VARIANTS"
${ARGN}
)
@@ -321,30 +336,43 @@
foreach(_INDEX RANGE "${_MAX_INDEX}")
list(GET _RULE_TARGET_BACKENDS ${_INDEX} _TARGET_BACKEND)
list(GET _RULE_DRIVERS ${_INDEX} _DRIVER)
- set(_SINGLE_BACKEND_TEST_NAME "${_RULE_NAME}_${_TARGET_BACKEND}_${_DRIVER}")
- iree_single_backend_generated_trace_runner_test(
- NAME
- ${_SINGLE_BACKEND_TEST_NAME}
- GENERATOR
- ${_RULE_GENERATOR}
- GENERATOR_ARGS
- ${_RULE_GENERATOR_ARGS}
- TRACE_RUNNER
- ${_RULE_TRACE_RUNNER}
- TARGET_BACKEND
- ${_TARGET_BACKEND}
- DRIVER
- ${_DRIVER}
- COMPILER_FLAGS
- ${_RULE_COMPILER_FLAGS}
- RUNNER_ARGS
- ${_RULE_RUNNER_ARGS}
- LABELS
- ${_RULE_LABELS}
- OPT_TOOL
- ${_RULE_OPT_TOOL}
- OPT_FLAGS
- ${_RULE_OPT_FLAGS}
- )
+ if (_TARGET_BACKEND STREQUAL "dylib-llvm-aot" AND _RULE_TARGET_CPU_FEATURES_VARIANTS)
+ set(_TARGET_CPU_FEATURES_VARIANTS "${_RULE_TARGET_CPU_FEATURES_VARIANTS}")
+ else()
+ set(_TARGET_CPU_FEATURES_VARIANTS "default")
+ endif()
+ foreach(_TARGET_CPU_FEATURES_LIST_ELEM IN LISTS _TARGET_CPU_FEATURES_VARIANTS)
+ process_target_cpu_features("${_TARGET_CPU_FEATURES_LIST_ELEM}" _ENABLED _TARGET_CPU_FEATURES _TARGET_CPU_FEATURES_SUFFIX)
+ if (NOT _ENABLED)
+ # The current entry is disabled on the target CPU architecture.
+ continue()
+ endif()
+ iree_single_backend_generated_trace_runner_test(
+ NAME
+ "${_RULE_NAME}_${_TARGET_BACKEND}_${_DRIVER}${_TARGET_CPU_FEATURES_SUFFIX}"
+ GENERATOR
+ ${_RULE_GENERATOR}
+ GENERATOR_ARGS
+ ${_RULE_GENERATOR_ARGS}
+ TRACE_RUNNER
+ ${_RULE_TRACE_RUNNER}
+ TARGET_BACKEND
+ ${_TARGET_BACKEND}
+ DRIVER
+ ${_DRIVER}
+ COMPILER_FLAGS
+ ${_RULE_COMPILER_FLAGS}
+ RUNNER_ARGS
+ ${_RULE_RUNNER_ARGS}
+ LABELS
+ ${_RULE_LABELS}
+ OPT_TOOL
+ ${_RULE_OPT_TOOL}
+ OPT_FLAGS
+ ${_RULE_OPT_FLAGS}
+ TARGET_CPU_FEATURES
+ ${_TARGET_CPU_FEATURES}
+ )
+ endforeach()
endforeach()
endfunction()
diff --git a/iree/compiler/Codegen/LLVMCPU/BUILD b/iree/compiler/Codegen/LLVMCPU/BUILD
index a24a4e2..ea350c1 100644
--- a/iree/compiler/Codegen/LLVMCPU/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/BUILD
@@ -22,7 +22,7 @@
"LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp",
"LLVMCPUUnfuseFMAOps.cpp",
"Passes.cpp",
- "VectorContractToAArch64InlineAsmOp.cpp",
+ "VectorContractCustomKernels.cpp",
],
hdrs = [
"KernelDispatch.h",
diff --git a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
index e29a47c..63674ee 100644
--- a/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -25,7 +25,7 @@
"LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp"
"LLVMCPUUnfuseFMAOps.cpp"
"Passes.cpp"
- "VectorContractToAArch64InlineAsmOp.cpp"
+ "VectorContractCustomKernels.cpp"
DEPS
IREELinalgExtDialect
IREELinalgExtPasses
diff --git a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
index ec4e284..cd3b6bf 100644
--- a/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
@@ -94,12 +94,9 @@
static DispatchLoweringPassPipeline getDispatchLoweringPassPipeline(
FuncOp entryPointFn, Operation *op) {
return TypeSwitch<Operation *, DispatchLoweringPassPipeline>(op)
- .Case<linalg::ContractionOpInterface>([&](auto op) {
+ .Case<linalg::ContractionOpInterface, linalg::Mmt4DOp>([&](auto op) {
return DispatchLoweringPassPipeline::CPUTileFuseAndVectorize;
})
- .Case<linalg::Mmt4DOp>([&](auto op) {
- return DispatchLoweringPassPipeline::CPUTensorToVectors;
- })
.Default([&](Operation *op) {
return DispatchLoweringPassPipeline::CPUDefault;
});
diff --git a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
index 35c989d..093030b 100644
--- a/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
+++ b/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileFuseAndVectorizeLinalgTensorOps.cpp
@@ -254,6 +254,15 @@
[&](linalg::LinalgOp op) { setMarker(op, getVectorizeMarker()); });
}
+ // Op specific conversion.
+ {
+ RewritePatternSet patterns(context);
+ populateLinalgToVectorVectorizeMMT4dPatterns(context, patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+
// Apply vectorization patterns.
{
OwningRewritePatternList vectorizationPatterns(&getContext());
@@ -328,6 +337,19 @@
llvm::dbgs() << "\n\n";
});
+ {
+ // Special-case vector.contract codegen paths. This needs to happen
+ // just before the generic vector ops lowerings.
+ CustomKernelsTargetInfo info;
+ if (succeeded(InferCustomKernelsTargetInfoFromParent(funcOp, info))) {
+ RewritePatternSet patterns(context);
+ populateVectorContractCustomKernelsPatterns(info, patterns);
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+ }
+
// Apply vector specific operation lowering.
{
vector::VectorTransformsOptions vectorTransformsOptions =
diff --git a/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp b/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
new file mode 100644
index 0000000..11e3bc8
--- /dev/null
+++ b/iree/compiler/Codegen/LLVMCPU/VectorContractCustomKernels.cpp
@@ -0,0 +1,368 @@
+// Copyright 2020 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/PassDetail.h"
+#include "iree/compiler/Codegen/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Triple.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+LogicalResult InferCustomKernelsTargetInfoFromParent(
+ FuncOp entryPointFn, CustomKernelsTargetInfo &target_info) {
+ // Set the out-value to defaults early so that early returns produce
+ // consistent results and so that we can write simpler code below
+ // (for loop OR-ing booleans, assuming initial 'false' value).
+ target_info = CustomKernelsTargetInfo();
+
+ // Try to find the parent ExecutableVariantOp and its relevant attributes.
+ auto variantOp =
+ entryPointFn->getParentOfType<IREE::HAL::ExecutableVariantOp>();
+ if (!variantOp) {
+ return failure();
+ }
+ IREE::HAL::ExecutableTargetAttr targetAttr = variantOp.target();
+ if (!targetAttr) {
+ return failure();
+ }
+ auto config = targetAttr.getConfiguration();
+ if (!config) {
+ return failure();
+ }
+ auto tripleAttr = config.getAs<StringAttr>("target_triple");
+ if (!tripleAttr) {
+ return failure();
+ }
+ auto cpuFeaturesAttr = config.getAs<StringAttr>("cpu_features");
+ if (!cpuFeaturesAttr) {
+ return failure();
+ }
+
+ // Set the out-value target_info fields.
+ llvm::Triple triple(tripleAttr.getValue());
+ llvm::SmallVector<llvm::StringRef> cpuFeatures;
+ cpuFeaturesAttr.getValue().split(cpuFeatures, ',');
+ switch (triple.getArch()) {
+ case llvm::Triple::ArchType::aarch64:
+ target_info.aarch64 = true;
+ for (auto f : cpuFeatures) {
+ target_info.dotprod |= (f == "+dotprod");
+ }
+ break;
+ default:
+ break;
+ }
+ return success();
+}
+
+namespace {
+
+// Returns true if `contractionOp` is of the form
+// matrix * transposed_matrix.
+// That is, if there are 2 parallel iterators, say M and N, 1 additive reduction
+// iterator, say K, and the indexing maps are {{M, K}, {N, K}, {M, N}}.
+static bool isMatrixTimesMatrixTransposed(vector::ContractionOp contractionOp) {
+ // Check that the reduction is additive.
+ if (contractionOp.kind() != vector::CombiningKind::ADD) {
+ return false;
+ }
+ // Check that there are 2 parallel and 1 reduction iterators.
+ auto iteratorTypes = contractionOp.iterator_types().getValue();
+ if (iteratorTypes.size() != 3) {
+ return false;
+ }
+ SmallVector<int, 3> parallel_iterators;
+ SmallVector<int, 3> reduction_iterators;
+ for (int i = 0; i < 3; i++) {
+ if (isParallelIterator(iteratorTypes[i])) {
+ parallel_iterators.push_back(i);
+ } else if (isReductionIterator(iteratorTypes[i])) {
+ reduction_iterators.push_back(i);
+ } else {
+ return false;
+ }
+ }
+ if (parallel_iterators.size() != 2 || reduction_iterators.size() != 1) {
+ return false;
+ }
+ // Give the found iterators some idiomatic names.
+ const int MIter = parallel_iterators[0];
+ const int NIter = parallel_iterators[1];
+ const int KIter = reduction_iterators[0];
+ // Check that there are 3 indexing maps.
+ auto indexingMaps = contractionOp.indexing_maps().getValue();
+ if (indexingMaps.size() != 3) {
+ return false;
+ }
+ // Check that the indexing maps have the expected form.
+ const int expectedMapResults[3][2] = {
+ {MIter, KIter}, {NIter, KIter}, {MIter, NIter}};
+ for (int m = 0; m < 3; ++m) {
+ auto map = indexingMaps[m].cast<AffineMapAttr>().getValue();
+ if (map.getNumDims() != 3 || map.getNumResults() != 2) {
+ return false;
+ }
+ for (int r = 0; r < 2; ++r) {
+ int actualMapResult =
+ map.getResults()[r].cast<AffineDimExpr>().getPosition();
+ if (actualMapResult != expectedMapResults[m][r]) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+// Returns true if `contractionOp` is of the form
+// matrix * transposed_matrix
+// where matrix is a vector<{mSize}x{kSize}xType>, and
+// transposed_matrix is a vector<{nSize}x{kSize}xType>
+static bool isMatrixTimesMatrixTransposedOfGivenShape(
+ vector::ContractionOp contractionOp, int64_t mSize, int64_t kSize,
+ int64_t nSize) {
+ if (!isMatrixTimesMatrixTransposed(contractionOp)) {
+ return false;
+ }
+ VectorType lhsType = contractionOp.lhs().getType().cast<VectorType>();
+ VectorType rhsType = contractionOp.rhs().getType().cast<VectorType>();
+ auto lhsShape = lhsType.getShape();
+ auto rhsShape = rhsType.getShape();
+ if (lhsShape[0] != mSize || lhsShape[1] != kSize || rhsShape[0] != nSize ||
+ rhsShape[1] != kSize) {
+ return false;
+ }
+ return true;
+}
+
+// Checks that the Value `extResult` is defined by an arith::ExtSIOp promoting
+// from `extSrcType` to `extDstType`. Returns the ExtSIOp's input, or nullptr.
+static Value getExtSIInput(Type extSrcType, Type extDstType, Value extResult) {
+ auto extSIOp = extResult.getDefiningOp<arith::ExtSIOp>();
+ if (!extSIOp) {
+ return nullptr;
+ }
+ Value extInput = extSIOp.getIn();
+ if (extInput.getType().cast<VectorType>().getElementType() != extSrcType) {
+ return nullptr;
+ }
+ return extInput;
+}
+
+// Helper to create a 1D, contiguous slice of a 1D vector.
+static Value extract1DSlice(PatternRewriter &rewriter, Location loc,
+ VectorType dstVecType, Value input, int position) {
+ assert(input.getType().cast<VectorType>().getRank() == 1);
+ assert(dstVecType.getRank() == 1);
+ std::array<int64_t, 1> offsets{position};
+ std::array<int64_t, 1> strides{1};
+ return rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, input, offsets, dstVecType.getShape(), strides);
+}
+
+// Helper to flatten an N-dimensional vector to a 1D vector.
+static Value flatten(PatternRewriter &rewriter, Location loc, Value vector) {
+ VectorType inputVecType = vector.getType().cast<VectorType>();
+ VectorType dstType = VectorType::get(inputVecType.getNumElements(),
+ inputVecType.getElementType());
+ return rewriter.create<vector::ShapeCastOp>(loc, dstType, vector);
+}
+
+/// Converts matrix-times-matrix-transposed vector.contracts with
+/// lhs and rhs inputs defined by arith.extsi promoting from i8 to i32,
+///
+/// %lhs_i32 = arith.extsi %lhs_i8 : i8 to i32
+/// %rhs_i32 = arith.extsi %rhs_i8 : i8 to i32
+/// %result = vector.contract [...]
+/// %lhs_i32 : vector<8x4xi32>,
+/// %rhs_i32 : vector<8x4xi32>,
+/// %acc_i32 : vector<8x8xi32>,
+/// [...]
+///
+/// To vector ops reading directly from the %lhs_i8 and %rhs_i8 values
+/// (bypassing the existing arith.extsi) and passing that to a llvm.inline_asm
+/// block implementing the matrix multiplication arithmetic using AArch64
+/// dot-product instructions (sdot).
+struct MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm
+ : OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractionOp,
+ PatternRewriter &rewriter) const override {
+ // Check if `contractionOp` matches, and obtain the un-promoted i8 input
+ // LHS and RHS vectors, `lhsI8` and `rhsI8`.
+ if (!isMatrixTimesMatrixTransposedOfGivenShape(contractionOp, 8, 4, 8)) {
+ return failure();
+ }
+ Type I8Type = rewriter.getIntegerType(8);
+ Type I32Type = rewriter.getIntegerType(32);
+ VectorType accType = contractionOp.acc().getType().cast<VectorType>();
+ if (accType.getElementType() != I32Type) {
+ return failure();
+ }
+ Value lhsI8 = getExtSIInput(I8Type, I32Type, contractionOp.lhs());
+ Value rhsI8 = getExtSIInput(I8Type, I32Type, contractionOp.rhs());
+ if (!lhsI8 || !rhsI8) {
+ return failure();
+ }
+
+    // `contractionOp` matches; start rewriting it. We only reference the
+    // `lhsI8` and `rhsI8` values obtained above, i.e. the inputs of the
+    // arith.extsi ops, so this rewrite leaves those arith.extsi ops without
+    // any users (unless something else was using them), allowing them to be
+    // removed by a later transformation.
+ Location loc = contractionOp.getLoc();
+ // Flatten the inputs to 1D vectors.
+ Value flatLhsI8 = flatten(rewriter, loc, lhsI8);
+ Value flatRhsI8 = flatten(rewriter, loc, rhsI8);
+ Value flatAcc = flatten(rewriter, loc, contractionOp.acc());
+
+ // Create the 1D input vectors of 16 bytes each that are directly what
+ // the target SIMD instructions will want.
+ SmallVector<Value> lhsVec;
+ SmallVector<Value> rhsVec;
+ VectorType vector16xi8Type = VectorType::get({16}, I8Type);
+ for (int position = 0; position < 8 * 4; position += 16) {
+ lhsVec.push_back(
+ extract1DSlice(rewriter, loc, vector16xi8Type, flatLhsI8, position));
+ rhsVec.push_back(
+ extract1DSlice(rewriter, loc, vector16xi8Type, flatRhsI8, position));
+ }
+ SmallVector<Value> accVec;
+ VectorType int32x4Type = VectorType::get({4}, I32Type);
+ for (int position = 0; position < 8 * 8; position += 4) {
+ accVec.push_back(
+ extract1DSlice(rewriter, loc, int32x4Type, flatAcc, position));
+ }
+
+ // Start of the code that's specific to inline assembly. An intrinsics
+ // code path would diverge here.
+
+ // Create the inline asm op's operands list.
+ SmallVector<Value> asmOperands;
+    // First the input operands.
+ asmOperands.append(lhsVec);
+ asmOperands.append(rhsVec);
+ // Then the input-output operands.
+ asmOperands.append(accVec);
+ SmallVector<Type> asmOutputOperandTypes(
+ llvm::map_range(accVec, [](Value v) { return v.getType(); }));
+
+ // Create the inline asm op.
+ auto returnType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
+ asmOutputOperandTypes);
+ auto dialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+ LLVM::AsmDialect::AD_ATT);
+ // The LLVM inline asm syntax is documented here:
+ // https://llvm.org/docs/LangRef.html#inline-assembler-expressions
+ LLVM::InlineAsmOp asmOp = rewriter.create<LLVM::InlineAsmOp>(
+ loc, returnType, asmOperands,
+ R"ASM(
+ sdot $0.4s, $18.16b, $16.4b[0]
+ sdot $1.4s, $19.16b, $16.4b[0]
+ sdot $2.4s, $18.16b, $16.4b[1]
+ sdot $3.4s, $19.16b, $16.4b[1]
+ sdot $4.4s, $18.16b, $16.4b[2]
+ sdot $5.4s, $19.16b, $16.4b[2]
+ sdot $6.4s, $18.16b, $16.4b[3]
+ sdot $7.4s, $19.16b, $16.4b[3]
+ sdot $8.4s, $18.16b, $17.4b[0]
+ sdot $9.4s, $19.16b, $17.4b[0]
+ sdot $10.4s, $18.16b, $17.4b[1]
+ sdot $11.4s, $19.16b, $17.4b[1]
+ sdot $12.4s, $18.16b, $17.4b[2]
+ sdot $13.4s, $19.16b, $17.4b[2]
+ sdot $14.4s, $18.16b, $17.4b[3]
+ sdot $15.4s, $19.16b, $17.4b[3]
+ )ASM",
+ "=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,=w,w,w,w,w,0,1,2,3,4,5,6,"
+ "7,8,9,10,11,12,13,14,15",
+ /*has_side_effects=*/false, /*is_align_stack=*/false, dialectAttr);
+
+ // Extract result vectors from the asm op.
+ SmallVector<Value, 16> resVec;
+ for (int i = 0; i < 16; ++i) {
+ resVec.push_back(rewriter.create<LLVM::ExtractValueOp>(
+ loc, int32x4Type, asmOp.getRes(), rewriter.getI64ArrayAttr({i})));
+ }
+
+ // End of the code that's specific to inline assembly. An intrinsics code
+ // path would merge here.
+
+ // Insert the result vectors of size 4 into the overall result vector of
+ // size 64, still 1D.
+ VectorType int32x64xType = VectorType::get({64}, I32Type);
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, int32x64xType, DenseIntElementsAttr::get(int32x64xType, 0));
+ for (int i = 0; i < 16; ++i) {
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, resVec[i], result, std::array<int64_t, 1>{4 * i},
+ std::array<int64_t, 1>{1});
+ }
+
+ // Cast the result from 1D to 2D and replace the original vector.contract.
+ VectorType int32x8x8xType = VectorType::get({8, 8}, I32Type);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(contractionOp,
+ int32x8x8xType, result);
+ return success();
+ }
+};
+
+class VectorContractCustomKernelsPass
+ : public VectorContractCustomKernelsBase<VectorContractCustomKernelsPass> {
+ public:
+  void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect, LLVM::LLVMDialect>();
+ }
+ LogicalResult initializeOptions(StringRef options) override {
+ if (failed(Pass::initializeOptions(options))) {
+ return failure();
+ }
+ target_info.aarch64 = aarch64;
+ target_info.dotprod = dotprod;
+ return success();
+ }
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ OwningRewritePatternList patterns(context);
+ populateVectorContractCustomKernelsPatterns(target_info, patterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns)))) {
+ signalPassFailure();
+ }
+ }
+
+ private:
+ CustomKernelsTargetInfo target_info;
+};
+
+} // namespace
+
+void populateVectorContractCustomKernelsPatterns(
+ const CustomKernelsTargetInfo &target_info,
+ OwningRewritePatternList &patterns) {
+ MLIRContext *context = patterns.getContext();
+ if (target_info.aarch64 && target_info.dotprod) {
+ patterns.insert<MMT_8x4x8_i8i8i32_Aarch64Dotprod_InlineAsm>(context);
+ }
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createVectorContractCustomKernelsPass() {
+ return std::make_unique<VectorContractCustomKernelsPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Codegen/LLVMCPU/VectorContractToAArch64InlineAsmOp.cpp b/iree/compiler/Codegen/LLVMCPU/VectorContractToAArch64InlineAsmOp.cpp
deleted file mode 100644
index 3e06019..0000000
--- a/iree/compiler/Codegen/LLVMCPU/VectorContractToAArch64InlineAsmOp.cpp
+++ /dev/null
@@ -1,173 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/PassDetail.h"
-#include "iree/compiler/Codegen/Passes.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-namespace iree_compiler {
-
-namespace {
-
-/// Converts 4x4x4 vector contraction with matmul(A_, B) semantics to AArch64
-/// inline assembly using aarch64 4-sdot instructions. Each sdot instruction
-/// performas a single matrix-vector product and to compute matmul(A, B) with
-/// matrix-vector products B is transposed.
-struct ConvertVectorContract4x4x4_i8i8i32_ToAArch64InlineAsmPattern
- : public OpRewritePattern<vector::ContractionOp> {
- public:
- using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ContractionOp contractionOp,
- PatternRewriter &rewriter) const override {
- auto lhsType = contractionOp.lhs().getType().cast<VectorType>();
- auto rhsType = contractionOp.rhs().getType().cast<VectorType>();
- auto accType = contractionOp.acc().getType().cast<VectorType>();
- auto lhsShape = lhsType.getShape();
- auto rhsShape = rhsType.getShape();
- if (lhsShape[0] != 4 || lhsShape[1] != 4 || rhsShape[0] != 4 ||
- rhsShape[1] != 4) {
- return failure();
- }
-
- Value inLhs = contractionOp.lhs();
- Value inRhs = contractionOp.rhs();
-
- auto I8Type = rewriter.getIntegerType(8);
- auto I32Type = rewriter.getIntegerType(32);
-
- if (accType.getElementType() != I32Type) {
- return failure();
- }
-
- auto getI8Value = [&](Value v) -> Value {
- if (auto parentOp = v.getDefiningOp<arith::ExtSIOp>()) {
- if (parentOp.getIn().getType().cast<VectorType>().getElementType() !=
- I8Type) {
- return nullptr;
- } else {
- return parentOp.getIn();
- }
- }
- return nullptr;
- };
- if (lhsType.getElementType() != I8Type) {
- inLhs = getI8Value(inLhs);
- }
- if (rhsType.getElementType() != I8Type) {
- inRhs = getI8Value(inRhs);
- }
-
- if (!inLhs || !inRhs) return failure();
-
- auto loc = contractionOp.getLoc();
-
- SmallVector<Value> dstVec;
- for (int i = 0; i < 4; ++i) {
- dstVec.push_back(
- rewriter.create<vector::ExtractOp>(loc, contractionOp.acc(), i));
- }
-
- auto flattnedVectorType = VectorType::get({16}, I8Type);
-
- auto lhs =
- rewriter.create<vector::ShapeCastOp>(loc, flattnedVectorType, inLhs);
-
- auto inRhsTransposed = rewriter.create<vector::TransposeOp>(
- loc, inRhs, ArrayRef<int64_t>({1, 0}));
-
- auto rhs = rewriter.create<vector::ShapeCastOp>(loc, flattnedVectorType,
- inRhsTransposed);
-
- auto int32x4VType = VectorType::get({4}, I32Type);
-
- auto returnType = LLVM::LLVMStructType::getLiteral(
- rewriter.getContext(),
- {int32x4VType, int32x4VType, int32x4VType, int32x4VType});
-
- /// TODO(ataei): We an have a formatter like the c++ inline asm to
- /// ssa-values to string names which will make the inline-assembly
- /// statements more redable e.g :
- /// sdot ${dstVec_0}.4s, ${lhs}.16b,${rhs}.4b[0]
- auto packedResult = rewriter.create<LLVM::InlineAsmOp>(
- loc, returnType,
- ArrayRef<Value>({lhs, rhs, dstVec[0], dstVec[1], dstVec[2], dstVec[3]}),
- R"ASM(
- sdot $0.4s, $4.16b, $5.4b[0]
- sdot $1.4s, $4.16b, $5.4b[1]
- sdot $2.4s, $4.16b, $5.4b[2]
- sdot $3.4s, $4.16b, $5.4b[3]
- )ASM",
- "=w,=w,=w,=w,w,w,0,1,2,3", false, false,
- LLVM::AsmDialectAttr::get(rewriter.getContext(),
- LLVM::AsmDialect::AD_ATT));
-
- auto resVec =
- llvm::to_vector<4>(llvm::map_range(llvm::seq<int>(0, 4), [&](int i) {
- return rewriter.create<LLVM::ExtractValueOp>(
- loc, int32x4VType, packedResult.getRes(),
- rewriter.getI64ArrayAttr({i}));
- }));
-
- auto int32x4x4xVType = VectorType::get({4, 4}, I32Type);
-
- Value result;
- result = rewriter.create<arith::ConstantOp>(
- loc, int32x4x4xVType, DenseIntElementsAttr::get(int32x4x4xVType, 0));
- for (int i = 0; i < 4; ++i) {
- result = rewriter.create<vector::InsertOp>(loc, resVec[i], result,
- ArrayRef<int64_t>({i}));
- }
- rewriter.replaceOp(contractionOp, {result});
- return success();
- }
-};
-
-} // namespace
-
-namespace {
-struct VectorToAArch64InlineAsmPass
- : public VectorToAArch64InlineAsmBase<VectorToAArch64InlineAsmPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect, LLVM::LLVMDialect>();
- }
- void runOnOperation() override;
-};
-} // namespace
-
-void populateVectorContractToAArch64InlineAsm(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<ConvertVectorContract4x4x4_i8i8i32_ToAArch64InlineAsmPattern>(
- context);
-}
-
-void VectorToAArch64InlineAsmPass::runOnOperation() {
- MLIRContext *context = &getContext();
- OwningRewritePatternList patterns(context);
- populateVectorContractToAArch64InlineAsm(patterns, context);
-
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
- signalPassFailure();
- }
-}
-
-std::unique_ptr<OperationPass<FuncOp>>
-createVectorToAArch64InlineAssemblyPass() {
- return std::make_unique<VectorToAArch64InlineAsmPass>();
-}
-
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/iree/compiler/Codegen/LLVMCPU/test/BUILD b/iree/compiler/Codegen/LLVMCPU/test/BUILD
index e267572..511c649 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/BUILD
+++ b/iree/compiler/Codegen/LLVMCPU/test/BUILD
@@ -31,7 +31,7 @@
"tile_and_vectorize.mlir",
"tile_fuse_and_vectorize.mlir",
"unfused_fma.mlir",
- "vector_contract_to_aarch64_asm.mlir",
+ "vector_contract_custom_kernels.mlir",
],
include = ["*.mlir"],
),
diff --git a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index badf0ec..742023d 100644
--- a/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -26,7 +26,7 @@
"tile_and_vectorize.mlir"
"tile_fuse_and_vectorize.mlir"
"unfused_fma.mlir"
- "vector_contract_to_aarch64_asm.mlir"
+ "vector_contract_custom_kernels.mlir"
DATA
FileCheck
iree::tools::iree-opt
diff --git a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_custom_kernels.mlir b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_custom_kernels.mlir
new file mode 100644
index 0000000..db29930
--- /dev/null
+++ b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_custom_kernels.mlir
@@ -0,0 +1,85 @@
+// RUN: iree-opt -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod' %s | FileCheck %s
+
+func @mmt_8x4x8_i8i8i32_aarch64_dotprod_inline_asm(
+ %lhs: vector<8x4xi8>,
+ %rhs: vector<8x4xi8>,
+ %acc: vector<8x8xi32>) -> vector<8x8xi32> {
+ %lhs_wide = arith.extsi %lhs : vector<8x4xi8> to vector<8x4xi32>
+ %rhs_wide = arith.extsi %rhs : vector<8x4xi8> to vector<8x4xi32>
+ %res = vector.contract {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>
+ ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+ } %lhs_wide, %rhs_wide, %acc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+ return %res : vector<8x8xi32>
+}
+// CHECK-LABEL: func @mmt_8x4x8_i8i8i32_aarch64_dotprod_inline_asm(
+// CHECK-SAME: %[[LHS:[^:[:space:]]+]]
+// CHECK-SAME: %[[RHS:[^:[:space:]]+]]
+// CHECK-SAME: %[[ACC:[^:[:space:]]+]]
+// CHECK-SAME: -> vector<8x8xi32> {
+// CHECK-DAG: %[[INITRES:.+]] = arith.constant dense<0> : vector<64xi32>
+// CHECK-DAG: %[[LHS1D:.+]] = vector.shape_cast %[[LHS]] : vector<8x4xi8> to vector<32xi8>
+// CHECK-DAG: %[[RHS1D:.+]] = vector.shape_cast %[[RHS]] : vector<8x4xi8> to vector<32xi8>
+// CHECK-DAG: %[[ACC1D:.+]] = vector.shape_cast %[[ACC]] : vector<8x8xi32> to vector<64xi32>
+// CHECK-DAG: %[[LHS1D_0:.+]] = vector.extract_strided_slice %[[LHS1D]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xi8> to vector<16xi8>
+// CHECK-DAG: %[[RHS1D_0:.+]] = vector.extract_strided_slice %[[RHS1D]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xi8> to vector<16xi8>
+// CHECK-DAG: %[[LHS1D_1:.+]] = vector.extract_strided_slice %[[LHS1D]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xi8> to vector<16xi8>
+// CHECK-DAG: %[[RHS1D_1:.+]] = vector.extract_strided_slice %[[RHS1D]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xi8> to vector<16xi8>
+// CHECK-DAG: %[[ACC1D_0:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [0], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_1:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [4], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_2:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [8], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_3:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [12], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_4:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [16], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_5:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [20], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_6:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [24], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_7:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [28], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_8:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [32], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_9:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [36], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_10:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [40], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_11:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [44], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_12:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [48], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_13:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [52], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_14:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [56], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ACC1D_15:.+]] = vector.extract_strided_slice %[[ACC1D]] {offsets = [60], sizes = [4], strides = [1]} : vector<64xi32> to vector<4xi32>
+// CHECK-DAG: %[[ASM:.+]] = llvm.inline_asm asm_dialect = att
+// CHECK-SAME: "{{([\0A ]+sdot \$[0-9]+\.4s, \$[0-9]+\.16b, \$[0-9]+\.4b[[0-9]+]){16}[\0A ]+}}"
+// CHECK-SAME: "{{(\=w,){16}(w,){4}0,1,.*,15}}"
+// CHECK-SAME: {{\((vector<16xi8>, ){4}(vector<4xi32>(, )?){16}\)}}
+// CHECK-SAME: -> !llvm.struct<({{((vector<4xi32>(, )?){16})}})>
+// CHECK-DAG: %[[RES0:.+]] = llvm.extractvalue %[[ASM]][0]
+// CHECK-DAG: %[[RES1:.+]] = llvm.extractvalue %[[ASM]][1]
+// CHECK-DAG: %[[RES2:.+]] = llvm.extractvalue %[[ASM]][2]
+// CHECK-DAG: %[[RES3:.+]] = llvm.extractvalue %[[ASM]][3]
+// CHECK-DAG: %[[RES4:.+]] = llvm.extractvalue %[[ASM]][4]
+// CHECK-DAG: %[[RES5:.+]] = llvm.extractvalue %[[ASM]][5]
+// CHECK-DAG: %[[RES6:.+]] = llvm.extractvalue %[[ASM]][6]
+// CHECK-DAG: %[[RES7:.+]] = llvm.extractvalue %[[ASM]][7]
+// CHECK-DAG: %[[RES8:.+]] = llvm.extractvalue %[[ASM]][8]
+// CHECK-DAG: %[[RES9:.+]] = llvm.extractvalue %[[ASM]][9]
+// CHECK-DAG: %[[RES10:.+]] = llvm.extractvalue %[[ASM]][10]
+// CHECK-DAG: %[[RES11:.+]] = llvm.extractvalue %[[ASM]][11]
+// CHECK-DAG: %[[RES12:.+]] = llvm.extractvalue %[[ASM]][12]
+// CHECK-DAG: %[[RES13:.+]] = llvm.extractvalue %[[ASM]][13]
+// CHECK-DAG: %[[RES14:.+]] = llvm.extractvalue %[[ASM]][14]
+// CHECK-DAG: %[[RES15:.+]] = llvm.extractvalue %[[ASM]][15]
+// CHECK: %[[INS0:.+]] = vector.insert_strided_slice %[[RES0]], %[[INITRES]] {offsets = [0], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS1:.+]] = vector.insert_strided_slice %[[RES1]], %[[INS0]] {offsets = [4], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS2:.+]] = vector.insert_strided_slice %[[RES2]], %[[INS1]] {offsets = [8], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS3:.+]] = vector.insert_strided_slice %[[RES3]], %[[INS2]] {offsets = [12], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS4:.+]] = vector.insert_strided_slice %[[RES4]], %[[INS3]] {offsets = [16], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS5:.+]] = vector.insert_strided_slice %[[RES5]], %[[INS4]] {offsets = [20], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS6:.+]] = vector.insert_strided_slice %[[RES6]], %[[INS5]] {offsets = [24], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS7:.+]] = vector.insert_strided_slice %[[RES7]], %[[INS6]] {offsets = [28], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS8:.+]] = vector.insert_strided_slice %[[RES8]], %[[INS7]] {offsets = [32], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS9:.+]] = vector.insert_strided_slice %[[RES9]], %[[INS8]] {offsets = [36], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS10:.+]] = vector.insert_strided_slice %[[RES10]], %[[INS9]] {offsets = [40], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS11:.+]] = vector.insert_strided_slice %[[RES11]], %[[INS10]] {offsets = [44], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS12:.+]] = vector.insert_strided_slice %[[RES12]], %[[INS11]] {offsets = [48], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS13:.+]] = vector.insert_strided_slice %[[RES13]], %[[INS12]] {offsets = [52], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS14:.+]] = vector.insert_strided_slice %[[RES14]], %[[INS13]] {offsets = [56], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[INS15:.+]] = vector.insert_strided_slice %[[RES15]], %{{.+}} {offsets = [60], strides = [1]} : vector<4xi32> into vector<64xi32>
+// CHECK: %[[RESULT1D:.+]] = vector.shape_cast %[[INS15]] : vector<64xi32> to vector<8x8xi32>
+// CHECK: return %[[RESULT1D]] : vector<8x8xi32>
diff --git a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_aarch64_asm.mlir b/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_aarch64_asm.mlir
deleted file mode 100644
index 3277e84..0000000
--- a/iree/compiler/Codegen/LLVMCPU/test/vector_contract_to_aarch64_asm.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: iree-opt -iree-llvmcpu-vector-to-aarch64-inline-asm %s | FileCheck %s
-
-// CHECK-LABEL: @vector_i8i8i32matmul_to_aarch64_asm_vec_dot(
-func @vector_i8i8i32matmul_to_aarch64_asm_vec_dot(
- // CHECK-SAME: %[[LHS:[a-zA-Z0-9_]+]]
- %lhs: vector<4x4xi8>,
- // CHECK-SAME: %[[RHS:[a-zA-Z0-9_]+]]
- %rhs: vector<4x4xi8>,
- // CHECK-SAME: %[[ACC:[a-zA-Z0-9_]+]]
- %acc: vector<4x4xi32>) -> vector<4x4xi32> {
- %lhs_wide = arith.extsi %lhs : vector<4x4xi8> to vector<4x4xi32>
- %rhs_wide = arith.extsi %rhs : vector<4x4xi8> to vector<4x4xi32>
- // CHECK-DAG: %[[RES_2D:.+]] = arith.constant dense<0> : vector<4x4xi32>
- // CHECK-DAG: %[[DST0:.+]] = vector.extract %[[ACC]][0] : vector<4x4xi32>
- // CHECK-DAG: %[[DST1:.+]] = vector.extract %[[ACC]][1] : vector<4x4xi32>
- // CHECK-DAG: %[[DST2:.+]] = vector.extract %[[ACC]][2] : vector<4x4xi32>
- // CHECK-DAG: %[[DST3:.+]] = vector.extract %[[ACC]][3] : vector<4x4xi32>
- // CHECK-DAG: %[[LHS_1D:.+]] = vector.shape_cast %[[LHS]] : vector<4x4xi8> to vector<16xi8>
- // CHECK-DAG: %[[RHS_T_2d:.+]] = vector.transpose %[[RHS]], [1, 0]
- // CHECK-DAG: %[[RHS_T:.+]] = vector.shape_cast %[[RHS_T_2d]] : vector<4x4xi8> to vector<16xi8>
- // CHECK: %[[ASM_RESULT:.+]] = llvm.inline_asm {{.*}} "=w,=w,=w,=w,w,w,0,1,2,3" %[[LHS_1D]], %[[RHS_T]], %[[DST0]], %[[DST1]], %[[DST2]], %[[DST3]]
- // CHECK-DAG: %[[RES_0:.+]] = llvm.extractvalue %[[ASM_RESULT]][0] : !llvm.struct<(vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>)>
- // CHECK-DAG: %[[RES_1:.+]] = llvm.extractvalue %[[ASM_RESULT]][1] : !llvm.struct<(vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>)>
- // CHECK-DAG: %[[RES_2:.+]] = llvm.extractvalue %[[ASM_RESULT]][2] : !llvm.struct<(vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>)>
- // CHECK-DAG: %[[RES_3:.+]] = llvm.extractvalue %[[ASM_RESULT]][3] : !llvm.struct<(vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>)>
- // CHECK-DAG: %[[RES_2D_0:.+]] = vector.insert %[[RES_0]], %[[RES_2D]] [0] : vector<4xi32> into vector<4x4xi32>
- // CHECK-DAG: %[[RES_2D_1:.+]] = vector.insert %[[RES_1]], %[[RES_2D_0]] [1] : vector<4xi32> into vector<4x4xi32>
- // CHECK-DAG: %[[RES_2D_2:.+]] = vector.insert %[[RES_2]], %[[RES_2D_1]] [2] : vector<4xi32> into vector<4x4xi32>
- // CHECK-DAG: %[[RES_2D_3:.+]] = vector.insert %[[RES_3]], %[[RES_2D_2]] [3] : vector<4xi32> into vector<4x4xi32>
- %res = vector.contract {
- indexing_maps = [
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0, d1)>
- ], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
- } %lhs_wide, %rhs_wide, %acc : vector<4x4xi32>, vector<4x4xi32> into vector<4x4xi32>
- // CHECK: return %[[RES_2D_3]]
- return %res : vector<4x4xi32>
-}
diff --git a/iree/compiler/Codegen/Passes.h b/iree/compiler/Codegen/Passes.h
index 93a80f7..49fc0c8 100644
--- a/iree/compiler/Codegen/Passes.h
+++ b/iree/compiler/Codegen/Passes.h
@@ -171,18 +171,42 @@
/// Replaces llvm.intr.fma with its unfused mul and add ops.
std::unique_ptr<OperationPass<FuncOp>> createLLVMCPUUnfuseFMAOpsPass();
-/// A pass that converts vector dialect operations to inline assembly
-std::unique_ptr<OperationPass<FuncOp>>
-createVectorToAArch64InlineAssemblyPass();
+/// A pass that converts certain vector.contract ops to custom kernels.
+std::unique_ptr<OperationPass<FuncOp>> createVectorContractCustomKernelsPass();
//------------------------------------------------------------------------------
// LLVMCPU Codegen specific patterns.
//------------------------------------------------------------------------------
-/// Populates `patterns` to convert vector.contract op to a sequence
-/// of AArch64 inline assembly operations.
-void populateVectorContractToAArch64InlineAsm(
- OwningRewritePatternList &patterns, MLIRContext *context);
+// Some codegen patterns need to know target CPU information. They can receive
+// such information by means of this struct, which can be populated from either
+// pass options (e.g. in lit tests,
+// -iree-llvmcpu-vector-contract-custom-kernels='aarch64 dotprod')
+// or from the parent op's attributes (see InferCustomKernelsTargetInfoFromParent below).
+//
+// It would be interesting to find an opportunity to de-duplicate this with
+// other data structures containing similar information, but a difficulty here
+// is that in the case of lit tests, where we need to populate this from
+// a minimal set of custom boolean options passed to a pass such as
+// -iree-llvmcpu-vector-contract-custom-kernels, we do not have enough
+// information to populate all the other fields of existing, larger data
+// structures. That's the motivation for this custom, minimal struct.
+struct CustomKernelsTargetInfo {
+  // Indicates that the target ISA is AArch64.
+ bool aarch64 = false;
+ // Under aarch64: indicates dot-product extension (SDOT, UDOT)
+ bool dotprod = false;
+};
+
+// Populate target_info fields from the parent HAL::ExecutableVariantOp.
+LogicalResult InferCustomKernelsTargetInfoFromParent(
+ FuncOp entryPointFn, CustomKernelsTargetInfo &target_info);
+
+/// Populates `patterns` to convert certain vector.contract ops to special
+/// "kernels" written either in SIMD intrinsics or inline assembly.
+void populateVectorContractCustomKernelsPatterns(
+ const CustomKernelsTargetInfo &target_info,
+ OwningRewritePatternList &patterns);
void populateUnfusedFMAOpsPassPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
diff --git a/iree/compiler/Codegen/Passes.td b/iree/compiler/Codegen/Passes.td
index 6eb5586..9d67ecd 100644
--- a/iree/compiler/Codegen/Passes.td
+++ b/iree/compiler/Codegen/Passes.td
@@ -159,10 +159,18 @@
let constructor = "mlir::iree_compiler::createLLVMCPUUnfuseFMAOpsPass()";
}
-def VectorToAArch64InlineAsm :
- Pass<"iree-llvmcpu-vector-to-aarch64-inline-asm", "FuncOp"> {
- let summary = "Convert vector operations to aarch64 inline asm LLVMIR dialect";
- let constructor = "mlir::iree_compiler::createVectorToAArch64InlineAssemblyPass()";
+def VectorContractCustomKernels :
+ Pass<"iree-llvmcpu-vector-contract-custom-kernels", "FuncOp"> {
+ let summary = "Enable custom kernels (inline assembly or intrinsics) for some vector.contract ops";
+ let constructor = "mlir::iree_compiler::createVectorContractCustomKernelsPass()";
+ let options = [
+ Option<"aarch64", "aarch64", "bool",
+ /*default=*/"false",
+ "Enable aarch64 kernels">,
+ Option<"dotprod", "dotprod", "bool",
+ /*default=*/"false",
+ "Under aarch64, enable kernels that use dotprod instructions">,
+ ];
}
//------------------------------------------------------------------------------
diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
index 6257157..b0cb0b0 100644
--- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
@@ -660,6 +660,10 @@
addConfig("native_vector_size",
IntegerAttr::get(IndexType::get(context), config_.vectorSize));
+ // Set target CPU features.
+ addConfig("cpu_features",
+ StringAttr::get(context, options_.targetCPUFeatures));
+
return IREE::HAL::ExecutableTargetAttr::get(
context, StringAttr::get(context, "llvm"),
StringAttr::get(context, format), DictionaryAttr::get(context, config));
@@ -700,9 +704,9 @@
LLVMTargetOptions options_;
- // Configuration to be set on each `hal.executable.variant` that only depend
- // on the `options_`.
- struct ConfigurationValues {
+  // Additional target information beyond what is contained in
+  // LLVMTargetOptions options_.
+ struct AdditionalConfigurationValues {
std::string dataLayoutStr;
int64_t vectorSize;
} config_;
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index c184213..d2b0e4a 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -143,6 +143,10 @@
("dylib-llvm-aot", "dylib"),
("vmvx", "vmvx"),
],
+ target_cpu_features_variants = [
+ "default",
+ "aarch64:+dotprod",
+ ],
trace_runner = "//iree/tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"i8",
@@ -163,6 +167,10 @@
("dylib-llvm-aot", "dylib"),
# TODO: enable VMVX. Skipped for now: it's very slow for these large matmul tests.
],
+ target_cpu_features_variants = [
+ "default",
+ "aarch64:+dotprod",
+ ],
trace_runner = "//iree/tools:iree-e2e-matmul-test",
) for lhs_rhs_type in [
"i8",
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index 5c20a61..914c9ce 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -178,6 +178,9 @@
"vmvx"
OPT_FLAGS
"--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "aarch64:+dotprod"
)
iree_generated_trace_runner_test(
@@ -198,6 +201,9 @@
"vmvx"
OPT_FLAGS
"--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "aarch64:+dotprod"
)
iree_generated_trace_runner_test(
@@ -216,6 +222,9 @@
"dylib"
OPT_FLAGS
"--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "aarch64:+dotprod"
)
iree_generated_trace_runner_test(
@@ -234,6 +243,9 @@
"dylib"
OPT_FLAGS
"--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+ "aarch64:+dotprod"
)
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###