| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "provenance": [], |
| "collapsed_sections": [ |
| "UUXnh11hA75x", |
| "jbcW5jMLK8gK" |
| ] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| }, |
| "language_info": { |
| "name": "python" |
| }, |
| "widgets": { |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "##### Copyright 2024 The IREE Authors" |
| ], |
| "metadata": { |
| "id": "UUXnh11hA75x" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "metadata": { |
| "cellView": "form", |
| "id": "FqsvmKpjBJO2" |
| }, |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# <img src=\"https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png\" height=\"20px\"> Hugging Face to <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch to <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", |
| "\n", |
| "This notebook uses [iree-turbine](https://github.com/iree-org/iree-turbine) to export a pretrained [Hugging Face Transformers](https://huggingface.co/docs/transformers/) model to [IREE](https://github.com/iree-org/iree). Turbine relies on [torch-mlir](https://github.com/llvm/torch-mlir) under the hood.\n", |
| "\n", |
| "* The pretrained [whisper-small](https://huggingface.co/openai/whisper-small)\n", |
| "  model is used here because it is small enough to run comfortably in a Colab\n", |
| "  notebook. Other pretrained models are listed at\n", |
| "  https://huggingface.co/docs/transformers/index." |
| ], |
| "metadata": { |
| "id": "38UDc27KBPD1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Setup" |
| ], |
| "metadata": { |
| "id": "jbcW5jMLK8gK" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ], |
| "metadata": { |
| "id": "KsPubQSvCbXd" |
| }, |
| "execution_count": 2, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "!python -m pip install --pre --index-url https://download.pytorch.org/whl/cpu --upgrade torch==2.5.0" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "oO1tirq2ggmO", |
| "outputId": "1c10e964-1bd3-41e7-d7ce-70cf574d817b" |
| }, |
| "execution_count": 3, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Looking in indexes: https://download.pytorch.org/whl/cpu\n", |
| "Collecting torch==2.5.0\n", |
| " Downloading https://download.pytorch.org/whl/cpu/torch-2.5.0%2Bcpu-cp310-cp310-linux_x86_64.whl (174.7 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.7/174.7 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.16.1)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (4.12.2)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.4.2)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.1.5)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (2024.10.0)\n", |
| "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (1.13.1)\n", |
| "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch==2.5.0) (1.3.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.5.0) (3.0.2)\n", |
| "Installing collected packages: torch\n", |
| " Attempting uninstall: torch\n", |
| " Found existing installation: torch 2.5.1+cu121\n", |
| " Uninstalling torch-2.5.1+cu121:\n", |
| " Successfully uninstalled torch-2.5.1+cu121\n", |
| "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", |
| "timm 1.0.12 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n", |
| "\u001b[0mSuccessfully installed torch-2.5.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "4iJFDHbsAzo4", |
| "outputId": "c95e32a5-70ab-43e7-8c8c-300d37cccfd3" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting iree-turbine\n", |
| " Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n", |
| "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n", |
| "Collecting iree-base-compiler (from iree-turbine)\n", |
| " Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n", |
| "Collecting iree-base-runtime (from iree-turbine)\n", |
| " Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n", |
| "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n", |
| "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n", |
| " Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n", |
| "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n", |
| "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n", |
| "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m34.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n", |
| " Attempting uninstall: ml_dtypes\n", |
| " Found existing installation: ml-dtypes 0.4.1\n", |
| " Uninstalling ml-dtypes-0.4.1:\n", |
| " Successfully uninstalled ml-dtypes-0.4.1\n", |
| "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", |
| "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n", |
| "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n" |
| ] |
| } |
| ], |
| "source": [ |
| "!python -m pip install iree-turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Report version information\n", |
| "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "nkVLzRpcDnVL", |
| "outputId": "210a54b9-4044-4426-f9ee-09d5fd23839c" |
| }, |
| "execution_count": 5, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed iree-turbine, Version: 3.1.0\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n", |
| " LLVM version 20.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.5.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Load and run whisper-small" |
| ], |
| "metadata": { |
| "id": "I0OfTFxwOud1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Load the pretrained model from https://huggingface.co/openai/whisper-small.\n", |
| "\n", |
| "See also:\n", |
| "\n", |
| "* Model documentation: https://huggingface.co/docs/transformers/model_doc/whisper\n", |
| "* Test case in [SHARK-TestSuite](https://github.com/nod-ai/SHARK-TestSuite/): [`pytorch/models/whisper-small/model.py`](https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/pytorch/models/whisper-small/model.py)" |
| ], |
| "metadata": { |
| "id": "94Ji4URLT_xM" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "from transformers import AutoModelForCausalLM, AutoTokenizer\n", |
| "\n", |
| "# https://huggingface.co/docs/transformers/model_doc/auto\n", |
| "# AutoModelForCausalLM -> WhisperForCausalLM\n", |
| "# AutoTokenizer -> WhisperTokenizerFast\n", |
| "\n", |
| "modelname = \"openai/whisper-small\"\n", |
| "tokenizer = AutoTokenizer.from_pretrained(modelname)\n", |
| "\n", |
| "# Some of the options here affect how the model is exported. See the test cases\n", |
| "# at https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/models\n", |
| "# for other options that may be useful to set.\n", |
| "model = AutoModelForCausalLM.from_pretrained(\n", |
| " modelname,\n", |
| " output_attentions=False,\n", |
| " output_hidden_states=False,\n", |
| " attn_implementation=\"eager\",\n", |
| " torchscript=True,\n", |
| ")\n", |
| "\n", |
| "# This is just a simple demo to get some data flowing through the model.\n", |
| "# Depending on the model and the input it expects (text, image, audio, etc.),\n", |
| "# a model-specific Processor class may be needed instead. For Whisper,\n", |
| "# WhisperProcessor handles audio pre-processing and output post-processing.\n", |
| "example_prompt = \"Hello world!\"\n", |
| "example_encoding = tokenizer(example_prompt, return_tensors=\"pt\")\n", |
| "example_input = example_encoding[\"input_ids\"].cpu()\n", |
| "example_args = (example_input,)" |
| ], |
| "metadata": { |
| "id": "HLbfUuoBPHgH", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 476, |
| "referenced_widgets": [ |
| "aaaaa863923e4b6fb92380ff063677a7", |
| "37ee0e0bcecc44a1aaf25a8e66d752ad", |
| "139c5c777da84307a28ffc7de9e037f6", |
| "ea6b64fa6c5a4d0aa8797ac9dfa9051a", |
| "28125f8e13a14bfaad7c645590360b22", |
| "4b4afd6e36a94799856aa8dd2243f9a2", |
| "7e2f1ea6da7d430fa5df1fc3b9733e10", |
| "5257eab4b9cd42bbb106404bfe428903", |
| "9ae9e376cec54470b398a9cf5fa7b9fd", |
| "b0a36845fd614468a45de443b2e5c0f8", |
| "757a8987af504be89d2cc7d058728d9c", |
| "d88e2b3299f8431bbe7af6e1232ce54c", |
| "d1d3a7166d0f40b4bb26a146e8b76d5a", |
| "f105605987244b56b6e33f162fbe6930", |
| "70a4da0f7f224aa9b828c4493fe31101", |
| "4e8a310daae6485cb53c853bcbb6b029", |
| "9c966c41c6eb4407b6c8751c29b9d082", |
| "86685457be39483fa223541a8e51a79e", |
| "cc26cd911fb84bc19bdb782060138df4", |
| "a8fc625551bb40c9aedb37d548837cd2", |
| "ef484e2a7891478d822929c4728dfdd3", |
| "c2a512da304643e9ae86eb6b1c434934", |
| "d24fd8bd3c6349b8a265a18f96901458", |
| "a217516b31244cb194bb47f4da51ae6c", |
| "8b50957482094ec58561ed62fe53c720", |
| "87e3626416c440a785c3898baf2c8bce", |
| "ecafe4f7ddc64973a6fce7a4d0fcdcb3", |
| "cfa6bedf2057488ca273dea84107cdc9", |
| "6e7a805a79c749f48d3bfa1028e3de70", |
| "fff09ea7e19b47f29304f9f315425884", |
| "9931ae28665347c883d6b4723f405bfb", |
| "f454754510404b61b18a8f87cd8ba1ce", |
| "da14fb80f6994378bf405b38c3f86bab", |
| "560e52debb244737b1cb8f3088506e80", |
| "4fe7fe3078b24fc3811072b7790a9371", |
| "bec9bf346af9464b8fb120fbbaf2fefd", |
| "ffd5fca2e0b84317a473e358e85f3d77", |
| "114aa03e366f46188a771c98607e5adb", |
| "e8dbea7f1cd0443ca3cbe114e24ff3da", |
| "ab512519d4bb4d28b6109a25d9bd6b88", |
| "4ead0d0ec9994682ab8cc9ef027eaae2", |
| "3e61b4815dd14f619b90c715d62df347", |
| "c1420dce3e7246f3866a8ce85474fe5f", |
| "fa9daa6d53564dfdb7af8593edc69884", |
| "a6b1386c9842438ca5801134d26f0b51", |
| "c21ed7621a004d9bb24e9d68c3a76a6f", |
| "7b6e231cec8946118d7bcd745f518010", |
| "2fd77a6a408a41f1be6068ef77057a28", |
| "2c2b92f12df041b0b973ae9793cdc1ab", |
| "446e30992e244f48a68d9130e82a7126", |
| "9aece5f7db4646a7a7f46a186cde18eb", |
| "6faf78708ec44e5f8d5db701938965c3", |
| "3a26fee124d442158961fd5c1b28e5bd", |
| "549fc99a11d44ef1ab48c64b742d58e7", |
| "7516468965394c4b8bcbc7ea3db3b457", |
| "d476cb392ec5423cba72f46757f1df1a", |
| "9fbede93702942c492488056485bdb6c", |
| "e5b32ffa050946469cc375b9234e35ae", |
| "4d18beace81f48c08540812193fd5244", |
| "9076e8da1009489291af97753dcae650", |
| "5f0a6e2ed40a47cc8016e1a11100579e", |
| "df19cd089b854db6ad2230f0a457aca8", |
| "b671614693b141f9821ef7b78ee98ae1", |
| "f3e226744d6e45918d4a48b924389ad7", |
| "6f4e119ee815411094a7d9d5311f10b5", |
| "77503681f26b46419157af3d49a71bcb", |
| "16778bf77b8942939e323b655ac4dfa6", |
| "67057ee8ce464d64a7769ade5d7479dd", |
| "e05d7305000c4b1091631d7b15f7900a", |
| "c9c0abecdc0e4462a8a72e47bfb1e53d", |
| "02202837100e448383e6615758556655", |
| "86ffbd1afeab411fbd322c909cb21a5d", |
| "5870517d9c124931886a88a310c2386e", |
| "59dc9a70d31c492788508f67cd975365", |
| "cf5a0cc0a32d4df2a871ec642d2de5da", |
| "9a9554c7d1f04d9ab14df042cfefbedd", |
| "33a04a7c30dd42de9cfe6a43e988604d", |
| "83e8699174dd43bc9b646d6d49c993b3", |
| "27dbc17542ba451888f1e847b59f2da7", |
| "5a480878ee0d441da7ab360d1c93fbc6", |
| "d646f8677a1b4ca190f2410a8cfc05b7", |
| "1c4cd997e1104a4ba0f3a15b4ce7dea0", |
| "342b91aaba4d4df8aa567779d1a6f4e7", |
| "c7487e228a7048ff9e9be026dc5a9f46", |
| "60cec20206c243efa658a94c7278719c", |
| "57f8627eb25c4b84a74c4bf71c6c122f", |
| "3772acd3cc9d4e61ab2a4bf0f4a15774", |
| "f7a3e9b12c4b461e972113b8880aa985", |
| "a8fdc4ca612b47c68fd130b51bcd1ece", |
| "b5a84ce7032343b9b040a03e5432d96f", |
| "b893ddddacf74e7c8ca40666fb84c24e", |
| "9b44401b93664666a42c56d3165de181", |
| "c976d288588145069099727ab5183da6", |
| "b952b8ec0ebb4727802daddb7a3d5f4a", |
| "18293fa01eb14c8b9d7f2b00669ba82a", |
| "eab894861d1f406c973b525366a8e157", |
| "b727f95de2e245019a61ac5c9466ece4", |
| "a23bf4d95ebe4af1991fe1531b6b7b2f", |
| "fe045669fdc147a5a6bde04b4f31fcef", |
| "ccee9eaca1d944099ee97f8c0a5790b5", |
| "675803d6f0b94b07b3b465b427547a00", |
| "cff95693504d4e7eac5fdfe972cf7e12", |
| "566899e9d7c04a7c954f7186d988c1e4", |
| "d994cd6184ee402ca6e4ae0b7db8faea", |
| "d9d7829fe66049379ecf88ce7b385c34", |
| "fb82acd9ae0f441d9c1dcf87b47f3486", |
| "ef1916aba2f449cbabc64248ef9cc95f", |
| "58d678849b444dfaa467f87a1b7bd9fc", |
| "bb1027ed5dbb4fb68f90480e85cde62c", |
| "79cf395270734f4a926dd3fc165f65e2" |
| ] |
| }, |
| "outputId": "c33917e5-8ec5-4e03-85c3-9424a529fac9" |
| }, |
| "execution_count": 6, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stderr", |
| "text": [ |
| "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", |
| "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", |
| "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", |
| "You will be able to reuse this secret in all of your notebooks.\n", |
| "Please note that authentication is recommended but still optional to access public models or datasets.\n", |
| " warnings.warn(\n" |
| ] |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "tokenizer_config.json: 0%| | 0.00/283k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "aaaaa863923e4b6fb92380ff063677a7" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "vocab.json: 0%| | 0.00/836k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "d88e2b3299f8431bbe7af6e1232ce54c" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "tokenizer.json: 0%| | 0.00/2.48M [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "d24fd8bd3c6349b8a265a18f96901458" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "merges.txt: 0%| | 0.00/494k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "560e52debb244737b1cb8f3088506e80" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "normalizer.json: 0%| | 0.00/52.7k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "a6b1386c9842438ca5801134d26f0b51" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "added_tokens.json: 0%| | 0.00/34.6k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "d476cb392ec5423cba72f46757f1df1a" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "special_tokens_map.json: 0%| | 0.00/2.19k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "16778bf77b8942939e323b655ac4dfa6" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "config.json: 0%| | 0.00/1.97k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "83e8699174dd43bc9b646d6d49c993b3" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "model.safetensors: 0%| | 0.00/967M [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "a8fdc4ca612b47c68fd130b51bcd1ece" |
| } |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "stream", |
| "name": "stderr", |
| "text": [ |
| "Some weights of WhisperForCausalLM were not initialized from the model checkpoint at openai/whisper-small and are newly initialized: ['proj_out.weight']\n", |
| "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" |
| ] |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "generation_config.json: 0%| | 0.00/3.87k [00:00<?, ?B/s]" |
| ], |
| "application/vnd.jupyter.widget-view+json": { |
| "version_major": 2, |
| "version_minor": 0, |
| "model_id": "ccee9eaca1d944099ee97f8c0a5790b5" |
| } |
| }, |
| "metadata": {} |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Test exporting using [`torch.export()`](https://pytorch.org/docs/stable/export.html#torch.export.export). Turbine's `aot.export()` builds on `torch.export`, so if this step succeeds, `aot.export()` should work as well." |
| ], |
| "metadata": { |
| "id": "vQlF_ua3UNvo" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "import torch\n", |
| "exported_program = torch.export.export(model, example_args)" |
| ], |
| "metadata": { |
| "id": "-4LykgffY9uH" |
| }, |
| "execution_count": 7, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Export using the simple [`aot.export()`](https://iree.dev/guides/ml-frameworks/pytorch/#simple-api) API from Turbine." |
| ], |
| "metadata": { |
| "id": "wXZI4GliUazA" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
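| "import iree.turbine.aot as aot\n", |
| "\n", |
| "# Unlike torch.export.export(), which takes a tuple of example args,\n", |
| "# aot.export() expects them unpacked as positional arguments.\n", |
| "# Despite its name, the result is only exported here; it is compiled below.\n", |
| "whisper_compiled_module = aot.export(model, *example_args)" |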
| ], |
| "metadata": { |
| "id": "R7-rN_z2Y_5z" |
| }, |
| "execution_count": 8, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Compile the exported module with Turbine/IREE, then load and run it with the IREE runtime." |
| ], |
| "metadata": { |
| "id": "YK3hjpTpUdhc" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
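| "import iree.runtime as ireert\n", |
| "\n", |
| "# Compile the module; save_to=None keeps the binary in memory instead of a file.\n", |
| "binary = whisper_compiled_module.compile(save_to=None)\n", |
| "\n", |
| "# Load the compiled module using the multithreaded CPU driver.\n", |
| "config = ireert.Config(\"local-task\")\n", |
| "vm_module = ireert.load_vm_module(\n", |
| "    ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n", |
| "    config,\n", |
| ")\n", |
| "\n", |
| "# Call the exported `main` function and copy the first output back to the host.\n", |
| "iree_outputs = vm_module.main(example_args[0])\n", |
| "print(iree_outputs[0].to_host())" |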
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "FctBxxEXZBan", |
| "outputId": "89733fb1-3d5a-4258-9da6-394d56ccd230" |
| }, |
| "execution_count": 9, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "[[[ 5.8126216 3.9667568 4.5749426 ... 2.7658575 2.6436937\n", |
| " 1.5479789]\n", |
| " [ 7.5634375 6.029962 5.1000347 ... 6.432704 6.101554\n", |
| " 6.4348 ]\n", |
| " [ 0.9380306 -4.4696145 -4.012748 ... -6.2486286 -7.7917867\n", |
| " -6.8453736]\n", |
| " [ 0.7450936 -3.7631674 -7.4870253 ... -6.734828 -6.966235\n", |
| " -10.022404 ]\n", |
| " [ -0.9628601 -3.510199 -6.015854 ... -7.116391 -6.7086434\n", |
| " -10.225704 ]\n", |
| " [ 3.347097 2.4927166 -3.3042672 ... -1.5709717 -1.8455461\n", |
| " -2.9991992]]]\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "Run the same input through the model in native PyTorch. The outputs should match the IREE results above to within floating-point rounding." |
| ], |
| "metadata": { |
| "id": "5WuFpyFfUjh8" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "# This is inference only, so skip gradient tracking.\n", |
| "with torch.no_grad():\n", |
| "    torch_outputs = model(example_args[0])\n", |
| "print(torch_outputs[0].numpy())" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "IxPYkcPycG4r", |
| "outputId": "f21bc1a0-ddc3-49ec-a122-927ad4e4b54b" |
| }, |
| "execution_count": 10, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stderr", |
| "text": [ |
| "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.\n" |
| ] |
| }, |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "[[[ 5.8126183 3.9667587 4.5749483 ... 2.7658575 2.643694\n", |
| " 1.5479784 ]\n", |
| " [ 7.563436 6.029952 5.100036 ... 6.4327083 6.101557\n", |
| " 6.4348083 ]\n", |
| " [ 0.93802685 -4.469646 -4.012787 ... -6.2486415 -7.7918167\n", |
| " -6.8453975 ]\n", |
| " [ 0.74507916 -3.763197 -7.487034 ... -6.734877 -6.966276\n", |
| " -10.022424 ]\n", |
| " [ -0.96288276 -3.510221 -6.0158725 ... -7.1164136 -6.708687\n", |
| " -10.225745 ]\n", |
| " [ 3.3470666 2.492654 -3.304323 ... -1.5709934 -1.8455791\n", |
| " -2.9992423 ]]]\n" |
| ] |
| } |
| ] |
| } |
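| , |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "The two printed tensors agree closely, but comparing printouts by eye does not scale. A minimal sketch of a numerical check follows. The three sample values are copied from the first row of each printed output above, and the tolerances are assumptions chosen for float32 CPU execution, not values prescribed by IREE or PyTorch.\n", |
| "\n", |
| "```python\n", |
| "import numpy as np\n", |
| "\n", |
| "# Stand-in data: the first three logits of the first row, copied from the\n", |
| "# printed IREE and PyTorch outputs. These arrays are illustrative only.\n", |
| "iree_logits = np.array([5.8126216, 3.9667568, 4.5749426], dtype=np.float32)\n", |
| "torch_logits = np.array([5.8126183, 3.9667587, 4.5749483], dtype=np.float32)\n", |
| "\n", |
| "# Largest element-wise difference between the two backends.\n", |
| "max_abs_diff = float(np.max(np.abs(iree_logits - torch_logits)))\n", |
| "print(\"max abs diff:\", max_abs_diff)\n", |
| "\n", |
| "# Raises AssertionError if any element differs by more than the tolerance.\n", |
| "# rtol and atol here are assumed values, not official recommendations.\n", |
| "np.testing.assert_allclose(iree_logits, torch_logits, rtol=1e-4, atol=1e-4)\n", |
| "```\n", |
| "\n", |
| "In this notebook, the same check could be applied to `iree_outputs[0].to_host()` and `torch_outputs[0].detach().numpy()`. The tolerances may need loosening for other models or backends." |
| ], |
| "metadata": {} |
| } |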
| ] |
| } |