blob: df89f897b3a48cbf3802f1d15145a4f955b79a7f [file] [log] [blame]
Stella Laurenzo1c2f6552020-01-21 10:44:57 -08001# Lint as: python3
Geoffrey Martin-Noble552d3f82021-05-25 17:56:09 -07002# Copyright 2020 The IREE Authors
Stella Laurenzo1c2f6552020-01-21 10:44:57 -08003#
Geoffrey Martin-Noble552d3f82021-05-25 17:56:09 -07004# Licensed under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Stella Laurenzo1c2f6552020-01-21 10:44:57 -08007
Phoenix Meadowlark61cb3a62020-10-01 13:35:03 -07008from absl import app
Phoenix Meadowlark5a8954e2021-03-17 18:22:12 -07009from iree.tf.support import tf_test_utils
10from iree.tf.support import tf_utils
Stella Laurenzo1c2f6552020-01-21 10:44:57 -080011import numpy as np
Stella Laurenzo1c2f6552020-01-21 10:44:57 -080012import tensorflow.compat.v2 as tf
13
# Layer sizes for the MLP exercised by this test. Nothing visible in this file
# reads these names; DynamicMlpModule's constructor defaults spell out the
# same values.
HIDDEN_1_DIM = 256
HIDDEN_2_DIM = 256
# Written as a product so the value cannot drift from its explanation: the
# previous literal was 728, while 28 * 28 is 784.
INPUT_DIM = 28 * 28
CLASSES = 10
18
19
class DynamicMlpModule(tf.Module):
  """Multi-layer perceptron whose `predict` accepts a dynamic first dimension.

  The network is three matmul + bias + sigmoid stages (input -> hidden 1 ->
  hidden 2 -> classes); `predict` applies a softmax to the result.
  """

  def __init__(self,
               hidden_1_dim=256,
               hidden_2_dim=256,
               input_dim=28 * 28,
               classes=10):
    """Creates the weight and bias variables and wraps `predict`.

    Args:
      hidden_1_dim: Width of the first hidden layer.
      hidden_2_dim: Width of the second hidden layer.
      input_dim: Size of the second dimension of the input to `predict`.
      classes: Width of the output layer.
    """
    super().__init__()
    # Seed before any variable is created; presumably this keeps the random
    # initial values reproducible (see tf_utils.set_random_seed to confirm).
    tf_utils.set_random_seed()

    self.hidden_1_dim = hidden_1_dim
    self.hidden_2_dim = hidden_2_dim
    self.input_dim = input_dim
    self.classes = classes

    def random_variable(shape):
      return tf.Variable(tf.random.normal(shape))

    # NOTE: the creation order below (all weights, then all biases) is kept
    # exactly as it was, since the order of the random draws may affect the
    # generated values.
    self.h1_weights = random_variable([input_dim, hidden_1_dim])
    self.h2_weights = random_variable([hidden_1_dim, hidden_2_dim])
    self.out_weights = random_variable([hidden_2_dim, classes])
    self.h1_bias = random_variable([hidden_1_dim])
    self.h2_bias = random_variable([hidden_2_dim])
    self.out_bias = random_variable([classes])

    # Compile with dynamic batch dim: the leading dimension is left as None.
    dynamic_batch_spec = tf.TensorSpec([None, self.input_dim])
    self.predict = tf.function(self.predict,
                               input_signature=[dynamic_batch_spec])

  def mlp(self, x):
    """Runs `x` through the three sigmoid-activated layers."""
    stages = (
        (self.h1_weights, self.h1_bias),
        (self.h2_weights, self.h2_bias),
        (self.out_weights, self.out_bias),
    )
    activation = x
    # A plain Python loop: it unrolls into the same matmul/add/sigmoid chain.
    for weights, bias in stages:
      activation = tf.sigmoid(tf.add(tf.matmul(activation, weights), bias))
    return activation

  def predict(self, x):
    """Returns the softmax of `mlp(x)`."""
    logits = self.mlp(x)
    return tf.nn.softmax(logits)
Stella Laurenzo1c2f6552020-01-21 10:44:57 -080054
55
class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
  """Calls DynamicMlpModule.predict via the backend-comparison harness."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Only `predict` is listed for export.
    exported = ["predict"]
    self._modules = tf_test_utils.compile_tf_module(
        DynamicMlpModule, exported_names=exported)

  def test_dynamic_batch(self):

    # The inner function keeps its name in case the harness uses it for
    # labelling (not visible from this file).
    def dynamic_batch(module):
      # A batch of 3 rows, each 28 * 28 wide, scaled down by 1e-3.
      batch_shape = [3, 28 * 28]
      x = tf_utils.uniform(batch_shape) * 1e-3
      module.predict(x)

    self.compare_backends(dynamic_batch, self._modules)
Stella Laurenzo1c2f6552020-01-21 10:44:57 -080070
71
def main(argv):
  """Enables TF v2 behavior when the switch exists, then runs the tests.

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv  # Unused
  # The attribute is probed first because not every `tf` namespace has it.
  has_v2_switch = hasattr(tf, 'enable_v2_behavior')
  if has_v2_switch:
    tf.enable_v2_behavior()
  tf.test.main()
Phoenix Meadowlark61cb3a62020-10-01 13:35:03 -070077
78
# Script entry point: hand `main` to absl so it is run through app.run.
if __name__ == '__main__':
  app.run(main)