| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "provenance": [], |
| "collapsed_sections": [ |
| "UUXnh11hA75x", |
| "jbcW5jMLK8gK" |
| ] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "##### Copyright 2024 The IREE Authors" |
| ], |
| "metadata": { |
| "id": "UUXnh11hA75x" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "metadata": { |
| "cellView": "form", |
| "id": "FqsvmKpjBJO2" |
| }, |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# <img src=\"https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png\" height=\"20px\"> Hugging Face to <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch to <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", |
| "\n", |
| "This notebook uses [iree-turbine](https://github.com/iree-org/iree-turbine) to export a pretrained [Hugging Face Transformers](https://huggingface.co/docs/transformers/) model to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n", |
| "\n", |
| "* The pretrained [whisper-small](https://huggingface.co/openai/whisper-small)\n", |
| " model is showcased here as it is small enough to fit comfortably into a Colab\n", |
| " notebook. Other pretrained models can be found at\n", |
| " https://huggingface.co/docs/transformers/index." |
| ], |
| "metadata": { |
| "id": "38UDc27KBPD1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Setup" |
| ], |
| "metadata": { |
| "id": "jbcW5jMLK8gK" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ], |
| "metadata": { |
| "id": "KsPubQSvCbXd", |
| "cellView": "form" |
| }, |
| "execution_count": 2, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "oO1tirq2ggmO", |
| "outputId": "c3390361-9f40-4a49-b5c7-898a62614143" |
| }, |
| "execution_count": 3, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n", |
| "Collecting torch==2.3.0\n", |
| " Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n", |
| "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n", |
| "Installing collected packages: torch\n", |
| " Attempting uninstall: torch\n", |
| " Found existing installation: torch 2.2.1+cu121\n", |
| " Uninstalling torch-2.2.1+cu121:\n", |
| " Successfully uninstalled torch-2.2.1+cu121\n", |
| "Successfully installed torch-2.3.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "4iJFDHbsAzo4", |
| "outputId": "94721ae8-e222-4203-c356-888b42bc20b9" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting iree-turbine\n", |
| " Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n", |
| "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/150.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━\u001b[0m \u001b[32m143.4/150.4 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n", |
| "Collecting iree-compiler>=20240410.859 (from iree-turbine)\n", |
| " Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n", |
| " Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m26.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n", |
| "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n", |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n", |
| "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n", |
| "Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n", |
| "Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n" |
| ] |
| } |
| ], |
| "source": [ |
| "!python -m pip install iree-turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Report version information\n", |
| "# Record the exact installed versions in the cell output so a later run of\n", |
| "# this notebook can be reproduced against the same package set.\n", |
| "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "nkVLzRpcDnVL", |
| "outputId": "ee4e956f-ca7d-45ac-9913-672ad444d89f" |
| }, |
| "execution_count": 5, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed iree-turbine, Version: 2.3.0rc20240410\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n", |
| " LLVM version 19.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.3.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Load and run whisper-small" |
| ], |
| "metadata": { |
| "id": "I0OfTFxwOud1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Load the pretrained model from https://huggingface.co/openai/whisper-small.\n", |
| "\n", |
| "See also:\n", |
| "\n", |
| "* Model documentation: https://huggingface.co/docs/transformers/model_doc/whisper\n", |
| "* Test case in [SHARK-TestSuite](https://github.com/nod-ai/SHARK-TestSuite/): [`pytorch/models/whisper-small/model.py`](https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/pytorch/models/whisper-small/model.py)" |
| ], |
| "metadata": { |
| "id": "94Ji4URLT_xM" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "from transformers import AutoModelForCausalLM, AutoTokenizer\n", |
| "\n", |
| "# https://huggingface.co/docs/transformers/model_doc/auto\n", |
| "# AutoModelForCausalLM -> WhisperForCausalLM\n", |
| "# AutoTokenizer -> WhisperTokenizerFast\n", |
| "\n", |
| "modelname = \"openai/whisper-small\"\n", |
| "tokenizer = AutoTokenizer.from_pretrained(modelname)\n", |
| "\n", |
| "# Some of the options here affect how the model is exported. See the test cases\n", |
| "# at https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/models\n", |
| "# for other options that may be useful to set.\n", |
| "model = AutoModelForCausalLM.from_pretrained(\n", |
| " modelname,\n", |
| " output_attentions=False,\n", |
| " output_hidden_states=False,\n", |
| " attn_implementation=\"eager\",\n", |
| " torchscript=True,\n", |
| ")\n", |
| "\n", |
| "# This is just a simple demo to get some data flowing through the model.\n", |
| "# Depending on this model and what input it expects (text, image, audio, etc.)\n", |
| "# this might instead use a specific Processor class. For Whisper,\n", |
| "# WhisperProcessor runs audio input pre-processing and output post-processing.\n", |
| "example_prompt = \"Hello world!\"\n", |
| "example_encoding = tokenizer(example_prompt, return_tensors=\"pt\")\n", |
| "# .cpu() ensures the example tensor is on the CPU before export.\n", |
| "example_input = example_encoding[\"input_ids\"].cpu()\n", |
| "# Pack the example input into a tuple: torch.export.export() below takes the\n", |
| "# example args as a tuple, while aot.export() wants them unpacked.\n", |
| "example_args = (example_input,)" |
| ], |
| "metadata": { |
| "id": "HLbfUuoBPHgH" |
| }, |
| "execution_count": null, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Test exporting using [`torch.export()`](https://pytorch.org/docs/stable/export.html#torch.export.export). If `torch.export` works, `aot.export()` from Turbine should work as well." |
| ], |
| "metadata": { |
| "id": "vQlF_ua3UNvo" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "import torch\n", |
| "\n", |
| "# Smoke test the export path: per the note above, if torch.export.export()\n", |
| "# succeeds on this model, Turbine's aot.export() should work as well.\n", |
| "exported_program = torch.export.export(model, example_args)" |
| ], |
| "metadata": { |
| "id": "-4LykgffY9uH" |
| }, |
| "execution_count": null, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Export using the simple [`aot.export()`](https://iree.dev/guides/ml-frameworks/pytorch/#simple-api) API from Turbine." |
| ], |
| "metadata": { |
| "id": "wXZI4GliUazA" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "import shark_turbine.aot as aot\n", |
| "\n", |
| "# Export the model for IREE ahead-of-time compilation.\n", |
| "# Note: aot.export() wants the example args to be unpacked, unlike\n", |
| "# torch.export.export(), which takes them as a tuple.\n", |
| "whisper_compiled_module = aot.export(model, *example_args)" |
| ], |
| "metadata": { |
| "id": "R7-rN_z2Y_5z" |
| }, |
| "execution_count": 8, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Compile using Turbine/IREE then run the program." |
| ], |
| "metadata": { |
| "id": "YK3hjpTpUdhc" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "# save_to=None keeps the compiled artifact in memory (no file is written\n", |
| "# here); it is handed to the runtime below via map_memory().\n", |
| "binary = whisper_compiled_module.compile(save_to=None)\n", |
| "\n", |
| "import iree.runtime as ireert\n", |
| "\n", |
| "# \"local-task\" runs the program on the local CPU using IREE's task scheduler.\n", |
| "config = ireert.Config(\"local-task\")\n", |
| "vm_module = ireert.load_vm_module(\n", |
| " ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n", |
| " config,\n", |
| ")\n", |
| "\n", |
| "# Invoke the exported entry point (exposed as `main` on the loaded module)\n", |
| "# and copy the result back to host memory for printing.\n", |
| "iree_outputs = vm_module.main(example_args[0])\n", |
| "print(iree_outputs[0].to_host())" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "FctBxxEXZBan", |
| "outputId": "12c042d0-f740-4de1-e246-36ed9b4b357b" |
| }, |
| "execution_count": 9, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "[[[ 5.8126216 3.9667568 4.5749426 ... 2.7658575 2.6436937\n", |
| " 1.5479789]\n", |
| " [ 7.563438 6.0299625 5.1000338 ... 6.4327035 6.101554\n", |
| " 6.434801 ]\n", |
| " [ 0.9380368 -4.4696164 -4.012759 ... -6.24863 -7.791795\n", |
| " -6.84537 ]\n", |
| " [ 0.7450911 -3.7631674 -7.4870267 ... -6.7348223 -6.966235\n", |
| " -10.022385 ]\n", |
| " [ -0.9628638 -3.5101964 -6.0158615 ... -7.116393 -6.7086525\n", |
| " -10.225711 ]\n", |
| " [ 3.3470955 2.4927258 -3.3042645 ... -1.5709444 -1.8455245\n", |
| " -2.9991858]]]\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Run the program using native PyTorch to compare outputs." |
| ], |
| "metadata": { |
| "id": "5WuFpyFfUjh8" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "# Run the same input through eager PyTorch; the printed logits should match\n", |
| "# the IREE output above to within small floating-point differences.\n", |
| "torch_outputs = model(example_args[0])\n", |
| "print(torch_outputs[0].detach().numpy())" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "IxPYkcPycG4r", |
| "outputId": "d1a3b111-4a6f-4e2a-f80a-645c192e57e3" |
| }, |
| "execution_count": 10, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "[[[ 5.8126183 3.9667587 4.5749483 ... 2.7658575 2.643694\n", |
| " 1.5479784 ]\n", |
| " [ 7.563436 6.029952 5.100036 ... 6.4327083 6.101557\n", |
| " 6.4348083 ]\n", |
| " [ 0.93802685 -4.469646 -4.012787 ... -6.2486415 -7.7918167\n", |
| " -6.8453975 ]\n", |
| " [ 0.74507916 -3.763197 -7.487034 ... -6.734877 -6.966276\n", |
| " -10.022424 ]\n", |
| " [ -0.96288276 -3.510221 -6.0158725 ... -7.1164136 -6.708687\n", |
| " -10.225745 ]\n", |
| " [ 3.3470666 2.492654 -3.304323 ... -1.5709934 -1.8455791\n", |
| " -2.9992423 ]]]\n" |
| ] |
| } |
| ] |
| } |
| ] |
| } |