| { | 
 |   "nbformat": 4, | 
 |   "nbformat_minor": 0, | 
 |   "metadata": { | 
 |     "colab": { | 
 |       "name": "tensorflow_hub_import.ipynb", | 
 |       "provenance": [], | 
 |       "collapsed_sections": [ | 
 |         "-V0X0E7LkEa4", | 
 |         "FH3IRpYTta2v" | 
 |       ] | 
 |     }, | 
 |     "kernelspec": { | 
 |       "name": "python3", | 
 |       "display_name": "Python 3" | 
 |     }, | 
 |     "language_info": { | 
 |       "name": "python" | 
 |     } | 
 |   }, | 
 |   "cells": [ | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "FH3IRpYTta2v" | 
 |       }, | 
 |       "source": [ | 
 |         "##### Copyright 2021 The IREE Authors" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "mWGa71_Ct2ug", | 
 |         "cellView": "form" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", | 
 |         "# See https://llvm.org/LICENSE.txt for license information.\n", | 
 |         "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" | 
 |       ], | 
 |       "execution_count": 1, | 
 |       "outputs": [] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "Qb3S0mSjpK7J" | 
 |       }, | 
 |       "source": [ | 
 |         "# IREE TensorFlow Hub Import\n", | 
 |         "\n", | 
 |         "This notebook demonstrates how to download, import, and compile models from [TensorFlow Hub](https://tfhub.dev/). It covers:\n", | 
 |         "\n", | 
 |         "* Downloading a model from TensorFlow Hub\n", | 
 |         "* Ensuring the model has serving signatures needed for import\n", | 
 |         "* Importing and compiling the model with IREE\n", | 
 |         "\n", | 
 |         "At the end of the notebook, the compilation artifacts are compressed into a .zip file for you to download and use in an application.\n", | 
 |         "\n", | 
 |         "See also https://openxla.github.io/iree/ml-frameworks/tensorflow/." | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "9rNAJKNVkKOr" | 
 |       }, | 
 |       "source": [ | 
 |         "## Setup" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "RdVc4TbOkHM2" | 
 |       }, | 
 |       "source": [ | 
 |         "%%capture\n", | 
 |         "!python -m pip install iree-compiler iree-runtime iree-tools-tf -f https://openxla.github.io/iree/pip-release-links.html" | 
 |       ], | 
 |       "execution_count": 2, | 
 |       "outputs": [] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "qRwv3qI_l5O_", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/" | 
 |         }, | 
 |         "outputId": "0e9e1cc3-c97f-4a5c-c980-a43897fc6703" | 
 |       }, | 
 |       "source": [ | 
 |         "import os\n", | 
 |         "import tensorflow as tf\n", | 
 |         "import tensorflow_hub as hub\n", | 
 |         "import tempfile\n", | 
 |         "from IPython.display import clear_output\n", | 
 |         "\n", | 
 |         "from iree.compiler import tf as tfc\n", | 
 |         "\n", | 
 |         "# Print version information for future notebook users to reference.\n", | 
 |         "print(\"TensorFlow version: \", tf.__version__)\n", | 
 |         "\n", | 
 |         "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", | 
 |         "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n", | 
 |         "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")" | 
 |       ], | 
 |       "execution_count": 3, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "TensorFlow version:  2.8.2\n", | 
 |             "Using artifacts directory '/tmp/iree/colab_artifacts'\n" | 
 |           ] | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "ZZAobcAhocFE" | 
 |       }, | 
 |       "source": [ | 
 |         "## Import pretrained [`mobilenet_v2`](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4) model\n", | 
 |         "\n", | 
 |         "IREE supports importing TensorFlow 2 models exported in the [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The model we'll be importing is already published in that format; other models may need to be converted first.\n", | 
 |         "\n", | 
 |         "MobileNet V2 is a family of neural network architectures for efficient on-device image classification and related tasks. This TensorFlow Hub module contains a trained instance of one particular network architecture packaged to perform image classification." | 
 |       ] | 
 |     }, | 
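 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": {}, | 
 |       "source": [ | 
 |         "As an aside, a model that is *not* already distributed as a SavedModel can be exported to the format with `tf.saved_model.save` before importing. The toy Keras model below is purely illustrative and not part of this workflow." | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": {}, | 
 |       "source": [ | 
 |         "#@title (Optional) Export an in-memory model to SavedModel format\n", | 
 |         "\n", | 
 |         "# Illustrative only: a toy Keras model exported to the SavedModel format.\n", | 
 |         "toy_model = tf.keras.Sequential([\n", | 
 |         "    tf.keras.layers.InputLayer(input_shape=(4,)),\n", | 
 |         "    tf.keras.layers.Dense(2),\n", | 
 |         "])\n", | 
 |         "toy_model_path = os.path.join(tempfile.gettempdir(), \"toy_saved_model\")\n", | 
 |         "tf.saved_model.save(toy_model, toy_model_path)\n", | 
 |         "print(f\"Saved toy model to '{toy_model_path}'\")" | 
 |       ], | 
 |       "execution_count": null, | 
 |       "outputs": [] | 
 |     }, | 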
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/" | 
 |         }, | 
 |         "id": "7fd0vmnloZo9", | 
 |         "outputId": "f3c075d8-0422-40b2-c9f8-bdfd865fd4c2" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Download the pretrained model\n", | 
 |         "\n", | 
 |         "# Use the `hub` library to download the pretrained model to the local disk\n", | 
 |         "# https://www.tensorflow.org/hub/api_docs/python/hub\n", | 
 |         "HUB_PATH = \"https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4\"\n", | 
 |         "model_path = hub.resolve(HUB_PATH)\n", | 
 |         "print(f\"Downloaded model from tfhub to path: '{model_path}'\")" | 
 |       ], | 
 |       "execution_count": 4, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "Downloaded model from tfhub to path: '/tmp/tfhub_modules/426589ad685896ab7954855255a52db3442cb38d'\n" | 
 |           ] | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "CedNRSQTOE7C" | 
 |       }, | 
 |       "source": [ | 
 |         "### Check for serving signatures and re-export as needed\n", | 
 |         "\n", | 
 |         "IREE's compiler tools, like TensorFlow's `saved_model_cli`, require \"serving signatures\" to be defined in SavedModels.\n", | 
 |         "\n", | 
 |         "More references:\n", | 
 |         "\n", | 
 |         "* https://www.tensorflow.org/tfx/serving/signature_defs\n", | 
 |         "* https://blog.tensorflow.org/2021/03/a-tour-of-savedmodel-signatures.html" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/" | 
 |         }, | 
 |         "id": "qiO66oEYQmsd", | 
 |         "outputId": "95950642-7225-4378-f3d6-ffeb8aedbcd3" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Check for serving signatures\n", | 
 |         "\n", | 
 |         "# Load the SavedModel from the local disk and check if it has serving signatures\n", | 
 |         "# https://www.tensorflow.org/guide/saved_model#loading_and_using_a_custom_model\n", | 
 |         "loaded_model = tf.saved_model.load(model_path)\n", | 
 |         "serving_signatures = list(loaded_model.signatures.keys())\n", | 
 |         "print(f\"Loaded SavedModel from '{model_path}'\")\n", | 
 |         "print(f\"Serving signatures: {serving_signatures}\")\n", | 
 |         "\n", | 
 |         "# Also check with the saved_model_cli:\n", | 
 |         "print(\"\\n---\\n\")\n", | 
 |         "print(\"Checking for signature_defs using saved_model_cli:\\n\")\n", | 
 |         "!saved_model_cli show --dir {model_path} --tag_set serve --signature_def serving_default" | 
 |       ], | 
 |       "execution_count": 5, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "Loaded SavedModel from '/tmp/tfhub_modules/426589ad685896ab7954855255a52db3442cb38d'\n", | 
 |             "Serving signatures: []\n", | 
 |             "\n", | 
 |             "---\n", | 
 |             "\n", | 
 |             "Checking for signature_defs using saved_model_cli:\n", | 
 |             "\n", | 
 |             "Traceback (most recent call last):\n", | 
 |             "  File \"/usr/local/bin/saved_model_cli\", line 8, in <module>\n", | 
 |             "    sys.exit(main())\n", | 
 |             "  File \"/usr/local/lib/python3.7/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 1260, in main\n", | 
 |             "    args.func(args)\n", | 
 |             "  File \"/usr/local/lib/python3.7/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 745, in show\n", | 
 |             "    _show_inputs_outputs(args.dir, args.tag_set, args.signature_def)\n", | 
 |             "  File \"/usr/local/lib/python3.7/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 154, in _show_inputs_outputs\n", | 
 |             "    meta_graph_def, signature_def_key)\n", | 
 |             "  File \"/usr/local/lib/python3.7/dist-packages/tensorflow/python/tools/saved_model_cli.py\", line 115, in _get_inputs_tensor_info_from_meta_graph_def\n", | 
 |             "    f'Could not find signature \"{signature_def_key}\". Please choose from: '\n", | 
 |             "ValueError: Could not find signature \"serving_default\". Please choose from: __saved_model_init_op\n" | 
 |           ] | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "kKqqX2LsReNz" | 
 |       }, | 
 |       "source": [ | 
 |         "Since the model we downloaded did not include any serving signatures, we'll re-export it with serving signatures defined.\n", | 
 |         "\n", | 
 |         "* https://www.tensorflow.org/guide/saved_model#specifying_signatures_during_export" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/" | 
 |         }, | 
 |         "id": "OlDG2OuqOBGC", | 
 |         "outputId": "8296a409-c630-4d03-c81c-58aa95cc0f77" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Look up input signatures to use when exporting\n", | 
 |         "\n", | 
 |         "# To save serving signatures we need to specify a `ConcreteFunction` with a\n", | 
 |         "# TensorSpec signature. We can determine what this signature should be by\n", | 
 |         "# looking at any documentation for the model or running the saved_model_cli.\n", | 
 |         "\n", | 
 |         "!saved_model_cli show --dir {model_path} --all \\\n", | 
 |         "    2> /dev/null | grep \"inputs: TensorSpec\" | tail -n 1" | 
 |       ], | 
 |       "execution_count": 6, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "          inputs: TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='inputs')\n" | 
 |           ] | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/" | 
 |         }, | 
 |         "id": "gnb4HhMmkgiT", | 
 |         "outputId": "f8cf1fe0-0bc4-4c2f-9622-325c15cb923c" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Re-export the model using the known signature\n", | 
 |         "\n", | 
 |         "# Get a concrete function using the signature we found above.\n", | 
 |         "# \n", | 
 |         "# The first element of the shape is a dynamic batch size. We'll be running\n", | 
 |         "# inference on a single image at a time, so set it to `1`. The rest of the\n", | 
 |         "# shape is the fixed image dimensions [height=224, width=224, channels=3].\n", | 
 |         "call = loaded_model.__call__.get_concrete_function(tf.TensorSpec([1, 224, 224, 3], tf.float32))\n", | 
 |         "\n", | 
 |         "# Save the model, setting the concrete function as a serving signature.\n", | 
 |         "# https://www.tensorflow.org/guide/saved_model#saving_a_custom_model\n", | 
 |         "resaved_model_path = '/tmp/resaved_model'\n", | 
 |         "tf.saved_model.save(loaded_model, resaved_model_path, signatures=call)\n", | 
 |         "clear_output()  # Skip over TensorFlow's output.\n", | 
 |         "print(f\"Saved model with serving signatures to '{resaved_model_path}'\")\n", | 
 |         "\n", | 
 |         "# Load the model back into memory and check that it has serving signatures now\n", | 
 |         "reloaded_model = tf.saved_model.load(resaved_model_path)\n", | 
 |         "reloaded_serving_signatures = list(reloaded_model.signatures.keys())\n", | 
 |         "print(f\"\\nReloaded SavedModel from '{resaved_model_path}'\")\n", | 
 |         "print(f\"Serving signatures: {reloaded_serving_signatures}\")\n", | 
 |         "\n", | 
 |         "# Also check with the saved_model_cli:\n", | 
 |         "print(\"\\n---\\n\")\n", | 
 |         "print(\"Checking for signature_defs using saved_model_cli:\\n\")\n", | 
 |         "!saved_model_cli show --dir {resaved_model_path} --tag_set serve --signature_def serving_default" | 
 |       ], | 
 |       "execution_count": 7, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "Saved model with serving signatures to '/tmp/resaved_model'\n", | 
 |             "\n", | 
 |             "Reloaded SavedModel from '/tmp/resaved_model'\n", | 
 |             "Serving signatures: ['serving_default']\n", | 
 |             "\n", | 
 |             "---\n", | 
 |             "\n", | 
 |             "Checking for signature_defs using saved_model_cli:\n", | 
 |             "\n", | 
 |             "The given SavedModel SignatureDef contains the following input(s):\n", | 
 |             "  inputs['inputs'] tensor_info:\n", | 
 |             "      dtype: DT_FLOAT\n", | 
 |             "      shape: (1, 224, 224, 3)\n", | 
 |             "      name: serving_default_inputs:0\n", | 
 |             "The given SavedModel SignatureDef contains the following output(s):\n", | 
 |             "  outputs['output_0'] tensor_info:\n", | 
 |             "      dtype: DT_FLOAT\n", | 
 |             "      shape: (1, 1001)\n", | 
 |             "      name: StatefulPartitionedCall:0\n", | 
 |             "Method name is: tensorflow/serving/predict\n" | 
 |           ] | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "YdmgASzwanSz" | 
 |       }, | 
 |       "source": [ | 
 |         "### Import and compile the SavedModel with IREE" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/" | 
 |         }, | 
 |         "id": "GLkjlHE5mdmg", | 
 |         "outputId": "f6a7718a-456b-4eb1-ea0f-3465e658f3c9" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Import from SavedModel\n", | 
 |         "\n", | 
 |         "# The main output file from compilation is a .vmfb \"VM FlatBuffer\". This file\n", | 
 |         "# can be used to run the compiled model with IREE's runtime.\n", | 
 |         "output_file = os.path.join(ARTIFACTS_DIR, \"mobilenet_v2.vmfb\")\n", | 
 |         "# As compilation runs, dump some intermediate .mlir files for future inspection.\n", | 
 |         "tf_input = os.path.join(ARTIFACTS_DIR, \"mobilenet_v2_tf_input.mlir\")\n", | 
 |         "iree_input = os.path.join(ARTIFACTS_DIR, \"mobilenet_v2_iree_input.mlir\")\n", | 
 |         "\n", | 
 |         "# Since our SavedModel uses signature defs, we use `saved_model_tags` with\n", | 
 |         "# `import_type=\"SIGNATURE_DEF\"`. If the SavedModel used an object graph, we\n", | 
 |         "# would use `exported_names` with `import_type=\"OBJECT_GRAPH\"` instead.\n", | 
 |         "\n", | 
 |         "# We'll set `target_backends=[\"vmvx\"]` to use IREE's reference CPU backend.\n", | 
 |         "# We could instead use different backends here, or set `import_only=True` then\n", | 
 |         "# download the imported .mlir file for compilation using native tools directly.\n", | 
 |         "\n", | 
 |         "tfc.compile_saved_model(\n", | 
 |         "    resaved_model_path,\n", | 
 |         "    output_file=output_file,\n", | 
 |         "    save_temp_tf_input=tf_input,\n", | 
 |         "    save_temp_iree_input=iree_input,\n", | 
 |         "    import_type=\"SIGNATURE_DEF\",\n", | 
 |         "    saved_model_tags=set([\"serve\"]),\n", | 
 |         "    target_backends=[\"vmvx\"])\n", | 
 |         "clear_output()  # Skip over TensorFlow's output.\n", | 
 |         "\n", | 
 |         "print(f\"Saved compiled output to '{output_file}'\")\n", | 
 |         "print(f\"Saved tf_input to        '{tf_input}'\")\n", | 
 |         "print(f\"Saved iree_input to      '{iree_input}'\")" | 
 |       ], | 
 |       "execution_count": 8, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "Saved compiled output to '/tmp/iree/colab_artifacts/mobilenet_v2.vmfb'\n", | 
 |             "Saved tf_input to        '/tmp/iree/colab_artifacts/mobilenet_v2_tf_input.mlir'\n", | 
 |             "Saved iree_input to      '/tmp/iree/colab_artifacts/mobilenet_v2_iree_input.mlir'\n" | 
 |           ] | 
 |         } | 
 |       ] | 
 |     }, | 
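 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": {}, | 
 |       "source": [ | 
 |         "As a quick sanity check, the compiled module can be loaded and run with IREE's Python runtime. This is a minimal sketch rather than part of the original workflow: the `local-task` driver name (IREE's CPU driver, paired here with the `vmvx` target backend) and the calling convention for the `serving_default` function are assumptions based on the signature exported above." | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": {}, | 
 |       "source": [ | 
 |         "#@title (Optional) Run the compiled module with IREE's runtime\n", | 
 |         "\n", | 
 |         "import numpy as np\n", | 
 |         "import iree.runtime as ireert\n", | 
 |         "\n", | 
 |         "# Load the compiled .vmfb. The `local-task` driver runs on the CPU and\n", | 
 |         "# pairs with the `vmvx` target backend used during compilation.\n", | 
 |         "vm_module = ireert.load_vm_flatbuffer_file(output_file, driver=\"local-task\")\n", | 
 |         "\n", | 
 |         "# Run inference on a dummy image matching the [1, 224, 224, 3] signature.\n", | 
 |         "dummy_input = np.zeros([1, 224, 224, 3], dtype=np.float32)\n", | 
 |         "result = vm_module.serving_default(dummy_input)\n", | 
 |         "\n", | 
 |         "# The MobileNet classifier produces [1, 1001] logits; depending on the\n", | 
 |         "# import, the result may be a device array or a dict of named outputs.\n", | 
 |         "print(result)" | 
 |       ], | 
 |       "execution_count": null, | 
 |       "outputs": [] | 
 |     }, | 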
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 104 | 
 |         }, | 
 |         "id": "IEJAzOb5qASI", | 
 |         "outputId": "d5a6ec5e-a583-47c9-b1eb-3bd81c02b50f" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Download compilation artifacts\n", | 
 |         "\n", | 
 |         "ARTIFACTS_ZIP = \"/tmp/mobilenet_colab_artifacts.zip\"\n", | 
 |         "\n", | 
 |         "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n", | 
 |         "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n", | 
 |         "\n", | 
 |         "# Note: you can also download files using the file explorer on the left\n", | 
 |         "try:\n", | 
 |         "    from google.colab import files\n", | 
 |         "    print(\"Downloading the artifacts zip file...\")\n", | 
 |         "    files.download(ARTIFACTS_ZIP)\n", | 
 |         "except ImportError:\n", | 
 |         "    print(\"Missing google.colab Python package; can't download files\")" | 
 |       ], | 
 |       "execution_count": 9, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "name": "stdout", | 
 |           "text": [ | 
 |             "Zipping '/tmp/iree/colab_artifacts' to '/tmp/mobilenet_colab_artifacts.zip' for download...\n", | 
 |             "  adding: mobilenet_v2.vmfb (deflated 8%)\n", | 
 |             "  adding: mobilenet_v2_tf_input.mlir (deflated 8%)\n", | 
 |             "  adding: mobilenet_v2_iree_input.mlir (deflated 8%)\n", | 
 |             "Downloading the artifacts zip file...\n" | 
 |           ] | 
 |         }, | 
 |         { | 
 |           "output_type": "display_data", | 
 |           "data": { | 
 |             "text/plain": [ | 
 |               "<IPython.core.display.Javascript object>" | 
 |             ], | 
 |             "application/javascript": [ | 
 |               "\n", | 
 |               "    async function download(id, filename, size) {\n", | 
 |               "      if (!google.colab.kernel.accessAllowed) {\n", | 
 |               "        return;\n", | 
 |               "      }\n", | 
 |               "      const div = document.createElement('div');\n", | 
 |               "      const label = document.createElement('label');\n", | 
 |               "      label.textContent = `Downloading \"${filename}\": `;\n", | 
 |               "      div.appendChild(label);\n", | 
 |               "      const progress = document.createElement('progress');\n", | 
 |               "      progress.max = size;\n", | 
 |               "      div.appendChild(progress);\n", | 
 |               "      document.body.appendChild(div);\n", | 
 |               "\n", | 
 |               "      const buffers = [];\n", | 
 |               "      let downloaded = 0;\n", | 
 |               "\n", | 
 |               "      const channel = await google.colab.kernel.comms.open(id);\n", | 
 |               "      // Send a message to notify the kernel that we're ready.\n", | 
 |               "      channel.send({})\n", | 
 |               "\n", | 
 |               "      for await (const message of channel.messages) {\n", | 
 |               "        // Send a message to notify the kernel that we're ready.\n", | 
 |               "        channel.send({})\n", | 
 |               "        if (message.buffers) {\n", | 
 |               "          for (const buffer of message.buffers) {\n", | 
 |               "            buffers.push(buffer);\n", | 
 |               "            downloaded += buffer.byteLength;\n", | 
 |               "            progress.value = downloaded;\n", | 
 |               "          }\n", | 
 |               "        }\n", | 
 |               "      }\n", | 
 |               "      const blob = new Blob(buffers, {type: 'application/binary'});\n", | 
 |               "      const a = document.createElement('a');\n", | 
 |               "      a.href = window.URL.createObjectURL(blob);\n", | 
 |               "      a.download = filename;\n", | 
 |               "      div.appendChild(a);\n", | 
 |               "      a.click();\n", | 
 |               "      div.remove();\n", | 
 |               "    }\n", | 
 |               "  " | 
 |             ] | 
 |           }, | 
 |           "metadata": {} | 
 |         }, | 
 |         { | 
 |           "output_type": "display_data", | 
 |           "data": { | 
 |             "text/plain": [ | 
 |               "<IPython.core.display.Javascript object>" | 
 |             ], | 
 |             "application/javascript": [ | 
 |               "download(\"download_e515e31a-0819-4066-9236-710263bad3e5\", \"mobilenet_colab_artifacts.zip\", 39551293)" | 
 |             ] | 
 |           }, | 
 |           "metadata": {} | 
 |         } | 
 |       ] | 
 |     } | 
 |   ] | 
 | } |