blob: 94a60891a53cb4d6ee7091c5b10ef798536cc59f [file] [log] [blame]
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "low_level_invoke_function.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uMVh8_lZDRa7",
"colab_type": "text"
},
"source": [
"See the IREE docs/using_colab.md document for instructions.\n",
"\n",
"This notebook shows off some concepts of the low level IREE python bindings."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "1F144M4wAFPz",
"colab": {}
},
"source": [
"import numpy as np\n",
"from pyiree.tf import compiler as ireec\n",
"from pyiree import rt as ireert"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "2Rq-JdzMAFPU",
"colab": {}
},
"source": [
"# Compile a module.\n",
"ctx = ireec.Context()\n",
"input_module = ctx.parse_asm(\"\"\"\n",
" module @arithmetic {\n",
" func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n",
" attributes { iree.module.export } {\n",
" %0 = \"xla_hlo.multiply\"(%arg0, %arg1) {name = \"mul.1\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n",
" return %0 : tensor<4xf32>\n",
" } \n",
" }\n",
" \"\"\")\n",
"blob = input_module.compile()\n",
"m = ireert.VmModule.from_flatbuffer(blob)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TNQiNeOU_cpK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "a3b16ec2-04f6-463a-a03a-8dceec277dc9"
},
"source": [
"# Register the module with a runtime context.\n",
"# Use the CPU interpreter (which has the most implementation done):\n",
"driver_name = \"interpreter\"\n",
"# Live on the edge and give the vulkan driver a try:\n",
"# driver_name = \"vulkan\"\n",
"config = ireert.Config(driver_name)\n",
"ctx = ireert.SystemContext(config=config)\n",
"ctx.add_module(m)\n",
"\n",
"# Invoke the function and print the result.\n",
"print(\"INVOKE simple_mul\")\n",
"arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)\n",
"arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)\n",
"f = ctx.modules.arithmetic[\"simple_mul\"]\n",
"results = f(arg0, arg1)\n",
"print(\"Results:\", results)"
],
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"text": [
"INVOKE simple_mul\n",
"Results: [ 4. 10. 18. 28.]\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"Created IREE driver interpreter: <pyiree.rt.binding.HalDriver object at 0x7fae493257f0>\n",
"SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x7fae493257f0>\n"
],
"name": "stderr"
}
]
}
]
}