Handle compiling pytrees with methods in jax frontend (#4630)
Since pytrees are allowed to have methods, the shape and dtypes of
their children is insufficient to uniquely identify their intended behavior.
This change allows multiple `flax.optim` optimizers to be passed to the
same `update` function.
diff --git a/bindings/python/pyiree/jax/frontend.py b/bindings/python/pyiree/jax/frontend.py
index d49edfe..0429f36 100644
--- a/bindings/python/pyiree/jax/frontend.py
+++ b/bindings/python/pyiree/jax/frontend.py
@@ -83,9 +83,9 @@
self._options = options
self._memoized_signatures = {}
- def _get_signature(self, args_flat):
+ def _get_signature(self, args_flat, in_tree):
args_flat = [rt.normalize_value(arg) for arg in args_flat]
- return tuple((arg.shape, arg.dtype) for arg in args_flat)
+ return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)
def _wrap_and_compile(self, signature, args_flat, in_tree):
"""Compiles the function for the given signature."""
@@ -110,7 +110,7 @@
def _get_compiled_artifacts(self, args, kwargs):
"""Returns the binary, loaded rt module and out_tree."""
args_flat, in_tree = jax.tree_flatten((args, kwargs))
- signature = self._get_signature(args_flat)
+ signature = self._get_signature(args_flat, in_tree)
if signature not in self._memoized_signatures:
self._wrap_and_compile(signature, args_flat, in_tree)
diff --git a/bindings/python/pyiree/jax/frontend_test.py b/bindings/python/pyiree/jax/frontend_test.py
index dbc59de..153976d 100644
--- a/bindings/python/pyiree/jax/frontend_test.py
+++ b/bindings/python/pyiree/jax/frontend_test.py
@@ -31,6 +31,40 @@
return np.random.normal(0, 1, shape).astype(np.float32)
+class SqrtNode:
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def apply(self, z):
+ return self.x * jnp.sqrt(self.y * z)
+
+ def tree_flatten(self):
+ return ((self.x, self.y), None)
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ return cls(*children)
+
+
+class SquareNode:
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def apply(self, z):
+ return self.x * (self.y * z)**2
+
+ def tree_flatten(self):
+ return ((self.x, self.y), None)
+
+ @classmethod
+ def tree_unflatten(cls, aux_data, children):
+ return cls(*children)
+
+
class JAXFrontendTest(unittest.TestCase):
def test_aot_pytree(self):
@@ -161,6 +195,22 @@
self.assertEqual(add_sqrt_four(2), 4)
+ def test_jit_pytree_method(self):
+
+ @iree.jax.jit
+ def apply_node(node, z):
+ return node.apply(z)
+
+ expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
+ compied_sqrt = apply_node(SqrtNode(2, 3), 4)
+ np.testing.assert_allclose(compied_sqrt, expected_sqrt)
+
+ expected_square = apply_node._function(SquareNode(2, 3), 4)
+ compied_square = apply_node(SquareNode(2, 3), 4)
+ np.testing.assert_allclose(expected_square, expected_square)
+
if __name__ == "__main__":
+ jax.tree_util.register_pytree_node_class(SqrtNode)
+ jax.tree_util.register_pytree_node_class(SquareNode)
unittest.main()