| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "simple_tensorflow_module_import.ipynb", |
| "provenance": [], |
| "collapsed_sections": [] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "h5s6ncerSpc5", |
| "colab_type": "text" |
| }, |
| "source": [ |
| "# Defines a simple TF module, saves it and loads it in IREE.\n", |
| "\n", |
| "## Start kernel:\n", |
| "* [Install a TensorFlow2 nightly pip](https://www.tensorflow.org/install) (or bring your own)\n", |
| "* Enable IREE/TF integration by adding to your user.bazelrc: `build --define=iree_tensorflow=true`\n", |
| "* *Optional:* Prime the build: `bazel build bindings/python/pyiree`\n", |
| "* Start colab by running `python build_tools/scripts/start_colab_kernel.py` (see that file for initial setup instructions)\n", |
| "\n", |
| "## TODO:\n", |
| "\n", |
| "* This is just using low-level binding classes. Change to high level API.\n", |
| "* Plumb through ability to run TF compiler lowering passes and import directly into IREE\n" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "s2bScbYkP6VZ", |
| "colab_type": "code", |
| "colab": {} |
| }, |
| "source": [ |
| "import os\n", |
| "import tensorflow as tf\n", |
| "import pyiree\n", |
| "from pyiree import binding\n", |
| "\n", |
| "SAVE_PATH = os.path.join(os.environ[\"HOME\"], \"saved_models\")\n", |
| "os.makedirs(SAVE_PATH, exist_ok=True)" |
| ], |
| "execution_count": 0, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "6YGqN2uqP_7P", |
| "colab_type": "code", |
| "outputId": "4cc03b22-bee5-4ffd-e4cf-6fa3055ea886", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 853 |
| } |
| }, |
| "source": [ |
| "class MyModule(tf.Module):\n", |
| " def __init__(self):\n", |
| " self.v = tf.Variable([4], dtype=tf.float32)\n", |
| " \n", |
| " @tf.function(\n", |
| " input_signature=[tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32)]\n", |
| " )\n", |
| " def add(self, a, b):\n", |
| " return tf.tanh(self.v * a + b)\n", |
| "\n", |
| "my_mod = MyModule()\n", |
| "\n", |
| "options = tf.saved_model.SaveOptions(save_debug_info=True)\n", |
| "tf.saved_model.save(my_mod, os.path.join(SAVE_PATH, \"simple.sm\"), options=options)\n", |
| "\n", |
| "ctx = binding.compiler.CompilerContext()\n", |
| "input_module = binding.tf_interop.load_saved_model(ctx, os.path.join(SAVE_PATH, \"simple.sm\"))\n", |
| "print('LOADED ASM:', input_module.to_asm())\n", |
| "\n", |
| "# Canonicalize the TF import.\n", |
| "input_module.run_pass_pipeline([\n", |
| " \"tf-executor-graph-pruning\",\n", |
| " \"tf-standard-pipeline\",\n", |
| " \"canonicalize\",\n", |
| "])\n", |
| "print(\"LOWERED TF ASM:\", input_module.to_asm())\n", |
| "\n", |
| "# Legalize to XLA (high-level).\n", |
| "input_module.run_pass_pipeline([\n", |
| " \"xla-legalize-tf\",\n", |
| "])\n", |
| "print(\"XLA ASM:\", input_module.to_asm())" |
| ], |
| "execution_count": 5, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "INFO:tensorflow:Assets written to: C:\\Users\\laurenzo\\saved_models\\simple.sm\\assets\n", |
| "LOADED ASM: \n", |
| "\n", |
| "module attributes {tf_saved_model.semantics} {\n", |
| " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", |
| " func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", |
| " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", |
| " %0 = tf_executor.graph {\n", |
| " %1:2 = tf_executor.island wraps \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", |
| " %2:2 = tf_executor.island wraps \"tf.Mul\"(%1#0, %arg0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"mul\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %3:2 = tf_executor.island wraps \"tf.AddV2\"(%2#0, %arg1) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"add\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %4:2 = tf_executor.island wraps \"tf.Tanh\"(%3#0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Tanh\"} : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " %5:2 = tf_executor.island(%1#1) wraps \"tf.Identity\"(%4#0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Identity\"} : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " tf_executor.fetch %5#0, %1#1 : tensor<4xf32>, !tf_executor.control\n", |
| " }\n", |
| " return %0 : tensor<4xf32>\n", |
| " }\n", |
| "}\n", |
| "\n", |
| "LOWERED TF ASM: \n", |
| "\n", |
| "module attributes {tf_saved_model.semantics} {\n", |
| " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", |
| " func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", |
| " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", |
| " %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", |
| " %1 = \"tf.Mul\"(%0, %arg0) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"mul\"} : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %2 = \"tf.AddV2\"(%1, %arg1) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"add\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %3 = \"tf.Tanh\"(%2) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Tanh\"} : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " %4 = \"tf.Identity\"(%3) {T = \"tfdtype$DT_FLOAT\", _output_shapes = [\"tfshape$dim { size: 4 }\"], device = \"\", name = \"Identity\"} : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " return %4 : tensor<4xf32>\n", |
| " }\n", |
| "}\n", |
| "\n", |
| "XLA ASM: \n", |
| "\n", |
| "module attributes {tf_saved_model.semantics} {\n", |
| " \"tf_saved_model.global_tensor\"() {is_mutable, sym_name = \"__sm_node1__v\", tf_saved_model.exported_names = [\"v\"], value = dense<4.000000e+00> : tensor<1xf32>} : () -> ()\n", |
| " func @__inference_add_2620(%arg0: tensor<4xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<4xf32> {tf_saved_model.index_path = [1]}, %arg2: tensor<*x!tf.resource> {tf_saved_model.bound_input = @__sm_node1__v}) -> (tensor<4xf32> {tf_saved_model.index_path = []})\n", |
| " attributes {tf._input_shapes = [\"tfshape$dim { size: 4 }\", \"tfshape$dim { size: 4 }\", \"tfshape$unknown_rank: true\"], tf.signature.is_stateful, tf_saved_model.exported_names = [\"add\"]} {\n", |
| " %0 = \"tf.ReadVariableOp\"(%arg2) {_output_shapes = [\"tfshape$dim { size: 1 }\"], device = \"\", dtype = \"tfdtype$DT_FLOAT\", name = \"ReadVariableOp\"} : (tensor<*x!tf.resource>) -> tensor<1xf32>\n", |
| " %1 = \"xla_hlo.mul\"(%0, %arg0) : (tensor<1xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " %2 = xla_hlo.add %1, %arg1 : tensor<4xf32>\n", |
| " %3 = \"xla_hlo.tanh\"(%2) : (tensor<4xf32>) -> tensor<4xf32>\n", |
| " return %3 : tensor<4xf32>\n", |
| " }\n", |
| "}\n", |
| "\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| } |
| ] |
| } |