Enable MobileBert on VMLA via SignatureDef SavedModels (#3307)

Adds a named constructor `create_from_signature_def_saved_model` to compile and test SignatureDef SavedModels.

Also fixes #3258.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
index 3d28439..7787a49 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/BUILD
@@ -59,6 +59,32 @@
     ],
 })
 
+# TODO: Isolate SignatureDef SavedModel support into its own library to decrease buildtime cost.
+# While it would be nice to simply depend on TensorFlow, manually paring
+# down the dependencies significantly reduces the build time that this adds.
+#
+# Baseline: 449s
+# SignatureDef SavedModels: 546s – 22% increase in build time.
+# SignatureDef SavedModels + Deps for MobileBert: 572s – 27% increase in build time.
+# TF OpenSource: 664s – 49% increase in build time.
+SIGNATURE_DEF_SAVED_MODEL_TF_RUNTIME_DEPS = [
+    # Deps for SignatureDef SavedModels:
+    "@org_tensorflow//tensorflow/core:direct_session",
+    "@org_tensorflow//tensorflow/core/kernels:resource_variable_ops",  #  VarHandleOp
+    "@org_tensorflow//tensorflow/core/kernels:regex_full_match_op",  # StaticRegexFullMatch
+    "@org_tensorflow//tensorflow/core/kernels:string_join_op",  # StringJoin
+    "@org_tensorflow//tensorflow/core/kernels:save_op",  # SharedFilename
+    "@org_tensorflow//tensorflow/core/kernels:save_restore_v2_ops",  # SaveV2
+
+    # Deps for MobileBert:
+    "@org_tensorflow//tensorflow/core/kernels:parameterized_truncated_normal_op",  # TruncatedNormal
+    "@org_tensorflow//tensorflow/core/kernels:state",  # Assign.
+    "@org_tensorflow//tensorflow/core/kernels:logging_ops",  # Assert
+    "@org_tensorflow//tensorflow/core/kernels:bias_op",  # BiasAdd
+    "@org_tensorflow//tensorflow/core/kernels:softmax_op",  # Softmax
+    "@org_tensorflow//tensorflow/core/kernels:relu_op",  # Relu
+]
+
 TF_XLA_PASS_DEPS = [
     "//integrations/tensorflow/compiler:tensorflow",
     "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_tf",
@@ -101,7 +127,7 @@
     hdrs = [
         "register_tensorflow.h",
     ],
-    deps = SAVED_MODEL_TF_RUNTIME_DEPS + TF_XLA_PASS_DEPS + [
+    deps = SAVED_MODEL_TF_RUNTIME_DEPS + TF_XLA_PASS_DEPS + SIGNATURE_DEF_SAVED_MODEL_TF_RUNTIME_DEPS + [
         "//bindings/python/pyiree/common",
         "//bindings/python/pyiree/compiler:compiler_library",
         "@llvm-project//llvm:Support",
@@ -127,10 +153,6 @@
     name = "signature_def_saved_model_test",
     srcs = ["signature_def_saved_model_test.py"],
     python_version = "PY3",
-    tags = [
-        "manual",
-        "nokokoro",
-    ],
     deps = INTREE_TENSORFLOW_PY_DEPS + NUMPY_DEPS + [
         "@absl_py//absl/testing:absltest",
         "//integrations/tensorflow/bindings/python/pyiree/tf/compiler",
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 43c8d08..4d432d8 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -30,7 +30,7 @@
 import pickle
 import sys
 import tempfile
-from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
 
 from absl import flags
 from absl import logging
@@ -383,15 +383,24 @@
             "Expected ref and tar to have the same dtype, but got %s  and %s",
             ref.dtype, tar.dtype)
         return False
+      if ref.size == tar.size == 0:
+        return True
+
       if np.issubdtype(ref.dtype, np.floating):
         same = np.allclose(ref, tar, rtol=rtol, atol=atol)
+        abs_diff = np.max(np.abs(ref - tar))
+        rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
         if not same:
-          abs_diff = np.max(np.abs(ref - tar))
-          rel_diff = np.max(np.abs(ref - tar) / np.max(tar))
           logging.error(
               "Floating point difference between ref and tar was too large. "
               "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
               abs_diff, atol, rel_diff, rtol)
+        else:
+          logging.info(
+              "Floating point difference between ref and tar was within "
+              "tolerance. "
+              "Max abs diff: %s, atol: %s, max relative diff: %s, rtol: %s",
+              abs_diff, atol, rel_diff, rtol)
         return same
       else:
         return np.array_equal(ref, tar)
@@ -598,6 +607,48 @@
   return Modules(ref_module, tar_modules, artifacts_dir)
 
 
+def compile_tf_signature_def_saved_model(saved_model_dir: str,
+                                         saved_model_tags: Set[str],
+                                         module_name: str, exported_name: str,
+                                         input_names: Sequence[str],
+                                         output_names: Sequence[str]):
+  """Compiles a SignatureDef SavedModel to each backend that we test.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    module_name: A name for this compiled module.
+    backend_info: BackendInfo with the details for compiling the saved model.
+    exported_name: A str representing the signature on the saved model to
+      compile.
+    input_names: A sequence of kwargs to feed to the saved model.
+    output_names: A sequence of named outputs to extract from the saved model.
+
+  Returns:
+    A 'Modules' namedtuple containing the reference module, target modules and
+    artifacts directory.
+  """
+
+  # Setup the directory for saving compilation artifacts and traces.
+  artifacts_dir = _setup_artifacts_dir(module_name)
+
+  # Get the backend information for this test.
+  ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+                                          f"{FLAGS.reference_backend}_ref")
+  tar_backend_infos = get_target_backends()
+
+  compile_backend = (
+      lambda backend_info: backend_info.compile_signature_def_saved_model(
+          saved_model_dir, saved_model_tags, module_name, exported_name,
+          input_names, output_names, artifacts_dir))
+
+  ref_module = compile_backend(ref_backend_info)
+  tar_modules = [
+      compile_backend(backend_info) for backend_info in tar_backend_infos
+  ]
+  return Modules(ref_module, tar_modules, artifacts_dir)
+
+
 class TracedModuleTestCase(tf.test.TestCase):
   """Compiles a tf.Module to multiple backends to test their correctness."""
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index b3237f9..a4a731a 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -21,7 +21,7 @@
 import random
 import re
 import tempfile
-from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
 
 from absl import flags
 from absl import logging
@@ -183,7 +183,7 @@
     exported_names: Sequence[str] = (),
     artifacts_dir: str = None,
 ) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
-  """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+  """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
 
   The module blob this creates is not callable. See IreeCompiledModule for an
   API that returns a module that can be called without any further steps.
@@ -193,7 +193,7 @@
 
   Args:
     module: A tf.Module.
-    backend_info: BackendInfo with the details for compiling module to IREE.
+    backend_info: BackendInfo with the details for compiling this module.
     exported_names: Optional sequence representing the exported names to keep.
     artifacts_dir: An optional string pointing to where compilation artifacts
       should be saved. No compilation artifacts will be saved if this is not
@@ -204,7 +204,7 @@
     artifacts_dir is provided.
   """
 
-  def _compile_module(module, exported_names, backend_info, artifacts_dir):
+  def _compile_module(module, backend_info, exported_names, artifacts_dir):
     compiler_module = compiler.tf_module_to_compiler_module(module,
                                                             exported_names,
                                                             pass_pipeline=())
@@ -213,7 +213,47 @@
 
   _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
                                                  backend_info.backend_id)
-  return _compile_module(module, exported_names, backend_info, artifacts_dir)
+  return _compile_module(module, backend_info, exported_names, artifacts_dir)
+
+
+def _incrementally_compile_tf_signature_def_saved_model(
+    saved_model_dir: str, saved_model_tags: Set[str],
+    backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
+  """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
+
+  The module blob this creates is not callable. See IreeCompiledModule for an
+  API that returns a module that can be called without any further steps.
+
+  See _incrementally_lower_compiler_module's docstring for details about which
+  artifacts will be saved.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    backend_info: BackendInfo with the details for compiling the saved model.
+    exported_name: A str representing the signature on the saved model to
+      compile.
+    artifacts_dir: An optional string pointing to where compilation artifacts
+      should be saved. No compilation artifacts will be saved if this is not
+      provided.
+
+  Returns:
+    A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+    artifacts_dir is provided.
+  """
+
+  def _compile_module(saved_model_dir, saved_model_tags, backend_info,
+                      exported_name, artifacts_dir):
+    # Convert the tf_module into raw TF input MLIR.
+    compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
+        saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
+    return _incrementally_lower_compiler_module(compiler_module, backend_info,
+                                                artifacts_dir)
+
+  _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+                                                 backend_info.backend_id)
+  return _compile_module(saved_model_dir, saved_model_tags, backend_info,
+                         exported_name, artifacts_dir)
 
 
 class CompiledModule(object):
@@ -252,7 +292,7 @@
 
     Args:
       module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling module to IREE.
+      backend_info: BackendInfo with the details for compiling this module.
       exported_names: Optional sequence representing the exported names to keep.
       artifacts_dir: An optional string pointing to where compilation artifacts
         should be saved. No compilation artifacts will be saved if this is not
@@ -280,6 +320,33 @@
     """
     raise NotImplementedError()
 
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    raise NotImplementedError()
+
   def iree_serializable(self):
     return False
 
@@ -301,8 +368,8 @@
     self._context = context
     self._f = f
 
-  def __call__(self, *args):
-    return self._f(*args)
+  def __call__(self, *args, **kwargs):
+    return self._f(*args, **kwargs)
 
   def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
     """Get cxx serialized inputs and outputs for this function."""
@@ -390,6 +457,46 @@
 
     return cls(module_name, backend_info, compiled_paths, vm_module, config)
 
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    del input_names  # Unused.
+    del output_names  # Unused.
+    module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
+        saved_model_dir, saved_model_tags, backend_info, exported_name,
+        artifacts_dir)
+    vm_module = rt.VmModule.from_flatbuffer(module_blob)
+    config = rt.Config(driver_name=backend_info.driver)
+
+    compiled_paths = None
+    if compiled_path is not None:
+      # IREE bundles every compiled method into the same compiled module :)
+      compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+    return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
   def reinitialize(self):
     """Reinitializes all stateful variables."""
     # set_random_seed is not needed here because the model_class.__init__ is not
@@ -441,6 +548,28 @@
                                  check_types=False)
 
 
+def _convert_inputs_to_tensors(function):
+
+  def decorator(*args, **kwargs):
+    args = [tf.convert_to_tensor(arg) for arg in args]
+    kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
+    return function(*args, **kwargs)
+
+  return decorator
+
+
+class SignatureDefSavedModelWrapper(object):
+  """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
+
+  def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
+               exported_name: str):
+    self.saved_model = tf.saved_model.load(saved_model_dir,
+                                           tags=saved_model_tags)
+    inference_func = self.saved_model.signatures[exported_name]
+    inference_func = _convert_inputs_to_tensors(inference_func)
+    self.__setattr__(exported_name, inference_func)
+
+
 class TfCompiledModule(CompiledModule):
   """TensorFlow 'compiled' module.
 
@@ -482,7 +611,7 @@
 
     Args:
       module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling module to IREE.
+      backend_info: BackendInfo with the details for compiling this module.
       exported_names: Optional sequence representing the exported names to keep.
       artifacts_dir: An optional string pointing to where compilation artifacts
         should be saved. No compilation artifacts will be saved if this is not
@@ -492,6 +621,35 @@
     constructor = module_class
     return cls(module_name, backend_info, constructor, exported_names)
 
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    constructor = lambda: SignatureDefSavedModelWrapper(
+        saved_model_dir, saved_model_tags, exported_name)
+    return cls(module_name, backend_info, constructor, [exported_name])
+
   def reinitialize(self):
     """Reinitializes all stateful variables."""
     set_random_seed()
@@ -530,58 +688,98 @@
   return functions, exported_names
 
 
-def tf_module_to_tflite_interpreters(
-    module_class: Type[tf.Module],
-    exported_names: Sequence[str] = (),
-    artifacts_dir: str = None
-) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str]], None]:
-  """Compile a tf.Module to TFLite interpreters for each of its methods.
+def tf_module_to_tflite_module_bytes(
+    module_class: Type[tf.Module], exported_names: Sequence[str] = ()
+) -> Dict[str, bytes]:
+  """Compiles a tf.Module's methods with TFLite.
 
   Args:
-    module_class: A tf.Module subclass to compile with TFLite. If module_class
-      has an attr get_legacy_tflite_saved_model_converter_kwargs then it will
-      be compiled using tf.compat.v1.lite. It's best not to use this, however.
+    module_class: A tf.Module subclass to compile with TFLite.
     exported_names: an optional iterable of strings representing which of the
-      module_class's functions should be callable. If exported_names is empty
-      then all functions will be callable.
+      module_class's functions should be compiled. If exported_names is empty
+      then all functions will be compiled.
+
+  Returns:
+    A dict mapping method names to compiled TFLite module bytes.
+  """
+  tflite_modules = []
+  methods, method_names = _get_concrete_functions(module_class, exported_names)
+  for method in methods:
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+    tflite_modules.append(converter.convert())
+  return dict(zip(method_names, tflite_modules))
+
+
+def tf_signature_def_saved_model_to_tflite_module_bytes(
+    saved_model_dir: str,
+    saved_model_tags: Set[str],
+    exported_name: str,
+    input_names: Sequence[str],
+    output_names: Sequence[str],
+) -> Dict[str, bytes]:
+  """Compiles a SignatureDef SavedModel signature with TFLite.
+
+  Args:
+    saved_model_dir: Directory of the saved model.
+    saved_model_tags: Optional set of tags to use when loading the model.
+    exported_name: A str representing the signature on the saved model to
+      compile.
+    input_names: A sequence of kwargs to feed to the saved model.
+    output_names: A sequence of named outputs to extract from the saved model.
+
+  Returns:
+    A dict mapping the signature name to the compiled TFLite module bytes.
+  """
+  converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
+      saved_model_dir,
+      tag_set=saved_model_tags,
+      signature_key=exported_name,
+      input_arrays=input_names,
+      output_arrays=output_names)
+  tflite_module = converter.convert()
+  return dict([[exported_name, tflite_module]])
+
+
+def tflite_module_bytes_to_tflite_interpreters(
+    tflite_module_bytes: Dict[str, bytes],
+    artifacts_dir: str = None
+) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
+  """Compile a dict of TFLite compiled bytes to  TFLite interpreters.
+
+  Args:
+    tflite_module_bytes: A dict mapping method names to compiled TFLite byte
+      strings.
     artifacts_dir: an optional path to save compilation artifacts to.
 
   Returns:
-    A dictionary of function names to TFLite interpreters and a dictionary of
-    function names to compiled tflite graph paths (or None if artifacts_dir)
-    is None.
+    A dictionary mapping method names to TFLite interpreters and a dictionary
+    mapping method names to compiled tflite graph paths (or None if
+    artifacts_dir is None).
   """
   interpreters = dict()
   compiled_paths = None
   if artifacts_dir is not None:
     compiled_paths = dict()
 
-  def _interpret_bytes(tflite_module: bytes, base_dir: str):
+  def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
     """Save compiled TFLite module bytes and convert into an interpreter."""
     tflite_dir = os.path.join(base_dir, "tflite")
     os.makedirs(tflite_dir, exist_ok=True)
-    tflite_path = os.path.join(tflite_dir, f"{name}.tflite")
+    tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
     with open(tflite_path, "wb") as f:
       f.write(tflite_module)
 
-    interpreters[name] = tf.lite.Interpreter(tflite_path)
+    interpreters[method_name] = tf.lite.Interpreter(tflite_path)
     if artifacts_dir is not None:
-      compiled_paths[name] = tflite_path
-
-  # Convert module_class's methods into TFLite module byte-strings.
-  tflite_modules = []
-  functions, names = _get_concrete_functions(module_class, exported_names)
-  for function in functions:
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
-    tflite_modules.append(converter.convert())
+      compiled_paths[method_name] = tflite_path
 
   # Load each of the converted methods above into tf.lite.Interpreters.
-  for name, tflite_module in zip(names, tflite_modules):
+  for method_name, tflite_module in tflite_module_bytes.items():
     if artifacts_dir is None:
       with tempfile.TemporaryDirectory() as base_dir:
-        _interpret_bytes(tflite_module, base_dir)
+        _interpret_bytes(method_name, tflite_module, base_dir)
     else:
-      _interpret_bytes(tflite_module, artifacts_dir)
+      _interpret_bytes(method_name, tflite_module, artifacts_dir)
 
   return interpreters, compiled_paths
 
@@ -593,14 +791,20 @@
     self._interpreter = interpreter
 
   def __call__(self, *args, **kwargs) -> Tuple[Any]:
-    if len(kwargs):
-      raise ValueError("kwargs are not supported, but the following kwargs "
-                       f"were provided {kwargs}")
+    if len(args) and len(kwargs):
+      raise ValueError("Passing both args and kwargs is not supported by "
+                       "_TfLiteFunctionWrapper")
 
     # Set up and run the function.
     self._interpreter.allocate_tensors()
-    for arg, detail in zip(args, self._interpreter.get_input_details()):
-      self._interpreter.set_tensor(detail["index"], arg)
+
+    if len(args):
+      for arg, detail in zip(args, self._interpreter.get_input_details()):
+        self._interpreter.set_tensor(detail["index"], arg)
+    else:
+      for detail in self._interpreter.get_input_details():
+        self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
+
     self._interpreter.invoke()
 
     # Extract the outputs from the TFLite interpreter.
@@ -664,18 +868,52 @@
 
     Args:
       module_class: The tf.Module subclass to compile.
-      backend_info: BackendInfo with the details for compiling module to IREE.
+      backend_info: BackendInfo with the details for compiling this module.
       exported_names: Optional sequence representing the exported names to keep.
       artifacts_dir: An optional string pointing to where compilation artifacts
         should be saved. No compilation artifacts will be saved if this is not
         provided.
     """
     set_random_seed()
-    interpreters, compiled_paths = tf_module_to_tflite_interpreters(
-        module_class, exported_names, artifacts_dir)
+    tflite_module_bytes = tf_module_to_tflite_module_bytes(
+        module_class, exported_names)
+    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+        tflite_module_bytes, artifacts_dir)
     module_name = module_class.__name__
     return cls(module_name, backend_info, compiled_paths, interpreters)
 
+  @classmethod
+  def create_from_signature_def_saved_model(cls,
+                                            saved_model_dir: str,
+                                            saved_model_tags: Set[str],
+                                            module_name: str,
+                                            backend_info: "BackendInfo",
+                                            exported_name: str,
+                                            input_names: Sequence[str],
+                                            output_names: Sequence[str],
+                                            artifacts_dir: str = None):
+    """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+    Args:
+      saved_model_dir: Directory of the saved model.
+      saved_model_tags: Optional set of tags to use when loading the model.
+      module_name: A name for this compiled module.
+      backend_info: BackendInfo with the details for compiling the saved model.
+      exported_name: A str representing the signature on the saved model to
+        compile.
+      input_names: A sequence of kwargs to feed to the saved model.
+      output_names: A sequence of named outputs to extract from the saved model.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
+        saved_model_dir, saved_model_tags, exported_name, input_names,
+        output_names)
+    interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+        tflite_module_bytes, artifacts_dir)
+    return cls(module_name, backend_info, compiled_paths, interpreters)
+
   def reinitialize(self):
     """Reinitializes all stateful variables."""
     # This is a noop because TFLite (mostly) doesn't support stateful modules.
@@ -761,6 +999,19 @@
     return self._compiled_module_class.create_from_class(
         module_class, self, exported_names, artifacts_dir)
 
+  def compile_signature_def_saved_model(
+      self,
+      saved_model_dir: str,
+      saved_model_tags: Set[str],
+      module_name: str,
+      exported_name: str,
+      input_names: Sequence[str],
+      output_names: Sequence[str],
+      artifacts_dir: str = None) -> CompiledModule:
+    return self._compiled_module_class.create_from_signature_def_saved_model(
+        saved_model_dir, saved_model_tags, module_name, self, exported_name,
+        input_names, output_names, artifacts_dir)
+
   @classmethod
   def get_all_backends(cls) -> Sequence["BackendInfo"]:
     """Returns a list of all BackendInfo configurations."""
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index dc289d6..5b89789 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -218,12 +218,18 @@
 
 iree_e2e_test_suite(
     name = "mobile_bert_squad_tests",
+    size = "enormous",
     backends_to_srcs = {
         "tf": ["mobile_bert_squad_test.py"],
+        "tflite": ["mobile_bert_squad_test.py"],
+        "iree_vmla": ["mobile_bert_squad_test.py"],
     },
     reference_backend = "tf",
     tags = [
+        "external",
+        "guitar",
         "manual",
+        "no-remote",
         "nokokoro",
         "notap",
     ],
@@ -234,15 +240,17 @@
 
 iree_e2e_test_suite(
     name = "mobile_bert_squad_tests_failing",
+    size = "enormous",
     backends_to_srcs = {
-        "tflite": ["mobile_bert_squad_test.py"],
-        "iree_vmla": ["mobile_bert_squad_test.py"],
         "iree_llvmjit": ["mobile_bert_squad_test.py"],
         "iree_vulkan": ["mobile_bert_squad_test.py"],
     },
     reference_backend = "tf",
     tags = [
+        "external",
+        "guitar",
         "manual",
+        "no-remote",
         "nokokoro",
         "notap",
     ],
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 7adb281..553faf7 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -221,14 +221,10 @@
 
 TODO(silvasean): debugging miscompiles
 
-## Legacy TFLite Compilation
+## Testing SignatureDef SavedModels
 
-_Please don't use this unless you are forced to._
-
-We support using `tf.compat.v1.lite.TFLiteConverter.from_saved_model` to compile
-older `tf.Module`s with TFLite. This will be used if the `tf.Module` being
-tested has a method named `get_legacy_tflite_saved_model_converter_kwargs`. This
-method must return a dict with the following kwargs: `model_path`,
-`input_arrays`, `output_arrays`, and `exported_name`. The module must use only
-one exported name, and `exported_name` should be equal to that name. See
-`mobile_bert_squad_test.py` for a concrete example.
+TensorFlow 1.x SavedModels can be tested using
+`tf_test_utils.compile_tf_signature_def_saved_model` instead of
+`tf_test_utils.compile_tf_module`. See `mobile_bert_squad_test.py` for a
+concrete example. The compilation artifacts will be saved under whatever
+you specify for `module_name`.
diff --git a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
index 60bbff4..159633a 100644
--- a/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
+++ b/integrations/tensorflow/e2e/iree_e2e_test_suite.bzl
@@ -21,6 +21,7 @@
         backends_to_srcs,
         reference_backend,
         deps = None,
+        size = None,
         tags = None,
         python_version = "PY3",
         **kwargs):
@@ -73,6 +74,7 @@
                 srcs = [src],
                 deps = deps,
                 args = args,
+                size = size,
                 tags = py_test_tags,
                 python_version = python_version,
                 **kwargs
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index 4f8daf5..b99759d 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -25,80 +25,57 @@
 from absl import flags
 import numpy as np
 from pyiree.tf.support import tf_test_utils
+from pyiree.tf import compiler
 import tensorflow.compat.v2 as tf
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_boolean("use_quantized_weights", False,
-                     "Whether to use quantized or floating point weights.")
+flags.DEFINE_boolean('use_quantized_weights', False,
+                     'Whether to use quantized or floating point weights.')
 
 MAX_SEQ_LENGTH = 384  # Max input sequence length used in mobilebert_squad.
 
-FILE_NAME = "mobilebert_squad_savedmodels.tar.gz"
+FILE_NAME = 'mobilebert_squad_savedmodels.tar.gz'
 MODEL_URL = posixpath.join(
-    "https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/",
+    'https://storage.googleapis.com/cloud-tpu-checkpoints/mobilebert/',
     FILE_NAME)
 
 
-class MobileBertSquad(tf.Module):
-  """Wrapper of MobileBertSquad saved model v1."""
-
-  def __init__(self):
-    self.model_path = self.get_model_path()
-    self.saved_model = tf.saved_model.load(self.model_path, tags=["serve"])
-    self.inference_func = self.saved_model.signatures["serving_default"]
-
-  @staticmethod
-  def get_model_path():
-    model_type = "quant_saved_model" if FLAGS.use_quantized_weights else "float"
-
-    # Get_file will download the model weights from a publicly available folder,
-    # save them to cache_dir=~/.keras/datasets/ and return a path to them.
-    model_path = tf.keras.utils.get_file(FILE_NAME, MODEL_URL, untar=True)
-    model_dir = os.path.dirname(model_path)
-    extracted_name = FILE_NAME.split(".")[0]
-    model_path = os.path.join(model_dir, extracted_name, model_type)
-    return model_path
-
-  @staticmethod
-  def get_legacy_tflite_saved_model_converter_kwargs():
-    return dict([("input_arrays", ["input_ids", "input_mask", "segment_ids"]),
-                 ("output_arrays", ["start_logits", "end_logits"]),
-                 ("exported_name", "predict"),
-                 ("model_path", MobileBertSquad.get_model_path())])
-
-  @tf.function(input_signature=[
-      tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
-      tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
-      tf.TensorSpec((1, MAX_SEQ_LENGTH), tf.int32),
-  ])
-  def predict(self, input_ids, input_mask, segment_ids):
-    inputs = {
-        "input_ids": input_ids,
-        "input_mask": input_mask,
-        "segment_ids": segment_ids,
-    }
-    return self.inference_func(**inputs)
-
-
 class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
   """Tests of MobileBertSquad."""
 
-  def __init__(self, methodName="runTest"):
+  def __init__(self, methodName='runTest'):
     super(MobileBertSquadTest, self).__init__(methodName)
-    self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
-                                                    exported_names=["predict"])
+    model_type = 'quant_saved_model' if FLAGS.use_quantized_weights else 'float'
 
-  def test_predict(self):
+    # Get_file will download the model weights from a publicly available folder,
+    # save them to cache_dir=~/.keras/datasets/ and return a path to them.
+    model_path = tf.keras.utils.get_file(FILE_NAME, MODEL_URL, untar=True)
+    model_dir = os.path.dirname(model_path)
+    extracted_name = FILE_NAME.split('.')[0]
+    model_path = os.path.join(model_dir, extracted_name, model_type)
 
-    def predict(module):
+    self._modules = tf_test_utils.compile_tf_signature_def_saved_model(
+        saved_model_dir=model_path,
+        saved_model_tags=set(['serve']),
+        module_name='MobileBertSquad',
+        exported_name='serving_default',
+        input_names=['input_ids', 'input_mask', 'segment_ids'],
+        output_names=['start_logits', 'end_logits'])
+
+  def test_serving_default(self):
+
+    def serving_default(module):
       input_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
       input_mask = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
       segment_ids = np.zeros((1, MAX_SEQ_LENGTH), dtype=np.int32)
 
-      module.predict(input_ids, input_mask, segment_ids, atol=1e0)
+      module.serving_default(input_ids=input_ids,
+                             input_mask=input_mask,
+                             segment_ids=segment_ids,
+                             atol=1e0)
 
-    self.compare_backends(predict, self._modules)
+    self.compare_backends(serving_default, self._modules)
 
 
 def main(argv):
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index 30503ba..b048baa 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -38,6 +38,8 @@
 TEST_SUITES_TO_HEADERS = {
     '//integrations/tensorflow/e2e:e2e_tests':
         'End to end TensorFlow tests',
+    '//integrations/tensorflow/e2e:mobile_bert_squad_tests':
+        'End to end test of MobileBert on SQuAD',
     '//integrations/tensorflow/e2e/keras:keras_tests':
         'End to end tests written using tf.keras',
     '//integrations/tensorflow/e2e/keras:imagenet_external_tests':