| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "provenance": [], |
| "collapsed_sections": [ |
| "UUXnh11hA75x" |
| ] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| }, |
| "language_info": { |
| "name": "python" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "##### Copyright 2023 The IREE Authors" |
| ], |
| "metadata": { |
| "id": "UUXnh11hA75x" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "metadata": { |
| "cellView": "form", |
| "id": "FqsvmKpjBJO2" |
| }, |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Ahead-of-time (AOT) export workflows using <img src=\"https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg\" height=\"20px\"> IREE\n", |
| "\n", |
| "This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for export from a PyTorch session to [IREE](https://github.com/openxla/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n", |
| "\n", |
| "SHARK-Turbine contains both a \"simple\" AOT exporter and an underlying advanced\n", |
| "API for complicated models and full feature availability. This notebook only\n", |
| "uses the \"simple\" exporter." |
| ], |
| "metadata": { |
| "id": "38UDc27KBPD1" |
| } |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Setup" |
| ], |
| "metadata": { |
| "id": "jbcW5jMLK8gK" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "# NOTE(review): presumably these preinstalled packages depend on the\n", |
| "# preinstalled torch build, which the install step below replaces - confirm.\n", |
| "# `%%capture` hides this cell's output; `-y` skips pip's confirmation prompt.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ], |
| "metadata": { |
| "id": "KsPubQSvCbXd" |
| }, |
| "execution_count": 2, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 300 |
| }, |
| "id": "4iJFDHbsAzo4", |
| "outputId": "4484964e-9163-4694-c5fc-1c4054bfbb84" |
| }, |
| "outputs": [ |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<IPython.core.display.Javascript object>" |
| ], |
| "application/javascript": [ |
| "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" |
| ] |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting shark-turbine\n", |
| " Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)\n", |
| "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/60.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━\u001b[0m \u001b[32m51.2/60.2 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.2/60.2 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", |
| " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", |
| " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", |
| "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (1.23.5)\n", |
| "Collecting iree-compiler>=20231004.665 (from shark-turbine)\n", |
| " Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n", |
| " Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m95.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting torch>=2.1.0 (from shark-turbine)\n", |
| " Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m670.2/670.2 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n", |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n", |
| "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n", |
| "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m62.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m49.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m64.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.8/209.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting triton==2.1.0 (from torch>=2.1.0->shark-turbine)\n", |
| " Downloading triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.2/89.2 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.1.0->shark-turbine)\n", |
| " Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.2/20.2 MB\u001b[0m \u001b[31m77.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n", |
| "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n", |
| "Building wheels for collected packages: shark-turbine\n", |
| " Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", |
| " Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=c498fa5dd115d1b796c30744cde960da1f63a85ea783c567832c27e896dcd3ee\n", |
| " Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n", |
| "Successfully built shark-turbine\n", |
| "Installing collected packages: triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, iree-runtime, iree-compiler, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, shark-turbine\n", |
| " Attempting uninstall: triton\n", |
| " Found existing installation: triton 2.0.0\n", |
| " Uninstalling triton-2.0.0:\n", |
| " Successfully uninstalled triton-2.0.0\n", |
| " Attempting uninstall: torch\n", |
| " Found existing installation: torch 2.0.1+cu118\n", |
| " Uninstalling torch-2.0.1+cu118:\n", |
| " Successfully uninstalled torch-2.0.1+cu118\n", |
| "Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 shark-turbine-0.9.1.dev3 torch-2.1.0 triton-2.1.0\n" |
| ] |
| } |
| ], |
| "source": [ |
| "#@title Install SHARK-Turbine\n", |
| "\n", |
| "# Limit cell height: cap this cell's rendered output at 300px, since the pip\n", |
| "# log below is long. The snippet calls `google.colab.output`, so it presumably\n", |
| "# only has an effect when the notebook is rendered in Colab - TODO confirm.\n", |
| "from IPython.display import Javascript\n", |
| "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n", |
| "\n", |
| "# The package version is not pinned, so pip installs whatever it resolves at\n", |
| "# run time. The output saved with this notebook shows shark-turbine 0.9.1.dev3.\n", |
| "!python -m pip install shark-turbine" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Report version information\n", |
| "# Print the installed SHARK-Turbine, IREE compiler, and PyTorch versions so the\n", |
| "# outputs saved in this notebook can be matched to the toolchain that made them.\n", |
| "!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "# `torch` is imported here and reused by the cells that follow.\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "nkVLzRpcDnVL", |
| "outputId": "46ec1bc5-3720-40fd-e2d6-3a617903b317" |
| }, |
| "execution_count": 4, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed SHARK-Turbine, Version: 0.9.1.dev3\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n", |
| " LLVM version 18.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.1.0+cu121\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Sample AOT workflow\n", |
| "\n", |
| "1. Define a program using `torch.nn.Module`\n", |
| "2. Export the program using `aot.export()`\n", |
| "3. Compile to a deployable artifact\n", |
| "    * a: By staying within a Python session\n", |
| "    * b: By outputting MLIR and continuing using native tools\n", |
| "\n", |
| "Useful documentation:\n", |
| "\n", |
| "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n", |
| "* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)" |
| ], |
| "metadata": { |
| "id": "1Mi3YR75LBxl" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 1. Define a program using `torch.nn.Module`\n", |
| "# Seed the global RNG so the randomly initialized parameters below (and the\n", |
| "# results printed later in this notebook) are reproducible.\n", |
| "torch.manual_seed(0)\n", |
| "\n", |
| "class LinearModule(torch.nn.Module):\n", |
| "  \"\"\"Affine layer computing `input @ weight + bias`.\n", |
| "\n", |
| "  Args:\n", |
| "    in_features: Number of rows of `weight`.\n", |
| "    out_features: Number of columns of `weight`, and the length of `bias`.\n", |
| "  \"\"\"\n", |
| "\n", |
| "  def __init__(self, in_features, out_features):\n", |
| "    super().__init__()\n", |
| "    # The weight is stored as (in_features, out_features), so `forward` can\n", |
| "    # multiply by it directly without a transpose.\n", |
| "    self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n", |
| "    self.bias = torch.nn.Parameter(torch.randn(out_features))\n", |
| "\n", |
| "  def forward(self, input):\n", |
| "    \"\"\"Returns `(input @ self.weight) + self.bias`.\"\"\"\n", |
| "    return (input @ self.weight) + self.bias\n", |
| "\n", |
| "# 4 input features -> 3 output features.\n", |
| "linear_module = LinearModule(4, 3)" |
| ], |
| "metadata": { |
| "id": "oPdjrmPZMNz6" |
| }, |
| "execution_count": 5, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 2. Export the program using `aot.export()`\n", |
| "import shark_turbine.aot as aot\n", |
| "\n", |
| "# Example input handed to `aot.export()` together with the module. Its length\n", |
| "# (4) matches the `in_features` that `linear_module` was constructed with.\n", |
| "example_arg = torch.randn(4)\n", |
| "export_output = aot.export(linear_module, example_arg)" |
| ], |
| "metadata": { |
| "id": "eK2fWVfiSQ8f" |
| }, |
| "execution_count": 6, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 3a. Compile fully to a deployable artifact, in our existing Python session\n", |
| "\n", |
| "# Staying in Python gives the API a chance to reuse memory, improving\n", |
| "# performance when compiling large programs.\n", |
| "\n", |
| "# NOTE(review): `save_to=None` presumably keeps the compiled module in memory\n", |
| "# rather than writing it to a file; it is read back via `map_memory()` below.\n", |
| "compiled_binary = export_output.compile(save_to=None)\n", |
| "\n", |
| "# Use the IREE runtime API to test the compiled program.\n", |
| "import numpy as np\n", |
| "import iree.runtime as ireert\n", |
| "\n", |
| "config = ireert.Config(\"local-task\")\n", |
| "vm_module = ireert.load_vm_module(\n", |
| "    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),\n", |
| "    config,\n", |
| ")\n", |
| "\n", |
| "# Named `input_array` rather than `input` so that the built-in `input()` is not\n", |
| "# shadowed for the rest of the kernel session.\n", |
| "input_array = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n", |
| "result = vm_module.main(input_array)\n", |
| "print(result.to_host())" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "eMRNdFdos900", |
| "outputId": "db4a575b-562c-40ab-dd8c-fba3bbe3a68f" |
| }, |
| "execution_count": 7, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "[ 1.4178504 -1.2343317 -7.4767947]\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title 3b. Output MLIR then continue from Python or native tools later\n", |
| "\n", |
| "# Leaving Python allows for file system checkpointing and grants access to\n", |
| "# native development workflows.\n", |
| "\n", |
| "# Where the exported MLIR and the compiled module are written.\n", |
| "mlir_file_path = \"/tmp/linear_module_pytorch.mlirbc\"\n", |
| "vmfb_file_path = \"/tmp/linear_module_pytorch_llvmcpu.vmfb\"\n", |
| "\n", |
| "# Print the exported program as text, then save it to `mlir_file_path`.\n", |
| "export_output.print_readable()\n", |
| "export_output.save_mlir(mlir_file_path)\n", |
| "\n", |
| "# `{name}` inside a `!` shell line is replaced with the Python variable's value.\n", |
| "# The `--input` below carries the same four values used in step 3a, so the two\n", |
| "# printed results can be compared directly.\n", |
| "!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu {mlir_file_path} -o {vmfb_file_path}\n", |
| "!iree-run-module --module={vmfb_file_path} --device=local-task --input=\"4xf32=[1.0, 2.0, 3.0, 4.0]\"" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "0AdkXY8VNL2-", |
| "outputId": "f521d749-a4b7-45a3-ba1a-4e67267ef244" |
| }, |
| "execution_count": 8, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "module @LinearModule {\n", |
| " util.global private @_params.weight {noinline} = dense<[[1.54099607, -0.293428898, -2.17878938], [0.568431258, -1.08452237, -1.39859545], [0.403346837, 0.838026344, -0.719257593], [-0.403343529, -0.596635341, 0.182036489]]> : tensor<4x3xf32>\n", |
| " util.global private @_params.bias {noinline} = dense<[-0.856674611, 1.10060418, -1.07118738]> : tensor<3xf32>\n", |
| " func.func @main(%arg0: tensor<4xf32>) -> tensor<3xf32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", |
| " %0 = torch_c.from_builtin_tensor %arg0 : tensor<4xf32> -> !torch.vtensor<[4],f32>\n", |
| " %1 = call @forward(%0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32>\n", |
| " %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3],f32> -> tensor<3xf32>\n", |
| " return %2 : tensor<3xf32>\n", |
| " }\n", |
| " func.func private @forward(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {\n", |
| " %int0 = torch.constant.int 0\n", |
| " %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n", |
| " %_params.weight = util.global.load @_params.weight : tensor<4x3xf32>\n", |
| " %1 = torch_c.from_builtin_tensor %_params.weight : tensor<4x3xf32> -> !torch.vtensor<[4,3],f32>\n", |
| " %2 = torch.aten.mm %0, %1 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n", |
| " %int0_0 = torch.constant.int 0\n", |
| " %3 = torch.aten.squeeze.dim %2, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " %_params.bias = util.global.load @_params.bias : tensor<3xf32>\n", |
| " %4 = torch_c.from_builtin_tensor %_params.bias : tensor<3xf32> -> !torch.vtensor<[3],f32>\n", |
| " %int1 = torch.constant.int 1\n", |
| " %5 = torch.aten.add.Tensor %3, %4, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", |
| " return %5 : !torch.vtensor<[3],f32>\n", |
| " }\n", |
| "}\n", |
| "EXEC @main\n", |
| "result[0]: hal.buffer_view\n", |
| "3xf32=1.41785 -1.23433 -7.47679\n" |
| ] |
| } |
| ] |
| } |
| ] |
| } |