Add output verification to linux benchmark tool (#15281)

Support output verification with `iree-run-module` in the Linux benchmark
tool.

Follow-up changes will propagate the input data and expected outputs from
the benchmark definitions to the tool.
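
At a high level, when `--verify` is passed and the expected output is
available, the tool fetches the npy archives and runs `iree-run-module`
against them before the usual `iree-benchmark-module` run. A minimal sketch of
the equivalent standalone invocation (the paths below are placeholders, not
real benchmark artifacts):

```python
import pathlib
import subprocess

# Placeholder layout; the tool derives these paths from the benchmark case
# directory and the fetched npy archives.
case_dir = pathlib.Path("benchmark_case")
inputs_dir = case_dir / "inputs_npy"
expected_outputs_dir = case_dir / "expected_outputs_npy"

verify_cmd = [
    "iree-run-module",
    f"--module={case_dir / 'module.vmfb'}",
    "--function=main",
    # One npy file per model input.
    f"--input=@{inputs_dir / 'input_0.npy'}",
    # Only a single expected output is supported for now.
    f"--expected_output=@{expected_outputs_dir / 'output_0.npy'}",
]
# iree-run-module exits non-zero when the outputs do not match, so a failed
# verification aborts the run before the benchmark starts.
subprocess.run(verify_cmd, check=True)
```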
diff --git a/build_tools/benchmarks/common/benchmark_config.py b/build_tools/benchmarks/common/benchmark_config.py
index f9fb2dd..a86ef0f 100644
--- a/build_tools/benchmarks/common/benchmark_config.py
+++ b/build_tools/benchmarks/common/benchmark_config.py
@@ -56,6 +56,7 @@
       times.
     continue_from_previous: skip the benchmarks if their results are found in
       the benchmark_results_dir.
+    verify: verify the output if the model's expected output is available.
     """
 
     root_benchmark_dir: pathlib.Path
@@ -73,6 +74,7 @@
     keep_going: bool = False
     benchmark_min_time: float = 0
     continue_from_previous: bool = False
+    verify: bool = False
 
     @staticmethod
     def build_from_args(args: Namespace, git_commit_hash: str):
@@ -124,4 +126,5 @@
             keep_going=args.keep_going,
             benchmark_min_time=args.benchmark_min_time,
             continue_from_previous=args.continue_from_previous,
+            verify=args.verify,
         )
diff --git a/build_tools/benchmarks/common/benchmark_config_test.py b/build_tools/benchmarks/common/benchmark_config_test.py
index 2a446ab..747982b 100644
--- a/build_tools/benchmarks/common/benchmark_config_test.py
+++ b/build_tools/benchmarks/common/benchmark_config_test.py
@@ -53,6 +53,7 @@
                 f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}",
                 f"--execution_benchmark_config={self.execution_config}",
                 "--target_device=test",
+                "--verify",
             ]
         )
 
@@ -79,6 +80,7 @@
             keep_going=True,
             benchmark_min_time=10,
             use_compatible_filter=True,
+            verify=True,
         )
         self.assertEqual(config, expected_config)
 
diff --git a/build_tools/benchmarks/common/benchmark_suite.py b/build_tools/benchmarks/common/benchmark_suite.py
index 05ee0b5..aa6ba88 100644
--- a/build_tools/benchmarks/common/benchmark_suite.py
+++ b/build_tools/benchmarks/common/benchmark_suite.py
@@ -9,10 +9,10 @@
 benchmark suite.
 """
 
-import collections
 import pathlib
 import re
 
+import dataclasses
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Sequence, Tuple
 from common.benchmark_definition import IREE_DRIVERS_INFOS, DriverInfo
@@ -36,6 +36,8 @@
     benchmark_tool_name: the benchmark tool, e.g., 'iree-benchmark-module'.
     benchmark_case_dir: the path to benchmark case directory.
     run_config: the run config from e2e test framework.
+    input_uri: URI to find the input npy.
+    expected_output_uri: URI to find the expected output npy.
     """
 
     model_name: str
@@ -46,6 +48,9 @@
     benchmark_tool_name: str
     benchmark_case_dir: pathlib.Path
     run_config: iree_definitions.E2EModelRunConfig
+    input_uri: Optional[str] = None
+    expected_output_uri: Optional[str] = None
+    verify_params: List[str] = dataclasses.field(default_factory=list)
 
 
 # A map from execution config to driver info. This is temporary during migration
diff --git a/build_tools/benchmarks/common/common_arguments.py b/build_tools/benchmarks/common/common_arguments.py
index 71e8a82..8b7001b 100644
--- a/build_tools/benchmarks/common/common_arguments.py
+++ b/build_tools/benchmarks/common/common_arguments.py
@@ -171,6 +171,11 @@
             "information",
         )
         self.add_argument(
+            "--verify",
+            action="store_true",
+            help="Verify the output when the expected output is available",
+        )
+        self.add_argument(
             "--execution_benchmark_config",
             type=_check_file_path,
             required=True,
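
The new `--verify` flag is a plain argparse boolean switch; a minimal sketch of
how it parses (the real parser in common_arguments.py defines many more
options):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--verify",
    action="store_true",
    help="Verify the output when the expected output is available",
)

# store_true options default to False, which is why BenchmarkConfig.verify
# also defaults to False.
assert parser.parse_args([]).verify is False
assert parser.parse_args(["--verify"]).verify is True
```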
diff --git a/build_tools/benchmarks/run_benchmarks_on_linux.py b/build_tools/benchmarks/run_benchmarks_on_linux.py
index 8f7a36a..d61d9f0 100755
--- a/build_tools/benchmarks/run_benchmarks_on_linux.py
+++ b/build_tools/benchmarks/run_benchmarks_on_linux.py
@@ -15,6 +15,7 @@
 from typing import Any, List, Optional
 import atexit
 import json
+import requests
 import shutil
 import subprocess
 import tarfile
@@ -34,6 +35,7 @@
 from common.linux_device_utils import get_linux_device_info
 from e2e_test_artifacts import iree_artifacts
 from e2e_model_tests import run_module_utils
+
 import common.common_arguments
 
 
@@ -50,8 +52,35 @@
         benchmark_results_filename: Optional[pathlib.Path],
         capture_filename: Optional[pathlib.Path],
     ) -> None:
+        case_dir = benchmark_case.benchmark_case_dir
+        inputs_dir = None
+        expected_output_dir = None
+        if benchmark_case.input_uri:
+            inputs_dir = self.__fetch_and_unpack_npy(
+                uri=benchmark_case.input_uri, dest_dir=case_dir / "inputs_npy"
+            )
+        if benchmark_case.expected_output_uri:
+            expected_output_dir = self.__fetch_and_unpack_npy(
+                uri=benchmark_case.expected_output_uri,
+                dest_dir=case_dir / "expected_outputs_npy",
+            )
+
         if benchmark_results_filename:
+            if self.config.normal_benchmark_tool_dir is None:
+                raise ValueError("normal_benchmark_tool_dir can't be None.")
+
+            if self.config.verify and expected_output_dir:
+                if not inputs_dir:
+                    raise ValueError(f"Input data is missing for {benchmark_case}.")
+                self.__run_verify(
+                    tool_dir=self.config.normal_benchmark_tool_dir,
+                    benchmark_case=benchmark_case,
+                    inputs_dir=inputs_dir,
+                    expected_outputs_dir=expected_output_dir,
+                )
+
             self.__run_benchmark(
+                tool_dir=self.config.normal_benchmark_tool_dir,
                 benchmark_case=benchmark_case,
                 results_filename=benchmark_results_filename,
             )
@@ -62,7 +91,10 @@
             )
 
     def __build_tool_cmds(
-        self, benchmark_case: BenchmarkCase, tool_path: pathlib.Path
+        self,
+        benchmark_case: BenchmarkCase,
+        tool_path: pathlib.Path,
+        inputs_dir: Optional[pathlib.Path] = None,
     ) -> List[Any]:
         run_config = benchmark_case.run_config
         cmds: List[Any] = run_module_utils.build_linux_wrapper_cmds_for_device_spec(
@@ -72,18 +104,66 @@
 
         module_dir_path = benchmark_case.benchmark_case_dir
         cmds += [f"--module={module_dir_path / iree_artifacts.MODULE_FILENAME}"]
-        cmds += run_config.materialize_run_flags(gpu_id=self.gpu_id)
+        cmds += run_config.materialize_run_flags(
+            gpu_id=self.gpu_id,
+            inputs_dir=inputs_dir,
+        )
 
         return cmds
 
-    def __run_benchmark(
-        self, benchmark_case: BenchmarkCase, results_filename: pathlib.Path
-    ):
-        if self.config.normal_benchmark_tool_dir is None:
-            raise ValueError("normal_benchmark_tool_dir can't be None.")
+    def __fetch_and_unpack_npy(self, uri: str, dest_dir: pathlib.Path) -> pathlib.Path:
+        out_dir = self.__unpack_file(
+            src=self.__fetch_file(
+                uri=uri,
+                dest=dest_dir.with_suffix(".tgz"),
+            ),
+            dest=dest_dir,
+        )
+        return out_dir.absolute()
 
+    def __fetch_file(self, uri: str, dest: pathlib.Path) -> pathlib.Path:
+        """Check and fetch file if needed."""
+        if dest.exists():
+            return dest
+        req = requests.get(uri, stream=True, timeout=60)
+        with dest.open("wb") as dest_file:
+            for data in req.iter_content():
+                dest_file.write(data)
+        return dest
+
+    def __unpack_file(self, src: pathlib.Path, dest: pathlib.Path) -> pathlib.Path:
+        """Unpack tar with/without compression."""
+        if dest.exists():
+            return dest
+        with tarfile.open(src) as tar_file:
+            tar_file.extractall(dest)
+        return dest
+
+    def __run_verify(
+        self,
+        tool_dir: pathlib.Path,
+        benchmark_case: BenchmarkCase,
+        inputs_dir: pathlib.Path,
+        expected_outputs_dir: pathlib.Path,
+    ):
+        cmd = self.__build_tool_cmds(
+            benchmark_case=benchmark_case,
+            tool_path=tool_dir / "iree-run-module",
+            inputs_dir=inputs_dir,
+        )
+        # Currently only a single output is supported.
+        cmd.append(f'--expected_output=@{expected_outputs_dir / "output_0.npy"}')
+        cmd += benchmark_case.verify_params
+        execute_cmd_and_get_output(cmd, verbose=self.verbose)
+
+    def __run_benchmark(
+        self,
+        tool_dir: pathlib.Path,
+        benchmark_case: BenchmarkCase,
+        results_filename: pathlib.Path,
+    ):
         tool_name = benchmark_case.benchmark_tool_name
-        tool_path = self.config.normal_benchmark_tool_dir / tool_name
+        tool_path = tool_dir / tool_name
         cmd = self.__build_tool_cmds(benchmark_case=benchmark_case, tool_path=tool_path)
 
         if tool_name == "iree-benchmark-module":
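
The `__fetch_file`/`__unpack_file` helpers above cache on the destination path
and accept a gzipped (or plain) tarball of npy files. The same pattern as a
standalone sketch, with a placeholder URI:

```python
import pathlib
import tarfile

import requests


def fetch_and_unpack_npy(uri: str, dest_dir: pathlib.Path) -> pathlib.Path:
    """Downloads a .tgz of npy files and extracts it, skipping finished steps."""
    archive = dest_dir.with_suffix(".tgz")
    if not archive.exists():
        resp = requests.get(uri, stream=True, timeout=60)
        resp.raise_for_status()
        with archive.open("wb") as archive_file:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                archive_file.write(chunk)
    if not dest_dir.exists():
        # tarfile auto-detects gzip vs. plain tar when opened for reading.
        with tarfile.open(archive) as tar_file:
            tar_file.extractall(dest_dir)
    return dest_dir.absolute()


# Placeholder URI; real runs take it from the benchmark case's input_uri.
inputs_dir = fetch_and_unpack_npy(
    uri="https://example.com/model_inputs.tgz",
    dest_dir=pathlib.Path("inputs_npy"),
)
```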
diff --git a/build_tools/python/e2e_model_tests/cmake_generator.py b/build_tools/python/e2e_model_tests/cmake_generator.py
index 5b0bb1a..3a5e8b0 100644
--- a/build_tools/python/e2e_model_tests/cmake_generator.py
+++ b/build_tools/python/e2e_model_tests/cmake_generator.py
@@ -54,12 +54,14 @@
         runner_args = (
             iree_definitions.generate_run_flags(
                 imported_model=imported_model,
-                input_data=test_config.input_data,
                 module_execution_config=test_config.execution_config,
                 with_driver=False,
             )
             + test_config.extra_test_flags
         )
+        runner_args += [
+            f"--input={input_type}=0" for input_type in imported_model.model.input_types
+        ]
         cmake_rule = cmake_builder.rules.build_iree_benchmark_suite_module_test(
             target_name=test_config.name,
             driver=test_config.execution_config.driver.value,
diff --git a/build_tools/python/e2e_model_tests/test_definitions.py b/build_tools/python/e2e_model_tests/test_definitions.py
index c1af11f..b7cbde8 100644
--- a/build_tools/python/e2e_model_tests/test_definitions.py
+++ b/build_tools/python/e2e_model_tests/test_definitions.py
@@ -49,9 +49,6 @@
 
     # Either a string literal or a file path.
     expected_output: str
-    input_data: common_definitions.ModelInputData = (
-        common_definitions.ZEROS_MODEL_INPUT_DATA
-    )
 
     # Platforms to ignore this test.
     unsupported_platforms: List[CMakePlatform] = dataclasses.field(default_factory=list)
diff --git a/build_tools/python/e2e_test_framework/definitions/iree_definitions.py b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py
index 8cce00f..f7278a2 100644
--- a/build_tools/python/e2e_test_framework/definitions/iree_definitions.py
+++ b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py
@@ -347,9 +347,30 @@
     def __str__(self):
         return self.name
 
-    def materialize_run_flags(self, gpu_id: str = "0"):
-        """Materialize flags with dependent values."""
-        return utils.substitute_flag_vars(flags=self.run_flags, GPU_ID=gpu_id)
+    def materialize_run_flags(
+        self, gpu_id: str = "0", inputs_dir: Optional[pathlib.PurePath] = None
+    ) -> List[str]:
+        """Materialize flags with dependent values.
+
+        Args:
+            gpu_id: gpu id to use.
+            inputs_dir: directory containing input_{0,1,...}.npy for each input.
+
+        Returns:
+            List of flags
+        """
+        flags = utils.substitute_flag_vars(flags=self.run_flags, GPU_ID=gpu_id)
+
+        model = self.module_generation_config.imported_model.model
+        if inputs_dir:
+            input_npys = [
+                inputs_dir / f"input_{idx}.npy" for idx in range(len(model.input_types))
+            ]
+            flags += [f"--input=@{npy}" for npy in input_npys]
+        else:
+            flags += [f"--input={input_type}=0" for input_type in model.input_types]
+
+        return flags
 
     @classmethod
     def build(
@@ -374,7 +395,6 @@
         name = f"{module_generation_config} {module_execution_config} with {input_data} @ {target_device_spec}"
         run_flags = generate_run_flags(
             imported_model=module_generation_config.imported_model,
-            input_data=input_data,
             module_execution_config=module_execution_config,
             gpu_id=r"${GPU_ID}",
         )
@@ -394,7 +414,6 @@
 
 def generate_run_flags(
     imported_model: ImportedModel,
-    input_data: common_definitions.ModelInputData,
     module_execution_config: ModuleExecutionConfig,
     gpu_id: str = "0",
     with_driver: bool = True,
@@ -413,9 +432,6 @@
 
     model = imported_model.model
     run_flags = [f"--function={model.entry_function}"]
-    if input_data != common_definitions.ZEROS_MODEL_INPUT_DATA:
-        raise ValueError("Currently only support all-zeros data.")
-    run_flags += [f"--input={input_type}=0" for input_type in model.input_types]
 
     exec_config = module_execution_config
     run_flags += exec_config.extra_flags.copy()
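
The input-flag selection in `materialize_run_flags` is the core of the change:
use the fetched npy files when an inputs directory is given, otherwise fall
back to all-zeros inputs. A condensed standalone sketch of that branch (the
function name here is illustrative, not part of the framework):

```python
import pathlib
from typing import List, Optional


def build_input_flags(
    input_types: List[str], inputs_dir: Optional[pathlib.PurePath] = None
) -> List[str]:
    """Returns --input flags: npy files when available, otherwise all zeros."""
    if inputs_dir:
        npy_files = [
            inputs_dir / f"input_{idx}.npy" for idx in range(len(input_types))
        ]
        return [f"--input=@{npy}" for npy in npy_files]
    return [f"--input={input_type}=0" for input_type in input_types]


print(build_input_flags(["1xf32", "2x2xf32"]))
# ['--input=1xf32=0', '--input=2x2xf32=0']
print(build_input_flags(["1xf32", "2x2xf32"], pathlib.PurePosixPath("inputs_npy")))
# ['--input=@inputs_npy/input_0.npy', '--input=@inputs_npy/input_1.npy']
```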
diff --git a/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py b/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py
index 585c7e5..71262e5 100644
--- a/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py
+++ b/build_tools/python/e2e_test_framework/definitions/iree_definitions_test.py
@@ -33,7 +33,6 @@
 
         flags = iree_definitions.generate_run_flags(
             imported_model=imported_model,
-            input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
             module_execution_config=execution_config,
         )
 
@@ -41,8 +40,6 @@
             flags,
             [
                 "--function=main",
-                "--input=1xf32=0",
-                "--input=2x2xf32=0",
                 "--task=10",
                 "--device=local-task",
             ],
@@ -70,14 +67,11 @@
 
         flags = iree_definitions.generate_run_flags(
             imported_model=imported_model,
-            input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
             module_execution_config=execution_config,
             gpu_id="3",
         )
 
-        self.assertEqual(
-            flags, ["--function=main", "--input=1xf32=0", "--device=cuda://3"]
-        )
+        self.assertEqual(flags, ["--function=main", "--device=cuda://3"])
 
     def test_generate_run_flags_without_driver(self):
         imported_model = iree_definitions.ImportedModel.from_model(
@@ -101,12 +95,71 @@
 
         flags = iree_definitions.generate_run_flags(
             imported_model=imported_model,
-            input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
             module_execution_config=execution_config,
             with_driver=False,
         )
 
-        self.assertEqual(flags, ["--function=main", "--input=1xf32=0", "--task=10"])
+        self.assertEqual(flags, ["--function=main", "--task=10"])
+
+    def test_materialize_run_flags(self):
+        imported_model = iree_definitions.ImportedModel.from_model(
+            common_definitions.Model(
+                id="1234",
+                name="tflite_m",
+                tags=[],
+                source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
+                source_url="https://example.com/xyz.tflite",
+                entry_function="main",
+                input_types=["1xf32", "2x2xf32"],
+            )
+        )
+        compile_target = iree_definitions.CompileTarget(
+            target_backend=iree_definitions.TargetBackend.CUDA,
+            target_architecture=common_definitions.DeviceArchitecture.CUDA_SM80,
+            target_abi=iree_definitions.TargetABI.LINUX_GNU,
+        )
+        compile_config = iree_definitions.CompileConfig(
+            id="compile_config_a",
+            name="compile_config_a",
+            tags=["test"],
+            compile_targets=[compile_target],
+        )
+        gen_config = iree_definitions.ModuleGenerationConfig.build(
+            imported_model=imported_model, compile_config=compile_config
+        )
+        exec_config = iree_definitions.ModuleExecutionConfig.build(
+            id="123",
+            tags=["test"],
+            loader=iree_definitions.RuntimeLoader.NONE,
+            driver=iree_definitions.RuntimeDriver.CUDA,
+        )
+        device_spec = common_definitions.DeviceSpec.build(
+            id="test_dev",
+            device_name="test_model",
+            host_environment=common_definitions.HostEnvironment.LINUX_X86_64,
+            architecture=common_definitions.DeviceArchitecture.CUDA_SM80,
+        )
+        run_config = iree_definitions.E2EModelRunConfig.build(
+            gen_config,
+            exec_config,
+            device_spec,
+            # TODO(#15282): ZEROS_MODEL_INPUT_DATA should be renamed to
+            # DEFAULT_INPUT_DATA, which means to use input npys if available;
+            # otherwise use all zeros data.
+            input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+            tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+        )
+
+        inputs_dir = pathlib.PurePath("inputs_dir")
+        flags = run_config.materialize_run_flags(gpu_id="10", inputs_dir=inputs_dir)
+
+        self.assertIn("--device=cuda://10", flags)
+        first_input = f'--input=@{inputs_dir / "input_0.npy"}'
+        self.assertIn(first_input, flags)
+        first_input_idx = flags.index(first_input)
+        self.assertEqual(
+            flags[first_input_idx + 1], f'--input=@{inputs_dir / "input_1.npy"}'
+        )
 
 
 class ModuleGenerationConfigTest(unittest.TestCase):
diff --git a/tests/e2e/models/generated_e2e_model_tests.cmake b/tests/e2e/models/generated_e2e_model_tests.cmake
index b4c1646..83523b9 100644
--- a/tests/e2e/models/generated_e2e_model_tests.cmake
+++ b/tests/e2e/models/generated_e2e_model_tests.cmake
@@ -13,8 +13,8 @@
     "x86_64-Linux=iree_module_MobileNetV1_fp32_tflite___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_/module.vmfb"
   RUNNER_ARGS
     "--function=main"
-    "--input=1x224x224x3xf32=0"
     "--device_allocator=caching"
+    "--input=1x224x224x3xf32=0"
 )
 
 iree_benchmark_suite_module_test(
@@ -25,8 +25,8 @@
     "x86_64-Linux=iree_module_EfficientNet_int8_tflite___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_/module.vmfb"
   RUNNER_ARGS
     "--function=main"
-    "--input=1x224x224x3xui8=0"
     "--device_allocator=caching"
+    "--input=1x224x224x3xui8=0"
 )
 
 iree_benchmark_suite_module_test(
@@ -38,9 +38,9 @@
     "x86_64-Linux=iree_module_DeepLabV3_fp32_tflite___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_/module.vmfb"
   RUNNER_ARGS
     "--function=main"
-    "--input=1x257x257x3xf32=0"
     "--device_allocator=caching"
     "--expected_f32_threshold=0.001"
+    "--input=1x257x257x3xf32=0"
 )
 
 iree_benchmark_suite_module_test(
@@ -53,7 +53,7 @@
     "x86_64-Linux=iree_module_PersonDetect_int8_tflite___x86_64-cascadelake-linux_gnu-llvm_cpu__default-flags_/module.vmfb"
   RUNNER_ARGS
     "--function=main"
-    "--input=1x96x96x1xi8=0"
     "--device_allocator=caching"
+    "--input=1x96x96x1xi8=0"
 )