Fix TFLite sample by setting input_type="tosa". (#6701)

Relates to https://github.com/google/iree/issues/6691

Not tested on CI (https://buildkite.com/iree/iree-samples) yet, since that's been broken since https://github.com/google/iree/pull/6611.
diff --git a/colab/test_notebooks.py b/colab/test_notebooks.py
index 2d3f6f0..683d15d 100644
--- a/colab/test_notebooks.py
+++ b/colab/test_notebooks.py
@@ -15,9 +15,8 @@
 NOTEBOOKS_TO_SKIP = []
 
 NOTEBOOKS_EXPECTED_TO_FAIL = [
-    # Text classification notebook
-    #   * fails to extract the vocab file on Docker
-    #   * fails to compile the imported .mlir in Colab
+    # Text classification notebook fails to extract the vocab file on Docker
+    # (needs visibility into the tempdir?)
     "tflite_text_classification.ipynb",
 ]
 
diff --git a/colab/tflite_text_classification.ipynb b/colab/tflite_text_classification.ipynb
index 5c43d1b..9d6ce5a 100644
--- a/colab/tflite_text_classification.ipynb
+++ b/colab/tflite_text_classification.ipynb
@@ -63,7 +63,7 @@
         "!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot -f https://github.com/google/iree/releases/latest\n",
         "!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime"
       ],
-      "execution_count": null,
+      "execution_count": 14,
       "outputs": []
     },
     {
@@ -87,7 +87,7 @@
         "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
         "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)"
       ],
-      "execution_count": null,
+      "execution_count": 15,
       "outputs": []
     },
     {
@@ -112,27 +112,27 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "cUTaotkV7taP",
-        "outputId": "778f3dbb-e540-4840-901e-1ede5a8b0cf7"
+        "outputId": "36f8be61-9565-4da4-ae92-8b0eceb05419"
       },
       "source": [
         "#@title Download pretrained text classification model\n",
         "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n",
         "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))"
       ],
-      "execution_count": null,
+      "execution_count": 16,
       "outputs": [
         {
           "output_type": "execute_result",
           "data": {
             "text/plain": [
               "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n",
-              " <http.client.HTTPMessage at 0x7f86b3d0add0>)"
+              " <http.client.HTTPMessage at 0x7f91d58b8990>)"
             ]
           },
           "metadata": {
             "tags": []
           },
-          "execution_count": 3
+          "execution_count": 16
         }
       ]
     },
@@ -143,11 +143,11 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "c563db87-ad8c-4f72-8b22-afd5c9a1717f"
+        "outputId": "63465e5d-8de8-4b99-a386-5fd264b03bc0"
       },
       "source": [
         "#@title Extract model vocab and label metadata\n",
-        "!unzip -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n",
+        "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n",
         "\n",
         "# Load the vocab file into a dictionary.  It contains the most common 1,000\n",
         "# words in the English language, mapped to an integer.\n",
@@ -161,7 +161,7 @@
         "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n",
         "  labels = label_file.read().splitlines()"
       ],
-      "execution_count": null,
+      "execution_count": 17,
       "outputs": [
         {
           "output_type": "stream",
@@ -222,7 +222,7 @@
         "\n",
         "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))"
       ],
-      "execution_count": null,
+      "execution_count": 18,
       "outputs": []
     },
     {
@@ -232,7 +232,7 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "866b81ac-74fb-4133-f229-980bb1a5ebae"
+        "outputId": "d04fddad-852f-426a-96c7-397d9e5cd72d"
       },
       "source": [
         "#@title Text samples\n",
@@ -242,7 +242,7 @@
         "print(positive_text)\n",
         "print(tokenize_input(positive_text))"
       ],
-      "execution_count": null,
+      "execution_count": 19,
       "outputs": [
         {
           "output_type": "stream",
@@ -305,7 +305,7 @@
         "  output_data = interpreter.get_tensor(output_details[0]['index'])\n",
         "  interpret_output(output_data[0])"
       ],
-      "execution_count": null,
+      "execution_count": 20,
       "outputs": []
     },
     {
@@ -315,7 +315,7 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "kpxfU88ckFxI",
-        "outputId": "4ab615c4-8492-464a-c5b2-d9107d6b22c7"
+        "outputId": "a6ef8776-8b9b-4b29-931a-387e1e6b18fa"
       },
       "source": [
         "print(\"Invoking text classification with TFLite\\n\")\n",
@@ -325,7 +325,7 @@
         "print(negative_text)\n",
         "classify_text_tflite(negative_text)"
       ],
-      "execution_count": null,
+      "execution_count": 21,
       "outputs": [
         {
           "output_type": "stream",
@@ -374,14 +374,10 @@
         "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n",
         "  tosa_mlir = mlir_file.read()\n",
         "\n",
-        "# Manually insert \"iree.module.export\" attribute until it is removed. \n",
-        "# https://github.com/google/iree/issues/3968\n",
-        "modified_tosa_mlir = tosa_mlir.replace('outputs = \"Identity\"}','outputs = \"Identity\"}, iree.module.export')\n",
-        "\n",
         "# The generated .mlir file could now be saved and used outside of Python, with\n",
         "# IREE native tools or in apps, etc."
       ],
-      "execution_count": null,
+      "execution_count": 22,
       "outputs": []
     },
     {
@@ -391,7 +387,7 @@
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "4e7a41f5-a886-454f-fad2-db7e210cf10e"
+        "outputId": "dc92732b-793d-431c-ca67-a1427044e7b8"
       },
       "source": [
         "# The model contains very large constants, so recompile a truncated version to print.\n",
@@ -401,13 +397,13 @@
         "  truncated_tosa_mlir = truncated_mlir_file.read()\n",
         "  print(truncated_tosa_mlir)"
       ],
-      "execution_count": null,
+      "execution_count": 23,
       "outputs": [
         {
           "output_type": "stream",
           "text": [
-            "module  {\n",
-            "  func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n",
+            "builtin.module  {\n",
+            "  builtin.func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n",
             "    %0 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<10003x16xf32>} : () -> tensor<10003x16xf32>\n",
             "    %1 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n",
             "    %2 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n",
@@ -418,22 +414,20 @@
             "    %7 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<f32>} : () -> tensor<f32>\n",
             "    %8 = \"tosa.transpose\"(%0, %5) : (tensor<10003x16xf32>, tensor<2xi32>) -> tensor<10003x16xf32>\n",
             "    %9 = \"tosa.reshape\"(%8) {new_shape = [1, 10003, 16]} : (tensor<10003x16xf32>) -> tensor<1x10003x16xf32>\n",
-            "    %10 = \"tosa.reshape\"(%arg0) {new_shape = [1, 256]} : (tensor<1x256xi32>) -> tensor<1x256xi32>\n",
-            "    %11 = \"tosa.gather\"(%9, %10) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
-            "    %12 = \"tosa.reshape\"(%11) {new_shape = [1, 256, 16]} : (tensor<1x256x16xf32>) -> tensor<1x256x16xf32>\n",
-            "    %13 = \"tosa.transpose\"(%12, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n",
-            "    %14 = \"tosa.reduce_sum\"(%13) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n",
-            "    %15 = \"tosa.reshape\"(%14) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n",
-            "    %16 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n",
-            "    %17 = \"tosa.mul\"(%15, %16) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
-            "    %18 = \"tosa.fully_connected\"(%17, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n",
-            "    %19 = \"tosa.clamp\"(%18) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n",
-            "    %20 = \"tosa.fully_connected\"(%19, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n",
-            "    %21 = \"tosa.exp\"(%20) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n",
-            "    %22 = \"tosa.reduce_sum\"(%21) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n",
-            "    %23 = \"tosa.reciprocal\"(%22) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n",
-            "    %24 = \"tosa.mul\"(%21, %23) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n",
-            "    return %24 : tensor<1x2xf32>\n",
+            "    %10 = \"tosa.gather\"(%9, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
+            "    %11 = \"tosa.transpose\"(%10, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n",
+            "    %12 = \"tosa.reduce_sum\"(%11) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n",
+            "    %13 = \"tosa.reshape\"(%12) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n",
+            "    %14 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n",
+            "    %15 = \"tosa.mul\"(%13, %14) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
+            "    %16 = \"tosa.fully_connected\"(%15, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n",
+            "    %17 = \"tosa.clamp\"(%16) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n",
+            "    %18 = \"tosa.fully_connected\"(%17, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n",
+            "    %19 = \"tosa.exp\"(%18) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n",
+            "    %20 = \"tosa.reduce_sum\"(%19) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n",
+            "    %21 = \"tosa.reciprocal\"(%20) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n",
+            "    %22 = \"tosa.mul\"(%19, %21) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n",
+            "    return %22 : tensor<1x2xf32>\n",
             "  }\n",
             "}\n",
             "\n",
@@ -446,15 +440,11 @@
     {
       "cell_type": "code",
       "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "M3gXX2AF7aS9",
-        "outputId": "b60627f7-a761-444a-b41a-058cea01fbfe"
+        "id": "M3gXX2AF7aS9"
       },
       "source": [
         "# Compile the TOSA MLIR into a VM module.\n",
-        "compiled_flatbuffer = compile_str(modified_tosa_mlir, target_backends=[\"vmvx\"])\n",
+        "compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n",
         "vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n",
         "\n",
         "# Register the module with a runtime context.\n",
@@ -467,17 +457,8 @@
         "  result = invoke_text_classification(tokenize_input(text))\n",
         "  interpret_output(result[0])"
       ],
-      "execution_count": null,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "text": [
-            "Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n",
-            "SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n"
-          ],
-          "name": "stderr"
-        }
-      ]
+      "execution_count": 24,
+      "outputs": []
     },
     {
       "cell_type": "code",
@@ -486,7 +467,7 @@
           "base_uri": "https://localhost:8080/"
         },
         "id": "jvv9zhMwAWgZ",
-        "outputId": "2fc7b07c-2612-45a1-e333-3b9a10aaca88"
+        "outputId": "b086363f-28b9-4ff2-817e-e1928d7cfe4a"
       },
       "source": [
         "print(\"Invoking text classification with IREE\\n\")\n",
@@ -496,7 +477,7 @@
         "print(negative_text)\n",
         "classify_text_iree(negative_text)"
       ],
-      "execution_count": null,
+      "execution_count": 25,
       "outputs": [
         {
           "output_type": "stream",
@@ -516,4 +497,4 @@
       ]
     }
   ]
-}
+}
\ No newline at end of file
diff --git a/docs/website/docs/ml-frameworks/tensorflow-lite.md b/docs/website/docs/ml-frameworks/tensorflow-lite.md
index e807555..8db8a33 100644
--- a/docs/website/docs/ml-frameworks/tensorflow-lite.md
+++ b/docs/website/docs/ml-frameworks/tensorflow-lite.md
@@ -47,6 +47,7 @@
 ``` shell
 iree-translate \
   --iree-mlir-to-vm-bytecode-module \
+  --iree-input-type=tosa \
   --iree-hal-target-backends=vmvx \
   sample.mlir \
   -o sample.vmfb