Add minimal sample using dynamic shapes from TF -> C. (#6191)

Progress on https://github.com/google/iree/issues/5222.
diff --git a/iree/samples/dynamic_shapes/CMakeLists.txt b/iree/samples/dynamic_shapes/CMakeLists.txt
new file mode 100644
index 0000000..deaaafd
--- /dev/null
+++ b/iree/samples/dynamic_shapes/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+set(_NAME "iree_samples_dynamic_shapes")
+add_executable(${_NAME} "")
+target_sources(${_NAME}
+  PRIVATE
+    main.c
+)
+
+set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "dynamic-shapes")
+
+target_link_libraries(${_NAME}
+  iree_runtime_runtime
+)
diff --git a/iree/samples/dynamic_shapes/README.md b/iree/samples/dynamic_shapes/README.md
new file mode 100644
index 0000000..6678425
--- /dev/null
+++ b/iree/samples/dynamic_shapes/README.md
@@ -0,0 +1,101 @@
+# "Dynamic Shapes" sample
+
+This sample shows how to
+
+1. Create a TensorFlow program that includes dynamic shapes in program inputs
+   and outputs
+2. Import that program into IREE's compiler
+3. Compile that program to an IREE VM bytecode module
+4. Load the compiled program using IREE's high level runtime C API
+5. Call exported functions on the loaded program
+
+Steps 1-2 are performed in Python via the
+[`dynamic_shapes.ipynb`](./dynamic_shapes.ipynb)
+[Colab](https://research.google.com/colaboratory/) notebook:
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/iree/blob/main/iree/samples/dynamic_shapes/dynamic_shapes.ipynb)
+
+Step 3 should be performed on your development host machine
+
+Steps 4-5 are in [`main.c`](./main.c)
+
+The program used to demonstrate includes functions with varying uses of
+dynamic shapes:
+
+```python
+class DynamicShapesModule(tf.Module):
+  # reduce_sum_1d (dynamic input size, static output size)
+  #   e.g. [1, 2, 3] -> 6
+  @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])
+  def reduce_sum_1d(self, values):
+    return tf.math.reduce_sum(values)
+
+  # reduce_sum_2d (partially dynamic input size, static output size)
+  #   e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]
+  @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.int32)])
+  def reduce_sum_2d(self, values):
+    return tf.math.reduce_sum(values, 0)
+
+  # add_one (dynamic input size, dynamic output size)
+  #   e.g. [1, 2, 3] -> [2, 3, 4]
+  @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])
+  def add_one(self, values):
+    return tf.math.add(values, tf.constant(1, dtype=tf.int32))
+```
+
+## Background
+
+Tensors are multi-dimensional arrays with a uniform type (e.g. int32, float32)
+and a shape. Shapes consist of a rank and a list of dimensions and may be
+static (i.e. fully known and fixed) or varying degrees of dynamic. See
+TensorFlow's [Introduction to Tensors](https://www.tensorflow.org/guide/tensor)
+for more information on how tensors are used in TensorFlow programs.
+
+Dynamic shapes are useful for passing variable sized batches as input,
+receiving variable length sentences of text as output, etc.
+
+NOTE: as in other domains, providing more information to a compiler allows it
+to generate more efficient code. As a general rule, the slowest varying
+dimensions of program data like batch index or timestep are safer to treat as
+dynamic than faster varying dimensions like image x/y/channel. See
+[this paper](https://arxiv.org/pdf/2006.03031.pdf) for a discussion of the
+challenges imposed by dynamic shapes and one project's approach to addressing
+them.
+
+## Instructions
+
+1. Run the Colab notebook and download the `dynamic_shapes.mlir` file it
+    generates
+
+2. Build the `iree-translate` tool (see
+    [here](https://google.github.io/iree/building-from-source/getting-started/)
+    for general instructions on building using CMake)
+
+    ```
+    cmake -B ../iree-build/ -DCMAKE_BUILD_TYPE=RelWithDebInfo .
+    cmake --build ../iree-build/ --target iree_tools_iree-translate
+    ```
+
+3. Compile the `dynamic_shapes.mlir` file using `iree-translate`. The
+    [dylib-llvm-aot](https://google.github.io/iree/deployment-configurations/cpu-dylib/)
+    configuration has the best support for dynamic shapes:
+
+    ```
+    ../iree-build/iree/tools/iree-translate \
+        -iree-mlir-to-vm-bytecode-module \
+        -iree-hal-target-backends=dylib-llvm-aot \
+        dynamic_shapes.mlir -o dynamic_shapes_dylib.vmfb
+    ```
+
+4. Build the `iree_samples_dynamic_shapes` CMake target
+
+    ```
+    cmake --build ../iree-build/ --target iree_samples_dynamic_shapes
+    ```
+
+5. Run the sample binary:
+
+   ```
+   ../iree-build/iree/samples/dynamic_shapes/dynamic-shapes \
+       /path/to/dynamic_shapes_dylib.vmfb dylib
+   ```
diff --git a/iree/samples/dynamic_shapes/dynamic_shapes.ipynb b/iree/samples/dynamic_shapes/dynamic_shapes.ipynb
new file mode 100644
index 0000000..556641f
--- /dev/null
+++ b/iree/samples/dynamic_shapes/dynamic_shapes.ipynb
@@ -0,0 +1,417 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "dynamic_shapes.ipynb",
+      "provenance": [],
+      "collapsed_sections": [
+        "FH3IRpYTta2v"
+      ]
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "FH3IRpYTta2v"
+      },
+      "source": [
+        "##### Copyright 2021 The IREE Authors"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "mWGa71_Ct2ug",
+        "cellView": "form"
+      },
+      "source": [
+        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
+        "# See https://llvm.org/LICENSE.txt for license information.\n",
+        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
+      ],
+      "execution_count": 1,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "h5s6ncerSpc5"
+      },
+      "source": [
+        "# Dynamic Shapes\n",
+        "\n",
+        "This notebook\n",
+        "\n",
+        "1. Creates a TensorFlow program with dynamic shapes\n",
+        "2. Imports that program into IREE's compiler\n",
+        "3. Compiles the imported program to an IREE VM bytecode module\n",
+        "4. Tests running the compiled VM module using IREE's runtime\n",
+        "5. Downloads compilation artifacts for use with the native (C API) sample application"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "s2bScbYkP6VZ",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "1fb6eb27-ae8b-4259-a38e-0b92ea307f73"
+      },
+      "source": [
+        "#@title General setup\n",
+        "\n",
+        "import os\n",
+        "import tempfile\n",
+        "\n",
+        "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
+        "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n",
+        "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")"
+      ],
+      "execution_count": 2,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Using artifacts directory '/tmp/iree/colab_artifacts'\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "dBHgjTjGPOJ7"
+      },
+      "source": [
+        "## Create a program using TensorFlow and import it into IREE\n",
+        "\n",
+        "NOTE: as in other domains, providing more information to a compiler allows it\n",
+        "to generate more efficient code. As a general rule, the slowest varying\n",
+        "dimensions of program data like batch index or timestep are safer to treat as\n",
+        "dynamic than faster varying dimensions like image x/y/channel. See\n",
+        "[this paper](https://arxiv.org/pdf/2006.03031.pdf) for a discussion of the\n",
+        "challenges imposed by dynamic shapes and one project's approach to addressing\n",
+        "them."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "hwApbPstraWZ"
+      },
+      "source": [
+        "#@title Define a sample TensorFlow module using dynamic shapes\n",
+        "\n",
+        "import tensorflow as tf\n",
+        "\n",
+        "class DynamicShapesModule(tf.Module):\n",
+        "  # reduce_sum_1d (dynamic input size, static output size)\n",
+        "  #   e.g. [1, 2, 3] -> 6\n",
+        "  @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])\n",
+        "  def reduce_sum_1d(self, values):\n",
+        "    return tf.math.reduce_sum(values)\n",
+        "    \n",
+        "  # reduce_sum_2d (partially dynamic input size, static output size)\n",
+        "  #   e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n",
+        "  @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.int32)])\n",
+        "  def reduce_sum_2d(self, values):\n",
+        "    return tf.math.reduce_sum(values, 0)\n",
+        "\n",
+        "  # add_one (dynamic input size, dynamic output size)\n",
+        "  #   e.g. [1, 2, 3] -> [2, 3, 4]\n",
+        "  @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])\n",
+        "  def add_one(self, values):\n",
+        "    return tf.math.add(values, tf.constant(1, dtype=tf.int32))"
+      ],
+      "execution_count": 3,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "k4aMPI2C7btB"
+      },
+      "source": [
+        "%%capture\n",
+        "!python -m pip install iree-compiler-snapshot iree-tools-tf-snapshot -f https://github.com/google/iree/releases"
+      ],
+      "execution_count": 4,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "3nSXZiZ_X8-P",
+        "outputId": "321b77e3-5890-4d95-a149-83d86ea6f039"
+      },
+      "source": [
+        "#@title Import the TensorFlow program into IREE as MLIR\n",
+        "\n",
+        "from IPython.display import clear_output\n",
+        "\n",
+        "from iree.compiler import tf as tfc\n",
+        "\n",
+        "compiler_module = tfc.compile_module(\n",
+        "    DynamicShapesModule(), import_only=True, \n",
+        "    output_mlir_debuginfo=False)\n",
+        "clear_output()  # Skip over TensorFlow's output.\n",
+        "\n",
+        "# Print the imported MLIR to see how the compiler views this program.\n",
+        "print(\"Dynamic Shapes MLIR:\\n```\\n%s```\\n\" % compiler_module.decode(\"utf-8\"))\n",
+        "\n",
+        "# Save the imported MLIR to disk.\n",
+        "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n",
+        "with open(imported_mlir_path, \"wt\") as output_file:\n",
+        "  output_file.write(compiler_module.decode(\"utf-8\"))\n",
+        "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")"
+      ],
+      "execution_count": 5,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Dynamic Shapes MLIR:\n",
+            "```\n",
+            "#map0 = affine_map<(d0) -> ()>\n",
+            "#map1 = affine_map<(d0) -> (d0)>\n",
+            "#map2 = affine_map<(d0, d1) -> (d1, d0)>\n",
+            "#map3 = affine_map<(d0, d1) -> (d0)>\n",
+            "module  {\n",
+            "  func @add_one(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22v\\22:1}\"} {\n",
+            "    %c0 = constant 0 : index\n",
+            "    %0 = hal.buffer_view.dim %arg0, 0 : index\n",
+            "    %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?xi32>{%0}\n",
+            "    %2 = call @__inference_add_one_70(%1) : (tensor<?xi32>) -> tensor<?xi32>\n",
+            "    %3 = memref.dim %2, %c0 : tensor<?xi32>\n",
+            "    %4 = hal.tensor.cast %2 : tensor<?xi32>{%3} -> !hal.buffer_view\n",
+            "    return %4 : !hal.buffer_view\n",
+            "  }\n",
+            "  func private @__inference_add_one_70(%arg0: tensor<?xi32> {tf._user_specified_name = \"values\"}) -> tensor<?xi32> attributes {tf._construction_context = \"kEagerRuntime\", tf._input_shapes = [#tf.shape<?>]} {\n",
+            "    %cst = constant dense<1> : tensor<i32>\n",
+            "    %c0 = constant 0 : index\n",
+            "    %0 = memref.dim %arg0, %c0 : tensor<?xi32>\n",
+            "    %1 = linalg.init_tensor [%0] : tensor<?xi32>\n",
+            "    %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = [\"parallel\"]} ins(%cst : tensor<i32>) outs(%1 : tensor<?xi32>) {\n",
+            "    ^bb0(%arg1: i32, %arg2: i32):  // no predecessors\n",
+            "      linalg.yield %arg1 : i32\n",
+            "    } -> tensor<?xi32>\n",
+            "    %3 = memref.dim %arg0, %c0 : tensor<?xi32>\n",
+            "    %4 = linalg.init_tensor [%3] : tensor<?xi32>\n",
+            "    %5 = linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = [\"parallel\"]} ins(%arg0, %2 : tensor<?xi32>, tensor<?xi32>) outs(%4 : tensor<?xi32>) {\n",
+            "    ^bb0(%arg1: i32, %arg2: i32, %arg3: i32):  // no predecessors\n",
+            "      %6 = addi %arg1, %arg2 : i32\n",
+            "      linalg.yield %6 : i32\n",
+            "    } -> tensor<?xi32>\n",
+            "    return %5 : tensor<?xi32>\n",
+            "  }\n",
+            "  func @reduce_sum_1d(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,1,null]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,0]],\\22v\\22:1}\"} {\n",
+            "    %c0_i32 = constant 0 : i32\n",
+            "    %0 = hal.buffer_view.dim %arg0, 0 : index\n",
+            "    %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?xi32>{%0}\n",
+            "    %2 = linalg.init_tensor [] : tensor<i32>\n",
+            "    %3 = linalg.fill(%2, %c0_i32) : tensor<i32>, i32 -> tensor<i32> \n",
+            "    %4 = linalg.generic {indexing_maps = [#map1, #map0], iterator_types = [\"reduction\"]} ins(%1 : tensor<?xi32>) outs(%3 : tensor<i32>) {\n",
+            "    ^bb0(%arg1: i32, %arg2: i32):  // no predecessors\n",
+            "      %6 = addi %arg1, %arg2 : i32\n",
+            "      linalg.yield %6 : i32\n",
+            "    } -> tensor<i32>\n",
+            "    %5 = hal.tensor.cast %4 : tensor<i32> -> !hal.buffer_view\n",
+            "    return %5 : !hal.buffer_view\n",
+            "  }\n",
+            "  func @reduce_sum_2d(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi = \"{\\22a\\22:[[\\22ndarray\\22,\\22i32\\22,2,null,3]],\\22r\\22:[[\\22ndarray\\22,\\22i32\\22,1,3]],\\22v\\22:1}\"} {\n",
+            "    %c0_i32 = constant 0 : i32\n",
+            "    %0 = hal.buffer_view.dim %arg0, 0 : index\n",
+            "    %1 = hal.tensor.cast %arg0 : !hal.buffer_view -> tensor<?x3xi32>{%0}\n",
+            "    %2 = linalg.init_tensor [3] : tensor<3xi32>\n",
+            "    %3 = linalg.fill(%2, %c0_i32) : tensor<3xi32>, i32 -> tensor<3xi32> \n",
+            "    %4 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = [\"parallel\", \"reduction\"]} ins(%1 : tensor<?x3xi32>) outs(%3 : tensor<3xi32>) {\n",
+            "    ^bb0(%arg1: i32, %arg2: i32):  // no predecessors\n",
+            "      %6 = addi %arg1, %arg2 : i32\n",
+            "      linalg.yield %6 : i32\n",
+            "    } -> tensor<3xi32>\n",
+            "    %5 = hal.tensor.cast %4 : tensor<3xi32> -> !hal.buffer_view\n",
+            "    return %5 : !hal.buffer_view\n",
+            "  }\n",
+            "}\n",
+            "\n",
+            "```\n",
+            "\n",
+            "Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlir'\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "WCiRV6KRh3iA"
+      },
+      "source": [
+        "## Test the imported program\n",
+        "\n",
+        "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n",
+        "\n",
+        "_See the [README](https://github.com/google/iree/tree/main/iree/samples/dynamic_shapes#instructions) for more details and example command line instructions._\n",
+        "\n",
+        "* _The \"imported MLIR\" can be used by IREE's generic compiler tools_\n",
+        "* _The \"flatbuffer blob\" can be saved and used by runtime applications_\n",
+        "\n",
+        "_The specific point at which you switch from Python to native tools will depend on your project._"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "6TV6_Hdu6Xlf"
+      },
+      "source": [
+        "%%capture\n",
+        "!python -m pip install iree-compiler-snapshot -f https://github.com/google/iree/releases"
+      ],
+      "execution_count": 6,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "GF0dzDsbaP2w"
+      },
+      "source": [
+        "#@title Compile the imported MLIR further into an IREE VM bytecode module\n",
+        "\n",
+        "from iree.compiler import compile_str\n",
+        "\n",
+        "# Note: we'll use the dylib-llvm-aot backend since it has the best support\n",
+        "# for dynamic shapes among our compiler targets.\n",
+        "\n",
+        "flatbuffer_blob = compile_str(compiler_module, target_backends=[\"dylib-llvm-aot\"])\n",
+        "\n",
+        "# Note: the dylib-llvm-aot target produces platform-specific code. Since you\n",
+        "# may need to recompile it yourself using the appropriate\n",
+        "# `-iree-llvm-target-triple` flag, we skip saving it to disk and downloading it."
+      ],
+      "execution_count": 7,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "G7g5eXYL6hWb"
+      },
+      "source": [
+        "%%capture\n",
+        "!python -m pip install iree-runtime-snapshot -f https://github.com/google/iree/releases"
+      ],
+      "execution_count": 8,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "h8cmF6nAfza0",
+        "outputId": "0e89eecd-ad88-45bb-9a81-c8cc9519eccb"
+      },
+      "source": [
+        "#@title Test running the compiled VM module using IREE's runtime\n",
+        "\n",
+        "from iree import runtime as ireert\n",
+        "\n",
+        "vm_module = ireert.VmModule.from_flatbuffer(flatbuffer_blob)\n",
+        "config = ireert.Config(\"dylib\")\n",
+        "ctx = ireert.SystemContext(config=config)\n",
+        "ctx.add_vm_module(vm_module)"
+      ],
+      "execution_count": 9,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Created IREE driver dylib: <iree.runtime.binding.HalDriver object at 0x7fc97f1611f0>\n",
+            "SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7fc97f1611f0>\n"
+          ],
+          "name": "stderr"
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "CQffg1iQatkb",
+        "outputId": "f1729cac-4bbc-420e-fe22-bd47b36d8ad4"
+      },
+      "source": [
+        "import numpy as np\n",
+        "\n",
+        "# Our @tf.functions are accessible by name on the module named 'module'\n",
+        "dynamic_shapes_program = ctx.modules.module\n",
+        "\n",
+        "print(dynamic_shapes_program.reduce_sum_1d(np.array([1, 10, 100], dtype=np.int32)))\n",
+        "print(dynamic_shapes_program.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30]], dtype=np.int32)))\n",
+        "print(dynamic_shapes_program.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30], [100, 200, 300]], dtype=np.int32)))\n",
+        "print(dynamic_shapes_program.add_one(np.array([1, 10, 100], dtype=np.int32)))"
+      ],
+      "execution_count": 10,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "111\n",
+            "[11 22 33]\n",
+            "[111 222 333]\n",
+            "[  2  11 101]\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "wCvwX1IEokm6"
+      },
+      "source": [
+        "## Download compilation artifacts"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "id": "bUaNUkS2ohRj"
+      },
+      "source": [
+        "ARTIFACTS_ZIP = \"/tmp/dynamic_shapes_colab_artifacts.zip\"\n",
+        "\n",
+        "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n",
+        "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n",
+        "\n",
+        "# Note: you can also download files using Colab's file explorer\n",
+        "try:\n",
+        "  from google.colab import files\n",
+        "  print(\"Downloading the artifacts zip file...\")\n",
+        "  files.download(ARTIFACTS_ZIP)  \n",
+        "except ImportError:\n",
+        "  print(\"Missing google_colab Python package, can't download files\")"
+      ],
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file
diff --git a/iree/samples/dynamic_shapes/main.c b/iree/samples/dynamic_shapes/main.c
new file mode 100644
index 0000000..1e1838d
--- /dev/null
+++ b/iree/samples/dynamic_shapes/main.c
@@ -0,0 +1,271 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <stdio.h>
+
+#include "iree/runtime/api.h"
+
+iree_status_t reduce_sum_1d(iree_runtime_session_t* session, const int* values,
+                            int values_length, int* out_result) {
+  iree_runtime_call_t call;
+  IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
+      session, iree_make_cstring_view("module.reduce_sum_1d"), &call));
+
+  iree_hal_buffer_view_t* arg0 = NULL;
+  const iree_hal_dim_t arg0_shape[1] = {values_length};
+
+  // TODO(scotttodd): use iree_hal_buffer_view_wrap_or_clone_heap_buffer
+  //   * debugging some apparent memory corruption with the stack-local value
+  iree_status_t status = iree_ok_status();
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_buffer_view_clone_heap_buffer(
+        iree_runtime_session_device_allocator(session), arg0_shape,
+        IREE_ARRAYSIZE(arg0_shape), IREE_HAL_ELEMENT_TYPE_SINT_32,
+        IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+        IREE_HAL_BUFFER_USAGE_ALL,
+        iree_make_const_byte_span((void*)values, sizeof(int) * values_length),
+        &arg0);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg0);
+  }
+  iree_hal_buffer_view_release(arg0);
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_call_invoke(&call, /*flags=*/0);
+  }
+
+  iree_hal_buffer_view_t* buffer_view = NULL;
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_runtime_call_outputs_pop_front_buffer_view(&call, &buffer_view);
+  }
+  iree_hal_buffer_mapping_t buffer_mapping;
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_buffer_map_range(iree_hal_buffer_view_buffer(buffer_view),
+                                       IREE_HAL_MEMORY_ACCESS_READ, 0,
+                                       IREE_WHOLE_BUFFER, &buffer_mapping);
+  }
+  if (iree_status_is_ok(status)) {
+    *out_result = *buffer_mapping.contents.data;
+  }
+  iree_hal_buffer_unmap_range(&buffer_mapping);
+  iree_hal_buffer_view_release(buffer_view);
+
+  iree_runtime_call_deinitialize(&call);
+  return status;
+}
+
+iree_status_t reduce_sum_2d(iree_runtime_session_t* session, const int* values,
+                            size_t values_length,
+                            iree_hal_buffer_view_t** out_buffer_view) {
+  iree_runtime_call_t call;
+  IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
+      session, iree_make_cstring_view("module.reduce_sum_2d"), &call));
+
+  iree_hal_buffer_view_t* arg0 = NULL;
+  const iree_hal_dim_t arg0_shape[2] = {values_length / 3, 3};
+
+  // TODO(scotttodd): use iree_hal_buffer_view_wrap_or_clone_heap_buffer
+  //   * debugging some apparent memory corruption with the stack-local value
+  iree_status_t status = iree_ok_status();
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_buffer_view_clone_heap_buffer(
+        iree_runtime_session_device_allocator(session), arg0_shape,
+        IREE_ARRAYSIZE(arg0_shape), IREE_HAL_ELEMENT_TYPE_SINT_32,
+        IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+        IREE_HAL_BUFFER_USAGE_ALL,
+        iree_make_const_byte_span((void*)values, sizeof(int) * values_length),
+        &arg0);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg0);
+  }
+  iree_hal_buffer_view_release(arg0);
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_call_invoke(&call, /*flags=*/0);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_runtime_call_outputs_pop_front_buffer_view(&call, out_buffer_view);
+  }
+
+  iree_runtime_call_deinitialize(&call);
+  return status;
+}
+
+iree_status_t add_one(iree_runtime_session_t* session, const int* values,
+                      size_t values_length,
+                      iree_hal_buffer_view_t** out_buffer_view) {
+  iree_runtime_call_t call;
+  IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
+      session, iree_make_cstring_view("module.add_one"), &call));
+
+  iree_hal_buffer_view_t* arg0 = NULL;
+  const iree_hal_dim_t arg0_shape[1] = {values_length};
+
+  // TODO(scotttodd): use iree_hal_buffer_view_wrap_or_clone_heap_buffer
+  //   * debugging some apparent memory corruption with the stack-local value
+  iree_status_t status = iree_ok_status();
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_buffer_view_clone_heap_buffer(
+        iree_runtime_session_device_allocator(session), arg0_shape,
+        IREE_ARRAYSIZE(arg0_shape), IREE_HAL_ELEMENT_TYPE_SINT_32,
+        IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
+        IREE_HAL_BUFFER_USAGE_ALL,
+        iree_make_const_byte_span((void*)values, sizeof(int) * values_length),
+        &arg0);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg0);
+  }
+  iree_hal_buffer_view_release(arg0);
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_call_invoke(&call, /*flags=*/0);
+  }
+
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_runtime_call_outputs_pop_front_buffer_view(&call, out_buffer_view);
+  }
+
+  iree_runtime_call_deinitialize(&call);
+  return status;
+}
+
+iree_status_t run_sample(iree_string_view_t bytecode_module_path,
+                         iree_string_view_t driver_name) {
+  iree_status_t status = iree_ok_status();
+
+  //===-------------------------------------------------------------------===//
+  // Instance configuration (this should be shared across sessions).
+  fprintf(stdout, "Configuring IREE runtime instance and '%s' device\n",
+          driver_name.data);
+  iree_runtime_instance_options_t instance_options;
+  iree_runtime_instance_options_initialize(IREE_API_VERSION_LATEST,
+                                           &instance_options);
+  iree_runtime_instance_options_use_all_available_drivers(&instance_options);
+  iree_runtime_instance_t* instance = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_instance_create(&instance_options,
+                                          iree_allocator_system(), &instance);
+  }
+  // TODO(#5724): move device selection into the compiled modules.
+  iree_hal_device_t* device = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_instance_try_create_default_device(
+        instance, driver_name, &device);
+  }
+  //===-------------------------------------------------------------------===//
+
+  //===-------------------------------------------------------------------===//
+  // Session configuration (one per loaded module to hold module state).
+  fprintf(stdout, "Creating IREE runtime session\n");
+  iree_runtime_session_options_t session_options;
+  iree_runtime_session_options_initialize(&session_options);
+  iree_runtime_session_t* session = NULL;
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_session_create_with_device(
+        instance, &session_options, device,
+        iree_runtime_instance_host_allocator(instance), &session);
+  }
+  iree_hal_device_release(device);
+
+  fprintf(stdout, "Loading bytecode module at '%s'\n",
+          bytecode_module_path.data);
+  if (iree_status_is_ok(status)) {
+    status = iree_runtime_session_append_bytecode_module_from_file(
+        session, bytecode_module_path.data);
+  }
+  //===-------------------------------------------------------------------===//
+
+  //===-------------------------------------------------------------------===//
+  // Call the exported sample functions with some test inputs
+  fprintf(stdout, "Calling functions\n\n");
+
+  // reduce_sum_1d([1, 10, 100])
+  if (iree_status_is_ok(status)) {
+    const int input[3] = {1, 10, 100};
+    int result = -1;
+    status = reduce_sum_1d(session, input, 3, &result);
+    fprintf(stdout, "reduce_sum_1d([1, 10, 100]): %d\n", result);
+  }
+
+  // reduce_sum_2d([[1, 2, 3], [10, 20, 30]])
+  if (iree_status_is_ok(status)) {
+    const int input[6] = {1, 2, 3, 10, 20, 30};
+    iree_hal_buffer_view_t* result_buffer_view = NULL;
+    status = reduce_sum_2d(session, input, 6, &result_buffer_view);
+    if (iree_status_is_ok(status)) {
+      fprintf(stdout, "reduce_sum_2d([[1, 2, 3], [10, 20, 30]]): ");
+      status = iree_hal_buffer_view_fprint(stdout, result_buffer_view,
+                                           /*max_element_count=*/4096);
+      fprintf(stdout, "\n");
+    }
+    iree_hal_buffer_view_release(result_buffer_view);
+  }
+
+  // reduce_sum_2d([[1, 2, 3], [10, 20, 30], [100, 200, 300]])
+  if (iree_status_is_ok(status)) {
+    const int input[9] = {1, 2, 3, 10, 20, 30, 100, 200, 300};
+    iree_hal_buffer_view_t* result_buffer_view = NULL;
+    status = reduce_sum_2d(session, input, 9, &result_buffer_view);
+    if (iree_status_is_ok(status)) {
+      fprintf(stdout,
+              "reduce_sum_2d([[1, 2, 3], [10, 20, 30], [100, 200, 300]]): ");
+      status = iree_hal_buffer_view_fprint(stdout, result_buffer_view,
+                                           /*max_element_count=*/4096);
+      fprintf(stdout, "\n");
+    }
+    iree_hal_buffer_view_release(result_buffer_view);
+  }
+
+  // add_one([1, 10, 100])
+  if (iree_status_is_ok(status)) {
+    const int input[3] = {1, 10, 100};
+    iree_hal_buffer_view_t* result_buffer_view = NULL;
+    status = add_one(session, input, 3, &result_buffer_view);
+    if (iree_status_is_ok(status)) {
+      fprintf(stdout, "add_one([1, 10, 100]): ");
+      status = iree_hal_buffer_view_fprint(stdout, result_buffer_view,
+                                           /*max_element_count=*/64);
+      fprintf(stdout, "\n");
+    }
+    iree_hal_buffer_view_release(result_buffer_view);
+  }
+  //===-------------------------------------------------------------------===//
+
+  //===-------------------------------------------------------------------===//
+  // Cleanup.
+  iree_runtime_session_release(session);
+  iree_runtime_instance_release(instance);
+  //===-------------------------------------------------------------------===//
+
+  return status;
+}
+
+int main(int argc, char** argv) {
+  if (argc != 3) {
+    fprintf(
+        stderr,
+        "Usage: dynamic-shapes </path/to/dynamic_shapes.vmfb> <driver_name>\n");
+    fprintf(stderr, "  (See the README for this sample for details)\n ");
+    return -1;
+  }
+
+  iree_string_view_t bytecode_module_path = iree_make_cstring_view(argv[1]);
+  iree_string_view_t driver_name = iree_make_cstring_view(argv[2]);
+
+  iree_status_t result = run_sample(bytecode_module_path, driver_name);
+  if (!iree_status_is_ok(result)) {
+    fprintf(stdout, "Failed!\n");
+    iree_status_fprint(stderr, result);
+    iree_status_ignore(result);
+    return -1;
+  }
+  fprintf(stdout, "\nSuccess!\n");
+  return 0;
+}
diff --git a/iree/samples/variables_and_state/README.md b/iree/samples/variables_and_state/README.md
index b219ea2..de84608 100644
--- a/iree/samples/variables_and_state/README.md
+++ b/iree/samples/variables_and_state/README.md
@@ -60,14 +60,16 @@
 
 1. Run the Colab notebook and download the `counter.mlir` and
    `counter_vmvx.vmfb` files it generates
-2. Compile the `iree_samples_variables_and_state` CMake target (see
-   [here](https://google.github.io/iree/building-from-source/getting-started/)
-   for general instructions on building using CMake)
+
+2. Build the `iree_samples_variables_and_state` CMake target (see
+    [here](https://google.github.io/iree/building-from-source/getting-started/)
+    for general instructions on building using CMake)
 
     ```
     cmake -B ../iree-build/ -DCMAKE_BUILD_TYPE=RelWithDebInfo .
     cmake --build ../iree-build/ --target iree_samples_variables_and_state
     ```
+
 3. Run the sample binary:
 
    ```
diff --git a/iree/samples/variables_and_state/main.c b/iree/samples/variables_and_state/main.c
index 9cdef29..c2fe84d 100644
--- a/iree/samples/variables_and_state/main.c
+++ b/iree/samples/variables_and_state/main.c
@@ -169,8 +169,8 @@
   int value = -1;
   if (iree_status_is_ok(status)) {
     status = counter_get_value(session, &value);
+    fprintf(stdout, "Initial get_value()    : %d\n", value);
   }
-  fprintf(stdout, "Initial get_value()    : %d\n", value);
 
   // 2. set_value(101)
   if (iree_status_is_ok(status)) {
@@ -178,8 +178,8 @@
   }
   if (iree_status_is_ok(status)) {
     status = counter_get_value(session, &value);
+    fprintf(stdout, "After set_value(101)   : %d\n", value);
   }
-  fprintf(stdout, "After set_value(101)   : %d\n", value);
 
   // 3. add_to_value(20)
   if (iree_status_is_ok(status)) {
@@ -187,8 +187,8 @@
   }
   if (iree_status_is_ok(status)) {
     status = counter_get_value(session, &value);
+    fprintf(stdout, "After add_to_value(20) : %d\n", value);
   }
-  fprintf(stdout, "After add_to_value(20) : %d\n", value);
 
   // 4. add_to_value(-50)
   if (iree_status_is_ok(status)) {
@@ -196,8 +196,8 @@
   }
   if (iree_status_is_ok(status)) {
     status = counter_get_value(session, &value);
+    fprintf(stdout, "After add_to_value(-50): %d\n", value);
   }
-  fprintf(stdout, "After add_to_value(-50): %d\n", value);
 
   // 5. reset_value()
   if (iree_status_is_ok(status)) {
@@ -205,8 +205,8 @@
   }
   if (iree_status_is_ok(status)) {
     status = counter_get_value(session, &value);
+    fprintf(stdout, "After reset_value()    : %d\n", value);
   }
-  fprintf(stdout, "After reset_value()    : %d\n", value);
   //===-------------------------------------------------------------------===//
 
   //===-------------------------------------------------------------------===//