| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "E8Ft5u8-s1tS" |
| }, |
| "source": [ |
| "```\n", |
| "Copyright 2021 The IREE Authors\n", |
| "\n", |
| "Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "See https://llvm.org/LICENSE.txt for license information.\n", |
| "SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n", |
| "```" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "DlPPifbUs1tY" |
| }, |
| "source": [ |
| "# TFLite text classification sample with IREE\n", |
| "\n", |
| "This notebook demonstrates how to download, compile, and run a TFLite model with IREE. It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and runs it with both the TFLite interpreter and IREE so the results can be compared. The model predicts whether a sentence's sentiment is positive or negative, and was trained on a dataset of IMDB movie reviews.\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "im5zRYThs1tY" |
| }, |
| "source": [ |
| "## Setup" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "iuAwhzC7s1tZ" |
| }, |
| "outputs": [], |
| "source": [ |
| "%%capture\n", |
| "!python -m pip install iree-compiler iree-runtime iree-tools-tflite -f https://github.com/google/iree/releases/latest\n", |
| "!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "hOUmhoAls1tb" |
| }, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "import urllib.request\n", |
| "import pathlib\n", |
| "import tempfile\n", |
| "import re\n", |
| "import tflite_runtime.interpreter as tflite\n", |
| "\n", |
| "from iree import runtime as iree_rt\n", |
| "from iree.compiler import compile_str\n", |
| "from iree.tools import tflite as iree_tflite\n", |
| "\n", |
| "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", |
| "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "1Ou5VKsLs1tc" |
| }, |
| "source": [ |
| "### Load the TFLite model\n", |
| "\n", |
| "1. Download files for the pretrained model\n", |
| "2. Extract model metadata used for input pre-processing and output post-processing\n", |
| "3. Define helper functions for pre- and post-processing\n", |
| "\n", |
| "These steps will differ from model to model. Consult the model source or reference documentation for details.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "2vlRjNSPs1tc" |
| }, |
| "outputs": [], |
| "source": [ |
| "#@title Download pretrained text classification model\n", |
| "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n", |
| "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "e9t6pCAVs1td" |
| }, |
| "outputs": [], |
| "source": [ |
| "#@title Extract model vocab and label metadata\n", |
| "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n", |
| "\n", |
| "# Load the vocab file into a dictionary. It maps each word in the model's\n", |
| "# vocabulary (roughly 10,000 words plus a few special tokens) to an integer.\n", |
| "vocab = {}\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n", |
| "  for line in vocab_file:\n", |
| "    (key, val) = line.split()\n", |
| "    vocab[key] = int(val)\n", |
| "\n", |
| "# Text will be labeled as either 'Positive' or 'Negative'.\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n", |
| "  labels = label_file.read().splitlines()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "g4QeoNHqs1te" |
| }, |
| "outputs": [], |
| "source": [ |
| "#@title Input and output processing\n", |
| "\n", |
| "# Input text is encoded as an integer array of fixed length 256. Each word\n", |
| "# in the input sentence is mapped to an integer from the vocab dictionary,\n", |
| "# and the unused positions are filled with padding.\n", |
| "\n", |
| "SENTENCE_LEN = 256\n", |
| "START = \"<START>\"\n", |
| "PAD = \"<PAD>\"\n", |
| "UNKNOWN = \"<UNKNOWN>\"\n", |
| "\n", |
| "def tokenize_input(text):\n", |
| "  output = np.empty([1, SENTENCE_LEN], dtype=np.int32)\n", |
| "  output.fill(vocab[PAD])\n", |
| "\n", |
| "  # Remove capitalization and punctuation from the input text.\n", |
| "  words = [word.lower() for word in text.split()]\n", |
| "  words = [re.sub(r\"[^\\w\\s']\", '', word) for word in words]\n", |
| "\n", |
| "  # Prepend <START>.\n", |
| "  index = 0\n", |
| "  output[0][index] = vocab[START]\n", |
| "  index += 1\n", |
| "\n", |
| "  # Truncate long inputs so <START> plus the words fit in SENTENCE_LEN slots.\n", |
| "  for word in words[:SENTENCE_LEN - 1]:\n", |
| "    output[0][index] = vocab[word] if word in vocab else vocab[UNKNOWN]\n", |
| "    index += 1\n", |
| "\n", |
| "  return output\n", |
| "\n", |
| "\n", |
| "def interpret_output(output):\n", |
| "  # Pick the label with the higher score and report that score.\n", |
| "  if output[0] >= output[1]:\n", |
| "    label = labels[0]\n", |
| "    confidence = output[0]\n", |
| "  else:\n", |
| "    label = labels[1]\n", |
| "    confidence = output[1]\n", |
| "\n", |
| "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "yXdBOIpBs1tf", |
| "outputId": "8223b6fc-f796-4692-ce83-2eeb5ccd25c8", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "[[ 1 13 8 3 117 19 206 109 10 1134 152 2301 385 11\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0]]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Text samples\n", |
| "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", |
| "negative_text = \"What a waste of my time.\"\n", |
| "\n", |
| "print(positive_text)\n", |
| "print(tokenize_input(positive_text))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "kM8RSTwAs1tg" |
| }, |
| "source": [ |
| "## Run using TFLite\n", |
| "\n", |
| "Overview:\n", |
| "\n", |
| "1. Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n", |
| "2. Allocate tensors and get the input and output shape information\n", |
| "3. Invoke the TFLite Interpreter to test the text classification function" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "Y28pZUAZs1tg" |
| }, |
| "outputs": [], |
| "source": [ |
| "interpreter = tflite.Interpreter(\n", |
| "    model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n", |
| "interpreter.allocate_tensors()\n", |
| "input_details = interpreter.get_input_details()\n", |
| "output_details = interpreter.get_output_details()\n", |
| "\n", |
| "def classify_text_tflite(text):\n", |
| "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n", |
| "  interpreter.invoke()\n", |
| "  output_data = interpreter.get_tensor(output_details[0]['index'])\n", |
| "  interpret_output(output_data[0])" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "GFrbbHOgs1th", |
| "outputId": "b37f72ab-197b-480d-aabf-a702bc9aa986", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Invoking text classification with TFLite\n", |
| "\n", |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "Label: Positive\n", |
| "Confidence: 0.8997293\n", |
| "\n", |
| "What a waste of my time.\n", |
| "Label: Negative\n", |
| "Confidence: 0.6275043\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(\"Invoking text classification with TFLite\\n\")\n", |
| "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", |
| "print(positive_text)\n", |
| "classify_text_tflite(positive_text)\n", |
| "print()\n", |
| "negative_text = \"What a waste of my time.\"\n", |
| "print(negative_text)\n", |
| "classify_text_tflite(negative_text)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "JiVzwcwCs1th" |
| }, |
| "source": [ |
| "## Run using IREE\n", |
| "\n", |
| "Overview:\n", |
| "\n", |
| "1. Import the TFLite model to TOSA MLIR \n", |
| "2. Compile the TOSA MLIR into an IREE flatbuffer and VM module\n", |
| "3. Run the VM module through IREE's runtime to test the text classification function\n", |
| "\n", |
| "Both runtimes should produce the same label and, up to floating-point rounding, the same confidence value.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "Y7gKbGVSs1ti" |
| }, |
| "outputs": [], |
| "source": [ |
| "# Convert TFLite model to TOSA MLIR with IREE's import tool.\n", |
| "IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')\n", |
| "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification.mlir\n", |
| "\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n", |
| "  tosa_mlir = mlir_file.read()\n", |
| "\n", |
| "# The generated .mlir file is saved in ARTIFACTS_DIR and can also be used\n", |
| "# outside of Python, for example with IREE's native tools or in apps." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "i2JGBMa-s1ti", |
| "outputId": "5ff0a33b-f7ac-4ed8-a757-23afe1b34b2c", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "module {\n", |
| " func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> {\n", |
| " %0 = \"tosa.const\"() {value = dense<[0.043447677, -0.0434476472]> : tensor<2xf32>} : () -> tensor<2xf32>\n", |
| " %1 = \"tosa.const\"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722, 0.795746565, -1.3122592, 0.352218777, 1.04682362, -1.18796027, -0.0409261398, 1.05883229, 1.48620188, -1.13325548, 1.03072512]]> : tensor<2x16xf32>} : () -> tensor<2x16xf32>\n", |
| " %2 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n", |
| " %3 = \"tosa.const\"() {value = opaque<\"elided_large_const\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n", |
| " %4 = \"tosa.const\"() {value = opaque<\"elided_large_const\", \"0xDEADBEEF\"> : tensor<1x10003x16xf32>} : () -> tensor<1x10003x16xf32>\n", |
| " %5 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<1x1xf32>} : () -> tensor<1x1xf32>\n", |
| " %6 = \"tosa.gather\"(%4, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n", |
| " %7 = \"tosa.reduce_sum\"(%6) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n", |
| " %8 = \"tosa.reshape\"(%7) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n", |
| " %9 = \"tosa.mul\"(%8, %5) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n", |
| " %10 = \"tosa.fully_connected\"(%9, %3, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n", |
| " %11 = \"tosa.clamp\"(%10) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n", |
| " %12 = \"tosa.fully_connected\"(%11, %1, %0) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n", |
| " %13 = \"tosa.exp\"(%12) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n", |
| " %14 = \"tosa.reduce_sum\"(%13) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n", |
| " %15 = \"tosa.reciprocal\"(%14) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n", |
| " %16 = \"tosa.mul\"(%13, %15) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n", |
| " return %16 : tensor<1x2xf32>\n", |
| " }\n", |
| "}\n", |
| "\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# The model contains very large constants, so import it again with large\n", |
| "# constants elided to get a version that is short enough to print.\n", |
| "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification_truncated.mlir -mlir-elide-elementsattrs-if-larger=50\n", |
| "\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"text_classification_truncated.mlir\")) as truncated_mlir_file:\n", |
| "  truncated_tosa_mlir = truncated_mlir_file.read()\n", |
| "  print(truncated_tosa_mlir, end='')" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "dsN89K7as1tj" |
| }, |
| "outputs": [], |
| "source": [ |
| "# Compile the TOSA MLIR into a VM module.\n", |
| "compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", |
| "vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n", |
| "\n", |
| "# Register the module with a runtime context.\n", |
| "config = iree_rt.Config(\"vmvx\")\n", |
| "ctx = iree_rt.SystemContext(config=config)\n", |
| "ctx.add_vm_module(vm_module)\n", |
| "invoke_text_classification = ctx.modules.module[\"main\"]\n", |
| "\n", |
| "def classify_text_iree(text):\n", |
| "  result = invoke_text_classification(tokenize_input(text)).to_host()[0]\n", |
| "  interpret_output(result)" |
| ] |
| }, |
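| { |
| "cell_type": "markdown", |
| "metadata": {}, |
| "source": [ |
| "Before reading the printed labels, the raw output tensors from the two runtimes can be compared numerically. The next cell is a minimal sketch and not part of the original sample: the helper names `raw_output_tflite` and `raw_output_iree` are hypothetical, and the tolerances are an assumption about how much floating-point drift between backends is acceptable. It only reuses objects defined in the cells above.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "# Hypothetical helpers (not part of the original sample): return the raw\n", |
| "# score tensor from each runtime instead of printing a label.\n", |
| "def raw_output_tflite(text):\n", |
| "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n", |
| "  interpreter.invoke()\n", |
| "  return interpreter.get_tensor(output_details[0]['index'])\n", |
| "\n", |
| "def raw_output_iree(text):\n", |
| "  return np.asarray(invoke_text_classification(tokenize_input(text)).to_host())\n", |
| "\n", |
| "# The tolerances below are an assumption; adjust them as needed.\n", |
| "for text in (positive_text, negative_text):\n", |
| "  np.testing.assert_allclose(\n", |
| "      raw_output_iree(text), raw_output_tflite(text), rtol=1e-5, atol=1e-6)\n", |
| "print(\"TFLite and IREE outputs match within tolerance\")" |
| ] |
| }, |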
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "metadata": { |
| "id": "LQiDmXn_s1tj", |
| "outputId": "5911e060-9aad-4163-8bf0-1daa2714b54f", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Invoking text classification with IREE\n", |
| "\n", |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "Label: Positive\n", |
| "Confidence: 0.8997293\n", |
| "\n", |
| "What a waste of my time.\n", |
| "Label: Negative\n", |
| "Confidence: 0.6275043\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(\"Invoking text classification with IREE\\n\")\n", |
| "print(positive_text)\n", |
| "classify_text_iree(positive_text)\n", |
| "print()\n", |
| "print(negative_text)\n", |
| "classify_text_iree(negative_text)" |
| ] |
| } |
| ], |
| "metadata": { |
| "colab": { |
| "collapsed_sections": [ |
| "1Ou5VKsLs1tc" |
| ], |
| "name": "tflite_text_classification.ipynb", |
| "provenance": [] |
| }, |
| "kernelspec": { |
| "display_name": "Python 3", |
| "name": "python3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 0 |
| } |