Fix TFLite sample by setting input_type="tosa". (#6701)
Relates to https://github.com/google/iree/issues/6691
Not tested on CI (https://buildkite.com/iree/iree-samples) yet, since that's been broken since https://github.com/google/iree/pull/6611.
diff --git a/colab/test_notebooks.py b/colab/test_notebooks.py
index 2d3f6f0..683d15d 100644
--- a/colab/test_notebooks.py
+++ b/colab/test_notebooks.py
@@ -15,9 +15,8 @@
NOTEBOOKS_TO_SKIP = []
NOTEBOOKS_EXPECTED_TO_FAIL = [
- # Text classification notebook
- # * fails to extract the vocab file on Docker
- # * fails to compile the imported .mlir in Colab
+ # Text classification notebook fails to extract the vocab file on Docker
+ # (needs visibility into the tempdir?)
"tflite_text_classification.ipynb",
]
diff --git a/colab/tflite_text_classification.ipynb b/colab/tflite_text_classification.ipynb
index 5c43d1b..9d6ce5a 100644
--- a/colab/tflite_text_classification.ipynb
+++ b/colab/tflite_text_classification.ipynb
@@ -63,7 +63,7 @@
"!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot -f https://github.com/google/iree/releases/latest\n",
"!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime"
],
- "execution_count": null,
+ "execution_count": 14,
"outputs": []
},
{
@@ -87,7 +87,7 @@
"ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
"ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)"
],
- "execution_count": null,
+ "execution_count": 15,
"outputs": []
},
{
@@ -112,27 +112,27 @@
"base_uri": "https://localhost:8080/"
},
"id": "cUTaotkV7taP",
- "outputId": "778f3dbb-e540-4840-901e-1ede5a8b0cf7"
+ "outputId": "36f8be61-9565-4da4-ae92-8b0eceb05419"
},
"source": [
"#@title Download pretrained text classification model\n",
"MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n",
"urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))"
],
- "execution_count": null,
+ "execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n",
- " <http.client.HTTPMessage at 0x7f86b3d0add0>)"
+ " <http.client.HTTPMessage at 0x7f91d58b8990>)"
]
},
"metadata": {
"tags": []
},
- "execution_count": 3
+ "execution_count": 16
}
]
},
@@ -143,11 +143,11 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "outputId": "c563db87-ad8c-4f72-8b22-afd5c9a1717f"
+ "outputId": "63465e5d-8de8-4b99-a386-5fd264b03bc0"
},
"source": [
"#@title Extract model vocab and label metadata\n",
- "!unzip -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n",
+ "!unzip -o -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n",
"\n",
"# Load the vocab file into a dictionary. It contains the most common 1,000\n",
"# words in the English language, mapped to an integer.\n",
@@ -161,7 +161,7 @@
"with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n",
" labels = label_file.read().splitlines()"
],
- "execution_count": null,
+ "execution_count": 17,
"outputs": [
{
"output_type": "stream",
@@ -222,7 +222,7 @@
"\n",
" print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))"
],
- "execution_count": null,
+ "execution_count": 18,
"outputs": []
},
{
@@ -232,7 +232,7 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "outputId": "866b81ac-74fb-4133-f229-980bb1a5ebae"
+ "outputId": "d04fddad-852f-426a-96c7-397d9e5cd72d"
},
"source": [
"#@title Text samples\n",
@@ -242,7 +242,7 @@
"print(positive_text)\n",
"print(tokenize_input(positive_text))"
],
- "execution_count": null,
+ "execution_count": 19,
"outputs": [
{
"output_type": "stream",
@@ -305,7 +305,7 @@
" output_data = interpreter.get_tensor(output_details[0]['index'])\n",
" interpret_output(output_data[0])"
],
- "execution_count": null,
+ "execution_count": 20,
"outputs": []
},
{
@@ -315,7 +315,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "kpxfU88ckFxI",
- "outputId": "4ab615c4-8492-464a-c5b2-d9107d6b22c7"
+ "outputId": "a6ef8776-8b9b-4b29-931a-387e1e6b18fa"
},
"source": [
"print(\"Invoking text classification with TFLite\\n\")\n",
@@ -325,7 +325,7 @@
"print(negative_text)\n",
"classify_text_tflite(negative_text)"
],
- "execution_count": null,
+ "execution_count": 21,
"outputs": [
{
"output_type": "stream",
@@ -374,14 +374,10 @@
"with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n",
" tosa_mlir = mlir_file.read()\n",
"\n",
- "# Manually insert \"iree.module.export\" attribute until it is removed. \n",
- "# https://github.com/google/iree/issues/3968\n",
- "modified_tosa_mlir = tosa_mlir.replace('outputs = \"Identity\"}','outputs = \"Identity\"}, iree.module.export')\n",
- "\n",
"# The generated .mlir file could now be saved and used outside of Python, with\n",
"# IREE native tools or in apps, etc."
],
- "execution_count": null,
+ "execution_count": 22,
"outputs": []
},
{
@@ -391,7 +387,7 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "outputId": "4e7a41f5-a886-454f-fad2-db7e210cf10e"
+ "outputId": "dc92732b-793d-431c-ca67-a1427044e7b8"
},
"source": [
"# The model contains very large constants, so recompile a truncated version to print.\n",
@@ -401,13 +397,13 @@
" truncated_tosa_mlir = truncated_mlir_file.read()\n",
" print(truncated_tosa_mlir)"
],
- "execution_count": null,
+ "execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
- "module {\n",
- " func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n",
+ "builtin.module {\n",
+ " builtin.func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n",
" %0 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<10003x16xf32>} : () -> tensor<10003x16xf32>\n",
" %1 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n",
" %2 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n",
@@ -418,22 +414,20 @@
" %7 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<f32>} : () -> tensor<f32>\n",
" %8 = \"tosa.transpose\"(%0, %5) : (tensor<10003x16xf32>, tensor<2xi32>) -> tensor<10003x16xf32>\n",
" %9 = \"tosa.reshape\"(%8) {new_shape = [1, 10003, 16]} : (tensor<10003x16xf32>) -> tensor<1x10003x16xf32>\n",
- " %10 = \"tosa.reshape\"(%arg0) {new_shape = [1, 256]} : (tensor<1x256xi32>) -> tensor<1x256xi32>\n",
- " %11 = \"tosa.gather\"(%9, %10) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
- " %12 = \"tosa.reshape\"(%11) {new_shape = [1, 256, 16]} : (tensor<1x256x16xf32>) -> tensor<1x256x16xf32>\n",
- " %13 = \"tosa.transpose\"(%12, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n",
- " %14 = \"tosa.reduce_sum\"(%13) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n",
- " %15 = \"tosa.reshape\"(%14) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n",
- " %16 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n",
- " %17 = \"tosa.mul\"(%15, %16) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
- " %18 = \"tosa.fully_connected\"(%17, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n",
- " %19 = \"tosa.clamp\"(%18) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n",
- " %20 = \"tosa.fully_connected\"(%19, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n",
- " %21 = \"tosa.exp\"(%20) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n",
- " %22 = \"tosa.reduce_sum\"(%21) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n",
- " %23 = \"tosa.reciprocal\"(%22) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n",
- " %24 = \"tosa.mul\"(%21, %23) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n",
- " return %24 : tensor<1x2xf32>\n",
+ " %10 = \"tosa.gather\"(%9, %arg0) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
+ " %11 = \"tosa.transpose\"(%10, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n",
+ " %12 = \"tosa.reduce_sum\"(%11) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n",
+ " %13 = \"tosa.reshape\"(%12) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n",
+ " %14 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n",
+ " %15 = \"tosa.mul\"(%13, %14) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
+ " %16 = \"tosa.fully_connected\"(%15, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n",
+ " %17 = \"tosa.clamp\"(%16) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n",
+ " %18 = \"tosa.fully_connected\"(%17, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n",
+ " %19 = \"tosa.exp\"(%18) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n",
+ " %20 = \"tosa.reduce_sum\"(%19) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n",
+ " %21 = \"tosa.reciprocal\"(%20) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n",
+ " %22 = \"tosa.mul\"(%19, %21) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n",
+ " return %22 : tensor<1x2xf32>\n",
" }\n",
"}\n",
"\n",
@@ -446,15 +440,11 @@
{
"cell_type": "code",
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "M3gXX2AF7aS9",
- "outputId": "b60627f7-a761-444a-b41a-058cea01fbfe"
+ "id": "M3gXX2AF7aS9"
},
"source": [
"# Compile the TOSA MLIR into a VM module.\n",
- "compiled_flatbuffer = compile_str(modified_tosa_mlir, target_backends=[\"vmvx\"])\n",
+ "compiled_flatbuffer = compile_str(tosa_mlir, input_type=\"tosa\", target_backends=[\"vmvx\"])\n",
"vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n",
"\n",
"# Register the module with a runtime context.\n",
@@ -467,17 +457,8 @@
" result = invoke_text_classification(tokenize_input(text))\n",
" interpret_output(result[0])"
],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n",
- "SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n"
- ],
- "name": "stderr"
- }
- ]
+ "execution_count": 24,
+ "outputs": []
},
{
"cell_type": "code",
@@ -486,7 +467,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "jvv9zhMwAWgZ",
- "outputId": "2fc7b07c-2612-45a1-e333-3b9a10aaca88"
+ "outputId": "b086363f-28b9-4ff2-817e-e1928d7cfe4a"
},
"source": [
"print(\"Invoking text classification with IREE\\n\")\n",
@@ -496,7 +477,7 @@
"print(negative_text)\n",
"classify_text_iree(negative_text)"
],
- "execution_count": null,
+ "execution_count": 25,
"outputs": [
{
"output_type": "stream",
@@ -516,4 +497,4 @@
]
}
]
-}
+}
\ No newline at end of file
diff --git a/docs/website/docs/ml-frameworks/tensorflow-lite.md b/docs/website/docs/ml-frameworks/tensorflow-lite.md
index e807555..8db8a33 100644
--- a/docs/website/docs/ml-frameworks/tensorflow-lite.md
+++ b/docs/website/docs/ml-frameworks/tensorflow-lite.md
@@ -47,6 +47,7 @@
``` shell
iree-translate \
--iree-mlir-to-vm-bytecode-module \
+ --iree-input-type=tosa \
--iree-hal-target-backends=vmvx \
sample.mlir \
-o sample.vmfb