[iree.build] Make the fetch_http action more robust. (#19330)
* Downloads to a staging file and then atomically renames into place,
avoiding potential for partial downloads.
* Reports completion percent as part of the console updates.
* Persists metadata for the source URL and will refetch if changed.
* Fixes an error handling test for the onnx mnist_builder that missed
the prior update.
More sophistication is possible, but this brings the action up to a minimum
viable level from a usability perspective.
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/bindings/python/iree/build/executor.py b/compiler/bindings/python/iree/build/executor.py
index c0fefe8..c0463b1 100644
--- a/compiler/bindings/python/iree/build/executor.py
+++ b/compiler/bindings/python/iree/build/executor.py
@@ -8,6 +8,7 @@
import concurrent.futures
import enum
+import json
import math
import multiprocessing
import os
@@ -128,6 +129,7 @@
self.failed_deps: set["BuildDependency"] = set()
self.stderr = stderr
self.reporter = reporter
+ self.metadata_lock = threading.RLock()
BuildContext("", self)
def check_path_not_exists(self, path: str, for_entity):
@@ -160,6 +162,7 @@
return existing
def write_status(self, message: str):
+ self.reporter.reset_display()
print(message, file=self.stderr)
def get_root(self, namespace: FileNamespace) -> Path:
@@ -294,6 +297,9 @@
self.future.set_result(self)
+BuildFileMetadata = dict[str, str | int | bool | float]
+
+
class BuildFile(BuildDependency):
"""Generated file in the build tree."""
@@ -322,6 +328,35 @@
path.parent.mkdir(parents=True, exist_ok=True)
return path
+ def access_metadata(
+ self,
+ mutation_callback: Callable[[BuildFileMetadata], bool] | None = None,
+ ) -> BuildFileMetadata:
+ """Accesses persistent metadata about the build file.
+
+ This is intended for the storage of small amounts of metadata relevant to the
+ build system for performing up-to-date checks and the like.
+
+ If a `mutation_callback=` is provided, it is invoked with the file's metadata
+ while a lock is held, which avoids data races. Any modifications it makes are
+ persisted prior to returning only if the callback returns True.
+ """
+ with self.executor.metadata_lock:
+ metadata = _load_metadata(self.executor)
+ path_metadata = metadata.get("paths")
+ if path_metadata is None:
+ path_metadata = {}
+ metadata["paths"] = path_metadata
+ file_key = f"{self.namespace}/{self.path}"
+ file_metadata = path_metadata.get(file_key)
+ if file_metadata is None:
+ file_metadata = {}
+ path_metadata[file_key] = file_metadata
+ if mutation_callback:
+ if mutation_callback(file_metadata):
+ _save_metadata(self.executor, metadata)
+ return file_metadata
+
def __repr__(self):
return f"BuildFile[{self.namespace}]({self.path})"
@@ -658,3 +693,20 @@
# Type aliases.
BuildFileLike = BuildFile | str
+
+# Private utilities.
+_METADATA_FILENAME = ".metadata.json"
+
+
+def _load_metadata(executor: Executor) -> dict:
+ path = executor.output_dir / _METADATA_FILENAME
+ if not path.exists():
+ return {}
+ with open(path, "rb") as f:
+ return json.load(f)
+
+
+def _save_metadata(executor: Executor, metadata: dict):
+ path = executor.output_dir / _METADATA_FILENAME
+ with open(path, "wt") as f:
+ json.dump(metadata, f, sort_keys=True, indent=2)
diff --git a/compiler/bindings/python/iree/build/net_actions.py b/compiler/bindings/python/iree/build/net_actions.py
index da74d9a..7a262bf 100644
--- a/compiler/bindings/python/iree/build/net_actions.py
+++ b/compiler/bindings/python/iree/build/net_actions.py
@@ -7,7 +7,7 @@
import urllib.error
import urllib.request
-from iree.build.executor import BuildAction, BuildContext, BuildFile
+from iree.build.executor import BuildAction, BuildContext, BuildFile, BuildFileMetadata
__all__ = [
"fetch_http",
@@ -29,11 +29,49 @@
super().__init__(**kwargs)
self.url = url
self.output_file = output_file
+ self.original_desc = self.desc
def _invoke(self):
+ # Determine whether metadata indicates that fetch is needed.
path = self.output_file.get_fs_path()
+ needs_fetch = False
+ existing_metadata = self.output_file.access_metadata()
+ existing_url = existing_metadata.get("fetch_http.url")
+ if existing_url != self.url:
+ needs_fetch = True
+
+ # Always fetch if empty or absent.
+ if not path.exists() or path.stat().st_size == 0:
+ needs_fetch = True
+
+ # Bail if already obtained.
+ if not needs_fetch:
+ return
+
+ # Download to a staging file.
+ stage_path = path.with_name(f".{path.name}.download")
self.executor.write_status(f"Fetching URL: {self.url} -> {path}")
+
+ def reporthook(received_blocks: int, block_size: int, total_size: int):
+ received_size = received_blocks * block_size
+ if total_size == 0:
+ self.desc = f"{self.original_desc} ({received_size} bytes received)"
+ else:
+ complete_percent = round(100 * received_size / total_size)
+ self.desc = f"{self.original_desc} ({complete_percent}% complete)"
+
try:
- urllib.request.urlretrieve(self.url, str(path))
+ urllib.request.urlretrieve(self.url, str(stage_path), reporthook=reporthook)
except urllib.error.HTTPError as e:
raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None
+ finally:
+ self.desc = self.original_desc
+
+ # Commit the download.
+ def commit(metadata: BuildFileMetadata) -> bool:
+ metadata["fetch_http.url"] = self.url
+ path.unlink(missing_ok=True)
+ stage_path.rename(path)
+ return True
+
+ self.output_file.access_metadata(commit)
diff --git a/compiler/bindings/python/test/build_api/CMakeLists.txt b/compiler/bindings/python/test/build_api/CMakeLists.txt
index 6dfcc38..6b9916c 100644
--- a/compiler/bindings/python/test/build_api/CMakeLists.txt
+++ b/compiler/bindings/python/test/build_api/CMakeLists.txt
@@ -20,3 +20,10 @@
SRCS
"basic_test.py"
)
+
+iree_py_test(
+ NAME
+ net_test
+ SRCS
+ "net_test.py"
+)
diff --git a/compiler/bindings/python/test/build_api/mnist_builder_test.py b/compiler/bindings/python/test/build_api/mnist_builder_test.py
index 7b1f641..60f750c 100644
--- a/compiler/bindings/python/test/build_api/mnist_builder_test.py
+++ b/compiler/bindings/python/test/build_api/mnist_builder_test.py
@@ -90,10 +90,7 @@
mod = load_build_module(THIS_DIR / "mnist_builder.py")
out_file = io.StringIO()
err_file = io.StringIO()
- with self.assertRaisesRegex(
- IOError,
- re.escape("Failed to fetch URL 'https://github.com/iree-org/doesnotexist'"),
- ):
+ with self.assertRaises(SystemExit):
iree_build_main(
mod,
args=[
@@ -104,6 +101,7 @@
stdout=out_file,
stderr=err_file,
)
+ self.assertIn("ERROR:", err_file.getvalue())
def testBuildNonDefaultSubTarget(self):
mod = load_build_module(THIS_DIR / "mnist_builder.py")
diff --git a/compiler/bindings/python/test/build_api/net_test.py b/compiler/bindings/python/test/build_api/net_test.py
new file mode 100644
index 0000000..6e10c7b
--- /dev/null
+++ b/compiler/bindings/python/test/build_api/net_test.py
@@ -0,0 +1,100 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import io
+import os
+from pathlib import Path
+import tempfile
+import unittest
+
+from iree.build import *
+from iree.build.executor import BuildContext
+from iree.build.test_actions import ExecuteOutOfProcessThunkAction
+
+
+TEST_URL = None
+TEST_URL_1 = "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json"
+TEST_URL_2 = "https://huggingface.co/google-bert/bert-base-cased/resolve/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer_config.json"
+
+
+@entrypoint
+def tokenizer_via_http():
+ return fetch_http(
+ name="tokenizer.json",
+ url=TEST_URL,
+ )
+
+
+class BasicTest(unittest.TestCase):
+ def setUp(self):
+ self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
+ self._temp_dir.__enter__()
+ self.output_path = Path(self._temp_dir.name)
+
+ def tearDown(self) -> None:
+ self._temp_dir.__exit__(None, None, None)
+
+ def test_fetch_http(self):
+ # Runs the build with `--test-force-console` and checks that fetch_http fetches
+ # on first use, skips when current, and refetches on URL change or file removal.
+ out = None
+ err = None
+ global TEST_URL
+ path = self.output_path / "genfiles" / "tokenizer_via_http" / "tokenizer.json"
+
+ def run():
+ nonlocal out
+ nonlocal err
+ try:
+ out_io = io.StringIO()
+ err_io = io.StringIO()
+ iree_build_main(
+ args=[
+ "tokenizer_via_http",
+ "--output-dir",
+ str(self.output_path),
+ "--test-force-console",
+ ],
+ stderr=err_io,
+ stdout=out_io,
+ )
+ finally:
+ out = out_io.getvalue()
+ err = err_io.getvalue()
+ print(f"::test_fetch_http err: {err!r}")
+ print(f"::test_fetch_http out: {out!r}")
+
+ def assertExists():
+ self.assertTrue(path.exists(), msg=f"Path {path} exists")
+
+ # First run should fetch.
+ TEST_URL = TEST_URL_1
+ run()
+ self.assertIn("Fetching URL: https://", err)
+ assertExists()
+
+ # Second run should not fetch.
+ TEST_URL = TEST_URL_1
+ run()
+ self.assertNotIn("Fetching URL: https://", err)
+ assertExists()
+
+ # Fetching a different URL should download again.
+ TEST_URL = TEST_URL_2
+ run()
+ self.assertIn("Fetching URL: https://", err)
+ assertExists()
+
+ # Removing the file should fetch again.
+ TEST_URL = TEST_URL_2
+ path.unlink()
+ run()
+ self.assertIn("Fetching URL: https://", err)
+ assertExists()
+
+
+if __name__ == "__main__":
+ unittest.main()