{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"UUXnh11hA75x",
"jbcW5jMLK8gK"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"##### Copyright 2024 The IREE Authors"
],
"metadata": {
"id": "UUXnh11hA75x"
}
},
{
"cell_type": "code",
"source": [
"#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
],
"metadata": {
"cellView": "form",
"id": "FqsvmKpjBJO2"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# <img src=\"https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png\" height=\"20px\"> Hugging Face to <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch to <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n",
"\n",
"This notebook uses [iree-turbine](https://github.com/iree-org/iree-turbine) to export a pretrained [Hugging Face Transformers](https://huggingface.co/docs/transformers/) model to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n",
"\n",
"* The pretrained [whisper-small](https://huggingface.co/openai/whisper-small)\n",
" model is showcased here as it is small enough to fit comfortably into a Colab\n",
" notebook. Other pretrained models can be found at\n",
" https://huggingface.co/docs/transformers/index."
],
"metadata": {
"id": "38UDc27KBPD1"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "jbcW5jMLK8gK"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"#@title Uninstall existing packages\n",
"# This avoids some warnings when installing specific PyTorch packages below.\n",
"!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
],
"metadata": {
"id": "KsPubQSvCbXd",
"cellView": "form"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oO1tirq2ggmO",
"outputId": "c3390361-9f40-4a49-b5c7-898a62614143"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/test/cpu\n",
"Collecting torch==2.3.0\n",
" Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n",
"Installing collected packages: torch\n",
" Attempting uninstall: torch\n",
" Found existing installation: torch 2.2.1+cu121\n",
" Uninstalling torch-2.2.1+cu121:\n",
" Successfully uninstalled torch-2.2.1+cu121\n",
"Successfully installed torch-2.3.0+cpu\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4iJFDHbsAzo4",
"outputId": "94721ae8-e222-4203-c356-888b42bc20b9"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting iree-turbine\n",
" Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n",
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/150.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━\u001b[0m \u001b[32m143.4/150.4 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n",
"Collecting iree-compiler>=20240410.859 (from iree-turbine)\n",
" Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n",
" Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m26.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n",
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n",
"Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n",
"Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n"
]
}
],
"source": [
"!python -m pip install iree-turbine"
]
},
{
"cell_type": "code",
"source": [
"#@title Report version information\n",
"!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
"\n",
"!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
"!iree-compile --version\n",
"\n",
"import torch\n",
"print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nkVLzRpcDnVL",
"outputId": "ee4e956f-ca7d-45ac-9913-672ad444d89f"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Installed iree-turbine, Version: 2.3.0rc20240410\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
" IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n",
" LLVM version 19.0.0git\n",
" Optimized build\n",
"\n",
"Installed PyTorch, version: 2.3.0+cpu\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Load and run whisper-small"
],
"metadata": {
"id": "I0OfTFxwOud1"
}
},
{
"cell_type": "markdown",
"source": [
"Load the pretrained model from https://huggingface.co/openai/whisper-small.\n",
"\n",
"See also:\n",
"\n",
"* Model card: https://huggingface.co/docs/transformers/model_doc/whisper\n",
"* Test case in [SHARK-TestSuite](https://github.com/nod-ai/SHARK-TestSuite/): [`pytorch/models/whisper-small/model.py`](https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/pytorch/models/whisper-small/model.py)"
],
"metadata": {
"id": "94Ji4URLT_xM"
}
},
{
"cell_type": "code",
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"# https://huggingface.co/docs/transformers/model_doc/auto\n",
"# AutoModelForCausalLM -> WhisperForCausalLM\n",
"# AutoTokenizer -> WhisperTokenizerFast\n",
"\n",
"modelname = \"openai/whisper-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(modelname)\n",
"\n",
"# Some of the options here affect how the model is exported. See the test cases\n",
"# at https://github.com/nod-ai/SHARK-TestSuite/tree/main/e2eshark/pytorch/models\n",
"# for other options that may be useful to set.\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" modelname,\n",
" output_attentions=False,\n",
" output_hidden_states=False,\n",
" attn_implementation=\"eager\",\n",
" torchscript=True,\n",
")\n",
"\n",
"# This is just a simple demo to get some data flowing through the model.\n",
"# Depending on this model and what input it expects (text, image, audio, etc.)\n",
"# this might instead use a specific Processor class. For Whisper,\n",
"# WhisperProcessor runs audio input pre-processing and output post-processing.\n",
"example_prompt = \"Hello world!\"\n",
"example_encoding = tokenizer(example_prompt, return_tensors=\"pt\")\n",
"example_input = example_encoding[\"input_ids\"].cpu()\n",
"example_args = (example_input,)"
],
"metadata": {
"id": "HLbfUuoBPHgH"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Test exporting using [`torch.export()`](https://pytorch.org/docs/stable/export.html#torch.export.export). If `torch.export` works, `aot.export()` from Turbine should work as well."
],
"metadata": {
"id": "vQlF_ua3UNvo"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"exported_program = torch.export.export(model, example_args)"
],
"metadata": {
"id": "-4LykgffY9uH"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Export using the simple [`aot.export()`](https://iree.dev/guides/ml-frameworks/pytorch/#simple-api) API from Turbine."
],
"metadata": {
"id": "wXZI4GliUazA"
}
},
{
"cell_type": "code",
"source": [
"import shark_turbine.aot as aot\n",
"# Note: aot.export() wants the example args to be unpacked.\n",
"whisper_compiled_module = aot.export(model, *example_args)"
],
"metadata": {
"id": "R7-rN_z2Y_5z"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Compile using Turbine/IREE then run the program."
],
"metadata": {
"id": "YK3hjpTpUdhc"
}
},
{
"cell_type": "code",
"source": [
"binary = whisper_compiled_module.compile(save_to=None)\n",
"\n",
"import iree.runtime as ireert\n",
"config = ireert.Config(\"local-task\")\n",
"vm_module = ireert.load_vm_module(\n",
" ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n",
" config,\n",
")\n",
"\n",
"iree_outputs = vm_module.main(example_args[0])\n",
"print(iree_outputs[0].to_host())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "FctBxxEXZBan",
"outputId": "12c042d0-f740-4de1-e246-36ed9b4b357b"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[[ 5.8126216 3.9667568 4.5749426 ... 2.7658575 2.6436937\n",
" 1.5479789]\n",
" [ 7.563438 6.0299625 5.1000338 ... 6.4327035 6.101554\n",
" 6.434801 ]\n",
" [ 0.9380368 -4.4696164 -4.012759 ... -6.24863 -7.791795\n",
" -6.84537 ]\n",
" [ 0.7450911 -3.7631674 -7.4870267 ... -6.7348223 -6.966235\n",
" -10.022385 ]\n",
" [ -0.9628638 -3.5101964 -6.0158615 ... -7.116393 -6.7086525\n",
" -10.225711 ]\n",
" [ 3.3470955 2.4927258 -3.3042645 ... -1.5709444 -1.8455245\n",
" -2.9991858]]]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Run the program using native PyTorch to compare outputs."
],
"metadata": {
"id": "5WuFpyFfUjh8"
}
},
{
"cell_type": "code",
"source": [
"torch_outputs = model(example_args[0])\n",
"print(torch_outputs[0].detach().numpy())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IxPYkcPycG4r",
"outputId": "d1a3b111-4a6f-4e2a-f80a-645c192e57e3"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[[[ 5.8126183 3.9667587 4.5749483 ... 2.7658575 2.643694\n",
" 1.5479784 ]\n",
" [ 7.563436 6.029952 5.100036 ... 6.4327083 6.101557\n",
" 6.4348083 ]\n",
" [ 0.93802685 -4.469646 -4.012787 ... -6.2486415 -7.7918167\n",
" -6.8453975 ]\n",
" [ 0.74507916 -3.763197 -7.487034 ... -6.734877 -6.966276\n",
" -10.022424 ]\n",
" [ -0.96288276 -3.510221 -6.0158725 ... -7.1164136 -6.708687\n",
" -10.225745 ]\n",
" [ 3.3470666 2.492654 -3.304323 ... -1.5709934 -1.8455791\n",
" -2.9992423 ]]]\n"
]
}
]
}
]
}