Add pytorch_jit sample Colab notebook using SHARK-Turbine. (#15146)
Progress on https://github.com/openxla/iree/issues/15117
This notebook shows how to use
[SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for eager
execution within a PyTorch session using IREE and
[torch-mlir](https://github.com/llvm/torch-mlir) under the covers. I'm
starting simple to get the concepts across, with minimal API usage and a
tiny `nn.Module` sourced from
https://pytorch.org/docs/stable/notes/modules.html. My expectation is
that this notebook will evolve alongside other notebooks (e.g.
`pytorch_aot.ipynb`), documentation
(https://github.com/openxla/iree/issues/15114), and the SHARK-Turbine
project itself.
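For reference, here is the core of what the notebook walks through, as a minimal sketch distilled from the cells below (`turbine_cpu` is the TorchDynamo backend that becomes available once `shark-turbine` is installed; no explicit import of the package is needed):

```python
import torch

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias

# torch.compile returns an OptimizedModule that is used like a regular
# nn.Module; compilation through IREE happens on the first invocation.
opt_module = torch.compile(LinearModule(4, 3), backend="turbine_cpu")
print(opt_module(torch.randn(4)))
```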
Preview URL for review:
https://colab.research.google.com/github/openxla/iree/blob/scotttodd-pytorch-samples-1/samples/colab/pytorch_jit.ipynb
skip-ci: no-op
diff --git a/samples/colab/README.md b/samples/colab/README.md
index 33a8127..4c98616 100644
--- a/samples/colab/README.md
+++ b/samples/colab/README.md
@@ -1,49 +1,19 @@
# Google Colaboratory (Colab) Notebooks
+These [Colab](https://colab.google/) notebooks contain interactive sample
+applications using IREE's Python bindings and ML framework integrations.
+
## Notebooks
-### [edge_detection\.ipynb](edge_detection.ipynb)
-
-Constructs a TF module for performing image edge detection and runs it using
-IREE
-
-[](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/edge_detection.ipynb)
-
-### [low_level_invoke_function\.ipynb](low_level_invoke_function.ipynb)
-
-Shows off some concepts of the low level IREE python bindings
-
-[](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/low_level_invoke_function.ipynb)
-
-### [mnist_training\.ipynb](mnist_training.ipynb)
-
-Compile, train and execute a TensorFlow Keras neural network with IREE
-
-[](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/mnist_training.ipynb)
-
-### [resnet\.ipynb](resnet.ipynb)
-
-Loads a pretrained
-[ResNet50](https://www.tensorflow.org/api_docs/python/tf/keras/applications/ResNet50)
-model and runs it using IREE
-
-[](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/resnet.ipynb)
-
-### [tensorflow_hub_import\.ipynb](tensorflow_hub_import.ipynb)
-
-Downloads a pretrained
-[MobileNet V2](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification)
-model, pre-processes it for import, then compiles it using IREE
-
-[](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/tensorflow_hub_import.ipynb)
-
-### [tflite_text_classification\.ipynb](tflite_text_classification.ipynb)
-
-Downloads a pretrained
-[TFLite text classification](https://www.tensorflow.org/lite/examples/text_classification/overview)
-model, and runs it using TFLite and IREE
-
-[](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/tflite_text_classification.ipynb)
+Framework | Notebook file | Description | Link
+-------- | ------------- | ----------- | ----
+Generic | [low_level_invoke_function\.ipynb](low_level_invoke_function.ipynb) | Demonstrates concepts of the low-level IREE Python bindings | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/low_level_invoke_function.ipynb)
+PyTorch | [pytorch_jit\.ipynb](pytorch_jit.ipynb) | Uses [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for eager execution in a PyTorch session | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/pytorch_jit.ipynb)
+TensorFlow | [edge_detection\.ipynb](edge_detection.ipynb) | Performs image edge detection using TF and IREE | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/edge_detection.ipynb)
+TensorFlow | [mnist_training\.ipynb](mnist_training.ipynb) | Compiles, trains, and executes a neural network with IREE | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/mnist_training.ipynb)
+TensorFlow | [resnet\.ipynb](resnet.ipynb) | Loads a pretrained [ResNet50](https://www.tensorflow.org/api_docs/python/tf/keras/applications/ResNet50) model and runs it using IREE | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/resnet.ipynb)
+TensorFlow | [tensorflow_hub_import\.ipynb](tensorflow_hub_import.ipynb) | Runs a pretrained [MobileNet V2](https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification) model using IREE | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/tensorflow_hub_import.ipynb)
+TFLite | [tflite_text_classification\.ipynb](tflite_text_classification.ipynb) | Runs a pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model using IREE | [](https://colab.research.google.com/github/openxla/iree/blob/main/samples/colab/tflite_text_classification.ipynb)
## Working with GitHub
@@ -59,3 +29,7 @@
* Download the modified notebook using `File > Download .ipynb`
* Move the downloaded notebook file into a clone of this repository and submit
a pull request
+
+## Testing
+
+These notebooks are tested continuously by the `samples.yml` CI job.
diff --git a/samples/colab/pytorch_jit.ipynb b/samples/colab/pytorch_jit.ipynb
new file mode 100644
index 0000000..6e559a6
--- /dev/null
+++ b/samples/colab/pytorch_jit.ipynb
@@ -0,0 +1,386 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "collapsed_sections": [
+ "UUXnh11hA75x"
+ ]
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2023 The IREE Authors"
+ ],
+ "metadata": {
+ "id": "UUXnh11hA75x"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
+ "# See https://llvm.org/LICENSE.txt for license information.\n",
+ "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "FqsvmKpjBJO2"
+ },
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Just-in-time (JIT) workflows using <img src=\"https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg\" height=\"20px\"> IREE\n",
+ "\n",
+ "This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for eager execution within a PyTorch session using [IREE](https://github.com/openxla/iree) and [torch-mlir](https://github.com/llvm/torch-mlir) under the covers."
+ ],
+ "metadata": {
+ "id": "38UDc27KBPD1"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Setup"
+ ],
+ "metadata": {
+ "id": "jbcW5jMLK8gK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%capture\n",
+ "#@title Uninstall existing packages\n",
+ "# This avoids some warnings when installing specific PyTorch packages below.\n",
+ "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
+ ],
+ "metadata": {
+ "id": "KsPubQSvCbXd"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 300
+ },
+ "id": "4iJFDHbsAzo4",
+ "outputId": "6ed6f706-f701-47a6-f8b9-2d0141579f8d"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<IPython.core.display.Javascript object>"
+ ],
+ "application/javascript": [
+ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting shark-turbine\n",
+ " Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.2/60.2 kB\u001b[0m \u001b[31m786.0 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (1.23.5)\n",
+ "Collecting iree-compiler>=20231004.665 (from shark-turbine)\n",
+ " Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n",
+ " Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m60.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting torch>=2.1.0 (from shark-turbine)\n",
+ " Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m670.2/670.2 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n",
+ "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m56.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m54.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m72.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m11.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.8/209.8 MB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting triton==2.1.0 (from torch>=2.1.0->shark-turbine)\n",
+ " Downloading triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.2/89.2 MB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.1.0->shark-turbine)\n",
+ " Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.2/20.2 MB\u001b[0m \u001b[31m74.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n",
+ "Building wheels for collected packages: shark-turbine\n",
+ " Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=73e3b15d1dfbe2c9d718b6d9f08ba3ec8dc149061c13935ed97214fb6aa77ac7\n",
+ " Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n",
+ "Successfully built shark-turbine\n",
+ "Installing collected packages: triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, iree-runtime, iree-compiler, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, shark-turbine\n",
+ " Attempting uninstall: triton\n",
+ " Found existing installation: triton 2.0.0\n",
+ " Uninstalling triton-2.0.0:\n",
+ " Successfully uninstalled triton-2.0.0\n",
+ " Attempting uninstall: torch\n",
+ " Found existing installation: torch 2.0.1+cu118\n",
+ " Uninstalling torch-2.0.1+cu118:\n",
+ " Successfully uninstalled torch-2.0.1+cu118\n",
+ "Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 shark-turbine-0.9.1.dev3 torch-2.1.0 triton-2.1.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "#@title Install SHARK-Turbine\n",
+ "\n",
+ "# Limit cell height.\n",
+ "from IPython.display import Javascript\n",
+ "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n",
+ "\n",
+ "!python -m pip install shark-turbine"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Report version information\n",
+ "!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n",
+ "\n",
+ "!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
+ "!iree-compile --version\n",
+ "\n",
+ "import torch\n",
+ "print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "nkVLzRpcDnVL",
+ "outputId": "1fa62bc3-6cba-4d7b-9ccf-d8ad024df53b"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Installed SHARK-Turbine, Version: 0.9.1.dev3\n",
+ "\n",
+ "Installed IREE, compiler version information:\n",
+ "IREE (https://openxla.github.io/iree):\n",
+ " IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n",
+ " LLVM version 18.0.0git\n",
+ " Optimized build\n",
+ "\n",
+ "Installed PyTorch, version: 2.1.0+cu121\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Sample JIT workflow\n",
+ "\n",
+ "1. Define a program using `torch.nn.Module`\n",
+ "2. Run `torch.compile(module, backend=\"turbine_cpu\")`\n",
+ "3. Use the resulting `OptimizedModule` as you would a regular `nn.Module`\n",
+ "\n",
+ "Useful documentation:\n",
+ "\n",
+ "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n",
+ "* [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) as an interface to TorchDynamo and optimizing using backend compilers like Turbine"
+ ],
+ "metadata": {
+ "id": "1Mi3YR75LBxl"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "torch.manual_seed(0)\n",
+ "\n",
+ "class LinearModule(torch.nn.Module):\n",
+ " def __init__(self, in_features, out_features):\n",
+ " super().__init__()\n",
+ " self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n",
+ " self.bias = torch.nn.Parameter(torch.randn(out_features))\n",
+ "\n",
+ " def forward(self, input):\n",
+ " return (input @ self.weight) + self.bias\n",
+ "\n",
+ "linear_module = LinearModule(4, 3)"
+ ],
+ "metadata": {
+ "id": "oPdjrmPZMNz6"
+ },
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "opt_linear_module = torch.compile(linear_module, backend=\"turbine_cpu\")\n",
+ "print(\"Compiled module using Turbine. New module type is\", type(opt_linear_module))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "eK2fWVfiSQ8f",
+ "outputId": "7696a60a-46d1-4d4b-a38b-901aa36530b5"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Compiled module using Turbine. New module type is <class 'torch._dynamo.eval_frame.OptimizedModule'>\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "args = torch.randn(4)\n",
+ "turbine_output = opt_linear_module(args)\n",
+ "\n",
+ "print(\"Weight:\", linear_module.weight)\n",
+ "print(\"Bias:\", linear_module.bias)\n",
+ "print(\"Args:\", args)\n",
+ "print(\"Output:\", turbine_output)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0AdkXY8VNL2-",
+ "outputId": "c965bf26-5d23-4776-8cda-80ce8a307d28"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "module {\n",
+ " func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[4],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {\n",
+ " %int0 = torch.constant.int 0\n",
+ " %0 = torch.aten.unsqueeze %arg2, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n",
+ " %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n",
+ " %int0_0 = torch.constant.int 0\n",
+ " %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
+ " %int1 = torch.constant.int 1\n",
+ " %3 = torch.aten.add.Tensor %2, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
+ " return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "#map = affine_map<(d0) -> (d0)>\n",
+ "module {\n",
+ " func.func @main(%arg0: tensor<4x3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<4xf32>) -> (tensor<3xf32>, tensor<1x4xf32>) {\n",
+ " %cst = arith.constant 0.000000e+00 : f32\n",
+ " %expanded = tensor.expand_shape %arg2 [[0, 1]] : tensor<4xf32> into tensor<1x4xf32>\n",
+ " %0 = tensor.empty() : tensor<1x3xf32>\n",
+ " %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x3xf32>) -> tensor<1x3xf32>\n",
+ " %2 = linalg.matmul ins(%expanded, %arg0 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%1 : tensor<1x3xf32>) -> tensor<1x3xf32>\n",
+ " %collapsed = tensor.collapse_shape %2 [[0, 1]] : tensor<1x3xf32> into tensor<3xf32>\n",
+ " %3 = tensor.empty() : tensor<3xf32>\n",
+ " %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = [\"parallel\"]} ins(%collapsed, %arg1 : tensor<3xf32>, tensor<3xf32>) outs(%3 : tensor<3xf32>) {\n",
+ " ^bb0(%in: f32, %in_0: f32, %out: f32):\n",
+ " %5 = arith.addf %in, %in_0 : f32\n",
+ " linalg.yield %5 : f32\n",
+ " } -> tensor<3xf32>\n",
+ " return %4, %expanded : tensor<3xf32>, tensor<1x4xf32>\n",
+ " }\n",
+ "}\n",
+ "\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Weight: Parameter containing:\n",
+ "tensor([[ 1.5410, -0.2934, -2.1788],\n",
+ " [ 0.5684, -1.0845, -1.3986],\n",
+ " [ 0.4033, 0.8380, -0.7193],\n",
+ " [-0.4033, -0.5966, 0.1820]], requires_grad=True)\n",
+ "Bias: Parameter containing:\n",
+ "tensor([-0.8567, 1.1006, -1.0712], requires_grad=True)\n",
+ "Args: tensor([ 0.1227, -0.5663, 0.3731, -0.8920])\n",
+ "Output: tensor([-0.4792, 2.5237, -0.9772], grad_fn=<CompiledFunctionBackward>)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/_functorch/aot_autograd.py:1510: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/samples/colab/test_notebooks.py b/samples/colab/test_notebooks.py
index 984002c..17ca31e 100755
--- a/samples/colab/test_notebooks.py
+++ b/samples/colab/test_notebooks.py
@@ -15,6 +15,8 @@
# tflite_runtime requires some deps ("version `GLIBC_2.29' not found") that
# samples.Dockerfile does not currently include.
"tflite_text_classification.ipynb",
+    # Requires Python 3.10+, which our Docker image does not currently provide.
+ "pytorch_jit.ipynb",
]
NOTEBOOKS_EXPECTED_TO_FAIL = [