|  | { | 
|  | "cells": [ | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "##### Copyright 2021 The IREE Authors" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "jFe0n0RUvf3t" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "E8Ft5u8-s1tS" | 
|  | }, | 
|  | "source": [ | 
|  | "```\n", | 
|  | "Copyright 2021 The IREE Authors\n", | 
|  | "\n", | 
|  | "Licensed under the Apache License v2.0 with LLVM Exceptions.\n", | 
|  | "See https://llvm.org/LICENSE.txt for license information.\n", | 
|  | "SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n", | 
|  | "```" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "DlPPifbUs1tY" | 
|  | }, | 
|  | "source": [ | 
|  | "# TFLite text classification sample with IREE\n", | 
|  | "\n", | 
|  | "This notebook demonstrates how to download, compile, and run a TFLite model with IREE. It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and runs it with both TFLite and IREE so the results can be compared. The model predicts whether a sentence's sentiment is positive or negative, and was trained on a dataset of IMDB movie reviews.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "im5zRYThs1tY" | 
|  | }, | 
|  | "source": [ | 
|  | "## Setup" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 1, | 
|  | "metadata": { | 
|  | "id": "iuAwhzC7s1tZ" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "%%capture\n", | 
|  | "!python -m pip install --upgrade tf-nightly  # Needed for experimental_tflite_to_tosa_bytecode in TF>=2.13\n", | 
|  | "!python -m pip install iree-compiler iree-runtime iree-tools-tflite -f https://openxla.github.io/iree/pip-release-links.html\n", | 
|  | "!python -m pip install tflite-runtime-nightly" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
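|  | "# This name is not used directly below; importing it here fails early if the\n", | 
|  | "# installed TensorFlow build lacks the TFLite-to-TOSA converter.\n", | 
|  | "from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode" | 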
|  | ], | 
|  | "metadata": { | 
|  | "id": "iGPG9DoEt2B5" | 
|  | }, | 
|  | "execution_count": 2, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 3, | 
|  | "metadata": { | 
|  | "id": "hOUmhoAls1tb", | 
|  | "outputId": "ed50ae3d-f114-441a-8b8a-b57cf272a130", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "IREE (https://openxla.github.io/iree):\n", | 
|  | "  IREE compiler version 20230512.517 @ 2778c08d7e49b3b26e107cb725d94bf1594ac0f3\n", | 
|  | "  LLVM version 17.0.0git\n", | 
|  | "  Optimized build\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "import numpy as np\n", | 
|  | "import urllib.request\n", | 
|  | "import pathlib\n", | 
|  | "import tempfile\n", | 
|  | "import re\n", | 
|  | "import tflite_runtime.interpreter as tflite\n", | 
|  | "\n", | 
|  | "from iree import runtime as iree_rt\n", | 
|  | "from iree.compiler import compile_file, compile_str\n", | 
|  | "from iree.tools import tflite as iree_tflite\n", | 
|  | "\n", | 
|  | "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", | 
|  | "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)\n", | 
|  | "\n", | 
|  | "# Print version information for future notebook users to reference.\n", | 
|  | "!iree-compile --version" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "1Ou5VKsLs1tc" | 
|  | }, | 
|  | "source": [ | 
|  | "### Load the TFLite model\n", | 
|  | "\n", | 
|  | "1.   Download files for the pretrained model\n", | 
|  | "2.   Extract model metadata used for input pre-processing and output post-processing\n", | 
|  | "3.   Define helper functions for pre- and post-processing\n", | 
|  | "\n", | 
|  | "These steps will differ from model to model.  Consult the model source or reference documentation for details.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 4, | 
|  | "metadata": { | 
|  | "id": "2vlRjNSPs1tc", | 
|  | "outputId": "a2c23d74-34b9-495b-d8ad-a8f14aabe2e2", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "execute_result", | 
|  | "data": { | 
|  | "text/plain": [ | 
|  | "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n", | 
|  | " <http.client.HTTPMessage at 0x7fc862349f30>)" | 
|  | ] | 
|  | }, | 
|  | "metadata": {}, | 
|  | "execution_count": 4 | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Download pretrained text classification model\n", | 
|  | "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n", | 
|  | "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 5, | 
|  | "metadata": { | 
|  | "id": "e9t6pCAVs1td", | 
|  | "outputId": "ae106ff1-77de-4b27-9796-f85ddc251245", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Archive:  /tmp/iree/colab_artifacts/text_classification.tflite\n", | 
|  | " extracting: /tmp/iree/colab_artifacts/labels.txt  \n", | 
|  | " extracting: /tmp/iree/colab_artifacts/vocab.txt  \n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Extract model vocab and label metadata\n", | 
|  | "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n", | 
|  | "\n", | 
|  | "# Load the vocab file into a dictionary that maps each word in the model's\n", | 
|  | "# vocabulary to an integer ID.\n", | 
|  | "vocab = {}\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n", | 
|  | "  for line in vocab_file:\n", | 
|  | "    (key, val) = line.split()\n", | 
|  | "    vocab[key] = int(val)\n", | 
|  | "\n", | 
|  | "# Text will be labeled as either 'Positive' or 'Negative'.\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n", | 
|  | "  labels = label_file.read().splitlines()" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 6, | 
|  | "metadata": { | 
|  | "id": "g4QeoNHqs1te" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "#@title Input and output processing\n", | 
|  | "\n", | 
|  | "# Input text will be encoded as an integer array of fixed length 256.  The \n", | 
|  | "# input sentence will be mapped to integers from the vocab dictionary, and the \n", | 
|  | "# empty array spaces are filled with padding.\n", | 
|  | "\n", | 
|  | "SENTENCE_LEN = 256\n", | 
|  | "START = \"<START>\"\n", | 
|  | "PAD = \"<PAD>\"\n", | 
|  | "UNKNOWN = \"<UNKNOWN>\"\n", | 
|  | "\n", | 
|  | "def tokenize_input(text):\n", | 
|  | "  # Start with every slot set to <PAD>; real tokens overwrite the front.\n", | 
|  | "  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)\n", | 
|  | "\n", | 
|  | "  # Remove capitalization and punctuation from the input text.\n", | 
|  | "  words = [re.sub(r\"[^\\w\\s']\", '', word.lower()) for word in text.split()]\n", | 
|  | "\n", | 
|  | "  # Prepend <START>, then keep only as many words as fit in the array.\n", | 
|  | "  output[0][0] = vocab[START]\n", | 
|  | "  for index, word in enumerate(words[:SENTENCE_LEN - 1], start=1):\n", | 
|  | "    output[0][index] = vocab.get(word, vocab[UNKNOWN])\n", | 
|  | "\n", | 
|  | "  return output\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def interpret_output(output):\n", | 
|  | "  if output[0] >= output[1]:\n", | 
|  | "    label = labels[0]\n", | 
|  | "    confidence = output[0]\n", | 
|  | "  else:\n", | 
|  | "    label = labels[1]\n", | 
|  | "    confidence = output[1]\n", | 
|  | "\n", | 
|  | "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 7, | 
|  | "metadata": { | 
|  | "id": "yXdBOIpBs1tf", | 
|  | "outputId": "2b5a3a16-2132-4648-9aa8-c9cbc3b7b81b", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0]]\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Text samples\n", | 
|  | "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", | 
|  | "negative_text = \"What a waste of my time.\"\n", | 
|  | "\n", | 
|  | "print(positive_text)\n", | 
|  | "print(tokenize_input(positive_text))" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "kM8RSTwAs1tg" | 
|  | }, | 
|  | "source": [ | 
|  | "## Run using TFLite\n", | 
|  | "\n", | 
|  | "Overview:\n", | 
|  | "\n", | 
|  | "1.   Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n", | 
|  | "2.   Allocate tensors and get the input and output shape information\n", | 
|  | "3.   Invoke the TFLite Interpreter to test the text classification function" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 8, | 
|  | "metadata": { | 
|  | "id": "Y28pZUAZs1tg" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "interpreter = tflite.Interpreter(\n", | 
|  | "      model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n", | 
|  | "interpreter.allocate_tensors()\n", | 
|  | "input_details = interpreter.get_input_details()\n", | 
|  | "output_details = interpreter.get_output_details()\n", | 
|  | "\n", | 
|  | "def classify_text_tflite(text):\n", | 
|  | "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n", | 
|  | "  interpreter.invoke()\n", | 
|  | "  output_data = interpreter.get_tensor(output_details[0]['index'])\n", | 
|  | "  interpret_output(output_data[0])" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 9, | 
|  | "metadata": { | 
|  | "id": "GFrbbHOgs1th", | 
|  | "outputId": "f7f73696-52ff-43f2-acc3-c923947faaaf", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Invoking text classification with TFLite\n", | 
|  | "\n", | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "Label: Positive\n", | 
|  | "Confidence: 0.8997294\n", | 
|  | "\n", | 
|  | "What a waste of my time.\n", | 
|  | "Label: Negative\n", | 
|  | "Confidence: 0.6275043\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "print(\"Invoking text classification with TFLite\\n\")\n", | 
|  | "print(positive_text)\n", | 
|  | "classify_text_tflite(positive_text)\n", | 
|  | "print()\n", | 
|  | "print(negative_text)\n", | 
|  | "classify_text_tflite(negative_text)" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "JiVzwcwCs1th" | 
|  | }, | 
|  | "source": [ | 
|  | "## Run using IREE\n", | 
|  | "\n", | 
|  | "Overview:\n", | 
|  | "\n", | 
|  | "1.   Import the TFLite model to TOSA MLIR\n", | 
|  | "2.   Compile the TOSA MLIR into an IREE flatbuffer and VM module\n", | 
|  | "3.   Run the VM module through IREE's runtime to test the text classification function\n", | 
|  | "\n", | 
|  | "Both runtimes should produce the same labels and confidence scores.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 10, | 
|  | "metadata": { | 
|  | "id": "Y7gKbGVSs1ti", | 
|  | "outputId": "8148f54a-d73b-49e4-d8de-e7e22f8b3c3e", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "2023-05-12 16:24:00.764695: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:7704] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", | 
|  | "2023-05-12 16:24:00.764757: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", | 
|  | "2023-05-12 16:24:00.764773: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1520] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", | 
|  | "2023-05-12 16:24:02.167994: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "# Convert TFLite model to TOSA MLIR (bytecode) with IREE's import tool.\n", | 
|  | "tosa_mlirbc_file = ARTIFACTS_DIR.joinpath(\"text_classification.mlirbc\")\n", | 
|  | "!iree-import-tflite {ARTIFACTS_DIR}/text_classification.tflite -o={tosa_mlirbc_file}\n", | 
|  | "\n", | 
|  | "# The generated .mlirbc file could now be saved and used outside of Python, with\n", | 
|  | "# IREE native tools or in apps, etc." | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 11, | 
|  | "metadata": { | 
|  | "id": "dsN89K7as1tj" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "# Compile the TOSA MLIR into a VM module.\n", | 
|  | "compiled_flatbuffer = compile_file(tosa_mlirbc_file, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", | 
|  | "\n", | 
|  | "# Register the module with a runtime context.\n", | 
|  | "config = iree_rt.Config(\"local-task\")\n", | 
|  | "ctx = iree_rt.SystemContext(config=config)\n", | 
|  | "vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, compiled_flatbuffer)\n", | 
|  | "ctx.add_vm_module(vm_module)\n", | 
|  | "invoke_text_classification = ctx.modules.module[\"main\"]\n", | 
|  | "\n", | 
|  | "def classify_text_iree(text):\n", | 
|  | "  result = invoke_text_classification(tokenize_input(text)).to_host()[0]\n", | 
|  | "  interpret_output(result)" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 12, | 
|  | "metadata": { | 
|  | "id": "LQiDmXn_s1tj", | 
|  | "outputId": "d62e7a02-709c-4257-f2a5-ffc6c66d99cd", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Invoking text classification with IREE\n", | 
|  | "\n", | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "Label: Positive\n", | 
|  | "Confidence: 0.8997294\n", | 
|  | "\n", | 
|  | "What a waste of my time.\n", | 
|  | "Label: Negative\n", | 
|  | "Confidence: 0.6275043\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "print(\"Invoking text classification with IREE\\n\")\n", | 
|  | "print(positive_text)\n", | 
|  | "classify_text_iree(positive_text)\n", | 
|  | "print()\n", | 
|  | "print(negative_text)\n", | 
|  | "classify_text_iree(negative_text)" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "collapsed_sections": [ | 
|  | "1Ou5VKsLs1tc" | 
|  | ], | 
|  | "name": "tflite_text_classification.ipynb", | 
|  | "provenance": [] | 
|  | }, | 
|  | "kernelspec": { | 
|  | "display_name": "Python 3", | 
|  | "name": "python3" | 
|  | }, | 
|  | "language_info": { | 
|  | "name": "python" | 
|  | } | 
|  | }, | 
|  | "nbformat": 4, | 
|  | "nbformat_minor": 0 | 
|  | } |