{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "simple_tensorflow_module_import.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "h5s6ncerSpc5",
"colab_type": "text"
},
"source": [
|  | "# Defines a simple TF module, saves it and loads it in IREE.\n", | 
|  | "\n", | 
|  | "## Start kernel:\n", | 
|  | "*   [Install a TensorFlow2 nightly pip](https://www.tensorflow.org/install) (or bring your own)\n", | 
|  | "*   Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n", | 
|  | "*   *Optional:* Prime the build: `bazel build bindings/python/pyiree`\n", | 
|  | "*   Start colab by running `python colab/start_colab_kernel.py` (see that file for initial setup instructions)\n", | 
|  | "\n", | 
|  | "## TODO:\n", | 
|  | "\n", | 
|  | "* This is just using low-level binding classes. Change to high level API.\n", | 
|  | "* Plumg through ability to run TF compiler lowering passes and import directly into IREE\n" | 
|  | ] | 
|  | }, | 
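{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before continuing, a small environment check (an illustrative sketch, not part of the original flow): it only confirms that a TensorFlow 2 build and the IREE/TF compiler bindings are importable on this kernel.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Environment check (illustrative sketch): confirm that a TensorFlow 2 build\n",
"# and the IREE/TF compiler bindings can be imported before running the cells below.\n",
"import tensorflow as tf\n",
"from pyiree.tf import compiler  # import check only\n",
"\n",
"print(\"TensorFlow version:\", tf.__version__)\n",
"assert tf.__version__.startswith(\"2.\"), \"This notebook expects a TensorFlow 2.x build\""
],
"execution_count": 0,
"outputs": []
},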
{
"cell_type": "code",
"metadata": {
"id": "s2bScbYkP6VZ",
"colab_type": "code",
"colab": {}
},
"source": [
"import os\n",
"import tensorflow as tf\n",
"from pyiree.tf import compiler as ireec\n",
"\n",
"SAVE_PATH = os.path.join(os.environ[\"HOME\"], \"saved_models\")\n",
"os.makedirs(SAVE_PATH, exist_ok=True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6YGqN2uqP_7P",
"colab_type": "code",
"outputId": "ec634eb9-25e7-42c8-dd44-2aa035fa80e0",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 802
}
},
|  | "source": [ | 
|  | "class MyModule(tf.Module):\n", | 
|  | "  def __init__(self):\n", | 
|  | "    self.v = tf.Variable([4], dtype=tf.float32)\n", | 
|  | "  \n", | 
|  | "  @tf.function(\n", | 
|  | "      input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)]\n", | 
|  | "  )\n", | 
|  | "  def add(self, a, b):\n", | 
|  | "    return tf.tanh(self.v * a + b)\n", | 
|  | "\n", | 
|  | "my_mod = MyModule()\n", | 
|  | "\n", | 
|  | "saved_model_path = os.path.join(SAVE_PATH, \"simple.sm\")\n", | 
|  | "\n", | 
|  | "options = tf.saved_model.SaveOptions(save_debug_info=True)\n", | 
|  | "tf.saved_model.save(my_mod, saved_model_path, options=options)\n", | 
|  | "\n", | 
|  | "input_module = ireec.tf_load_saved_model(saved_model_path, pass_pipeline=[])\n", | 
|  | "print('LOADED ASM:', input_module.to_asm())\n", | 
|  | "\n", | 
|  | "# Canonicalize the TF import.\n", | 
|  | "input_module.run_pass_pipeline([\n", | 
|  | "  \"tf-executor-graph-pruning\",\n", | 
|  | "  \"tf-standard-pipeline\",\n", | 
|  | "  \"canonicalize\",\n", | 
|  | "])\n", | 
|  | "print(\"LOWERED TF ASM:\", input_module.to_asm())\n", | 
|  | "\n", | 
|  | "# Legalize to XLA (high-level).\n", | 
|  | "input_module.run_pass_pipeline([\n", | 
|  | "  \"xla-legalize-tf{allow-partial-conversion=true}\",\n", | 
|  | "])\n", | 
|  | "print(\"XLA ASM:\", input_module.to_asm())" | 
|  | ], | 
|  | "execution_count": 15, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "text": [ | 
|  | "INFO:tensorflow:Assets written to: /usr/local/google/home/scotttodd/saved_models/simple.sm/assets\n", | 
|  | "LOADED ASM: \n", | 
|  | "\n", | 
|  | "module attributes {tf_saved_model.semantics} {\n", | 
|  | "  \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", | 
|  | "  func @__inference_add_10820(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", | 
|  | "    %0 = tf_executor.graph {\n", | 
|  | "      %outputs, %control = tf_executor.island wraps \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", | 
|  | "      %outputs_0, %control_1 = tf_executor.island wraps \"tf.Mul\"(%outputs, %arg0) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "      %outputs_2, %control_3 = tf_executor.island wraps \"tf.AddV2\"(%outputs_0, %arg1) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "      %outputs_4, %control_5 = tf_executor.island wraps \"tf.Tanh\"(%outputs_2) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "      %outputs_6, %control_7 = tf_executor.island(%control) wraps \"tf.Identity\"(%outputs_4) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "      tf_executor.fetch %outputs_6, %control : tensor<4xf32>, !tf_executor.control\n", | 
|  | "    }\n", | 
|  | "    return %0 : tensor<4xf32>\n", | 
|  | "  }\n", | 
|  | "}\n", | 
|  | "\n", | 
|  | "LOWERED TF ASM: \n", | 
|  | "\n", | 
|  | "module attributes {tf_saved_model.semantics} {\n", | 
|  | "  \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", | 
|  | "  func @__inference_add_10820(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", | 
|  | "    %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", | 
|  | "    %1 = \"tf.Mul\"(%0, %arg0) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "    %2 = \"tf.AddV2\"(%1, %arg1) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "    %3 = \"tf.Tanh\"(%2) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "    %4 = \"tf.Identity\"(%3) {T = f32, _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\"} : (tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "    return %4 : tensor<4xf32>\n", | 
|  | "  }\n", | 
|  | "}\n", | 
|  | "\n", | 
|  | "XLA ASM: \n", | 
|  | "\n", | 
|  | "module attributes {tf_saved_model.semantics} {\n", | 
|  | "  \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], type = tensor<1xf32>, value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", | 
|  | "  func @__inference_add_10820(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", | 
|  | "    %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = f32} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", | 
|  | "    %1 = \"mhlo.multiply\"(%0, %arg0) : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "    %2 = mhlo.add %1, %arg1 : tensor<4xf32>\n", | 
|  | "    %3 = \"mhlo.tanh\"(%2) : (tensor<4xf32>) -> tensor<4xf32>\n", | 
|  | "    return %3 : tensor<4xf32>\n", | 
|  | "  }\n", | 
|  | "}\n", | 
|  | "\n" | 
|  | ], | 
|  | "name": "stdout" | 
|  | } | 
|  | ] | 
},
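{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Possible next step (sketch): compile for the IREE VM and invoke\n",
"\n",
"The cell below is an illustrative sketch, not part of the original notebook run: it shows how the imported module might be compiled to an IREE VM FlatBuffer and its `add` function called through the low-level runtime bindings. The `compile()`, `VmModule.from_flatbuffer`, `Config`, `SystemContext`, and `ctx.modules.module[...]` names follow the low-level `pyiree` samples of this era and may differ in your build, and the remaining `tf_saved_model` lowering passes (see the TODO above) may be needed before compilation succeeds.\n"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative sketch only: these runtime API names assume the low-level\n",
"# pyiree bindings of this era and may differ in your build.\n",
"import numpy as np\n",
"from pyiree import rt as ireert\n",
"\n",
"# Compile the legalized module to a VM FlatBuffer for the reference backend.\n",
"blob = input_module.compile(target_backends=[\"vmla\"])\n",
"\n",
"# Load the FlatBuffer into a runtime context and look up the exported function.\n",
"vm_module = ireert.VmModule.from_flatbuffer(blob)\n",
"config = ireert.Config(\"vmla\")\n",
"ctx = ireert.SystemContext(config=config)\n",
"ctx.add_module(vm_module)\n",
"add_fn = ctx.modules.module[\"add\"]\n",
"\n",
"a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n",
"b = np.array([0.5, 0.5, 0.5, 0.5], dtype=np.float32)\n",
"print(add_fn(a, b))"
],
"execution_count": 0,
"outputs": []
}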
]
}