{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"UUXnh11hA75x"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"##### Copyright 2023 The IREE Authors"
],
"metadata": {
"id": "UUXnh11hA75x"
}
},
{
"cell_type": "code",
"source": [
"#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
],
"metadata": {
"cellView": "form",
"id": "FqsvmKpjBJO2"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Just-in-time (JIT) workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n",
"\n",
"This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for eager execution within a PyTorch session using [IREE](https://github.com/iree-org/iree) and [torch-mlir](https://github.com/llvm/torch-mlir) under the covers."
],
"metadata": {
"id": "38UDc27KBPD1"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "jbcW5jMLK8gK"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"#@title Uninstall existing packages\n",
"# This avoids some warnings when installing specific PyTorch packages below.\n",
"!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
],
"metadata": {
"id": "KsPubQSvCbXd",
"cellView": "form"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Install Pytorch 2.3.0 (prerelease)\n",
"!python -m pip install --pre --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0"
],
"metadata": {
"id": "KHbDmehBWuDW",
"outputId": "c2af25cd-58c9-4757-bdda-1124f9f2aa88",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/test/cpu\n",
"Collecting torch==2.3.0\n",
" Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.13.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2023.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n",
"Installing collected packages: torch\n",
" Attempting uninstall: torch\n",
" Found existing installation: torch 2.2.1+cu121\n",
" Uninstalling torch-2.2.1+cu121:\n",
" Successfully uninstalled torch-2.2.1+cu121\n",
"Successfully installed torch-2.3.0+cpu\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4iJFDHbsAzo4",
"outputId": "13397b7e-42cd-4f14-d6e8-a21fa2d1f524"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting iree-turbine\n",
" Downloading iree_turbine-2.3.0rc20240410-py3-none-any.whl (150 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.4/150.4 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.25.2)\n",
"Collecting iree-compiler>=20240410.859 (from iree-turbine)\n",
" Downloading iree_compiler-20240410.859-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (64.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 MB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting iree-runtime>=20240410.859 (from iree-turbine)\n",
" Downloading iree_runtime-20240410.859-cp310-cp310-manylinux_2_28_x86_64.whl (7.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m31.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n",
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20240410.859->iree-turbine) (6.0.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.13.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (4.11.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (3.1.3)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->iree-turbine) (2023.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->iree-turbine) (2.1.5)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->iree-turbine) (1.3.0)\n",
"Installing collected packages: iree-runtime, iree-compiler, iree-turbine\n",
"Successfully installed iree-compiler-20240410.859 iree-runtime-20240410.859 iree-turbine-2.3.0rc20240410\n"
]
}
],
"source": [
"#@title Install iree-turbine\n",
"!python -m pip install iree-turbine"
]
},
{
"cell_type": "code",
"source": [
"#@title Report version information\n",
"!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
"\n",
"!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
"!iree-compile --version\n",
"\n",
"import torch\n",
"print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nkVLzRpcDnVL",
"outputId": "230b113c-6800-45e2-ec93-eddefa439803"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Installed iree-turbine, Version: 2.3.0rc20240410\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
" IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1\n",
" LLVM version 19.0.0git\n",
" Optimized build\n",
"\n",
"Installed PyTorch, version: 2.3.0+cpu\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Sample JIT workflow\n",
"\n",
"1. Define a program using `torch.nn.Module`\n",
"2. Run `torch.compile(module, backend=\"turbine_cpu\")`\n",
"3. Use the resulting `OptimizedModule` as you would a regular `nn.Module`\n",
"\n",
"Useful documentation:\n",
"\n",
"* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n",
"* [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) as an interface to TorchDynamo and optimizing using backend compilers like Turbine"
],
"metadata": {
"id": "1Mi3YR75LBxl"
}
},
{
"cell_type": "code",
"source": [
"torch.manual_seed(0)\n",
"\n",
"class LinearModule(torch.nn.Module):\n",
" def __init__(self, in_features, out_features):\n",
" super().__init__()\n",
" self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n",
" self.bias = torch.nn.Parameter(torch.randn(out_features))\n",
"\n",
" def forward(self, input):\n",
" return (input @ self.weight) + self.bias\n",
"\n",
"linear_module = LinearModule(4, 3)"
],
"metadata": {
"id": "oPdjrmPZMNz6"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"opt_linear_module = torch.compile(linear_module, backend=\"turbine_cpu\")\n",
"print(\"Compiled module using Turbine. New module type is\", type(opt_linear_module))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eK2fWVfiSQ8f",
"outputId": "337f5077-22e3-4ff0-cde1-38b08ff5bea5"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Compiled module using Turbine. New module type is <class 'torch._dynamo.eval_frame.OptimizedModule'>\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"args = torch.randn(4)\n",
"turbine_output = opt_linear_module(args)\n",
"\n",
"print(\"Weight:\", linear_module.weight)\n",
"print(\"Bias:\", linear_module.bias)\n",
"print(\"Args:\", args)\n",
"print(\"Output:\", turbine_output)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0AdkXY8VNL2-",
"outputId": "0f19bdd4-15ff-43ce-b9a7-6fa1113d124c"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"module {\n",
" func.func @main(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[4],f32>) -> (!torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>) {\n",
" %int0 = torch.constant.int 0\n",
" %0 = torch.aten.unsqueeze %arg2, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n",
" %1 = torch.aten.mm %0, %arg0 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n",
" %int0_0 = torch.constant.int 0\n",
" %2 = torch.aten.squeeze.dim %1, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" %int1 = torch.constant.int 1\n",
" %3 = torch.aten.add.Tensor %2, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" return %3, %0 : !torch.vtensor<[3],f32>, !torch.vtensor<[1,4],f32>\n",
" }\n",
"}\n",
"\n",
"#map = affine_map<(d0) -> (d0)>\n",
"module {\n",
" util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.fence, %arg4: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view) attributes {inlining_policy = #util.inline.never, iree.abi.model = \"coarse-fences\", iree.abi.stub} {\n",
" %cst = arith.constant 0.000000e+00 : f32\n",
" %0 = hal.tensor.import wait(%arg3) => %arg0 : !hal.buffer_view -> tensor<4x3xf32>\n",
" %1 = hal.tensor.import wait(%arg3) => %arg1 : !hal.buffer_view -> tensor<3xf32>\n",
" %2 = hal.tensor.import wait(%arg3) => %arg2 : !hal.buffer_view -> tensor<4xf32>\n",
" %expanded = tensor.expand_shape %2 [[0, 1]] : tensor<4xf32> into tensor<1x4xf32>\n",
" %3 = tensor.empty() : tensor<1x3xf32>\n",
" %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x3xf32>) -> tensor<1x3xf32>\n",
" %5 = linalg.matmul ins(%expanded, %0 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%4 : tensor<1x3xf32>) -> tensor<1x3xf32>\n",
" %collapsed = tensor.collapse_shape %5 [[0, 1]] : tensor<1x3xf32> into tensor<3xf32>\n",
" %6 = tensor.empty() : tensor<3xf32>\n",
" %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = [\"parallel\"]} ins(%collapsed, %1 : tensor<3xf32>, tensor<3xf32>) outs(%6 : tensor<3xf32>) {\n",
" ^bb0(%in: f32, %in_0: f32, %out: f32):\n",
" %11 = arith.addf %in, %in_0 : f32\n",
" linalg.yield %11 : f32\n",
" } -> tensor<3xf32>\n",
" %8:2 = hal.tensor.barrier join(%7, %expanded : tensor<3xf32>, tensor<1x4xf32>) => %arg4 : !hal.fence\n",
" %9 = hal.tensor.export %8#0 : tensor<3xf32> -> !hal.buffer_view\n",
" %10 = hal.tensor.export %8#1 : tensor<1x4xf32> -> !hal.buffer_view\n",
" util.return %9, %10 : !hal.buffer_view, !hal.buffer_view\n",
" }\n",
" util.func public @main(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view) -> (!hal.buffer_view, !hal.buffer_view) attributes {iree.abi.stub} {\n",
" %c-1_i32 = arith.constant -1 : i32\n",
" %c0 = arith.constant 0 : index\n",
" %device_0 = hal.devices.get %c0 : !hal.device\n",
" %0 = util.null : !hal.fence\n",
" %fence = hal.fence.create device(%device_0 : !hal.device) flags(\"None\") : !hal.fence\n",
" %1:2 = util.call @main$async(%arg0, %arg1, %arg2, %0, %fence) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.fence, !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)\n",
" %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32\n",
" util.return %1#0, %1#1 : !hal.buffer_view, !hal.buffer_view\n",
" }\n",
"}\n",
"\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Weight: Parameter containing:\n",
"tensor([[ 1.5410, -0.2934, -2.1788],\n",
" [ 0.5684, -1.0845, -1.3986],\n",
" [ 0.4033, 0.8380, -0.7193],\n",
" [-0.4033, -0.5966, 0.1820]], requires_grad=True)\n",
"Bias: Parameter containing:\n",
"tensor([-0.8567, 1.1006, -1.0712], requires_grad=True)\n",
"Args: tensor([ 0.1227, -0.5663, 0.3731, -0.8920])\n",
"Output: tensor([-0.4792, 2.5237, -0.9772], grad_fn=<CompiledFunctionBackward>)\n"
]
}
]
}
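,
{
"cell_type": "markdown",
"source": [
"The compiled `OptimizedModule` can be used anywhere a regular `nn.Module` is expected. As a quick sanity check, the cell below compares the Turbine-compiled output against eager execution of the original `linear_module` on the same input. This is a minimal illustrative check that assumes the cells above have already run; the `torch.allclose` tolerance may need adjusting for other programs."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"#@title Compare with eager PyTorch execution\n",
"\n",
"# Run the original (uncompiled) module on the same input and check that the\n",
"# Turbine-compiled module produced numerically equivalent results.\n",
"eager_output = linear_module(args)\n",
"\n",
"print(\"Eager output:  \", eager_output)\n",
"print(\"Turbine output:\", turbine_output)\n",
"print(\"Outputs match: \", torch.allclose(turbine_output, eager_output, atol=1e-6))"
],
"metadata": {},
"execution_count": null,
"outputs": []
}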
]
}