| # Copyright 2019 The IREE Authors |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| from iree.tf.support import tf_test_utils |
| import tensorflow.compat.v2 as tf |
# Minimal stateful tf.Module: holds a scalar float counter that one exported
# tf.function mutates (and, presumably, a second one reads back).
# NOTE(review): this excerpt is missing lines (e.g. the method header that
# owns `self.counter = ...` and the body of the zero-argument function), so
# the comments below describe only what is visible here.
| class SimpleStatefulModule(tf.Module): |
# Counter state, initialized to 0.0. Presumably set in __init__ (header not
# visible). NOTE(review): no super().__init__() call is visible in this
# excerpt; tf.Module subclasses are documented as calling it before creating
# variables. Confirm it exists.
| self.counter = tf.Variable(0.0) |
# Exported with a fixed signature: exactly one scalar (shape []) float32 input.
| @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) |
# Adds `x` (presumably the scalar argument declared by the signature above)
# to the counter, updating the variable in place via assign().
| self.counter.assign(self.counter + x) |
# Exported with no inputs. The decorated function is not visible here;
# presumably a getter that returns self.counter (TODO confirm).
| @tf.function(input_signature=[]) |
# Test case that runs SimpleStatefulModule on the backends provided by
# tf_test_utils and compares them via compare_backends.
| class StatefulTest(tf_test_utils.TracedModuleTestCase): |
| def __init__(self, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
# Compile the module in the constructor; the result is what compare_backends
# is run against below.
| self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule) |
# Increment the module's counter by 1.0. The float32 dtype matches the
# scalar float32 TensorSpec declared for inc_by.
# NOTE(review): `np` is used here but no numpy import is visible in this
# excerpt. Confirm `import numpy as np` exists at the top of the file.
# NOTE(review): inc_by's visible body returns nothing. Confirm the enclosing
# `get_state` trace function also reads the counter back; otherwise there
# is no output for the backends to be compared on.
| module.inc_by(np.array(1., dtype=np.float32)) |
# Run the `get_state` trace function (its header is not visible here)
# against the compiled modules. Presumably this compares the traced results
# across backends.
| self.compare_backends(get_state, self._modules) |
# Script entry-point fragments. The enclosing function header and both
# guarded bodies are not visible in this excerpt.
# Feature check: the guarded code runs only when this TensorFlow build still
# exposes `enable_v2_behavior`.
| if hasattr(tf, 'enable_v2_behavior'): |
# Standard guard so the tests run only when the file is executed directly.
| if __name__ == '__main__': |