blob: 3bf41cfee9dffeefaa0794c951fc7cc749206420 [file] [log] [blame]
Geoffrey Martin-Noble552d3f82021-05-25 17:56:09 -07001# Copyright 2020 The IREE Authors
Rob Suderman42165522020-05-21 17:42:05 -07002#
Geoffrey Martin-Noble552d3f82021-05-25 17:56:09 -07003# Licensed under the Apache License v2.0 with LLVM Exceptions.
4# See https://llvm.org/LICENSE.txt for license information.
5# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Rob Suderman42165522020-05-21 17:42:05 -07006
Phoenix Meadowlark61cb3a62020-10-01 13:35:03 -07007from absl import app
Phoenix Meadowlark5a8954e2021-03-17 18:22:12 -07008from iree.tf.support import tf_test_utils
Rob Suderman42165522020-05-21 17:42:05 -07009import numpy as np
Rob Suderman42165522020-05-21 17:42:05 -070010import tensorflow.compat.v2 as tf
11
12
class ResourcesOpsModule(tf.Module):
  """A tf.Module wrapping one scalar resource variable and ops on it.

  Every public method is a `tf.function` with an explicit input signature, so
  each one traces to a single concrete function over scalar float32 inputs.
  """

  def __init__(self):
    super().__init__()
    # The only piece of state: a scalar variable initialized to 0.0.
    self.counter = tf.Variable(0.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def add_assign(self, value):
    """Adds `value` to the counter and returns the result of `assign_add`."""
    updated = self.counter.assign_add(value)
    return updated

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def set_value(self, new_value):
    """Overwrites the counter with `new_value`; nothing is returned."""
    self.counter.assign(new_value)

  @tf.function(input_signature=[])
  def get_value(self):
    """Returns the counter variable (its current value once traced)."""
    current = self.counter
    return current
30
Rob Suderman42165522020-05-21 17:42:05 -070031
class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
  """Traces calls on ResourcesOpsModule and compares the results per backend."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Compiled in the constructor so every test method can reach the modules
    # through `self._modules`.
    self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule)

  def test_add_assign(self):
    """Checks that a single add_assign call agrees across backends."""

    def add_assign(module):
      increment = np.array(9.0, dtype=np.float32)
      module.add_assign(increment)

    self.compare_backends(add_assign, self._modules)

  def test_assign_get(self):
    """Checks that set_value followed by get_value agrees across backends."""

    def assign_get(module):
      replacement = np.array(9.0, dtype=np.float32)
      module.set_value(replacement)
      return module.get_value()

    self.compare_backends(assign_get, self._modules)
52
Rob Suderman42165522020-05-21 17:42:05 -070053
def main(argv):
  """absl entry point: optionally enables TF2 behavior, then runs the tests.

  Args:
    argv: Remaining command-line arguments from absl; not used.
  """
  del argv  # Unused
  # `enable_v2_behavior` may not exist on the imported `tf` module, so it is
  # looked up defensively and only invoked when present.
  enable_v2 = getattr(tf, 'enable_v2_behavior', None)
  if enable_v2 is not None:
    enable_v2()
  tf.test.main()
Phoenix Meadowlark61cb3a62020-10-01 13:35:03 -070059
60
if __name__ == '__main__':
  # Hand control to absl so flags are parsed before main() runs.
  app.run(main=main)