blob: 0d8f81ad08b607bdb4ef596638d3778ae24c4ee6 [file] [log] [blame]
# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test utilities interop with TensorFlow."""
import functools
import os
import tempfile
import timeit
from .. import binding
from .. import compiler
import numpy as np
import tensorflow.compat.v2 as tf
def save_and_compile_tf_module(tf_module):
  """Exports `tf_module` as a SavedModel and compiles it with IREE.

  The SavedModel (with debug info) is written into a scratch directory that
  is deleted again once compilation has finished.

  Args:
    tf_module: The TF module object to export via tf.saved_model.save.

  Returns:
    The result of compiler.tf_compile_saved_model for the exported model.
  """
  save_options = tf.saved_model.SaveOptions(save_debug_info=True)
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(tf_module, saved_model_dir, options=save_options)
    compiled = compiler.tf_compile_saved_model(saved_model_dir)
  return compiled
def dump_iree_module(m):
  """Prints the name of module `m` followed by each exported function.

  Functions are looked up by ordinal, starting at 0, until
  `m.lookup_function_by_ordinal` returns a falsy value. Each export is printed
  with its argument and result counts.

  Args:
    m: The loaded VM module to describe.
  """
  print("Loaded module:", m.name)
  ordinal = 0
  fn = m.lookup_function_by_ordinal(ordinal)
  while fn:
    print(" Export:", fn.name, "-> args(", fn.signature.argument_count,
          "), results(", fn.signature.result_count, ")")
    ordinal += 1
    fn = m.lookup_function_by_ordinal(ordinal)
def get_default_test_backends():
  """Returns the backend names to use when a test case sets no BACKENDS.

  The comma-separated IREE_TEST_BACKENDS environment variable takes
  precedence. Whitespace around each name is stripped and empty entries are
  ignored, so values such as "tf, iree.vulkan," work. If the variable is
  unset or names no backends, ("tf", "iree.interpreter") is used.

  Returns:
    A list of backend names parsed from IREE_TEST_BACKENDS, or the default
    tuple.
  """
  backends_env = os.environ.get("IREE_TEST_BACKENDS", "")
  # Strip whitespace and drop empty entries; otherwise names like " tf" or ""
  # would be looked up verbatim in the backend table and fail.
  backends = [name.strip() for name in backends_env.split(",") if name.strip()]
  if backends:
    return backends
  return ("tf", "iree.interpreter")
class _TfBackend(object):
"""Backend for running directly on the TF module."""
def __init__(self, test_case, backend_name, fn_name):
self.backend_name = backend_name
self.module_f = getattr(test_case.tf_module, fn_name)
def __call__(self, *args):
return self.module_f(*args)
def postprocess(self, results):
# Handle single result (technically ambiguous with return of a tuple).
if not isinstance(results, tuple):
results = (results,)
# TODO(laurenzo): Handle structure mapping, etc.
return [r.numpy() for r in results]
class _IreeBackend(object):
"""Backend for running on an IREE driver."""
def __init__(self, test_case, backend_name, fn_name):
self.backend_name = backend_name
driver_name = backend_name.split(".")[-1]
self.policy = binding.rt.Policy()
instance = binding.rt.Instance(driver_name=driver_name)
self.context = binding.rt.Context(instance=instance, policy=self.policy)
self.context.register_module(test_case.iree_vm_module)
self.f = self.context.resolve_function("module." + fn_name)
def __call__(self, *args):
args = [self.context.wrap_for_input(arg) for arg in args]
# Invoke the function and wait for completion.
inv = self.context.invoke(self.f, self.policy, args)
inv.await_ready()
# Get results as a numpy array.
results = [np.array(r.map(), copy=False) for r in inv.results]
return results
def postprocess(self, results):
return results
# Maps each backend name accepted in BACKENDS / IREE_TEST_BACKENDS to the
# class implementing it. Every "iree.<driver>" entry shares _IreeBackend,
# which takes the driver name from the text after the last ".".
_ALL_BACKENDS = dict(
    [("tf", _TfBackend)] +
    [("iree." + driver, _IreeBackend) for driver in ("interpreter", "vulkan")])
def _wrap_per_backend_fn(saved_model_test_case, fn_name, iterations=100):
  """Generates a wrapper function for a backend fn name.

  Args:
    saved_model_test_case: The running test case. Its BACKENDS attribute
      selects the backends (falling back to get_default_test_backends()), and
      its assert methods are used to compare results.
    fn_name: Name of the function to invoke on every backend.
    iterations: Number of extra timed invocations per backend. Timing is
      skipped when this is not positive.

  Returns:
    A function taking the call arguments. It runs the function once per
    backend for correctness, then times it. It asserts that every backend
    produced the same number of results as the first backend and that they
    are all close. It returns the first backend's postprocessed results.
  """

  def invoke_fn(*args):
    """Lambda that invokes the function on all backends."""
    backend_names = saved_model_test_case.BACKENDS
    if not backend_names:
      backend_names = get_default_test_backends()
    backends = [
        _ALL_BACKENDS[b](saved_model_test_case, b, fn_name)
        for b in backend_names
    ]
    test_id = saved_model_test_case.id().split(".")[-1]
    per_backend_results = []
    binding.tracing.enable_thread()
    for backend in backends:
      # pylint: disable=cell-var-from-loop
      print(":INVOKE %s:%s on %s" % (test_id, fn_name, backend.backend_name))
      event = binding.tracing.ScopedEvent(
          "%s_%s#%s" % (test_id, fn_name, backend.backend_name))

      def run_iteration():
        with event:
          return backend(*args)

      # Run one for correctness.
      results = backend.postprocess(run_iteration())
      per_backend_results.append((backend.backend_name, results))
      # Then time it. A non-positive count would divide by zero below.
      if iterations > 0:
        backend_time_ms = timeit.timeit(run_iteration, number=iterations) * 1000
        iteration_time_ms = backend_time_ms / iterations
        print(":BENCHMARK %s:%s on %s: time=%rms" %
              (test_id, fn_name, backend.backend_name, iteration_time_ms))
      # pylint: enable=cell-var-from-loop

    # Verify results against the first backend.
    ref_backend_name, ref_results = per_backend_results[0]
    print(":REF RESULTS %s:%s %s:" % (test_id, fn_name, ref_backend_name),
          ref_results)
    for backend_name, results in per_backend_results[1:]:
      print(":COMPARE %s:%s %s vs %s" %
            (test_id, fn_name, ref_backend_name, backend_name))
      print(" :", results)
      # zip() below stops at the shorter sequence, so a backend returning
      # too few or too many results would otherwise go unnoticed.
      saved_model_test_case.assertEqual(
          len(ref_results),
          len(results),
          msg="Result count mismatch %s vs %s" %
          (ref_backend_name, backend_name))
      for ref_result, result in zip(ref_results, results):
        saved_model_test_case.assertAllClose(
            ref_result,
            result,
            msg="Result mismatch %s vs %s" % (ref_backend_name, backend_name))
    return ref_results

  return invoke_fn
def per_backend_test(*fn_names):
  """Wraps a SavedModelTestCase test method to run per backend tests.

  Args:
    *fn_names: Names of functions to run tests against. These will be converted
      to python functions that invoke all of the backends and passed to the test
      case method.

  Returns:
    A decorator. The decorated test method takes only `self`. When run, it
    calls the original method with one wrapper per name in `fn_names`, in
    order. The original method's name, docstring and other metadata are
    preserved.
  """

  def decorator(f):

    # functools.wraps copies __name__ (as before), plus __doc__, __qualname__,
    # __module__ and __wrapped__, so test reports and introspection show the
    # real test method.
    @functools.wraps(f)
    def replacement(self):
      fns = [_wrap_per_backend_fn(self, fn_name) for fn_name in fn_names]
      f(self, *fns)

    return replacement

  return decorator
class SavedModelTestCase(tf.test.TestCase):
  """Tests against a SavedModel.

  Use this by subclassing and then defining a TF_MODULE_CONSTRUCTOR member.
  The module is built and compiled once per class in setUpClass. A trace file
  is flushed to the temp directory in tearDownClass.
  """

  # Callable with no arguments that returns the TF module under test.
  # Required; setUpClass raises ValueError if it is left as None.
  TF_MODULE_CONSTRUCTOR = None
  # File name (inside tempfile.gettempdir()) for the trace flushed in
  # tearDownClass. A falsy value means "<ClassName>.wtf-trace".
  TRACE_FILE_NAME = None
  # Backend names to run per_backend_test functions on. A falsy value means
  # get_default_test_backends() is used.
  BACKENDS = None

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    if cls.TF_MODULE_CONSTRUCTOR is None:
      raise ValueError("Expected a class level TF_MODULE_CONSTRUCTOR")
    # Compile the module. We do this once.
    cls.tf_module = cls.TF_MODULE_CONSTRUCTOR()  # pylint: disable=not-callable
    cls.iree_blob = save_and_compile_tf_module(cls.tf_module)
    cls.iree_vm_module = binding.vm.create_module_from_blob(cls.iree_blob)
    dump_iree_module(cls.iree_vm_module)

  @classmethod
  def tearDownClass(cls):
    trace_file_name = cls.TRACE_FILE_NAME or cls.__name__ + ".wtf-trace"
    trace_file = os.path.join(tempfile.gettempdir(), trace_file_name)
    try:
      print("Flushing trace file to:", trace_file)
      binding.tracing.flush(trace_file)
      print("Flush complete")
    finally:
      # Always run the base class teardown, even if flushing the trace fails.
      super().tearDownClass()