| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "low_level_invoke_function.ipynb", |
| "provenance": [], |
| "collapsed_sections": [] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "uMVh8_lZDRa7", |
| "colab_type": "text" |
| }, |
| "source": [ |
| "See the IREE docs/using_colab.md document for instructions.\n", |
| "\n", |
| "This notebook shows off some concepts of the low-level IREE Python bindings." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "7qZfpb7Ob6id", |
| "colab_type": "code", |
| "colab": {} |
| }, |
| "source": [ |
| "import numpy as np\n", |
| "from pyiree import binding\n" |
| ], |
| "execution_count": 0, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "FsHplZ3kEWrl", |
| "colab_type": "code", |
| "colab": {} |
| }, |
| "source": [ |
| "def dump_module(m):\n", |
| "  \"\"\"Prints the module's name followed by one line per exported function.\n", |
| "\n", |
| "  Exports are looked up by ordinal, starting at 0; the walk stops at the\n", |
| "  first ordinal for which the lookup returns a falsy value. Each export line\n", |
| "  shows the function's name and its signature's argument and result counts.\n", |
| "\n", |
| "  Args:\n", |
| "    m: object exposing `name` and `lookup_function_by_ordinal(ordinal)`.\n", |
| "  \"\"\"\n", |
| "  print(\"Loaded module:\", m.name)\n", |
| "  ordinal = 0\n", |
| "  export = m.lookup_function_by_ordinal(ordinal)\n", |
| "  while export:\n", |
| "    print(\" Export:\", export.name,\n", |
| "          \"-> args(\", export.signature.argument_count,\n", |
| "          \"), results(\", export.signature.result_count, \")\")\n", |
| "    # Advance to the next ordinal before re-testing the loop condition.\n", |
| "    ordinal += 1\n", |
| "    export = m.lookup_function_by_ordinal(ordinal)" |
| ], |
| "execution_count": 0, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "rxaiDxiq96SD", |
| "colab_type": "code", |
| "outputId": "a1304fa3-b15a-4fab-eaaf-1467dc867191", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 51 |
| } |
| }, |
| "source": [ |
| "# Parse the MLIR assembly below, compile it to a sequencer blob and load the\n", |
| "# blob as a VM module.\n", |
| "SIMPLE_MUL_ASM = \"\"\"\n", |
| " func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n", |
| " attributes { iree.module.export } {\n", |
| " %0 = \"xla_hlo.mul\"(%arg0, %arg1) {name = \"mul.1\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " return %0 : tensor<4xf32>\n", |
| " }\n", |
| " \"\"\"\n", |
| "\n", |
| "# Every intermediate stays bound to a name so it remains alive for the rest\n", |
| "# of the session.\n", |
| "compiler_context = binding.compiler.CompilerContext()\n", |
| "parsed_module = compiler_context.parse_asm(SIMPLE_MUL_ASM)\n", |
| "sequencer_blob = parsed_module.compile_to_sequencer_blob()\n", |
| "\n", |
| "# `m` is the name the runtime cell registers with its context.\n", |
| "m = binding.vm.create_module_from_blob(sequencer_blob)\n", |
| "\n", |
| "# Print the loaded module's name and exports.\n", |
| "dump_module(m)" |
| ], |
| "execution_count": 7, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "Loaded module: module\n", |
| " Export: simple_mul -> args( 2 ), results( 1 )\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "aH6VdaoXD4hV", |
| "colab_type": "code", |
| "outputId": "d109d4e7-83bf-4038-c0d1-c643dbd10c8e", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 102 |
| } |
| }, |
| "source": [ |
| "# Set up the runtime (policy, instance, context) and register the module.\n", |
| "# The CPU interpreter driver has the most implementation done:\n", |
| "driver_name = \"interpreter\"\n", |
| "\n", |
| "# Live on the edge and give the vulkan driver a try:\n", |
| "# driver_name = \"vulkan\"\n", |
| "\n", |
| "policy = binding.rt.Policy()\n", |
| "instance = binding.rt.Instance(driver_name=driver_name)\n", |
| "context = binding.rt.Context(instance=instance, policy=policy)\n", |
| "context.register_module(m)\n", |
| "\n", |
| "# Look the export up by its module-qualified name.\n", |
| "simple_mul = context.resolve_function(\"module.simple_mul\")\n", |
| "print(\"INVOKE F:\", simple_mul.name)\n", |
| "\n", |
| "# Two float32 vectors of length 4, matching the tensor<4xf32> arguments.\n", |
| "host_inputs = (\n", |
| "    np.array([1., 2., 3., 4.], dtype=np.float32),\n", |
| "    np.array([4., 5., 6., 7.], dtype=np.float32),\n", |
| ")\n", |
| "wrapped_inputs = [context.wrap_for_input(host_ary) for host_ary in host_inputs]\n", |
| "\n", |
| "# Start the invocation, print its status, then wait on it via await_ready().\n", |
| "invocation = context.invoke(simple_mul, policy, wrapped_inputs)\n", |
| "print(\"Status:\", invocation.query_status())\n", |
| "invocation.await_ready()\n", |
| "\n", |
| "# Map the first result and hand the mapping to numpy (copy=False asks numpy\n", |
| "# not to copy), printing each intermediate object along the way.\n", |
| "results = invocation.results\n", |
| "print(\"Results:\", results)\n", |
| "mapped_result = results[0].map()\n", |
| "print(\"Mapped result:\", mapped_result)\n", |
| "result_ary = np.array(mapped_result, copy=False)\n", |
| "print(\"NP result:\", result_ary)\n" |
| ], |
| "execution_count": 8, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "INVOKE F: simple_mul\n", |
| "Status: True\n", |
| "Results: [<pyiree.binding.hal.BufferView object at 0x00000179E51410D8>]\n", |
| "Mapped result: <pyiree.binding.hal.MappedMemory object at 0x00000179E51412D0>\n", |
| "NP result: [ 4. 10. 18. 28.]\n" |
| ], |
| "name": "stdout" |
| } |
| ] |
| } |
| ] |
| } |