Support exporting configs for compilation statistics (#11604)

diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index 91aac17..1aa8edd 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -64,6 +64,7 @@
           BENCHMARK_PRESETS: ${{ inputs.benchmark-presets }}
         run: |
           ./build_tools/benchmarks/export_benchmark_config.py \
+            execution \
             --benchmark_presets="${BENCHMARK_PRESETS}" \
             --output="${BENCHMARK_CONFIG}"
           echo "benchmark-config=${BENCHMARK_CONFIG}" >> "${GITHUB_OUTPUT}"
diff --git a/build_tools/benchmarks/CMakeLists.txt b/build_tools/benchmarks/CMakeLists.txt
index 00905e2..28192b9 100644
--- a/build_tools/benchmarks/CMakeLists.txt
+++ b/build_tools/benchmarks/CMakeLists.txt
@@ -47,3 +47,10 @@
   SRC
     "collect_compilation_statistics_test.py"
 )
+
+benchmark_tool_py_test(
+  NAME
+    export_benchmark_config_test
+  SRC
+    "export_benchmark_config_test.py"
+)
diff --git a/build_tools/benchmarks/export_benchmark_config.py b/build_tools/benchmarks/export_benchmark_config.py
index 4fec45d..39243d8 100755
--- a/build_tools/benchmarks/export_benchmark_config.py
+++ b/build_tools/benchmarks/export_benchmark_config.py
@@ -4,9 +4,9 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""Exports JSON config for benchmarking.
+"""Exports JSON config for benchmarking and compilation statistics.
 
-The exported JSON is a list of object:
+Export type: "execution" outputs:
 [
   <target device name>: {
     host_environment: HostEnvironment,
@@ -15,6 +15,11 @@
   },
   ...
 ]
+to be used in build_tools/benchmarks/run_benchmarks_on_*.py
+
+Export type: "compilation" outputs a serialized list of module generation
+configs defined for compilation statistics, to be used in
+build_tools/benchmarks/collect_compilation_statistics.py
 """
 
 import sys
@@ -23,16 +28,18 @@
 # Add build_tools python dir to the search path.
 sys.path.insert(0, str(pathlib.Path(__file__).parent.with_name("python")))
 
+from typing import Callable, Dict, List, Optional, Set
 import argparse
 import collections
 import dataclasses
 import json
-from typing import Callable, Dict, List
+import textwrap
 
 from benchmark_suites.iree import benchmark_collections
-from e2e_test_framework.definitions import common_definitions, iree_definitions
-from e2e_test_framework import serialization
 from e2e_test_artifacts import iree_artifacts
+from e2e_test_framework import serialization
+from e2e_test_framework.definitions import common_definitions, iree_definitions
+
 
 PresetMatcher = Callable[[iree_definitions.E2EModelRunConfig], bool]
 BENCHMARK_PRESET_MATCHERS: Dict[str, PresetMatcher] = {
@@ -55,52 +62,24 @@
 }
 
 
-def parse_arguments():
-  """Parses command-line options."""
+def filter_and_group_run_configs(
+    run_configs: List[iree_definitions.E2EModelRunConfig],
+    target_device_names: Optional[Set[str]] = None,
+    preset_matchers: Optional[List[PresetMatcher]] = None
+) -> Dict[str, List[iree_definitions.E2EModelRunConfig]]:
+  """Filters run configs and groups by target device name.
+
+  Args:
+    run_configs: source e2e model run configs.
+    target_device_names: list of target device names, includes all if not set.
+    preset_matchers: list of preset matchers, matches all if not set.
 
-  def parse_and_strip_list_argument(arg) -> List[str]:
-    return [part.strip() for part in arg.split(",")]
-
-  def parse_benchmark_presets(arg) -> List[PresetMatcher]:
-    matchers = []
-    for preset in parse_and_strip_list_argument(arg):
-      matcher = BENCHMARK_PRESET_MATCHERS.get(preset)
-      if matcher is None:
-        raise argparse.ArgumentTypeError(
-            f"Unrecognized benchmark preset: '{preset}'.")
-      matchers.append(matcher)
-    return matchers
-
-  parser = argparse.ArgumentParser(
-      description="Exports JSON config for benchmarking. Filters can be "
-      "specified jointly to select a subset of benchmarks.")
-  parser.add_argument(
-      "--target_device_names",
-      type=parse_and_strip_list_argument,
-      help=("Target device names, separated by comma, not specified means "
-            "including all devices."))
-  parser.add_argument(
-      "--benchmark_presets",
-      type=parse_benchmark_presets,
-      help=("Presets that select a bundle of benchmarks, separated by comma, "
-            "multiple presets will be union. Available options: "
-            f"{','.join(BENCHMARK_PRESET_MATCHERS.keys())}"))
-  parser.add_argument("--output",
-                      type=pathlib.Path,
-                      help="Path to write the JSON output.")
-
-  return parser.parse_args()
-
-
-def main(args: argparse.Namespace):
-  _, all_run_configs = benchmark_collections.generate_benchmarks()
-
-  target_device_names = (set(args.target_device_names)
-                         if args.target_device_names is not None else None)
-  preset_matchers = args.benchmark_presets
-
+  Returns:
+    A map of e2e model run configs keyed by target device name.
+  """
   grouped_run_config_map = collections.defaultdict(list)
-  for run_config in all_run_configs:
+
+  for run_config in run_configs:
     device_name = run_config.target_device_spec.device_name
     if (target_device_names is not None and
         device_name not in target_device_names):
@@ -110,6 +89,18 @@
       continue
     grouped_run_config_map[device_name].append(run_config)
 
+  return grouped_run_config_map
+
+
+def _export_execution_handler(args: argparse.Namespace):
+  _, all_run_configs = benchmark_collections.generate_benchmarks()
+  target_device_names = (set(args.target_device_names)
+                         if args.target_device_names is not None else None)
+  grouped_run_config_map = filter_and_group_run_configs(
+      all_run_configs,
+      target_device_names=target_device_names,
+      preset_matchers=args.benchmark_presets)
+
   output_map = {}
   for device_name, run_configs in grouped_run_config_map.items():
     host_environments = set(run_config.target_device_spec.host_environment
@@ -131,7 +122,94 @@
         "run_configs": serialization.serialize_and_pack(run_configs),
     }
 
-  json_data = json.dumps(output_map, indent=2)
+  return output_map
+
+
+def _export_compilation_handler(_args: argparse.Namespace):
+  all_gen_configs, _ = benchmark_collections.generate_benchmarks()
+  compile_stats_gen_configs = [
+      config for config in all_gen_configs
+      if benchmark_collections.COMPILE_STATS_TAG in config.compile_config.tags
+  ]
+
+  return serialization.serialize_and_pack(compile_stats_gen_configs)
+
+
+def _parse_and_strip_list_argument(arg) -> List[str]:
+  return [part.strip() for part in arg.split(",")]
+
+
+def _parse_benchmark_presets(arg) -> List[PresetMatcher]:
+  matchers = []
+  for preset in _parse_and_strip_list_argument(arg):
+    matcher = BENCHMARK_PRESET_MATCHERS.get(preset)
+    if matcher is None:
+      raise argparse.ArgumentTypeError(
+          f"Unrecognized benchmark preset: '{preset}'.")
+    matchers.append(matcher)
+  return matchers
+
+
+def _parse_arguments():
+  """Parses command-line options."""
+
+  # Makes global options come *after* command.
+  # See https://stackoverflow.com/q/23296695
+  subparser_base = argparse.ArgumentParser(add_help=False)
+  subparser_base.add_argument("--output",
+                              type=pathlib.Path,
+                              help="Path to write the JSON output.")
+
+  parser = argparse.ArgumentParser(
+      formatter_class=argparse.RawDescriptionHelpFormatter,
+      description=textwrap.dedent("""
+      Export type: "execution" outputs:
+      [
+        <target device name>: {
+          host_environment: HostEnvironment,
+          module_dir_paths: [<paths of dependent module directories>],
+          run_configs: [E2EModelRunConfig]
+        },
+        ...
+      ]
+      to be used in build_tools/benchmarks/run_benchmarks_on_*.py
+
+      Export type: "compilation" outputs a serialized list of module generation
+      configs defined for compilation statistics, to be used in
+      build_tools/benchmarks/collect_compilation_statistics.py
+      """))
+
+  subparser = parser.add_subparsers(required=True, title="export type")
+  execution_parser = subparser.add_parser(
+      "execution",
+      parents=[subparser_base],
+      help="Export execution config to run benchmarks.")
+  execution_parser.set_defaults(handler=_export_execution_handler)
+  execution_parser.add_argument(
+      "--target_device_names",
+      type=_parse_and_strip_list_argument,
+      help=("Target device names, separated by comma, not specified means "
+            "including all devices."))
+  execution_parser.add_argument(
+      "--benchmark_presets",
+      type=_parse_benchmark_presets,
+      help=("Presets that select a bundle of benchmarks, separated by comma, "
+            "multiple presets will be union. Available options: "
+            f"{','.join(BENCHMARK_PRESET_MATCHERS.keys())}"))
+
+  compilation_parser = subparser.add_parser(
+      "compilation",
+      parents=[subparser_base],
+      help=("Export serialized list of module generation configs defined for "
+            "compilation statistics."))
+  compilation_parser.set_defaults(handler=_export_compilation_handler)
+
+  return parser.parse_args()
+
+
+def main(args: argparse.Namespace):
+  output_obj = args.handler(args)
+  json_data = json.dumps(output_obj, indent=2)
   if args.output is None:
     print(json_data)
   else:
@@ -139,4 +217,4 @@
 
 
 if __name__ == "__main__":
-  main(parse_arguments())
+  main(_parse_arguments())
diff --git a/build_tools/benchmarks/export_benchmark_config_test.py b/build_tools/benchmarks/export_benchmark_config_test.py
new file mode 100644
index 0000000..3ceddbc
--- /dev/null
+++ b/build_tools/benchmarks/export_benchmark_config_test.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+
+from e2e_test_framework.definitions import common_definitions, iree_definitions
+import export_benchmark_config
+
+COMMON_MODEL = common_definitions.Model(
+    id="tflite",
+    name="model_tflite",
+    tags=[],
+    source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
+    source_url="",
+    entry_function="predict",
+    input_types=["1xf32"])
+COMMON_GEN_CONFIG = iree_definitions.ModuleGenerationConfig(
+    imported_model=iree_definitions.ImportedModel.from_model(COMMON_MODEL),
+    compile_config=iree_definitions.CompileConfig(id="1",
+                                                  tags=[],
+                                                  compile_targets=[]))
+COMMON_EXEC_CONFIG = iree_definitions.ModuleExecutionConfig(
+    id="exec",
+    tags=[],
+    loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
+    driver=iree_definitions.RuntimeDriver.LOCAL_SYNC)
+
+
+class ExportBenchmarkConfigTest(unittest.TestCase):
+
+  def test_filter_and_group_run_configs_set_all_filters(self):
+    device_spec_a = common_definitions.DeviceSpec(
+        id="dev_a_cpu",
+        device_name="dev_a_cpu",
+        architecture=common_definitions.DeviceArchitecture.RV64_GENERIC,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    device_spec_b = common_definitions.DeviceSpec(
+        id="dev_a_gpu",
+        device_name="dev_a_gpu",
+        architecture=common_definitions.DeviceArchitecture.MALI_VALHALL,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    device_spec_c = common_definitions.DeviceSpec(
+        id="dev_c",
+        device_name="dev_c",
+        architecture=common_definitions.DeviceArchitecture.CUDA_SM80,
+        host_environment=common_definitions.HostEnvironment.LINUX_X86_64)
+    matched_run_config_a = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_a,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    unmatched_run_config_b = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_b,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    matched_run_config_c = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_c,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    matchers = [(lambda config: config.target_device_spec.architecture.
+                 architecture == "cuda"),
+                (lambda config: config.target_device_spec.host_environment.
+                 platform == "android")]
+
+    run_config_map = export_benchmark_config.filter_and_group_run_configs(
+        run_configs=[
+            matched_run_config_a, unmatched_run_config_b, matched_run_config_c
+        ],
+        target_device_names={"dev_a_cpu", "dev_c"},
+        preset_matchers=matchers)
+
+    self.assertEqual(run_config_map, {
+        "dev_a_cpu": [matched_run_config_a],
+        "dev_c": [matched_run_config_c],
+    })
+
+  def test_filter_and_group_run_configs_include_all(self):
+    device_spec_a = common_definitions.DeviceSpec(
+        id="dev_a_cpu",
+        device_name="dev_a_cpu",
+        architecture=common_definitions.DeviceArchitecture.RV64_GENERIC,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    device_spec_b = common_definitions.DeviceSpec(
+        id="dev_a_gpu",
+        device_name="dev_a_gpu",
+        architecture=common_definitions.DeviceArchitecture.MALI_VALHALL,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    device_spec_c = common_definitions.DeviceSpec(
+        id="dev_a_second_gpu",
+        device_name="dev_a_gpu",
+        architecture=common_definitions.DeviceArchitecture.ADRENO_GENERIC,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    run_config_a = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_a,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    run_config_b = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_b,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    run_config_c = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_c,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+
+    run_config_map = export_benchmark_config.filter_and_group_run_configs(
+        run_configs=[run_config_a, run_config_b, run_config_c])
+
+    self.maxDiff = 100000
+
+    self.assertEqual(run_config_map, {
+        "dev_a_cpu": [run_config_a],
+        "dev_a_gpu": [run_config_b, run_config_c],
+    })
+
+  def test_filter_and_group_run_configs_set_target_device_names(self):
+    device_spec_a = common_definitions.DeviceSpec(
+        id="dev_a",
+        device_name="dev_a",
+        architecture=common_definitions.DeviceArchitecture.RV64_GENERIC,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    device_spec_b = common_definitions.DeviceSpec(
+        id="dev_b",
+        device_name="dev_b",
+        architecture=common_definitions.DeviceArchitecture.MALI_VALHALL,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    run_config_a = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_a,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    run_config_b = iree_definitions.E2EModelRunConfig(
+        module_generation_config=COMMON_GEN_CONFIG,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_b,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+
+    run_config_map = export_benchmark_config.filter_and_group_run_configs(
+        run_configs=[run_config_a, run_config_b],
+        target_device_names={"dev_a", "dev_b"})
+
+    self.assertEqual(run_config_map, {
+        "dev_a": [run_config_a],
+        "dev_b": [run_config_b],
+    })
+
+  def test_filter_and_group_run_configs_set_preset_matchers(self):
+    small_model = common_definitions.Model(
+        id="small_model",
+        name="small_model",
+        tags=[],
+        source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
+        source_url="",
+        entry_function="predict",
+        input_types=["1xf32"])
+    big_model = common_definitions.Model(
+        id="big_model",
+        name="big_model",
+        tags=[],
+        source_type=common_definitions.ModelSourceType.EXPORTED_TFLITE,
+        source_url="",
+        entry_function="predict",
+        input_types=["1xf32"])
+    compile_config = iree_definitions.CompileConfig(id="1",
+                                                    tags=[],
+                                                    compile_targets=[])
+    small_gen_config = iree_definitions.ModuleGenerationConfig(
+        imported_model=iree_definitions.ImportedModel.from_model(small_model),
+        compile_config=compile_config)
+    big_gen_config = iree_definitions.ModuleGenerationConfig(
+        imported_model=iree_definitions.ImportedModel.from_model(big_model),
+        compile_config=compile_config)
+    device_spec_a = common_definitions.DeviceSpec(
+        id="dev_a",
+        device_name="dev_a",
+        architecture=common_definitions.DeviceArchitecture.RV64_GENERIC,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    device_spec_b = common_definitions.DeviceSpec(
+        id="dev_b",
+        device_name="dev_b",
+        architecture=common_definitions.DeviceArchitecture.MALI_VALHALL,
+        host_environment=common_definitions.HostEnvironment.ANDROID_ARMV8_2_A)
+    run_config_a = iree_definitions.E2EModelRunConfig(
+        module_generation_config=small_gen_config,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_a,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+    run_config_b = iree_definitions.E2EModelRunConfig(
+        module_generation_config=big_gen_config,
+        module_execution_config=COMMON_EXEC_CONFIG,
+        target_device_spec=device_spec_b,
+        input_data=common_definitions.ZEROS_MODEL_INPUT_DATA)
+
+    run_config_map = export_benchmark_config.filter_and_group_run_configs(
+        run_configs=[run_config_a, run_config_b],
+        preset_matchers=[
+            lambda config: config.module_generation_config.imported_model.model.
+            id == "small_model"
+        ])
+
+    self.assertEqual(run_config_map, {
+        "dev_a": [run_config_a],
+    })
+
+
+if __name__ == "__main__":
+  unittest.main()