Merge pull request #3486 from ScottTodd:main-to-google
PiperOrigin-RevId: 337331640
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 701ba0d..472cea2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -394,8 +394,8 @@
if(CMAKE_CROSSCOMPILING)
# We need flatc to generate some source code. When cross-compiling, we need
# to make sure the flatc binary is configured for the host environment.
- iree_declare_host_excutable(flatc BUILDONLY)
- iree_declare_host_excutable(flatcc_cli BUILDONLY)
+ iree_declare_host_excutable(flatc "flatc" BUILDONLY)
+ iree_declare_host_excutable(flatcc_cli "flatcc_cli" BUILDONLY)
# Set FLATBUFFERS_FLATC_EXECUTABLE, which controls where BuildFlatBuffers()
# finds the flatc binary.
@@ -414,7 +414,7 @@
COMMAND
"${CMAKE_COMMAND}" -E copy_if_different
"${PROJECT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
- "${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli"
+ "${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli${IREE_HOST_EXECUTABLE_SUFFIX}"
DEPENDS iree_host_build_flatcc_cli
COMMENT "Installing host flatcc..."
)
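The destination fix above matters on Windows hosts, where executables carry an `.exe` suffix: the copy's source path already appends `IREE_HOST_EXECUTABLE_SUFFIX`, and consumers resolve host tools with the suffix attached, so a file installed at the bare `bin/flatcc_cli` path would never be found. A hypothetical Python sketch of the mismatch (the helper below is illustrative, not part of the build system):

```python
import os

def host_executable_path(root, name, suffix):
    # Illustrative stand-in for the "<root>/bin/<name><suffix>" convention.
    return os.path.join(root, "bin", name + suffix)

suffix = ".exe"  # IREE_HOST_EXECUTABLE_SUFFIX on a Windows host; "" elsewhere
expected = host_executable_path("build-host", "flatcc_cli", suffix)
old_copy_dest = os.path.join("build-host", "bin", "flatcc_cli")  # pre-fix path
assert expected != old_copy_dest  # lookups missed the file the old rule wrote
```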
diff --git a/build_tools/cmake/iree_cc_binary.cmake b/build_tools/cmake/iree_cc_binary.cmake
index 5e71014..5e1bd0a 100644
--- a/build_tools/cmake/iree_cc_binary.cmake
+++ b/build_tools/cmake/iree_cc_binary.cmake
@@ -78,7 +78,7 @@
# The binary is marked as host only. We need to declare the rules for
# generating it under the host configuration so that the binary is still
# available when cross-compiling for the target.
- iree_declare_host_excutable(${_RULE_NAME})
+ iree_declare_host_excutable(${_RULE_NAME} ${_NAME})
# Still define the package-prefixed target so we can have a consistent way
# to reference this binary, whether cross-compiling or not. But this time
diff --git a/build_tools/cmake/iree_cross_compile.cmake b/build_tools/cmake/iree_cross_compile.cmake
index 4c67410..b52b4d2 100644
--- a/build_tools/cmake/iree_cross_compile.cmake
+++ b/build_tools/cmake/iree_cross_compile.cmake
@@ -131,13 +131,13 @@
# iree_get_build_command
#
-# Gets the CMake build command for the given `EXECUTABLE`.
+# Gets the CMake build command for the given `EXECUTABLE_TARGET`.
#
# Parameters:
-# EXECUTABLE: the executable to build.
+# EXECUTABLE_TARGET: the target for the executable to build.
# BINDIR: root binary directory containing CMakeCache.txt.
# CMDVAR: variable name for receiving the build command.
-function(iree_get_build_command EXECUTABLE)
+function(iree_get_build_command EXECUTABLE_TARGET)
cmake_parse_arguments(_RULE "" "BINDIR;CMDVAR;CONFIG" "" ${ARGN})
if(NOT _RULE_CONFIG)
set(_RULE_CONFIG "$<CONFIG>")
@@ -145,11 +145,11 @@
if (CMAKE_GENERATOR MATCHES "Make")
# Use special command for Makefiles to support parallelism.
set(${_RULE_CMDVAR}
- "$(MAKE)" "-C" "${_RULE_BINDIR}" "${EXECUTABLE}" PARENT_SCOPE)
+ "$(MAKE)" "-C" "${_RULE_BINDIR}" "${EXECUTABLE_TARGET}" PARENT_SCOPE)
else()
set(${_RULE_CMDVAR}
"${CMAKE_COMMAND}" --build ${_RULE_BINDIR}
- --target ${EXECUTABLE}${IREE_HOST_EXECUTABLE_SUFFIX}
+ --target ${EXECUTABLE_TARGET}
--config ${_RULE_CONFIG} PARENT_SCOPE)
endif()
endfunction()
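For the same reason, the `--target` fix above drops the suffix from the other side: CMake target names never carry the platform executable suffix (only output files do), so `--target flatc${IREE_HOST_EXECUTABLE_SUFFIX}` asked for a nonexistent target on Windows. A hedged Python mirror of the function's two code paths (illustrative only; the real logic is CMake):

```python
def get_build_command(executable_target, bindir, config, generator):
    # Mirror of iree_get_build_command's dispatch, for illustration.
    if "Make" in generator:
        # "$(MAKE)" recursion shares the parent make's jobserver, keeping
        # nested builds parallel.
        return ["$(MAKE)", "-C", bindir, executable_target]
    # Generic generators go through `cmake --build`; note the bare target
    # name, with no host executable suffix appended.
    return ["cmake", "--build", bindir,
            "--target", executable_target, "--config", config]

assert "--target" in get_build_command("flatc", "host", "Release", "Ninja")
```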
@@ -201,14 +201,15 @@
#
# Parameters:
# EXECUTABLE: the executable to build on the host.
+# EXECUTABLE_TARGET: the target name for the executable.
# BUILDONLY: only generates commands for building the target.
# DEPENDS: any additional dependencies for the target.
-function(iree_declare_host_excutable EXECUTABLE)
+function(iree_declare_host_excutable EXECUTABLE EXECUTABLE_TARGET)
cmake_parse_arguments(_RULE "BUILDONLY" "" "DEPENDS" ${ARGN})
iree_get_executable_path(_OUTPUT_PATH ${EXECUTABLE})
- iree_get_build_command(${EXECUTABLE}
+ iree_get_build_command(${EXECUTABLE_TARGET}
BINDIR ${IREE_HOST_BINARY_ROOT}
CMDVAR build_cmd)
diff --git a/build_tools/embed_data/CMakeLists.txt b/build_tools/embed_data/CMakeLists.txt
index ed9a879..8679ccc 100644
--- a/build_tools/embed_data/CMakeLists.txt
+++ b/build_tools/embed_data/CMakeLists.txt
@@ -13,7 +13,7 @@
# limitations under the License.
if(CMAKE_CROSSCOMPILING)
- iree_declare_host_excutable(generate_cc_embed_data)
+ iree_declare_host_excutable(generate_cc_embed_data "generate_cc_embed_data")
else()
add_executable(generate_cc_embed_data)
target_sources(generate_cc_embed_data PRIVATE generate_cc_embed_data_main.cc)
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index ed15f36..af2c510 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -105,7 +105,6 @@
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
- "sort_test.py",
"strings_test.py",
]
@@ -126,7 +125,6 @@
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
- "sort_test.py",
"strings_test.py",
]
@@ -180,6 +178,7 @@
},
reference_backend = "tf",
tags = [
+ "failing",
"manual",
"nokokoro",
"notap",
@@ -220,6 +219,7 @@
},
reference_backend = "tf",
tags = [
+ "failing",
"manual",
"nokokoro",
"notap",
@@ -261,6 +261,7 @@
reference_backend = "tf",
tags = [
"external",
+ "failing",
"guitar",
"manual",
"no-remote",
diff --git a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
new file mode 100644
index 0000000..e7a4fe5
--- /dev/null
+++ b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
@@ -0,0 +1,229 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building e2e tests from a single source with multiple flags."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+load(
+ "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
+ "get_driver",
+)
+
+def _normalize_dictionary(dictionary):
+ """Wraps every value of dictionary in a list if it isn't one already."""
+ for key, value in dictionary.items():
+ if type(value) != type([]):
+ dictionary[key] = [value]
+ return dictionary
+
+def _dictionary_product(dictionary):
+ """Returns a named cartesian product of dictionary's values."""
+
+ # Converts {'a': [1, 2], 'b': [3, 4]} into
+ # [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}]
+ product = [[]]
+ for values in dictionary.values():
+ # Iteratively grow the elements of the product.
+ product = [element + [value] for element in product for value in values]
+ dicts = [{k: v for k, v in zip(dictionary, element)} for element in product]
+ return dicts
+
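For reference, the same named cartesian product in plain Python; Starlark has no `itertools`, which is why the macro grows the product iteratively above. This block is purely illustrative and not part of the change:

```python
import itertools

def dictionary_product(dictionary):
    # {'a': [1, 2], 'b': [3, 4]} ->
    # [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}]
    keys = list(dictionary)
    return [dict(zip(keys, combo))
            for combo in itertools.product(*dictionary.values())]

assert dictionary_product({"a": [1, 2], "b": [3, 4]}) == [
    {"a": 1, "b": 3}, {"a": 1, "b": 4},
    {"a": 2, "b": 3}, {"a": 2, "b": 4},
]
```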
+def iree_e2e_cartesian_product_test_suite(
+ name,
+ srcs,
+ main,
+ flags_to_values,
+ failing_configurations = None,
+ tags = None,
+ data = None,
+ deps = None,
+ size = None,
+ python_version = "PY3",
+ **kwargs):
+ """Creates a test for each configuration and bundles a succeeding and failing test suite.
+
+ Computes the cartesian product of `flags_to_values` and then creates a test
+ for each element of that product. Tests specified in
+ `failing_configurations` are bundled into a test suite suffixed with
+ "_failing" and tagged to be excluded from CI and wildcard builds. All other
+ tests are bundled into a suite with the same name as the macro.
+
+ For example, given the following values
+
+ flags_to_values = {
+ "use_external_weights": True,
+ "model": [
+ "ResNet50",
+ "MobileBert",
+ ],
+ "target_backends": [
+ "tf",
+ "iree_vmla"
+ "iree_vulkan",
+ ]
+ }
+ failing_configurations = [
+ {
+ "model": "MobileBert",
+ "target_backends": "iree_vulkan",
+ },
+ {
+ "model": ["ResNet50"],
+ },
+ ]
+
+ the following passing and failing configurations would be generated:
+ # Passing
+ {use_external_weights: True, model: MobileBert, target_backends: tf}
+ {use_external_weights: True, model: MobileBert, target_backends: iree_vmla}
+
+ # Failing
+ {use_external_weights: True, model: ResNet50, target_backends: tf}
+ {use_external_weights: True, model: ResNet50, target_backends: iree_vmla}
+ {use_external_weights: True, model: ResNet50, target_backends: iree_vulkan}
+ {use_external_weights: True, model: MobileBert, target_backends: iree_vulkan}
+
+ Args:
+ name:
+ name of the generated passing test suite. If `failing_configurations`
+ is not `None`, a test suite named `<name>_failing` will also be
+ generated.
+ srcs:
+ src files for iree_py_test
+ main:
+ main file for iree_py_test
+ failing_configurations:
+ an iterable of dictionaries specifying the flag values for which the
+ test fails. If a flag name is present in `flags_to_values` but missing
+ from a failing configuration, that configuration is expanded to fail
+ for every value in `flags_to_values[flag_name]` (see `ResNet50` in the
+ example above).
+ flags_to_values:
+ a dictionary of strings (flag names) to lists (of values for the flags)
+ to take a cartesian product of. `target_backends` must be specified.
+ tags:
+ tags to apply to the test. Note that as in standard test suites, manual
+ is treated specially and will also apply to the test suite itself.
+ data:
+ external data for iree_py_test.
+ deps:
+ test dependencies for iree_py_test.
+ size:
+ size of the tests for iree_py_test.
+ python_version:
+ the python version to run the tests with. Uses python3 by default.
+ **kwargs:
+ any additional arguments that will be passed to the underlying tests and
+ test_suite.
+ """
+ if "target_backends" not in flags_to_values:
+ fail("`target_backends` must be a key in `flags_to_values`.")
+
+ # Normalize flags_to_values to always have lists as its values.
+ # e.g. {use_external_data: True} -> {use_external_data: [True]}
+ flags_to_values = _normalize_dictionary(flags_to_values)
+
+ all_flag_configurations = _dictionary_product(flags_to_values)
+
+ failing_flag_configurations = []
+ if failing_configurations != None:
+ for failing_configuration in failing_configurations:
+ failing_configuration = _normalize_dictionary(failing_configuration)
+
+ # If a flag isn't specified in the failing configuration, assume it
+ # is failing for all values of that flag.
+ for key, values in flags_to_values.items():
+ if key not in failing_configuration:
+ failing_configuration[key] = values
+
+ failing_flag_configurations.extend(
+ _dictionary_product(failing_configuration),
+ )
+
+ tests = []
+ for flags in all_flag_configurations:
+ # Check if this is a failing configuration.
+ failing = flags in failing_flag_configurations
+
+ # Append "_failing" to name if this is a failing configuration.
+ test_name = name if not failing else name + "_failing"
+ test_name = [test_name]
+ for k, v in flags.items():
+ # Only include the flag's value in the test name if it's not always
+ # the same.
+ if len(flags_to_values[k]) > 1:
+ test_name.append(k)
+ test_name.append(str(v))
+ test_name = "__".join(test_name)
+ tests.append(test_name)
+
+ args = ["--{}={}".format(k, v) for k, v in flags.items()]
+
+ if len(flags["target_backends"].split(",")) > 1:
+ fail("Multiple target backends cannot be specified at once, but " +
+ "got `{}`".format(flags["target_backends"]))
+
+ driver = get_driver(flags["target_backends"])
+ py_test_tags = ["driver={}".format(driver)]
+ if tags != None: # `is` is not supported.
+ py_test_tags += tags
+
+ # Add additional tags if this is a failing configuration.
+ if failing:
+ py_test_tags += [
+ "failing",
+ "manual",
+ "nokokoro",
+ "notap",
+ ]
+
+ iree_py_test(
+ name = test_name,
+ main = main,
+ srcs = srcs,
+ args = args,
+ data = data,
+ deps = deps,
+ size = size,
+ tags = py_test_tags,
+ python_version = python_version,
+ **kwargs
+ )
+
+ if tags == None:
+ tags = []
+ native.test_suite(
+ name = name,
+ tests = tests,
+ # Add "-failing" to exclude tests in `tests` that have the "failing"
+ # tag.
+ tags = tags + ["-failing"],
+ # If there are kwargs that need to be passed here which only apply to
+ # the generated tests and not to test_suite, they should be extracted
+ # into separate named arguments.
+ **kwargs
+ )
+
+ if failing_configurations != None:
+ native.test_suite(
+ name = name + "_failing",
+ tests = tests,
+ # Add "+failing" to only include tests in `tests` that have the
+ # "failing" tag.
+ tags = tags + ["+failing"],
+ # If there are kwargs that need to be passed here which only apply
+ # to the generated tests and not to test_suite, they should be
+ # extracted into separate named arguments.
+ **kwargs
+ )
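Taken together, the macro's failing-configuration expansion reduces to a few lines of plain Python. The model below reproduces the docstring's example under the rule that a flag omitted from a failing configuration fails for all of its values; the helper names are illustrative, not code from the change:

```python
import itertools

def product(d):
    # Named cartesian product, as in _dictionary_product above.
    keys = list(d)
    return [dict(zip(keys, combo)) for combo in itertools.product(*d.values())]

def expand_failing(flags_to_values, failing_configurations):
    expanded = []
    for conf in failing_configurations:
        conf = {k: v if isinstance(v, list) else [v] for k, v in conf.items()}
        for key, values in flags_to_values.items():
            conf.setdefault(key, values)  # omitted flags fail everywhere
        expanded.extend(product(conf))
    return expanded

flags = {
    "model": ["ResNet50", "MobileBert"],
    "target_backends": ["tf", "iree_vmla", "iree_vulkan"],
}
failing = expand_failing(flags, [
    {"model": "MobileBert", "target_backends": "iree_vulkan"},
    {"model": ["ResNet50"]},  # expands to all three backends
])
assert len(failing) == 4
passing = [c for c in product(flags) if c not in failing]
assert len(passing) == 2  # MobileBert on tf and on iree_vmla
```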
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 159633a..8fb35a4 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -16,10 +16,18 @@
load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+def get_driver(backend):
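+ """Maps an "iree_<driver>" target backend name to its runtime driver name."""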
+ # TODO(#2175): Simplify this after backend names are standardized.
+ driver = backend.replace("iree_", "") # "iree_<driver>" --> "<driver>"
+ if driver == "llvmjit":
+ driver = "llvm"
+ return driver
+
def iree_e2e_test_suite(
name,
backends_to_srcs,
reference_backend,
+ data = None,
deps = None,
size = None,
tags = None,
@@ -34,8 +42,10 @@
a dictionary mapping backends to a list of test files to run on them.
reference_backend:
the backend to use as a source of truth for the expected output results.
+ data:
+ external data for iree_py_test.
deps:
- test dependencies.
+ test dependencies for iree_py_test.
tags:
tags to apply to the test. Note that as in standard test suites, manual
is treated specially and will also apply to the test suite itself.
@@ -49,10 +59,9 @@
for backend, srcs in backends_to_srcs.items():
for src in srcs:
- test_name = "{}_{}__{}__{}".format(
+ test_name = "{}__{}__target_backends__{}".format(
name,
src[:-3],
- reference_backend,
backend,
)
args = [
@@ -60,10 +69,7 @@
"--target_backends={}".format(backend),
]
- # TODO(GH-2175): Simplify this after backend names are standardized.
- driver = backend.replace("iree_", "") # "iree_<driver>" --> "<driver>"
- if driver == "llvmjit":
- driver = "llvm"
+ driver = get_driver(backend)
py_test_tags = ["driver={}".format(driver)]
if tags != None: # `is` is not supported.
py_test_tags += tags
@@ -72,8 +78,9 @@
name = test_name,
main = src,
srcs = [src],
- deps = deps,
args = args,
+ data = data,
+ deps = deps,
size = size,
tags = py_test_tags,
python_version = python_version,
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 7753475..78098b1 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -25,8 +25,8 @@
"iree_py_binary",
)
load(
- "//integrations/tensorflow/e2e/keras:iree_vision_test_suite.bzl",
- "iree_vision_test_suite",
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
)
load(
"//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
@@ -152,6 +152,7 @@
},
reference_backend = "tf",
tags = [
+ "failing",
"manual",
"nokokoro",
"notap",
@@ -161,73 +162,80 @@
],
)
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
name = "large_cifar10_internal_tests",
size = "large",
- backends = [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- datasets = ["cifar10"],
- models = [
- # All models with runtime shorter than ResNet50.
- "MobileNet", # Max: Vulkan 61.0s
- "MobileNetV2", # Max: LLVM 96.3s
- "ResNet50", # Max: LLVM 145.6s
- "VGG16", # Max: LLVM 89.5s
- "VGG19", # Max: LLVM 94.7s
- ],
- reference_backend = "tf",
+ srcs = ["vision_model_test.py"],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "model": [
+ # All models with runtime shorter than ResNet50.
+ "MobileNet", # Max: Vulkan 61.0s
+ "MobileNetV2", # Max: LLVM 96.3s
+ "ResNet50", # Max: LLVM 145.6s
+ "VGG16", # Max: LLVM 89.5s
+ "VGG19", # Max: LLVM 94.7s
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "vision_model_test.py",
tags = ["manual"],
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
)
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
name = "enormous_cifar10_internal_tests",
size = "enormous",
- backends = [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- datasets = ["cifar10"],
+ srcs = ["vision_model_test.py"],
failing_configurations = [
{
# Failing on llvm and vulkan:
- "models": [
+ "model": [
"NASNetLarge",
"NASNetMobile",
"ResNet50V2",
"ResNet101V2",
"ResNet152V2",
],
- "datasets": ["cifar10"],
- "backends": [
+ "target_backends": [
"iree_llvmjit",
"iree_vulkan",
],
},
],
- models = [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- ],
- reference_backend = "tf",
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "model": [
+ "DenseNet121",
+ "DenseNet169",
+ "DenseNet201",
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50V2",
+ "ResNet101",
+ "ResNet101V2",
+ "ResNet152",
+ "ResNet152V2",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "vision_model_test.py",
tags = [
"guitar",
"manual",
@@ -239,22 +247,29 @@
],
)
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
name = "cifar10_external_tests",
- backends = [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- datasets = ["cifar10"],
- models = [
- "MobileNet",
- "MobileNetV2",
- "ResNet50",
- ],
- reference_backend = "tf",
+ size = "large",
+ srcs = ["vision_model_test.py"],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "cifar10",
+ "url": "https://storage.googleapis.com/iree_models/",
+ "use_external_weights": True,
+ "model": [
+ "MobileNet",
+ "MobileNetV2",
+ "ResNet50",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "vision_model_test.py",
tags = [
"external",
"guitar",
@@ -263,28 +278,19 @@
"nokokoro",
"notap",
],
- url = "https://storage.googleapis.com/iree_models/",
- use_external_weights = True,
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
)
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
name = "imagenet_external_tests",
size = "enormous",
- backends = [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
- datasets = ["imagenet"],
+ srcs = ["vision_model_test.py"],
failing_configurations = [
{
# Failing vulkan:
- "models": [
+ "model": [
"InceptionResNetV2",
"InceptionV3",
],
@@ -303,34 +309,45 @@
"ResNet152V2",
"Xception",
],
- "datasets": ["imagenet"],
- "backends": [
+ "target_backends": [
"iree_llvmjit",
"iree_vulkan",
],
},
],
- models = [
- "DenseNet121",
- "DenseNet169",
- "DenseNet201",
- "InceptionResNetV2",
- "InceptionV3",
- "MobileNet",
- "MobileNetV2",
- "NASNetLarge",
- "NASNetMobile",
- "ResNet50",
- "ResNet50V2",
- "ResNet101",
- "ResNet101V2",
- "ResNet152",
- "ResNet152V2",
- "VGG16",
- "VGG19",
- "Xception",
- ],
- reference_backend = "tf",
+ flags_to_values = {
+ "reference_backend": "tf",
+ "data": "imagenet",
+ "use_external_weights": True,
+ "model": [
+ "DenseNet121",
+ "DenseNet169",
+ "DenseNet201",
+ "InceptionResNetV2",
+ "InceptionV3",
+ "MobileNet",
+ "MobileNetV2",
+ "NASNetLarge",
+ "NASNetMobile",
+ "ResNet50",
+ "ResNet50V2",
+ "ResNet101",
+ "ResNet101V2",
+ "ResNet152",
+ "ResNet152V2",
+ "VGG16",
+ "VGG19",
+ "Xception",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "vision_model_test.py",
tags = [
"external",
"guitar",
@@ -338,7 +355,6 @@
"nokokoro",
"notap",
],
- use_external_weights = True,
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
diff --git a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
deleted file mode 100644
index 303938c..0000000
--- a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro for building e2e keras vision model tests."""
-
-load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
-load("@bazel_skylib//lib:new_sets.bzl", "sets")
-
-def iree_vision_test_suite(
- name,
- models,
- datasets,
- backends,
- reference_backend,
- failing_configurations = None,
- tags = None,
- url = None,
- use_external_weights = None,
- deps = None,
- size = "large",
- python_version = "PY3",
- **kwargs):
- """Creates a test for each configuration and bundles a succeeding and failing test suite.
-
- Creates one test per dataset, backend, and model. Tests indicated in
- `failing_configurations` are bundled into a suite suffixed with "_failing"
- tagged to be excluded from CI and wildcard builds. All other tests are
- bundled into a suite with the same name as the macro.
-
- Args:
- name:
- name of the generated passing test suite. If failing_configurations is
- not `None` then a test suite named name_failing will also be generated.
- models:
- an iterable of model names to generate targets for.
- datasets:
- an iterable specifying the datasets on which the models are based. This
- controls the shape of the input images. Also indicates which weight file
- to use when loading weights from an external source.
- backends:
- an iterable of targets backends to generate targets for.
- reference_backend:
- the backend to use as a source of truth for the expected output results.
- failing_configurations:
- an iterable of dictionaries with the keys `models`, `datasets` and
- `backends`. Each key points to a string or iterable of strings
- specifying a set of models, datasets and backends that are failing.
- tags:
- tags to apply to the test. Note that as in standard test suites, manual
- is treated specially and will also apply to the test suite itself.
- url:
- a base url to fetch non-keras trained model weights from.
- use_external_weights:
- whether or not to load model weights from the web (either uses keras or
- the supplied url).
- deps:
- test dependencies.
- size:
- size of the tests. Default: "large".
- python_version:
- the python version to run the tests with. Uses python3 by default.
- **kwargs:
- any additional arguments that will be passed to the underlying tests and
- test_suite.
- """
- failing_set = sets.make([])
- if failing_configurations != None:
- # Parse failing configurations.
- for configuration in failing_configurations:
- # Normalize configuration input.
- # {backend: "iree_llvmjit"} -> {backend: ["iree_llvmjit"]}
- for key, value in configuration.items():
- if type(value) == type(""):
- configuration[key] = [value]
-
- for model in configuration["models"]:
- for dataset in configuration["datasets"]:
- for backend in configuration["backends"]:
- sets.insert(failing_set, (model, dataset, backend))
-
- tests = []
- for model in models:
- for dataset in datasets:
- for backend in backends:
- # Check if this is a failing configuration.
- failing = sets.contains(failing_set, (model, dataset, backend))
-
- # Append "_failing" to name if this is a failing configuration.
- test_name = name if not failing else name + "_failing"
- if len(datasets) > 1:
- test_name = "{}_{}_{}__{}__{}".format(
- test_name,
- model,
- dataset,
- reference_backend,
- backend,
- )
- else:
- test_name = "{}_{}__{}__{}".format(
- test_name,
- model,
- reference_backend,
- backend,
- )
- tests.append(test_name)
-
- args = [
- "--model={}".format(model),
- "--data={}".format(dataset),
- "--reference_backend={}".format(reference_backend),
- "--target_backends={}".format(backend),
- ]
- if url:
- args.append("--url={}".format(url))
- if use_external_weights:
- args.append(
- "--use_external_weights={}".format(use_external_weights),
- )
-
- # TODO(GH-2175): Simplify this after backend names are
- # standardized.
- # "iree_<driver>" --> "<driver>"
- driver = backend.replace("iree_", "")
- if driver == "llvmjit":
- driver = "llvm"
- py_test_tags = ["driver={}".format(driver)]
- if tags != None: # `is` is not supported.
- py_test_tags += tags
-
- # Add additional tags if this is a failing configuration.
- if failing:
- py_test_tags += [
- "failing", # Only used for test_suite filtering below.
- "manual",
- "nokokoro",
- "notap",
- ]
-
- iree_py_test(
- name = test_name,
- main = "vision_model_test.py",
- srcs = ["vision_model_test.py"],
- args = args,
- tags = py_test_tags,
- deps = deps,
- size = size,
- python_version = python_version,
- **kwargs
- )
-
- native.test_suite(
- name = name,
- tests = tests,
- # Add "-failing" to exclude tests in `tests` that have the "failing"
- # tag.
- tags = tags + ["-failing"],
- # If there are kwargs that need to be passed here which only apply to
- # the generated tests and not to test_suite, they should be extracted
- # into separate named arguments.
- **kwargs
- )
-
- if failing_configurations != None:
- native.test_suite(
- name = name + "_failing",
- tests = tests,
- # Add "+failing" to only include tests in `tests` that have the
- # "failing" tag.
- tags = tags + ["+failing"],
- # If there are kwargs that need to be passed here which only apply
- # to the generated tests and not to test_suite, they should be
- # extracted into separate named arguments.
- **kwargs
- )
diff --git a/integrations/tensorflow/e2e/keras/train/BUILD b/integrations/tensorflow/e2e/keras/train/BUILD
index d6436b2..177bcaf 100644
--- a/integrations/tensorflow/e2e/keras/train/BUILD
+++ b/integrations/tensorflow/e2e/keras/train/BUILD
@@ -18,8 +18,8 @@
"NUMPY_DEPS",
)
load(
- "//integrations/tensorflow/e2e/keras/train:iree_train_test_suite.bzl",
- "iree_train_test_suite",
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
)
package(
@@ -28,13 +28,34 @@
licenses = ["notice"], # Apache 2.0
)
-# TODO(meadowlark): Refactor this rule to match iree_vision_test_suite.bzl
-iree_train_test_suite(
+iree_e2e_cartesian_product_test_suite(
name = "train_tests",
- configurations = [
- # tuples of (optimizer, backends)
- ("sgd", "tf"),
+ srcs = ["model_train_test.py"],
+ failing_configurations = [
+ {
+ "target_backends": [
+ "tflite",
+ "iree_vmla", # TODO(b/157581521)
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
],
+ flags_to_values = {
+ "reference_backend": "tf",
+ "optimizer_name": [
+ "sgd",
+ "adam",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "model_train_test.py",
tags = [
"guitar",
"manual",
@@ -45,21 +66,3 @@
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
)
-
-iree_train_test_suite(
- name = "train_tests_failing",
- configurations = [
- # tuples of (optimizer, backends)
- # TODO: Combine this suite with keras_model_train once these tests pass.
- ("sgd", "tf,iree_vmla"),
- ("adam", "tf,iree_vmla"), # TODO(b/157581521)
- ],
- tags = [
- "manual",
- "nokokoro",
- "notap",
- ],
- deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
- "//integrations/tensorflow/bindings/python/pyiree/tf/support",
- ],
-)
diff --git a/integrations/tensorflow/e2e/keras/train/README.md b/integrations/tensorflow/e2e/keras/train/README.md
new file mode 100644
index 0000000..86222b1
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/README.md
@@ -0,0 +1,10 @@
+# Keras Training Tests
+
+These tests require an additional Python dependency on `sklearn`, which
+can be installed as follows:
+
+```shell
+python3 -m pip install sklearn
+```
+
+These tests are not checked by the OSS CI.
diff --git a/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl b/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl
deleted file mode 100644
index d6b244b..0000000
--- a/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro for building e2e keras vision model tests."""
-
-load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
-
-def iree_train_test_suite(
- name,
- configurations,
- tags = None,
- deps = None,
- size = None,
- python_version = "PY3",
- **kwargs):
- """Creates one iree_py_test per configuration tuple and a test suite that bundles them.
-
- Args:
- name:
- name of the generated test suite.
- configurations:
- a list of tuples of (optimizer, backends).
- tags:
- tags to apply to the test. Note that as in standard test suites, manual
- is treated specially and will also apply to the test suite itself.
- deps:
- test dependencies.
- size:
- size of the tests.
- python_version:
- the python version to run the tests with. Uses python3 by default.
- **kwargs:
- Any additional arguments that will be passed to the underlying tests
- and test_suite.
- """
- tests = []
- for optimizer, backends in configurations:
- test_name = "{}_{}_{}_test".format(name, optimizer, backends)
- tests.append(test_name)
-
- args = [
- "--optimizer_name={}".format(optimizer),
- "--target_backends={}".format(backends),
- ]
-
- iree_py_test(
- name = test_name,
- main = "model_train_test.py",
- srcs = ["model_train_test.py"],
- args = args,
- tags = tags,
- deps = deps,
- size = size,
- python_version = python_version,
- **kwargs
- )
-
- native.test_suite(
- name = name,
- tests = tests,
- # Note that only the manual tag really has any effect here. Others are
- # used for test suite filtering, but all tests are passed the same tags.
- tags = tags,
- # If there are kwargs that need to be passed here which only apply to
- # the generated tests and not to test_suite, they should be extracted
- # into separate named arguments.
- **kwargs
- )
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index c7c6101..bbfe6a7 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -26,8 +26,8 @@
"iree_py_binary",
)
load(
- "//integrations/tensorflow/e2e/slim_vision_models:iree_slim_vision_test_suite.bzl",
- "iree_slim_vision_test_suite",
+ "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+ "iree_e2e_cartesian_product_test_suite",
)
package(
@@ -48,33 +48,18 @@
],
)
-iree_slim_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
name = "slim_vision_tests",
size = "enormous",
- backends = [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
+ srcs = ["slim_vision_model_test.py"],
failing_configurations = [
{
# SavedModelV2 (classification/4) not available.
- "models": [
- "amoebanet_a_n18_f448",
- ],
- "backends": [
- "tf",
- "tflite",
- "iree_vmla",
- "iree_llvmjit",
- "iree_vulkan",
- ],
+ "model": "amoebanet_a_n18_f448",
},
{
# Failing llvmjit and vulkan:
- "models": [
+ "model": [
"nasnet_mobile",
"nasnet_large",
"pnasnet_large",
@@ -82,84 +67,95 @@
"resnet_v2_101",
"resnet_v2_152",
],
- "backends": [
+ "target_backends": [
"iree_llvmjit",
"iree_vulkan",
],
},
{
# Failing vulkan:
- "models": [
+ "model": [
# [ERROR]: cannot separate Linalg/Parallel ops into multiple kernels
"inception_v1",
"inception_v2",
"inception_v3",
"inception_resnet_v2",
],
- "backends": [
+ "target_backends": [
"iree_vulkan",
],
},
],
- models = [
- "amoebanet_a_n18_f448",
- "inception_resnet_v2",
- "inception_v1",
- "inception_v2",
- "inception_v3",
- # MobileNetV1
- "mobilenet_v1_025_128",
- "mobilenet_v1_025_160",
- "mobilenet_v1_025_192",
- "mobilenet_v1_025_224",
- "mobilenet_v1_050_128",
- "mobilenet_v1_050_160",
- "mobilenet_v1_050_192",
- "mobilenet_v1_050_224",
- "mobilenet_v1_075_128",
- "mobilenet_v1_075_160",
- "mobilenet_v1_075_192",
- "mobilenet_v1_075_224",
- "mobilenet_v1_100_128",
- "mobilenet_v1_100_160",
- "mobilenet_v1_100_192",
- "mobilenet_v1_100_224",
- # MobileNetV2:
- "mobilenet_v2_035_96",
- "mobilenet_v2_035_128",
- "mobilenet_v2_035_160",
- "mobilenet_v2_035_192",
- "mobilenet_v2_035_224",
- "mobilenet_v2_050_96",
- "mobilenet_v2_050_128",
- "mobilenet_v2_050_160",
- "mobilenet_v2_050_192",
- "mobilenet_v2_050_224",
- "mobilenet_v2_075_96",
- "mobilenet_v2_075_128",
- "mobilenet_v2_075_160",
- "mobilenet_v2_075_192",
- "mobilenet_v2_075_224",
- "mobilenet_v2_100_96",
- "mobilenet_v2_100_128",
- "mobilenet_v2_100_160",
- "mobilenet_v2_100_192",
- "mobilenet_v2_100_224",
- "mobilenet_v2_130_224",
- "mobilenet_v2_140_224",
- "nasnet_mobile",
- "nasnet_large",
- "pnasnet_large",
- # ResNetV1
- "resnet_v1_50",
- "resnet_v1_101",
- "resnet_v1_152",
- # ResNetV2
- "resnet_v2_50",
- "resnet_v2_101",
- "resnet_v2_152",
- ],
- reference_backend = "tf",
+ flags_to_values = {
+ "reference_backend": "tf",
+ "tf_hub_url": ["https://tfhub.dev/google/imagenet/"],
+ "model": [
+ "amoebanet_a_n18_f448",
+ "inception_resnet_v2",
+ "inception_v1",
+ "inception_v2",
+ "inception_v3",
+ # MobileNetV1
+ "mobilenet_v1_025_128",
+ "mobilenet_v1_025_160",
+ "mobilenet_v1_025_192",
+ "mobilenet_v1_025_224",
+ "mobilenet_v1_050_128",
+ "mobilenet_v1_050_160",
+ "mobilenet_v1_050_192",
+ "mobilenet_v1_050_224",
+ "mobilenet_v1_075_128",
+ "mobilenet_v1_075_160",
+ "mobilenet_v1_075_192",
+ "mobilenet_v1_075_224",
+ "mobilenet_v1_100_128",
+ "mobilenet_v1_100_160",
+ "mobilenet_v1_100_192",
+ "mobilenet_v1_100_224",
+ # MobileNetV2:
+ "mobilenet_v2_035_96",
+ "mobilenet_v2_035_128",
+ "mobilenet_v2_035_160",
+ "mobilenet_v2_035_192",
+ "mobilenet_v2_035_224",
+ "mobilenet_v2_050_96",
+ "mobilenet_v2_050_128",
+ "mobilenet_v2_050_160",
+ "mobilenet_v2_050_192",
+ "mobilenet_v2_050_224",
+ "mobilenet_v2_075_96",
+ "mobilenet_v2_075_128",
+ "mobilenet_v2_075_160",
+ "mobilenet_v2_075_192",
+ "mobilenet_v2_075_224",
+ "mobilenet_v2_100_96",
+ "mobilenet_v2_100_128",
+ "mobilenet_v2_100_160",
+ "mobilenet_v2_100_192",
+ "mobilenet_v2_100_224",
+ "mobilenet_v2_130_224",
+ "mobilenet_v2_140_224",
+ "nasnet_mobile",
+ "nasnet_large",
+ "pnasnet_large",
+ # ResNetV1
+ "resnet_v1_50",
+ "resnet_v1_101",
+ "resnet_v1_152",
+ # ResNetV2
+ "resnet_v2_50",
+ "resnet_v2_101",
+ "resnet_v2_152",
+ ],
+ "target_backends": [
+ "tf",
+ "tflite",
+ "iree_vmla",
+ "iree_llvmjit",
+ "iree_vulkan",
+ ],
+ },
+ main = "slim_vision_model_test.py",
tags = [
"external",
"guitar",
@@ -168,8 +164,7 @@
"nokokoro",
"notap",
],
- tf_hub_url = "https://tfhub.dev/google/imagenet/",
- deps = INTREE_TENSORFLOW_PY_DEPS + INTREE_TF_HUB_DEPS + NUMPY_DEPS + [
+ deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
"//integrations/tensorflow/bindings/python/pyiree/tf/support",
],
)
diff --git a/integrations/tensorflow/e2e/slim_vision_models/README.md b/integrations/tensorflow/e2e/slim_vision_models/README.md
index 0a0b5dd..8a09a14 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/README.md
+++ b/integrations/tensorflow/e2e/slim_vision_models/README.md
@@ -7,4 +7,4 @@
python3 -m pip install tensorflow_hub
```
-Like the `vision_external_tests`, these tests are not checked by the OSS CI.
+These tests are not checked by the OSS CI.
diff --git a/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl b/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
deleted file mode 100644
index 857c1b4..0000000
--- a/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro for building e2e keras vision model tests."""
-
-load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
-load("@bazel_skylib//lib:new_sets.bzl", "sets")
-
-def iree_slim_vision_test_suite(
- name,
- models,
- backends,
- reference_backend,
- failing_configurations = None,
- tf_hub_url = None,
- tags = None,
- deps = None,
- size = "large",
- python_version = "PY3",
- **kwargs):
- """Creates a test for each configuration and bundles a succeeding and failing test suite.
-
- Creates one test per model and backend. Tests indicated in
- `failing_configurations` are bundled into a suite suffixed with "_failing"
- tagged to be excluded from CI and wildcard builds. All other tests are
- bundled into a suite with the same name as the macro.
-
- Args:
- name:
- name of the generated passing test suite. If failing_configurations is
- not `None` then a test suite named name_failing will also be generated.
- models:
- an iterable of slim vision model tags to generate targets for.
- backends:
- an iterable of targets backends to generate targets for.
- reference_backend:
- the backend to use as a source of truth for the expected output results.
- failing_configurations:
- an iterable of dictionaries with the keys `models` and `backends`. Each
- key points to a string or iterable of strings specifying a set of models
- and backends that are failing.
- tf_hub_url:
- a string pointing to the TF Hub base url of the models to test.
- tags:
- tags to apply to the test. Note that as in standard test suites, manual
- is treated specially and will also apply to the test suite itself.
- deps:
- test dependencies.
- size:
- size of the tests. Default: "large".
- python_version:
- the python version to run the tests with. Uses python3 by default.
- **kwargs:
- any additional arguments that will be passed to the underlying tests and
- test_suite.
- """
- failing_set = sets.make([])
- if failing_configurations != None:
- # Parse failing configurations.
- for configuration in failing_configurations:
- # Normalize configuration input.
- # {backend: "iree_llvmjit"} -> {backend: ["iree_llvmjit"]}
- for key, value in configuration.items():
- if type(value) == type(""):
- configuration[key] = [value]
-
- for model in configuration["models"]:
- for backend in configuration["backends"]:
- sets.insert(failing_set, (model, backend))
-
- tests = []
- for model in models:
- for backend in backends:
- # Check if this is a failing configuration.
- failing = sets.contains(failing_set, (model, backend))
-
- # Append "_failing" to name if this is a failing configuration.
- test_name = name if not failing else name + "_failing"
- test_name = "{}_{}__{}__{}".format(
- test_name,
- model,
- reference_backend,
- backend,
- )
- tests.append(test_name)
-
- args = [
- "--model={}".format(model),
- "--tf_hub_url={}".format(tf_hub_url),
- "--reference_backend={}".format(reference_backend),
- "--target_backends={}".format(backend),
- ]
-
- # TODO(GH-2175): Simplify this after backend names are
- # standardized.
- # "iree_<driver>" --> "<driver>"
- driver = backend.replace("iree_", "")
- if driver == "llvmjit":
- driver = "llvm"
- py_test_tags = ["driver={}".format(driver)]
- if tags != None: # `is` is not supported.
- py_test_tags += tags
-
- # Add additional tags if this is a failing configuration.
- if failing:
- py_test_tags += [
- "failing", # Only used for test_suite filtering below.
- "manual",
- "nokokoro",
- "notap",
- ]
-
- iree_py_test(
- name = test_name,
- main = "slim_vision_model_test.py",
- srcs = ["slim_vision_model_test.py"],
- args = args,
- tags = py_test_tags,
- deps = deps,
- size = size,
- python_version = python_version,
- **kwargs
- )
-
- native.test_suite(
- name = name,
- tests = tests,
- # Add "-failing" to exclude tests in `tests` that have the "failing"
- # tag.
- tags = tags + ["-failing"],
- # If there are kwargs that need to be passed here which only apply to
- # the generated tests and not to test_suite, they should be extracted
- # into separate named arguments.
- **kwargs
- )
-
- if failing_configurations != None:
- native.test_suite(
- name = name + "_failing",
- tests = tests,
- # Add "+failing" to only include tests in `tests` that have the
- # "failing" tag.
- tags = tags + ["+failing"],
- # If there are kwargs that need to be passed here which only apply
- # to the generated tests and not to test_suite, they should be
- # extracted into separate named arguments.
- **kwargs
- )
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 819373f..0f97aae 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -202,10 +202,9 @@
namespace {
template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
using OpRewritePattern<INDIRECT>::OpRewritePattern;
- public:
LogicalResult matchAndRewrite(INDIRECT op,
PatternRewriter &rewriter) const override {
if (auto addressOp =
@@ -244,10 +243,9 @@
namespace {
template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
using OpRewritePattern<INDIRECT>::OpRewritePattern;
- public:
LogicalResult matchAndRewrite(INDIRECT op,
PatternRewriter &rewriter) const override {
if (auto addressOp =
@@ -382,15 +380,13 @@
// Native integer arithmetic
//===----------------------------------------------------------------------===//
-namespace {
-
/// Performs const folding `calculate` with element-wise behavior on the given
/// attribute in `operands` and returns the result if possible.
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
- const CalculationT &calculate) {
+static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
assert(operands.size() == 1 && "unary op takes one operand");
if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
return AttrElementT::get(operand.getType(), calculate(operand.getValue()));
@@ -414,8 +410,8 @@
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT =
std::function<ElementValueT(ElementValueT, ElementValueT)>>
-Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
- const CalculationT &calculate) {
+static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
@@ -448,42 +444,54 @@
return {};
}
-} // namespace
-
-template <typename T>
-static OpFoldResult foldAddOp(T op, ArrayRef<Attribute> operands) {
+template <typename ADD, typename SUB>
+static OpFoldResult foldAddOp(ADD op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x + 0 = x or 0 + y = y (commutative)
return op.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a + b; });
+ if (auto subOp = dyn_cast_or_null<SUB>(op.lhs().getDefiningOp())) {
+ // (a - b) + b = a
+ if (subOp.rhs() == op.rhs()) return subOp.lhs();
+ } else if (auto subOp = dyn_cast_or_null<SUB>(op.rhs().getDefiningOp())) {
+ // b + (a - b) = a
+ if (subOp.rhs() == op.lhs()) return subOp.lhs();
+ }
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a + b; });
}
OpFoldResult AddI32Op::fold(ArrayRef<Attribute> operands) {
- return foldAddOp(*this, operands);
+ return foldAddOp<AddI32Op, SubI32Op>(*this, operands);
}
OpFoldResult AddI64Op::fold(ArrayRef<Attribute> operands) {
- return foldAddOp(*this, operands);
+ return foldAddOp<AddI64Op, SubI64Op>(*this, operands);
}
-template <typename T>
-static OpFoldResult foldSubOp(T op, ArrayRef<Attribute> operands) {
+template <typename SUB, typename ADD>
+static OpFoldResult foldSubOp(SUB op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
// x - 0 = x
return op.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a - b; });
+ if (auto addOp = dyn_cast_or_null<ADD>(op.lhs().getDefiningOp())) {
+ // (a + b) - a = b
+ if (addOp.lhs() == op.rhs()) return addOp.rhs();
+ // (a + b) - b = a
+ if (addOp.rhs() == op.rhs()) return addOp.lhs();
+ }
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a - b; });
}
OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) {
- return foldSubOp(*this, operands);
+ return foldSubOp<SubI32Op, AddI32Op>(*this, operands);
}
OpFoldResult SubI64Op::fold(ArrayRef<Attribute> operands) {
- return foldSubOp(*this, operands);
+ return foldSubOp<SubI64Op, AddI64Op>(*this, operands);
}
template <typename T>
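A quick numeric check of the cancellation identities the new add/sub folds implement, with masking standing in for fixed-width APInt wraparound (illustrative Python, not the MLIR pattern API):

```python
import random

MASK32 = (1 << 32) - 1  # i32 wraparound, matching APInt arithmetic

for _ in range(1000):
    a, b = random.getrandbits(32), random.getrandbits(32)
    assert ((a - b) + b) & MASK32 == a  # (a - b) + b = a
    assert (b + (a - b)) & MASK32 == a  # b + (a - b) = a
    assert ((a + b) - a) & MASK32 == b  # (a + b) - a = b
    assert ((a + b) - b) & MASK32 == a  # (a + b) - b = a
```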
@@ -495,18 +503,51 @@
// x * 1 = x or 1 * y = y (commutative)
return op.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a * b; });
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a * b; });
}
+template <typename T, typename CONST_OP>
+struct FoldConstantMulOperand : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ IntegerAttr c1, c2;
+ if (!matchPattern(op.rhs(), m_Constant(&c1))) return failure();
+ if (auto mulOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp())) {
+ if (matchPattern(mulOp.rhs(), m_Constant(&c2))) {
+ auto c = rewriter.createOrFold<CONST_OP>(
+ FusedLoc::get({mulOp.getLoc(), op.getLoc()}, rewriter.getContext()),
+ constFoldBinaryOp<IntegerAttr>(
+ {c1, c2},
+ [](const APInt &a, const APInt &b) { return a * b; }));
+ rewriter.replaceOpWithNewOp<T>(op, op.getType(), mulOp.lhs(), c);
+ return success();
+ }
+ }
+ return failure();
+ }
+};
+
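The new canonicalization reassociates chained constant multiplies, rewriting `(x * c2) * c1` to `x * (c1 * c2)` so the two constants fold into one at compile time. A short Python check that the rewrite preserves values under 32-bit modular arithmetic (illustrative only):

```python
def fused_constant(c1, c2, bits=32):
    # The single constant the two multiplies fold into, modulo 2**bits.
    return (c1 * c2) & ((1 << bits) - 1)

MASK32 = (1 << 32) - 1
x, c1, c2 = 7, 5, 0xFFFFFFFF  # c2 exercises the modular wraparound
lhs = (((x * c2) & MASK32) * c1) & MASK32  # (x * c2) * c1
rhs = (x * fused_constant(c1, c2)) & MASK32  # x * (c1 * c2)
assert lhs == rhs
```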
OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) {
return foldMulOp(*this, operands);
}
+void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context);
+}
+
OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) {
return foldMulOp(*this, operands);
}
+void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context);
+}
+
template <typename T>
static OpFoldResult foldDivSOp(T op, ArrayRef<Attribute> operands) {
if (matchPattern(op.rhs(), m_Zero())) {
@@ -521,7 +562,7 @@
return op.lhs();
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return a.sdiv(b); });
+ operands, [](const APInt &a, const APInt &b) { return a.sdiv(b); });
}
OpFoldResult DivI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -546,7 +587,7 @@
return op.lhs();
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return a.udiv(b); });
+ operands, [](const APInt &a, const APInt &b) { return a.udiv(b); });
}
OpFoldResult DivI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -570,7 +611,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return a.srem(b); });
+ operands, [](const APInt &a, const APInt &b) { return a.srem(b); });
}
OpFoldResult RemI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -589,7 +630,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return a.urem(b); });
+ operands, [](const APInt &a, const APInt &b) { return a.urem(b); });
}
OpFoldResult RemI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -625,8 +666,8 @@
// x & x = x
return op.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a & b; });
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a & b; });
}
OpFoldResult AndI32Op::fold(ArrayRef<Attribute> operands) {
@@ -646,8 +687,8 @@
// x | x = x
return op.lhs();
}
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a | b; });
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a | b; });
}
OpFoldResult OrI32Op::fold(ArrayRef<Attribute> operands) {
@@ -667,8 +708,8 @@
// x ^ x = 0
return zeroOfType(op.getType());
}
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a ^ b; });
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](const APInt &a, const APInt &b) { return a ^ b; });
}
OpFoldResult XorI32Op::fold(ArrayRef<Attribute> operands) {
@@ -693,7 +734,7 @@
return op.operand();
}
return constFoldUnaryOp<IntegerAttr>(
- operands, [&](APInt a) { return a.shl(op.amount()); });
+ operands, [&](const APInt &a) { return a.shl(op.amount()); });
}
OpFoldResult ShlI32Op::fold(ArrayRef<Attribute> operands) {
@@ -714,7 +755,7 @@
return op.operand();
}
return constFoldUnaryOp<IntegerAttr>(
- operands, [&](APInt a) { return a.ashr(op.amount()); });
+ operands, [&](const APInt &a) { return a.ashr(op.amount()); });
}
OpFoldResult ShrI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -735,7 +776,7 @@
return op.operand();
}
return constFoldUnaryOp<IntegerAttr>(
- operands, [&](APInt a) { return a.lshr(op.amount()); });
+ operands, [&](const APInt &a) { return a.lshr(op.amount()); });
}
OpFoldResult ShrI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -755,8 +796,9 @@
template <class AttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldConversionOp(Type resultType, ArrayRef<Attribute> operands,
- const CalculationT &calculate) {
+static Attribute constFoldConversionOp(Type resultType,
+ ArrayRef<Attribute> operands,
+ const CalculationT &calculate) {
assert(operands.size() == 1 && "unary op takes one operand");
if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
return AttrElementT::get(resultType, calculate(operand.getValue()));
@@ -767,101 +809,100 @@
OpFoldResult TruncI32I8Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(8).zext(32); });
+ [&](const APInt &a) { return a.trunc(8).zext(32); });
}
OpFoldResult TruncI32I16Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(16).zext(32); });
+ [&](const APInt &a) { return a.trunc(16).zext(32); });
}
OpFoldResult TruncI64I8Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(8).zext(32); });
+ [&](const APInt &a) { return a.trunc(8).zext(32); });
}
OpFoldResult TruncI64I16Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(16).zext(32); });
+ [&](const APInt &a) { return a.trunc(16).zext(32); });
}
OpFoldResult TruncI64I32Op::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(32); });
+ [&](const APInt &a) { return a.trunc(32); });
}
OpFoldResult ExtI8I32SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(8).sext(32); });
+ [&](const APInt &a) { return a.trunc(8).sext(32); });
}
OpFoldResult ExtI8I32UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(8).zext(32); });
+ [&](const APInt &a) { return a.trunc(8).zext(32); });
}
OpFoldResult ExtI16I32SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(16).sext(32); });
+ [&](const APInt &a) { return a.trunc(16).sext(32); });
}
OpFoldResult ExtI16I32UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(32, getContext()), operands,
- [&](APInt a) { return a.trunc(16).zext(32); });
+ [&](const APInt &a) { return a.trunc(16).zext(32); });
}
OpFoldResult ExtI8I64SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
- [&](APInt a) { return a.trunc(8).sext(64); });
+ [&](const APInt &a) { return a.trunc(8).sext(64); });
}
OpFoldResult ExtI8I64UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
- [&](APInt a) { return a.trunc(8).zext(64); });
+ [&](const APInt &a) { return a.trunc(8).zext(64); });
}
OpFoldResult ExtI16I64SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
- [&](APInt a) { return a.trunc(16).sext(64); });
+ [&](const APInt &a) { return a.trunc(16).sext(64); });
}
OpFoldResult ExtI16I64UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
- [&](APInt a) { return a.trunc(16).zext(64); });
+ [&](const APInt &a) { return a.trunc(16).zext(64); });
}
OpFoldResult ExtI32I64SOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
- [&](APInt a) { return a.sext(64); });
+ [&](const APInt &a) { return a.sext(64); });
}
OpFoldResult ExtI32I64UOp::fold(ArrayRef<Attribute> operands) {
return constFoldConversionOp<IntegerAttr>(
IntegerType::get(64, getContext()), operands,
- [&](APInt a) { return a.zext(64); });
+ [&](const APInt &a) { return a.zext(64); });
}
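All of the conversion folds above share one shape: truncate to the narrow width, then zero- or sign-extend back to the register width. A Python model of the two i8-to-i32 cases, with masking emulating `APInt` (helper names are illustrative):

```python
MASK32 = (1 << 32) - 1

def trunc8_zext32(value):
    # Models APInt::trunc(8).zext(32): keep the low byte, zero the rest.
    return value & 0xFF

def trunc8_sext32(value):
    # Models APInt::trunc(8).sext(32): keep the low byte, replicate its sign.
    low = value & 0xFF
    return (low - 0x100) & MASK32 if low & 0x80 else low

assert trunc8_zext32(0x1FF) == 0x000000FF  # ExtI8I32UOp fold result
assert trunc8_sext32(0x1FF) == 0xFFFFFFFF  # ExtI8I32SOp fold: 0xFF reads as -1
```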
namespace {
template <typename SRC_OP, typename OP_A, int SZ_T, typename OP_B>
-class PseudoIntegerConversionToSplitConversionOp
+struct PseudoIntegerConversionToSplitConversionOp
: public OpRewritePattern<SRC_OP> {
using OpRewritePattern<SRC_OP>::OpRewritePattern;
- public:
LogicalResult matchAndRewrite(SRC_OP op,
PatternRewriter &rewriter) const override {
auto tmp = rewriter.createOrFold<OP_A>(
@@ -956,7 +997,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.eq(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.eq(b); });
}
OpFoldResult CmpEQI32Op::fold(ArrayRef<Attribute> operands) {
@@ -984,7 +1025,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.ne(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.ne(b); });
}
OpFoldResult CmpNEI32Op::fold(ArrayRef<Attribute> operands) {
@@ -1032,7 +1073,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.slt(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.slt(b); });
}
OpFoldResult CmpLTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1056,7 +1097,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.ult(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.ult(b); });
}
OpFoldResult CmpLTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1100,7 +1141,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.sle(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.sle(b); });
}
OpFoldResult CmpLTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1130,7 +1171,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.ule(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.ule(b); });
}
OpFoldResult CmpLTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1176,7 +1217,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.sgt(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.sgt(b); });
}
OpFoldResult CmpGTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1206,7 +1247,7 @@
return zeroOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.ugt(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.ugt(b); });
}
OpFoldResult CmpGTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1256,7 +1297,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.sge(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.sge(b); });
}
OpFoldResult CmpGTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1286,7 +1327,7 @@
return oneOfType(op.getType());
}
return constFoldBinaryOp<IntegerAttr>(
- operands, [&](APInt a, APInt b) { return a.uge(b); });
+ operands, [&](const APInt &a, const APInt &b) { return a.uge(b); });
}
OpFoldResult CmpGTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1311,12 +1352,12 @@
OpFoldResult CmpNZI32Op::fold(ArrayRef<Attribute> operands) {
return constFoldUnaryOp<IntegerAttr>(
- operands, [&](APInt a) { return APInt(32, a.getBoolValue()); });
+ operands, [&](const APInt &a) { return APInt(32, a.getBoolValue()); });
}
OpFoldResult CmpNZI64Op::fold(ArrayRef<Attribute> operands) {
return constFoldUnaryOp<IntegerAttr>(
- operands, [&](APInt a) { return APInt(64, a.getBoolValue()); });
+ operands, [&](const APInt &a) { return APInt(64, a.getBoolValue()); });
}
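The comparison folds follow the same pattern through `constFoldBinaryOp`. A hedged sketch of what such a helper plausibly looks like, again inferred from the call sites rather than copied from the source; note it derives the result attribute from the lhs type, which matches the i32 compares shown above:

    // Sketch only: inferred from the constFoldBinaryOp call sites above.
    template <typename AttrElementT, typename CalculationT>
    static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
                                       const CalculationT &calculate) {
      assert(operands.size() == 2 && "binary op takes two operands");
      auto lhs = operands[0].dyn_cast_or_null<AttrElementT>();
      auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
      if (!lhs || !rhs) return {};
      // Both APInts are handed to the caller's lambda by const reference.
      return AttrElementT::get(lhs.getType(),
                               calculate(lhs.getValue(), rhs.getValue()));
    }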
OpFoldResult CmpEQRefOp::fold(ArrayRef<Attribute> operands) {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 727cbf0..d613c8b 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1311,6 +1311,7 @@
VM_BinaryArithmeticOp<I32, "mul.i32", VM_OPC_MulI32, [Commutative]> {
let summary = [{integer multiplication operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_MulI64Op :
@@ -1318,6 +1319,7 @@
[VM_ExtI64, Commutative]> {
let summary = [{integer multiplication operation}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def VM_DivI32SOp :
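Setting `hasCanonicalizer = 1` obliges the op to define `getCanonicalizationPatterns`. Judging from the `@mul_mul_i32_const` test added below, the new canonicalization merges chained constant multiplies: `mul(mul(x, c1), c2) -> mul(x, c1 * c2)`. A hypothetical sketch, where the pattern name, the `ConstI32Op` builder, and the `lhs()`/`rhs()` accessors are illustrative assumptions rather than the actual IREE code:

    // Hypothetical pattern: merge mul(mul(x, c1), c2) into mul(x, c1 * c2).
    struct MergeConstantMulOps : public OpRewritePattern<MulI32Op> {
      using OpRewritePattern<MulI32Op>::OpRewritePattern;
      LogicalResult matchAndRewrite(MulI32Op op,
                                    PatternRewriter &rewriter) const override {
        IntegerAttr c1, c2;
        auto innerOp = dyn_cast_or_null<MulI32Op>(op.lhs().getDefiningOp());
        if (!innerOp || !matchPattern(op.rhs(), m_Constant(&c2)) ||
            !matchPattern(innerOp.rhs(), m_Constant(&c1))) {
          return failure();
        }
        // Fold the two constants at compile time and rewrite to a single mul.
        auto merged = rewriter.createOrFold<ConstI32Op>(
            op.getLoc(), rewriter.getIntegerAttr(op.getType(),
                                                 c1.getValue() * c2.getValue()));
        rewriter.replaceOpWithNewOp<MulI32Op>(op, op.getType(), innerOp.lhs(),
                                              merged);
        return success();
      }
    };

    void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               MLIRContext *context) {
      results.insert<MergeConstantMulOps>(context);
    }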
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
index 153389c..38e49b2 100644
--- a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -56,6 +56,41 @@
// -----
+// CHECK-LABEL: @add_sub_i32_folds
+vm.module @add_sub_i32_folds {
+ // CHECK-LABEL: @add_sub_x
+ vm.func @add_sub_x(%arg0 : i32, %arg1 : i32) -> i32 {
+ // CHECK-NEXT: vm.return %arg0
+ %0 = vm.add.i32 %arg0, %arg1 : i32
+ %1 = vm.sub.i32 %0, %arg1 : i32
+ vm.return %1 : i32
+ }
+ // CHECK-LABEL: @add_sub_x_rev
+ vm.func @add_sub_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+ // CHECK-NEXT: vm.return %arg0
+ %0 = vm.add.i32 %arg1, %arg0 : i32
+ %1 = vm.sub.i32 %0, %arg1 : i32
+ vm.return %1 : i32
+ }
+
+ // CHECK-LABEL: @sub_add_x
+ vm.func @sub_add_x(%arg0 : i32, %arg1 : i32) -> i32 {
+ // CHECK-NEXT: vm.return %arg0
+ %0 = vm.sub.i32 %arg0, %arg1 : i32
+ %1 = vm.add.i32 %0, %arg1 : i32
+ vm.return %1 : i32
+ }
+ // CHECK-LABEL: @sub_add_x_rev
+ vm.func @sub_add_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+ // CHECK-NEXT: vm.return %arg0
+ %0 = vm.sub.i32 %arg0, %arg1 : i32
+ %1 = vm.add.i32 %arg1, %0 : i32
+ vm.return %1 : i32
+ }
+}
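These cancellation tests exercise folds rather than full rewrite patterns. A hedged sketch of the sub-side fold they imply, assuming the VM dialect's generated `lhs()`/`rhs()` accessors (the add side is symmetric):

    // Sketch: sub(add(a, b), b) -> a, matching either add operand since
    // vm.add.i32 is commutative. Constant folding is elided.
    OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) {
      if (auto addOp = dyn_cast_or_null<AddI32Op>(lhs().getDefiningOp())) {
        if (addOp.lhs() == rhs()) return addOp.rhs();  // sub(add(b, a), b) -> a
        if (addOp.rhs() == rhs()) return addOp.lhs();  // sub(add(a, b), b) -> a
      }
      return {};
    }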
+
+// -----
+
// CHECK-LABEL: @mul_i32_folds
vm.module @mul_i32_folds {
// CHECK-LABEL: @mul_i32_by_0
@@ -96,6 +131,23 @@
// -----
+// CHECK-LABEL: @mul_mul_i32_folds
+vm.module @mul_mul_i32_folds {
+ // CHECK-LABEL: @mul_mul_i32_const
+ vm.func @mul_mul_i32_const(%arg0 : i32) -> i32 {
+ // CHECK: %c40 = vm.const.i32 40 : i32
+ %c4 = vm.const.i32 4 : i32
+ %c10 = vm.const.i32 10 : i32
+ // CHECK: %0 = vm.mul.i32 %arg0, %c40 : i32
+ %0 = vm.mul.i32 %arg0, %c4 : i32
+ %1 = vm.mul.i32 %0, %c10 : i32
+ // CHECK-NEXT: vm.return %0 : i32
+ vm.return %1 : i32
+ }
+}
+
+// -----
+
// CHECK-LABEL: @div_i32_folds
vm.module @div_i32_folds {
// CHECK-LABEL: @div_i32_0_y
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index b3982b0..6bd35bb 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -43,6 +43,7 @@
// If we end up with a lot of these, consider using an "is pseudo" trait.
addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
addIllegalOp<IREE::VMLA::SortPseudoOp>();
+ addIllegalOp<IREE::VMLA::FftPseudoOp>();
// Allow other ops to pass through so long as their type is valid (not a
// tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 6f43d6e..5b842a5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -703,6 +703,43 @@
TypeConverter &typeConverter;
};
+struct FftOpConversion : public OpConversionPattern<IREE::VMLA::FftPseudoOp> {
+ FftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ LogicalResult matchAndRewrite(
+ IREE::VMLA::FftPseudoOp srcOp, ArrayRef<Value> rawOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto input_shape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
+
+ auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
+ auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
+
+ // The input type/shape should match for the real and imag components.
+ if (real_input_type != imag_input_type) {
+ srcOp.emitWarning() << "real and imag should have matching types";
+ return failure();
+ }
+
+ auto real_out = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
+ auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
+
+ rewriter.createOrFold<IREE::VMLA::FftOp>(
+ srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
+ input_shape, real_out, imag_out,
+ TypeAttr::get(real_input_type.getElementType()),
+ TypeAttr::get(imag_input_type.getElementType()));
+
+ rewriter.replaceOp(srcOp, {real_out, imag_out});
+ return success();
+ }
+
+ TypeConverter &typeConverter;
+};
+
struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
using OpConversionPattern::OpConversionPattern;
@@ -769,6 +806,9 @@
// vmla.sort.pseudo
patterns.insert<SortOpConversion>(context, typeConverter);
+ // vmla.fft.pseudo
+ patterns.insert<FftOpConversion>(context, typeConverter);
+
// Simple 1:1 conversion patterns using the automated trait-based converter.
// Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
new file mode 100644
index 0000000..1e3c365
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>) attributes { sym_visibility = "private" } {
+ // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+ // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
+ // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+ // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+ // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>), %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32, f32
+ %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
+ return %real, %imag : tensor<8xf32>, tensor<8xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f31545d..f139690 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -694,5 +694,24 @@
}];
}
+def VMLA_FftOp : VMLA_ElementTypeOp<"fft", [VMLA_IncludeShapes]> {
+ let arguments = (ins
+ VMLA_Buffer:$real_in,
+ VMLA_Shape:$real_in_shape,
+ VMLA_Buffer:$imag_in,
+ VMLA_Shape:$imag_in_shape,
+ VMLA_Buffer:$real_out,
+ VMLA_Buffer:$imag_out,
+ VMLA_AnyTypeAttr:$real_element_type,
+ VMLA_AnyTypeAttr:$imag_element_type
+ );
+
+ let assemblyFormat = [{
+ $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
+ $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
+ `out` $real_out `,` $imag_out attr-dict `:` $real_element_type `,` $imag_element_type
+ }];
+}
+
#endif // IREE_DIALECT_VMLA_OPS
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index baa8d41..35732c6 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -50,6 +50,15 @@
'End to end tests of TensorFlow slim vision models',
}
+# Key whose value is used to name each test's row in the left column of the
+# suite's coverage table.
+TEST_SUITE_TO_ROW_ID_KEY = {
+ '//integrations/tensorflow/e2e/keras:imagenet_external_tests':
+ 'model',
+ '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
+ 'model',
+}
+
# Some test suites are generated from a single source. This allows us to point
# to the right test file when generating test URLs.
SINGLE_SOURCE_SUITES = {
@@ -101,6 +110,34 @@
return parsed_args
+def parse_test_name(test_name, test_suite):
+ """Splits a test name into a dictionary with its source file and backend."""
+ test_name_parts = test_name.split("__")
+ test_info = {}
+
+ # For brevity, iree_e2e_test_suite elides the leading 'src' key, so an odd
+ # number of parts means the name starts with the bare source value.
+ if len(test_name_parts) % 2 == 1:
+ test_info['src'] = test_name_parts.pop(0)
+
+ # The rest of the test name should follow 'key__value__key__value__...'.
+ for key, value in zip(test_name_parts[::2], test_name_parts[1::2]):
+ test_info[key] = value
+
+ # Default to using the test source file name as the row id for the table.
+ if 'src' in test_info:
+ test_info['row_id'] = test_info['src']
+ else:
+ test_info['src'] = SINGLE_SOURCE_SUITES[test_suite]
+ test_info['row_id'] = test_info[TEST_SUITE_TO_ROW_ID_KEY[test_suite]]
+
+ if 'target_backends' not in test_info:
+ raise ValueError('Expected `target_backends` to be in the test name but '
+ f'got `{test_name}`.')
+
+ return test_info
+
+
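A small worked example of the double-underscore encoding that `parse_test_name` decodes; the target name here is made up for illustration:

    # Illustrative only: a hypothetical generated target name.
    name = 'mobilenet__target_backends__tflite'
    parts = name.split('__')  # ['mobilenet', 'target_backends', 'tflite']
    # Odd part count: the leading 'src' key was elided, so the first part is
    # the bare source value.
    info = {'src': parts.pop(0)} if len(parts) % 2 == 1 else {}
    info.update(zip(parts[::2], parts[1::2]))
    print(info)  # {'src': 'mobilenet', 'target_backends': 'tflite'}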
def get_name_and_backend(test_string):
"""Splits a pathless test target into its name and comparison backend."""
name, backend = test_string.split(f'__{REFERENCE_BACKEND}__')
@@ -113,42 +150,41 @@
failing = utils.get_test_targets(f'{test_suite}_failing')
# Remove bazel path.
- passing = [test.replace(f'{test_suite}_', '') for test in passing]
- failing = [test.replace(f'{test_suite}_failing_', '') for test in failing]
+ passing = [test.replace(f'{test_suite}__', '') for test in passing]
+ failing = [test.replace(f'{test_suite}_failing__', '') for test in failing]
- # Split into (test_name, target_backend).
- passing = [get_name_and_backend(test) for test in passing]
- failing = [get_name_and_backend(test) for test in failing]
- passing_names = [test[0] for test in passing]
- failing_names = [test[0] for test in failing]
- all_names = list(sorted(set(passing_names + failing_names)))
- return all_names, passing, failing
+ # Split into a dictionary mapping 'src', 'target_backend', ... to the
+ # appropriate values for each test target.
+ passing_info = [parse_test_name(test, test_suite) for test in passing]
+ failing_info = [parse_test_name(test, test_suite) for test in failing]
+ return passing_info, failing_info
-def get_name_element(test_suite, name):
+def get_row_hyperlink(test_suite, row_id, test_source):
"""Returns a Markdown hyperlink pointing to the test source on GitHub."""
# Convert `//path/to/tests:test_suite` to `path/to/tests`
- test_path = test_suite.split(':')[0]
- test_path = test_path.replace('//', '')
+ test_path = test_suite.replace('//', '').split(':')[0]
- if test_suite in SINGLE_SOURCE_SUITES:
- test_name = SINGLE_SOURCE_SUITES[test_suite]
- else:
- test_name = name
-
- test_url = os.path.join(MAIN_URL, test_path, f'{test_name}.py')
- return f'[{name}]({test_url})'
+ test_url = os.path.join(MAIN_URL, test_path, f'{test_source}.py')
+ return f'[{row_id}]({test_url})'
def generate_table(test_suite):
"""Generates an e2e backend coverage Markdown table."""
- all_names, passing, _ = get_suite_metadata(test_suite)
+ passing_info, _ = get_suite_metadata(test_suite)
- # Generate a dictionary mapping test names to their backend coverage.
+ # Create a dictionary mapping row names to source file names.
+ row_id_to_source = {}
+ for test_info in passing_info:
+ row_id_to_source[test_info['row_id']] = test_info['src']
+
+ # Create a dictionary mapping row ids to a list of bools representing their
+ # backend coverage.
table = collections.defaultdict(lambda: [False] * len(BACKENDS_TO_TITLES))
ordered_backends = list(BACKENDS_TO_TITLES.keys())
- for name, backend in passing:
- table[name][ordered_backends.index(backend)] = True
+ for test_info in passing_info:
+ backend_index = ordered_backends.index(test_info['target_backends'])
+ table[test_info['row_id']][backend_index] = True
# Create a header for the coverage table.
ordered_backend_titles = list(BACKENDS_TO_TITLES.values())
@@ -157,11 +193,12 @@
# Generate the coverage table as a 2D array.
rows = [first_row, second_row]
- for name, backends in sorted(table.items()):
- if any(re.match(pattern, name) for pattern in TARGET_EXCLUSION_FILTERS):
+ for row_id, backends in sorted(table.items()):
+ # Skip any rows matching a pattern in TARGET_EXCLUSION_FILTERS.
+ if any(re.match(pattern, row_id) for pattern in TARGET_EXCLUSION_FILTERS):
continue
- row = [get_name_element(test_suite, name)]
+ row = [get_row_hyperlink(test_suite, row_id, row_id_to_source[row_id])]
row.extend([
SUCCESS_ELEMENT if backend else FAILURE_ELEMENT for backend in backends
])