blob: 8419f740a3b723bcb9423e2f4004c27ed7bd72c9 [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Download IREE host compiler from the snapshot release."""
import errno
import os
import sys
import shutil
import subprocess
import tarfile
import time
import argparse
import urllib
from pathlib import Path
import requests
import wget
def download_artifact(assets, keywords, out_dir, num_retries=3,
                      retry_delay_sec=5):
    """Download the first asset whose name contains every keyword.

    Args:
        assets: Iterable of release asset entries. Each entry is indexed by
            "name" and "browser_download_url".
        keywords: Substrings that must all occur in the asset name.
        out_dir: Directory to store the artifact in. Created when missing.
        num_retries: Number of additional attempts after the first failure.
        retry_delay_sec: Seconds to sleep between attempts.

    Returns:
        The path of the downloaded file: out_dir joined with the asset name.

    Raises:
        SystemExit: With code 1 when no asset name matches the keywords.
        urllib.error.URLError, ConnectionError: When the download still
            fails on the last attempt.
    """
    # `import urllib` on its own does not guarantee that the `error`
    # submodule is loaded, so import it explicitly before naming it below.
    import urllib.error

    asset = next(
        (a for a in assets if all(k in a["name"] for k in keywords)), None)
    if asset is None:
        print(f"{keywords[0]} is not found")
        sys.exit(1)

    artifact_name = asset["name"]
    download_url = asset["browser_download_url"]
    print(f"\nDownload {artifact_name} from {download_url}\n")

    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, artifact_name)
    # wget does not overwrite an existing file; it saves under a suffixed
    # name instead. Drop any stale copy so the fresh download lands exactly
    # at out_file, which is the path returned to the caller.
    if os.path.exists(out_file):
        os.remove(out_file)

    for attempt in range(num_retries + 1):
        try:
            wget.download(download_url, out=out_file)
            break
        # URLError is the base class of HTTPError, so HTTP failures are still
        # retried, and so are lower-level failures such as DNS errors.
        except (urllib.error.URLError, ConnectionError) as e:
            if attempt == num_retries:
                raise
            print(f"{e}\nDownload failed. Retrying...")
            time.sleep(retry_delay_sec)
    return out_file
def _read_tag_lines(tag_file):
    """Return the lines of tag_file without newlines, or [] if it is absent."""
    if not os.path.isfile(tag_file):
        return []
    with open(tag_file, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def _fetch_snapshot(release_url, tag_name):
    """Query the release API and return the JSON of the release to install.

    With a tag_name, the release at "<release_url>/tags/<tag_name>" is
    returned. Otherwise the first entry of the listing at release_url is
    returned. Exits with code 1 on a non-200 response or an empty listing.
    """
    # NOTE(review): these placeholder basic-auth credentials are carried over
    # unchanged; confirm whether the API call actually needs them.
    auth = ("user", "pass")
    if tag_name:
        r = requests.get(f"{release_url}/tags/{tag_name}", auth=auth,
                         timeout=60)
        if r.status_code != 200:
            print(
                f"!!!!!IREE snapshot can't be found with tag {tag_name}, "
                "please try a different tag!!!!!")
            sys.exit(1)
        return r.json()

    r = requests.get(release_url, auth=auth, timeout=60)
    if r.status_code != 200:
        print("Not getting the right snapshot information. "
              f"Status code: {r.status_code}")
        sys.exit(1)
    releases = r.json()
    if not releases:
        print("No IREE snapshot release is available")
        sys.exit(1)
    return releases[0]


def _install_compiler(iree_compiler_dir, whl_file, tar_file):
    """Install the TFLite tools wheel and unpack the compiler tarball."""
    # An argument list avoids shell quoting problems with paths, and
    # check=True surfaces a failed install instead of ignoring it.
    pip_cmd = ["pip3", "install", f"--target={iree_compiler_dir}",
               str(whl_file), "--upgrade", "--no-cache-dir"]
    try:
        subprocess.run(pip_cmd, check=True)
    except subprocess.CalledProcessError as e:
        print("Failed to install the IREE TFLite tools")
        sys.exit(e.returncode)

    # Extract the tarball to ${iree_compiler_dir}/install
    install_dir = iree_compiler_dir / "install"
    install_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): extractall trusts the member paths of the downloaded
    # archive; consider a tarfile extraction filter where the Python version
    # in use supports one.
    with tarfile.open(tar_file) as tar:
        tar.extractall(path=install_dir)

    # Place the pip-installed importer next to the extracted binaries.
    importer_src = iree_compiler_dir / "bin" / "iree-import-tflite"
    importer_dst = install_dir / "bin" / "iree-import-tflite"
    try:
        shutil.copy2(importer_src, importer_dst, follow_symlinks=True)
    except OSError as e:
        if e.errno == errno.EEXIST:
            os.remove(importer_dst)
            shutil.copy2(importer_src, importer_dst, follow_symlinks=True)
        else:
            # Still non-fatal, but no longer silent.
            print(f"Warning: failed to copy iree-import-tflite: {e}",
                  file=sys.stderr)


def main():
    """Download IREE host compiler from the snapshot release.

    Steps, in order:
      1. Exit 0 right away when PIN_TOOLCHAINS contains "iree".
      2. Exit 0 when the tag file in --iree_compiler_dir already mentions the
         HEAD commit of ${ROOTDIR}/toolchain/iree.
      3. Fetch the snapshot metadata and exit 0 when its tag name is already
         recorded in the tag file.
      4. Download the wheel and tarball into ${OUT}/tmp, install them, delete
         the downloads and rewrite the tag file.

    Exits non-zero when ROOTDIR or OUT is unset, when git or pip fails, or
    when the release query does not succeed.
    """
    pin_toolchains = os.getenv("PIN_TOOLCHAINS", "").lower().split()
    if "iree" in pin_toolchains:
        print()
        print("****************************************************")
        print("* *")
        print("* PIN_TOOLCHAINS includes iree! Skipping the *")
        print("* download of the latest IREE compiler binaries. *")
        print("* Please DO NOT file bugs for IREE mis-behavior! *")
        print("* *")
        print("****************************************************")
        print()
        sys.exit(0)

    parser = argparse.ArgumentParser(
        description="Download IREE host compiler from snapshot releases")
    parser.add_argument(
        "--tag_name", action="store", default="",
        help="snapshot tag to download. If not set, download the latest")
    parser.add_argument(
        "--release_url", action="store",
        default="https://api.github.com/repos/google/iree/releases",
        help=("URL to check the IREE release."
              "(default: https://api.github.com/repos/google/iree/releases)")
    )
    parser.add_argument(
        "--iree_compiler_dir", action="store", required=True,
        default="",
        help=("IREE compiler installed directory")
    )
    args = parser.parse_args()

    # Check if the IREE runtime lib is in sync with the tag
    root_dir = os.getenv("ROOTDIR")
    if not root_dir:
        # Without this check the path below would start with "None/".
        print("ROOTDIR environment variable is not set")
        sys.exit(1)
    cmd = ["git", "-C", f"{root_dir}/toolchain/iree", "rev-parse", "HEAD"]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
    except subprocess.CalledProcessError as e:
        print("Failed to check IREE runtime version")
        sys.exit(e.returncode)
    iree_runtime_commit = result.stdout.decode("utf-8").strip()

    iree_compiler_dir = Path(args.iree_compiler_dir)
    tag_file = iree_compiler_dir / "tag"
    # Nothing modifies the tag file before the final write, so one read
    # serves both the commit check and the tag check.
    tag_lines = _read_tag_lines(tag_file)
    if iree_runtime_commit and any(
            iree_runtime_commit in line for line in tag_lines):
        print("Compiler version matches runtime. Skip download")
        sys.exit(0)

    snapshot = _fetch_snapshot(args.release_url, args.tag_name)
    tag_name = snapshot["tag_name"]
    commit_sha = snapshot["target_commitish"]
    print(f"Snapshot: {tag_name}")

    # Check the tag of the existing download.
    if tag_name in tag_lines:
        print("IREE compiler is up-to-date")
        sys.exit(0)

    out_root = os.getenv("OUT")
    if not out_root:
        print("OUT environment variable is not set")
        sys.exit(1)
    tmp_dir = Path(out_root) / "tmp"

    downloads = []
    try:
        whl_file = download_artifact(snapshot["assets"],
                                     ["iree_tools_tflite", ".whl"], tmp_dir)
        downloads.append(whl_file)
        tar_file = download_artifact(
            snapshot["assets"], ["linux-x86_64.tar"], tmp_dir)
        downloads.append(tar_file)
        _install_compiler(iree_compiler_dir, whl_file, tar_file)
    finally:
        # Remove the downloads whether or not the installation succeeded.
        for path in downloads:
            if os.path.exists(path):
                os.remove(path)

    print("\nIREE compiler is installed")
    # Add tag file for future checks
    with open(tag_file, "w", encoding="utf-8") as f:
        f.write(f"{tag_name}\ncommit_sha: {commit_sha}\n")
# Run the downloader only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()