end-to-end matmul tests (#7347)
end-to-end matmul tests, meaning: we generate input matrices, perform linalg.matmul's, check that the results are correct.
In an earlier attempt (PR #7154) the test was entirely generated code. That approach was abandoned because it potentially amounted to comparing the compiler's output with itself.
The new approach here is to have a custom C test runner that has its own C reference implementation of matmul and uses it to compare the output of the generated code, which it drives by means of replay-trace. So now each test is not just a .mlir source, but a (.mlir, .yaml) pair where the yaml file is a trace, determining the runtime values of the inputs of the code being tested. Both .mlir and .yaml files are co-generated by a Python script.
A new CMake function iree_trace_runner_test is added. It is a variant of iree_check_test, taking a custom trace runner and trace file, instead of iree_check_test's behavior of just running the module in iree-check-module.
A new CMake function iree_generated_trace_runner_test is added to call the Python script to co-generate the .yaml and the .mlir source and call iree_trace_runner_test on the resulting files.
One aspect of PR #7154 is retained here: all the above-mentioned CMake functions are generalized also to allow running iree-opt (or an alternate tool) as a pre-processing step on the source .mlir file, before running iree-translate. That is because by default, new MLIR transformations are not necessarily known to iree-translate. In our case, that's the mmt4d-related transformations.
The tests added in this PR share the prefix e2e_matmul (that can be used with ctest -R).
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 231833c..36b05b6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -266,6 +266,7 @@
include(iree_lit_test)
include(iree_add_all_subdirs)
include(iree_check_test)
+include(iree_trace_runner_test)
include(iree_run_binary_test)
include(iree_mlir_benchmark_suite)
diff --git a/build_tools/bazel/iree_check_test.bzl b/build_tools/bazel/iree_check_test.bzl
index 04d3afe..92d382c 100644
--- a/build_tools/bazel/iree_check_test.bzl
+++ b/build_tools/bazel/iree_check_test.bzl
@@ -22,6 +22,8 @@
driver,
compiler_flags = [],
runner_args = [],
+ opt_tool = "//iree/tools:iree-opt",
+ opt_flags = [],
tags = [],
timeout = None,
**kwargs):
@@ -36,6 +38,10 @@
flags are passed automatically.
runner_args: additional runner_args to pass to iree-check-module. The driver and input file
are passed automatically.
+ opt_tool: Defaulting to iree-opt. Tool used to preprocess the source files
+ if opt_flags is specified.
+ opt_flags: If specified, source files are preprocessed with opt_tool with
+ these flags.
tags: additional tags to apply to the generated test. A tag "driver=DRIVER" is added
automatically.
timeout: timeout for the generated tests.
@@ -50,6 +56,8 @@
"-mlir-print-op-on-diagnostic=false",
"-iree-hal-target-backends=%s" % target_backend,
] + compiler_flags,
+ opt_tool = opt_tool,
+ opt_flags = opt_flags,
visibility = ["//visibility:private"],
)
@@ -73,6 +81,8 @@
driver,
compiler_flags = [],
runner_args = [],
+ opt_tool = "//iree/tools:iree-opt",
+ opt_flags = [],
tags = [],
timeout = None,
**kwargs):
@@ -90,6 +100,10 @@
runner_args: additional runner_args to pass to the underlying iree-check-module tests. The
driver and input file are passed automatically. To use different runner_args per test,
create a separate suite or iree_check_test.
+ opt_tool: Defaulting to iree-opt. Tool used to preprocess the source files
+ if opt_flags is specified.
+ opt_flags: If specified, source files are preprocessed with opt_tool with
+ these flags.
tags: tags to apply to the generated tests. Note that as in standard test suites, manual
is treated specially and will also apply to the test suite itself.
timeout: timeout for the generated tests.
@@ -105,6 +119,8 @@
driver = driver,
compiler_flags = compiler_flags,
runner_args = runner_args,
+ opt_tool = opt_tool,
+ opt_flags = opt_flags,
tags = tags,
timeout = timeout,
**kwargs
@@ -128,6 +144,8 @@
target_backends_and_drivers = ALL_TARGET_BACKENDS_AND_DRIVERS,
compiler_flags = [],
runner_args = [],
+ opt_tool = "//iree/tools:iree-opt",
+ opt_flags = [],
tags = [],
**kwargs):
"""Creates a test suite of iree-check-module tests.
@@ -143,6 +161,10 @@
runner_args: additional runner_args to pass to the underlying iree-check-module tests. The
driver and input file are passed automatically. To use different runner_args per test,
create a separate suite or iree_check_test.
+ opt_tool: Defaulting to iree-opt. Tool used to preprocess the source files
+ if opt_flags is specified.
+ opt_flags: If specified, source files are preprocessed with opt_tool with
+ these flags.
tags: tags to apply to the generated tests. Note that as in standard test suites, manual
is treated specially and will also apply to the test suite itself.
**kwargs: any additional attributes to pass to the underlying tests and test suite.
@@ -160,6 +182,8 @@
target_backend = backend,
compiler_flags = compiler_flags,
runner_args = runner_args,
+ opt_tool = opt_tool,
+ opt_flags = opt_flags,
tags = tags,
**kwargs
)
diff --git a/build_tools/bazel/iree_trace_runner_test.bzl b/build_tools/bazel/iree_trace_runner_test.bzl
new file mode 100644
index 0000000..4c2b50a
--- /dev/null
+++ b/build_tools/bazel/iree_trace_runner_test.bzl
@@ -0,0 +1,12 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Macros for defining tests that use a trace-runner."""
+
+def iree_generated_trace_runner_test(**kwargs):
+ # TODO: implement this. For now, it's only parsed by bazel_to_cmake.py, so
+ # the iree_generated_trace_runner_test's are just omitted in bazel builds.
+ pass
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
index 7c1611d..9abd641 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_converter.py
@@ -66,20 +66,15 @@
return ""
-def _convert_translate_tool_block(translate_tool):
- if translate_tool is None:
+def _convert_target_block(name, target):
+ if target is None:
return ""
# Bazel target name to cmake binary name
# Bazel `//iree/custom:custom-translate` -> CMake `iree_custom_custom-translate`
- translate_tool = translate_tool.replace(
- "//iree", "iree") # iree/custom:custom-translate
- translate_tool = translate_tool.replace(":",
- "_") # iree/custom_custom-translate
- translate_tool = translate_tool.replace("/",
- "_") # iree_custom_custom-translate
- return _convert_string_arg_block("TRANSLATE_TOOL",
- translate_tool,
- quote=False)
+ target = target.replace("//iree", "iree") # iree/custom:custom-translate
+ target = target.replace(":", "_") # iree/custom_custom-translate
+ target = target.replace("/", "_") # iree_custom_custom-translate
+ return _convert_string_arg_block(name, target, quote=False)
def _convert_srcs_block(srcs):
@@ -439,12 +434,15 @@
flags=None,
translate_tool=None,
c_identifier=None,
+ opt_flags=None,
testonly=None):
name_block = _convert_string_arg_block("NAME", name, quote=False)
src_block = _convert_string_arg_block("SRC", src)
c_identifier_block = _convert_string_arg_block("C_IDENTIFIER", c_identifier)
- translate_tool_block = _convert_translate_tool_block(translate_tool)
+ translate_tool_block = _convert_target_block("TRANSLATE_TOOL",
+ translate_tool)
flags_block = _convert_string_list_block("FLAGS", flags)
+ opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
testonly_block = _convert_option_block("TESTONLY", testonly)
self.converter.body += (f"iree_bytecode_module(\n"
@@ -453,6 +451,7 @@
f"{c_identifier_block}"
f"{translate_tool_block}"
f"{flags_block}"
+ f"{opt_flags_block}"
f"{testonly_block}"
f" PUBLIC\n)\n\n")
@@ -532,6 +531,7 @@
target_backends_and_drivers=None,
runner_args=None,
tags=None,
+ opt_flags=None,
**kwargs):
name_block = _convert_string_arg_block("NAME", name, quote=False)
srcs_block = _convert_srcs_block(srcs)
@@ -542,6 +542,7 @@
compiler_flags)
runner_args_block = _convert_string_list_block("RUNNER_ARGS", runner_args)
labels_block = _convert_string_list_block("LABELS", tags)
+ opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
self.converter.body += (f"iree_check_single_backend_test_suite(\n"
f"{name_block}"
@@ -551,6 +552,7 @@
f"{compiler_flags_block}"
f"{runner_args_block}"
f"{labels_block}"
+ f"{opt_flags_block}"
f")\n\n")
def iree_check_test_suite(self,
@@ -560,6 +562,7 @@
compiler_flags=None,
runner_args=None,
tags=None,
+ opt_flags=None,
**kwargs):
target_backends = None
drivers = None
@@ -576,6 +579,7 @@
compiler_flags)
runner_args_block = _convert_string_list_block("RUNNER_ARGS", runner_args)
labels_block = _convert_string_list_block("LABELS", tags)
+ opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
self.converter.body += (f"iree_check_test_suite(\n"
f"{name_block}"
@@ -585,6 +589,54 @@
f"{compiler_flags_block}"
f"{runner_args_block}"
f"{labels_block}"
+ f"{opt_flags_block}"
+ f")\n\n")
+
+ def iree_generated_trace_runner_test(self,
+ name,
+ generator,
+ generator_args=None,
+ trace_runner=None,
+ target_backends_and_drivers=None,
+ compiler_flags=None,
+ runner_args=None,
+ tags=None,
+ opt_tool=None,
+ opt_flags=None,
+ **kwargs):
+ target_backends = None
+ drivers = None
+ if target_backends_and_drivers is not None:
+ target_backends = [it[0] for it in target_backends_and_drivers]
+ drivers = [it[1] for it in target_backends_and_drivers]
+
+ name_block = _convert_string_arg_block("NAME", name, quote=False)
+ generator_block = _convert_string_arg_block("GENERATOR",
+ generator,
+ quote=True)
+ generator_args_block = _convert_string_list_block("GENERATOR_ARGS",
+ generator_args)
+ trace_runner_block = _convert_target_block("TRACE_RUNNER", trace_runner)
+ target_backends_block = _convert_string_list_block("TARGET_BACKENDS",
+ target_backends)
+ drivers_block = _convert_string_list_block("DRIVERS", drivers)
+ compiler_flags_block = _convert_string_list_block("COMPILER_FLAGS",
+ compiler_flags)
+ runner_args_block = _convert_string_list_block("RUNNER_ARGS", runner_args)
+ labels_block = _convert_string_list_block("LABELS", tags)
+ opt_flags_block = _convert_string_list_block("OPT_FLAGS", opt_flags)
+
+ self.converter.body += (f"iree_generated_trace_runner_test(\n"
+ f"{name_block}"
+ f"{generator_block}"
+ f"{generator_args_block}"
+ f"{trace_runner_block}"
+ f"{target_backends_block}"
+ f"{drivers_block}"
+ f"{compiler_flags_block}"
+ f"{runner_args_block}"
+ f"{labels_block}"
+ f"{opt_flags_block}"
f")\n\n")
def iree_e2e_cartesian_product_test_suite(self,
diff --git a/build_tools/cmake/iree_bytecode_module.cmake b/build_tools/cmake/iree_bytecode_module.cmake
index a63495b..6d05eb7 100644
--- a/build_tools/cmake/iree_bytecode_module.cmake
+++ b/build_tools/cmake/iree_bytecode_module.cmake
@@ -32,8 +32,8 @@
cmake_parse_arguments(
_RULE
"PUBLIC;TESTONLY"
- "NAME;SRC;TRANSLATE_TOOL;C_IDENTIFIER"
- "FLAGS"
+ "NAME;SRC;TRANSLATE_TOOL;C_IDENTIFIER;OPT_TOOL;MODULE_FILE_NAME"
+ "FLAGS;OPT_FLAGS"
${ARGN}
)
@@ -48,21 +48,72 @@
set(_TRANSLATE_TOOL "iree-translate")
endif()
+ if(DEFINED _RULE_MODULE_FILE_NAME)
+ set(_MODULE_FILE_NAME "${_RULE_MODULE_FILE_NAME}")
+ else()
+ set(_MODULE_FILE_NAME "${_RULE_NAME}.vmfb")
+ endif()
+
+ # If OPT_FLAGS was specified, preprocess the source file with the OPT_TOOL
+ if(_RULE_OPT_FLAGS)
+ # Create the filename for the output of OPT_TOOL, which
+ # will replace _RULE_SRC as the input to the translate tool.
+ set(_TRANSLATE_SRC_BASENAME "${_RULE_NAME}.opt.mlir")
+ set(_TRANSLATE_SRC "${CMAKE_CURRENT_BINARY_DIR}/${_TRANSLATE_SRC_BASENAME}")
+
+ # Set default for OPT_TOOL.
+ if(_RULE_OPT_TOOL)
+ set(_OPT_TOOL ${_RULE_OPT_TOOL})
+ else()
+ set(_OPT_TOOL "iree-opt")
+ endif()
+
+ # Prepare the OPT_TOOL command line.
+ iree_get_executable_path(_OPT_TOOL_EXECUTABLE ${_OPT_TOOL})
+
+ set(_ARGS "${_RULE_OPT_FLAGS}")
+ get_filename_component(_SRC_PATH "${_RULE_SRC}" REALPATH)
+ list(APPEND _ARGS "${_SRC_PATH}")
+ list(APPEND _ARGS "-o")
+ list(APPEND _ARGS "${_TRANSLATE_SRC}")
+
+ add_custom_command(
+ OUTPUT
+ "${_TRANSLATE_SRC_BASENAME}"
+ COMMAND
+ ${_OPT_TOOL_EXECUTABLE}
+ ${_ARGS}
+ # Changes to the opt tool should trigger rebuilding.
+ # Using {_OPT_TOOL} as the dependency would only work when the tools
+ # are built in the same cmake build directory as the tests, that is,
+ # when NOT cross-compiling. Using {_OPT_TOOL_EXECUTABLE} works
+ # uniformly regardless of that.
+ DEPENDS
+ ${_OPT_TOOL_EXECUTABLE}
+ ${_RULE_SRC}
+ )
+ else()
+ # OPT_FLAGS was not specified, so we are not using the OPT_TOOL.
+ # Just pass the source file directly as the source for the bytecode module.
+ set(_TRANSLATE_SRC "${_RULE_SRC}")
+ endif()
+
iree_get_executable_path(_TRANSLATE_TOOL_EXECUTABLE ${_TRANSLATE_TOOL})
iree_get_executable_path(_EMBEDDED_LINKER_TOOL_EXECUTABLE "lld")
set(_ARGS "${_RULE_FLAGS}")
- get_filename_component(_SRC_PATH "${_RULE_SRC}" REALPATH)
- list(APPEND _ARGS "${_SRC_PATH}")
+
+ get_filename_component(_TRANSLATE_SRC_PATH "${_TRANSLATE_SRC}" REALPATH)
+ list(APPEND _ARGS "${_TRANSLATE_SRC_PATH}")
list(APPEND _ARGS "-o")
- list(APPEND _ARGS "${_RULE_NAME}.vmfb")
+ list(APPEND _ARGS "${_MODULE_FILE_NAME}")
list(APPEND _ARGS "-iree-llvm-embedded-linker-path=${_EMBEDDED_LINKER_TOOL_EXECUTABLE}")
# Depending on the binary instead of the target here given we might not have
# a target in this CMake invocation when cross-compiling.
add_custom_command(
OUTPUT
- "${_RULE_NAME}.vmfb"
+ "${_MODULE_FILE_NAME}"
COMMAND
${_TRANSLATE_TOOL_EXECUTABLE}
${_ARGS}
@@ -71,7 +122,7 @@
DEPENDS
${_TRANSLATE_TOOL_EXECUTABLE}
${_EMBEDDED_LINKER_TOOL_EXECUTABLE}
- ${_RULE_SRC}
+ ${_TRANSLATE_SRC}
)
if(_RULE_TESTONLY)
diff --git a/build_tools/cmake/iree_check_test.cmake b/build_tools/cmake/iree_check_test.cmake
index 32cf435..9b3a79e 100644
--- a/build_tools/cmake/iree_check_test.cmake
+++ b/build_tools/cmake/iree_check_test.cmake
@@ -6,6 +6,57 @@
include(CMakeParseArguments)
+# Helper for iree_check_test and iree_trace_runner_test.
+# Just a thin wrapper around iree_bytecode_module, passing it some
+# common flags, including the appropriate --iree-llvm-target-triple in the
+# Android case.
+function(iree_bytecode_module_for_iree_check_test_and_friends)
+ if(NOT IREE_BUILD_TESTS)
+ return()
+ endif()
+
+ cmake_parse_arguments(
+ _RULE
+ ""
+ "MODULE_NAME;SRC;TARGET_BACKEND;OPT_TOOL;MODULE_FILE_NAME"
+ "FLAGS;OPT_FLAGS"
+ ${ARGN}
+ )
+
+ if(ANDROID)
+ # Android's CMake toolchain defines some variables that we can use to infer
+ # the appropriate target triple from the configured settings:
+ # https://developer.android.com/ndk/guides/cmake#android_platform
+ #
+ # In typical CMake fashion, the various strings are pretty fuzzy and can
+ # have multiple values like "latest", "android-25"/"25"/"android-N-MR1".
+ #
+ # From looking at the toolchain file, ANDROID_PLATFORM_LEVEL seems like it
+ # should pretty consistently be just a number we can use for target triple.
+ set(_TARGET_TRIPLE "aarch64-none-linux-android${ANDROID_PLATFORM_LEVEL}")
+ list(APPEND _RULE_FLAGS "--iree-llvm-target-triple=${_TARGET_TRIPLE}")
+ endif()
+
+ iree_bytecode_module(
+ NAME
+ "${_RULE_MODULE_NAME}"
+ MODULE_FILE_NAME
+ "${_RULE_MODULE_FILE_NAME}"
+ SRC
+ "${_RULE_SRC}"
+ FLAGS
+ "-iree-mlir-to-vm-bytecode-module"
+ "-mlir-print-op-on-diagnostic=false"
+ "--iree-hal-target-backends=${_RULE_TARGET_BACKEND}"
+ ${_RULE_FLAGS}
+ OPT_TOOL
+ ${_RULE_OPT_TOOL}
+ OPT_FLAGS
+ ${_RULE_OPT_FLAGS}
+ TESTONLY
+ )
+endfunction()
+
# iree_check_test()
#
# Creates a test using iree-check-module for the specified source file.
@@ -23,6 +74,12 @@
# and input file are passed automatically.
# LABELS: Additional labels to apply to the test. The package path and
# "driver=${DRIVER}" are added automatically.
+# OPT_TOOL: Defaulting to iree-opt. Tool used to preprocess the source files
+# if OPT_FLAGS is specified.
+# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
+# these flags.
+# MODULE_FILE_NAME: Optional, specifies the absolute path to the filename
+# to use for the generated IREE module (.vmfb).
function(iree_check_test)
if(NOT IREE_BUILD_TESTS)
return()
@@ -51,8 +108,8 @@
cmake_parse_arguments(
_RULE
""
- "NAME;SRC;TARGET_BACKEND;DRIVER"
- "COMPILER_FLAGS;RUNNER_ARGS;LABELS"
+ "NAME;SRC;TARGET_BACKEND;DRIVER;OPT_TOOL;MODULE_FILE_NAME"
+ "COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
${ARGN}
)
@@ -61,49 +118,28 @@
set(_MODULE_NAME "${_RULE_NAME}_module")
- if(ANDROID)
- # Android's CMake toolchain defines some variables that we can use to infer
- # the appropriate target triple from the configured settings:
- # https://developer.android.com/ndk/guides/cmake#android_platform
- #
- # In typical CMake fashion, the various strings are pretty fuzzy and can
- # have multiple values like "latest", "android-25"/"25"/"android-N-MR1".
- #
- # From looking at the toolchain file, ANDROID_PLATFORM_LEVEL seems like it
- # should pretty consistently be just a number we can use for target triple.
- set(_TARGET_TRIPLE "aarch64-none-linux-android${ANDROID_PLATFORM_LEVEL}")
+ if(DEFINED _RULE_MODULE_FILE_NAME)
+ set(_MODULE_FILE_NAME "${_RULE_MODULE_FILE_NAME}")
+ else(DEFINED _RULE_MODULE_FILE_NAME)
+ set(_MODULE_FILE_NAME "${_MODULE_NAME}.vmfb")
+ endif(DEFINED _RULE_MODULE_FILE_NAME)
- iree_bytecode_module(
- NAME
- "${_MODULE_NAME}"
- SRC
- "${_RULE_SRC}"
- FLAGS
- "-iree-mlir-to-vm-bytecode-module"
- "-mlir-print-op-on-diagnostic=false"
- "--iree-hal-target-backends=${_RULE_TARGET_BACKEND}"
- "--iree-llvm-target-triple=${_TARGET_TRIPLE}"
- ${_RULE_COMPILER_FLAGS}
- TESTONLY
- )
- else(ANDROID)
- iree_bytecode_module(
- NAME
- "${_MODULE_NAME}"
- SRC
- "${_RULE_SRC}"
- FLAGS
- "-iree-mlir-to-vm-bytecode-module"
- "-mlir-print-op-on-diagnostic=false"
- "--iree-hal-target-backends=${_RULE_TARGET_BACKEND}"
- ${_RULE_COMPILER_FLAGS}
- TESTONLY
- )
- endif(ANDROID)
-
- # TODO(b/146898896): It would be nice if this were something we could query
- # rather than having to know the conventions used by iree_bytecode_module.
- set(_MODULE_FILE_NAME "${_MODULE_NAME}.vmfb")
+ iree_bytecode_module_for_iree_check_test_and_friends(
+ MODULE_NAME
+ "${_MODULE_NAME}"
+ MODULE_FILE_NAME
+ "${_MODULE_FILE_NAME}"
+ SRC
+ "${_RULE_SRC}"
+ TARGET_BACKEND
+ "${_RULE_TARGET_BACKEND}"
+ FLAGS
+ ${_RULE_COMPILER_FLAGS}
+ OPT_TOOL
+ ${_RULE_OPT_TOOL}
+ OPT_FLAGS
+ ${_RULE_OPT_FLAGS}
+ )
# iree_bytecode_module does not define a target, only a custom command.
# We need to create a target that depends on the command to ensure the
@@ -117,69 +153,33 @@
"${_MODULE_FILE_NAME}"
)
+ set(_RUNNER_TARGET "iree_tools_iree-check-module")
+
# A target specifically for the test. We could combine this with the above,
# but we want that one to get pulled into iree_bytecode_module.
add_custom_target("${_NAME}" ALL)
add_dependencies(
"${_NAME}"
"${_MODULE_TARGET_NAME}"
- iree_tools_iree-check-module
+ "${_RUNNER_TARGET}"
)
- iree_package_ns(_PACKAGE_NS)
- string(REPLACE "::" "/" _PACKAGE_PATH ${_PACKAGE_NS})
- set(_TEST_NAME "${_PACKAGE_PATH}/${_RULE_NAME}")
-
- # Case for cross-compiling towards Android.
- if(ANDROID)
- set(_ANDROID_EXE_REL_DIR "iree/modules/check")
- set(_ANDROID_REL_DIR "${_PACKAGE_PATH}/${_RULE_NAME}")
- set(_ANDROID_ABS_DIR "/data/local/tmp/${_ANDROID_REL_DIR}")
-
- # Define a custom target for pushing and running the test on Android device.
- set(_TEST_NAME ${_TEST_NAME}_on_android_device)
- add_test(
- NAME
- ${_TEST_NAME}
- COMMAND
- "${CMAKE_SOURCE_DIR}/build_tools/cmake/run_android_test.${IREE_HOST_SCRIPT_EXT}"
- "${_ANDROID_REL_DIR}/$<TARGET_FILE_NAME:iree_tools_iree-check-module>"
- "--driver=${_RULE_DRIVER}"
- "${_ANDROID_REL_DIR}/${_MODULE_FILE_NAME}"
- ${_RULE_RUNNER_ARGS}
- )
- # Use environment variables to instruct the script to push artifacts
- # onto the Android device before running the test. This needs to match
- # with the expectation of the run_android_test.{sh|bat|ps1} script.
- set(
- _ENVIRONMENT_VARS
- TEST_ANDROID_ABS_DIR=${_ANDROID_ABS_DIR}
- TEST_DATA=${CMAKE_CURRENT_BINARY_DIR}/${_MODULE_FILE_NAME}
- TEST_EXECUTABLE=$<TARGET_FILE:iree_tools_iree-check-module>
- )
- set_property(TEST ${_TEST_NAME} PROPERTY ENVIRONMENT ${_ENVIRONMENT_VARS})
- iree_add_test_environment_properties(${_TEST_NAME})
- else(ANDROID)
- add_test(
- NAME
- "${_TEST_NAME}"
- COMMAND
- "${CMAKE_SOURCE_DIR}/build_tools/cmake/run_test.${IREE_HOST_SCRIPT_EXT}"
- "$<TARGET_FILE:iree_tools_iree-check-module>"
- "--driver=${_RULE_DRIVER}"
- "${CMAKE_CURRENT_BINARY_DIR}/${_MODULE_FILE_NAME}"
- ${_RULE_RUNNER_ARGS}
- )
- set_property(TEST "${_TEST_NAME}" PROPERTY ENVIRONMENT "TEST_TMPDIR=${_NAME}_test_tmpdir")
- iree_add_test_environment_properties(${_TEST_NAME})
- endif(ANDROID)
-
- list(APPEND _RULE_LABELS "${_PACKAGE_PATH}" "driver=${_RULE_DRIVER}")
- set_property(TEST "${_TEST_NAME}" PROPERTY REQUIRED_FILES "${_MODULE_FILE_NAME}")
- set_property(TEST "${_TEST_NAME}" PROPERTY LABELS "${_RULE_LABELS}")
+ iree_run_binary_test(
+ NAME
+ "${_RULE_NAME}"
+ DRIVER
+ "${_RULE_DRIVER}"
+ TEST_BINARY
+ "${_RUNNER_TARGET}"
+ TEST_INPUT_FILE_ARG
+ "${_MODULE_FILE_NAME}"
+ ARGS
+ ${_RULE_RUNNER_ARGS}
+ LABELS
+ ${_RULE_LABELS}
+ )
endfunction()
-
# iree_check_single_backend_test_suite()
#
# Creates a test suite of iree-check-module tests for a single backend/driver pair.
@@ -199,6 +199,10 @@
# different args per test, create a separate suite or iree_check_test.
# LABELS: Additional labels to apply to the generated tests. The package path is
# added automatically.
+# OPT_TOOL: Defaulting to iree-opt. Tool used to preprocess the source files
+# if OPT_FLAGS is specified.
+# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
+# these flags.
function(iree_check_single_backend_test_suite)
if(NOT IREE_BUILD_TESTS)
return()
@@ -211,8 +215,8 @@
cmake_parse_arguments(
_RULE
""
- "NAME;TARGET_BACKEND;DRIVER"
- "SRCS;COMPILER_FLAGS;RUNNER_ARGS;LABELS"
+ "NAME;TARGET_BACKEND;DRIVER;OPT_TOOL"
+ "SRCS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
${ARGN}
)
@@ -258,6 +262,10 @@
${_RULE_RUNNER_ARGS}
LABELS
${_RULE_LABELS}
+ OPT_TOOL
+ ${_RULE_OPT_TOOL}
+ OPT_FLAGS
+ ${_RULE_OPT_FLAGS}
)
endforeach()
endfunction()
@@ -286,6 +294,10 @@
# test, create a separate suite or iree_check_test.
# LABELS: Additional labels to apply to the generated tests. The package path is
# added automatically.
+# OPT_TOOL: Defaulting to iree-opt. Tool used to preprocess the source files
+# if OPT_FLAGS is specified.
+# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
+# these flags.
function(iree_check_test_suite)
if(NOT IREE_BUILD_TESTS)
return()
@@ -332,6 +344,10 @@
${_RULE_RUNNER_ARGS}
LABELS
${_RULE_LABELS}
+ OPT_TOOL
+ ${_RULE_OPT_TOOL}
+ OPT_FLAGS
+ ${_RULE_OPT_FLAGS}
)
endforeach()
endfunction()
diff --git a/build_tools/cmake/iree_run_binary_test.cmake b/build_tools/cmake/iree_run_binary_test.cmake
index 8db938c..5129c39 100644
--- a/build_tools/cmake/iree_run_binary_test.cmake
+++ b/build_tools/cmake/iree_run_binary_test.cmake
@@ -14,14 +14,23 @@
#
# Parameters:
# NAME: name of target
-# ARGS: arguments passed to the test binary.
+# DRIVER: If specified, passes --driver=DRIVER to the test binary and adds
+# a driver label to the test.
+# TEST_INPUT_FILE_ARG: If specified, the input file will be added to DATA and
+# its device path appended to ARGS. Note that the device path may be different
+# from the host path, so this parameter should be used to portably pass file arguments
+# to tests.
+# DATA: Additional input files needed by the test binary. When running tests on a
+# separate device (e.g. Android), these files will be pushed to the device.
+# TEST_INPUT_FILE_ARG is automatically added if specified.
+# ARGS: additional arguments passed to the test binary. TEST_INPUT_FILE_ARG and
+# --driver=DRIVER are automatically added if specified.
# TEST_BINARY: binary target to run as the test.
# LABELS: Additional labels to apply to the test. The package path is added
# automatically.
#
-# Note: the DATA argument is not supported because CMake doesn't have a good way
-# to specify a data dependency for a test.
-#
+# Note: the DATA argument is not actually adding dependencies because CMake
+# doesn't have a good way to specify a data dependency for a test.
#
# Usage:
# iree_cc_binary(
@@ -46,8 +55,8 @@
cmake_parse_arguments(
_RULE
""
- "NAME;TEST_BINARY"
- "ARGS;LABELS"
+ "NAME;TEST_BINARY;DRIVER;TEST_INPUT_FILE_ARG"
+ "ARGS;LABELS;DATA"
${ARGN}
)
@@ -58,6 +67,27 @@
iree_package_path(_PACKAGE_PATH)
set(_TEST_NAME "${_PACKAGE_PATH}/${_RULE_NAME}")
+ # If driver was specified, add the corresponding test arg and label.
+ if (DEFINED _RULE_DRIVER)
+ list(APPEND _RULE_ARGS "--driver=${_RULE_DRIVER}")
+ list(APPEND _RULE_LABELS "driver=${_RULE_DRIVER}")
+ endif()
+
+ if(ANDROID)
+ set(_ANDROID_REL_DIR "${_PACKAGE_PATH}/${_RULE_NAME}")
+ set(_ANDROID_ABS_DIR "/data/local/tmp/${_ANDROID_REL_DIR}")
+ endif()
+
+ if (DEFINED _RULE_TEST_INPUT_FILE_ARG)
+ if (ANDROID)
+ get_filename_component(_TEST_INPUT_FILE_BASENAME "${_RULE_TEST_INPUT_FILE_ARG}" NAME)
+ list(APPEND _RULE_ARGS "${_ANDROID_REL_DIR}/${_TEST_INPUT_FILE_BASENAME}")
+ else()
+ list(APPEND _RULE_ARGS "${_RULE_TEST_INPUT_FILE_ARG}")
+ endif()
+ list(APPEND _RULE_DATA "${_RULE_TEST_INPUT_FILE_ARG}")
+ endif()
+
# Replace binary passed by relative ::name with iree::package::name
string(REGEX REPLACE "^::" "${_PACKAGE_NS}::" _TEST_BINARY_TARGET ${_RULE_TEST_BINARY})
@@ -78,11 +108,13 @@
# Use environment variables to instruct the script to push artifacts
# onto the Android device before running the test. This needs to match
# with the expectation of the run_android_test.{sh|bat|ps1} script.
+ string (REPLACE ";" " " _DATA_SPACE_SEPARATED "${_RULE_DATA}")
set(
_ENVIRONMENT_VARS
- TEST_ANDROID_ABS_DIR=${_ANDROID_ABS_DIR}
- TEST_EXECUTABLE=$<TARGET_FILE:${_TEST_BINARY_TARGET}>
- TEST_TMPDIR=${_ANDROID_ABS_DIR}/test_tmpdir
+ "TEST_ANDROID_ABS_DIR=${_ANDROID_ABS_DIR}"
+ "TEST_EXECUTABLE=$<TARGET_FILE:${_TEST_BINARY_TARGET}>"
+ "DATA=${_DATA_SPACE_SEPARATED}"
+ "TEST_TMPDIR=${_ANDROID_ABS_DIR}/test_tmpdir"
)
set_property(TEST ${_TEST_NAME} PROPERTY ENVIRONMENT ${_ENVIRONMENT_VARS})
else()
@@ -100,4 +132,5 @@
list(APPEND _RULE_LABELS "${_PACKAGE_PATH}")
set_property(TEST ${_TEST_NAME} PROPERTY LABELS "${_RULE_LABELS}")
+ set_property(TEST "${_TEST_NAME}" PROPERTY REQUIRED_FILES "${_RULE_DATA}")
endfunction()
diff --git a/build_tools/cmake/iree_trace_runner_test.cmake b/build_tools/cmake/iree_trace_runner_test.cmake
new file mode 100644
index 0000000..2e34645
--- /dev/null
+++ b/build_tools/cmake/iree_trace_runner_test.cmake
@@ -0,0 +1,341 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+include(CMakeParseArguments)
+
+# iree_trace_runner_test()
+#
+# Creates a test using a specified trace-runner program for the specified
+# replay trace.
+#
+# Parameters:
+# NAME: Name of the target
+# SRC: mlir source file to be compiled to an IREE module.
+# TARGET_BACKEND: target backend to compile for.
+# DRIVER: driver to run the module with.
+# COMPILER_FLAGS: additional flags to pass to the compiler. Bytecode
+# translation and backend flags are passed automatically.
+# RUNNER_ARGS: additional args to pass to the trace-runner program. The driver
+# and input file flags are passed automatically.
+# LABELS: Additional labels to apply to the test. The package path and
+# "driver=${DRIVER}" are added automatically.
+# OPT_TOOL: Defaulting to iree-opt. Tool used to preprocess the source files
+# if OPT_FLAGS is specified.
+# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
+# these flags.
+# TRACE_RUNNER: trace-runner program to run.
+# TRACE: trace file input to the trace-runner program.
+# MODULE_FILE_NAME: specifies the absolute path to the filename to use for the
+# generated IREE module (.vmfb). Mandatory, unlike in iree_check_test,
+# because trace files (.yaml) reference a specific module file path.
+function(iree_trace_runner_test)
+  if(NOT IREE_BUILD_TESTS)
+    return()
+  endif()
+
+  # See comment in iree_check_test about this condition.
+  if(NOT IREE_BUILD_COMPILER AND NOT CMAKE_CROSSCOMPILING)
+    return()
+  endif()
+
+  cmake_parse_arguments(
+    _RULE
+    ""
+    "NAME;SRC;TRACE;TARGET_BACKEND;DRIVER;OPT_TOOL;TRACE_RUNNER;MODULE_FILE_NAME"
+    "COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+    ${ARGN}
+  )
+
+  iree_package_name(_PACKAGE_NAME)
+  set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
+
+  set(_MODULE_NAME "${_RULE_NAME}_module")
+
+  iree_bytecode_module_for_iree_check_test_and_friends(
+    MODULE_NAME
+      "${_MODULE_NAME}"
+    MODULE_FILE_NAME
+      "${_RULE_MODULE_FILE_NAME}"
+    SRC
+      "${_RULE_SRC}"
+    TARGET_BACKEND
+      "${_RULE_TARGET_BACKEND}"
+    FLAGS
+      ${_RULE_COMPILER_FLAGS}
+    OPT_TOOL
+      ${_RULE_OPT_TOOL}
+    OPT_FLAGS
+      ${_RULE_OPT_FLAGS}
+  )
+
+  # iree_bytecode_module does not define a target, only a custom command.
+  # We need to create a target that depends on the command to ensure the
+  # module gets built.
+  # TODO(b/146898896): Do this in iree_bytecode_module and avoid having to
+  # reach into the internals.
+  set(_MODULE_TARGET_NAME "${_NAME}_module")
+  add_custom_target(
+    "${_MODULE_TARGET_NAME}"
+    DEPENDS
+      "${_RULE_MODULE_FILE_NAME}"
+  )
+
+  # A target specifically for the test. We could combine this with the above,
+  # but we want that one to get pulled into iree_bytecode_module.
+  add_custom_target("${_NAME}" ALL)
+  add_dependencies(
+    "${_NAME}"
+    "${_MODULE_TARGET_NAME}"
+    "${_RULE_TRACE_RUNNER}"
+  )
+
+  # The module file is passed as DATA using this function's own
+  # MODULE_FILE_NAME argument. The unprefixed _MODULE_FILE_NAME is never set
+  # in this function, so it would only resolve when a caller happens to define
+  # a variable of that name in its own scope.
+  iree_run_binary_test(
+    NAME
+      "${_RULE_NAME}"
+    DRIVER
+      "${_RULE_DRIVER}"
+    TEST_BINARY
+      "${_RULE_TRACE_RUNNER}"
+    TEST_INPUT_FILE_ARG
+      ${_RULE_TRACE}
+    DATA
+      ${_RULE_MODULE_FILE_NAME}
+    ARGS
+      ${_RULE_RUNNER_ARGS}
+    LABELS
+      ${_RULE_LABELS}
+  )
+endfunction()
+
+# iree_single_backend_generated_trace_runner_test()
+#
+# Variant of iree_trace_runner_test where instead of specifying
+# a source file (and possibly a trace file and module path), one passes a
+# generator program.
+#
+# Parameters:
+# NAME: Name of the target
+# GENERATOR: Program (at the moment, must be Python3) to run to generate the
+# source file (and possibly a trace file and module path). It will be
+# invoked with the following standard flags, in addition to GENERATOR_ARGS:
+# --output_code=${CMAKE_CURRENT_BINARY_DIR}/name.mlir
+# --output_trace=${CMAKE_CURRENT_BINARY_DIR}/name.yaml
+# --module_path=${CMAKE_CURRENT_BINARY_DIR}/name.vmfb
+# GENERATOR_ARGS: additional args to pass to the generator program.
+# TARGET_BACKEND: target backend to compile for.
+# DRIVER: driver to run the module with.
+# COMPILER_FLAGS: additional flags to pass to the compiler. Bytecode
+# translation and backend flags are passed automatically.
+# RUNNER_ARGS: additional args to pass to the trace-runner program. The driver
+# and input file flags are passed automatically.
+# LABELS: Additional labels to apply to the test. The package path and
+# "driver=${DRIVER}" are added automatically.
+# OPT_TOOL: Defaulting to iree-opt. Tool used to preprocess the source files
+# if OPT_FLAGS is specified.
+# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
+# these flags.
+# TRACE_RUNNER: trace-runner program to run.
+function(iree_single_backend_generated_trace_runner_test)
+  if(NOT IREE_BUILD_TESTS)
+    return()
+  endif()
+
+  # Copied from iree_check_test. Refer to the comment there.
+  if(NOT IREE_BUILD_COMPILER AND NOT CMAKE_CROSSCOMPILING)
+    return()
+  endif()
+
+  cmake_parse_arguments(
+    _RULE
+    ""
+    "NAME;GENERATOR;TARGET_BACKEND;DRIVER;OPT_TOOL;TRACE_RUNNER"
+    "GENERATOR_ARGS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+    ${ARGN}
+  )
+
+  # Skip the test when the requested HAL driver is known but disabled. An
+  # unknown driver name is reported as an error. This overlaps with directory
+  # exclusions and other filtering mechanisms.
+  string(TOUPPER ${_RULE_DRIVER} _UPPERCASE_DRIVER)
+  if(NOT DEFINED IREE_HAL_DRIVER_${_UPPERCASE_DRIVER})
+    message(SEND_ERROR "Unknown driver '${_RULE_DRIVER}'. Check IREE_ALL_HAL_DRIVERS.")
+  endif()
+  if(NOT IREE_HAL_DRIVER_${_UPPERCASE_DRIVER})
+    return()
+  endif()
+
+  # Same treatment for the compiler target backend: unknown names are errors.
+  string(TOUPPER ${_RULE_TARGET_BACKEND} _UPPERCASE_TARGET_BACKEND)
+  if(NOT DEFINED IREE_TARGET_BACKEND_${_UPPERCASE_TARGET_BACKEND})
+    message(SEND_ERROR "Unknown backend '${_RULE_TARGET_BACKEND}'. Check IREE_ALL_TARGET_BACKENDS.")
+  endif()
+
+  # The enabled-backend check only applies when the host tools are built from
+  # source in this configuration. When IREE_HOST_BINARY_ROOT is defined (such
+  # as when cross compiling) we can't easily tell which compiler target
+  # backends are enabled, so all are assumed enabled and only the runtime HAL
+  # driver check above filters tests.
+  if(NOT DEFINED IREE_HOST_BINARY_ROOT AND
+     NOT IREE_TARGET_BACKEND_${_UPPERCASE_TARGET_BACKEND})
+    return()
+  endif()
+
+  iree_package_name(_PACKAGE_NAME)
+  set(_NAME "${_PACKAGE_NAME}_${_RULE_NAME}")
+
+  # Paths of the files involved, all under the current binary directory.
+  set(_SRC "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.mlir")
+  set(_TRACE "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.yaml")
+  set(_MODULE_FILE_NAME "${CMAKE_CURRENT_BINARY_DIR}/${_RULE_NAME}.vmfb")
+
+  # The generator writes the source and the trace. The module path is only
+  # passed to it as a flag; the module itself is produced by
+  # iree_trace_runner_test below.
+  set(_GENERATOR_OUTPUT "${_SRC}" "${_TRACE}")
+  list(APPEND _GENERATOR_STANDARD_FLAGS
+    "--output_code=${_SRC}"
+    "--output_trace=${_TRACE}"
+    "--module_path=${_MODULE_FILE_NAME}"
+  )
+
+  add_custom_command(
+    COMMAND
+      "${Python3_EXECUTABLE}"
+      "${CMAKE_CURRENT_SOURCE_DIR}/${_RULE_GENERATOR}"
+      ${_GENERATOR_STANDARD_FLAGS}
+      ${_RULE_GENERATOR_ARGS}
+    OUTPUT
+      ${_GENERATOR_OUTPUT}
+    DEPENDS
+      ${_RULE_GENERATOR}
+  )
+
+  add_custom_target(
+    "${_NAME}_generated_files"
+    DEPENDS
+      ${_GENERATOR_OUTPUT}
+  )
+
+  iree_trace_runner_test(
+    NAME
+      "${_RULE_NAME}"
+    SRC
+      "${_SRC}"
+    TRACE
+      "${_TRACE}"
+    TRACE_RUNNER
+      "${_RULE_TRACE_RUNNER}"
+    MODULE_FILE_NAME
+      "${_MODULE_FILE_NAME}"
+    TARGET_BACKEND
+      ${_RULE_TARGET_BACKEND}
+    DRIVER
+      ${_RULE_DRIVER}
+    COMPILER_FLAGS
+      ${_RULE_COMPILER_FLAGS}
+    RUNNER_ARGS
+      ${_RULE_RUNNER_ARGS}
+    LABELS
+      ${_RULE_LABELS}
+    OPT_TOOL
+      ${_RULE_OPT_TOOL}
+    OPT_FLAGS
+      ${_RULE_OPT_FLAGS}
+  )
+
+  # Note we are relying on the fact that the target created by
+  # iree_trace_runner_test is _NAME, even though we passed _RULE_NAME to it,
+  # i.e. we are relying on the prefixing to be identical.
+  add_dependencies("${_NAME}" "${_NAME}_generated_files")
+endfunction()
+
+
+# iree_generated_trace_runner_test()
+#
+# Creates a set of iree_single_backend_generated_trace_runner_test's differing
+# by target backend and driver.
+#
+# Mirrors the bzl rule of the same name.
+#
+# One test is generated per source and backend/driver pair.
+# Parameters:
+# NAME: Name of the target
+# GENERATOR: Program (at the moment, must be Python3) to run to generate the
+# source file (and possibly a trace file and module path). It will be
+# invoked with the following standard flags, in addition to GENERATOR_ARGS:
+# --output_code=${CMAKE_CURRENT_BINARY_DIR}/name.mlir
+# --output_trace=${CMAKE_CURRENT_BINARY_DIR}/name.yaml
+# --module_path=${CMAKE_CURRENT_BINARY_DIR}/name.vmfb
+# GENERATOR_ARGS: additional args to pass to the generator program.
+# TARGET_BACKENDS: backends to compile the module for. These form pairs with
+# the DRIVERS argument (due to cmake limitations they are separate list
+# arguments). The lengths must exactly match. If no backends or drivers are
+# specified, a test will be generated for every supported pair.
+# DRIVERS: drivers to run the module with. These form pairs with the
+# TARGET_BACKENDS argument (due to cmake limitations they are separate list
+# arguments). The lengths must exactly match. If no backends or drivers are
+# specified, a test will be generated for every supported pair.
+# COMPILER_FLAGS: additional flags to pass to the compiler. Bytecode
+# translation and backend flags are passed automatically.
+# RUNNER_ARGS: additional args to pass to the trace-runner program. The driver
+# and input file flags are passed automatically.
+# LABELS: Additional labels to apply to the test. The package path and
+# "driver=${DRIVER}" are added automatically.
+# OPT_TOOL: Defaulting to iree-opt. Tool used to preprocess the source files
+# if OPT_FLAGS is specified.
+# OPT_FLAGS: If specified, source files are preprocessed with OPT_TOOL with
+# these flags.
+# TRACE_RUNNER: trace-runner program to run.
+function(iree_generated_trace_runner_test)
+  if(NOT IREE_BUILD_TESTS)
+    return()
+  endif()
+
+  cmake_parse_arguments(
+    _RULE
+    ""
+    "NAME;GENERATOR;OPT_TOOL;TRACE_RUNNER"
+    "TARGET_BACKENDS;DRIVERS;GENERATOR_ARGS;COMPILER_FLAGS;RUNNER_ARGS;LABELS;OPT_FLAGS"
+    ${ARGN}
+  )
+
+  # Default backend/driver pairs, used only when neither list was given.
+  if(NOT DEFINED _RULE_TARGET_BACKENDS AND NOT DEFINED _RULE_DRIVERS)
+    set(_RULE_TARGET_BACKENDS "vmvx" "vulkan-spirv" "dylib-llvm-aot")
+    set(_RULE_DRIVERS "vmvx" "vulkan" "dylib")
+  endif()
+
+  list(LENGTH _RULE_TARGET_BACKENDS _TARGET_BACKEND_COUNT)
+  list(LENGTH _RULE_DRIVERS _DRIVER_COUNT)
+
+  if(NOT _TARGET_BACKEND_COUNT EQUAL _DRIVER_COUNT)
+    message(SEND_ERROR
+        "TARGET_BACKENDS count ${_TARGET_BACKEND_COUNT} does not match DRIVERS count ${_DRIVER_COUNT}")
+    # SEND_ERROR does not stop processing. Bail out here so that the loop
+    # below does not index past the end of the shorter list (or iterate over
+    # an invalid range when one list is empty) and bury the real error under
+    # follow-up errors.
+    return()
+  endif()
+
+  # Backends and drivers are paired by position: one test per pair.
+  math(EXPR _MAX_INDEX "${_TARGET_BACKEND_COUNT} - 1")
+  foreach(_INDEX RANGE "${_MAX_INDEX}")
+    list(GET _RULE_TARGET_BACKENDS ${_INDEX} _TARGET_BACKEND)
+    list(GET _RULE_DRIVERS ${_INDEX} _DRIVER)
+    set(_SINGLE_BACKEND_TEST_NAME "${_RULE_NAME}_${_TARGET_BACKEND}_${_DRIVER}")
+    iree_single_backend_generated_trace_runner_test(
+      NAME
+        ${_SINGLE_BACKEND_TEST_NAME}
+      GENERATOR
+        ${_RULE_GENERATOR}
+      GENERATOR_ARGS
+        ${_RULE_GENERATOR_ARGS}
+      TRACE_RUNNER
+        ${_RULE_TRACE_RUNNER}
+      TARGET_BACKEND
+        ${_TARGET_BACKEND}
+      DRIVER
+        ${_DRIVER}
+      COMPILER_FLAGS
+        ${_RULE_COMPILER_FLAGS}
+      RUNNER_ARGS
+        ${_RULE_RUNNER_ARGS}
+      LABELS
+        ${_RULE_LABELS}
+      OPT_TOOL
+        ${_RULE_OPT_TOOL}
+      OPT_FLAGS
+        ${_RULE_OPT_FLAGS}
+    )
+  endforeach()
+endfunction()
diff --git a/build_tools/cmake/run_android_test.sh b/build_tools/cmake/run_android_test.sh
index 201a95f..caf0b34 100755
--- a/build_tools/cmake/run_android_test.sh
+++ b/build_tools/cmake/run_android_test.sh
@@ -15,7 +15,7 @@
# This script reads the following environment variables:
# - TEST_ANDROID_ABS_DIR: the absolute path on Android device for the build
# artifacts.
-# - TEST_DATA: optional; the data file to push to the Android device.
+# - TEST_DATA: optional; the files to push to the Android device. Space-separated.
# - TEST_EXECUTABLE: the executable file to push to the Android device.
# - TEST_TMPDIR: optional; temporary directory on the Android device for
# running tests.
@@ -30,7 +30,10 @@
adb push $TEST_EXECUTABLE $TEST_ANDROID_ABS_DIR/$(basename $TEST_EXECUTABLE)
if [ -n "$TEST_DATA" ]; then
- adb push $TEST_DATA $TEST_ANDROID_ABS_DIR/$(basename $TEST_DATA)
+ for datafile in $TEST_DATA
+ do
+ adb push "$datafile" "$TEST_ANDROID_ABS_DIR/$(basename "$datafile")"
+ done
fi
if [ -n "$TEST_TMPDIR" ]; then
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index 872d0b4..77e48c3 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -11,6 +11,7 @@
load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
load("//iree:lit_test.bzl", "iree_lit_test_suite")
load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")
+load("//build_tools/bazel:iree_trace_runner_test.bzl", "iree_generated_trace_runner_test")
package(
default_visibility = ["//visibility:public"],
@@ -103,3 +104,62 @@
],
target_backend = "cuda",
)
+
+# End-to-end matmul tests on the "small" shape collection with no opt_flags,
+# i.e. the generated source is not preprocessed. One test per element type
+# (i8, f32), each on the dylib and vmvx backend/driver pairs.
+[iree_generated_trace_runner_test(
+ name = "e2e_matmul_direct_%s_small" % lhs_rhs_type,
+ generator = "generate_e2e_matmul_tests.py",
+ generator_args = [
+ "--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--shapes=small",
+ ],
+ target_backends_and_drivers = [
+ ("dylib-llvm-aot", "dylib"),
+ ("vmvx", "vmvx"),
+ ],
+ trace_runner = "//iree/tools:iree-e2e-matmul-test",
+) for lhs_rhs_type in [
+ "i8",
+ "f32",
+]]
+
+# Same as the "direct" small tests, but the source is first preprocessed with
+# the matmul-to-mmt4d conversion and mmt4d vectorization flags. The K0 tile
+# size is 4 for i8 and 1 for f32; M0 and N0 are 8 for both.
+[iree_generated_trace_runner_test(
+ name = "e2e_matmul_mmt4d_%s_small" % lhs_rhs_type,
+ generator = "generate_e2e_matmul_tests.py",
+ generator_args = [
+ "--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--shapes=small",
+ ],
+ opt_flags = [
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
+ "--iree-codegen-vectorize-linalg-mmt4d",
+ ],
+ target_backends_and_drivers = [
+ ("dylib-llvm-aot", "dylib"),
+ ("vmvx", "vmvx"),
+ ],
+ trace_runner = "//iree/tools:iree-e2e-matmul-test",
+) for lhs_rhs_type in [
+ "i8",
+ "f32",
+]]
+
+# mmt4d tests on the "large" shape collection, using the same opt_flags as the
+# small mmt4d tests. Only the dylib backend/driver pair is exercised (see the
+# TODO below regarding VMVX).
+[iree_generated_trace_runner_test(
+ name = "e2e_matmul_mmt4d_%s_large" % lhs_rhs_type,
+ generator = "generate_e2e_matmul_tests.py",
+ generator_args = [
+ "--lhs_rhs_type=%s" % lhs_rhs_type,
+ "--shapes=large",
+ ],
+ opt_flags = [
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=%d N0=8" % (4 if lhs_rhs_type == "i8" else 1),
+ "--iree-codegen-vectorize-linalg-mmt4d",
+ ],
+ target_backends_and_drivers = [
+ ("dylib-llvm-aot", "dylib"),
+ # TODO: enable VMVX. Skipped for now: it's very slow for these large matmul tests.
+ ],
+ trace_runner = "//iree/tools:iree-e2e-matmul-test",
+) for lhs_rhs_type in [
+ "i8",
+ "f32",
+]]
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index cb1ea0b..795a874 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -122,4 +122,120 @@
"requires-gpu-nvidia"
)
+# End-to-end matmul tests. There is one iree_generated_trace_runner_test call
+# per (direct or mmt4d, element type, shape collection) combination, matching
+# the entries in this directory's BUILD file. TARGET_BACKENDS and DRIVERS are
+# paired by position.
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_direct_i8_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=i8"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree_tools_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "dylib-llvm-aot"
+ "vmvx"
+ DRIVERS
+ "dylib"
+ "vmvx"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_direct_f32_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree_tools_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "dylib-llvm-aot"
+ "vmvx"
+ DRIVERS
+ "dylib"
+ "vmvx"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_mmt4d_i8_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=i8"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree_tools_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "dylib-llvm-aot"
+ "vmvx"
+ DRIVERS
+ "dylib"
+ "vmvx"
+ OPT_FLAGS
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ "--iree-codegen-vectorize-linalg-mmt4d"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_mmt4d_f32_small
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--shapes=small"
+ TRACE_RUNNER
+ iree_tools_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "dylib-llvm-aot"
+ "vmvx"
+ DRIVERS
+ "dylib"
+ "vmvx"
+ OPT_FLAGS
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ "--iree-codegen-vectorize-linalg-mmt4d"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_mmt4d_i8_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=i8"
+ "--shapes=large"
+ TRACE_RUNNER
+ iree_tools_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "dylib-llvm-aot"
+ DRIVERS
+ "dylib"
+ OPT_FLAGS
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=4 N0=8"
+ "--iree-codegen-vectorize-linalg-mmt4d"
+)
+
+iree_generated_trace_runner_test(
+ NAME
+ e2e_matmul_mmt4d_f32_large
+ GENERATOR
+ "generate_e2e_matmul_tests.py"
+ GENERATOR_ARGS
+ "--lhs_rhs_type=f32"
+ "--shapes=large"
+ TRACE_RUNNER
+ iree_tools_iree-e2e-matmul-test
+ TARGET_BACKENDS
+ "dylib-llvm-aot"
+ DRIVERS
+ "dylib"
+ OPT_FLAGS
+ "--iree-flow-convert-linalg-matmul-to-mmt4d=M0=8 K0=1 N0=8"
+ "--iree-codegen-vectorize-linalg-mmt4d"
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/test/e2e/regression/generate_e2e_matmul_tests.py b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
new file mode 100644
index 0000000..26aa509
--- /dev/null
+++ b/iree/test/e2e/regression/generate_e2e_matmul_tests.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""iree_generated_trace_runner_test generator for end-to-end matrix multiplication.
+"""
+
+import argparse
+import random
+import os
+import yaml
+import re
+
+
+# Returns lists of shapes as (M, K, N) tuples.
+# For example (M, K, 1) is a matrix*vector product, and (M, 1, N) is an outer
+# product.
+def get_test_shapes():
+ return {
+ "small": [ # Small sizes, square matrices
+ (x, x, x) for x in range(1, 40)
+ ] + [
+ # Small sizes, slightly rectangular matrices
+ (2, 3, 4),
+ (8, 7, 6),
+ (15, 16, 17),
+ (14, 19, 23),
+ (31, 33, 32),
+ (25, 41, 35),
+ # Small sizes, involving vectors (i.e. most rectangular cases)
+ (10, 1, 1),
+ (1, 10, 1),
+ (1, 1, 10),
+ (1, 10, 10),
+ (10, 1, 10),
+ (10, 10, 1),
+ # Small sizes, involving other very small dimensions just above 1
+ (13, 14, 2),
+ (3, 17, 12),
+ (21, 4, 18),
+ # Medium sizes, square matrices
+ (100, 100, 100),
+ # Medium sizes, slightly rectangular matrices
+ (101, 102, 103),
+ # Medium sizes, involving vectors (i.e. most rectangular cases)
+ (10000, 1, 1),
+ (1, 10000, 1),
+ (1, 1, 10000),
+ (1, 1000, 1000),
+ (1000, 1, 1000),
+ (1000, 1000, 1),
+ # Medium sizes, involving other very small dimensions just above 1
+ (1300, 1300, 2),
+ (1300, 1300, 3),
+ (1300, 1300, 4),
+ ],
+ "large": [
+ # Large sizes, powers of two
+ (256, 256, 512),
+ (512, 512, 128),
+ (1024, 512, 512),
+ (512, 1024, 512),
+ # Large sizes, powers of two minus one
+ (127, 63, 511),
+ # Large sizes, powers of two plus one
+ (129, 65, 513),
+ # Large sizes, misc.
+ (200, 300, 400),
+ (123, 456, 789),
+ (500, 500, 50),
+ # Be conservative in adding larger shapes. They can result in
+ # high latency tests. If you have to, consider splitting them
+ # out in a way that constrains the latency impact, e.g. by
+ # running on fewer backends/drivers or with fewer generators
+ # (see get_test_generators).
+ ]
+ }
+
+
+# Returns lists of 'generators', which are tuples of the form
+# (lhs_generator, rhs_generator, acc_generator, dynamicity)
+# The first 3 entries specify how to generate test input data.
+# The dynamicity entry chooses between static, dynamic or mixed shapes.
+#
+# TODO (Issue #7431): turn into enum and dataclass.
+def get_test_generators():
+ return {
+ "small": [
+ # Generators using simple matrices for ease of numerical debugging.
+ # They don't add significant test coverage (all bugs are hit by
+ # tests using random matrices anyway). They are only here to make
+ # the bulk of our debugging easier.
+ ("identity", "identity", "zero", "dynamic"),
+ ("random", "identity", "zero", "dynamic"),
+ ("identity", "random", "zero", "dynamic"),
+ ("identity", "identity", "random", "dynamic"),
+ # Generators using general random matrices
+ ("random", "random", "random", "dynamic"),
+ ("random", "random", "random", "static"),
+ # TODO: enable 'mixed' testcases. For now they cause iree-opt
+ # errors.
+ #("random", "random", "random", "mixed"),
+ ],
+ "large": [
+ # Fewer generators are used for large shapes, to limit the
+ # latency impact. Most bugs are going to be caught on small
+ # shapes anyway.
+ ("random", "random", "random", "dynamic"),
+ ("random", "random", "random", "static"),
+ ]
+ }
+
+
+# Generates a name for a test function in the generated MLIR code.
+def function_name(lhs_rhs_type, accum_type, shape, gen):
+ return f"{lhs_rhs_type}_{gen[3]}_{gen[0]}_{shape[0]}x{shape[1]}_times_{gen[1]}_{shape[1]}x{shape[2]}_plus_{gen[2]}_{accum_type}"
+
+
+# Intentionally fixed seed! We want full reproducibility here, both across runs
+# and across machines.
+# Intentionally not shared with pseudorandom_generator_seed to limit the ways
+# in which shuffling testcases changes which random values are generated.
+local_pseudorandom_state = 1
+
+
+# Generates a compile-time MLIR size value, i.e. either a fixed positive integer
+# or a '?' depending on dynamicity.
+def static_size(x, dynamicity):
+ if dynamicity == "dynamic":
+ return "?"
+ elif dynamicity == "static":
+ return x
+ elif dynamicity == "mixed":
+ global local_pseudorandom_state
+ # Same as C++ std::minstd_rand.
+ # Using a local pseudorandom generator implementation ensures that it's
+ # completely reproducible, across runs and across machines.
+ local_pseudorandom_state = (local_pseudorandom_state * 48271) % 2147483647
+ return x if local_pseudorandom_state > 1073741824 else "?"
+ else:
+ raise ValueError(dynamicity)
+
+
+# Generates a test function in the generated MLIR code.
+# The generated function will take the same arguments as linalg.matmul and
+# will just call linalg.matmul with them, returning its result.
+def generate_function(func_name, lhs_rhs_type, accum_type, shape, gen):
+ (m, k, n) = shape
+ lhs_m = static_size(m, gen[3])
+ lhs_k = static_size(k, gen[3])
+ rhs_k = static_size(k, gen[3])
+ rhs_n = static_size(n, gen[3])
+ acc_m = static_size(m, gen[3])
+ acc_n = static_size(n, gen[3])
+ lhs_tensor_type = f"tensor<{lhs_m}x{lhs_k}x{lhs_rhs_type}>"
+ rhs_tensor_type = f"tensor<{rhs_k}x{rhs_n}x{lhs_rhs_type}>"
+ acc_tensor_type = f"tensor<{acc_m}x{acc_n}x{accum_type}>"
+ return (
+ f"func @{func_name}(%lhs: {lhs_tensor_type}, %rhs: {rhs_tensor_type}, %acc: {acc_tensor_type}) -> {acc_tensor_type} {{\n"
+ f" %result = linalg.matmul ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
+ f" return %result: {acc_tensor_type}\n"
+ f"}}\n")
+
+
+# Intentionally fixed seed! We want full reproducibility here, both across runs
+# and across machines.
+# Intentionally not shared with local_pseudorandom_state to limit the ways
+# in which shuffling testcases changes which random values are generated.
+pseudorandom_generator_seed = 1
+
+
+# Generates a contents_generator tag to use in the output trace.
+def contents_generator_tag(generator):
+ if generator == "zero":
+ return ""
+ elif generator == "identity":
+ return "!tag:iree:identity_matrix"
+ elif generator == "random":
+ global pseudorandom_generator_seed
+ pseudorandom_generator_seed = pseudorandom_generator_seed + 1
+ return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}"
+ else:
+ raise ValueError(generator)
+
+
+# Generate a matrix function argument in the output trace, as a dictionary
+# to be passed to yaml.dump.
+def generate_trace_matrix_arg(matrix_shape, element_type, generator):
+ result = {
+ "type": "hal.buffer_view",
+ "shape": matrix_shape,
+ "element_type": element_type,
+ }
+ generator_tag = contents_generator_tag(generator)
+ if generator_tag:
+ result["contents_generator"] = generator_tag
+ return result
+
+
+# Generates the output trace for a testcase i.e. a single test function call,
+# as a dictionary to be passed to yaml.dump.
+def generate_trace(func_name, lhs_rhs_type, acc_type, shape, gen):
+ (m, k, n) = shape
+ lhs_arg = generate_trace_matrix_arg([m, k], lhs_rhs_type, gen[0])
+ rhs_arg = generate_trace_matrix_arg([k, n], lhs_rhs_type, gen[1])
+ acc_arg = generate_trace_matrix_arg([m, n], acc_type, gen[2])
+ result_arg = generate_trace_matrix_arg([m, n], acc_type, "zero")
+ return {
+ "type": "call",
+ "function": "module." + func_name,
+ "args": [
+ lhs_arg,
+ rhs_arg,
+ acc_arg,
+ ],
+ "results": [result_arg,],
+ }
+
+
+# Generates all output files' contents as strings.
+def generate(args):
+ functions = {}
+ traces = []
+ lhs_rhs_type = args.lhs_rhs_type
+ accum_type = 'i32' if lhs_rhs_type == 'i8' else lhs_rhs_type
+ for shape in get_test_shapes()[args.shapes]:
+ for gen in get_test_generators()[args.shapes]:
+ func_name = function_name(lhs_rhs_type, accum_type, shape, gen)
+ # Different testcases may differ only by runtime parameters but
+ # share the same code. For example, dynamic-shapes testcases
+ # share the same code involing tensor<?x?xf32> even though the runtime
+ # value in the trace are different. That's why we call
+ # generate_function conditionally, and generate_trace unconditionally.
+ if func_name not in functions:
+ functions[func_name] = generate_function(func_name, lhs_rhs_type,
+ accum_type, shape, gen)
+ traces.append(
+ generate_trace(func_name, lhs_rhs_type, accum_type, shape, gen))
+ return (functions, traces)
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="Generator of e2e matmul tests")
+ parser.add_argument("--output_code",
+ type=str,
+ help="Path of output .mlir file",
+ required=True)
+ parser.add_argument("--output_trace",
+ type=str,
+ help="Path of output .yaml trace file",
+ required=True)
+ parser.add_argument("--lhs_rhs_type",
+ type=str,
+ choices=["i8", "f32"],
+ help="Numeric type of input matrices",
+ required=True)
+ parser.add_argument("--shapes",
+ type=str,
+ choices=["small", "large"],
+ help="Collection of matrix shapes to test",
+ required=True)
+ parser.add_argument(
+ "--module_path",
+ type=str,
+ help=
+ "Module path (typically .vmfb) to be referenced in the output trace. Should match the output path of the iree-translate command generating the module.",
+ required=True)
+
+ return parser.parse_args()
+
+
+def write_code_file(functions, filename):
+ with open(filename, "w") as file:
+ for funcname in functions:
+ file.write(functions[funcname] + "\n")
+
+
+def write_trace_file(traces, filename, module_path):
+ yaml_documents = [
+ {
+ "type": "context_load",
+ },
+ {
+ "type": "module_load",
+ "module": {
+ "name": "hal",
+ "type": "builtin",
+ }
+ },
+ {
+ "type": "module_load",
+ "module": {
+ "name": "module",
+ "type": "bytecode",
+ "path": os.path.relpath(module_path, os.path.dirname(filename))
+ }
+ },
+ ]
+
+ for trace in traces:
+ yaml_documents.append(trace)
+
+ dumped_yaml = yaml.dump_all(yaml_documents, sort_keys=False)
+
+ processed_yaml = re.sub(r"'(![^']*)'", "\\1", dumped_yaml)
+
+ with open(filename, "w") as file:
+ file.write(processed_yaml)
+
+
+def main(args):
+ (functions, traces) = generate(args)
+ write_code_file(functions, args.output_code)
+ write_trace_file(traces, args.output_trace, args.module_path)
+
+
+if __name__ == "__main__":
+ main(parse_arguments())
diff --git a/iree/tools/CMakeLists.txt b/iree/tools/CMakeLists.txt
index fc29931..6050277 100644
--- a/iree/tools/CMakeLists.txt
+++ b/iree/tools/CMakeLists.txt
@@ -168,6 +168,25 @@
yaml
)
+# Test runner binary built from iree-e2e-matmul-test.c. It is the program
+# referenced as TRACE_RUNNER (iree_tools_iree-e2e-matmul-test) by the
+# e2e_matmul tests in iree/test/e2e/regression.
+iree_cc_binary(
+ NAME
+ iree-e2e-matmul-test
+ SRCS
+ "iree-e2e-matmul-test.c"
+ DEPS
+ iree::base
+ iree::base::internal::file_path
+ iree::base::internal::flags
+ iree::base::tracing
+ iree::hal
+ iree::hal::drivers
+ iree::modules::hal
+ iree::tools::utils::trace_replay
+ iree::tools::utils::yaml_util
+ iree::vm
+ yaml
+)
+
if(${IREE_BUILD_COMPILER})
iree_cc_binary(
NAME
diff --git a/iree/tools/compilation.bzl b/iree/tools/compilation.bzl
index a75a5a6..ae4fd03 100644
--- a/iree/tools/compilation.bzl
+++ b/iree/tools/compilation.bzl
@@ -15,11 +15,31 @@
flags = ["-iree-mlir-to-vm-bytecode-module"],
translate_tool = "//iree/tools:iree-translate",
embedded_linker_tool = "@llvm-project//lld:lld",
+ opt_tool = "//iree/tools:iree-opt",
+ opt_flags = [],
c_identifier = "",
**kwargs):
+ translate_src = src
+ if opt_flags:
+ translate_src = "%s.opt.mlir" % (name)
+ native.genrule(
+ name = "%s_opt" % (name),
+ srcs = [src],
+ outs = [translate_src],
+ cmd = " ".join([
+ "$(location %s)" % (opt_tool),
+ " ".join([('"%s"' % flag) for flag in opt_flags]),
+ "$(location %s)" % (src),
+ "-o $(location %s)" % (translate_src),
+ ]),
+ tools = [opt_tool],
+ message = "Transforming MLIR source for IREE module %s..." % (name),
+ output_to_bindir = 1,
+ )
+
native.genrule(
name = name,
- srcs = [src],
+ srcs = [translate_src],
outs = [
"%s.vmfb" % (name),
],
@@ -29,7 +49,7 @@
" ".join(flags),
"-iree-llvm-embedded-linker-path=$(location %s)" % (embedded_linker_tool),
"-o $(location %s.vmfb)" % (name),
- "$(location %s)" % (src),
+ "$(location %s)" % (translate_src),
]),
]),
tools = [translate_tool, embedded_linker_tool],
diff --git a/iree/tools/iree-e2e-matmul-test.c b/iree/tools/iree-e2e-matmul-test.c
new file mode 100644
index 0000000..a48fb8d
--- /dev/null
+++ b/iree/tools/iree-e2e-matmul-test.c
@@ -0,0 +1,713 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include "iree/base/api.h"
+#include "iree/base/internal/file_path.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/target_platform.h"
+#include "iree/hal/api.h"
+#include "iree/hal/buffer_view.h"
+#include "iree/hal/drivers/init.h"
+#include "iree/modules/hal/module.h"
+#include "iree/tools/utils/trace_replay.h"
+#include "iree/tools/utils/yaml_util.h"
+#include "iree/vm/api.h"
+
+IREE_FLAG(bool, trace_execution, false, "Traces VM execution to stderr.");
+
+IREE_FLAG(string, driver, "vmvx", "Backend driver to use.");
+
+// We rely on environment variables for some internal knobs because they are
+// easier to propagate through ctest to this program than command-line
+// arguments.
+//
+// Returns the value of the environment variable |env_var|, or NULL if it is
+// not set. getenv() is part of the C standard library and is provided by the
+// Windows C runtime as well, so no platform-specific dispatch is needed
+// (there is no `_getenv` function in the Windows CRT).
+const char* portable_getenv(const char* env_var) {
+  return getenv(env_var);
+}
+
+// Fetches element |i| of |list| and returns it in |out_value| as a
+// buffer_view. Fails with INVALID_ARGUMENT if the element is not a ref.
+static iree_status_t iree_get_buffer_view_list_item(
+    iree_vm_list_t* list, iree_host_size_t i,
+    iree_hal_buffer_view_t** out_value) {
+  iree_vm_variant_t item = iree_vm_variant_empty();
+  IREE_RETURN_IF_ERROR(iree_vm_list_get_variant(list, i, &item));
+  if (iree_vm_variant_is_ref(item)) {
+    return iree_hal_buffer_view_check_deref(item.ref, out_value);
+  }
+  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                          "expected list item %zu to be a ref", i);
+}
+
+// Writes the two dimensions of |buffer_view| to dims[0] and dims[1].
+// Fails with INVALID_ARGUMENT unless the shape is 2D and both dimensions are
+// positive.
+static iree_status_t get_matrix_buffer_view_shape(
+    iree_hal_buffer_view_t* buffer_view, iree_hal_dim_t* dims) {
+  const iree_host_size_t rank = iree_hal_buffer_view_shape_rank(buffer_view);
+  if (rank != 2) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "expected a matrix (2D tensor) shape, got a %zu-dimensional shape",
+        rank);
+  }
+  for (int axis = 0; axis < 2; ++axis) {
+    dims[axis] = iree_hal_buffer_view_shape_dim(buffer_view, axis);
+  }
+  if (dims[0] <= 0 || dims[1] <= 0) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "expected matrix dims to be positive, got %dx%d",
+                            dims[0], dims[1]);
+  }
+  return iree_ok_status();
+}
+
+// Helper to get a pointer to dense row-major data in a buffer_view.
+//
+// Fails with INVALID_ARGUMENT if the encoding of |buffer_view| is not
+// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR. Otherwise maps the whole underlying
+// buffer and stores the address of the mapped contents in |*data|.
+//
+// NOTE(review): the mapping created here is a local that is never passed to
+// an unmap call, so the mapping stays open after the function returns and the
+// returned pointer aliases it. The mapping is requested with
+// IREE_HAL_MEMORY_ACCESS_READ, yet reference_matmul in this file writes its
+// result through the returned pointer; confirm this is acceptable for the
+// allocators in use.
+static iree_status_t get_buffer_view_dense_row_major_data(
+ iree_hal_buffer_view_t* buffer_view, void** data) {
+ if (iree_hal_buffer_view_encoding_type(buffer_view) !=
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "buffer_view is not dense row major");
+ }
+ iree_hal_buffer_mapping_t mapping;
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+ iree_hal_buffer_view_buffer(buffer_view), IREE_HAL_MEMORY_ACCESS_READ, 0,
+ IREE_WHOLE_BUFFER, &mapping));
+ *data = mapping.contents.data;
+ return iree_ok_status();
+}
+
+// Helper for iree_check_matmul and reference_matmul.
+// Reads the shapes of the four matrices and checks that they are compatible
+// with result = lhs * rhs + acc, i.e. lhs is MxK, rhs is KxN, and both acc
+// and result are MxN. M, K and N are stored to |m_size|, |k_size| and
+// |n_size| before the compatibility check is performed.
+static iree_status_t get_matmul_sizes(
+    iree_hal_buffer_view_t* lhs, iree_hal_buffer_view_t* rhs,
+    iree_hal_buffer_view_t* acc, iree_hal_buffer_view_t* result,
+    iree_hal_dim_t* m_size, iree_hal_dim_t* k_size, iree_hal_dim_t* n_size) {
+  iree_hal_dim_t lhs_dims[2];
+  iree_hal_dim_t rhs_dims[2];
+  iree_hal_dim_t acc_dims[2];
+  iree_hal_dim_t result_dims[2];
+  IREE_RETURN_IF_ERROR(get_matrix_buffer_view_shape(lhs, lhs_dims));
+  IREE_RETURN_IF_ERROR(get_matrix_buffer_view_shape(rhs, rhs_dims));
+  IREE_RETURN_IF_ERROR(get_matrix_buffer_view_shape(acc, acc_dims));
+  IREE_RETURN_IF_ERROR(get_matrix_buffer_view_shape(result, result_dims));
+
+  // M and K come from lhs and N from rhs; every remaining dimension has to
+  // agree with them.
+  const iree_hal_dim_t m = lhs_dims[0];
+  const iree_hal_dim_t k = lhs_dims[1];
+  const iree_hal_dim_t n = rhs_dims[1];
+  *m_size = m;
+  *k_size = k;
+  *n_size = n;
+
+  const bool rhs_ok = rhs_dims[0] == k;
+  const bool acc_ok = acc_dims[0] == m && acc_dims[1] == n;
+  const bool result_ok = result_dims[0] == m && result_dims[1] == n;
+  if (rhs_ok && acc_ok && result_ok) {
+    return iree_ok_status();
+  }
+  return iree_make_status(
+      IREE_STATUS_INVALID_ARGUMENT,
+      "mismatched matrix shapes in matmul: %dx%d * %dx%d + %dx%d -> %dx%d",
+      lhs_dims[0], lhs_dims[1], rhs_dims[0], rhs_dims[1], acc_dims[0],
+      acc_dims[1], result_dims[0], result_dims[1]);
+}
+
+// Helper for reference_matmul_element. f32 case.
+// Computes result[m, n] = acc[m, n] + sum over k of lhs[m, k] * rhs[k, n],
+// with all matrices indexed as row-major arrays and k increasing.
+static void reference_matmul_element_f32(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+    float* lhs_data, float* rhs_data, float* acc_data, float* result_data,
+    iree_hal_dim_t m, iree_hal_dim_t n) {
+  const float* lhs_row = lhs_data + m * k_size;
+  const float* rhs_column = rhs_data + n;
+  float sum = acc_data[m * n_size + n];
+  for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    sum += lhs_row[k] * rhs_column[k * n_size];
+  }
+  result_data[m * n_size + n] = sum;
+}
+
+// Helper for reference_matmul_element. i8*i8->i32 case.
+// Computes result[m, n] = acc[m, n] + sum over k of lhs[m, k] * rhs[k, n],
+// widening each int8 operand to int32 before multiplying.
+static void reference_matmul_element_i8_i8_i32(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+    int8_t* lhs_data, int8_t* rhs_data, int32_t* acc_data, int32_t* result_data,
+    iree_hal_dim_t m, iree_hal_dim_t n) {
+  const int8_t* lhs_row = lhs_data + m * k_size;
+  const int8_t* rhs_column = rhs_data + n;
+  int32_t sum = acc_data[m * n_size + n];
+  for (iree_hal_dim_t k = 0; k < k_size; ++k) {
+    const int32_t lhs_wide = (int32_t)lhs_row[k];
+    const int32_t rhs_wide = (int32_t)rhs_column[k * n_size];
+    sum += lhs_wide * rhs_wide;
+  }
+  result_data[m * n_size + n] = sum;
+}
+
+// Helper for reference_matmul.
+// Computes one element (row |m|, column |n|) of the result matrix by
+// dispatching on the element types. Supported combinations are f32*f32->f32
+// and i8*i8->i32; anything else aborts the program.
+static void reference_matmul_element(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
+    iree_hal_element_type_t acc_type, void* lhs_data, void* rhs_data,
+    void* acc_data, void* result_data, iree_hal_dim_t m, iree_hal_dim_t n) {
+  const bool is_f32 = lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
+                      rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
+                      acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32;
+  if (is_f32) {
+    reference_matmul_element_f32(m_size, k_size, n_size, lhs_type, rhs_type,
+                                 (float*)lhs_data, (float*)rhs_data,
+                                 (float*)acc_data, (float*)result_data, m, n);
+    return;
+  }
+  const bool is_i8_i8_i32 = lhs_type == IREE_HAL_ELEMENT_TYPE_SINT_8 &&
+                            rhs_type == IREE_HAL_ELEMENT_TYPE_SINT_8 &&
+                            acc_type == IREE_HAL_ELEMENT_TYPE_SINT_32;
+  if (is_i8_i8_i32) {
+    reference_matmul_element_i8_i8_i32(
+        m_size, k_size, n_size, lhs_type, rhs_type, (int8_t*)lhs_data,
+        (int8_t*)rhs_data, (int32_t*)acc_data, (int32_t*)result_data, m, n);
+    return;
+  }
+  iree_status_abort(
+      iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                       "unhandled combination of element types in matmul"));
+}
+
+// Reference matmul implementation, used to compare matmul results against.
+// Takes lhs, rhs and acc from items 0, 1 and 2 of |input_list| and fills
+// |result| with lhs * rhs + acc, one element at a time.
+static iree_status_t reference_matmul(iree_vm_list_t* input_list,
+                                      iree_hal_buffer_view_t* result) {
+  iree_hal_buffer_view_t* lhs;
+  IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(input_list, 0, &lhs));
+  iree_hal_buffer_view_t* rhs;
+  IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(input_list, 1, &rhs));
+  iree_hal_buffer_view_t* acc;
+  IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(input_list, 2, &acc));
+
+  iree_hal_dim_t m_size, k_size, n_size;
+  IREE_RETURN_IF_ERROR(
+      get_matmul_sizes(lhs, rhs, acc, result, &m_size, &k_size, &n_size));
+
+  void* lhs_data;
+  IREE_RETURN_IF_ERROR(get_buffer_view_dense_row_major_data(lhs, &lhs_data));
+  void* rhs_data;
+  IREE_RETURN_IF_ERROR(get_buffer_view_dense_row_major_data(rhs, &rhs_data));
+  void* acc_data;
+  IREE_RETURN_IF_ERROR(get_buffer_view_dense_row_major_data(acc, &acc_data));
+  void* result_data;
+  IREE_RETURN_IF_ERROR(
+      get_buffer_view_dense_row_major_data(result, &result_data));
+
+  const iree_hal_element_type_t lhs_type =
+      iree_hal_buffer_view_element_type(lhs);
+  const iree_hal_element_type_t rhs_type =
+      iree_hal_buffer_view_element_type(rhs);
+  const iree_hal_element_type_t acc_type =
+      iree_hal_buffer_view_element_type(acc);
+  for (iree_hal_dim_t row = 0; row < m_size; ++row) {
+    for (iree_hal_dim_t col = 0; col < n_size; ++col) {
+      reference_matmul_element(m_size, k_size, n_size, lhs_type, rhs_type,
+                               acc_type, lhs_data, rhs_data, acc_data,
+                               result_data, row, col);
+    }
+  }
+  return iree_ok_status();
+}
+
+// Reads the element at row |m|, column |n| of a row-major matrix with
+// |n_size| columns and returns it as a iree_vm_value_t. Handles f32, i32 and
+// i8 element types; any other type aborts the program.
+static iree_vm_value_t read_matrix_element(iree_hal_dim_t m_size,
+                                           iree_hal_dim_t n_size,
+                                           iree_hal_element_type_t result_type,
+                                           void* data, iree_hal_dim_t m,
+                                           iree_hal_dim_t n) {
+  (void)m_size;
+  const iree_host_size_t offset = n + m * n_size;
+  if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+    return iree_vm_value_make_f32(((float*)data)[offset]);
+  }
+  if (result_type == IREE_HAL_ELEMENT_TYPE_SINT_32) {
+    return iree_vm_value_make_i32(((int32_t*)data)[offset]);
+  }
+  if (result_type == IREE_HAL_ELEMENT_TYPE_SINT_8) {
+    return iree_vm_value_make_i8(((int8_t*)data)[offset]);
+  }
+  iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                     "unhandled matmul result type"));
+  return iree_vm_value_make_none();
+}
+
+// Selects how many significant digits snprintf_value uses for floating-point
+// values. Integer values are printed the same way for both settings.
+typedef enum precision_e {
+ PRECISION_LOW,
+ PRECISION_HIGH,
+} precision_t;
+
+// Formats |value| into |buf| (of capacity |bufsize|) and returns what
+// snprintf returns. |precision| only affects floating-point values: f32 uses
+// 8 significant digits and f64 uses 16 when PRECISION_HIGH, and both use 4
+// otherwise. Value types that are not handled abort the program.
+static int snprintf_value(char* buf, size_t bufsize, iree_vm_value_t value,
+                          precision_t precision) {
+  const bool high_precision = precision == PRECISION_HIGH;
+  switch (value.type) {
+    case IREE_VM_VALUE_TYPE_I8:
+      return snprintf(buf, bufsize, "%" PRIi8, value.i8);
+    case IREE_VM_VALUE_TYPE_I16:
+      return snprintf(buf, bufsize, "%" PRIi16, value.i16);
+    case IREE_VM_VALUE_TYPE_I32:
+      return snprintf(buf, bufsize, "%" PRIi32, value.i32);
+    case IREE_VM_VALUE_TYPE_I64:
+      return snprintf(buf, bufsize, "%" PRIi64, value.i64);
+    case IREE_VM_VALUE_TYPE_F32:
+      if (high_precision) {
+        return snprintf(buf, bufsize, "%.8g", value.f32);
+      }
+      return snprintf(buf, bufsize, "%.4g", value.f32);
+    case IREE_VM_VALUE_TYPE_F64:
+      if (high_precision) {
+        return snprintf(buf, bufsize, "%.16g", value.f64);
+      }
+      return snprintf(buf, bufsize, "%.4g", value.f64);
+    default:
+      break;
+  }
+  iree_status_abort(
+      iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled value type"));
+  return 0;
+}
+
+// Returns true if |expected| and |actual| agree to tolerable accuracy.
+// i32 values must be equal; f32 values must differ by less than 1e-3 in
+// absolute value. Mismatched or unhandled value types abort the program.
+static bool matmul_result_elements_agree(iree_vm_value_t expected,
+                                         iree_vm_value_t actual) {
+  if (expected.type != actual.type) {
+    iree_status_abort(
+        iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "mismatched types"));
+    return false;
+  }
+  if (expected.type == IREE_VM_VALUE_TYPE_I32) {
+    return actual.i32 == expected.i32;
+  }
+  if (expected.type == IREE_VM_VALUE_TYPE_F32) {
+    // Comparing against a fixed absolute tolerance is naive.
+    //
+    // It is nearly good enough here because only plain matmuls are tested,
+    // not fused with any other op. A matmul is a polynomial expression whose
+    // coefficients are all 0 or 1, so a result is either correct or
+    // completely wrong. That would stop being true with non-trivial
+    // accumulation strategies that limit accumulation depth, which we do not
+    // use. The tested sizes are also not huge and all test data has the same
+    // order of magnitude.
+    //
+    // A better approach would derive the tolerance from the magnitude of the
+    // matrix elements, the accumulation depth (k_size) and the epsilon of
+    // the accumulator type. Floating-point checks should be scale-invariant:
+    // rescaling all inputs by a constant factor should not change whether a
+    // test passes (as long as we don't run out of exponents).
+    const float abs_diff = fabsf(actual.f32 - expected.f32);
+    return abs_diff < 1e-3f;
+  }
+  iree_status_abort(
+      iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled value type"));
+  return false;
+}
+
+// Prints |matrix| to |file|, with |label| as caption.
+// |precision| controls how many decimals are printed for float values.
+// Only rows [row_start, row_end) and columns [col_start, col_end) are
+// printed; all columns are right-aligned to the widest printed element.
+//
+// If |other_matrix| is not NULL, then any matrix entries that disagree
+// between |matrix| and |other_matrix| (according to
+// matmul_result_elements_agree) are highlighted.
+//
+// This function has no way to report errors, so a failure to query the shape
+// or to access the data of either matrix aborts the program instead of
+// continuing with uninitialized values.
+static void print_matrix(FILE* file, const char* label, precision_t precision,
+                         int row_start, int row_end, int col_start, int col_end,
+                         iree_hal_buffer_view_t* matrix,
+                         iree_hal_buffer_view_t* other_matrix) {
+  iree_hal_dim_t dims[2];
+  IREE_CHECK_OK(get_matrix_buffer_view_shape(matrix, dims));
+  int rows = dims[0];
+  int cols = dims[1];
+  iree_hal_element_type_t elem_type = iree_hal_buffer_view_element_type(matrix);
+  void* data = NULL;
+  IREE_CHECK_OK(get_buffer_view_dense_row_major_data(matrix, &data));
+  void* other_data = NULL;
+  if (other_matrix) {
+    IREE_CHECK_OK(
+        get_buffer_view_dense_row_major_data(other_matrix, &other_data));
+  }
+
+  // First pass: find the widest element so that columns can be aligned.
+  int max_elem_width = 0;
+  for (int row = row_start; row < row_end; row++) {
+    for (int col = col_start; col < col_end; col++) {
+      iree_vm_value_t elem =
+          read_matrix_element(rows, cols, elem_type, data, row, col);
+      char buf[64];
+      max_elem_width = iree_max(
+          max_elem_width, snprintf_value(buf, sizeof buf, elem, precision));
+    }
+  }
+
+  fprintf(file,
+          "%s (rows %d..%d out of %d..%d, columns %d..%d out of %d..%d)\n",
+          label, row_start, row_end - 1, 0, rows - 1, col_start, col_end - 1, 0,
+          cols - 1);
+
+  // Second pass: print the elements, flagging the ones that disagree.
+  for (int row = row_start; row < row_end; row++) {
+    for (int col = col_start; col < col_end; col++) {
+      iree_vm_value_t elem =
+          read_matrix_element(rows, cols, elem_type, data, row, col);
+      bool bad_elem = false;
+      if (other_matrix) {
+        iree_vm_value_t other_elem =
+            read_matrix_element(rows, cols, elem_type, other_data, row, col);
+        bad_elem = !matmul_result_elements_agree(elem, other_elem);
+      }
+      char buf[64];
+      snprintf_value(buf, sizeof buf, elem, precision);
+      fprintf(file, "%*s", max_elem_width, buf);
+      if (bad_elem) {
+        fprintf(file, "💩");
+      } else if (col < col_end - 1) {
+        // The marker above is a wide character occupying two columns, so use
+        // two spaces here to keep the columns aligned, per
+        // https://www.unicode.org/reports/tr11/#Recommendations
+        fprintf(file, "  ");
+      }
+    }
+    fprintf(file, "\n");
+  }
+}
+
+// Helper for iree_check_matmul: handler for the failure case.
+//
+// Logs to |file| the position (|row|, |col|) of the first disagreement, the
+// actual and expected values there, and a window of each of the matrices
+// involved around that position. The window extends |context| rows/columns
+// around the failing element; |context| defaults to 8 and can be overridden
+// with the IREE_MATMUL_TEST_SHOW_CONTEXT environment variable.
+//
+// Returns ABORTED once the details have been logged, or INVALID_ARGUMENT if
+// the matrix shapes are inconsistent or the environment variable cannot be
+// parsed as an integer.
+static iree_status_t iree_check_matmul_failure(
+    FILE* file, iree_vm_value_t actual_value, iree_vm_value_t expected_value,
+    iree_hal_dim_t row, iree_hal_dim_t col, iree_hal_buffer_view_t* lhs,
+    iree_hal_buffer_view_t* rhs, iree_hal_buffer_view_t* acc,
+    iree_hal_buffer_view_t* actual_result,
+    iree_hal_buffer_view_t* expected_result) {
+  fprintf(file,
+          "\n\nerror: the actual and expected result matrices disagree "
+          "at row %d, column %d.\n\n",
+          row, col);
+  char actual_value_buf[32];
+  char expected_value_buf[32];
+  snprintf_value(actual_value_buf, sizeof actual_value_buf, actual_value,
+                 PRECISION_HIGH);
+  snprintf_value(expected_value_buf, sizeof expected_value_buf, expected_value,
+                 PRECISION_HIGH);
+  fprintf(file, "actual value: %s\n", actual_value_buf);
+  fprintf(file, "expected value: %s\n", expected_value_buf);
+
+  iree_hal_dim_t m_size, k_size, n_size;
+  IREE_RETURN_IF_ERROR(get_matmul_sizes(lhs, rhs, acc, actual_result, &m_size,
+                                        &k_size, &n_size));
+  iree_hal_dim_t context = 8;
+  // Look the variable up once, through the portable helper, and test that
+  // same result rather than querying the environment a second time.
+  const char* context_env = portable_getenv("IREE_MATMUL_TEST_SHOW_CONTEXT");
+  if (context_env) {
+    // Parse into a plain int: that is the type "%d" requires, independently
+    // of how iree_hal_dim_t is defined.
+    int parsed_context = 0;
+    if (1 != sscanf(context_env, "%d", &parsed_context)) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "Failed to parse IREE_MATMUL_TEST_SHOW_CONTEXT "
+                              "as \"%%d\". Got \"%s\"",
+                              context_env);
+    }
+    context = parsed_context;
+  }
+  int m_start = iree_max(0, row - context);
+  int m_end = iree_min(m_size, row + context);
+  int n_start = iree_max(0, col - context);
+  int n_end = iree_min(n_size, col + context);
+  // We have a lot more freedom to pick k_start, k_end, since these parameters
+  // only affect which regions of the input lhs and rhs matrices are printed.
+  // If we were only testing random lhs and rhs, we would just pick
+  // k_start = 0 and any reasonable k_end value. Since we are often using
+  // identity matrices for lhs and rhs, and we expect the majority of
+  // test failures to occur with such identity matrices, we try to pick
+  // k_start and k_end so that nontrivial regions of identity matrices will be
+  // printed. That means that we try to have [k_start, k_end) intervals
+  // overlap [m_start, m_end) and [n_start, n_end).
+  int k_start = iree_max(0, iree_min(m_start, n_start));
+  int k_end = iree_min(k_size, iree_max(m_end, n_end));
+  // [k_start, k_end) could be arbitrarily long at this point. Constrain it a
+  // bit to avoid huge output.
+  k_end = iree_min(k_end, k_start + 4 * context);
+
+  fprintf(file, "\n");
+  print_matrix(file, "left-hand side", PRECISION_LOW, m_start, m_end, k_start,
+               k_end, lhs, NULL);
+  fprintf(file, "\n");
+  print_matrix(file, "right-hand side", PRECISION_LOW, k_start, k_end, n_start,
+               n_end, rhs, NULL);
+  fprintf(file, "\n");
+  print_matrix(file, "input accumulator", PRECISION_LOW, m_start, m_end,
+               n_start, n_end, acc, NULL);
+  fprintf(file, "\n");
+  print_matrix(file, "expected result", PRECISION_LOW, m_start, m_end, n_start,
+               n_end, expected_result, actual_result);
+  fprintf(file, "\n");
+  print_matrix(file, "actual result", PRECISION_LOW, m_start, m_end, n_start,
+               n_end, actual_result, expected_result);
+  fprintf(file, "\n");
+  return iree_make_status(IREE_STATUS_ABORTED,
+                          "matmul test failure, details logged above");
+}
+
+// Helper for iree_check_matmul: the comparison itself, run once the
+// {m,k,n}_size values have been obtained and validated.
+// Walks the result matrices in row-major order and reports the first element
+// where |actual_result| and |expected_result| disagree through
+// iree_check_matmul_failure. Returns OK if every element agrees.
+static iree_status_t check_matmul_impl(
+    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
+    iree_hal_buffer_view_t* lhs, iree_hal_buffer_view_t* rhs,
+    iree_hal_buffer_view_t* acc, iree_hal_buffer_view_t* actual_result,
+    iree_hal_buffer_view_t* expected_result) {
+  void* actual_data;
+  IREE_RETURN_IF_ERROR(
+      get_buffer_view_dense_row_major_data(actual_result, &actual_data));
+  void* expected_data;
+  IREE_RETURN_IF_ERROR(
+      get_buffer_view_dense_row_major_data(expected_result, &expected_data));
+  const iree_hal_element_type_t elem_type =
+      iree_hal_buffer_view_element_type(actual_result);
+  for (iree_hal_dim_t row = 0; row < m_size; ++row) {
+    for (iree_hal_dim_t col = 0; col < n_size; ++col) {
+      const iree_vm_value_t actual = read_matrix_element(
+          m_size, n_size, elem_type, actual_data, row, col);
+      const iree_vm_value_t expected = read_matrix_element(
+          m_size, n_size, elem_type, expected_data, row, col);
+      // Same argument order as the original call site.
+      if (matmul_result_elements_agree(actual, expected)) continue;
+      return iree_check_matmul_failure(stderr, actual, expected, row, col, lhs,
+                                       rhs, acc, actual_result,
+                                       expected_result);
+    }
+  }
+  return iree_ok_status();
+}
+
+// Checks |actual_result| against |expected_result| for the matmul whose
+// inputs (lhs, rhs, acc) are items 0, 1 and 2 of |input_list|.
+// The shapes are validated first; the element-wise comparison and the
+// failure report are delegated to check_matmul_impl.
+static iree_status_t iree_check_matmul(
+    iree_vm_list_t* input_list, iree_hal_buffer_view_t* actual_result,
+    iree_hal_buffer_view_t* expected_result) {
+  iree_hal_buffer_view_t* lhs;
+  IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(input_list, 0, &lhs));
+  iree_hal_buffer_view_t* rhs;
+  IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(input_list, 1, &rhs));
+  iree_hal_buffer_view_t* acc;
+  IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(input_list, 2, &acc));
+
+  iree_hal_dim_t m, k, n;
+  IREE_RETURN_IF_ERROR(
+      get_matmul_sizes(lhs, rhs, acc, actual_result, &m, &k, &n));
+
+  return check_matmul_impl(m, k, n, lhs, rhs, acc, actual_result,
+                           expected_result);
+}
+
+// Allocates |dst| from |hal_allocator| with the shape, element type and
+// encoding type of |src|. The contents of |src| are not copied.
+static iree_status_t allocate_buffer_like(iree_hal_allocator_t* hal_allocator,
+                                          iree_hal_buffer_view_t* src,
+                                          iree_hal_buffer_view_t** dst) {
+  const iree_hal_memory_type_t memory_type =
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+  return iree_hal_buffer_view_allocate_buffer(
+      hal_allocator, iree_hal_buffer_view_shape_dims(src),
+      iree_hal_buffer_view_shape_rank(src),
+      iree_hal_buffer_view_element_type(src),
+      iree_hal_buffer_view_encoding_type(src), memory_type,
+      IREE_HAL_BUFFER_USAGE_ALL, dst);
+}
+
+// Performs a deep copy of |src| into |dst|. Takes care of allocating |dst|.
+//
+// The contents of |src| are mapped for reading only for the duration of the
+// copy: the mapping is released before returning, whether or not the clone
+// succeeded. |dst| gets the shape, element type and encoding type of |src|.
+static iree_status_t copy_buffer(iree_hal_allocator_t* hal_allocator,
+                                 iree_hal_buffer_view_t* src,
+                                 iree_hal_buffer_view_t** dst) {
+  iree_hal_buffer_mapping_t src_mapping;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+      iree_hal_buffer_view_buffer(src), IREE_HAL_MEMORY_ACCESS_READ, 0,
+      IREE_WHOLE_BUFFER, &src_mapping));
+  iree_const_byte_span_t src_span;
+  src_span.data = src_mapping.contents.data;
+  src_span.data_length = src_mapping.contents.data_length;
+  iree_status_t status = iree_hal_buffer_view_clone_heap_buffer(
+      hal_allocator, iree_hal_buffer_view_shape_dims(src),
+      iree_hal_buffer_view_shape_rank(src),
+      iree_hal_buffer_view_element_type(src),
+      iree_hal_buffer_view_encoding_type(src),
+      IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+      IREE_HAL_BUFFER_USAGE_ALL, src_span, dst);
+  // |dst| holds its own copy of the data, so the source mapping is no longer
+  // needed once the clone call has returned.
+  iree_hal_buffer_unmap_range(&src_mapping);
+  return status;
+}
+
+// Creates |dst| as a list with the same element type and size as |src| and
+// fills it with deep copies (see copy_buffer) of the buffer_views in |src|.
+static iree_status_t copy_list_of_buffer_views(
+    iree_hal_allocator_t* hal_allocator, iree_vm_list_t* src,
+    iree_vm_list_t** dst) {
+  iree_vm_type_def_t element_type;
+  IREE_RETURN_IF_ERROR(iree_vm_list_element_type(src, &element_type));
+  const iree_host_size_t count = iree_vm_list_size(src);
+  IREE_RETURN_IF_ERROR(iree_vm_list_create(
+      &element_type, count, iree_hal_allocator_host_allocator(hal_allocator),
+      dst));
+  IREE_RETURN_IF_ERROR(iree_vm_list_resize(*dst, count));
+  for (iree_host_size_t index = 0; index < count; ++index) {
+    iree_hal_buffer_view_t* original;
+    IREE_RETURN_IF_ERROR(iree_get_buffer_view_list_item(src, index, &original));
+    iree_hal_buffer_view_t* duplicate;
+    IREE_RETURN_IF_ERROR(copy_buffer(hal_allocator, original, &duplicate));
+    // Wrap the new buffer_view in a ref and move that ref into the list.
+    iree_vm_ref_t duplicate_ref = {0};
+    IREE_RETURN_IF_ERROR(iree_vm_ref_wrap_assign(
+        duplicate, iree_hal_buffer_view_type_id(), &duplicate_ref));
+    IREE_RETURN_IF_ERROR(
+        iree_vm_list_set_ref_move(*dst, index, &duplicate_ref));
+  }
+  return iree_ok_status();
+}
+
+// Special handler for function calls in a e2e matmul test trace.
+// Assumes that all calls are to functions that take 3 inputs (lhs, rhs, acc)
+// and return the result of a matmul (lhs*rhs+acc).
+//
+// The function is invoked on a deep copy of the inputs, the reference matmul
+// is run on the original inputs, and the two results are compared. Any
+// failure, including a mismatch between the two results, is returned as a
+// status to the caller. Everything allocated here is released on every path.
+static iree_status_t replay_event_call(iree_trace_replay_t* replay,
+                                       yaml_document_t* document,
+                                       yaml_node_t* event_node) {
+  yaml_node_t* function_node = NULL;
+  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
+      document, event_node, iree_make_cstring_view("function"),
+      &function_node));
+  iree_string_view_t function_name = iree_yaml_node_as_string(function_node);
+  fprintf(stderr, "--- CALL[%.*s] ---\n", (int)function_name.size,
+          function_name.data);
+
+  // Resources owned by this function; released at the end if non-NULL.
+  iree_hal_allocator_t* heap_hal_allocator = NULL;
+  iree_vm_list_t* input_list = NULL;
+  iree_vm_list_t* copy_of_input_list = NULL;
+  iree_vm_list_t* output_list = NULL;
+  iree_hal_buffer_view_t* expected_result = NULL;
+
+  iree_status_t status = iree_hal_allocator_create_heap(
+      iree_make_cstring_view("e2e-matmul-test-heap-allocator"),
+      replay->host_allocator, &heap_hal_allocator);
+
+  iree_vm_function_t function = {0};
+  if (iree_status_is_ok(status)) {
+    status = iree_trace_replay_event_call_prepare(replay, document, event_node,
+                                                  &function, &input_list);
+  }
+
+  // Perform a deep copy of the input list to pass to the test function.
+  // Rationale: the test function may mutate some of the input list elements,
+  // e.g. input-output parameters. For instance, the accumulator input of a
+  // linalg.matmul. We need to preserve the original test inputs to run the
+  // reference matmul on and to use in test failure logs.
+  if (iree_status_is_ok(status)) {
+    status = copy_list_of_buffer_views(heap_hal_allocator, input_list,
+                                       &copy_of_input_list);
+  }
+
+  // Invoke the function to produce the actual result.
+  if (iree_status_is_ok(status)) {
+    status = iree_vm_list_create(/*element_type=*/NULL,
+                                 /*initial_capacity=*/8,
+                                 replay->host_allocator, &output_list);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_vm_invoke(replay->context, function,
+                            IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL,
+                            copy_of_input_list, output_list,
+                            replay->host_allocator);
+  }
+
+  // Get the actual_result buffer from the output_list. It is not owned here:
+  // releasing output_list releases it.
+  iree_hal_buffer_view_t* actual_result = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_get_buffer_view_list_item(output_list, 0, &actual_result);
+  }
+
+  // Allocate an expected_result buffer, with same shape as actual_result.
+  if (iree_status_is_ok(status)) {
+    status = allocate_buffer_like(heap_hal_allocator, actual_result,
+                                  &expected_result);
+  }
+
+  // Use the reference matmul implementation to fill expected_result.
+  if (iree_status_is_ok(status)) {
+    status = reference_matmul(input_list, expected_result);
+  }
+
+  // Check that actual_result and expected_result agree.
+  if (iree_status_is_ok(status)) {
+    status = iree_check_matmul(input_list, actual_result, expected_result);
+  }
+
+  // Clean up. Anything still NULL was never created.
+  if (input_list) iree_vm_list_release(input_list);
+  if (copy_of_input_list) iree_vm_list_release(copy_of_input_list);
+  if (output_list) iree_vm_list_release(output_list);
+  if (expected_result) iree_hal_buffer_view_release(expected_result);
+  if (heap_hal_allocator) iree_hal_allocator_release(heap_hal_allocator);
+
+  return status;
+}
+
+// Replays a single trace event. Events whose "type" is "call" are handled by
+// replay_event_call; every other event is forwarded to
+// iree_trace_replay_event. |event_node| must be a mapping node.
+static iree_status_t iree_e2e_matmul_test_trace_replay_event(
+    iree_trace_replay_t* replay, yaml_document_t* document,
+    yaml_node_t* event_node) {
+  if (event_node->type != YAML_MAPPING_NODE) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "(%zu): expected mapping node",
+                            event_node->start_mark.line);
+  }
+  yaml_node_t* type_node = NULL;
+  IREE_RETURN_IF_ERROR(iree_yaml_mapping_find(
+      document, event_node, iree_make_cstring_view("type"), &type_node));
+  const bool is_call =
+      iree_yaml_string_equal(type_node, iree_make_cstring_view("call"));
+  return is_call ? replay_event_call(replay, document, event_node)
+                 : iree_trace_replay_event(replay, document, event_node);
+}
+
+// Replays every yaml document found in |file|, in order, within a fresh
+// replay state. |root_path| is the base for any path lookups required for
+// external files referenced in |file|. Stops at the first document that fails
+// to parse or to replay and returns that failure.
+static iree_status_t run_trace_file(iree_string_view_t root_path, FILE* file,
+                                    iree_vm_instance_t* instance) {
+  iree_vm_context_flags_t context_flags = IREE_VM_CONTEXT_FLAG_NONE;
+  if (FLAG_trace_execution) {
+    context_flags = IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION;
+  }
+  iree_trace_replay_t replay;
+  IREE_RETURN_IF_ERROR(iree_trace_replay_initialize(
+      root_path, instance, context_flags, iree_allocator_system(), &replay));
+  iree_trace_replay_set_hal_driver_override(
+      &replay, iree_make_cstring_view(FLAG_driver));
+
+  yaml_parser_t parser;
+  if (!yaml_parser_initialize(&parser)) {
+    iree_trace_replay_deinitialize(&replay);
+    return iree_make_status(IREE_STATUS_INTERNAL,
+                            "yaml_parser_initialize failed");
+  }
+  yaml_parser_set_input_file(&parser, file);
+
+  iree_status_t status = iree_ok_status();
+  bool reached_end = false;
+  while (!reached_end && iree_status_is_ok(status)) {
+    yaml_document_t document;
+    if (!yaml_parser_load(&parser, &document)) {
+      status = iree_status_from_yaml_parser_error(&parser);
+      break;
+    }
+    // A document without a root node marks the end of the input.
+    yaml_node_t* root_node = yaml_document_get_root_node(&document);
+    if (!root_node) {
+      reached_end = true;
+    } else {
+      status = iree_e2e_matmul_test_trace_replay_event(&replay, &document,
+                                                       root_node);
+    }
+    yaml_document_delete(&document);
+  }
+
+  yaml_parser_delete(&parser);
+  iree_trace_replay_deinitialize(&replay);
+  return status;
+}
+
+// Replays the |file_count| trace files named by |file_paths| one after the
+// other, each through its own run_trace_file call. The directory of each
+// file is used as its root path. Stops at the first file that cannot be
+// opened or fails to replay.
+static iree_status_t run_trace_files(int file_count, char** file_paths,
+                                     iree_vm_instance_t* instance) {
+  for (int index = 0; index < file_count; ++index) {
+    const char* path = file_paths[index];
+    iree_string_view_t path_view = iree_make_cstring_view(path);
+    iree_string_view_t directory = iree_file_path_dirname(path_view);
+    FILE* trace = fopen(path, "rb");
+    if (!trace) {
+      return iree_make_status(iree_status_code_from_errno(errno),
+                              "failed to open trace file '%.*s'",
+                              (int)path_view.size, path_view.data);
+    }
+    iree_status_t status = run_trace_file(directory, trace, instance);
+    fclose(trace);
+    IREE_RETURN_IF_ERROR(status, "replaying trace file '%.*s'",
+                         (int)path_view.size, path_view.data);
+  }
+  return iree_ok_status();
+}
+
+// Entry point. Parses flags, then replays every trace file given as a
+// positional argument. Returns 0 if all traces replayed successfully and 1
+// otherwise (including when no trace file was given).
+int main(int argc, char** argv) {
+  iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
+  if (argc <= 1) {
+    // Terminate the message with a newline so it does not run into whatever
+    // the shell prints next.
+    fprintf(stderr,
+            "no trace files provided; pass one or more yaml file paths\n");
+    return 1;
+  }
+
+  iree_vm_instance_t* instance = NULL;
+  iree_status_t status =
+      iree_vm_instance_create(iree_allocator_system(), &instance);
+  if (iree_status_is_ok(status)) {
+    IREE_CHECK_OK(iree_hal_register_all_available_drivers(
+        iree_hal_driver_registry_default()));
+    // argv[0] is the program name; the trace files follow it.
+    status = run_trace_files(argc - 1, argv + 1, instance);
+  }
+  iree_vm_instance_release(instance);
+  if (!iree_status_is_ok(status)) {
+    iree_status_fprint(stderr, status);
+    iree_status_free(status);
+    return 1;
+  }
+  return 0;
+}
diff --git a/iree/tools/utils/trace_replay.c b/iree/tools/utils/trace_replay.c
index 9e91d62..4f01524 100644
--- a/iree/tools/utils/trace_replay.c
+++ b/iree/tools/utils/trace_replay.c
@@ -499,6 +499,155 @@
return status;
}
+// Stores the integral |value|, converted to the C type that corresponds to
+// |element_type|, at |dst|. Signed and unsigned integers of 8 to 64 bits and
+// 32/64-bit floats are supported; other element types trip an assertion.
+static void iree_trace_replay_write_element(
+    iree_hal_element_type_t element_type, int value, void* dst) {
+  switch (element_type) {
+    case IREE_HAL_ELEMENT_TYPE_SINT_8:
+      *(int8_t*)dst = (int8_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_SINT_16:
+      *(int16_t*)dst = (int16_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_SINT_32:
+      *(int32_t*)dst = (int32_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_SINT_64:
+      *(int64_t*)dst = (int64_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_8:
+      *(uint8_t*)dst = (uint8_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_16:
+      *(uint16_t*)dst = (uint16_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_32:
+      *(uint32_t*)dst = (uint32_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_UINT_64:
+      *(uint64_t*)dst = (uint64_t)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_32:
+      *(float*)dst = (float)value;
+      break;
+    case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
+      *(double*)dst = (double)value;
+      break;
+    default:
+      IREE_ASSERT(false, "unhandled element type");
+      break;
+  }
+}
+
+// Fills |span| with an identity matrix whose elements have the given
+// |element_type|. |inner_size| is the number of columns; the number of rows
+// follows from the length of the span.
+//
+// 'Identity matrix' is used loosely here for any two-dimensional array of
+// the form
+//
+//   array[i, j] = ((i == j) ? 1 : 0)
+//
+// even though the term strictly applies to square shapes only.
+//
+// Such matrices make matrix multiplication testcases easy to debug
+// numerically, since the identity matrix is the neutral element for matrix
+// multiplication.
+static void iree_trace_replay_generate_identity_matrix(
+    iree_hal_element_type_t element_type, iree_byte_span_t span,
+    iree_hal_dim_t inner_size) {
+  const iree_host_size_t element_size =
+      iree_hal_element_byte_count(element_type);
+  uint8_t* const end = span.data + span.data_length;
+  iree_host_size_t row = 0;
+  iree_host_size_t col = 0;
+  for (uint8_t* cursor = span.data; cursor < end; cursor += element_size) {
+    const int value = (col == row) ? 1 : 0;
+    iree_trace_replay_write_element(element_type, value, cursor);
+    // Advance to the next column, wrapping to the start of the next row.
+    if (++col == inner_size) {
+      col = 0;
+      ++row;
+    }
+  }
+}
+
+// Simple deterministic pseudorandom generator.
+// Typically in tests we want reproducible results both across runs and across
+// machines.
+//
+// Advances |*state| by one step and returns one byte derived from the new
+// state. |*state| should be nonzero: a state of 0 stays at 0.
+static uint8_t iree_trace_replay_pseudorandom_uint8(uint32_t* state) {
+  // Same as C++'s std::minstd_rand. The product is computed in 64 bits: in
+  // 32 bits it would wrap around before the modulo is taken, producing a
+  // different sequence than minstd_rand.
+  *state = (uint32_t)(((uint64_t)*state * 48271u) % 2147483647u);
+  // Return the second-least-significant out of the 4 bytes of state. It
+  // avoids some mild issues with the least-significant and most-significant
+  // bytes.
+  return (uint8_t)(*state >> 8);
+}
+
+// Fills |span| with pseudorandom elements of the given |element_type|,
+// drawn from a generator seeded with |seed|. Each element is a generated
+// byte, used as-is (0..255) for unsigned integer element types and shifted
+// down by 128 (-128..127) for all other element types. The values are
+// reproducible both across runs and across machines.
+static void iree_trace_replay_generate_fully_specified_pseudorandom_buffer(
+    iree_hal_element_type_t element_type, iree_byte_span_t span,
+    uint32_t seed) {
+  const int offset = (iree_hal_element_numerical_type(element_type) ==
+                      IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED)
+                         ? 0
+                         : -128;
+  const iree_host_size_t element_size =
+      iree_hal_element_byte_count(element_type);
+  uint8_t* const end = span.data + span.data_length;
+  uint32_t generator_state = seed;
+  for (uint8_t* cursor = span.data; cursor < end; cursor += element_size) {
+    const int byte_value =
+        iree_trace_replay_pseudorandom_uint8(&generator_state);
+    iree_trace_replay_write_element(element_type, byte_value + offset, cursor);
+  }
+}
+
+// Generates the destination |buffer| using the generator specified by
+// |contents_generator_node|.
+//
+// Does nothing and returns OK if |contents_generator_node| is NULL. Otherwise
+// the node must be a scalar whose tag selects the generator:
+//   !tag:iree:identity_matrix: requires a 2D shape (|shape_size| == 2).
+//   !tag:iree:fully_specified_pseudorandom: the scalar value is the seed,
+//     parsed as an unsigned 32-bit integer.
+// Any other tag yields UNIMPLEMENTED. |shape| holds |shape_size| dimensions.
+static iree_status_t iree_trace_replay_generate_hal_buffer(
+    iree_trace_replay_t* replay, yaml_document_t* document,
+    yaml_node_t* contents_generator_node, iree_hal_element_type_t element_type,
+    iree_hal_buffer_t* buffer, iree_hal_dim_t* shape,
+    iree_host_size_t shape_size) {
+  if (!contents_generator_node) {
+    return iree_ok_status();
+  } else if (contents_generator_node->type != YAML_SCALAR_NODE) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "(%zu): expected scalar node for buffer contents_generator",
+        contents_generator_node->start_mark.line);
+  }
+
+  iree_hal_buffer_mapping_t mapping;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_buffer_map_range(buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0,
+                                IREE_WHOLE_BUFFER, &mapping));
+  iree_status_t status = iree_ok_status();
+  if (strcmp(contents_generator_node->tag, "!tag:iree:identity_matrix") == 0) {
+    if (shape_size == 2) {
+      iree_hal_dim_t inner_size = shape[shape_size - 1];
+      iree_trace_replay_generate_identity_matrix(element_type, mapping.contents,
+                                                 inner_size);
+    } else {
+      status = iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "the identity_matrix generator is only for 2D shapes (matrices)");
+    }
+  } else if (strcmp(contents_generator_node->tag,
+                    "!tag:iree:fully_specified_pseudorandom") == 0) {
+    // To enable pseudorandom tests that are both reproducible and invariant
+    // under reordering and filtering testcases, the seed is explicitly
+    // passed as argument in the contents_generator tag.
+    iree_string_view_t seed_str = iree_string_view_trim(
+        iree_yaml_node_as_string(contents_generator_node));
+    uint32_t seed;
+    if (iree_string_view_atoi_uint32(seed_str, &seed)) {
+      iree_trace_replay_generate_fully_specified_pseudorandom_buffer(
+          element_type, mapping.contents, seed);
+    } else {
+      // A string view carries an explicit size, so print exactly that many
+      // characters instead of relying on where a terminator happens to be.
+      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "could not parse the seed argument ('%.*s') of "
+                                "the fully_specified_pseudorandom tag",
+                                (int)seed_str.size, seed_str.data);
+    }
+  } else {
+    status = iree_make_status(
+        IREE_STATUS_UNIMPLEMENTED, "(%zu): unimplemented buffer generator '%s'",
+        contents_generator_node->start_mark.line, contents_generator_node->tag);
+  }
+  // The mapping is released on every path, including the error ones.
+  iree_hal_buffer_unmap_range(&mapping);
+  return status;
+}
+
// Parses a !hal.buffer_view and appends it to |target_list|.
//
// ```yaml
@@ -541,6 +690,18 @@
document, value_node, iree_make_cstring_view("contents"),
&contents_node));
+ yaml_node_t* contents_generator_node = NULL;
+ IREE_RETURN_IF_ERROR(iree_yaml_mapping_try_find(
+ document, value_node, iree_make_cstring_view("contents_generator"),
+ &contents_generator_node));
+
+ if (contents_node && contents_generator_node) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "(%zu): cannot have both contents and contents_generator",
+ contents_generator_node->start_mark.line);
+ }
+
iree_device_size_t allocation_size = 0;
IREE_RETURN_IF_ERROR(iree_hal_buffer_compute_view_size(
shape, shape_rank, element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
@@ -551,8 +712,15 @@
iree_hal_device_allocator(replay->device),
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
IREE_HAL_BUFFER_USAGE_ALL, allocation_size, &buffer));
- iree_status_t status = iree_trace_replay_parse_hal_buffer(
- replay, document, contents_node, element_type, buffer);
+ iree_status_t status = iree_trace_replay_generate_hal_buffer(
+ replay, document, contents_generator_node, element_type, buffer, shape,
+ shape_rank);
+ if (!iree_status_is_ok(status)) {
+ iree_hal_buffer_release(buffer);
+ return status;
+ }
+ status = iree_trace_replay_parse_hal_buffer(replay, document, contents_node,
+ element_type, buffer);
if (!iree_status_is_ok(status)) {
iree_hal_buffer_release(buffer);
return status;