|  | { | 
|  | "nbformat": 4, | 
|  | "nbformat_minor": 0, | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "provenance": [], | 
|  | "collapsed_sections": [ | 
|  | "UUXnh11hA75x", | 
|  | "jbcW5jMLK8gK" | 
|  | ] | 
|  | }, | 
|  | "kernelspec": { | 
|  | "name": "python3", | 
|  | "display_name": "Python 3" | 
|  | }, | 
|  | "language_info": { | 
|  | "name": "python" | 
|  | } | 
|  | }, | 
|  | "cells": [ | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "##### Copyright 2024 The IREE Authors" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "UUXnh11hA75x" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", | 
|  | "# See https://llvm.org/LICENSE.txt for license information.\n", | 
|  | "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" | 
|  | ], | 
|  | "metadata": { | 
|  | "cellView": "form", | 
|  | "id": "FqsvmKpjBJO2" | 
|  | }, | 
|  | "execution_count": 1, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "# <img src=\"https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png\" height=\"20px\"> Hugging Face to <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch to <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", | 
|  | "\n", | 
|  | "This notebook uses [iree-turbine](https://github.com/iree-org/iree-turbine) to export a pretrained [Hugging Face Transformers](https://huggingface.co/docs/transformers/) model to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n", | 
|  | "\n", | 
|  | "* The pretrained [whisper-small](https://huggingface.co/openai/whisper-small)\n", | 
|  | "  model is showcased here as it is small enough to fit comfortably into a Colab\n", | 
|  | "  notebook. Other pretrained models can be found at\n", | 
|  | "  https://huggingface.co/docs/transformers/index." | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "38UDc27KBPD1" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "## Setup" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "jbcW5jMLK8gK" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "%%capture\n", | 
|  | "#@title Uninstall existing packages\n", | 
|  | "#   This avoids some warnings when installing specific PyTorch packages below.\n", | 
|  | "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "KsPubQSvCbXd", | 
|  | "cellView": "form" | 
|  | }, | 
|  | "execution_count": 2, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "oO1tirq2ggmO", | 
|  | "outputId": "c3390361-9f40-4a49-b5c7-898a62614143" | 
|  | }, | 
|  | "execution_count": 3, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n", | 
|  | "Collecting torch==2.3.0\n", | 
|  | "  Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n", | 
|  | "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n", | 
|  | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n", | 
|  | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n", | 
|  | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n", | 
|  | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n", | 
|  | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n", | 
|  | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n", | 
|  | "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n", | 
|  | "Installing collected packages: torch\n", | 
|  | "  Attempting uninstall: torch\n", | 
|  | "    Found existing installation: torch 2.2.1+cu121\n", | 
|  | "    Uninstalling torch-2.2.1+cu121:\n", | 
|  | "      Successfully uninstalled torch-2.2.1+cu121\n", | 
|  | "Successfully installed torch-2.3.0+cpu\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 4, | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "4iJFDHbsAzo4", | 
|  | "outputId": "94721ae8-e222-4203-c356-888b42bc20b9" | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Collecting iree-turbine\n", | 
|  | "  Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n", | 
|  | "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/150.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━\u001b[0m \u001b[32m143.4/150.4 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n", | 
|  | "Collecting iree-compiler>=20240410.859 (from iree-turbine)\n", | 
|  | "  Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n", | 
|  | "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n", | 
|  | "  Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n", | 
|  | "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m26.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n", | 
|  | "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n", | 
|  | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n", | 
|  | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n", | 
|  | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n", | 
|  | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n", | 
|  | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n", | 
|  | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n", | 
|  | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n", | 
|  | "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n", | 
|  | "Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n", | 
|  | "Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "!python -m pip install iree-turbine" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title Report version information\n", | 
|  | "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", | 
|  | "\n", | 
|  | "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", | 
|  | "!iree-compile --version\n", | 
|  | "\n", | 
|  | "import torch\n", | 
|  | "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "nkVLzRpcDnVL", | 
|  | "outputId": "ee4e956f-ca7d-45ac-9913-672ad444d89f" | 
|  | }, | 
|  | "execution_count": 5, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Installed iree-turbine, Version: 2.3.0rc20240410\n", | 
|  | "\n", | 
|  | "Installed IREE, compiler version information:\n", | 
|  | "IREE (https://iree.dev):\n", | 
|  | "  IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n", | 
|  | "  LLVM version 19.0.0git\n", | 
|  | "  Optimized build\n", | 
|  | "\n", | 
|  | "Installed PyTorch, version: 2.3.0+cpu\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "## Load and run whisper-small" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "I0OfTFxwOud1" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "Load the pretrained model from https://huggingface.co/openai/whisper-small.\n", | 
|  | "\n", | 
|  | "See also:\n", | 
|  | "\n", | 
|  | "* Model card: https://huggingface.co/docs/transformers/model_doc/whisper\n", | 
|  | "* Test case in [SHARK-TestSuite](https://github.com/nod-ai/SHARK-TestSuite/): [`pytorch/models/whisper-small/model.py`](https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/pytorch/models/whisper-small/model.py)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "94Ji4URLT_xM" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "from transformers import AutoModelForCausalLM, AutoTokenizer\n", | 
|  | "\n", | 
|  | "# https://huggingface.co/docs/transformers/model_doc/auto\n", | 
|  | "# AutoModelForCausalLM -> WhisperForCausalLM\n", | 
|  | "# AutoTokenizer        -> WhisperTokenizerFast\n", | 
|  | "\n", | 
|  | "modelname = \"openai/whisper-small\"\n", | 
|  | "tokenizer = AutoTokenizer.from_pretrained(modelname)\n", | 
|  | "\n", | 
|  | "# Some of the options here affect how the model is exported. See the test cases\n", | 
|  | "# at https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/models\n", | 
|  | "# for other options that may be useful to set.\n", | 
|  | "model = AutoModelForCausalLM.from_pretrained(\n", | 
|  | "    modelname,\n", | 
|  | "    output_attentions=False,\n", | 
|  | "    output_hidden_states=False,\n", | 
|  | "    attn_implementation=\"eager\",\n", | 
|  | "    torchscript=True,\n", | 
|  | ")\n", | 
|  | "\n", | 
|  | "# This is just a simple demo to get some data flowing through the model.\n", | 
|  | "# Depending on this model and what input it expects (text, image, audio, etc.)\n", | 
|  | "# this might instead use a specific Processor class. For Whisper,\n", | 
|  | "# WhisperProcessor runs audio input pre-processing and output post-processing.\n", | 
|  | "example_prompt = \"Hello world!\"\n", | 
|  | "example_encoding = tokenizer(example_prompt, return_tensors=\"pt\")\n", | 
|  | "example_input = example_encoding[\"input_ids\"].cpu()\n", | 
|  | "example_args = (example_input,)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "HLbfUuoBPHgH" | 
|  | }, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
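|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "The cell above drives the decoder directly with token IDs from the tokenizer,\n", | 
|  | "which is enough to exercise the export path. As the comment notes, real audio\n", | 
|  | "input would instead go through `WhisperProcessor`. The optional sketch below is\n", | 
|  | "a rough illustration of that pre-processing, using a placeholder silent waveform\n", | 
|  | "rather than real audio; it is not needed for the export steps that follow." | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Optional sketch: WhisperProcessor converts a raw waveform into the log-mel\n", | 
|  | "# features that the Whisper encoder expects. The silent 1-second waveform here\n", | 
|  | "# is only a placeholder so the cell stays self-contained.\n", | 
|  | "import numpy as np\n", | 
|  | "from transformers import WhisperProcessor\n", | 
|  | "\n", | 
|  | "processor = WhisperProcessor.from_pretrained(modelname)\n", | 
|  | "placeholder_audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz\n", | 
|  | "features = processor(placeholder_audio, sampling_rate=16000, return_tensors=\"pt\")\n", | 
|  | "print(features.input_features.shape)" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 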
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "Test exporting using [`torch.export()`](https://pytorch.org/docs/stable/export.html#torch.export.export). If `torch.export` works, `aot.export()` from Turbine should work as well." | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "vQlF_ua3UNvo" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "import torch\n", | 
|  | "exported_program = torch.export.export(model, example_args)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "-4LykgffY9uH" | 
|  | }, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 
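|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "If the call above succeeds, the resulting `ExportedProgram` can optionally be\n", | 
|  | "inspected before handing the model to Turbine, for example to see which ops were\n", | 
|  | "captured in the FX graph. A minimal sketch (the short node summary below is just\n", | 
|  | "one way to look at it):" | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Optional: summarize the FX graph captured by torch.export. Printing the full\n", | 
|  | "# program for a model of this size is very verbose, so only count the nodes and\n", | 
|  | "# show the first few op targets.\n", | 
|  | "graph_nodes = list(exported_program.graph.nodes)\n", | 
|  | "print(\"FX graph nodes:\", len(graph_nodes))\n", | 
|  | "for node in graph_nodes[:5]:\n", | 
|  | "    print(node.op, node.target)" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 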
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "Export using the simple [`aot.export()`](https://iree.dev/guides/ml-frameworks/pytorch/#simple-api) API from Turbine." | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "wXZI4GliUazA" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "import shark_turbine.aot as aot\n", | 
|  | "# Note: aot.export() wants the example args to be unpacked.\n", | 
|  | "whisper_compiled_module = aot.export(model, *example_args)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "R7-rN_z2Y_5z" | 
|  | }, | 
|  | "execution_count": 8, | 
|  | "outputs": [] | 
|  | }, | 
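|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "Before compiling, the imported MLIR can optionally be saved for inspection or for\n", | 
|  | "use with standalone tools such as `iree-compile`. A minimal sketch, assuming the\n", | 
|  | "`save_mlir()` helper on the object returned by `aot.export()` is available in\n", | 
|  | "this iree-turbine release (the file path is arbitrary):" | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Optional: write the imported MLIR to a file for inspection.\n", | 
|  | "# save_mlir() is assumed to be available on the export output in this release.\n", | 
|  | "whisper_compiled_module.save_mlir(\"/tmp/whisper_small.mlir\")\n", | 
|  | "!ls -lh /tmp/whisper_small.mlir" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 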
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "Compile using Turbine/IREE then run the program." | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "YK3hjpTpUdhc" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "binary = whisper_compiled_module.compile(save_to=None)\n", | 
|  | "\n", | 
|  | "import iree.runtime as ireert\n", | 
|  | "config = ireert.Config(\"local-task\")\n", | 
|  | "vm_module = ireert.load_vm_module(\n", | 
|  | "    ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n", | 
|  | "    config,\n", | 
|  | ")\n", | 
|  | "\n", | 
|  | "iree_outputs = vm_module.main(example_args[0])\n", | 
|  | "print(iree_outputs[0].to_host())" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "FctBxxEXZBan", | 
|  | "outputId": "12c042d0-f740-4de1-e246-36ed9b4b357b" | 
|  | }, | 
|  | "execution_count": 9, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "[[[  5.8126216   3.9667568   4.5749426 ...   2.7658575   2.6436937\n", | 
|  | "     1.5479789]\n", | 
|  | "  [  7.563438    6.0299625   5.1000338 ...   6.4327035   6.101554\n", | 
|  | "     6.434801 ]\n", | 
|  | "  [  0.9380368  -4.4696164  -4.012759  ...  -6.24863    -7.791795\n", | 
|  | "    -6.84537  ]\n", | 
|  | "  [  0.7450911  -3.7631674  -7.4870267 ...  -6.7348223  -6.966235\n", | 
|  | "   -10.022385 ]\n", | 
|  | "  [ -0.9628638  -3.5101964  -6.0158615 ...  -7.116393   -6.7086525\n", | 
|  | "   -10.225711 ]\n", | 
|  | "  [  3.3470955   2.4927258  -3.3042645 ...  -1.5709444  -1.8455245\n", | 
|  | "    -2.9991858]]]\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
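|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "The cell above keeps the compiled artifact in memory (`save_to=None`). For\n", | 
|  | "deployment it is usually written to a `.vmfb` file instead and loaded from disk.\n", | 
|  | "A minimal sketch, assuming `compile()` can be invoked again on the same export\n", | 
|  | "output (the file path is arbitrary):" | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Optional: write the compiled program to a .vmfb file, then load it back from\n", | 
|  | "# disk with the IREE runtime and run the same example input.\n", | 
|  | "whisper_compiled_module.compile(save_to=\"/tmp/whisper_small.vmfb\")\n", | 
|  | "\n", | 
|  | "vm_module_from_file = ireert.load_vm_module(\n", | 
|  | "    ireert.VmModule.mmap(config.vm_instance, \"/tmp/whisper_small.vmfb\"),\n", | 
|  | "    config,\n", | 
|  | ")\n", | 
|  | "print(vm_module_from_file.main(example_args[0])[0].to_host().shape)" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 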
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "Run the program using native PyTorch to compare outputs." | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "5WuFpyFfUjh8" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "torch_outputs = model(example_args[0])\n", | 
|  | "print(torch_outputs[0].detach().numpy())" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "IxPYkcPycG4r", | 
|  | "outputId": "d1a3b111-4a6f-4e2a-f80a-645c192e57e3" | 
|  | }, | 
|  | "execution_count": 10, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "[[[  5.8126183    3.9667587    4.5749483  ...   2.7658575    2.643694\n", | 
|  | "     1.5479784 ]\n", | 
|  | "  [  7.563436     6.029952     5.100036   ...   6.4327083    6.101557\n", | 
|  | "     6.4348083 ]\n", | 
|  | "  [  0.93802685  -4.469646    -4.012787   ...  -6.2486415   -7.7918167\n", | 
|  | "    -6.8453975 ]\n", | 
|  | "  [  0.74507916  -3.763197    -7.487034   ...  -6.734877    -6.966276\n", | 
|  | "   -10.022424  ]\n", | 
|  | "  [ -0.96288276  -3.510221    -6.0158725  ...  -7.1164136   -6.708687\n", | 
|  | "   -10.225745  ]\n", | 
|  | "  [  3.3470666    2.492654    -3.304323   ...  -1.5709934   -1.8455791\n", | 
|  | "    -2.9992423 ]]]\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
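|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "The two print-outs above should agree closely. A small sketch of a numerical\n", | 
|  | "check (the tolerances here are illustrative, not a formal accuracy criterion):" | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Compare the IREE and native PyTorch results numerically. Small differences are\n", | 
|  | "# expected from differing op implementations and floating point reduction orders.\n", | 
|  | "import numpy as np\n", | 
|  | "\n", | 
|  | "iree_logits = iree_outputs[0].to_host()\n", | 
|  | "torch_logits = torch_outputs[0].detach().numpy()\n", | 
|  | "print(\"Max absolute difference:\", np.max(np.abs(iree_logits - torch_logits)))\n", | 
|  | "print(\"Allclose:\", np.allclose(iree_logits, torch_logits, rtol=1e-4, atol=1e-4))" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | } | 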
|  | ] | 
|  | } |