Example e2e test for control flow Test that runs a TF module through the entire compiler to iree. This includes a test of both tf.While and tf.If. PiperOrigin-RevId: 281611370
diff --git a/bindings/python/pyiree/compiler.py b/bindings/python/pyiree/compiler.py index f7ce73f..9a59eb7 100644 --- a/bindings/python/pyiree/compiler.py +++ b/bindings/python/pyiree/compiler.py
@@ -35,6 +35,7 @@ "tf-executor-graph-pruning", "tf-standard-pipeline", "canonicalize", + "xla-legalize-tf-control-flow", "xla-legalize-tf", ) @@ -92,6 +93,7 @@ Returns: An OpaqueBlob representing the compiled module. """ + print(pass_pipeline) input_module = tf_load_saved_model(saved_model_dir, compiler_context, exported_names, pass_pipeline) return input_module.compile_to_sequencer_blob(
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD index b9765e0..2a1902c 100644 --- a/integrations/tensorflow/e2e/BUILD +++ b/integrations/tensorflow/e2e/BUILD
@@ -24,6 +24,15 @@ ) py_test( + name = "control_flow_test", + srcs = ["control_flow_test.py"], + python_version = "PY3", + deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [ + "//bindings/python/pyiree", + ], +) + +py_test( name = "simple_arithmetic_test", srcs = ["simple_arithmetic_test.py"], python_version = "PY3",
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py new file mode 100644 index 0000000..9717066 --- /dev/null +++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -0,0 +1,60 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy +from pyiree import tf_test_utils +import tensorflow.compat.v2 as tf + + +class ControlFlowModule(tf.Module): + + def __init__(self): + pass + + @tf.function(input_signature=[tf.TensorSpec([], tf.float32)]) + def collatz(self, a): + i = 0. + while a > 1.: + i = i + 1. + if (a % 2.) > 0.: + a = 3. * a + 1. + else: + a = a / 2. + return i + + +@tf_test_utils.compile_modules( + [tf_test_utils.BackendInfo.ALL["iree_interpreter"]], + control_flow=ControlFlowModule) +class ControlFlowTest(tf_test_utils.SavedModelTestCase): + + def test_short_sequence(self): + input_array = numpy.array(9., dtype=numpy.float32) + result = self.modules.control_flow.all.collatz(input_array) + result.print().assert_all_close() + + def test_long_sequence(self): + input_array = numpy.array(178., dtype=numpy.float32) + result = self.modules.control_flow.all.collatz(input_array) + result.print().assert_all_close() + + +if __name__ == "__main__": + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main()