blob: 4a1f1aad923d36b4e0026f655114d9e7e92ad904 [file] [log] [blame]
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "mnist_tensorflow.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "PZtRtMMUZHJS"
},
"source": [
"##### Copyright 2020 Google LLC.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "TouZL3JZZSQe"
},
"source": [
"#@title License header\n",
"# Copyright 2020 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "O6c3qfq5Zv57"
},
"source": [
"# MNIST Model TensorFlow Training, IREE Execution\n",
"\n",
"## Overview\n",
"\n",
"This notebook creates and trains a TensorFlow 2.0 model for recognizing handwritten digits using the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database), then compiles and executes that trained model using IREE.\n",
"\n",
"## Running Locally\n",
"\n",
"* Refer to [using_colab.md](https://google.github.io/iree/using-iree/using-colab) for general information\n",
"* Ensure that you have a recent version of TensorFlow 2.0 [installed on your system](https://www.tensorflow.org/install)\n",
"* Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n",
"* Start colab by running `python colab/start_colab_kernel.py` (see that file for additional instructions)\n",
"* Note: you may need to restart your runtime in order to re-run certain cells. Some of the APIs are not yet stable enough for repeated invocations."
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"id": "EPF7RGQDYK-M",
"outputId": "524db42f-b1d0-4f26-c911-3dc32bb19ba0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"#@title Imports and Setup\n",
"\n",
"from pyiree import rt as ireert\n",
"from pyiree.tf import compiler as ireec\n",
"from pyiree.tf.support import module_utils\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import os\n",
"import tempfile\n",
"import tensorflow as tf\n",
"\n",
"ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), 'iree', 'modules')\n",
"print(\"Artifacts directory is: \", ARTIFACTS_DIR)\n",
"\n",
"# Matplotlib 3.6 renamed its bundled seaborn styles to \"seaborn-v0_8-*\" and\n",
"# 3.8 removed the old names, so use whichever name this install provides.\n",
"_PLOT_STYLE = \"seaborn-v0_8-whitegrid\"\n",
"if _PLOT_STYLE not in plt.style.available:\n",
"  _PLOT_STYLE = \"seaborn-whitegrid\"\n",
"plt.style.use(_PLOT_STYLE)\n",
"plt.rcParams[\"font.family\"] = \"monospace\"\n",
"\n",
"# Print version information for future notebook users to reference.\n",
"print(\"TensorFlow version: \", tf.__version__)\n",
"print(\"Numpy version: \", np.__version__)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Artifacts directory is: C:\\Users\\Scott\\AppData\\Local\\Temp\\iree\\modules\n",
"TensorFlow version: 2.5.0-dev20200626\n",
"Numpy version: 1.18.4\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5vkQOMOMbXdy"
},
"source": [
"# Create and Train MNIST Model in TensorFlow\n",
"\n",
"The specific details of the training process here aren't critical to the model compilation and execution through IREE."
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"id": "XPo8ATGqqZbW",
"outputId": "acfd2492-a7c5-484a-e9af-7d1a58579dc2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 314
}
},
"source": [
"#@title Load MNIST dataset, setup training and evaluation\n",
"\n",
"# Keras datasets don't provide metadata.\n",
"NUM_CLASSES = 10\n",
"NUM_ROWS, NUM_COLS = 28, 28\n",
"\n",
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
"\n",
"# Reshape into grayscale images:\n",
"x_train = np.reshape(x_train, (-1, NUM_ROWS, NUM_COLS, 1))\n",
"x_test = np.reshape(x_test, (-1, NUM_ROWS, NUM_COLS, 1))\n",
"\n",
"# Rescale uint8 pixel values into floats in [0, 1]. Cast to float32 *before*\n",
"# dividing: IREE uses single precision, and dividing the uint8 arrays directly\n",
"# would first materialize float64 copies of the entire dataset.\n",
"x_train = x_train.astype(np.float32) / 255\n",
"x_test = x_test.astype(np.float32) / 255\n",
"assert x_train.dtype == np.float32 and x_test.dtype == np.float32\n",
"\n",
"print(\"Sample image from the dataset:\")\n",
"sample_index = np.random.randint(x_train.shape[0])\n",
"plt.imshow(x_train[sample_index].reshape(NUM_ROWS, NUM_COLS), cmap=\"gray\")\n",
"plt.title(f\"Sample #{sample_index}, label: {y_train[sample_index]}\")\n",
"plt.axis(\"off\")\n",
"plt.tight_layout()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"Sample image from the dataset:\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAAEYCAYAAACgIGhkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAR4ElEQVR4nO3deWxUZRuG8bvtFNFoAVERQou4sS9DS6FlKxUUyipQFcWguIC2BhFDMIAiQqJRolU0KlEMisjmgoLIIou4UJAi+yKihgqV1VKB0mW+PwgTG3n6nsoUSr/rl5Ag5zlzzkzx6jnlzUxYIBAICADOIvxCnwCAiotAADARCAAmAgHARCAAmAgEABOB+I9ycnLk9/vVrFkzDRw48EKfzkVn4cKF8vv9atKkiV5++eUy7Tt69Ogy73Oux/x/VWkCsWvXLg0ePFhxcXFKSEjQY489Vq7Hq1WrlrKysvTss8+Wy+OvXr1ajz/+uCSpY8eOOn78uCTpySefVLt27RQbG6vU1FRlZWUF9/nll1/0wAMPKC4uTsnJySUe7+mnn5bf7w/+atq0qXr16hXc/sYbb6hHjx5q2LChPv7447Oe07p169SgQQPNmTPnnJ9fSkqKsrKySpxDeSvPY3733Xfq3bu3WrZsqS5dumj79u0hP8aFUGkCkZaWpqSkJK1Zs0ZfffWV2rdvf6FP6Zxs2bJFTZo0UU5OjqKionTZZZdJkh588EEtW7ZMP/74o4YPH65HH31URUVFkiSfz6cePXpo1KhR/3q8CRMmKCsrK/jrtttu06233hrcHh0drbFjx6pJkyZnPZ/CwkK99NJLuv7668vh2V7c9u7dq/T0dA0bNkzr1q3Thx9+qKuvvvpCn1ZIVIpAHD58WL/99psGDBigiIgIRUVF6c477wxuX7lypfr06aNWrVqpS5cuevvtt3VmAWlycrKGDh2qdu3a6a233lK7du00btw4Sae/8A0aNNAbb7yh2NhY9e7dW1u3bvV0TkVFRZoyZYqSk5OVmJioiRMnqqCgwPNz2rx5s5o2bapNmzapWbNmwT9v2LChqlatqkAgoMLCQh0+fFhHjhyRJMXExKhfv36qW7duqY997NgxLVu2TH369An+Wa9evZSQkKAqVaqcdZ8PPvhASUlJuuqqqzw/h3MxdepU3XLLLfL7/erdu7dWrlxZYvv+/fvVu3dvtWnTRpMmTQpGUpIWL16snj17qnXr1nrooYf0559/ej5uTk6OunXrpiVLlnje55NPPlGHDh2UkpIin8+na665RjVr1vS8f0VWKQJRvXp11alTR+PGjVNmZqZOnTpVYnsgENC4ceO0Zs0aTZs2TdOmTdPXX38d3D5o0CD169dPS5cu1aJFi/T5558rPz8/uP3o0aP6/vvvlZqaqpEjR8rL6vRp06Zp2bJlmjlzphYvXqyff/5Z77//vnO/yZMnKy4uTkuXLlV6erpGjBihhQsXqmPHjsGZ8ePHq1mzZho6dKi6detW5v9pFyxYoEaNGikmJsbT/IEDBzRv3jzdf//9ZTrOuahWrZqmTp2q9evXa8SIEUpLS9Phw4eD21esWKGMjAwtXLhQ33zzjRYsWCBJ2rhxo8aMGaNJkybp+++/V+PGjfX00097Pm5BQYH27NmjY8eOed5n+/btqlatmlJTU5WYmKiRI0eWaf+KrFIEIjw8XO+9954iIiKUlpamxMRETZkyJbg9KSlJcXFxioyMVHR0tNq0aaNt27YFt9erV08xMTG67rrrdMUVVygqKkpHjx4Nbh88eLCqVKmigQMHKjs7W7/99pvznObMmaO0tDTVqlVLl19+ue655x4tXrzYud/IkSP1+uuvKzk5WZmZmapXr56WLl2qVatWBWfGjx+v9evXa/LkyerevbvXlyno448/Vt++fT3Pv/DCCxo2bJguueSSMh/rv7rjjjt0/fXXKywsTJ07d1ZUVJR279
4d3J6cnKz69eurZs2aJa4w5s6dq9tvv10tWrSQz+fTkCFDtGLFin9907DUrVtXO3bsUL9+/Tyfa15enhYtWqRnn31WS5cuVV5enjIyMsr2hCso34U+gVCpV6+eJk+erOLiYq1Zs0bDhw9Xs2bN1KlTJ23cuFEvvviidu3apcLCQp08eVLXXXddcN/w8HBFREQoIiJC0ul7+cLCwuD2M5eLPp9PUVFROnjwYIn9z2b//v0aNWqUwsNPN7i4uNh5X7phwwY9+OCDOnnypHw+n+Li4nTq1Cl1795dr732mhISEoKzVapUUc+ePZWSkqKbbrpJN9xwg6fXaffu3dq+fbvnsPz444/au3evevTo4Wk+VD799FO9++672r9/v4qLi5WXl1fia3LllVcGf1+zZk1lZmZKkvbt26fMzMwSP2iNjIzUn3/+6bz1+q8uvfRSJSYmqnHjxpJOx41AVFDh4eFKSEhQfHy8fv75Z3Xq1ElPPPGEBg0apHfffVeRkZFKT0933ib8c/vBgwdVt25dFRYWKjc3t8T9ZWRkpIqLi/+1/7XXXqvnn39efr/f87m3bNlS69atU8+ePfXOO+/oiy++UFFRkR5++GFzn6KiIu3atctzIObNmxf8juzFpk2blJWVpQYNGgT/LDMzUzt37tSYMWM8PUZpzvb6ZWdna+zYsZo+fbr8fr/CwsIUHx9f4mty6NChEr8/8zWpXbu2Hn30UQ0dOrRMxzwX0dHROnjwYPC/A4GAp9vQi0GluMUoKipSRkaG9u/fL+n0PeG6deuCRf/7779Vo0YN+Xw+rVmzRqtXry7T40+fPl0FBQWaOXOmateurXr16gW31a9fX7t37y5xfyxJ/fv316uvvqqcnBwFAgHt2bPH03FPnDihY8eOqVatWtq0aZOaNm0a3HbgwAHNmTNHx44dU2FhoT766CNlZ2cH/+UhEAgoPz9fBQUFwd//89K6qKhI8+fPP+vtRUFBgfLz84M//MzPz1dxcbHuu+8+7dixI/grPj5eEydO/Fcc7r33Xt17773eXtB/qF+/vjZs2FDi6uDEiRMKCwtTzZo1VVRUpHfeeUe5ubkl9lu+fLn27NmjQ4cOaf78+ercubMkqV+/fvroo4+0ZcsWBQIBHTp0SAsXLnQe84ycnBx17drV0+3gGV27dtXKlSu1Y8cO5efna968eWrTpk1ZXoYKq1IEIjw8XL///rtSU1Pl9/uVnp6uRx55JHhJ/swzz+iVV15Rq1atNGPGjBI/8POievXqatu2rWbPnq3JkycHbxskqWnTpurbt6+6du0qv98f/BeFIUOGKDY2VnfffbdatWqlxx57rMR3Pcu2bdvUqFEjSaf/JeNM5CQpIiJCn3/+ubp06aLWrVtr5syZysjIUHR0tKTT33mbN2+uhx9+WH/88YeaN2+uBx54ILj/qlWrVFRUpA4dOvzruOPGjVPz5s2VlZUV/P3atWs9v0YnTpxQjRo1PM+fkZqaqsjISLVu3VoDBgyQJN14440aMmSIUlNT1b59e/3999+qU6dOif2SkpI0fPhwpaSkqGPHjkpJSZF0+ips9OjReuqppxQbG6v+/ftr8+bNzmOeUVBQoN9//115eXmen0N8fLzS0tI0ZMgQdejQQVWrVtXw4cPL/FpURGG8YYxt7969uuWWW7Rlyxb5fJXubixkTp06pdjYWE2dOlVt27a90KeDEKoUVxC4sLZt26abbrqJOFRCBALnrEWLFubybFzcuMUAYOIKAoCp1J+8hYWFna/zAHABWTcSXEEAMBEIACYCAcBEIACYCAQAE4EAYCIQAEwEAoCJQAAwEQgAJgIBwEQgAJgIBAATgQBgIhAATAQCgIlAADDxXu6ViJePe6tevXqp2wcPHhyq00ElwBUEABOBAGAiEABMBAKAiUAAMBEIACYCAcBEIACYSv3wXj56r+Jo0KCBc2bTpk3OGddnNbdo0cL5GNu3b3fO4OLCR+8BKDMCAcBEIACYCA
QAE4EAYCIQAEwEAoCJQAAw8Y5SFUB0dLRzZsmSJc4Zn8/95ZwxY0ap21kEhX/iCgKAiUAAMBEIACYCAcBEIACYCAQAE4EAYCIQAEy8o1QFMGvWLOdMamqqc2bz5s3Omfbt25e6PTc31/kYqHx4RykAZUYgAJgIBAATgQBgIhAATAQCgIlAADARCAAmFkqVsyZNmjhnNmzY4JyJiIhwzrRt29Y5k5mZ6ZzB/x8WSgEoMwIBwEQgAJgIBAATgQBgIhAATAQCgIlAADDx0XvlbNiwYc4ZL4ugXB+ZJ0lr1671dE6AV1xBADARCAAmAgHARCAAmAgEABOBAGAiEABMBAKAiXeUOgc33nijc2bHjh3OmUOHDjlnkpKSnDNbt251zgBnwztKASgzAgHARCAAmAgEABOBAGAiEABMBAKAiUAAMPGOUufgzjvvdM54WWy2ceNG50xlXAR12WWXOWcmTpzonPGyYO3XX391zrzyyivOmV9++cU5U5lwBQHARCAAmAgEABOBAGAiEABMBAKAiUAAMBEIACbeUcrg5ePwMjMznTN+v98507dvX+fM/PnznTMVSZUqVZwzXp7TrbfeGorT8eTkyZPOmfj4eOfM5s2bQ3E65xXvKAWgzAgEABOBAGAiEABMBAKAiUAAMBEIACYCAcDEQilD7dq1nTPZ2dnOmZ07dzpnWrRo4ZzJz893zlQkgwYNcs5Mnz7dOZOVleWcmTBhgnPm5ptvds6MHz/eOeNlcdfAgQOdMxUNC6UAlBmBAGAiEABMBAKAiUAAMBEIACYCAcBEIACY+Oi9cnb8+HHnzMW2CMqLAQMGhORxnnvuOefMZ599FpJjefkowP79+4fkWBcLriAAmAgEABOBAGAiEABMBAKAiUAAMBEIACYCAcDEQqlyduDAgQt9CiHnZbFQSkqKc2bXrl3OmS+//NLTOYVCbm7ueTvWxYIrCAAmAgHARCAAmAgEABOBAGAiEABMBAKAiXUQ5ezbb7+90KcQcp06dXLO+Hzuv1pvvvmmc+Z8vplOq1atztuxLhZcQQAwEQgAJgIBwEQgAJgIBAATgQBgIhAATAQCgImFUuWsT58+zpkJEyachzPxJiwszDmTmJgYkmP9+uuvIXkcL2rUqOGc6datm3Pm9ddfD8XpXDS4ggBgIhAATAQCgIlAADARCAAmAgHARCAAmAgEABMLpcqZl4VHFYmXBUUxMTEhOdbatWtD8jheXuMpU6Y4Z7w897lz53o6p8qCKwgAJgIBwEQgAJgIBAATgQBgIhAATAQCgIlAADCxUKqcXX311c6Z2rVrO2f27dsXitNxOnz4sHNm69atzpmOHTs6Z7p37+6cmTFjhnMmPT3dOTNw4EDnzDPPPOOc2blzp3OmWrVqzpm//vrLOVMRcAUBwEQgAJgIBAATgQBgIhAATAQCgIlAADARCACmsEAgEDA3XmTvhhRKV1xxhXPmhx9+cM40atTIObN8+XLnzG233eacKSwsdM6EwlNPPeWcmTRpknMmLy/POXPgwAHnTP369Z0zXqxatco5U6tWLefMTz/95Jy56667PJ3T+WJlgCsIACYCAcBEIACYCAQAE4EAYCIQAEwEAoCJQAAwsVDqHLRs2dI5s379+pAcy8uirFmzZjlnFixYUOr2Sy65xPkYCQkJzpm3337bOVPReHnXroyMDOfM1KlTnTNHjhzxdE7nCwulAJQZgQBgIhAATAQCgIlAADARCAAmAgHARCAAmFgodQ58PvcnF/bq1cs54+Xj5apWrerpnFxOnDhR6nYvzykyMjIk5+LFsmXLnDOjR492zmRnZztnjh8/7pzJzc11zlyMWCgFoMwIBAATgQBgIhAATAQCgIlAADARCAAmAgHAxEKpCiAmJsY588ILLzhnUlNTnTMRERGlbi/lr0OQl4/Mmz17tnPmtddec85s3brVOVNQUOCcQelYKAWgzAgEABOBAGAiEABMBAKAiUAAMBEIACYCAcDEQikALJ
QCUHYEAoCJQAAwEQgAJgIBwEQgAJgIBAATgQBgIhAATAQCgIlAADARCAAmAgHARCAAmAgEABOBAGAiEABMBAKAiUAAMBEIACYCAcBEIACYCAQAE4EAYCIQAEwEAoCJQAAwEQgAJgIBwEQgAJgIBAATgQBgIhAATAQCgIlAADARCAAmAgHARCAAmAgEABOBAGAiEABMBAKAiUAAMBEIACYCAcBEIACYCAQAE4EAYCIQAEwEAoCJQAAwEQgAJgIBwOQrbWMgEDhf5wGgAuIKAoCJQAAwEQgAJgIBwEQgAJgIBADT/wAoeB4tE68iGgAAAABJRU5ErkJggg==\n"
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"id": "tHq96SIJcNfx"
},
"source": [
"#@title Define a DNN model using tf.keras API\n",
"\n",
"def simple_dnn(num_classes):\n",
"  \"\"\"Builds a small multi-layer perceptron for image classification.\n",
"\n",
"  Args:\n",
"    num_classes: Number of output classes (units in the softmax layer).\n",
"\n",
"  Returns:\n",
"    A tf.keras Sequential model: Flatten -> Dense(128, relu) ->\n",
"    Dense(num_classes, softmax).\n",
"  \"\"\"\n",
"  return tf.keras.models.Sequential([\n",
"      # Flatten each input to a 1d array (e.g. 28x28x1 -> 784).\n",
"      tf.keras.layers.Flatten(),\n",
"      # Hidden fully-connected layer: 128 neurons, ReLU activation.\n",
"      tf.keras.layers.Dense(128, activation=\"relu\"),\n",
"      # Output layer returning probability scores for each class.\n",
"      tf.keras.layers.Dense(num_classes, activation=\"softmax\"),\n",
"  ])"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"id": "43BH_9YcsGs8"
},
"source": [
"#@markdown ### Training Parameters\n",
"\n",
"# Mini-batch size and number of training epochs passed to `fit` in the\n",
"# training cell; adjust them with the Colab sliders.\n",
"batch_size = 32 #@param { type: \"slider\", min: 10, max: 400 }\n",
"num_epochs = 8 #@param { type: \"slider\", min: 1, max: 20 }"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"id": "7Gdxh7qWcPSO",
"outputId": "4017f6c9-9ca0-4e72-dc8a-104465231644",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 306
}
},
"source": [
"#@title Train the Keras model\n",
"\n",
"tf_model = simple_dnn(NUM_CLASSES)\n",
"# Stateful optimizers like Adam create variables that are incompatible with\n",
"# compilation as currently implemented, so use plain SGD.\n",
"tf_model.compile(\n",
"    optimizer=\"sgd\",\n",
"    loss=\"sparse_categorical_crossentropy\",\n",
"    metrics=[\"accuracy\"])\n",
"# Keep the History object (instead of leaving it as the cell's displayed\n",
"# value) so per-epoch metrics stay available without printing its repr.\n",
"history = tf_model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    batch_size=batch_size,\n",
"    epochs=num_epochs,\n",
"    validation_split=0.1)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1/8\n",
"1688/1688 [==============================] - 1s 657us/step - loss: 0.6559 - accuracy: 0.8389 - val_loss: 0.3182 - val_accuracy: 0.9182\n",
"Epoch 2/8\n",
"1688/1688 [==============================] - 1s 542us/step - loss: 0.3476 - accuracy: 0.9031 - val_loss: 0.2620 - val_accuracy: 0.9287\n",
"Epoch 3/8\n",
"1688/1688 [==============================] - 1s 539us/step - loss: 0.2992 - accuracy: 0.9163 - val_loss: 0.2299 - val_accuracy: 0.9368\n",
"Epoch 4/8\n",
"1688/1688 [==============================] - 1s 542us/step - loss: 0.2687 - accuracy: 0.9249 - val_loss: 0.2095 - val_accuracy: 0.9428\n",
"Epoch 5/8\n",
"1688/1688 [==============================] - 1s 541us/step - loss: 0.2456 - accuracy: 0.9308 - val_loss: 0.1965 - val_accuracy: 0.9423\n",
"Epoch 6/8\n",
"1688/1688 [==============================] - 1s 538us/step - loss: 0.2269 - accuracy: 0.9362 - val_loss: 0.1806 - val_accuracy: 0.9500\n",
"Epoch 7/8\n",
"1688/1688 [==============================] - 1s 542us/step - loss: 0.2106 - accuracy: 0.9409 - val_loss: 0.1685 - val_accuracy: 0.9548\n",
"Epoch 8/8\n",
"1688/1688 [==============================] - 1s 545us/step - loss: 0.1971 - accuracy: 0.9444 - val_loss: 0.1601 - val_accuracy: 0.9588\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x15e010aac40>"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nZdVUd_dgTtc"
},
"source": [
"# Compile and Execute MNIST Model using IREE"
]
},
{
"cell_type": "code",
"metadata": {
"id": "DmespEaFcSEL"
},
"source": [
"#@title Wrap the model in a tf.Module with IREE-compatible settings\n",
"\n",
"# Since the model was written in sequential style, explicitly wrap in a module.\n",
"inference_module = tf.Module()\n",
"inference_module.model = tf_model\n",
"\n",
"# Hack: Convert to static shape. Won't be necessary once dynamic shapes are in.\n",
"input_shape = list(tf_model.inputs[0].shape)\n",
"input_shape[0] = 1 # Make fixed (batch=1)\n",
"\n",
"# Produce a concrete function to compile. Use a named function rather than a\n",
"# lambda: AutoGraph cannot parse a lambda embedded in a larger expression, and\n",
"# would otherwise warn and run it unconverted.\n",
"@tf.function(input_signature=[\n",
"    tf.TensorSpec(input_shape, tf_model.inputs[0].dtype)\n",
"])\n",
"def predict(x):\n",
"  \"\"\"Runs the wrapped Keras model in inference mode (training=False).\"\"\"\n",
"  return tf_model.call(x, training=False)\n",
"\n",
"inference_module.predict = predict\n",
"\n",
"# Only try to compile the function we care about:\n",
"exported_names = [\"predict\"]"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "G7v-2EbjyggO"
},
"source": [
"#@markdown ### Backend Configuration\n",
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"# Each dropdown label is a backend id followed by a parenthesized device hint;\n",
"# keep only the id (the text before the first space).\n",
"backend_choice = backend_choice.partition(\" \")[0]\n",
"backend = module_utils.BackendInfo(backend_choice)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IDHI7h3khJr9",
"outputId": "2d4117fb-f1d4-4788-8ad4-cd50bdf0e5a5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 459
}
},
"source": [
"#@title Compile the mhlo MLIR to an IREE backend and prepare a context to execute it\n",
"\n",
"# Compile only the functions listed in `exported_names` for the selected\n",
"# backend; intermediate artifacts are written under ARTIFACTS_DIR.\n",
"iree_module = module_utils.IreeCompiledModule.create_from_instance(\n",
"    inference_module,\n",
"    backend,\n",
"    exported_names,\n",
"    ARTIFACTS_DIR,\n",
")\n",
"\n",
"print(f\"* Module compiled! See intermediate .mlir files in {ARTIFACTS_DIR} *\")"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x0000015E0A89D160> and will run it as-is.\n",
"Cause: could not parse the source code:\n",
"\n",
"])(lambda x: tf_model.call(x, training=False))\n",
"\n",
"This error may be avoided by creating the lambda in a standalone statement.\n",
"\n",
"To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
"WARNING: AutoGraph could not transform <function <lambda> at 0x0000015E0A89D160> and will run it as-is.\n",
"Cause: could not parse the source code:\n",
"\n",
"])(lambda x: tf_model.call(x, training=False))\n",
"\n",
"This error may be avoided by creating the lambda in a standalone statement.\n",
"\n",
"To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
"WARNING:tensorflow:From c:\\users\\scott\\scoop\\apps\\python\\current\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
"WARNING:tensorflow:From c:\\users\\scott\\scoop\\apps\\python\\current\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
"INFO:tensorflow:Assets written to: C:\\Users\\Scott\\AppData\\Local\\Temp\\tmpd667eaqu\\assets\n",
"* Module compiled! See intermediate .mlir files in C:\\Users\\Scott\\AppData\\Local\\Temp\\iree\\modules *\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Created IREE driver vmla: <pyiree.rt.binding.HalDriver object at 0x0000015E0AB1D570>\n",
"SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x0000015E0AB1D570>\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "S2FYao92Xd6r",
"outputId": "734f8b72-e752-43c4-efb8-4a301f8d8624",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
}
},
"source": [
"#@title Execute the compiled module and compare the results with TensorFlow\n",
"\n",
"# Invoke the 'predict' function with a single image (a batch of 1) as an\n",
"# argument; the same batch is fed to both IREE and TensorFlow.\n",
"sample_batch = x_train[sample_index][None, :]\n",
"iree_prediction = iree_module.predict(sample_batch)[0]\n",
"tf_prediction = tf_model.predict(sample_batch)[0]\n",
"error = tf_prediction - iree_prediction\n",
"\n",
"fig, axs = plt.subplots(1, 2)\n",
"fig.set_figwidth(12)\n",
"\n",
"ax = axs[0]\n",
"ax.plot(iree_prediction, linewidth=2, label=backend.backend_name)\n",
"ax.plot(tf_prediction, linewidth=2, label=\"tf\")\n",
"\n",
"ax.set_title(\"Predictions\")\n",
"ax.set_ylabel(\"Softmax 'Probability'\")\n",
"ax.set_xlabel(\"Digit\")\n",
"ax.set_ylim(0, 1)\n",
"ax.set_xlim(0, NUM_CLASSES - 1)\n",
"ax.set_xticks(range(NUM_CLASSES))\n",
"ax.legend(frameon=True)\n",
"\n",
"\n",
"ax = axs[1]\n",
"ax.plot(error)\n",
"\n",
"ax.set_title(\"Error\")\n",
"ax.set_ylabel(\"Numerical error between TF and IREE\")\n",
"ax.set_xlabel(\"Digit\")\n",
"# Pad the symmetric y-range by 25%. If the results match exactly the range\n",
"# would collapse to zero (singular axis limits), so fall back to a tiny one.\n",
"ylim = 1.25 * np.max(np.abs(error)) or 1e-6\n",
"ax.set_ylim(-ylim, ylim)\n",
"ax.set_xlim(0, NUM_CLASSES - 1)\n",
"ax.set_xticks(range(NUM_CLASSES))\n",
"\n",
"fig.tight_layout()"
],
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {
"tags": []
}
}
]
}
]
}