Refactor tf bindings compilation API (#3326)
This will make it easier to add support for multiple compilation paths to our TensorFlow integration tests.
List of changes:
- Rename `tf_load_saved_model` to `tf_saved_model_to_compiler_module`
- Rename `tf_load_signature_def_saved_model` to `tf_signature_def_saved_model_to_compiler_module`
- Rename `tf_compile_saved_model` to `compile_tf_saved_model`
- Add `compile_tf_signature_def_saved_model`
- Add `compile_tf_module`
- Rename `tf_utils.compile_tf_module` to `tf_utils._incrementally_compile_tf_module` and factor out `_incrementally_lower_compiler_module` and `_setup_mlir_crash_reproducer` for future code reuse.
- Remove support for compiling to multiple backends from `tf_utils`.
- This was unused and hacky because it didn't fit well into the artifact saving framework in `tf_utils.compile_tf_module`.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 982242e..774412d 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -28,14 +28,16 @@
"OutputFormat",
# TensorFlow
"TF_IMPORT_PASS_PIPELINE",
- "tf_load_saved_model",
- "tf_load_signature_def_saved_model",
- "tf_compile_saved_model",
+ "tf_saved_model_to_compiler_module",
+ "tf_signature_def_saved_model_to_compiler_module",
"tf_module_to_compiler_module",
+ "compile_tf_saved_model",
+ "compile_tf_signature_def_saved_model",
+ "compile_tf_module",
]
import tempfile
-from typing import Collection, Optional, Sequence
+from typing import Collection, Optional, Sequence, Set
from . import binding as binding
import tensorflow as tf
@@ -95,74 +97,14 @@
)
-def tf_load_saved_model(saved_model_dir: str,
- exported_names: Collection[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> Module:
- """Loads a TensorFlow saved model from its persistent representation.
-
- See also tf_compile_saved_model() for a one-shot API to load and compile.
-
- Args:
- saved_model_dir: Directory of the saved model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- input_module = binding.load_saved_model(
- compiler_context, saved_model_dir, exported_names=exported_names)
- if pass_pipeline:
- input_module.run_pass_pipeline(pass_pipeline)
- return input_module
-
-
-def tf_load_signature_def_saved_model(
- saved_model_dir: str,
- tags: Collection[str] = set(),
- exported_names: Collection[str] = [],
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> Module:
- """Loads a TensorFlow SignatureDef saved model from persistent representation.
-
- Args:
- saved_model_dir: Directory of the saved model.
- tags: Optional tuple of tags to use when loading the model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- input_module = binding.load_signature_def_saved_model(
- compiler_context, saved_model_dir, tags, exported_names=exported_names)
- if pass_pipeline:
- input_module.run_pass_pipeline(pass_pipeline)
- return input_module
-
-
-def tf_compile_saved_model(
+def tf_saved_model_to_compiler_module(
saved_model_dir: str,
exported_names: Collection[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- target_backends: Collection[str] = (),
- compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
- """Loads and compiles a TensorFlow saved model in one shot.
+ compiler_context: Optional[Context] = None) -> Module:
+ """Converts a TensorFlow SavedModel into a MLIR module.
+
+ See also compile_tf_saved_model() for a one-shot API to load and compile.
Args:
saved_model_dir: Directory of the saved model.
@@ -170,31 +112,64 @@
keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+ if not compiler_context:
+ compiler_context = Context()
+ compiler_module = binding.load_saved_model(compiler_context,
+ saved_model_dir,
+ exported_names=exported_names)
+ if pass_pipeline:
+ compiler_module.run_pass_pipeline(pass_pipeline)
+ return compiler_module
+
+
+def compile_tf_saved_model(
+ saved_model_dir: str,
+ exported_names: Collection[str] = (),
+ target_backends: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
+ """Compiles a TensorFlow SavedModel to IREE in one shot.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
target_backends: The specific target backends to compile for (defaults to
all compiled in targets).
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
Returns:
An OpaqueBlob representing the compiled module.
"""
- input_module = tf_load_saved_model(saved_model_dir, exported_names,
- pass_pipeline, compiler_context)
- return input_module.compile(target_backends=target_backends)
+ compiler_module = tf_saved_model_to_compiler_module(saved_model_dir,
+ exported_names,
+ pass_pipeline,
+ compiler_context)
+ return compiler_module.compile(target_backends=target_backends)
-def tf_module_to_compiler_module(
- module: tf.Module,
- exported_names: Collection[str] = (),
- sm_path: str = None,
+def tf_signature_def_saved_model_to_compiler_module(
+ saved_model_dir: str,
+ saved_model_tags: Set[str] = set(),
+ exported_names: Collection[str] = [],
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> Module:
- """Converts a tf.Module into a MLIR module.
+ """Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
Args:
- module: The tf.Module instance to convert to MLIR
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
exported_names: Optional tuple of strings representing the exported names to
keep.
- sm_path: the path to save the tf.Module to, if any. Defaults to None.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -204,16 +179,110 @@
This can be further compiled to an IREE blob by calling
.compile_to_sequencer_blob.
"""
-
- def _convert(sm_path):
- options = tf.saved_model.SaveOptions(save_debug_info=True)
- tf.saved_model.save(module, sm_path, options=options)
- return tf_load_saved_model(sm_path, exported_names, pass_pipeline,
- compiler_context)
-
- if sm_path is None:
- with tempfile.TemporaryDirectory() as sm_path:
- compiler_module = _convert(sm_path)
- else:
- compiler_module = _convert(sm_path)
+ if not compiler_context:
+ compiler_context = Context()
+ compiler_module = binding.load_signature_def_saved_model(
+ compiler_context,
+ saved_model_dir,
+ saved_model_tags,
+ exported_names=exported_names)
+ if pass_pipeline:
+ compiler_module.run_pass_pipeline(pass_pipeline)
return compiler_module
+
+
+def compile_tf_signature_def_saved_model(
+ saved_model_dir: str,
+ saved_model_tags: Set[str] = set(),
+ exported_names: Collection[str] = (),
+ target_backends: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
+ """Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ target_backends: The specific target backends to compile for (defaults to
+ all compiled in targets).
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+
+ Returns:
+ An OpaqueBlob representing the compiled module.
+ """
+ compiler_module = tf_signature_def_saved_model_to_compiler_module(
+ saved_model_dir, saved_model_tags, exported_names, pass_pipeline,
+ compiler_context)
+ return compiler_module.compile(target_backends=target_backends)
+
+
+def tf_module_to_compiler_module(
+ module: tf.Module,
+ exported_names: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None,
+ saved_model_dir: str = None) -> Module:
+ """Converts a tf.Module instance into a MLIR module.
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+ saved_model_dir: Optional path to save the tf.Module to. The module will not
+ be saved on disk if this is not provided.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+
+ def _convert(saved_model_dir):
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(module, saved_model_dir, options=options)
+ return tf_saved_model_to_compiler_module(saved_model_dir, exported_names,
+ pass_pipeline, compiler_context)
+
+ if saved_model_dir is None:
+ with tempfile.TemporaryDirectory() as saved_model_dir:
+ compiler_module = _convert(saved_model_dir)
+ else:
+ compiler_module = _convert(saved_model_dir)
+ return compiler_module
+
+
+def compile_tf_module(module: tf.Module,
+ exported_names: Collection[str] = (),
+ target_backends: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None,
+ saved_model_dir: str = None):
+ """Compiles a tf.Module to IREE in one shot.
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ target_backends: The specific target backends to compile for (defaults to
+ all compiled in targets).
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+ saved_model_dir: Optional path to save the tf.Module to. The module will not
+ be saved on disk if this is not provided.
+
+ Returns:
+ An OpaqueBlob representing the compiled module.
+ """
+ compiler_module = tf_module_to_compiler_module(module, exported_names,
+ pass_pipeline,
+ compiler_context,
+ saved_model_dir)
+ return compiler_module.compile(target_backends=target_backends)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
index a7aba66..10e55c3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
@@ -62,7 +62,7 @@
tf.saved_model.save(my_module, sm_dir, options=options)
# Load it up.
- input_module = compiler.tf_load_saved_model(sm_dir)
+ input_module = compiler.tf_saved_model_to_compiler_module(sm_dir)
xla_asm = input_module.to_asm()
print("XLA ASM:", xla_asm)
self.assertRegex(xla_asm, "mhlo.tanh")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
index 8a2e1cb..2640862 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
@@ -55,7 +55,7 @@
sess, ["bar"], {"baz": sig}, strip_default_attrs=True)
builder.save()
- module = compiler.tf_load_signature_def_saved_model(
+ module = compiler.tf_signature_def_saved_model_to_compiler_module(
sm_dir, tags=set(["bar"]), exported_names=["baz"])
module_asm = module.to_asm(large_element_limit=100)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 01c6e0f..fc8bab6 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -88,40 +88,47 @@
return result
-def backends_to_str(backend_infos: Sequence["BackendInfo"]) -> str:
- """Creates a normalized string representing the provided backends."""
- normalized_names = []
- for backend_info in backend_infos:
- # Remove unusual characters and ensure names don't end or start in "_".
- name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
- normalized_names.append(name.strip("_"))
- return "__".join(normalized_names)
+def _setup_mlir_crash_reproducer(
+ function: Callable[[Any], Any],
+ artifacts_dir: str,
+ backend_name: str,
+) -> Callable[[Any], Any]:
+ """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+
+ Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+
+ Args:
+ function: The callable to decorate.
+ artifacts_dir: The directory to write the reproducer to.
+ backend_name: The name of the backend `function` compiles to.
+
+ Returns:
+ A function with the same API as the passed function.
+ """
+
+ def decorator(*args, **kwargs):
+ # Set up a crash reproducer for debugging.
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = os.path.join(
+ artifacts_dir, f"reproducer__{backend_name}.mlir")
+ try:
+ results = function(*args, **kwargs)
+ except Exception: # pylint: disable=broad-except
+ # Disable the crash reproducer (to avoid inadvertently overwriting it).
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = None
+ raise
+ return results
+
+ return decorator
-def _get_backends_path(artifact_name: str,
- backend_infos: Sequence["BackendInfo"],
- artifacts_dir: str) -> str:
- """Gets the path to save artifact_name under for the specified backend(s)."""
- backends_string = backends_to_str(backend_infos)
- # Put the artifact in a directory if there's only one backend.
- if len(backend_infos) == 1:
- backend_dir = os.path.join(artifacts_dir, backends_string)
- os.makedirs(backend_dir, exist_ok=True)
- return os.path.join(artifacts_dir, backends_string, artifact_name)
- else:
- return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
-
-
-def compile_tf_module(
- tf_module: Type[tf.Module],
- backend_infos: Sequence["BackendInfo"] = (),
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None
+def _incrementally_lower_compiler_module(
+ compiler_module: compiler.Module,
+ backend_info: "BackendInfo",
+ artifacts_dir: str,
) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
-
- The artifact this creates is not callable. See IreeCompiledModule for an API
- that returns a module that can be called without any further steps.
+ """Lowers a MLIR compiler module incrementally and saves its outputs.
If artifacts_dir is provided then the following artifacts will be saved:
tf_input.mlir:
@@ -129,75 +136,84 @@
iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
backend_name/compiled.vmfb:
- A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
-
- If multiple backends are specified, then instead of saving compiled 'vmfb'
- under 'backend_name/', it will be saved as follows:
- - 'compiled__{backends}.vmfb'
- where 'backends' is a '__' delimited list (e.g. iree_vmla__iree_llvmjit).
+ A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
Args:
- tf_module: A tf.Module.
- backend_infos: Iterable of BackendInfo names to compile for.
+ compiler_module: A compiler.Module to lower.
+ backend_info: BackendInfo with the details for lowering compiler_module to
+ IREE.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ if artifacts_dir is not None:
+ os.makedirs(artifacts_dir, exist_ok=True)
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+ logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+ with open(tf_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ # Manually run the passes that tf_module_to_compiler_module usually would.
+ compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+ if artifacts_dir is not None:
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+ logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+ with open(iree_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ compiled_module = compiler_module.compile(
+ target_backends=backend_info.compiler_targets)
+
+ compiled_path = None
+ if artifacts_dir is not None:
+ backend_dir = os.path.join(artifacts_dir, backend_info.name)
+ os.makedirs(backend_dir, exist_ok=True)
+ compiled_path = os.path.join(backend_dir, "compiled.vmfb")
+ logging.info("Saving compiled IREE module to: %s", compiled_path)
+ with open(compiled_path, "wb") as f:
+ f.write(compiled_module)
+ return compiled_module, compiled_path
+
+
+def _incrementally_compile_tf_module(
+ module: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+ """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+
+ The module blob this creates is not callable. See IreeCompiledModule for an
+ API that returns a module that can be called without any further steps.
+
+ See _incrementally_lower_compiler_module's docstring for details about which
+ artifacts will be saved.
+
+ Args:
+ module: A tf.Module.
+ backend_info: BackendInfo with the details for compiling module to IREE.
exported_names: Iterable of dotted function names to consider for
compilation.
artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved.
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
Returns:
A compiled IREE module blob and the path to the compiled VM FlatBuffer if
artifacts_dir is provided.
"""
- if artifacts_dir is not None:
- # Set up a crash reproducer for debugging.
- backends_string = backends_to_str(backend_infos)
- compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backends_string}.mlir")
-
- try:
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_module_to_compiler_module(tf_module,
+ def _compile_module(module, exported_names, backend_info, artifacts_dir):
+ compiler_module = compiler.tf_module_to_compiler_module(module,
exported_names,
pass_pipeline=())
+ return _incrementally_lower_compiler_module(compiler_module, backend_info,
+ artifacts_dir)
- if artifacts_dir is not None:
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- # Now run the passes manually that tf_module_to_compiler_module would
- # usually do.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- target_backends = []
- for backend_info in backend_infos:
- target_backends.extend(backend_info.compiler_targets)
- compiled_module = compiler_module.compile(target_backends=target_backends)
-
- compiled_path = None
- if artifacts_dir is not None:
- compiled_path = _get_backends_path("compiled", backend_infos,
- artifacts_dir)
- compiled_path = f"{compiled_path}.vmfb"
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
-
- except Exception: # pylint: disable=broad-except
- if artifacts_dir is not None:
- # Disable the crash reproducer (to avoid inadvertently overwriting it).
- compiler.Context.default_crash_reproducer_path = None
- raise
-
- return compiled_module, compiled_path
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+ backend_info.name)
+ return _compile_module(module, exported_names, backend_info, artifacts_dir)
class CompiledModule(object):
@@ -280,9 +296,9 @@
super().__init__(module_class, backend_info, exported_names, artifacts_dir)
set_random_seed()
- self._module_blob, compiled_path = compile_tf_module(
- tf_module=module_class(),
- backend_infos=[backend_info],
+ self._module_blob, compiled_path = _incrementally_compile_tf_module(
+ module=module_class(),
+ backend_info=backend_info,
exported_names=exported_names,
artifacts_dir=artifacts_dir)
self._module = rt.VmModule.from_flatbuffer(self._module_blob)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 43ec6cf..1a441d0 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -44,6 +44,7 @@
def get_count(self):
return self.count
+
class RandomInitModule(tf.Module):
def __init__(self):
@@ -56,25 +57,15 @@
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
- @parameterized.named_parameters([
- {
- 'testcase_name': 'single_backend',
- 'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
- },
- {
- 'testcase_name':
- 'multiple_backends',
- 'backend_infos': [
- tf_utils.BackendInfo('iree_vmla'),
- tf_utils.BackendInfo('iree_llvmjit')
- ],
- },
- ])
- def test_artifact_saving(self, backend_infos):
+ def test_artifact_saving(self):
+ backend_info = tf_utils.BackendInfo('iree_vmla')
with tempfile.TemporaryDirectory() as artifacts_dir:
tf_module = ConstantModule()
- iree_compiled_module, compiled_path = tf_utils.compile_tf_module(
- tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
+ iree_compiled_module, compiled_path = (
+ tf_utils._incrementally_compile_tf_module(
+ tf_module,
+ backend_info=backend_info,
+ artifacts_dir=artifacts_dir))
artifacts_to_check = [
'tf_input.mlir',
@@ -121,7 +112,6 @@
inputs = [np.array([1, 2], dtype=np.float32)]
self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
-
@parameterized.named_parameters([
{
'testcase_name': 'tensorflow',