| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "provenance": [], |
| "collapsed_sections": [ |
| "FH3IRpYTta2v" |
| ] |
| }, |
| "kernelspec": { |
| "display_name": "Python 3", |
| "name": "python3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "FH3IRpYTta2v" |
| }, |
| "source": [ |
| "##### Copyright 2023 The IREE Authors" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "mWGa71_Ct2ug", |
| "cellView": "form" |
| }, |
| "source": [ |
| "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n", |
| "# See https://llvm.org/LICENSE.txt for license information.\n", |
| "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception" |
| ], |
| "execution_count": 1, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "h5s6ncerSpc5" |
| }, |
| "source": [ |
| "# Dynamic Shapes\n", |
| "\n", |
| "This notebook\n", |
| "\n", |
| "1. Creates a PyTorch program with dynamic shapes using the advanced AOT toolkit from [iree-turbine](https://github.com/iree-org/iree-turbine) (formerly SHARK-Turbine; this notebook installs it as the `shark-turbine` package)\n", |
| "2. Compiles the program to an IREE VM bytecode module\n", |
| "3. Tests running the compiled VM module using IREE's runtime\n", |
| "4. Downloads compilation artifacts for use with the native (C API) sample application" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "s2bScbYkP6VZ", |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "outputId": "7b268798-a20d-4df4-f00d-ed7811f77767" |
| }, |
| "source": [ |
| "#@title General setup\n", |
| "\n", |
| "import os\n", |
| "import tempfile\n", |
| "\n", |
| "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n", |
| "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n", |
| "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")" |
| ], |
| "execution_count": 1, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Using artifacts directory '/tmp/iree/colab_artifacts'\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "%%capture\n", |
| "#@title Uninstall existing packages\n", |
| "# This avoids some warnings when installing specific PyTorch packages below.\n", |
| "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision" |
| ], |
| "metadata": { |
| "id": "y9KOsqosg6Ms" |
| }, |
| "execution_count": 2, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Install SHARK-Turbine\n", |
| "\n", |
| "# Limit cell height.\n", |
| "from IPython.display import Javascript\n", |
| "display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))\n", |
| "\n", |
| "!python -m pip install shark-turbine" |
| ], |
| "metadata": { |
| "id": "SdCAvI3sqBO7", |
| "outputId": "2be248d9-bf6b-475e-c44a-aa529f20de23", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 300 |
| } |
| }, |
| "execution_count": 3, |
| "outputs": [ |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<IPython.core.display.Javascript object>" |
| ], |
| "application/javascript": [ |
| "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" |
| ] |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Collecting shark-turbine\n", |
| " Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)\n", |
| "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/60.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.2/60.2 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", |
| " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", |
| " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", |
| "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (1.23.5)\n", |
| "Collecting iree-compiler>=20231004.665 (from shark-turbine)\n", |
| " Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.2/57.2 MB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hCollecting iree-runtime>=20231004.665 (from shark-turbine)\n", |
| " Downloading iree_runtime-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (7.8 MB)\n", |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m91.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
| "\u001b[?25hRequirement already satisfied: torch>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from shark-turbine) (2.1.0+cu118)\n", |
| "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from iree-compiler>=20231004.665->shark-turbine) (6.0.1)\n", |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.12.4)\n", |
| "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (4.5.0)\n", |
| "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (1.12)\n", |
| "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1)\n", |
| "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (3.1.2)\n", |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2023.6.0)\n", |
| "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.1.0->shark-turbine) (2.1.0)\n", |
| "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.1.0->shark-turbine) (2.1.3)\n", |
| "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.1.0->shark-turbine) (1.3.0)\n", |
| "Building wheels for collected packages: shark-turbine\n", |
| " Building wheel for shark-turbine (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", |
| " Created wheel for shark-turbine: filename=shark_turbine-0.9.1.dev3-py3-none-any.whl size=70102 sha256=507dec827b9a2eea18f47c6ebdc84347c9956b8f2e0b186d3107a006e0742d81\n", |
| " Stored in directory: /root/.cache/pip/wheels/e9/78/0f/88c9d8224ef1550fe00b18a014eab5121f26264e2261f31926\n", |
| "Successfully built shark-turbine\n", |
| "Installing collected packages: iree-runtime, iree-compiler, shark-turbine\n", |
| "Successfully installed iree-compiler-20231004.665 iree-runtime-20231004.665 shark-turbine-0.9.1.dev3\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Report version information\n", |
| "!echo \"Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)\"\n", |
| "\n", |
| "!echo -e \"\\nInstalled IREE, compiler version information:\"\n", |
| "!iree-compile --version\n", |
| "\n", |
| "import torch\n", |
| "print(\"\\nInstalled PyTorch, version:\", torch.__version__)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "Oj5I6R9LI7t_", |
| "outputId": "35d79e6a-7bd0-46e1-8113-5af1a7bcbb5b" |
| }, |
| "execution_count": 4, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Installed SHARK-Turbine, Version: 0.9.1.dev3\n", |
| "\n", |
| "Installed IREE, compiler version information:\n", |
| "IREE (https://iree.dev):\n", |
| " IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447\n", |
| " LLVM version 18.0.0git\n", |
| " Optimized build\n", |
| "\n", |
| "Installed PyTorch, version: 2.1.0+cu118\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Create a program using PyTorch + SHARK-Turbine\n", |
| "\n", |
| "NOTE: as in other domains, providing more information to a compiler allows it\n", |
| "to generate more efficient code. As a general rule, the slowest varying\n", |
| "dimensions of program data like batch index or timestep are safer to treat as\n", |
| "dynamic than faster varying dimensions like image x/y/channel. See\n", |
| "[this paper](https://arxiv.org/pdf/2006.03031.pdf) for a discussion of the\n", |
| "challenges imposed by dynamic shapes and one project's approach to addressing\n", |
| "them." |
| ], |
| "metadata": { |
| "id": "C3mhaullI940" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "#@title Define a sample `shark_turbine.aot.CompiledModule` using dynamic shapes\n", |
| "\n", |
| "import shark_turbine.aot as aot\n", |
| "\n", |
| "class DynamicShapesModule(aot.CompiledModule, export_name=\"module\"):\n", |
| "  # reduce_sum_1d (dynamic input size, static output size)\n", |
| "  # tensor<?xi32> -> tensor<i32>\n", |
| "  # e.g. [1, 2, 3] -> 6\n", |
| "  def reduce_sum_1d(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n", |
| "    return self.compute_reduce_sum_1d(values)\n", |
| "\n", |
| "  @aot.jittable\n", |
| "  def compute_reduce_sum_1d(values):\n", |
| "    return torch.sum(values, dtype=torch.int32)\n", |
| "\n", |
| "  # reduce_sum_2d (partially dynamic input size, static output size)\n", |
| "  # tensor<?x3xi32> -> tensor<3xi32>\n", |
| "  # e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n", |
| "  def reduce_sum_2d(self, values=aot.AbstractTensor(None, 3, dtype=torch.int32)):\n", |
| "    return self.compute_reduce_sum_2d(values)\n", |
| "\n", |
| "  @aot.jittable\n", |
| "  def compute_reduce_sum_2d(values):\n", |
| "    return torch.sum(values, 0, dtype=torch.int32)\n", |
| "\n", |
| "  # add_one (dynamic input size, dynamic output size)\n", |
| "  # tensor<?xi32> -> tensor<?xi32>\n", |
| "  # e.g. [1, 2, 3] -> [2, 3, 4]\n", |
| "  def add_one(self, values=aot.AbstractTensor(None, dtype=torch.int32)):\n", |
| "    return self.compute_add_one(values)\n", |
| "\n", |
| "  @aot.jittable\n", |
| "  def compute_add_one(values):\n", |
| "    return values + 1" |
| ], |
| "metadata": { |
| "id": "vsf9F4WxI_DX" |
| }, |
| "execution_count": 5, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "from iree.compiler.ir import Context\n", |
| "\n", |
| "# Import into MLIR and save to disk.\n", |
| "dynamic_shapes_instance = DynamicShapesModule(context=Context())\n", |
| "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n", |
| "aot.CompiledModule.save_mlir(dynamic_shapes_instance, imported_mlir_path)\n", |
| "print(f\"Wrote MLIR to path '{imported_mlir_path}'\")\n", |
| "\n", |
| "# Inspect the IR.\n", |
| "# Note the question marks for dynamic shapes in types, like `tensor<?xi32>`.\n", |
| "print(\"\\nDynamic Shapes MLIR:\")\n", |
| "!cat {imported_mlir_path}" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "_OQIpOtNr4Gh", |
| "outputId": "888c0bf3-bec6-403c-9993-ad45d21364fb" |
| }, |
| "execution_count": 6, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlir'\n", |
| "\n", |
| "Dynamic Shapes MLIR:\n", |
| "#map = affine_map<(d0) -> (d0)>\n", |
| "#map1 = affine_map<(d0) -> ()>\n", |
| "#map2 = affine_map<(d0, d1) -> (d0, d1)>\n", |
| "#map3 = affine_map<(d0, d1) -> (d1)>\n", |
| "module @module {\n", |
| " func.func @reduce_sum_1d(%arg0: tensor<?xi32>) -> tensor<i32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", |
| " %0 = call @compute_reduce_sum_1d(%arg0) : (tensor<?xi32>) -> tensor<i32>\n", |
| " return %0 : tensor<i32>\n", |
| " }\n", |
| " func.func private @compute_reduce_sum_1d(%arg0: tensor<?xi32>) -> tensor<i32> {\n", |
| " %c0_i32 = arith.constant 0 : i32\n", |
| " %0 = tensor.empty() : tensor<i32>\n", |
| " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<i32>) -> tensor<i32>\n", |
| " %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = [\"reduction\"]} ins(%arg0 : tensor<?xi32>) outs(%1 : tensor<i32>) {\n", |
| " ^bb0(%in: i32, %out: i32):\n", |
| " %3 = arith.addi %in, %out : i32\n", |
| " linalg.yield %3 : i32\n", |
| " } -> tensor<i32>\n", |
| " return %2 : tensor<i32>\n", |
| " }\n", |
| " func.func @reduce_sum_2d(%arg0: tensor<?x3xi32>) -> tensor<3xi32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", |
| " %0 = call @compute_reduce_sum_2d(%arg0) : (tensor<?x3xi32>) -> tensor<3xi32>\n", |
| " return %0 : tensor<3xi32>\n", |
| " }\n", |
| " func.func private @compute_reduce_sum_2d(%arg0: tensor<?x3xi32>) -> tensor<3xi32> {\n", |
| " %c0_i32 = arith.constant 0 : i32\n", |
| " %0 = tensor.empty() : tensor<3xi32>\n", |
| " %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<3xi32>) -> tensor<3xi32>\n", |
| " %2 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = [\"reduction\", \"parallel\"]} ins(%arg0 : tensor<?x3xi32>) outs(%1 : tensor<3xi32>) {\n", |
| " ^bb0(%in: i32, %out: i32):\n", |
| " %3 = arith.addi %in, %out : i32\n", |
| " linalg.yield %3 : i32\n", |
| " } -> tensor<3xi32>\n", |
| " return %2 : tensor<3xi32>\n", |
| " }\n", |
| " func.func @add_one(%arg0: tensor<?xi32>) -> tensor<?xi32> attributes {torch.args_schema = \"[1, {\\22type\\22: \\22builtins.tuple\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: \\22builtins.list\\22, \\22context\\22: \\22null\\22, \\22children_spec\\22: [{\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]}, {\\22type\\22: \\22builtins.dict\\22, \\22context\\22: \\22[]\\22, \\22children_spec\\22: []}]}]\", torch.return_schema = \"[1, {\\22type\\22: null, \\22context\\22: null, \\22children_spec\\22: []}]\"} {\n", |
| " %0 = call @compute_add_one(%arg0) : (tensor<?xi32>) -> tensor<?xi32>\n", |
| " return %0 : tensor<?xi32>\n", |
| " }\n", |
| " func.func private @compute_add_one(%arg0: tensor<?xi32>) -> tensor<?xi32> {\n", |
| " %c0 = arith.constant 0 : index\n", |
| " %c1_i32 = arith.constant 1 : i32\n", |
| " %dim = tensor.dim %arg0, %c0 : tensor<?xi32>\n", |
| " %0 = tensor.empty(%dim) : tensor<?xi32>\n", |
| " %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = [\"parallel\"]} ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xi32>) {\n", |
| " ^bb0(%in: i32, %out: i32):\n", |
| " %2 = arith.addi %in, %c1_i32 : i32\n", |
| " linalg.yield %2 : i32\n", |
| " } -> tensor<?xi32>\n", |
| " return %1 : tensor<?xi32>\n", |
| " }\n", |
| "}\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Test the imported program\n", |
| "\n", |
| "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n", |
| "\n", |
| "_See the [README](https://github.com/iree-org/iree/tree/main/samples/dynamic_shapes#instructions) for more details and example command line instructions._\n", |
| "\n", |
| "* _The \"imported MLIR\" (above) can be used by IREE's generic compiler tools_\n", |
| "* _The \"binary\" can be saved and used by runtime applications_\n", |
| "\n", |
| "_The specific point at which you switch from Python to native tools will depend on your project._" |
| ], |
| "metadata": { |
| "id": "z6w_Pbl6tUtJ" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "# Export and compile.\n", |
| "exported_output = aot.export(DynamicShapesModule)\n", |
| "\n", |
| "# Compile to a file on disk for usage outside of Python.\n", |
| "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes_cpu.vmfb\")\n", |
| "exported_output.compile(save_to=flatbuffer_path)\n", |
| "print(f\"Wrote compiled program to path '{flatbuffer_path}'\")\n", |
| "\n", |
| "# Compile into memory for testing.\n", |
| "binary = exported_output.compile(save_to=None)" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "0PGyH1tvI_Ic", |
| "outputId": "23b53928-4d77-461f-e4b8-b2c8ffb25ef0" |
| }, |
| "execution_count": 7, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Wrote compiled program to path '/tmp/iree/colab_artifacts/dynamic_shapes_cpu.vmfb'\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "import iree.runtime as ireert\n", |
| "import numpy as np\n", |
| "\n", |
| "# Use the IREE runtime API to test the compiled program.\n", |
| "config = ireert.Config(\"local-task\")\n", |
| "vm_module = ireert.load_vm_module(\n", |
| " ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),\n", |
| " config,\n", |
| ")\n", |
| "\n", |
| "print(vm_module.reduce_sum_1d(np.array([1, 10, 100], dtype=np.int32)).to_host())\n", |
| "print(vm_module.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30]], dtype=np.int32)).to_host())\n", |
| "print(vm_module.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30], [100, 200, 300]], dtype=np.int32)).to_host())\n", |
| "print(vm_module.add_one(np.array([1, 10, 100], dtype=np.int32)).to_host())" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/" |
| }, |
| "id": "9ilJY15BI_LD", |
| "outputId": "57db6e52-83f1-4283-fc08-31e743cc9b42" |
| }, |
| "execution_count": 8, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "111\n", |
| "[11 22 33]\n", |
| "[111 222 333]\n", |
| "[ 2 11 101]\n" |
| ] |
| } |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "source": [ |
| "## Download compilation artifacts" |
| ], |
| "metadata": { |
| "id": "3mizlpY9uJEW" |
| } |
| }, |
| { |
| "cell_type": "code", |
| "source": [ |
| "ARTIFACTS_ZIP = \"/tmp/dynamic_shapes_colab_artifacts.zip\"\n", |
| "\n", |
| "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n", |
| "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n", |
| "\n", |
| "# Note: you can also download files using Colab's file explorer\n", |
| "try:\n", |
| " from google.colab import files\n", |
| " print(\"Downloading the artifacts zip file...\")\n", |
| " files.download(ARTIFACTS_ZIP)\n", |
| "except ImportError:\n", |
| " print(\"Missing google_colab Python package, can't download files\")" |
| ], |
| "metadata": { |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 86 |
| }, |
| "id": "dgaXpdiWuGtx", |
| "outputId": "dc0fbca1-c5b0-44f9-e1ff-9bf1307c049f" |
| }, |
| "execution_count": 9, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "name": "stdout", |
| "text": [ |
| "Zipping '/tmp/iree/colab_artifacts' to '/tmp/dynamic_shapes_colab_artifacts.zip' for download...\n", |
| " adding: dynamic_shapes_cpu.vmfb (deflated 66%)\n", |
| " adding: dynamic_shapes.mlir (deflated 82%)\n", |
| "Downloading the artifacts zip file...\n" |
| ] |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<IPython.core.display.Javascript object>" |
| ], |
| "application/javascript": [ |
| "\n", |
| " async function download(id, filename, size) {\n", |
| " if (!google.colab.kernel.accessAllowed) {\n", |
| " return;\n", |
| " }\n", |
| " const div = document.createElement('div');\n", |
| " const label = document.createElement('label');\n", |
| " label.textContent = `Downloading \"${filename}\": `;\n", |
| " div.appendChild(label);\n", |
| " const progress = document.createElement('progress');\n", |
| " progress.max = size;\n", |
| " div.appendChild(progress);\n", |
| " document.body.appendChild(div);\n", |
| "\n", |
| " const buffers = [];\n", |
| " let downloaded = 0;\n", |
| "\n", |
| " const channel = await google.colab.kernel.comms.open(id);\n", |
| " // Send a message to notify the kernel that we're ready.\n", |
| " channel.send({})\n", |
| "\n", |
| " for await (const message of channel.messages) {\n", |
| " // Send a message to notify the kernel that we're ready.\n", |
| " channel.send({})\n", |
| " if (message.buffers) {\n", |
| " for (const buffer of message.buffers) {\n", |
| " buffers.push(buffer);\n", |
| " downloaded += buffer.byteLength;\n", |
| " progress.value = downloaded;\n", |
| " }\n", |
| " }\n", |
| " }\n", |
| " const blob = new Blob(buffers, {type: 'application/binary'});\n", |
| " const a = document.createElement('a');\n", |
| " a.href = window.URL.createObjectURL(blob);\n", |
| " a.download = filename;\n", |
| " div.appendChild(a);\n", |
| " a.click();\n", |
| " div.remove();\n", |
| " }\n", |
| " " |
| ] |
| }, |
| "metadata": {} |
| }, |
| { |
| "output_type": "display_data", |
| "data": { |
| "text/plain": [ |
| "<IPython.core.display.Javascript object>" |
| ], |
| "application/javascript": [ |
| "download(\"download_e2630f9b-e811-4164-b2d8-80cf52f17145\", \"dynamic_shapes_colab_artifacts.zip\", 5699)" |
| ] |
| }, |
| "metadata": {} |
| } |
| ] |
| } |
| ] |
| } |