| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "provenance": [], |
| "collapsed_sections": [ |
| "UUXnh11hA75x" |
| ] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "##### Copyright 2023 The IREE Authors" |
| ], |
| "metadata": { |
| "id": "UUXnh11hA75x" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "metadata": { |
| "cellView": "form", |
| "id": "FqsvmKpjBJO2" |
| }, |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Just-in-time (JIT) workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", |
| "\n", |
| "This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for eager execution within a PyTorch session using [IREE](https://github.com/iree-org/iree) and [torch-mlir](https://github.com/llvm/torch-mlir) under the covers." |
| ], |
| "metadata": { |
| "id": "38UDc27KBPD1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Setup" |
| ], |
| "metadata": { |
| "id": "jbcW5jMLK8gK" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ], |
| "metadata": { |
| "id": "KsPubQSvCbXd", |
| "cellView": "form" |
| }, |
| "execution_count": 2, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Install PyTorch 2.3.0 (prerelease)\n", |
| "!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0" |
| ], |
| "metadata": { |
| "id": "KHbDmehBWuDW", |
| "outputId": "c2af25cd-58c9-4757-bdda-1124f9f2aa88", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| } |
| }, |
| "execution_count": 3, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n", |
| "Collecting torch==2.3.0\n", |
| " Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n", |
| "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n", |
| "Installing collected packages: torch\n", |
| " Attempting uninstall: torch\n", |
| " Found existing installation: torch 2.2.1+cu121\n", |
| " Uninstalling torch-2.2.1+cu121:\n", |
| " Successfully uninstalled torch-2.2.1+cu121\n", |
| "Successfully installed torch-2.3.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "4iJFDHbsAzo4", |
| "outputId": "13397b7e-42cd-4f14-d6e8-a21fa2d1f524" |
| }, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting iree-turbine\n", |
| " Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n", |
| "Collecting iree-compiler>=20240410.859 (from iree-turbine)\n", |
| " Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n", |
| " Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m31.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n", |
| "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n", |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n", |
| "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n", |
| "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n", |
| "Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n", |
| "Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Install iree-turbine\n", |
| "!python -m pip install iree-turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Report version information\n", |
| "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "nkVLzRpcDnVL", |
| "outputId": "230b113c-6800-45e2-ec93-eddefa439803" |
| }, |
| "execution_count": 5, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed iree-turbine, Version: 2.3.0rc20240410\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n", |
| " LLVM version 19.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.3.0+cpu\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Sample JIT workflow\n", |
| "\n", |
| "1. Define a program using `torch.nn.Module`\n", |
| "2. Run `torch.compile(module, backend=\"turbine_cpu\")`\n", |
| "3. Use the resulting `OptimizedModule` as you would a regular `nn.Module`\n", |
| "\n", |
| "Useful documentation:\n", |
| "\n", |
| "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n", |
| "* [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) as an interface to TorchDynamo for optimizing modules using backend compilers like Turbine" |
| ], |
| "metadata": { |
| "id": "1Mi3YR75LBxl" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "torch.manual_seed(0)\n", |
| "\n", |
| "class LinearModule(torch.nn.Module):\n", |
| " def __init__(self, in_features, out_features):\n", |
| " super().__init__()\n", |
| " self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n", |
| " self.bias = torch.nn.Parameter(torch.randn(out_features))\n", |
| "\n", |
| " def forward(self, input):\n", |
| " return (input @ self.weight) + self.bias\n", |
| "\n", |
| "linear_module = LinearModule(4, 3)" |
| ], |
| "metadata": { |
| "id": "oPdjrmPZMNz6" |
| }, |
| "execution_count": 6, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "opt_linear_module = torch.compile(linear_module, backend=\"turbine_cpu\")\n", |
| "print(\"Compiled module using Turbine. New module type is\", type(opt_linear_module))" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "eK2fWVfiSQ8f", |
| "outputId": "337f5077-22e3-4ff0-cde1-38b08ff5bea5" |
| }, |
| "execution_count": 7, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Compiled module using Turbine. New module type is <class 'torch._dynamo.eval_frame.OptimizedModule'>\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "args = torch.randn(4)\n", |
| "turbine_output = opt_linear_module(args)\n", |
| "\n", |
| "print(\"Weight:\", linear_module.weight)\n", |
| "print(\"Bias:\", linear_module.bias)\n", |
| "print(\"Args:\", args)\n", |
| "print(\"Output:\", turbine_output)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "0AdkXY8VNL2-", |
| "outputId": "0f19bdd4-15ff-43ce-b9a7-6fa1113d124c" |
| }, |
| "execution_count": 8, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stderr", |
| "text": [ |
| "module {\n", |
| " func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[4],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {\n", |
| " %int0 = torch.constant.int 0\n", |
| " %0 = torch.aten.unsqueeze %arg2, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n", |
| " %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n", |
| " %int0_0 = torch.constant.int 0\n", |
| " %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " %int1 = torch.constant.int 1\n", |
| " %3 = torch.aten.add.Tensor %2, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>\n", |
| " }\n", |
| "}\n", |
| "\n", |
| "#map = affine_map<(d0) -> (d0)>\n", |
| "module {\n", |
| " util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = \"coarse-fences\", iree.abi.stub} {\n", |
| " %cst = arith.constant 0.000000e+00 : f32\n", |
| " %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<4x3xf32>\n", |
| " %1 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<3xf32>\n", |
| " %2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<4xf32>\n", |
| " %expanded = tensor.expand_shape %2 [[0, 1]] : tensor<4xf32> into tensor<1x4xf32>\n", |
| " %3 = tensor.empty() : tensor<1x3xf32>\n", |
| " %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x3xf32>) -> tensor<1x3xf32>\n", |
| " %5 = linalg.matmul ins(%expanded, %0 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%4 : tensor<1x3xf32>) -> tensor<1x3xf32>\n", |
| " %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<1x3xf32> into tensor<3xf32>\n", |
| " %6 = tensor.empty() : tensor<3xf32>\n", |
| " %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = [\"parallel\"]} ins(%collapsed, %1 : tensor<3xf32>, tensor<3xf32>) outs(%6 : tensor<3xf32>) {\n", |
| " ^bb0(%in: f32, %in_0: f32, %out: f32):\n", |
| " %11 = arith.addf %in, %in_0 : f32\n", |
| " linalg.yield %11 : f32\n", |
| " } -> tensor<3xf32>\n", |
| " %8:2 = hal.tensor.barrier join(%7, %expanded : tensor<3xf32>, tensor<1x4xf32>) => %arg4 : !hal.fence\n", |
| " %9 = hal.tensor.export %8#0 : tensor<3xf32> -> !hal.buffer_view\n", |
| " %10 = hal.tensor.export %8#1 : tensor<1x4xf32> -> !hal.buffer_view\n", |
| " util.return %9, %10 : !hal.buffer_view, !hal.buffer_view\n", |
| " }\n", |
| " util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {\n", |
| " %c-1_i32 = arith.constant -1 : i32\n", |
| " %c0 = arith.constant 0 : index\n", |
| " %device_0 = hal.devices.get %c0 : !hal.device\n", |
| " %0 = util.null : !hal.fence\n", |
| " %fence = hal.fence.create device(%device_0 : !hal.device) flags(\"None\") : !hal.fence\n", |
| " %1:2 = util.call @main$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)\n", |
| " %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32\n", |
| " util.return %1#0, %1#1 : !hal.buffer_view, !hal.buffer_view\n", |
| " }\n", |
| "}\n", |
| "\n" |
| ] |
| }, |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Weight: Parameter containing:\n", |
| "tensor([[ 1.5410, -0.2934, -2.1788],\n", |
| " [ 0.5684, -1.0845, -1.3986],\n", |
| " [ 0.4033, 0.8380, -0.7193],\n", |
| " [-0.4033, -0.5966, 0.1820]], requires_grad=True)\n", |
| "Bias: Parameter containing:\n", |
| "tensor([-0.8567, 1.1006, -1.0712], requires_grad=True)\n", |
| "Args: tensor([ 0.1227, -0.5663, 0.3731, -0.8920])\n", |
| "Output: tensor([-0.4792, 2.5237, -0.9772], grad_fn=<CompiledFunctionBackward>)\n" |
| ] |
| } |
| ] |
| } |
| ] |
| } |