Enable MobileBert on VMLA via SignatureDef SavedModels (#3307)
Adds a named constructor `create_from_signature_def_saved_model` to the `CompiledModule` classes so that SignatureDef SavedModels can be compiled and tested on each backend.
Also fixes #3258.
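
For reference, tests can now target a SignatureDef SavedModel directly. A minimal sketch based on the updated `mobile_bert_squad_test.py` (the path, module name, and tensor names below are placeholders):

```python
from pyiree.tf.support import tf_test_utils


class MyTest(tf_test_utils.TracedModuleTestCase):

  def __init__(self, methodName='runTest'):
    super().__init__(methodName)
    # Compiles the 'serving_default' signature for the reference backend and
    # each backend under test, returning a 'Modules' namedtuple.
    self._modules = tf_test_utils.compile_tf_signature_def_saved_model(
        saved_model_dir='/path/to/saved_model',
        saved_model_tags={'serve'},
        module_name='MyModule',
        exported_name='serving_default',
        input_names=['input_ids', 'input_mask', 'segment_ids'],
        output_names=['start_logits', 'end_logits'])
```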
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
index 3d28439..7787a49 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
@@ -59,6 +59,32 @@
],
})
+# TODO: Isolate SignatureDef SavedModel support into its own library to decrease build-time cost.
+# While it would be nice to simply depend on TensorFlow, manually paring
+# down the dependencies significantly reduces the build time that this adds.
+#
+# Baseline: 449s
+# SignatureDef SavedModels: 546s – 22% increase in build time.
+# SignatureDef SavedModels + Deps for MobileBert: 572s – 27% increase in build time.
+# TF OpenSource: 664s – 48% increase in build time.
+SIGNATURE_DEF_SAVED_MODEL_TF_RUNTIME_DEPS = [
+ # Deps for SignatureDef SavedModels:
+ "@org_tensorflow//tensorflow/core:direct_session",
+ "@org_tensorflow//tensorflow/core/kernels:resource_variable_ops", # VarHandleOp
+ "@org_tensorflow//tensorflow/core/kernels:regex_full_match_op", # StaticRegexFullMatch
+ "@org_tensorflow//tensorflow/core/kernels:string_join_op", # StringJoin
+    "@org_tensorflow//tensorflow/core/kernels:save_op",  # ShardedFilename
+ "@org_tensorflow//tensorflow/core/kernels:save_restore_v2_ops", # SaveV2
+
+ # Deps for MobileBert:
+ "@org_tensorflow//tensorflow/core/kernels:parameterized_truncated_normal_op", # TruncatedNormal
+    "@org_tensorflow//tensorflow/core/kernels:state",  # Assign
+ "@org_tensorflow//tensorflow/core/kernels:logging_ops", # Assert
+ "@org_tensorflow//tensorflow/core/kernels:bias_op", # BiasAdd
+ "@org_tensorflow//tensorflow/core/kernels:softmax_op", # Softmax
+ "@org_tensorflow//tensorflow/core/kernels:relu_op", # Relu
+]
+
TF_XLA_PASS_DEPS = [
"//integrations/tensorflow/compiler:tensorflow",
"@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
@@ -101,7 +127,7 @@
hdrs = [
"register_tensorflow.h",
],
- deps = SAVED_MODEL_TF_RUNTIME_DEPS + TF_XLA_PASS_DEPS + [
+ deps = SAVED_MODEL_TF_RUNTIME_DEPS + TF_XLA_PASS_DEPS + SIGNATURE_DEF_SAVED_MODEL_TF_RUNTIME_DEPS + [
"//bindings/python/pyiree/common",
"//bindings/python/pyiree/compiler:compiler_library",
"@llvm-project//llvm:Support",
@@ -127,10 +153,6 @@
name = "signature_def_saved_model_test",
srcs = ["signature_def_saved_model_test.py"],
python_version = "PY3",
- tags = [
- "manual",
- "nokokoro",
- ],
deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
"@absl_py//absl/testing:absltest",
"//integrations/tensorflow/bindings/python/pyiree/tf/compiler",
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 43c8d08..4d432d8 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -30,7 +30,7 @@
import pickle
import sys
import tempfile
-from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
from absl import flags
from absl import logging
@@ -383,15 +383,24 @@
"Expected ref and tar to have the same dtype, but got %s and %s",
ref.dtype, tar.dtype)
return False
+ if ref.size == tar.size == 0:
+ return True
+
if np.issubdtype(ref.dtype, np.floating):
same = np.allclose(ref, tar, rtol=rtol, atol=atol)
+ abs_diff = np.max(np.abs(ref - tar))
+ rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
if not same:
- abs_diff = np.max(np.abs(ref - tar))
- rel_diff = np.max(np.abs(ref - tar) / np.max(tar))
logging.error(
"Floating point difference between ref and tar was too large. "
"Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
abs_diff, atol, rel_diff, rtol)
+ else:
+ logging.info(
+ "Floating point difference between ref and tar was within "
+ "tolerance. "
+ "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
+ abs_diff, atol, rel_diff, rtol)
return same
else:
return np.array_equal(ref, tar)
@@ -598,6 +607,48 @@
return Modules(ref_module, tar_modules, artifacts_dir)
+def compile_tf_signature_def_saved_model(saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str, exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str]):
+ """Compiles a SignatureDef SavedModel to each backend that we test.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+    input_names: A sequence of input names to feed the saved model as kwargs.
+ output_names: A sequence of named outputs to extract from the saved model.
+
+ Returns:
+ A 'Modules' namedtuple containing the reference module, target modules and
+ artifacts directory.
+ """
+
+ # Setup the directory for saving compilation artifacts and traces.
+ artifacts_dir = _setup_artifacts_dir(module_name)
+
+ # Get the backend information for this test.
+ ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
+ tar_backend_infos = get_target_backends()
+
+ compile_backend = (
+ lambda backend_info: backend_info.compile_signature_def_saved_model(
+ saved_model_dir, saved_model_tags, module_name, exported_name,
+ input_names, output_names, artifacts_dir))
+
+ ref_module = compile_backend(ref_backend_info)
+ tar_modules = [
+ compile_backend(backend_info) for backend_info in tar_backend_infos
+ ]
+ return Modules(ref_module, tar_modules, artifacts_dir)
+
+
class TracedModuleTestCase(tf.test.TestCase):
"""Compiles a tf.Module to multiple backends to test their correctness."""
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index b3237f9..a4a731a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -21,7 +21,7 @@
import random
import re
import tempfile
-from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
from absl import flags
from absl import logging
@@ -183,7 +183,7 @@
exported_names: Sequence[str] = (),
artifacts_dir: str = None,
) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+ """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
The module blob this creates is not callable. See IreeCompiledModule for an
API that returns a module that can be called without any further steps.
@@ -193,7 +193,7 @@
Args:
module: A tf.Module.
- backend_info: BackendInfo with the details for compiling module to IREE.
+ backend_info: BackendInfo with the details for compiling this module.
exported_names: Optional sequence representing the exported names to keep.
artifacts_dir: An optional string pointing to where compilation artifacts
should be saved. No compilation artifacts will be saved if this is not
@@ -204,7 +204,7 @@
artifacts_dir is provided.
"""
- def _compile_module(module, exported_names, backend_info, artifacts_dir):
+ def _compile_module(module, backend_info, exported_names, artifacts_dir):
compiler_module = compiler.tf_module_to_compiler_module(module,
exported_names,
pass_pipeline=())
@@ -213,7 +213,47 @@
_compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
backend_info.backend_id)
- return _compile_module(module, exported_names, backend_info, artifacts_dir)
+ return _compile_module(module, backend_info, exported_names, artifacts_dir)
+
+
+def _incrementally_compile_tf_signature_def_saved_model(
+ saved_model_dir: str, saved_model_tags: Set[str],
+ backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
+ """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
+
+ The module blob this creates is not callable. See IreeCompiledModule for an
+ API that returns a module that can be called without any further steps.
+
+ See _incrementally_lower_compiler_module's docstring for details about which
+ artifacts will be saved.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+
+ Returns:
+ A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+ artifacts_dir is provided.
+ """
+
+ def _compile_module(saved_model_dir, saved_model_tags, backend_info,
+ exported_name, artifacts_dir):
+    # Convert the SavedModel into raw TF input MLIR.
+ compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
+ saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
+ return _incrementally_lower_compiler_module(compiler_module, backend_info,
+ artifacts_dir)
+
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+ backend_info.backend_id)
+ return _compile_module(saved_model_dir, saved_model_tags, backend_info,
+ exported_name, artifacts_dir)
class CompiledModule(object):
@@ -252,7 +292,7 @@
Args:
module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
+ backend_info: BackendInfo with the details for compiling this module.
exported_names: Optional sequence representing the exported names to keep.
artifacts_dir: An optional string pointing to where compilation artifacts
should be saved. No compilation artifacts will be saved if this is not
@@ -280,6 +320,33 @@
"""
raise NotImplementedError()
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+      input_names: A sequence of input names to feed the saved model as kwargs.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ raise NotImplementedError()
+
def iree_serializable(self):
return False
@@ -301,8 +368,8 @@
self._context = context
self._f = f
- def __call__(self, *args):
- return self._f(*args)
+ def __call__(self, *args, **kwargs):
+ return self._f(*args, **kwargs)
def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
"""Get cxx serialized inputs and outputs for this function."""
@@ -390,6 +457,46 @@
return cls(module_name, backend_info, compiled_paths, vm_module, config)
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+      input_names: A sequence of input names to feed the saved model as kwargs.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ del input_names # Unused.
+ del output_names # Unused.
+ module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
+ saved_model_dir, saved_model_tags, backend_info, exported_name,
+ artifacts_dir)
+ vm_module = rt.VmModule.from_flatbuffer(module_blob)
+ config = rt.Config(driver_name=backend_info.driver)
+
+ compiled_paths = None
+ if compiled_path is not None:
+ # IREE bundles every compiled method into the same compiled module :)
+ compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+ return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
def reinitialize(self):
"""Reinitializes all stateful variables."""
# set_random_seed is not needed here because the model_class.__init__ is not
@@ -441,6 +548,28 @@
check_types=False)
+def _convert_inputs_to_tensors(function):
+  """Wraps 'function' to convert its args and kwargs into tf.Tensors."""
+
+  def wrapper(*args, **kwargs):
+    args = [tf.convert_to_tensor(arg) for arg in args]
+    kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
+    return function(*args, **kwargs)
+
+  return wrapper
+
+
+class SignatureDefSavedModelWrapper(object):
+ """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
+
+ def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
+ exported_name: str):
+ self.saved_model = tf.saved_model.load(saved_model_dir,
+ tags=saved_model_tags)
+ inference_func = self.saved_model.signatures[exported_name]
+ inference_func = _convert_inputs_to_tensors(inference_func)
+    setattr(self, exported_name, inference_func)
+
+
class TfCompiledModule(CompiledModule):
"""TensorFlow 'compiled' module.
@@ -482,7 +611,7 @@
Args:
module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
+ backend_info: BackendInfo with the details for compiling this module.
exported_names: Optional sequence representing the exported names to keep.
artifacts_dir: An optional string pointing to where compilation artifacts
should be saved. No compilation artifacts will be saved if this is not
@@ -492,6 +621,35 @@
constructor = module_class
return cls(module_name, backend_info, constructor, exported_names)
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+      input_names: A sequence of input names to feed the saved model as kwargs.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ constructor = lambda: SignatureDefSavedModelWrapper(
+ saved_model_dir, saved_model_tags, exported_name)
+ return cls(module_name, backend_info, constructor, [exported_name])
+
def reinitialize(self):
"""Reinitializes all stateful variables."""
set_random_seed()
@@ -530,58 +688,98 @@
return functions, exported_names
-def tf_module_to_tflite_interpreters(
- module_class: Type[tf.Module],
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None
-) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str]], None]:
- """Compile a tf.Module to TFLite interpreters for each of its methods.
+def tf_module_to_tflite_module_bytes(
+ module_class: Type[tf.Module], exported_names: Sequence[str] = ()
+) -> Dict[str, bytes]:
+ """Compiles a tf.Module's methods with TFLite.
Args:
- module_class: A tf.Module subclass to compile with TFLite. If module_class
- has an attr get_legacy_tflite_saved_model_converter_kwargs then it will
- be compiled using tf.compat.v1.lite. It's best not to use this, however.
+ module_class: A tf.Module subclass to compile with TFLite.
exported_names: an optional iterable of strings representing which of the
- module_class's functions should be callable. If exported_names is empty
- then all functions will be callable.
+ module_class's functions should be compiled. If exported_names is empty
+ then all functions will be compiled.
+
+ Returns:
+ A dict mapping method names to compiled TFLite module bytes.
+ """
+ tflite_modules = []
+ methods, method_names = _get_concrete_functions(module_class, exported_names)
+ for method in methods:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+ tflite_modules.append(converter.convert())
+ return dict(zip(method_names, tflite_modules))
+
+
+def tf_signature_def_saved_model_to_tflite_module_bytes(
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+) -> Dict[str, bytes]:
+ """Compiles a SignatureDef SavedModel signature with TFLite.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+    input_names: A sequence of input names to feed the saved model as kwargs.
+ output_names: A sequence of named outputs to extract from the saved model.
+
+ Returns:
+ A dict mapping the signature name to the compiled TFLite module bytes.
+ """
+ converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
+ saved_model_dir,
+ tag_set=saved_model_tags,
+ signature_key=exported_name,
+ input_arrays=input_names,
+ output_arrays=output_names)
+ tflite_module = converter.convert()
+  return {exported_name: tflite_module}
+
+
+def tflite_module_bytes_to_tflite_interpreters(
+ tflite_module_bytes: Dict[str, bytes],
+ artifacts_dir: str = None
+) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
+  """Load a dict of compiled TFLite module bytes into TFLite interpreters.
+
+ Args:
+ tflite_module_bytes: A dict mapping method names to compiled TFLite byte
+ strings.
artifacts_dir: an optional path to save compilation artifacts to.
Returns:
- A dictionary of function names to TFLite interpreters and a dictionary of
- function names to compiled tflite graph paths (or None if artifacts_dir)
- is None.
+ A dictionary mapping method names to TFLite interpreters and a dictionary
+ mapping method names to compiled tflite graph paths (or None if
+ artifacts_dir is None).
"""
interpreters = dict()
compiled_paths = None
if artifacts_dir is not None:
compiled_paths = dict()
- def _interpret_bytes(tflite_module: bytes, base_dir: str):
+ def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
"""Save compiled TFLite module bytes and convert into an interpreter."""
tflite_dir = os.path.join(base_dir, "tflite")
os.makedirs(tflite_dir, exist_ok=True)
- tflite_path = os.path.join(tflite_dir, f"{name}.tflite")
+ tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
with open(tflite_path, "wb") as f:
f.write(tflite_module)
- interpreters[name] = tf.lite.Interpreter(tflite_path)
+ interpreters[method_name] = tf.lite.Interpreter(tflite_path)
if artifacts_dir is not None:
- compiled_paths[name] = tflite_path
-
- # Convert module_class's methods into TFLite module byte-strings.
- tflite_modules = []
- functions, names = _get_concrete_functions(module_class, exported_names)
- for function in functions:
- converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
- tflite_modules.append(converter.convert())
+ compiled_paths[method_name] = tflite_path
# Load each of the converted methods above into tf.lite.Interpreters.
- for name, tflite_module in zip(names, tflite_modules):
+ for method_name, tflite_module in tflite_module_bytes.items():
if artifacts_dir is None:
with tempfile.TemporaryDirectory() as base_dir:
- _interpret_bytes(tflite_module, base_dir)
+ _interpret_bytes(method_name, tflite_module, base_dir)
else:
- _interpret_bytes(tflite_module, artifacts_dir)
+ _interpret_bytes(method_name, tflite_module, artifacts_dir)
return interpreters, compiled_paths
@@ -593,14 +791,20 @@
self._interpreter = interpreter
def __call__(self, *args, **kwargs) -> Tuple[Any]:
- if len(kwargs):
- raise ValueError("kwargs are not supported, but the following kwargs "
- f"were provided {kwargs}")
+    if args and kwargs:
+ raise ValueError("Passing both args and kwargs is not supported by "
+ "_TfLiteFunctionWrapper")
# Set up and run the function.
self._interpreter.allocate_tensors()
- for arg, detail in zip(args, self._interpreter.get_input_details()):
- self._interpreter.set_tensor(detail["index"], arg)
+
+    if args:
+ for arg, detail in zip(args, self._interpreter.get_input_details()):
+ self._interpreter.set_tensor(detail["index"], arg)
+ else:
+ for detail in self._interpreter.get_input_details():
+ self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
+
self._interpreter.invoke()
# Extract the outputs from the TFLite interpreter.
@@ -664,18 +868,52 @@
Args:
module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
+ backend_info: BackendInfo with the details for compiling this module.
exported_names: Optional sequence representing the exported names to keep.
artifacts_dir: An optional string pointing to where compilation artifacts
should be saved. No compilation artifacts will be saved if this is not
provided.
"""
set_random_seed()
- interpreters, compiled_paths = tf_module_to_tflite_interpreters(
- module_class, exported_names, artifacts_dir)
+ tflite_module_bytes = tf_module_to_tflite_module_bytes(
+ module_class, exported_names)
+ interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+ tflite_module_bytes, artifacts_dir)
module_name = module_class.__name__
return cls(module_name, backend_info, compiled_paths, interpreters)
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+      input_names: A sequence of input names to feed the saved model as kwargs.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
+ saved_model_dir, saved_model_tags, exported_name, input_names,
+ output_names)
+ interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+ tflite_module_bytes, artifacts_dir)
+ return cls(module_name, backend_info, compiled_paths, interpreters)
+
def reinitialize(self):
"""Reinitializes all stateful variables."""
# This is a noop because TFLite (mostly) doesn't support stateful modules.
@@ -761,6 +999,19 @@
return self._compiled_module_class.create_from_class(
module_class, self, exported_names, artifacts_dir)
+ def compile_signature_def_saved_model(
+ self,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None) -> CompiledModule:
+ return self._compiled_module_class.create_from_signature_def_saved_model(
+ saved_model_dir, saved_model_tags, module_name, self, exported_name,
+ input_names, output_names, artifacts_dir)
+
@classmethod
def get_all_backends(cls) -> Sequence["BackendInfo"]:
"""Returns a list of all BackendInfo configurations."""
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index dc289d6..5b89789 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -218,12 +218,18 @@
iree_e2e_test_suite(
name = "mobile_bert_squad_tests",
+ size = "enormous",
backends_to_srcs = {
"tf": ["mobile_bert_squad_test.py"],
+ "tflite": ["mobile_bert_squad_test.py"],
+ "iree_vmla": ["mobile_bert_squad_test.py"],
},
reference_backend = "tf",
tags = [
+ "external",
+ "guitar",
"manual",
+ "no-remote",
"nokokoro",
"notap",
],
@@ -234,15 +240,17 @@
iree_e2e_test_suite(
name = "mobile_bert_squad_tests_failing",
+ size = "enormous",
backends_to_srcs = {
- "tflite": ["mobile_bert_squad_test.py"],
- "iree_vmla": ["mobile_bert_squad_test.py"],
"iree_llvmjit": ["mobile_bert_squad_test.py"],
"iree_vulkan": ["mobile_bert_squad_test.py"],
},
reference_backend = "tf",
tags = [
+ "external",
+ "guitar",
"manual",
+ "no-remote",
"nokokoro",
"notap",
],
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 7adb281..553faf7 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -221,14 +221,10 @@
TODO(silvasean): debugging miscompiles
-## Legacy TFLite Compilation
+## Testing SignatureDef SavedModels
-_Please don't use this unless you are forced to._
-
-We support using `tf.compat.v1.lite.TFLiteConverter.from_saved_model` to compile
-older `tf.Module`s with TFLite. This will be used if the `tf.Module` being
-tested has a method named `get_legacy_tflite_saved_model_converter_kwargs`. This
-method must return a dict with the following kwargs: `model_path`,
-`input_arrays`, `output_arrays`, and `exported_name`. The module must use only
-one exported name, and `exported_name` should be equal to that name. See
-`mobile_bert_squad_test.py` for a concrete example.
+TensorFlow 1.x SavedModels can be tested using
+`tf_test_utils.compile_tf_signature_def_saved_model` instead of
+`tf_test_utils.compile_tf_module`. See `mobile_bert_squad_test.py` for a
+concrete example. The compilation artifacts will be saved under a directory
+named after `module_name`.
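+
+A minimal sketch of such a test (mirroring `mobile_bert_squad_test.py`; the
+module name and tensor names below are illustrative):
+
+```python
+self._modules = tf_test_utils.compile_tf_signature_def_saved_model(
+    saved_model_dir='/path/to/saved_model',
+    saved_model_tags={'serve'},
+    module_name='MyModule',
+    exported_name='serving_default',
+    input_names=['input_ids', 'input_mask', 'segment_ids'],
+    output_names=['start_logits', 'end_logits'])
+```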
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 60bbff4..159633a 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -21,6 +21,7 @@
backends_to_srcs,
reference_backend,
deps = None,
+ size = None,
tags = None,
python_version = "PY3",
**kwargs):
@@ -73,6 +74,7 @@
srcs = [src],
deps = deps,
args = args,
+ size = size,
tags = py_test_tags,
python_version = python_version,
**kwargs
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index 4f8daf5..b99759d 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -25,80 +25,57 @@
from absl import flags
import numpy as np
from pyiree.tf.support import tf_test_utils
+from pyiree.tf import compiler
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
-flags.DEFINE_boolean("use_quantized_weights", False,
- "Whether to use quantized or floating point weights.")
+flags.DEFINE_boolean('use_quantized_weights', False,
+ 'Whether to use quantized or floating point weights.')
MAX_SEQ_LENGTH = 384 # Max input sequence length used in mobilebert_squad.
-FILE_NAME = "mobilebert_squad_savedmodels.tar.gz"
+FILE_NAME = 'mobilebert_squad_savedmodels.tar.gz'
MODEL_URL = posixpath.join(
- "https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/",
+ 'https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/',
FILE_NAME)
-class MobileBertSquad(tf.Module):
- """Wrapper of MobileBertSquad saved model v1."""
-
- def __init__(self):
- self.model_path = self.get_model_path()
- self.saved_model = tf.saved_model.load(self.model_path, tags=["serve"])
- self.inference_func = self.saved_model.signatures["serving_default"]
-
- @staticmethod
- def get_model_path():
- model_type = "quant_saved_model" if FLAGS.use_quantized_weights else "float"
-
- # Get_file will download the model weights from a publicly available folder,
- # save them to cache_dir=~/.keras/datasets/ and return a path to them.
- model_path = tf.keras.utils.get_file(FILE_NAME, MODEL_URL, untar=True)
- model_dir = os.path.dirname(model_path)
- extracted_name = FILE_NAME.split(".")[0]
- model_path = os.path.join(model_dir, extracted_name, model_type)
- return model_path
-
- @staticmethod
- def get_legacy_tflite_saved_model_converter_kwargs():
- return dict([("input_arrays", ["input_ids", "input_mask", "segment_ids"]),
- ("output_arrays", ["start_logits", "end_logits"]),
- ("exported_name", "predict"),
- ("model_path", MobileBertSquad.get_model_path())])
-
- @tf.function(input_signature=[
- tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
- tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
- tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
- ])
- def predict(self, input_ids, input_mask, segment_ids):
- inputs = {
- "input_ids": input_ids,
- "input_mask": input_mask,
- "segment_ids": segment_ids,
- }
- return self.inference_func(**inputs)
-
-
class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
"""Tests of MobileBertSquad."""
- def __init__(self, methodName="runTest"):
+ def __init__(self, methodName='runTest'):
super(MobileBertSquadTest, self).__init__(methodName)
- self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
- exported_names=["predict"])
+ model_type = 'quant_saved_model' if FLAGS.use_quantized_weights else 'float'
- def test_predict(self):
+    # get_file downloads the model weights from a publicly available folder,
+    # saves them to cache_dir=~/.keras/datasets/ and returns a path to them.
+ model_path = tf.keras.utils.get_file(FILE_NAME, MODEL_URL, untar=True)
+ model_dir = os.path.dirname(model_path)
+ extracted_name = FILE_NAME.split('.')[0]
+ model_path = os.path.join(model_dir, extracted_name, model_type)
- def predict(module):
+ self._modules = tf_test_utils.compile_tf_signature_def_saved_model(
+ saved_model_dir=model_path,
+        saved_model_tags={'serve'},
+ module_name='MobileBertSquad',
+ exported_name='serving_default',
+ input_names=['input_ids', 'input_mask', 'segment_ids'],
+ output_names=['start_logits', 'end_logits'])
+
+ def test_serving_default(self):
+
+ def serving_default(module):
input_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
input_mask = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
segment_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
- module.predict(input_ids, input_mask, segment_ids, atol=1e0)
+ module.serving_default(input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ atol=1e0)
- self.compare_backends(predict, self._modules)
+ self.compare_backends(serving_default, self._modules)
def main(argv):
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 30503ba..b048baa 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -38,6 +38,8 @@
TEST_SUITES_TO_HEADERS = {
'//integrations/tensorflow/e2e:e2e_tests':
'End to end TensorFlow tests',
+ '//integrations/tensorflow/e2e:mobile_bert_squad_tests':
+        'End to end tests of MobileBert on SQuAD',
'//integrations/tensorflow/e2e/keras:keras_tests':
'End to end tests written using tf.keras',
'//integrations/tensorflow/e2e/keras:imagenet_external_tests':