| { |
| "nbformat": 4, |
| "nbformat_minor": 0, |
| "metadata": { |
| "colab": { |
| "name": "low_level_invoke_function.ipynb", |
| "provenance": [], |
| "collapsed_sections": [] |
| }, |
| "kernelspec": { |
| "name": "python3", |
| "display_name": "Python 3" |
| } |
| }, |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "metadata": { |
| "id": "uMVh8_lZDRa7", |
| "colab_type": "text" |
| }, |
| "source": [ |
| "See the IREE `docs/using_colab.md` document for instructions.\n", |
| "\n", |
| "This notebook shows off some concepts of the low-level IREE Python bindings." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab_type": "code", |
| "id": "1F144M4wAFPz", |
| "colab": {} |
| }, |
| "source": [ |
| "import numpy as np\n", |
| "from pyiree.tf import compiler as ireec\n", |
| "from pyiree import rt as ireert" |
| ], |
| "execution_count": 0, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "colab_type": "code", |
| "id": "2Rq-JdzMAFPU", |
| "colab": {} |
| }, |
| "source": [ |
| "# Compile a module and load the resulting flatbuffer as a VM module.\n", |
| "ctx = ireec.Context()\n", |
| "input_module = ctx.parse_asm(\"\"\"\n", |
| " module @arithmetic {\n", |
| " func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>\n", |
| " attributes { iree.module.export } {\n", |
| " %0 = \"xla_hlo.multiply\"(%arg0, %arg1) {name = \"mul.1\"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>\n", |
| " return %0 : tensor<4xf32>\n", |
| " } \n", |
| " }\n", |
| " \"\"\")\n", |
| "blob = input_module.compile()\n", |
| "m = ireert.VmModule.from_flatbuffer(blob)" |
| ], |
| "execution_count": 0, |
| "outputs": [] |
| }, |
| { |
| "cell_type": "code", |
| "metadata": { |
| "id": "TNQiNeOU_cpK", |
| "colab_type": "code", |
| "colab": { |
| "base_uri": "https://localhost:8080/", |
| "height": 85 |
| }, |
| "outputId": "a3b16ec2-04f6-463a-a03a-8dceec277dc9" |
| }, |
| "source": [ |
| "# Register the module with a runtime context.\n", |
| "# Use the CPU interpreter driver (which has the most complete implementation):\n", |
| "driver_name = \"interpreter\"\n", |
| "# Or live on the edge and try the Vulkan driver by uncommenting the next line:\n", |
| "# driver_name = \"vulkan\"\n", |
| "config = ireert.Config(driver_name)\n", |
| "ctx = ireert.SystemContext(config=config)\n", |
| "ctx.add_module(m)\n", |
| "\n", |
| "# Invoke the function and print the result.\n", |
| "print(\"INVOKE simple_mul\")\n", |
| "arg0 = np.array([1., 2., 3., 4.], dtype=np.float32)\n", |
| "arg1 = np.array([4., 5., 6., 7.], dtype=np.float32)\n", |
| "f = ctx.modules.arithmetic[\"simple_mul\"]\n", |
| "results = f(arg0, arg1)\n", |
| "print(\"Results:\", results)" |
| ], |
| "execution_count": 36, |
| "outputs": [ |
| { |
| "output_type": "stream", |
| "text": [ |
| "INVOKE simple_mul\n", |
| "Results: [ 4. 10. 18. 28.]\n" |
| ], |
| "name": "stdout" |
| }, |
| { |
| "output_type": "stream", |
| "text": [ |
| "Created IREE driver interpreter: <pyiree.rt.binding.HalDriver object at 0x7fae493257f0>\n", |
| "SystemContext driver=<pyiree.rt.binding.HalDriver object at 0x7fae493257f0>\n" |
| ], |
| "name": "stderr" |
| } |
| ] |
| } |
| ] |
| } |