| { | 
 |   "nbformat": 4, | 
 |   "nbformat_minor": 0, | 
 |   "metadata": { | 
 |     "colab": { | 
 |       "name": "mnist_tensorflow.ipynb", | 
 |       "provenance": [], | 
 |       "collapsed_sections": [ | 
 |         "PZtRtMMUZHJS" | 
 |       ] | 
 |     }, | 
 |     "kernelspec": { | 
 |       "name": "python3", | 
 |       "display_name": "Python 3" | 
 |     } | 
 |   }, | 
 |   "cells": [ | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "PZtRtMMUZHJS", | 
 |         "colab_type": "text" | 
 |       }, | 
 |       "source": [ | 
 |         "##### Copyright 2020 Google LLC.\n", | 
 |         "\n", | 
 |         "Licensed under the Apache License, Version 2.0 (the \"License\");" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "TouZL3JZZSQe", | 
 |         "colab_type": "code", | 
 |         "cellView": "both", | 
 |         "colab": {} | 
 |       }, | 
 |       "source": [ | 
 |         "#@title License header\n", | 
 |         "# Copyright 2020 Google LLC\n", | 
 |         "#\n", | 
 |         "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | 
 |         "# you may not use this file except in compliance with the License.\n", | 
 |         "# You may obtain a copy of the License at\n", | 
 |         "#\n", | 
 |         "#      https://www.apache.org/licenses/LICENSE-2.0\n", | 
 |         "#\n", | 
 |         "# Unless required by applicable law or agreed to in writing, software\n", | 
 |         "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | 
 |         "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | 
 |         "# See the License for the specific language governing permissions and\n", | 
 |         "# limitations under the License." | 
 |       ], | 
 |       "execution_count": 1, | 
 |       "outputs": [] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "O6c3qfq5Zv57", | 
 |         "colab_type": "text" | 
 |       }, | 
 |       "source": [ | 
 |         "# MNIST Model TensorFlow Training, IREE Execution\n", | 
 |         "\n", | 
 |         "## Overview\n", | 
 |         "\n", | 
 |         "This notebook creates and trains a TensorFlow 2.0 model for recognizing handwritten digits using the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database), then compiles and executes that trained model using IREE.\n", | 
 |         "\n", | 
 |         "## Running Locally\n", | 
 |         "\n", | 
 |         "*  Refer to [iree/docs/using_colab.md](https://github.com/google/iree/blob/main/docs/using_colab.md) for general information\n", | 
 |         "*  Ensure that you have a recent version of TensorFlow 2.x [installed on your system](https://www.tensorflow.org/install)\n", | 
 |         "*  Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n", | 
 |         "*  Start colab by running `python colab/start_colab_kernel.py` (see that file for additional instructions)\n", | 
 |         "*  Note: you may need to restart your runtime in order to re-run certain cells. Some of the APIs are not yet stable enough for repeated invocations." | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "wBXlE69Ia2QU", | 
 |         "colab_type": "text" | 
 |       }, | 
 |       "source": [ | 
 |         "# Setup Steps" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "EPF7RGQDYK-M", | 
 |         "colab_type": "code", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 51 | 
 |         }, | 
 |         "outputId": "fe0d703a-2ad7-4d14-9aef-c69b4c342a16" | 
 |       }, | 
 |       "source": [ | 
 |         "import os\n", | 
 |         "import numpy as np\n", | 
 |         "import tensorflow as tf\n", | 
 |         "from matplotlib import pyplot as plt\n", | 
 |         "from pyiree.tf import compiler as ireec\n", | 
 |         "from pyiree import rt as ireert\n", | 
 |         "\n", | 
 |         "tf.compat.v1.enable_eager_execution()\n", | 
 |         "\n", | 
 |         "SAVE_PATH = os.path.join(os.environ[\"HOME\"], \"saved_models\")\n", | 
 |         "os.makedirs(SAVE_PATH, exist_ok=True)\n", | 
 |         "\n", | 
 |         "# Print version information for future notebook users to reference.\n", | 
 |         "print(\"TensorFlow version: \", tf.__version__)\n", | 
 |         "print(\"Numpy version: \", np.__version__)" | 
 |       ], | 
 |       "execution_count": 2, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "TensorFlow version:  2.5.0-dev20200626\n", | 
 |             "Numpy version:  1.18.4\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "43BH_9YcsGs8", | 
 |         "colab_type": "code", | 
 |         "cellView": "form", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 34 | 
 |         }, | 
 |         "outputId": "46a560cb-5947-4cfa-f71e-9065bf2c07ab" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Notebook settings { run: \"auto\" }\n", | 
 |         "\n", | 
 |         "#@markdown -----\n", | 
 |         "#@markdown ### Configuration\n", | 
 |         "\n", | 
 |         "backend_choice = \"GPU (vulkan-spirv)\" #@param [ \"GPU (vulkan-spirv)\", \"CPU (VMLA)\" ]\n", | 
 |         "\n", | 
 |         "if backend_choice == \"GPU (vulkan-spirv)\":\n", | 
 |         "  backend_name = \"vulkan-spirv\"\n", | 
 |         "  driver_name = \"vulkan\"\n", | 
 |         "else:\n", | 
 |         "  backend_name = \"vmla\"\n", | 
 |         "  driver_name = \"vmla\"\n", | 
 |         "tf.print(\"Using IREE compiler backend '%s' and runtime driver '%s'\" % (backend_name, driver_name))\n", | 
 |         "\n", | 
 |         "#@markdown -----\n", | 
 |         "#@markdown ### Training Parameters\n", | 
 |         "\n", | 
 |         "#@markdown <sup>Batch size used to subdivide the training and evaluation samples</sup>\n", | 
 |         "batch_size = 200  #@param { type: \"slider\", min: 10, max: 400 }\n", | 
 |         "\n", | 
 |         "#@markdown <sup>Epochs for training/eval. Higher values take longer to run but generally produce more accurate models</sup>\n", | 
 |         "num_epochs = 5    #@param { type: \"slider\", min:  1, max:  20 }\n", | 
 |         "\n", | 
 |         "#@markdown -----" | 
 |       ], | 
 |       "execution_count": 3, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "Using IREE compiler backend 'vulkan-spirv' and runtime driver 'vulkan'\r\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "5vkQOMOMbXdy", | 
 |         "colab_type": "text" | 
 |       }, | 
 |       "source": [ | 
 |         "# Create and Train MNIST Model in TensorFlow\n", | 
 |         "\n", | 
 |         "The specific details of the training process here aren't critical to the model compilation and execution through IREE." | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "GXZIrReTbTHN", | 
 |         "colab_type": "code", | 
 |         "cellView": "form", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 486 | 
 |         }, | 
 |         "outputId": "9c01fab1-f8cb-4a63-fff1-6b82a9b6b49d" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Load MNIST dataset, setup training and evaluation\n", | 
 |         "\n", | 
 |         "NUM_CLASSES = 10  # One per digit [0, 1, 2, ..., 9]\n", | 
 |         "IMG_ROWS, IMG_COLS = 28, 28\n", | 
 |         "\n", | 
 |         "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", | 
 |         "tf.print(\"Loaded MNIST dataset!\")\n", | 
 |         "\n", | 
 |         "x_train = x_train.reshape(x_train.shape[0], IMG_ROWS, IMG_COLS, 1)\n", | 
 |         "x_test = x_test.reshape(x_test.shape[0], IMG_ROWS, IMG_COLS, 1)\n", | 
 |         "input_shape = (IMG_ROWS, IMG_COLS, 1)\n", | 
 |         "\n", | 
 |         "# Scale pixel values from [0, 255] integers to [0.0, 1.0] floats.\n", | 
 |         "x_train = x_train.astype(\"float32\") / 255\n", | 
 |         "x_test = x_test.astype(\"float32\") / 255\n", | 
 |         "\n", | 
 |         "steps_per_epoch = int(x_train.shape[0] / batch_size)\n", | 
 |         "steps_per_eval = int(x_test.shape[0] / batch_size)\n", | 
 |         "\n", | 
 |         "# Convert class vectors to binary class matrices.\n", | 
 |         "y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)\n", | 
 |         "y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)\n", | 
 |         "\n", | 
 |         "# Construct batched datasets for training/evaluation.\n", | 
 |         "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", | 
 |         "train_dataset = train_dataset.batch(batch_size, drop_remainder=True)\n", | 
 |         "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n", | 
 |         "test_dataset = test_dataset.batch(batch_size, drop_remainder=True)\n", | 
 |         "\n", | 
 |         "# Create a distribution strategy for the dataset (single machine).\n", | 
 |         "strategy = tf.distribute.experimental.CentralStorageStrategy()\n", | 
 |         "train_dist_ds = strategy.experimental_distribute_dataset(train_dataset)\n", | 
 |         "test_dist_ds = strategy.experimental_distribute_dataset(test_dataset)\n", | 
 |         "\n", | 
 |         "tf.print(\"Configured data for training and evaluation!\")\n", | 
 |         "tf.print(\"  sample shape: %s\" % str(x_train[0].shape))\n", | 
 |         "tf.print(\"  training samples: %s\" % x_train.shape[0])\n", | 
 |         "tf.print(\"  test     samples: %s\" % x_test.shape[0])\n", | 
 |         "tf.print(\"  epochs: %s\" % num_epochs)\n", | 
 |         "tf.print(\"  steps/epoch: %s\" % steps_per_epoch)\n", | 
 |         "tf.print(\"  steps/eval : %s\" % steps_per_eval)\n", | 
 |         "\n", | 
 |         "tf.print(\"\")\n", | 
 |         "tf.print(\"Sample image from the dataset:\")\n", | 
 |         "SAMPLE_EXAMPLE_INDEX = 1\n", | 
 |         "sample_image = x_test[SAMPLE_EXAMPLE_INDEX]\n", | 
 |         "sample_image_batch = np.expand_dims(sample_image, axis=0)\n", | 
 |         "sample_label = y_test[SAMPLE_EXAMPLE_INDEX]\n", | 
 |         "plt.imshow(sample_image.reshape(IMG_ROWS, IMG_COLS))\n", | 
 |         "plt.show()\n", | 
 |         "tf.print(\"\\nGround truth labels: %s\" % str(sample_label))" | 
 |       ], | 
 |       "execution_count": 4, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "Loaded MNIST dataset!\r\n", | 
 |             "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:localhost/replica:0/task:0/device:CPU:0'], variable_device = '/job:localhost/replica:0/task:0/device:CPU:0'\n", | 
 |             "Configured data for training and evaluation!\n", | 
 |             "  sample shape: (28, 28, 1)\n", | 
 |             "  training samples: 60000\n", | 
 |             "  test     samples: 10000\n", | 
 |             "  epochs: 5\n", | 
 |             "  steps/epoch: 300\n", | 
 |             "  steps/eval : 50\n", | 
 |             "\n", | 
 |             "Sample image from the dataset:\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         }, | 
 |         { | 
 |           "output_type": "display_data", | 
 |           "data": { | 
 |             "text/plain": [ | 
 |               "<Figure size 432x288 with 1 Axes>" | 
 |             ], | 
 |             "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANxUlEQVR4nO3de4xU93nG8ecBc7EwtqFgSjGygwOycSpDsiJx3YstN6nDH8GRckOJgyNHpGrcJhJSYrmV4igXWVVst1WjVCRGIZUvcn2JqWIlJsSR6wRhLy4BbJJAXOpgVmDEpuBWhd312z/2UG3wzpll5sycMe/3I41m5rxzznk18OyZmd+c+TkiBODsN6nuBgB0B2EHkiDsQBKEHUiCsANJnNPNnU31tJiuGd3cJZDK/+q/dTJOeLxaW2G3fYOkv5c0WdK3IuLOssdP1wy909e3s0sAJbbFloa1ll/G254s6euS3itpqaTVtpe2uj0AndXOe/YVkvZFxEsRcVLSg5JWVdMWgKq1E/YFkn495v6BYtlvsb3Wdr/t/iGdaGN3ANrRTtjH+xDgDd+9jYj1EdEXEX1TNK2N3QFoRzthPyBp4Zj7F0s62F47ADqlnbA/J2mx7bfYnirpI5I2VdMWgKq1PPQWEcO2b5X0A40OvW2IiBcq6wxApdoaZ4+IJyQ9UVEvADqIr8sCSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5BEV39KGq3Z/+WrS+sj0xtPzjn3yldL19161SMt9XTKZT/6RGl95rPnNqzN+4eftrVvnBmO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsPWDwe4tL67uX/WPH9j3UeIh+Qn5+3bdK6/f1zW9Ye2jzn5SuO7Jnb0s9YXwc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZu6DZOPpPlj3YsX3/028Wldbv3vru0vqll5SfD//k0kdL6x+dOdCw9pWb55Suu+jzjLNXqa2w294v6bikEUnDEdFXRVMAqlfFkf26iDhSwXYAdBDv2YEk2g17SHrS9nbba8d7gO21tvtt9w/pRJu7A9Cqdl/GXxMRB21fJGmz7Z9HxNNjHxAR6yWtl6TzPbvN0y4AtKqtI3tEHCyuD0t6TNKKKpoCUL2Ww257hu2Zp25Leo+k3VU1BqBa7byMnyfpMduntnN/RHy/kq7eZIavf0dp/UdXfb3JFqaUVv9ucElp/akPl4x4Hjxcuu6Swf7S+qTp00vrX932+6X12+fsalgbnjVcui6q1XLYI+IlSVdV2AuADmLoDUiCsANJEHYgCcIOJEHYgSQ4xbUCry2YWlqf1ORvarOhtR+/r3x4a+SlX5TW27Hvi8tL6/fPvqvJFqY1rFz8fY413cSzDSRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5egQu/s7W0/oH+j5XWPXistD48sP8MO6rOJ1f+sLR+3qTG4+joLRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtm7YOTFX9bdQkP7v3J1af2WC7/WZAvlPzW9buBdDWszf7indN2RJnvGmeHIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5+lvvNTeXj6D/5ePk4+gWTysfRt56YXFrf8eXGvzt/7rFnS9dFtZoe2W1vsH3Y9u4xy2bb3mx7b3E9q7NtAmjXRF7Gf1vSDactu03SlohYLGlLcR9AD2sa9oh4WtLR0xavkrSxuL1R0o0V9wWgYq1+QDcvIgYkqbi+qNEDba+13W+7f0gnWtwdgHZ1/NP4iFgfEX0R0TelZJI/AJ3VatgP2Z4vScX14epaAtAJrYZ9k6Q1xe01kh6vph0AndJ0nN32A5KulTTH9gFJX5B0p6SHbN8i6WVJH+xkk2jdkbdHab3ZOHoza378ydL6ku8ylt4
rmoY9IlY3KF1fcS8AOoivywJJEHYgCcIOJEHYgSQIO5AEp7ieBU5uvqRhbevldzVZu3zo7aqta0rrV6z7VWmdn4PuHRzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtnfBM5ZdGlp/Utv/ZeGtVlNTmHd3uSXwi75UvlI+cjgYPkG0DM4sgNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyzvwlc9tArpfXlU1v/m716y5+X1pf87LmWt43ewpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0HDK65urT+xXnNfvt9WsPKmv1/WrrmFZ/bV1rnd9/PHk2P7LY32D5se/eYZXfYfsX2juKysrNtAmjXRF7Gf1vSDeMsvycilhWXJ6ptC0DVmoY9Ip6WdLQLvQDooHY+oLvV9s7iZf6sRg+yvdZ2v+3+ITX5wTMAHdNq2L8h6TJJyyQNSGr4CVJErI+Ivojom1LyQRKAzmop7BFxKCJGIuJ1Sd+UtKLatgBUraWw254/5u77Je1u9FgAvaHpOLvtByRdK2mO7QOSviDpWtvLJIWk/ZI+1cEe3/TOWfB7pfU/+qttpfXzJrX+9mfri28trS8Z5Hz1LJqGPSJWj7P43g70AqCD+LoskARhB5Ig7EAShB1IgrADSXCKaxfsuX1haf27v/uvbW3/ul0fbFjjFFacwpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0Ltr/vniaPaO8XfC74i9cb1oYHB9vaNs4eHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2c8CQ/MuaFibcnJBFzt5o5FXjzSsxYny6cA8rfz7B5PnzmmpJ0kamXthaX3vuqktb3siYsQNa5f/ZZPfIDh2rKV9cmQHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZz8LfO/hDXW30NAf/Pt4kwCPOnLo/NJ1Z809Xlrf9o77W+qp1y39m1tL64s+t7Wl7TY9stteaPsp23tsv2D7M8Xy2bY3295bXM9qqQMAXTGRl/HDktZFxBWS3iXp07aXSrpN0paIWCxpS3EfQI9qGvaIGIiI54vbxyXtkbRA0ipJG4uHbZR0Y6eaBNC+M/qAzvalkpZL2iZpXkQMSKN/ECRd1GCdtbb7bfcPqfy70AA6Z8Jht32epEckfTYiJvxN/IhYHxF9EdE3pc0fVgTQugmF3fYUjQb9voh4tFh8yPb8oj5f0uHOtAigCk2H3mxb0r2S9kTE3WNKmyStkXRncf14Rzo8C6x68aOl9S1ve7hLnXTfT5c/UNu+/ydONqwNReOf356IlTtvLq3/147WT79d8Mxwy+uWmcg4+zWSbpK0y/aOYtntGg35Q7ZvkfSypMaThAOoXdOwR8QzkhqdaX99te0A6BS+LgskQdiBJAg7kARhB5Ig7EASnOLaBef+2X+U1q/8avkpjdHBf6WZlx8trXfyNNIr/+0TpfV4eUZb21/08GuNi8/uamvbs7S3rXodOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKOiK7t7HzPjneaE+WATtkWW3Qsjo57lipHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiiadhtL7T9lO09tl+w/Zli+R22X7G9o7is7Hy7AFo1kekHhiWti4jnbc+UtN325qJ2T0R8rXPtAajKROZnH5A0UNw+bnuPpAWdbgxAtc7oPbvtSyUtl7StWHSr7Z22N9ie1WCdtbb7bfcP6URbzQJo3YTDbvs8SY9I+mxEHJP0DUmXSVqm0SP/XeOtFxHrI6IvIvqmaFoFLQNoxYTCbnuKRoN+X0Q8KkkRcSgiRiLidUnflLSic20CaNdEPo23pHsl7YmIu8csnz/mYe+XtLv69gBUZSKfxl8j6SZJu2zvKJbdLmm17WWSQtJ+SZ/qSIcAKjGRT+OfkTTe71A/UX07ADqFb9ABSRB2IAn
CDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeScER0b2f2q5L+c8yiOZKOdK2BM9OrvfVqXxK9tarK3i6JiLnjFboa9jfs3O6PiL7aGijRq731al8SvbWqW73xMh5IgrADSdQd9vU1779Mr/bWq31J9NaqrvRW63t2AN1T95EdQJcQdiCJWsJu+wbbv7C9z/ZtdfTQiO39tncV01D319zLBtuHbe8es2y27c229xbX486xV1NvPTGNd8k047U+d3VPf9719+y2J0v6paR3Szog6TlJqyPixa420oDt/ZL6IqL2L2DY/mNJr0n6TkS8rVj2t5KORsSdxR/KWRHx+R7p7Q5Jr9U9jXcxW9H8sdOMS7pR0s2q8bkr6etD6sLzVseRfYWkfRHxUkSclPSgpFU19NHzIuJpSUdPW7xK0sbi9kaN/mfpuga99YSIGIiI54vbxyWdmma81ueupK+uqCPsCyT9esz9A+qt+d5D0pO2t9teW3cz45gXEQPS6H8eSRfV3M/pmk7j3U2nTTPeM89dK9Oft6uOsI83lVQvjf9dExFvl/ReSZ8uXq5iYiY0jXe3jDPNeE9odfrzdtUR9gOSFo65f7GkgzX0Ma6IOFhcH5b0mHpvKupDp2bQLa4P19zP/+ulabzHm2ZcPfDc1Tn9eR1hf07SYttvsT1V0kckbaqhjzewPaP44ES2Z0h6j3pvKupNktYUt9dIerzGXn5Lr0zj3WiacdX83NU+/XlEdP0iaaVGP5H/laS/rqOHBn0tkvSz4vJC3b1JekCjL+uGNPqK6BZJvyNpi6S9xfXsHurtnyXtkrRTo8GaX1Nvf6jRt4Y7Je0oLivrfu5K+urK88bXZYEk+AYdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTxfy43Cn7d/BIFAAAAAElFTkSuQmCC\n" | 
 |           }, | 
 |           "metadata": { | 
 |             "tags": [], | 
 |             "needs_background": "light" | 
 |           } | 
 |         }, | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "\n", | 
 |             "Ground truth labels: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\r\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "tHq96SIJcNfx", | 
 |         "colab_type": "code", | 
 |         "cellView": "both", | 
 |         "colab": {} | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Define MNIST model architecture using tf.keras API\n", | 
 |         "\n", | 
 |         "def simple_mnist_model(input_shape):\n", | 
 |         "  \"\"\"Creates a simple (multi-layer perceptron) MNIST model.\"\"\"\n", | 
 |         "\n", | 
 |         "  model = tf.keras.models.Sequential()\n", | 
 |         "  # Flatten to a 1d array (e.g. 28x28 -> 784)\n", | 
 |         "  model.add(tf.keras.layers.Flatten(input_shape=input_shape))\n", | 
 |         "  # Fully-connected neural layer with 128 neurons, RELU activation\n", | 
 |         "  model.add(tf.keras.layers.Dense(128, activation='relu'))\n", | 
 |         "  # Fully-connected neural layer returning probability scores for each class\n", | 
 |         "  model.add(tf.keras.layers.Dense(10, activation='softmax'))\n", | 
 |         "  return model" | 
 |       ], | 
 |       "execution_count": 5, | 
 |       "outputs": [] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "7Gdxh7qWcPSO", | 
 |         "colab_type": "code", | 
 |         "cellView": "form", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 374 | 
 |         }, | 
 |         "outputId": "50b0aede-9a8f-4ce5-b340-783be5fbfc06" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Train the Keras model\n", | 
 |         "\n", | 
 |         "with strategy.scope():\n", | 
 |         "  model = simple_mnist_model(input_shape)\n", | 
 |         "  tf.print(\"Constructed Keras MNIST model, training...\")\n", | 
 |         "\n", | 
 |         "  optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)\n", | 
 |         "  training_loss = tf.keras.metrics.Mean(\"training_loss\", dtype=tf.float32)\n", | 
 |         "  training_accuracy = tf.keras.metrics.CategoricalAccuracy(\n", | 
 |         "      \"training_accuracy\", dtype=tf.float32)\n", | 
 |         "  test_loss = tf.keras.metrics.Mean(\"test_loss\", dtype=tf.float32)\n", | 
 |         "  test_accuracy = tf.keras.metrics.CategoricalAccuracy(\n", | 
 |         "      \"test_accuracy\", dtype=tf.float32)\n", | 
 |         "\n", | 
 |         "  @tf.function\n", | 
 |         "  def train_step(iterator):\n", | 
 |         "    \"\"\"Training StepFn.\"\"\"\n", | 
 |         "\n", | 
 |         "    def step_fn(inputs):\n", | 
 |         "      \"\"\"Per-Replica StepFn.\"\"\"\n", | 
 |         "      images, labels = inputs\n", | 
 |         "      with tf.GradientTape() as tape:\n", | 
 |         "        logits = model(images, training=True)\n", | 
 |         "        loss = tf.keras.losses.categorical_crossentropy(labels, logits)\n", | 
 |         "        loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync\n", | 
 |         "      grads = tape.gradient(loss, model.trainable_variables)\n", | 
 |         "      optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", | 
 |         "      training_loss.update_state(loss)\n", | 
 |         "      training_accuracy.update_state(labels, logits)\n", | 
 |         "\n", | 
 |         "    strategy.run(step_fn, args=(next(iterator),))\n", | 
 |         "\n", | 
 |         "  @tf.function\n", | 
 |         "  def test_step(iterator):\n", | 
 |         "    \"\"\"Evaluation StepFn.\"\"\"\n", | 
 |         "\n", | 
 |         "    def step_fn(inputs):\n", | 
 |         "      images, labels = inputs\n", | 
 |         "      logits = model(images, training=False)\n", | 
 |         "      loss = tf.keras.losses.categorical_crossentropy(labels, logits)\n", | 
 |         "      loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync\n", | 
 |         "      test_loss.update_state(loss)\n", | 
 |         "      test_accuracy.update_state(labels, logits)\n", | 
 |         "\n", | 
 |         "    strategy.run(step_fn, args=(next(iterator),))\n", | 
 |         "\n", | 
 |         "  for epoch in range(0, num_epochs):\n", | 
 |         "    tf.print(\"Running epoch #%s\" % (epoch + 1))\n", | 
 |         "\n", | 
 |         "    train_iterator = iter(train_dist_ds)\n", | 
 |         "    for step in range(steps_per_epoch):\n", | 
 |         "      train_step(train_iterator)\n", | 
 |         "    tf.print(\"  Training loss: %f, accuracy: %f\" % (training_loss.result(), training_accuracy.result() * 100))\n", | 
 |         "    training_loss.reset_states()\n", | 
 |         "    training_accuracy.reset_states()\n", | 
 |         "\n", | 
 |         "    test_iterator = iter(test_dist_ds)\n", | 
 |         "    for step in range(steps_per_eval):\n", | 
 |         "      test_step(test_iterator)\n", | 
 |         "    tf.print(\"  Test loss    : %f, accuracy: %f\" % (test_loss.result(), test_accuracy.result() * 100))\n", | 
 |         "    test_loss.reset_states()\n", | 
 |         "    test_accuracy.reset_states()\n", | 
 |         "\n", | 
 |         "  tf.print(\"Completed training!\")\n", | 
 |         "  tf.print(\"\")\n", | 
 |         "\n", | 
 |         "  # Run a single prediction on the trained model\n", | 
 |         "  tf_prediction = model(sample_image_batch, training=False)\n", | 
 |         "  tf.print(\"Sample prediction:\")\n", | 
 |         "  tf.print(tf_prediction[0] * 100.0, summarize=100)\n", | 
 |         "  tf.print(\"\")" | 
 |       ], | 
 |       "execution_count": 6, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "Constructed Keras MNIST model, training...\n", | 
 |             "Running epoch #1\n", | 
 |             "  Training loss: 0.732439, accuracy: 81.403336\n", | 
 |             "  Test loss    : 0.390855, accuracy: 89.490005\n", | 
 |             "Running epoch #2\n", | 
 |             "  Training loss: 0.365308, accuracy: 89.811668\n", | 
 |             "  Test loss    : 0.315630, accuracy: 91.119995\n", | 
 |             "Running epoch #3\n", | 
 |             "  Training loss: 0.312111, accuracy: 91.129997\n", | 
 |             "  Test loss    : 0.281829, accuracy: 92.040001\n", | 
 |             "Running epoch #4\n", | 
 |             "  Training loss: 0.281028, accuracy: 92.038330\n", | 
 |             "  Test loss    : 0.258432, accuracy: 92.629997\n", | 
 |             "Running epoch #5\n", | 
 |             "  Training loss: 0.257909, accuracy: 92.753334\n", | 
 |             "  Test loss    : 0.240058, accuracy: 93.229996\n", | 
 |             "Completed training!\n", | 
 |             "\n", | 
 |             "Sample prediction:\n", | 
 |             "[0.243134052 0.00337268948 95.5214081 0.925373673 2.25061958e-05 0.992091119 2.20864391 3.87712953e-06 0.105901182 4.44369543e-05]\n", | 
 |             "\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "DmespEaFcSEL", | 
 |         "colab_type": "code", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 153 | 
 |         }, | 
 |         "outputId": "3c8579db-7b3c-4164-ff91-394346595107" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Export the trained model as a SavedModel, with IREE-compatible settings\n", | 
 |         "\n", | 
 |         "# Since the model was written in sequential style, explicitly wrap in a module.\n", | 
 |         "saved_model_dir = \"/tmp/mnist.sm\"\n", | 
 |         "inference_module = tf.Module()\n", | 
 |         "inference_module.model = model\n", | 
 |         "# Hack: Convert to static shape. Won't be necessary once dynamic shapes are in.\n", | 
 |         "# NOTE: despite the variable's name, once element 0 is overwritten below this\n", | 
 |         "# list describes a static input shape with batch size 1.\n", | 
 |         "dynamic_input_shape = list(model.inputs[0].shape)\n", | 
 |         "dynamic_input_shape[0] = 1  # Make fixed (batch=1)\n", | 
 |         "# Produce a concrete function.\n", | 
 |         "# It is attached to the module as `predict`, wraps the model's forward pass in\n", | 
 |         "# inference mode (training=False) and takes a single tensor of the shape and\n", | 
 |         "# dtype given by the input signature.\n", | 
 |         "inference_module.predict = tf.function(\n", | 
 |         "    input_signature=[\n", | 
 |         "        tf.TensorSpec(dynamic_input_shape, model.inputs[0].dtype)])(\n", | 
 |         "            lambda x: model.call(x, training=False))\n", | 
 |         "# Debug info is requested in the export; presumably the IREE importer uses it\n", | 
 |         "# for source locations - TODO confirm.\n", | 
 |         "save_options = tf.saved_model.SaveOptions(save_debug_info=True)\n", | 
 |         "tf.print(\"Exporting SavedModel to %s\" % saved_model_dir)\n", | 
 |         "tf.saved_model.save(inference_module, saved_model_dir, options=save_options)" | 
 |       ], | 
 |       "execution_count": 7, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "Exporting SavedModel to /tmp/mnist.sm\r\n", | 
 |             "WARNING:tensorflow:From c:\\users\\scott\\scoop\\apps\\python\\current\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n", | 
 |             "Instructions for updating:\n", | 
 |             "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", | 
 |             "WARNING:tensorflow:From c:\\users\\scott\\scoop\\apps\\python\\current\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", | 
 |             "Instructions for updating:\n", | 
 |             "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n", | 
 |             "INFO:tensorflow:Assets written to: /tmp/mnist.sm\\assets\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "markdown", | 
 |       "metadata": { | 
 |         "id": "nZdVUd_dgTtc", | 
 |         "colab_type": "text" | 
 |       }, | 
 |       "source": [ | 
 |         "# Compile and Execute MNIST Model using IREE" | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "rqwIx4j4gS1a", | 
 |         "colab_type": "code", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 836 | 
 |         }, | 
 |         "outputId": "24cded90-c436-47ce-b4f4-7a5da46ea38a" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Load the SavedModel into IREE's compiler as MLIR mhlo\n", | 
 |         "\n", | 
 |         "compiler_module = ireec.tf_load_saved_model(\n", | 
 |         "    saved_model_dir, exported_names=[\"predict\"])\n", | 
 |         "tf.print(\"Imported MLIR:\\n\", compiler_module.to_asm(large_element_limit=100))\n", | 
 |         "\n", | 
 |         "# Write to a file for use outside of this notebook.\n", | 
 |         "mnist_mlir_path = os.path.join(SAVE_PATH, \"mnist.mlir\")\n", | 
 |         "with open(mnist_mlir_path, \"wt\") as output_file:\n", | 
 |         "  output_file.write(compiler_module.to_asm())\n", | 
 |         "print(\"Wrote MLIR to path '%s'\" % mnist_mlir_path)" | 
 |       ], | 
 |       "execution_count": 8, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "Imported MLIR:\n", | 
 |             " \n", | 
 |             "\n", | 
 |             "module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 443 : i32}} {\n", | 
 |             "  flow.variable @\"__iree_flow___sm_node14__model.layer-1.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<784x128xf32> attributes {sym_visibility = \"private\"}\n", | 
 |             "  flow.variable @\"__iree_flow___sm_node15__model.layer-1.bias\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128xf32> attributes {sym_visibility = \"private\"}\n", | 
 |             "  flow.variable @\"__iree_flow___sm_node20__model.layer-2.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128x10xf32> attributes {sym_visibility = \"private\"}\n", | 
 |             "  flow.variable @\"__iree_flow___sm_node21__model.layer-2.bias\" dense<[-0.114143081, 0.0953421518, 4.84912744E-5, -0.0384164825, 0.0063888072, 0.218958765, 0.0256200824, 0.0551806651, -0.22108613, -0.0278935507]> : tensor<10xf32> attributes {sym_visibility = \"private\"}\n", | 
 |             "  func @predict(%arg0: tensor<1x28x28x1xf32> {tf._user_specified_name = \"x\"}) -> tensor<1x10xf32> attributes {iree.module.export, iree.reflection = {abi = \"sip\", abiv = 1 : i32, sip = \"I8!S5!k0_0R3!_0\"}, tf._input_shapes = [#tf.shape<1x28x28x1>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful} {\n", | 
 |             "    %0 = flow.variable.address @\"__iree_flow___sm_node14__model.layer-1.kernel\" : !iree.ptr<tensor<784x128xf32>>\n", | 
 |             "    %1 = flow.variable.address @\"__iree_flow___sm_node15__model.layer-1.bias\" : !iree.ptr<tensor<128xf32>>\n", | 
 |             "    %2 = flow.variable.address @\"__iree_flow___sm_node20__model.layer-2.kernel\" : !iree.ptr<tensor<128x10xf32>>\n", | 
 |             "    %3 = flow.variable.address @\"__iree_flow___sm_node21__model.layer-2.bias\" : !iree.ptr<tensor<10xf32>>\n", | 
 |             "    %4 = mhlo.constant dense<0xFF800000> : tensor<f32>\n", | 
 |             "    %5 = mhlo.constant dense<0.000000e+00> : tensor<f32>\n", | 
 |             "    %6 = flow.variable.load.indirect %3 : !iree.ptr<tensor<10xf32>> -> tensor<10xf32>\n", | 
 |             "    %7 = flow.variable.load.indirect %2 : !iree.ptr<tensor<128x10xf32>> -> tensor<128x10xf32>\n", | 
 |             "    %8 = flow.variable.load.indirect %1 : !iree.ptr<tensor<128xf32>> -> tensor<128xf32>\n", | 
 |             "    %9 = flow.variable.load.indirect %0 : !iree.ptr<tensor<784x128xf32>> -> tensor<784x128xf32>\n", | 
 |             "    %10 = \"mhlo.reshape\"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>\n", | 
 |             "    %11 = \"mhlo.dot\"(%10, %9) : (tensor<1x784xf32>, tensor<784x128xf32>) -> tensor<1x128xf32>\n", | 
 |             "    %12 = \"mhlo.broadcast_in_dim\"(%8) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32>\n", | 
 |             "    %13 = mhlo.add %11, %12 : tensor<1x128xf32>\n", | 
 |             "    %14 = \"mhlo.broadcast_in_dim\"(%5) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<1x128xf32>\n", | 
 |             "    %15 = mhlo.maximum %14, %13 : tensor<1x128xf32>\n", | 
 |             "    %16 = \"mhlo.dot\"(%15, %7) : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>\n", | 
 |             "    %17 = \"mhlo.broadcast_in_dim\"(%6) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<10xf32>) -> tensor<1x10xf32>\n", | 
 |             "    %18 = mhlo.add %16, %17 : tensor<1x10xf32>\n", | 
 |             "    %19 = \"mhlo.reduce\"(%18, %4) ( {\n", | 
 |             "    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors\n", | 
 |             "      %26 = mhlo.maximum %arg1, %arg2 : tensor<f32>\n", | 
 |             "      \"mhlo.return\"(%26) : (tensor<f32>) -> ()\n", | 
 |             "    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>\n", | 
 |             "    %20 = \"mhlo.broadcast_in_dim\"(%19) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>\n", | 
 |             "    %21 = mhlo.subtract %18, %20 : tensor<1x10xf32>\n", | 
 |             "    %22 = \"mhlo.exponential\"(%21) : (tensor<1x10xf32>) -> tensor<1x10xf32>\n", | 
 |             "    %23 = \"mhlo.reduce\"(%22, %5) ( {\n", | 
 |             "    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors\n", | 
 |             "      %26 = mhlo.add %arg1, %arg2 : tensor<f32>\n", | 
 |             "      \"mhlo.return\"(%26) : (tensor<f32>) -> ()\n", | 
 |             "    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>\n", | 
 |             "    %24 = \"mhlo.broadcast_in_dim\"(%23) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<1x10xf32>\n", | 
 |             "    %25 = mhlo.divide %22, %24 : tensor<1x10xf32>\n", | 
 |             "    return %25 : tensor<1x10xf32>\n", | 
 |             "  }\n", | 
 |             "}\r\n", | 
 |             "Wrote MLIR to path 'C:\\Users\\Scott\\saved_models\\mnist.mlir'\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "IDHI7h3khJr9", | 
 |         "colab_type": "code", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 51 | 
 |         }, | 
 |         "outputId": "b8958b7f-c7bb-4fbd-b800-e58c46134086" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Compile the mhlo MLIR and prepare a context to execute it\n", | 
 |         "\n", | 
 |         "# Compile the MLIR module into a VM module for execution.\n", | 
 |         "# `compiler_module` and `backend_name` are defined outside this cell; the\n", | 
 |         "# compiled result is passed straight to `VmModule.from_flatbuffer`.\n", | 
 |         "flatbuffer_blob = compiler_module.compile(target_backends=[backend_name])\n", | 
 |         "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n", | 
 |         "\n", | 
 |         "# Register the module with a runtime context.\n", | 
 |         "# `driver_name` (defined outside this cell) is passed to `ireert.Config`.\n", | 
 |         "# `ctx` is used by the next cell to call the module's `predict` function.\n", | 
 |         "config = ireert.Config(driver_name)\n", | 
 |         "ctx = ireert.SystemContext(config=config)\n", | 
 |         "ctx.add_module(vm_module)" | 
 |       ], | 
 |       "execution_count": 9, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "Created IREE driver vulkan: <pyiree.rt.binding.HalDriver object at 0x000001DC44C47370>\n", | 
 |             "SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x000001DC44C47370>\n" | 
 |           ], | 
 |           "name": "stderr" | 
 |         } | 
 |       ] | 
 |     }, | 
 |     { | 
 |       "cell_type": "code", | 
 |       "metadata": { | 
 |         "id": "SKflpnLtkLYE", | 
 |         "colab_type": "code", | 
 |         "colab": { | 
 |           "base_uri": "https://localhost:8080/", | 
 |           "height": 102 | 
 |         }, | 
 |         "outputId": "337ddf79-6746-4685-89f5-65787280c0af" | 
 |       }, | 
 |       "source": [ | 
 |         "#@title Execute the compiled module and compare the results with TensorFlow\n", | 
 |         "\n", | 
 |         "# Invoke the 'predict' function with a single image as an argument\n", | 
 |         "iree_prediction = ctx.modules.module.predict(sample_image_batch)\n", | 
 |         "\n", | 
 |         "tf.print(\"IREE prediction ('%s' backend, '%s' driver):\" % (backend_name, driver_name))\n", | 
 |         "tf.print(tf.convert_to_tensor(iree_prediction[0]) * 100.0, summarize=100)\n", | 
 |         "tf.print(\"\")\n", | 
 |         "tf.print(\"TensorFlow prediction:\")\n", | 
 |         "tf.print(tf_prediction[0] * 100.0, summarize=100)" | 
 |       ], | 
 |       "execution_count": 10, | 
 |       "outputs": [ | 
 |         { | 
 |           "output_type": "stream", | 
 |           "text": [ | 
 |             "IREE prediction ('vulkan-spirv' backend, 'vulkan' driver):\n", | 
 |             "[0.243133873 0.00337268622 95.5214233 0.92537272 2.25061631e-05 0.992090821 2.20864058 3.87712225e-06 0.105901062 4.44369434e-05]\n", | 
 |             "\n", | 
 |             "TensorFlow prediction:\n", | 
 |             "[0.243134052 0.00337268948 95.5214081 0.925373673 2.25061958e-05 0.992091119 2.20864391 3.87712953e-06 0.105901182 4.44369543e-05]\n" | 
 |           ], | 
 |           "name": "stdout" | 
 |         } | 
 |       ] | 
 |     } | 
 |   ] | 
 | } |