| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "provenance": [], |
| "collapsed_sections": [ |
| "UUXnh11hA75x" |
| ] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "##### Copyright 2023 The IREE Authors" |
| ], |
| "metadata": { |
| "id": "UUXnh11hA75x" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "metadata": { |
| "cellView": "form", |
| "id": "FqsvmKpjBJO2" |
| }, |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Ahead-of-time (AOT) export workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", |
| "\n", |
| "This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) to export a program from a PyTorch session to [IREE](https://github.com/iree-org/iree). iree-turbine uses [torch-mlir](https://github.com/llvm/torch-mlir) under the hood.\n", |
| "\n", |
| "iree-turbine provides both a \"simple\" AOT exporter and a lower-level advanced\n", |
| "API for complex models that need the full feature set. This notebook uses only\n", |
| "the \"simple\" exporter." |
| ], |
| "metadata": { |
| "id": "38UDc27KBPD1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Setup" |
| ], |
| "metadata": { |
| "id": "jbcW5jMLK8gK" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ], |
| "metadata": { |
| "id": "KsPubQSvCbXd", |
| "cellView": "form" |
| }, |
| "execution_count": 2, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Install PyTorch 2.3.0 (for CPU)\n", |
| "!python -m pip install --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "3ebWazfjJ6en", |
| "outputId": "913c0ccf-9895-4c3f-cd12-20c8185c747d", |
| "cellView": "form" |
| }, |
| "execution_count": 3, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n", |
| "Collecting torch==2.3.0\n", |
| " Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.16.1)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.12.2)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.13.1)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.4.2)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.4)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2024.10.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (3.0.2)\n", |
| "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n", |
| "Installing collected packages: torch\n", |
| " Attempting uninstall: torch\n", |
| " Found existing installation: torch 2.5.0+cu121\n", |
| " Uninstalling torch-2.5.0+cu121:\n", |
| " Successfully uninstalled torch-2.5.0+cu121\n", |
| "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", |
| "timm 1.0.11 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n", |
| "\u001b[0mSuccessfully installed torch-2.3.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "4iJFDHbsAzo4", |
| "outputId": "5ba8936d-0cf4-40e8-d012-d987ee9ad9e3" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting iree-turbine\n", |
| " Downloading iree_turbine-2.5.0-py3-none-any.whl.metadata (5.7 kB)\n", |
| "Requirement already satisfied: numpy>=1.26.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n", |
| "Collecting iree-compiler (from iree-turbine)\n", |
| " Downloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (615 bytes)\n", |
| "Collecting iree-runtime (from iree-turbine)\n", |
| " Downloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (798 bytes)\n", |
| "Requirement already satisfied: torch>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n", |
| "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.4)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n", |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (3.16.1)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (4.12.2)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (1.13.1)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (3.4.2)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (2024.10.0)\n", |
| "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.3.0->iree-turbine) (1.3.0)\n", |
| "Downloading iree_turbine-2.5.0-py3-none-any.whl (271 kB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m271.3/271.3 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (70.7 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m70.7/70.7 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl (8.0 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m52.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hInstalling collected packages: iree-runtime, iree-compiler, iree-turbine\n", |
| "Successfully installed iree-compiler-20241104.1068 iree-runtime-20241104.1068 iree-turbine-2.5.0\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Install iree-turbine\n", |
| "\n", |
| "!python -m pip install iree-turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Report version information\n", |
| "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "nkVLzRpcDnVL", |
| "outputId": "0481a14b-e261-4ea5-8d4e-228078b775f6" |
| }, |
| "execution_count": 5, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed iree-turbine, Version: 2.5.0\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 20241104.1068 @ 9c85e30df30d6efcf68a7a1b594e89322bd6085d\n", |
| " LLVM version 20.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.3.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Sample AOT workflow\n", |
| "\n", |
| "1. Define a program using `torch.nn.Module`\n", |
| "2. Export the program using `aot.export()`\n", |
| "3. Compile to a deployable artifact\n", |
| " * a: By staying within a Python session\n", |
| " * b: By outputting MLIR and continuing using native tools\n", |
| "\n", |
| "Useful documentation:\n", |
| "\n", |
| "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n", |
| "* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)" |
| ], |
| "metadata": { |
| "id": "1Mi3YR75LBxl" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 1. Define a program using `torch.nn.Module`\n", |
| "torch.manual_seed(0)\n", |
| "\n", |
| "class LinearModule(torch.nn.Module):\n", |
| " def __init__(self, in_features, out_features):\n", |
| " super().__init__()\n", |
| " self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n", |
| " self.bias = torch.nn.Parameter(torch.randn(out_features))\n", |
| "\n", |
| " def forward(self, input):\n", |
| " return (input @ self.weight) + self.bias\n", |
| "\n", |
| "linear_module = LinearModule(4, 3)" |
| ], |
| "metadata": { |
| "id": "oPdjrmPZMNz6" |
| }, |
| "execution_count": 6, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 2. Export the program using `aot.export()`\n", |
| "import iree.turbine.aot as aot\n", |
| "\n", |
| "example_arg = torch.randn(4)\n", |
| "export_output = aot.export(linear_module, example_arg)" |
| ], |
| "metadata": { |
| "id": "eK2fWVfiSQ8f" |
| }, |
| "execution_count": 7, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 3a. Compile fully to a deployable artifact, in our existing Python session\n", |
| "\n", |
| "# Staying in Python gives the API a chance to reuse memory, improving\n", |
| "# performance when compiling large programs.\n", |
| "\n", |
| "compiled_binary = export_output.compile(save_to=None)\n", |
| "\n", |
| "# Use the IREE runtime API to test the compiled program.\n", |
| "import numpy as np\n", |
| "import iree.runtime as ireert\n", |
| "\n", |
| "config = ireert.Config(\"local-task\")\n", |
| "vm_module = ireert.load_vm_module(\n", |
| " ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),\n", |
| " config,\n", |
| ")\n", |
| "\n", |
| "# Named `input_array` to avoid shadowing Python's built-in `input`.\n", |
| "input_array = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n", |
| "result = vm_module.main(input_array)\n", |
| "print(result.to_host())" |
| ], |
| "metadata": { |
| "id": "eMRNdFdos900", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "073486bf-1590-4838-ebc8-b76b80e87d32" |
| }, |
| "execution_count": 8, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "[ 1.4178505 -1.2343317 -7.4767942]\n" |
| ] |
| } |
| ] |
| }, |
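| { |
| "cell_type": "markdown", |
| "source": [ |
| "The compiled program should produce the same values as running the module eagerly in PyTorch. The cell below is a minimal sketch of that check, reusing `linear_module` and `result` from the cells above. The tolerances passed to `np.testing.assert_allclose` are illustrative assumptions, not values recommended by IREE." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Optional: compare the compiled result with PyTorch eager execution\n", |
| "\n", |
| "# Run the original module eagerly on the same values to get a reference output.\n", |
| "with torch.no_grad():\n", |
| "    expected = linear_module(torch.tensor([1.0, 2.0, 3.0, 4.0])).numpy()\n", |
| "\n", |
| "# Assumed tolerances: loose enough to absorb float32 rounding differences.\n", |
| "np.testing.assert_allclose(result.to_host(), expected, rtol=1e-5, atol=1e-5)\n", |
| "print(\"Compiled output matches PyTorch eager output:\", expected)" |
| ], |
| "metadata": {}, |
| "execution_count": null, |
| "outputs": [] |
| }, |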
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 3b. Output MLIR then continue from Python or native tools later\n", |
| "\n", |
| "# Saving the MLIR to disk provides a file system checkpoint and lets you\n", |
| "# continue with native development tools.\n", |
| "\n", |
| "mlir_file_path = \"/tmp/linear_module_pytorch.mlirbc\"\n", |
| "vmfb_file_path = \"/tmp/linear_module_pytorch_llvmcpu.vmfb\"\n", |
| "\n", |
| "print(\"Exported .mlir:\")\n", |
| "export_output.print_readable()\n", |
| "export_output.save_mlir(mlir_file_path)\n", |
| "\n", |
| "print(\"Compiling and running...\")\n", |
| "!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host {mlir_file_path} -o {vmfb_file_path}\n", |
| "!iree-run-module --module={vmfb_file_path} --device=local-task --input=\"4xf32=[1.0, 2.0, 3.0, 4.0]\"" |
| ], |
| "metadata": { |
| "id": "0AdkXY8VNL2-", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "3381872c-d048-4f99-b35d-a0ebfef00dfc" |
| }, |
| "execution_count": 9, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Exported .mlir:\n", |
| "module @module {\n", |
| " func.func @main(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> attributes {torch.assume_strict_symbolic_shapes} {\n", |
| " %int0 = torch.constant.int 0\n", |
| " %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n", |
| " %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>) : !torch.vtensor<[4,3],f32>\n", |
| " %2 = torch.aten.mm %0, %1 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n", |
| " %int0_0 = torch.constant.int 0\n", |
| " %3 = torch.aten.squeeze.dim %2, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " %4 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>\n", |
| " %int1 = torch.constant.int 1\n", |
| " %5 = torch.aten.add.Tensor %3, %4, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " return %5 : !torch.vtensor<[3],f32>\n", |
| " }\n", |
| "}\n", |
| "\n", |
| "{-#\n", |
| " dialect_resources: {\n", |
| " builtin: {\n", |
| " torch_tensor_4_3_torch.float32: \"0x040000005C3FC53F503C96BE49710BC0B684113FA1D18ABF2D05B3BF7A83CE3EE588563F442138BF0B83CEBE18BD18BFC6673A3E\",\n", |
| " torch_tensor_3_torch.float32: \"0x04000000074F5BBF99E08C3FAB1C89BF\"\n", |
| " }\n", |
| " }\n", |
| "#-}\n", |
| "Compiling and running...\n", |
| "EXEC @main\n", |
| "result[0]: hal.buffer_view\n", |
| "3xf32=1.41785 -1.23433 -7.47679\n" |
| ] |
| } |
| ] |
| } |
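| , |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "The `dialect_resources` block in the exported MLIR holds the module's weight and bias as hex strings. The cell below is a minimal sketch that decodes them using only the Python standard library and recomputes the result by hand. It assumes that the first 4 bytes of each blob are an alignment header and that the remaining bytes are little-endian `f32` values. `decode_f32_blob` is a hypothetical helper written for this notebook, not part of IREE or torch-mlir." |
| ], |
| "metadata": {} |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Optional: decode the parameters embedded in the exported MLIR\n", |
| "import struct\n", |
| "\n", |
| "def decode_f32_blob(hex_blob):\n", |
| "    # Hypothetical helper. Assumption: after the 0x prefix, the first 4 bytes\n", |
| "    # are an alignment header and the rest are little-endian float32 values.\n", |
| "    raw = bytes.fromhex(hex_blob[2:])\n", |
| "    payload = raw[4:]\n", |
| "    return list(struct.unpack(f'<{len(payload) // 4}f', payload))\n", |
| "\n", |
| "# Hex strings copied from the MLIR printed above.\n", |
| "weight = decode_f32_blob('0x040000005C3FC53F503C96BE49710BC0B684113FA1D18ABF2D05B3BF7A83CE3EE588563F442138BF0B83CEBE18BD18BFC6673A3E')\n", |
| "bias = decode_f32_blob('0x04000000074F5BBF99E08C3FAB1C89BF')\n", |
| "\n", |
| "# The weight tensor is 4x3, assumed here to be stored in row-major order.\n", |
| "rows = [weight[i * 3:(i + 1) * 3] for i in range(4)]\n", |
| "x = [1.0, 2.0, 3.0, 4.0]\n", |
| "\n", |
| "# Same computation as LinearModule.forward: (input @ weight) + bias.\n", |
| "y = [sum(x[i] * rows[i][j] for i in range(4)) + bias[j] for j in range(3)]\n", |
| "print(y)  # compare with the values printed by iree-run-module above" |
| ], |
| "metadata": {}, |
| "execution_count": null, |
| "outputs": [] |
| } |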
| ] |
| } |