{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2021 The IREE Authors"
      ],
      "metadata": {
        "id": "jFe0n0RUvf3t"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "metadata": {
        "cellView": "form",
        "id": "rhiNq8ZUo1kL"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DlPPifbUs1tY"
      },
      "source": [
        "# TFLite text classification sample with IREE\n",
        "\n",
        "This notebook demonstrates how to download, compile, and run a TFLite model with IREE.  It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and runs it with both TFLite and IREE so the results can be compared.  The model predicts whether a sentence's sentiment is positive or negative and was trained on a dataset of IMDB movie reviews.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "im5zRYThs1tY"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "iuAwhzC7s1tZ"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!python -m pip install --upgrade tensorflow\n",
        "!python -m pip install --pre iree-base-compiler iree-base-runtime iree-tools-tflite -f https://iree.dev/pip-release-links.html\n",
        "!python -m pip install tflite-runtime-nightly"
      ]
    },
    {
      "cell_type": "code",
      "source": [
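        "# Confirm that the installed TensorFlow build exposes its experimental\n",
        "# TFLite-to-TOSA converter, which the TFLite import path for IREE relies on.\n",
        "from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode"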
      ],
      "metadata": {
        "id": "iGPG9DoEt2B5"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "hOUmhoAls1tb",
        "outputId": "a614a2d7-13b4-453a-f794-d5cb0297c6ff",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "IREE (https://iree.dev):\n",
            "  IREE compiler version 20240511.890 @ a3b7e12f1ae3b4d0da9cc5dfa5fb7865b178ec4b\n",
            "  LLVM version 19.0.0git\n",
            "  Optimized build\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import urllib.request\n",
        "import pathlib\n",
        "import tempfile\n",
        "import re\n",
        "import tflite_runtime.interpreter as tflite\n",
        "\n",
        "from iree import runtime as iree_rt\n",
        "from iree.compiler import compile_file, compile_str\n",
        "from iree.tools import tflite as iree_tflite\n",
        "\n",
        "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
        "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "# Print version information for future notebook users to reference.\n",
        "!iree-compile --version"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Ou5VKsLs1tc"
      },
      "source": [
        "### Load the TFLite model\n",
        "\n",
        "1.   Download files for the pretrained model\n",
        "2.   Extract model metadata used for input pre-processing and output post-processing\n",
        "3.   Define helper functions for pre- and post-processing\n",
        "\n",
        "These steps will differ from model to model.  Consult the model source or reference documentation for details.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "2vlRjNSPs1tc",
        "outputId": "188aaeda-714a-46fd-cdf3-993205ad988e",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n",
              " <http.client.HTTPMessage at 0x7bc4358faaa0>)"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ],
      "source": [
        "#@title Download pretrained text classification model\n",
        "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n",
        "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "e9t6pCAVs1td",
        "outputId": "72dc16fd-dc31-4799-8431-b64a3fa446f0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Archive:  /tmp/iree/colab_artifacts/text_classification.tflite\n",
            " extracting: /tmp/iree/colab_artifacts/labels.txt  \n",
            " extracting: /tmp/iree/colab_artifacts/vocab.txt  \n"
          ]
        }
      ],
      "source": [
        "#@title Extract model vocab and label metadata\n",
        "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n",
        "\n",
        "# Load the vocab file into a dictionary.  Each line maps a word from the model's\n",
        "# training vocabulary (plus special tokens such as <PAD>) to an integer ID.\n",
        "vocab = {}\n",
        "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n",
        "  for line in vocab_file:\n",
        "    (key, val) = line.split()\n",
        "    vocab[key] = int(val)\n",
        "\n",
        "# Text will be labeled as either 'Positive' or 'Negative'.\n",
        "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n",
        "  labels = label_file.read().splitlines()"
      ]
    },
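    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The vocab file stores one whitespace-separated `word id` pair per line.  As a self-contained sketch of that parsing step, the hypothetical `parse_vocab` helper below adds a check for malformed lines and runs on a few made-up lines instead of the real `vocab.txt` (the IDs shown are illustrative, not the model's)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper (not part of any library): parse 'word id' lines into a dict.\n",
        "def parse_vocab(lines):\n",
        "  word_ids = {}\n",
        "  for line_number, line in enumerate(lines, start=1):\n",
        "    parts = line.split()\n",
        "    if not parts:\n",
        "      continue  # Skip blank lines.\n",
        "    if len(parts) != 2:\n",
        "      raise ValueError(f'Malformed vocab line {line_number}: {line!r}')\n",
        "    word, word_id = parts\n",
        "    word_ids[word] = int(word_id)\n",
        "  return word_ids\n",
        "\n",
        "# Made-up lines standing in for vocab.txt.\n",
        "sample_lines = ['<PAD> 0', '<START> 1', '<UNKNOWN> 2', 'movie 3', '']\n",
        "toy_vocab = parse_vocab(sample_lines)\n",
        "print(toy_vocab)"
      ]
    },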
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "g4QeoNHqs1te"
      },
      "outputs": [],
      "source": [
        "#@title Input and output processing\n",
        "\n",
        "# Input text is encoded as an int32 array of shape [1, 256].  Each word of the\n",
        "# input sentence is mapped to its ID from the vocab dictionary (or to <UNKNOWN>),\n",
        "# and the remaining positions are filled with the <PAD> ID.\n",
        "\n",
        "SENTENCE_LEN = 256\n",
        "START = \"<START>\"\n",
        "PAD = \"<PAD>\"\n",
        "UNKNOWN = \"<UNKNOWN>\"\n",
        "\n",
        "def tokenize_input(text):\n",
        "  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)\n",
        "\n",
        "  # Lowercase each word and strip punctuation (apostrophes are kept).\n",
        "  words = [re.sub(r\"[^\\w\\s']\", '', word.lower()) for word in text.split()]\n",
        "\n",
        "  # Prepend <START>, then map each word to its ID.  Words beyond the first\n",
        "  # SENTENCE_LEN - 1 are dropped so long inputs do not overflow the array.\n",
        "  output[0][0] = vocab[START]\n",
        "  for index, word in enumerate(words[:SENTENCE_LEN - 1], start=1):\n",
        "    output[0][index] = vocab.get(word, vocab[UNKNOWN])\n",
        "\n",
        "  return output\n",
        "\n",
        "\n",
        "def interpret_output(output):\n",
        "  if output[0] >= output[1]:\n",
        "    label = labels[0]\n",
        "    confidence = output[0]\n",
        "  else:\n",
        "    label = labels[1]\n",
        "    confidence = output[1]\n",
        "\n",
        "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))"
      ]
    },
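    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a self-contained view of the same encoding, the sketch below repeats the scheme with a tiny made-up vocabulary and a hypothetical length of 8 instead of 256.  It also shows an `np.argmax`-based variant of `interpret_output` that handles any number of labels; ties go to the first label, matching the `>=` comparison above.  `TOY_VOCAB`, `TOY_LEN`, `encode`, and `decode` are illustrative names, not part of the model or any library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "import numpy as np\n",
        "\n",
        "# Made-up vocabulary and sequence length, for illustration only.\n",
        "TOY_VOCAB = {'<PAD>': 0, '<START>': 1, '<UNKNOWN>': 2, 'great': 3, 'movie': 4}\n",
        "TOY_LEN = 8\n",
        "\n",
        "def encode(text, vocab=TOY_VOCAB, length=TOY_LEN):\n",
        "  # Fill with <PAD>, put <START> first, then one ID per word, truncated to fit.\n",
        "  ids = np.full([1, length], vocab['<PAD>'], dtype=np.int32)\n",
        "  ids[0, 0] = vocab['<START>']\n",
        "  words = [re.sub(r\"[^\\w\\s']\", '', w.lower()) for w in text.split()]\n",
        "  for i, word in enumerate(words[:length - 1], start=1):\n",
        "    ids[0, i] = vocab.get(word, vocab['<UNKNOWN>'])\n",
        "  return ids\n",
        "\n",
        "def decode(scores, labels):\n",
        "  # Return the highest-scoring label and its score; ties pick the first label.\n",
        "  best = int(np.argmax(scores))\n",
        "  return labels[best], float(scores[best])\n",
        "\n",
        "print(encode('Great movie, truly!'))\n",
        "print(decode(np.array([0.2, 0.8]), ['Negative', 'Positive']))"
      ]
    },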
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "yXdBOIpBs1tf",
        "outputId": "97e74561-20ec-46de-e837-52385bb64bf3",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "This is the best movie I've seen in recent years. Strongly recommend it!\n",
            "[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0]]\n"
          ]
        }
      ],
      "source": [
        "#@title Text samples\n",
        "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n",
        "negative_text = \"What a waste of my time.\"\n",
        "\n",
        "print(positive_text)\n",
        "print(tokenize_input(positive_text))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kM8RSTwAs1tg"
      },
      "source": [
        "## Run using TFLite\n",
        "\n",
        "Overview:\n",
        "\n",
        "1.   Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n",
        "2.   Allocate tensors and get the input and output shape information\n",
        "3.   Invoke the TFLite Interpreter to test the text classification function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "Y28pZUAZs1tg"
      },
      "outputs": [],
      "source": [
        "interpreter = tflite.Interpreter(\n",
        "      model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "\n",
        "def classify_text_tflite(text):\n",
        "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n",
        "  interpreter.invoke()\n",
        "  output_data = interpreter.get_tensor(output_details[0]['index'])\n",
        "  interpret_output(output_data[0])"
      ]
    },
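    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 2 above also mentions shape information.  In the TFLite Python API, `get_input_details()` and `get_output_details()` return lists of dicts with keys such as `name`, `index`, `shape`, and `dtype`.  The hypothetical `describe_tensors` helper below summarizes them; so that it runs on its own, it is demonstrated on a hand-written entry whose values are made up rather than read from this model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def describe_tensors(details):\n",
        "  # Build one readable line per tensor from TFLite interpreter details.\n",
        "  lines = []\n",
        "  for entry in details:\n",
        "    name = entry['name']\n",
        "    index = entry['index']\n",
        "    shape = tuple(int(dim) for dim in entry['shape'])\n",
        "    dtype = np.dtype(entry['dtype']).name\n",
        "    lines.append(f'{name} (index {index}): shape={shape}, dtype={dtype}')\n",
        "  return '\\n'.join(lines)\n",
        "\n",
        "# Made-up entry shaped like one element of get_input_details().\n",
        "fake_details = [{'name': 'example_input', 'index': 0,\n",
        "                 'shape': np.array([1, 256], dtype=np.int32), 'dtype': np.int32}]\n",
        "print(describe_tensors(fake_details))\n",
        "\n",
        "# With this notebook's interpreter, one could instead run:\n",
        "# print(describe_tensors(input_details))\n",
        "# print(describe_tensors(output_details))"
      ]
    },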
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "GFrbbHOgs1th",
        "outputId": "25840d4d-1b17-4e0d-b19c-9d53b2df2e35",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Invoking text classification with TFLite\n",
            "\n",
            "This is the best movie I've seen in recent years. Strongly recommend it!\n",
            "Label: Positive\n",
            "Confidence: 0.8997294\n",
            "\n",
            "What a waste of my time.\n",
            "Label: Negative\n",
            "Confidence: 0.6275043\n"
          ]
        }
      ],
      "source": [
        "print(\"Invoking text classification with TFLite\\n\")\n",
        "print(positive_text)\n",
        "classify_text_tflite(positive_text)\n",
        "print()\n",
        "print(negative_text)\n",
        "classify_text_tflite(negative_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JiVzwcwCs1th"
      },
      "source": [
        "## Run using IREE\n",
        "\n",
        "Overview:\n",
        "\n",
        "1.   Import the TFLite model to TOSA MLIR\n",
        "2.   Compile the TOSA MLIR into an IREE VM FlatBuffer module\n",
        "3.   Run the VM module with IREE's runtime to test the text classification function\n",
        "\n",
        "Both runtimes should produce the same labels and confidence scores.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "Y7gKbGVSs1ti",
        "outputId": "e96f075f-4af5-41bb-af8c-dd51ca013a39",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "2024-05-13 20:12:11.397050: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
          ]
        }
      ],
      "source": [
        "# Convert TFLite model to TOSA MLIR (bytecode) with IREE's import tool.\n",
        "tosa_mlirbc_file = ARTIFACTS_DIR.joinpath(\"text_classification.mlirbc\")\n",
        "!iree-import-tflite {ARTIFACTS_DIR}/text_classification.tflite -o={tosa_mlirbc_file}\n",
        "\n",
        "# The generated .mlirbc file can also be saved and used outside of Python, for\n",
        "# example as input to the iree-compile command-line tool."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "dsN89K7as1tj"
      },
      "outputs": [],
      "source": [
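        "# Compile the TOSA MLIR into an IREE VM FlatBuffer.  'vmvx' is IREE's portable\n",
        "# reference CPU backend; 'llvm-cpu' would generate optimized native code instead.\n",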
        "compiled_flatbuffer = compile_file(tosa_mlirbc_file, input_type=\"tosa\", target_backends=[\"vmvx\"])\n",
        "\n",
        "# Register the module with a runtime context.\n",
        "config = iree_rt.Config(\"local-task\")\n",
        "ctx = iree_rt.SystemContext(config=config)\n",
        "vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, compiled_flatbuffer)\n",
        "ctx.add_vm_module(vm_module)\n",
        "invoke_text_classification = ctx.modules.module[\"main\"]\n",
        "\n",
        "def classify_text_iree(text):\n",
        "  result = invoke_text_classification(tokenize_input(text)).to_host()[0]\n",
        "  interpret_output(result)"
      ]
    },
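    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The overview above expects both runtimes to produce the same output.  The printed text only shows the winning label and its score, so comparing the full raw output arrays is a stricter check.  Below is a minimal sketch built on `np.allclose`; the `outputs_match` name and the default tolerances are illustrative choices rather than anything defined by IREE or TFLite."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def outputs_match(expected, actual, rtol=1e-5, atol=1e-6):\n",
        "  # True when both arrays have the same shape and agree within tolerance.\n",
        "  expected = np.asarray(expected)\n",
        "  actual = np.asarray(actual)\n",
        "  if expected.shape != actual.shape:\n",
        "    return False\n",
        "  return bool(np.allclose(expected, actual, rtol=rtol, atol=atol))\n",
        "\n",
        "print(outputs_match([[0.1, 0.9]], [[0.1, 0.9000001]]))\n",
        "\n",
        "# Hypothetical usage with this notebook's objects:\n",
        "# tokens = tokenize_input(positive_text)\n",
        "# interpreter.set_tensor(input_details[0]['index'], tokens)\n",
        "# interpreter.invoke()\n",
        "# tflite_out = interpreter.get_tensor(output_details[0]['index'])\n",
        "# iree_out = invoke_text_classification(tokens).to_host()\n",
        "# print(outputs_match(tflite_out, iree_out))"
      ]
    },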
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "LQiDmXn_s1tj",
        "outputId": "7c87745a-326c-4646-ef79-23c2b09a8b6c",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Invoking text classification with IREE\n",
            "\n",
            "This is the best movie I've seen in recent years. Strongly recommend it!\n",
            "Label: Positive\n",
            "Confidence: 0.8997294\n",
            "\n",
            "What a waste of my time.\n",
            "Label: Negative\n",
            "Confidence: 0.6275043\n"
          ]
        }
      ],
      "source": [
        "print(\"Invoking text classification with IREE\\n\")\n",
        "print(positive_text)\n",
        "classify_text_iree(positive_text)\n",
        "print()\n",
        "print(negative_text)\n",
        "classify_text_iree(negative_text)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "1Ou5VKsLs1tc"
      ],
      "name": "tflite_text_classification.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
