Add support for returning empty sequences in the IREE-JAX frontend.
PiperOrigin-RevId: 370581258
diff --git a/bindings/python/iree/jax/frontend.py b/bindings/python/iree/jax/frontend.py
index 7c57563..0c26e11 100644
--- a/bindings/python/iree/jax/frontend.py
+++ b/bindings/python/iree/jax/frontend.py
@@ -124,9 +124,18 @@
_, module, out_tree = self._get_compiled_artifacts(args, kwargs)
results = module.main(*args_flat)
- if not isinstance(results, tuple):
- results = (results,)
- return jax.tree_unflatten(out_tree, results)
+ if results is not None:
+ if not isinstance(results, tuple):
+ results = (results,)
+ return jax.tree_unflatten(out_tree, results)
+ else:
+ # Address IREE returning None instead of empty sequences.
+ if out_tree == jax.tree_flatten([])[-1]:
+ return []
+ elif out_tree == jax.tree_flatten(())[-1]:
+ return ()
+ else:
+ return results
def get_binary(self, *args, **kwargs):
"""Gets the IREE-compiled binary for the given inputs."""