Geoffrey Martin-Noble | 552d3f8 | 2021-05-25 17:56:09 -0700 | [diff] [blame] | 1 | # Copyright 2020 The IREE Authors |
Rob Suderman | 4216552 | 2020-05-21 17:42:05 -0700 | [diff] [blame] | 2 | # |
Geoffrey Martin-Noble | 552d3f8 | 2021-05-25 17:56:09 -0700 | [diff] [blame] | 3 | # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | # See https://llvm.org/LICENSE.txt for license information. |
| 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
Rob Suderman | 4216552 | 2020-05-21 17:42:05 -0700 | [diff] [blame] | 6 | |
Phoenix Meadowlark | 61cb3a6 | 2020-10-01 13:35:03 -0700 | [diff] [blame] | 7 | from absl import app |
Phoenix Meadowlark | 5a8954e | 2021-03-17 18:22:12 -0700 | [diff] [blame] | 8 | from iree.tf.support import tf_test_utils |
Rob Suderman | 4216552 | 2020-05-21 17:42:05 -0700 | [diff] [blame] | 9 | import numpy as np |
Rob Suderman | 4216552 | 2020-05-21 17:42:05 -0700 | [diff] [blame] | 10 | import tensorflow.compat.v2 as tf |
| 11 | |
| 12 | |
class ResourcesOpsModule(tf.Module):
  """Wraps one scalar counter variable behind traced tf.function methods.

  Each method declares an explicit `input_signature`, so every method has a
  single fixed signature: a scalar float32, or no arguments for `get_value`.
  """

  def __init__(self):
    super().__init__()
    # A Python float literal produces a float32 variable, matching the
    # scalar float32 specs used by the methods below.
    self.counter = tf.Variable(initial_value=0.0)

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  def add_assign(self, value):
    """Adds `value` to the counter and returns what `assign_add` returns."""
    updated = self.counter.assign_add(value)
    return updated

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  def set_value(self, new_value):
    """Overwrites the counter with `new_value`; produces no output."""
    self.counter.assign(value=new_value)

  @tf.function(input_signature=[])
  def get_value(self):
    """Returns the counter variable as seen inside the traced function."""
    current = self.counter
    return current
| 30 | |
Rob Suderman | 4216552 | 2020-05-21 17:42:05 -0700 | [diff] [blame] | 31 | |
class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
  """Runs traces against ResourcesOpsModule and compares backend results."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Compile the module under test via tf_test_utils; each test below hands
    # the result to compare_backends.
    self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule)

  def test_add_assign(self):
    """Traces a single add_assign call with a scalar float32 of 9."""

    # NOTE(review): inner function names are kept as-is in case
    # tf_test_utils uses them to label traces.
    def add_assign(module):
      increment = np.array(9.0, np.float32)
      module.add_assign(increment)

    self.compare_backends(add_assign, self._modules)

  def test_assign_get(self):
    """Traces set_value(9) followed by get_value."""

    def assign_get(module):
      module.set_value(np.array(9.0, np.float32))
      current = module.get_value()
      return current

    self.compare_backends(assign_get, self._modules)
| 52 | |
Rob Suderman | 4216552 | 2020-05-21 17:42:05 -0700 | [diff] [blame] | 53 | |
def main(argv):
  """absl entry point: turn on TF2 behavior if the API offers it, then test.

  Args:
    argv: Command-line arguments forwarded by `app.run`; not used.
  """
  del argv  # Unused
  # Only some TensorFlow API surfaces expose this toggle, so look it up
  # defensively instead of assuming it exists.
  enable_v2 = getattr(tf, 'enable_v2_behavior', None)
  if enable_v2 is not None:
    enable_v2()
  tf.test.main()


if __name__ == '__main__':
  app.run(main)