| # Copyright 2023 The IREE Authors | 
 | # | 
 | # Licensed under the Apache License v2.0 with LLVM Exceptions. | 
 | # See https://llvm.org/LICENSE.txt for license information. | 
 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
 |  | 
 | from typing import Any, Callable, Collection, Dict, Union | 
 | import functools | 
 | from pathlib import Path | 
 | from tqdm import tqdm | 
 | import urllib.parse | 
 | import urllib.request | 
 | import os | 
 | from azure.storage.blob import BlobClient, BlobProperties | 
 | import hashlib | 
 | import mmap | 
 | import re | 
 | import logging | 
 |  | 
 | logger = logging.getLogger(__name__) | 
# Adjust logging levels: quiet the verbose Azure SDK loggers.
logging.basicConfig(level=logging.INFO)
for log_name in logging.Logger.manager.loggerDict:
    if log_name.startswith("azure"):
        logging.getLogger(log_name).setLevel(logging.WARNING)
 |  | 
 |  | 
def show_progress(t):
    """Returns an urllib-style reporthook that advances the tqdm bar `t`."""
    last_b = [0]
 |  | 
 |     def update_to(b=1, bsize=1, tsize=None): | 
 |         if tsize is not None: | 
 |             t.total = tsize | 
 |         t.update((b - last_b[0]) * bsize) | 
 |         last_b[0] = b | 
 |  | 
 |     return update_to | 
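

# Hedged usage sketch: `show_progress` is shaped as a `reporthook` for
# urllib.request.urlretrieve, driving a tqdm progress bar. The function below
# is illustrative only; its name, URL, and destination are placeholders.
def _example_show_progress(url: str, destination: Path):
    name = Path(urllib.parse.urlparse(url).path).name
    with tqdm(unit="B", unit_scale=True, miniters=1, desc=name) as t:
        urllib.request.urlretrieve(url, destination, reporthook=show_progress(t))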
 |  | 
 |  | 
 | @functools.cache | 
 | def get_artifact_root_dir() -> Path: | 
 |     root_path = os.getenv("IREE_TEST_FILES", default=str(Path.cwd())) + "/artifacts" | 
 |     return Path(os.path.expanduser(root_path)).resolve() | 
 |  | 
 |  | 
 | class ArtifactGroup: | 
 |     """A group of artifacts with a persistent location on disk.""" | 
 |  | 
 |     _INSTANCES: Dict[str, "ArtifactGroup"] = {} | 
 |  | 
 |     def __init__(self, group_name: str): | 
 |         self.group_name = group_name | 
 |         if group_name: | 
 |             self.directory = get_artifact_root_dir() / group_name | 
 |         else: | 
 |             self.directory = get_artifact_root_dir() | 
 |         self.directory.mkdir(parents=True, exist_ok=True) | 
 |  | 
 |     @classmethod | 
 |     def get(cls, group: Union["ArtifactGroup", str]) -> "ArtifactGroup": | 
 |         if isinstance(group, ArtifactGroup): | 
 |             return group | 
 |         try: | 
 |             return cls._INSTANCES[group] | 
 |         except KeyError: | 
 |             instance = ArtifactGroup(group) | 
 |             cls._INSTANCES[group] = instance | 
 |             return instance | 
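

# Hedged usage sketch: groups are interned by name, so repeated lookups share
# one on-disk directory under the artifact root. "example" is an illustrative
# group name, not one this module defines.
def _example_artifact_group() -> Path:
    group = ArtifactGroup.get("example")
    assert ArtifactGroup.get("example") is group  # Same instance on re-lookup.
    return group.directory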
 |  | 
 |  | 
 | class Artifact: | 
 |     """Some form of artifact materialized to disk.""" | 
 |  | 
 |     def __init__( | 
 |         self, | 
 |         group: Union[ArtifactGroup, str], | 
 |         name: str, | 
 |         depends: Collection["Artifact"] = (), | 
 |     ): | 
 |         self.group = ArtifactGroup.get(group) | 
 |         self.name = name | 
 |         self.depends = tuple(depends) | 
 |  | 
 |     @property | 
 |     def path(self) -> Path: | 
 |         return self.group.directory / self.name | 
 |  | 
    def join(self):
        """Waits for the artifact to become available."""
 |  | 
 |     def __str__(self): | 
 |         return str(self.path) | 
 |  | 
 |  | 
class ProducedArtifact(Artifact):
    """An artifact that is produced on demand by a callback."""

    def __init__(
 |         self, | 
 |         group: Union[ArtifactGroup, str], | 
 |         name: str, | 
 |         callback: Callable[["ProducedArtifact"], Any], | 
 |         *, | 
 |         depends: Collection["Artifact"] = (), | 
 |     ): | 
        super().__init__(group, name, depends)
        self.callback = callback
 |  | 
 |     def start(self) -> "ProducedArtifact": | 
 |         self.callback(self) | 
 |         if not self.path.exists(): | 
 |             raise RuntimeError( | 
 |                 f"Artifact {self} succeeded generation but was not produced" | 
 |             ) | 
 |         return self | 
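

# Hedged usage sketch: a ProducedArtifact whose callback writes the file it
# promises. Group and file names here are illustrative; real callbacks would
# typically wrap generator or compiler invocations.
def _example_produced_artifact() -> Path:
    def _write(artifact: ProducedArtifact):
        artifact.path.write_text("hello\n")

    produced = ProducedArtifact("example", "greeting.txt", _write)
    produced.start()  # Runs the callback, then verifies the file exists.
    return produced.path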
 |  | 
 |  | 
 | class FetchedArtifact(ProducedArtifact): | 
 |     """Represents an artifact that is to be fetched.""" | 
 |  | 
 |     def __init__(self, group: Union[ArtifactGroup, str], url: str): | 
 |         name = Path(urllib.parse.urlparse(url).path).name | 
 |         super().__init__(group, name, FetchedArtifact._callback) | 
 |         self.url = url | 
 |  | 
    def human_readable_size(self, size, decimal_places=2):
        """Formats a byte count as a human readable string, e.g. '32.00 MiB'."""
        for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
 |             if size < 1024.0 or unit == "PiB": | 
 |                 break | 
 |             size /= 1024.0 | 
 |         return f"{size:.{decimal_places}f} {unit}" | 
 |  | 
 |     def get_azure_md5(self, remote_file: str, azure_blob_properties: BlobProperties): | 
 |         """Gets the content_md5 hash for a blob on Azure, if available.""" | 
 |         content_settings = azure_blob_properties.get("content_settings") | 
 |         if not content_settings: | 
 |             return None | 
 |         azure_md5 = content_settings.get("content_md5") | 
 |         if not azure_md5: | 
 |             logger.warning( | 
 |                 f"  Remote file '{remote_file}' on Azure is missing the " | 
 |                 "'content_md5' property, can't check if local matches remote" | 
 |             ) | 
 |         return azure_md5 | 
 |  | 
 |     def get_local_md5(self, local_file_path: Path): | 
 |         """Gets the content_md5 hash for a lolca file, if it exists.""" | 
 |         if not local_file_path.exists() or local_file_path.stat().st_size == 0: | 
 |             return None | 
 |  | 
        # Memory-map the file so large artifacts can be hashed without reading
        # them fully into memory. The file must be opened in binary mode.
        with open(local_file_path, "rb") as file, mmap.mmap(
            file.fileno(), 0, access=mmap.ACCESS_READ
        ) as mapped:
            return hashlib.md5(mapped).digest()
 |  | 
 |     def download_azure_artifact(self: "FetchedArtifact"): | 
 |         """ | 
 |         Checks the hashes between the local file and azure file. | 
 |         """ | 
 |         remote_file_name = self.url.rsplit("/", 1)[-1] | 
 |  | 
 |         # Extract path components from Azure URL to use with the Azure Storage Blobs | 
 |         # client library for Python (https://pypi.org/project/azure-storage-blob/). | 
 |         # | 
 |         # For example: | 
 |         #   https://sharkpublic.blob.core.windows.net/sharkpublic/path/to/blob.txt | 
 |         #                                            ^           ^ | 
 |         #   account_url:    https://sharkpublic.blob.core.windows.net | 
 |         #   container_name: sharkpublic | 
 |         #   blob_name:      path/to/blob.txt | 
        result = re.search(r"(https.+\.net)/([^/]+)/(.+)", self.url)
        if not result:
            raise ValueError(f"Unable to parse Azure blob URL '{self.url}'")
        account_url, container_name, blob_name = result.groups()
 |  | 
 |         with BlobClient( | 
 |             account_url, | 
 |             container_name, | 
 |             blob_name, | 
 |             max_chunk_get_size=1024 * 1024 * 32,  # 32 MiB | 
 |             max_single_get_size=1024 * 1024 * 32,  # 32 MiB | 
 |         ) as blob_client: | 
 |             blob_properties = blob_client.get_blob_properties() | 
 |             blob_size_str = self.human_readable_size(blob_properties.size) | 
 |             azure_md5 = self.get_azure_md5(self.url, blob_properties) | 
 |  | 
 |             local_md5 = self.get_local_md5(self.path) | 
 |  | 
 |             if azure_md5 and azure_md5 == local_md5: | 
 |                 logger.info( | 
 |                     f"  Skipping '{remote_file_name}' download ({blob_size_str}) " | 
 |                     "- local MD5 hash matches" | 
 |                 ) | 
 |                 return | 
 |  | 
            reason = " (local MD5 does not match)" if local_md5 else ""
            logger.info(
                f"  Downloading '{remote_file_name}' ({blob_size_str}) "
                f"to '{self.path}'{reason}"
            )
            with open(self.path, mode="wb") as local_blob:
                download_stream = blob_client.download_blob(max_concurrency=4)
                local_blob.write(download_stream.readall())
 |  | 
 |     @staticmethod | 
 |     def _callback(self: "FetchedArtifact"): | 
 |         if "blob.core.windows.net" in self.url: | 
 |             self.download_azure_artifact() | 
 |         else: | 
 |             raise NotImplementedError( | 
 |                 f"Unsupported fetched artifact URL schema for '{self.url}'" | 
 |             ) | 
 |  | 
 |  | 
class StreamArtifact(Artifact):
    """An artifact that is appended to over time, e.g. a log file."""

    def __init__(self, group: Union[ArtifactGroup, str], name: str):
 |         super().__init__(group, name) | 
 |         self.io = open(self.path, "ab", buffering=0) | 
 |  | 
    def __del__(self):
        # Guard against __init__ having failed before self.io was assigned.
        if hasattr(self, "io"):
            self.io.close()
 |  | 
 |     def write_line(self, line: Union[str, bytes]): | 
 |         contents = line if isinstance(line, bytes) else line.encode() | 
 |         self.io.write(contents + b"\n") |