| # Copyright 2020 The IREE Authors |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
"""Tests for tf.quantization fake-quant ops on dynamically shaped inputs."""
| from iree.tf.support import tf_test_utils |
| from iree.tf.support import tf_utils |
| import tensorflow.compat.v2 as tf |
class QuantizationDynModule(tf.Module):
    """tf.Module exposing fake quantization on a 1-D float32 input of unknown length."""

    # The input signature leaves the only dimension as None, so the traced
    # function is built for a dynamically sized input.
    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def fake_quant(self, x):
        """Fake-quantize `x` with an 8-bit range of [-6, 6].

        Args:
          x: 1-D float32 tensor (length not fixed by the signature).

        Returns:
          The result of tf.quantization.fake_quant_with_min_max_args on `x`.
        """
        # The range and bit width are spelled out explicitly. They equal the
        # op's documented defaults and match the low=-6 / high=6 range that the
        # test samples its inputs from.
        return tf.quantization.fake_quant_with_min_max_args(
            x, min=-6, max=6, num_bits=8, narrow_range=False)
class QuantizationDynTest(tf_test_utils.TracedModuleTestCase):
    """Runs QuantizationDynModule.fake_quant through compare_backends."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The module is compiled in the constructor, i.e. for every test-case
        # instance; presumably compile_tf_module builds one artifact per
        # configured backend (TODO confirm in tf_test_utils).
        self._modules = tf_test_utils.compile_tf_module(QuantizationDynModule)

    def test_fake_quant(self):
        # Trace function for compare_backends: it receives a compiled module
        # and calls fake_quant on 32 uniformly sampled values (low=-6, high=6).
        # It is named `fake_quant` instead of `abs` so it does not shadow the
        # builtin. NOTE(review): compare_backends may use this name to label
        # traces.
        def fake_quant(module):
            module.fake_quant(tf_utils.uniform([32], low=-6, high=6))

        self.compare_backends(fake_quant, self._modules)
def main():
    """Enable TF2 behavior when the API offers it, then run the test suite."""
    # The hasattr guard tolerates TensorFlow builds that do not provide
    # enable_v2_behavior.
    if hasattr(tf, 'enable_v2_behavior'):
        tf.enable_v2_behavior()
    tf.test.main()


if __name__ == '__main__':
    main()