Merge "Add tempita to python packages."
diff --git a/download_iree_compiler.py b/download_iree_compiler.py
index dba4705..ab5939f 100755
--- a/download_iree_compiler.py
+++ b/download_iree_compiler.py
@@ -8,106 +8,129 @@
 import requests
 import wget
 
-iree_compiler_dir = os.getenv("IREE_COMPILER_DIR")
-if not iree_compiler_dir:
-    print("Please run 'source build/setup.sh' first")
-    sys.exit(-1)
+from pathlib import Path
 
-parser = argparse.ArgumentParser(
-    description="Download IREE host compiler from snapshot releases")
-parser.add_argument(
-    "--tag_name", action="store", default="",
-    help="snapshot tag to download. If not set, download the latest")
-args = parser.parse_args()
-r = requests.get(
-    "https://api.github.com/repos/google/iree/releases?per_page=60", auth=(
-        'user', 'pass'))
 
-if r.status_code != 200:
-    print("Not getting the right snapshot information. Status code: %d",
-          r.status_code)
-    sys.exit(-1)
-
-TAG_FOUND = False
-if args.tag_name:
-    for x in r.json():
-        if x["tag_name"] == args.tag_name:
-            TAG_FOUND = True
-            snapshot = x
+def download_artifact(assets, keywords, out_dir):
+    """Download the first asset matching all keywords and return its path."""
+    # Select the first asset whose name contains every keyword.
+    artifact_match = False
+    for asset in assets:
+        download_url = asset["browser_download_url"]
+        artifact_name = asset["name"]
+        if all(x in artifact_name for x in keywords):
+            artifact_match = True
             break
-else:
-    TAG_FOUND = True
-    snapshot = r.json()[0]
+    if not artifact_match:
+        print("%s is not found" % ", ".join(keywords))
+        sys.exit(1)
 
-if not TAG_FOUND:
-    print("!!!!!IREE snapshot can't be found with tag %s, please try a "
-          "different tag!!!!!" % args.tag_name)
-    sys.exit(-1)
+    print("\nDownload %s from %s\n" % (artifact_name, download_url))
+    # Also create missing parents (e.g. $OUT); an existing dir is fine.
+    os.makedirs(out_dir, exist_ok=True)
+    out_file = os.path.join(out_dir, artifact_name)
+    wget.download(download_url, out=out_file)
+    return out_file
 
-tag_name = snapshot["tag_name"]
-commit_sha = snapshot["target_commitish"]
 
-print("Snapshot: %s" % tag_name)
+def main():
+    """Download the IREE host compiler from a snapshot release."""
+    iree_compiler_dir = os.getenv("IREE_COMPILER_DIR")
+    if not iree_compiler_dir:
+        print("Please run 'source build/setup.sh' first")
+        sys.exit(1)
+    iree_compiler_dir = Path(iree_compiler_dir)
 
-tag_file = os.path.join(iree_compiler_dir, "tag")
+    parser = argparse.ArgumentParser(
+        description="Download IREE host compiler from snapshot releases")
+    parser.add_argument(
+        "--tag_name", action="store", default="",
+        help="snapshot tag to download. If not set, download the latest")
+    parser.add_argument(
+        "--release_url", action="store",
+        default="https://api.github.com/repos/google/iree/releases",
+        help=("URL to check the IREE release. "
+              "(default: https://api.github.com/repos/google/iree/releases)")
+    )
+    parser.add_argument(
+        "--depth", action="store", type=int, default=60,
+        help=("Depth of the release history to search the snapshot for "
+              "(max 100). (default: 60)")
+    )
+    args = parser.parse_args()
+    r = requests.get(args.release_url, params={"per_page": args.depth},
+                     auth=('user', 'pass'), timeout=60)
 
-# Check the tag of the existing download.
-TAG_MATCH = False
-if os.path.isfile(tag_file):
-    file = open(tag_file, "r")
-    for line in file:
-        if tag_name == line.replace("\n", ""):
-            TAG_MATCH = True
-            file.close()
-            break
-    file.close()
+    if r.status_code != 200:
+        print("Not getting the right snapshot information. Status code: %d" %
+              r.status_code)
+        sys.exit(1)
 
-if TAG_MATCH:
-    print("IREE compiler is up-to-date")
-    sys.exit(0)
+    releases = r.json()
+    snapshot = None
+    if args.tag_name:
+        for x in releases:
+            if x["tag_name"] == args.tag_name:
+                snapshot = x
+                break
+    elif releases:
+        # No tag requested: take the first (presumably newest) release.
+        snapshot = releases[0]
+    tag_found = snapshot is not None
 
-# Install IREE TFLite tool
-# Python whl version can be found in tag_name as "snapshot-<version>"
-version=tag_name[9:]
-cmd = ("pip3 install iree-tools-tflite-snapshot==%s -f "
-       "https://github.com/google/iree/releases/ --no-cache-dir" % version)
-os.system(cmd)
+    if not tag_found:
+        print("!!!!!IREE snapshot can't be found with tag %s, please try a "
+              "different tag!!!!!" % args.tag_name)
+        sys.exit(1)
 
-# Find the linux tarball and download it.
-TAR_MATCH = False
-for asset in snapshot["assets"]:
-    download_url = asset["browser_download_url"]
-    tar_name = asset["name"]
-    if "linux-x86_64.tar" in tar_name:
-        TAR_MATCH = True
-        break
+    tag_name = snapshot["tag_name"]
+    commit_sha = snapshot["target_commitish"]
 
-if not TAR_MATCH:
-    print("linux-x86_64 tarball is not found")
-    sys.exit(-1)
+    print("Snapshot: %s" % tag_name)
 
-print("Download %s from %s" % (tar_name, download_url))
+    tag_file = iree_compiler_dir / "tag"
 
-tmp_dir = os.path.join(os.getenv("OUT"), "tmp")
+    # Check the tag of the existing download.
+    tag_match = False
+    if os.path.isfile(tag_file):
+        with open(tag_file, 'r') as f:
+            for line in f:
+                if tag_name == line.replace("\n", ""):
+                    tag_match = True
+                    break
 
-if not os.path.isdir(tmp_dir):
-    os.mkdir(tmp_dir)
+    if tag_match:
+        print("IREE compiler is up-to-date")
+        sys.exit(0)
 
-tar_file = os.path.join(tmp_dir, tar_name)
-wget.download(download_url, out=tar_file)
+    tmp_dir = Path(os.getenv("OUT")) / "tmp"
+    whl_file = download_artifact(snapshot["assets"],
+                                 ["iree_tools_tflite", "linux", "x86_64.whl"],
+                                 tmp_dir)
+    tar_file = download_artifact(
+        snapshot["assets"], ["linux-x86_64.tar"], tmp_dir)
 
-# Extract the tarball to ${iree_compiler_dir}/install
-install_dir = os.path.join(iree_compiler_dir, "install")
-if not install_dir:
-    os.mkdir(install_dir)
+    # Install IREE TFLite tool; stop before writing the tag file on failure.
+    if os.system("pip3 install %s --no-cache-dir" % whl_file) != 0:
+        sys.exit("Failed to install %s" % whl_file)
 
-tar = tarfile.open(tar_file)
-tar.extractall(path=install_dir)
-tar.close()
+    # Extract the tarball to ${iree_compiler_dir}/install
+    install_dir = iree_compiler_dir / "install"
+    if not install_dir.is_dir():
+        install_dir.mkdir(parents=True)
 
-os.remove(tar_file)
-print("\nIREE compiler is installed")
+    # Close the archive even if extraction raises.
+    with tarfile.open(tar_file) as tar:
+        tar.extractall(path=install_dir)
 
-# Add tag file for future checks
-with open(tag_file, "w") as f:
-    f.write("%s\ncommit_sha: %s\n" % (tag_name, commit_sha))
+    os.remove(tar_file)
+    os.remove(whl_file)
+    print("\nIREE compiler is installed")
+
+    # Add tag file for future checks
+    with open(tag_file, "w") as f:
+        f.write("%s\ncommit_sha: %s\n" % (tag_name, commit_sha))
+
+
+if __name__ == "__main__":
+    main()