Merge main -> google
* f8b466e2 Bumping tracy version and fixing API usage. (#3328)
* c2bd3b9d Refactor tf bindings compilation API (#3326)
* 9c698256 [vulkan] Create a GUI application to run an IREE module (#3274)
* 61cb3a6e Remove e2e test case decorator (#3304)
* e5953855 Canonicalizes mhlo.dot_general to a rank-3 mhlo.dot_general or mhlo.dot (#3225)
* f7cf2195 Merge google -> main (#3321)
* f25eef3c Update android benchmarking titles to match.. (#3322)
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/3334 from hanhanW:main-to-google 62c8861c803837bfb3b67ac102b6afe4e4c299b0
PiperOrigin-RevId: 335072918
diff --git a/SUBMODULE_VERSIONS b/SUBMODULE_VERSIONS
index 006a608..f18bcd0 100644
--- a/SUBMODULE_VERSIONS
+++ b/SUBMODULE_VERSIONS
@@ -15,6 +15,6 @@
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
57eb48aed36160c4876bc8310d9ca84d42ee9e2a third_party/swiftshader
1454ee0907ee3e71030d7123595fa5487f23a56d third_party/tensorflow
-864d86e8b6d21449474db5e9313dbff90aa9c24f third_party/tracy
+a9a09ab0940408898fccfdcfe2bb8dc19b50f13c third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
909f36b714c9239ee0b112a321220213a474ba53 third_party/vulkan_memory_allocator
diff --git a/docs/developing_iree/e2e_benchmarking.md b/docs/developing_iree/e2e_benchmarking.md
index 2187536..a712d2f 100644
--- a/docs/developing_iree/e2e_benchmarking.md
+++ b/docs/developing_iree/e2e_benchmarking.md
@@ -247,7 +247,7 @@
changes. The flagfile can still take care of specifying the input data, driver
and entry function however.
-## 5. Benchmark the model on Android with TFLite
+## 5. Benchmarking TFLite on Android
### 5.1 Prepare the benchmarking tools
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 982242e..774412d 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -28,14 +28,16 @@
"OutputFormat",
# TensorFlow
"TF_IMPORT_PASS_PIPELINE",
- "tf_load_saved_model",
- "tf_load_signature_def_saved_model",
- "tf_compile_saved_model",
+ "tf_saved_model_to_compiler_module",
+ "tf_signature_def_saved_model_to_compiler_module",
"tf_module_to_compiler_module",
+ "compile_tf_saved_model",
+ "compile_tf_signature_def_saved_model",
+ "compile_tf_module",
]
import tempfile
-from typing import Collection, Optional, Sequence
+from typing import Collection, Optional, Sequence, Set
from . import binding as binding
import tensorflow as tf
@@ -95,74 +97,14 @@
)
-def tf_load_saved_model(saved_model_dir: str,
- exported_names: Collection[str] = (),
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> Module:
- """Loads a TensorFlow saved model from its persistent representation.
-
- See also tf_compile_saved_model() for a one-shot API to load and compile.
-
- Args:
- saved_model_dir: Directory of the saved model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- input_module = binding.load_saved_model(
- compiler_context, saved_model_dir, exported_names=exported_names)
- if pass_pipeline:
- input_module.run_pass_pipeline(pass_pipeline)
- return input_module
-
-
-def tf_load_signature_def_saved_model(
- saved_model_dir: str,
- tags: Collection[str] = set(),
- exported_names: Collection[str] = [],
- pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- compiler_context: Optional[Context] = None) -> Module:
- """Loads a TensorFlow SignatureDef saved model from persistent representation.
-
- Args:
- saved_model_dir: Directory of the saved model.
- tags: Optional tuple of tags to use when loading the model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- pass_pipeline: Passes to run on the imported module prior to returning.
- Defaults to TF_IMPORT_PASS_PIPELINE.
- compiler_context: The pyiree.compiler.Context() backing the module.
-
- Returns:
- An MLIR Module suitable for compilation by the IREE compiler.
- This can be further compiled to an IREE blob by calling
- .compile_to_sequencer_blob.
- """
- if not compiler_context:
- compiler_context = Context()
- input_module = binding.load_signature_def_saved_model(
- compiler_context, saved_model_dir, tags, exported_names=exported_names)
- if pass_pipeline:
- input_module.run_pass_pipeline(pass_pipeline)
- return input_module
-
-
-def tf_compile_saved_model(
+def tf_saved_model_to_compiler_module(
saved_model_dir: str,
exported_names: Collection[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
- target_backends: Collection[str] = (),
- compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
- """Loads and compiles a TensorFlow saved model in one shot.
+ compiler_context: Optional[Context] = None) -> Module:
+  """Converts a TensorFlow SavedModel into an MLIR module.
+
+ See also compile_tf_saved_model() for a one-shot API to load and compile.
Args:
saved_model_dir: Directory of the saved model.
@@ -170,31 +112,64 @@
keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+ if not compiler_context:
+ compiler_context = Context()
+ compiler_module = binding.load_saved_model(compiler_context,
+ saved_model_dir,
+ exported_names=exported_names)
+ if pass_pipeline:
+ compiler_module.run_pass_pipeline(pass_pipeline)
+ return compiler_module
+
+
+def compile_tf_saved_model(
+ saved_model_dir: str,
+ exported_names: Collection[str] = (),
+ target_backends: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
+ """Compiles a TensorFlow SavedModel to IREE in one shot.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
target_backends: The specific target backends to compile for (defaults to
all compiled in targets).
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
Returns:
An OpaqueBlob representing the compiled module.
"""
- input_module = tf_load_saved_model(saved_model_dir, exported_names,
- pass_pipeline, compiler_context)
- return input_module.compile(target_backends=target_backends)
+ compiler_module = tf_saved_model_to_compiler_module(saved_model_dir,
+ exported_names,
+ pass_pipeline,
+ compiler_context)
+ return compiler_module.compile(target_backends=target_backends)
-def tf_module_to_compiler_module(
- module: tf.Module,
- exported_names: Collection[str] = (),
- sm_path: str = None,
+def tf_signature_def_saved_model_to_compiler_module(
+ saved_model_dir: str,
+ saved_model_tags: Set[str] = set(),
+ exported_names: Collection[str] = [],
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> Module:
- """Converts a tf.Module into a MLIR module.
+  """Converts a TensorFlow SignatureDef SavedModel into an MLIR module.
Args:
- module: The tf.Module instance to convert to MLIR
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
exported_names: Optional tuple of strings representing the exported names to
keep.
- sm_path: the path to save the tf.Module to, if any. Defaults to None.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -204,16 +179,110 @@
This can be further compiled to an IREE blob by calling
.compile_to_sequencer_blob.
"""
-
- def _convert(sm_path):
- options = tf.saved_model.SaveOptions(save_debug_info=True)
- tf.saved_model.save(module, sm_path, options=options)
- return tf_load_saved_model(sm_path, exported_names, pass_pipeline,
- compiler_context)
-
- if sm_path is None:
- with tempfile.TemporaryDirectory() as sm_path:
- compiler_module = _convert(sm_path)
- else:
- compiler_module = _convert(sm_path)
+ if not compiler_context:
+ compiler_context = Context()
+ compiler_module = binding.load_signature_def_saved_model(
+ compiler_context,
+ saved_model_dir,
+ saved_model_tags,
+ exported_names=exported_names)
+ if pass_pipeline:
+ compiler_module.run_pass_pipeline(pass_pipeline)
return compiler_module
+
+
+def compile_tf_signature_def_saved_model(
+ saved_model_dir: str,
+ saved_model_tags: Set[str] = set(),
+ exported_names: Collection[str] = (),
+ target_backends: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
+ """Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
+
+ Args:
+ saved_model_dir: Directory of the saved model.
+ saved_model_tags: Optional set of tags to use when loading the model.
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ target_backends: The specific target backends to compile for (defaults to
+ all compiled in targets).
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+
+ Returns:
+ An OpaqueBlob representing the compiled module.
+ """
+ compiler_module = tf_signature_def_saved_model_to_compiler_module(
+ saved_model_dir, saved_model_tags, exported_names, pass_pipeline,
+ compiler_context)
+ return compiler_module.compile(target_backends=target_backends)
+
+
+def tf_module_to_compiler_module(
+ module: tf.Module,
+ exported_names: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None,
+ saved_model_dir: str = None) -> Module:
+  """Converts a tf.Module instance into an MLIR module.
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+    saved_model_dir: Optional path to save the tf.Module to. If not provided,
+      a temporary directory is used and removed once the import completes.
+
+ Returns:
+ An MLIR Module suitable for compilation by the IREE compiler.
+ This can be further compiled to an IREE blob by calling
+ .compile_to_sequencer_blob.
+ """
+
+ def _convert(saved_model_dir):
+ options = tf.saved_model.SaveOptions(save_debug_info=True)
+ tf.saved_model.save(module, saved_model_dir, options=options)
+ return tf_saved_model_to_compiler_module(saved_model_dir, exported_names,
+ pass_pipeline, compiler_context)
+
+ if saved_model_dir is None:
+ with tempfile.TemporaryDirectory() as saved_model_dir:
+ compiler_module = _convert(saved_model_dir)
+ else:
+ compiler_module = _convert(saved_model_dir)
+ return compiler_module
+
+
+def compile_tf_module(module: tf.Module,
+ exported_names: Collection[str] = (),
+ target_backends: Collection[str] = (),
+ pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
+ compiler_context: Optional[Context] = None,
+ saved_model_dir: str = None):
+ """Compiles a tf.Module to IREE in one shot.
+
+ Args:
+ module: The tf.Module instance to convert to MLIR
+ exported_names: Optional tuple of strings representing the exported names to
+ keep.
+ target_backends: The specific target backends to compile for (defaults to
+ all compiled in targets).
+ pass_pipeline: Passes to run on the imported module prior to returning.
+ Defaults to TF_IMPORT_PASS_PIPELINE.
+ compiler_context: The pyiree.compiler.Context() backing the module.
+    saved_model_dir: Optional path to save the tf.Module to. If not provided,
+      a temporary directory is used and removed once the import completes.
+
+ Returns:
+ An OpaqueBlob representing the compiled module.
+ """
+ compiler_module = tf_module_to_compiler_module(module, exported_names,
+ pass_pipeline,
+ compiler_context,
+ saved_model_dir)
+ return compiler_module.compile(target_backends=target_backends)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
index a7aba66..10e55c3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/saved_model_test.py
@@ -62,7 +62,7 @@
tf.saved_model.save(my_module, sm_dir, options=options)
# Load it up.
- input_module = compiler.tf_load_saved_model(sm_dir)
+ input_module = compiler.tf_saved_model_to_compiler_module(sm_dir)
xla_asm = input_module.to_asm()
print("XLA ASM:", xla_asm)
self.assertRegex(xla_asm, "mhlo.tanh")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
index 8a2e1cb..9764b13 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/signature_def_saved_model_test.py
@@ -55,8 +55,8 @@
sess, ["bar"], {"baz": sig}, strip_default_attrs=True)
builder.save()
- module = compiler.tf_load_signature_def_saved_model(
- sm_dir, tags=set(["bar"]), exported_names=["baz"])
+ module = compiler.tf_signature_def_saved_model_to_compiler_module(
+ sm_dir, saved_model_tags=set(["bar"]), exported_names=["baz"])
module_asm = module.to_asm(large_element_limit=100)
self.assertRegexpMatches(module_asm, "flow.variable @[^ ]* dense<10>")
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 87a4c8a..fa3df03 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -22,6 +22,7 @@
# ref: reference – for the reference CompiledModule
# tar: target - for one of the target CompiledModules
+import collections
import copy
import glob
import inspect
@@ -465,7 +466,7 @@
f"--input_file={compiled_path}",
f"--driver={self.backend_driver}",
f"--inputs={serialized_inputs}",
- f"--entry_function={entry_function}"
+ f"--entry_function={entry_function}",
]
with open(os.path.join(trace_dir, "flagfile"), "w") as f:
f.writelines(line + "\n" for line in flagfile)
@@ -551,7 +552,11 @@
return self._trace_call(module_attr, method_name=attr)
-def compile_module(
+Modules = collections.namedtuple('Modules',
+ ['ref_module', 'tar_modules', 'artifacts_dir'])
+
+
+def compile_tf_module(
module_class: Type[tf.Module], exported_names: Sequence[str] = ()
) -> Callable[[Any], Any]:
"""CompiledModuleTestCase decorator that compiles a tf.Module.
@@ -565,80 +570,38 @@
exported_names: optional iterable of strings representing which of
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
-
- Returns:
- Class decorator function.
"""
- def decorator(cls):
- """Decorator Function."""
- if not issubclass(cls, TracedModuleTestCase):
- logging.exception(
- "The 'compile_module' decorator must be applied to a "
- "TracedModuleTestCase derived class, which %s is not.", cls)
- cls._module_class = module_class
- cls._exported_names = exported_names
- return cls
+ # Setup the directory for saving compilation artifacts and traces.
+ artifacts_dir = _setup_artifacts_dir(module_class.__name__)
- return decorator
+ # Get the backend information for this test.
+ ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
+ tar_backend_infos = get_target_backends()
+ compile_backend = lambda backend_info: backend_info.compile(
+ module_class, exported_names, artifacts_dir)
-# Will be initialized by TracedModuleTestCase.setUpClass
-# Global variables are used because storing the compiler context on the cls
-# causes cleaning up refcounts to fail, and tf.test.TestCase wipes the variables
-# on the class instance (self.*) before each unittest.
-# TODO(#2900): Move these back to class variables when we figure out issues with
-# refcounting.
-_global_ref_module = None
-_global_tar_modules = None
+ ref_module = compile_backend(ref_backend_info)
+ tar_modules = [
+ compile_backend(backend_info) for backend_info in tar_backend_infos
+ ]
+ return Modules(ref_module, tar_modules, artifacts_dir)
class TracedModuleTestCase(tf.test.TestCase):
"""Compiles a tf.Module to multiple backends to test their correctness."""
- # Will be initialized by the @compile_module decorator.
- _module_class = None
- _exported_names = ()
-
- @classmethod
- def _compile(cls, backend_info: tf_utils.BackendInfo):
- return backend_info.compile(cls._module_class, cls._exported_names,
- cls._artifacts_dir)
-
- @classmethod
- def setUpClass(cls) -> None:
- # Ran before any of the unit tests.
- super().setUpClass()
- if cls._module_class is None:
- raise AttributeError(
- "setUpClass was called but no module was specified. Specify a module "
- "to compile via the @tf_test_utils.compile_module decorator.")
-
- # Setup the directory for saving compilation artifacts and traces.
- cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
-
- # Get the backend information for this test.
- ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
- f"{FLAGS.reference_backend}_ref")
- tar_backend_infos = get_target_backends()
-
- global _global_ref_module
- global _global_tar_modules
- _global_ref_module = cls._compile(ref_backend_info)
- _global_tar_modules = [
- cls._compile(backend_info) for backend_info in tar_backend_infos
- ]
def setUp(self) -> None:
# Runs before each unit test.
super().setUp()
- global _global_ref_module
- global _global_tar_modules
- _global_ref_module.reinitialize()
- for module in _global_tar_modules:
+ self._modules.ref_module.reinitialize()
+ for module in self._modules.tar_modules:
module.reinitialize()
- def compare_backends(self, trace_function: Callable[[TracedModule],
- None]) -> None:
+ def compare_backends(self, trace_function: Callable[[TracedModule], None],
+ modules: Modules) -> None:
"""Run the reference and target backends on trace_function and compare them.
Random seeds for tensorflow, numpy and python are set before each invocation
@@ -648,17 +611,17 @@
trace_function: a function accepting a TracedModule as its argument.
"""
# Create Traces for each backend.
- ref_trace = Trace(_global_ref_module, trace_function)
+ ref_trace = Trace(modules.ref_module, trace_function)
tar_traces = [
- Trace(module, trace_function) for module in _global_tar_modules
+ Trace(module, trace_function) for module in modules.tar_modules
]
# Run the traces through trace_function with their associated modules.
tf_utils.set_random_seed()
- trace_function(TracedModule(_global_ref_module, ref_trace))
+ trace_function(TracedModule(modules.ref_module, ref_trace))
if FLAGS.log_all_traces:
logging.info(ref_trace)
- for module, trace in zip(_global_tar_modules, tar_traces):
+ for module, trace in zip(modules.tar_modules, tar_traces):
tf_utils.set_random_seed()
trace_function(TracedModule(module, trace))
if FLAGS.log_all_traces:
@@ -674,11 +637,11 @@
failed_backend_indices.append(i)
# Save the results to disk before validating.
- ref_trace_dir = _get_trace_dir(self._artifacts_dir, ref_trace)
+ ref_trace_dir = _get_trace_dir(modules.artifacts_dir, ref_trace)
ref_trace.save_plaintext(ref_trace_dir, FLAGS.summarize)
ref_trace.serialize(ref_trace_dir)
for tar_trace in tar_traces:
- tar_trace_dir = _get_trace_dir(self._artifacts_dir, tar_trace)
+ tar_trace_dir = _get_trace_dir(modules.artifacts_dir, tar_trace)
tar_trace.save_plaintext(tar_trace_dir, FLAGS.summarize)
tar_trace.serialize(tar_trace_dir)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index 01c6e0f..fc8bab6 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -88,40 +88,47 @@
return result
-def backends_to_str(backend_infos: Sequence["BackendInfo"]) -> str:
- """Creates a normalized string representing the provided backends."""
- normalized_names = []
- for backend_info in backend_infos:
- # Remove unusual characters and ensure names don't end or start in "_".
- name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
- normalized_names.append(name.strip("_"))
- return "__".join(normalized_names)
+def _setup_mlir_crash_reproducer(
+ function: Callable[[Any], Any],
+ artifacts_dir: str,
+ backend_name: str,
+) -> Callable[[Any], Any]:
+  """Wraps `function` so that an MLIR crash reproducer is saved if it crashes.
+
+ Writes to `artifacts_dir/reproducer__{backend}.mlir` in the case of a crash.
+
+ Args:
+ function: The callable to decorate.
+ artifacts_dir: The directory to write the reproducer to.
+ backend_name: The name of the backend `function` compiles to.
+
+ Returns:
+ A function with the same API as the passed function.
+ """
+
+ def decorator(*args, **kwargs):
+ # Set up a crash reproducer for debugging.
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = os.path.join(
+ artifacts_dir, f"reproducer__{backend_name}.mlir")
+ try:
+ results = function(*args, **kwargs)
+ except Exception: # pylint: disable=broad-except
+ # Disable the crash reproducer (to avoid inadvertently overwriting it).
+ if artifacts_dir is not None:
+ compiler.Context.default_crash_reproducer_path = None
+ raise
+ return results
+
+ return decorator
-def _get_backends_path(artifact_name: str,
- backend_infos: Sequence["BackendInfo"],
- artifacts_dir: str) -> str:
- """Gets the path to save artifact_name under for the specified backend(s)."""
- backends_string = backends_to_str(backend_infos)
- # Put the artifact in a directory if there's only one backend.
- if len(backend_infos) == 1:
- backend_dir = os.path.join(artifacts_dir, backends_string)
- os.makedirs(backend_dir, exist_ok=True)
- return os.path.join(artifacts_dir, backends_string, artifact_name)
- else:
- return os.path.join(artifacts_dir, f"{artifact_name}__{backends_string}")
-
-
-def compile_tf_module(
- tf_module: Type[tf.Module],
- backend_infos: Sequence["BackendInfo"] = (),
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None
+def _incrementally_lower_compiler_module(
+ compiler_module: compiler.Module,
+ backend_info: "BackendInfo",
+ artifacts_dir: str,
) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
- """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
-
- The artifact this creates is not callable. See IreeCompiledModule for an API
- that returns a module that can be called without any further steps.
+  """Lowers an MLIR compiler module incrementally and saves its outputs.
If artifacts_dir is provided then the following artifacts will be saved:
tf_input.mlir:
@@ -129,75 +136,84 @@
iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
backend_name/compiled.vmfb:
- A VM FlatBuffer compiled to the target backends from the IREE MLIR above.
-
- If multiple backends are specified, then instead of saving compiled 'vmfb'
- under 'backend_name/', it will be saved as follows:
- - 'compiled__{backends}.vmfb'
- where 'backends' is a '__' delimited list (e.g. iree_vmla__iree_llvmjit).
+ A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
Args:
- tf_module: A tf.Module.
- backend_infos: Iterable of BackendInfo names to compile for.
+ compiler_module: A compiler.Module to lower.
+ backend_info: BackendInfo with the details for lowering compiler_module to
+ IREE.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ if artifacts_dir is not None:
+ os.makedirs(artifacts_dir, exist_ok=True)
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+ logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+ with open(tf_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ # Manually run the passes that tf_module_to_compiler_module usually would.
+ compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+
+ if artifacts_dir is not None:
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+ logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+ with open(iree_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ compiled_module = compiler_module.compile(
+ target_backends=backend_info.compiler_targets)
+
+ compiled_path = None
+ if artifacts_dir is not None:
+ backend_dir = os.path.join(artifacts_dir, backend_info.name)
+ os.makedirs(backend_dir, exist_ok=True)
+ compiled_path = os.path.join(backend_dir, "compiled.vmfb")
+ logging.info("Saving compiled IREE module to: %s", compiled_path)
+ with open(compiled_path, "wb") as f:
+ f.write(compiled_module)
+ return compiled_module, compiled_path
+
+
+def _incrementally_compile_tf_module(
+ module: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None
+) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
+ """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
+
+ The module blob this creates is not callable. See IreeCompiledModule for an
+ API that returns a module that can be called without any further steps.
+
+ See _incrementally_lower_compiler_module's docstring for details about which
+ artifacts will be saved.
+
+ Args:
+ module: A tf.Module.
+ backend_info: BackendInfo with the details for compiling module to IREE.
exported_names: Iterable of dotted function names to consider for
compilation.
artifacts_dir: An optional string pointing to where compilation artifacts
- should be saved.
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
Returns:
A compiled IREE module blob and the path to the compiled VM FlatBuffer if
artifacts_dir is provided.
"""
- if artifacts_dir is not None:
- # Set up a crash reproducer for debugging.
- backends_string = backends_to_str(backend_infos)
- compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backends_string}.mlir")
-
- try:
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_module_to_compiler_module(tf_module,
+ def _compile_module(module, exported_names, backend_info, artifacts_dir):
+ compiler_module = compiler.tf_module_to_compiler_module(module,
exported_names,
pass_pipeline=())
+ return _incrementally_lower_compiler_module(compiler_module, backend_info,
+ artifacts_dir)
- if artifacts_dir is not None:
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- # Now run the passes manually that tf_module_to_compiler_module would
- # usually do.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
-
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
-
- target_backends = []
- for backend_info in backend_infos:
- target_backends.extend(backend_info.compiler_targets)
- compiled_module = compiler_module.compile(target_backends=target_backends)
-
- compiled_path = None
- if artifacts_dir is not None:
- compiled_path = _get_backends_path("compiled", backend_infos,
- artifacts_dir)
- compiled_path = f"{compiled_path}.vmfb"
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
-
- except Exception: # pylint: disable=broad-except
- if artifacts_dir is not None:
- # Disable the crash reproducer (to avoid inadvertently overwriting it).
- compiler.Context.default_crash_reproducer_path = None
- raise
-
- return compiled_module, compiled_path
+ _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
+ backend_info.name)
+ return _compile_module(module, exported_names, backend_info, artifacts_dir)
class CompiledModule(object):
@@ -280,9 +296,9 @@
super().__init__(module_class, backend_info, exported_names, artifacts_dir)
set_random_seed()
- self._module_blob, compiled_path = compile_tf_module(
- tf_module=module_class(),
- backend_infos=[backend_info],
+ self._module_blob, compiled_path = _incrementally_compile_tf_module(
+ module=module_class(),
+ backend_info=backend_info,
exported_names=exported_names,
artifacts_dir=artifacts_dir)
self._module = rt.VmModule.from_flatbuffer(self._module_blob)
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 43ec6cf..1a441d0 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -44,6 +44,7 @@
def get_count(self):
return self.count
+
class RandomInitModule(tf.Module):
def __init__(self):
@@ -56,25 +57,15 @@
class UtilsTests(tf.test.TestCase, parameterized.TestCase):
- @parameterized.named_parameters([
- {
- 'testcase_name': 'single_backend',
- 'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
- },
- {
- 'testcase_name':
- 'multiple_backends',
- 'backend_infos': [
- tf_utils.BackendInfo('iree_vmla'),
- tf_utils.BackendInfo('iree_llvmjit')
- ],
- },
- ])
- def test_artifact_saving(self, backend_infos):
+ def test_artifact_saving(self):
+ backend_info = tf_utils.BackendInfo('iree_vmla')
with tempfile.TemporaryDirectory() as artifacts_dir:
tf_module = ConstantModule()
- iree_compiled_module, compiled_path = tf_utils.compile_tf_module(
- tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
+ iree_compiled_module, compiled_path = (
+ tf_utils._incrementally_compile_tf_module(
+ tf_module,
+ backend_info=backend_info,
+ artifacts_dir=artifacts_dir))
artifacts_to_check = [
'tf_input.mlir',
@@ -121,7 +112,6 @@
inputs = [np.array([1, 2], dtype=np.float32)]
self.assertEqual('2xf32=1.0 2.0', tf_utils.save_input_values(inputs))
-
@parameterized.named_parameters([
{
'testcase_name': 'tensorflow',
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index ff95379..7adb281 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -80,11 +80,15 @@
backend as a source of truth. For example:
```python
-# Compile a `tf.Module` named `SimpleArithmeticModule` into a `CompiledModule`.
-@tf_test_utils.compile_module(SimpleArithmeticModule)
# Inherit from `TracedModuleTestCase`.
class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(SimpleArithmeticTest, self).__init__(methodName)
+ # Compile a `tf.Module` named `SimpleArithmeticModule` into
+ # `CompiledModule`s for each reference and target backend.
+ self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
# Unit test.
def test_simple_mul(self):
@@ -103,7 +107,7 @@
# Calls `simple_mul` once for each backend, recording the inputs and outputs
# to `module` and then comparing them.
- self.compare_backends(simple_mul)
+ self.compare_backends(simple_mul, self._modules)
```
## Test Suites
diff --git a/integrations/tensorflow/e2e/batch_norm_test.py b/integrations/tensorflow/e2e/batch_norm_test.py
index e16436d..4e3c0cd 100644
--- a/integrations/tensorflow/e2e/batch_norm_test.py
+++ b/integrations/tensorflow/e2e/batch_norm_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Batch norm tests."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -39,9 +40,12 @@
variance_epsilon=1e-4)
-@tf_test_utils.compile_module(BatchNormModule)
class BatchNormTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BatchNormTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(BatchNormModule)
+
def test_batch_norm_inference(self):
def batch_norm_inference(module):
@@ -53,10 +57,15 @@
scale = tf_utils.uniform((16,)) * 1e-3
module.batch_norm_inference(x, mean, variance, offset, scale)
- self.compare_backends(batch_norm_inference)
+ self.compare_backends(batch_norm_inference, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/bool_test.py b/integrations/tensorflow/e2e/bool_test.py
index 161f6ad..ef9dd10 100644
--- a/integrations/tensorflow/e2e/bool_test.py
+++ b/integrations/tensorflow/e2e/bool_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tests for ops in the tf.math module."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -37,22 +38,25 @@
return tf.math.logical_and(x, y)
-@tf_test_utils.compile_module(MathModule)
class BooleanTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BooleanTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MathModule)
+
def test_constant(self):
def constant(module):
module.constant()
- self.compare_backends(constant)
+ self.compare_backends(constant, self._modules)
def test_greater_than(self):
def greater_than(module):
module.greater_than(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(greater_than)
+ self.compare_backends(greater_than, self._modules)
def test_logical_and(self):
@@ -61,10 +65,15 @@
np.array([True, True, False, False], dtype=np.bool),
np.array([True, False, False, True], dtype=np.bool))
- self.compare_backends(logical_and)
+ self.compare_backends(logical_and, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcast_to_test.py b/integrations/tensorflow/e2e/broadcast_to_test.py
index 6d57d6f..dc53f34 100644
--- a/integrations/tensorflow/e2e/broadcast_to_test.py
+++ b/integrations/tensorflow/e2e/broadcast_to_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
return tf.broadcast_to(x, shape)
-@tf_test_utils.compile_module(BroadcastToModule)
class BroadcastToTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BroadcastToTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(BroadcastToModule)
+
def test_scalar_broadcast_to(self):
def scalar_broadcast_to(module):
@@ -40,10 +44,15 @@
shape = np.array([3, 3], dtype=np.int32)
result = module.scalar_broadcast_to(x, shape)
- self.compare_backends(scalar_broadcast_to)
+ self.compare_backends(scalar_broadcast_to, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/broadcasting_test.py b/integrations/tensorflow/e2e/broadcasting_test.py
index c4f8e38..f72f3b3 100644
--- a/integrations/tensorflow/e2e/broadcasting_test.py
+++ b/integrations/tensorflow/e2e/broadcasting_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test broadcasting support."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -30,9 +31,12 @@
return lhs + rhs
-@tf_test_utils.compile_module(BroadcastingModule)
class BroadcastingTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(BroadcastingTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(BroadcastingModule)
+
def test_add_same_shape(self):
def add_same_shape(module):
@@ -40,7 +44,7 @@
rhs = tf_utils.uniform([4])
module.add(lhs, rhs)
- self.compare_backends(add_same_shape)
+ self.compare_backends(add_same_shape, self._modules)
def test_add_broadcast_lhs(self):
@@ -49,7 +53,7 @@
rhs = tf_utils.uniform([4])
module.add(lhs, rhs)
- self.compare_backends(add_broadcast_lhs)
+ self.compare_backends(add_broadcast_lhs, self._modules)
def test_add_broadcast_rhs(self):
@@ -58,10 +62,15 @@
rhs = tf_utils.uniform([1])
module.add(lhs, rhs)
- self.compare_backends(add_broadcast_rhs)
+ self.compare_backends(add_broadcast_rhs, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/complex_test.py b/integrations/tensorflow/e2e/complex_test.py
index 102396b..4832146 100644
--- a/integrations/tensorflow/e2e/complex_test.py
+++ b/integrations/tensorflow/e2e/complex_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -32,9 +33,12 @@
return tf.math.real(exp)
-@tf_test_utils.compile_module(ComplexModule)
class ComplexTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ComplexTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ComplexModule)
+
def test_complex(self):
def complex_exp(module):
@@ -42,10 +46,15 @@
imag = np.array([-1., 0.4], dtype=np.float32)
module.complex_exp(real, imag)
- self.compare_backends(complex_exp)
+ self.compare_backends(complex_exp, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/concat_test.py b/integrations/tensorflow/e2e/concat_test.py
index 187c6cb..dc0fab8 100644
--- a/integrations/tensorflow/e2e/concat_test.py
+++ b/integrations/tensorflow/e2e/concat_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test concat op."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -51,9 +52,12 @@
return tf.concat([a, b], axis=2)
-@tf_test_utils.compile_module(ConcatOpsModule)
class ConcatOpsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ConcatOpsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ConcatOpsModule)
+
def test_concat_zero_dim(self):
def concat_zero_dim(module):
@@ -61,7 +65,7 @@
b = tf_utils.uniform([1, 5, 1])
module.concat_zero_dim(a, b)
- self.compare_backends(concat_zero_dim)
+ self.compare_backends(concat_zero_dim, self._modules)
def test_concat0axis(self):
@@ -70,7 +74,7 @@
b = tf_utils.uniform([1, 5, 1])
module.concat0axis(a, b)
- self.compare_backends(concat0axis)
+ self.compare_backends(concat0axis, self._modules)
def test_concat1axis(self):
@@ -79,7 +83,7 @@
b = tf_utils.uniform([1, 5, 1])
module.concat1axis(a, b)
- self.compare_backends(concat1axis)
+ self.compare_backends(concat1axis, self._modules)
def test_concat2axis(self):
@@ -88,10 +92,15 @@
b = tf_utils.uniform([1, 5, 1])
module.concat2axis(a, b)
- self.compare_backends(concat2axis)
+ self.compare_backends(concat2axis, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/control_flow_test.py b/integrations/tensorflow/e2e/control_flow_test.py
index bc0a328..4ba691a 100644
--- a/integrations/tensorflow/e2e/control_flow_test.py
+++ b/integrations/tensorflow/e2e/control_flow_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -34,16 +35,19 @@
return i
-@tf_test_utils.compile_module(ControlFlowModule)
class ControlFlowTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ControlFlowTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ControlFlowModule)
+
def test_short_sequence(self):
def short_sequence(module):
input_array = np.array(9., dtype=np.float32)
module.collatz(input_array)
- self.compare_backends(short_sequence)
+ self.compare_backends(short_sequence, self._modules)
def test_long_sequence(self):
@@ -51,10 +55,15 @@
input_array = np.array(178., dtype=np.float32)
module.collatz(input_array)
- self.compare_backends(long_sequence)
+ self.compare_backends(long_sequence, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/conv_test.py b/integrations/tensorflow/e2e/conv_test.py
index 9346a52..cff4496 100644
--- a/integrations/tensorflow/e2e/conv_test.py
+++ b/integrations/tensorflow/e2e/conv_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -99,9 +100,12 @@
return tf.nn.conv2d(img, kernel, [1, 1, 1, 1], "VALID", name="result")
-@tf_test_utils.compile_module(Conv2dModule)
class ConvTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ConvTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(Conv2dModule)
+
def test_id_batch_size_1(self):
def id_batch_size_1(module):
@@ -109,7 +113,7 @@
k = np.ones([1, 1, 1, 1], dtype=np.float32)
module.conv2d_1451x1111_valid(i, k)
- self.compare_backends(id_batch_size_1)
+ self.compare_backends(id_batch_size_1, self._modules)
def test_id_batch_size_2(self):
@@ -118,7 +122,7 @@
k = np.ones([1, 1, 1, 1], dtype=np.float32)
module.conv2d_2451x1111_valid(i, k)
- self.compare_backends(id_batch_size_2)
+ self.compare_backends(id_batch_size_2, self._modules)
def test_asymmetric_kernel(self):
@@ -128,7 +132,7 @@
dtype=np.float32).reshape(2, 3, 1, 1)
module.conv2d_1451x2311_valid(i, k)
- self.compare_backends(asymmetric_kernel)
+ self.compare_backends(asymmetric_kernel, self._modules)
def test_padding(self):
@@ -138,7 +142,7 @@
dtype=np.float32).reshape(2, 3, 1, 1)
module.conv2d_1451x2311_same(i, k)
- self.compare_backends(padding)
+ self.compare_backends(padding, self._modules)
def test_batched_padding(self):
@@ -148,7 +152,7 @@
dtype=np.float32).reshape(2, 3, 1, 1)
module.conv2d_2451x2311_same(i, k)
- self.compare_backends(batched_padding)
+ self.compare_backends(batched_padding, self._modules)
def test_feature_reduce(self):
@@ -157,7 +161,7 @@
k = np.ones([3, 2, 2, 1], dtype=np.float32)
module.conv2d_1452x3221_same(i, k)
- self.compare_backends(feature_reduce)
+ self.compare_backends(feature_reduce, self._modules)
def test_feature_inflate(self):
@@ -166,7 +170,7 @@
k = tf_utils.ndarange([1, 1, 1, 2])
module.conv2d_1451x1112_same(i, k)
- self.compare_backends(feature_inflate)
+ self.compare_backends(feature_inflate, self._modules)
def test_feature_mix(self):
@@ -175,7 +179,7 @@
k = tf_utils.ndarange([1, 1, 2, 2])
module.conv2d_1452x1122_same(i, k)
- self.compare_backends(feature_mix)
+ self.compare_backends(feature_mix, self._modules)
def test_feature_padded(self):
@@ -184,7 +188,7 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_1452x2223_same(i, k)
- self.compare_backends(feature_padded)
+ self.compare_backends(feature_padded, self._modules)
def test_feature_unpadded(self):
@@ -193,7 +197,7 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_1452x2223_valid(i, k)
- self.compare_backends(feature_unpadded)
+ self.compare_backends(feature_unpadded, self._modules)
def test_batched_feature_unpadded(self):
@@ -202,10 +206,15 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_2452x2223_valid(i, k)
- self.compare_backends(batched_feature_unpadded)
+ self.compare_backends(batched_feature_unpadded, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index e3ba303..c928d5e 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -27,45 +28,58 @@
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_2452x2423_valid(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 1, 1, 1], "VALID", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 1, 1, 1],
+ "VALID",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2423_same(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 1, 1, 1], "SAME", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 1, 1, 1],
+ "SAME",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2423_valid_stride_2(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 2, 2, 1], "VALID", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 2, 2, 1],
+ "VALID",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2423_same_stride_2(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 2, 2, 1], "SAME", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 2, 2, 1],
+ "SAME",
+ name="result")
@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 4], tf.float32),
tf.TensorSpec([2, 4, 4, 1], tf.float32),
])
def conv2d_2453x2441_same_stride_1(self, img, kernel):
- return tf.nn.depthwise_conv2d(
- img, kernel, [1, 1, 1, 1], "SAME", name="result")
+ return tf.nn.depthwise_conv2d(img,
+ kernel, [1, 1, 1, 1],
+ "SAME",
+ name="result")
-@tf_test_utils.compile_module(DepthConv2dModule)
class ConvTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ConvTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(DepthConv2dModule)
+
def test_batched_feature_unpadded(self):
def batched_feature_unpadded(module):
@@ -73,7 +87,7 @@
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_2452x2423_valid(i, k)
- self.compare_backends(batched_feature_unpadded)
+ self.compare_backends(batched_feature_unpadded, self._modules)
def test_batched_feature_unpadded_same(self):
@@ -82,7 +96,7 @@
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2423_same(i, k)
- self.compare_backends(batched_feature_unpadded_same)
+ self.compare_backends(batched_feature_unpadded_same, self._modules)
def test_batched_feature_unpadded_same_stride_2(self):
@@ -91,7 +105,8 @@
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2423_valid_stride_2(i, k)
- self.compare_backends(batched_feature_unpadded_same_stride_2)
+ self.compare_backends(batched_feature_unpadded_same_stride_2,
+ self._modules)
def test_batched_feature_padded_same_stride_2(self):
@@ -100,7 +115,7 @@
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2423_same_stride_2(i, k)
- self.compare_backends(batched_feature_padded_same_stride_2)
+ self.compare_backends(batched_feature_padded_same_stride_2, self._modules)
def test_batched_feature_padded_same_stride_1_output_1(self):
@@ -109,10 +124,16 @@
k = tf_utils.ndarange([2, 4, 4, 1])
module.conv2d_2453x2441_same_stride_1(i, k)
- self.compare_backends(batched_feature_padded_same_stride_1_output_1)
+ self.compare_backends(batched_feature_padded_same_stride_1_output_1,
+ self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
index 3742c6e..1e5ba08 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_relu_test.py
@@ -17,6 +17,7 @@
# This uses a relu instead, allowing it to get to the remaining issue
# (unimplemented dynamic dot_general).
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -42,8 +43,8 @@
self.input_dim = input_dim
self.classes = classes
self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
- self.h2_weights = tf.Variable(
- tf.random.normal([hidden_1_dim, hidden_2_dim]))
+ self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+ hidden_2_dim]))
self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -51,8 +52,7 @@
# Compile with dynamic batch dim.
self.predict = tf.function(
- input_signature=[tf.TensorSpec([None, self.input_dim])])(
- self.predict)
+ input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
def mlp(self, x):
layer_1 = tf.nn.relu(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -65,19 +65,28 @@
return tf.nn.softmax(self.mlp(x))
-@tf_test_utils.compile_module(DynamicMlpReluModule, exported_names=["predict"])
class DynamicMlpReluTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(DynamicMlpReluTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(DynamicMlpReluModule,
+ exported_names=["predict"])
+
def test_dynamic_batch(self):
def dynamic_batch(module):
x = tf_utils.uniform([3, 28 * 28]) * 1e-3
module.predict(x)
- self.compare_backends(dynamic_batch)
+ self.compare_backends(dynamic_batch, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/dynamic_mlp_test.py b/integrations/tensorflow/e2e/dynamic_mlp_test.py
index ff905a5..a0ecdc5 100644
--- a/integrations/tensorflow/e2e/dynamic_mlp_test.py
+++ b/integrations/tensorflow/e2e/dynamic_mlp_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -38,8 +39,8 @@
self.input_dim = input_dim
self.classes = classes
self.h1_weights = tf.Variable(tf.random.normal([input_dim, hidden_1_dim]))
- self.h2_weights = tf.Variable(
- tf.random.normal([hidden_1_dim, hidden_2_dim]))
+ self.h2_weights = tf.Variable(tf.random.normal([hidden_1_dim,
+ hidden_2_dim]))
self.out_weights = tf.Variable(tf.random.normal([hidden_2_dim, classes]))
self.h1_bias = tf.Variable(tf.random.normal([hidden_1_dim]))
self.h2_bias = tf.Variable(tf.random.normal([hidden_2_dim]))
@@ -47,8 +48,7 @@
# Compile with dynamic batch dim.
self.predict = tf.function(
- input_signature=[tf.TensorSpec([None, self.input_dim])])(
- self.predict)
+ input_signature=[tf.TensorSpec([None, self.input_dim])])(self.predict)
def mlp(self, x):
layer_1 = tf.sigmoid(tf.add(tf.matmul(x, self.h1_weights), self.h1_bias))
@@ -61,19 +61,28 @@
return tf.nn.softmax(self.mlp(x))
-@tf_test_utils.compile_module(DynamicMlpModule, exported_names=["predict"])
class DynamicMlpTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(DynamicMlpTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(DynamicMlpModule,
+ exported_names=["predict"])
+
def test_dynamic_batch(self):
def dynamic_batch(module):
x = tf_utils.uniform([3, 28 * 28]) * 1e-3
module.predict(x)
- self.compare_backends(dynamic_batch)
+ self.compare_backends(dynamic_batch, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/fill_test.py b/integrations/tensorflow/e2e/fill_test.py
index 050eb47..adaeac9 100644
--- a/integrations/tensorflow/e2e/fill_test.py
+++ b/integrations/tensorflow/e2e/fill_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -30,9 +31,12 @@
return tf.fill(dims, value)
-@tf_test_utils.compile_module(FillModule)
class FillTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(FillTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(FillModule)
+
def test_fill(self):
def fill(module):
@@ -40,10 +44,15 @@
value = np.array(9., dtype=np.float32)
module.fill(dims, value)
- self.compare_backends(fill)
+ self.compare_backends(fill, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/finite_test.py b/integrations/tensorflow/e2e/finite_test.py
index ff62f3a..1f53406 100644
--- a/integrations/tensorflow/e2e/finite_test.py
+++ b/integrations/tensorflow/e2e/finite_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -24,18 +25,26 @@
return tf.math.is_finite(x)
-@tf_test_utils.compile_module(FiniteModule)
class FiniteTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(FiniteTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(FiniteModule)
+
def test_finite(self):
def finite(module):
module.finite(np.array([0.0, 1.2, -5.0, np.inf], dtype=np.float32))
- self.compare_backends(finite)
+ self.compare_backends(finite, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/gather_test.py b/integrations/tensorflow/e2e/gather_test.py
index fc0ec13..da71d4b 100644
--- a/integrations/tensorflow/e2e/gather_test.py
+++ b/integrations/tensorflow/e2e/gather_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -65,9 +66,12 @@
return tf.gather(params, indices, axis=1, batch_dims=1)
-@tf_test_utils.compile_module(GatherModule)
class GatherTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(GatherTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(GatherModule)
+
def test_gather_axis0_scalar(self):
def gather_axis0_scalar(module):
@@ -75,7 +79,7 @@
params = tf_utils.ndarange([4, 8])
module.gather_axis0_scalar(params, indices)
- self.compare_backends(gather_axis0_scalar)
+ self.compare_backends(gather_axis0_scalar, self._modules)
def test_gather_axis0_batch0(self):
@@ -84,7 +88,7 @@
params = tf_utils.ndarange([4, 8])
module.gather_axis0_batch0(params, indices)
- self.compare_backends(gather_axis0_batch0)
+ self.compare_backends(gather_axis0_batch0, self._modules)
def test_gather_axis1_batch0(self):
@@ -93,7 +97,7 @@
params = tf_utils.ndarange([4, 7, 8])
module.gather_axis1_batch0(params, indices)
- self.compare_backends(gather_axis1_batch0)
+ self.compare_backends(gather_axis1_batch0, self._modules)
def test_gather_axis2_batch1(self):
@@ -102,7 +106,7 @@
params = tf_utils.ndarange([4, 7, 8, 2])
module.gather_axis2_batch1(params, indices)
- self.compare_backends(gather_axis2_batch1)
+ self.compare_backends(gather_axis2_batch1, self._modules)
def test_gather_axis1_batch1(self):
@@ -111,7 +115,7 @@
params = tf_utils.ndarange([4, 7, 8, 2])
module.gather_axis1_batch1(params, indices)
- self.compare_backends(gather_axis1_batch1)
+ self.compare_backends(gather_axis1_batch1, self._modules)
def test_gather_axis2_batch2(self):
@@ -120,11 +124,16 @@
values = np.array([[0, 1, 2, 3], [9, 8, 7, 0]], dtype=np.int32)
module.gather_axis2_batch2(values, indices)
- self.compare_backends(gather_axis2_batch2)
+ self.compare_backends(gather_axis2_batch2, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_static_test.py b/integrations/tensorflow/e2e/keras/lstm_static_test.py
index db8f86e..5cb1827 100644
--- a/integrations/tensorflow/e2e/keras/lstm_static_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_static_test.py
@@ -16,6 +16,7 @@
# This test is the same as keras_lstm_test, but all shapes are static.
# This stresses the TensorList lowering more specifically.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -42,19 +43,28 @@
self.m.call)
-@tf_test_utils.compile_module(LstmStaticModule, exported_names=["predict"])
class LstmStaticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LstmStaticTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LstmStaticModule,
+ exported_names=["predict"])
+
def test_lstm(self):
def predict(module):
inputs = tf_utils.ndarange(INPUT_SHAPE)
module.predict(inputs, rtol=1e-5, atol=1e-5)
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/lstm_test.py b/integrations/tensorflow/e2e/keras/lstm_test.py
index 112e32a..747cc6f 100644
--- a/integrations/tensorflow/e2e/keras/lstm_test.py
+++ b/integrations/tensorflow/e2e/keras/lstm_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -31,28 +32,36 @@
super(LstmModule, self).__init__()
tf_utils.set_random_seed()
inputs = tf.keras.layers.Input(batch_size=None, shape=DYNAMIC_SHAPE[1:])
- outputs = tf.keras.layers.LSTM(
- units=NUM_UNITS, return_sequences=True)(
- inputs)
+ outputs = tf.keras.layers.LSTM(units=NUM_UNITS,
+ return_sequences=True)(inputs)
self.m = tf.keras.Model(inputs, outputs)
self.predict = tf.function(
- input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(
- self.m.call)
+ input_signature=[tf.TensorSpec(DYNAMIC_SHAPE, tf.float32)])(self.m.call)
-@tf_test_utils.compile_module(LstmModule, exported_names=["predict"])
class LstmTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LstmTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LstmModule,
+ exported_names=["predict"])
+
def test_lstm(self):
def predict(module):
inputs = tf_utils.ndarange(INPUT_SHAPE)
module.predict(inputs)
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
+
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/keras/train/model_train_test.py b/integrations/tensorflow/e2e/keras/train/model_train_test.py
index 75286ad..49a3949 100644
--- a/integrations/tensorflow/e2e/keras/train/model_train_test.py
+++ b/integrations/tensorflow/e2e/keras/train/model_train_test.py
@@ -75,8 +75,6 @@
return loss_value
-@tf_test_utils.compile_module(
- ModelTrain.CreateModule, exported_names=["train_step"])
class ModelTrainTest(tf_test_utils.TracedModuleTestCase):
def generate_regression_data(self, size=8):
@@ -104,11 +102,12 @@
# Run one iteration of training step.
module.train_step(inputs, targets)
- self.compare_backends(train_step)
+ self.compare_backends(train_step, self._modules)
if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
tf.enable_v2_behavior()
-
+ tf_test_utils.compile_tf_module(ModelTrain.CreateModule,
+ exported_names=["train_step"])
tf.test.main()
diff --git a/integrations/tensorflow/e2e/keras/vision_model_test.py b/integrations/tensorflow/e2e/keras/vision_model_test.py
index 10663ec..4e1ca64 100644
--- a/integrations/tensorflow/e2e/keras/vision_model_test.py
+++ b/integrations/tensorflow/e2e/keras/vision_model_test.py
@@ -113,8 +113,9 @@
# an external tf.keras URL.
weights = 'imagenet' if FLAGS.data == 'imagenet' else None
- model = APP_MODELS[FLAGS.model](
- weights=weights, include_top=FLAGS.include_top, input_shape=input_shape)
+ model = APP_MODELS[FLAGS.model](weights=weights,
+ include_top=FLAGS.include_top,
+ input_shape=input_shape)
if FLAGS.data == 'cifar10' and FLAGS.url:
model = load_cifar10_weights(model)
@@ -131,19 +132,22 @@
# TODO(b/142948097): Add support for dynamic shapes in SPIR-V lowering.
# Replace input_shape with m.input_shape to make the batch size dynamic.
self.predict = tf.function(
- input_signature=[tf.TensorSpec(get_input_shape())])(
- self.m.predict)
+ input_signature=[tf.TensorSpec(get_input_shape())])(self.m.predict)
-@tf_test_utils.compile_module(VisionModule, exported_names=['predict'])
class AppTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(AppTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(VisionModule,
+ exported_names=['predict'])
+
def test_application(self):
def predict(module):
module.predict(tf_utils.uniform(get_input_shape()))
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
def main(argv):
diff --git a/integrations/tensorflow/e2e/linspace_test.py b/integrations/tensorflow/e2e/linspace_test.py
index b535021..4acc170 100644
--- a/integrations/tensorflow/e2e/linspace_test.py
+++ b/integrations/tensorflow/e2e/linspace_test.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
-class LinSpaceModule(tf.Module):
+class LinspaceModule(tf.Module):
def __init__(self):
pass
@@ -33,9 +34,12 @@
return tf.linspace(start, stop, num)
-@tf_test_utils.compile_module(LinSpaceModule)
class LinspaceTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LinspaceTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LinspaceModule)
+
def test_linspace(self):
def linspace(module):
@@ -43,10 +47,15 @@
stop = np.array(12., dtype=np.float32)
module.linspace(start, stop)
- self.compare_backends(linspace)
+ self.compare_backends(linspace, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/logical_ops_test.py b/integrations/tensorflow/e2e/logical_ops_test.py
index ea3fd7a..ee83a95 100644
--- a/integrations/tensorflow/e2e/logical_ops_test.py
+++ b/integrations/tensorflow/e2e/logical_ops_test.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for ops in the tf.math module that specifically handle logical ops."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -46,9 +47,12 @@
return tf.math.logical_not(x)
-@tf_test_utils.compile_module(LogicalOpsModule)
class LogicalOpsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(LogicalOpsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(LogicalOpsModule)
+
def test_logical_and(self):
def logical_and(module):
@@ -56,7 +60,7 @@
np.array([1, 1, 0, 0], dtype=np.bool),
np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_and)
+ self.compare_backends(logical_and, self._modules)
def test_logical_or(self):
@@ -65,7 +69,7 @@
np.array([1, 1, 0, 0], dtype=np.bool),
np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_or)
+ self.compare_backends(logical_or, self._modules)
def test_logical_xor(self):
@@ -74,17 +78,22 @@
np.array([1, 1, 0, 0], dtype=np.bool),
np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_xor)
+ self.compare_backends(logical_xor, self._modules)
def test_logical_not(self):
def logical_not(module):
module.logical_not(np.array([0, 1, 1, 0], dtype=np.bool))
- self.compare_backends(logical_not)
+ self.compare_backends(logical_not, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/mandelbrot_test.py b/integrations/tensorflow/e2e/mandelbrot_test.py
index b5a3929..51c1509 100644
--- a/integrations/tensorflow/e2e/mandelbrot_test.py
+++ b/integrations/tensorflow/e2e/mandelbrot_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -90,9 +91,12 @@
return tf.reshape(in_the_set, shape=[view_pixels, view_pixels])
-@tf_test_utils.compile_module(MandelbrotModule)
class MandelbrotTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MandelbrotTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MandelbrotModule)
+
def test_mandelbrot(self):
def mandelbrot(module):
@@ -101,10 +105,15 @@
# This is a much more detailed view, so more iterations are needed.
module.calculate(-0.7436447860, 0.1318252536, 0.0000029336, 400, 3000)
- self.compare_backends(mandelbrot)
+ self.compare_backends(mandelbrot, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/math_test.py b/integrations/tensorflow/e2e/math_test.py
index 48bcbb6..a96193a 100644
--- a/integrations/tensorflow/e2e/math_test.py
+++ b/integrations/tensorflow/e2e/math_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Tests for ops in the tf.math module."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -42,46 +43,54 @@
return tf.math.mod(x, 2.0)
-@tf_test_utils.compile_module(MathModule)
class MathTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MathTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MathModule)
+
def test_abs(self):
def abs(module):
module.abs(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(abs)
+ self.compare_backends(abs, self._modules)
def test_ceil(self):
def ceil(module):
module.ceil(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(ceil)
+ self.compare_backends(ceil, self._modules)
def test_cos(self):
def cos(module):
module.cos(np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(cos)
+ self.compare_backends(cos, self._modules)
def test_log(self):
def log(module):
module.log(np.array([0.1, 0.2, 0.5, 1.0], dtype=np.float32))
- self.compare_backends(log)
+ self.compare_backends(log, self._modules)
def test_mod(self):
def mod(module):
module.mod(np.array([0.0, 1.2, 1.5, 3.75], dtype=np.float32))
- self.compare_backends(mod)
+ self.compare_backends(mod, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
index 9c87b7b..4a3daf7 100644
--- a/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_dynamic_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test matrix ops."""
+from absl import app
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -43,16 +44,19 @@
return tf.matmul(lhs, rhs)
-@tf_test_utils.compile_module(MatrixOpsDynamicModule)
class MatrixOpsDynamicTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MatrixOpsDynamicTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MatrixOpsDynamicModule)
+
def test_matmul_high_rank_batch(self):
def matmul_high_rank_batch(module):
module.matmul_high_rank_batch(
tf_utils.uniform([1, 7, 4, 2]), tf_utils.uniform([7, 1, 2, 4]))
- self.compare_backends(matmul_high_rank_batch)
+ self.compare_backends(matmul_high_rank_batch, self._modules)
def test_matmul_dynamic_matching_batch(self):
@@ -60,7 +64,7 @@
module.matmul_dynamic(
tf_utils.uniform([2, 2, 3]), tf_utils.uniform([2, 3, 4]))
- self.compare_backends(matmul_dynamic_matching_batch)
+ self.compare_backends(matmul_dynamic_matching_batch, self._modules)
def test_matmul_dynamic_broadcast_lhs(self):
@@ -68,7 +72,7 @@
module.matmul_dynamic(
tf_utils.uniform([1, 2, 3]), tf_utils.uniform([2, 3, 4]))
- self.compare_backends(matmul_dynamic_broadcast_lhs)
+ self.compare_backends(matmul_dynamic_broadcast_lhs, self._modules)
def test_matmul_dynamic_broadcast_rhs(self):
@@ -76,7 +80,7 @@
module.matmul_dynamic(
tf_utils.uniform([2, 2, 3]), tf_utils.uniform([1, 3, 4]))
- self.compare_backends(matmul_dynamic_broadcast_rhs)
+ self.compare_backends(matmul_dynamic_broadcast_rhs, self._modules)
def test_matmul_dynamic_rank_broadcasting(self):
@@ -84,10 +88,15 @@
module.matmul_dynamic_lhs_batch(
tf_utils.uniform([7, 2, 3]), tf_utils.uniform([3, 4]))
- self.compare_backends(matmul_dynamic_rank_broadcasting)
+ self.compare_backends(matmul_dynamic_rank_broadcasting, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/matrix_ops_static_test.py b/integrations/tensorflow/e2e/matrix_ops_static_test.py
index 82fa482..3a7fe47 100644
--- a/integrations/tensorflow/e2e/matrix_ops_static_test.py
+++ b/integrations/tensorflow/e2e/matrix_ops_static_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Test matrix ops."""
+from absl import app
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
import tensorflow.compat.v2 as tf
@@ -55,16 +56,19 @@
return tf.matmul(lhs, rhs)
-@tf_test_utils.compile_module(MatrixOpsStaticModule)
class MatrixOpsStaticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(MatrixOpsStaticTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MatrixOpsStaticModule)
+
def test_basic_matmul(self):
def basic_matmul(module):
module.basic_matmul(tf_utils.uniform([LEFT_DIM, INNER_DIM]),
tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
- self.compare_backends(basic_matmul)
+ self.compare_backends(basic_matmul, self._modules)
def test_matmul_lhs_batch(self):
@@ -73,7 +77,7 @@
tf_utils.uniform([BATCH_DIM, LEFT_DIM, INNER_DIM]),
tf_utils.uniform([INNER_DIM, RIGHT_DIM]))
- self.compare_backends(matmul_lhs_batch)
+ self.compare_backends(matmul_lhs_batch, self._modules)
def test_matmul_rhs_batch(self):
@@ -82,7 +86,7 @@
tf_utils.uniform([LEFT_DIM, INNER_DIM]),
tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
- self.compare_backends(matmul_rhs_batch)
+ self.compare_backends(matmul_rhs_batch, self._modules)
def test_matmul_broadcast_singleton_dimension(self):
@@ -91,10 +95,15 @@
tf_utils.uniform([1, LEFT_DIM, INNER_DIM]),
tf_utils.uniform([BATCH_DIM, INNER_DIM, RIGHT_DIM]))
- self.compare_backends(matmul_broadcast_singleton_dimension)
+ self.compare_backends(matmul_broadcast_singleton_dimension, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/mobile_bert_squad_test.py b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
index e8d3997..4f8daf5 100644
--- a/integrations/tensorflow/e2e/mobile_bert_squad_test.py
+++ b/integrations/tensorflow/e2e/mobile_bert_squad_test.py
@@ -81,10 +81,14 @@
return self.inference_func(**inputs)
-@tf_test_utils.compile_module(MobileBertSquad, exported_names=["predict"])
class MobileBertSquadTest(tf_test_utils.TracedModuleTestCase):
"""Tests of MobileBertSquad."""
+ def __init__(self, methodName="runTest"):
+ super(MobileBertSquadTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(MobileBertSquad,
+ exported_names=["predict"])
+
def test_predict(self):
def predict(module):
@@ -94,11 +98,15 @@
module.predict(input_ids, input_mask, segment_ids, atol=1e0)
- self.compare_backends(predict)
+ self.compare_backends(predict, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
-
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/range_test.py b/integrations/tensorflow/e2e/range_test.py
index f1e093c..f4e3e97 100644
--- a/integrations/tensorflow/e2e/range_test.py
+++ b/integrations/tensorflow/e2e/range_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -31,9 +32,12 @@
return tf.range(start, stop, delta)
-@tf_test_utils.compile_module(RangeModule)
class RangeTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(RangeTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(RangeModule)
+
def test_range(self):
def range(module):
@@ -42,10 +46,15 @@
delta = np.array(3, dtype=np.float32)
result = module.range(start, stop, delta)
- self.compare_backends(range)
+ self.compare_backends(range, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/resource_ops_test.py b/integrations/tensorflow/e2e/resource_ops_test.py
index dd5ad6d..d387e86 100644
--- a/integrations/tensorflow/e2e/resource_ops_test.py
+++ b/integrations/tensorflow/e2e/resource_ops_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -28,18 +29,26 @@
return self.counter.assign_add(value)
-@tf_test_utils.compile_module(ResourcesOpsModule)
class ResourcesOpsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ResourcesOpsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ResourcesOpsModule)
+
def test_add_assign(self):
def add_assign(module):
module.add_assign(np.array(9., dtype=np.float32))
- self.compare_backends(add_assign)
+ self.compare_backends(add_assign, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/ring_buffer_test.py b/integrations/tensorflow/e2e/ring_buffer_test.py
index 8e437ea..c56bf31 100644
--- a/integrations/tensorflow/e2e/ring_buffer_test.py
+++ b/integrations/tensorflow/e2e/ring_buffer_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -31,16 +32,20 @@
# buffer has size [buffer_size, dims]
# only the first dimension is used for updating buffer in a ring manner
- self._buffer = tf.Variable(
- tf.zeros((self._buffer_size,) + dims, dtype=dtype),
- trainable=False,
- name="RingBuffer")
+ self._buffer = tf.Variable(tf.zeros((self._buffer_size,) + dims,
+ dtype=dtype),
+ trainable=False,
+ name="RingBuffer")
# Size of the data available for reading
- self._data_size = tf.Variable(
- 0, trainable=False, dtype=tf.int32, name="FramerBuffer/Size")
+ self._data_size = tf.Variable(0,
+ trainable=False,
+ dtype=tf.int32,
+ name="FramerBuffer/Size")
# The index pointing to the head of the data available for reading
- self._read_head = tf.Variable(
- 0, trainable=False, dtype=tf.int32, name="FramerBuffer/Head")
+ self._read_head = tf.Variable(0,
+ trainable=False,
+ dtype=tf.int32,
+ name="FramerBuffer/Head")
@property
def dtype(self):
@@ -82,8 +87,8 @@
start = tf.math.floormod(
self._read_head.read_value() + self._data_size.read_value(),
self._buffer_size)
- indices = tf.math.floormod(
- tf.range(start, limit=start + elements_size), self._buffer_size)
+ indices = tf.math.floormod(tf.range(start, limit=start + elements_size),
+ self._buffer_size)
tf.compat.v1.scatter_update(self._buffer, indices, elements)
@@ -118,8 +123,8 @@
Tensor of elements with shape [length, dims...].
"""
start = self._read_head + offset
- indices = tf.math.floormod(
- tf.range(start, limit=start + length), self._buffer_size)
+ indices = tf.math.floormod(tf.range(start, limit=start + length),
+ self._buffer_size)
result = tf.gather(self._buffer, indices)
if consume:
self.consume(length, offset)
@@ -148,8 +153,9 @@
def build(self, input_shape):
super(StatefulRingBuffer, self).build(input_shape)
buffer_size = self.state_shape[1]
- self.rb = RingBuffer(
- buffer_size=buffer_size, dims=(self.state_shape[2],), dtype=tf.float32)
+ self.rb = RingBuffer(buffer_size=buffer_size,
+ dims=(self.state_shape[2],),
+ dtype=tf.float32)
def call(self, inputs):
self.rb.write(inputs)
@@ -177,10 +183,13 @@
return self.rb(x)
-@tf_test_utils.compile_module(
- StatefulRingBufferModule, exported_names=["predict"])
class StatefulRingBufferTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(StatefulRingBufferTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(StatefulRingBufferModule,
+ exported_names=["predict"])
+
def test_stateful_ringbuffer(self):
def stateful_ringbuffer(module):
@@ -198,10 +207,15 @@
module.predict(input3)
# output = np.array([[3.0, 4.0]], dtype=np.float32)
- self.compare_backends(stateful_ringbuffer)
+ self.compare_backends(stateful_ringbuffer, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/scatter_update_test.py b/integrations/tensorflow/e2e/scatter_update_test.py
index 8b43e3a..8f34002 100644
--- a/integrations/tensorflow/e2e/scatter_update_test.py
+++ b/integrations/tensorflow/e2e/scatter_update_test.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Test scatter update behavior for tensorflow."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
-"""Test scatter update behavior for tensorflow."""
class ScatterUpdateModule(tf.Module):
@@ -48,9 +49,12 @@
return tf.tensor_scatter_nd_update(tensor, indices, updates)
-@tf_test_utils.compile_module(ScatterUpdateModule)
class ScatterUpdateTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(ScatterUpdateTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(ScatterUpdateModule)
+
def test_scatter_update_1D(self):
def scatter_update_1D(module):
@@ -59,7 +63,7 @@
updates = np.array([9, 10, 11], dtype=np.int32)
module.scatter_update_1D(tensor, indices, updates)
- self.compare_backends(scatter_update_1D)
+ self.compare_backends(scatter_update_1D, self._modules)
def test_scatter_update_2D(self):
@@ -69,7 +73,7 @@
updates = np.array([2, 5, 8], dtype=np.int32)
module.scatter_update_2D(tensor, indices, updates)
- self.compare_backends(scatter_update_2D)
+ self.compare_backends(scatter_update_2D, self._modules)
def test_scatter_update_2D_slice(self):
@@ -79,10 +83,15 @@
updates = np.array([[2, 3, 4]], dtype=np.int32)
module.scatter_update_2D_slice(tensor, indices, updates)
- self.compare_backends(scatter_update_2D_slice)
+ self.compare_backends(scatter_update_2D_slice, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_arithmetic_test.py b/integrations/tensorflow/e2e/simple_arithmetic_test.py
index aaec578..cc9ca09 100644
--- a/integrations/tensorflow/e2e/simple_arithmetic_test.py
+++ b/integrations/tensorflow/e2e/simple_arithmetic_test.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Several baseline e2e simple arithmetic tests."""
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -37,9 +38,12 @@
return tf.matmul(a, b)
-@tf_test_utils.compile_module(SimpleArithmeticModule)
class SimpleArithmeticTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(SimpleArithmeticTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SimpleArithmeticModule)
+
def test_simple_mul(self):
def simple_mul(module):
@@ -48,7 +52,7 @@
c = module.simple_mul(a, b)
module.simple_mul(a, c)
- self.compare_backends(simple_mul)
+ self.compare_backends(simple_mul, self._modules)
def test_simple_matmul(self):
@@ -58,10 +62,15 @@
b = tf_utils.uniform((3072, 256)) * 1e-3
module.simple_matmul(a, b)
- self.compare_backends(simple_matmul)
+ self.compare_backends(simple_matmul, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/simple_stateful_test.py b/integrations/tensorflow/e2e/simple_stateful_test.py
index 1120a4f..e4140d4 100644
--- a/integrations/tensorflow/e2e/simple_stateful_test.py
+++ b/integrations/tensorflow/e2e/simple_stateful_test.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -33,19 +34,27 @@
return self.counter
-@tf_test_utils.compile_module(SimpleStatefulModule)
class StatefulTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(StatefulTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SimpleStatefulModule)
+
def test_stateful(self):
def get_state(module):
module.inc_by(np.array(1., dtype=np.float32))
module.get_state()
- self.compare_backends(get_state)
+ self.compare_backends(get_state, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/sliding_window_test.py b/integrations/tensorflow/e2e/sliding_window_test.py
index 513aa97..413ffb2 100644
--- a/integrations/tensorflow/e2e/sliding_window_test.py
+++ b/integrations/tensorflow/e2e/sliding_window_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import tensorflow.compat.v2 as tf
@@ -75,9 +76,13 @@
return self.sw(x)
-@tf_test_utils.compile_module(SlidingWindowModule, exported_names=["predict"])
class SlidingWindowTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(SlidingWindowTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(SlidingWindowModule,
+ exported_names=["predict"])
+
def test_sliding_window(self):
def sliding_window(module):
@@ -89,10 +94,15 @@
result2 = module.predict(input2)
# output2 = np.array([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
- self.compare_backends(sliding_window)
+ self.compare_backends(sliding_window, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/strings_test.py b/integrations/tensorflow/e2e/strings_test.py
index 206b33c..70f6468 100644
--- a/integrations/tensorflow/e2e/strings_test.py
+++ b/integrations/tensorflow/e2e/strings_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
import string
@@ -40,9 +41,12 @@
return tf.strings.reduce_join(wps, 1)
-@tf_test_utils.compile_module(StringsModule)
class StringsTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(StringsTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(StringsModule)
+
def test_print_ids(self):
def print_ids(module):
@@ -51,7 +55,7 @@
[13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
module.print_ids(input_ids)
- self.compare_backends(print_ids)
+ self.compare_backends(print_ids, self._modules)
def test_strings_to_ids(self):
@@ -61,10 +65,15 @@
[13, 24, 16, 28, 94, 15, 24, 27, 94, 28, 29, 10, 34]])
module.strings_to_ids(input_ids)
- self.compare_backends(strings_to_ids)
+ self.compare_backends(strings_to_ids, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/integrations/tensorflow/e2e/tensorlist_test.py b/integrations/tensorflow/e2e/tensorlist_test.py
index 440bd43..62babc5 100644
--- a/integrations/tensorflow/e2e/tensorlist_test.py
+++ b/integrations/tensorflow/e2e/tensorlist_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from absl import app
import numpy as np
from pyiree.tf.support import tf_test_utils
from pyiree.tf.support import tf_utils
@@ -50,8 +51,9 @@
@tf.function(
input_signature=[tf.TensorSpec([STATIC_SIZE, STATIC_SIZE], tf.float32)])
def slice_first_element_with_from_tensor_high_rank(self, t):
- ta = tf.TensorArray(
- dtype=tf.float32, size=STATIC_SIZE, element_shape=[STATIC_SIZE])
+ ta = tf.TensorArray(dtype=tf.float32,
+ size=STATIC_SIZE,
+ element_shape=[STATIC_SIZE])
ta = ta.unstack(t)
return ta.read(0)
@@ -66,23 +68,26 @@
return ta.stack()
-@tf_test_utils.compile_module(TensorListModule)
class TensorListTest(tf_test_utils.TracedModuleTestCase):
+ def __init__(self, methodName="runTest"):
+ super(TensorListTest, self).__init__(methodName)
+ self._modules = tf_test_utils.compile_tf_module(TensorListModule)
+
def test_identity_through_tensorlist(self):
def identity_through_tensorlist(module):
module.identity_through_tensorlist(np.array(42., dtype=np.float32))
- self.compare_backends(identity_through_tensorlist)
+ self.compare_backends(identity_through_tensorlist, self._modules)
def test_add_through_tensorlist(self):
def add_through_tensorlist(module):
- module.add_through_tensorlist(
- np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+ module.add_through_tensorlist(np.array(42., dtype=np.float32),
+ np.array(43., dtype=np.float32))
- self.compare_backends(add_through_tensorlist)
+ self.compare_backends(add_through_tensorlist, self._modules)
def test_slice_first_element_with_from_tensor(self):
@@ -90,7 +95,7 @@
module.slice_first_element_with_from_tensor(
np.arange(STATIC_SIZE, dtype=np.float32))
- self.compare_backends(slice_first_element_with_from_tensor)
+ self.compare_backends(slice_first_element_with_from_tensor, self._modules)
def test_slice_first_element_with_from_tensor_high_rank(self):
@@ -98,18 +103,24 @@
module.slice_first_element_with_from_tensor_high_rank(
tf_utils.ndarange([STATIC_SIZE, STATIC_SIZE]))
- self.compare_backends(slice_first_element_with_from_tensor_high_rank)
+ self.compare_backends(slice_first_element_with_from_tensor_high_rank,
+ self._modules)
def test_concat_with_tensorlist_stack(self):
def concat_with_tensorlist_stack(module):
- module.concat_with_tensorlist_stack(
- np.array(42., dtype=np.float32), np.array(43., dtype=np.float32))
+ module.concat_with_tensorlist_stack(np.array(42., dtype=np.float32),
+ np.array(43., dtype=np.float32))
- self.compare_backends(concat_with_tensorlist_stack)
+ self.compare_backends(concat_with_tensorlist_stack, self._modules)
-if __name__ == "__main__":
- if hasattr(tf, "enable_v2_behavior"):
+def main(argv):
+ del argv # Unused
+ if hasattr(tf, 'enable_v2_behavior'):
tf.enable_v2_behavior()
tf.test.main()
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/iree/base/tracing.cc b/iree/base/tracing.cc
index 7049d39..15301ad 100644
--- a/iree/base/tracing.cc
+++ b/iree/base/tracing.cc
@@ -68,12 +68,13 @@
TracyLfqCommitC;
}
#endif // TRACY_NO_VERIFY
- auto name_ptr =
- reinterpret_cast<char*>(tracy::tracy_malloc(name_length + 1));
+ auto name_ptr = reinterpret_cast<char*>(tracy::tracy_malloc(name_length));
memcpy(name_ptr, name, name_length);
- name_ptr[name_length] = '\0';
TracyLfqPrepareC(tracy::QueueType::ZoneName);
- tracy::MemWrite(&item->zoneText.text, reinterpret_cast<uint64_t>(name_ptr));
+ tracy::MemWrite(&item->zoneTextFat.text,
+ reinterpret_cast<uint64_t>(name_ptr));
+ tracy::MemWrite(&item->zoneTextFat.size,
+ static_cast<uint64_t>(name_length));
TracyLfqCommitC;
}
@@ -84,22 +85,9 @@
const char* file_name, size_t file_name_length, uint32_t line,
const char* function_name, size_t function_name_length, const char* name,
size_t name_length) {
- // NOTE: cloned from tracy::Profiler::AllocSourceLocation so that we can use
- // the string lengths we already have.
- const uint32_t src_loc_length =
- static_cast<uint32_t>(4 + 4 + 4 + function_name_length + 1 +
- file_name_length + 1 + name_length);
- auto ptr = reinterpret_cast<char*>(tracy::tracy_malloc(src_loc_length));
- memcpy(ptr, &src_loc_length, 4);
- memset(ptr + 4, 0, 4);
- memcpy(ptr + 8, &line, 4);
- memcpy(ptr + 12, function_name, function_name_length + 1);
- memcpy(ptr + 12 + function_name_length + 1, file_name, file_name_length + 1);
- if (name_length) {
- memcpy(ptr + 12 + function_name_length + 1 + file_name_length + 1, name,
- name_length);
- }
- uint64_t src_loc = reinterpret_cast<uint64_t>(ptr);
+ uint64_t src_loc = tracy::Profiler::AllocSourceLocation(
+ line, file_name, file_name_length, function_name, function_name_length,
+ name, name_length);
const iree_zone_id_t zone_id = tracy::GetProfiler().GetNextZoneId();
@@ -152,3 +140,19 @@
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
+
+#if defined(__cplusplus) && \
+ (IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING)
+
+void* operator new(size_t count) noexcept {
+ auto ptr = malloc(count);
+ IREE_TRACE_ALLOC(ptr, count);
+ return ptr;
+}
+
+void operator delete(void* ptr) noexcept {
+ IREE_TRACE_FREE(ptr);
+ free(ptr);
+}
+
+#endif // __cplusplus && IREE_TRACING_FEATURE_ALLOCATION_TRACKING
diff --git a/iree/base/tracing.h b/iree/base/tracing.h
index 03de9b6..d8026a9 100644
--- a/iree/base/tracing.h
+++ b/iree/base/tracing.h
@@ -355,14 +355,14 @@
#define IREE_TRACE_ALLOC(ptr, size) \
___tracy_emit_memory_alloc_callstack(ptr, size, \
- IREE_TRACING_MAX_CALLSTACK_DEPTH)
+ IREE_TRACING_MAX_CALLSTACK_DEPTH, 0)
#define IREE_TRACE_FREE(ptr) \
- ___tracy_emit_memory_free_callstack(ptr, IREE_TRACING_MAX_CALLSTACK_DEPTH)
+ ___tracy_emit_memory_free_callstack(ptr, IREE_TRACING_MAX_CALLSTACK_DEPTH, 0)
#else
-#define IREE_TRACE_ALLOC(ptr, size) ___tracy_emit_memory_alloc(ptr, size)
-#define IREE_TRACE_FREE(ptr) ___tracy_emit_memory_free(ptr)
+#define IREE_TRACE_ALLOC(ptr, size) ___tracy_emit_memory_alloc(ptr, size, 0)
+#define IREE_TRACE_FREE(ptr) ___tracy_emit_memory_free(ptr, 0)
#endif // IREE_TRACING_FEATURE_ALLOCATION_CALLSTACKS
@@ -371,24 +371,11 @@
#define IREE_TRACE_FREE(ptr)
#endif // IREE_TRACING_FEATURE_ALLOCATION_TRACKING
-#ifdef __cplusplus
-
-#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING
-
-inline void* operator new(size_t count) {
- auto ptr = malloc(count);
- IREE_TRACE_ALLOC(ptr, count);
- return ptr;
-}
-
-inline void operator delete(void* ptr) noexcept {
- IREE_TRACE_FREE(ptr);
- free(ptr);
-}
-
-#endif // IREE_TRACING_FEATURE_ALLOCATION_TRACKING
-
-#endif // __cplusplus
+#if defined(__cplusplus) && \
+ (IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING)
+void* operator new(size_t count) noexcept;
+void operator delete(void* ptr) noexcept;
+#endif // __cplusplus && IREE_TRACING_FEATURE_ALLOCATION_TRACKING
//===----------------------------------------------------------------------===//
// Instrumentation C++ RAII types, wrappers, and macros
diff --git a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
index 8123574..e6fe664 100644
--- a/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/HLOToHLOPreprocessing.cpp
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <numeric>
+
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
@@ -56,6 +58,13 @@
[](APInt v) -> bool { return !v.isNullValue(); });
}
+static DenseIntElementsAttr make1DElementsAttr(PatternRewriter &rewriter,
+ ArrayRef<int64_t> integers) {
+ auto type = RankedTensorType::get({static_cast<int64_t>(integers.size())},
+ rewriter.getIntegerType(64));
+ return DenseIntElementsAttr::get(type, integers);
+}
+
class DecomposeLog1PPattern : public OpRewritePattern<mhlo::Log1pOp> {
public:
using OpRewritePattern<mhlo::Log1pOp>::OpRewritePattern;
@@ -324,6 +333,240 @@
}
};
+// Rewrites rank-3 mhlo.dot_general so lhs contraction dimension is
+// inner most (2) and rhs contraction dimension is dim right after batch
+// dimension. The pattern inserts transposes so the dot_general always has the
+// form: {batch_dim, parallel, contraction}.{batch_dim, contraction, parallel}
+class TransposeRank3GenericDotGeneral
+ : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+ using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+ PatternRewriter &rewriter) const override {
+ auto lhsShapeType = op.lhs().getType().dyn_cast<RankedTensorType>();
+ auto rhsShapeType = op.rhs().getType().dyn_cast<RankedTensorType>();
+ auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
+
+ if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+ if (resultType.getRank() != 3) return failure();
+
+ if (op.dot_dimension_numbers().lhs_contracting_dimensions().size() != 1 ||
+ op.dot_dimension_numbers().rhs_contracting_dimensions().size() != 1)
+ return failure();
+
+ int64_t lhsBatchDim = (*op.dot_dimension_numbers()
+ .lhs_batching_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ int64_t rhsBatchDim = (*op.dot_dimension_numbers()
+ .rhs_batching_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ int64_t lhsContractionDim = (*op.dot_dimension_numbers()
+ .lhs_contracting_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ int64_t rhsContractionDim = (*op.dot_dimension_numbers()
+ .rhs_contracting_dimensions()
+ .int_value_begin())
+ .getSExtValue();
+ // Only accept rank-3 tensors whose batch dimension is dim 0, e.g.:
+ // lhs : {batch_dim, contraction, parallel}
+ // rhs : {batch_dim, parallel, contraction}
+ if (lhsBatchDim != 0 || rhsBatchDim != 0) return failure();
+ // No transposes are needed.
+ if (lhsContractionDim == 2 && rhsContractionDim == 1) return failure();
+
+ Value lhs = op.lhs(), rhs = op.rhs();
+
+ // transpose {batch_dim, contraction, parallel} case.
+ if (lhsContractionDim == 1) {
+ Type transposedType = RankedTensorType::get(
+ {lhsShapeType.getDimSize(0), lhsShapeType.getDimSize(2),
+ lhsShapeType.getDimSize(1)},
+ resultType.getElementType());
+ lhs = rewriter.create<mhlo::TransposeOp>(
+ op.getLoc(), transposedType, lhs,
+ make1DElementsAttr(rewriter, {0, 2, 1}));
+ }
+
+ // transpose {batch_dim, parallel, contraction} case.
+ if (rhsContractionDim == 2) {
+ Type transposedType = RankedTensorType::get(
+ {rhsShapeType.getDimSize(0), rhsShapeType.getDimSize(2),
+ rhsShapeType.getDimSize(1)},
+ resultType.getElementType());
+ rhs = rewriter.create<mhlo::TransposeOp>(
+ op.getLoc(), transposedType, rhs,
+ make1DElementsAttr(rewriter, {0, 2, 1}));
+ }
+
+ auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+ /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*lhs_contracting_dimensions=*/make1DElementsAttr(rewriter, {2}),
+ /*rhs_contracting_dimensions=*/
+ make1DElementsAttr(rewriter, {1}), rewriter.getContext());
+
+ Value result = rewriter.create<mhlo::DotGeneralOp>(
+ op.getLoc(), op.getType(), lhs, rhs, dimensionNumbers,
+ op.precision_configAttr());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
+// Rewrite mhlo.dot_general to operate on rank-3 tensors when reduction dims are
+// in consecutive order and not splitting the domain. This pattern inserts
+// reshapes to collapse consecutive reduction and parallel dims to always
+// generate a rank-3 dot_general op.
+class RankReducedDotGeneral : public OpRewritePattern<mhlo::DotGeneralOp> {
+ public:
+ using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
+ PatternRewriter &rewriter) const override {
+ auto lhsShapeType = op.lhs().getType().dyn_cast<ShapedType>();
+ auto rhsShapeType = op.rhs().getType().dyn_cast<ShapedType>();
+ auto resultType = op.getResult().getType().dyn_cast<ShapedType>();
+
+ if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
+ if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
+ return failure();
+ if (resultType.getRank() <= 3) return failure();
+
+ mhlo::DotDimensionNumbers dimNumbers = op.dot_dimension_numbers();
+ auto lhsBatchingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.lhs_batching_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+ auto rhsBatchingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.rhs_batching_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+ auto lhsContractingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.lhs_contracting_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+ auto rhsContractingDims = llvm::to_vector<4>(
+ llvm::map_range(dimNumbers.rhs_contracting_dimensions(),
+ [](APInt v) { return v.getSExtValue(); }));
+
+ if (lhsBatchingDims.empty() || rhsBatchingDims.empty()) return failure();
+
+ llvm::sort(lhsBatchingDims);
+ llvm::sort(lhsContractingDims);
+ llvm::sort(rhsBatchingDims);
+ llvm::sort(rhsContractingDims);
+
+ auto isConsecutive = [](ArrayRef<int64_t> array) {
+ for (int i = 1; i < array.size(); ++i) {
+ if (array[i] - array[i - 1] != 1) return false;
+ }
+ return true;
+ };
+
+ auto isDomainSplit = [](ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> batchingDims,
+ ArrayRef<int64_t> contractingDims) {
+ // Batching and contracting are contiguous.
+ if ((contractingDims.front() - batchingDims.back()) == 1) return false;
+ // Contracting dims are inner most.
+ if (contractingDims.back() == (shape.size() - 1)) return false;
+ return true;
+ };
+
+ if (!isConsecutive(lhsBatchingDims) || !isConsecutive(lhsContractingDims) ||
+ !isConsecutive(rhsBatchingDims) || !isConsecutive(rhsContractingDims))
+ return failure();
+
+ if (isDomainSplit(lhsShapeType.getShape(), lhsBatchingDims,
+ lhsContractingDims) ||
+ isDomainSplit(rhsShapeType.getShape(), rhsBatchingDims,
+ rhsContractingDims))
+ return failure();
+
+ // Collapsing shape into a rank-3 tensor, returns the new collapsed shape,
+ // contraction and parallel dim indices.
+ auto computeCollapsedShape = [](ArrayRef<int64_t> shape,
+ ArrayRef<int64_t> batchingDims,
+ ArrayRef<int64_t> contractingDims) {
+ auto newRank =
+ shape.size() - batchingDims.size() - contractingDims.size() + 2;
+ auto batchingSize = std::accumulate(
+ batchingDims.begin(), batchingDims.end(), 1,
+ [shape](const int64_t accum, const int64_t index) -> int64_t {
+ return accum * shape[index];
+ });
+ auto contractingSize = std::accumulate(
+ contractingDims.begin(), contractingDims.end(), 1,
+ [shape](const int64_t accum, const int64_t index) -> int64_t {
+ return accum * shape[index];
+ });
+
+ int parallelDimIndex, contractingDimIndex, parallelDimSize = 1;
+ if (contractingDims.front() - batchingDims.back() > 1) {
+ parallelDimIndex = 1;
+ contractingDimIndex = 2;
+ for (int i = batchingDims.back() + 1; i < contractingDims.front();
+ ++i) {
+ parallelDimSize *= shape[i];
+ }
+ } else {
+ contractingDimIndex = 1;
+ parallelDimIndex = 2;
+ for (int i = contractingDims.back() + 1; i < shape.size(); ++i) {
+ parallelDimSize *= shape[i];
+ }
+ }
+ llvm::SmallVector<int64_t, 4> newShape(newRank);
+ newShape[0] = batchingSize;
+ newShape[contractingDimIndex] = contractingSize;
+ newShape[parallelDimIndex] = parallelDimSize;
+ return std::make_tuple(newShape, contractingDimIndex, parallelDimIndex);
+ };
+
+ int lhsContractingDimIndex, rhsContractingDimIndex, lhsParallelDimIndex,
+ rhsParallelDimIndex;
+ SmallVector<int64_t, 4> lhsNewShape, rhsNewShape;
+ std::tie(lhsNewShape, lhsContractingDimIndex, lhsParallelDimIndex) =
+ computeCollapsedShape(lhsShapeType.getShape(), lhsBatchingDims,
+ lhsContractingDims);
+
+ std::tie(rhsNewShape, rhsContractingDimIndex, rhsParallelDimIndex) =
+ computeCollapsedShape(rhsShapeType.getShape(), rhsBatchingDims,
+ rhsContractingDims);
+ SmallVector<int64_t, 4> resultNewShape = {lhsNewShape[0],
+ lhsNewShape[lhsParallelDimIndex],
+ rhsNewShape[rhsParallelDimIndex]};
+ Type dotGeneralResultType =
+ RankedTensorType::get(resultNewShape, resultType.getElementType());
+
+ auto loc = op.getLoc();
+ Value reshapedLhs = rewriter.create<mhlo::ReshapeOp>(
+ loc, RankedTensorType::get(lhsNewShape, lhsShapeType.getElementType()),
+ op.lhs());
+ Value reshapedRhs = rewriter.create<mhlo::ReshapeOp>(
+ loc, RankedTensorType::get(rhsNewShape, rhsShapeType.getElementType()),
+ op.rhs());
+ auto dimensionNumbers = mhlo::DotDimensionNumbers::get(
+ /*lhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*rhs_batching_dimensions=*/make1DElementsAttr(rewriter, {0}),
+ /*lhs_contracting_dimensions=*/
+ make1DElementsAttr(rewriter, {lhsContractingDimIndex}),
+ /*rhs_contracting_dimensions=*/
+ make1DElementsAttr(rewriter, {rhsContractingDimIndex}),
+ rewriter.getContext());
+ Value dotGeneralResult = rewriter.create<mhlo::DotGeneralOp>(
+ loc, dotGeneralResultType, reshapedLhs, reshapedRhs, dimensionNumbers,
+ op.precision_configAttr());
+
+ Value result =
+ rewriter.create<mhlo::ReshapeOp>(loc, resultType, dotGeneralResult);
+ rewriter.replaceOp(op, result);
+
+ return success();
+ }
+}; // namespace
+
// clang-format off
//
// Reorder BroadcastInDimOp and N-ary elementwise op.
@@ -425,6 +668,11 @@
AdjustDepthwiseFilterShape, DecomposeLog1PPattern,
DecomposeExpM1Pattern>(context);
+ // dot_general canonicalization patterns.
+ mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, context);
+ patterns.insert<RankReducedDotGeneral, TransposeRank3GenericDotGeneral>(
+ context);
+
// Unary elementwise op.
patterns.insert<
ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>,
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir
new file mode 100644
index 0000000..67b8f46
--- /dev/null
+++ b/iree/compiler/Dialect/Flow/Transforms/test/hlo_to_hlo_preprocessing_canoncalize_dot_general.mlir
@@ -0,0 +1,104 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-flow-hlo-to-hlo-preprocessing %s | IreeFileCheck %s
+
+func @dot_general_to_dot(%arg0: tensor<1x32x128x4xf32>, %arg1: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<> : tensor<0xi64>,
+ lhs_contracting_dimensions = dense<[2, 3]> : tensor<2xi64>,
+ rhs_batching_dimensions = dense<> : tensor<0xi64>,
+ rhs_contracting_dimensions = dense<[0, 1]> : tensor<2xi64>
+ }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x32x128x4xf32>, tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+ return %0 : tensor<1x32x8x64xf32>
+}
+
+// CHECK: dot_general_to_dot(%[[ARG0:.+]]: tensor<1x32x128x4xf32>, %[[ARG1:.+]]: tensor<128x4x8x64xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x32x128x4xf32>) -> tensor<32x512xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<128x4x8x64xf32>) -> tensor<512x512xf32>
+// CHECK: %[[DOT:.+]] = "mhlo.dot"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT]]) : (tensor<32x512xf32>) -> tensor<1x32x8x64xf32>
+// CHECK: return %[[RESULT]] : tensor<1x32x8x64xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced(%arg0: tensor<1x8x32x64xf32>, %arg1 : tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+ }, name = "dot_general_to_dot", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x32x64xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED]], %[[ARG1_RESHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_a_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<2> : tensor<1xi64>
+ }, name = "dot_general_to_dot_trans_a", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x64x32xf32>, tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_a_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x64x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_RESHAPED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_RSSHAPED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG0_RESHAPED_TR:.+]] = "mhlo.transpose"(%[[ARG0_RESHAPED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_RESHAPED_TR]], %[[ARG1_RSSHAPED]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_b_transposed(%arg0: tensor<1x8x32x64xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<3> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+ }, name = "dot_general_to_dot_trans_b", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x32x64xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_b_transposed(%[[ARG0:.+]]: tensor<1x8x32x64xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_REHSPAED_TR:.+]] = "mhlo.transpose"(%[[ARG1_REHSPAED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_REHSPAED]], %[[ARG1_REHSPAED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
+
+
+// -----
+
+func @dot_general_to_dot_general_rank_reduced_ab_transposed(%arg0: tensor<1x8x64x32xf32>, %arg1: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32> {
+ %0 = "mhlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = {
+ lhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
+ rhs_batching_dimensions = dense<[0, 1]> : tensor<2xi64>,
+ rhs_contracting_dimensions = dense<3> : tensor<1xi64>
+ }, name = "dot_general_to_dot_trans_ab", precision_config = ["DEFAULT", "DEFAULT"]
+ } : (tensor<1x8x64x32xf32>, tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+ return %0 : tensor<1x8x32x32xf32>
+}
+// CHECK: dot_general_to_dot_general_rank_reduced_ab_transposed(%[[ARG0:.+]]: tensor<1x8x64x32xf32>, %[[ARG1:.+]]: tensor<1x8x32x64xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: %[[ARG0_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG0]]) : (tensor<1x8x64x32xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[ARG1_REHSPAED:.+]] = "mhlo.reshape"(%[[ARG1]]) : (tensor<1x8x32x64xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG0_REHSPAED_TR:.+]] = "mhlo.transpose"(%[[ARG0_REHSPAED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x64x32xf32>) -> tensor<8x32x64xf32>
+// CHECK: %[[ARG1_REHSPAED_TR:.+]] = "mhlo.transpose"(%[[ARG1_REHSPAED]]) {permutation = dense<[0, 2, 1]> : tensor<3xi64>} : (tensor<8x32x64xf32>) -> tensor<8x64x32xf32>
+// CHECK: %[[DOT_RESULT:.+]] = "mhlo.dot_general"(%[[ARG0_REHSPAED_TR]], %[[ARG1_REHSPAED_TR]])
+// CHECK: %[[RESULT:.+]] = "mhlo.reshape"(%[[DOT_RESULT]]) : (tensor<8x32x32xf32>) -> tensor<1x8x32x32xf32>
+// CHECK: return %[[RESULT]] : tensor<1x8x32x32xf32>
diff --git a/iree/test/e2e/xla_ops/BUILD b/iree/test/e2e/xla_ops/BUILD
index 55f8428..2bbd6b4 100644
--- a/iree/test/e2e/xla_ops/BUILD
+++ b/iree/test/e2e/xla_ops/BUILD
@@ -52,6 +52,7 @@
"cosine.mlir",
"divide.mlir",
"dot.mlir",
+ "dot_general.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"floor.mlir",
@@ -104,6 +105,7 @@
"cosine.mlir",
"divide.mlir",
"dot.mlir",
+ "dot_general.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"floor.mlir",
diff --git a/iree/test/e2e/xla_ops/CMakeLists.txt b/iree/test/e2e/xla_ops/CMakeLists.txt
index ae2a8c1..4bd7219 100644
--- a/iree/test/e2e/xla_ops/CMakeLists.txt
+++ b/iree/test/e2e/xla_ops/CMakeLists.txt
@@ -45,6 +45,7 @@
"cosine.mlir"
"divide.mlir"
"dot.mlir"
+ "dot_general.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"
@@ -97,6 +98,7 @@
"cosine.mlir"
"divide.mlir"
"dot.mlir"
+ "dot_general.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"floor.mlir"
diff --git a/iree/testing/vulkan/BUILD b/iree/testing/vulkan/BUILD
index bf3fdce..4e235aa 100644
--- a/iree/testing/vulkan/BUILD
+++ b/iree/testing/vulkan/BUILD
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+load(
+ "//iree:build_defs.oss.bzl",
+ "PLATFORM_VULKAN_DEPS",
+)
+
package(
default_visibility = ["//visibility:public"],
features = ["layering_check"],
@@ -37,3 +42,33 @@
"@vulkan_sdk//:sdk",
],
)
+
+cc_binary(
+ name = "iree-run-module-vulkan-gui",
+ srcs = ["iree-run-module-vulkan-gui-main.cc"],
+ linkopts = select({
+ "@bazel_tools//src/conditions:windows": [
+ "-SUBSYSTEM:WINDOWS",
+ ],
+ "//conditions:default": [],
+ }),
+ tags = [
+ "manual",
+ "nokokoro",
+ ],
+ deps = [
+ ":vulkan_gui_util",
+ "//iree/base:init",
+ "//iree/base:main",
+ "//iree/base:status",
+ "//iree/base:tracing",
+ "//iree/modules/hal",
+ "//iree/hal/vulkan:vulkan_driver_module",
+ "//iree/tools/utils:vm_util",
+ "//iree/vm",
+ "//iree/vm:bytecode_module",
+ "//iree/vm:ref_cc",
+ "@com_google_absl//absl/flags:flag",
+ "@sdl2//:SDL2",
+ ] + PLATFORM_VULKAN_DEPS,
+)
diff --git a/iree/testing/vulkan/CMakeLists.txt b/iree/testing/vulkan/CMakeLists.txt
index 2c995a9..6ceb5dc 100644
--- a/iree/testing/vulkan/CMakeLists.txt
+++ b/iree/testing/vulkan/CMakeLists.txt
@@ -44,3 +44,31 @@
SDL2-static
Vulkan::Vulkan
)
+
+if(${CMAKE_HOST_SYSTEM_NAME} STREQUAL "Windows")
+ set(_GUI_LINKOPTS "-SUBSYSTEM:WINDOWS")
+else()
+ set(_GUI_LINKOPTS "")
+endif()
+
+iree_cc_binary(
+ NAME
+ iree-run-module-vulkan-gui
+ SRCS
+ "iree-run-module-vulkan-gui-main.cc"
+ DEPS
+ ::vulkan_gui_util
+ absl::flags
+ iree::base::file_io
+ iree::base::init
+ iree::base::main
+ iree::base::status
+ iree::base::tracing
+ iree::hal::vulkan::vulkan_driver_module
+ iree::modules::hal
+ iree::tools::utils::vm_util
+ iree::vm
+ iree::vm::bytecode_module
+ LINKOPTS
+ "${_GUI_LINKOPTS}"
+)
diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
new file mode 100644
index 0000000..ec1349b
--- /dev/null
+++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
@@ -0,0 +1,454 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Vulkan GUI utility functions
+// Order matters here: we need to pull this in first to make sure Vulkan API
+// prototypes are defined so that we can statically link against them.
+#include "iree/testing/vulkan/vulkan_gui_util.h"
+
+// Other dependencies (helpers, etc.)
+#include "absl/flags/flag.h"
+#include "iree/base/file_io.h"
+#include "iree/base/init.h"
+#include "iree/base/main.h"
+#include "iree/base/status.h"
+#include "iree/modules/hal/hal_module.h"
+#include "iree/tools/utils/vm_util.h"
+#include "iree/vm/api.h"
+#include "iree/vm/bytecode_module.h"
+
+ABSL_FLAG(std::string, module_file, "-",
+ "File containing the module to load that contains the entry "
+ "function. Defaults to stdin.");
+
+ABSL_FLAG(std::string, entry_function, "",
+ "Name of a function contained in the module specified by input_file "
+ "to run.");
+
+ABSL_FLAG(std::vector<std::string>, inputs, {},
+ "A comma-separated list of of input buffers of the format:"
+ "[shape]xtype=[value]\n"
+ "2x2xi32=1 2 3 4\n"
+ "Optionally, brackets may be used to separate the element values. "
+ "They are ignored by the parser.\n"
+ "2x2xi32=[[1 2][3 4]]\n"
+ "Due to the absence of repeated flags in absl, commas should not be "
+ "used to separate elements. They are reserved for separating input "
+ "values:\n"
+ "2x2xi32=[[1 2][3 4]], 1x2xf32=[[1 2]]");
+
+ABSL_FLAG(std::string, inputs_file, "",
+ "Provides a file for input shapes and optional values (see "
+ "ParseToVariantListFromFile in vm_util.h for details)");
+
+static VkAllocationCallbacks* g_Allocator = NULL;
+static VkInstance g_Instance = VK_NULL_HANDLE;
+static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
+static VkDevice g_Device = VK_NULL_HANDLE;
+static uint32_t g_QueueFamily = (uint32_t)-1;
+static VkQueue g_Queue = VK_NULL_HANDLE;
+static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
+static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
+
+static ImGui_ImplVulkanH_Window g_MainWindowData;
+static uint32_t g_MinImageCount = 2;
+static bool g_SwapChainRebuild = false;
+static int g_SwapChainResizeWidth = 0;
+static int g_SwapChainResizeHeight = 0;
+
+namespace iree {
+namespace {
+
+void check_vk_result(VkResult err) {
+ if (err == 0) return;
+ IREE_LOG(FATAL) << "VkResult: " << err;
+}
+
+void CleanupVulkan() {
+ vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
+
+ vkDestroyDevice(g_Device, g_Allocator);
+ vkDestroyInstance(g_Instance, g_Allocator);
+}
+
+void CleanupVulkanWindow() {
+ ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
+ g_Allocator);
+}
+
+StatusOr<std::string> GetModuleContentsFromFlags() {
+ auto module_file = absl::GetFlag(FLAGS_module_file);
+ std::string contents;
+ if (module_file == "-") {
+ contents = std::string{std::istreambuf_iterator<char>(std::cin),
+ std::istreambuf_iterator<char>()};
+ } else {
+ IREE_ASSIGN_OR_RETURN(contents, file_io::GetFileContents(module_file));
+ }
+ return contents;
+}
+
+// Runs the current IREE bytecode module and renders its result to a window
+// using ImGui.
+Status RunModuleAndUpdateImGuiWindow(
+ iree_hal_device_t* device, iree_vm_context_t* context,
+ iree_vm_function_t function, const std::string& function_name,
+ const vm::ref<iree_vm_list_t>& inputs,
+ const std::vector<RawSignatureParser::Description>& output_descs,
+ const std::string& window_title) {
+ vm::ref<iree_vm_list_t> outputs;
+ IREE_RETURN_IF_ERROR(iree_vm_list_create(/*element_type=*/nullptr,
+ output_descs.size(),
+ iree_allocator_system(), &outputs));
+
+ IREE_LOG(INFO) << "EXEC @" << function_name;
+ IREE_RETURN_IF_ERROR(iree_vm_invoke(context, function, /*policy=*/nullptr,
+ inputs.get(), outputs.get(),
+ iree_allocator_system()))
+ << "invoking function " << function_name;
+
+ std::ostringstream oss;
+ IREE_RETURN_IF_ERROR(PrintVariantList(output_descs, outputs.get(), &oss))
+ << "printing results";
+
+ outputs.reset();
+
+ ImGui::Begin(window_title.c_str(), /*p_open=*/nullptr,
+ ImGuiWindowFlags_AlwaysAutoResize);
+
+ ImGui::Text("Entry function:");
+ ImGui::Text(function_name.c_str());
+ ImGui::Separator();
+
+ ImGui::Text("Invocation result:");
+ ImGui::Text(oss.str().c_str());
+ ImGui::Separator();
+
+ // Framerate counter.
+ ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
+ 1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
+
+ ImGui::End();
+ return OkStatus();
+}
+} // namespace
+} // namespace iree
+
+int iree::IreeMain(int argc, char** argv) {
+ iree::InitializeEnvironment(&argc, &argv);
+
+ // --------------------------------------------------------------------------
+ // Create a window.
+ if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
+ IREE_LOG(FATAL) << "Failed to initialize SDL";
+ return 1;
+ }
+
+ // Setup window
+ SDL_WindowFlags window_flags = (SDL_WindowFlags)(
+ SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
+ SDL_Window* window = SDL_CreateWindow(
+ "IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
+ SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
+
+ // Setup Vulkan
+ iree_hal_vulkan_features_t iree_vulkan_features =
+ static_cast<iree_hal_vulkan_features_t>(
+ IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS |
+ IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS |
+ IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS);
+ std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
+ std::vector<const char*> extensions =
+ GetInstanceExtensions(window, iree_vulkan_features);
+ SetupVulkan(iree_vulkan_features, layers.data(), layers.size(),
+ extensions.data(), extensions.size(), g_Allocator, &g_Instance,
+ &g_QueueFamily, &g_PhysicalDevice, &g_Queue, &g_Device,
+ &g_DescriptorPool);
+
+ // Create Window Surface
+ VkSurfaceKHR surface;
+ VkResult err;
+ if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
+ printf("Failed to create Vulkan surface.\n");
+ return 1;
+ }
+
+ // Create Framebuffers
+ int w, h;
+ SDL_GetWindowSize(window, &w, &h);
+ ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
+ SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
+ g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
+
+ // Setup Dear ImGui context
+ IMGUI_CHECKVERSION();
+ ImGui::CreateContext();
+ ImGuiIO& io = ImGui::GetIO();
+ (void)io;
+
+ ImGui::StyleColorsDark();
+
+ // Setup Platform/Renderer bindings
+ ImGui_ImplSDL2_InitForVulkan(window);
+ ImGui_ImplVulkan_InitInfo init_info = {};
+ init_info.Instance = g_Instance;
+ init_info.PhysicalDevice = g_PhysicalDevice;
+ init_info.Device = g_Device;
+ init_info.QueueFamily = g_QueueFamily;
+ init_info.Queue = g_Queue;
+ init_info.PipelineCache = g_PipelineCache;
+ init_info.DescriptorPool = g_DescriptorPool;
+ init_info.Allocator = g_Allocator;
+ init_info.MinImageCount = g_MinImageCount;
+ init_info.ImageCount = wd->ImageCount;
+ init_info.CheckVkResultFn = check_vk_result;
+ ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
+
+ // Upload Fonts
+ {
+ // Use any command queue
+ VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
+ VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
+
+ err = vkResetCommandPool(g_Device, command_pool, 0);
+ check_vk_result(err);
+ VkCommandBufferBeginInfo begin_info = {};
+ begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+ begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+ err = vkBeginCommandBuffer(command_buffer, &begin_info);
+ check_vk_result(err);
+
+ ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
+
+ VkSubmitInfo end_info = {};
+ end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+ end_info.commandBufferCount = 1;
+ end_info.pCommandBuffers = &command_buffer;
+ err = vkEndCommandBuffer(command_buffer);
+ check_vk_result(err);
+ err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
+ check_vk_result(err);
+
+ err = vkDeviceWaitIdle(g_Device);
+ check_vk_result(err);
+ ImGui_ImplVulkan_DestroyFontUploadObjects();
+ }
+ // --------------------------------------------------------------------------
+
+ // --------------------------------------------------------------------------
+ // Setup IREE.
+ // This call to |iree_api_init| is not technically required, but it is
+ // included for completeness.
+ IREE_CHECK_OK(iree_api_init(&argc, &argv));
+
+ // Check API version.
+ iree_api_version_t actual_version;
+ iree_status_t status =
+ iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
+ if (iree_status_is_ok(status)) {
+ IREE_LOG(INFO) << "IREE runtime API version " << actual_version;
+ } else {
+ IREE_LOG(FATAL) << "Unsupported runtime API version " << actual_version;
+ }
+
+ // Register HAL module types.
+ IREE_CHECK_OK(iree_hal_module_register_types());
+
+ // Create a runtime Instance.
+ iree_vm_instance_t* iree_instance = nullptr;
+ IREE_CHECK_OK(
+ iree_vm_instance_create(iree_allocator_system(), &iree_instance));
+
+ // Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
+ IREE_LOG(INFO) << "Creating Vulkan driver/device";
+ // Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
+ iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
+ IREE_CHECK_OK(iree_hal_vulkan_syms_create(
+ reinterpret_cast<void*>(&vkGetInstanceProcAddr), &iree_vk_syms));
+ // Create the driver sharing our VkInstance.
+ iree_hal_driver_t* iree_vk_driver = nullptr;
+ iree_hal_vulkan_driver_options_t options;
+ options.api_version = VK_API_VERSION_1_0;
+ options.features = static_cast<iree_hal_vulkan_features_t>(
+ IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS |
+ IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS);
+ IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
+ options, iree_vk_syms, g_Instance, &iree_vk_driver));
+ // Create a device sharing our VkDevice and queue. This makes capturing with
+ // vendor tools easier because we will have sync compute residing in the
+ // rendered frame.
+ iree_hal_vulkan_queue_set_t compute_queue_set;
+ compute_queue_set.queue_family_index = g_QueueFamily;
+ compute_queue_set.queue_indices = 1 << 0;
+ iree_hal_vulkan_queue_set_t transfer_queue_set;
+ transfer_queue_set.queue_indices = 0;
+ iree_hal_device_t* iree_vk_device = nullptr;
+ IREE_CHECK_OK(iree_hal_vulkan_driver_wrap_device(
+ iree_vk_driver, g_PhysicalDevice, g_Device, compute_queue_set,
+ transfer_queue_set, &iree_vk_device));
+ // Create a HAL module using the HAL device.
+ iree_vm_module_t* hal_module = nullptr;
+ IREE_CHECK_OK(iree_hal_module_create(iree_vk_device, iree_allocator_system(),
+ &hal_module));
+
+ // Load the bytecode module contents from the module file specified via flags.
+ IREE_LOG(INFO) << "Loading IREE byecode module...";
+ auto module_file_or = iree::GetModuleContentsFromFlags();
+ if (!module_file_or) {
+ IREE_LOG(FATAL) << "Error when reading module file"
+ << module_file_or.status();
+ }
+ iree_vm_module_t* bytecode_module = nullptr;
+ IREE_CHECK_OK(iree_vm_bytecode_module_create(
+ iree_const_byte_span_t{
+ reinterpret_cast<const uint8_t*>(module_file_or->data()),
+ module_file_or->size()},
+ iree_allocator_null(), iree_allocator_system(), &bytecode_module));
+
+ // Allocate a context that will hold the module state across invocations.
+ iree_vm_context_t* iree_context = nullptr;
+ std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
+ IREE_CHECK_OK(iree_vm_context_create_with_modules(
+ iree_instance, modules.data(), modules.size(), iree_allocator_system(),
+ &iree_context));
+ IREE_LOG(INFO) << "Context with modules is ready for use";
+
+ // Lookup the entry point function.
+ std::string entry_function = absl::GetFlag(FLAGS_entry_function);
+ iree_vm_function_t main_function;
+ IREE_CHECK_OK(bytecode_module->lookup_function(
+ bytecode_module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT,
+ iree_string_view_t{entry_function.data(), entry_function.size()},
+ &main_function));
+ iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
+ IREE_LOG(INFO) << "Resolved main function named '"
+ << std::string(main_function_name.data,
+ main_function_name.size)
+ << "'";
+
+ IREE_CHECK_OK(ValidateFunctionAbi(main_function));
+
+ auto main_function_input_descs = ParseInputSignature(main_function);
+ if (!main_function_input_descs.ok()) {
+ IREE_LOG(FATAL) << main_function_input_descs.status().ToString();
+ }
+ StatusOr<vm::ref<iree_vm_list_t>> main_function_inputs;
+ if (!absl::GetFlag(FLAGS_inputs_file).empty()) {
+ if (!absl::GetFlag(FLAGS_inputs).empty()) {
+ IREE_LOG(FATAL)
+ << "Expected only one of inputs and inputs_file to be set";
+ }
+ main_function_inputs = ParseToVariantListFromFile(
+ *main_function_input_descs, iree_hal_device_allocator(iree_vk_device),
+ absl::GetFlag(FLAGS_inputs_file));
+ } else {
+ main_function_inputs = ParseToVariantList(
+ *main_function_input_descs, iree_hal_device_allocator(iree_vk_device),
+ absl::GetFlag(FLAGS_inputs));
+ }
+ if (!main_function_inputs.ok()) {
+ IREE_LOG(FATAL) << main_function_inputs.status().ToString();
+ }
+
+ auto main_function_output_descs = ParseOutputSignature(main_function);
+ if (!main_function_output_descs.ok()) {
+ IREE_LOG(FATAL) << main_function_output_descs.status().ToString();
+ }
+
+ const std::string& window_title = absl::GetFlag(FLAGS_module_file);
+ // --------------------------------------------------------------------------
+
+ // --------------------------------------------------------------------------
+ // Main loop.
+ bool done = false;
+ while (!done) {
+ SDL_Event event;
+
+ while (SDL_PollEvent(&event)) {
+ if (event.type == SDL_QUIT) {
+ done = true;
+ }
+
+ ImGui_ImplSDL2_ProcessEvent(&event);
+ if (event.type == SDL_QUIT) done = true;
+ if (event.type == SDL_WINDOWEVENT &&
+ event.window.event == SDL_WINDOWEVENT_RESIZED &&
+ event.window.windowID == SDL_GetWindowID(window)) {
+ g_SwapChainResizeWidth = (int)event.window.data1;
+ g_SwapChainResizeHeight = (int)event.window.data2;
+ g_SwapChainRebuild = true;
+ }
+ }
+
+ if (g_SwapChainRebuild) {
+ g_SwapChainRebuild = false;
+ ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
+ ImGui_ImplVulkanH_CreateWindow(g_Instance, g_PhysicalDevice, g_Device,
+ &g_MainWindowData, g_QueueFamily,
+ g_Allocator, g_SwapChainResizeWidth,
+ g_SwapChainResizeHeight, g_MinImageCount);
+ g_MainWindowData.FrameIndex = 0;
+ }
+
+ // Start the Dear ImGui frame
+ ImGui_ImplVulkan_NewFrame();
+ ImGui_ImplSDL2_NewFrame(window);
+ ImGui::NewFrame();
+
+ // Custom window.
+ auto status = RunModuleAndUpdateImGuiWindow(
+ iree_vk_device, iree_context, main_function, entry_function,
+ main_function_inputs.value(), main_function_output_descs.value(),
+ window_title);
+ if (!status.ok()) {
+ IREE_LOG(FATAL) << status;
+ done = true;
+ continue;
+ }
+
+ // Rendering
+ ImGui::Render();
+ RenderFrame(wd, g_Device, g_Queue);
+
+ PresentFrame(wd, g_Queue);
+ }
+ // --------------------------------------------------------------------------
+
+ // --------------------------------------------------------------------------
+ // Cleanup
+ main_function_inputs.value().reset();
+
+ iree_vm_module_release(hal_module);
+ iree_vm_module_release(bytecode_module);
+ iree_vm_context_release(iree_context);
+ iree_hal_device_release(iree_vk_device);
+ iree_hal_driver_release(iree_vk_driver);
+ iree_hal_vulkan_syms_release(iree_vk_syms);
+ iree_vm_instance_release(iree_instance);
+
+ err = vkDeviceWaitIdle(g_Device);
+ check_vk_result(err);
+ ImGui_ImplVulkan_Shutdown();
+ ImGui_ImplSDL2_Shutdown();
+ ImGui::DestroyContext();
+
+ CleanupVulkanWindow();
+ CleanupVulkan();
+
+ SDL_DestroyWindow(window);
+ SDL_Quit();
+ // --------------------------------------------------------------------------
+
+ return 0;
+}
diff --git a/third_party/tracy b/third_party/tracy
index 864d86e..a9a09ab 160000
--- a/third_party/tracy
+++ b/third_party/tracy
@@ -1 +1 @@
-Subproject commit 864d86e8b6d21449474db5e9313dbff90aa9c24f
+Subproject commit a9a09ab0940408898fccfdcfe2bb8dc19b50f13c