|  | { | 
|  | "nbformat": 4, | 
|  | "nbformat_minor": 0, | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "provenance": [], | 
|  | "collapsed_sections": [ | 
|  | "UUXnh11hA75x" | 
|  | ] | 
|  | }, | 
|  | "kernelspec": { | 
|  | "name": "python3", | 
|  | "display_name": "Python 3" | 
|  | }, | 
|  | "language_info": { | 
|  | "name": "python" | 
|  | } | 
|  | }, | 
|  | "cells": [ | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "##### Copyright 2023 The IREE Authors" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "UUXnh11hA75x" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", | 
|  | "# See https://llvm.org/LICENSE.txt for license information.\n", | 
|  | "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" | 
|  | ], | 
|  | "metadata": { | 
|  | "cellView": "form", | 
|  | "id": "FqsvmKpjBJO2" | 
|  | }, | 
|  | "execution_count": 1, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Ahead-of-time (AOT) export workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n", | 
|  | "\n", | 
|  | "This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for export from a PyTorch session to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n", | 
|  | "\n", | 
|  | "iree-turbine contains both a \"simple\" AOT exporter and an underlying advanced\n", | 
|  | "API for complicated models and full feature availability. This notebook shows\n", | 
|  | "some of the features available in the \"advanced\" toolkit." | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "38UDc27KBPD1" | 
|  | } | 
|  | }, | 
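|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "For contrast, here is a rough sketch of the \"simple\" export path. It is illustrative only and is not executed in this notebook (the required packages are installed in the Setup section below), and it assumes the `aot.export(model, example_input)` entry point described in the [IREE PyTorch guide](https://iree.dev/guides/ml-frameworks/pytorch/):\n", | 
|  | "\n", | 
|  | "```python\n", | 
|  | "import torch\n", | 
|  | "import iree.turbine.aot as aot\n", | 
|  | "\n", | 
|  | "# Any torch.nn.Module plus an example input is enough for the simple path.\n", | 
|  | "model = torch.nn.Linear(4, 3)\n", | 
|  | "example_input = torch.randn(4)\n", | 
|  | "\n", | 
|  | "# One-shot export and in-memory compile, without writing an aot.CompiledModule.\n", | 
|  | "export_output = aot.export(model, example_input)\n", | 
|  | "compiled_binary = export_output.compile(save_to=None)\n", | 
|  | "```\n", | 
|  | "\n", | 
|  | "The rest of this notebook uses the \"advanced\" `aot.CompiledModule` API, which exposes parameters and additional entry points." | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 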
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "## Setup" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "jbcW5jMLK8gK" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "%%capture\n", | 
|  | "#@title Uninstall existing packages\n", | 
|  | "#   This avoids some warnings when installing specific PyTorch packages below.\n", | 
|  | "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "KsPubQSvCbXd", | 
|  | "cellView": "form" | 
|  | }, | 
|  | "execution_count": 2, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title Install Pytorch 2.3.0 (for CPU)\n", | 
|  | "!python -m pip install --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "oO1tirq2ggmO", | 
|  | "outputId": "2f36a84a-ac8d-453c-c8ed-3d19161b8866", | 
|  | "cellView": "form" | 
|  | }, | 
|  | "execution_count": 3, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n", | 
|  | "Collecting torch==2.3.0\n", | 
|  | "  Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n", | 
|  | "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.16.1)\n", | 
|  | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.12.2)\n", | 
|  | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.13.1)\n", | 
|  | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.4.2)\n", | 
|  | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.4)\n", | 
|  | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2024.10.0)\n", | 
|  | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (3.0.2)\n", | 
|  | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n", | 
|  | "Installing collected packages: torch\n", | 
|  | "  Attempting uninstall: torch\n", | 
|  | "    Found existing installation: torch 2.5.0+cu121\n", | 
|  | "    Uninstalling torch-2.5.0+cu121:\n", | 
|  | "      Successfully uninstalled torch-2.5.0+cu121\n", | 
|  | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | 
|  | "timm 1.0.11 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n", | 
|  | "\u001b[0mSuccessfully installed torch-2.3.0+cpu\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "execution_count": 4, | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "4iJFDHbsAzo4", | 
|  | "outputId": "0234c0d1-94b0-4a4d-8876-2feb320c4ae5" | 
|  | }, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Collecting iree-turbine\n", | 
|  | "  Downloading iree_turbine-2.5.0-py3-none-any.whl.metadata (5.7 kB)\n", | 
|  | "Requirement already satisfied: numpy>=1.26.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n", | 
|  | "Collecting iree-compiler (from iree-turbine)\n", | 
|  | "  Downloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (615 bytes)\n", | 
|  | "Collecting iree-runtime (from iree-turbine)\n", | 
|  | "  Downloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (798 bytes)\n", | 
|  | "Requirement already satisfied: torch>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n", | 
|  | "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.4)\n", | 
|  | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n", | 
|  | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (3.16.1)\n", | 
|  | "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (4.12.2)\n", | 
|  | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (1.13.1)\n", | 
|  | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (3.4.2)\n", | 
|  | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (2024.10.0)\n", | 
|  | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.3.0->iree-turbine) (1.3.0)\n", | 
|  | "Downloading iree_turbine-2.5.0-py3-none-any.whl (271 kB)\n", | 
|  | "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m271.3/271.3 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hDownloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (70.7 MB)\n", | 
|  | "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m70.7/70.7 MB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hDownloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl (8.0 MB)\n", | 
|  | "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m59.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | 
|  | "\u001b[?25hInstalling collected packages: iree-runtime, iree-compiler, iree-turbine\n", | 
|  | "Successfully installed iree-compiler-20241104.1068 iree-runtime-20241104.1068 iree-turbine-2.5.0\n" | 
|  | ] | 
|  | } | 
|  | ], | 
|  | "source": [ | 
|  | "#@title Install iree-turbine\n", | 
|  | "!python -m pip install iree-turbine" | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title Report version information\n", | 
|  | "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n", | 
|  | "\n", | 
|  | "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", | 
|  | "!iree-compile --version\n", | 
|  | "\n", | 
|  | "import torch\n", | 
|  | "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "nkVLzRpcDnVL", | 
|  | "outputId": "f989db15-1644-4c9e-f307-ea5e2abbbc82" | 
|  | }, | 
|  | "execution_count": 5, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "Installed iree-turbine, Version: 2.5.0\n", | 
|  | "\n", | 
|  | "Installed IREE, compiler version information:\n", | 
|  | "IREE (https://iree.dev):\n", | 
|  | "  IREE compiler version 20241104.1068 @ 9c85e30df30d6efcf68a7a1b594e89322bd6085d\n", | 
|  | "  LLVM version 20.0.0git\n", | 
|  | "  Optimized build\n", | 
|  | "\n", | 
|  | "Installed PyTorch, version: 2.3.0+cpu\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "## Advanced AOT toolkit examples\n", | 
|  | "\n", | 
|  | "1. Define a PyTorch program using `torch.nn.Module`\n", | 
|  | "2. Define the API and properties of that program by using `aot.CompiledModule`\n", | 
|  | "3. Export the program using `aot.export()`\n", | 
|  | "4. Compile to a deployable artifact\n", | 
|  | "  * a: By staying within a Python session\n", | 
|  | "  * b: By outputting MLIR and continuing using native tools\n", | 
|  | "\n", | 
|  | "Useful documentation:\n", | 
|  | "\n", | 
|  | "* [IREE PyTorch guide](https://iree.dev/guides/ml-frameworks/pytorch/)\n", | 
|  | "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n", | 
|  | "* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "1Mi3YR75LBxl" | 
|  | } | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title 1. Define a program using `torch.nn.Module`\n", | 
|  | "torch.manual_seed(0)\n", | 
|  | "\n", | 
|  | "class LinearModule(torch.nn.Module):\n", | 
|  | "  def __init__(self, in_features, out_features):\n", | 
|  | "    super().__init__()\n", | 
|  | "    self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n", | 
|  | "    self.bias = torch.nn.Parameter(torch.randn(out_features))\n", | 
|  | "\n", | 
|  | "  def forward(self, input):\n", | 
|  | "    return (input @ self.weight) + self.bias\n", | 
|  | "\n", | 
|  | "linear_module = LinearModule(4, 3)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "oPdjrmPZMNz6" | 
|  | }, | 
|  | "execution_count": 6, | 
|  | "outputs": [] | 
|  | }, | 
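|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "As an optional sanity check, the module can be run eagerly in PyTorch to produce a reference result. The input values in the next cell are chosen to match the input passed to the compiled program in step 4, so the two outputs can be compared directly." | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Run the eager PyTorch module on the same values used with IREE later.\n", | 
|  | "example_input = torch.tensor([1.0, 2.0, 3.0, 4.0])\n", | 
|  | "print(linear_module(example_input))" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 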
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title 2. Define the API and properties of that program by using aot.CompiledModule\n", | 
|  | "\n", | 
|  | "import iree.turbine.aot as aot\n", | 
|  | "\n", | 
|  | "example_weight = torch.randn(4, 3)\n", | 
|  | "example_bias = torch.randn(3)\n", | 
|  | "\n", | 
|  | "class CompiledLinearModule(aot.CompiledModule):\n", | 
|  | "  params = aot.export_parameters(linear_module, mutable=True)\n", | 
|  | "  compute = aot.jittable(linear_module.forward)\n", | 
|  | "\n", | 
|  | "  def main(self, x=aot.AbstractTensor(4)):\n", | 
|  | "    return self.compute(x)\n", | 
|  | "\n", | 
|  | "  def get_weight(self):\n", | 
|  | "    return self.params[\"weight\"]\n", | 
|  | "\n", | 
|  | "  def set_weight(self, weight=aot.abstractify(example_weight)):\n", | 
|  | "    self.params[\"weight\"] = weight\n", | 
|  | "\n", | 
|  | "  def get_bias(self):\n", | 
|  | "    return self.params[\"bias\"]\n", | 
|  | "\n", | 
|  | "  def set_bias(self, bias=aot.abstractify(example_bias)):\n", | 
|  | "    self.params[\"bias\"] = bias" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "Ua3tNtUIozoa" | 
|  | }, | 
|  | "execution_count": 8, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title 3. Export the program using `aot.export()`\n", | 
|  | "export_output = aot.export(CompiledLinearModule)" | 
|  | ], | 
|  | "metadata": { | 
|  | "id": "eK2fWVfiSQ8f" | 
|  | }, | 
|  | "execution_count": 9, | 
|  | "outputs": [] | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title 4a. Compile fully to a deployable artifact, in our existing Python session\n", | 
|  | "\n", | 
|  | "# Staying in Python gives the API a chance to reuse memory, improving\n", | 
|  | "# performance when compiling large programs.\n", | 
|  | "\n", | 
|  | "compiled_binary = export_output.compile(save_to=None)\n", | 
|  | "\n", | 
|  | "# Use the IREE runtime API to test the compiled program.\n", | 
|  | "import numpy as np\n", | 
|  | "import iree.runtime as ireert\n", | 
|  | "\n", | 
|  | "config = ireert.Config(\"local-task\")\n", | 
|  | "vm_module = ireert.load_vm_module(\n", | 
|  | "    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),\n", | 
|  | "    config,\n", | 
|  | ")\n", | 
|  | "\n", | 
|  | "input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n", | 
|  | "result = vm_module.main(input)\n", | 
|  | "print(result.to_host())" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "eMRNdFdos900", | 
|  | "outputId": "fae696f2-8dbe-4873-f392-953673b6094f" | 
|  | }, | 
|  | "execution_count": 10, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "[ 1.4178505 -1.2343317 -7.4767942]\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | }, | 
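|  | { | 
|  | "cell_type": "markdown", | 
|  | "source": [ | 
|  | "The `get_weight` / `set_weight` / `get_bias` / `set_bias` procedures defined on `CompiledLinearModule` are exported alongside `main`, so they can be called through the same loaded runtime module. The sketch below assumes the `vm_module` from the previous cell is still loaded; it reads the current parameter values, zeroes the bias, and re-runs `main` to show that the mutable globals affect the result." | 
|  | ], | 
|  | "metadata": {} | 
|  | }, | 
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "# Read the exported parameters back from the compiled program's mutable globals.\n", | 
|  | "print(\"weight:\", vm_module.get_weight().to_host())\n", | 
|  | "print(\"bias:\", vm_module.get_bias().to_host())\n", | 
|  | "\n", | 
|  | "# Overwrite the bias, then re-run main on the same input; the result changes\n", | 
|  | "# accordingly since the program computes input @ weight + bias.\n", | 
|  | "vm_module.set_bias(np.zeros(3, dtype=np.float32))\n", | 
|  | "arg = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n", | 
|  | "print(\"result with zeroed bias:\", vm_module.main(arg).to_host())" | 
|  | ], | 
|  | "metadata": {}, | 
|  | "execution_count": null, | 
|  | "outputs": [] | 
|  | }, | 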
|  | { | 
|  | "cell_type": "code", | 
|  | "source": [ | 
|  | "#@title 4b. Output MLIR then continue from Python or native tools later\n", | 
|  | "\n", | 
|  | "# Leaving Python allows for file system checkpointing and grants access to\n", | 
|  | "# native development workflows.\n", | 
|  | "\n", | 
|  | "mlir_file_path = \"/tmp/linear_module_pytorch.mlirbc\"\n", | 
|  | "vmfb_file_path = \"/tmp/linear_module_pytorch_llvmcpu.vmfb\"\n", | 
|  | "\n", | 
|  | "export_output.print_readable()\n", | 
|  | "export_output.save_mlir(mlir_file_path)\n", | 
|  | "\n", | 
|  | "!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host {mlir_file_path} -o {vmfb_file_path}\n", | 
|  | "!iree-run-module --module={vmfb_file_path} --device=local-task --function=main --input=\"4xf32=[1.0, 2.0, 3.0, 4.0]\"" | 
|  | ], | 
|  | "metadata": { | 
|  | "colab": { | 
|  | "base_uri": "https://localhost:8080/" | 
|  | }, | 
|  | "id": "0AdkXY8VNL2-", | 
|  | "outputId": "66335f65-2e9a-4a3a-b9ae-638490063a6e" | 
|  | }, | 
|  | "execution_count": 12, | 
|  | "outputs": [ | 
|  | { | 
|  | "output_type": "stream", | 
|  | "name": "stdout", | 
|  | "text": [ | 
|  | "module @compiled_linear {\n", | 
|  | "  util.global private mutable @_params.weight = dense_resource<_params.weight> : tensor<4x3xf32>\n", | 
|  | "  util.global private mutable @_params.bias = dense_resource<_params.bias> : tensor<3xf32>\n", | 
|  | "  func.func @main(%arg0: tensor<4xf32>) -> tensor<3xf32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", | 
|  | "    %0 = torch_c.from_builtin_tensor %arg0 : tensor<4xf32> -> !torch.vtensor<[4],f32>\n", | 
|  | "    %1 = call @forward(%0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32>\n", | 
|  | "    %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[3],f32> -> tensor<3xf32>\n", | 
|  | "    return %2 : tensor<3xf32>\n", | 
|  | "  }\n", | 
|  | "  func.func private @forward(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {\n", | 
|  | "    %int0 = torch.constant.int 0\n", | 
|  | "    %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n", | 
|  | "    %_params.weight = util.global.load @_params.weight : tensor<4x3xf32>\n", | 
|  | "    %1 = torch_c.from_builtin_tensor %_params.weight : tensor<4x3xf32> -> !torch.vtensor<[4,3],f32>\n", | 
|  | "    %2 = torch.aten.mm %0, %1 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n", | 
|  | "    %int0_0 = torch.constant.int 0\n", | 
|  | "    %3 = torch.aten.squeeze.dim %2, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", | 
|  | "    %_params.bias = util.global.load @_params.bias : tensor<3xf32>\n", | 
|  | "    %4 = torch_c.from_builtin_tensor %_params.bias : tensor<3xf32> -> !torch.vtensor<[3],f32>\n", | 
|  | "    %int1 = torch.constant.int 1\n", | 
|  | "    %5 = torch.aten.add.Tensor %3, %4, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n", | 
|  | "    return %5 : !torch.vtensor<[3],f32>\n", | 
|  | "  }\n", | 
|  | "  func.func @get_weight() -> tensor<4x3xf32> attributes {torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", | 
|  | "    %_params.weight = util.global.load @_params.weight : tensor<4x3xf32>\n", | 
|  | "    return %_params.weight : tensor<4x3xf32>\n", | 
|  | "  }\n", | 
|  | "  func.func @set_weight(%arg0: tensor<4x3xf32>) attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\"} {\n", | 
|  | "    util.global.store %arg0, @_params.weight : tensor<4x3xf32>\n", | 
|  | "    return\n", | 
|  | "  }\n", | 
|  | "  func.func @get_bias() -> tensor<3xf32> attributes {torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", | 
|  | "    %_params.bias = util.global.load @_params.bias : tensor<3xf32>\n", | 
|  | "    return %_params.bias : tensor<3xf32>\n", | 
|  | "  }\n", | 
|  | "  func.func @set_bias(%arg0: tensor<3xf32>) attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\"} {\n", | 
|  | "    util.global.store %arg0, @_params.bias : tensor<3xf32>\n", | 
|  | "    return\n", | 
|  | "  }\n", | 
|  | "}\n", | 
|  | "\n", | 
|  | "{-#\n", | 
|  | "  dialect_resources: {\n", | 
|  | "    builtin: {\n", | 
|  | "      _params.weight: \"0x040000005C3FC53F503C96BE49710BC0B684113FA1D18ABF2D05B3BF7A83CE3EE588563F442138BF0B83CEBE18BD18BFC6673A3E\",\n", | 
|  | "      _params.bias: \"0x04000000074F5BBF99E08C3FAB1C89BF\"\n", | 
|  | "    }\n", | 
|  | "  }\n", | 
|  | "#-}\n", | 
|  | "EXEC @main\n", | 
|  | "result[0]: hal.buffer_view\n", | 
|  | "3xf32=1.41785 -1.23433 -7.47679\n" | 
|  | ] | 
|  | } | 
|  | ] | 
|  | } | 
|  | ] | 
|  | } |