Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 1 | # Lint as: python3 |
Geoffrey Martin-Noble | 552d3f8 | 2021-05-25 17:56:09 -0700 | [diff] [blame] | 2 | # Copyright 2020 The IREE Authors |
Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 3 | # |
Geoffrey Martin-Noble | 552d3f8 | 2021-05-25 17:56:09 -0700 | [diff] [blame] | 4 | # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 5 | # See https://llvm.org/LICENSE.txt for license information. |
| 6 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 7 | |
Phoenix Meadowlark | 61cb3a6 | 2020-10-01 13:35:03 -0700 | [diff] [blame] | 8 | from absl import app |
Phoenix Meadowlark | 5a8954e | 2021-03-17 18:22:12 -0700 | [diff] [blame] | 9 | from iree.tf.support import tf_test_utils |
| 10 | from iree.tf.support import tf_utils |
Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 11 | import numpy as np |
Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 12 | import tensorflow.compat.v2 as tf |
| 13 | |
# Reference layer sizes for the MLP under test; these mirror the default
# constructor arguments of DynamicMlpModule below.
HIDDEN_1_DIM = 256
HIDDEN_2_DIM = 256
# Flattened 28x28 image. Written as the product so the value cannot drift from
# its description (it was previously mistyped as 728; 28 * 28 == 784).
INPUT_DIM = 28 * 28
CLASSES = 10
| 18 | |
| 19 | |
class DynamicMlpModule(tf.Module):
  """Three-layer sigmoid MLP whose `predict` accepts any batch size.

  Weights and biases are random-normal initialized after calling
  `tf_utils.set_random_seed()`. `predict` is wrapped in a `tf.function` whose
  input signature leaves the batch dimension unspecified (None).
  """

  def __init__(self,
               hidden_1_dim=256,
               hidden_2_dim=256,
               input_dim=28 * 28,
               classes=10):
    super().__init__()
    tf_utils.set_random_seed()

    self.hidden_1_dim = hidden_1_dim
    self.hidden_2_dim = hidden_2_dim
    self.input_dim = input_dim
    self.classes = classes

    # NOTE: keep the draw order below unchanged; reordering would presumably
    # change the seeded initial values.
    self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
    self.h2_weights = tf.Variable(
        tf.random.normal([hidden_1_dim, hidden_2_dim]))
    self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
    self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
    self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
    self.out_bias = tf.Variable(tf.random.normal([classes]))

    # Export `predict` with a dynamic (None) batch dimension.
    batch_agnostic_spec = tf.TensorSpec([None, self.input_dim])
    self.predict = tf.function(self.predict,
                               input_signature=[batch_agnostic_spec])

  @staticmethod
  def _dense_sigmoid(inputs, weights, bias):
    """Returns sigmoid(matmul(inputs, weights) + bias)."""
    return tf.sigmoid(tf.add(tf.matmul(inputs, weights), bias))

  def mlp(self, x):
    """Applies the two hidden layers and the output layer, all sigmoid."""
    hidden = self._dense_sigmoid(x, self.h1_weights, self.h1_bias)
    hidden = self._dense_sigmoid(hidden, self.h2_weights, self.h2_bias)
    return self._dense_sigmoid(hidden, self.out_weights, self.out_bias)

  def predict(self, x):
    """Returns the softmax over the MLP's (sigmoid-activated) outputs."""
    scores = self.mlp(x)
    return tf.nn.softmax(scores)
Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 54 | |
| 55 | |
class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
  """Compares backends on DynamicMlpModule.predict with a runtime batch size."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Only `predict` is exported from the compiled module.
    self._modules = tf_test_utils.compile_tf_module(
        DynamicMlpModule, exported_names=["predict"])

  def test_dynamic_batch(self):
    """Calls `predict` on a batch of 3 small-valued flattened 28x28 inputs."""

    def dynamic_batch(module):
      batch_size = 3
      inputs = tf_utils.uniform([batch_size, 28 * 28]) * 1e-3
      module.predict(inputs)

    self.compare_backends(dynamic_batch, self._modules)
Stella Laurenzo | 1c2f655 | 2020-01-21 10:44:57 -0800 | [diff] [blame] | 70 | |
| 71 | |
def main(argv):
  """Enables TF2 behavior when the API offers it, then runs tf.test."""
  del argv  # Unused.
  enable_v2_behavior = getattr(tf, 'enable_v2_behavior', None)
  if enable_v2_behavior is not None:
    enable_v2_behavior()
  tf.test.main()


if __name__ == '__main__':
  app.run(main)