Handle compiling pytrees with methods in jax frontend (#4630)

Since pytrees are allowed to have methods, the shape and dtypes of
their children is insufficient to uniquely identify their intended behavior.
This change allows multiple `flax.optim` optimizers to be passed to the
same `update` function.
diff --git a/bindings/python/pyiree/jax/frontend.py b/bindings/python/pyiree/jax/frontend.py
index d49edfe..0429f36 100644
--- a/bindings/python/pyiree/jax/frontend.py
+++ b/bindings/python/pyiree/jax/frontend.py
@@ -83,9 +83,9 @@
     self._options = options
     self._memoized_signatures = {}
 
-  def _get_signature(self, args_flat):
+  def _get_signature(self, args_flat, in_tree):
     args_flat = [rt.normalize_value(arg) for arg in args_flat]
-    return tuple((arg.shape, arg.dtype) for arg in args_flat)
+    return tuple((arg.shape, arg.dtype) for arg in args_flat) + (in_tree,)
 
   def _wrap_and_compile(self, signature, args_flat, in_tree):
     """Compiles the function for the given signature."""
@@ -110,7 +110,7 @@
   def _get_compiled_artifacts(self, args, kwargs):
     """Returns the binary, loaded rt module and out_tree."""
     args_flat, in_tree = jax.tree_flatten((args, kwargs))
-    signature = self._get_signature(args_flat)
+    signature = self._get_signature(args_flat, in_tree)
 
     if signature not in self._memoized_signatures:
       self._wrap_and_compile(signature, args_flat, in_tree)
diff --git a/bindings/python/pyiree/jax/frontend_test.py b/bindings/python/pyiree/jax/frontend_test.py
index dbc59de..153976d 100644
--- a/bindings/python/pyiree/jax/frontend_test.py
+++ b/bindings/python/pyiree/jax/frontend_test.py
@@ -31,6 +31,40 @@
   return np.random.normal(0, 1, shape).astype(np.float32)
 
 
+class SqrtNode:
+
+  def __init__(self, x, y):
+    self.x = x
+    self.y = y
+
+  def apply(self, z):
+    return self.x * jnp.sqrt(self.y * z)
+
+  def tree_flatten(self):
+    return ((self.x, self.y), None)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    return cls(*children)
+
+
+class SquareNode:
+
+  def __init__(self, x, y):
+    self.x = x
+    self.y = y
+
+  def apply(self, z):
+    return self.x * (self.y * z)**2
+
+  def tree_flatten(self):
+    return ((self.x, self.y), None)
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    return cls(*children)
+
+
 class JAXFrontendTest(unittest.TestCase):
 
   def test_aot_pytree(self):
@@ -161,6 +195,22 @@
 
     self.assertEqual(add_sqrt_four(2), 4)
 
+  def test_jit_pytree_method(self):
+
+    @iree.jax.jit
+    def apply_node(node, z):
+      return node.apply(z)
+
+    expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
+    compied_sqrt = apply_node(SqrtNode(2, 3), 4)
+    np.testing.assert_allclose(compied_sqrt, expected_sqrt)
+
+    expected_square = apply_node._function(SquareNode(2, 3), 4)
+    compied_square = apply_node(SquareNode(2, 3), 4)
+    np.testing.assert_allclose(expected_square, expected_square)
+
 
 if __name__ == "__main__":
+  jax.tree_util.register_pytree_node_class(SqrtNode)
+  jax.tree_util.register_pytree_node_class(SquareNode)
   unittest.main()