| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "simple_tensorflow_module_import.ipynb", |
| "provenance": [], |
| "collapsed_sections": [] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "h5s6ncerSpc5", |
| "colab_type": "text" |
| }, |
| "source": [ |
| "# Defines a simple TF module, saves it and loads it in IREE.\n", |
| "\n", |
| "## Start kernel:\n", |
| "* [Install a TensorFlow2 nightly pip](https://www.tensorflow.org/install) (or bring your own)\n", |
| "* Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n", |
| "* *Optional:* Prime the build: `bazel build bindings/python/pyiree`\n", |
| "* Start colab by running `python build_tools/scripts/start_colab_kernel.py` (see that file for initial setup instructions)\n", |
| "\n", |
| "## TODO:\n", |
| "\n", |
| "* This is just using low-level binding classes. Change to high level API.\n", |
| "* Plumg through ability to run TF compiler lowering passes and import directly into IREE\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "s2bScbYkP6VZ", |
| "colab_type": "code", |
| "colab": {} |
| }, |
| "source": [ |
| "import os\n", |
| "import tensorflow as tf\n", |
| "import pyiree\n", |
| "from pyiree import binding\n", |
| "\n", |
| "SAVE_PATH = os.path.join(os.environ[\"HOME\"], \"saved_models\")\n", |
| "os.makedirs(SAVE_PATH, exist_ok=True)" |
| ], |
| "execution_count": 0, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "6YGqN2uqP_7P", |
| "colab_type": "code", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 411 |
| }, |
| "outputId": "4e8ba182-c7ee-402b-b6e9-15590e8617c5" |
| }, |
| "source": [ |
| "class MyModule(tf.Module):\n", |
| " def __init__(self):\n", |
| " self.v = tf.Variable([4], dtype=tf.float32)\n", |
| " \n", |
| " @tf.function(\n", |
| " input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)]\n", |
| " )\n", |
| " def add(self, a, b):\n", |
| " return tf.tanh(self.v * a + b)\n", |
| "\n", |
| "my_mod = MyModule()\n", |
| "\n", |
| "options = tf.saved_model.SaveOptions(save_debug_info=True)\n", |
| "tf.saved_model.save(my_mod, os.path.join(SAVE_PATH, \"simple.sm\"), options=options)\n", |
| "\n", |
| "mlir_asm = binding.tf_interop.import_saved_model_to_mlir_asm(os.path.join(SAVE_PATH, \"simple.sm\"))\n", |
| "print(mlir_asm)" |
| ], |
| "execution_count": 2, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "WARNING:tensorflow:From c:\\users\\laurenzo\\scoop\\apps\\python36\\current\\lib\\site-packages\\tensorflow_core\\python\\ops\\resource_variable_ops.py:1785: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", |
| "Instructions for updating:\n", |
| "If using Keras pass *_constraint arguments to layers.\n", |
| "INFO:tensorflow:Assets written to: C:\\Users\\laurenzo\\saved_models\\simple.sm\\assets\n", |
| "\n", |
| "\n", |
| "module attributes {tf_saved_model.semantics} {\n", |
| " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", |
| " func @__inference_add_160(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", |
| " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", |
| " %0 = tf_executor.graph {\n", |
| " %1:2 = tf_executor.island wraps \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", |
| " %2:2 = tf_executor.island wraps \"tf.Mul\"(%1#0, %arg0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"mul\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %3:2 = tf_executor.island wraps \"tf.AddV2\"(%2#0, %arg1) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"add\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %4:2 = tf_executor.island wraps \"tf.Tanh\"(%3#0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Tanh\"} : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " %5:2 = tf_executor.island(%1#1) wraps \"tf.Identity\"(%4#0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Identity\"} : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " tf_executor.fetch %5#0, %1#1 : tensor<4xf32>, !tf_executor.control\n", |
| " }\n", |
| " return %0 : tensor<4xf32>\n", |
| " }\n", |
| "}\n", |
| "\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| } |
| ] |
| } |