Merge pull request #3486 from ScottTodd:main-to-google

PiperOrigin-RevId: 337331640
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 701ba0d..472cea2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -394,8 +394,8 @@
 if(CMAKE_CROSSCOMPILING)
   # We need flatc to generate some source code. When cross-compiling, we need
   # to make sure the flatc binary is configured under host environment.
-  iree_declare_host_excutable(flatc BUILDONLY)
-  iree_declare_host_excutable(flatcc_cli BUILDONLY)
+  iree_declare_host_excutable(flatc "flatc" BUILDONLY)
+  iree_declare_host_excutable(flatcc_cli "flatcc_cli" BUILDONLY)
 
   # Set the FLATBUFFERS_FLATC_EXECUTABLE. It controls where to find the flatc
   # binary in BuildFlatBuffers().
@@ -414,7 +414,7 @@
     COMMAND
       "${CMAKE_COMMAND}" -E copy_if_different
         "${PROJECT_SOURCE_DIR}/third_party/flatcc/bin/flatcc${IREE_HOST_EXECUTABLE_SUFFIX}"
-        "${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli"
+        "${IREE_HOST_BINARY_ROOT}/bin/flatcc_cli${IREE_HOST_EXECUTABLE_SUFFIX}"
     DEPENDS iree_host_build_flatcc_cli
     COMMENT "Installing host flatcc..."
   )
diff --git a/build_tools/cmake/iree_cc_binary.cmake b/build_tools/cmake/iree_cc_binary.cmake
index 5e71014..5e1bd0a 100644
--- a/build_tools/cmake/iree_cc_binary.cmake
+++ b/build_tools/cmake/iree_cc_binary.cmake
@@ -78,7 +78,7 @@
     # The binary is marked as host only. We need to declare the rules for
     # generating them under host configuration so when cross-compiling towards
     # target we can still have this binary.
-    iree_declare_host_excutable(${_RULE_NAME})
+    iree_declare_host_excutable(${_RULE_NAME} ${_NAME})
 
     # Still define the package-prefixed target so we can have a consistent way
     # to reference this binary, whether cross-compiling or not. But this time
diff --git a/build_tools/cmake/iree_cross_compile.cmake b/build_tools/cmake/iree_cross_compile.cmake
index 4c67410..b52b4d2 100644
--- a/build_tools/cmake/iree_cross_compile.cmake
+++ b/build_tools/cmake/iree_cross_compile.cmake
@@ -131,13 +131,13 @@
 
 # iree_get_build_command
 #
-# Gets the CMake build command for the given `EXECUTABLE`.
+# Gets the CMake build command for the given `EXECUTABLE_TARGET`.
 #
 # Parameters:
-# EXECUTABLE: the executable to build.
+# EXECUTABLE_TARGET: the target for the executable to build.
 # BINDIR: root binary directory containing CMakeCache.txt.
 # CMDVAR: variable name for receiving the build command.
-function(iree_get_build_command EXECUTABLE)
+function(iree_get_build_command EXECUTABLE_TARGET)
   cmake_parse_arguments(_RULE "" "BINDIR;CMDVAR;CONFIG" "" ${ARGN})
   if(NOT _RULE_CONFIG)
     set(_RULE_CONFIG "$<CONFIG>")
@@ -145,11 +145,11 @@
   if (CMAKE_GENERATOR MATCHES "Make")
     # Use special command for Makefiles to support parallelism.
     set(${_RULE_CMDVAR}
-        "$(MAKE)" "-C" "${_RULE_BINDIR}" "${EXECUTABLE}" PARENT_SCOPE)
+        "$(MAKE)" "-C" "${_RULE_BINDIR}" "${EXECUTABLE_TARGET}" PARENT_SCOPE)
   else()
     set(${_RULE_CMDVAR}
         "${CMAKE_COMMAND}" --build ${_RULE_BINDIR}
-                           --target ${EXECUTABLE}${IREE_HOST_EXECUTABLE_SUFFIX}
+                           --target ${EXECUTABLE_TARGET}
                            --config ${_RULE_CONFIG} PARENT_SCOPE)
   endif()
 endfunction()
@@ -201,14 +201,15 @@
 #
 # Parameters:
 # EXECUTABLE: the executable to build on host.
+# EXECUTABLE_TARGET: the target name for the executable.
 # BUILDONLY: only generates commands for building the target.
 # DEPENDS: any additional dependencies for the target.
-function(iree_declare_host_excutable EXECUTABLE)
+function(iree_declare_host_excutable EXECUTABLE EXECUTABLE_TARGET)
   cmake_parse_arguments(_RULE "BUILDONLY" "" "DEPENDS" ${ARGN})
 
   iree_get_executable_path(_OUTPUT_PATH ${EXECUTABLE})
 
-  iree_get_build_command(${EXECUTABLE}
+  iree_get_build_command(${EXECUTABLE_TARGET}
     BINDIR ${IREE_HOST_BINARY_ROOT}
     CMDVAR build_cmd)
 
diff --git a/build_tools/embed_data/CMakeLists.txt b/build_tools/embed_data/CMakeLists.txt
index ed9a879..8679ccc 100644
--- a/build_tools/embed_data/CMakeLists.txt
+++ b/build_tools/embed_data/CMakeLists.txt
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 if(CMAKE_CROSSCOMPILING)
-  iree_declare_host_excutable(generate_cc_embed_data)
+  iree_declare_host_excutable(generate_cc_embed_data "generate_cc_embed_data")
 else()
   add_executable(generate_cc_embed_data)
   target_sources(generate_cc_embed_data PRIVATE generate_cc_embed_data_main.cc)
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index ed15f36..af2c510 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -105,7 +105,6 @@
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
-    "sort_test.py",
     "strings_test.py",
 ]
 
@@ -126,7 +125,6 @@
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
-    "sort_test.py",
     "strings_test.py",
 ]
 
@@ -180,6 +178,7 @@
     },
     reference_backend = "tf",
     tags = [
+        "failing",
         "manual",
         "nokokoro",
         "notap",
@@ -220,6 +219,7 @@
     },
     reference_backend = "tf",
     tags = [
+        "failing",
         "manual",
         "nokokoro",
         "notap",
@@ -261,6 +261,7 @@
     reference_backend = "tf",
     tags = [
         "external",
+        "failing",
         "guitar",
         "manual",
         "no-remote",
diff --git a/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
new file mode 100644
index 0000000..e7a4fe5
--- /dev/null
+++ b/integrations/tensorflow/e2e/iree_e2e_cartesian_product_test_suite.bzl
@@ -0,0 +1,229 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Macro for building e2e tests from a single source with multiple flags."""
+
+load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
+load(
+    "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
+    "get_driver",
+)
+
+def _normalize_dictionary(dictionary):
+    """Wraps every value of dictionary in a list if it isn't one already."""
+    for key, value in dictionary.items():
+        if type(value) != type([]):
+            dictionary[key] = [value]
+    return dictionary
+
+def _dictionary_product(dictionary):
+    """Returns a named cartesian product of dictionary's values."""
+
+    # Converts {'a': [1, 2], 'b': [3, 4]} into
+    # [{'a': 1, 'b': 3}, {'a': 1, 'b': 4}, {'a': 2, 'b': 3}, {'a': 2, 'b': 4}]
+    product = [[]]
+    for values in dictionary.values():
+        # Iteratively grow the elements of the product.
+        product = [element + [value] for element in product for value in values]
+    dicts = [{k: v for k, v in zip(dictionary, element)} for element in product]
+    return dicts
+
+def iree_e2e_cartesian_product_test_suite(
+        name,
+        srcs,
+        main,
+        flags_to_values,
+        failing_configurations = None,
+        tags = None,
+        data = None,
+        deps = None,
+        size = None,
+        python_version = "PY3",
+        **kwargs):
+    """Creates a test for each configuration and bundles a succeeding and failing test suite.
+
+    Computes the cartesian product of `flags_to_values` and then creates a test
+    for each element of that product. Tests specified in
+    `failing_configurations` are bundled into a test suite suffixed with
+    "_failing" and tagged to be excluded from CI and wildcard builds. All other
+    tests are bundled into a suite with the same name as the macro.
+
+    For example, given the following values
+
+        flags_to_values = {
+            "use_external_weights": True,
+            "model": [
+                "ResNet50",
+                "MobileBert",
+            ],
+            "target_backends": [
+                "tf",
+                "iree_vmla",
+                "iree_vulkan",
+            ]
+        }
+        failing_configurations = [
+            {
+                "model": "MobileBert",
+                "target_backends": "iree_vulkan",
+            },
+            {
+                "model": ["ResNet50"],
+            },
+        ]
+
+    the following passing and failing configurations would be generated:
+        # Passing
+        {use_external_weights: True, model: MobileBert, target_backends: tf}
+        {use_external_weights: True, model: MobileBert, target_backends: iree_vmla}
+
+        # Failing
+        {use_external_weights: True, model: ResNet50,   target_backends: tf}
+        {use_external_weights: True, model: ResNet50,   target_backends: iree_vmla}
+        {use_external_weights: True, model: ResNet50,   target_backends: iree_vulkan}
+        {use_external_weights: True, model: MobileBert, target_backends: iree_vulkan}
+
+    Args:
+      name:
+        name of the generated passing test suite. If failing_configurations
+        is not `None` then a test suite named name_failing will also be
+        generated.
+      srcs:
+        src files for iree_py_test
+      main:
+        main file for iree_py_test
+      failing_configurations:
+        an iterable of dictionaries specifying which flag values the test is
+        failing for. If a flag name is present in `flags_to_values`, but not
+        present in `failing_configurations`, then all of the values in
+        `flags_to_values[flag_name]` are included. (See `ResNet50` in the
+        example above).
+      flags_to_values:
+        a dictionary of strings (flag names) to lists (of values for the flags)
+        to take a cartesian product of. `target_backends` must be specified.
+      tags:
+        tags to apply to the test. Note that as in standard test suites, manual
+        is treated specially and will also apply to the test suite itself.
+      data:
+        external data for iree_py_test.
+      deps:
+        test dependencies for iree_py_test.
+      size:
+        size of the tests for iree_py_test.
+      python_version:
+        the python version to run the tests with. Uses python3 by default.
+      **kwargs:
+        any additional arguments that will be passed to the underlying tests and
+        test_suite.
+    """
+    if not "target_backends" in flags_to_values:
+        fail("`target_backends` must be a key in `flags_to_values`.")
+
+    # Normalize flags_to_values to always have lists as its values.
+    # e.g. {use_external_data: True} -> {use_external_data: [True]}
+    flags_to_values = _normalize_dictionary(flags_to_values)
+
+    all_flag_configurations = _dictionary_product(flags_to_values)
+
+    failing_flag_configurations = []
+    if failing_configurations != None:
+        for failing_configuration in failing_configurations:
+            failing_configuration = _normalize_dictionary(failing_configuration)
+
+            # If a flag isn't specified in the failing configuration, assume it
+            # is failing for all values of that flag.
+            for key, values in flags_to_values.items():
+                if key not in failing_configuration:
+                    failing_configuration[key] = values
+
+            failing_flag_configurations.extend(
+                _dictionary_product(failing_configuration),
+            )
+
+    tests = []
+    for flags in all_flag_configurations:
+        # Check if this is a failing configuration.
+        failing = flags in failing_flag_configurations
+
+        # Append "_failing" to name if this is a failing configuration.
+        test_name = name if not failing else name + "_failing"
+        test_name = [test_name]
+        for k, v in flags.items():
+            # Only include the flag's value in the test name if it's not always
+            # the same.
+            if len(flags_to_values[k]) > 1:
+                test_name.append(k)
+                test_name.append(str(v))
+        test_name = "__".join(test_name)
+        tests.append(test_name)
+
+        args = ["--{}={}".format(k, v) for k, v in flags.items()]
+
+        if len(flags["target_backends"].split(",")) > 1:
+            fail("Multiple target backends cannot be specified at once, but " +
+                 "got `{}`".format(flags["target_backends"]))
+
+        driver = get_driver(flags["target_backends"])
+        py_test_tags = ["driver={}".format(driver)]
+        if tags != None:  # `is` is not supported.
+            py_test_tags += tags
+
+        # Add additional tags if this is a failing configuration.
+        if failing:
+            py_test_tags += [
+                "failing",
+                "manual",
+                "nokokoro",
+                "notap",
+            ]
+
+        iree_py_test(
+            name = test_name,
+            main = main,
+            srcs = srcs,
+            args = args,
+            data = data,
+            deps = deps,
+            size = size,
+            tags = py_test_tags,
+            python_version = python_version,
+            **kwargs
+        )
+
+    if tags == None:
+        tags = []
+    native.test_suite(
+        name = name,
+        tests = tests,
+        # Add "-failing" to exclude tests in `tests` that have the "failing"
+        # tag.
+        tags = tags + ["-failing"],
+        # If there are kwargs that need to be passed here which only apply to
+        # the generated tests and not to test_suite, they should be extracted
+        # into separate named arguments.
+        **kwargs
+    )
+
+    if failing_configurations != None:
+        native.test_suite(
+            name = name + "_failing",
+            tests = tests,
+            # Add "+failing" to only include tests in `tests` that have the
+            # "failing" tag.
+            tags = tags + ["+failing"],
+            # If there are kwargs that need to be passed here which only apply
+            # to the generated tests and not to test_suite, they should be
+            # extracted into separate named arguments.
+            **kwargs
+        )
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 159633a..8fb35a4 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -16,10 +16,18 @@
 
 load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
 
+def get_driver(backend):
+    # TODO(#2175): Simplify this after backend names are standardized.
+    driver = backend.replace("iree_", "")  # "iree_<driver>" --> "<driver>"
+    if driver == "llvmjit":
+        driver = "llvm"
+    return driver
+
 def iree_e2e_test_suite(
         name,
         backends_to_srcs,
         reference_backend,
+        data = None,
         deps = None,
         size = None,
         tags = None,
@@ -34,8 +42,10 @@
         a dictionary mapping backends to a list of test files to run on them.
       reference_backend:
         the backend to use as a source of truth for the expected output results.
+      data:
+        external data for iree_py_test.
       deps:
-        test dependencies.
+        test dependencies for iree_py_test.
       tags:
         tags to apply to the test. Note that as in standard test suites, manual
         is treated specially and will also apply to the test suite itself.
@@ -49,10 +59,9 @@
 
     for backend, srcs in backends_to_srcs.items():
         for src in srcs:
-            test_name = "{}_{}__{}__{}".format(
+            test_name = "{}__{}__target_backends__{}".format(
                 name,
                 src[:-3],
-                reference_backend,
                 backend,
             )
             args = [
@@ -60,10 +69,7 @@
                 "--target_backends={}".format(backend),
             ]
 
-            # TODO(GH-2175): Simplify this after backend names are standardized.
-            driver = backend.replace("iree_", "")  # "iree_<driver>" --> "<driver>"
-            if driver == "llvmjit":
-                driver = "llvm"
+            driver = get_driver(backend)
             py_test_tags = ["driver={}".format(driver)]
             if tags != None:  # `is` is not supported.
                 py_test_tags += tags
@@ -72,8 +78,9 @@
                 name = test_name,
                 main = src,
                 srcs = [src],
-                deps = deps,
                 args = args,
+                data = data,
+                deps = deps,
                 size = size,
                 tags = py_test_tags,
                 python_version = python_version,
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 7753475..78098b1 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -25,8 +25,8 @@
     "iree_py_binary",
 )
 load(
-    "//integrations/tensorflow/e2e/keras:iree_vision_test_suite.bzl",
-    "iree_vision_test_suite",
+    "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+    "iree_e2e_cartesian_product_test_suite",
 )
 load(
     "//integrations/tensorflow/e2e:iree_e2e_test_suite.bzl",
@@ -152,6 +152,7 @@
     },
     reference_backend = "tf",
     tags = [
+        "failing",
         "manual",
         "nokokoro",
         "notap",
@@ -161,73 +162,80 @@
     ],
 )
 
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
     name = "large_cifar10_internal_tests",
     size = "large",
-    backends = [
-        "tf",
-        "tflite",
-        "iree_vmla",
-        "iree_llvmjit",
-        "iree_vulkan",
-    ],
-    datasets = ["cifar10"],
-    models = [
-        # All models with runtime shorter than ResNet50.
-        "MobileNet",  # Max: Vulkan 61.0s
-        "MobileNetV2",  # Max: LLVM 96.3s
-        "ResNet50",  # Max: LLVM 145.6s
-        "VGG16",  # Max: LLVM 89.5s
-        "VGG19",  # Max: LLVM 94.7s
-    ],
-    reference_backend = "tf",
+    srcs = ["vision_model_test.py"],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "model": [
+            # All models with runtime shorter than ResNet50.
+            "MobileNet",  # Max: Vulkan 61.0s
+            "MobileNetV2",  # Max: LLVM 96.3s
+            "ResNet50",  # Max: LLVM 145.6s
+            "VGG16",  # Max: LLVM 89.5s
+            "VGG19",  # Max: LLVM 94.7s
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "vision_model_test.py",
     tags = ["manual"],
     deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
 
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
     name = "enormous_cifar10_internal_tests",
     size = "enormous",
-    backends = [
-        "tf",
-        "tflite",
-        "iree_vmla",
-        "iree_llvmjit",
-        "iree_vulkan",
-    ],
-    datasets = ["cifar10"],
+    srcs = ["vision_model_test.py"],
     failing_configurations = [
         {
             # Failing on llvm and vulkan:
-            "models": [
+            "model": [
                 "NASNetLarge",
                 "NASNetMobile",
                 "ResNet50V2",
                 "ResNet101V2",
                 "ResNet152V2",
             ],
-            "datasets": ["cifar10"],
-            "backends": [
+            "target_backends": [
                 "iree_llvmjit",
                 "iree_vulkan",
             ],
         },
     ],
-    models = [
-        "DenseNet121",
-        "DenseNet169",
-        "DenseNet201",
-        "NASNetLarge",
-        "NASNetMobile",
-        "ResNet50V2",
-        "ResNet101",
-        "ResNet101V2",
-        "ResNet152",
-        "ResNet152V2",
-    ],
-    reference_backend = "tf",
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "model": [
+            "DenseNet121",
+            "DenseNet169",
+            "DenseNet201",
+            "NASNetLarge",
+            "NASNetMobile",
+            "ResNet50V2",
+            "ResNet101",
+            "ResNet101V2",
+            "ResNet152",
+            "ResNet152V2",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "vision_model_test.py",
     tags = [
         "guitar",
         "manual",
@@ -239,22 +247,29 @@
     ],
 )
 
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
     name = "cifar10_external_tests",
-    backends = [
-        "tf",
-        "tflite",
-        "iree_vmla",
-        "iree_llvmjit",
-        "iree_vulkan",
-    ],
-    datasets = ["cifar10"],
-    models = [
-        "MobileNet",
-        "MobileNetV2",
-        "ResNet50",
-    ],
-    reference_backend = "tf",
+    size = "large",
+    srcs = ["vision_model_test.py"],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "cifar10",
+        "url": "https://storage.googleapis.com/iree_models/",
+        "use_external_weights": True,
+        "model": [
+            "MobileNet",
+            "MobileNetV2",
+            "ResNet50",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "vision_model_test.py",
     tags = [
         "external",
         "guitar",
@@ -263,28 +278,19 @@
         "nokokoro",
         "notap",
     ],
-    url = "https://storage.googleapis.com/iree_models/",
-    use_external_weights = True,
     deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
 
-iree_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
     name = "imagenet_external_tests",
     size = "enormous",
-    backends = [
-        "tf",
-        "tflite",
-        "iree_vmla",
-        "iree_llvmjit",
-        "iree_vulkan",
-    ],
-    datasets = ["imagenet"],
+    srcs = ["vision_model_test.py"],
     failing_configurations = [
         {
             # Failing vulkan:
-            "models": [
+            "model": [
                 "InceptionResNetV2",
                 "InceptionV3",
             ],
@@ -303,34 +309,45 @@
                 "ResNet152V2",
                 "Xception",
             ],
-            "datasets": ["imagenet"],
-            "backends": [
+            "target_backends": [
                 "iree_llvmjit",
                 "iree_vulkan",
             ],
         },
     ],
-    models = [
-        "DenseNet121",
-        "DenseNet169",
-        "DenseNet201",
-        "InceptionResNetV2",
-        "InceptionV3",
-        "MobileNet",
-        "MobileNetV2",
-        "NASNetLarge",
-        "NASNetMobile",
-        "ResNet50",
-        "ResNet50V2",
-        "ResNet101",
-        "ResNet101V2",
-        "ResNet152",
-        "ResNet152V2",
-        "VGG16",
-        "VGG19",
-        "Xception",
-    ],
-    reference_backend = "tf",
+    flags_to_values = {
+        "reference_backend": "tf",
+        "data": "imagenet",
+        "use_external_weights": True,
+        "model": [
+            "DenseNet121",
+            "DenseNet169",
+            "DenseNet201",
+            "InceptionResNetV2",
+            "InceptionV3",
+            "MobileNet",
+            "MobileNetV2",
+            "NASNetLarge",
+            "NASNetMobile",
+            "ResNet50",
+            "ResNet50V2",
+            "ResNet101",
+            "ResNet101V2",
+            "ResNet152",
+            "ResNet152V2",
+            "VGG16",
+            "VGG19",
+            "Xception",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "vision_model_test.py",
     tags = [
         "external",
         "guitar",
@@ -338,7 +355,6 @@
         "nokokoro",
         "notap",
     ],
-    use_external_weights = True,
     deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
diff --git a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl b/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
deleted file mode 100644
index 303938c..0000000
--- a/integrations/tensorflow/e2e/keras/iree_vision_test_suite.bzl
+++ /dev/null
@@ -1,185 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro for building e2e keras vision model tests."""
-
-load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
-load("@bazel_skylib//lib:new_sets.bzl", "sets")
-
-def iree_vision_test_suite(
-        name,
-        models,
-        datasets,
-        backends,
-        reference_backend,
-        failing_configurations = None,
-        tags = None,
-        url = None,
-        use_external_weights = None,
-        deps = None,
-        size = "large",
-        python_version = "PY3",
-        **kwargs):
-    """Creates a test for each configuration and bundles a succeeding and failing test suite.
-
-    Creates one test per dataset, backend, and model. Tests indicated in
-    `failing_configurations` are bundled into a suite suffixed with "_failing"
-    tagged to be excluded from CI and wildcard builds. All other tests are
-    bundled into a suite with the same name as the macro.
-
-    Args:
-      name:
-        name of the generated passing test suite. If failing_configurations is
-        not `None` then a test suite named name_failing will also be generated.
-      models:
-        an iterable of model names to generate targets for.
-      datasets:
-        an iterable specifying the datasets on which the models are based. This
-        controls the shape of the input images. Also indicates which weight file
-        to use when loading weights from an external source.
-      backends:
-        an iterable of targets backends to generate targets for.
-      reference_backend:
-        the backend to use as a source of truth for the expected output results.
-      failing_configurations:
-        an iterable of dictionaries with the keys `models`, `datasets` and
-        `backends`. Each key points to a string or iterable of strings
-        specifying a set of models, datasets and backends that are failing.
-      tags:
-        tags to apply to the test. Note that as in standard test suites, manual
-        is treated specially and will also apply to the test suite itself.
-      url:
-        a base url to fetch non-keras trained model weights from.
-      use_external_weights:
-        whether or not to load model weights from the web (either uses keras or
-          the supplied url).
-      deps:
-        test dependencies.
-      size:
-        size of the tests. Default: "large".
-      python_version:
-        the python version to run the tests with. Uses python3 by default.
-      **kwargs:
-        any additional arguments that will be passed to the underlying tests and
-        test_suite.
-    """
-    failing_set = sets.make([])
-    if failing_configurations != None:
-        # Parse failing configurations.
-        for configuration in failing_configurations:
-            # Normalize configuration input.
-            # {backend: "iree_llvmjit"} -> {backend: ["iree_llvmjit"]}
-            for key, value in configuration.items():
-                if type(value) == type(""):
-                    configuration[key] = [value]
-
-            for model in configuration["models"]:
-                for dataset in configuration["datasets"]:
-                    for backend in configuration["backends"]:
-                        sets.insert(failing_set, (model, dataset, backend))
-
-    tests = []
-    for model in models:
-        for dataset in datasets:
-            for backend in backends:
-                # Check if this is a failing configuration.
-                failing = sets.contains(failing_set, (model, dataset, backend))
-
-                # Append "_failing" to name if this is a failing configuration.
-                test_name = name if not failing else name + "_failing"
-                if len(datasets) > 1:
-                    test_name = "{}_{}_{}__{}__{}".format(
-                        test_name,
-                        model,
-                        dataset,
-                        reference_backend,
-                        backend,
-                    )
-                else:
-                    test_name = "{}_{}__{}__{}".format(
-                        test_name,
-                        model,
-                        reference_backend,
-                        backend,
-                    )
-                tests.append(test_name)
-
-                args = [
-                    "--model={}".format(model),
-                    "--data={}".format(dataset),
-                    "--reference_backend={}".format(reference_backend),
-                    "--target_backends={}".format(backend),
-                ]
-                if url:
-                    args.append("--url={}".format(url))
-                if use_external_weights:
-                    args.append(
-                        "--use_external_weights={}".format(use_external_weights),
-                    )
-
-                # TODO(GH-2175): Simplify this after backend names are
-                # standardized.
-                # "iree_<driver>" --> "<driver>"
-                driver = backend.replace("iree_", "")
-                if driver == "llvmjit":
-                    driver = "llvm"
-                py_test_tags = ["driver={}".format(driver)]
-                if tags != None:  # `is` is not supported.
-                    py_test_tags += tags
-
-                # Add additional tags if this is a failing configuration.
-                if failing:
-                    py_test_tags += [
-                        "failing",  # Only used for test_suite filtering below.
-                        "manual",
-                        "nokokoro",
-                        "notap",
-                    ]
-
-                iree_py_test(
-                    name = test_name,
-                    main = "vision_model_test.py",
-                    srcs = ["vision_model_test.py"],
-                    args = args,
-                    tags = py_test_tags,
-                    deps = deps,
-                    size = size,
-                    python_version = python_version,
-                    **kwargs
-                )
-
-    native.test_suite(
-        name = name,
-        tests = tests,
-        # Add "-failing" to exclude tests in `tests` that have the "failing"
-        # tag.
-        tags = tags + ["-failing"],
-        # If there are kwargs that need to be passed here which only apply to
-        # the generated tests and not to test_suite, they should be extracted
-        # into separate named arguments.
-        **kwargs
-    )
-
-    if failing_configurations != None:
-        native.test_suite(
-            name = name + "_failing",
-            tests = tests,
-            # Add "+failing" to only include tests in `tests` that have the
-            # "failing" tag.
-            tags = tags + ["+failing"],
-            # If there are kwargs that need to be passed here which only apply
-            # to the generated tests and not to test_suite, they should be
-            # extracted into separate named arguments.
-            **kwargs
-        )
diff --git a/integrations/tensorflow/e2e/keras/train/BUILD b/integrations/tensorflow/e2e/keras/train/BUILD
index d6436b2..177bcaf 100644
--- a/integrations/tensorflow/e2e/keras/train/BUILD
+++ b/integrations/tensorflow/e2e/keras/train/BUILD
@@ -18,8 +18,8 @@
     "NUMPY_DEPS",
 )
 load(
-    "//integrations/tensorflow/e2e/keras/train:iree_train_test_suite.bzl",
-    "iree_train_test_suite",
+    "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+    "iree_e2e_cartesian_product_test_suite",
 )
 
 package(
@@ -28,13 +28,34 @@
     licenses = ["notice"],  # Apache 2.0
 )
 
-# TODO(meadowlark): Refactor this rule to match iree_vision_test_suite.bzl
-iree_train_test_suite(
+iree_e2e_cartesian_product_test_suite(
     name = "train_tests",
-    configurations = [
-        # tuples of (optimizer, backends)
-        ("sgd", "tf"),
+    srcs = ["model_train_test.py"],
+    failing_configurations = [
+        {
+            "target_backends": [
+                "tflite",
+                "iree_vmla",  # TODO(b/157581521)
+                "iree_llvmjit",
+                "iree_vulkan",
+            ],
+        },
     ],
+    flags_to_values = {
+        "reference_backend": "tf",
+        "optimizer_name": [
+            "sgd",
+            "adam",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "model_train_test.py",
     tags = [
         "guitar",
         "manual",
@@ -45,21 +66,3 @@
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
-
-iree_train_test_suite(
-    name = "train_tests_failing",
-    configurations = [
-        # tuples of (optimizer, backends)
-        # TODO: Combine this suite with keras_model_train once these tests pass.
-        ("sgd", "tf,iree_vmla"),
-        ("adam", "tf,iree_vmla"),  # TODO(b/157581521)
-    ],
-    tags = [
-        "manual",
-        "nokokoro",
-        "notap",
-    ],
-    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
-        "//integrations/tensorflow/bindings/python/pyiree/tf/support",
-    ],
-)
diff --git a/integrations/tensorflow/e2e/keras/train/README.md b/integrations/tensorflow/e2e/keras/train/README.md
new file mode 100644
index 0000000..86222b1
--- /dev/null
+++ b/integrations/tensorflow/e2e/keras/train/README.md
@@ -0,0 +1,10 @@
+# Keras Training Tests
+
+These tests require an additional python dependency on `sklearn`, which
+can be installed as follows:
+
+```shell
+python3 -m pip install sklearn
+```
+
+These tests are not checked by the OSS CI.
diff --git a/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl b/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl
deleted file mode 100644
index d6b244b..0000000
--- a/integrations/tensorflow/e2e/keras/train/iree_train_test_suite.bzl
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro for building e2e keras vision model tests."""
-
-load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
-
-def iree_train_test_suite(
-        name,
-        configurations,
-        tags = None,
-        deps = None,
-        size = None,
-        python_version = "PY3",
-        **kwargs):
-    """Creates one iree_py_test per configuration tuple and a test suite that bundles them.
-
-    Args:
-      name:
-        name of the generated test suite.
-      configurations:
-        a list of tuples of (optimizer, backends).
-      tags:
-        tags to apply to the test. Note that as in standard test suites, manual
-        is treated specially and will also apply to the test suite itself.
-      deps:
-        test dependencies.
-      size:
-        size of the tests.
-      python_version:
-        the python version to run the tests with. Uses python3 by default.
-      **kwargs:
-        Any additional arguments that will be passed to the underlying tests
-        and test_suite.
-    """
-    tests = []
-    for optimizer, backends in configurations:
-        test_name = "{}_{}_{}_test".format(name, optimizer, backends)
-        tests.append(test_name)
-
-        args = [
-            "--optimizer_name={}".format(optimizer),
-            "--target_backends={}".format(backends),
-        ]
-
-        iree_py_test(
-            name = test_name,
-            main = "model_train_test.py",
-            srcs = ["model_train_test.py"],
-            args = args,
-            tags = tags,
-            deps = deps,
-            size = size,
-            python_version = python_version,
-            **kwargs
-        )
-
-    native.test_suite(
-        name = name,
-        tests = tests,
-        # Note that only the manual tag really has any effect here. Others are
-        # used for test suite filtering, but all tests are passed the same tags.
-        tags = tags,
-        # If there are kwargs that need to be passed here which only apply to
-        # the generated tests and not to test_suite, they should be extracted
-        # into separate named arguments.
-        **kwargs
-    )
diff --git a/integrations/tensorflow/e2e/slim_vision_models/BUILD b/integrations/tensorflow/e2e/slim_vision_models/BUILD
index c7c6101..bbfe6a7 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/BUILD
+++ b/integrations/tensorflow/e2e/slim_vision_models/BUILD
@@ -26,8 +26,8 @@
     "iree_py_binary",
 )
 load(
-    "//integrations/tensorflow/e2e/slim_vision_models:iree_slim_vision_test_suite.bzl",
-    "iree_slim_vision_test_suite",
+    "//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
+    "iree_e2e_cartesian_product_test_suite",
 )
 
 package(
@@ -48,33 +48,18 @@
     ],
 )
 
-iree_slim_vision_test_suite(
+iree_e2e_cartesian_product_test_suite(
     name = "slim_vision_tests",
     size = "enormous",
-    backends = [
-        "tf",
-        "tflite",
-        "iree_vmla",
-        "iree_llvmjit",
-        "iree_vulkan",
-    ],
+    srcs = ["slim_vision_model_test.py"],
     failing_configurations = [
         {
             # SavedModelV2 (classification/4) not available.
-            "models": [
-                "amoebanet_a_n18_f448",
-            ],
-            "backends": [
-                "tf",
-                "tflite",
-                "iree_vmla",
-                "iree_llvmjit",
-                "iree_vulkan",
-            ],
+            "model": "amoebanet_a_n18_f448",
         },
         {
             # Failing llvmjit and vulkan:
-            "models": [
+            "model": [
                 "nasnet_mobile",
                 "nasnet_large",
                 "pnasnet_large",
@@ -82,84 +67,95 @@
                 "resnet_v2_101",
                 "resnet_v2_152",
             ],
-            "backends": [
+            "target_backends": [
                 "iree_llvmjit",
                 "iree_vulkan",
             ],
         },
         {
             # Failing vulkan:
-            "models": [
+            "model": [
                 # [ERROR]: cannot separate Linalg/Parallel ops into multiple kernels
                 "inception_v1",
                 "inception_v2",
                 "inception_v3",
                 "inception_resnet_v2",
             ],
-            "backends": [
+            "target_backends": [
                 "iree_vulkan",
             ],
         },
     ],
-    models = [
-        "amoebanet_a_n18_f448",
-        "inception_resnet_v2",
-        "inception_v1",
-        "inception_v2",
-        "inception_v3",
-        # MobileNetV1
-        "mobilenet_v1_025_128",
-        "mobilenet_v1_025_160",
-        "mobilenet_v1_025_192",
-        "mobilenet_v1_025_224",
-        "mobilenet_v1_050_128",
-        "mobilenet_v1_050_160",
-        "mobilenet_v1_050_192",
-        "mobilenet_v1_050_224",
-        "mobilenet_v1_075_128",
-        "mobilenet_v1_075_160",
-        "mobilenet_v1_075_192",
-        "mobilenet_v1_075_224",
-        "mobilenet_v1_100_128",
-        "mobilenet_v1_100_160",
-        "mobilenet_v1_100_192",
-        "mobilenet_v1_100_224",
-        # MobileNetV2:
-        "mobilenet_v2_035_96",
-        "mobilenet_v2_035_128",
-        "mobilenet_v2_035_160",
-        "mobilenet_v2_035_192",
-        "mobilenet_v2_035_224",
-        "mobilenet_v2_050_96",
-        "mobilenet_v2_050_128",
-        "mobilenet_v2_050_160",
-        "mobilenet_v2_050_192",
-        "mobilenet_v2_050_224",
-        "mobilenet_v2_075_96",
-        "mobilenet_v2_075_128",
-        "mobilenet_v2_075_160",
-        "mobilenet_v2_075_192",
-        "mobilenet_v2_075_224",
-        "mobilenet_v2_100_96",
-        "mobilenet_v2_100_128",
-        "mobilenet_v2_100_160",
-        "mobilenet_v2_100_192",
-        "mobilenet_v2_100_224",
-        "mobilenet_v2_130_224",
-        "mobilenet_v2_140_224",
-        "nasnet_mobile",
-        "nasnet_large",
-        "pnasnet_large",
-        # ResNetV1
-        "resnet_v1_50",
-        "resnet_v1_101",
-        "resnet_v1_152",
-        # ResNetV2
-        "resnet_v2_50",
-        "resnet_v2_101",
-        "resnet_v2_152",
-    ],
-    reference_backend = "tf",
+    flags_to_values = {
+        "reference_backend": "tf",
+        "tf_hub_url": ["https://tfhub.dev/google/imagenet/"],
+        "model": [
+            "amoebanet_a_n18_f448",
+            "inception_resnet_v2",
+            "inception_v1",
+            "inception_v2",
+            "inception_v3",
+            # MobileNetV1
+            "mobilenet_v1_025_128",
+            "mobilenet_v1_025_160",
+            "mobilenet_v1_025_192",
+            "mobilenet_v1_025_224",
+            "mobilenet_v1_050_128",
+            "mobilenet_v1_050_160",
+            "mobilenet_v1_050_192",
+            "mobilenet_v1_050_224",
+            "mobilenet_v1_075_128",
+            "mobilenet_v1_075_160",
+            "mobilenet_v1_075_192",
+            "mobilenet_v1_075_224",
+            "mobilenet_v1_100_128",
+            "mobilenet_v1_100_160",
+            "mobilenet_v1_100_192",
+            "mobilenet_v1_100_224",
+            # MobileNetV2:
+            "mobilenet_v2_035_96",
+            "mobilenet_v2_035_128",
+            "mobilenet_v2_035_160",
+            "mobilenet_v2_035_192",
+            "mobilenet_v2_035_224",
+            "mobilenet_v2_050_96",
+            "mobilenet_v2_050_128",
+            "mobilenet_v2_050_160",
+            "mobilenet_v2_050_192",
+            "mobilenet_v2_050_224",
+            "mobilenet_v2_075_96",
+            "mobilenet_v2_075_128",
+            "mobilenet_v2_075_160",
+            "mobilenet_v2_075_192",
+            "mobilenet_v2_075_224",
+            "mobilenet_v2_100_96",
+            "mobilenet_v2_100_128",
+            "mobilenet_v2_100_160",
+            "mobilenet_v2_100_192",
+            "mobilenet_v2_100_224",
+            "mobilenet_v2_130_224",
+            "mobilenet_v2_140_224",
+            "nasnet_mobile",
+            "nasnet_large",
+            "pnasnet_large",
+            # ResNetV1
+            "resnet_v1_50",
+            "resnet_v1_101",
+            "resnet_v1_152",
+            # ResNetV2
+            "resnet_v2_50",
+            "resnet_v2_101",
+            "resnet_v2_152",
+        ],
+        "target_backends": [
+            "tf",
+            "tflite",
+            "iree_vmla",
+            "iree_llvmjit",
+            "iree_vulkan",
+        ],
+    },
+    main = "slim_vision_model_test.py",
     tags = [
         "external",
         "guitar",
@@ -168,8 +164,7 @@
         "nokokoro",
         "notap",
     ],
-    tf_hub_url = "https://tfhub.dev/google/imagenet/",
-    deps = INTREE_TENSORFLOW_PY_DEPS + INTREE_TF_HUB_DEPS + NUMPY_DEPS + [
+    deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
         "//integrations/tensorflow/bindings/python/pyiree/tf/support",
     ],
 )
diff --git a/integrations/tensorflow/e2e/slim_vision_models/README.md b/integrations/tensorflow/e2e/slim_vision_models/README.md
index 0a0b5dd..8a09a14 100644
--- a/integrations/tensorflow/e2e/slim_vision_models/README.md
+++ b/integrations/tensorflow/e2e/slim_vision_models/README.md
@@ -7,4 +7,4 @@
 python3 -m pip install tensorflow_hub
 ```
 
-Like the `vision_external_tests`, these tests are not checked by the OSS CI.
+These tests are not checked by the OSS CI.
diff --git a/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl b/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
deleted file mode 100644
index 857c1b4..0000000
--- a/integrations/tensorflow/e2e/slim_vision_models/iree_slim_vision_test_suite.bzl
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Macro for building e2e keras vision model tests."""
-
-load("//bindings/python:build_defs.oss.bzl", "iree_py_test")
-load("@bazel_skylib//lib:new_sets.bzl", "sets")
-
-def iree_slim_vision_test_suite(
-        name,
-        models,
-        backends,
-        reference_backend,
-        failing_configurations = None,
-        tf_hub_url = None,
-        tags = None,
-        deps = None,
-        size = "large",
-        python_version = "PY3",
-        **kwargs):
-    """Creates a test for each configuration and bundles a succeeding and failing test suite.
-
-    Creates one test per model and backend. Tests indicated in
-    `failing_configurations` are bundled into a suite suffixed with "_failing"
-    tagged to be excluded from CI and wildcard builds. All other tests are
-    bundled into a suite with the same name as the macro.
-
-    Args:
-      name:
-        name of the generated passing test suite. If failing_configurations is
-        not `None` then a test suite named name_failing will also be generated.
-      models:
-        an iterable of slim vision model tags to generate targets for.
-      backends:
-        an iterable of targets backends to generate targets for.
-      reference_backend:
-        the backend to use as a source of truth for the expected output results.
-      failing_configurations:
-        an iterable of dictionaries with the keys `models` and `backends`. Each
-        key points to a string or iterable of strings specifying a set of models
-        and backends that are failing.
-      tf_hub_url:
-        a string pointing to the TF Hub base url of the models to test.
-      tags:
-        tags to apply to the test. Note that as in standard test suites, manual
-        is treated specially and will also apply to the test suite itself.
-      deps:
-        test dependencies.
-      size:
-        size of the tests. Default: "large".
-      python_version:
-        the python version to run the tests with. Uses python3 by default.
-      **kwargs:
-        any additional arguments that will be passed to the underlying tests and
-        test_suite.
-    """
-    failing_set = sets.make([])
-    if failing_configurations != None:
-        # Parse failing configurations.
-        for configuration in failing_configurations:
-            # Normalize configuration input.
-            # {backend: "iree_llvmjit"} -> {backend: ["iree_llvmjit"]}
-            for key, value in configuration.items():
-                if type(value) == type(""):
-                    configuration[key] = [value]
-
-            for model in configuration["models"]:
-                for backend in configuration["backends"]:
-                    sets.insert(failing_set, (model, backend))
-
-    tests = []
-    for model in models:
-        for backend in backends:
-            # Check if this is a failing configuration.
-            failing = sets.contains(failing_set, (model, backend))
-
-            # Append "_failing" to name if this is a failing configuration.
-            test_name = name if not failing else name + "_failing"
-            test_name = "{}_{}__{}__{}".format(
-                test_name,
-                model,
-                reference_backend,
-                backend,
-            )
-            tests.append(test_name)
-
-            args = [
-                "--model={}".format(model),
-                "--tf_hub_url={}".format(tf_hub_url),
-                "--reference_backend={}".format(reference_backend),
-                "--target_backends={}".format(backend),
-            ]
-
-            # TODO(GH-2175): Simplify this after backend names are
-            # standardized.
-            # "iree_<driver>" --> "<driver>"
-            driver = backend.replace("iree_", "")
-            if driver == "llvmjit":
-                driver = "llvm"
-            py_test_tags = ["driver={}".format(driver)]
-            if tags != None:  # `is` is not supported.
-                py_test_tags += tags
-
-            # Add additional tags if this is a failing configuration.
-            if failing:
-                py_test_tags += [
-                    "failing",  # Only used for test_suite filtering below.
-                    "manual",
-                    "nokokoro",
-                    "notap",
-                ]
-
-            iree_py_test(
-                name = test_name,
-                main = "slim_vision_model_test.py",
-                srcs = ["slim_vision_model_test.py"],
-                args = args,
-                tags = py_test_tags,
-                deps = deps,
-                size = size,
-                python_version = python_version,
-                **kwargs
-            )
-
-    native.test_suite(
-        name = name,
-        tests = tests,
-        # Add "-failing" to exclude tests in `tests` that have the "failing"
-        # tag.
-        tags = tags + ["-failing"],
-        # If there are kwargs that need to be passed here which only apply to
-        # the generated tests and not to test_suite, they should be extracted
-        # into separate named arguments.
-        **kwargs
-    )
-
-    if failing_configurations != None:
-        native.test_suite(
-            name = name + "_failing",
-            tests = tests,
-            # Add "+failing" to only include tests in `tests` that have the
-            # "failing" tag.
-            tags = tags + ["+failing"],
-            # If there are kwargs that need to be passed here which only apply
-            # to the generated tests and not to test_suite, they should be
-            # extracted into separate named arguments.
-            **kwargs
-        )
diff --git a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index 819373f..0f97aae 100644
--- a/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -202,10 +202,9 @@
 namespace {
 
 template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalLoadAddress : public OpRewritePattern<INDIRECT> {
   using OpRewritePattern<INDIRECT>::OpRewritePattern;
 
- public:
   LogicalResult matchAndRewrite(INDIRECT op,
                                 PatternRewriter &rewriter) const override {
     if (auto addressOp =
@@ -244,10 +243,9 @@
 namespace {
 
 template <typename INDIRECT, typename DIRECT>
-class PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
+struct PropagateGlobalStoreAddress : public OpRewritePattern<INDIRECT> {
   using OpRewritePattern<INDIRECT>::OpRewritePattern;
 
- public:
   LogicalResult matchAndRewrite(INDIRECT op,
                                 PatternRewriter &rewriter) const override {
     if (auto addressOp =
@@ -382,15 +380,13 @@
 // Native integer arithmetic
 //===----------------------------------------------------------------------===//
 
-namespace {
-
 /// Performs const folding `calculate` with element-wise behavior on the given
 /// attribute in `operands` and returns the result if possible.
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
-                           const CalculationT &calculate) {
+static Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
+                                  const CalculationT &calculate) {
   assert(operands.size() == 1 && "unary op takes one operand");
   if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
     return AttrElementT::get(operand.getType(), calculate(operand.getValue()));
@@ -414,8 +410,8 @@
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT =
               std::function<ElementValueT(ElementValueT, ElementValueT)>>
-Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
-                            const CalculationT &calculate) {
+static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
+                                   const CalculationT &calculate) {
   assert(operands.size() == 2 && "binary op takes two operands");
   if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
     auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
@@ -448,42 +444,54 @@
   return {};
 }
 
-}  // namespace
-
-template <typename T>
-static OpFoldResult foldAddOp(T op, ArrayRef<Attribute> operands) {
+template <typename ADD, typename SUB>
+static OpFoldResult foldAddOp(ADD op, ArrayRef<Attribute> operands) {
   if (matchPattern(op.rhs(), m_Zero())) {
     // x + 0 = x or 0 + y = y (commutative)
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a + b; });
+  if (auto subOp = dyn_cast_or_null<SUB>(op.lhs().getDefiningOp())) {
+    if (subOp.lhs() == op.rhs()) return subOp.rhs();
+    if (subOp.rhs() == op.rhs()) return subOp.lhs();
+  } else if (auto subOp = dyn_cast_or_null<SUB>(op.rhs().getDefiningOp())) {
+    if (subOp.lhs() == op.lhs()) return subOp.rhs();
+    if (subOp.rhs() == op.lhs()) return subOp.lhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a + b; });
 }
 
 OpFoldResult AddI32Op::fold(ArrayRef<Attribute> operands) {
-  return foldAddOp(*this, operands);
+  return foldAddOp<AddI32Op, SubI32Op>(*this, operands);
 }
 
 OpFoldResult AddI64Op::fold(ArrayRef<Attribute> operands) {
-  return foldAddOp(*this, operands);
+  return foldAddOp<AddI64Op, SubI64Op>(*this, operands);
 }
 
-template <typename T>
-static OpFoldResult foldSubOp(T op, ArrayRef<Attribute> operands) {
+template <typename SUB, typename ADD>
+static OpFoldResult foldSubOp(SUB op, ArrayRef<Attribute> operands) {
   if (matchPattern(op.rhs(), m_Zero())) {
     // x - 0 = x
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a - b; });
+  if (auto addOp = dyn_cast_or_null<ADD>(op.lhs().getDefiningOp())) {
+    if (addOp.lhs() == op.rhs()) return addOp.rhs();
+    if (addOp.rhs() == op.rhs()) return addOp.lhs();
+  } else if (auto addOp = dyn_cast_or_null<ADD>(op.rhs().getDefiningOp())) {
+    if (addOp.lhs() == op.lhs()) return addOp.rhs();
+    if (addOp.rhs() == op.lhs()) return addOp.lhs();
+  }
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a - b; });
 }
 
 OpFoldResult SubI32Op::fold(ArrayRef<Attribute> operands) {
-  return foldSubOp(*this, operands);
+  return foldSubOp<SubI32Op, AddI32Op>(*this, operands);
 }
 
 OpFoldResult SubI64Op::fold(ArrayRef<Attribute> operands) {
-  return foldSubOp(*this, operands);
+  return foldSubOp<SubI64Op, AddI64Op>(*this, operands);
 }
 
 template <typename T>
@@ -495,18 +503,51 @@
     // x * 1 = x or 1 * y = y (commutative)
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a * b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+template <typename T, typename CONST_OP>
+struct FoldConstantMulOperand : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    IntegerAttr c1, c2;
+    if (!matchPattern(op.rhs(), m_Constant(&c1))) return failure();
+    if (auto mulOp = dyn_cast_or_null<T>(op.lhs().getDefiningOp())) {
+      if (matchPattern(mulOp.rhs(), m_Constant(&c2))) {
+        auto c = rewriter.createOrFold<CONST_OP>(
+            FusedLoc::get({mulOp.getLoc(), op.getLoc()}, rewriter.getContext()),
+            constFoldBinaryOp<IntegerAttr>(
+                {c1, c2},
+                [](const APInt &a, const APInt &b) { return a * b; }));
+        rewriter.replaceOpWithNewOp<T>(op, op.getType(), mulOp.lhs(), c);
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
 OpFoldResult MulI32Op::fold(ArrayRef<Attribute> operands) {
   return foldMulOp(*this, operands);
 }
 
+void MulI32Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<FoldConstantMulOperand<MulI32Op, ConstI32Op>>(context);
+}
+
 OpFoldResult MulI64Op::fold(ArrayRef<Attribute> operands) {
   return foldMulOp(*this, operands);
 }
 
+void MulI64Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<FoldConstantMulOperand<MulI64Op, ConstI64Op>>(context);
+}
+
 template <typename T>
 static OpFoldResult foldDivSOp(T op, ArrayRef<Attribute> operands) {
   if (matchPattern(op.rhs(), m_Zero())) {
@@ -521,7 +562,7 @@
     return op.lhs();
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.sdiv(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.sdiv(b); });
 }
 
 OpFoldResult DivI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -546,7 +587,7 @@
     return op.lhs();
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.udiv(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.udiv(b); });
 }
 
 OpFoldResult DivI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -570,7 +611,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.srem(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.srem(b); });
 }
 
 OpFoldResult RemI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -589,7 +630,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [](APInt a, APInt b) { return a.urem(b); });
+      operands, [](const APInt &a, const APInt &b) { return a.urem(b); });
 }
 
 OpFoldResult RemI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -625,8 +666,8 @@
     // x & x = x
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a & b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a & b; });
 }
 
 OpFoldResult AndI32Op::fold(ArrayRef<Attribute> operands) {
@@ -646,8 +687,8 @@
     // x | x = x
     return op.lhs();
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a | b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a | b; });
 }
 
 OpFoldResult OrI32Op::fold(ArrayRef<Attribute> operands) {
@@ -667,8 +708,8 @@
     // x ^ x = 0
     return zeroOfType(op.getType());
   }
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a ^ b; });
+  return constFoldBinaryOp<IntegerAttr>(
+      operands, [](const APInt &a, const APInt &b) { return a ^ b; });
 }
 
 OpFoldResult XorI32Op::fold(ArrayRef<Attribute> operands) {
@@ -693,7 +734,7 @@
     return op.operand();
   }
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.shl(op.amount()); });
+      operands, [&](const APInt &a) { return a.shl(op.amount()); });
 }
 
 OpFoldResult ShlI32Op::fold(ArrayRef<Attribute> operands) {
@@ -714,7 +755,7 @@
     return op.operand();
   }
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.ashr(op.amount()); });
+      operands, [&](const APInt &a) { return a.ashr(op.amount()); });
 }
 
 OpFoldResult ShrI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -735,7 +776,7 @@
     return op.operand();
   }
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return a.lshr(op.amount()); });
+      operands, [&](const APInt &a) { return a.lshr(op.amount()); });
 }
 
 OpFoldResult ShrI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -755,8 +796,9 @@
 template <class AttrElementT,
           class ElementValueT = typename AttrElementT::ValueType,
           class CalculationT = std::function<ElementValueT(ElementValueT)>>
-Attribute constFoldConversionOp(Type resultType, ArrayRef<Attribute> operands,
-                                const CalculationT &calculate) {
+static Attribute constFoldConversionOp(Type resultType,
+                                       ArrayRef<Attribute> operands,
+                                       const CalculationT &calculate) {
   assert(operands.size() == 1 && "unary op takes one operand");
   if (auto operand = operands[0].dyn_cast_or_null<AttrElementT>()) {
     return AttrElementT::get(resultType, calculate(operand.getValue()));
@@ -767,101 +809,100 @@
 OpFoldResult TruncI32I8Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
 }
 
 OpFoldResult TruncI32I16Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
 }
 
 OpFoldResult TruncI64I8Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
 }
 
 OpFoldResult TruncI64I16Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
 }
 
 OpFoldResult TruncI64I32Op::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(32); });
+      [&](const APInt &a) { return a.trunc(32); });
 }
 
 OpFoldResult ExtI8I32SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).sext(32); });
+      [&](const APInt &a) { return a.trunc(8).sext(32); });
 }
 
 OpFoldResult ExtI8I32UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(32); });
+      [&](const APInt &a) { return a.trunc(8).zext(32); });
 }
 
 OpFoldResult ExtI16I32SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).sext(32); });
+      [&](const APInt &a) { return a.trunc(16).sext(32); });
 }
 
 OpFoldResult ExtI16I32UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(32, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(32); });
+      [&](const APInt &a) { return a.trunc(16).zext(32); });
 }
 
 OpFoldResult ExtI8I64SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).sext(64); });
+      [&](const APInt &a) { return a.trunc(8).sext(64); });
 }
 
 OpFoldResult ExtI8I64UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(8).zext(64); });
+      [&](const APInt &a) { return a.trunc(8).zext(64); });
 }
 
 OpFoldResult ExtI16I64SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).sext(64); });
+      [&](const APInt &a) { return a.trunc(16).sext(64); });
 }
 
 OpFoldResult ExtI16I64UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.trunc(16).zext(64); });
+      [&](const APInt &a) { return a.trunc(16).zext(64); });
 }
 
 OpFoldResult ExtI32I64SOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.sext(64); });
+      [&](const APInt &a) { return a.sext(64); });
 }
 
 OpFoldResult ExtI32I64UOp::fold(ArrayRef<Attribute> operands) {
   return constFoldConversionOp<IntegerAttr>(
       IntegerType::get(64, getContext()), operands,
-      [&](APInt a) { return a.zext(64); });
+      [&](const APInt &a) { return a.zext(64); });
 }
 
 namespace {
 
 template <typename SRC_OP, typename OP_A, int SZ_T, typename OP_B>
-class PseudoIntegerConversionToSplitConversionOp
+struct PseudoIntegerConversionToSplitConversionOp
     : public OpRewritePattern<SRC_OP> {
   using OpRewritePattern<SRC_OP>::OpRewritePattern;
 
- public:
   LogicalResult matchAndRewrite(SRC_OP op,
                                 PatternRewriter &rewriter) const override {
     auto tmp = rewriter.createOrFold<OP_A>(
@@ -956,7 +997,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.eq(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.eq(b); });
 }
 
 OpFoldResult CmpEQI32Op::fold(ArrayRef<Attribute> operands) {
@@ -984,7 +1025,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ne(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ne(b); });
 }
 
 OpFoldResult CmpNEI32Op::fold(ArrayRef<Attribute> operands) {
@@ -1032,7 +1073,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.slt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.slt(b); });
 }
 
 OpFoldResult CmpLTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1056,7 +1097,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ult(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ult(b); });
 }
 
 OpFoldResult CmpLTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1100,7 +1141,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sle(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sle(b); });
 }
 
 OpFoldResult CmpLTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1130,7 +1171,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ule(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ule(b); });
 }
 
 OpFoldResult CmpLTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1176,7 +1217,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sgt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sgt(b); });
 }
 
 OpFoldResult CmpGTI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1206,7 +1247,7 @@
     return zeroOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.ugt(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.ugt(b); });
 }
 
 OpFoldResult CmpGTI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1256,7 +1297,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.sge(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.sge(b); });
 }
 
 OpFoldResult CmpGTEI32SOp::fold(ArrayRef<Attribute> operands) {
@@ -1286,7 +1327,7 @@
     return oneOfType(op.getType());
   }
   return constFoldBinaryOp<IntegerAttr>(
-      operands, [&](APInt a, APInt b) { return a.uge(b); });
+      operands, [&](const APInt &a, const APInt &b) { return a.uge(b); });
 }
 
 OpFoldResult CmpGTEI32UOp::fold(ArrayRef<Attribute> operands) {
@@ -1311,12 +1352,12 @@
 
 OpFoldResult CmpNZI32Op::fold(ArrayRef<Attribute> operands) {
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return APInt(32, a.getBoolValue()); });
+      operands, [&](const APInt &a) { return APInt(32, a.getBoolValue()); });
 }
 
 OpFoldResult CmpNZI64Op::fold(ArrayRef<Attribute> operands) {
   return constFoldUnaryOp<IntegerAttr>(
-      operands, [&](APInt a) { return APInt(64, a.getBoolValue()); });
+      operands, [&](const APInt &a) { return APInt(64, a.getBoolValue()); });
 }
 
 OpFoldResult CmpEQRefOp::fold(ArrayRef<Attribute> operands) {
diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td
index 727cbf0..d613c8b 100644
--- a/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1311,6 +1311,7 @@
     VM_BinaryArithmeticOp<I32, "mul.i32", VM_OPC_MulI32, [Commutative]> {
   let summary = [{integer multiplication operation}];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def VM_MulI64Op :
@@ -1318,6 +1319,7 @@
                           [VM_ExtI64, Commutative]> {
   let summary = [{integer multiplication operation}];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def VM_DivI32SOp :
diff --git a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
index 153389c..38e49b2 100644
--- a/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
+++ b/iree/compiler/Dialect/VM/IR/test/arithmetic_folding.mlir
@@ -56,6 +56,41 @@
 
 // -----
 
+// CHECK-LABEL: @add_sub_i32_folds
+vm.module @add_sub_i32_folds {
+  // CHECK-LABEL: @add_sub_x
+  vm.func @add_sub_x(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.add.i32 %arg0, %arg1 : i32
+    %1 = vm.sub.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // CHECK-LABEL: @add_sub_x_rev
+  vm.func @add_sub_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.add.i32 %arg1, %arg0 : i32
+    %1 = vm.sub.i32 %arg1, %0 : i32
+    vm.return %1 : i32
+  }
+
+  // CHECK-LABEL: @sub_add_x
+  vm.func @sub_add_x(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.sub.i32 %arg0, %arg1 : i32
+    %1 = vm.add.i32 %0, %arg1 : i32
+    vm.return %1 : i32
+  }
+  // CHECK-LABEL: @sub_add_x_rev
+  vm.func @sub_add_x_rev(%arg0 : i32, %arg1 : i32) -> i32 {
+    // CHECK-NEXT: vm.return %arg0
+    %0 = vm.sub.i32 %arg0, %arg1 : i32
+    %1 = vm.add.i32 %arg1, %0 : i32
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @mul_i32_folds
 vm.module @mul_i32_folds {
   // CHECK-LABEL: @mul_i32_by_0
@@ -96,6 +131,23 @@
 
 // -----
 
+// CHECK-LABEL: @mul_mul_i32_folds
+vm.module @mul_mul_i32_folds {
+  // CHECK-LABEL: @mul_mul_i32_const
+  vm.func @mul_mul_i32_const(%arg0 : i32) -> i32 {
+    // CHECK: %c40 = vm.const.i32 40 : i32
+    %c4 = vm.const.i32 4 : i32
+    %c10 = vm.const.i32 10 : i32
+    // CHECK: %0 = vm.mul.i32 %arg0, %c40 : i32
+    %0 = vm.mul.i32 %arg0, %c4 : i32
+    %1 = vm.mul.i32 %0, %c10 : i32
+    // CHECK-NEXT: vm.return %0 : i32
+    vm.return %1 : i32
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @div_i32_folds
 vm.module @div_i32_folds {
   // CHECK-LABEL: @div_i32_0_y
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index b3982b0..6bd35bb 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -43,6 +43,7 @@
   // If we end up with a lot of these, consider using an "is pseudo" trait.
   addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
   addIllegalOp<IREE::VMLA::SortPseudoOp>();
+  addIllegalOp<IREE::VMLA::FftPseudoOp>();
 
   // Allow other ops to pass through so long as their type is valid (not a
   // tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index 6f43d6e..5b842a5 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -703,6 +703,43 @@
   TypeConverter &typeConverter;
 };
 
+struct FftOpConversion : public OpConversionPattern<IREE::VMLA::FftPseudoOp> {
+  FftOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  LogicalResult matchAndRewrite(
+      IREE::VMLA::FftPseudoOp srcOp, ArrayRef<Value> rawOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto input_shape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.real_in(), typeConverter, rewriter);
+
+    auto real_input_type = srcOp.getOperand(0).getType().cast<ShapedType>();
+    auto imag_input_type = srcOp.getOperand(1).getType().cast<ShapedType>();
+
+    // The input type/shape should match for the real and imag components.
+    if (real_input_type != imag_input_type) {
+      srcOp.emitWarning() << "real and imag should have matching types";
+      return failure();
+    }
+
+    auto real_out = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(0), typeConverter, rewriter);
+    auto imag_out = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(1), typeConverter, rewriter);
+
+    rewriter.createOrFold<IREE::VMLA::FftOp>(
+        srcOp.getLoc(), rawOperands[0], input_shape, rawOperands[1],
+        input_shape, real_out, imag_out,
+        TypeAttr::get(real_input_type.getElementType()),
+        TypeAttr::get(imag_input_type.getElementType()));
+
+    rewriter.replaceOp(srcOp, {real_out, imag_out});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -769,6 +806,9 @@
   // vmla.sort.pseudo
   patterns.insert<SortOpConversion>(context, typeConverter);
 
+  // vmla.fft.pseudo
+  patterns.insert<FftOpConversion>(context, typeConverter);
+
   // Simple 1:1 conversion patterns using the automated trait-based converter.
   // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
   patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
new file mode 100644
index 0000000..1e3c365
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/fft.mlir
@@ -0,0 +1,11 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @fft(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->  (tensor<8xf32>, tensor<8xf32>) attributes { sym_visibility = "private" } {
+  // CHECK: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[8]>
+  // CHECK-NEXT: [[C32:%.+]] = constant 32 : index
+  // CHECK-NEXT: [[OUTBUF1:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-NEXT: [[OUTBUF2:%.+]] = vmla.buffer.alloc byte_length = [[C32]] : !vmla.buffer
+  // CHECK-NEXT: vmla.fft %arg0([[RS]] : !shapex.ranked_shape<[8]>),  %arg1([[RS]] : !shapex.ranked_shape<[8]>), out [[OUTBUF1]], [[OUTBUF2]] : f32, f32
+  %real, %imag = "vmla.fft.pseudo"(%arg0, %arg1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>)
+  return %real, %imag : tensor<8xf32>, tensor<8xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f31545d..f139690 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -694,5 +694,24 @@
   }];
 }
 
+def VMLA_FftOp : VMLA_ElementTypeOp<"fft", [VMLA_IncludeShapes]> {
+  let arguments = (ins
+    VMLA_Buffer:$real_in,
+    VMLA_Shape:$real_in_shape,
+    VMLA_Buffer:$imag_in,
+    VMLA_Shape:$imag_in_shape,
+    VMLA_Buffer:$real_out,
+    VMLA_Buffer:$imag_out,
+    VMLA_AnyTypeAttr:$real_element_type,
+    VMLA_AnyTypeAttr:$imag_element_type
+  );
+
+  let assemblyFormat = [{
+    $real_in`(`$real_in_shape `:` type($real_in_shape)`)` `,`
+    $imag_in`(`$imag_in_shape `:` type($imag_in_shape)`)` `,`
+    `out` $real_out `,` $imag_out attr-dict `:` $real_element_type `,` $imag_element_type
+  }];
+}
+
 
 #endif  // IREE_DIALECT_VMLA_OPS
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index baa8d41..35732c6 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -50,6 +50,15 @@
         'End to end tests of TensorFlow slim vision models',
 }
 
+# Key to use as the name of the rows in the left column for each test in the
+# suite.
+TEST_SUITE_TO_ROW_ID_KEY = {
+    '//integrations/tensorflow/e2e/keras:imagenet_external_tests':
+        'model',
+    '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests':
+        'model',
+}
+
 # Some test suites are generated from a single source. This allows us to point
 # to the right test file when generating test URLs.
 SINGLE_SOURCE_SUITES = {
@@ -101,6 +110,34 @@
   return parsed_args
 
 
+def parse_test_name(test_name, test_suite):
+  """Splits a test name into a dictionary with its source file and backend."""
+  test_name_parts = test_name.split("__")
+  test_info = {}
+
+  # The iree_e2e_test_suite elides a 'src' key before the name of the test
+  # for brevity.
+  if len(test_name_parts) % 2 == 1:
+    test_info['src'] = test_name_parts.pop(0)
+
+  # The rest of the test name should follow 'key__value__key__value__...'.
+  for key, value in zip(test_name_parts[::2], test_name_parts[1::2]):
+    test_info[key] = value
+
+  # Default to using the test source file name as the row id for the table.
+  if 'src' in test_info:
+    test_info['row_id'] = test_info['src']
+  else:
+    test_info['src'] = SINGLE_SOURCE_SUITES[test_suite]
+    test_info['row_id'] = test_info[TEST_SUITE_TO_ROW_ID_KEY[test_suite]]
+
+  if 'target_backends' not in test_info:
+    raise ValueError('Expected `target_backends` to be in the test name but '
+                     f'got `{test_name}`.')
+
+  return test_info
+
+
 def get_name_and_backend(test_string):
   """Splits a pathless test target into its name and comparison backend."""
   name, backend = test_string.split(f'__{REFERENCE_BACKEND}__')
@@ -113,42 +150,41 @@
   failing = utils.get_test_targets(f'{test_suite}_failing')
 
   # Remove bazel path.
-  passing = [test.replace(f'{test_suite}_', '') for test in passing]
-  failing = [test.replace(f'{test_suite}_failing_', '') for test in failing]
+  passing = [test.replace(f'{test_suite}__', '') for test in passing]
+  failing = [test.replace(f'{test_suite}_failing__', '') for test in failing]
 
-  # Split into (test_name, target_backend).
-  passing = [get_name_and_backend(test) for test in passing]
-  failing = [get_name_and_backend(test) for test in failing]
-  passing_names = [test[0] for test in passing]
-  failing_names = [test[0] for test in failing]
-  all_names = list(sorted(set(passing_names + failing_names)))
-  return all_names, passing, failing
+  # Split into a dictionary mapping 'src', 'target_backend', ... to the
+  # appropriate values for each test target.
+  passing_info = [parse_test_name(test, test_suite) for test in passing]
+  failing_info = [parse_test_name(test, test_suite) for test in failing]
+  return passing_info, failing_info
 
 
-def get_name_element(test_suite, name):
+def get_row_hyperlink(test_suite, row_id, test_source):
   """Returns a Markdown hyperlink pointing to the test source on GitHub."""
   # Convert `//path/to/tests:test_suite` to `path/to/tests`
-  test_path = test_suite.split(':')[0]
-  test_path = test_path.replace('//', '')
+  test_path = test_suite.replace('//', '').split(':')[0]
 
-  if test_suite in SINGLE_SOURCE_SUITES:
-    test_name = SINGLE_SOURCE_SUITES[test_suite]
-  else:
-    test_name = name
-
-  test_url = os.path.join(MAIN_URL, test_path, f'{test_name}.py')
-  return f'[{name}]({test_url})'
+  test_url = os.path.join(MAIN_URL, test_path, f'{test_source}.py')
+  return f'[{row_id}]({test_url})'
 
 
 def generate_table(test_suite):
   """Generates an e2e backend coverage Markdown table."""
-  all_names, passing, _ = get_suite_metadata(test_suite)
+  passing_info, _ = get_suite_metadata(test_suite)
 
-  # Generate a dictionary mapping test names to their backend coverage.
+  # Create a dictionary mapping row names to source file names.
+  row_id_to_source = {}
+  for test_info in passing_info:
+    row_id_to_source[test_info['row_id']] = test_info['src']
+
+  # Create a dictionary mapping test names to a list of bools representing their
+  # backend coverage.
   table = collections.defaultdict(lambda: [False] * len(BACKENDS_TO_TITLES))
   ordered_backends = list(BACKENDS_TO_TITLES.keys())
-  for name, backend in passing:
-    table[name][ordered_backends.index(backend)] = True
+  for test_info in passing_info:
+    backend_index = ordered_backends.index(test_info['target_backends'])
+    table[test_info['row_id']][backend_index] = True
 
   # Create a header for the coverage table.
   ordered_backend_titles = list(BACKENDS_TO_TITLES.values())
@@ -157,11 +193,12 @@
 
   # Generate the coverage table as a 2D array.
   rows = [first_row, second_row]
-  for name, backends in sorted(table.items()):
-    if any(re.match(pattern, name) for pattern in TARGET_EXCLUSION_FILTERS):
+  for row_id, backends in sorted(table.items()):
+    # Skip any rows defined in the TARGET_EXCLUSION_FILTERS.
+    if any(re.match(pattern, row_id) for pattern in TARGET_EXCLUSION_FILTERS):
       continue
 
-    row = [get_name_element(test_suite, name)]
+    row = [get_row_hyperlink(test_suite, row_id, row_id_to_source[row_id])]
     row.extend([
         SUCCESS_ELEMENT if backend else FAILURE_ELEMENT for backend in backends
     ])