blob: e7656eaf24324f30e9eb1ccced80b122ba5deb17 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utilities for running tests from TensorFlow models."""
import contextlib
import io
import os
import subprocess
import sys
import tempfile
import traceback
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
from .. import binding
# Command-line flags controlling how test output is verified.
flags.DEFINE_string(
    name="filecheck_binary",
    default="filecheck",
    help="Location of the filecheck binary.",
)
flags.DEFINE_bool(
    name="disable_filecheck",
    default=False,
    help="Disables filecheck redirection (for debugging).",
)

# Module-level handle to the parsed flag values (populated by app.run).
FLAGS = flags.FLAGS
# Registry of test description dicts, in the order add_test() was called.
ALL_TEST_DICTS = []


def add_test(**kwargs):
    """Registers a test to be executed by run_tests().

    The keyword arguments are stored verbatim as one dict in ALL_TEST_DICTS.
    Besides "test_name", _run_test reads "tf_module_builder" (required when the
    test runs) and the optional "passes", "expect_pass_failure" and
    "print_input_module" keys.

    Args:
      **kwargs: The test description; must include "test_name".

    Raises:
      TypeError: If "test_name" is not supplied.
    """
    # An explicit check rather than `assert`, so validation still happens when
    # Python runs with -O (which strips assert statements).
    if "test_name" not in kwargs:
        raise TypeError("'test_name' is a required argument")
    ALL_TEST_DICTS.append(kwargs)
def _run_test(test_dict):
    """Executes a single registered test description.

    Calls test_dict["tf_module_builder"] to obtain a TF module, saves it as a
    SavedModel (with debug info) in a temporary directory and imports that
    into a fresh compiler context. When test_dict["passes"] is truthy, that
    pass pipeline is run on the imported module. If the pipeline raises, the
    exception is always re-raised; unless test_dict["expect_pass_failure"] is
    truthy, a notice and the intermediate ASM are first written to stderr.
    Finally, the module ASM is printed to stdout when
    test_dict["print_input_module"] is truthy.
    """
    tf_module = test_dict["tf_module_builder"]()
    compiler_ctx = binding.compiler.CompilerContext()

    # The SavedModel only needs to exist long enough to be imported.
    with tempfile.TemporaryDirectory() as saved_model_dir:
        save_options = tf.saved_model.SaveOptions(save_debug_info=True)
        tf.saved_model.save(tf_module, saved_model_dir, options=save_options)
        input_module = binding.tf_interop.load_saved_model(
            compiler_ctx, saved_model_dir)

    pass_pipeline = test_dict.get("passes")
    failure_expected = test_dict.get("expect_pass_failure")
    if pass_pipeline:
        try:
            input_module.run_pass_pipeline(pass_pipeline)
        except:  # pylint: disable=bare-except
            # Only dump diagnostics for failures the test did not anticipate;
            # either way the error propagates so the caller records it.
            if not failure_expected:
                print(
                    "UNEXPECTED PASS FAILURE (INTERMEDIATE ASM FOLLOWS ON STDERR):",
                    file=sys.stderr)
                print(input_module.to_asm(), file=sys.stderr)
            raise

    if test_dict.get("print_input_module"):
        print(input_module.to_asm())
def _internal_run_tests():
    """Runs every registered test, emitting FileCheck-friendly markers.

    For each dict in ALL_TEST_DICTS this prints "RUN_TEST: <name>" to stdout
    and calls _run_test. On success it prints "FINISH_TEST: <name>". If the
    test raises an ordinary exception, the traceback and then
    "FINISH_TEST_WITH_EXCEPTION: <name>" are printed to stdout, and the
    remaining tests still run. KeyboardInterrupt and SystemExit are not
    caught, so the run can still be interrupted. A summary with the number of
    tests run is written to stderr at the end.
    """
    for test_dict in ALL_TEST_DICTS:
        test_name = test_dict["test_name"]
        print("RUN_TEST:", test_name)
        try:
            _run_test(test_dict)
        except Exception:  # pylint: disable=broad-except
            # Error goes to stdout so FileCheck can match against it.
            traceback.print_exc(file=sys.stdout)
            print("FINISH_TEST_WITH_EXCEPTION:", test_name)
        else:
            print("FINISH_TEST:", test_name)
    print("FINISHED: RAN", len(ALL_TEST_DICTS), "TESTS", file=sys.stderr)
def _find_filecheck():
filecheck_binary = FLAGS.filecheck_binary
if os.path.isabs(filecheck_binary):
return filecheck_binary
# TODO(laurenzo): Why is this runfiles resolution so hard and undocumented.
# Talk to bazel team.
runfiles_dir = os.environ.get("RUNFILES_DIR")
if runfiles_dir:
workspace_name = os.environ.get("TEST_WORKSPACE")
if workspace_name:
runfiles_dir = os.path.join(runfiles_dir, workspace_name)
filecheck_binary = os.path.join(runfiles_dir, filecheck_binary)
# Convert forward slash version to platform default (Windows).
filecheck_binary = filecheck_binary.replace("/", os.path.sep)
return filecheck_binary
def run_tests(main_file, with_filecheck=True):
    """Main entry point: runs all registered tests under absl's app.run.

    Args:
      main_file: Path of the test source file. It is passed to FileCheck as
        its first positional argument (the file holding the CHECK lines).
      with_filecheck: When True and --disable_filecheck is not set, everything
        the tests print to stdout is captured and piped into FileCheck, and
        the process exits with FileCheck's return code. Otherwise the tests
        print straight to stdout.
    """

    def internal_main(unused_argv):
        """App main."""
        # In case if running with a version prior to v2 defaulting.
        tf.enable_v2_behavior()
        if not with_filecheck or FLAGS.disable_filecheck:
            # Just run directly.
            _internal_run_tests()
            return

        # Capture the tests' stdout so it can be fed to FileCheck as input.
        filecheck_capture_io = io.StringIO()
        with contextlib.redirect_stdout(filecheck_capture_io):
            _internal_run_tests()
        filecheck_input = filecheck_capture_io.getvalue()

        filecheck_args = [_find_filecheck(), main_file, "--dump-input=fail"]
        print("LAUNCHING FILECHECK:", filecheck_args, file=sys.stderr)
        # check=False: a FileCheck mismatch is reported via the exit code
        # below rather than as a Python exception.
        completed = subprocess.run(
            filecheck_args,
            input=filecheck_input.encode("UTF-8"),
            check=False)
        sys.exit(completed.returncode)

    app.run(internal_main)