Update PyTorch sample notebooks using latest iree-turbine code. (#19658)
Progress on https://github.com/iree-org/iree-turbine/issues/28 and
https://github.com/iree-org/iree/issues/18564.
* Switch the package name from `shark_turbine` to `iree.turbine`
* Fix the dynamic shapes notebook to use `FxProgramsBuilder` (supported,
  used by sharktank) instead of `jittable` (an awkward API that we no longer
  advertise in the docs); see the sketch below
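
For reviewers, a minimal sketch of the new export path, condensed from the
updated dynamic shapes notebook cells in the diff below (only `add_one` is
shown here; the notebook exports `reduce_sum_1d` and `reduce_sum_2d` the same
way):

```python
import torch
import iree.turbine.aot as aot

class DynamicShapesModule(torch.nn.Module):
    # add_one (dynamic input size, dynamic output size)
    # tensor<?xi32> -> tensor<?xi32>, e.g. [1, 2, 3] -> [2, 3, 4]
    def add_one(self, values):
        return values + 1

fxb = aot.FxProgramsBuilder(DynamicShapesModule())

# A single dynamic export dimension, bound to the example input's first axis.
dynamic_x = torch.export.Dim("x")
example_1d = torch.empty(16, dtype=torch.int32)

# Export add_one with a dynamic leading dimension.
@fxb.export_program(
    args=(example_1d,),
    dynamic_shapes={"values": {0: dynamic_x}},
)
def add_one(module, values):
    return module.add_one(values)

# Export the builder, then compile in memory (or pass save_to= a .vmfb path).
export_output = aot.export(fxb)
binary = export_output.compile(save_to=None)
```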
skip-ci: not covered by presubmit CI
diff --git a/samples/colab/pytorch_huggingface_whisper.ipynb b/samples/colab/pytorch_huggingface_whisper.ipynb
index ef61b9c..4d0314e 100644
--- a/samples/colab/pytorch_huggingface_whisper.ipynb
+++ b/samples/colab/pytorch_huggingface_whisper.ipynb
@@ -15,6 +15,8 @@
},
"language_info": {
"name": "python"
+ },
+ "widgets": {
}
},
"cells": [
@@ -75,8 +77,7 @@
"!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
],
"metadata": {
- "id": "KsPubQSvCbXd",
- "cellView": "form"
+ "id": "KsPubQSvCbXd"
},
"execution_count": 2,
"outputs": []
@@ -84,14 +85,14 @@
{
"cell_type": "code",
"source": [
- "!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0"
+ "!python -m pip install --pre --index-url https://download.pytorch.org/whl/cpu --upgrade torch==2.5.0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oO1tirq2ggmO",
- "outputId": "c3390361-9f40-4a49-b5c7-898a62614143"
+ "outputId": "1c10e964-1bd3-41e7-d7ce-70cf574d817b"
},
"execution_count": 3,
"outputs": [
@@ -99,24 +100,26 @@
"output_type": "stream",
"name": "stdout",
"text": [
- "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n",
- "Collecting torch==2.3.0\n",
- " Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n",
- "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n",
+ "Looking in indexes: https://download.pytorch.org/whl/cpu\n",
+ "Collecting torch==2.5.0\n",
+ " Downloading https://download.pytorch.org/whl/cpu/torch-2.5.0%2Bcpu-cp310-cp310-linux_x86_64.whl (174.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m174.7/174.7 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.16.1)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (4.12.2)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (3.1.5)\n",
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (2024.10.0)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.0) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch==2.5.0) (1.3.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.5.0) (3.0.2)\n",
"Installing collected packages: torch\n",
" Attempting uninstall: torch\n",
- " Found existing installation: torch 2.2.1+cu121\n",
- " Uninstalling torch-2.2.1+cu121:\n",
- " Successfully uninstalled torch-2.2.1+cu121\n",
- "Successfully installed torch-2.3.0+cpu\n"
+ " Found existing installation: torch 2.5.1+cu121\n",
+ " Uninstalling torch-2.5.1+cu121:\n",
+ " Successfully uninstalled torch-2.5.1+cu121\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "timm 1.0.12 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed torch-2.5.0+cpu\n"
]
}
]
@@ -129,7 +132,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "4iJFDHbsAzo4",
- "outputId": "94721ae8-e222-4203-c356-888b42bc20b9"
+ "outputId": "c95e32a5-70ab-43e7-8c8c-300d37cccfd3"
},
"outputs": [
{
@@ -137,27 +140,35 @@
"name": "stdout",
"text": [
"Collecting iree-turbine\n",
- " Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n",
- "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/150.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━\u001b[0m \u001b[32m143.4/150.4 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n",
- "Collecting iree-compiler>=20240410.859 (from iree-turbine)\n",
- " Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n",
- " Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m26.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n",
- "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n",
- "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n",
- "Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n",
- "Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n"
+ " Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n",
+ "Collecting iree-base-compiler (from iree-turbine)\n",
+ " Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
+ "Collecting iree-base-runtime (from iree-turbine)\n",
+ " Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
+ "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n",
+ "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n",
+ " Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
+ "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n",
+ "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m34.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n",
+ " Attempting uninstall: ml_dtypes\n",
+ " Found existing installation: ml-dtypes 0.4.1\n",
+ " Uninstalling ml-dtypes-0.4.1:\n",
+ " Successfully uninstalled ml-dtypes-0.4.1\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n"
]
}
],
@@ -182,7 +193,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "nkVLzRpcDnVL",
- "outputId": "ee4e956f-ca7d-45ac-9913-672ad444d89f"
+ "outputId": "210a54b9-4044-4426-f9ee-09d5fd23839c"
},
"execution_count": 5,
"outputs": [
@@ -190,15 +201,15 @@
"output_type": "stream",
"name": "stdout",
"text": [
- "Installed iree-turbine, Version: 2.3.0rc20240410\n",
+ "Installed iree-turbine, Version: 3.1.0\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
- " IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n",
- " LLVM version 19.0.0git\n",
+ " IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n",
+ " LLVM version 20.0.0git\n",
" Optimized build\n",
"\n",
- "Installed PyTorch, version: 2.3.0+cpu\n"
+ "Installed PyTorch, version: 2.5.0+cpu\n"
]
}
]
@@ -259,10 +270,288 @@
"example_args = (example_input,)"
],
"metadata": {
- "id": "HLbfUuoBPHgH"
+ "id": "HLbfUuoBPHgH",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 476,
+ "referenced_widgets": [
+ "aaaaa863923e4b6fb92380ff063677a7",
+ "37ee0e0bcecc44a1aaf25a8e66d752ad",
+ "139c5c777da84307a28ffc7de9e037f6",
+ "ea6b64fa6c5a4d0aa8797ac9dfa9051a",
+ "28125f8e13a14bfaad7c645590360b22",
+ "4b4afd6e36a94799856aa8dd2243f9a2",
+ "7e2f1ea6da7d430fa5df1fc3b9733e10",
+ "5257eab4b9cd42bbb106404bfe428903",
+ "9ae9e376cec54470b398a9cf5fa7b9fd",
+ "b0a36845fd614468a45de443b2e5c0f8",
+ "757a8987af504be89d2cc7d058728d9c",
+ "d88e2b3299f8431bbe7af6e1232ce54c",
+ "d1d3a7166d0f40b4bb26a146e8b76d5a",
+ "f105605987244b56b6e33f162fbe6930",
+ "70a4da0f7f224aa9b828c4493fe31101",
+ "4e8a310daae6485cb53c853bcbb6b029",
+ "9c966c41c6eb4407b6c8751c29b9d082",
+ "86685457be39483fa223541a8e51a79e",
+ "cc26cd911fb84bc19bdb782060138df4",
+ "a8fc625551bb40c9aedb37d548837cd2",
+ "ef484e2a7891478d822929c4728dfdd3",
+ "c2a512da304643e9ae86eb6b1c434934",
+ "d24fd8bd3c6349b8a265a18f96901458",
+ "a217516b31244cb194bb47f4da51ae6c",
+ "8b50957482094ec58561ed62fe53c720",
+ "87e3626416c440a785c3898baf2c8bce",
+ "ecafe4f7ddc64973a6fce7a4d0fcdcb3",
+ "cfa6bedf2057488ca273dea84107cdc9",
+ "6e7a805a79c749f48d3bfa1028e3de70",
+ "fff09ea7e19b47f29304f9f315425884",
+ "9931ae28665347c883d6b4723f405bfb",
+ "f454754510404b61b18a8f87cd8ba1ce",
+ "da14fb80f6994378bf405b38c3f86bab",
+ "560e52debb244737b1cb8f3088506e80",
+ "4fe7fe3078b24fc3811072b7790a9371",
+ "bec9bf346af9464b8fb120fbbaf2fefd",
+ "ffd5fca2e0b84317a473e358e85f3d77",
+ "114aa03e366f46188a771c98607e5adb",
+ "e8dbea7f1cd0443ca3cbe114e24ff3da",
+ "ab512519d4bb4d28b6109a25d9bd6b88",
+ "4ead0d0ec9994682ab8cc9ef027eaae2",
+ "3e61b4815dd14f619b90c715d62df347",
+ "c1420dce3e7246f3866a8ce85474fe5f",
+ "fa9daa6d53564dfdb7af8593edc69884",
+ "a6b1386c9842438ca5801134d26f0b51",
+ "c21ed7621a004d9bb24e9d68c3a76a6f",
+ "7b6e231cec8946118d7bcd745f518010",
+ "2fd77a6a408a41f1be6068ef77057a28",
+ "2c2b92f12df041b0b973ae9793cdc1ab",
+ "446e30992e244f48a68d9130e82a7126",
+ "9aece5f7db4646a7a7f46a186cde18eb",
+ "6faf78708ec44e5f8d5db701938965c3",
+ "3a26fee124d442158961fd5c1b28e5bd",
+ "549fc99a11d44ef1ab48c64b742d58e7",
+ "7516468965394c4b8bcbc7ea3db3b457",
+ "d476cb392ec5423cba72f46757f1df1a",
+ "9fbede93702942c492488056485bdb6c",
+ "e5b32ffa050946469cc375b9234e35ae",
+ "4d18beace81f48c08540812193fd5244",
+ "9076e8da1009489291af97753dcae650",
+ "5f0a6e2ed40a47cc8016e1a11100579e",
+ "df19cd089b854db6ad2230f0a457aca8",
+ "b671614693b141f9821ef7b78ee98ae1",
+ "f3e226744d6e45918d4a48b924389ad7",
+ "6f4e119ee815411094a7d9d5311f10b5",
+ "77503681f26b46419157af3d49a71bcb",
+ "16778bf77b8942939e323b655ac4dfa6",
+ "67057ee8ce464d64a7769ade5d7479dd",
+ "e05d7305000c4b1091631d7b15f7900a",
+ "c9c0abecdc0e4462a8a72e47bfb1e53d",
+ "02202837100e448383e6615758556655",
+ "86ffbd1afeab411fbd322c909cb21a5d",
+ "5870517d9c124931886a88a310c2386e",
+ "59dc9a70d31c492788508f67cd975365",
+ "cf5a0cc0a32d4df2a871ec642d2de5da",
+ "9a9554c7d1f04d9ab14df042cfefbedd",
+ "33a04a7c30dd42de9cfe6a43e988604d",
+ "83e8699174dd43bc9b646d6d49c993b3",
+ "27dbc17542ba451888f1e847b59f2da7",
+ "5a480878ee0d441da7ab360d1c93fbc6",
+ "d646f8677a1b4ca190f2410a8cfc05b7",
+ "1c4cd997e1104a4ba0f3a15b4ce7dea0",
+ "342b91aaba4d4df8aa567779d1a6f4e7",
+ "c7487e228a7048ff9e9be026dc5a9f46",
+ "60cec20206c243efa658a94c7278719c",
+ "57f8627eb25c4b84a74c4bf71c6c122f",
+ "3772acd3cc9d4e61ab2a4bf0f4a15774",
+ "f7a3e9b12c4b461e972113b8880aa985",
+ "a8fdc4ca612b47c68fd130b51bcd1ece",
+ "b5a84ce7032343b9b040a03e5432d96f",
+ "b893ddddacf74e7c8ca40666fb84c24e",
+ "9b44401b93664666a42c56d3165de181",
+ "c976d288588145069099727ab5183da6",
+ "b952b8ec0ebb4727802daddb7a3d5f4a",
+ "18293fa01eb14c8b9d7f2b00669ba82a",
+ "eab894861d1f406c973b525366a8e157",
+ "b727f95de2e245019a61ac5c9466ece4",
+ "a23bf4d95ebe4af1991fe1531b6b7b2f",
+ "fe045669fdc147a5a6bde04b4f31fcef",
+ "ccee9eaca1d944099ee97f8c0a5790b5",
+ "675803d6f0b94b07b3b465b427547a00",
+ "cff95693504d4e7eac5fdfe972cf7e12",
+ "566899e9d7c04a7c954f7186d988c1e4",
+ "d994cd6184ee402ca6e4ae0b7db8faea",
+ "d9d7829fe66049379ecf88ce7b385c34",
+ "fb82acd9ae0f441d9c1dcf87b47f3486",
+ "ef1916aba2f449cbabc64248ef9cc95f",
+ "58d678849b444dfaa467f87a1b7bd9fc",
+ "bb1027ed5dbb4fb68f90480e85cde62c",
+ "79cf395270734f4a926dd3fc165f65e2"
+ ]
+ },
+ "outputId": "c33917e5-8ec5-4e03-85c3-9424a529fac9"
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+ "You will be able to reuse this secret in all of your notebooks.\n",
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "tokenizer_config.json: 0%| | 0.00/283k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "aaaaa863923e4b6fb92380ff063677a7"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "vocab.json: 0%| | 0.00/836k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "d88e2b3299f8431bbe7af6e1232ce54c"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "tokenizer.json: 0%| | 0.00/2.48M [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "d24fd8bd3c6349b8a265a18f96901458"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "merges.txt: 0%| | 0.00/494k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "560e52debb244737b1cb8f3088506e80"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "normalizer.json: 0%| | 0.00/52.7k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a6b1386c9842438ca5801134d26f0b51"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "added_tokens.json: 0%| | 0.00/34.6k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "d476cb392ec5423cba72f46757f1df1a"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "special_tokens_map.json: 0%| | 0.00/2.19k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "16778bf77b8942939e323b655ac4dfa6"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.97k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "83e8699174dd43bc9b646d6d49c993b3"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/967M [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "a8fdc4ca612b47c68fd130b51bcd1ece"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Some weights of WhisperForCausalLM were not initialized from the model checkpoint at openai/whisper-small and are newly initialized: ['proj_out.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/3.87k [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "ccee9eaca1d944099ee97f8c0a5790b5"
+ }
+ },
+ "metadata": {}
+ }
+ ]
},
{
"cell_type": "markdown",
@@ -282,7 +571,7 @@
"metadata": {
"id": "-4LykgffY9uH"
},
- "execution_count": null,
+ "execution_count": 7,
"outputs": []
},
{
@@ -297,7 +586,7 @@
{
"cell_type": "code",
"source": [
- "import shark_turbine.aot as aot\n",
+ "import iree.turbine.aot as aot\n",
"# Note: aot.export() wants the example args to be unpacked.\n",
"whisper_compiled_module = aot.export(model, *example_args)"
],
@@ -336,7 +625,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "FctBxxEXZBan",
- "outputId": "12c042d0-f740-4de1-e246-36ed9b4b357b"
+ "outputId": "89733fb1-3d5a-4258-9da6-394d56ccd230"
},
"execution_count": 9,
"outputs": [
@@ -346,16 +635,16 @@
"text": [
"[[[ 5.8126216 3.9667568 4.5749426 ... 2.7658575 2.6436937\n",
" 1.5479789]\n",
- " [ 7.563438 6.0299625 5.1000338 ... 6.4327035 6.101554\n",
- " 6.434801 ]\n",
- " [ 0.9380368 -4.4696164 -4.012759 ... -6.24863 -7.791795\n",
- " -6.84537 ]\n",
- " [ 0.7450911 -3.7631674 -7.4870267 ... -6.7348223 -6.966235\n",
- " -10.022385 ]\n",
- " [ -0.9628638 -3.5101964 -6.0158615 ... -7.116393 -6.7086525\n",
- " -10.225711 ]\n",
- " [ 3.3470955 2.4927258 -3.3042645 ... -1.5709444 -1.8455245\n",
- " -2.9991858]]]\n"
+ " [ 7.5634375 6.029962 5.1000347 ... 6.432704 6.101554\n",
+ " 6.4348 ]\n",
+ " [ 0.9380306 -4.4696145 -4.012748 ... -6.2486286 -7.7917867\n",
+ " -6.8453736]\n",
+ " [ 0.7450936 -3.7631674 -7.4870253 ... -6.734828 -6.966235\n",
+ " -10.022404 ]\n",
+ " [ -0.9628601 -3.510199 -6.015854 ... -7.116391 -6.7086434\n",
+ " -10.225704 ]\n",
+ " [ 3.347097 2.4927166 -3.3042672 ... -1.5709717 -1.8455461\n",
+ " -2.9991992]]]\n"
]
}
]
@@ -380,12 +669,19 @@
"base_uri": "https://localhost:8080/"
},
"id": "IxPYkcPycG4r",
- "outputId": "d1a3b111-4a6f-4e2a-f80a-645c192e57e3"
+ "outputId": "f21bc1a0-ddc3-49ec-a122-927ad4e4b54b"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.\n"
+ ]
+ },
+ {
+ "output_type": "stream",
"name": "stdout",
"text": [
"[[[ 5.8126183 3.9667587 4.5749483 ... 2.7658575 2.643694\n",
@@ -405,4 +701,4 @@
]
}
]
-}
+}
\ No newline at end of file
diff --git a/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb b/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb
index dc08015..70b7836 100644
--- a/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb
+++ b/samples/dynamic_shapes/pytorch_dynamic_shapes.ipynb
@@ -60,7 +60,7 @@
"colab": {
"base_uri": "https://localhost:8080/"
},
- "outputId": "7b268798-a20d-4df4-f00d-ed7811f77767"
+ "outputId": "fede6b57-a87b-42fc-ab71-57c1b8ff4ab3"
},
"source": [
"#@title General setup\n",
@@ -72,7 +72,7 @@
"os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n",
"print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")"
],
- "execution_count": 1,
+ "execution_count": 2,
"outputs": [
{
"output_type": "stream",
@@ -94,29 +94,29 @@
"metadata": {
"id": "y9KOsqosg6Ms"
},
- "execution_count": 2,
+ "execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
- "#@title Install SHARK-Turbine\n",
+ "#@title Install iree-turbine\n",
"\n",
"# Limit cell height.\n",
"from IPython.display import Javascript\n",
"display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n",
"\n",
- "!python -m pip install shark-turbine"
+ "!python -m pip install iree-turbine"
],
"metadata": {
"id": "SdCAvI3sqBO7",
- "outputId": "2be248d9-bf6b-475e-c44a-aa529f20de23",
+ "outputId": "2d38c722-33cf-4210-89a7-bf4f42f92ab9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
}
},
- "execution_count": 3,
+ "execution_count": 4,
"outputs": [
{
"output_type": "display_data",
@@ -134,37 +134,36 @@
"output_type": "stream",
"name": "stdout",
"text": [
- "Collecting shark-turbine\n",
- " Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)\n",
- "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/60.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.2/60.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (1.23.5)\n",
- "Collecting iree-compiler>=20231004.665 (from shark-turbine)\n",
- " Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n",
- " Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m91.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (2.1.0+cu118)\n",
- "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n",
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n",
- "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2.1.0)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n",
- "Building wheels for collected packages: shark-turbine\n",
- " Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=507dec827b9a2eea18f47c6ebdc84347c9956b8f2e0b186d3107a006e0742d81\n",
- " Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n",
- "Successfully built shark-turbine\n",
- "Installing collected packages: iree-runtime, iree-compiler, shark-turbine\n",
- "Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 shark-turbine-0.9.1.dev3\n"
+ "Collecting iree-turbine\n",
+ " Downloading iree_turbine-3.1.0-py3-none-any.whl.metadata (6.7 kB)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n",
+ "Collecting iree-base-compiler (from iree-turbine)\n",
+ " Downloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
+ "Collecting iree-base-runtime (from iree-turbine)\n",
+ " Downloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)\n",
+ "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.5)\n",
+ "Collecting ml_dtypes>=0.5.0 (from iree-turbine)\n",
+ " Downloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
+ "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (4.12.2)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from iree-base-compiler->iree-turbine) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n",
+ "Downloading iree_turbine-3.1.0-py3-none-any.whl (301 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m301.7/301.7 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading iree_base_compiler-3.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (71.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.2/71.2 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hDownloading iree_base_runtime-3.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (8.2 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m25.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: ml_dtypes, iree-base-runtime, iree-base-compiler, iree-turbine\n",
+ " Attempting uninstall: ml_dtypes\n",
+ " Found existing installation: ml-dtypes 0.4.1\n",
+ " Uninstalling ml-dtypes-0.4.1:\n",
+ " Successfully uninstalled ml-dtypes-0.4.1\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow 2.17.1 requires ml-dtypes<0.5.0,>=0.3.1, but you have ml-dtypes 0.5.1 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0mSuccessfully installed iree-base-compiler-3.1.0 iree-base-runtime-3.1.0 iree-turbine-3.1.0 ml_dtypes-0.5.1\n"
]
}
]
@@ -173,7 +172,7 @@
"cell_type": "code",
"source": [
"#@title Report version information\n",
- "!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n",
+ "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
"\n",
"!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
"!iree-compile --version\n",
@@ -186,23 +185,23 @@
"base_uri": "https://localhost:8080/"
},
"id": "Oj5I6R9LI7t_",
- "outputId": "35d79e6a-7bd0-46e1-8113-5af1a7bcbb5b"
+ "outputId": "deaa1abf-dc0e-49d8-d165-47d53592d94f"
},
- "execution_count": 4,
+ "execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
- "Installed SHARK-Turbine, Version: 0.9.1.dev3\n",
+ "Installed iree-turbine, Version: 3.1.0\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
- " IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n",
- " LLVM version 18.0.0git\n",
+ " IREE compiler version 3.1.0rc20250107 @ d2242207764230ad398585a5771f9d54ce91b4c8\n",
+ " LLVM version 20.0.0git\n",
" Optimized build\n",
"\n",
- "Installed PyTorch, version: 2.1.0+cu118\n"
+ "Installed PyTorch, version: 2.5.1+cu121\n"
]
}
]
@@ -210,7 +209,7 @@
{
"cell_type": "markdown",
"source": [
- "## Create a program using PyTorch + SHARK-Turbine\n",
+ "## Create a program using PyTorch + iree-turbine\n",
"\n",
"NOTE: as in other domains, providing more information to a compiler allows it\n",
"to generate more efficient code. As a general rule, the slowest varying\n",
@@ -227,45 +226,78 @@
{
"cell_type": "code",
"source": [
- "#@title Define a sample `shark_turbine.aot.CompiledModule` using dynamic shapes\n",
+ "#@title Define a sample `torch.nn.Module`.\n",
"\n",
- "import shark_turbine.aot as aot\n",
+ "import iree.turbine.aot as aot\n",
"\n",
- "class DynamicShapesModule(aot.CompiledModule, export_name=\"module\"):\n",
+ "class DynamicShapesModule(torch.nn.Module):\n",
" # reduce_sum_1d (dynamic input size, static output size)\n",
" # tensor<?xi32> -> tensor<i32>\n",
" # e.g. [1, 2, 3] -> 6\n",
- " def reduce_sum_1d(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n",
- " return self.compute_reduce_sum_1d(values)\n",
- "\n",
- " @aot.jittable\n",
- " def compute_reduce_sum_1d(values):\n",
- " return torch.sum(values, dtype=torch.int32)\n",
+ " def reduce_sum_1d(self, values):\n",
+ " return torch.sum(values)\n",
"\n",
" # reduce_sum_2d (partially dynamic input size, static output size)\n",
" # tensor<?x3xi32> -> tensor<3xi32>\n",
" # e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n",
- " def reduce_sum_2d(self, values=aot.AbstractTensor(None, 3, dtype=torch.int32)):\n",
- " return self.compute_reduce_sum_2d(values)\n",
- "\n",
- " @aot.jittable\n",
- " def compute_reduce_sum_2d(values):\n",
- " return torch.sum(values, 0, dtype=torch.int32)\n",
+ " def reduce_sum_2d(self, values):\n",
+ " return torch.sum(values, 0)\n",
"\n",
" # add_one (dynamic input size, dynamic output size)\n",
" # tensor<?xi32>) -> tensor<?xi32>\n",
" # e.g. [1, 2, 3] -> [2, 3, 4]\n",
- " def add_one(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n",
- " return self.compute_add_one(values)\n",
- "\n",
- " @aot.jittable\n",
- " def compute_add_one(values):\n",
+ " def add_one(self, values):\n",
" return values + 1"
],
"metadata": {
"id": "vsf9F4WxI_DX"
},
- "execution_count": 5,
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Export using FxProgramsBuilder.\n",
+ "\n",
+ "fxb = aot.FxProgramsBuilder(DynamicShapesModule())\n",
+ "\n",
+ "# Create a single dynamic export dimension.\n",
+ "dynamic_x = torch.export.Dim(\"x\")\n",
+ "# Example inputs with a mix of placeholder (dynamic) and static dimensions.\n",
+ "example_1d = torch.empty(16, dtype=torch.int32)\n",
+ "example_2d = torch.empty((16, 3), dtype=torch.int32)\n",
+ "\n",
+ "# Export reduce_sum_1d with a dynamic dimension.\n",
+ "@fxb.export_program(\n",
+ " args=(example_1d,),\n",
+ " dynamic_shapes={\"values\": {0: dynamic_x}},\n",
+ ")\n",
+ "def reduce_sum_1d(module, values):\n",
+ " return module.reduce_sum_1d(values)\n",
+ "\n",
+ "# Export reduce_sum_2d with one dynamic dimension.\n",
+ "@fxb.export_program(\n",
+ " args=(example_2d,),\n",
+ " dynamic_shapes={\"values\": {0: dynamic_x}},\n",
+ ")\n",
+ "def reduce_sum_2d(module, values):\n",
+ " return module.reduce_sum_2d(values)\n",
+ "\n",
+ "# Export add_one with a dynamic dimension.\n",
+ "@fxb.export_program(\n",
+ " args=(example_1d,),\n",
+ " dynamic_shapes={\"values\": {0: dynamic_x}},\n",
+ ")\n",
+ "def add_one(module, values):\n",
+ " return module.add_one(values)\n",
+ "\n",
+ "export_output = aot.export(fxb)"
+ ],
+ "metadata": {
+ "id": "cCy3nuLBKTAg"
+ },
+ "execution_count": 7,
"outputs": []
},
{
@@ -273,10 +305,8 @@
"source": [
"from iree.compiler.ir import Context\n",
"\n",
- "# Import into MLIR and save to disk.\n",
- "dynamic_shapes_instance = DynamicShapesModule(context=Context())\n",
"imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n",
- "aot.CompiledModule.save_mlir(dynamic_shapes_instance, imported_mlir_path)\n",
+ "export_output.save_mlir(imported_mlir_path)\n",
"print(f\"Wrote MLIR to path '{imported_mlir_path}'\")\n",
"\n",
"# Inspect the IR.\n",
@@ -289,9 +319,9 @@
"base_uri": "https://localhost:8080/"
},
"id": "_OQIpOtNr4Gh",
- "outputId": "888c0bf3-bec6-403c-9993-ad45d21364fb"
+ "outputId": "abe96b74-88de-4979-959c-cdfbc981b17c"
},
- "execution_count": 6,
+ "execution_count": 8,
"outputs": [
{
"output_type": "stream",
@@ -300,56 +330,25 @@
"Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlir'\n",
"\n",
"Dynamic Shapes MLIR:\n",
- "#map = affine_map<(d0) -> (d0)>\n",
- "#map1 = affine_map<(d0) -> ()>\n",
- "#map2 = affine_map<(d0, d1) -> (d0, d1)>\n",
- "#map3 = affine_map<(d0, d1) -> (d1)>\n",
"module @module {\n",
- " func.func @reduce_sum_1d(%arg0: tensor<?xi32>) -> tensor<i32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n",
- " %0 = call @compute_reduce_sum_1d(%arg0) : (tensor<?xi32>) -> tensor<i32>\n",
- " return %0 : tensor<i32>\n",
+ " func.func @reduce_sum_1d(%arg0: !torch.vtensor<[?],si32>) -> !torch.vtensor<[],si64> attributes {torch.assume_strict_symbolic_shapes} {\n",
+ " %none = torch.constant.none\n",
+ " %0 = torch.aten.sum %arg0, %none : !torch.vtensor<[?],si32>, !torch.none -> !torch.vtensor<[],si64>\n",
+ " return %0 : !torch.vtensor<[],si64>\n",
" }\n",
- " func.func private @compute_reduce_sum_1d(%arg0: tensor<?xi32>) -> tensor<i32> {\n",
- " %c0_i32 = arith.constant 0 : i32\n",
- " %0 = tensor.empty() : tensor<i32>\n",
- " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<i32>) -> tensor<i32>\n",
- " %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = [\"reduction\"]} ins(%arg0 : tensor<?xi32>) outs(%1 : tensor<i32>) {\n",
- " ^bb0(%in: i32, %out: i32):\n",
- " %3 = arith.addi %in, %out : i32\n",
- " linalg.yield %3 : i32\n",
- " } -> tensor<i32>\n",
- " return %2 : tensor<i32>\n",
+ " func.func @reduce_sum_2d(%arg0: !torch.vtensor<[?,3],si32>) -> !torch.vtensor<[3],si64> attributes {torch.assume_strict_symbolic_shapes} {\n",
+ " %int0 = torch.constant.int 0\n",
+ " %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n",
+ " %false = torch.constant.bool false\n",
+ " %none = torch.constant.none\n",
+ " %1 = torch.aten.sum.dim_IntList %arg0, %0, %false, %none : !torch.vtensor<[?,3],si32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3],si64>\n",
+ " return %1 : !torch.vtensor<[3],si64>\n",
" }\n",
- " func.func @reduce_sum_2d(%arg0: tensor<?x3xi32>) -> tensor<3xi32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n",
- " %0 = call @compute_reduce_sum_2d(%arg0) : (tensor<?x3xi32>) -> tensor<3xi32>\n",
- " return %0 : tensor<3xi32>\n",
- " }\n",
- " func.func private @compute_reduce_sum_2d(%arg0: tensor<?x3xi32>) -> tensor<3xi32> {\n",
- " %c0_i32 = arith.constant 0 : i32\n",
- " %0 = tensor.empty() : tensor<3xi32>\n",
- " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<3xi32>) -> tensor<3xi32>\n",
- " %2 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = [\"reduction\", \"parallel\"]} ins(%arg0 : tensor<?x3xi32>) outs(%1 : tensor<3xi32>) {\n",
- " ^bb0(%in: i32, %out: i32):\n",
- " %3 = arith.addi %in, %out : i32\n",
- " linalg.yield %3 : i32\n",
- " } -> tensor<3xi32>\n",
- " return %2 : tensor<3xi32>\n",
- " }\n",
- " func.func @add_one(%arg0: tensor<?xi32>) -> tensor<?xi32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n",
- " %0 = call @compute_add_one(%arg0) : (tensor<?xi32>) -> tensor<?xi32>\n",
- " return %0 : tensor<?xi32>\n",
- " }\n",
- " func.func private @compute_add_one(%arg0: tensor<?xi32>) -> tensor<?xi32> {\n",
- " %c0 = arith.constant 0 : index\n",
- " %c1_i32 = arith.constant 1 : i32\n",
- " %dim = tensor.dim %arg0, %c0 : tensor<?xi32>\n",
- " %0 = tensor.empty(%dim) : tensor<?xi32>\n",
- " %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = [\"parallel\"]} ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xi32>) {\n",
- " ^bb0(%in: i32, %out: i32):\n",
- " %2 = arith.addi %in, %c1_i32 : i32\n",
- " linalg.yield %2 : i32\n",
- " } -> tensor<?xi32>\n",
- " return %1 : tensor<?xi32>\n",
+ " func.func @add_one(%arg0: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si32> attributes {torch.assume_strict_symbolic_shapes} {\n",
+ " %int1 = torch.constant.int 1\n",
+ " %int1_0 = torch.constant.int 1\n",
+ " %0 = torch.aten.add.Scalar %arg0, %int1, %int1_0 : !torch.vtensor<[?],si32>, !torch.int, !torch.int -> !torch.vtensor<[?],si32>\n",
+ " return %0 : !torch.vtensor<[?],si32>\n",
" }\n",
"}\n"
]
@@ -377,25 +376,22 @@
{
"cell_type": "code",
"source": [
- "# Export and compile.\n",
- "exported_output = aot.export(DynamicShapesModule)\n",
- "\n",
"# Compile to a file on disk for usage outside of Python.\n",
"flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes_cpu.vmfb\")\n",
- "exported_output.compile(save_to=flatbuffer_path)\n",
+ "export_output.compile(save_to=flatbuffer_path)\n",
"print(f\"Wrote compiled program to path '{flatbuffer_path}'\")\n",
"\n",
"# Compile into memory for testing.\n",
- "binary = exported_output.compile(save_to=None)"
+ "binary = export_output.compile(save_to=None)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0PGyH1tvI_Ic",
- "outputId": "23b53928-4d77-461f-e4b8-b2c8ffb25ef0"
+ "outputId": "2ac3f280-1834-4d6c-f5b0-c9b470549ca7"
},
- "execution_count": 7,
+ "execution_count": 9,
"outputs": [
{
"output_type": "stream",
@@ -429,9 +425,9 @@
"base_uri": "https://localhost:8080/"
},
"id": "9ilJY15BI_LD",
- "outputId": "57db6e52-83f1-4283-fc08-31e743cc9b42"
+ "outputId": "f20aec4f-353e-4793-f9f1-066006d4471b"
},
- "execution_count": 8,
+ "execution_count": 10,
"outputs": [
{
"output_type": "stream",
@@ -476,9 +472,9 @@
"height": 86
},
"id": "dgaXpdiWuGtx",
- "outputId": "dc0fbca1-c5b0-44f9-e1ff-9bf1307c049f"
+ "outputId": "94823b69-1095-4a97-9974-7d36fb3e2fb8"
},
- "execution_count": 9,
+ "execution_count": 11,
"outputs": [
{
"output_type": "stream",
@@ -486,7 +482,7 @@
"text": [
"Zipping '/tmp/iree/colab_artifacts' to '/tmp/dynamic_shapes_colab_artifacts.zip' for download...\n",
" adding: dynamic_shapes_cpu.vmfb (deflated 66%)\n",
- " adding: dynamic_shapes.mlir (deflated 82%)\n",
+ " adding: dynamic_shapes.mlir (deflated 72%)\n",
"Downloading the artifacts zip file...\n"
]
},
@@ -549,7 +545,7 @@
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
- "download(\"download_e2630f9b-e811-4164-b2d8-80cf52f17145\", \"dynamic_shapes_colab_artifacts.zip\", 5699)"
+ "download(\"download_7377c999-5cd8-4987-95c4-921d56969f65\", \"dynamic_shapes_colab_artifacts.zip\", 5472)"
]
},
"metadata": {}
@@ -557,4 +553,4 @@
]
}
]
-}
+}
\ No newline at end of file