| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "UUXnh11hA75x" |
| }, |
| "source": [ |
| "##### Copyright 2023 The IREE Authors" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "metadata": { |
| "cellView": "form", |
| "id": "FqsvmKpjBJO2" |
| }, |
| "outputs": [], |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "38UDc27KBPD1" |
| }, |
| "source": [ |
| "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Just-in-time (JIT) workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", |
| "\n", |
| "This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for eager execution within a PyTorch session using [IREE](https://github.com/iree-org/iree) and [torch-mlir](https://github.com/llvm/torch-mlir) under the covers." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "jbcW5jMLK8gK" |
| }, |
| "source": [ |
| "## Setup" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "metadata": { |
| "cellView": "form", |
| "id": "KsPubQSvCbXd" |
| }, |
| "outputs": [], |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "KHbDmehBWuDW", |
| "outputId": "b52afa55-bba8-415f-f1ee-1add3cc1740e" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Looking in indexes: https://download.pytorch.org/whl/cpu\n", |
| "Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.8.0+cu126)\n", |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch) (3.19.1)\n", |
| "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.15.0)\n", |
| "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch) (75.2.0)\n", |
| "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.13.3)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch) (3.5)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch) (2025.3.0)\n", |
| "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", |
| "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", |
| "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.80)\n", |
| "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch) (9.10.2.21)\n", |
| "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.4.1)\n", |
| "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch) (11.3.0.4)\n", |
| "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch) (10.3.7.77)\n", |
| "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch) (11.7.1.2)\n", |
| "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch) (12.5.4.2)\n", |
| "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch) (0.7.1)\n", |
| "Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch) (2.27.3)\n", |
| "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n", |
| "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.85)\n", |
| "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch) (1.11.1.6)\n", |
| "Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch) (3.4.0)\n", |
| "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch) (3.0.3)\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Install Pytorch 2.8+ (supports Python 3.12)\n", |
| "!python -m pip install torch --index-url https://download.pytorch.org/whl/cpu" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "4iJFDHbsAzo4", |
| "outputId": "e37b1217-6104-46ca-f694-02d8758a90a3" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting iree-turbine\n", |
| " Downloading iree_turbine-3.7.0-py3-none-any.whl.metadata (7.4 kB)\n", |
| "Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (2.0.2)\n", |
| "Collecting iree-base-compiler (from iree-turbine)\n", |
| " Downloading iree_base_compiler-3.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.3 kB)\n", |
| "Collecting iree-base-runtime (from iree-turbine)\n", |
| " Downloading iree_base_runtime-3.7.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (1.3 kB)\n", |
| "Requirement already satisfied: Jinja2>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (3.1.6)\n", |
| "Requirement already satisfied: ml_dtypes>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (0.5.3)\n", |
| "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (4.15.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from Jinja2>=3.1.4->iree-turbine) (3.0.3)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.12/dist-packages (from iree-base-compiler->iree-turbine) (1.13.3)\n", |
| "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n", |
| "Downloading iree_turbine-3.7.0-py3-none-any.whl (306 kB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m306.7/306.7 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading iree_base_compiler-3.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (81.5 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.5/81.5 MB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hDownloading iree_base_runtime-3.7.0-cp312-cp312-manylinux_2_28_x86_64.whl (8.1 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.1/8.1 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hInstalling collected packages: iree-base-runtime, iree-base-compiler, iree-turbine\n", |
| "Successfully installed iree-base-compiler-3.7.1 iree-base-runtime-3.7.0 iree-turbine-3.7.0\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Install iree-turbine\n", |
| "!python -m pip install iree-turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "nkVLzRpcDnVL", |
| "outputId": "9b639275-97ec-4a05-f800-0eb0be1c6b71" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed iree-turbine, Version: 3.7.0\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 3.7.1 @ cb2048e0e7694c76ac1af36d5e3450aa3625995c\n", |
| " LLVM version 22.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.8.0+cu126\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Report version information\n", |
| "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "1Mi3YR75LBxl" |
| }, |
| "source": [ |
| "## Sample JIT workflow\n", |
| "\n", |
| "1. **(Optional)** Set `TURBINE_LOG_LEVEL=debug` to see verbose compilation output including intermediate MLIR IR\n", |
| "2. Define a program using `torch.nn.Module`\n", |
| "3. Run `torch.compile(module, backend=\"turbine_cpu\")`\n", |
| "4. Use the resulting `OptimizedModule` as you would a regular `nn.Module`\n", |
| "\n", |
| "Useful documentation:\n", |
| "\n", |
| "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n", |
| "* [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) as an interface to TorchDynamo and optimizing using backend compilers like Turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "6s1U-N91sE-a", |
| "outputId": "5caf5f2f-ba05-41fc-8f93-bd4ac81c748b" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "env: TURBINE_LOG_LEVEL=debug\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Enable debug logging to see MLIR output\n", |
| "%env TURBINE_LOG_LEVEL=debug" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "metadata": { |
| "id": "oPdjrmPZMNz6" |
| }, |
| "outputs": [], |
| "source": [ |
| "torch.manual_seed(0)\n", |
| "\n", |
| "class LinearModule(torch.nn.Module):\n", |
| " def __init__(self, in_features, out_features):\n", |
| " super().__init__()\n", |
| " self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n", |
| " self.bias = torch.nn.Parameter(torch.randn(out_features))\n", |
| "\n", |
| " def forward(self, input):\n", |
| " return (input @ self.weight) + self.bias\n", |
| "\n", |
| "linear_module = LinearModule(4, 3)" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "eK2fWVfiSQ8f", |
| "outputId": "1c1596e6-4361-486e-a12f-dd6d52ef195b" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Compiled module using Turbine. New module type is <class 'torch._dynamo.eval_frame.OptimizedModule'>\n" |
| ] |
| } |
| ], |
| "source": [ |
| "opt_linear_module = torch.compile(linear_module, backend=\"turbine_cpu\")\n", |
| "print(\"Compiled module using Turbine. New module type is\", type(opt_linear_module))" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "0AdkXY8VNL2-", |
| "outputId": "47fbd84f-f76d-44be-f746-914f0a3cda64" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stderr", |
| "text": [ |
| "DEBUG 10-05 09:26:00 [base.py:24] Traced Graph Module:\n", |
| "<lambda>()\n", |
| "\n", |
| "\n", |
| "\n", |
| "def forward(self, arg0_1, arg1_1, arg2_1):\n", |
| " unsqueeze = torch.ops.aten.unsqueeze.default(arg1_1, 0); arg1_1 = None\n", |
| " mm = torch.ops.aten.mm.default(unsqueeze, arg0_1); arg0_1 = None\n", |
| " squeeze = torch.ops.aten.squeeze.dim(mm, 0); mm = None\n", |
| " add = torch.ops.aten.add.Tensor(squeeze, arg2_1); squeeze = arg2_1 = None\n", |
| " return (add, unsqueeze)\n", |
| " \n", |
| "# To see more debug info, please use `graph_module.print_readable()`\n", |
| "DEBUG 10-05 09:26:00 [base.py:28] Successfully imported gm to mlir:\n", |
| "module {\n", |
| " func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[3],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {\n", |
| " %int0 = torch.constant.int 0\n", |
| " %0 = torch.aten.unsqueeze %arg1, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n", |
| " %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n", |
| " %int0_0 = torch.constant.int 0\n", |
| " %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " %int1 = torch.constant.int 1\n", |
| " %3 = torch.aten.add.Tensor %2, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>\n", |
| " }\n", |
| "}\n", |
| "\n", |
| "DEBUG 10-05 09:26:00 [device.py:648] Creating turbine device for torch.device = device(type='cpu')\n", |
| "DEBUG 10-05 09:26:03 [launch.py:193] Cached new module for local-task:--iree-hal-target-backends=llvm-cpu;--iree-llvmcpu-target-cpu-features=host\n", |
| "DEBUG 10-05 09:26:03 [launch.py:148] Cached new binary for local-task:0:None\n", |
| "DEBUG 10-05 09:26:03 [launch.py:171] Launching cached binary for local-task:0:None\n" |
| ] |
| }, |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Weight: Parameter containing:\n", |
| "tensor([[ 1.5410, -0.2934, -2.1788],\n", |
| " [ 0.5684, -1.0845, -1.3986],\n", |
| " [ 0.4033, 0.8380, -0.7193],\n", |
| " [-0.4033, -0.5966, 0.1820]], requires_grad=True)\n", |
| "Bias: Parameter containing:\n", |
| "tensor([-0.8567, 1.1006, -1.0712], requires_grad=True)\n", |
| "Args: tensor([ 0.1227, -0.5663, 0.3731, -0.8920])\n", |
| "Output: tensor([-0.4792, 2.5237, -0.9772], grad_fn=<CompiledFunctionBackward>)\n" |
| ] |
| } |
| ], |
| "source": [ |
| "args = torch.randn(4)\n", |
| "turbine_output = opt_linear_module(args)\n", |
| "\n", |
| "print(\"Weight:\", linear_module.weight)\n", |
| "print(\"Bias:\", linear_module.bias)\n", |
| "print(\"Args:\", args)\n", |
| "print(\"Output:\", turbine_output)" |
| ] |
| } |
| ], |
| "metadata": { |
| "colab": { |
| "collapsed_sections": [ |
| "UUXnh11hA75x" |
| ], |
| "provenance": [] |
| }, |
| "kernelspec": { |
| "display_name": "Python 3", |
| "name": "python3" |
| }, |
| "language_info": { |
| "name": "python", |
| "version": "3.11.6" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 0 |
| } |