Refactor trace_utils and module_utils out of tf_test_utils/tf_utils (#3916)
These files were becoming long catch-alls so I split them up:
- `ModuleCall`, `Trace` and `TracedModule` are moved from `tf_test_utils` into `trace_utils`.
- `compare_traces` is removed from `Trace` and in favor of accessing it from `trace_utils.compare_traces`.
- `check_same` is moved into `tf_utils`.
- `_FunctionWrapper` and `CompiledModule` (sub)classes are moved into `module_utils` along with `BackendInfo`.
- The tests are split up to match these changes.
diff --git a/colab/edge_detection.ipynb b/colab/edge_detection.ipynb
index 18528e4..4501f53 100644
--- a/colab/edge_detection.ipynb
+++ b/colab/edge_detection.ipynb
@@ -90,7 +90,7 @@
"import numpy as np\n",
"import tensorflow as tf\n",
"from pyiree.tf import compiler as ireec\n",
- "from pyiree.tf.support import tf_utils\n",
+ "from pyiree.tf.support import module_utils\n",
"from pyiree import rt as ireert"
]
},
@@ -271,7 +271,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
]
},
{
@@ -435,7 +435,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
]
},
{
diff --git a/colab/mnist_tensorflow.ipynb b/colab/mnist_tensorflow.ipynb
index 733f65d..4a1f1aa 100644
--- a/colab/mnist_tensorflow.ipynb
+++ b/colab/mnist_tensorflow.ipynb
@@ -86,7 +86,7 @@
"\n",
"from pyiree import rt as ireert\n",
"from pyiree.tf import compiler as ireec\n",
- "from pyiree.tf.support import tf_utils\n",
+ "from pyiree.tf.support import module_utils\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
@@ -335,7 +335,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
],
"execution_count": 8,
"outputs": []
@@ -353,7 +353,7 @@
"source": [
"#@title Compile the mhlo MLIR to an IREE backend and prepare a context to execute it\n",
"\n",
- "iree_module = tf_utils.IreeCompiledModule.create_from_instance(\n",
+ "iree_module = module_utils.IreeCompiledModule.create_from_instance(\n",
" inference_module, backend, exported_names, ARTIFACTS_DIR)\n",
"\n",
"print(\"* Module compiled! See intermediate .mlir files in\", ARTIFACTS_DIR, \"*\")"
@@ -462,4 +462,4 @@
]
}
]
-}
\ No newline at end of file
+}
diff --git a/colab/resnet.ipynb b/colab/resnet.ipynb
index 8595da9..af7e8e6 100644
--- a/colab/resnet.ipynb
+++ b/colab/resnet.ipynb
@@ -82,7 +82,7 @@
"\n",
"from pyiree import rt as ireert\n",
"from pyiree.tf import compiler as ireec\n",
- "from pyiree.tf.support import tf_utils\n",
+ "from pyiree.tf.support import module_utils\n",
"\n",
"import tensorflow as tf\n",
"from matplotlib import pyplot as plt\n",
@@ -139,7 +139,7 @@
"\n",
"backend_choice = \"iree_vmla (CPU)\" #@param [ \"iree_vmla (CPU)\", \"iree_llvmjit (CPU)\", \"iree_vulkan (GPU/SwiftShader)\" ]\n",
"backend_choice = backend_choice.split(\" \")[0]\n",
- "backend = tf_utils.BackendInfo(backend_choice)"
+ "backend = module_utils.BackendInfo(backend_choice)"
],
"execution_count": 3,
"outputs": []
@@ -314,4 +314,4 @@
]
}
]
-}
\ No newline at end of file
+}
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
index 9b1c869..09fe138 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/BUILD
@@ -29,9 +29,11 @@
name = "support",
srcs = [
"__init__.py",
+ "module_utils.py",
"tf_test_driver.py",
"tf_test_utils.py",
"tf_utils.py",
+ "trace_utils.py",
],
deps = INTREE_TENSORFLOW_PY_DEPS + [
"//integrations/tensorflow/bindings/python:pathsetup", # build_cleaner: keep
@@ -41,6 +43,22 @@
)
iree_py_test(
+ name = "module_utils_test",
+ srcs = [
+ "module_utils.py",
+ "module_utils_test.py",
+ ],
+ python_version = "PY3",
+ tags = [
+ "driver=llvm",
+ "driver=vmla",
+ ],
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_py_test(
name = "tf_test_utils_test",
srcs = [
"tf_test_utils.py",
@@ -59,6 +77,18 @@
"tf_utils_test.py",
],
python_version = "PY3",
+ deps = INTREE_TENSORFLOW_PY_DEPS + [
+ "//integrations/tensorflow/bindings/python/pyiree/tf/support",
+ ],
+)
+
+iree_py_test(
+ name = "trace_utils_test",
+ srcs = [
+ "trace_utils.py",
+ "trace_utils_test.py",
+ ],
+ python_version = "PY3",
tags = [
"driver=vmla",
],
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
new file mode 100644
index 0000000..4ec9d66
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils.py
@@ -0,0 +1,959 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for compiling 'tf.Module's"""
+
+import collections
+import os
+import tempfile
+from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+
+from absl import logging
+import numpy as np
+from pyiree import rt
+from pyiree.tf import compiler
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+
+def _setup_mlir_crash_reproducer(
+ function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
+ artifacts_dir: str,
+ backend_id: str,
+) -> Any: # Callable[Any, Any]
+ """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+
+ Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+
+ Args:
+ function: The callable to decorate.
+ artifacts_dir: The directory to write the reproducer to.
+ backend_id: The unique backend name to use when writting the reproducer.
+
+ Returns:
+ A function with the same API as the passed function.
+ """
+
+ def decorator(*args, **kwargs):
+ # Set up a crash reproducer for debugging.
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = os.path.join(
+ artifacts_dir, f"reproducer__{backend_id}.mlir")
+ try:
+ results = function(*args, **kwargs)
+ except Exception: # pylint: disable=broad-except
+ # Disable the crash reproducer (to avoid inadvertently overwriting it).
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = None
+ raise
+ return results
+
+ return decorator
+
+
+def _incrementally_lower_compiler_module(
+ compiler_module: compiler.Module,
+ backend_info: "BackendInfo",
+ artifacts_dir: str,
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+ """Lowers a MLIR compiler module incrementally and saves its outputs.
+
+ If artifacts_dir is provided then the following artifacts will be saved:
+ tf_input.mlir:
+ MLIR for the module in TF's input dialect.
+ iree_input.mlir:
+ The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
+ backend_id/compiled.vmfb:
+ A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
+
+ Args:
+ compiler_module: A compiler.Module to lower.
+ backend_info: BackendInfo with the details for lowering compiler_module to
+ IREE.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ if artifacts_dir is not None:
+ os.makedirs(artifacts_dir, exist_ok=True)
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+ logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+ with open(tf_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ # Manually run the passes that tf_module_to_compiler_module usually would.
+ compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+ if artifacts_dir is not None:
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+ logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+ with open(iree_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ compiled_module = compiler_module.compile(
+ target_backends=backend_info.compiler_targets)
+
+ compiled_path = None
+ if artifacts_dir is not None:
+ backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
+ os.makedirs(backend_dir, exist_ok=True)
+ compiled_path = os.path.join(backend_dir, "compiled.vmfb")
+ logging.info("Saving compiled IREE module to: %s", compiled_path)
+ with open(compiled_path, "wb") as f:
+ f.write(compiled_module)
+ return compiled_module, compiled_path
+
+
+def _incrementally_compile_tf_module(
+ module: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None,
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+ """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
+
+ The module blob this creates is not callable. See IreeCompiledModule for an
+ API that returns a module that can be called without any further steps.
+
+ See _incrementally_lower_compiler_module's docstring for details about which
+ artifacts will be saved.
+
+ Args:
+ module: A tf.Module.
+ backend_info: BackendInfo with the details for compiling this module.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+
+ Returns:
+ A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+ artifacts_dir is provided.
+ """
+
+ def _compile_module(module, backend_info, exported_names, artifacts_dir):
+ compiler_module = compiler.tf_module_to_compiler_module(module,
+ exported_names,
+ pass_pipeline=())
+ return _incrementally_lower_compiler_module(compiler_module, backend_info,
+ artifacts_dir)
+
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+ backend_info.backend_id)
+ return _compile_module(module, backend_info, exported_names, artifacts_dir)
+
+
+def _incrementally_compile_tf_signature_def_saved_model(
+ saved_model_dir: str, saved_model_tags: Set[str],
+ backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
+ """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
+
+ The module blob this creates is not callable. See IreeCompiledModule for an
+ API that returns a module that can be called without any further steps.
+
+ See _incrementally_lower_compiler_module's docstring for details about which
+ artifacts will be saved.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+
+ Returns:
+ A compiled IREE module blob and the path to the compiled VM FlatBuffer if
+ artifacts_dir is provided.
+ """
+
+ def _compile_module(saved_model_dir, saved_model_tags, backend_info,
+ exported_name, artifacts_dir):
+ # Convert the tf_module into raw TF input MLIR.
+ compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
+ saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
+ return _incrementally_lower_compiler_module(compiler_module, backend_info,
+ artifacts_dir)
+
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+ backend_info.backend_id)
+ return _compile_module(saved_model_dir, saved_model_tags, backend_info,
+ exported_name, artifacts_dir)
+
+
+class _FunctionWrapper(object):
+
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+ """Dummy function to match _IreeFunctionWrapper's API."""
+ return ("",), ("",)
+
+
+class CompiledModule(object):
+ """Base class for the TF and IREE compiled modules."""
+
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Union[Dict[str, str], None],
+ ):
+ """Shared base constructor – not useful on its own.
+
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ """
+ self.module_name = module_name
+ self.backend_info = backend_info
+ self.compiled_paths = compiled_paths
+
+ def reinitialize(self):
+ """Reinitializes all stateful variables."""
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling this module.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_instance(cls,
+ module_instance: tf.Module,
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module instance to the target backend in backend_info.
+
+ This is only implemented for IreeCompiledModule.
+
+ Args:
+ module_instance: The tf.Module instance to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ input_names: A sequence of kwargs to feed to the saved model.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ raise NotImplementedError()
+
+ def __getattr__(self, attr: str) -> _FunctionWrapper:
+ raise NotImplementedError()
+
+ def iree_serializable(self):
+ return False
+
+ def tflite_serializable(self):
+ return False
+
+
+class _IreeFunctionWrapper(_FunctionWrapper):
+ """Wraps an IREE function, making it callable."""
+
+ def __init__(self, context: rt.SystemContext, f: rt.system_api.BoundFunction):
+ self._context = context
+ self._f = f
+
+ def __call__(self, *args, **kwargs):
+ return self._f(*args, **kwargs)
+
+ def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
+ """Get cxx serialized inputs and outputs for this function."""
+ return self._f.get_serialized_values()
+
+
+class IreeCompiledModule(CompiledModule):
+ """Iree compiled module."""
+
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ vm_module: rt.VmModule,
+ config: rt.Config,
+ ):
+ """Base constructor – Use one of the named constructors instead.
+
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ vm_module: A rt.VmModule containing compilation info to wrap.
+ config: A rt.Config containing compilation info to wrap.
+ """
+ super().__init__(module_name, backend_info, compiled_paths)
+ self._vm_module = vm_module
+ self._config = config
+ self.reinitialize()
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ tf_utils.set_random_seed()
+ module_instance = module_class()
+ return cls.create_from_instance(module_instance, backend_info,
+ exported_names, artifacts_dir)
+
+ @classmethod
+ def create_from_instance(cls,
+ module_instance: tf.Module,
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module instance to the target backend in backend_info.
+
+ Args:
+ module_instance: The tf.Module instance to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ module_blob, compiled_path = _incrementally_compile_tf_module(
+ module=module_instance,
+ backend_info=backend_info,
+ exported_names=exported_names,
+ artifacts_dir=artifacts_dir)
+ vm_module = rt.VmModule.from_flatbuffer(module_blob)
+ config = rt.Config(driver_name=backend_info.driver)
+
+ compiled_paths = None
+ if compiled_path is not None:
+ # IREE bundles every compiled method into the same compiled module.
+ compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+ module_name = type(module_instance).__name__
+
+ return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ input_names: A sequence of kwargs to feed to the saved model.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ del input_names # Unused.
+ del output_names # Unused.
+ module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
+ saved_model_dir, saved_model_tags, backend_info, exported_name,
+ artifacts_dir)
+ vm_module = rt.VmModule.from_flatbuffer(module_blob)
+ config = rt.Config(driver_name=backend_info.driver)
+
+ compiled_paths = None
+ if compiled_path is not None:
+ # IREE bundles every compiled method into the same compiled module :)
+ compiled_paths = collections.defaultdict(lambda: compiled_path)
+
+ return cls(module_name, backend_info, compiled_paths, vm_module, config)
+
+ def reinitialize(self):
+ """Reinitializes all stateful variables."""
+ # set_random_seed is not needed here because the model_class.__init__ is not
+ # called.
+ self._context = rt.SystemContext(modules=[self._vm_module],
+ config=self._config)
+
+ def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
+ # Try to resolve it as a function.
+ m = self._context.modules[self._vm_module.name]
+ f = m[attr]
+ return _IreeFunctionWrapper(self._context, f)
+
+ def iree_serializable(self) -> bool:
+ return self.compiled_paths is not None
+
+
+class _TfFunctionWrapper(_FunctionWrapper):
+ """Wraps a TF function, normalizing it to numpy."""
+
+ def __init__(self, f: Callable[..., Any]):
+ self._f = f
+
+ def __call__(self, *args, **kwargs):
+ # TensorFlow will auto-convert all inbound args.
+ results = self._f(*args, **kwargs)
+ return tf_utils.convert_to_numpy(results)
+
+
+def _convert_inputs_to_tensors(function):
+
+ def decorator(*args, **kwargs):
+ args = [tf.convert_to_tensor(arg) for arg in args]
+ kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
+ return function(*args, **kwargs)
+
+ return decorator
+
+
+class SignatureDefSavedModelWrapper(object):
+ """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
+
+ def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
+ exported_name: str):
+ self.saved_model = tf.saved_model.load(saved_model_dir,
+ tags=saved_model_tags)
+ inference_func = self.saved_model.signatures[exported_name]
+ inference_func = _convert_inputs_to_tensors(inference_func)
+ self.__setattr__(exported_name, inference_func)
+
+
+class TfCompiledModule(CompiledModule):
+ """TensorFlow 'compiled' module.
+
+ This facade exists to provide a complimentary API to IreeCompiledModule and
+ normalize TensorFlow's output to Numpy.
+ """
+
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ constructor: Callable[[], tf.Module],
+ exported_names: Sequence[str],
+ ):
+ """Base constructor – Use one of the named constructors instead.
+
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ constructor: A callable (class or function) which returns the tf.Module
+ subclass instance to wrap.
+ exported_names: an optional iterable of strings representing which of the
+ tf.Module subclass instance's functions should be callable. If
+ exported_names is empty then all functions will be callable.
+ """
+ super().__init__(module_name, backend_info, compiled_paths=None)
+ self._constructor = constructor
+ self._exported_names = exported_names
+ self.reinitialize()
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling this module.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ module_name = module_class.__name__
+ constructor = module_class
+ return cls(module_name, backend_info, constructor, exported_names)
+
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ input_names: A sequence of kwargs to feed to the saved model.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ constructor = lambda: SignatureDefSavedModelWrapper(
+ saved_model_dir, saved_model_tags, exported_name)
+ return cls(module_name, backend_info, constructor, [exported_name])
+
+ def reinitialize(self):
+ """Reinitializes all stateful variables."""
+ tf_utils.set_random_seed()
+ self._tf_module = self._constructor()
+
+ def __getattr__(self, attr: str) -> _TfFunctionWrapper:
+ # Try to resolve it as a function.
+ exported = not self._exported_names or attr in self._exported_names
+ if not hasattr(self._tf_module, attr) or not exported:
+ raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
+ f = getattr(self._tf_module, attr)
+ if not f or not hasattr(f, "__call__"):
+ raise AttributeError(
+ f"The TensorFlow module does not have a callable attr '{attr}'")
+ return _TfFunctionWrapper(f)
+
+
+def _get_non_inhereted_function_names(cls):
+ """Gets all methods that cls has that its parents don't have."""
+ names = set(dir(cls))
+ for parent in cls.__bases__:
+ names -= set(dir(parent))
+ return list(names)
+
+
+def _get_concrete_functions(module_class: Type[tf.Module],
+ exported_names: Sequence[str] = ()):
+ """Get concrete functions from non-inherited methods or exported_names."""
+ if not len(exported_names):
+ # Get all method names on 'module_class' that aren't on 'tf.Module'.
+ exported_names = _get_non_inhereted_function_names(module_class)
+ instance = module_class()
+ functions = []
+ for name in exported_names:
+ functions.append(getattr(instance, name).get_concrete_function())
+ return functions, exported_names, instance
+
+
+def tf_module_to_tflite_module_bytes(
+ module_class: Type[tf.Module], exported_names: Sequence[str] = ()
+) -> Dict[str, bytes]:
+ """Compiles a tf.Module's methods with TFLite.
+
+ Args:
+ module_class: A tf.Module subclass to compile with TFLite.
+ exported_names: an optional iterable of strings representing which of the
+ module_class's functions should be compiled. If exported_names is empty
+ then all functions will be compiled.
+
+ Returns:
+ A dict mapping method names to compiled TFLite module bytes.
+ """
+ tflite_modules = []
+ methods, method_names, instance = _get_concrete_functions(
+ module_class, exported_names)
+ failed_methods = []
+ for method, method_name in zip(methods, method_names):
+ logging.info("Attempting to convert '%s' to tflite...", method_name)
+ try:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
+ logging.info("...converted '%s' to tflite.", method_name)
+ tflite_modules.append(converter.convert())
+ except Exception as e:
+ logging.error("Failed to convert '%s' to tflite.", method_name)
+ logging.error("TFLite excpetion: %s", e)
+ failed_methods.append(method_name)
+
+ if failed_methods:
+ raise RuntimeError(
+ f"Failed to convert the following methods to tflite: {failed_methods}")
+
+ # Keep variables alive until TFLite has done the conversion; ConcreteFunctions
+ # themselves only keep weak references to variables.
+ del instance
+ return dict(zip(method_names, tflite_modules))
+
+
+def tf_signature_def_saved_model_to_tflite_module_bytes(
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+) -> Dict[str, bytes]:
+ """Compiles a SignatureDef SavedModel signature with TFLite.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ input_names: A sequence of kwargs to feed to the saved model.
+ output_names: A sequence of named outputs to extract from the saved model.
+
+ Returns:
+ A dict mapping the signature name to the compiled TFLite module bytes.
+ """
+ converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
+ saved_model_dir,
+ tag_set=saved_model_tags,
+ signature_key=exported_name,
+ input_arrays=input_names,
+ output_arrays=output_names)
+ tflite_module = converter.convert()
+ return dict([[exported_name, tflite_module]])
+
+
+def tflite_module_bytes_to_tflite_interpreters(
+ tflite_module_bytes: Dict[str, bytes],
+ artifacts_dir: str = None
+) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
+ """Compile a dict of TFLite compiled bytes to TFLite interpreters.
+
+ Args:
+ tflite_module_bytes: A dict mapping method names to compiled TFLite byte
+ strings.
+ artifacts_dir: an optional path to save compilation artifacts to.
+
+ Returns:
+ A dictionary mapping method names to TFLite interpreters and a dictionary
+ mapping method names to compiled tflite graph paths (or None if
+ artifacts_dir is None).
+ """
+ interpreters = dict()
+ compiled_paths = None
+ if artifacts_dir is not None:
+ compiled_paths = dict()
+
+ def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
+ """Save compiled TFLite module bytes and convert into an interpreter."""
+ tflite_dir = os.path.join(base_dir, "tflite")
+ os.makedirs(tflite_dir, exist_ok=True)
+ tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
+ with open(tflite_path, "wb") as f:
+ f.write(tflite_module)
+
+ interpreters[method_name] = tf.lite.Interpreter(tflite_path)
+ if artifacts_dir is not None:
+ compiled_paths[method_name] = tflite_path
+
+ # Load each of the converted methods above into tf.lite.Interpreters.
+ for method_name, tflite_module in tflite_module_bytes.items():
+ if artifacts_dir is None:
+ with tempfile.TemporaryDirectory() as base_dir:
+ _interpret_bytes(method_name, tflite_module, base_dir)
+ else:
+ _interpret_bytes(method_name, tflite_module, artifacts_dir)
+
+ return interpreters, compiled_paths
+
+
+class _TfLiteFunctionWrapper(_FunctionWrapper):
+ """Wraps a TFLite interpreter and makes it behave like a python function."""
+
+ def __init__(self, interpreter: tf.lite.Interpreter,
+ output_names: Sequence[str]):
+ self._interpreter = interpreter
+ self._output_names = output_names
+
+ def __call__(self, *args,
+ **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
+ if len(args) and len(kwargs):
+ raise ValueError("Passing both args and kwargs is not supported by "
+ "_TfLiteFunctionWrapper")
+
+ # Set up and run the function.
+ self._interpreter.allocate_tensors()
+
+ if len(args):
+ # Specifically to get TFLite to work with keras models that take a list of
+ # inputs instead of a sequence of args as their inputs, because it decides
+ # to change the input signature but it still technically works if you
+ # ignore that it does that.
+ if len(args) == 1 and isinstance(args[0], list):
+ args = args[0]
+
+ for arg, detail in zip(args, self._interpreter.get_input_details()):
+ self._interpreter.set_tensor(detail["index"], arg)
+ else:
+ for detail in self._interpreter.get_input_details():
+ self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
+
+ self._interpreter.invoke()
+
+ # Extract the outputs from the TFLite interpreter.
+ outputs = []
+ for detail in self._interpreter.get_output_details():
+ value = tf_utils.normalize_numpy(
+ self._interpreter.get_tensor(detail["index"]))
+ if self._output_names is not None:
+ name = detail["name"]
+ if name not in self._output_names:
+ raise ValueError(f"Expected '{name}' to be in {self._output_names}")
+ outputs.append([detail["name"], value])
+ else:
+ outputs.append(value)
+
+ # Process them to match the output of the tf.Module.
+ if self._output_names is not None:
+ return dict(outputs)
+ else:
+ if len(outputs) == 1:
+ return outputs[0]
+ return tuple(outputs)
+
+
+class TfLiteCompiledModule(CompiledModule):
+ """Compiles a tf.Module with TFLite and allows it to be called."""
+
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ interpreters: Dict[str, tf.lite.Interpreter],
+ output_names: Sequence[str] = None,
+ ):
+ """Base constructor – Use one of the named constructors instead.
+
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ interpreters: A dict of tf.lite.Interpreters to make callable.
+ """
+ super().__init__(module_name, backend_info, compiled_paths)
+ self._interpreters = interpreters
+ self._output_names = output_names
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling this module.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ tf_utils.set_random_seed()
+ tflite_module_bytes = tf_module_to_tflite_module_bytes(
+ module_class, exported_names)
+ interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+ tflite_module_bytes, artifacts_dir)
+ module_name = module_class.__name__
+ return cls(module_name, backend_info, compiled_paths, interpreters)
+
+ @classmethod
+ def create_from_signature_def_saved_model(cls,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ backend_info: "BackendInfo",
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None):
+ """Compile a SignatureDef SavedModel to the target backend in backend_info.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ module_name: A name for this compiled module.
+ backend_info: BackendInfo with the details for compiling the saved model.
+ exported_name: A str representing the signature on the saved model to
+ compile.
+ input_names: A sequence of kwargs to feed to the saved model.
+ output_names: A sequence of named outputs to extract from the saved model.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
+ saved_model_dir, saved_model_tags, exported_name, input_names,
+ output_names)
+ interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
+ tflite_module_bytes, artifacts_dir)
+ return cls(module_name, backend_info, compiled_paths, interpreters,
+ output_names)
+
+ def reinitialize(self):
+ """Reinitializes all stateful variables."""
+ # This is a noop because TFLite (mostly) doesn't support stateful modules.
+ pass
+
+ def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper:
+ # Try to resolve it as an interpreter.
+ if not attr in self._interpreters:
+ raise AttributeError(
+ f"The TFLite module does not have an interpreter for '{attr}'")
+ return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names)
+
+ def tflite_serializable(self) -> bool:
+ return self.compiled_paths is not None
+
+
+class BackendInfo:
+ """Contains information for compiling the specified backend."""
+
+ _name_to_info = {
+ "tf": {
+ "compiled_module_class": TfCompiledModule,
+ "driver": None,
+ "compiler_targets": None,
+ },
+ "tflite": {
+ "compiled_module_class": TfLiteCompiledModule,
+ "driver": None,
+ "compiler_targets": None,
+ },
+ "iree_vmla": {
+ "compiled_module_class": IreeCompiledModule,
+ "driver": "vmla",
+ "compiler_targets": ["vmla"]
+ },
+ "iree_vulkan": {
+ "compiled_module_class": IreeCompiledModule,
+ "driver": "vulkan",
+ "compiler_targets": ["vulkan-*"]
+ },
+ }
+
+ def __init__(self, backend_name: str, backend_id: str = None):
+ """Creates a BackendInfo with the compilation details for backend_name.
+
+ Args:
+ backend_name: a str specifying which backend to use. Should be one of
+ 'tf', 'tflite', 'iree_vmla', 'iree_vulkan'.
+ backend_id: an optional str specifying what name to use when saving
+ compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
+
+ Raises:
+ KeyError: if backend_name is not one of ['tf', 'tflite', 'iree_vmla',
+ 'iree_vulkan'].
+ ValueError: if backend_id doesn't start with backend_name.
+ """
+ if backend_name not in self._name_to_info:
+ raise KeyError(
+ "Expected backend_name to be one of "
+ f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+ if backend_id is not None and not backend_id.startswith(backend_name):
+ raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+ f"but got '{backend_id}'.")
+
+ self.backend_name = backend_name
+ self.backend_id = backend_name if backend_id is None else backend_id
+
+ info = self._name_to_info[backend_name]
+ self._compiled_module_class = info["compiled_module_class"]
+ self.driver = info["driver"]
+ self.compiler_targets = info["compiler_targets"]
+
+ def compile_from_class(self,
+ module_class: Type[tf.Module],
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None) -> CompiledModule:
+ """Creates a 'CompiledModule' for this backend."""
+ return self._compiled_module_class.create_from_class(
+ module_class, self, exported_names, artifacts_dir)
+
+ def compile_signature_def_saved_model(
+ self,
+ saved_model_dir: str,
+ saved_model_tags: Set[str],
+ module_name: str,
+ exported_name: str,
+ input_names: Sequence[str],
+ output_names: Sequence[str],
+ artifacts_dir: str = None) -> CompiledModule:
+ return self._compiled_module_class.create_from_signature_def_saved_model(
+ saved_model_dir, saved_model_tags, module_name, self, exported_name,
+ input_names, output_names, artifacts_dir)
+
+ @classmethod
+ def get_all_backends(cls) -> Sequence["BackendInfo"]:
+ """Returns a list of all BackendInfo configurations."""
+ return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py
new file mode 100644
index 0000000..8baa6bc
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py
@@ -0,0 +1,114 @@
+# Lint as: python3
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.module_utils."""
+
+import os
+import tempfile
+
+from absl import logging
+from absl.testing import parameterized
+from pyiree.tf.support import module_utils
+import tensorflow as tf
+
+
+class ConstantModule(tf.Module):
+
+ @tf.function(input_signature=[])
+ def meaning(self):
+ return tf.constant([42.])
+
+
+class StatefulCountingModule(tf.Module):
+
+ def __init__(self):
+ self.count = tf.Variable([0.])
+
+ @tf.function(input_signature=[])
+ def increment(self):
+ self.count.assign_add(tf.constant([1.]))
+
+ @tf.function(input_signature=[])
+ def get_count(self):
+ return self.count
+
+
+class RandomInitModule(tf.Module):
+
+ def __init__(self):
+ self.value = tf.Variable(tf.random.uniform([1]))
+
+ @tf.function(input_signature=[])
+ def get(self):
+ return self.value
+
+
+class UtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+ def test_artifact_saving(self):
+ backend_info = module_utils.BackendInfo('iree_vmla')
+ with tempfile.TemporaryDirectory() as artifacts_dir:
+ tf_module = ConstantModule()
+ iree_module_utils, compiled_path = (
+ module_utils._incrementally_compile_tf_module(
+ tf_module, backend_info=backend_info,
+ artifacts_dir=artifacts_dir))
+
+ artifacts_to_check = [
+ 'tf_input.mlir',
+ 'iree_input.mlir',
+ compiled_path,
+ ]
+ for artifact in artifacts_to_check:
+ artifact_path = os.path.join(artifacts_dir, artifact)
+ logging.info('Checking path: %s', artifact_path)
+ self.assertTrue(os.path.exists(artifact_path))
+
+ @parameterized.named_parameters([
+ ('tensorflow', 'tf'),
+ ('vmla', 'iree_vmla'),
+ ])
+ def test_unaltered_state(self, backend_name):
+ backend_info = module_utils.BackendInfo(backend_name)
+ module = backend_info.compile_from_class(StatefulCountingModule)
+
+ # Test that incrementing works properly.
+ self.assertEqual([0.], module.get_count())
+ module.increment()
+ self.assertEqual([1.], module.get_count())
+
+ module.reinitialize()
+ # Test reinitialization.
+ self.assertEqual([0.], module.get_count())
+
+ @parameterized.named_parameters([
+ ('tensorflow', 'tf'),
+ ('vmla', 'iree_vmla'),
+ ])
+ def test_random_initialization(self, backend_name):
+ backend_info = module_utils.BackendInfo(backend_name)
+
+ # Test compilation is the same.
+ module_1 = backend_info.compile_from_class(RandomInitModule)
+ module_2 = backend_info.compile_from_class(RandomInitModule)
+ self.assertAllEqual(module_1.get(), module_2.get())
+
+ # Test reinitialization is the same.
+ old_value = module_1.get()
+ module_1.reinitialize()
+ self.assertAllEqual(old_value, module_1.get())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index e18b5ce..e2a67b9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -24,20 +24,17 @@
import collections
import copy
-import glob
-import inspect
import itertools
import os
-import pickle
import re
-import sys
import tempfile
from typing import Any, Callable, Dict, List, Sequence, Set, Tuple, Type, Union
from absl import flags
from absl import logging
-import numpy as np
+from pyiree.tf.support import module_utils
from pyiree.tf.support import tf_utils
+from pyiree.tf.support import trace_utils
import tensorflow.compat.v2 as tf
flags.DEFINE_string("reference_backend", "tf",
@@ -56,8 +53,8 @@
flags.DEFINE_bool(
"get_saved_model", False,
"Creates and stores a SavedModel for the tf.Module class to be tested.")
+
FLAGS = flags.FLAGS
-NUMPY_LINEWIDTH = 120
DEFAULT_INPUT_GENERATOR = tf_utils.uniform
@@ -96,7 +93,7 @@
return backend_names, backend_ids
-def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
+def get_target_backends() -> Sequence[module_utils.BackendInfo]:
"""Gets the BackendInfo instances to compare with the reference backend.
By default all backends in BackendInfo will be used. Specific backends to
@@ -109,496 +106,15 @@
logging.info("Using backends from command line: %s", FLAGS.target_backends)
backend_names, backend_ids = _parse_target_backends()
backends = [
- tf_utils.BackendInfo(backend_name, backend_id)
+ module_utils.BackendInfo(backend_name, backend_id)
for backend_name, backend_id in zip(backend_names, backend_ids)
]
else:
# If no backends are specified, use them all.
- backends = tf_utils.BackendInfo.get_all_backends()
+ backends = module_utils.BackendInfo.get_all_backends()
return backends
-def _indent(input_str: str, indentation: int = 2) -> str:
- """Indents a string by the specified number of spaces, defaulting to 2."""
- spaces = " " * indentation
- lines = input_str.split("\n")
- # Prepend spaces to each non-empty line.
- lines = [f"{spaces}{line}" if len(line) else line for line in lines]
- return "\n".join(lines)
-
-
-def _zfill_width(length: int) -> Union[int, None]:
- return int(np.ceil(np.log10(length))) if length else None
-
-
-class ModuleCall:
-
- def __init__(self,
- method: str,
- inputs: Tuple[Any],
- outputs: Tuple[Any],
- serialized_inputs: Tuple[str],
- serialized_outputs: Tuple[str],
- rtol: float = 1e-6,
- atol: float = 1e-6):
- """Records the details of a call to a CompiledModule."""
- self.method = method
-
- # Deepcopy to safegard against mutation.
- self.inputs = copy.deepcopy(inputs)
- if outputs is not None:
- outputs = copy.deepcopy(outputs)
- else:
- outputs = tuple()
- self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
-
- self.serialized_inputs = serialized_inputs
- self.serialized_outputs = serialized_outputs
-
- self.rtol = rtol
- self.atol = atol
-
- def get_tolerances(self) -> Tuple[float, float]:
- """Gets the floating point tolerances associated with this call."""
- return self.rtol, self.atol
-
- def _get_shape_and_dtype(self, value: Any) -> str:
- if isinstance(value, np.ndarray):
- return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True)
- else:
- return str(type(value))
-
- def __str__(self):
- prior_printoptions = np.get_printoptions()
- np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
-
- header = f"Method: {self.method}"
- inputs = "\n".join(_indent(str(value)) for value in self.inputs)
- input_shapes = ", ".join(
- self._get_shape_and_dtype(value) for value in self.inputs)
-
- outputs = "\n".join(_indent(str(value)) for value in self.outputs)
- output_shapes = ", ".join(
- self._get_shape_and_dtype(value) for value in self.outputs)
-
- tolerances = _indent(f"rtol={self.rtol}, atol={self.atol}")
- body = (f"Inputs: {input_shapes}\n{inputs}\n"
- f"Outputs: {output_shapes}\n{outputs}"
- f"\nTolerances:\n{tolerances}")
- result = f"{header}\n{_indent(body)}"
-
- np.set_printoptions(**prior_printoptions)
- return result
-
- def serialize(self, call_dir: str) -> None:
- """Stores a serialized copy of this call.
-
- Can be loaded via ModuleCall.load(call_dir)
-
- Args:
- call_dir: str, the path to the directory to serialize this call to.
- """
- os.makedirs(call_dir, exist_ok=True)
-
- metadata = {
- "method": self.method,
- "serialized_inputs": self.serialized_inputs,
- "serialized_outputs": self.serialized_outputs,
- "rtol": self.rtol,
- "atol": self.atol
- }
- with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
- pickle.dump(metadata, f)
-
- width = _zfill_width(len(self.inputs))
- for i, value in enumerate(self.inputs):
- path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl")
- with open(path, "wb") as f:
- pickle.dump(value, f)
-
- width = _zfill_width(len(self.outputs))
- for i, value in enumerate(self.outputs):
- path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl")
- with open(path, "wb") as f:
- pickle.dump(value, f)
-
- @staticmethod
- def load(call_dir: str) -> "ModuleCall":
- """Loads and returns a trace serialized with ModuleCall.serialize."""
- with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f:
- kwargs = pickle.load(f)
-
- for result_type in ["input", "output"]:
- key = f"{result_type}s" # inputs or outputs
- kwargs[key] = []
-
- files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl"))
- for filename in sorted(files):
- with open(filename, "rb") as f:
- kwargs[key].append(pickle.load(f))
-
- # Convert to tuple to match python's return type for multiple results.
- kwargs[key] = tuple(kwargs[key])
-
- return ModuleCall(**kwargs)
-
-
-class Trace:
- """Stores the inputs and outputs of a series of calls to a module."""
-
- def __init__(self,
- module: Union[tf_utils.CompiledModule, None],
- function: Union[Callable[["TracedModule"], None], None],
- _load_dict: Dict[str, Any] = None):
- """Extracts metadata from module and function and initializes.
-
- Example usage:
- def forward_pass(...):
- ...
- module = IreeCompiledModule(...)
- trace = Trace(module, forward_pass)
- forward_pass(TracedModule(module, trace))
-
- Args:
- module: the module who's outputs this trace will record.
- function: the function that module will be traced on.
- _load_dict: used internally
- """
- if _load_dict is None:
- # Extract metadata from module and function.
- self.module_name = module.module_name
- self.compiled_paths = module.compiled_paths
- self.backend_name = module.backend_info.backend_name
- self.backend_id = module.backend_info.backend_id
- self.backend_driver = module.backend_info.driver
- self.iree_serializable = module.iree_serializable()
- self.tflite_serializable = module.tflite_serializable()
- self.function_name = function.__name__
- self.function_sourcefile = inspect.getsourcefile(function)
- source, start_line = inspect.getsourcelines(function)
- self.function_line_numbers = (start_line, start_line + len(source))
- self.function_source = "".join(source)
-
- self.calls = []
- else:
- self.module_name = _load_dict["module_name"]
- self.compiled_paths = _load_dict["compiled_paths"]
- self.backend_name = _load_dict["backend_name"]
- self.backend_id = _load_dict["backend_id"]
- self.backend_driver = _load_dict["backend_driver"]
- self.iree_serializable = _load_dict["iree_serializable"]
- self.tflite_serializable = _load_dict["tflite_serializable"]
- self.function_name = _load_dict["function_name"]
- self.function_sourcefile = _load_dict["function_sourcefile"]
- self.function_line_numbers = _load_dict["function_line_numbers"]
- self.function_source = _load_dict["function_source"]
- self.calls = _load_dict["calls"]
-
- def __str__(self):
- header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
- f"on function '{self.function_name}':")
- # Give each call a number so it's easier to compare between multiple traces.
- calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
- calls = _indent("\n".join(calls))
- return f"{header}\n{calls}"
-
- def __iter__(self):
- for call in self.calls:
- yield call
-
- @staticmethod
- def compare_traces(ref_trace: "Trace",
- tar_trace: "Trace") -> Tuple[bool, Sequence[str]]:
- traces_match = True
- error_messages = []
-
- # Check that all method invocations match.
- ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
- tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
- if ref_methods != tar_methods:
- # Raise a ValueError instead of returning False since this is an
- # unexpected error.
- raise ValueError(
- "The reference and target traces have different call structures:\n"
- f"Reference: {ref_methods}\nTarget: {tar_methods}")
-
- for ref_call, tar_call in zip(ref_trace, tar_trace):
- logging.info("Comparing calls to '%s'", ref_call.method)
- rtol, atol = ref_call.get_tolerances()
-
- inputs_match, error_message = Trace._check_same(ref_call.inputs,
- tar_call.inputs, rtol,
- atol)
- if not inputs_match:
- error_messages.append(error_message)
- logging.error("Inputs did not match.")
- outputs_match, error_message = Trace._check_same(ref_call.outputs,
- tar_call.outputs, rtol,
- atol)
- if not outputs_match:
- error_messages.append(error_message)
- logging.error("Outputs did not match.")
- calls_match = inputs_match and outputs_match
-
- if not calls_match:
- logging.error("Comparision between '%s' and '%s' failed on method '%s'",
- ref_trace.backend_id, tar_trace.backend_id,
- ref_call.method)
- logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
- ref_call)
- logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
-
- traces_match = traces_match and calls_match
- return traces_match, error_messages
-
- @staticmethod
- def _check_same(ref: Any, tar: Any, rtol: float,
- atol: float) -> Tuple[bool, Union[str, None]]:
- """Checks that ref and tar have identical datastructures and values."""
- # Check for matching types.
- if not isinstance(tar, type(ref)):
- error = ("Expected ref and tar to have the same type but got "
- f"'{type(ref)}' and '{type(tar)}'")
- logging.error(error)
- return False, error
-
- if ref is None:
- # Nothing to compare (e.g. the called method had no outputs).
- return True, None
-
- # Recursive check for dicts.
- if isinstance(ref, dict):
- if ref.keys() != tar.keys():
- error = ("Expected ref and tar to have the same keys, but got "
- f"'{ref.keys()}' and '{tar.keys()}'")
- logging.error(error)
- return False, error
- # Check that all of the dictionaries' values are the same.
- for key in ref:
- same, error = Trace._check_same(ref[key], tar[key], rtol, atol)
- if not same:
- return same, error
-
- # Recursive check for iterables.
- elif isinstance(ref, list) or isinstance(ref, tuple):
- if len(ref) != len(tar):
- error = ("Expected ref and tar to have the same length, but got "
- f"{len(ref)} and {len(tar)}")
- logging.error(error)
- return False, error
- # Check that all of the iterables' values are the same.
- for i in range(len(ref)):
- same, error = Trace._check_same(ref[i], tar[i], rtol, atol)
- if not same:
- return same, error
-
- # Base check for numpy arrays.
- elif isinstance(ref, np.ndarray):
- if ref.dtype != tar.dtype:
- error = ("Expected ref and tar to have the same dtype, but got "
- f"'{ref.dtype}' and '{tar.dtype}'")
- logging.error(error)
- return False, error
- if ref.size == tar.size == 0:
- return True, None
-
- if np.issubdtype(ref.dtype, np.floating):
- same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True)
- abs_diff = np.max(np.abs(ref - tar))
- rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
- diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
- f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
- if not same:
- error = ("Floating point difference between ref and tar was too "
- f"large. {diff_string}")
- logging.error(error)
- else:
- error = None
- logging.info(
- "Floating point difference between ref and tar was within "
- "tolerance. %s", diff_string)
- return same, error
- elif np.issubdtype(ref.dtype, np.integer):
- same = np.array_equal(ref, tar)
- if not same:
- abs_diff = np.max(np.abs(ref - tar))
- error = ("Expected array equality between ref and tar, but got "
- f"a max elementwise difference of {abs_diff}")
- logging.error(error)
- else:
- error = None
- return same, error
- else:
- return np.array_equal(ref, tar), None
-
- # Base check for native number types.
- elif isinstance(ref, (int, float)):
- return ref == tar, None
-
- # If outputs end up here then an extra branch for that type should be added.
- else:
- raise TypeError(f"Encountered results with unexpected type {type(ref)}")
- return True, None
-
- def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
- """Saves a human-readable string representation of this trace to disk.
-
- Args:
- trace_dir: str, path to the directory to save the trace in.
- summarize: a bool controlling whether numpy should summarize the inputs
- and outputs if they're large. Setting this to False is very slow for
- large outputs.
- """
- prior_printoptions = np.get_printoptions()
- np.set_printoptions(
- linewidth=NUMPY_LINEWIDTH,
- threshold=None if summarize else sys.maxsize,
- edgeitems=10) # Can show more items since they won't clutter the logs.
-
- path = os.path.join(trace_dir, "log.txt")
- with open(path, "w") as f:
- f.write(str(self))
- f.write("\n")
-
- np.set_printoptions(**prior_printoptions)
-
- def serialize(self, trace_dir: str) -> None:
- """Stores a serialized copy of this trace in trace_dir.
-
- It can be loaded via `Trace.load(trace_dir)`.
-
- Args:
- trace_dir: str, path to the directory to serialize the trace to.
- """
-
- compiled_paths = None
- if self.compiled_paths is not None:
- # Convert to a dict to avoid the issues with serializing defaultdicts.
- compiled_paths = dict(self.compiled_paths)
-
- # Python serialization.
- metadata = {
- "module_name": self.module_name,
- "compiled_paths": compiled_paths,
- "backend_name": self.backend_name,
- "backend_id": self.backend_id,
- "backend_driver": self.backend_driver,
- "iree_serializable": self.iree_serializable,
- "tflite_serializable": self.tflite_serializable,
- "function_name": self.function_name,
- "function_sourcefile": self.function_sourcefile,
- "function_line_numbers": self.function_line_numbers,
- "function_source": self.function_source
- }
- with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f:
- pickle.dump(metadata, f)
-
- width = _zfill_width(len(self.calls))
- for i, call in enumerate(self.calls):
- call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
- call.serialize(call_dir)
-
- # C++ benchmark serialization.
- if self.iree_serializable or self.tflite_serializable:
- entry_function = self.calls[0].method
- compiled_path = self.compiled_paths[entry_function]
-
- if self.iree_serializable:
- serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
- flagfile = [
- f"--module_file={compiled_path}",
- f"--driver={self.backend_driver}",
- f"--function_inputs={serialized_inputs}",
- f"--entry_function={entry_function}",
- ]
- with open(os.path.join(trace_dir, "flagfile"), "w") as f:
- f.writelines(line + "\n" for line in flagfile)
- else:
- with open(os.path.join(trace_dir, "graph_path"), "w") as f:
- f.writelines(compiled_path + "\n")
-
- @staticmethod
- def load(trace_dir: str) -> "Trace":
- """Loads and returns a trace serialized with Trace.serialize.
-
- Args:
- trace_dir: str, path to the directory of the serialized trace.
-
- Returns:
- A Trace deserialized from trace_dir.
- """
- with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f:
- load_dict = pickle.load(f)
- call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*")))
- calls = [ModuleCall.load(call_dir) for call_dir in call_dirs]
- load_dict["calls"] = calls
- return Trace(module=None, function=None, _load_dict=load_dict)
-
-
-def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
- trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
- trace.function_name)
- os.makedirs(trace_dir, exist_ok=True)
- return trace_dir
-
-
-class TracedModule:
-
- def __init__(self, module: tf_utils.CompiledModule, trace: Trace):
- """Wraps a CompiledModule so that all inputs and outputs are traced.
-
- The TracedModule returned will have an API almost identical to that of the
- passed CompiledModule. The only changes is that if the keywords `rtol` or
- `atol` are passed to one of the CompiledModule's methods, then they will be
- used to set the tolerance for comparing that call to the same call in
- another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)`
- would be the same as calling `module.add(a, b)`.
-
- Args:
- module: the CompiledModule to trace.
- trace: the Trace to record calls to this module with.
- """
- self._module = module
- self._trace = trace
-
- def _trace_call(self, method: tf_utils._FunctionWrapper, method_name: str):
- """Decorates a CompiledModule method to capture its inputs and outputs."""
-
- def call(*args, **kwargs):
- # Pop manually specified tolerances from the kwargs (if any).
- tolerances = {}
- tolerances["rtol"] = kwargs.pop("rtol", None)
- tolerances["atol"] = kwargs.pop("atol", None)
- # Only pass these to ModuleCall if they were specified by the user.
- tolerances = {k: v for k, v in tolerances.items() if v is not None}
-
- # Ensure the inputs are numpy inputs.
- args = tf_utils.convert_to_numpy(args)
- kwargs = tf_utils.convert_to_numpy(kwargs)
-
- # Run the method and record the details of the call.
- outputs = method(*args, **kwargs)
- serialized_inputs, serialized_outputs = method.get_serialized_values()
- self._trace.calls.append(
- ModuleCall(method_name, args, outputs, serialized_inputs,
- serialized_outputs, **tolerances))
- return outputs
-
- return call
-
- def __getattr__(self, attr):
- # Try to resolve it as an attr on self._module.
- if not hasattr(self._module, attr):
- raise AttributeError(f"The compiled module does not have attr '{attr}'")
- module_attr = getattr(self._module, attr)
- if not hasattr(module_attr, "__call__"):
- # e.g. traced_module.backend
- return module_attr
- else:
- # e.g. traced_module.simple_mul(a, b)
- return self._trace_call(module_attr, method_name=attr)
-
-
Modules = collections.namedtuple("Modules",
["ref_module", "tar_modules", "artifacts_dir"])
@@ -635,8 +151,8 @@
artifacts_dir = _setup_artifacts_dir(module_class.__name__)
# Get the backend information for this test.
- ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
- f"{FLAGS.reference_backend}_ref")
+ ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
tar_backend_infos = get_target_backends()
compile_backend = lambda backend_info: backend_info.compile_from_class(
@@ -678,8 +194,8 @@
artifacts_dir = _setup_artifacts_dir(module_name)
# Get the backend information for this test.
- ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
- f"{FLAGS.reference_backend}_ref")
+ ref_backend_info = module_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
tar_backend_infos = get_target_backends()
compile_backend = (
@@ -1049,7 +565,9 @@
f"unit_test '{unit_test.__name__}'.")
setattr(cls, unit_test.__name__, unit_test)
- def compare_backends(self, trace_function: Callable[[TracedModule], None],
+ def compare_backends(self,
+ trace_function: Callable[[trace_utils.TracedModule],
+ None],
modules: Modules) -> None:
"""Run the reference and target backends on trace_function and compare them.
@@ -1060,19 +578,20 @@
trace_function: a function accepting a TracedModule as its argument.
"""
# Create Traces for each backend.
- ref_trace = Trace(modules.ref_module, trace_function)
+ ref_trace = trace_utils.Trace(modules.ref_module, trace_function)
tar_traces = [
- Trace(module, trace_function) for module in modules.tar_modules
+ trace_utils.Trace(module, trace_function)
+ for module in modules.tar_modules
]
# Run the traces through trace_function with their associated modules.
tf_utils.set_random_seed()
- trace_function(TracedModule(modules.ref_module, ref_trace))
+ trace_function(trace_utils.TracedModule(modules.ref_module, ref_trace))
if FLAGS.log_all_traces:
logging.info(ref_trace)
for module, trace in zip(modules.tar_modules, tar_traces):
tf_utils.set_random_seed()
- trace_function(TracedModule(module, trace))
+ trace_function(trace_utils.TracedModule(module, trace))
if FLAGS.log_all_traces:
logging.info(trace)
@@ -1082,17 +601,18 @@
for i, tar_trace in enumerate(tar_traces):
logging.info("Comparing the reference backend '%s' with '%s'",
ref_trace.backend_id, tar_trace.backend_id)
- traces_match, errors = Trace.compare_traces(ref_trace, tar_trace)
+ traces_match, errors = trace_utils.compare_traces(ref_trace, tar_trace)
if not traces_match:
failed_backend_indices.append(i)
error_messages.extend(errors)
# Save the results to disk before validating.
- ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
+ ref_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir, ref_trace)
ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
ref_trace.serialize(ref_trace_dir)
for tar_trace in tar_traces:
- tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
+ tar_trace_dir = trace_utils.get_trace_dir(modules.artifacts_dir,
+ tar_trace)
tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
tar_trace.serialize(tar_trace_dir)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 6069688..84c1c36 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -14,45 +14,13 @@
# limitations under the License.
"""Tests for pyiree.tf.support.tf_test_utils."""
-import os
-import tempfile
-
-from absl.testing import parameterized
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow as tf
-class StatefulCountingModule(tf.Module):
-
- def __init__(self):
- self.count = tf.Variable([0.])
-
- @tf.function(input_signature=[])
- def increment(self):
- self.count.assign_add(tf.constant([1.]))
-
- @tf.function(input_signature=[])
- def get_count(self):
- return self.count
-
- @tf.function(input_signature=[tf.TensorSpec([1])])
- def increment_by(self, value):
- self.count.assign_add(value)
-
- @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])])
- def increment_by_max(self, a, b):
- result = tf.maximum(a, b)
- self.count.assign_add(result)
- return result
-
- @tf.function(input_signature=[])
- def decrement(self):
- self.count.assign_sub(tf.constant([1.]))
-
-
-class TfFunctionUnittestModule(tf_test_utils.TestModule):
+class TfFunctionUnitTestModule(tf_test_utils.TestModule):
@tf_test_utils.tf_function_unit_test(input_signature=[])
def no_args(self):
@@ -100,158 +68,7 @@
return tf.matmul(a, b)
-class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
-
- @parameterized.named_parameters([
- {
- 'testcase_name': 'all the same',
- 'array_c': np.array([0, 1, 2]),
- 'array_d': np.array(['0', '1', '2']),
- 'array_e': np.array([0.0, 0.1, 0.2]),
- 'tar_same': True,
- },
- {
- 'testcase_name': 'wrong int',
- 'array_c': np.array([1, 1, 2]),
- 'array_d': np.array(['0', '1', '2']),
- 'array_e': np.array([0.0, 0.1, 0.2]),
- 'tar_same': False,
- },
- {
- 'testcase_name': 'wrong string',
- 'array_c': np.array([0, 1, 2]),
- 'array_d': np.array(['a', '1', '2']),
- 'array_e': np.array([0.0, 0.1, 0.2]),
- 'tar_same': False,
- },
- {
- 'testcase_name': 'wrong float',
- 'array_c': np.array([0, 1, 2]),
- 'array_d': np.array(['0', '1', '2']),
- 'array_e': np.array([1.0, 0.1, 0.2]),
- 'tar_same': False,
- },
- ])
- def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
-
- # yapf: disable
- ref = {
- 'a': 1,
- 'b': [
- {'c': np.array([0, 1, 2])},
- {'d': np.array(['0', '1', '2'])},
- {'e': np.array([0.0, 0.1, 0.2])}
- ],
- }
- tar = {
- 'a': 1,
- 'b': [
- {'c': array_c},
- {'d': array_d},
- {'e': array_e}
- ],
- }
- # yapf: enable
- same, _ = tf_test_utils.Trace._check_same(ref, tar, rtol=1e-6, atol=1e-6)
- self.assertEqual(tar_same, same)
-
- def test_trace_inputs_and_outputs(self):
-
- def trace_function(module):
- # No inputs or outputs
- module.increment()
- # Only inputs
- module.increment_by(np.array([81.], dtype=np.float32))
- # Only outputs
- module.get_count()
-
- module = tf_utils.TfCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('tf'))
- trace = tf_test_utils.Trace(module, trace_function)
- trace_function(tf_test_utils.TracedModule(module, trace))
-
- self.assertIsInstance(trace.calls[0].inputs, tuple)
- self.assertEmpty(trace.calls[0].inputs)
- self.assertIsInstance(trace.calls[0].outputs, tuple)
- self.assertEmpty(trace.calls[0].outputs)
-
- self.assertAllClose(trace.calls[1].inputs[0], [81.])
- self.assertAllClose(trace.calls[2].outputs[0], [82.])
-
- def test_nonmatching_methods(self):
-
- def tf_function(module):
- module.increment()
- module.increment()
-
- def vmla_function(module):
- module.increment()
- module.decrement()
-
- tf_module = tf_utils.TfCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('tf'))
- tf_trace = tf_test_utils.Trace(tf_module, tf_function)
- tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
-
- vmla_module = tf_utils.IreeCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
- vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
- vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
-
- with self.assertRaises(ValueError):
- tf_test_utils.Trace.compare_traces(tf_trace, vmla_trace)
-
- def test_nonmatching_inputs(self):
-
- def tf_function(module):
- module.increment_by(np.array([42.], dtype=np.float32))
-
- def vmla_function(module):
- module.increment_by(np.array([22.], dtype=np.float32))
-
- tf_module = tf_utils.TfCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('tf'))
- tf_trace = tf_test_utils.Trace(tf_module, tf_function)
- tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
-
- vmla_module = tf_utils.IreeCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
- vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
- vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
-
- same, error_messages = tf_test_utils.Trace.compare_traces(
- tf_trace, vmla_trace)
- self.assertFalse(same)
-
- def test_trace_serialize_and_load(self):
-
- def trace_function(module):
- module.increment()
- module.increment_by(np.array([81.], dtype=np.float32))
- module.increment_by_max(np.array([81], dtype=np.float32),
- np.array([92], dtype=np.float32))
- module.get_count()
-
- module = tf_utils.IreeCompiledModule.create_from_class(
- StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
- trace = tf_test_utils.Trace(module, trace_function)
- trace_function(tf_test_utils.TracedModule(module, trace))
-
- with tempfile.TemporaryDirectory() as artifacts_dir:
- trace_function_dir = tf_test_utils._get_trace_dir(artifacts_dir, trace)
- trace.serialize(trace_function_dir)
- self.assertTrue(
- os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
- loaded_trace = tf_test_utils.Trace.load(trace_function_dir)
-
- # Check all calls match.
- self.assertTrue(tf_test_utils.Trace.compare_traces(trace, loaded_trace))
-
- # Check all other metadata match.
- self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys())
- for key in trace.__dict__.keys():
- if key != 'calls':
- self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
+class TestUtilsTests(tf.test.TestCase):
def test_tf_function_unittet(self):
@@ -260,9 +77,9 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modules = tf_test_utils.compile_tf_module(
- TfFunctionUnittestModule)
+ TfFunctionUnitTestModule)
- TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnittestModule)
+ TfFunctionUnittestTest.generate_unit_tests(TfFunctionUnitTestModule)
test_case = TfFunctionUnittestTest()
self.assertTrue(hasattr(test_case, 'test_no_args'))
self.assertTrue(hasattr(test_case, 'test_default_uniform_inputs'))
@@ -270,7 +87,7 @@
self.assertTrue(hasattr(test_case, 'test_custom_input_args'))
self.assertTrue(hasattr(test_case, 'test_high_tolerance'))
- # Will throw an error if 'atol' and 'rtol' are not set.
+ # Will throw an error if 'atol' is not set.
test_case = TfFunctionUnittestTest()
test_case.test_high_tolerance()
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index aa94744..1af3cca 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -14,21 +14,18 @@
# limitations under the License.
"""Utilities interop with TensorFlow."""
-# pylint: disable=protected-access
-
-import collections
import os
import random
import re
-import tempfile
-from typing import Any, Callable, Dict, Sequence, Set, Tuple, Type, Union
+from typing import Any, Callable, Sequence, Set, Tuple, Union
from absl import logging
import numpy as np
-from pyiree import rt
-from pyiree.tf import compiler
import tensorflow.compat.v2 as tf
+InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
+ np.ndarray]
+
def set_random_seed(seed: int = 0) -> None:
"""Set random seed for tf, np and random."""
@@ -37,10 +34,6 @@
np.random.seed(seed)
-InputGeneratorType = Callable[[Sequence[int], Union[tf.DType, np.dtype]],
- np.ndarray]
-
-
def uniform(shape: Sequence[int],
dtype: Union[tf.DType, np.dtype] = np.float32,
low: float = -1.0,
@@ -94,8 +87,32 @@
return apply_function(spec, generate)
+def normalize_numpy(result: np.ndarray):
+ """Normalizes TF and TFLite's outputs to match IREE's"""
+ if np.isscalar(result):
+ result = np.array(result)
+ if result.dtype == np.bool:
+ # IREE interprets bools as int8s, so we modify this for comparison.
+ result = result.astype(dtype=np.int8)
+ return result
+
+
+def convert_to_numpy(values: Any) -> Any:
+ """Converts any tf.Tensor in values to numpy."""
+
+ def _convert_to_numpy(tensor: Any) -> Any:
+ if not isinstance(tensor, tf.Tensor):
+ return tensor
+ return normalize_numpy(tensor.numpy())
+
+ return apply_function(values, _convert_to_numpy)
+
+
def to_mlir_type(dtype: np.dtype) -> str:
"""Returns a string that denotes the type 'dtype' in MLIR style."""
+ if not isinstance(dtype, np.dtype):
+ # Handle np.int8 _not_ being a dtype.
+ dtype = np.dtype(dtype)
bits = dtype.itemsize * 8
if np.issubdtype(dtype, np.integer):
return f"i{bits}"
@@ -192,953 +209,90 @@
return tf.TensorSpec([None] * len(spec.shape), spec.dtype)
-def _setup_mlir_crash_reproducer(
- function: Any, # pytype doesn't support arbitrary Callable[*args, **kwargs]
- artifacts_dir: str,
- backend_id: str,
-) -> Any: # Callable[Any, Any]
- """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
+def check_same(ref: Any, tar: Any, rtol: float,
+ atol: float) -> Tuple[bool, Union[str, None]]:
+ """Checks that ref and tar have identical datastructures and values."""
+ # Check for matching types.
+ if not isinstance(tar, type(ref)):
+ error = ("Expected ref and tar to have the same type but got "
+ f"'{type(ref)}' and '{type(tar)}'")
+ logging.error(error)
+ return False, error
- Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+ if ref is None:
+ # Nothing to compare (e.g. the called method had no outputs).
+ return True, None
- Args:
- function: The callable to decorate.
- artifacts_dir: The directory to write the reproducer to.
- backend_id: The unique backend name to use when writting the reproducer.
+ # Recursive check for dicts.
+ if isinstance(ref, dict):
+ if ref.keys() != tar.keys():
+ error = ("Expected ref and tar to have the same keys, but got "
+ f"'{ref.keys()}' and '{tar.keys()}'")
+ logging.error(error)
+ return False, error
+ # Check that all of the dictionaries' values are the same.
+ for key in ref:
+ same, error = check_same(ref[key], tar[key], rtol, atol)
+ if not same:
+ return same, error
- Returns:
- A function with the same API as the passed function.
- """
+ # Recursive check for iterables.
+ elif isinstance(ref, list) or isinstance(ref, tuple):
+ if len(ref) != len(tar):
+ error = ("Expected ref and tar to have the same length, but got "
+ f"{len(ref)} and {len(tar)}")
+ logging.error(error)
+ return False, error
+ # Check that all of the iterables' values are the same.
+ for i in range(len(ref)):
+ same, error = check_same(ref[i], tar[i], rtol, atol)
+ if not same:
+ return same, error
- def decorator(*args, **kwargs):
- # Set up a crash reproducer for debugging.
- if artifacts_dir is not None:
- compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backend_id}.mlir")
- try:
- results = function(*args, **kwargs)
- except Exception: # pylint: disable=broad-except
- # Disable the crash reproducer (to avoid inadvertently overwriting it).
- if artifacts_dir is not None:
- compiler.Context.default_crash_reproducer_path = None
- raise
- return results
+ # Base check for numpy arrays.
+ elif isinstance(ref, np.ndarray):
+ if ref.dtype != tar.dtype:
+ error = ("Expected ref and tar to have the same dtype, but got "
+ f"'{ref.dtype}' and '{tar.dtype}'")
+ logging.error(error)
+ return False, error
+ if ref.size == tar.size == 0:
+ return True, None
- return decorator
-
-
-def _incrementally_lower_compiler_module(
- compiler_module: compiler.Module,
- backend_info: "BackendInfo",
- artifacts_dir: str,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Lowers a MLIR compiler module incrementally and saves its outputs.
-
- If artifacts_dir is provided then the following artifacts will be saved:
- tf_input.mlir:
- MLIR for the module in TF's input dialect.
- iree_input.mlir:
- The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
- backend_id/compiled.vmfb:
- A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
-
- Args:
- compiler_module: A compiler.Module to lower.
- backend_info: BackendInfo with the details for lowering compiler_module to
- IREE.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- if artifacts_dir is not None:
- os.makedirs(artifacts_dir, exist_ok=True)
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- # Manually run the passes that tf_module_to_compiler_module usually would.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- compiled_module = compiler_module.compile(
- target_backends=backend_info.compiler_targets)
-
- compiled_path = None
- if artifacts_dir is not None:
- backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
- os.makedirs(backend_dir, exist_ok=True)
- compiled_path = os.path.join(backend_dir, "compiled.vmfb")
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
- return compiled_module, compiled_path
-
-
-def _incrementally_compile_tf_module(
- module: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None,
-) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Compile a TensorFlow tf.Module and optionally save compilation artifacts.
-
- The module blob this creates is not callable. See IreeCompiledModule for an
- API that returns a module that can be called without any further steps.
-
- See _incrementally_lower_compiler_module's docstring for details about which
- artifacts will be saved.
-
- Args:
- module: A tf.Module.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
-
- Returns:
- A compiled IREE module blob and the path to the compiled VM FlatBuffer if
- artifacts_dir is provided.
- """
-
- def _compile_module(module, backend_info, exported_names, artifacts_dir):
- compiler_module = compiler.tf_module_to_compiler_module(module,
- exported_names,
- pass_pipeline=())
- return _incrementally_lower_compiler_module(compiler_module, backend_info,
- artifacts_dir)
-
- _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.backend_id)
- return _compile_module(module, backend_info, exported_names, artifacts_dir)
-
-
-def _incrementally_compile_tf_signature_def_saved_model(
- saved_model_dir: str, saved_model_tags: Set[str],
- backend_info: "BackendInfo", exported_name: str, artifacts_dir: str):
- """Compile a SignatureDef SavedModel and optionally save compilation artifacts.
-
- The module blob this creates is not callable. See IreeCompiledModule for an
- API that returns a module that can be called without any further steps.
-
- See _incrementally_lower_compiler_module's docstring for details about which
- artifacts will be saved.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
-
- Returns:
- A compiled IREE module blob and the path to the compiled VM FlatBuffer if
- artifacts_dir is provided.
- """
-
- def _compile_module(saved_model_dir, saved_model_tags, backend_info,
- exported_name, artifacts_dir):
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_signature_def_saved_model_to_compiler_module(
- saved_model_dir, saved_model_tags, [exported_name], pass_pipeline=())
- return _incrementally_lower_compiler_module(compiler_module, backend_info,
- artifacts_dir)
-
- _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.backend_id)
- return _compile_module(saved_model_dir, saved_model_tags, backend_info,
- exported_name, artifacts_dir)
-
-
-class _FunctionWrapper(object):
-
- def __call__(self, *args, **kwargs):
- raise NotImplementedError()
-
- def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
- """Dummy function to match _IreeFunctionWrapper's API."""
- return ("",), ("",)
-
-
-class CompiledModule(object):
- """Base class for the TF and IREE compiled modules."""
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- compiled_paths: Union[Dict[str, str], None],
- ):
- """Shared base constructor – not useful on its own.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- compiled_paths: A dictionary mapping compiled method names to file paths
- corresponding to their serialized representations.
- """
- self.module_name = module_name
- self.backend_info = backend_info
- self.compiled_paths = compiled_paths
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- raise NotImplementedError()
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- raise NotImplementedError()
-
- @classmethod
- def create_from_instance(cls,
- module_instance: tf.Module,
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module instance to the target backend in backend_info.
-
- This is only implemented for IreeCompiledModule.
-
- Args:
- module_instance: The tf.Module instance to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- raise NotImplementedError()
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- raise NotImplementedError()
-
- def __getattr__(self, attr: str) -> _FunctionWrapper:
- raise NotImplementedError()
-
- def iree_serializable(self):
- return False
-
- def tflite_serializable(self):
- return False
-
-
-class _IreeFunctionWrapper(_FunctionWrapper):
- """Wraps an IREE function, making it callable."""
-
- def __init__(self, context: rt.SystemContext, f: rt.system_api.BoundFunction):
- self._context = context
- self._f = f
-
- def __call__(self, *args, **kwargs):
- return self._f(*args, **kwargs)
-
- def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
- """Get cxx serialized inputs and outputs for this function."""
- return self._f.get_serialized_values()
-
-
-class IreeCompiledModule(CompiledModule):
- """Iree compiled module."""
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- compiled_paths: Dict[str, str],
- vm_module: rt.VmModule,
- config: rt.Config,
- ):
- """Base constructor – Use one of the named constructors instead.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- compiled_paths: A dictionary mapping compiled method names to file paths
- corresponding to their serialized representations.
- vm_module: A rt.VmModule containing compilation info to wrap.
- config: A rt.Config containing compilation info to wrap.
- """
- super().__init__(module_name, backend_info, compiled_paths)
- self._vm_module = vm_module
- self._config = config
- self.reinitialize()
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- set_random_seed()
- module_instance = module_class()
- return cls.create_from_instance(module_instance, backend_info,
- exported_names, artifacts_dir)
-
- @classmethod
- def create_from_instance(cls,
- module_instance: tf.Module,
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module instance to the target backend in backend_info.
-
- Args:
- module_instance: The tf.Module instance to compile.
- backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- module_blob, compiled_path = _incrementally_compile_tf_module(
- module=module_instance,
- backend_info=backend_info,
- exported_names=exported_names,
- artifacts_dir=artifacts_dir)
- vm_module = rt.VmModule.from_flatbuffer(module_blob)
- config = rt.Config(driver_name=backend_info.driver)
-
- compiled_paths = None
- if compiled_path is not None:
- # IREE bundles every compiled method into the same compiled module.
- compiled_paths = collections.defaultdict(lambda: compiled_path)
-
- module_name = type(module_instance).__name__
-
- return cls(module_name, backend_info, compiled_paths, vm_module, config)
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- del input_names # Unused.
- del output_names # Unused.
- module_blob, compiled_path = _incrementally_compile_tf_signature_def_saved_model(
- saved_model_dir, saved_model_tags, backend_info, exported_name,
- artifacts_dir)
- vm_module = rt.VmModule.from_flatbuffer(module_blob)
- config = rt.Config(driver_name=backend_info.driver)
-
- compiled_paths = None
- if compiled_path is not None:
- # IREE bundles every compiled method into the same compiled module :)
- compiled_paths = collections.defaultdict(lambda: compiled_path)
-
- return cls(module_name, backend_info, compiled_paths, vm_module, config)
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- # set_random_seed is not needed here because the model_class.__init__ is not
- # called.
- self._context = rt.SystemContext(modules=[self._vm_module],
- config=self._config)
-
- def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
- # Try to resolve it as a function.
- m = self._context.modules[self._vm_module.name]
- f = m[attr]
- return _IreeFunctionWrapper(self._context, f)
-
- def iree_serializable(self) -> bool:
- return self.compiled_paths is not None
-
-
-def _normalize_numpy(result: np.ndarray):
- """Normalizes TF and TFLite's outputs to match IREE's"""
- if np.isscalar(result):
- result = np.array(result)
- if result.dtype == np.bool:
- # IREE interprets bools as int8s, so we modify this for comparison.
- result = result.astype(dtype=np.int8)
- return result
-
-
-def _convert_to_numpy(tensor: Any) -> Any:
- if not isinstance(tensor, tf.Tensor):
- return tensor
- return _normalize_numpy(tensor.numpy())
-
-
-def convert_to_numpy(values: Any) -> Any:
- """Converts any tf.Tensor in values to numpy."""
- return apply_function(values, _convert_to_numpy)
-
-
-class _TfFunctionWrapper(_FunctionWrapper):
- """Wraps a TF function, normalizing it to numpy."""
-
- def __init__(self, f: Callable[..., Any]):
- self._f = f
-
- def __call__(self, *args, **kwargs):
- # TensorFlow will auto-convert all inbound args.
- results = self._f(*args, **kwargs)
- return convert_to_numpy(results)
-
-
-def _convert_inputs_to_tensors(function):
-
- def decorator(*args, **kwargs):
- args = [tf.convert_to_tensor(arg) for arg in args]
- kwargs = {k: tf.convert_to_tensor(v) for k, v in kwargs.items()}
- return function(*args, **kwargs)
-
- return decorator
-
-
-class SignatureDefSavedModelWrapper(object):
- """Wraps a SavedModel to imitate a tf.Module with a method 'exported_name'."""
-
- def __init__(self, saved_model_dir: str, saved_model_tags: Set[str],
- exported_name: str):
- self.saved_model = tf.saved_model.load(saved_model_dir,
- tags=saved_model_tags)
- inference_func = self.saved_model.signatures[exported_name]
- inference_func = _convert_inputs_to_tensors(inference_func)
- self.__setattr__(exported_name, inference_func)
-
-
-class TfCompiledModule(CompiledModule):
- """TensorFlow 'compiled' module.
-
- This facade exists to provide a complimentary API to IreeCompiledModule and
- normalize TensorFlow's output to Numpy.
- """
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- constructor: Callable[[], tf.Module],
- exported_names: Sequence[str],
- ):
- """Base constructor – Use one of the named constructors instead.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- constructor: A callable (class or function) which returns the tf.Module
- subclass instance to wrap.
- exported_names: an optional iterable of strings representing which of the
- tf.Module subclass instance's functions should be callable. If
- exported_names is empty then all functions will be callable.
- """
- super().__init__(module_name, backend_info, compiled_paths=None)
- self._constructor = constructor
- self._exported_names = exported_names
- self.reinitialize()
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- module_name = module_class.__name__
- constructor = module_class
- return cls(module_name, backend_info, constructor, exported_names)
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- constructor = lambda: SignatureDefSavedModelWrapper(
- saved_model_dir, saved_model_tags, exported_name)
- return cls(module_name, backend_info, constructor, [exported_name])
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- set_random_seed()
- self._tf_module = self._constructor()
-
- def __getattr__(self, attr: str) -> _TfFunctionWrapper:
- # Try to resolve it as a function.
- exported = not self._exported_names or attr in self._exported_names
- if not hasattr(self._tf_module, attr) or not exported:
- raise AttributeError(f"The TensorFlow module does not have attr '{attr}'")
- f = getattr(self._tf_module, attr)
- if not f or not hasattr(f, "__call__"):
- raise AttributeError(
- f"The TensorFlow module does not have a callable attr '{attr}'")
- return _TfFunctionWrapper(f)
-
-
-def _get_non_inhereted_function_names(cls):
- """Gets all methods that cls has that its parents don't have."""
- names = set(dir(cls))
- for parent in cls.__bases__:
- names -= set(dir(parent))
- return list(names)
-
-
-def _get_concrete_functions(module_class: Type[tf.Module],
- exported_names: Sequence[str] = ()):
- """Get concrete functions from non-inherited methods or exported_names."""
- if not len(exported_names):
- # Get all method names on 'module_class' that aren't on 'tf.Module'.
- exported_names = _get_non_inhereted_function_names(module_class)
- instance = module_class()
- functions = []
- for name in exported_names:
- functions.append(getattr(instance, name).get_concrete_function())
- return functions, exported_names, instance
-
-
-def tf_module_to_tflite_module_bytes(
- module_class: Type[tf.Module], exported_names: Sequence[str] = ()
-) -> Dict[str, bytes]:
- """Compiles a tf.Module's methods with TFLite.
-
- Args:
- module_class: A tf.Module subclass to compile with TFLite.
- exported_names: an optional iterable of strings representing which of the
- module_class's functions should be compiled. If exported_names is empty
- then all functions will be compiled.
-
- Returns:
- A dict mapping method names to compiled TFLite module bytes.
- """
- tflite_modules = []
- methods, method_names, instance = _get_concrete_functions(
- module_class, exported_names)
- failed_methods = []
- for method, method_name in zip(methods, method_names):
- logging.info("Attempting to convert '%s' to tflite...", method_name)
- try:
- converter = tf.lite.TFLiteConverter.from_concrete_functions([method])
- logging.info("...converted '%s' to tflite.", method_name)
- tflite_modules.append(converter.convert())
- except Exception as e:
- logging.error("Failed to convert '%s' to tflite.", method_name)
- logging.error("TFLite excpetion: %s", e)
- failed_methods.append(method_name)
-
- if failed_methods:
- raise RuntimeError(
- f"Failed to convert the following methods to tflite: {failed_methods}")
-
- # Keep variables alive until TFLite has done the conversion; ConcreteFunctions
- # themselves only keep weak references to variables.
- del instance
- return dict(zip(method_names, tflite_modules))
-
-
-def tf_signature_def_saved_model_to_tflite_module_bytes(
- saved_model_dir: str,
- saved_model_tags: Set[str],
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
-) -> Dict[str, bytes]:
- """Compiles a SignatureDef SavedModel signature with TFLite.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
-
- Returns:
- A dict mapping the signature name to the compiled TFLite module bytes.
- """
- converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
- saved_model_dir,
- tag_set=saved_model_tags,
- signature_key=exported_name,
- input_arrays=input_names,
- output_arrays=output_names)
- tflite_module = converter.convert()
- return dict([[exported_name, tflite_module]])
-
-
-def tflite_module_bytes_to_tflite_interpreters(
- tflite_module_bytes: Dict[str, bytes],
- artifacts_dir: str = None
-) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
- """Compile a dict of TFLite compiled bytes to TFLite interpreters.
-
- Args:
- tflite_module_bytes: A dict mapping method names to compiled TFLite byte
- strings.
- artifacts_dir: an optional path to save compilation artifacts to.
-
- Returns:
- A dictionary mapping method names to TFLite interpreters and a dictionary
- mapping method names to compiled tflite graph paths (or None if
- artifacts_dir is None).
- """
- interpreters = dict()
- compiled_paths = None
- if artifacts_dir is not None:
- compiled_paths = dict()
-
- def _interpret_bytes(method_name: str, tflite_module: bytes, base_dir: str):
- """Save compiled TFLite module bytes and convert into an interpreter."""
- tflite_dir = os.path.join(base_dir, "tflite")
- os.makedirs(tflite_dir, exist_ok=True)
- tflite_path = os.path.join(tflite_dir, f"{method_name}.tflite")
- with open(tflite_path, "wb") as f:
- f.write(tflite_module)
-
- interpreters[method_name] = tf.lite.Interpreter(tflite_path)
- if artifacts_dir is not None:
- compiled_paths[method_name] = tflite_path
-
- # Load each of the converted methods above into tf.lite.Interpreters.
- for method_name, tflite_module in tflite_module_bytes.items():
- if artifacts_dir is None:
- with tempfile.TemporaryDirectory() as base_dir:
- _interpret_bytes(method_name, tflite_module, base_dir)
- else:
- _interpret_bytes(method_name, tflite_module, artifacts_dir)
-
- return interpreters, compiled_paths
-
-
-class _TfLiteFunctionWrapper(_FunctionWrapper):
- """Wraps a TFLite interpreter and makes it behave like a python function."""
-
- def __init__(self, interpreter: tf.lite.Interpreter,
- output_names: Sequence[str]):
- self._interpreter = interpreter
- self._output_names = output_names
-
- def __call__(self, *args,
- **kwargs) -> Union[Dict[str, Any], Tuple[Any], np.ndarray]:
- if len(args) and len(kwargs):
- raise ValueError("Passing both args and kwargs is not supported by "
- "_TfLiteFunctionWrapper")
-
- # Set up and run the function.
- self._interpreter.allocate_tensors()
-
- if len(args):
- # Specifically to get TFLite to work with keras models that take a list of
- # inputs instead of a sequence of args as their inputs, because it decides
- # to change the input signature but it still technically works if you
- # ignore that it does that.
- if len(args) == 1 and isinstance(args[0], list):
- args = args[0]
-
- for arg, detail in zip(args, self._interpreter.get_input_details()):
- self._interpreter.set_tensor(detail["index"], arg)
- else:
- for detail in self._interpreter.get_input_details():
- self._interpreter.set_tensor(detail["index"], kwargs[detail["name"]])
-
- self._interpreter.invoke()
-
- # Extract the outputs from the TFLite interpreter.
- outputs = []
- for detail in self._interpreter.get_output_details():
- value = _normalize_numpy(self._interpreter.get_tensor(detail["index"]))
- if self._output_names is not None:
- name = detail["name"]
- if name not in self._output_names:
- raise ValueError(f"Expected '{name}' to be in {self._output_names}")
- outputs.append([detail["name"], value])
+ if np.issubdtype(ref.dtype, np.floating):
+ same = np.allclose(ref, tar, rtol=rtol, atol=atol, equal_nan=True)
+ abs_diff = np.max(np.abs(ref - tar))
+ rel_diff = np.max(np.abs(ref - tar) / np.max(np.abs(tar)))
+ diff_string = (f"Max abs diff: {abs_diff:.2e}, atol: {atol:.2e}, "
+ f"max relative diff: {rel_diff:.2e}, rtol: {rtol:.2e}")
+ if not same:
+ error = ("Floating point difference between ref and tar was too "
+ f"large. {diff_string}")
+ logging.error(error)
else:
- outputs.append(value)
-
- # Process them to match the output of the tf.Module.
- if self._output_names is not None:
- return dict(outputs)
+ error = None
+ logging.info(
+ "Floating point difference between ref and tar was within "
+ "tolerance. %s", diff_string)
+ return same, error
+ elif np.issubdtype(ref.dtype, np.integer):
+ same = np.array_equal(ref, tar)
+ if not same:
+ abs_diff = np.max(np.abs(ref - tar))
+ error = ("Expected array equality between ref and tar, but got "
+ f"a max elementwise difference of {abs_diff}")
+ logging.error(error)
+ else:
+ error = None
+ return same, error
else:
- if len(outputs) == 1:
- return outputs[0]
- return tuple(outputs)
+ return np.array_equal(ref, tar), None
+ # Base check for native number types.
+ elif isinstance(ref, (int, float)):
+ return ref == tar, None
-class TfLiteCompiledModule(CompiledModule):
- """Compiles a tf.Module with TFLite and allows it to be called."""
-
- def __init__(
- self,
- module_name: str,
- backend_info: "BackendInfo",
- compiled_paths: Dict[str, str],
- interpreters: Dict[str, tf.lite.Interpreter],
- output_names: Sequence[str] = None,
- ):
- """Base constructor – Use one of the named constructors instead.
-
- Args:
- module_name: A name for this compiled module. In most cases this will be
- the name of the tf.Module subclass or instance that is compiled.
- backend_info: BackendInfo with the details about compiling this module.
- compiled_paths: A dictionary mapping compiled method names to file paths
- corresponding to their serialized representations.
- interpreters: A dict of tf.lite.Interpreters to make callable.
- """
- super().__init__(module_name, backend_info, compiled_paths)
- self._interpreters = interpreters
- self._output_names = output_names
-
- @classmethod
- def create_from_class(cls,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module subclass to the target backend in backend_info.
-
- Args:
- module_class: The tf.Module subclass to compile.
- backend_info: BackendInfo with the details for compiling this module.
- exported_names: Optional sequence representing the exported names to keep.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- set_random_seed()
- tflite_module_bytes = tf_module_to_tflite_module_bytes(
- module_class, exported_names)
- interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
- tflite_module_bytes, artifacts_dir)
- module_name = module_class.__name__
- return cls(module_name, backend_info, compiled_paths, interpreters)
-
- @classmethod
- def create_from_signature_def_saved_model(cls,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- backend_info: "BackendInfo",
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None):
- """Compile a SignatureDef SavedModel to the target backend in backend_info.
-
- Args:
- saved_model_dir: Directory of the saved model.
- saved_model_tags: Optional set of tags to use when loading the model.
- module_name: A name for this compiled module.
- backend_info: BackendInfo with the details for compiling the saved model.
- exported_name: A str representing the signature on the saved model to
- compile.
- input_names: A sequence of kwargs to feed to the saved model.
- output_names: A sequence of named outputs to extract from the saved model.
- artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved. No compilation artifacts will be saved if this is not
- provided.
- """
- tflite_module_bytes = tf_signature_def_saved_model_to_tflite_module_bytes(
- saved_model_dir, saved_model_tags, exported_name, input_names,
- output_names)
- interpreters, compiled_paths = tflite_module_bytes_to_tflite_interpreters(
- tflite_module_bytes, artifacts_dir)
- return cls(module_name, backend_info, compiled_paths, interpreters,
- output_names)
-
- def reinitialize(self):
- """Reinitializes all stateful variables."""
- # This is a noop because TFLite (mostly) doesn't support stateful modules.
- pass
-
- def __getattr__(self, attr: str) -> _TfLiteFunctionWrapper:
- # Try to resolve it as an interpreter.
- if not attr in self._interpreters:
- raise AttributeError(
- f"The TFLite module does not have an interpreter for '{attr}'")
- return _TfLiteFunctionWrapper(self._interpreters[attr], self._output_names)
-
- def tflite_serializable(self) -> bool:
- return self.compiled_paths is not None
-
-
-class BackendInfo:
- """Contains information for compiling the specified backend."""
-
- _name_to_info = {
- "tf": {
- "compiled_module_class": TfCompiledModule,
- "driver": None,
- "compiler_targets": None,
- },
- "tflite": {
- "compiled_module_class": TfLiteCompiledModule,
- "driver": None,
- "compiler_targets": None,
- },
- "iree_vmla": {
- "compiled_module_class": IreeCompiledModule,
- "driver": "vmla",
- "compiler_targets": ["vmla"]
- },
- "iree_vulkan": {
- "compiled_module_class": IreeCompiledModule,
- "driver": "vulkan",
- "compiler_targets": ["vulkan-*"]
- },
- }
-
- def __init__(self, backend_name: str, backend_id: str = None):
- """Creates a BackendInfo with the compilation details for backend_name.
-
- Args:
- backend_name: a str specifying which backend to use. Should be one of
- 'tf', 'iree_vmla', 'iree_vulkan'.
- backend_id: an optional str specifying what name to use when saving
- compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
-
- Raises:
- KeyError: if backend_name is not one of
- ['tf', 'iree_vmla', 'iree_vulkan'].
- ValueError: if backend_id doesn't start with backend_name.
- """
- if backend_name not in self._name_to_info:
- raise KeyError(
- "Expected backend_name to be one of "
- f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
- if backend_id is not None and not backend_id.startswith(backend_name):
- raise ValueError(f"Expected backend_id to start with '{backend_name}' "
- f"but got '{backend_id}'.")
-
- self.backend_name = backend_name
- self.backend_id = backend_name if backend_id is None else backend_id
-
- info = self._name_to_info[backend_name]
- self._compiled_module_class = info["compiled_module_class"]
- self.driver = info["driver"]
- self.compiler_targets = info["compiler_targets"]
-
- def compile_from_class(self,
- module_class: Type[tf.Module],
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None) -> CompiledModule:
- """Creates a 'CompiledModule' for this backend."""
- return self._compiled_module_class.create_from_class(
- module_class, self, exported_names, artifacts_dir)
-
- def compile_signature_def_saved_model(
- self,
- saved_model_dir: str,
- saved_model_tags: Set[str],
- module_name: str,
- exported_name: str,
- input_names: Sequence[str],
- output_names: Sequence[str],
- artifacts_dir: str = None) -> CompiledModule:
- return self._compiled_module_class.create_from_signature_def_saved_model(
- saved_model_dir, saved_model_tags, module_name, self, exported_name,
- input_names, output_names, artifacts_dir)
-
- @classmethod
- def get_all_backends(cls) -> Sequence["BackendInfo"]:
- """Returns a list of all BackendInfo configurations."""
- return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
+ # If outputs end up here then an extra branch for that type should be added.
+ else:
+ raise TypeError(f"Encountered results with unexpected type {type(ref)}")
+ return True, None
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 9934460..deef8d9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -14,125 +14,87 @@
# limitations under the License.
"""Tests for pyiree.tf.support.tf_utils."""
-import os
-import tempfile
-
-from absl import logging
from absl.testing import parameterized
import numpy as np
from pyiree.tf.support import tf_utils
import tensorflow as tf
-class ConstantModule(tf.Module):
-
- @tf.function(input_signature=[])
- def meaning(self):
- return tf.constant([42.])
-
-
-class StatefulCountingModule(tf.Module):
-
- def __init__(self):
- self.count = tf.Variable([0.])
-
- @tf.function(input_signature=[])
- def increment(self):
- self.count.assign_add(tf.constant([1.]))
-
- @tf.function(input_signature=[])
- def get_count(self):
- return self.count
-
-
-class RandomInitModule(tf.Module):
-
- def __init__(self):
- self.value = tf.Variable(tf.random.uniform([1]))
-
- @tf.function(input_signature=[])
- def get(self):
- return self.value
-
-
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
- def test_artifact_saving(self):
- backend_info = tf_utils.BackendInfo('iree_vmla')
- with tempfile.TemporaryDirectory() as artifacts_dir:
- tf_module = ConstantModule()
- iree_compiled_module, compiled_path = (
- tf_utils._incrementally_compile_tf_module(
- tf_module, backend_info=backend_info,
- artifacts_dir=artifacts_dir))
+ @parameterized.named_parameters([('int8_to_i8', np.int8, 'i8'),
+ ('int32_to_i32', np.int32, 'i32'),
+ ('float32_to_f32', np.float32, 'f32'),
+ ('float64_to_f64', np.float64, 'f64')])
+ def test_to_mlir_type(self, numpy_type, mlir_type):
+ self.assertEqual(tf_utils.to_mlir_type(numpy_type), mlir_type)
- artifacts_to_check = [
- 'tf_input.mlir',
- 'iree_input.mlir',
- compiled_path,
- ]
- for artifact in artifacts_to_check:
- artifact_path = os.path.join(artifacts_dir, artifact)
- logging.info('Checking path: %s', artifact_path)
- self.assertTrue(os.path.exists(artifact_path))
+ @parameterized.named_parameters([
+ ('single_i32', [np.array([1, 2], dtype=np.int32)], '2xi32=1 2'),
+ ('single_f32', [np.array([1, 2], dtype=np.float32)], '2xf32=1.0 2.0'),
+ ])
+ def test_save_input_values(self, inputs, inputs_str):
+ self.assertEqual(tf_utils.save_input_values(inputs), inputs_str)
+
+ def test_apply_function(self):
+ inputs = [1, [2, 3], (4, 5), {'6': 6, '78': [7, 8]}]
+ expected = [0, [1, 2], (3, 4), {'6': 5, '78': [6, 7]}]
+ result = tf_utils.apply_function(inputs, lambda x: x - 1)
+ self.assertEqual(result, expected)
+ self.assertNotEqual(inputs, expected)
@parameterized.named_parameters([
{
- 'testcase_name': 'tensorflow',
- 'backend_name': 'tf',
+ 'testcase_name': 'all the same',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tar_same': True,
},
{
- 'testcase_name': 'vmla',
- 'backend_name': 'iree_vmla',
+ 'testcase_name': 'wrong int',
+ 'array_c': np.array([1, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tar_same': False,
+ },
+ {
+ 'testcase_name': 'wrong string',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['a', '1', '2']),
+ 'array_e': np.array([0.0, 0.1, 0.2]),
+ 'tar_same': False,
+ },
+ {
+ 'testcase_name': 'wrong float',
+ 'array_c': np.array([0, 1, 2]),
+ 'array_d': np.array(['0', '1', '2']),
+ 'array_e': np.array([1.0, 0.1, 0.2]),
+ 'tar_same': False,
},
])
- def test_unaltered_state(self, backend_name):
- backend_info = tf_utils.BackendInfo(backend_name)
- module = backend_info.compile_from_class(StatefulCountingModule)
+ def test_recursive_check_same(self, array_c, array_d, array_e, tar_same):
- # Test that incrementing works properly.
- self.assertEqual([0.], module.get_count())
- module.increment()
- self.assertEqual([1.], module.get_count())
-
- module.reinitialize()
- # Test reinitialization.
- self.assertEqual([0.], module.get_count())
-
- def test_to_mlir_type(self):
- self.assertEqual('i8', tf_utils.to_mlir_type(np.dtype('int8')))
- self.assertEqual('i32', tf_utils.to_mlir_type(np.dtype('int32')))
- self.assertEqual('f32', tf_utils.to_mlir_type(np.dtype('float32')))
- self.assertEqual('f64', tf_utils.to_mlir_type(np.dtype('float64')))
-
- def test_save_input_values(self):
- inputs = [np.array([1, 2], dtype=np.int32)]
- self.assertEqual('2xi32=1 2', tf_utils.save_input_values(inputs))
- inputs = [np.array([1, 2], dtype=np.float32)]
- self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
-
- @parameterized.named_parameters([
- {
- 'testcase_name': 'tensorflow',
- 'backend_name': 'tf',
- },
- {
- 'testcase_name': 'vmla',
- 'backend_name': 'iree_vmla',
- },
- ])
- def test_random_initialization(self, backend_name):
- backend_info = tf_utils.BackendInfo(backend_name)
-
- # Test compilation is the same.
- module_1 = backend_info.compile_from_class(RandomInitModule)
- module_2 = backend_info.compile_from_class(RandomInitModule)
- self.assertAllEqual(module_1.get(), module_2.get())
-
- # Test reinitialization is the same.
- old_value = module_1.get()
- module_1.reinitialize()
- self.assertAllEqual(old_value, module_1.get())
+ # yapf: disable
+ ref = {
+ 'a': 1,
+ 'b': [
+ {'c': np.array([0, 1, 2])},
+ {'d': np.array(['0', '1', '2'])},
+ {'e': np.array([0.0, 0.1, 0.2])}
+ ],
+ }
+ tar = {
+ 'a': 1,
+ 'b': [
+ {'c': array_c},
+ {'d': array_d},
+ {'e': array_e}
+ ],
+ }
+ # yapf: enable
+ same, _ = tf_utils.check_same(ref, tar, rtol=1e-6, atol=1e-6)
+ self.assertEqual(tar_same, same)
if __name__ == '__main__':
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py
new file mode 100644
index 0000000..1a0c789
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils.py
@@ -0,0 +1,421 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for tracing tf.function inputs and outputs."""
+
+# This file uses the following abbreviations:
+# ref: reference – for the reference CompiledModule
+# tar: target - for one of the target CompiledModules
+
+import copy
+import glob
+import inspect
+import os
+import pickle
+import sys
+import textwrap
+from typing import Any, Callable, Dict, Sequence, Tuple, Union
+
+from absl import logging
+import numpy as np
+from pyiree.tf.support import module_utils
+from pyiree.tf.support import tf_utils
+import tensorflow.compat.v2 as tf
+
+NUMPY_LINEWIDTH = 120
+INDENT = " " * 2
+
+
+def _zfill_width(length: int) -> Union[int, None]:
+ return int(np.ceil(np.log10(length))) if length else None
+
+
+def get_trace_dir(artifacts_dir: str, trace: "Trace") -> str:
+ trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
+ trace.function_name)
+ os.makedirs(trace_dir, exist_ok=True)
+ return trace_dir
+
+
+class ModuleCall:
+
+ def __init__(self,
+ method: str,
+ inputs: Tuple[Any],
+ outputs: Tuple[Any],
+ serialized_inputs: Tuple[str],
+ serialized_outputs: Tuple[str],
+ rtol: float = 1e-6,
+ atol: float = 1e-6):
+ """Records the details of a call to a CompiledModule."""
+ self.method = method
+
+ # Deepcopy to safegard against mutation.
+ self.inputs = copy.deepcopy(inputs)
+ if outputs is not None:
+ outputs = copy.deepcopy(outputs)
+ else:
+ outputs = tuple()
+ self.outputs = outputs if isinstance(outputs, tuple) else (outputs,)
+
+ self.serialized_inputs = serialized_inputs
+ self.serialized_outputs = serialized_outputs
+
+ self.rtol = rtol
+ self.atol = atol
+
+ def get_tolerances(self) -> Tuple[float, float]:
+ """Gets the floating point tolerances associated with this call."""
+ return self.rtol, self.atol
+
+ def _get_shape_and_dtype(self, value: Any) -> str:
+ if isinstance(value, np.ndarray):
+ return tf_utils.get_shape_and_dtype(value, allow_non_mlir_dtype=True)
+ else:
+ return str(type(value))
+
+ def __str__(self):
+ prior_printoptions = np.get_printoptions()
+ np.set_printoptions(linewidth=NUMPY_LINEWIDTH)
+
+ header = f"Method: {self.method}"
+ inputs = "\n".join(
+ [textwrap.indent(str(value), INDENT) for value in self.inputs])
+ input_shapes = ", ".join(
+ self._get_shape_and_dtype(value) for value in self.inputs)
+
+ outputs = "\n".join(
+ [textwrap.indent(str(value), INDENT) for value in self.outputs])
+ output_shapes = ", ".join(
+ self._get_shape_and_dtype(value) for value in self.outputs)
+
+ tolerances = textwrap.indent(f"rtol={self.rtol}, atol={self.atol}", INDENT)
+ body = (f"Inputs: {input_shapes}\n{inputs}\n"
+ f"Outputs: {output_shapes}\n{outputs}"
+ f"\nTolerances:\n{tolerances}")
+ result = f"{header}\n{textwrap.indent(body, INDENT)}"
+
+ np.set_printoptions(**prior_printoptions)
+ return result
+
+ def serialize(self, call_dir: str) -> None:
+ """Stores a serialized copy of this call.
+
+ Can be loaded via ModuleCall.load(call_dir)
+
+ Args:
+ call_dir: str, the path to the directory to serialize this call to.
+ """
+ os.makedirs(call_dir, exist_ok=True)
+
+ metadata = {
+ "method": self.method,
+ "serialized_inputs": self.serialized_inputs,
+ "serialized_outputs": self.serialized_outputs,
+ "rtol": self.rtol,
+ "atol": self.atol
+ }
+ with open(os.path.join(call_dir, "metadata.pkl"), "wb") as f:
+ pickle.dump(metadata, f)
+
+ width = _zfill_width(len(self.inputs))
+ for i, value in enumerate(self.inputs):
+ path = os.path.join(call_dir, f"input_{str(i).zfill(width)}.pkl")
+ with open(path, "wb") as f:
+ pickle.dump(value, f)
+
+ width = _zfill_width(len(self.outputs))
+ for i, value in enumerate(self.outputs):
+ path = os.path.join(call_dir, f"output_{str(i).zfill(width)}.pkl")
+ with open(path, "wb") as f:
+ pickle.dump(value, f)
+
+ @staticmethod
+ def load(call_dir: str) -> "ModuleCall":
+ """Loads and returns a trace serialized with ModuleCall.serialize."""
+ with open(os.path.join(call_dir, "metadata.pkl"), "rb") as f:
+ kwargs = pickle.load(f)
+
+ for result_type in ["input", "output"]:
+ key = f"{result_type}s" # inputs or outputs
+ kwargs[key] = []
+
+ files = glob.glob(os.path.join(call_dir, f"{result_type}_*.pkl"))
+ for filename in sorted(files):
+ with open(filename, "rb") as f:
+ kwargs[key].append(pickle.load(f))
+
+ # Convert to tuple to match python's return type for multiple results.
+ kwargs[key] = tuple(kwargs[key])
+
+ return ModuleCall(**kwargs)
+
+
+class Trace:
+ """Stores the inputs and outputs of a series of calls to a module."""
+
+ def __init__(self,
+ module: Union[module_utils.CompiledModule, None],
+ function: Union[Callable[["TracedModule"], None], None],
+ _load_dict: Dict[str, Any] = None):
+ """Extracts metadata from module and function and initializes.
+
+ Example usage:
+ def forward_pass(...):
+ ...
+ module = IreeCompiledModule(...)
+ trace = Trace(module, forward_pass)
+ forward_pass(TracedModule(module, trace))
+
+ Args:
+ module: the module who's outputs this trace will record.
+ function: the function that module will be traced on.
+ _load_dict: used internally
+ """
+ if _load_dict is None:
+ # Extract metadata from module and function.
+ self.module_name = module.module_name
+ self.compiled_paths = module.compiled_paths
+ self.backend_name = module.backend_info.backend_name
+ self.backend_id = module.backend_info.backend_id
+ self.backend_driver = module.backend_info.driver
+ self.iree_serializable = module.iree_serializable()
+ self.tflite_serializable = module.tflite_serializable()
+ self.function_name = function.__name__
+ self.function_sourcefile = inspect.getsourcefile(function)
+ source, start_line = inspect.getsourcelines(function)
+ self.function_line_numbers = (start_line, start_line + len(source))
+ self.function_source = "".join(source)
+
+ self.calls = []
+ else:
+ self.module_name = _load_dict["module_name"]
+ self.compiled_paths = _load_dict["compiled_paths"]
+ self.backend_name = _load_dict["backend_name"]
+ self.backend_id = _load_dict["backend_id"]
+ self.backend_driver = _load_dict["backend_driver"]
+ self.iree_serializable = _load_dict["iree_serializable"]
+ self.tflite_serializable = _load_dict["tflite_serializable"]
+ self.function_name = _load_dict["function_name"]
+ self.function_sourcefile = _load_dict["function_sourcefile"]
+ self.function_line_numbers = _load_dict["function_line_numbers"]
+ self.function_source = _load_dict["function_source"]
+ self.calls = _load_dict["calls"]
+
+ def __str__(self):
+ header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
+ f"on function '{self.function_name}':")
+ # Give each call a number so it's easier to compare between multiple traces.
+ calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
+ calls = textwrap.indent("\n".join(calls), prefix=INDENT)
+ return f"{header}\n{calls}"
+
+ def __iter__(self):
+ for call in self.calls:
+ yield call
+
+ def save_plaintext(self, trace_dir: str, summarize: bool = True) -> None:
+ """Saves a human-readable string representation of this trace to disk.
+
+ Args:
+ trace_dir: str, path to the directory to save the trace in.
+ summarize: a bool controlling whether numpy should summarize the inputs
+ and outputs if they're large. Setting this to False is very slow for
+ large outputs.
+ """
+ prior_printoptions = np.get_printoptions()
+ np.set_printoptions(
+ linewidth=NUMPY_LINEWIDTH,
+ threshold=None if summarize else sys.maxsize,
+ edgeitems=10) # Can show more items since they won't clutter the logs.
+
+ path = os.path.join(trace_dir, "log.txt")
+ with open(path, "w") as f:
+ f.write(str(self))
+ f.write("\n")
+
+ np.set_printoptions(**prior_printoptions)
+
+ def serialize(self, trace_dir: str) -> None:
+ """Stores a serialized copy of this trace in trace_dir.
+
+ It can be loaded via `Trace.load(trace_dir)`.
+
+ Args:
+ trace_dir: str, path to the directory to serialize the trace to.
+ """
+
+ compiled_paths = None
+ if self.compiled_paths is not None:
+ # Convert to a dict to avoid the issues with serializing defaultdicts.
+ compiled_paths = dict(self.compiled_paths)
+
+ # Python serialization.
+ metadata = {
+ "module_name": self.module_name,
+ "compiled_paths": compiled_paths,
+ "backend_name": self.backend_name,
+ "backend_id": self.backend_id,
+ "backend_driver": self.backend_driver,
+ "iree_serializable": self.iree_serializable,
+ "tflite_serializable": self.tflite_serializable,
+ "function_name": self.function_name,
+ "function_sourcefile": self.function_sourcefile,
+ "function_line_numbers": self.function_line_numbers,
+ "function_source": self.function_source
+ }
+ with open(os.path.join(trace_dir, "metadata.pkl"), "wb") as f:
+ pickle.dump(metadata, f)
+
+ width = _zfill_width(len(self.calls))
+ for i, call in enumerate(self.calls):
+ call_dir = os.path.join(trace_dir, f"call_{str(i).zfill(width)}")
+ call.serialize(call_dir)
+
+ # C++ benchmark serialization.
+ if self.iree_serializable or self.tflite_serializable:
+ entry_function = self.calls[0].method
+ compiled_path = self.compiled_paths[entry_function]
+
+ if self.iree_serializable:
+ serialized_inputs = ", ".join(self.calls[0].serialized_inputs)
+ flagfile = [
+ f"--module_file={compiled_path}",
+ f"--driver={self.backend_driver}",
+ f"--function_inputs={serialized_inputs}",
+ f"--entry_function={entry_function}",
+ ]
+ with open(os.path.join(trace_dir, "flagfile"), "w") as f:
+ f.writelines(line + "\n" for line in flagfile)
+ else:
+ with open(os.path.join(trace_dir, "graph_path"), "w") as f:
+ f.writelines(compiled_path + "\n")
+
+ @staticmethod
+ def load(trace_dir: str) -> "Trace":
+ """Loads and returns a trace serialized with Trace.serialize.
+
+ Args:
+ trace_dir: str, path to the directory of the serialized trace.
+
+ Returns:
+ A Trace deserialized from trace_dir.
+ """
+ with open(os.path.join(trace_dir, "metadata.pkl"), "rb") as f:
+ load_dict = pickle.load(f)
+ call_dirs = sorted(glob.glob(os.path.join(trace_dir, "call_*")))
+ calls = [ModuleCall.load(call_dir) for call_dir in call_dirs]
+ load_dict["calls"] = calls
+ return Trace(module=None, function=None, _load_dict=load_dict)
+
+
+class TracedModule:
+
+ def __init__(self, module: module_utils.CompiledModule, trace: Trace):
+ """Wraps a CompiledModule so that all inputs and outputs are traced.
+
+ The TracedModule returned will have an API almost identical to that of the
+ passed CompiledModule. The only changes is that if the keywords `rtol` or
+ `atol` are passed to one of the CompiledModule's methods, then they will be
+ used to set the tolerance for comparing that call to the same call in
+ another trace. So for example, calling `traced_module.add(a, b rtol=1e-8)`
+ would be the same as calling `module.add(a, b)`.
+
+ Args:
+ module: the CompiledModule to trace.
+ trace: the Trace to record calls to this module with.
+ """
+ self._module = module
+ self._trace = trace
+
+ def _trace_call(self, method: module_utils._FunctionWrapper,
+ method_name: str):
+ """Decorates a CompiledModule method to capture its inputs and outputs."""
+
+ def call(*args, **kwargs):
+ # Pop manually specified tolerances from the kwargs (if any).
+ tolerances = {}
+ tolerances["rtol"] = kwargs.pop("rtol", None)
+ tolerances["atol"] = kwargs.pop("atol", None)
+ # Only pass these to ModuleCall if they were specified by the user.
+ tolerances = {k: v for k, v in tolerances.items() if v is not None}
+
+ # Ensure the inputs are numpy inputs.
+ args = tf_utils.convert_to_numpy(args)
+ kwargs = tf_utils.convert_to_numpy(kwargs)
+
+ # Run the method and record the details of the call.
+ outputs = method(*args, **kwargs)
+ serialized_inputs, serialized_outputs = method.get_serialized_values()
+ self._trace.calls.append(
+ ModuleCall(method_name, args, outputs, serialized_inputs,
+ serialized_outputs, **tolerances))
+ return outputs
+
+ return call
+
+ def __getattr__(self, attr):
+ # Try to resolve it as an attr on self._module.
+ if not hasattr(self._module, attr):
+ raise AttributeError(f"The compiled module does not have attr '{attr}'")
+ module_attr = getattr(self._module, attr)
+ if not hasattr(module_attr, "__call__"):
+ # e.g. traced_module.backend
+ return module_attr
+ else:
+ # e.g. traced_module.simple_mul(a, b)
+ return self._trace_call(module_attr, method_name=attr)
+
+
+def compare_traces(ref_trace: Trace,
+ tar_trace: Trace) -> Tuple[bool, Sequence[str]]:
+ traces_match = True
+ error_messages = []
+
+ # Check that all method invocations match.
+ ref_methods = [(call.method, call.rtol, call.atol) for call in ref_trace]
+ tar_methods = [(call.method, call.rtol, call.atol) for call in tar_trace]
+ if ref_methods != tar_methods:
+ # Raise a ValueError instead of returning False since this is an
+ # unexpected error.
+ raise ValueError(
+ "The reference and target traces have different call structures:\n"
+ f"Reference: {ref_methods}\nTarget: {tar_methods}")
+
+ for ref_call, tar_call in zip(ref_trace, tar_trace):
+ logging.info("Comparing calls to '%s'", ref_call.method)
+ rtol, atol = ref_call.get_tolerances()
+
+ inputs_match, error_message = tf_utils.check_same(ref_call.inputs,
+ tar_call.inputs, rtol,
+ atol)
+ if not inputs_match:
+ error_messages.append(error_message)
+ logging.error("Inputs did not match.")
+ outputs_match, error_message = tf_utils.check_same(ref_call.outputs,
+ tar_call.outputs, rtol,
+ atol)
+ if not outputs_match:
+ error_messages.append(error_message)
+ logging.error("Outputs did not match.")
+ calls_match = inputs_match and outputs_match
+
+ if not calls_match:
+ logging.error("Comparision between '%s' and '%s' failed on method '%s'",
+ ref_trace.backend_id, tar_trace.backend_id, ref_call.method)
+ logging.error("Reference call '%s':\n%s", ref_trace.backend_id, ref_call)
+ logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
+
+ traces_match = traces_match and calls_match
+ return traces_match, error_messages
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py
new file mode 100644
index 0000000..58315c8
--- /dev/null
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/trace_utils_test.py
@@ -0,0 +1,156 @@
+# Lint as: python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for pyiree.tf.support.trace_utils."""
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+import numpy as np
+from pyiree.tf.support import module_utils
+from pyiree.tf.support import trace_utils
+import tensorflow as tf
+
+
+class StatefulCountingModule(tf.Module):
+
+ def __init__(self):
+ self.count = tf.Variable([0.])
+
+ @tf.function(input_signature=[])
+ def increment(self):
+ self.count.assign_add(tf.constant([1.]))
+
+ @tf.function(input_signature=[])
+ def get_count(self):
+ return self.count
+
+ @tf.function(input_signature=[tf.TensorSpec([1])])
+ def increment_by(self, value):
+ self.count.assign_add(value)
+
+ @tf.function(input_signature=[tf.TensorSpec([1]), tf.TensorSpec([1])])
+ def increment_by_max(self, a, b):
+ result = tf.maximum(a, b)
+ self.count.assign_add(result)
+ return result
+
+ @tf.function(input_signature=[])
+ def decrement(self):
+ self.count.assign_sub(tf.constant([1.]))
+
+
+class TestUtilsTests(tf.test.TestCase, parameterized.TestCase):
+
+ def test_trace_inputs_and_outputs(self):
+
+ def trace_function(module):
+ # No inputs or outputs
+ module.increment()
+ # Only inputs
+ module.increment_by(np.array([81.], dtype=np.float32))
+ # Only outputs
+ module.get_count()
+
+ module = module_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('tf'))
+ trace = trace_utils.Trace(module, trace_function)
+ trace_function(trace_utils.TracedModule(module, trace))
+
+ self.assertIsInstance(trace.calls[0].inputs, tuple)
+ self.assertEmpty(trace.calls[0].inputs)
+ self.assertIsInstance(trace.calls[0].outputs, tuple)
+ self.assertEmpty(trace.calls[0].outputs)
+
+ self.assertAllClose(trace.calls[1].inputs[0], [81.])
+ self.assertAllClose(trace.calls[2].outputs[0], [82.])
+
+ def test_nonmatching_methods(self):
+
+ def tf_function(module):
+ module.increment()
+ module.increment()
+
+ def vmla_function(module):
+ module.increment()
+ module.decrement()
+
+ tf_module = module_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('tf'))
+ tf_trace = trace_utils.Trace(tf_module, tf_function)
+ tf_function(trace_utils.TracedModule(tf_module, tf_trace))
+
+ vmla_module = module_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+ vmla_trace = trace_utils.Trace(vmla_module, vmla_function)
+ vmla_function(trace_utils.TracedModule(vmla_module, vmla_trace))
+
+ with self.assertRaises(ValueError):
+ trace_utils.compare_traces(tf_trace, vmla_trace)
+
+ def test_nonmatching_inputs(self):
+
+ def tf_function(module):
+ module.increment_by(np.array([42.], dtype=np.float32))
+
+ def vmla_function(module):
+ module.increment_by(np.array([22.], dtype=np.float32))
+
+ tf_module = module_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('tf'))
+ tf_trace = trace_utils.Trace(tf_module, tf_function)
+ tf_function(trace_utils.TracedModule(tf_module, tf_trace))
+
+ vmla_module = module_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+ vmla_trace = trace_utils.Trace(vmla_module, vmla_function)
+ vmla_function(trace_utils.TracedModule(vmla_module, vmla_trace))
+
+ same, error_messages = trace_utils.compare_traces(tf_trace, vmla_trace)
+ self.assertFalse(same)
+
+ def test_trace_serialize_and_load(self):
+
+ def trace_function(module):
+ module.increment()
+ module.increment_by(np.array([81.], dtype=np.float32))
+ module.increment_by_max(np.array([81], dtype=np.float32),
+ np.array([92], dtype=np.float32))
+ module.get_count()
+
+ module = module_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, module_utils.BackendInfo('iree_vmla'))
+ trace = trace_utils.Trace(module, trace_function)
+ trace_function(trace_utils.TracedModule(module, trace))
+
+ with tempfile.TemporaryDirectory() as artifacts_dir:
+ trace_function_dir = trace_utils.get_trace_dir(artifacts_dir, trace)
+ trace.serialize(trace_function_dir)
+ self.assertTrue(
+ os.path.exists(os.path.join(trace_function_dir, 'metadata.pkl')))
+ loaded_trace = trace_utils.Trace.load(trace_function_dir)
+
+ # Check all calls match.
+ self.assertTrue(trace_utils.compare_traces(trace, loaded_trace))
+
+ # Check all other metadata match.
+ self.assertAllEqual(trace.__dict__.keys(), loaded_trace.__dict__.keys())
+ for key in trace.__dict__.keys():
+ if key != 'calls':
+ self.assertEqual(trace.__dict__[key], loaded_trace.__dict__[key])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index e806496..ccc2f63 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -32,16 +32,16 @@
specified directory. These artifacts include MLIR across various lowerings and
the compiled VM FlatBuffer. A basic example of creating and calling an
`IreeCompiledModule` can be found in
-[`tf_utils_test.py`](https://github.com/google/iree/blob/main/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py)
+[`module_utils_test.py`](https://github.com/google/iree/blob/main/integrations/tensorflow/bindings/python/pyiree/tf/support/module_utils_test.py)
When using Keras models or tf.Modules with functions that IREE can't compile,
`exported_names` should be specified. For example:
```python
-from pyiree.tf.support import tf_utils
-vmla_module = tf_utils.IreeCompiledModule(
+from pyiree.tf.support import module_utils
+vmla_module = module_utils.IreeCompiledModule(
module_class=KerasTFModuleClass,
- backend_info=tf_utils.BackendInfo('iree_vmla'),
+ backend_info=module_utils.BackendInfo('iree_vmla'),
exported_names=['predict'])
vmla_module.predict(...)
```