Add validators for module generation and run configs (#14234)
diff --git a/build_tools/python/benchmark_suites/CMakeLists.txt b/build_tools/python/benchmark_suites/CMakeLists.txt
new file mode 100644
index 0000000..1e70c9d
--- /dev/null
+++ b/build_tools/python/benchmark_suites/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Copyright 2022 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_add_all_subdirs()
diff --git a/build_tools/python/benchmark_suites/iree/CMakeLists.txt b/build_tools/python/benchmark_suites/iree/CMakeLists.txt
new file mode 100644
index 0000000..1930c71
--- /dev/null
+++ b/build_tools/python/benchmark_suites/iree/CMakeLists.txt
@@ -0,0 +1,12 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_build_tools_py_test(
+ NAME
+ benchmark_collections_test
+ SRC
+ "benchmark_collections_test.py"
+)
diff --git a/build_tools/python/benchmark_suites/iree/benchmark_collections.py b/build_tools/python/benchmark_suites/iree/benchmark_collections.py
index d2e1cc2..5e1838a 100644
--- a/build_tools/python/benchmark_suites/iree/benchmark_collections.py
+++ b/build_tools/python/benchmark_suites/iree/benchmark_collections.py
@@ -5,7 +5,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Generates all benchmarks."""
-from typing import List, Tuple
+import re
+from typing import List, Tuple, Sequence
from e2e_test_artifacts import iree_artifacts
from e2e_test_framework.definitions import iree_definitions
@@ -22,6 +23,69 @@
)
COMPILE_STATS_ID_SUFFIX = "-compile-stats"
+ALLOWED_NAME_FORMAT = re.compile(r"^[0-9a-zA-Z.,\-_()\[\] @]+$")
+
+
+def validate_gen_configs(
+ gen_configs: Sequence[iree_definitions.ModuleGenerationConfig],
+):
+ """Check the uniqueness and name format of module generation configs."""
+
+ ids_to_configs = {}
+ names_to_configs = {}
+ for gen_config in gen_configs:
+ if not ALLOWED_NAME_FORMAT.match(gen_config.name):
+ raise ValueError(
+ f"Module generation config name: '{gen_config.name}' doesn't"
+ f" follow the format '{ALLOWED_NAME_FORMAT.pattern}'"
+ )
+
+ if gen_config.composite_id in ids_to_configs:
+ raise ValueError(
+ "Two module generation configs have the same ID:\n\n"
+ f"{repr(gen_config)}\n\n"
+ f"{repr(ids_to_configs[gen_config.composite_id])}"
+ )
+ ids_to_configs[gen_config.composite_id] = gen_config
+
+ if gen_config.name in names_to_configs:
+ raise ValueError(
+ "Two module generation configs have the same name:\n\n"
+ f"{repr(gen_config)}\n\n"
+ f"{repr(names_to_configs[gen_config.name])}"
+ )
+ names_to_configs[gen_config.name] = gen_config
+
+
+def validate_run_configs(
+ run_configs: Sequence[iree_definitions.E2EModelRunConfig],
+):
+ """Check the uniqueness and name format of E2E model run configs."""
+
+ ids_to_configs = {}
+ names_to_configs = {}
+ for run_config in run_configs:
+ if not ALLOWED_NAME_FORMAT.match(run_config.name):
+ raise ValueError(
+ f"E2E model run config name: '{run_config.name}' doesn't"
+ f" follow the format '{ALLOWED_NAME_FORMAT.pattern}'"
+ )
+
+ if run_config.composite_id in ids_to_configs:
+ raise ValueError(
+ "Two e2e model run configs have the same ID:\n\n"
+ f"{repr(run_config)}\n\n"
+ f"{repr(ids_to_configs[run_config.composite_id])}"
+ )
+ ids_to_configs[run_config.composite_id] = run_config
+
+ if run_config.name in names_to_configs:
+ raise ValueError(
+ "Two e2e model run configs have the same name:\n\n"
+ f"{repr(run_config)}\n\n"
+ f"{repr(names_to_configs[run_config.name])}"
+ )
+ names_to_configs[run_config.name] = run_config
def generate_benchmarks() -> (
@@ -48,6 +112,9 @@
all_gen_configs += module_generation_configs
all_run_configs += run_configs
+ validate_gen_configs(all_gen_configs)
+ validate_run_configs(all_run_configs)
+
compile_stats_gen_configs = []
# For now we simply track compilation statistics of all modules.
for gen_config in all_gen_configs:
diff --git a/build_tools/python/benchmark_suites/iree/benchmark_collections_test.py b/build_tools/python/benchmark_suites/iree/benchmark_collections_test.py
new file mode 100644
index 0000000..297abba
--- /dev/null
+++ b/build_tools/python/benchmark_suites/iree/benchmark_collections_test.py
@@ -0,0 +1,252 @@
+## Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import unittest
+
+from benchmark_suites.iree import benchmark_collections
+from e2e_test_framework.definitions import common_definitions, iree_definitions
+
+MODEL = common_definitions.Model(
+ id="dummy-model-1234",
+ name="dummy-model",
+ tags=[],
+ source_type=common_definitions.ModelSourceType.EXPORTED_STABLEHLO_MLIR,
+ source_url="https://example.com/xyz.mlir",
+ entry_function="main",
+ input_types=["1xf32"],
+)
+IMPORTED_MODEL = iree_definitions.ImportedModel.from_model(MODEL)
+COMPILE_CONFIG = iree_definitions.CompileConfig.build(
+ id="dummy-config-1234",
+ compile_targets=[
+ iree_definitions.CompileTarget(
+ target_architecture=common_definitions.DeviceArchitecture.X86_64_CASCADELAKE,
+ target_backend=iree_definitions.TargetBackend.LLVM_CPU,
+ target_abi=iree_definitions.TargetABI.LINUX_GNU,
+ )
+ ],
+)
+EXEC_CONFIG = iree_definitions.ModuleExecutionConfig.build(
+ id="dummy-exec-1234",
+ loader=iree_definitions.RuntimeLoader.EMBEDDED_ELF,
+ driver=iree_definitions.RuntimeDriver.LOCAL_SYNC,
+)
+DEVICE_SPEC = common_definitions.DeviceSpec.build(
+ id="dummy-device-1234",
+ device_name="dummy-device",
+ architecture=common_definitions.DeviceArchitecture.CUDA_SM80,
+ host_environment=common_definitions.HostEnvironment.LINUX_X86_64,
+)
+
+
+class BenchmarkCollectionsTest(unittest.TestCase):
+ def test_validate_gen_configs(self):
+ config_a = iree_definitions.ModuleGenerationConfig(
+ composite_id="a",
+ name="model-name (A.RCH)[tag_0,tag_1]",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+ config_b = iree_definitions.ModuleGenerationConfig(
+ composite_id="b",
+ name="name-b",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+
+ benchmark_collections.validate_gen_configs([config_a, config_b])
+
+ def test_validate_gen_configs_duplicate_name(self):
+ config_a = iree_definitions.ModuleGenerationConfig(
+ composite_id="a",
+ name="name",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+ config_b = iree_definitions.ModuleGenerationConfig(
+ composite_id="b",
+ name="name",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+
+ self.assertRaises(
+ ValueError,
+ lambda: benchmark_collections.validate_gen_configs([config_a, config_b]),
+ )
+
+ def test_validate_gen_configs_disallowed_characters(self):
+ config_a = iree_definitions.ModuleGenerationConfig(
+ composite_id="a",
+ name="name+a",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+
+ self.assertRaises(
+ ValueError,
+ lambda: benchmark_collections.validate_gen_configs([config_a]),
+ )
+
+ def test_validate_gen_configs_duplicate_id(self):
+ config_a = iree_definitions.ModuleGenerationConfig(
+ composite_id="x",
+ name="name-a",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+ config_b = iree_definitions.ModuleGenerationConfig(
+ composite_id="x",
+ name="name-b",
+ tags=[],
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ compile_flags=[],
+ )
+
+ self.assertRaises(
+ ValueError,
+ lambda: benchmark_collections.validate_gen_configs([config_a, config_b]),
+ )
+
+ def test_validate_run_configs(self):
+ config_a = iree_definitions.E2EModelRunConfig(
+ composite_id="a",
+ name="model-name (A.RCH)[tag_0,tag_1] @ device",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+ config_b = iree_definitions.E2EModelRunConfig(
+ composite_id="b",
+ name="name-b",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+
+ benchmark_collections.validate_run_configs([config_a, config_b])
+
+ def test_validate_run_configs_duplicate_name(self):
+ config_a = iree_definitions.E2EModelRunConfig(
+ composite_id="a",
+ name="name",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+ config_b = iree_definitions.E2EModelRunConfig(
+ composite_id="b",
+ name="name",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+
+ self.assertRaises(
+ ValueError,
+ lambda: benchmark_collections.validate_run_configs([config_a, config_b]),
+ )
+
+ def test_validate_run_configs_duplicate_id(self):
+ config_a = iree_definitions.E2EModelRunConfig(
+ composite_id="x",
+ name="name-a",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+ config_b = iree_definitions.E2EModelRunConfig(
+ composite_id="x",
+ name="name-b",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+
+ self.assertRaises(
+ ValueError,
+ lambda: benchmark_collections.validate_run_configs([config_a, config_b]),
+ )
+
+ def test_validate_run_configs_disallowed_characters(self):
+ config = iree_definitions.E2EModelRunConfig(
+ composite_id="x",
+ name="name+a",
+ tags=[],
+ module_generation_config=iree_definitions.ModuleGenerationConfig.build(
+ imported_model=IMPORTED_MODEL,
+ compile_config=COMPILE_CONFIG,
+ ),
+ module_execution_config=EXEC_CONFIG,
+ target_device_spec=DEVICE_SPEC,
+ input_data=common_definitions.ZEROS_MODEL_INPUT_DATA,
+ tool=iree_definitions.E2EModelRunTool.IREE_BENCHMARK_MODULE,
+ run_flags=[],
+ )
+
+ self.assertRaises(
+ ValueError,
+ lambda: benchmark_collections.validate_run_configs([config]),
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/build_tools/python/e2e_test_framework/definitions/common_definitions.py b/build_tools/python/e2e_test_framework/definitions/common_definitions.py
index f74aef2..9ad166d 100644
--- a/build_tools/python/e2e_test_framework/definitions/common_definitions.py
+++ b/build_tools/python/e2e_test_framework/definitions/common_definitions.py
@@ -5,11 +5,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Common classes for benchmark definitions."""
+import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Sequence
+
from e2e_test_framework import serialization, unique_ids
-import dataclasses
class ArchitectureType(Enum):
@@ -152,10 +153,10 @@
cls,
id: str,
device_name: str,
- tags: Sequence[str],
host_environment: HostEnvironment,
architecture: DeviceArchitecture,
device_parameters: Optional[Sequence[str]] = None,
+ tags: Sequence[str] = (),
):
tag_part = ",".join(tags)
# Format: <device_name>[<tag>,...]
diff --git a/build_tools/python/e2e_test_framework/definitions/iree_definitions.py b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py
index 5230164..7b05c63 100644
--- a/build_tools/python/e2e_test_framework/definitions/iree_definitions.py
+++ b/build_tools/python/e2e_test_framework/definitions/iree_definitions.py
@@ -91,9 +91,9 @@
def build(
cls,
id: str,
- tags: Sequence[str],
compile_targets: Sequence[CompileTarget],
extra_flags: Optional[Sequence[str]] = None,
+ tags: Sequence[str] = (),
):
target_part = ",".join(str(target) for target in compile_targets)
tag_part = ",".join(tags)
@@ -128,10 +128,10 @@
def build(
cls,
id: str,
- tags: Sequence[str],
loader: RuntimeLoader,
driver: RuntimeDriver,
extra_flags: Optional[Sequence[str]] = None,
+ tags: Sequence[str] = (),
):
runtime_part = f"{driver.name}({loader.name})".lower()
tag_part = ",".join(tags)
@@ -356,7 +356,7 @@
target_device_spec: common_definitions.DeviceSpec,
input_data: common_definitions.ModelInputData,
tool: E2EModelRunTool,
- tags: Optional[Sequence[str]] = None,
+ tags: Sequence[str] = (),
):
composite_id = unique_ids.hash_composite_id(
[