Tensorflow to IREE test case for tf.Fill, which is failing because broadcast_in_dims can't infer its output shape based on its inputs with the new tf2xla path. For now, only run the test on the 'tf' backend PiperOrigin-RevId: 300585192
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD index f4ecdf3..b1fb681 100644 --- a/integrations/tensorflow/e2e/BUILD +++ b/integrations/tensorflow/e2e/BUILD
@@ -37,6 +37,7 @@ ) for name in [ "batch_norm_test", + "fill_test", "control_flow_test", "dynamic_mlp_test", "exported_names_test",
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py new file mode 100644 index 0000000..29e3ab7 --- /dev/null +++ b/integrations/tensorflow/e2e/fill_test.py
@@ -0,0 +1,48 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from pyiree.tf.support import tf_test_utils +import tensorflow.compat.v2 as tf + + +class FillModule(tf.Module): + + def __init__(self): + pass + + @tf.function(input_signature=[ + tf.TensorSpec([2], tf.int32), + tf.TensorSpec([], tf.float32) + ]) + def fill(self, dims, value): + return tf.fill(dims, value) + + +# TODO(jennik): Get this test working on IREE. +@tf_test_utils.compile_modules(backends=["tf"], fill=FillModule) +class FillTest(tf_test_utils.SavedModelTestCase): + + def test_fill(self): + dims = np.array([2, 3], dtype=np.int32) + value = np.array(9., dtype=np.float32) + + result = self.modules.fill.all.fill(dims, value) + result.assert_all_close() + + +if __name__ == "__main__": + if hasattr(tf, "enable_v2_behavior"): + tf.enable_v2_behavior() + tf.test.main()