| # Copyright 2021 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
# This tool handles mirroring tflite testing files from their source to the
# iree-model-artifacts test bucket. This avoids taking a dependency on
# external test files that may change or no longer be available.
| # |
| # To update all files: |
| # python update_tflite_models.py --file all |
| # |
| # To update a specific file: |
| # python update_tflite_models.py --file posenet_i8_input.jpg |
| # |
| # Note you must have write permission to the iree-model-artifacts GCS bucket |
| # with local gcloud authentication. |
| |
import sys
import tempfile
import urllib
import urllib.request

from absl import app
from absl import flags
from google.cloud import storage
from google_auth_oauthlib import flow
| |
# Command-line flags. --file is compared in main() against the keys of
# file_dict and against the literal "all"; it defaults to the empty string.
FLAGS = flags.FLAGS
flags.DEFINE_string(name="file", default="", help="file to update")
| |
# Maps the destination file name used inside the bucket folder to the
# upstream URL the file is downloaded from.
file_dict = {
    "mobilenet_v1.tflite": "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite",
    "posenet_i8.tflite": "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite",
    "posenet_i8_input.jpg": "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg",
}
| |
# GCS bucket that upload_model() looks up before uploading.
BUCKET_NAME = "iree-model-artifacts"
# Prefix joined with "/" in front of each destination name to form the blob
# name inside the bucket.
FOLDER_NAME = "tflite-integration-tests"
| |
| |
def upload_model(source, destination, tmpfile):
    """Download one file and upload it to the test bucket.

    The file at `source` is first fetched to the local path `tmpfile`, then
    uploaded to the bucket named by BUCKET_NAME as the blob
    "<FOLDER_NAME>/<destination>".

    Args:
        source: URL to download the file from.
        destination: File name to use for the blob inside FOLDER_NAME.
        tmpfile: Local filesystem path used as scratch space for the download;
            any existing content at that path is replaced by the download.
    """
    # Requires `import urllib.request` at module level: a bare `import urllib`
    # does not guarantee that the `request` submodule has been loaded.
    urllib.request.urlretrieve(source, tmpfile)

    client = storage.Client()
    bucket = client.get_bucket(BUCKET_NAME)
    blob_name = f"{FOLDER_NAME}/{destination}"
    bucket.blob(blob_name).upload_from_filename(tmpfile)
| |
| |
def main(argv):
    """Mirror the file(s) selected by --file into the test bucket.

    If FLAGS.file is a key of file_dict, only that file is uploaded. If it is
    the literal "all", every entry of file_dict is uploaded. Any other value
    (including the empty default) prints an error message and terminates the
    process with exit status 1.

    Args:
        argv: Positional command-line arguments passed in by absl; unused.

    Raises:
        SystemExit: If FLAGS.file is neither a known file name nor "all".
    """
    del argv  # Unused; the selection comes from FLAGS.file.

    requested = FLAGS.file
    if requested in file_dict:
        items = [(requested, file_dict[requested])]
    elif requested == "all":
        items = list(file_dict.items())
    else:
        print("Unknown file to upload: ", '"' + requested + '"')
        # Exit with a non-zero status so callers can detect the failure;
        # sys.exit is used because the bare exit() helper is provided by the
        # site module and is not guaranteed to exist.
        sys.exit(1)

    # A single scratch file is reused for every download. The context manager
    # closes it (which also removes it) even when an upload raises.
    with tempfile.NamedTemporaryFile() as scratch:
        for dst, src in items:
            upload_model(src, dst, scratch.name)
| |
| |
# Script entry point: only when this file is executed directly (not imported)
# is main handed to absl's app.run.
if __name__ == "__main__":
    app.run(main)