blob: bcc7bf5e41e79010ce96e23a9ae45a87ea278bd1 [file] [log] [blame]
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "low_level_invoke_function.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uMVh8_lZDRa7",
"colab_type": "text"
},
"source": [
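"See the IREE `docs/using_colab.md` document for instructions on running this notebook.\n",
"\n",
"This notebook demonstrates some concepts of the low-level IREE Python bindings: compiling an MLIR module, loading it into the VM, and invoking one of its exported functions through a runtime context."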
]
},
{
"cell_type": "code",
"metadata": {
"id": "7qZfpb7Ob6id",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"from pyiree import binding\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FsHplZ3kEWrl",
"colab_type": "code",
"colab": {}
},
"source": [
"def dump_module(m):\n",
"  \"\"\"Print the name of module `m` and the arity of each exported function.\n",
"\n",
"  Exports are enumerated by ordinal, starting at 0, until\n",
"  `lookup_function_by_ordinal` returns a falsy value.\n",
"  \"\"\"\n",
"  print(\"Loaded module:\", m.name)\n",
"  i = 0\n",
"  while True:\n",
"    f = m.lookup_function_by_ordinal(i)\n",
"    if not f:\n",
"      break\n",
"    print(\" Export:\", f.name, \"-> args(\", f.signature.argument_count,\n",
"          \"), results(\", f.signature.result_count, \")\")\n",
"    i += 1"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rxaiDxiq96SD",
"colab_type": "code",
"outputId": "a1304fa3-b15a-4fab-eaaf-1467dc867191",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"# Compile a module: parse MLIR assembly defining simple_mul (an element-wise\n",
"# multiply of two tensor<4xf32> operands), compile it to a sequencer blob, and\n",
"# load that blob as a VM module.\n",
"ctx = binding.compiler.CompilerContext()\n",
"input_module = ctx.parse_asm(\"\"\"\n",
" func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n",
" attributes { iree.module.export } {\n",
" %0 = \"xla_hlo.mul\"(%arg0, %arg1) {name = \"mul.1\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
" return %0 : tensor<4xf32>\n",
" }\n",
" \"\"\")\n",
"blob = input_module.compile_to_sequencer_blob()\n",
"m = binding.vm.create_module_from_blob(blob)\n",
"# List the module's exports; simple_mul is expected with 2 args and 1 result.\n",
"dump_module(m)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Loaded module: module\n",
" Export: simple_mul -> args( 2 ), results( 1 )\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aH6VdaoXD4hV",
"colab_type": "code",
"outputId": "d109d4e7-83bf-4038-c0d1-c643dbd10c8e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
}
},
"source": [
"# Initialize the runtime and register the module.\n",
"# Use the CPU interpreter driver (which has the most implementation done):\n",
"driver_name = \"interpreter\"\n",
"\n",
"# Live on the edge and give the vulkan driver a try:\n",
"# driver_name = \"vulkan\"\n",
"\n",
"# Build a runtime instance for the chosen driver and a context on top of it,\n",
"# then register the compiled module `m` with that context.\n",
"policy = binding.rt.Policy()\n",
"instance = binding.rt.Instance(driver_name=driver_name)\n",
"context = binding.rt.Context(instance=instance, policy=policy)\n",
"context.register_module(m)\n",
"# Resolve the export by qualified name: <module name>.<export name>. The loaded\n",
"# module is named `module`, as printed by dump_module above.\n",
"f = context.resolve_function(\"module.simple_mul\")\n",
"\n",
"print(\"INVOKE F:\", f.name)\n",
"# Wrap the float32 numpy inputs so they can be passed to context.invoke().\n",
"arg0 = context.wrap_for_input(np.array([1., 2., 3., 4.], dtype=np.float32))\n",
"arg1 = context.wrap_for_input(np.array([4., 5., 6., 7.], dtype=np.float32))\n",
"\n",
"# Invoke the function and wait for completion.\n",
"# NOTE(review): query_status() is printed before await_ready(), so it shows the\n",
"# status at that moment; confirm what a True status means in this API.\n",
"inv = context.invoke(f, policy, [arg0, arg1])\n",
"print(\"Status:\", inv.query_status())\n",
"inv.await_ready()\n",
"\n",
"# Get the result as a numpy array and print.\n",
"results = inv.results\n",
"print(\"Results:\", results)\n",
"result = results[0].map()\n",
"print(\"Mapped result:\", result)\n",
"# copy=False asks numpy to avoid copying the mapped memory where possible.\n",
"# NOTE(review): under NumPy >= 2.0, copy=False raises ValueError if a copy\n",
"# would be required; np.asarray(result) is the permissive equivalent.\n",
"result_ary = np.array(result, copy=False)\n",
"print(\"NP result:\", result_ary)\n"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"INVOKE F: simple_mul\n",
"Status: True\n",
"Results: [<pyiree.binding.hal.BufferView object at 0x00000179E51410D8>]\n",
"Mapped result: <pyiree.binding.hal.MappedMemory object at 0x00000179E51412D0>\n",
"NP result: [ 4. 10. 18. 28.]\n"
],
"name": "stdout"
}
]
}
]
}