| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "##### Copyright 2021 The IREE Authors" |
| ], |
| "metadata": { |
| "id": "jFe0n0RUvf3t" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "metadata": { |
| "cellView": "form", |
| "id": "rhiNq8ZUo1kL" |
| }, |
| "execution_count": 5, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "DlPPifbUs1tY" |
| }, |
| "source": [ |
| "# TFLite text classification sample with IREE\n", |
| "\n", |
| "This notebook demonstrates how to download, compile, and run a TFLite model with IREE. It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and runs it with both TFLite and IREE so the results can be compared. The model predicts whether a sentence's sentiment is positive or negative, and was trained on a dataset of IMDB movie reviews.\n" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "im5zRYThs1tY" |
| }, |
| "source": [ |
| "## Setup" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": { |
| "id": "iuAwhzC7s1tZ" |
| }, |
| "outputs": [], |
| "source": [ |
| "%%capture\n", |
| "!python -m pip install --upgrade tensorflow\n", |
| "!python -m pip install --pre iree-base-compiler iree-base-runtime iree-tools-tflite -f https://iree.dev/pip-release-links.html\n", |
| "!python -m pip install tflite-runtime-nightly" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode" |
| ], |
| "metadata": { |
| "id": "iGPG9DoEt2B5" |
| }, |
| "execution_count": 7, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "metadata": { |
| "id": "hOUmhoAls1tb", |
| "outputId": "a614a2d7-13b4-453a-f794-d5cb0297c6ff", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 20240511.890 @ a3b7e12f1ae3b4d0da9cc5dfa5fb7865b178ec4b\n", |
| " LLVM version 19.0.0git\n", |
| " Optimized build\n" |
| ] |
| } |
| ], |
| "source": [ |
| "import numpy as np\n", |
| "import urllib.request\n", |
| "import pathlib\n", |
| "import tempfile\n", |
| "import re\n", |
| "import tflite_runtime.interpreter as tflite\n", |
| "\n", |
| "from iree import runtime as iree_rt\n", |
| "from iree.compiler import compile_file, compile_str\n", |
| "from iree.tools import tflite as iree_tflite\n", |
| "\n", |
| "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", |
| "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)\n", |
| "\n", |
| "# Print version information for future notebook users to reference.\n", |
| "!iree-compile --version" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "1Ou5VKsLs1tc" |
| }, |
| "source": [ |
| "### Load the TFLite model\n", |
| "\n", |
| "1. Download files for the pretrained model\n", |
| "2. Extract model metadata used for input pre-processing and output post-processing\n", |
| "3. Define helper functions for pre- and post-processing\n", |
| "\n", |
| "These steps will differ from model to model. Consult the model source or reference documentation for details.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "metadata": { |
| "id": "2vlRjNSPs1tc", |
| "outputId": "188aaeda-714a-46fd-cdf3-993205ad988e", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "execute_result", |
| "data": { |
| "text/plain": [ |
| "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n", |
| " <http.client.HTTPMessage at 0x7bc4358faaa0>)" |
| ] |
| }, |
| "metadata": {}, |
| "execution_count": 9 |
| } |
| ], |
| "source": [ |
| "#@title Download pretrained text classification model\n", |
| "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n", |
| "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "metadata": { |
| "id": "e9t6pCAVs1td", |
| "outputId": "72dc16fd-dc31-4799-8431-b64a3fa446f0", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Archive: /tmp/iree/colab_artifacts/text_classification.tflite\n", |
| " extracting: /tmp/iree/colab_artifacts/labels.txt \n", |
| " extracting: /tmp/iree/colab_artifacts/vocab.txt \n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Extract model vocab and label metadata\n", |
| "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n", |
| "\n", |
| "# Load the vocab file into a dictionary. Each line holds a word followed by\n", |
| "# the integer ID that the model uses for that word.\n", |
| "vocab = {}\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n", |
| "  for line in vocab_file:\n", |
| "    (key, val) = line.split()\n", |
| "    vocab[key] = int(val)\n", |
| "\n", |
| "# Text will be labeled as either 'Positive' or 'Negative'.\n", |
| "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n", |
| "  labels = label_file.read().splitlines()" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 11, |
| "metadata": { |
| "id": "g4QeoNHqs1te" |
| }, |
| "outputs": [], |
| "source": [ |
| "#@title Input and output processing\n", |
| "\n", |
| "# Input text will be encoded as an integer array of fixed length 256. The\n", |
| "# input sentence will be mapped to integers from the vocab dictionary, and the\n", |
| "# empty array spaces are filled with padding.\n", |
| "\n", |
| "SENTENCE_LEN = 256\n", |
| "START = \"<START>\"\n", |
| "PAD = \"<PAD>\"\n", |
| "UNKNOWN = \"<UNKNOWN>\"\n", |
| "\n", |
| "def tokenize_input(text):\n", |
| "  # Start from an array filled with the padding token.\n", |
| "  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)\n", |
| "\n", |
| "  # Remove capitalization and punctuation from the input text.\n", |
| "  words = [re.sub(r\"[^\\w\\s']\", '', word.lower()) for word in text.split()]\n", |
| "\n", |
| "  # Prepend <START>, then truncate so the sequence fits in SENTENCE_LEN.\n", |
| "  output[0][0] = vocab[START]\n", |
| "  for index, word in enumerate(words[:SENTENCE_LEN - 1], start=1):\n", |
| "    # Words missing from the vocab are mapped to <UNKNOWN>.\n", |
| "    output[0][index] = vocab.get(word, vocab[UNKNOWN])\n", |
| "\n", |
| "  return output\n", |
| "\n", |
| "\n", |
| "def interpret_output(output):\n", |
| "  if output[0] >= output[1]:\n", |
| "    label = labels[0]\n", |
| "    confidence = output[0]\n", |
| "  else:\n", |
| "    label = labels[1]\n", |
| "    confidence = output[1]\n", |
| "\n", |
| "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 12, |
| "metadata": { |
| "id": "yXdBOIpBs1tf", |
| "outputId": "97e74561-20ec-46de-e837-52385bb64bf3", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "[[ 1 13 8 3 117 19 206 109 10 1134 152 2301 385 11\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", |
| " 0 0 0 0]]\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Text samples\n", |
| "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", |
| "negative_text = \"What a waste of my time.\"\n", |
| "\n", |
| "print(positive_text)\n", |
| "print(tokenize_input(positive_text))" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "kM8RSTwAs1tg" |
| }, |
| "source": [ |
| "## Run using TFLite\n", |
| "\n", |
| "Overview:\n", |
| "\n", |
| "1. Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n", |
| "2. Allocate tensors and get the input and output shape information\n", |
| "3. Invoke the TFLite Interpreter to test the text classification function" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 13, |
| "metadata": { |
| "id": "Y28pZUAZs1tg" |
| }, |
| "outputs": [], |
| "source": [ |
| "# Load the model into a TFLite interpreter and allocate its tensors.\n", |
| "interpreter = tflite.Interpreter(\n", |
| "    model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n", |
| "interpreter.allocate_tensors()\n", |
| "input_details = interpreter.get_input_details()\n", |
| "output_details = interpreter.get_output_details()\n", |
| "\n", |
| "def classify_text_tflite(text):\n", |
| "  # Run inference on the tokenized text and print the predicted label.\n", |
| "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n", |
| "  interpreter.invoke()\n", |
| "  output_data = interpreter.get_tensor(output_details[0]['index'])\n", |
| "  interpret_output(output_data[0])" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 14, |
| "metadata": { |
| "id": "GFrbbHOgs1th", |
| "outputId": "25840d4d-1b17-4e0d-b19c-9d53b2df2e35", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Invoking text classification with TFLite\n", |
| "\n", |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "Label: Positive\n", |
| "Confidence: 0.8997294\n", |
| "\n", |
| "What a waste of my time.\n", |
| "Label: Negative\n", |
| "Confidence: 0.6275043\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(\"Invoking text classification with TFLite\\n\")\n", |
| "print(positive_text)\n", |
| "classify_text_tflite(positive_text)\n", |
| "print()\n", |
| "print(negative_text)\n", |
| "classify_text_tflite(negative_text)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "JiVzwcwCs1th" |
| }, |
| "source": [ |
| "## Run using IREE\n", |
| "\n", |
| "Overview:\n", |
| "\n", |
| "1. Import the TFLite model to TOSA MLIR\n", |
| "2. Compile the TOSA MLIR into an IREE flatbuffer and VM module\n", |
| "3. Run the VM module through IREE's runtime to test the text classification function\n", |
| "\n", |
| "Both runtimes should generate the same output.\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 15, |
| "metadata": { |
| "id": "Y7gKbGVSs1ti", |
| "outputId": "e96f075f-4af5-41bb-af8c-dd51ca013a39", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "2024-05-13 20:12:11.397050: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" |
| ] |
| } |
| ], |
| "source": [ |
| "# Convert TFLite model to TOSA MLIR (bytecode) with IREE's import tool.\n", |
| "tosa_mlirbc_file = ARTIFACTS_DIR.joinpath(\"text_classification.mlirbc\")\n", |
| "!iree-import-tflite {ARTIFACTS_DIR}/text_classification.tflite -o={tosa_mlirbc_file}\n", |
| "\n", |
| "# The generated .mlirbc file could now be saved and used outside of Python, with\n", |
| "# IREE native tools or in apps, etc." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 16, |
| "metadata": { |
| "id": "dsN89K7as1tj" |
| }, |
| "outputs": [], |
| "source": [ |
| "# Compile the TOSA MLIR into a VM module.\n", |
| "compiled_flatbuffer = compile_file(tosa_mlirbc_file, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", |
| "\n", |
| "# Register the module with a runtime context.\n", |
| "config = iree_rt.Config(\"local-task\")\n", |
| "ctx = iree_rt.SystemContext(config=config)\n", |
| "vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, compiled_flatbuffer)\n", |
| "ctx.add_vm_module(vm_module)\n", |
| "invoke_text_classification = ctx.modules.module[\"main\"]\n", |
| "\n", |
| "def classify_text_iree(text):\n", |
| "  result = invoke_text_classification(tokenize_input(text)).to_host()[0]\n", |
| "  interpret_output(result)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 17, |
| "metadata": { |
| "id": "LQiDmXn_s1tj", |
| "outputId": "7c87745a-326c-4646-ef79-23c2b09a8b6c", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Invoking text classification with IREE\n", |
| "\n", |
| "This is the best movie I've seen in recent years. Strongly recommend it!\n", |
| "Label: Positive\n", |
| "Confidence: 0.8997294\n", |
| "\n", |
| "What a waste of my time.\n", |
| "Label: Negative\n", |
| "Confidence: 0.6275043\n" |
| ] |
| } |
| ], |
| "source": [ |
| "print(\"Invoking text classification with IREE\\n\")\n", |
| "print(positive_text)\n", |
| "classify_text_iree(positive_text)\n", |
| "print()\n", |
| "print(negative_text)\n", |
| "classify_text_iree(negative_text)" |
| ] |
| } |
| ], |
| "metadata": { |
| "colab": { |
| "collapsed_sections": [ |
| "1Ou5VKsLs1tc" |
| ], |
| "name": "tflite_text_classification.ipynb", |
| "provenance": [] |
| }, |
| "kernelspec": { |
| "display_name": "Python 3", |
| "name": "python3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 0 |
| } |