|  | { | 
|  | "cells": [ | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "##### Copyright 2021 The IREE Authors" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "jFe0n0RUvf3t" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", | 
|  | "# See https://llvm.org/LICENSE.txt for license information.\n", | 
|  | "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" | 
|  | ], | 
|  | "metadata": { | 
|  | "cellView": "form", | 
|  | "id": "rhiNq8ZUo1kL" | 
|  | }, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "DlPPifbUs1tY" | 
|  | }, | 
|  | "source": [ | 
|  | "# TFLite text classification sample with IREE\n", | 
|  | "\n", | 
|  | "This notebook demonstrates how to download, compile, and run a TFLite model with IREE.  It looks at the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model, and shows how to run it with both TFLite and IREE.  The model predicts whether a sentence's sentiment is positive or negative, and was trained on a dataset of IMDB movie reviews.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "im5zRYThs1tY" | 
|  | }, | 
|  | "source": [ | 
|  | "## Setup" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 1, | 
|  | "metadata": { | 
|  | "id": "iuAwhzC7s1tZ" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "%%capture\n", | 
|  | "!python -m pip install --upgrade tf-nightly  # Needed for experimental_tflite_to_tosa_bytecode in TF>=2.14\n", | 
|  | "!python -m pip install iree-compiler iree-runtime iree-tools-tflite -f https://iree.dev/pip-release-links.html\n", | 
|  | "!python -m pip install tflite-runtime-nightly" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "from tensorflow.python.pywrap_mlir import experimental_tflite_to_tosa_bytecode" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "iGPG9DoEt2B5" | 
|  | }, | 
|  | "execution_count": 2, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 3, | 
|  | "metadata": { | 
|  | "id": "hOUmhoAls1tb", | 
|  | "outputId": "fd4fc259-b979-4e12-9c02-fd3f699592cb", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "IREE (https://iree.dev):\n", | 
|  | "  IREE compiler version 20230831.630 @ 9ed3dab7ac4fcda959f5b8ebbcd7732aeb4b0c8d\n", | 
|  | "  LLVM version 18.0.0git\n", | 
|  | "  Optimized build\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "import numpy as np\n", | 
|  | "import urllib.request\n", | 
|  | "import pathlib\n", | 
|  | "import tempfile\n", | 
|  | "import re\n", | 
|  | "import tflite_runtime.interpreter as tflite\n", | 
|  | "\n", | 
|  | "from iree import runtime as iree_rt\n", | 
|  | "from iree.compiler import compile_file, compile_str\n", | 
|  | "from iree.tools import tflite as iree_tflite\n", | 
|  | "\n", | 
|  | "# All downloaded and generated files live under one scratch directory.\n", | 
|  | "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir()) / \"iree\" / \"colab_artifacts\"\n", | 
|  | "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)\n", | 
|  | "\n", | 
|  | "# Record the compiler version so future readers can compare against it.\n", | 
|  | "!iree-compile --version" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "1Ou5VKsLs1tc" | 
|  | }, | 
|  | "source": [ | 
|  | "### Load the TFLite model\n", | 
|  | "\n", | 
|  | "1.   Download files for the pretrained model\n", | 
|  | "2.   Extract model metadata used for input pre-processing and output post-processing\n", | 
|  | "3.   Define helper functions for pre- and post-processing\n", | 
|  | "\n", | 
|  | "These steps will differ from model to model.  Consult the model source or reference documentation for details.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 4, | 
|  | "metadata": { | 
|  | "id": "2vlRjNSPs1tc", | 
|  | "outputId": "c1b75fcf-6e47-4bf5-b514-3667448722ae", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "execute_result", | 
|  | "data": { | 
|  | "text/plain": [ | 
|  | "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n", | 
|  | " <http.client.HTTPMessage at 0x7d22f6441630>)" | 
|  | ] | 
|  | }, | 
|  | "metadata": {}, | 
|  | "execution_count": 4 | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Download pretrained text classification model\n", | 
|  | "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n", | 
|  | "# urlretrieve returns a (local_path, headers) tuple, which is echoed below.\n", | 
|  | "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR / \"text_classification.tflite\")" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 5, | 
|  | "metadata": { | 
|  | "id": "e9t6pCAVs1td", | 
|  | "outputId": "f46f5d6c-ee72-4e53-8e82-0b9e9bae64ec", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Archive:  /tmp/iree/colab_artifacts/text_classification.tflite\n", | 
|  | " extracting: /tmp/iree/colab_artifacts/labels.txt  \n", | 
|  | " extracting: /tmp/iree/colab_artifacts/vocab.txt  \n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Extract model vocab and label metadata\n", | 
|  | "# The .tflite file doubles as a zip archive bundling vocab.txt and labels.txt.\n", | 
|  | "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n", | 
|  | "\n", | 
|  | "# Load the vocab file into a dictionary mapping each word to its integer token\n", | 
|  | "# id.  Each line holds a word and its id separated by whitespace.\n", | 
|  | "vocab = {}\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\"), encoding=\"utf-8\") as vocab_file:\n", | 
|  | "  for line in vocab_file:\n", | 
|  | "    if not line.strip():\n", | 
|  | "      continue  # Skip blank lines instead of failing to unpack them.\n", | 
|  | "    (key, val) = line.split()\n", | 
|  | "    vocab[key] = int(val)\n", | 
|  | "\n", | 
|  | "# Text will be labeled as either 'Positive' or 'Negative'.\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\"), encoding=\"utf-8\") as label_file:\n", | 
|  | "  labels = label_file.read().splitlines()" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 6, | 
|  | "metadata": { | 
|  | "id": "g4QeoNHqs1te" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "#@title Input and output processing\n", | 
|  | "\n", | 
|  | "# Input text will be encoded as an integer array of fixed length 256.  The\n", | 
|  | "# input sentence will be mapped to integers from the vocab dictionary, and the\n", | 
|  | "# empty array spaces are filled with padding.  Words that do not fit in the\n", | 
|  | "# fixed length are dropped.\n", | 
|  | "\n", | 
|  | "SENTENCE_LEN = 256\n", | 
|  | "START = \"<START>\"\n", | 
|  | "PAD = \"<PAD>\"\n", | 
|  | "UNKNOWN = \"<UNKNOWN>\"\n", | 
|  | "\n", | 
|  | "def tokenize_input(text):\n", | 
|  | "  \"\"\"Encodes `text` as a (1, SENTENCE_LEN) int32 array of vocab token ids.\n", | 
|  | "\n", | 
|  | "  The row starts with the <START> id, followed by one id per word (<UNKNOWN>\n", | 
|  | "  for words missing from `vocab`), and the remainder is filled with <PAD>.\n", | 
|  | "  Inputs longer than SENTENCE_LEN - 1 words are truncated rather than\n", | 
|  | "  raising IndexError.\n", | 
|  | "  \"\"\"\n", | 
|  | "  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)\n", | 
|  | "\n", | 
|  | "  # Remove capitalization and punctuation from the input text.\n", | 
|  | "  words = [re.sub(r\"[^\\w\\s']\", '', word.lower()) for word in text.split()]\n", | 
|  | "\n", | 
|  | "  # Prepend <START>; keep only as many words as fit in the remaining slots so\n", | 
|  | "  # long inputs are truncated instead of indexing past the end of `output`.\n", | 
|  | "  token_ids = [vocab[START]] + [\n", | 
|  | "      vocab[word] if word in vocab else vocab[UNKNOWN]\n", | 
|  | "      for word in words[:SENTENCE_LEN - 1]]\n", | 
|  | "  output[0, :len(token_ids)] = token_ids\n", | 
|  | "  return output\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def interpret_output(output):\n", | 
|  | "  \"\"\"Prints the more likely of the two labels together with its score.\n", | 
|  | "\n", | 
|  | "  `output[i]` is treated as the score for `labels[i]`; only the first two\n", | 
|  | "  entries are compared, and ties go to labels[0].\n", | 
|  | "  \"\"\"\n", | 
|  | "  best = 0 if output[0] >= output[1] else 1\n", | 
|  | "  print(\"Label: \" + labels[best] + \"\\nConfidence: \" + str(output[best]))" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 7, | 
|  | "metadata": { | 
|  | "id": "yXdBOIpBs1tf", | 
|  | "outputId": "1997efd2-1fb1-4007-afc6-e03ae30f8590", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0]]\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Text samples\n", | 
|  | "# One positive and one negative review, used by the classification runs below.\n", | 
|  | "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", | 
|  | "negative_text = \"What a waste of my time.\"\n", | 
|  | "\n", | 
|  | "# Show the raw positive sample next to the token ids the model consumes.\n", | 
|  | "print(positive_text, tokenize_input(positive_text), sep=\"\\n\")" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "kM8RSTwAs1tg" | 
|  | }, | 
|  | "source": [ | 
|  | "## Run using TFLite\n", | 
|  | "\n", | 
|  | "Overview:\n", | 
|  | "\n", | 
|  | "1.   Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n", | 
|  | "2.   Allocate tensors and get the input and output shape information\n", | 
|  | "3.   Invoke the TFLite Interpreter to test the text classification function" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 8, | 
|  | "metadata": { | 
|  | "id": "Y28pZUAZs1tg" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "interpreter = tflite.Interpreter(\n", | 
|  | "    model_path=str(ARTIFACTS_DIR / \"text_classification.tflite\"))\n", | 
|  | "interpreter.allocate_tensors()\n", | 
|  | "input_details = interpreter.get_input_details()\n", | 
|  | "output_details = interpreter.get_output_details()\n", | 
|  | "\n", | 
|  | "def classify_text_tflite(text):\n", | 
|  | "  \"\"\"Runs `text` through the TFLite interpreter and prints the prediction.\"\"\"\n", | 
|  | "  input_index = input_details[0]['index']\n", | 
|  | "  output_index = output_details[0]['index']\n", | 
|  | "  interpreter.set_tensor(input_index, tokenize_input(text))\n", | 
|  | "  interpreter.invoke()\n", | 
|  | "  scores = interpreter.get_tensor(output_index)[0]\n", | 
|  | "  interpret_output(scores)" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 9, | 
|  | "metadata": { | 
|  | "id": "GFrbbHOgs1th", | 
|  | "outputId": "6242b91d-a7b3-4e3f-a6b1-8adc2fdfd09b", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Invoking text classification with TFLite\n", | 
|  | "\n", | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "Label: Positive\n", | 
|  | "Confidence: 0.8997294\n", | 
|  | "\n", | 
|  | "What a waste of my time.\n", | 
|  | "Label: Negative\n", | 
|  | "Confidence: 0.6275043\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "print(\"Invoking text classification with TFLite\\n\")\n", | 
|  | "# Reuse the samples from the \"Text samples\" cell rather than redefining them,\n", | 
|  | "# so the TFLite and IREE runs cannot drift apart.\n", | 
|  | "print(positive_text)\n", | 
|  | "classify_text_tflite(positive_text)\n", | 
|  | "print()\n", | 
|  | "print(negative_text)\n", | 
|  | "classify_text_tflite(negative_text)" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "JiVzwcwCs1th" | 
|  | }, | 
|  | "source": [ | 
|  | "## Run using IREE\n", | 
|  | "\n", | 
|  | "Overview:\n", | 
|  | "\n", | 
|  | "1.   Import the TFLite model to TOSA MLIR\n", | 
|  | "2.   Compile the TOSA MLIR into an IREE flatbuffer and VM module\n", | 
|  | "3.   Run the VM module through IREE's runtime to test the text classification function\n", | 
|  | "\n", | 
|  | "Both runtimes should generate the same output.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 10, | 
|  | "metadata": { | 
|  | "id": "Y7gKbGVSs1ti", | 
|  | "outputId": "5262b352-c609-4e39-ee2f-a6cad21b1c67", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "2023-08-31 21:32:50.137814: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9511] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", | 
|  | "2023-08-31 21:32:50.137865: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", | 
|  | "2023-08-31 21:32:50.137895: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", | 
|  | "2023-08-31 21:32:51.571208: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "# Convert the TFLite model to TOSA MLIR bytecode with IREE's TFLite import tool.\n", | 
|  | "tosa_mlirbc_file = ARTIFACTS_DIR / \"text_classification.mlirbc\"\n", | 
|  | "!iree-import-tflite {ARTIFACTS_DIR}/text_classification.tflite -o={tosa_mlirbc_file}\n", | 
|  | "\n", | 
|  | "# The resulting .mlirbc file can be kept and reused outside of Python, for\n", | 
|  | "# example with IREE's native tools or inside applications." | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 11, | 
|  | "metadata": { | 
|  | "id": "dsN89K7as1tj" | 
|  | }, | 
|  | "outputs": [], | 
|  | "source": [ | 
|  | "# Compile the TOSA MLIR into a VM module for the \"vmvx\" target backend.\n", | 
|  | "compiled_flatbuffer = compile_file(\n", | 
|  | "    tosa_mlirbc_file, input_type=\"tosa\", target_backends=[\"vmvx\"])\n", | 
|  | "\n", | 
|  | "# Register the module with a runtime context and look up its entry point.\n", | 
|  | "config = iree_rt.Config(\"local-task\")\n", | 
|  | "ctx = iree_rt.SystemContext(config=config)\n", | 
|  | "vm_module = iree_rt.VmModule.from_flatbuffer(config.vm_instance, compiled_flatbuffer)\n", | 
|  | "ctx.add_vm_module(vm_module)\n", | 
|  | "invoke_text_classification = ctx.modules.module[\"main\"]\n", | 
|  | "\n", | 
|  | "def classify_text_iree(text):\n", | 
|  | "  \"\"\"Runs `text` through the compiled IREE module and prints the prediction.\"\"\"\n", | 
|  | "  raw_result = invoke_text_classification(tokenize_input(text))\n", | 
|  | "  interpret_output(raw_result.to_host()[0])" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 12, | 
|  | "metadata": { | 
|  | "id": "LQiDmXn_s1tj", | 
|  | "outputId": "58d92deb-a335-4675-a25f-88870966b610", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | } | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Invoking text classification with IREE\n", | 
|  | "\n", | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "Label: Positive\n", | 
|  | "Confidence: 0.8997294\n", | 
|  | "\n", | 
|  | "What a waste of my time.\n", | 
|  | "Label: Negative\n", | 
|  | "Confidence: 0.6275043\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "print(\"Invoking text classification with IREE\\n\")\n", | 
|  | "for sample_number, sample_text in enumerate((positive_text, negative_text)):\n", | 
|  | "  if sample_number:\n", | 
|  | "    print()  # Blank line between samples, matching the TFLite run.\n", | 
|  | "  print(sample_text)\n", | 
|  | "  classify_text_iree(sample_text)" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "collapsed_sections": [ | 
|  | "1Ou5VKsLs1tc" | 
|  | ], | 
|  | "name": "tflite_text_classification.ipynb", | 
|  | "provenance": [] | 
|  | }, | 
|  | "kernelspec": { | 
|  | "display_name": "Python 3", | 
|  | "name": "python3" | 
|  | }, | 
|  | "language_info": { | 
|  | "name": "python" | 
|  | } | 
|  | }, | 
|  | "nbformat": 4, | 
|  | "nbformat_minor": 0 | 
|  | } |