blob: 975f9f3e229bff3c7066f9e9968d285680f95641 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Exports JSON config for benchmarking.

The exported JSON is an object keyed by target device name:
{
  <target device name>: {
    host_environment: HostEnvironment,
    module_dir_paths: [<paths of dependent module directories>],
    run_configs: [E2EModelRunConfig]
  },
  ...
}
"""
import sys
import pathlib
# Add build_tools python dir to the search path.
sys.path.insert(0, str(pathlib.Path(__file__).parent.with_name("python")))
import argparse
import collections
import dataclasses
import json
from benchmark_suites.iree import benchmark_collections
from e2e_test_framework import serialization
from e2e_test_artifacts import iree_artifacts
def parse_arguments():
  """Builds the command-line parser for this exporter and parses sys.argv.

  Returns:
    An argparse.Namespace with:
      target_device_names: list of str collected from every
        `--target_device_name` flag, or None when the flag was never given.
      output: pathlib.Path from `--output`, or None when omitted.
  """
  arg_parser = argparse.ArgumentParser(
      description="Exports JSON config for benchmarking.")
  device_flag_help = ("Target device name, can be specified multiple times. "
                      "Not specified means including all devices.")
  # `append` accumulates repeated flags into a list under the plural dest.
  arg_parser.add_argument("--target_device_name",
                          dest="target_device_names",
                          action="append",
                          type=str,
                          help=device_flag_help)
  arg_parser.add_argument("--output",
                          help="Path to write the JSON output.",
                          type=pathlib.Path)
  parsed_args = arg_parser.parse_args()
  return parsed_args
def _group_run_configs_by_device(run_configs, target_device_names):
  """Groups run configs by target device name, optionally filtering devices.

  Args:
    run_configs: iterable of run configs; each must expose
      `target_device_spec.device_name`.
    target_device_names: set of device names to keep, or None to keep all.

  Returns:
    A dict mapping device name to the list of its run configs, in the order
    the configs were encountered.
  """
  grouped = collections.defaultdict(list)
  for run_config in run_configs:
    device_name = run_config.target_device_spec.device_name
    if target_device_names is None or device_name in target_device_names:
      grouped[device_name].append(run_config)
  return grouped


def _build_device_config(device_name, run_configs):
  """Builds the JSON-serializable export entry for a single target device.

  Args:
    device_name: name of the target device (used in error messages).
    run_configs: non-empty list of run configs targeting that device.

  Returns:
    A dict with "host_environment", "module_dir_paths" and "run_configs".

  Raises:
    ValueError: if the run configs disagree on the host environment.
  """
  host_environments = {
      run_config.target_device_spec.host_environment
      for run_config in run_configs
  }
  if len(host_environments) > 1:
    # Name the device and the conflicting environments so the bad spec can be
    # located without re-running under a debugger.
    raise ValueError(
        "Device specs of the same device should have the same host "
        f"environment. Device '{device_name}' has {len(host_environments)} "
        f"different host environments: {host_environments}")
  host_environment = host_environments.pop()
  # Deduplicate (several run configs can share one module) and sort for a
  # deterministic output.
  module_dir_paths = sorted({
      str(iree_artifacts.get_module_dir_path(config.module_generation_config))
      for config in run_configs
  })
  return {
      "host_environment": dataclasses.asdict(host_environment),
      "module_dir_paths": module_dir_paths,
      "run_configs": serialization.serialize_and_pack(run_configs),
  }


def main(args: argparse.Namespace):
  """Generates benchmark run configs and exports them as JSON per device.

  Args:
    args: parsed CLI options; reads `target_device_names` (list of str or
      None to include every device) and `output` (pathlib.Path, or None to
      print to stdout).

  Raises:
    ValueError: if run configs of one device disagree on host environment.
  """
  _, all_run_configs = benchmark_collections.generate_benchmarks()
  target_device_names = (set(args.target_device_names)
                         if args.target_device_names is not None else None)
  grouped_run_config_map = _group_run_configs_by_device(
      all_run_configs, target_device_names)

  output_map = {
      device_name: _build_device_config(device_name, run_configs)
      for device_name, run_configs in grouped_run_config_map.items()
  }

  json_data = json.dumps(output_map, indent=2)
  if args.output is None:
    print(json_data)
  else:
    args.output.write_text(json_data)
if __name__ == "__main__":
  # Script entry point: parse the CLI flags, then run the exporter.
  cli_args = parse_arguments()
  main(cli_args)