|  | { | 
|  | "nbformat": 4, | 
|  | "nbformat_minor": 0, | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "name": "tflite_text_classification.ipynb", | 
|  | "provenance": [], | 
|  | "collapsed_sections": [ | 
|  | "dweksnWs1eZb" | 
|  | ] | 
|  | }, | 
|  | "kernelspec": { | 
|  | "name": "python3", | 
|  | "display_name": "Python 3" | 
|  | }, | 
|  | "language_info": { | 
|  | "name": "python" | 
|  | } | 
|  | }, | 
|  | "cells": [ | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "_VsBo9D7GJ8G" | 
|  | }, | 
|  | "source": [ | 
|  | "```\n", | 
|  | "Copyright 2021 The IREE Authors\n", | 
|  | "\n", | 
|  | "Licensed under the Apache License v2.0 with LLVM Exceptions.\n", | 
|  | "See https://llvm.org/LICENSE.txt for license information.\n", | 
|  | "SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n", | 
|  | "```" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "3Mjqf_WmSkAs" | 
|  | }, | 
|  | "source": [ | 
|  | "# TFLite text classification sample with IREE\n", | 
|  | "\n", | 
|  | "This notebook demonstrates how to download, compile, and run a TFLite model with IREE.  It looks at the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and shows how to run it with both TFLite and IREE.  The model predicts whether a sentence's sentiment is positive or negative, and was trained on a dataset of IMDB movie reviews.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "W71slxNjS3SB" | 
|  | }, | 
|  | "source": [ | 
|  | "## Setup" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "T8WMt2sPS4ft" | 
|  | }, | 
|  | "source": [ | 
|  | "%%capture\n", | 
|  | "# Install the IREE compiler, runtime and TFLite import tool snapshots from the\n", | 
|  | "# latest IREE GitHub release, plus the standalone TFLite runtime from the Coral\n", | 
|  | "# package index.  The %%capture cell magic hides the (long) pip output.\n", | 
|  | "!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot -f https://github.com/google/iree/releases/latest\n", | 
|  | "!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "cellView": "code", | 
|  | "id": "7L4_gnRVVBdi" | 
|  | }, | 
|  | "source": [ | 
|  | "import numpy as np\n", | 
|  | "import urllib.request\n", | 
|  | "import pathlib\n", | 
|  | "import tempfile\n", | 
|  | "import re\n", | 
|  | "import tflite_runtime.interpreter as tflite\n", | 
|  | "\n", | 
|  | "from iree import runtime as iree_rt\n", | 
|  | "from iree.compiler import compile_str\n", | 
|  | "from iree.tools import tflite as iree_tflite\n", | 
|  | "\n", | 
|  | "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", | 
|  | "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "dweksnWs1eZb" | 
|  | }, | 
|  | "source": [ | 
|  | "### Load the TFLite model\n", | 
|  | "\n", | 
|  | "1.   Download files for the pretrained model\n", | 
|  | "2.   Extract model metadata used for input pre-processing and output post-processing\n", | 
|  | "3.   Define helper functions for pre- and post-processing\n", | 
|  | "\n", | 
|  | "These steps will differ from model to model.  Consult the model source or reference documentation for details.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "cUTaotkV7taP", | 
|  | "outputId": "778f3dbb-e540-4840-901e-1ede5a8b0cf7" | 
|  | }, | 
|  | "source": [ | 
|  | "#@title Download pretrained text classification model\n", | 
|  | "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n", | 
|  | "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "execute_result", | 
|  | "data": { | 
|  | "text/plain": [ | 
|  | "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n", | 
|  | " <http.client.HTTPMessage at 0x7f86b3d0add0>)" | 
|  | ] | 
|  | }, | 
|  | "metadata": { | 
|  | "tags": [] | 
|  | }, | 
|  | "execution_count": 3 | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "ADu-gSDnIm2B", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "outputId": "c563db87-ad8c-4f72-8b22-afd5c9a1717f" | 
|  | }, | 
|  | "source": [ | 
|  | "#@title Extract model vocab and label metadata\n", | 
|  | "!unzip -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n", | 
|  | "\n", | 
|  | "# Load the vocab file into a dictionary.  It contains the most common 1,000\n", | 
|  | "# words in the English language, mapped to an integer.\n", | 
|  | "vocab = {}\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n", | 
|  | "  for line in vocab_file:\n", | 
|  | "    (key, val) = line.split()\n", | 
|  | "    vocab[key] = int(val)\n", | 
|  | "\n", | 
|  | "# Text will be labeled as either 'Positive' or 'Negative'.\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n", | 
|  | "  labels = label_file.read().splitlines()" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "Archive:  /tmp/iree/colab_artifacts/text_classification.tflite\n", | 
|  | " extracting: /tmp/iree/colab_artifacts/labels.txt  \n", | 
|  | " extracting: /tmp/iree/colab_artifacts/vocab.txt  \n" | 
|  | ], | 
|  | "name": "stdout" | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "Z1WKEZY1JH6E" | 
|  | }, | 
|  | "source": [ | 
|  | "#@title Input and output processing\n", | 
|  | "\n", | 
|  | "# Input text will be encoded as an integer array of fixed length 256.  The \n", | 
|  | "# input sentence will be mapped to integers from the vocab dictionary, and the \n", | 
|  | "# empty array spaces are filled with padding.\n", | 
|  | "\n", | 
|  | "SENTENCE_LEN = 256\n", | 
|  | "START = \"<START>\"\n", | 
|  | "PAD = \"<PAD>\"\n", | 
|  | "UNKNOWN = \"<UNKNOWN>\"\n", | 
|  | "\n", | 
|  | "def tokenize_input(text):\n", | 
|  | "  output = np.empty([1, SENTENCE_LEN], dtype=np.int32)\n", | 
|  | "  output.fill(vocab[PAD])\n", | 
|  | "\n", | 
|  | "  # Remove capitalization and punctuation from the input text.\n", | 
|  | "  text_split = text.split()\n", | 
|  | "  text_split = [text.lower() for text in text_split]\n", | 
|  | "  text_split = [re.sub(r\"[^\\w\\s']\", '', text) for text in text_split]\n", | 
|  | "\n", | 
|  | "  # Prepend <START>.\n", | 
|  | "  index = 0\n", | 
|  | "  output[0][index] = vocab[START]\n", | 
|  | "  index += 1\n", | 
|  | "\n", | 
|  | "  for word in text_split:\n", | 
|  | "    output[0][index] = vocab[word] if word in vocab else vocab[UNKNOWN]\n", | 
|  | "    index += 1\n", | 
|  | "\n", | 
|  | "  return output\n", | 
|  | "\n", | 
|  | "\n", | 
|  | "def interpret_output(output):\n", | 
|  | "  if output[0] >= output[1]:\n", | 
|  | "    label = labels[0]\n", | 
|  | "    confidence = output[0]\n", | 
|  | "  else:\n", | 
|  | "    label = labels[1]\n", | 
|  | "    confidence = output[1]\n", | 
|  | "\n", | 
|  | "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "xq00JknF0cOz", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "outputId": "866b81ac-74fb-4133-f229-980bb1a5ebae" | 
|  | }, | 
|  | "source": [ | 
|  | "#@title Text samples\n", | 
|  | "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n", | 
|  | "negative_text = \"What a waste of my time.\"\n", | 
|  | "\n", | 
|  | "print(positive_text)\n", | 
|  | "print(tokenize_input(positive_text))" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n", | 
|  | "     0    0    0    0]]\n" | 
|  | ], | 
|  | "name": "stdout" | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "fvxhY1W7X7Gf" | 
|  | }, | 
|  | "source": [ | 
|  | "## Run using TFLite\n", | 
|  | "\n", | 
|  | "Overview:\n", | 
|  | "\n", | 
|  | "1.  Load the TFLite model in a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n", | 
|  | "2.   Allocate tensors and get the input and output shape information\n", | 
|  | "3.   Invoke the TFLite Interpreter to test the text classification function" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "Nv8x81S5eV5g" | 
|  | }, | 
|  | "source": [ | 
|  | "interpreter = tflite.Interpreter(\n", | 
|  | "      model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n", | 
|  | "interpreter.allocate_tensors()\n", | 
|  | "input_details = interpreter.get_input_details()\n", | 
|  | "output_details = interpreter.get_output_details()\n", | 
|  | "\n", | 
|  | "def classify_text_tflite(text):\n", | 
|  | "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n", | 
|  | "  interpreter.invoke()\n", | 
|  | "  output_data = interpreter.get_tensor(output_details[0]['index'])\n", | 
|  | "  interpret_output(output_data[0])" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "kpxfU88ckFxI", | 
|  | "outputId": "4ab615c4-8492-464a-c5b2-d9107d6b22c7" | 
|  | }, | 
|  | "source": [ | 
|  | "print(\"Invoking text classification with TFLite\\n\")\n", | 
|  | "print(positive_text)\n", | 
|  | "classify_text_tflite(positive_text)\n", | 
|  | "print()\n", | 
|  | "print(negative_text)\n", | 
|  | "classify_text_tflite(negative_text)" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "Invoking text classification with TFLite\n", | 
|  | "\n", | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "Label: Positive\n", | 
|  | "Confidence: 0.8997293\n", | 
|  | "\n", | 
|  | "What a waste of my time.\n", | 
|  | "Label: Negative\n", | 
|  | "Confidence: 0.6275043\n" | 
|  | ], | 
|  | "name": "stdout" | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "metadata": { | 
|  | "id": "_FgwUBtm7Y5n" | 
|  | }, | 
|  | "source": [ | 
|  | "## Run using IREE\n", | 
|  | "\n", | 
|  | "Overview:\n", | 
|  | "\n", | 
|  | "1.   Import the TFLite model to TOSA MLIR \n", | 
|  | "2.   Compile the TOSA MLIR into an IREE flatbuffer and VM module\n", | 
|  | "3.   Run the VM module through IREE's runtime to test the text classification function\n", | 
|  | "\n", | 
|  | "Both runtimes should produce the same labels, with confidences that agree up to small floating-point differences.\n" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "ZifuTzgEXLJn" | 
|  | }, | 
|  | "source": [ | 
|  | "# Convert TFLite model to TOSA MLIR with IREE's import tool.\n", | 
|  | "IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')\n", | 
|  | "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification.mlir\n", | 
|  | "\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n", | 
|  | "  tosa_mlir = mlir_file.read()\n", | 
|  | "\n", | 
|  | "# Manually insert \"iree.module.export\" attribute until it is removed. \n", | 
|  | "# https://github.com/google/iree/issues/3968\n", | 
|  | "modified_tosa_mlir = tosa_mlir.replace('outputs = \"Identity\"}','outputs = \"Identity\"}, iree.module.export')\n", | 
|  | "\n", | 
|  | "# The generated .mlir file could now be saved and used outside of Python, with\n", | 
|  | "# IREE native tools or in apps, etc." | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "id": "-C-5X4C0D0La", | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "outputId": "4e7a41f5-a886-454f-fad2-db7e210cf10e" | 
|  | }, | 
|  | "source": [ | 
|  | "# The model contains very large constants, so recompile a truncated version to print.\n", | 
|  | "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification_truncated.mlir -mlir-elide-elementsattrs-if-larger=50\n", | 
|  | "\n", | 
|  | "with open(ARTIFACTS_DIR.joinpath(\"text_classification_truncated.mlir\")) as truncated_mlir_file:\n", | 
|  | "  truncated_tosa_mlir = truncated_mlir_file.read()\n", | 
|  | "  print(truncated_tosa_mlir)" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "module  {\n", | 
|  | "  func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n", | 
|  | "    %0 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<10003x16xf32>} : () -> tensor<10003x16xf32>\n", | 
|  | "    %1 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n", | 
|  | "    %2 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n", | 
|  | "    %3 = \"tosa.const\"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722, 0.795746565, -1.3122592, 0.352218777, 1.04682362, -1.18796027, -0.0409261398, 1.05883229, 1.48620188, -1.13325548, 1.03072512]]> : tensor<2x16xf32>} : () -> tensor<2x16xf32>\n", | 
|  | "    %4 = \"tosa.const\"() {value = dense<[0.043447677, -0.0434476472]> : tensor<2xf32>} : () -> tensor<2xf32>\n", | 
|  | "    %5 = \"tosa.const\"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>\n", | 
|  | "    %6 = \"tosa.const\"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>\n", | 
|  | "    %7 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<f32>} : () -> tensor<f32>\n", | 
|  | "    %8 = \"tosa.transpose\"(%0, %5) : (tensor<10003x16xf32>, tensor<2xi32>) -> tensor<10003x16xf32>\n", | 
|  | "    %9 = \"tosa.reshape\"(%8) {new_shape = [1, 10003, 16]} : (tensor<10003x16xf32>) -> tensor<1x10003x16xf32>\n", | 
|  | "    %10 = \"tosa.reshape\"(%arg0) {new_shape = [1, 256]} : (tensor<1x256xi32>) -> tensor<1x256xi32>\n", | 
|  | "    %11 = \"tosa.gather\"(%9, %10) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n", | 
|  | "    %12 = \"tosa.reshape\"(%11) {new_shape = [1, 256, 16]} : (tensor<1x256x16xf32>) -> tensor<1x256x16xf32>\n", | 
|  | "    %13 = \"tosa.transpose\"(%12, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n", | 
|  | "    %14 = \"tosa.reduce_sum\"(%13) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n", | 
|  | "    %15 = \"tosa.reshape\"(%14) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n", | 
|  | "    %16 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n", | 
|  | "    %17 = \"tosa.mul\"(%15, %16) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n", | 
|  | "    %18 = \"tosa.fully_connected\"(%17, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n", | 
|  | "    %19 = \"tosa.clamp\"(%18) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n", | 
|  | "    %20 = \"tosa.fully_connected\"(%19, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n", | 
|  | "    %21 = \"tosa.exp\"(%20) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n", | 
|  | "    %22 = \"tosa.reduce_sum\"(%21) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n", | 
|  | "    %23 = \"tosa.reciprocal\"(%22) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n", | 
|  | "    %24 = \"tosa.mul\"(%21, %23) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n", | 
|  | "    return %24 : tensor<1x2xf32>\n", | 
|  | "  }\n", | 
|  | "}\n", | 
|  | "\n", | 
|  | "\n" | 
|  | ], | 
|  | "name": "stdout" | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "M3gXX2AF7aS9", | 
|  | "outputId": "b60627f7-a761-444a-b41a-058cea01fbfe" | 
|  | }, | 
|  | "source": [ | 
|  | "# Compile the TOSA MLIR into a VM module.\n", | 
|  | "compiled_flatbuffer = compile_str(modified_tosa_mlir, target_backends=[\"vmvx\"])\n", | 
|  | "vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n", | 
|  | "\n", | 
|  | "# Register the module with a runtime context.\n", | 
|  | "config = iree_rt.Config(\"vmvx\")\n", | 
|  | "ctx = iree_rt.SystemContext(config=config)\n", | 
|  | "ctx.add_vm_module(vm_module)\n", | 
|  | "invoke_text_classification = ctx.modules.module[\"main\"]\n", | 
|  | "\n", | 
|  | "def classify_text_iree(text):\n", | 
|  | "  result = invoke_text_classification(tokenize_input(text))\n", | 
|  | "  interpret_output(result[0])" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n", | 
|  | "SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n" | 
|  | ], | 
|  | "name": "stderr" | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "jvv9zhMwAWgZ", | 
|  | "outputId": "2fc7b07c-2612-45a1-e333-3b9a10aaca88" | 
|  | }, | 
|  | "source": [ | 
|  | "print(\"Invoking text classification with IREE\\n\")\n", | 
|  | "print(positive_text)\n", | 
|  | "classify_text_iree(positive_text)\n", | 
|  | "print()\n", | 
|  | "print(negative_text)\n", | 
|  | "classify_text_iree(negative_text)" | 
|  | ], | 
|  | "execution_count": null, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "Invoking text classification with IREE\n", | 
|  | "\n", | 
|  | "This is the best movie I've seen in recent years. Strongly recommend it!\n", | 
|  | "Label: Positive\n", | 
|  | "Confidence: 0.8997293\n", | 
|  | "\n", | 
|  | "What a waste of my time.\n", | 
|  | "Label: Negative\n", | 
|  | "Confidence: 0.62750435\n" | 
|  | ], | 
|  | "name": "stdout" | 
|  | } | 
|  | ] | 
|  | } | 
|  | ] | 
|  | } |