blob: da9614b94942faa74170bf676a8e4dd49a800c21 [file] [log] [blame]
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "UUXnh11hA75x"
},
"source": [
"##### Copyright 2023 The IREE Authors"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"cellView": "form",
"id": "FqsvmKpjBJO2"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "38UDc27KBPD1"
},
"source": [
"# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Just-in-time (JIT) workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n",
"\n",
"This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for just-in-time (JIT) compilation and execution within a PyTorch session using [IREE](https://github.com/iree-org/iree) and [torch-mlir](https://github.com/llvm/torch-mlir) under the covers."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jbcW5jMLK8gK"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"cellView": "form",
"id": "KsPubQSvCbXd"
},
"outputs": [],
"source": [
"%%capture\n",
"#@title Uninstall existing packages\n",
"# This avoids some warnings when installing specific PyTorch packages below.\n",
"!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KHbDmehBWuDW",
"outputId": "b52afa55-bba8-415f-f1ee-1add3cc1740e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cpu\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.12/dist-packages (2.8.0+cu126)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch) (3.19.1)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch) (4.15.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch) (75.2.0)\n",
"Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch) (1.13.3)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.12/dist-packages (from torch) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.12/dist-packages (from torch) (2025.3.0)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.80)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch) (9.10.2.21)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.4.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch) (11.3.0.4)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch) (10.3.7.77)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch) (11.7.1.2)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch) (12.5.4.2)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch) (0.7.1)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.27.3 in /usr/local/lib/python3.12/dist-packages (from torch) (2.27.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.77)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch) (12.6.85)\n",
"Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch) (1.11.1.6)\n",
"Requirement already satisfied: triton==3.4.0 in /usr/local/lib/python3.12/dist-packages (from torch) (3.4.0)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch) (1.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch) (3.0.3)\n"
]
}
],
"source": [
"#@title Install PyTorch 2.8+ (supports Python 3.12)\n",
"!python -m pip install torch --index-url https://download.pytorch.org/whl/cpu"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4iJFDHbsAzo4",
"outputId": "e37b1217-6104-46ca-f694-02d8758a90a3"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting iree-turbine\n",
" Downloading iree_turbine-3.7.0-py3-none-any.whl.metadata (7.4 kB)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (2.0.2)\n",
"Collecting iree-base-compiler (from iree-turbine)\n",
" Downloading iree_base_compiler-3.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.3 kB)\n",
"Collecting iree-base-runtime (from iree-turbine)\n",
" Downloading iree_base_runtime-3.7.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (1.3 kB)\n",
"Requirement already satisfied: Jinja2>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (3.1.6)\n",
"Requirement already satisfied: ml_dtypes>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (0.5.3)\n",
"Requirement already satisfied: typing_extensions in /usr/local/lib/python3.12/dist-packages (from iree-turbine) (4.15.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from Jinja2>=3.1.4->iree-turbine) (3.0.3)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.12/dist-packages (from iree-base-compiler->iree-turbine) (1.13.3)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy->iree-base-compiler->iree-turbine) (1.3.0)\n",
"Downloading iree_turbine-3.7.0-py3-none-any.whl (306 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m306.7/306.7 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading iree_base_compiler-3.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (81.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m81.5/81.5 MB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading iree_base_runtime-3.7.0-cp312-cp312-manylinux_2_28_x86_64.whl (8.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.1/8.1 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: iree-base-runtime, iree-base-compiler, iree-turbine\n",
"Successfully installed iree-base-compiler-3.7.1 iree-base-runtime-3.7.0 iree-turbine-3.7.0\n"
]
}
],
"source": [
"#@title Install iree-turbine\n",
"!python -m pip install iree-turbine"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nkVLzRpcDnVL",
"outputId": "9b639275-97ec-4a05-f800-0eb0be1c6b71"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Installed iree-turbine, Version: 3.7.0\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
" IREE compiler version 3.7.1 @ cb2048e0e7694c76ac1af36d5e3450aa3625995c\n",
" LLVM version 22.0.0git\n",
" Optimized build\n",
"\n",
"Installed PyTorch, version: 2.8.0+cu126\n"
]
}
],
"source": [
"#@title Report version information\n",
"!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
"\n",
"!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
"!iree-compile --version\n",
"\n",
"import torch\n",
"print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1Mi3YR75LBxl"
},
"source": [
"## Sample JIT workflow\n",
"\n",
"1. **(Optional)** Set `TURBINE_LOG_LEVEL=debug` to see verbose compilation output, including the intermediate MLIR\n",
"2. Define a program using `torch.nn.Module`\n",
"3. Run `torch.compile(module, backend=\"turbine_cpu\")`\n",
"4. Use the resulting `OptimizedModule` as you would a regular `nn.Module`\n",
"\n",
"Useful documentation:\n",
"\n",
"* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n",
"* [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) as an interface to TorchDynamo and to optimization using backend compilers like Turbine"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6s1U-N91sE-a",
"outputId": "5caf5f2f-ba05-41fc-8f93-bd4ac81c748b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"env: TURBINE_LOG_LEVEL=debug\n"
]
}
],
"source": [
"#@title Enable debug logging to see MLIR output\n",
"%env TURBINE_LOG_LEVEL=debug"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "oPdjrmPZMNz6"
},
"outputs": [],
"source": [
"torch.manual_seed(0)\n",
"\n",
"class LinearModule(torch.nn.Module):\n",
" def __init__(self, in_features, out_features):\n",
" super().__init__()\n",
" self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n",
" self.bias = torch.nn.Parameter(torch.randn(out_features))\n",
"\n",
" def forward(self, input):\n",
" return (input @ self.weight) + self.bias\n",
"\n",
"linear_module = LinearModule(4, 3)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eK2fWVfiSQ8f",
"outputId": "1c1596e6-4361-486e-a12f-dd6d52ef195b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Compiled module using Turbine. New module type is <class 'torch._dynamo.eval_frame.OptimizedModule'>\n"
]
}
],
"source": [
"opt_linear_module = torch.compile(linear_module, backend=\"turbine_cpu\")\n",
"print(\"Compiled module using Turbine. New module type is\", type(opt_linear_module))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0AdkXY8VNL2-",
"outputId": "47fbd84f-f76d-44be-f746-914f0a3cda64"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"DEBUG 10-05 09:26:00 [base.py:24] Traced Graph Module:\n",
"<lambda>()\n",
"\n",
"\n",
"\n",
"def forward(self, arg0_1, arg1_1, arg2_1):\n",
" unsqueeze = torch.ops.aten.unsqueeze.default(arg1_1, 0); arg1_1 = None\n",
" mm = torch.ops.aten.mm.default(unsqueeze, arg0_1); arg0_1 = None\n",
" squeeze = torch.ops.aten.squeeze.dim(mm, 0); mm = None\n",
" add = torch.ops.aten.add.Tensor(squeeze, arg2_1); squeeze = arg2_1 = None\n",
" return (add, unsqueeze)\n",
" \n",
"# To see more debug info, please use `graph_module.print_readable()`\n",
"DEBUG 10-05 09:26:00 [base.py:28] Successfully imported gm to mlir:\n",
"module {\n",
" func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[3],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {\n",
" %int0 = torch.constant.int 0\n",
" %0 = torch.aten.unsqueeze %arg1, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n",
" %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n",
" %int0_0 = torch.constant.int 0\n",
" %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" %int1 = torch.constant.int 1\n",
" %3 = torch.aten.add.Tensor %2, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>\n",
" }\n",
"}\n",
"\n",
"DEBUG 10-05 09:26:00 [device.py:648] Creating turbine device for torch.device = device(type='cpu')\n",
"DEBUG 10-05 09:26:03 [launch.py:193] Cached new module for local-task:--iree-hal-target-backends=llvm-cpu;--iree-llvmcpu-target-cpu-features=host\n",
"DEBUG 10-05 09:26:03 [launch.py:148] Cached new binary for local-task:0:None\n",
"DEBUG 10-05 09:26:03 [launch.py:171] Launching cached binary for local-task:0:None\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Weight: Parameter containing:\n",
"tensor([[ 1.5410, -0.2934, -2.1788],\n",
" [ 0.5684, -1.0845, -1.3986],\n",
" [ 0.4033, 0.8380, -0.7193],\n",
" [-0.4033, -0.5966, 0.1820]], requires_grad=True)\n",
"Bias: Parameter containing:\n",
"tensor([-0.8567, 1.1006, -1.0712], requires_grad=True)\n",
"Args: tensor([ 0.1227, -0.5663, 0.3731, -0.8920])\n",
"Output: tensor([-0.4792, 2.5237, -0.9772], grad_fn=<CompiledFunctionBackward>)\n"
]
}
],
"source": [
"args = torch.randn(4)\n",
"turbine_output = opt_linear_module(args)\n",
"\n",
"print(\"Weight:\", linear_module.weight)\n",
"print(\"Bias:\", linear_module.bias)\n",
"print(\"Args:\", args)\n",
"print(\"Output:\", turbine_output)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"UUXnh11hA75x"
],
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}