Add support for returning empty sequences in the IREE-JAX frontend.

PiperOrigin-RevId: 370581258
diff --git a/bindings/python/iree/jax/frontend.py b/bindings/python/iree/jax/frontend.py
index 7c57563..0c26e11 100644
--- a/bindings/python/iree/jax/frontend.py
+++ b/bindings/python/iree/jax/frontend.py
@@ -124,9 +124,18 @@
 
     _, module, out_tree = self._get_compiled_artifacts(args, kwargs)
     results = module.main(*args_flat)
-    if not isinstance(results, tuple):
-      results = (results,)
-    return jax.tree_unflatten(out_tree, results)
+    if results is not None:
+      if not isinstance(results, tuple):
+        results = (results,)
+      return jax.tree_unflatten(out_tree, results)
+    else:
+      # Address IREE returning None instead of empty sequences.
+      if out_tree == jax.tree_flatten([])[-1]:
+        return []
+      elif out_tree == jax.tree_flatten(())[-1]:
+        return ()
+      else:
+        return results
 
   def get_binary(self, *args, **kwargs):
     """Gets the IREE-compiled binary for the given inputs."""