blob: 1fee225a0fa045a1b05ff2749d1f6ed721cdebd6 [file] [log] [blame]
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"UUXnh11hA75x"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"##### Copyright 2023 The IREE Authors"
],
"metadata": {
"id": "UUXnh11hA75x"
}
},
{
"cell_type": "code",
"source": [
"#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
],
"metadata": {
"cellView": "form",
"id": "FqsvmKpjBJO2"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Ahead-of-time (AOT) export workflows using <img src=\"https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg\" height=\"20px\"> IREE\n",
"\n",
"This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for export from a PyTorch session to [IREE](https://github.com/openxla/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n",
"\n",
"SHARK-Turbine contains both a \"simple\" AOT exporter and an underlying advanced\n",
"API for complicated models and full feature availability. This notebook only\n",
"uses the \"simple\" exporter."
],
"metadata": {
"id": "38UDc27KBPD1"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "jbcW5jMLK8gK"
}
},
{
"cell_type": "code",
"source": [
"%%capture\n",
"#@title Uninstall existing packages\n",
"# This avoids some warnings when installing specific PyTorch packages below.\n",
"!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
],
"metadata": {
"id": "KsPubQSvCbXd"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "4iJFDHbsAzo4",
"outputId": "4484964e-9163-4694-c5fc-1c4054bfbb84"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"google.colab.output.setIframeHeight(0, true, {maxHeight: 300})"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting shark-turbine\n",
" Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)\n",
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/60.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━\u001b[0m \u001b[32m51.2/60.2 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.2/60.2 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (1.23.5)\n",
"Collecting iree-compiler>=20231004.665 (from shark-turbine)\n",
" Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n",
" Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m95.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting torch>=2.1.0 (from shark-turbine)\n",
" Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m670.2/670.2 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m62.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m49.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m64.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-nccl-cu12==2.18.1 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_nccl_cu12-2.18.1-py3-none-manylinux1_x86_64.whl (209.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.8/209.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting triton==2.1.0 (from torch>=2.1.0->shark-turbine)\n",
" Downloading triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.2/89.2 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.1.0->shark-turbine)\n",
" Downloading nvidia_nvjitlink_cu12-12.2.140-py3-none-manylinux1_x86_64.whl (20.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.2/20.2 MB\u001b[0m \u001b[31m77.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n",
"Building wheels for collected packages: shark-turbine\n",
" Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=c498fa5dd115d1b796c30744cde960da1f63a85ea783c567832c27e896dcd3ee\n",
" Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n",
"Successfully built shark-turbine\n",
"Installing collected packages: triton, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, iree-runtime, iree-compiler, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch, shark-turbine\n",
" Attempting uninstall: triton\n",
" Found existing installation: triton 2.0.0\n",
" Uninstalling triton-2.0.0:\n",
" Successfully uninstalled triton-2.0.0\n",
" Attempting uninstall: torch\n",
" Found existing installation: torch 2.0.1+cu118\n",
" Uninstalling torch-2.0.1+cu118:\n",
" Successfully uninstalled torch-2.0.1+cu118\n",
"Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.18.1 nvidia-nvjitlink-cu12-12.2.140 nvidia-nvtx-cu12-12.1.105 shark-turbine-0.9.1.dev3 torch-2.1.0 triton-2.1.0\n"
]
}
],
"source": [
"#@title Install SHARK-Turbine\n",
"\n",
"# Limit cell height.\n",
"from IPython.display import Javascript\n",
"display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n",
"\n",
"!python -m pip install shark-turbine"
]
},
{
"cell_type": "code",
"source": [
"#@title Report version information\n",
"!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n",
"\n",
"!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
"!iree-compile --version\n",
"\n",
"import torch\n",
"print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nkVLzRpcDnVL",
"outputId": "46ec1bc5-3720-40fd-e2d6-3a617903b317"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Installed SHARK-Turbine, Version: 0.9.1.dev3\n",
"\n",
"Installed IREE, compiler version information:\n",
"IREE (https://iree.dev):\n",
" IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n",
" LLVM version 18.0.0git\n",
" Optimized build\n",
"\n",
"Installed PyTorch, version: 2.1.0+cu121\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Sample AOT workflow\n",
"\n",
"1. Define a program using `torch.nn.Module`\n",
"2. Export the program using `aot.export()`\n",
"3. Compile to a deployable artifact\n",
" * a: By staying within a Python session\n",
" * b: By outputting MLIR and continuing using native tools\n",
"\n",
"Useful documentation:\n",
"\n",
"* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n",
"* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)"
],
"metadata": {
"id": "1Mi3YR75LBxl"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Define a program using `torch.nn.Module`\n",
"torch.manual_seed(0)\n",
"\n",
"class LinearModule(torch.nn.Module):\n",
" def __init__(self, in_features, out_features):\n",
" super().__init__()\n",
" self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n",
" self.bias = torch.nn.Parameter(torch.randn(out_features))\n",
"\n",
" def forward(self, input):\n",
" return (input @ self.weight) + self.bias\n",
"\n",
"linear_module = LinearModule(4, 3)"
],
"metadata": {
"id": "oPdjrmPZMNz6"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 2. Export the program using `aot.export()`\n",
"import shark_turbine.aot as aot\n",
"\n",
"example_arg = torch.randn(4)\n",
"export_output = aot.export(linear_module, example_arg)"
],
"metadata": {
"id": "eK2fWVfiSQ8f"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 3a. Compile fully to a deployable artifact, in our existing Python session\n",
"\n",
"# Staying in Python gives the API a chance to reuse memory, improving\n",
"# performance when compiling large programs.\n",
"\n",
"compiled_binary = export_output.compile(save_to=None)\n",
"\n",
"# Use the IREE runtime API to test the compiled program.\n",
"import numpy as np\n",
"import iree.runtime as ireert\n",
"\n",
"config = ireert.Config(\"local-task\")\n",
"vm_module = ireert.load_vm_module(\n",
" ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),\n",
" config,\n",
")\n",
"\n",
"input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n",
"result = vm_module.main(input)\n",
"print(result.to_host())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eMRNdFdos900",
"outputId": "db4a575b-562c-40ab-dd8c-fba3bbe3a68f"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[ 1.4178504 -1.2343317 -7.4767947]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#@title 3b. Output MLIR then continue from Python or native tools later\n",
"\n",
"# Leaving Python allows for file system checkpointing and grants access to\n",
"# native development workflows.\n",
"\n",
"mlir_file_path = \"/tmp/linear_module_pytorch.mlirbc\"\n",
"vmfb_file_path = \"/tmp/linear_module_pytorch_llvmcpu.vmfb\"\n",
"\n",
"export_output.print_readable()\n",
"export_output.save_mlir(mlir_file_path)\n",
"\n",
"!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu {mlir_file_path} -o {vmfb_file_path}\n",
"!iree-run-module --module={vmfb_file_path} --device=local-task --input=\"4xf32=[1.0, 2.0, 3.0, 4.0]\""
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0AdkXY8VNL2-",
"outputId": "f521d749-a4b7-45a3-ba1a-4e67267ef244"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"module @LinearModule {\n",
" util.global private @_params.weight {noinline} = dense<[[1.54099607, -0.293428898, -2.17878938], [0.568431258, -1.08452237, -1.39859545], [0.403346837, 0.838026344, -0.719257593], [-0.403343529, -0.596635341, 0.182036489]]> : tensor<4x3xf32>\n",
" util.global private @_params.bias {noinline} = dense<[-0.856674611, 1.10060418, -1.07118738]> : tensor<3xf32>\n",
" func.func @main(%arg0: tensor<4xf32>) -> tensor<3xf32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n",
" %0 = torch_c.from_builtin_tensor %arg0 : tensor<4xf32> -> !torch.vtensor<[4],f32>\n",
" %1 = call @forward(%0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32>\n",
" %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3],f32> -> tensor<3xf32>\n",
" return %2 : tensor<3xf32>\n",
" }\n",
" func.func private @forward(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {\n",
" %int0 = torch.constant.int 0\n",
" %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n",
" %_params.weight = util.global.load @_params.weight : tensor<4x3xf32>\n",
" %1 = torch_c.from_builtin_tensor %_params.weight : tensor<4x3xf32> -> !torch.vtensor<[4,3],f32>\n",
" %2 = torch.aten.mm %0, %1 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n",
" %int0_0 = torch.constant.int 0\n",
" %3 = torch.aten.squeeze.dim %2, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" %_params.bias = util.global.load @_params.bias : tensor<3xf32>\n",
" %4 = torch_c.from_builtin_tensor %_params.bias : tensor<3xf32> -> !torch.vtensor<[3],f32>\n",
" %int1 = torch.constant.int 1\n",
" %5 = torch.aten.add.Tensor %3, %4, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
" return %5 : !torch.vtensor<[3],f32>\n",
" }\n",
"}\n",
"EXEC @main\n",
"result[0]: hal.buffer_view\n",
"3xf32=1.41785 -1.23433 -7.47679\n"
]
}
]
}
]
}