| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "mnist_tensorflow.ipynb", |
| "provenance": [], |
| "collapsed_sections": [] |
| }, |
| "kernelspec": { |
| "display_name": "Python 3", |
| "name": "python3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "PZtRtMMUZHJS" |
| }, |
| "source": [ |
| "##### Copyright 2020 Google LLC.\n", |
| "\n", |
| "Licensed under the Apache License, Version 2.0 (the \"License\");" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "cellView": "form", |
| "id": "TouZL3JZZSQe" |
| }, |
| "source": [ |
| "#@title License header\n", |
| "# Copyright 2020 Google LLC\n", |
| "#\n", |
| "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", |
| "# you may not use this file except in compliance with the License.\n", |
| "# You may obtain a copy of the License at\n", |
| "#\n", |
| "# https://www.apache.org/licenses/LICENSE-2.0\n", |
| "#\n", |
| "# Unless required by applicable law or agreed to in writing, software\n", |
| "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", |
| "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", |
| "# See the License for the specific language governing permissions and\n", |
| "# limitations under the License." |
| ], |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "O6c3qfq5Zv57" |
| }, |
| "source": [ |
| "# MNIST Model TensorFlow Training, IREE Execution\n", |
| "\n", |
| "## Overview\n", |
| "\n", |
| "This notebook creates and trains a TensorFlow 2.0 model for recognizing handwritten digits using the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database), then compiles and executes that trained model using IREE." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "ZpkMjTQxLzLq", |
| "outputId": "f2c8bee0-5407-4fda-8950-ed7a9666292e" |
| }, |
| "source": [ |
| "!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tf-snapshot -f https://github.com/google/iree/releases" |
| ], |
| "execution_count": 2, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Looking in links: https://github.com/google/iree/releases\n", |
| "Collecting iree-compiler-snapshot\n", |
| "\u001b[?25l Downloading https://github.com/google/iree/releases/download/snapshot-20210107.21/iree_compiler_snapshot-20210107.21-py3-none-manylinux2014_x86_64.whl (27.8MB)\n", |
| "\u001b[K |████████████████████████████████| 27.9MB 154kB/s \n", |
| "\u001b[?25hCollecting iree-runtime-snapshot\n", |
| "\u001b[?25l Downloading https://github.com/google/iree/releases/download/snapshot-20210107.21/iree_runtime_snapshot-20210107.21-cp36-cp36m-manylinux2014_x86_64.whl (1.0MB)\n", |
| "\u001b[K |████████████████████████████████| 1.0MB 56.9MB/s \n", |
| "\u001b[?25hCollecting iree-tools-tf-snapshot\n", |
| "\u001b[?25l Downloading https://github.com/google/iree/releases/download/snapshot-20210107.21/iree_tools_tf_snapshot-20210107.21-py3-none-manylinux2014_x86_64.whl (41.4MB)\n", |
| "\u001b[K |████████████████████████████████| 41.4MB 85kB/s \n", |
| "\u001b[?25hInstalling collected packages: iree-compiler-snapshot, iree-runtime-snapshot, iree-tools-tf-snapshot\n", |
| "Successfully installed iree-compiler-snapshot-20210107.21 iree-runtime-snapshot-20210107.21 iree-tools-tf-snapshot-20210107.21\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "EPF7RGQDYK-M", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "99323e03-bad5-4146-ec39-5fba2e62939b" |
| }, |
| "source": [ |
| "#@title Imports and Setup\n", |
| "\n", |
| "from pyiree import rt as ireert\n", |
| "from pyiree.tf.support import module_utils\n", |
| "from pyiree.compiler2 import tf as tfc\n", |
| "\n", |
| "from matplotlib import pyplot as plt\n", |
| "import numpy as np\n", |
| "import os\n", |
| "import tempfile\n", |
| "import tensorflow as tf\n", |
| "\n", |
| "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), 'iree', 'modules')\n", |
| "print(\"Artifacts directory is: \", ARTIFACTS_DIR)\n", |
| "\n", |
| "plt.style.use(\"seaborn-whitegrid\")\n", |
| "plt.rcParams[\"font.family\"] = \"monospace\"\n", |
| "\n", |
| "# Print version information for future notebook users to reference.\n", |
| "print(\"TensorFlow version: \", tf.__version__)\n", |
| "print(\"Numpy version: \", np.__version__)\n", |
| "\n", |
| "# (Temporary) workaround for absl flags...\n", |
| "# https://github.com/googlecolab/colabtools/issues/1323#issuecomment-756343620\n", |
| "import sys\n", |
| "from absl import app\n", |
| "sys.argv = sys.argv[:1]\n", |
| "try:\n", |
| " app.run(lambda argv: None)\n", |
| "except SystemExit:\n", |
| " pass" |
| ], |
| "execution_count": 3, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Artifacts directory is: /tmp/iree/modules\n", |
| "TensorFlow version: 2.4.0\n", |
| "Numpy version: 1.19.4\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "5vkQOMOMbXdy" |
| }, |
| "source": [ |
| "# Create and Train MNIST Model in TensorFlow\n", |
| "\n", |
| "The specific details of the training process here aren't critical to the model compilation and execution through IREE." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "XPo8ATGqqZbW", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 348 |
| }, |
| "outputId": "b1a692a8-4523-4ba6-e00e-6cd71668fcc1" |
| }, |
| "source": [ |
| "#@title Load MNIST dataset, setup training and evaluation\n", |
| "\n", |
| "# Keras datasets don't provide metadata.\n", |
| "NUM_CLASSES = 10\n", |
| "NUM_ROWS, NUM_COLS = 28, 28\n", |
| "\n", |
| "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", |
| "\n", |
| "# Reshape into grayscale images:\n", |
| "x_train = np.reshape(x_train, (-1, NUM_ROWS, NUM_COLS, 1))\n", |
| "x_test = np.reshape(x_test, (-1, NUM_ROWS, NUM_COLS, 1))\n", |
| "\n", |
| "# Rescale uint8 pixel values into floats:\n", |
| "x_train = x_train / 255\n", |
| "x_test = x_test / 255\n", |
| "\n", |
| "# Explicitly cast to float32 because numpy defaults to double precision and\n", |
| "# IREE uses single precision:\n", |
| "x_train = x_train.astype(np.float32)\n", |
| "x_test = x_test.astype(np.float32)\n", |
| "\n", |
| "print(\"Sample image from the dataset:\")\n", |
| "sample_index = np.random.randint(x_train.shape[0])\n", |
| "plt.imshow(x_train[sample_index].reshape(NUM_ROWS, NUM_COLS), cmap=\"gray\")\n", |
| "plt.title(f\"Sample #{sample_index}, label: {y_train[sample_index]}\")\n", |
| "plt.axis(\"off\")\n", |
| "plt.tight_layout()" |
| ], |
| "execution_count": 4, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", |
| "11493376/11490434 [==============================] - 0s 0us/step\n", |
| "Sample image from the dataset:\n" |
| ], |
| "name": "stdout" |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<Figure size 432x288 with 1 Axes>" |
| ] |
| }, |
| "metadata": { |
| "tags": [] |
| } |
| } |
| ] |
| }, |
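| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The `astype(np.float32)` step above matters because dividing a `uint8` array by an integer promotes it to `float64`. The sketch below illustrates that promotion on a tiny made-up array rather than on MNIST itself, assuming NumPy's standard type-promotion rules." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": {}, |
| "source": [ |
| "# Sketch: dtype promotion when rescaling uint8 pixels (illustrative data only).\n", |
| "import numpy as np\n", |
| "\n", |
| "# Hypothetical stand-in for pixel values in [0, 255].\n", |
| "pixels_u8 = np.array([[0, 128, 255]], dtype=np.uint8)\n", |
| "\n", |
| "# True division of an integer array yields float64 (double precision).\n", |
| "pixels_scaled = pixels_u8 / 255\n", |
| "print(pixels_scaled.dtype)  # float64\n", |
| "\n", |
| "# An explicit cast gives the single-precision values IREE expects here.\n", |
| "pixels_f32 = pixels_scaled.astype(np.float32)\n", |
| "print(pixels_f32.dtype, pixels_f32.min(), pixels_f32.max())  # float32 0.0 1.0" |
| ], |
| "execution_count": null, |
| "outputs": [] |
| }, |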
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "tHq96SIJcNfx" |
| }, |
| "source": [ |
| "#@title Define a DNN model using tf.keras API\n", |
| "\n", |
| "def simple_dnn(num_classes):\n", |
| " \"\"\"Creates a simple multi-layer perceptron model.\"\"\"\n", |
| "\n", |
| " model = tf.keras.models.Sequential()\n", |
| " # Flatten to a 1d array (e.g. 28x28x1 -> 784).\n", |
| " model.add(tf.keras.layers.Flatten())\n", |
| " # Fully-connected layer with 128 neurons and ReLU activation.\n", |
| " model.add(tf.keras.layers.Dense(128, activation=\"relu\"))\n", |
| " # Fully-connected neural layer returning probability scores for each class.\n", |
| " model.add(tf.keras.layers.Dense(num_classes, activation=\"softmax\"))\n", |
| " return model" |
| ], |
| "execution_count": 5, |
| "outputs": [] |
| }, |
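| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The size of `simple_dnn` can be worked out by hand: each `Dense` layer has `inputs x units` weights plus `units` biases, and `Flatten` has no parameters. The sketch below does that arithmetic for the 28x28x1 input and 10 classes used here; the helper name is hypothetical, and the total is expected (not checked in this notebook) to agree with Keras's `tf_model.count_params()` once the model has been built." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": {}, |
| "source": [ |
| "# Sketch: count the trainable parameters of simple_dnn by hand.\n", |
| "def dense_param_count(num_inputs, num_units):\n", |
| "  # A Dense layer has a (num_inputs x num_units) kernel plus num_units biases.\n", |
| "  return num_inputs * num_units + num_units\n", |
| "\n", |
| "rows, cols, channels = 28, 28, 1\n", |
| "flattened = rows * cols * channels  # Flatten: 28x28x1 -> 784, no parameters.\n", |
| "hidden_params = dense_param_count(flattened, 128)\n", |
| "output_params = dense_param_count(128, 10)\n", |
| "print(hidden_params, output_params, hidden_params + output_params)  # 100480 1290 101770" |
| ], |
| "execution_count": null, |
| "outputs": [] |
| }, |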
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "43BH_9YcsGs8" |
| }, |
| "source": [ |
| "#@markdown ### Training Parameters\n", |
| "\n", |
| "batch_size = 32 #@param { type: \"slider\", min: 10, max: 400 }\n", |
| "num_epochs = 8 #@param { type: \"slider\", min: 1, max: 20 }" |
| ], |
| "execution_count": 6, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "7Gdxh7qWcPSO", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "ff0cf33d-353e-478b-b80c-3fb67d97744d" |
| }, |
| "source": [ |
| "#@title Train the Keras model\n", |
| "\n", |
| "tf_model = simple_dnn(NUM_CLASSES)\n", |
| "# Stateful optimizers like Adam create variables that are incompatible with\n", |
| "# compilation as currently implemented, so use plain SGD here.\n", |
| "tf_model.compile(\n", |
| " optimizer=\"sgd\", loss=\"sparse_categorical_crossentropy\", metrics=\"accuracy\")\n", |
| "tf_model.fit(\n", |
| "    x_train, y_train, batch_size=batch_size, epochs=num_epochs,\n", |
| "    validation_split=0.1)" |
| ], |
| "execution_count": 7, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Epoch 1/8\n", |
| "1688/1688 [==============================] - 4s 2ms/step - loss: 1.1068 - accuracy: 0.7158 - val_loss: 0.3270 - val_accuracy: 0.9142\n", |
| "Epoch 2/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.3705 - accuracy: 0.8982 - val_loss: 0.2648 - val_accuracy: 0.9288\n", |
| "Epoch 3/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.3143 - accuracy: 0.9133 - val_loss: 0.2326 - val_accuracy: 0.9397\n", |
| "Epoch 4/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.2728 - accuracy: 0.9236 - val_loss: 0.2127 - val_accuracy: 0.9440\n", |
| "Epoch 5/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.2529 - accuracy: 0.9294 - val_loss: 0.1982 - val_accuracy: 0.9477\n", |
| "Epoch 6/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.2327 - accuracy: 0.9347 - val_loss: 0.1853 - val_accuracy: 0.9523\n", |
| "Epoch 7/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.2141 - accuracy: 0.9403 - val_loss: 0.1720 - val_accuracy: 0.9572\n", |
| "Epoch 8/8\n", |
| "1688/1688 [==============================] - 3s 2ms/step - loss: 0.2040 - accuracy: 0.9417 - val_loss: 0.1652 - val_accuracy: 0.9588\n" |
| ], |
| "name": "stdout" |
| }, |
| { |
| "output_type": "execute_result", |
| "data": { |
| "text/plain": [ |
| "<tensorflow.python.keras.callbacks.History at 0x7fb807694160>" |
| ] |
| }, |
| "metadata": { |
| "tags": [] |
| }, |
| "execution_count": 7 |
| } |
| ] |
| }, |
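| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The `1688/1688` in the training log follows from the arguments to `fit`: with `validation_split=0.1`, Keras holds out the last 10% of the 60,000 training images, and the remaining 54,000 are split into batches of 32 with one final partial batch. A small sketch of that arithmetic (the helper name is hypothetical):" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": {}, |
| "source": [ |
| "# Sketch: reproduce the steps-per-epoch count shown in the training log.\n", |
| "import math\n", |
| "\n", |
| "def steps_per_epoch(num_examples, batch_size, validation_split):\n", |
| "  # Keras keeps the first (1 - validation_split) fraction for training.\n", |
| "  num_train = int(num_examples * (1 - validation_split))\n", |
| "  # The last batch may be partial, hence the ceiling.\n", |
| "  return num_train, math.ceil(num_train / batch_size)\n", |
| "\n", |
| "print(steps_per_epoch(60000, 32, 0.1))  # (54000, 1688)" |
| ], |
| "execution_count": null, |
| "outputs": [] |
| }, |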
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "nZdVUd_dgTtc" |
| }, |
| "source": [ |
| "# Compile and Execute MNIST Model using IREE" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "DmespEaFcSEL" |
| }, |
| "source": [ |
| "#@title Wrap the model in a tf.Module with an IREE-compatible (static-shape) signature\n", |
| "\n", |
| "# Since the model was written in sequential style, explicitly wrap in a module.\n", |
| "inference_module = tf.Module()\n", |
| "inference_module.model = tf_model\n", |
| "\n", |
| "# Workaround: IREE needs a static input shape here, so fix the batch dimension.\n", |
| "# This won't be necessary once dynamic shapes are supported.\n", |
| "input_shape = list(tf_model.inputs[0].shape)\n", |
| "input_shape[0] = 1 # Make fixed (batch=1)\n", |
| "\n", |
| "# Produce a concrete function to compile.\n", |
| "inference_module.predict = tf.function(input_signature=[\n", |
| " tf.TensorSpec(input_shape, tf_model.inputs[0].dtype)\n", |
| "])(lambda x: tf_model.call(x, training=False))\n", |
| "\n", |
| "# Only try to compile the function we care about:\n", |
| "exported_names = [\"predict\"]" |
| ], |
| "execution_count": 8, |
| "outputs": [] |
| }, |
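| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Because `predict` is traced with a fixed batch size of 1, evaluating several images means calling it once per image and stacking the results. The sketch below shows that pattern with a placeholder function standing in for the compiled module; `predict_many`, `uniform_predict`, and the `(1, 10)` per-call output shape are assumptions for illustration, not part of the IREE API." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": {}, |
| "source": [ |
| "# Sketch: drive a batch=1 function over a batch of images (placeholder model).\n", |
| "import numpy as np\n", |
| "\n", |
| "def predict_many(predict_one, images):\n", |
| "  # Call predict_one on each image with a leading batch dimension of 1,\n", |
| "  # then stack the (1, num_classes) results into (num_images, num_classes).\n", |
| "  outputs = [predict_one(image[None, ...]) for image in images]\n", |
| "  return np.concatenate(outputs, axis=0)\n", |
| "\n", |
| "def uniform_predict(batch):\n", |
| "  # Hypothetical stand-in for a compiled model: equal scores for 10 classes.\n", |
| "  return np.full((batch.shape[0], 10), 0.1, dtype=np.float32)\n", |
| "\n", |
| "images = np.zeros((5, 28, 28, 1), dtype=np.float32)\n", |
| "print(predict_many(uniform_predict, images).shape)  # (5, 10)" |
| ], |
| "execution_count": null, |
| "outputs": [] |
| }, |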
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "G7v-2EbjyggO" |
| }, |
| "source": [ |
| "#@markdown ### Backend Configuration\n", |
| "\n", |
| "backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmaot (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n", |
| "backend_choice = backend_choice.split(\" \")[0]\n", |
| "backend = module_utils.BackendInfo(backend_choice)" |
| ], |
| "execution_count": 9, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "IDHI7h3khJr9", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "79f3191b-7e8e-4e10-9cea-cff168f7c958" |
| }, |
| "source": [ |
| "#@title Compile the mhlo MLIR to an IREE backend and prepare a context to execute it\n", |
| "\n", |
| "iree_module = module_utils.IreeCompiledModule.create_from_instance(\n", |
| " inference_module, backend, exported_names, ARTIFACTS_DIR)\n", |
| "\n", |
| "print(\"* Module compiled! See intermediate .mlir files in\", ARTIFACTS_DIR, \"*\")" |
| ], |
| "execution_count": 10, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "I0107 21:49:35.160961 140429953972096 module_utils.py:88] Outputting intermediate artifacts (--artifacts_dir is set):\n", |
| " output_file: /tmp/iree/modules/iree_vmla/compiled.vmfb\n", |
| " saved_model_dir: /tmp/iree/modules/tfmodule.saved_model\n", |
| " save_temp_tf_input: /tmp/iree/modules/tf_input.mlir\n", |
| " save_temp_iree_input: /tmp/iree/modules/iree_input.mlir\n", |
| " crash_reproducer_path: /tmp/iree/modules/reproducer__iree_vmla.mlir\n" |
| ], |
| "name": "stderr" |
| }, |
| { |
| "output_type": "stream", |
| "text": [ |
| "INFO:tensorflow:Assets written to: /tmp/iree/modules/tfmodule.saved_model/assets\n" |
| ], |
| "name": "stdout" |
| }, |
| { |
| "output_type": "stream", |
| "text": [ |
| "I0107 21:49:35.632078 140429953972096 builder_impl.py:775] Assets written to: /tmp/iree/modules/tfmodule.saved_model/assets\n" |
| ], |
| "name": "stderr" |
| }, |
| { |
| "output_type": "stream", |
| "text": [ |
| "* Module compiled! See intermediate .mlir files in /tmp/iree/modules *\n" |
| ], |
| "name": "stdout" |
| }, |
| { |
| "output_type": "stream", |
| "text": [ |
| "2021-01-07 21:49:35.767974: I external/org_tensorflow/tensorflow/cc/saved_model/bundle_v2.cc:32] Reading SavedModel from: /tmp/iree/modules/tfmodule.saved_model\n", |
| "2021-01-07 21:49:35.768968: I external/org_tensorflow/tensorflow/cc/saved_model/bundle_v2.cc:55] Reading SavedModel debug info (if present) from: /tmp/iree/modules/tfmodule.saved_model\n", |
| "Created IREE driver vmla: <pyiree.rt.binding.HalDriver object at 0x7fb803d9d2d0>\n", |
| "SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x7fb803d9d2d0>\n" |
| ], |
| "name": "stderr" |
| } |
| ] |
| }, |
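| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Compiled backends rarely match TensorFlow bit-for-bit, so a practical check is whether the two prediction vectors agree within a tolerance and pick the same digit. The helper below is a sketch of such a check; the `rtol`/`atol` defaults are illustrative assumptions rather than values taken from IREE, and the sample vectors are made up. After the next cell runs, it could be called as `compare_predictions(tf_prediction, iree_prediction)`." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": {}, |
| "source": [ |
| "# Sketch: tolerance-based comparison of two softmax vectors.\n", |
| "import numpy as np\n", |
| "\n", |
| "def compare_predictions(reference, candidate, rtol=1e-5, atol=1e-6):\n", |
| "  reference = np.asarray(reference, dtype=np.float32)\n", |
| "  candidate = np.asarray(candidate, dtype=np.float32)\n", |
| "  return {\n", |
| "      'max_abs_error': float(np.max(np.abs(reference - candidate))),\n", |
| "      'allclose': bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),\n", |
| "      'same_digit': bool(np.argmax(reference) == np.argmax(candidate)),\n", |
| "  }\n", |
| "\n", |
| "# Made-up example vectors (not real model outputs).\n", |
| "reference_scores = np.array([0.1, 0.7, 0.2], dtype=np.float32)\n", |
| "candidate_scores = np.array([0.1, 0.7000001, 0.1999999], dtype=np.float32)\n", |
| "print(compare_predictions(reference_scores, candidate_scores))" |
| ], |
| "execution_count": null, |
| "outputs": [] |
| }, |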
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "S2FYao92Xd6r", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 297 |
| }, |
| "outputId": "38a1c963-f484-455e-da78-1dcabdf1f92b" |
| }, |
| "source": [ |
| "#@title Execute the compiled module and compare the results with TensorFlow\n", |
| "\n", |
| "# Invoke the 'predict' function on a single image; [None, :] adds the batch dimension.\n", |
| "iree_prediction = iree_module.predict(x_train[sample_index][None, :])[0]\n", |
| "tf_prediction = tf_model.predict(x_train[sample_index][None, :])[0]\n", |
| "error = tf_prediction - iree_prediction\n", |
| "\n", |
| "fig, axs = plt.subplots(1, 2)\n", |
| "fig.set_figwidth(12)\n", |
| "\n", |
| "ax = axs[0]\n", |
| "ax.plot(iree_prediction, linewidth=2, label=backend.backend_name)\n", |
| "ax.plot(tf_prediction, linewidth=2, label=\"tf\")\n", |
| "\n", |
| "ax.set_title(\"Predictions\")\n", |
| "ax.set_ylabel(\"Softmax 'Probability'\")\n", |
| "ax.set_xlabel(\"Digit\")\n", |
| "ax.set_ylim(0, 1)\n", |
| "ax.set_xlim(0, 9)\n", |
| "ax.legend(frameon=True)\n", |
| "\n", |
| "ax = axs[1]\n", |
| "ax.plot(error)\n", |
| "\n", |
| "ax.set_title(\"Error\")\n", |
| "ax.set_ylabel(\"Numerical difference between TF and IREE\")\n", |
| "ax.set_xlabel(\"Digit\")\n", |
| "ylim = 1.25 * np.max(np.abs(error))\n", |
| "ax.set_ylim(-ylim, ylim)\n", |
| "ax.set_xlim(0, 9)\n", |
| "\n", |
| "fig.tight_layout()" |
| ], |
| "execution_count": 11, |
| "outputs": [ |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<Figure size 864x288 with 2 Axes>" |
| ] |
| }, |
| "metadata": { |
| "tags": [] |
| } |
| } |
| ] |
| } |
| ] |
| } |