blob: e2ea88727bc41cd43613cae2d8d574352483b3d8 [file] [log] [blame]
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This tool mirrors TFLite test files from their original sources to the
# iree-model-artifacts test bucket. This avoids taking a dependency on
# external test files that may change or stop being available.
#
# To update all files:
# python update_tflite_models.py --file all
#
# To update a specific file:
# python update_tflite_models.py --file posenet_i8_input.jpg
#
# Note: you must have write permission to the iree-model-artifacts GCS bucket
# and be authenticated locally with gcloud.
import os
import sys
import tempfile
import urllib
import urllib.request

from absl import app
from absl import flags
from google.cloud import storage
from google_auth_oauthlib import flow
FLAGS = flags.FLAGS
# --file selects what to mirror: a single key of file_dict, or "all".
flags.DEFINE_string(
    'file', '',
    'Name of the file to mirror (e.g. posenet_i8_input.jpg), or "all" to '
    'mirror every known file.')
# Maps each object name (stored under FOLDER_NAME in BUCKET_NAME) to the URL
# it is mirrored from.
file_dict = {
    "mobilenet_v1.tflite":
        "https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_160/1/default/1?lite-format=tflite",
    # NOTE(review): despite the "posenet" name, this URL is the MoveNet
    # SinglePose Lightning int8 model. The key is the uploaded object name, so
    # it is presumably referenced by tests and is kept as-is.
    "posenet_i8.tflite":
        "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/int8/4?lite-format=tflite",
    "posenet_i8_input.jpg":
        "https://github.com/tensorflow/examples/raw/master/lite/examples/pose_estimation/raspberry_pi/test_data/image3.jpeg",
}
# Destination GCS bucket and the folder (object-name prefix) inside it.
BUCKET_NAME = "iree-model-artifacts"
FOLDER_NAME = "tflite-integration-tests"
def upload_model(source, destination, tmpfile):
    """Downloads `source` and uploads it to BUCKET_NAME/FOLDER_NAME/destination.

    Args:
      source: URL to download the file from.
      destination: Object name to create under FOLDER_NAME in BUCKET_NAME.
      tmpfile: Local path used as scratch space for the download. Its contents
        are overwritten.

    Exceptions raised by the download or the upload propagate to the caller.
    """
    urllib.request.urlretrieve(source, tmpfile)
    storage_client = storage.Client()
    # Client.bucket() only builds a local handle and makes no API request.
    # Client.get_bucket() would issue an extra GET on the bucket, which also
    # needs bucket-level read permission beyond object write access.
    bucket = storage_client.bucket(BUCKET_NAME)
    blob = bucket.blob("/".join([FOLDER_NAME, destination]))
    blob.upload_from_filename(tmpfile)
def main(argv):
    """Mirrors the file(s) selected by --file into the GCS bucket.

    Args:
      argv: Positional arguments left over after absl flag parsing (unused).

    Returns:
      1 if --file is neither a key of file_dict nor "all", otherwise None.
      absl's app.run passes this value to sys.exit, so an unknown file yields
      a failing exit status instead of a silent success.
    """
    del argv  # Unused.
    items = file_dict.items()
    if FLAGS.file in file_dict:
        items = [(FLAGS.file, file_dict[FLAGS.file])]
    elif FLAGS.file != "all":
        valid = ", ".join(sorted(file_dict))
        print('Unknown file to upload: "%s". Expected one of: %s, or "all".' %
              (FLAGS.file, valid),
              file=sys.stderr)
        return 1
    # Download into a fresh path inside a self-cleaning temporary directory.
    # Reopening a still-open NamedTemporaryFile by name is not supported on
    # all platforms (e.g. Windows), and the handle was never closed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        scratch_path = os.path.join(tmp_dir, "download")
        for dst, src in items:
            upload_model(src, dst, scratch_path)
if __name__ == "__main__":
    # app.run parses the absl flags (such as --file) and then invokes main.
    app.run(main)