blob: ce9b3f5fd39414a596635f7d9bfed6eb3033033a [file] [log] [blame]
## Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Represents the directory structure of common artifacts."""
from dataclasses import dataclass
from typing import Sequence, OrderedDict
import collections
import pathlib
import urllib.parse
from e2e_test_framework.definitions import common_definitions
# Archive extensions used to pack models.
ARCHIVE_FILE_EXTENSIONS = [".tar", ".gz"]
@dataclass(frozen=True)
class ModelArtifact(object):
model: common_definitions.Model
file_path: pathlib.PurePath
@dataclass(frozen=True)
class ArtifactsRoot(object):
# Map of model artifacts, keyed by model id.
model_artifact_map: OrderedDict[str, ModelArtifact]
def generate_artifacts_root(
parent_path: pathlib.PurePath,
models: Sequence[common_definitions.Model]) -> ArtifactsRoot:
"""Generates model directory structure."""
model_artifact_map = collections.OrderedDict()
for model in models:
if model.id in model_artifact_map:
raise ValueError(f"Duplicate model {model.id}.")
model_url = urllib.parse.urlparse(model.source_url)
# Drop the archive extensions.
file_exts = pathlib.PurePath(model_url.path).suffixes
while len(file_exts) > 0 and file_exts[-1] in ARCHIVE_FILE_EXTENSIONS:
file_exts.pop()
model_ext = "".join(file_exts)
# Model path: <model_artifacts_root>/<model_id>_<model_name><model_ext>
file_path = parent_path / f"{model.id}_{model.name}{model_ext}"
model_artifact_map[model.id] = ModelArtifact(model=model,
file_path=file_path)
return ArtifactsRoot(model_artifact_map=model_artifact_map)