blob: 2af25fb0c4093b9a2c97e7394097c2586983d486 [file] [log] [blame]
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This tool mirrors tflite test files from their upstream sources to the
# iree-model-artifacts test bucket. This avoids taking a dependency on
# external test files that may change or become unavailable.
#
# To update all files:
# python update_tflite_models.py --file all
#
# To update a specific file:
# python update_tflite_models.py --file posenet_i8_input.jpg
#
# Note: you must have write permission to the iree-model-artifacts GCS bucket
# and be authenticated locally with gcloud.
import tempfile
import urllib
import urllib.request

from absl import app
from absl import flags
from google.cloud import storage
from google_auth_oauthlib import flow
FLAGS = flags.FLAGS
# --file selects a single entry of file_dict to mirror; the special value
# "all" mirrors every entry (see main()).
flags.DEFINE_string(
    "file",
    "",
    'File to update (e.g. "posenet_i8.tflite"), or "all" to update every file.',
)
# Maps each mirrored file name (uploaded as "<FOLDER_NAME>/<name>" in the
# bucket) to the upstream URL it is downloaded from.
file_dict = {
    "mobilenet_v1.tflite": "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite",
    "posenet_i8.tflite": "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite",
    "posenet_i8_input.jpg": "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg",
}

# GCS bucket that hosts the mirrored test artifacts.
BUCKET_NAME = "iree-model-artifacts"
# Object-name prefix ("folder") inside BUCKET_NAME under which files are stored.
FOLDER_NAME = "tflite-integration-tests"
def upload_model(source, destination, tmpfile, bucket_name=None, folder_name=None):
    """Downloads `source` and uploads it to the GCS mirror bucket.

    Args:
      source: URL to download the file from.
      destination: Object name, relative to the folder, to upload the file as.
      tmpfile: Local path used as scratch space for the download; any existing
        contents are overwritten.
      bucket_name: GCS bucket to upload into. Defaults to BUCKET_NAME.
      folder_name: Object-name prefix inside the bucket. Defaults to
        FOLDER_NAME.
    """
    # Resolve defaults at call time so the module constants stay the single
    # source of truth and existing three-argument callers are unaffected.
    if bucket_name is None:
        bucket_name = BUCKET_NAME
    if folder_name is None:
        folder_name = FOLDER_NAME

    urllib.request.urlretrieve(source, tmpfile)

    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)
    blob = bucket.blob("/".join([folder_name, destination]))
    blob.upload_from_filename(tmpfile)
def main(argv):
    """Mirrors the file(s) selected by --file into the GCS test bucket.

    Args:
      argv: Command-line arguments remaining after absl flag parsing (unused).

    Raises:
      app.UsageError: If --file is neither a key of file_dict nor "all".
    """
    del argv  # Unused.

    if FLAGS.file in file_dict:
        items = [(FLAGS.file, file_dict[FLAGS.file])]
    elif FLAGS.file == "all":
        items = list(file_dict.items())
    else:
        # UsageError makes absl print the usage text and exit with a non-zero
        # status, so callers (shells, CI) can detect the bad invocation.
        choices = ", ".join(f'"{name}"' for name in file_dict)
        raise app.UsageError(
            f'Unknown file to upload: "{FLAGS.file}". '
            f'Expected one of {choices}, or "all".'
        )

    # One scratch file is reused for every download; the context manager
    # closes and deletes it even if a download or upload fails partway.
    with tempfile.NamedTemporaryFile() as scratch:
        for dst, src in items:
            upload_model(src, dst, scratch.name)
if __name__ == "__main__":
    # Let absl parse the command-line flags, then invoke main() with the
    # remaining positional arguments.
    app.run(main)