# blob: 153976dc2f1ec146fa0f798acf68db9074011aa6 [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import jax
import jax.numpy as jnp
import numpy as np
import pyiree as iree
import pyiree.jax
import pyiree.rt
# pytype thinks pyiree.jax is jax.
# pytype: disable=module-attr
# Keyword arguments splatted into np.testing.assert_allclose by the tests
# below when comparing IREE results against reference results.
TOLERANCE = dict(rtol=1e-6, atol=1e-6)
def normal(shape):
    """Return float32 samples from a normal distribution (mean 0, std 1).

    Args:
        shape: Output shape, forwarded to ``np.random.normal`` as ``size``.

    Returns:
        A ``np.float32`` array of the requested shape. Values come from
        NumPy's global random state, so callers seed it for reproducibility.
    """
    samples = np.random.normal(loc=0, scale=1, size=shape)
    return samples.astype(np.float32)
class SqrtNode:
    """Two-field container computing ``x * sqrt(y * z)``.

    Provides ``tree_flatten`` / ``tree_unflatten`` so the class can be
    registered as a JAX pytree node.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def apply(self, z):
        """Return ``x * sqrt(y * z)`` computed with ``jnp.sqrt``."""
        scaled = self.y * z
        return self.x * jnp.sqrt(scaled)

    def tree_flatten(self):
        """Return ``(children, aux_data)``: the ``(x, y)`` pair and ``None``."""
        children = (self.x, self.y)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``children``; ``aux_data`` is ignored."""
        del aux_data  # Nothing is stored outside the children.
        return cls(*children)
class SquareNode:
    """Two-field container computing ``x * (y * z) ** 2``.

    Provides ``tree_flatten`` / ``tree_unflatten`` so the class can be
    registered as a JAX pytree node.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def apply(self, z):
        """Return ``x * (y * z) ** 2``."""
        scaled = self.y * z
        return self.x * scaled**2

    def tree_flatten(self):
        """Return ``(children, aux_data)``: the ``(x, y)`` pair and ``None``."""
        children = (self.x, self.y)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from ``children``; ``aux_data`` is ignored."""
        del aux_data  # Nothing is stored outside the children.
        return cls(*children)
class JAXFrontendTest(unittest.TestCase):
    """Tests for the ``iree.jax`` ``aot`` and ``jit`` entry points."""

    def test_aot_pytree(self):
        """``aot`` accepts a function whose arguments include a dict pytree."""

        def pytree_func(params, x):
            return jnp.max(jnp.matmul(x, params["w"]) + params["b"], 0)

        trace_args = [
            {
                "w": jnp.zeros((32, 8)),
                "b": jnp.zeros((8,)),
            },
            jnp.zeros((1, 32)),
        ]
        # Nothing is asserted about the returned value; the test passes as
        # long as compilation does not raise.
        iree.jax.aot(pytree_func, *trace_args, target_backends=["vmla"])

    def test_jit_get_binary(self):
        """Jitted results match the wrapped function; binaries are per-shape."""

        @iree.jax.jit(backend="vmla")
        def scaled_dot_product_attention(q, k, v):
            # Subscripts:
            # q: seq_len_q, k: seq_len_kv, d: att_depth_qk, v: att_depth_v
            qk = jnp.einsum("...qd, ...kd -> ...qk", q, k)
            soft_qk = jax.nn.softmax(qk / np.sqrt(k.shape[-1]), axis=-1)
            return jnp.einsum("...qk, ...kv -> ...qv", soft_qk, v)

        np.random.seed(0)

        # `_function` is used as the reference implementation here
        # (presumably the undecorated Python function -- confirm in pyiree).
        inputs_1 = [normal((2, 4, 5, 8)) for _ in range(3)]
        vmla_result_1 = scaled_dot_product_attention(*inputs_1)
        jax_result_1 = scaled_dot_product_attention._function(*inputs_1)
        np.testing.assert_allclose(vmla_result_1, jax_result_1, **TOLERANCE)

        inputs_2 = [normal((2, 1, 6, 8)) for _ in range(3)]
        vmla_result_2 = scaled_dot_product_attention(*inputs_2)
        jax_result_2 = scaled_dot_product_attention._function(*inputs_2)
        np.testing.assert_allclose(vmla_result_2, jax_result_2, **TOLERANCE)

        # Test that the binaries are different.
        binary_1 = scaled_dot_product_attention.get_binary(*inputs_1)
        binary_2 = scaled_dot_product_attention.get_binary(*inputs_2)
        self.assertNotEqual(binary_1, binary_2)

        # Test that the binary gives the same result.
        cpp_vm_module = iree.rt.VmModule.from_flatbuffer(binary_1)
        module = iree.rt.load_module(cpp_vm_module,
                                     config=iree.rt.Config("vmla"))
        binary_result_1 = module.main(*inputs_1)
        np.testing.assert_equal(binary_result_1, vmla_result_1)

        # The module was built from inputs_1, so the differently shaped
        # inputs_2 are expected to be rejected.
        with self.assertRaises(ValueError):
            module.main(*inputs_2)

    def test_jit_pytree_return(self):
        """A jitted function can return a nested pytree of arrays."""

        @iree.jax.jit
        def apply_sqrt(pytree):
            return jax.tree_util.tree_map(jnp.sqrt, pytree)

        np.random.seed(0)
        input_tree = {
            "a": [
                normal((2, 3)),
                {
                    "b": normal(3),
                },
            ],
            "c": (
                {
                    "d": [normal(2), normal(3)],
                },
                (normal(1), normal(4)),
            ),
        }

        expected = jax.tree_util.tree_map(jnp.sqrt, input_tree)
        expected_arrays, expected_tree = jax.tree_util.tree_flatten(expected)
        result = apply_sqrt(input_tree)
        result_arrays, result_tree = jax.tree_util.tree_flatten(result)

        # Both the tree structure and every leaf must agree.
        self.assertEqual(expected_tree, result_tree)
        for expected_array, result_array in zip(expected_arrays,
                                                result_arrays):
            np.testing.assert_allclose(expected_array, result_array,
                                       **TOLERANCE)

    def test_iree_jit_of_iree_jit(self):
        """An ``iree.jax.jit`` function can call another one."""

        @iree.jax.jit
        def add(a, b):
            return a + b

        @iree.jax.jit
        def mul_two(a):
            return add(a, a)

        self.assertEqual(mul_two(3), 6)

    def test_jax_jit_of_iree_jit(self):
        """A ``jax.jit`` function can call an ``iree.jax.jit`` function."""

        @iree.jax.jit
        def add(a, b):
            return a + b

        @jax.jit
        def mul_two(a):
            return add(a, a)

        self.assertEqual(mul_two(3), 6)

    def test_iree_jit_of_jax_jit(self):
        """An ``iree.jax.jit`` function can call a ``jax.jit`` function."""

        @jax.jit
        def add(a, b):
            return a + b

        @iree.jax.jit
        def mul_two(a):
            return add(a, a)

        self.assertEqual(mul_two(3), 6)

    def test_iree_jit_of_empty_iree_jit(self):
        """An ``iree.jax.jit`` function can call a zero-argument one."""

        @iree.jax.jit
        def sqrt_four():
            return jnp.sqrt(4)

        @iree.jax.jit
        def add_sqrt_four(a):
            return a + sqrt_four()

        self.assertEqual(add_sqrt_four(2), 4)

    def test_jit_pytree_method(self):
        """One jitted function handles different pytree node classes."""

        @iree.jax.jit
        def apply_node(node, z):
            return node.apply(z)

        expected_sqrt = apply_node._function(SqrtNode(2, 3), 4)
        compiled_sqrt = apply_node(SqrtNode(2, 3), 4)
        np.testing.assert_allclose(compiled_sqrt, expected_sqrt)

        expected_square = apply_node._function(SquareNode(2, 3), 4)
        compiled_square = apply_node(SquareNode(2, 3), 4)
        # Compare the compiled result against the reference; comparing the
        # reference with itself would pass no matter what was compiled.
        np.testing.assert_allclose(compiled_square, expected_square)
# Register the pytree node classes at import time rather than only under the
# __main__ guard, so they are also registered when a test runner imports this
# module instead of executing it as a script.
jax.tree_util.register_pytree_node_class(SqrtNode)
jax.tree_util.register_pytree_node_class(SquareNode)

if __name__ == "__main__":
    unittest.main()