Colab showing definition of a module in Python and loading.
Closes #104
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/104 from stellaraccident:colabsm1 ef7126186385830fb1f195199b1dbd8305cd49b0
PiperOrigin-RevId: 276736666
diff --git a/bindings/python/pyiree/tensorflow/BUILD b/bindings/python/pyiree/tensorflow/BUILD
index f05a403..2838997 100644
--- a/bindings/python/pyiree/tensorflow/BUILD
+++ b/bindings/python/pyiree/tensorflow/BUILD
@@ -76,6 +76,8 @@
"@org_tensorflow//tensorflow/core/kernels:partitioned_function_ops",
"@org_tensorflow//tensorflow/core/kernels:identity_op",
"@org_tensorflow//tensorflow/core/kernels:identity_n_op",
+ "@org_tensorflow//tensorflow/core/kernels:resource_variable_ops",
+ "@org_tensorflow//tensorflow/core/kernels:state",
]
cc_library(
diff --git a/colab/simple_tensorflow_module_import.ipynb b/colab/simple_tensorflow_module_import.ipynb
new file mode 100644
index 0000000..ed4766e
--- /dev/null
+++ b/colab/simple_tensorflow_module_import.ipynb
@@ -0,0 +1,119 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "simple_tensorflow_module_import.ipynb",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "h5s6ncerSpc5",
+ "colab_type": "text"
+ },
+ "source": [
+ "# Defines a simple TF module, saves it and loads it in IREE.\n",
+ "\n",
+ "## Start kernel:\n",
+ "* [Install a TensorFlow2 nightly pip](https://www.tensorflow.org/install) (or bring your own)\n",
+ "* Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n",
+ "* *Optional:* Prime the build: `bazel build bindings/python/pyiree`\n",
+ "* Start colab by running `python build_tools/scripts/start_colab_kernel.py` (see that file for initial setup instructions)\n",
+ "\n",
+ "## TODO:\n",
+ "\n",
+ "* This is just using low-level binding classes. Change to high level API.\n",
+ "* Plumg through ability to run TF compiler lowering passes and import directly into IREE\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "s2bScbYkP6VZ",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import os\n",
+ "import tensorflow as tf\n",
+ "import pyiree\n",
+ "from pyiree import binding\n",
+ "\n",
+ "SAVE_PATH = os.path.join(os.environ[\"HOME\"], \"saved_models\")\n",
+ "os.makedirs(SAVE_PATH, exist_ok=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6YGqN2uqP_7P",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 411
+ },
+ "outputId": "4e8ba182-c7ee-402b-b6e9-15590e8617c5"
+ },
+ "source": [
+ "class MyModule(tf.Module):\n",
+ " def __init__(self):\n",
+ " self.v = tf.Variable([4], dtype=tf.float32)\n",
+ " \n",
+ " @tf.function(\n",
+ " input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)]\n",
+ " )\n",
+ " def add(self, a, b):\n",
+ " return tf.tanh(self.v * a + b)\n",
+ "\n",
+ "my_mod = MyModule()\n",
+ "\n",
+ "options = tf.saved_model.SaveOptions(save_debug_info=True)\n",
+ "tf.saved_model.save(my_mod, os.path.join(SAVE_PATH, \"simple.sm\"), options=options)\n",
+ "\n",
+ "mlir_asm = binding.tf_interop.import_saved_model_to_mlir_asm(os.path.join(SAVE_PATH, \"simple.sm\"))\n",
+ "print(mlir_asm)"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From c:\\users\\laurenzo\\scoop\\apps\\python36\\current\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1785: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "If using Keras pass *_constraint arguments to layers.\n",
+ "INFO:tensorflow:Assets written to: C:\\Users\\laurenzo\\saved_models\\simple.sm\\assets\n",
+ "\n",
+ "\n",
+ "module attributes {tf_saved_model.semantics} {\n",
+ " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n",
+ " func @__inference_add_160(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n",
+ " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n",
+ " %0 = tf_executor.graph {\n",
+ " %1:2 = tf_executor.island wraps \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n",
+ " %2:2 = tf_executor.island wraps \"tf.Mul\"(%1#0, %arg0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"mul\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
+ " %3:2 = tf_executor.island wraps \"tf.AddV2\"(%2#0, %arg1) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"add\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
+ " %4:2 = tf_executor.island wraps \"tf.Tanh\"(%3#0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Tanh\"} : (tensor<4xf32>) -> tensor<4xf32>\n",
+ " %5:2 = tf_executor.island(%1#1) wraps \"tf.Identity\"(%4#0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Identity\"} : (tensor<4xf32>) -> tensor<4xf32>\n",
+ " tf_executor.fetch %5#0, %1#1 : tensor<4xf32>, !tf_executor.control\n",
+ " }\n",
+ " return %0 : tensor<4xf32>\n",
+ " }\n",
+ "}\n",
+ "\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file