| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "variables_and_state.ipynb", |
| "provenance": [], |
| "collapsed_sections": [ |
| "FH3IRpYTta2v" |
| ] |
| }, |
| "kernelspec": { |
| "display_name": "Python 3", |
| "name": "python3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "FH3IRpYTta2v" |
| }, |
| "source": [ |
| "##### Copyright 2021 The IREE Authors" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "mWGa71_Ct2ug", |
| "cellView": "form" |
| }, |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "h5s6ncerSpc5" |
| }, |
| "source": [ |
| "# Variables and State\n", |
| "\n", |
| "This notebook\n", |
| "\n", |
| "1. Creates a TensorFlow program with basic tf.Variable use\n", |
| "2. Imports that program into IREE's compiler\n", |
| "3. Compiles the imported program to an IREE VM bytecode module\n", |
| "4. Tests running the compiled VM module using IREE's runtime\n", |
| "5. Downloads compilation artifacts for use with the native (C API) sample application" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "s2bScbYkP6VZ", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "8f3fec05-7c72-41c3-a094-541193dc4c86" |
| }, |
| "source": [ |
| "#@title General setup\n", |
| "\n", |
| "import os\n", |
| "import tempfile\n", |
| "\n", |
| "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", |
| "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n", |
| "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")" |
| ], |
| "execution_count": 2, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Using artifacts directory '/tmp/iree/colab_artifacts'\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "dBHgjTjGPOJ7" |
| }, |
| "source": [ |
| "## Create a program using TensorFlow and import it into IREE\n", |
| "\n", |
| "This program uses `tf.Variable` to track state internal to the program, then exports functions that can be used to interact with that variable.\n", |
| "\n", |
| "Note that each function we want to be callable from our compiled program needs\n", |
| "to use `@tf.function` with an `input_signature` specified.\n", |
| "\n", |
| "References:\n", |
| "\n", |
| "* [\"Introduction to Variables\" Guide](https://www.tensorflow.org/guide/variable)\n", |
| "* [`tf.Variable` reference](https://www.tensorflow.org/api_docs/python/tf/Variable)\n", |
| "* [`tf.function` reference](https://www.tensorflow.org/api_docs/python/tf/function)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "hwApbPstraWZ" |
| }, |
| "source": [ |
| "#@title Define a simple \"counter\" TensorFlow module\n", |
| "\n", |
| "import tensorflow as tf\n", |
| "\n", |
| "class CounterModule(tf.Module):\n", |
| " def __init__(self):\n", |
| " super().__init__()\n", |
| " self.counter = tf.Variable(0)\n", |
| "\n", |
| " @tf.function(input_signature=[])\n", |
| " def get_value(self):\n", |
| " return self.counter\n", |
| " \n", |
| " @tf.function(input_signature=[tf.TensorSpec([], tf.int32)])\n", |
| " def set_value(self, new_value):\n", |
| " self.counter.assign(new_value)\n", |
| " \n", |
| " @tf.function(input_signature=[tf.TensorSpec([], tf.int32)])\n", |
| " def add_to_value(self, x):\n", |
| " self.counter.assign(self.counter + x)\n", |
| "\n", |
| " @tf.function(input_signature=[])\n", |
| " def reset_value(self):\n", |
| " self.set_value(0)" |
| ], |
| "execution_count": 3, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "k4aMPI2C7btB" |
| }, |
| "source": [ |
| "%%capture\n", |
| "!python -m pip install iree-compiler iree-tools-tf -f https://github.com/iree-org/iree/releases" |
| ], |
| "execution_count": 4, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "3nSXZiZ_X8-P", |
| "outputId": "c8996863-0733-4699-d365-1d04994fbf73" |
| }, |
| "source": [ |
| "#@title Import the TensorFlow program into IREE as MLIR\n", |
| "\n", |
| "from IPython.display import clear_output\n", |
| "\n", |
| "from iree.compiler import tf as tfc\n", |
| "\n", |
| "compiler_module = tfc.compile_module(\n", |
| " CounterModule(), import_only=True, output_mlir_debuginfo=False)\n", |
| "clear_output() # Skip over TensorFlow's output.\n", |
| "\n", |
| "# Print the imported MLIR to see how the compiler views this TensorFlow program.\n", |
| "# Note IREE's `util.global` ops and the public (exported) functions.\n", |
| "print(\"Counter MLIR:\\n```\\n%s```\\n\" % compiler_module.decode(\"utf-8\"))\n", |
| "\n", |
| "# Save the imported MLIR to disk.\n", |
| "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"counter.mlir\")\n", |
| "with open(imported_mlir_path, \"wt\") as output_file:\n", |
| " output_file.write(compiler_module.decode(\"utf-8\"))\n", |
| "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")" |
| ], |
| "execution_count": 5, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Counter MLIR:\n", |
| "```\n", |
| "\"builtin.module\"() ({\n", |
| " \"iree_input.global\"() {initial_value = dense<0> : tensor<i32>, is_mutable, sym_name = \"counter\", sym_visibility = \"private\", type = tensor<i32>} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " ^bb0(%arg0: !iree_input.buffer_view):\n", |
| " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<i32>\n", |
| " \"func.call\"(%0) {callee = @__inference_add_to_value_100} : (tensor<i32>) -> ()\n", |
| " \"func.return\"() : () -> ()\n", |
| " }) {function_type = (!iree_input.buffer_view) -> (), iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"add_to_value\"} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " ^bb0(%arg0: tensor<i32>):\n", |
| " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n", |
| " %1 = \"iree_input.global.load.indirect\"(%0) : (!iree_input.ptr<tensor<i32>>) -> tensor<i32>\n", |
| " %2 = \"chlo.broadcast_add\"(%1, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>\n", |
| " \"iree_input.global.store.indirect\"(%2, %0) : (tensor<i32>, !iree_input.ptr<tensor<i32>>) -> ()\n", |
| " \"func.return\"() : () -> ()\n", |
| " }) {arg_attrs = [{tf._user_specified_name = \"x\"}], function_type = (tensor<i32>) -> (), sym_name = \"__inference_add_to_value_100\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " %0 = \"func.call\"() {callee = @__inference_get_value_160} : () -> tensor<i32>\n", |
| " %1 = \"iree_input.cast.tensor_to_buffer_view\"(%0) : (tensor<i32>) -> !iree_input.buffer_view\n", |
| " \"func.return\"(%1) : (!iree_input.buffer_view) -> ()\n", |
| " }) {function_type = () -> !iree_input.buffer_view, iree.abi = \"{\\22a\\22:[],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\", sym_name = \"get_value\"} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n", |
| " %1 = \"iree_input.global.load.indirect\"(%0) : (!iree_input.ptr<tensor<i32>>) -> tensor<i32>\n", |
| " \"func.return\"(%1) : (tensor<i32>) -> ()\n", |
| " }) {function_type = () -> tensor<i32>, sym_name = \"__inference_get_value_160\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " \"func.call\"() {callee = @__inference_reset_value_270} : () -> ()\n", |
| " \"func.return\"() : () -> ()\n", |
| " }) {function_type = () -> (), iree.abi = \"{\\22a\\22:[],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"reset_value\"} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " %0 = \"mhlo.constant\"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>\n", |
| " %1 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n", |
| " \"iree_input.global.store.indirect\"(%0, %1) : (tensor<i32>, !iree_input.ptr<tensor<i32>>) -> ()\n", |
| " \"func.return\"() : () -> ()\n", |
| " }) {function_type = () -> (), sym_name = \"__inference_reset_value_270\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " ^bb0(%arg0: !iree_input.buffer_view):\n", |
| " %0 = \"iree_input.cast.buffer_view_to_tensor\"(%arg0) : (!iree_input.buffer_view) -> tensor<i32>\n", |
| " \"func.call\"(%0) {callee = @__sm_exported___inference_set_value_230} : (tensor<i32>) -> ()\n", |
| " \"func.return\"() : () -> ()\n", |
| " }) {function_type = (!iree_input.buffer_view) -> (), iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22r\\22:[],\\22v\\22:1}\", sym_name = \"set_value\"} : () -> ()\n", |
| " \"func.func\"() ({\n", |
| " ^bb0(%arg0: tensor<i32>):\n", |
| " %0 = \"iree_input.global.address\"() {global = @counter} : () -> !iree_input.ptr<tensor<i32>>\n", |
| " \"iree_input.global.store.indirect\"(%arg0, %0) : (tensor<i32>, !iree_input.ptr<tensor<i32>>) -> ()\n", |
| " \"func.return\"() : () -> ()\n", |
| " }) {arg_attrs = [{tf._user_specified_name = \"new_value\"}], function_type = (tensor<i32>) -> (), sym_name = \"__sm_exported___inference_set_value_230\", sym_visibility = \"private\", tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} : () -> ()\n", |
| "}) : () -> ()\n", |
| "\n", |
| "```\n", |
| "\n", |
| "Wrote MLIR to path '/tmp/iree/colab_artifacts/counter.mlir'\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "WCiRV6KRh3iA" |
| }, |
| "source": [ |
| "## Test the imported program\n", |
| "\n", |
| "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n", |
| "\n", |
| "_See the [README](https://github.com/iree-org/iree/tree/main/iree/samples/variables_and_state#changing-compilation-options) for more details and example command line instructions._\n", |
| "\n", |
| "* _The \"imported MLIR\" can be used by IREE's generic compiler tools_\n", |
| "* _The \"flatbuffer blob\" can be saved and used by runtime applications_\n", |
| "\n", |
| "_The specific point at which you switch from Python to native tools will depend on your project._" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "6TV6_Hdu6Xlf" |
| }, |
| "source": [ |
| "%%capture\n", |
| "!python -m pip install iree-compiler -f https://github.com/iree-org/iree/releases" |
| ], |
| "execution_count": 6, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "GF0dzDsbaP2w", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "fc95cf5b-f51f-41d2-c6f8-580d644ac287" |
| }, |
| "source": [ |
| "#@title Compile the imported MLIR further into an IREE VM bytecode module\n", |
| "\n", |
| "from iree.compiler import compile_str\n", |
| "\n", |
| "flatbuffer_blob = compile_str(compiler_module, target_backends=[\"vmvx\"], input_type=\"mhlo\")\n", |
| "\n", |
| "# Save the compiled program to disk.\n", |
| "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"counter_vmvx.vmfb\")\n", |
| "with open(flatbuffer_path, \"wb\") as output_file:\n", |
| " output_file.write(flatbuffer_blob)\n", |
| "print(f\"Wrote .vmfb to path '{flatbuffer_path}'\")" |
| ], |
| "execution_count": 7, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Wrote .vmfb to path '/tmp/iree/colab_artifacts/counter_vmvx.vmfb'\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "G7g5eXYL6hWb" |
| }, |
| "source": [ |
| "%%capture\n", |
| "!python -m pip install iree-runtime -f https://github.com/iree-org/iree/releases" |
| ], |
| "execution_count": 8, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "h8cmF6nAfza0" |
| }, |
| "source": [ |
| "#@title Test running the compiled VM module using IREE's runtime\n", |
| "\n", |
| "from iree import runtime as ireert\n", |
| "\n", |
| "config = ireert.Config(\"local-task\")\n", |
| "ctx = ireert.SystemContext(config=config)\n", |
| "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)\n", |
| "ctx.add_vm_module(vm_module)" |
| ], |
| "execution_count": 9, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "CQffg1iQatkb", |
| "outputId": "0cdaa978-06cd-4afa-f70e-e0d9d90e9d23" |
| }, |
| "source": [ |
| "# Our @tf.functions are accessible by name on the module named 'module'\n", |
| "counter = ctx.modules.module\n", |
| "\n", |
| "print(counter.get_value().to_host())\n", |
| "counter.set_value(101)\n", |
| "print(counter.get_value().to_host())\n", |
| "\n", |
| "counter.add_to_value(20)\n", |
| "print(counter.get_value().to_host())\n", |
| "counter.add_to_value(-50)\n", |
| "print(counter.get_value().to_host())\n", |
| "\n", |
| "counter.reset_value()\n", |
| "print(counter.get_value().to_host())" |
| ], |
| "execution_count": 10, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "0\n", |
| "101\n", |
| "121\n", |
| "71\n", |
| "0\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "wCvwX1IEokm6" |
| }, |
| "source": [ |
| "## Download compilation artifacts" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "bUaNUkS2ohRj", |
| "outputId": "192869dd-8e63-4b68-9f97-9664ad942346", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 86 |
| } |
| }, |
| "source": [ |
| "ARTIFACTS_ZIP = \"/tmp/variables_and_state_colab_artifacts.zip\"\n", |
| "\n", |
| "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n", |
| "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n", |
| "\n", |
| "# Note: you can also download files using Colab's file explorer\n", |
| "try:\n", |
| " from google.colab import files\n", |
| " print(\"Downloading the artifacts zip file...\")\n", |
| " files.download(ARTIFACTS_ZIP)\n", |
| "except ImportError:\n", |
| " print(\"Missing google_colab Python package, can't download files\")" |
| ], |
| "execution_count": 11, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Zipping '/tmp/iree/colab_artifacts' to '/tmp/variables_and_state_colab_artifacts.zip' for download...\n", |
| " adding: counter.mlir (deflated 82%)\n", |
| " adding: counter_vmvx.vmfb (deflated 62%)\n", |
| "Downloading the artifacts zip file...\n" |
| ] |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<IPython.core.display.Javascript object>" |
| ], |
| "application/javascript": [ |
| "\n", |
| " async function download(id, filename, size) {\n", |
| " if (!google.colab.kernel.accessAllowed) {\n", |
| " return;\n", |
| " }\n", |
| " const div = document.createElement('div');\n", |
| " const label = document.createElement('label');\n", |
| " label.textContent = `Downloading \"${filename}\": `;\n", |
| " div.appendChild(label);\n", |
| " const progress = document.createElement('progress');\n", |
| " progress.max = size;\n", |
| " div.appendChild(progress);\n", |
| " document.body.appendChild(div);\n", |
| "\n", |
| " const buffers = [];\n", |
| " let downloaded = 0;\n", |
| "\n", |
| " const channel = await google.colab.kernel.comms.open(id);\n", |
| " // Send a message to notify the kernel that we're ready.\n", |
| " channel.send({})\n", |
| "\n", |
| " for await (const message of channel.messages) {\n", |
| " // Send a message to notify the kernel that we're ready.\n", |
| " channel.send({})\n", |
| " if (message.buffers) {\n", |
| " for (const buffer of message.buffers) {\n", |
| " buffers.push(buffer);\n", |
| " downloaded += buffer.byteLength;\n", |
| " progress.value = downloaded;\n", |
| " }\n", |
| " }\n", |
| " }\n", |
| " const blob = new Blob(buffers, {type: 'application/binary'});\n", |
| " const a = document.createElement('a');\n", |
| " a.href = window.URL.createObjectURL(blob);\n", |
| " a.download = filename;\n", |
| " div.appendChild(a);\n", |
| " a.click();\n", |
| " div.remove();\n", |
| " }\n", |
| " " |
| ] |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<IPython.core.display.Javascript object>" |
| ], |
| "application/javascript": [ |
| "download(\"download_1285be07-105d-4d50-84e7-523370341c4a\", \"variables_and_state_colab_artifacts.zip\", 4286)" |
| ] |
| }, |
| "metadata": {} |
| } |
| ] |
| } |
| ] |
| } |