blob: b06c22f211405c0a66a292e50bf716a38356cfc9 [file] [log] [blame]
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "low_level_invoke_function.ipynb",
"provenance": [],
"collapsed_sections": [
"FH3IRpYTta2v"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "FH3IRpYTta2v"
},
"source": [
"##### Copyright 2019 The IREE Authors"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mWGa71_Ct2ug",
"cellView": "form"
},
"source": [
"#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
"# See https://llvm.org/LICENSE.txt for license information.\n",
"# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZEmrd07EvthK"
},
"source": [
"# Low Level Invoke Function"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uMVh8_lZDRa7"
},
"source": [
"This notebook shows off some concepts of the low level IREE python bindings."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Go2Nw7BgIHYU",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bfd1cc6a-bc75-4d16-850d-64776dc6d64f"
},
"source": [
"!python -m pip install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Looking in links: https://github.com/google/iree/releases\n",
"Collecting iree-compiler-snapshot\n",
"\u001b[?25l Downloading https://github.com/google/iree/releases/download/snapshot-20210608.328/iree_compiler_snapshot-20210608.328-py3-none-manylinux2014_x86_64.whl (33.6MB)\n",
"\u001b[K |████████████████████████████████| 33.6MB 118kB/s \n",
"\u001b[?25hCollecting iree-runtime-snapshot\n",
"\u001b[?25l Downloading https://github.com/google/iree/releases/download/snapshot-20210608.328/iree_runtime_snapshot-20210608.328-cp37-cp37m-manylinux2014_x86_64.whl (521kB)\n",
"\u001b[K |████████████████████████████████| 522kB 27.1MB/s \n",
"\u001b[?25hInstalling collected packages: iree-compiler-snapshot, iree-runtime-snapshot\n",
"Successfully installed iree-compiler-snapshot-20210608.328 iree-runtime-snapshot-20210608.328\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "1F144M4wAFPz"
},
"source": [
"import numpy as np\n",
"\n",
"from iree import runtime as ireert\n",
"from iree.compiler import compile_str"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2Rq-JdzMAFPU"
},
"source": [
"# Compile a module.\n",
"SIMPLE_MUL_ASM = \"\"\"\n",
" module @arithmetic {\n",
" func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n",
" {\n",
" %0 = mulf %arg0, %arg1 : tensor<4xf32>\n",
" return %0 : tensor<4xf32>\n",
" } \n",
" }\n",
"\"\"\"\n",
"\n",
"compiled_flatbuffer = compile_str(SIMPLE_MUL_ASM, target_backends=[\"vmvx\"])\n",
"vm_module = ireert.VmModule.from_flatbuffer(compiled_flatbuffer)"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TNQiNeOU_cpK",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "37e2df44-90fb-425e-f6e5-059efdfe099e"
},
"source": [
"# Register the module with a runtime context.\n",
"# Use the CPU interpreter (which has the most implementation done):\n",
"config = ireert.Config(\"vmvx\")\n",
"ctx = ireert.SystemContext(config=config)\n",
"ctx.add_vm_module(vm_module)\n",
"\n",
"# Invoke the function and print the result.\n",
"print(\"INVOKE simple_mul\")\n",
"arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)\n",
"arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)\n",
"f = ctx.modules.arithmetic[\"simple_mul\"]\n",
"results = f(arg0, arg1)\n",
"print(\"Results:\", results)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"INVOKE simple_mul\n",
"Results: [ 4. 10. 18. 28.]\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7f9b9d6e6ef0>\n",
"SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7f9b9d6e6ef0>\n"
],
"name": "stderr"
}
]
}
]
}