| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "tflite_text_classification.ipynb", |
| "provenance": [], |
| "collapsed_sections": [ |
| "dweksnWs1eZb" |
| ] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "_VsBo9D7GJ8G" |
| }, |
| "source": [ |
| "```\n", |
| "Copyright 2021 The IREE Authors\n", |
| "\n", |
| "Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "See https://llvm.org/LICENSE.txt for license information.\n", |
| "SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n", |
| "```" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "3Mjqf_WmSkAs" |
| }, |
| "source": [ |
| "# TFLite text classification sample with IREE\n", |
| "\n", |
| "This notebook demonstrates how to download, compile, and run a TFLite model with IREE. It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and runs it with both TFLite and IREE so the results can be compared. The model predicts whether a sentence's sentiment is positive or negative and was trained on a dataset of IMDB movie reviews.\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "W71slxNjS3SB" |
| }, |
| "source": [ |
| "## Setup" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "T8WMt2sPS4ft" |
| }, |
| "source": [ |
| "%%capture\n", |
| "!python -m pip install iree-compiler iree-runtime iree-tools-tflite -f https://github.com/google/iree/releases/latest\n", |
| "!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime" |
| ], |
| "execution_count": 14, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "cellView": "code", |
| "id": "7L4_gnRVVBdi" |
| }, |
| "source": [ |
| "import numpy as np\n", |
| "import urllib.request\n", |
| "import pathlib\n", |
| "import tempfile\n", |
| "import re\n", |
| "import tflite_runtime.interpreter as tflite\n", |
| "\n", |
| "from iree import runtime as iree_rt\n", |
| "from iree.compiler import compile_str\n", |
| "from iree.tools import tflite as iree_tflite\n", |
| "\n", |
| "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", |
| "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)" |
| ], |
| "execution_count": 15, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "dweksnWs1eZb" |
| }, |
| "source": [ |
| "### Load the TFLite model\n", |
| "\n", |
| "1. Download files for the pretrained model\n", |
| "2. Extract model metadata used for input pre-processing and output post-processing\n", |
| "3. Define helper functions for pre- and post-processing\n", |
| "\n", |
| "These steps will differ from model to model. Consult the model source or reference documentation for details.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "cUTaotkV7taP", |
| "outputId": "36f8be61-9565-4da4-ae92-8b0eceb05419" |
| }, |
| "source": [ |
| "#@title Download pretrained text classification model\n", |
| "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n", |
| "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))" |
| ], |
| "execution_count": 16, |
| "outputs": [ |
| { |
| "output_type": "execute_result", |
| "data": { |
| "text/plain": [ |
| "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n", |
| " <http.client.HTTPMessage at 0x7f91d58b8990>)" |
| ] |
| }, |
| "metadata": { |
| "tags": [] |
| }, |
| "execution_count": 16 |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "ADu-gSDnIm2B", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "63465e5d-8de8-4b99-a386-5fd264b03bc0" |
| }, |
| "source": [ |
| "#@title Extract model vocab and label metadata\n", |
| "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n", |
| "\n", |
| "# Load the vocab file into a dictionary that maps each known word to an\n", |
| "# integer ID. The IDs index the model's embedding table, which has 10,003 rows,\n", |
| "# so the vocabulary holds roughly 10,000 words.\n", |
| "vocab = {}\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n", |
| "  for line in vocab_file:\n", |
| "    (key, val) = line.split()\n", |
| "    vocab[key] = int(val)\n", |
| "\n", |
| "# Text will be labeled as either 'Positive' or 'Negative'.\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n", |
| "  labels = label_file.read().splitlines()" |
| ], |
| "execution_count": 17, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Archive: /tmp/iree/colab_artifacts/text_classification.tflite\n", |
| " extracting: /tmp/iree/colab_artifacts/labels.txt \n", |
| " extracting: /tmp/iree/colab_artifacts/vocab.txt \n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
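| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The `unzip` call above works because the `.tflite` file can be read as a zip archive that holds `vocab.txt` and `labels.txt`. As a rough sketch of the same extraction step in pure Python, the standard `zipfile` module can usually open a file that has other bytes in front of the archive. The helper name `read_packed_lines`, the stand-in file `demo_model.bin`, and its label contents are illustrative assumptions, not part of the original notebook.\n", |
| "\n", |
| "```python\n", |
| "import io\n", |
| "import pathlib\n", |
| "import tempfile\n", |
| "import zipfile\n", |
| "\n", |
| "\n", |
| "def read_packed_lines(model_path, member):\n", |
| "  # Hypothetical helper: read one text member from a file that ends with a zip archive.\n", |
| "  with zipfile.ZipFile(model_path) as archive:\n", |
| "    return archive.read(member).decode('utf-8').splitlines()\n", |
| "\n", |
| "\n", |
| "# Build a stand-in file: arbitrary leading bytes followed by a small zip archive.\n", |
| "zip_buffer = io.BytesIO()\n", |
| "with zipfile.ZipFile(zip_buffer, 'w') as archive:\n", |
| "  archive.writestr('labels.txt', 'Negative\\nPositive\\n')\n", |
| "demo_path = pathlib.Path(tempfile.mkdtemp(), 'demo_model.bin')\n", |
| "demo_path.write_bytes(b'placeholder-model-bytes' + zip_buffer.getvalue())\n", |
| "\n", |
| "demo_labels = read_packed_lines(demo_path, 'labels.txt')\n", |
| "print(demo_labels)\n", |
| "```\n", |
| "\n", |
| "With the real model, `read_packed_lines(ARTIFACTS_DIR.joinpath('text_classification.tflite'), 'labels.txt')` would be the equivalent call, assuming the archive layout matches what `unzip` reported." |
| ] |
| }, |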
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "Z1WKEZY1JH6E" |
| }, |
| "source": [ |
| "#@title Input and output processing\n", |
| "\n", |
| "# Input text will be encoded as an integer array of fixed length 256. The \n", |
| "# input sentence will be mapped to integers from the vocab dictionary, and the \n", |
| "# empty array spaces are filled with padding.\n", |
| "\n", |
| "SENTENCE_LEN = 256\n", |
| "START = \"<START>\"\n", |
| "PAD = \"<PAD>\"\n", |
| "UNKNOWN = \"<UNKNOWN>\"\n", |
| "\n", |
| "def tokenize_input(text):\n", |
| "  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)\n", |
| "\n", |
| "  # Remove capitalization and punctuation from the input text.\n", |
| "  words = [re.sub(r\"[^\\w\\s']\", '', word.lower()) for word in text.split()]\n", |
| "\n", |
| "  # Prepend <START> and map each word to its ID, falling back to <UNKNOWN>.\n", |
| "  tokens = [vocab[START]]\n", |
| "  tokens += [vocab.get(word, vocab[UNKNOWN]) for word in words]\n", |
| "\n", |
| "  # Truncate long inputs so they fit in the fixed-length array.\n", |
| "  tokens = tokens[:SENTENCE_LEN]\n", |
| "  output[0, :len(tokens)] = tokens\n", |
| "\n", |
| "  return output\n", |
| "\n", |
| "\n", |
| "def interpret_output(output):\n", |
| "  if output[0] >= output[1]:\n", |
| "    label = labels[0]\n", |
| "    confidence = output[0]\n", |
| "  else:\n", |
| "    label = labels[1]\n", |
| "    confidence = output[1]\n", |
| "\n", |
| "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))" |
| ], |
| "execution_count": 18, |
| "outputs": [] |
| }, |
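| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "As a small worked example of the encoding described above, the sketch below applies the same scheme (lowercase, strip punctuation, prepend `<START>`, map unknown words to `<UNKNOWN>`, pad to a fixed length) to a hypothetical five-entry vocabulary. The names `toy_vocab`, `TOY_LEN`, `toy_tokenize`, and `toy_decode` are illustrative and do not come from the model; the real vocabulary is loaded from `vocab.txt` and the real length is 256.\n", |
| "\n", |
| "```python\n", |
| "import re\n", |
| "\n", |
| "import numpy as np\n", |
| "\n", |
| "# Hypothetical miniature vocabulary, for illustration only.\n", |
| "toy_vocab = {'<PAD>': 0, '<START>': 1, '<UNKNOWN>': 2, 'great': 3, 'movie': 4}\n", |
| "TOY_LEN = 8  # Illustrative length; the model itself expects 256.\n", |
| "\n", |
| "\n", |
| "def toy_tokenize(text):\n", |
| "  # Lowercase each word and drop punctuation other than apostrophes.\n", |
| "  words = [re.sub(r\"[^\\w\\s']\", '', w.lower()) for w in text.split()]\n", |
| "  ids = [toy_vocab['<START>']]\n", |
| "  ids += [toy_vocab.get(w, toy_vocab['<UNKNOWN>']) for w in words]\n", |
| "  # Truncate, then pad on the right up to the fixed length.\n", |
| "  ids = ids[:TOY_LEN]\n", |
| "  ids += [toy_vocab['<PAD>']] * (TOY_LEN - len(ids))\n", |
| "  return np.array([ids], dtype=np.int32)\n", |
| "\n", |
| "\n", |
| "def toy_decode(ids):\n", |
| "  # Map IDs back to words, skipping padding, to inspect an encoding.\n", |
| "  inverse = {index: word for word, index in toy_vocab.items()}\n", |
| "  return [inverse[i] for i in ids[0].tolist() if i != toy_vocab['<PAD>']]\n", |
| "\n", |
| "\n", |
| "encoded = toy_tokenize('Great movie, honestly!')\n", |
| "print(encoded)\n", |
| "print(toy_decode(encoded))\n", |
| "```\n", |
| "\n", |
| "Decoding an encoded sentence this way is a quick check on which words fell outside the vocabulary." |
| ] |
| }, |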
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "xq00JknF0cOz", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "d04fddad-852f-426a-96c7-397d9e5cd72d" |
| }, |
| "source": [ |
| "#@title Text samples\n", |
| "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", |
| "negative_text = \"What a waste of my time.\"\n", |
| "\n", |
| "print(positive_text)\n", |
| "print(tokenize_input(positive_text))" |
| ], |
| "execution_count": 19, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "[[ 1 13 8 3 117 19 206 109 10 1134 152 2301 385 11\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0]]\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "fvxhY1W7X7Gf" |
| }, |
| "source": [ |
| "## Run using TFLite\n", |
| "\n", |
| "Overview:\n", |
| "\n", |
| "1. Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n", |
| "2. Allocate tensors and get the input and output shape information\n", |
| "3. Invoke the TFLite Interpreter to test the text classification function" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "Nv8x81S5eV5g" |
| }, |
| "source": [ |
| "interpreter = tflite.Interpreter(\n", |
| " model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n", |
| "interpreter.allocate_tensors()\n", |
| "input_details = interpreter.get_input_details()\n", |
| "output_details = interpreter.get_output_details()\n", |
| "\n", |
| "def classify_text_tflite(text):\n", |
| "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n", |
| "  interpreter.invoke()\n", |
| "  output_data = interpreter.get_tensor(output_details[0]['index'])\n", |
| "  interpret_output(output_data[0])" |
| ], |
| "execution_count": 20, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "kpxfU88ckFxI", |
| "outputId": "a6ef8776-8b9b-4b29-931a-387e1e6b18fa" |
| }, |
| "source": [ |
| "print(\"Invoking text classification with TFLite\\n\")\n", |
| "print(positive_text)\n", |
| "classify_text_tflite(positive_text)\n", |
| "print()\n", |
| "print(negative_text)\n", |
| "classify_text_tflite(negative_text)" |
| ], |
| "execution_count": 21, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Invoking text classification with TFLite\n", |
| "\n", |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "Label: Positive\n", |
| "Confidence: 0.8997293\n", |
| "\n", |
| "What a waste of my time.\n", |
| "Label: Negative\n", |
| "Confidence: 0.6275043\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "_FgwUBtm7Y5n" |
| }, |
| "source": [ |
| "## Run using IREE\n", |
| "\n", |
| "Overview:\n", |
| "\n", |
| "1. Import the TFLite model to TOSA MLIR \n", |
| "2. Compile the TOSA MLIR into an IREE flatbuffer and VM module\n", |
| "3. Run the VM module through IREE's runtime to test the text classification function\n", |
| "\n", |
| "Both runtimes should predict the same labels, with confidence values that agree up to small floating-point differences.\n" |
| ] |
| }, |
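| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "One way to make the comparison between the two runtimes concrete is to check that the predicted class matches and that the probabilities agree within a tolerance, rather than comparing printed values by eye. The sketch below is a minimal version of such a check. The helper name `outputs_match`, the tolerances, and the sample probability pairs are assumptions chosen for illustration; they are not captured from an actual run.\n", |
| "\n", |
| "```python\n", |
| "import numpy as np\n", |
| "\n", |
| "\n", |
| "def outputs_match(reference, candidate, rtol=1e-5, atol=1e-6):\n", |
| "  # Hypothetical helper: same predicted class and element-wise agreement\n", |
| "  # within the given tolerances.\n", |
| "  reference = np.asarray(reference, dtype=np.float32)\n", |
| "  candidate = np.asarray(candidate, dtype=np.float32)\n", |
| "  same_label = np.argmax(reference) == np.argmax(candidate)\n", |
| "  return bool(same_label and np.allclose(reference, candidate, rtol=rtol, atol=atol))\n", |
| "\n", |
| "\n", |
| "# Illustrative probability pairs standing in for the two runtimes' raw outputs.\n", |
| "tflite_probs = [0.3724957, 0.6275043]\n", |
| "iree_probs = [0.37249565, 0.62750435]\n", |
| "print(outputs_match(tflite_probs, iree_probs))\n", |
| "```\n", |
| "\n", |
| "A tolerance-based check avoids false alarms from differences in the last printed digit while still catching a real mismatch." |
| ] |
| }, |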
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "ZifuTzgEXLJn" |
| }, |
| "source": [ |
| "# Convert TFLite model to TOSA MLIR with IREE's import tool.\n", |
| "IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')\n", |
| "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification.mlir\n", |
| "\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n", |
| "  tosa_mlir = mlir_file.read()\n", |
| "\n", |
| "# The generated .mlir file is saved in ARTIFACTS_DIR, so it can also be used outside\n", |
| "# of Python, for example with IREE's native tools or in applications." |
| ], |
| "execution_count": 22, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "-C-5X4C0D0La", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "dc92732b-793d-431c-ca67-a1427044e7b8" |
| }, |
| "source": [ |
| "# The model contains very large constants, so import it again with large constants\n", |
| "# elided to produce a version that is practical to print.\n", |
| "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification_truncated.mlir -mlir-elide-elementsattrs-if-larger=50\n", |
| "\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"text_classification_truncated.mlir\")) as truncated_mlir_file:\n", |
| "  truncated_tosa_mlir = truncated_mlir_file.read()\n", |
| "  print(truncated_tosa_mlir)" |
| ], |
| "execution_count": 23, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "builtin.module {\n", |
| " builtin.func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n", |
| " %0 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<10003x16xf32>} : () -> tensor<10003x16xf32>\n", |
| " %1 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n", |
| " %2 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n", |
| " %3 = \"tosa.const\"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722, 0.795746565, -1.3122592, 0.352218777, 1.04682362, -1.18796027, -0.0409261398, 1.05883229, 1.48620188, -1.13325548, 1.03072512]]> : tensor<2x16xf32>} : () -> tensor<2x16xf32>\n", |
| " %4 = \"tosa.const\"() {value = dense<[0.043447677, -0.0434476472]> : tensor<2xf32>} : () -> tensor<2xf32>\n", |
| " %5 = \"tosa.const\"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>\n", |
| " %6 = \"tosa.const\"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>\n", |
| " %7 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<f32>} : () -> tensor<f32>\n", |
| " %8 = \"tosa.transpose\"(%0, %5) : (tensor<10003x16xf32>, tensor<2xi32>) -> tensor<10003x16xf32>\n", |
| " %9 = \"tosa.reshape\"(%8) {new_shape = [1, 10003, 16]} : (tensor<10003x16xf32>) -> tensor<1x10003x16xf32>\n", |
| " %10 = \"tosa.gather\"(%9, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n", |
| " %11 = \"tosa.transpose\"(%10, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n", |
| " %12 = \"tosa.reduce_sum\"(%11) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n", |
| " %13 = \"tosa.reshape\"(%12) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n", |
| " %14 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n", |
| " %15 = \"tosa.mul\"(%13, %14) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n", |
| " %16 = \"tosa.fully_connected\"(%15, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n", |
| " %17 = \"tosa.clamp\"(%16) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n", |
| " %18 = \"tosa.fully_connected\"(%17, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n", |
| " %19 = \"tosa.exp\"(%18) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n", |
| " %20 = \"tosa.reduce_sum\"(%19) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n", |
| " %21 = \"tosa.reciprocal\"(%20) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n", |
| " %22 = \"tosa.mul\"(%19, %21) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n", |
| " return %22 : tensor<1x2xf32>\n", |
| " }\n", |
| "}\n", |
| "\n", |
| "\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
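| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "The printed TOSA program is short enough to read as a whole: an embedding lookup (`tosa.gather`), an average over the 256 token positions (`tosa.reduce_sum` followed by a multiply by `3.906250e-03`, which is 1/256), two fully connected layers with a ReLU-style `tosa.clamp` between them, and a softmax spelled out as `exp`, `reduce_sum`, `reciprocal`, and `mul`. The NumPy sketch below mirrors that structure with the same tensor shapes. The weights are random stand-ins, so the outputs are not the model's predictions, and the names `forward`, `w1`, `b1`, `w2`, and `b2` are illustrative.\n", |
| "\n", |
| "```python\n", |
| "import numpy as np\n", |
| "\n", |
| "VOCAB_ROWS, EMBED_DIM, SENTENCE_LEN, NUM_CLASSES = 10003, 16, 256, 2\n", |
| "\n", |
| "# Random stand-ins for the model constants; the real values live in the .tflite file.\n", |
| "rng = np.random.default_rng(0)\n", |
| "embedding = rng.normal(scale=0.1, size=(VOCAB_ROWS, EMBED_DIM)).astype(np.float32)\n", |
| "w1 = rng.normal(scale=0.1, size=(EMBED_DIM, EMBED_DIM)).astype(np.float32)\n", |
| "b1 = np.zeros(EMBED_DIM, dtype=np.float32)\n", |
| "w2 = rng.normal(scale=0.1, size=(NUM_CLASSES, EMBED_DIM)).astype(np.float32)\n", |
| "b2 = np.zeros(NUM_CLASSES, dtype=np.float32)\n", |
| "\n", |
| "\n", |
| "def forward(token_ids):\n", |
| "  # Embedding lookup, like tosa.gather: (1, 256) -> (1, 256, 16).\n", |
| "  embedded = embedding[token_ids]\n", |
| "  # Sum over positions and scale by 1/256; padding positions are included.\n", |
| "  pooled = embedded.sum(axis=1) / SENTENCE_LEN\n", |
| "  # Fully connected layer (weights stored as [out, in]) followed by ReLU.\n", |
| "  hidden = np.maximum(pooled @ w1.T + b1, 0.0)\n", |
| "  # Second fully connected layer produces one logit per class.\n", |
| "  logits = hidden @ w2.T + b2\n", |
| "  # Softmax written as exp, sum, and divide, as in the printed ops.\n", |
| "  exps = np.exp(logits)\n", |
| "  return exps / exps.sum(axis=1, keepdims=True)\n", |
| "\n", |
| "\n", |
| "token_ids = np.zeros([1, SENTENCE_LEN], dtype=np.int32)\n", |
| "token_ids[0, :3] = [1, 13, 8]\n", |
| "probs = forward(token_ids)\n", |
| "print(probs.shape, probs.sum())\n", |
| "```\n", |
| "\n", |
| "Because the average runs over all 256 positions, short inputs are dominated by the padding embedding, which is one plausible reason confidence can be lower for very short sentences." |
| ] |
| }, |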
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "M3gXX2AF7aS9" |
| }, |
| "source": [ |
| "# Compile the TOSA MLIR into a VM module.\n", |
| "compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", |
| "vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n", |
| "\n", |
| "# Register the module with a runtime context.\n", |
| "config = iree_rt.Config(\"vmvx\")\n", |
| "ctx = iree_rt.SystemContext(config=config)\n", |
| "ctx.add_vm_module(vm_module)\n", |
| "invoke_text_classification = ctx.modules.module[\"main\"]\n", |
| "\n", |
| "def classify_text_iree(text):\n", |
| "  result = invoke_text_classification(tokenize_input(text))\n", |
| "  interpret_output(result[0])" |
| ], |
| "execution_count": 24, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "jvv9zhMwAWgZ", |
| "outputId": "b086363f-28b9-4ff2-817e-e1928d7cfe4a" |
| }, |
| "source": [ |
| "print(\"Invoking text classification with IREE\\n\")\n", |
| "print(positive_text)\n", |
| "classify_text_iree(positive_text)\n", |
| "print()\n", |
| "print(negative_text)\n", |
| "classify_text_iree(negative_text)" |
| ], |
| "execution_count": 25, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Invoking text classification with IREE\n", |
| "\n", |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "Label: Positive\n", |
| "Confidence: 0.8997293\n", |
| "\n", |
| "What a waste of my time.\n", |
| "Label: Negative\n", |
| "Confidence: 0.62750435\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| } |
| ] |
| } |