Refactor 'CompiledModule' to use cls constructors (#3329)
Refactors `CompiledModule` and its subclasses to use the named constructors `create_from_class` and `create_from_instance`. This makes it easier to add additional compilation paths to our TensorFlow integration tests.
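A minimal usage sketch of the new classmethod constructors (the `Counter` module and its method are made up for illustration, and the import path assumes the in-tree `pyiree.tf.support` package):

```python
import tensorflow as tf
from pyiree.tf.support import tf_utils

class Counter(tf.Module):
  """Toy stateful module used only to illustrate the new constructors."""

  def __init__(self):
    super().__init__()
    self.count = tf.Variable([0.])

  @tf.function(input_signature=[])
  def increment(self):
    return self.count.assign_add([1.])

backend_info = tf_utils.BackendInfo("iree_vmla")

# Compile from a tf.Module subclass (supported by all CompiledModule subclasses).
module = tf_utils.IreeCompiledModule.create_from_class(Counter, backend_info)

# Compile from an already-constructed instance (IreeCompiledModule only).
module = tf_utils.IreeCompiledModule.create_from_instance(Counter(), backend_info)
```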
Additional changes:
- Clean up docstrings and make type info more accurate.
- Explicitly differentiate between `backend_name` and `backend_id`:
  - `backend_name` is one of `[tf, tflite, iree_vmla, iree_llvmjit, iree_vulkan]`.
  - `backend_id` uniquely identifies each instantiated backend at test time for the purpose of saving compilation artifacts.
- Use a `defaultdict` to match IREE's behavior of compiling all methods into one binary (this and the `backend_id` scheme are sketched below).
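A minimal sketch of the `backend_id` indexing and the `defaultdict` behavior described above; the flag value and artifact path are made-up examples, and the real logic lives in `_parse_target_backends` and `IreeCompiledModule.create_from_instance` in the diff below:

```python
import collections
from typing import List, Tuple

def parse_target_backends(flag_value: str) -> Tuple[List[str], List[str]]:
  """Suffixes duplicated backend names with an index to form unique backend ids."""
  backend_names = flag_value.split(",")
  counters = {name: 0 for name in backend_names if backend_names.count(name) > 1}
  backend_ids = []
  for name in backend_names:
    if name in counters:
      backend_ids.append(f"{name}_{counters[name]}")
      counters[name] += 1
    else:
      backend_ids.append(name)
  return backend_names, backend_ids

# Duplicated names get indexed: "tf,iree_vmla,tf" -> ["tf_0", "iree_vmla", "tf_1"].
assert parse_target_backends("tf,iree_vmla,tf")[1] == ["tf_0", "iree_vmla", "tf_1"]

# IREE compiles all exported methods into a single binary, so every method name
# should resolve to the same artifact path.
compiled_path = "/tmp/artifacts/iree_vmla/compiled.vmfb"  # made-up path
compiled_paths = collections.defaultdict(lambda: compiled_path)
assert compiled_paths["increment"] == compiled_paths["get_count"] == compiled_path
```

Saving artifacts under the `backend_id` directory (rather than `backend_name`) keeps multiple instantiations of the same backend from overwriting each other's outputs.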
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 774412d..ba966db 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -99,7 +99,7 @@
def tf_saved_model_to_compiler_module(
saved_model_dir: str,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> Module:
"""Converts a TensorFlow SavedModel into a MLIR module.
@@ -108,8 +108,7 @@
Args:
saved_model_dir: Directory of the saved model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -131,18 +130,17 @@
def compile_tf_saved_model(
saved_model_dir: str,
- exported_names: Collection[str] = (),
- target_backends: Collection[str] = (),
+ exported_names: Sequence[str] = (),
+ target_backends: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
"""Compiles a TensorFlow SavedModel to IREE in one shot.
Args:
saved_model_dir: Directory of the saved model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ exported_names: Optional sequence representing the exported names to keep.
+ target_backends: Optional sequence of specific target backends to compile
+ for (defaults to all compiled in targets).
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -160,7 +158,7 @@
def tf_signature_def_saved_model_to_compiler_module(
saved_model_dir: str,
saved_model_tags: Set[str] = set(),
- exported_names: Collection[str] = [],
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> Module:
"""Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
@@ -168,8 +166,7 @@
Args:
saved_model_dir: Directory of the saved model.
saved_model_tags: Optional set of tags to use when loading the model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -194,8 +191,8 @@
def compile_tf_signature_def_saved_model(
saved_model_dir: str,
saved_model_tags: Set[str] = set(),
- exported_names: Collection[str] = (),
- target_backends: Collection[str] = (),
+ exported_names: Sequence[str] = (),
+ target_backends: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
"""Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
@@ -203,10 +200,9 @@
Args:
saved_model_dir: Directory of the saved model.
saved_model_tags: Optional set of tags to use when loading the model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ exported_names: Optional sequence representing the exported names to keep.
+ target_backends: Optional sequence of specific target backends to compile
+ for (defaults to all compiled in targets).
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -222,7 +218,7 @@
def tf_module_to_compiler_module(
module: tf.Module,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None,
saved_model_dir: str = None) -> Module:
@@ -230,8 +226,7 @@
Args:
module: The tf.Module instance to convert to MLIR
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -259,8 +254,8 @@
def compile_tf_module(module: tf.Module,
- exported_names: Collection[str] = (),
- target_backends: Collection[str] = (),
+ exported_names: Sequence[str] = (),
+ target_backends: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None,
saved_model_dir: str = None):
@@ -268,10 +263,9 @@
Args:
module: The tf.Module instance to convert to MLIR
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ exported_names: Optional sequence representing the exported names to keep.
+ target_backends: Optional sequence of specific target backends to compile
+ for (defaults to all compiled in targets).
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index fa3df03..aa9481e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -66,22 +66,22 @@
def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
- """Decodes --target_backends and creates unique names for their artifacts."""
+ """Decodes --target_backends and creates unique ids for them."""
backend_names = FLAGS.target_backends.split(",")
backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
- artifact_names = []
+ backend_ids = []
# If there are multiple copies of the same backend_name, index them. e.g.
# backend_names = ["tf", "iree_vmla", "tf"]
- # --> artifact_names = ["tf_0", "iree_vmla", "tf_1"]
+ # --> backend_ids = ["tf_0", "iree_vmla", "tf_1"]
for backend_name in backend_names:
if backend_name in backend_to_index:
- artifact_names.append(f"{backend_name}_{backend_to_index[backend_name]}")
+ backend_ids.append(f"{backend_name}_{backend_to_index[backend_name]}")
backend_to_index[backend_name] += 1
else:
- artifact_names.append(backend_name)
+ backend_ids.append(backend_name)
- return backend_names, artifact_names
+ return backend_names, backend_ids
def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
@@ -95,10 +95,10 @@
"""
if FLAGS.target_backends is not None:
logging.info("Using backends from command line: %s", FLAGS.target_backends)
- backend_names, names = _parse_target_backends()
+ backend_names, backend_ids = _parse_target_backends()
backends = [
- tf_utils.BackendInfo(backend, name)
- for backend, name in zip(backend_names, names)
+ tf_utils.BackendInfo(backend_name, backend_id)
+ for backend_name, backend_id in zip(backend_names, backend_ids)
]
else:
# If no backends are specified, use them all.
@@ -261,10 +261,11 @@
# Extract metadata from module and function.
self.module_name = module.module_name
self.compiled_paths = module.compiled_paths
- self.backend_name = module.backend
+ self.backend_name = module.backend_info.backend_name
+ self.backend_id = module.backend_info.backend_id
+ self.backend_driver = module.backend_info.driver
self.iree_serializable = module.iree_serializable()
self.tflite_serializable = module.tflite_serializable()
- self.backend_driver = module.backend_driver
self.function_name = function.__name__
self.function_sourcefile = inspect.getsourcefile(function)
source, start_line = inspect.getsourcelines(function)
@@ -276,9 +277,10 @@
self.module_name = _load_dict["module_name"]
self.compiled_paths = _load_dict["compiled_paths"]
self.backend_name = _load_dict["backend_name"]
+ self.backend_id = _load_dict["backend_id"]
+ self.backend_driver = _load_dict["backend_driver"]
self.iree_serializable = _load_dict["iree_serializable"]
self.tflite_serializable = _load_dict["tflite_serializable"]
- self.backend_driver = _load_dict["backend_driver"]
self.function_name = _load_dict["function_name"]
self.function_sourcefile = _load_dict["function_sourcefile"]
self.function_line_numbers = _load_dict["function_line_numbers"]
@@ -286,7 +288,7 @@
self.calls = _load_dict["calls"]
def __str__(self):
- header = (f"Trace of {self.module_name} compiled to '{self.backend_name}' "
+ header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
f"on function '{self.function_name}':")
# Give each call a number so it's easier to compare between multiple traces.
calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
@@ -327,11 +329,11 @@
if not calls_match:
logging.error("Comparision between '%s' and '%s' failed on method '%s'",
- ref_trace.backend_name, tar_trace.backend_name,
+ ref_trace.backend_id, tar_trace.backend_id,
ref_call.method)
- logging.error("Reference call '%s':\n%s", ref_trace.backend_name,
+ logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
ref_call)
- logging.error("Target call '%s':\n%s", tar_trace.backend_name, tar_call)
+ logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
traces_match = traces_match and calls_match
return traces_match
@@ -434,14 +436,20 @@
trace_dir: str, path to the directory to serialize the trace to.
"""
+ compiled_paths = None
+ if self.compiled_paths is not None:
+ # Convert to a dict to avoid the issues with serializing defaultdicts.
+ compiled_paths = dict(self.compiled_paths)
+
# Python serialization.
metadata = {
"module_name": self.module_name,
- "compiled_paths": self.compiled_paths,
+ "compiled_paths": compiled_paths,
"backend_name": self.backend_name,
+ "backend_id": self.backend_id,
+ "backend_driver": self.backend_driver,
"iree_serializable": self.iree_serializable,
"tflite_serializable": self.tflite_serializable,
- "backend_driver": self.backend_driver,
"function_name": self.function_name,
"function_sourcefile": self.function_sourcefile,
"function_line_numbers": self.function_line_numbers,
@@ -493,7 +501,7 @@
def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
- trace_dir = os.path.join(artifacts_dir, trace.backend_name, "traces",
+ trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
trace.function_name)
os.makedirs(trace_dir, exist_ok=True)
return trace_dir
@@ -559,17 +567,17 @@
def compile_tf_module(
module_class: Type[tf.Module], exported_names: Sequence[str] = ()
) -> Callable[[Any], Any]:
- """CompiledModuleTestCase decorator that compiles a tf.Module.
-
- A CompiledModule is created for each backend in --target_backends. They can
- be accessed individually via self.compiled_modules.backend_name or as a union
- via self.get_module().
+ """Compiles module_class to each backend that we test.
Args:
module_class: the tf.Module subclass to compile.
exported_names: optional iterable of strings representing which of
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
+
+ Returns:
+ A 'Modules' namedtuple containing the reference module, target modules and
+ artifacts directory.
"""
# Setup the directory for saving compilation artifacts and traces.
@@ -580,7 +588,7 @@
f"{FLAGS.reference_backend}_ref")
tar_backend_infos = get_target_backends()
- compile_backend = lambda backend_info: backend_info.compile(
+ compile_backend = lambda backend_info: backend_info.compile_from_class(
module_class, exported_names, artifacts_dir)
ref_module = compile_backend(ref_backend_info)
@@ -631,7 +639,7 @@
failed_backend_indices = []
for i, tar_trace in enumerate(tar_traces):
logging.info("Comparing the reference backend '%s' with '%s'",
- ref_trace.backend_name, tar_trace.backend_name)
+ ref_trace.backend_id, tar_trace.backend_id)
traces_match = Trace.compare_traces(ref_trace, tar_trace)
if not traces_match:
failed_backend_indices.append(i)
@@ -649,7 +657,7 @@
if failed_backend_indices:
# Extract info for logging.
failed_backends = [
- tar_traces[i].backend_name for i in failed_backend_indices
+ tar_traces[i].backend_id for i in failed_backend_indices
]
self.fail(
"Comparision between the reference backend and the following targets "
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 6157c68..e1fb6e3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -117,8 +117,8 @@
# Only outputs
module.get_count()
- module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ module = tf_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('tf'))
trace = tf_test_utils.Trace(module, trace_function)
trace_function(tf_test_utils.TracedModule(module, trace))
@@ -140,13 +140,13 @@
module.increment()
module.decrement()
- tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ tf_module = tf_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('tf'))
tf_trace = tf_test_utils.Trace(tf_module, tf_function)
tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
- vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('iree_vmla'))
+ vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
@@ -161,13 +161,13 @@
def vmla_function(module):
module.increment_by(np.array([22.], dtype=np.float32))
- tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ tf_module = tf_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('tf'))
tf_trace = tf_test_utils.Trace(tf_module, tf_function)
tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
- vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('iree_vmla'))
+ vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
@@ -178,12 +178,12 @@
def trace_function(module):
module.increment()
module.increment_by(np.array([81.], dtype=np.float32))
- module.increment_by_max(
- np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
+ module.increment_by_max(np.array([81], dtype=np.float32),
+ np.array([92], dtype=np.float32))
module.get_count()
- module = tf_utils.IreeCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('iree_vmla'))
+ module = tf_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
trace = tf_test_utils.Trace(module, trace_function)
trace_function(tf_test_utils.TracedModule(module, trace))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index fc8bab6..b3237f9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -16,6 +16,7 @@
# pylint: disable=protected-access
+import collections
import os
import random
import re
@@ -91,7 +92,7 @@
def _setup_mlir_crash_reproducer(
function: Callable[[Any], Any],
artifacts_dir: str,
- backend_name: str,
+ backend_id: str,
) -> Callable[[Any], Any]:
"""Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
@@ -100,7 +101,7 @@
Args:
function: The callable to decorate.
artifacts_dir: The directory to write the reproducer to.
- backend_name: The name of the backend `function` compiles to.
+ backend_id: The unique backend id to use when writing the reproducer.
Returns:
A function with the same API as the passed function.
@@ -110,7 +111,7 @@
# Set up a crash reproducer for debugging.
if artifacts_dir is not None:
compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backend_name}.mlir")
+ artifacts_dir, f"reproducer__{backend_id}.mlir")
try:
results = function(*args, **kwargs)
except Exception: # pylint: disable=broad-except
@@ -135,7 +136,7 @@
MLIR for the module in TF's input dialect.
iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
- backend_name/compiled.vmfb:
+ backend_id/compiled.vmfb:
A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
Args:
@@ -167,7 +168,7 @@
compiled_path = None
if artifacts_dir is not None:
- backend_dir = os.path.join(artifacts_dir, backend_info.name)
+ backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
os.makedirs(backend_dir, exist_ok=True)
compiled_path = os.path.join(backend_dir, "compiled.vmfb")
logging.info("Saving compiled IREE module to: %s", compiled_path)
@@ -180,7 +181,7 @@
module: Type[tf.Module],
backend_info: "BackendInfo",
exported_names: Sequence[str] = (),
- artifacts_dir: str = None
+ artifacts_dir: str = None,
) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
"""Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
@@ -193,8 +194,7 @@
Args:
module: A tf.Module.
backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Iterable of dotted function names to consider for
- compilation.
+ exported_names: Optional sequence representing the exported names to keep.
artifacts_dir: An optional string pointing to where compilation artifacts
should be saved. No compilation artifacts will be saved if this is not
provided.
@@ -212,29 +212,72 @@
artifacts_dir)
_compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.name)
+ backend_info.backend_id)
return _compile_module(module, exported_names, backend_info, artifacts_dir)
class CompiledModule(object):
"""Base class for the TF and IREE compiled modules."""
- def __init__(self, module_class: Type[tf.Module], backend_info: "BackendInfo",
- exported_names: Sequence[str], artifacts_dir: str):
- """Shared base constructor – not useful on its own."""
- self._module_class = module_class
- self._backend_info = backend_info
- self._exported_names = exported_names
- self._artifacts_dir = artifacts_dir
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ ):
+ """Shared base constructor – not useful on its own.
- # Public attributes:
- self.backend = self._backend_info.name
- self.backend_driver = self._backend_info.driver
- self.module_name = self._module_class.__name__
- self.compiled_paths = None
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ """
+ self.module_name = module_name
+ self.backend_info = backend_info
+ self.compiled_paths = compiled_paths
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_instance(cls,
+ module_instance: tf.Module,
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module instance to the target backend in backend_info.
+
+ This is only implemented for IreeCompiledModule.
+
+ Args:
+ module_instance: The tf.Module instance to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
raise NotImplementedError()
def iree_serializable(self):
@@ -244,14 +287,6 @@
return False
-def _get_non_inhereted_function_names(cls):
- """Gets all methods that cls has that its parents don't have."""
- names = set(dir(cls))
- for parent in cls.__bases__:
- names -= set(dir(parent))
- return list(names)
-
-
class _FunctionWrapper(object):
def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
@@ -277,57 +312,94 @@
class IreeCompiledModule(CompiledModule):
"""Iree compiled module."""
- def __init__(self,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module to the target backend in backend_info.
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ vm_module: rt.VmModule,
+ config: rt.Config,
+ ):
+ """Base constructor – Use one of the named constructors instead.
Args:
- module_class: the tf.Module subclass to compile.
- backend_info: an element of BackendInfo corresponding to the IREE backend
- to compile to.
- exported_names: an optional iterable of strings representing which of the
- module_class's functions to compile. If exported_names is empty all
- functions will be compiled.
- artifacts_dir: an optional path to save compilation artifacts to.
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ vm_module: A rt.VmModule containing compilation info to wrap.
+ config: A rt.Config containing compilation info to wrap.
"""
- super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ super().__init__(module_name, backend_info, compiled_paths)
+ self._vm_module = vm_module
+ self._config = config
+ self.reinitialize()
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
set_random_seed()
- self._module_blob, compiled_path = _incrementally_compile_tf_module(
- module=module_class(),
+ module_instance = module_class()
+ return cls.create_from_instance(module_instance, backend_info,
+ exported_names, artifacts_dir)
+
+ @classmethod
+ def create_from_instance(cls,
+ module_instance: tf.Module,
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module instance to the target backend in backend_info.
+
+ Args:
+ module_instance: The tf.Module instance to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ module_blob, compiled_path = _incrementally_compile_tf_module(
+ module=module_instance,
backend_info=backend_info,
exported_names=exported_names,
artifacts_dir=artifacts_dir)
- self._module = rt.VmModule.from_flatbuffer(self._module_blob)
- self._config = rt.Config(driver_name=backend_info.driver)
+ vm_module = rt.VmModule.from_flatbuffer(module_blob)
+ config = rt.Config(driver_name=backend_info.driver)
- self.compiled_paths = None
+ compiled_paths = None
if compiled_path is not None:
- if not len(exported_names):
- # Get all method names on 'module_class' that aren't on 'tf.Module'.
- # This doesn't address all possbile scenarios.
- # TODO(meadowlark): Figure out how to get a list of all of the functions
- # that this module has access to via `pyiree.rt.system_api.BoundModule`.
- exported_names = _get_non_inhereted_function_names(module_class)
- self.compiled_paths = dict([
- (method, compiled_path) for method in exported_names
- ])
+ # IREE bundles every compiled method into the same compiled module.
+ compiled_paths = collections.defaultdict(lambda: compiled_path)
- self.reinitialize()
+ module_name = type(module_instance).__name__
+
+ return cls(module_name, backend_info, compiled_paths, vm_module, config)
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
# set_random_seed is not needed here because the model_class.__init__ is not
# called.
- self._context = rt.SystemContext(modules=[self._module],
+ self._context = rt.SystemContext(modules=[self._vm_module],
config=self._config)
def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
# Try to resolve it as a function.
- m = self._context.modules[self._module.name]
+ m = self._context.modules[self._vm_module.name]
f = m[attr]
return _IreeFunctionWrapper(self._context, f)
@@ -376,29 +448,54 @@
normalize TensorFlow's output to Numpy.
"""
- def __init__(self,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Wrap a tf.Module in a TFCompiledModule facade.
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ constructor: Callable[[], tf.Module],
+ exported_names: Sequence[str],
+ ):
+ """Base constructor – Use one of the named constructors instead.
Args:
- module_class: the tf.Module subclass to 'compile'.
- backend_info: one of the 'tf*' elements in BackendInfo.
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ constructor: A callable (class or function) which returns the tf.Module
+ subclass instance to wrap.
exported_names: an optional iterable of strings representing which of the
- module_class's functions should be callable. If exported_names is empty
- then all functions will be callable.
- artifacts_dir: an optional path to save compilation artifacts to. Has no
- effect for this subclass as nothing is compiled.
+ tf.Module subclass instance's functions should be callable. If
+ exported_names is empty then all functions will be callable.
"""
- super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ super().__init__(module_name, backend_info, compiled_paths=None)
+ self._constructor = constructor
+ self._exported_names = exported_names
self.reinitialize()
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ module_name = module_class.__name__
+ constructor = module_class
+ return cls(module_name, backend_info, constructor, exported_names)
+
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
set_random_seed()
- self._tf_module = self._module_class()
+ self._tf_module = self._constructor()
def __getattr__(self, attr: str) -> _TfFunctionWrapper:
# Try to resolve it as a function.
@@ -412,6 +509,14 @@
return _TfFunctionWrapper(f)
+def _get_non_inhereted_function_names(cls):
+ """Gets all methods that cls has that its parents don't have."""
+ names = set(dir(cls))
+ for parent in cls.__bases__:
+ names -= set(dir(parent))
+ return list(names)
+
+
def _get_concrete_functions(module_class: Type[tf.Module],
exported_names: Sequence[str] = ()):
"""Get concrete functions from non-inherited methods or exported_names."""
@@ -425,12 +530,12 @@
return functions, exported_names
-def compile_to_tflite(
+def tf_module_to_tflite_interpreters(
module_class: Type[tf.Module],
exported_names: Sequence[str] = (),
artifacts_dir: str = None
) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str]], None]:
- """Compile a dict of TFLite interpreters for the methods on module_class.
+ """Compile a tf.Module to TFLite interpreters for each of its methods.
Args:
module_class: A tf.Module subclass to compile with TFLite. If module_class
@@ -463,22 +568,14 @@
if artifacts_dir is not None:
compiled_paths[name] = tflite_path
+ # Convert module_class's methods into TFLite module byte-strings.
tflite_modules = []
- names = []
- if hasattr(module_class, "get_legacy_tflite_saved_model_converter_kwargs"):
- kwargs = module_class.get_legacy_tflite_saved_model_converter_kwargs()
- converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
- kwargs["model_path"],
- input_arrays=kwargs["input_arrays"],
- output_arrays=kwargs["output_arrays"])
+ functions, names = _get_concrete_functions(module_class, exported_names)
+ for function in functions:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
tflite_modules.append(converter.convert())
- names.append(kwargs["exported_name"])
- else:
- functions, names = _get_concrete_functions(module_class, exported_names)
- for function in functions:
- converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
- tflite_modules.append(converter.convert())
+ # Load each of the converted methods above into tf.lite.Interpreters.
for name, tflite_module in zip(names, tflite_modules):
if artifacts_dir is None:
with tempfile.TemporaryDirectory() as base_dir:
@@ -537,18 +634,50 @@
class TfLiteCompiledModule(CompiledModule):
"""Compiles a tf.Module with TFLite and allows it to be called."""
- def __init__(self,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ interpreters: Dict[str, tf.lite.Interpreter],
+ ):
+ """Base constructor – Use one of the named constructors instead.
+
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ interpreters: A dict of tf.lite.Interpreters to make callable.
+ """
+ super().__init__(module_name, backend_info, compiled_paths)
+ self._interpreters = interpreters
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
set_random_seed()
- self._interpreters, self.compiled_paths = compile_to_tflite(
+ interpreters, compiled_paths = tf_module_to_tflite_interpreters(
module_class, exported_names, artifacts_dir)
+ module_name = module_class.__name__
+ return cls(module_name, backend_info, compiled_paths, interpreters)
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
# This is a noop because TFLite (mostly) doesn't support stateful modules.
pass
@@ -594,36 +723,43 @@
},
}
- def __init__(self, backend_name: str, artifact_name: str = None):
+ def __init__(self, backend_name: str, backend_id: str = None):
"""Creates a BackendInfo with the compilation details for backend_name.
Args:
backend_name: a str specifying which backend to use. Should be one of
'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
- artifact_name: an optional str specifying what name to use when saving
- compiled artifacts.
+ backend_id: an optional str specifying what name to use when saving
+ compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
Raises:
KeyError: if backend_name is not one of ['tf', 'iree_vmla',
'iree_llvmjit', 'iree_vulkan'].
+ ValueError: if backend_id doesn't start with backend_name.
"""
if backend_name not in self._name_to_info:
raise KeyError(
"Expected backend_name to be one of "
f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+ if backend_id is not None and not backend_id.startswith(backend_name):
+ raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+ f"but got '{backend_id}'.")
+
+ self.backend_name = backend_name
+ self.backend_id = backend_name if backend_id is None else backend_id
+
info = self._name_to_info[backend_name]
self._compiled_module_class = info["compiled_module_class"]
self.driver = info["driver"]
self.compiler_targets = info["compiler_targets"]
- self.name = backend_name if artifact_name is None else artifact_name
- def compile(self,
- module: Type[tf.Module],
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None) -> CompiledModule:
+ def compile_from_class(self,
+ module_class: Type[tf.Module],
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None) -> CompiledModule:
"""Creates a 'CompiledModule' for this backend."""
- return self._compiled_module_class(module, self, exported_names,
- artifacts_dir)
+ return self._compiled_module_class.create_from_class(
+ module_class, self, exported_names, artifacts_dir)
@classmethod
def get_all_backends(cls) -> Sequence["BackendInfo"]:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 1a441d0..a9c259c 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -89,7 +89,7 @@
])
def test_unaltered_state(self, backend_name):
backend_info = tf_utils.BackendInfo(backend_name)
- module = backend_info.compile(StatefulCountingModule)
+ module = backend_info.compile_from_class(StatefulCountingModule)
# Test that incrementing works properly.
self.assertEqual([0.], module.get_count())
@@ -126,8 +126,8 @@
backend_info = tf_utils.BackendInfo(backend_name)
# Test compilation is the same.
- module_1 = backend_info.compile(RandomInitModule)
- module_2 = backend_info.compile(RandomInitModule)
+ module_1 = backend_info.compile_from_class(RandomInitModule)
+ module_2 = backend_info.compile_from_class(RandomInitModule)
self.assertAllEqual(module_1.get(), module_2.get())
# Test reinitialization is the same.
diff --git a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
index 78baeb6..df9b317 100644
--- a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
@@ -54,7 +54,7 @@
def xla_load_module_proto(
xla_computation,
compiler_context: Optional[Context] = None,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE) -> Module:
"""Loads a XLA saved model from its persistent representation.
@@ -63,8 +63,7 @@
Args:
xla_computation: XLA Computation generate from XLA Python client
compiler_context: The pyiree.compiler.Context() backing the module.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to XLA_IMPORT_PASS_PIPELINE.
@@ -85,21 +84,20 @@
def xla_compile_module_proto(
xla_computation,
compiler_context: Optional[Context] = None,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE,
- target_backends: Collection[str] = ()
+ target_backends: Sequence[str] = ()
) -> binding.OpaqueBlob:
"""Loads and compiles a XLA saved model in one shot.
Args:
xla_computation: XLA Computation generate from XLA Python client
compiler_context: The pyiree.compiler.Context() backing the module.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to XLA_IMPORT_PASS_PIPELINE.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ target_backends: Optional sequence of specific target backends to compile
+ for (defaults to all compiled in targets).
Returns:
An OpaqueBlob representing the compiled module.
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 2a0b458..7a07478 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -218,7 +218,6 @@
name = "mobile_bert_squad_tests",
backends_to_srcs = {
"tf": ["mobile_bert_squad_test.py"],
- "tflite": ["mobile_bert_squad_test.py"],
},
reference_backend = "tf",
tags = [
@@ -234,6 +233,7 @@
iree_e2e_test_suite(
name = "mobile_bert_squad_tests_failing",
backends_to_srcs = {
+ "tflite": ["mobile_bert_squad_test.py"],
"iree_vmla": ["mobile_bert_squad_test.py"],
"iree_llvmjit": ["mobile_bert_squad_test.py"],
"iree_vulkan": ["mobile_bert_squad_test.py"],