Add benchmark preset to benchmark config export tool (#11561)

Co-authored-by: Geoffrey Martin-Noble <gcmn@google.com>
diff --git a/build_tools/benchmarks/export_benchmark_config.py b/build_tools/benchmarks/export_benchmark_config.py
index 975f9f3..4fec45d 100755
--- a/build_tools/benchmarks/export_benchmark_config.py
+++ b/build_tools/benchmarks/export_benchmark_config.py
@@ -27,24 +27,64 @@
 import collections
 import dataclasses
 import json
+from typing import Callable, Dict, List
 
 from benchmark_suites.iree import benchmark_collections
+from e2e_test_framework.definitions import common_definitions, iree_definitions
 from e2e_test_framework import serialization
 from e2e_test_artifacts import iree_artifacts
 
+PresetMatcher = Callable[[iree_definitions.E2EModelRunConfig], bool]
+BENCHMARK_PRESET_MATCHERS: Dict[str, PresetMatcher] = {
+    "x86_64":
+        lambda config: config.target_device_spec.architecture.architecture ==
+        "x86_64",
+    "cuda":
+        lambda config: config.target_device_spec.architecture.architecture ==
+        "cuda",
+    "android-cpu":
+        lambda config:
+        (config.target_device_spec.architecture.type == common_definitions.
+         ArchitectureType.CPU and config.target_device_spec.host_environment.
+         platform == "android"),
+    "android-gpu":
+        lambda config:
+        (config.target_device_spec.architecture.type == common_definitions.
+         ArchitectureType.GPU and config.target_device_spec.host_environment.
+         platform == "android"),
+}
+
 
 def parse_arguments():
   """Parses command-line options."""
 
+  def parse_and_strip_list_argument(arg) -> List[str]:
+    return [part.strip() for part in arg.split(",")]
+
+  def parse_benchmark_presets(arg) -> List[PresetMatcher]:
+    matchers = []
+    for preset in parse_and_strip_list_argument(arg):
+      matcher = BENCHMARK_PRESET_MATCHERS.get(preset)
+      if matcher is None:
+        raise argparse.ArgumentTypeError(
+            f"Unrecognized benchmark preset: '{preset}'.")
+      matchers.append(matcher)
+    return matchers
+
   parser = argparse.ArgumentParser(
-      description="Exports JSON config for benchmarking.")
+      description="Exports JSON config for benchmarking. Filters can be "
+      "specified jointly to select a subset of benchmarks.")
   parser.add_argument(
-      "--target_device_name",
-      type=str,
-      action="append",
-      dest="target_device_names",
-      help=("Target device name, can be specified multiple times. "
-            "Not specified means including all devices."))
+      "--target_device_names",
+      type=parse_and_strip_list_argument,
+      help=("Target device names, separated by comma, not specified means "
+            "including all devices."))
+  parser.add_argument(
+      "--benchmark_presets",
+      type=parse_benchmark_presets,
+      help=("Presets that select a bundle of benchmarks, separated by comma, "
+            "multiple presets will be union. Available options: "
+            f"{','.join(BENCHMARK_PRESET_MATCHERS.keys())}"))
   parser.add_argument("--output",
                       type=pathlib.Path,
                       help="Path to write the JSON output.")
@@ -57,11 +97,18 @@
 
   target_device_names = (set(args.target_device_names)
                          if args.target_device_names is not None else None)
+  preset_matchers = args.benchmark_presets
+
   grouped_run_config_map = collections.defaultdict(list)
   for run_config in all_run_configs:
     device_name = run_config.target_device_spec.device_name
-    if target_device_names is None or device_name in target_device_names:
-      grouped_run_config_map[device_name].append(run_config)
+    if (target_device_names is not None and
+        device_name not in target_device_names):
+      continue
+    if (preset_matchers is not None and
+        not any(matcher(run_config) for matcher in preset_matchers)):
+      continue
+    grouped_run_config_map[device_name].append(run_config)
 
   output_map = {}
   for device_name, run_configs in grouped_run_config_map.items():