blob: a106f1070795b9325a2b0de494124bf4f9b2717d [file] [log] [blame]
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"UUXnh11hA75x"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"##### Copyright 2023 The IREE Authors"
],
"metadata": {
"id": "UUXnh11hA75x"
}
},
{
"cell_type": "code",
"source": [
"#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
],
"metadata": {
"cellView": "form",
"id": "FqsvmKpjBJO2"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Ahead-of-time (AOT) export workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/ghost.svg\" height=\"20px\"> IREE\n",
"\n",
"This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) to export a program from a PyTorch session to [IREE](https://github.com/iree-org/iree), using [torch-mlir](https://github.com/llvm/torch-mlir) under the hood. In this notebook, SHARK-Turbine is installed as the `iree-turbine` pip package and imported as `shark_turbine`.\n",
"\n",
"SHARK-Turbine provides both a \"simple\" AOT exporter and the underlying advanced\n",
"API, which is intended for complex models that need the full feature set. This\n",
"notebook only uses the \"simple\" exporter."
],
"metadata": {
"id": "38UDc27KBPD1"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "jbcW5jMLK8gK"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"#@title Uninstall existing packages\n",
"# This avoids some warnings when installing specific PyTorch packages below.\n",
"!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
],
"metadata": {
"id": "KsPubQSvCbXd",
"cellView": "form"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Install PyTorch 2.3.0 (prerelease)\n",
"!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3ebWazfjJ6en",
"outputId": "a04009ab-92d2-4796-a476-e7912fc84410"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/test/cpu\n",
"Collecting torch==2.3.0\n",
" Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n",
"Installing collected packages: torch\n",
" Attempting uninstall: torch\n",
" Found existing installation: torch 2.2.1+cu121\n",
" Uninstalling torch-2.2.1+cu121:\n",
" Successfully uninstalled torch-2.2.1+cu121\n",
"Successfully installed torch-2.3.0+cpu\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4iJFDHbsAzo4",
"outputId": "72f7e43a-fbec-4140-d15b-9df10691e984"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting iree-turbine\n",
" Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n",
"Collecting iree-compiler>=20240410.859 (from iree-turbine)\n",
" Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n",
" Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m23.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n",
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n",
"Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n",
"Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n"
]
}
],
"source": [
"#@title Install iree-turbine\n",
"\n",
"!python -m pip install iree-turbine"
]
},
{
"cell_type": "code",
"source": [
"#@title Report version information\n",
"!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
"\n",
"!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
"!iree-compile --version\n",
"\n",
"import torch\n",
"print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nkVLzRpcDnVL",
"outputId": "c8687418-f2c4-42af-974f-2733a5797c1c"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Installed iree-turbine, Version: 2.3.0rc20240410\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
" IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n",
" LLVM version 19.0.0git\n",
" Optimized build\n",
"\n",
"Installed PyTorch, version: 2.3.0+cpu\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Sample AOT workflow\n",
"\n",
"1. Define a program using `torch.nn.Module`\n",
"2. Export the program using `aot.export()`\n",
"3. Compile to a deployable artifact\n",
" * a: By staying within a Python session\n",
" * b: By outputting MLIR and continuing using native tools\n",
"\n",
"Useful documentation:\n",
"\n",
"* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n",
"* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)"
],
"metadata": {
"id": "1Mi3YR75LBxl"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Define a program using `torch.nn.Module`\n",
"torch.manual_seed(0)\n",
"\n",
"class LinearModule(torch.nn.Module):\n",
" def __init__(self, in_features, out_features):\n",
" super().__init__()\n",
" self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n",
" self.bias = torch.nn.Parameter(torch.randn(out_features))\n",
"\n",
" def forward(self, input):\n",
" return (input @ self.weight) + self.bias\n",
"\n",
"linear_module = LinearModule(4, 3)"
],
"metadata": {
"id": "oPdjrmPZMNz6"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 2. Export the program using `aot.export()`\n",
"import shark_turbine.aot as aot\n",
"\n",
"example_arg = torch.randn(4)\n",
"export_output = aot.export(linear_module, example_arg)"
],
"metadata": {
"id": "eK2fWVfiSQ8f"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 3a. Compile fully to a deployable artifact, in our existing Python session\n",
"\n",
"# Staying in Python gives the API a chance to reuse memory, improving\n",
"# performance when compiling large programs.\n",
"\n",
"compiled_binary = export_output.compile(save_to=None)\n",
"\n",
"# Use the IREE runtime API to test the compiled program.\n",
"import numpy as np\n",
"import iree.runtime as ireert\n",
"\n",
"config = ireert.Config(\"local-task\")\n",
"vm_module = ireert.load_vm_module(\n",
" ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),\n",
" config,\n",
")\n",
"\n",
"input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n",
"result = vm_module.main(input)\n",
"print(result.to_host())"
],
"metadata": {
"id": "eMRNdFdos900",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2b4db1b2-12d5-4b63-8f0e-9635dfe48277"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[ 1.4178505 -1.2343317 -7.4767942]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#@title 3b. Output MLIR then continue from Python or native tools later\n",
"\n",
"# Leaving Python allows for file system checkpointing and grants access to\n",
"# native development workflows.\n",
"\n",
"mlir_file_path = \"/tmp/linear_module_pytorch.mlirbc\"\n",
"vmfb_file_path = \"/tmp/linear_module_pytorch_llvmcpu.vmfb\"\n",
"\n",
"print(\"Exported .mlir:\")\n",
"export_output.print_readable()\n",
"export_output.save_mlir(mlir_file_path)\n",
"\n",
"print(\"Compiling and running...\")\n",
"!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu {mlir_file_path} -o {vmfb_file_path}\n",
"!iree-run-module --module={vmfb_file_path} --device=local-task --input=\"4xf32=[1.0, 2.0, 3.0, 4.0]\""
],
"metadata": {
"id": "0AdkXY8VNL2-",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "75ca37ad-0321-4a40-9f54-5c6613f01e9e"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Exported .mlir:\n",
"module @module {\n",
" func.func @main(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {\n",
" %int0 = torch.constant.int 0\n",
" %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n",
" %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>) : !torch.vtensor<[4,3],f32>\n",
" %2 = torch.aten.mm %0, %1 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n",
" %int0_0 = torch.constant.int 0\n",
" %3 = torch.aten.squeeze.dim %2, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" %4 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>\n",
" %int1 = torch.constant.int 1\n",
" %5 = torch.aten.add.Tensor %3, %4, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" return %5 : !torch.vtensor<[3],f32>\n",
" }\n",
"}\n",
"\n",
"{-#\n",
" dialect_resources: {\n",
" builtin: {\n",
" torch_tensor_4_3_torch.float32: \"0x040000005C3FC53F503C96BE49710BC0B684113FA1D18ABF2D05B3BF7A83CE3EE588563F442138BF0B83CEBE18BD18BFC6673A3E\",\n",
" torch_tensor_3_torch.float32: \"0x04000000074F5BBF99E08C3FAB1C89BF\"\n",
" }\n",
" }\n",
"#-}\n",
"Compiling and running...\n",
"EXEC @main\n",
"result[0]: hal.buffer_view\n",
"3xf32=1.41785 -1.23433 -7.47679\n"
]
}
]
}
]
}