A local JAX installation is necessary in addition to IREE's Python requirements:
    python -m pip install jax jaxlib
A just-in-time compilation decorator similar to jax.jit is provided by iree.jax.jit:
    import pyiree as iree
    import pyiree.jax

    import jax
    import jax.numpy as jnp

    # 'backend' is one of 'vmla', 'llvmaot' and 'vulkan' and defaults to 'llvmaot'.
    @iree.jax.jit(backend="llvmaot")
    def linear_relu_layer(params, x):
        w, b = params
        # jnp.maximum is the elementwise maximum (ReLU);
        # jnp.max(..., 0) would instead reduce over axis 0.
        return jnp.maximum(jnp.matmul(x, w) + b, 0)

    w = jnp.zeros((784, 128))
    b = jnp.zeros(128)
    x = jnp.zeros((1, 784))

    linear_relu_layer([w, b], x)
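For checking the compiled layer on tiny inputs, it can help to have a dependency-free reference of what a linear ReLU layer computes, namely the elementwise maximum of x·w + b and 0. The sketch below is not part of IREE or JAX; the function name `linear_relu_reference` and the sample values are made up for illustration.

```python
def linear_relu_reference(w, b, x):
    """Compute max(x @ w + b, 0) elementwise for a batch of row vectors x.

    w is a list of rows (inputs x outputs), b a list of biases,
    x a list of input rows. Hypothetical helper for illustration only.
    """
    out = []
    for row in x:
        out_row = []
        for j in range(len(b)):
            # Dot product of the input row with column j of w, plus the bias.
            acc = sum(row[k] * w[k][j] for k in range(len(row))) + b[j]
            # ReLU: clamp negative activations to zero.
            out_row.append(max(acc, 0.0))
        out.append(out_row)
    return out


w = [[1.0, -1.0], [0.5, 2.0]]
b = [0.0, -4.0]
x = [[2.0, 1.0]]
result = linear_relu_reference(w, b, x)
print(result)
```

The second output is negative before the clamp, so it comes out as zero, while the first passes through unchanged.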
An ahead-of-time compilation function, iree.jax.aot, provides a lower-level API for compiling a function with a specific input signature without creating the runtime bindings for execution within Python. This is primarily useful for targeting other runtime environments like Android.
Install the Android NDK according to the Android Getting Started doc, and then ensure the following environment variable is set:
    export ANDROID_NDK=# NDK install location
The code below assumes that you have flax installed.
    import pyiree as iree
    import pyiree.jax

    import jax
    import jax.numpy as jnp
    import flax
    from flax import linen as nn

    import os

    # Configure the linker to target Android.
    os.environ["IREE_LLVMAOT_LINKER_PATH"] = "${ANDROID_NDK?}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android29-clang++ -static-libstdc++"

    class MLP(nn.Module):

        @nn.compact
        def __call__(self, x):
            x = x.reshape((x.shape[0], -1))  # Flatten.
            x = nn.Dense(128)(x)
            x = nn.relu(x)
            x = nn.Dense(10)(x)
            x = nn.log_softmax(x)
            return x

    image = jnp.zeros((1, 28, 28, 1))
    params = MLP().init(jax.random.PRNGKey(0), image)["params"]

    apply_args = [{"params": params}, image]
    options = dict(target_backends=["dylib-llvm-aot"],
                   extra_args=["--iree-llvm-target-triple=aarch64-linux-android"])

    # This will throw a compilation error if IREE_LLVMAOT_LINKER_PATH is not set.
    compiled_binary = iree.jax.aot(MLP().apply, *apply_args, **options)

    with open("/tmp/mlp_apply.vmfb", "wb") as f:
        f.write(compiled_binary)
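Because compilation fails when the linker is not configured, it may be convenient to build the linker command in Python and fail early with a clearer message. The following is a minimal sketch under stated assumptions: the helper name `android_linker_command` and its parameters are hypothetical, and the path layout simply mirrors the one used above (a linux-x86_64 host and Android API level 29).

```python
import os


def android_linker_command(ndk_root=None, host_tag="linux-x86_64", api_level=29):
    """Return the clang++ linker command for an aarch64 Android target.

    Hypothetical helper: falls back to the ANDROID_NDK environment
    variable when ndk_root is not given, and raises if neither is set.
    """
    ndk_root = ndk_root or os.environ.get("ANDROID_NDK")
    if not ndk_root:
        raise RuntimeError(
            "ANDROID_NDK is not set; point it at the NDK install location.")
    # Same toolchain layout as the path used in the example above.
    clang = (f"{ndk_root}/toolchains/llvm/prebuilt/{host_tag}/bin/"
             f"aarch64-linux-android{api_level}-clang++")
    return f"{clang} -static-libstdc++"


# Example with an explicit (made-up) NDK location:
command = android_linker_command("/opt/android-ndk")
print(command)
```

The returned string could then be assigned to os.environ["IREE_LLVMAOT_LINKER_PATH"] before compiling. Unlike the shell-style `${ANDROID_NDK?}` form, this resolves the NDK path at the time the helper is called.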
IREE doesn't provide installable tools for Android at this time, so they'll need to be built according to the Android Getting Started doc. Afterward, the compiled .vmfb can be pushed to an Android device and executed using iree-run-module:
    adb push /tmp/mlp_apply.vmfb /data/local/tmp/
    adb push ../iree-build-android/iree/tools/iree-run-module /data/local/tmp/
    adb shell /data/local/tmp/iree-run-module \
      -module_file=/data/local/tmp/mlp_apply.vmfb \
      -function_inputs="128xf32,784x128xf32,10xf32,128x10xf32,1x28x28x1xf32" \
      -driver=dylib \
      -entry_function=main
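The -function_inputs string lists the shape and element type of each array passed to the compiled function, in the order the arguments are flattened. The sketch below shows one way that string could be derived from the argument shapes. It assumes that JAX flattens dictionaries in sorted key order (so each layer's bias comes before its kernel) and that Flax names the two layers Dense_0 and Dense_1. The helpers `flatten_shapes` and `function_inputs` are hypothetical and use only the standard library.

```python
def flatten_shapes(tree):
    """Yield leaf shapes depth-first, visiting dict keys in sorted order.

    Dicts and lists are containers; tuples are treated as leaf shapes.
    This mimics (but does not call) JAX's pytree flattening order.
    """
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from flatten_shapes(tree[key])
    elif isinstance(tree, list):
        for item in tree:
            yield from flatten_shapes(item)
    else:
        yield tree


def function_inputs(tree, dtype="f32"):
    """Format every leaf shape as e.g. '784x128xf32', joined by commas."""
    return ",".join(
        "x".join([str(dim) for dim in shape] + [dtype])
        for shape in flatten_shapes(tree))


# Shapes of apply_args for the MLP above (layer names are an assumption).
apply_shapes = [
    {"params": {
        "Dense_0": {"kernel": (784, 128), "bias": (128,)},
        "Dense_1": {"kernel": (128, 10), "bias": (10,)},
    }},
    (1, 28, 28, 1),  # the image
]
inputs = function_inputs(apply_shapes)
print(inputs)
```

Under these assumptions the result matches the string passed to iree-run-module above. If the model architecture changes, the flag has to be updated to match the new parameter shapes.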