Add Colab notebook that trains, compiles, and runs an MNIST model.

This is an updated version of an older (private) notebook that uses
IREE's latest Python bindings, compiler, and VM. Its output is similar
to the already checked-in mnist.mlir files and docs/mnist_example.md.

Some of the verbose TensorFlow training steps are in hidden cells
(expanded on GitHub but folded by default in Colab), to make the
notebook easier to scan through.

Closes https://github.com/google/iree/pull/434

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/434 from ScottTodd:colab-mnist f1031c6535a1c612d2a35e0972f5655236375c14
PiperOrigin-RevId: 288940519
diff --git a/colab/README.md b/colab/README.md
index 83bdfb5..3de3a24 100644
--- a/colab/README.md
+++ b/colab/README.md
@@ -18,6 +18,13 @@
 
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/iree/blob/master/colab/low_level_invoke_function.ipynb)
 
+### [mnist_tensorflow\.ipynb](mnist_tensorflow.ipynb)
+
+Trains a TensorFlow 2.0 model for recognizing handwritten digits and runs it
+using IREE
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/iree/blob/master/colab/mnist_tensorflow.ipynb)
+
 ### [simple_tensorflow_module_import\.ipynb](simple_tensorflow_module_import.ipynb)
 
 Defines a simple TF module, saves it and loads it in IREE
diff --git a/colab/mnist_tensorflow.ipynb b/colab/mnist_tensorflow.ipynb
new file mode 100644
index 0000000..30e11b3
--- /dev/null
+++ b/colab/mnist_tensorflow.ipynb
@@ -0,0 +1,626 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "mnist_tensorflow.ipynb",
+      "provenance": [],
+      "collapsed_sections": [
+        "PZtRtMMUZHJS"
+      ]
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "PZtRtMMUZHJS",
+        "colab_type": "text"
+      },
+      "source": [
+        "##### Copyright 2020 Google LLC.\n",
+        "\n",
+        "Licensed under the Apache License, Version 2.0 (the \"License\");"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "TouZL3JZZSQe",
+        "colab_type": "code",
+        "colab": {},
+        "cellView": "both"
+      },
+      "source": [
+        "#@title License header\n",
+        "# Copyright 2020 Google LLC\n",
+        "#\n",
+        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+        "# you may not use this file except in compliance with the License.\n",
+        "# You may obtain a copy of the License at\n",
+        "#\n",
+        "#      https://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing, software\n",
+        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+        "# See the License for the specific language governing permissions and\n",
+        "# limitations under the License."
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "O6c3qfq5Zv57",
+        "colab_type": "text"
+      },
+      "source": [
+        "# MNIST Model TensorFlow Training, IREE Execution\n",
+        "\n",
+        "## Overview\n",
+        "\n",
+        "This notebook creates and trains a TensorFlow 2.0 model for recognizing handwritten digits using the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database), then compiles and executes that trained model using IREE.\n",
+        "\n",
+        "## Running Locally\n",
+        "\n",
+        "*  Refer to [iree/docs/using_colab.md](https://github.com/google/iree/blob/master/docs/using_colab.md) for general information\n",
+        "*  Ensure that you have a recent version of TensorFlow 2.0 [installed on your system](https://www.tensorflow.org/install)\n",
+        "*  Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n",
+        "*  Start colab by running `python build_tools/scripts/start_colab_kernel.py` (see that file for additional instructions)\n",
+        "*  Note: you may need to restart your runtime in order to re-run certain cells. Some of the APIs are not yet stable enough for repeated invocations"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "wBXlE69Ia2QU",
+        "colab_type": "text"
+      },
+      "source": [
+        "# Setup Steps"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "EPF7RGQDYK-M",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "import os\n",
+        "import numpy as np\n",
+        "import tensorflow as tf\n",
+        "from matplotlib import pyplot as plt\n",
+        "from pyiree import compiler as ireec\n",
+        "from pyiree import rt as ireert\n",
+        "\n",
+        "tf.compat.v1.enable_eager_execution()\n",
+        "\n",
+        "SAVE_PATH = os.path.join(os.environ[\"HOME\"], \"saved_models\")\n",
+        "os.makedirs(SAVE_PATH, exist_ok=True)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "43BH_9YcsGs8",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 34
+        },
+        "cellView": "form",
+        "outputId": "e24a17ea-297b-491b-f97c-244b411ea8d1"
+      },
+      "source": [
+        "#@title Notebook settings { run: \"auto\" }\n",
+        "\n",
+        "#@markdown -----\n",
+        "#@markdown ### Configuration\n",
+        "\n",
+        "backend_choice = \"GPU (vulkan-spirv)\" #@param [ \"GPU (vulkan-spirv)\", \"CPU (interpreter bytecode)\" ]\n",
+        "\n",
+        "if backend_choice == \"GPU (vulkan-spirv)\":\n",
+        "  backend_name = \"vulkan-spirv\"\n",
+        "  driver_name = \"vulkan\"\n",
+        "else:\n",
+        "  backend_name = \"interpreter-bytecode\"\n",
+        "  driver_name = \"interpreter\"\n",
+        "tf.print(\"Using IREE compiler backend '%s' and runtime driver '%s'\" % (backend_name, driver_name))\n",
+        "\n",
+        "#@markdown -----\n",
+        "#@markdown ### Training Parameters\n",
+        "\n",
+        "#@markdown <sup>Batch size used to subdivide the training and evaluation samples</sup>\n",
+        "batch_size = 200  #@param { type: \"slider\", min: 10, max: 400 }\n",
+        "\n",
+        "#@markdown <sup>Epochs for training/eval. Higher values take longer to run but generally produce more accurate models</sup>\n",
+        "num_epochs = 5    #@param { type: \"slider\", min:  1, max:  20 }\n",
+        "\n",
+        "#@markdown -----"
+      ],
+      "execution_count": 3,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Using IREE compiler backend 'vulkan-spirv' and runtime driver 'vulkan'\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "5vkQOMOMbXdy",
+        "colab_type": "text"
+      },
+      "source": [
+        "# Create and Train MNIST Model in TensorFlow\n",
+        "\n",
+        "The specific details of the training process here aren't critical to the model compilation and execution through IREE."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "GXZIrReTbTHN",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 486
+        },
+        "cellView": "form",
+        "outputId": "d727261b-7457-47ba-97c6-aca339a9d91f"
+      },
+      "source": [
+        "#@title Load MNIST dataset, setup training and evaluation\n",
+        "\n",
+        "NUM_CLASSES = 10  # One per digit [0, 1, 2, ..., 9]\n",
+        "IMG_ROWS, IMG_COLS = 28, 28\n",
+        "\n",
+        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
+        "tf.print(\"Loaded MNIST dataset!\")\n",
+        "\n",
+        "x_train = x_train.reshape(x_train.shape[0], IMG_ROWS, IMG_COLS, 1)\n",
+        "x_test = x_test.reshape(x_test.shape[0], IMG_ROWS, IMG_COLS, 1)\n",
+        "input_shape = (IMG_ROWS, IMG_COLS, 1)\n",
+        "\n",
+        "# Scale pixel values from [0, 255] integers to [0.0, 1.0] floats.\n",
+        "x_train = x_train.astype(\"float32\") / 255\n",
+        "x_test = x_test.astype(\"float32\") / 255\n",
+        "\n",
+        "steps_per_epoch = int(x_train.shape[0] / batch_size)\n",
+        "steps_per_eval = int(x_test.shape[0] / batch_size)\n",
+        "\n",
+        "# Convert class vectors to binary class matrices.\n",
+        "y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)\n",
+        "y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)\n",
+        "\n",
+        "# Construct batched datasets for training/evaluation.\n",
+        "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
+        "train_dataset = train_dataset.batch(batch_size, drop_remainder=True)\n",
+        "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
+        "test_dataset = test_dataset.batch(batch_size, drop_remainder=True)\n",
+        "\n",
+        "# Create a distribution strategy for the dataset (single machine).\n",
+        "strategy = tf.distribute.experimental.CentralStorageStrategy()\n",
+        "train_dist_ds = strategy.experimental_distribute_dataset(train_dataset)\n",
+        "test_dist_ds = strategy.experimental_distribute_dataset(test_dataset)\n",
+        "\n",
+        "tf.print(\"Configured data for training and evaluation!\")\n",
+        "tf.print(\"  sample shape: %s\" % str(x_train[0].shape))\n",
+        "tf.print(\"  training samples: %s\" % x_train.shape[0])\n",
+        "tf.print(\"  test     samples: %s\" % x_test.shape[0])\n",
+        "tf.print(\"  epochs: %s\" % num_epochs)\n",
+        "tf.print(\"  steps/epoch: %s\" % steps_per_epoch)\n",
+        "tf.print(\"  steps/eval : %s\" % steps_per_eval)\n",
+        "\n",
+        "tf.print(\"\")\n",
+        "tf.print(\"Sample image from the dataset:\")\n",
+        "SAMPLE_EXAMPLE_INDEX = 1\n",
+        "sample_image = x_test[SAMPLE_EXAMPLE_INDEX]\n",
+        "sample_image_batch = np.expand_dims(sample_image, axis=0)\n",
+        "sample_label = y_test[SAMPLE_EXAMPLE_INDEX]\n",
+        "plt.imshow(sample_image.reshape(IMG_ROWS, IMG_COLS))\n",
+        "plt.show()\n",
+        "tf.print(\"\\nGround truth labels: %s\" % str(sample_label))"
+      ],
+      "execution_count": 4,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Loaded MNIST dataset!\n",
+            "INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ('/device:CPU:0',), variable_device = '/device:CPU:0'\n",
+            "Configured data for training and evaluation!\n",
+            "  sample shape: (28, 28, 1)\n",
+            "  training samples: 60000\n",
+            "  test     samples: 10000\n",
+            "  epochs: 5\n",
+            "  steps/epoch: 300\n",
+            "  steps/eval : 50\n",
+            "\n",
+            "Sample image from the dataset:\n"
+          ],
+          "name": "stdout"
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 432x288 with 1 Axes>"
+            ],
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAANzUlEQVR4nO3df6zV9X3H8dcL5IdFVBiMMSRaLMRiF6G9oXV1m8a1s/xRbLK5ks5hY3O7rG5tQtIat6Q2/RGzVN2WNV1oJaWLP+L8UVlqOpHaOFuCXhwFhLZQhyvsChJuB24ZcK/v/XG/NFe93++5nPM9P+T9fCQ355zv+3y/33eOvvie8/2c7/k4IgTg7Dep2w0A6AzCDiRB2IEkCDuQBGEHkjinkzub6mkxXTM6uUsglf/T/+hknPB4tZbCbvs6SX8nabKkb0bEHVXPn64Zeq+vbWWXACpsjc2ltabfxtueLOlrkj4kaamk1baXNrs9AO3Vymf2FZL2RcSLEXFS0gOSVtXTFoC6tRL2BZJ+MebxgWLZ69jutz1ge+CUTrSwOwCtaPvZ+IhYFxF9EdE3RdPavTsAJVoJ+0FJC8c8vqhYBqAHtRL25yQttv1221MlfVTSxnraAlC3pofeImLY9i2S/lWjQ2/rI+KF2joDUKuWxtkj4nFJj9fUC4A24uuyQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dGfkkZz9n/pysr6yPTyyTnnXv5K5bpbrni4qZ5Ou/T7H6+sz3z23NLavL//UUv7xpnhyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3gOGvru4sr5r2T+0bd+nyofoJ+Qn13yzsn5v3/zS2oObfq9y3ZE9e5vqCePjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDO3gGNxtF/uOyBtu37H3+5qLJ+15YPVNYvubj6evgnlj5SWf/YzMHS2pdvmlO57qLPMc5ep5bCbnu/pOOSRiQNR0RfHU0BqF8dR/ZrIuJIDdsB0EZ8ZgeSaDXsIekJ29ts94/3BNv9tgdsD5zSiRZ3B6BZrb6NvyoiDtr+dUmbbP8kIp4e+4SIWCdpnSSd79ktXnYBoFktHdkj4mBxe1jSo5JW1NEUgPo1HXbbM2zPPH1f0gcl7aqrMQD1auVt/DxJj9o+vZ37IuJ7tXT1FjN87Xsq69+/4msNtjClsvq3Q0sq60/9ccWI538drlx3ydBAZX3S9OmV9a9s/a3K+m1zdpbWhmcNV66LejUd9oh4UdIVNfYCoI0YegOSIOxAEoQdSIKwA0kQdiAJLnGtwasLplbWJzX4N7XR0NoPPlw9vDXy4k8r663Y94XllfX7Zt/ZYAvTSisXfY9jTSfxagNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoyz1+DCb2+prP/hwJ9U1j10rLI+PLj/DDuqzydWPllZP29S+Tg6egtHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2DhjZ/bNut1Bq/5evrKzffOFXG2yh+qem1w6+r7Q288k9leuONNgzzgxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2s9wvb6weR//hn1aPo18wqXocfcuJyZX17V8q/935c489W7ku6tXwyG57ve3DtneNWTbb9ibbe4vbWe1tE0CrJvI2/luSrnvDslslbY6IxZI2F48B9LCGYY+IpyUdfcPiVZI2FPc3SLq+5r4A1KzZz+zzImKwuP+ypHllT7TdL6lfkqbrbU3uDkCrWj4bHxEhKSrq6yKiLyL6plRM8gegvZoN+yHb8yWpuD1cX0sA2qHZsG+UtKa4v0bSY/W0A6BdGn5mt32/pKslzbF9QNLnJd0h6UHbN0t6SdIN7WwSzTvy7tJPWJIaj6M3suYHn6isL/kOY+m9omHYI2J1SenamnsB0EZ8XRZIgrADSRB2IAnCDiRB2IEkuMT1LHBy08WltS2X3dlg7eqhtyu2rKmsv3Ptzyvr/Bx07+DIDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM7+FnDOoksq6198xz+X1mY1uIR124nqfV/8xeqR8pGhoeoNoGdwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnfwu49MGDlfXlU5v/N3v15j+rrC/58XNNbxu9hSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOHsPGFpzZWX9C/Ma/fb7tNLKmv2/X7nmOz+7r7LO776fPRoe2W2vt33Y9q4xy263fdD29uJvZXvbBNCqibyN/5ak68ZZfndELCv+Hq+3LQB1axj2iHha0tEO9AKgjVo5QXeL7R3F2/xZZU+y3W97wPbAKTX4wTMAbdNs2L8u6VJJyyQNSio9gxQR6yKiLyL6plScSALQXk2FPSIORcRIRLwm6RuSVtTbFoC6NRV22/PHPPyIpF1lzwXQGxqOs9u+X9LVkubYPiDp85Kutr1MUkjaL+mTbezxLe+cBb9ZWf+dv9xaWT9vUvMff7bsfkdlfckQ16tn0TDsEbF6nMX3tKEXAG3E12WBJAg7kARhB5Ig7EAShB1IgktcO2DPbQsr69/5jX9pafvX7Pyj0hqXsOI0juxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7B2w7cN3N3hGa7/gc8Gfv1ZaGx4aamnbOHtwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwucmndBaW3KyQUd7OTNRl45UlqLE9XTgXla9fcPJs+d01RPkjQy98LK+t61U5ve9kTEiEtrl/1Fg98gOHasqX1yZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJBhnPwt896H13W6h1G//+3iTAI86cuj8ynVnzT1eWd/6nvua6qnXLf3rWyrriz67pantNjyy215o+ynbu22/YPvTxfLZtjfZ3lvczmqqAwAdMZG38cOS1kbEUknvk/Qp20sl3Sppc0QslrS5eAygRzUMe0QMRsTzxf3jkvZIWiBplaQNxdM2SLq+XU0CaN0ZfWa3fYmk5ZK2SpoXEYNF6WVJ80rW6ZfUL0nT9bZm+wTQogmfjbd9nqSHJX0mIl73TfyICEkx3noRsS4i+iKib0qLP6wIoHkTCrvtKRoN+r0R8Uix+JDt+UV9vqTD7WkRQB0avo23bUn3SNoTEXeNKW2UtEbSHcXtY23p8CywavfHKuub3/VQhzrpvB8tv79r+/7fOFlaOxXlP789ESt33FRZ/+/tzV9+u+CZ4abXrTKRz+zvl3SjpJ22txfLbtNoyB+0fbOklyTd0JYOAdSiYdgj4hlJZVfaX1tvOwDaha/LAkkQdiAJwg4kQdiBJAg7kASXuHbAuX/wH5X1y79SfUljtPG/0szLjlbW23kZ6eX/9vHKevznjJa2v+ihV8uLz+5saduztLelejdwZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJDz6IzOdcb5nx3vN
hXJAu2yNzToWR8e9SpUjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTRMOy2F9p+yvZu2y/Y/nSx/HbbB21vL/5Wtr9dAM2ayPQDw5LWRsTztmdK2mZ7U1G7OyK+2r72ANRlIvOzD0oaLO4ft71H0oJ2NwagXmf0md32JZKWS9paLLrF9g7b623PKlmn3/aA7YFTOtFSswCaN+Gw2z5P0sOSPhMRxyR9XdKlkpZp9Mh/53jrRcS6iOiLiL4pmlZDywCaMaGw256i0aDfGxGPSFJEHIqIkYh4TdI3JK1oX5sAWjWRs/GWdI+kPRFx15jl88c87SOSdtXfHoC6TORs/Psl3Shpp+3txbLbJK22vUxSSNov6ZNt6RBALSZyNv4ZSeP9DvXj9bcDoF34Bh2QBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0Tndma/IumlMYvmSDrSsQbOTK/21qt9SfTWrDp7uzgi5o5X6GjY37RzeyAi+rrWQIVe7a1X+5LorVmd6o238UAShB1IotthX9fl/Vfp1d56tS+J3prVkd66+pkdQOd0+8gOoEMIO5BEV8Ju+zrbP7W9z/at3eihjO39tncW01APdLmX9bYP2941Ztls25ts7y1ux51jr0u99cQ03hXTjHf1tev29Ocd/8xue7Kkn0n6gKQDkp6TtDoidne0kRK290vqi4iufwHD9u9KelXStyPiXcWyv5F0NCLuKP6hnBURn+uR3m6X9Gq3p/EuZiuaP3aacUnXS7pJXXztKvq6QR143bpxZF8haV9EvBgRJyU9IGlVF/roeRHxtKSjb1i8StKG4v4Gjf7P0nElvfWEiBiMiOeL+8clnZ5mvKuvXUVfHdGNsC+Q9Isxjw+ot+Z7D0lP2N5mu7/bzYxjXkQMFvdfljSvm82Mo+E03p30hmnGe+a1a2b681Zxgu7NroqId0v6kKRPFW9Xe1KMfgbrpbHTCU3j3SnjTDP+K9187Zqd/rxV3Qj7QUkLxzy+qFjWEyLiYHF7WNKj6r2pqA+dnkG3uD3c5X5+pZem8R5vmnH1wGvXzenPuxH25yQttv1221MlfVTSxi708Sa2ZxQnTmR7hqQPqvemot4oaU1xf42kx7rYy+v0yjTeZdOMq8uvXdenP4+Ijv9JWqnRM/I/l/RX3eihpK9Fkn5c/L3Q7d4k3a/Rt3WnNHpu42ZJvyZps6S9kp6UNLuHevsnSTsl7dBosOZ3qberNPoWfYek7cXfym6/dhV9deR14+uyQBKcoAOSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJP4fcKgKSEIBgPIAAAAASUVORK5CYII=\n"
+          },
+          "metadata": {
+            "tags": [],
+            "needs_background": "light"
+          }
+        },
+        {
+          "output_type": "stream",
+          "text": [
+            "\n",
+            "Ground truth labels: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "tHq96SIJcNfx",
+        "colab_type": "code",
+        "colab": {},
+        "cellView": "both"
+      },
+      "source": [
+        "#@title Define MNIST model architecture using tf.keras API\n",
+        "\n",
+        "def simple_mnist_model(input_shape):\n",
+        "  \"\"\"Creates a simple (multi-layer perceptron) MNIST model.\"\"\"\n",
+        "\n",
+        "  model = tf.keras.models.Sequential()\n",
+        "  # Flatten to a 1d array (e.g. 28x28 -> 784)\n",
+        "  model.add(tf.keras.layers.Flatten(input_shape=input_shape))\n",
+        "  # Fully-connected neural layer with 128 neurons, RELU activation\n",
+        "  model.add(tf.keras.layers.Dense(128, activation='relu'))\n",
+        "  # Fully-connected neural layer returning probability scores for each class\n",
+        "  model.add(tf.keras.layers.Dense(10, activation='softmax'))\n",
+        "  return model"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
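+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "aXsummRy0001",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "#@title Preview the model architecture (optional)\n",
+        "\n",
+        "# A minimal sketch for orientation, not part of the training flow:\n",
+        "# instantiate a throwaway copy of the model and print its layer summary\n",
+        "# (28x28 inputs flattened to 784, then dense 128, then dense 10). The\n",
+        "# model actually used for training is created inside the distribution\n",
+        "# strategy scope below.\n",
+        "simple_mnist_model(input_shape).summary()"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },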
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "7Gdxh7qWcPSO",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 374
+        },
+        "cellView": "form",
+        "outputId": "e38d9b12-c27c-4174-824b-8bc232a2c417"
+      },
+      "source": [
+        "#@title Train the Keras model\n",
+        "\n",
+        "with strategy.scope():\n",
+        "  model = simple_mnist_model(input_shape)\n",
+        "  tf.print(\"Constructed Keras MNIST model, training...\")\n",
+        "\n",
+        "  optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)\n",
+        "  training_loss = tf.keras.metrics.Mean(\"training_loss\", dtype=tf.float32)\n",
+        "  training_accuracy = tf.keras.metrics.CategoricalAccuracy(\n",
+        "      \"training_accuracy\", dtype=tf.float32)\n",
+        "  test_loss = tf.keras.metrics.Mean(\"test_loss\", dtype=tf.float32)\n",
+        "  test_accuracy = tf.keras.metrics.CategoricalAccuracy(\n",
+        "      \"test_accuracy\", dtype=tf.float32)\n",
+        "\n",
+        "  @tf.function\n",
+        "  def train_step(iterator):\n",
+        "    \"\"\"Training StepFn.\"\"\"\n",
+        "\n",
+        "    def step_fn(inputs):\n",
+        "      \"\"\"Per-Replica StepFn.\"\"\"\n",
+        "      images, labels = inputs\n",
+        "      with tf.GradientTape() as tape:\n",
+        "        logits = model(images, training=True)\n",
+        "        loss = tf.keras.losses.categorical_crossentropy(labels, logits)\n",
+        "        loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync\n",
+        "      grads = tape.gradient(loss, model.trainable_variables)\n",
+        "      optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
+        "      training_loss.update_state(loss)\n",
+        "      training_accuracy.update_state(labels, logits)\n",
+        "\n",
+        "    strategy.experimental_run_v2(step_fn, args=(next(iterator),))\n",
+        "\n",
+        "  @tf.function\n",
+        "  def test_step(iterator):\n",
+        "    \"\"\"Evaluation StepFn.\"\"\"\n",
+        "\n",
+        "    def step_fn(inputs):\n",
+        "      images, labels = inputs\n",
+        "      logits = model(images, training=False)\n",
+        "      loss = tf.keras.losses.categorical_crossentropy(labels, logits)\n",
+        "      loss = tf.reduce_mean(loss) / strategy.num_replicas_in_sync\n",
+        "      test_loss.update_state(loss)\n",
+        "      test_accuracy.update_state(labels, logits)\n",
+        "\n",
+        "    strategy.experimental_run_v2(step_fn, args=(next(iterator),))\n",
+        "\n",
+        "  for epoch in range(0, num_epochs):\n",
+        "    tf.print(\"Running epoch #%s\" % (epoch + 1))\n",
+        "\n",
+        "    train_iterator = iter(train_dist_ds)\n",
+        "    for step in range(steps_per_epoch):\n",
+        "      train_step(train_iterator)\n",
+        "    tf.print(\"  Training loss: %f, accuracy: %f\" % (training_loss.result(), training_accuracy.result() * 100))\n",
+        "    training_loss.reset_states()\n",
+        "    training_accuracy.reset_states()\n",
+        "\n",
+        "    test_iterator = iter(test_dist_ds)\n",
+        "    for step in range(steps_per_eval):\n",
+        "      test_step(test_iterator)\n",
+        "    tf.print(\"  Test loss    : %f, accuracy: %f\" % (test_loss.result(), test_accuracy.result() * 100))\n",
+        "    test_loss.reset_states()\n",
+        "    test_accuracy.reset_states()\n",
+        "\n",
+        "  tf.print(\"Completed training!\")\n",
+        "  tf.print(\"\")\n",
+        "\n",
+        "  # Run a single prediction on the trained model\n",
+        "  tf_prediction = model(sample_image_batch, training=False)\n",
+        "  tf.print(\"Sample prediction:\")\n",
+        "  tf.print(tf_prediction[0] * 100.0, summarize=100)\n",
+        "  tf.print(\"\")"
+      ],
+      "execution_count": 6,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Constructed Keras MNIST model, training...\n",
+            "Running epoch #1\n",
+            "  Training loss: 0.714714, accuracy: 82.386665\n",
+            "  Test loss    : 0.386100, accuracy: 89.630005\n",
+            "Running epoch #2\n",
+            "  Training loss: 0.361406, accuracy: 90.100006\n",
+            "  Test loss    : 0.312490, accuracy: 91.000000\n",
+            "Running epoch #3\n",
+            "  Training loss: 0.308521, accuracy: 91.355003\n",
+            "  Test loss    : 0.278433, accuracy: 92.000000\n",
+            "Running epoch #4\n",
+            "  Training loss: 0.277716, accuracy: 92.261665\n",
+            "  Test loss    : 0.255600, accuracy: 92.650002\n",
+            "Running epoch #5\n",
+            "  Training loss: 0.255256, accuracy: 92.849998\n",
+            "  Test loss    : 0.237870, accuracy: 93.339996\n",
+            "Completed training!\n",
+            "\n",
+            "Sample prediction:\n",
+            "[0.942700624 0.0189254563 94.8591919 1.60412228 4.52805943e-06 1.21846104 0.845801711 3.33044341e-06 0.510759115 2.41755624e-05]\n",
+            "\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
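+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "aXargMax0002",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "#@title Interpret the sample prediction (optional)\n",
+        "\n",
+        "# An optional sanity check, sketched here as a convenience: the raw\n",
+        "# prediction above is a vector of per-class probabilities, so taking\n",
+        "# the argmax recovers the predicted digit, which should match the\n",
+        "# one-hot ground truth label once the model is trained.\n",
+        "predicted_digit = np.argmax(tf_prediction[0])\n",
+        "actual_digit = np.argmax(sample_label)\n",
+        "tf.print(\"Predicted digit: %d, ground truth digit: %d\" % (predicted_digit, actual_digit))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },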
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "DmespEaFcSEL",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 122
+        },
+        "outputId": "07ba7fb8-bd99-495d-941c-17320014573b"
+      },
+      "source": [
+        "#@title Export the trained model as a SavedModel, with IREE-compatible settings\n",
+        "\n",
+        "# Since the model was written in sequential style, explicitly wrap in a module.\n",
+        "saved_model_dir = \"/tmp/mnist.sm\"\n",
+        "inference_module = tf.Module()\n",
+        "inference_module.model = model\n",
+        "# Hack: Convert to static shape. Won't be necessary once dynamic shapes are in.\n",
+        "dynamic_input_shape = list(model.inputs[0].shape)\n",
+        "dynamic_input_shape[0] = 1  # Make fixed (batch=1)\n",
+        "# Produce a concrete function.\n",
+        "inference_module.predict = tf.function(\n",
+        "    input_signature=[\n",
+        "        tf.TensorSpec(dynamic_input_shape, model.inputs[0].dtype)])(\n",
+        "            lambda x: model.call(x, training=False))\n",
+        "save_options = tf.saved_model.SaveOptions(save_debug_info=True)\n",
+        "tf.print(\"Exporting SavedModel to %s\" % saved_model_dir)\n",
+        "tf.saved_model.save(inference_module, saved_model_dir, options=save_options)"
+      ],
+      "execution_count": 7,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Exporting SavedModel to /tmp/mnist.sm\n",
+            "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/ops/resource_variable_ops.py:1788: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
+            "Instructions for updating:\n",
+            "If using Keras pass *_constraint arguments to layers.\n",
+            "INFO:tensorflow:Assets written to: /tmp/mnist.sm/assets\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
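+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "aXreload0003",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "#@title Verify the exported SavedModel (optional)\n",
+        "\n",
+        "# An optional sanity-check sketch, assuming the export cell above ran:\n",
+        "# reload the SavedModel with the standard tf.saved_model.load API and\n",
+        "# confirm that its 'predict' function matches the in-memory Keras model\n",
+        "# on the sample image before handing the artifact to IREE.\n",
+        "reloaded_module = tf.saved_model.load(saved_model_dir)\n",
+        "reloaded_prediction = reloaded_module.predict(sample_image_batch)\n",
+        "max_abs_diff = np.max(np.abs(reloaded_prediction - tf_prediction))\n",
+        "tf.print(\"Max abs difference vs in-memory model: %f\" % max_abs_diff)"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },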
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "nZdVUd_dgTtc",
+        "colab_type": "text"
+      },
+      "source": [
+        "# Compile and Execute MNIST Model using IREE"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "rqwIx4j4gS1a",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 717
+        },
+        "outputId": "2a9b7627-8836-427d-af51-5e001c6a9226"
+      },
+      "source": [
+        "#@title Load the SavedModel into IREE's compiler as MLIR xla_hlo\n",
+        "\n",
+        "compiler_module = ireec.tf_load_saved_model(\n",
+        "    saved_model_dir, exported_names=[\"predict\"])\n",
+        "tf.print(\"Imported MLIR:\\n\", compiler_module.to_asm(large_element_limit=100))\n",
+        "\n",
+        "# Write to a file for use outside of this notebook.\n",
+        "mnist_mlir_path = os.path.join(SAVE_PATH, \"mnist.mlir\")\n",
+        "with open(mnist_mlir_path, \"wt\") as output_file:\n",
+        "  output_file.write(compiler_module.to_asm())\n",
+        "print(\"Wrote MLIR to path '%s'\" % mnist_mlir_path)"
+      ],
+      "execution_count": 8,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Imported MLIR:\n",
+            " \n",
+            "\n",
+            "module {\n",
+            "  flow.variable @\"__iree_flow___sm_node15__model.layer-2.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<784x128xf32>\n",
+            "  flow.variable @\"__iree_flow___sm_node16__model.layer-2.bias\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128xf32>\n",
+            "  flow.variable @\"__iree_flow___sm_node21__model.layer-3.kernel\" opaque<\"\", \"0xDEADBEEF\"> : tensor<128x10xf32>\n",
+            "  flow.variable @\"__iree_flow___sm_node22__model.layer-3.bias\" dense<[-0.124819458, 0.10169369, -0.0273698159, -0.0204691291, 0.0623821244, 0.213322133, -0.036071334, 0.112993836, -0.235609874, -0.0460525416]> : tensor<10xf32>\n",
+            "  func @predict(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x10xf32> attributes {iree.module.export, iree.reflection = {abi = \"sip\", abiv = 1 : i32, sip = \"I8!S5!k0_0R3!_0\"}, tf._input_shapes = [\"tfshape$dim { size: 1 } dim { size: 28 } dim { size: 28 } dim { size: 1 }\", \"tfshape$unknown_rank: true\", \"tfshape$unknown_rank: true\", \"tfshape$unknown_rank: true\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful} {\n",
+            "    %0 = flow.variable.load @\"__iree_flow___sm_node22__model.layer-3.bias\" : tensor<10xf32>\n",
+            "    %1 = flow.variable.load @\"__iree_flow___sm_node21__model.layer-3.kernel\" : tensor<128x10xf32>\n",
+            "    %2 = flow.variable.load @\"__iree_flow___sm_node16__model.layer-2.bias\" : tensor<128xf32>\n",
+            "    %3 = flow.variable.load @\"__iree_flow___sm_node15__model.layer-2.kernel\" : tensor<784x128xf32>\n",
+            "    %4 = xla_hlo.constant opaque<\"\", \"0xDEADBEEF\"> : tensor<1x128xf32>\n",
+            "    %5 = xla_hlo.constant dense<0xFF800000> : tensor<f32>\n",
+            "    %6 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>\n",
+            "    %7 = \"xla_hlo.reshape\"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>\n",
+            "    %8 = \"xla_hlo.dot\"(%7, %3) : (tensor<1x784xf32>, tensor<784x128xf32>) -> tensor<1x128xf32>\n",
+            "    %9 = \"xla_hlo.add\"(%8, %2) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>\n",
+            "    %10 = xla_hlo.max %9, %4 : tensor<1x128xf32>\n",
+            "    %11 = \"xla_hlo.dot\"(%10, %1) : (tensor<1x128xf32>, tensor<128x10xf32>) -> tensor<1x10xf32>\n",
+            "    %12 = \"xla_hlo.add\"(%11, %0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<10xf32>) -> tensor<1x10xf32>\n",
+            "    %13 = \"xla_hlo.reduce\"(%12, %5) ( {\n",
+            "    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):\t// no predecessors\n",
+            "      %18 = xla_hlo.max %arg1, %arg2 : tensor<f32>\n",
+            "      \"xla_hlo.return\"(%18) : (tensor<f32>) -> ()\n",
+            "    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>\n",
+            "    %14 = \"xla_hlo.sub\"(%12, %13) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1xf32>) -> tensor<1x10xf32>\n",
+            "    %15 = \"xla_hlo.exp\"(%14) : (tensor<1x10xf32>) -> tensor<1x10xf32>\n",
+            "    %16 = \"xla_hlo.reduce\"(%15, %6) ( {\n",
+            "    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):\t// no predecessors\n",
+            "      %18 = xla_hlo.add %arg1, %arg2 : tensor<f32>\n",
+            "      \"xla_hlo.return\"(%18) : (tensor<f32>) -> ()\n",
+            "    }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<f32>) -> tensor<1xf32>\n",
+            "    %17 = \"xla_hlo.div\"(%15, %16) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1xf32>) -> tensor<1x10xf32>\n",
+            "    return %17 : tensor<1x10xf32>\n",
+            "  }\n",
+            "}\n",
+            "\n",
+            "Wrote MLIR to path '/usr/local/google/home/scotttodd/saved_models/mnist.mlir'\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "IDHI7h3khJr9",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 51
+        },
+        "outputId": "f0676fc3-d7e7-497e-f28d-480e33c37fc9"
+      },
+      "source": [
+        "#@title Compile the xla_hlo MLIR and prepare a context to execute it\n",
+        "\n",
+        "# Compile the MLIR module into a VM module for execution\n",
+        "flatbuffer_blob = compiler_module.compile(target_backends=[backend_name])\n",
+        "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n",
+        "\n",
+        "# Register the module with a runtime context\n",
+        "config = ireert.Config(driver_name)\n",
+        "ctx = ireert.SystemContext(config=config)\n",
+        "ctx.add_module(vm_module)"
+      ],
+      "execution_count": 9,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Created IREE driver vulkan: <pyiree.rt.binding.HalDriver object at 0x7f6e4471bc70>\n",
+            "SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x7f6e4471bc70>\n"
+          ],
+          "name": "stderr"
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "SKflpnLtkLYE",
+        "colab_type": "code",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 102
+        },
+        "outputId": "48996e9a-f70f-4f45-b53c-fd9fab80d78d"
+      },
+      "source": [
+        "#@title Execute the compiled module and compare the results with TensorFlow\n",
+        "\n",
+        "# Invoke the 'predict' function with a single image as an argument\n",
+        "iree_prediction = ctx.modules.module.predict(sample_image_batch)\n",
+        "\n",
+        "tf.print(\"IREE prediction ('%s' backend, '%s' driver):\" % (backend_name, driver_name))\n",
+        "tf.print(tf.convert_to_tensor(iree_prediction[0]) * 100.0, summarize=100)\n",
+        "tf.print(\"\")\n",
+        "tf.print(\"TensorFlow prediction:\")\n",
+        "tf.print(tf_prediction[0] * 100.0, summarize=100)"
+      ],
+      "execution_count": 10,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "IREE prediction ('vulkan-spirv' backend, 'vulkan' driver):\n",
+            "[0.942700088 0.018925447 94.8592 1.60412395 4.52806762e-06 1.2184602 0.845801711 3.33043749e-06 0.510759652 2.41755624e-05]\n",
+            "\n",
+            "TensorFlow prediction:\n",
+            "[0.942700624 0.0189254563 94.8591919 1.60412228 4.52805943e-06 1.21846104 0.845801711 3.33044341e-06 0.510759115 2.41755624e-05]\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
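+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "aXcompare004",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "source": [
+        "#@title Numerically compare the IREE and TensorFlow predictions\n",
+        "\n",
+        "# A minimal follow-up sketch: beyond eyeballing the printed vectors,\n",
+        "# check that the two backends agree to within a loose floating point\n",
+        "# tolerance (small differences are expected across compilers and\n",
+        "# devices).\n",
+        "iree_results = np.asarray(iree_prediction[0])\n",
+        "tf_results = np.asarray(tf_prediction[0])\n",
+        "tf.print(\"Max abs difference: %e\" % np.max(np.abs(iree_results - tf_results)))\n",
+        "tf.print(\"Predictions allclose: %s\" % np.allclose(iree_results, tf_results, rtol=1e-4, atol=1e-5))"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    }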
+  ]
+}
\ No newline at end of file