Merge google -> main (#3337)
* 1e81da7a Merge main -> google
* 62c8861c Merge branch 'google' into main-to-google
* 8896f35a Synchronize submodules
* 2c02b9c1 Integrate TF at tensorflow/tensorflow@1454ee0907ee
* b843a99e Synchronize submodules
* bb076514 Integrate LLVM at llvm/llvm-project@8825fec37e73
* 9adef4b1 Synchronize submodules
* b56cefca Integrate LLVM at llvm/llvm-project@bfd7ee92ccec
* 52b019e0 Fix vkGetInstanceProcAddr returning nil issue.
* c795db9d Rename filecheck-lib target to match style of other targets
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 774412d..ba966db 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -99,7 +99,7 @@
def tf_saved_model_to_compiler_module(
saved_model_dir: str,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> Module:
"""Converts a TensorFlow SavedModel into a MLIR module.
@@ -108,8 +108,7 @@
Args:
saved_model_dir: Directory of the saved model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -131,18 +130,17 @@
def compile_tf_saved_model(
saved_model_dir: str,
- exported_names: Collection[str] = (),
- target_backends: Collection[str] = (),
+ exported_names: Sequence[str] = (),
+ target_backends: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
"""Compiles a TensorFlow SavedModel to IREE in one shot.
Args:
saved_model_dir: Directory of the saved model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ exported_names: Optional sequence representing the exported names to keep.
+ target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -160,7 +158,7 @@
def tf_signature_def_saved_model_to_compiler_module(
saved_model_dir: str,
saved_model_tags: Set[str] = set(),
- exported_names: Collection[str] = [],
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> Module:
"""Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
@@ -168,8 +166,7 @@
Args:
saved_model_dir: Directory of the saved model.
saved_model_tags: Optional set of tags to use when loading the model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -194,8 +191,8 @@
def compile_tf_signature_def_saved_model(
saved_model_dir: str,
saved_model_tags: Set[str] = set(),
- exported_names: Collection[str] = (),
- target_backends: Collection[str] = (),
+ exported_names: Sequence[str] = (),
+ target_backends: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
"""Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
@@ -203,10 +200,9 @@
Args:
saved_model_dir: Directory of the saved model.
saved_model_tags: Optional set of tags to use when loading the model.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ exported_names: Optional sequence representing the exported names to keep.
+ target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -222,7 +218,7 @@
def tf_module_to_compiler_module(
module: tf.Module,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None,
saved_model_dir: str = None) -> Module:
@@ -230,8 +226,7 @@
Args:
module: The tf.Module instance to convert to MLIR
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
@@ -259,8 +254,8 @@
def compile_tf_module(module: tf.Module,
- exported_names: Collection[str] = (),
- target_backends: Collection[str] = (),
+ exported_names: Sequence[str] = (),
+ target_backends: Sequence[str] = (),
pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
compiler_context: Optional[Context] = None,
saved_model_dir: str = None):
@@ -268,10 +263,9 @@
Args:
module: The tf.Module instance to convert to MLIR
- exported_names: Optional tuple of strings representing the exported names to
- keep.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ exported_names: Optional sequence representing the exported names to keep.
+ target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to TF_IMPORT_PASS_PIPELINE.
compiler_context: The pyiree.compiler.Context() backing the module.
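Beyond swapping `Collection` for `Sequence`, the last hunk above also replaces a mutable default argument (`exported_names: Collection[str] = []`) with an immutable tuple. A minimal sketch of the pitfall the tuple default avoids; the function names below are hypothetical, not from this codebase:

```python
def append_name_buggy(name, names=[]):
    # The default list is created once, at definition time, and is shared
    # by every call that relies on the default.
    names.append(name)
    return names


def append_name_fixed(name, names=()):
    # An immutable tuple default cannot accumulate state between calls.
    return list(names) + [name]


assert append_name_buggy("a") == ["a"]
assert append_name_buggy("b") == ["a", "b"]  # state leaked from the first call
assert append_name_fixed("a") == ["a"]
assert append_name_fixed("b") == ["b"]       # fresh result every call
```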
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index fa3df03..aa9481e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -66,22 +66,22 @@
def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
- """Decodes --target_backends and creates unique names for their artifacts."""
+ """Decodes --target_backends and creates unique ids for them."""
backend_names = FLAGS.target_backends.split(",")
backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
- artifact_names = []
+ backend_ids = []
# If there are multiple copies of the same backend_name, index them. e.g.
# backend_names = ["tf", "iree_vmla", "tf"]
- # --> artifact_names = ["tf_0", "iree_vmla", "tf_1"]
+ # --> backend_ids = ["tf_0", "iree_vmla", "tf_1"]
for backend_name in backend_names:
if backend_name in backend_to_index:
- artifact_names.append(f"{backend_name}_{backend_to_index[backend_name]}")
+ backend_ids.append(f"{backend_name}_{backend_to_index[backend_name]}")
backend_to_index[backend_name] += 1
else:
- artifact_names.append(backend_name)
+ backend_ids.append(backend_name)
- return backend_names, artifact_names
+ return backend_names, backend_ids
def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
@@ -95,10 +95,10 @@
"""
if FLAGS.target_backends is not None:
logging.info("Using backends from command line: %s", FLAGS.target_backends)
- backend_names, names = _parse_target_backends()
+ backend_names, backend_ids = _parse_target_backends()
backends = [
- tf_utils.BackendInfo(backend, name)
- for backend, name in zip(backend_names, names)
+ tf_utils.BackendInfo(backend_name, backend_id)
+ for backend_name, backend_id in zip(backend_names, backend_ids)
]
else:
# If no backends are specified, use them all.
@@ -261,10 +261,11 @@
# Extract metadata from module and function.
self.module_name = module.module_name
self.compiled_paths = module.compiled_paths
- self.backend_name = module.backend
+ self.backend_name = module.backend_info.backend_name
+ self.backend_id = module.backend_info.backend_id
+ self.backend_driver = module.backend_info.driver
self.iree_serializable = module.iree_serializable()
self.tflite_serializable = module.tflite_serializable()
- self.backend_driver = module.backend_driver
self.function_name = function.__name__
self.function_sourcefile = inspect.getsourcefile(function)
source, start_line = inspect.getsourcelines(function)
@@ -276,9 +277,10 @@
self.module_name = _load_dict["module_name"]
self.compiled_paths = _load_dict["compiled_paths"]
self.backend_name = _load_dict["backend_name"]
+ self.backend_id = _load_dict["backend_id"]
+ self.backend_driver = _load_dict["backend_driver"]
self.iree_serializable = _load_dict["iree_serializable"]
self.tflite_serializable = _load_dict["tflite_serializable"]
- self.backend_driver = _load_dict["backend_driver"]
self.function_name = _load_dict["function_name"]
self.function_sourcefile = _load_dict["function_sourcefile"]
self.function_line_numbers = _load_dict["function_line_numbers"]
@@ -286,7 +288,7 @@
self.calls = _load_dict["calls"]
def __str__(self):
- header = (f"Trace of {self.module_name} compiled to '{self.backend_name}' "
+ header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
f"on function '{self.function_name}':")
# Give each call a number so it's easier to compare between multiple traces.
calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
@@ -327,11 +329,11 @@
if not calls_match:
logging.error("Comparision between '%s' and '%s' failed on method '%s'",
- ref_trace.backend_name, tar_trace.backend_name,
+ ref_trace.backend_id, tar_trace.backend_id,
ref_call.method)
- logging.error("Reference call '%s':\n%s", ref_trace.backend_name,
+ logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
ref_call)
- logging.error("Target call '%s':\n%s", tar_trace.backend_name, tar_call)
+ logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
traces_match = traces_match and calls_match
return traces_match
@@ -434,14 +436,20 @@
trace_dir: str, path to the directory to serialize the trace to.
"""
+ compiled_paths = None
+ if self.compiled_paths is not None:
+ # Convert to a dict to avoid the issues with serializing defaultdicts.
+ compiled_paths = dict(self.compiled_paths)
+
# Python serialization.
metadata = {
"module_name": self.module_name,
- "compiled_paths": self.compiled_paths,
+ "compiled_paths": compiled_paths,
"backend_name": self.backend_name,
+ "backend_id": self.backend_id,
+ "backend_driver": self.backend_driver,
"iree_serializable": self.iree_serializable,
"tflite_serializable": self.tflite_serializable,
- "backend_driver": self.backend_driver,
"function_name": self.function_name,
"function_sourcefile": self.function_sourcefile,
"function_line_numbers": self.function_line_numbers,
@@ -493,7 +501,7 @@
def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
- trace_dir = os.path.join(artifacts_dir, trace.backend_name, "traces",
+ trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
trace.function_name)
os.makedirs(trace_dir, exist_ok=True)
return trace_dir
@@ -559,17 +567,17 @@
def compile_tf_module(
module_class: Type[tf.Module], exported_names: Sequence[str] = ()
) -> Callable[[Any], Any]:
- """CompiledModuleTestCase decorator that compiles a tf.Module.
-
- A CompiledModule is created for each backend in --target_backends. They can
- be accessed individually via self.compiled_modules.backend_name or as a union
- via self.get_module().
+ """Compiles module_class to each backend that we test.
Args:
module_class: the tf.Module subclass to compile.
exported_names: optional iterable of strings representing which of
module_class's functions to compile. If exported_names is empty all
functions will be compiled.
+
+ Returns:
+ A 'Modules' namedtuple containing the reference module, target modules and
+ artifacts directory.
"""
# Setup the directory for saving compilation artifacts and traces.
@@ -580,7 +588,7 @@
f"{FLAGS.reference_backend}_ref")
tar_backend_infos = get_target_backends()
- compile_backend = lambda backend_info: backend_info.compile(
+ compile_backend = lambda backend_info: backend_info.compile_from_class(
module_class, exported_names, artifacts_dir)
ref_module = compile_backend(ref_backend_info)
@@ -631,7 +639,7 @@
failed_backend_indices = []
for i, tar_trace in enumerate(tar_traces):
logging.info("Comparing the reference backend '%s' with '%s'",
- ref_trace.backend_name, tar_trace.backend_name)
+ ref_trace.backend_id, tar_trace.backend_id)
traces_match = Trace.compare_traces(ref_trace, tar_trace)
if not traces_match:
failed_backend_indices.append(i)
@@ -649,7 +657,7 @@
if failed_backend_indices:
# Extract info for logging.
failed_backends = [
- tar_traces[i].backend_name for i in failed_backend_indices
+ tar_traces[i].backend_id for i in failed_backend_indices
]
self.fail(
"Comparision between the reference backend and the following targets "
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 6157c68..e1fb6e3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -117,8 +117,8 @@
# Only outputs
module.get_count()
- module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ module = tf_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('tf'))
trace = tf_test_utils.Trace(module, trace_function)
trace_function(tf_test_utils.TracedModule(module, trace))
@@ -140,13 +140,13 @@
module.increment()
module.decrement()
- tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ tf_module = tf_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('tf'))
tf_trace = tf_test_utils.Trace(tf_module, tf_function)
tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
- vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('iree_vmla'))
+ vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
@@ -161,13 +161,13 @@
def vmla_function(module):
module.increment_by(np.array([22.], dtype=np.float32))
- tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('tf'))
+ tf_module = tf_utils.TfCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('tf'))
tf_trace = tf_test_utils.Trace(tf_module, tf_function)
tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
- vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('iree_vmla'))
+ vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
@@ -178,12 +178,12 @@
def trace_function(module):
module.increment()
module.increment_by(np.array([81.], dtype=np.float32))
- module.increment_by_max(
- np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
+ module.increment_by_max(np.array([81], dtype=np.float32),
+ np.array([92], dtype=np.float32))
module.get_count()
- module = tf_utils.IreeCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo('iree_vmla'))
+ module = tf_utils.IreeCompiledModule.create_from_class(
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
trace = tf_test_utils.Trace(module, trace_function)
trace_function(tf_test_utils.TracedModule(module, trace))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index fc8bab6..b3237f9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -16,6 +16,7 @@
# pylint: disable=protected-access
+import collections
import os
import random
import re
@@ -91,7 +92,7 @@
def _setup_mlir_crash_reproducer(
function: Callable[[Any], Any],
artifacts_dir: str,
- backend_name: str,
+ backend_id: str,
) -> Callable[[Any], Any]:
"""Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
@@ -100,7 +101,7 @@
Args:
function: The callable to decorate.
artifacts_dir: The directory to write the reproducer to.
- backend_name: The name of the backend `function` compiles to.
+    backend_id: The unique backend name to use when writing the reproducer.
Returns:
A function with the same API as the passed function.
@@ -110,7 +111,7 @@
# Set up a crash reproducer for debugging.
if artifacts_dir is not None:
compiler.Context.default_crash_reproducer_path = os.path.join(
- artifacts_dir, f"reproducer__{backend_name}.mlir")
+ artifacts_dir, f"reproducer__{backend_id}.mlir")
try:
results = function(*args, **kwargs)
except Exception: # pylint: disable=broad-except
@@ -135,7 +136,7 @@
MLIR for the module in TF's input dialect.
iree_input.mlir:
The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
- backend_name/compiled.vmfb:
+ backend_id/compiled.vmfb:
A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
Args:
@@ -167,7 +168,7 @@
compiled_path = None
if artifacts_dir is not None:
- backend_dir = os.path.join(artifacts_dir, backend_info.name)
+ backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
os.makedirs(backend_dir, exist_ok=True)
compiled_path = os.path.join(backend_dir, "compiled.vmfb")
logging.info("Saving compiled IREE module to: %s", compiled_path)
@@ -180,7 +181,7 @@
module: Type[tf.Module],
backend_info: "BackendInfo",
exported_names: Sequence[str] = (),
- artifacts_dir: str = None
+ artifacts_dir: str = None,
) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
"""Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
@@ -193,8 +194,7 @@
Args:
module: A tf.Module.
backend_info: BackendInfo with the details for compiling module to IREE.
- exported_names: Iterable of dotted function names to consider for
- compilation.
+ exported_names: Optional sequence representing the exported names to keep.
artifacts_dir: An optional string pointing to where compilation artifacts
should be saved. No compilation artifacts will be saved if this is not
provided.
@@ -212,29 +212,72 @@
artifacts_dir)
_compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
- backend_info.name)
+ backend_info.backend_id)
return _compile_module(module, exported_names, backend_info, artifacts_dir)
class CompiledModule(object):
"""Base class for the TF and IREE compiled modules."""
- def __init__(self, module_class: Type[tf.Module], backend_info: "BackendInfo",
- exported_names: Sequence[str], artifacts_dir: str):
- """Shared base constructor – not useful on its own."""
- self._module_class = module_class
- self._backend_info = backend_info
- self._exported_names = exported_names
- self._artifacts_dir = artifacts_dir
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ ):
+ """Shared base constructor – not useful on its own.
- # Public attributes:
- self.backend = self._backend_info.name
- self.backend_driver = self._backend_info.driver
- self.module_name = self._module_class.__name__
- self.compiled_paths = None
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ """
+ self.module_name = module_name
+ self.backend_info = backend_info
+ self.compiled_paths = compiled_paths
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def create_from_instance(cls,
+ module_instance: tf.Module,
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module instance to the target backend in backend_info.
+
+ This is only implemented for IreeCompiledModule.
+
+ Args:
+ module_instance: The tf.Module instance to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
raise NotImplementedError()
def iree_serializable(self):
@@ -244,14 +287,6 @@
return False
-def _get_non_inhereted_function_names(cls):
- """Gets all methods that cls has that its parents don't have."""
- names = set(dir(cls))
- for parent in cls.__bases__:
- names -= set(dir(parent))
- return list(names)
-
-
class _FunctionWrapper(object):
def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
@@ -277,57 +312,94 @@
class IreeCompiledModule(CompiledModule):
"""Iree compiled module."""
- def __init__(self,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Compile a tf.Module to the target backend in backend_info.
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ vm_module: rt.VmModule,
+ config: rt.Config,
+ ):
+ """Base constructor – Use one of the named constructors instead.
Args:
- module_class: the tf.Module subclass to compile.
- backend_info: an element of BackendInfo corresponding to the IREE backend
- to compile to.
- exported_names: an optional iterable of strings representing which of the
- module_class's functions to compile. If exported_names is empty all
- functions will be compiled.
- artifacts_dir: an optional path to save compilation artifacts to.
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ vm_module: A rt.VmModule containing compilation info to wrap.
+ config: A rt.Config containing compilation info to wrap.
"""
- super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ super().__init__(module_name, backend_info, compiled_paths)
+ self._vm_module = vm_module
+ self._config = config
+ self.reinitialize()
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
set_random_seed()
- self._module_blob, compiled_path = _incrementally_compile_tf_module(
- module=module_class(),
+ module_instance = module_class()
+ return cls.create_from_instance(module_instance, backend_info,
+ exported_names, artifacts_dir)
+
+ @classmethod
+ def create_from_instance(cls,
+ module_instance: tf.Module,
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module instance to the target backend in backend_info.
+
+ Args:
+ module_instance: The tf.Module instance to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ module_blob, compiled_path = _incrementally_compile_tf_module(
+ module=module_instance,
backend_info=backend_info,
exported_names=exported_names,
artifacts_dir=artifacts_dir)
- self._module = rt.VmModule.from_flatbuffer(self._module_blob)
- self._config = rt.Config(driver_name=backend_info.driver)
+ vm_module = rt.VmModule.from_flatbuffer(module_blob)
+ config = rt.Config(driver_name=backend_info.driver)
- self.compiled_paths = None
+ compiled_paths = None
if compiled_path is not None:
- if not len(exported_names):
- # Get all method names on 'module_class' that aren't on 'tf.Module'.
- # This doesn't address all possbile scenarios.
- # TODO(meadowlark): Figure out how to get a list of all of the functions
- # that this module has access to via `pyiree.rt.system_api.BoundModule`.
- exported_names = _get_non_inhereted_function_names(module_class)
- self.compiled_paths = dict([
- (method, compiled_path) for method in exported_names
- ])
+ # IREE bundles every compiled method into the same compiled module.
+ compiled_paths = collections.defaultdict(lambda: compiled_path)
- self.reinitialize()
+ module_name = type(module_instance).__name__
+
+ return cls(module_name, backend_info, compiled_paths, vm_module, config)
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
# set_random_seed is not needed here because the model_class.__init__ is not
# called.
- self._context = rt.SystemContext(modules=[self._module],
+ self._context = rt.SystemContext(modules=[self._vm_module],
config=self._config)
def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
# Try to resolve it as a function.
- m = self._context.modules[self._module.name]
+ m = self._context.modules[self._vm_module.name]
f = m[attr]
return _IreeFunctionWrapper(self._context, f)
@@ -376,29 +448,54 @@
normalize TensorFlow's output to Numpy.
"""
- def __init__(self,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- """Wrap a tf.Module in a TFCompiledModule facade.
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ constructor: Callable[[], tf.Module],
+ exported_names: Sequence[str],
+ ):
+ """Base constructor – Use one of the named constructors instead.
Args:
- module_class: the tf.Module subclass to 'compile'.
- backend_info: one of the 'tf*' elements in BackendInfo.
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ constructor: A callable (class or function) which returns the tf.Module
+ subclass instance to wrap.
exported_names: an optional iterable of strings representing which of the
- module_class's functions should be callable. If exported_names is empty
- then all functions will be callable.
- artifacts_dir: an optional path to save compilation artifacts to. Has no
- effect for this subclass as nothing is compiled.
+ tf.Module subclass instance's functions should be callable. If
+ exported_names is empty then all functions will be callable.
"""
- super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ super().__init__(module_name, backend_info, compiled_paths=None)
+ self._constructor = constructor
+ self._exported_names = exported_names
self.reinitialize()
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
+ module_name = module_class.__name__
+ constructor = module_class
+ return cls(module_name, backend_info, constructor, exported_names)
+
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
set_random_seed()
- self._tf_module = self._module_class()
+ self._tf_module = self._constructor()
def __getattr__(self, attr: str) -> _TfFunctionWrapper:
# Try to resolve it as a function.
@@ -412,6 +509,14 @@
return _TfFunctionWrapper(f)
+def _get_non_inhereted_function_names(cls):
+ """Gets all methods that cls has that its parents don't have."""
+ names = set(dir(cls))
+ for parent in cls.__bases__:
+ names -= set(dir(parent))
+ return list(names)
+
+
def _get_concrete_functions(module_class: Type[tf.Module],
exported_names: Sequence[str] = ()):
"""Get concrete functions from non-inherited methods or exported_names."""
@@ -425,12 +530,12 @@
return functions, exported_names
-def compile_to_tflite(
+def tf_module_to_tflite_interpreters(
module_class: Type[tf.Module],
exported_names: Sequence[str] = (),
artifacts_dir: str = None
) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str]], None]:
- """Compile a dict of TFLite interpreters for the methods on module_class.
+ """Compile a tf.Module to TFLite interpreters for each of its methods.
Args:
module_class: A tf.Module subclass to compile with TFLite. If module_class
@@ -463,22 +568,14 @@
if artifacts_dir is not None:
compiled_paths[name] = tflite_path
+ # Convert module_class's methods into TFLite module byte-strings.
tflite_modules = []
- names = []
- if hasattr(module_class, "get_legacy_tflite_saved_model_converter_kwargs"):
- kwargs = module_class.get_legacy_tflite_saved_model_converter_kwargs()
- converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
- kwargs["model_path"],
- input_arrays=kwargs["input_arrays"],
- output_arrays=kwargs["output_arrays"])
+ functions, names = _get_concrete_functions(module_class, exported_names)
+ for function in functions:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
tflite_modules.append(converter.convert())
- names.append(kwargs["exported_name"])
- else:
- functions, names = _get_concrete_functions(module_class, exported_names)
- for function in functions:
- converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
- tflite_modules.append(converter.convert())
+ # Load each of the converted methods above into tf.lite.Interpreters.
for name, tflite_module in zip(names, tflite_modules):
if artifacts_dir is None:
with tempfile.TemporaryDirectory() as base_dir:
@@ -537,18 +634,50 @@
class TfLiteCompiledModule(CompiledModule):
"""Compiles a tf.Module with TFLite and allows it to be called."""
- def __init__(self,
- module_class: Type[tf.Module],
- backend_info: "BackendInfo",
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None):
- super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+ def __init__(
+ self,
+ module_name: str,
+ backend_info: "BackendInfo",
+ compiled_paths: Dict[str, str],
+ interpreters: Dict[str, tf.lite.Interpreter],
+ ):
+ """Base constructor – Use one of the named constructors instead.
+
+ Args:
+ module_name: A name for this compiled module. In most cases this will be
+ the name of the tf.Module subclass or instance that is compiled.
+ backend_info: BackendInfo with the details about compiling this module.
+ compiled_paths: A dictionary mapping compiled method names to file paths
+ corresponding to their serialized representations.
+ interpreters: A dict of tf.lite.Interpreters to make callable.
+ """
+ super().__init__(module_name, backend_info, compiled_paths)
+ self._interpreters = interpreters
+
+ @classmethod
+ def create_from_class(cls,
+ module_class: Type[tf.Module],
+ backend_info: "BackendInfo",
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None):
+ """Compile a tf.Module subclass to the target backend in backend_info.
+
+ Args:
+ module_class: The tf.Module subclass to compile.
+ backend_info: BackendInfo with the details for compiling module to IREE.
+ exported_names: Optional sequence representing the exported names to keep.
+ artifacts_dir: An optional string pointing to where compilation artifacts
+ should be saved. No compilation artifacts will be saved if this is not
+ provided.
+ """
set_random_seed()
- self._interpreters, self.compiled_paths = compile_to_tflite(
+ interpreters, compiled_paths = tf_module_to_tflite_interpreters(
module_class, exported_names, artifacts_dir)
+ module_name = module_class.__name__
+ return cls(module_name, backend_info, compiled_paths, interpreters)
def reinitialize(self):
- """Reinitializes to the initial state of the passed module_class."""
+ """Reinitializes all stateful variables."""
# This is a noop because TFLite (mostly) doesn't support stateful modules.
pass
@@ -594,36 +723,43 @@
},
}
- def __init__(self, backend_name: str, artifact_name: str = None):
+ def __init__(self, backend_name: str, backend_id: str = None):
"""Creates a BackendInfo with the compilation details for backend_name.
Args:
backend_name: a str specifying which backend to use. Should be one of
'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
- artifact_name: an optional str specifying what name to use when saving
- compiled artifacts.
+ backend_id: an optional str specifying what name to use when saving
+ compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
Raises:
KeyError: if backend_name is not one of ['tf', 'iree_vmla',
'iree_llvmjit', 'iree_vulkan'].
+ ValueError: if backend_id doesn't start with backend_name.
"""
if backend_name not in self._name_to_info:
raise KeyError(
"Expected backend_name to be one of "
f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+ if backend_id is not None and not backend_id.startswith(backend_name):
+ raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+ f"but got '{backend_id}'.")
+
+ self.backend_name = backend_name
+ self.backend_id = backend_name if backend_id is None else backend_id
+
info = self._name_to_info[backend_name]
self._compiled_module_class = info["compiled_module_class"]
self.driver = info["driver"]
self.compiler_targets = info["compiler_targets"]
- self.name = backend_name if artifact_name is None else artifact_name
- def compile(self,
- module: Type[tf.Module],
- exported_names: Sequence[str] = (),
- artifacts_dir: str = None) -> CompiledModule:
+ def compile_from_class(self,
+ module_class: Type[tf.Module],
+ exported_names: Sequence[str] = (),
+ artifacts_dir: str = None) -> CompiledModule:
"""Creates a 'CompiledModule' for this backend."""
- return self._compiled_module_class(module, self, exported_names,
- artifacts_dir)
+ return self._compiled_module_class.create_from_class(
+ module_class, self, exported_names, artifacts_dir)
@classmethod
def get_all_backends(cls) -> Sequence["BackendInfo"]:
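The heart of the `tf_utils.py` refactor is that compilation moves out of `__init__` and into the `create_from_class` / `create_from_instance` named constructors, leaving `__init__` to wrap state that has already been built. A minimal sketch of the pattern with hypothetical class names:

```python
class CompiledThing:
    """Hypothetical stand-in for CompiledModule and its subclasses."""

    def __init__(self, module_name, payload):
        # The plain constructor only stores already-built state.
        self.module_name = module_name
        self.payload = payload

    @classmethod
    def create_from_class(cls, module_class):
        # The named constructor owns the expensive work (compilation is
        # faked here) and then delegates to the plain constructor.
        instance = module_class()
        payload = f"compiled({type(instance).__name__})"
        return cls(type(instance).__name__, payload)


class Model:
    pass


module = CompiledThing.create_from_class(Model)
assert module.module_name == "Model"
```

This layering is also what lets `IreeCompiledModule.create_from_class` forward to `create_from_instance`: both paths end in the same plain constructor.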
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 1a441d0..a9c259c 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -89,7 +89,7 @@
])
def test_unaltered_state(self, backend_name):
backend_info = tf_utils.BackendInfo(backend_name)
- module = backend_info.compile(StatefulCountingModule)
+ module = backend_info.compile_from_class(StatefulCountingModule)
# Test that incrementing works properly.
self.assertEqual([0.], module.get_count())
@@ -126,8 +126,8 @@
backend_info = tf_utils.BackendInfo(backend_name)
# Test compilation is the same.
- module_1 = backend_info.compile(RandomInitModule)
- module_2 = backend_info.compile(RandomInitModule)
+ module_1 = backend_info.compile_from_class(RandomInitModule)
+ module_2 = backend_info.compile_from_class(RandomInitModule)
self.assertAllEqual(module_1.get(), module_2.get())
# Test reinitialization is the same.
diff --git a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
index 78baeb6..df9b317 100644
--- a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
@@ -54,7 +54,7 @@
def xla_load_module_proto(
xla_computation,
compiler_context: Optional[Context] = None,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE) -> Module:
"""Loads a XLA saved model from its persistent representation.
@@ -63,8 +63,7 @@
Args:
xla_computation: XLA Computation generated from the XLA Python client
compiler_context: The pyiree.compiler.Context() backing the module.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to XLA_IMPORT_PASS_PIPELINE.
@@ -85,21 +84,20 @@
def xla_compile_module_proto(
xla_computation,
compiler_context: Optional[Context] = None,
- exported_names: Collection[str] = (),
+ exported_names: Sequence[str] = (),
pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE,
- target_backends: Collection[str] = ()
+ target_backends: Sequence[str] = ()
) -> binding.OpaqueBlob:
"""Loads and compiles a XLA saved model in one shot.
Args:
xla_computation: XLA Computation generated from the XLA Python client
compiler_context: The pyiree.compiler.Context() backing the module.
- exported_names: Optional tuple of strings representing the exported names to
- keep.
+ exported_names: Optional sequence representing the exported names to keep.
pass_pipeline: Passes to run on the imported module prior to returning.
Defaults to XLA_IMPORT_PASS_PIPELINE.
- target_backends: The specific target backends to compile for (defaults to
- all compiled in targets).
+ target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
Returns:
An OpaqueBlob representing the compiled module.
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 2a0b458..dc289d6 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -95,6 +95,7 @@
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
+ "sort_test.py",
"strings_test.py",
]
@@ -112,6 +113,7 @@
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
+ "sort_test.py",
"strings_test.py",
]
@@ -218,7 +220,6 @@
name = "mobile_bert_squad_tests",
backends_to_srcs = {
"tf": ["mobile_bert_squad_test.py"],
- "tflite": ["mobile_bert_squad_test.py"],
},
reference_backend = "tf",
tags = [
@@ -234,6 +235,7 @@
iree_e2e_test_suite(
name = "mobile_bert_squad_tests_failing",
backends_to_srcs = {
+ "tflite": ["mobile_bert_squad_test.py"],
"iree_vmla": ["mobile_bert_squad_test.py"],
"iree_llvmjit": ["mobile_bert_squad_test.py"],
"iree_vulkan": ["mobile_bert_squad_test.py"],
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
index 9c55ead..ae5d0d2 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
@@ -22,7 +22,7 @@
iree_cmake_extra_content(
content = """
-if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} AND NOT ${IREE_TARGET_BACKEND_METAL-SPIRV})
return()
endif()
""",
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
index 0922d4a..45d0b21 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} AND NOT ${IREE_TARGET_BACKEND_METAL-SPIRV})
return()
endif()
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index a80a260..6af672a 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -42,6 +42,7 @@
// Pseudo-ops are illegal.
// If we end up with a lot of these, consider using an "is pseudo" trait.
addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
+ addIllegalOp<IREE::VMLA::SortPseudoOp>();
// Allow other ops to pass through so long as their type is valid (not a
// tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index ccd1a3f..75b640f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -704,6 +704,29 @@
TypeConverter &typeConverter;
};
+struct SortOpConversion : public OpConversionPattern<IREE::VMLA::SortPseudoOp> {
+ SortOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ LogicalResult matchAndRewrite(
+ IREE::VMLA::SortPseudoOp srcOp, ArrayRef<Value> rawOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto inputType =
+ srcOp.getOperand().getType().cast<ShapedType>().getElementType();
+ auto src = rawOperands[0];
+ auto src_shape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.value(), typeConverter, rewriter);
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+ rewriter.createOrFold<IREE::VMLA::SortOp>(srcOp.getLoc(), src, src_shape,
+ dst, TypeAttr::get(inputType));
+ rewriter.replaceOp(srcOp, {dst});
+ return success();
+ }
+
+ TypeConverter &typeConverter;
+};
+
struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
ConvertOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
@@ -769,6 +792,9 @@
IREE::VMLA::BatchMatMulOp>>(context,
typeConverter);
+ // vmla.sort.pseudo
+ patterns.insert<SortOpConversion>(context, typeConverter);
+
// Simple 1:1 conversion patterns using the automated trait-based converter.
// Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
new file mode 100644
index 0000000..0903793
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @sort1D(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[C16:%.+]] = constant 16 : index
+ // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
+ // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+ // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4]>), out [[BL]] : f32
+ // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+ // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4]>) {batch_dims = 0 : i64, dim = 0 : i64} : f32
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+ // CHECK: return [[BUF]] : !vmla.buffer
+ return %sort : tensor<4xf32>
+}
+
+
+// CHECK-LABEL: func @sort2D
+func @sort2D(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[C64:%.+]] = constant 64 : index
+ // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,4]>
+ // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+ // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BL]] : f32
+ // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+ // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4,4]>) {batch_dims = 1 : i64, dim = 1 : i64} : f32
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+ // CHECK: return [[BUF]] : !vmla.buffer
+ return %sort : tensor<4x4xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 7e4f6ea..1b66485 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -321,6 +321,7 @@
VMLA_TYPED_IMPORT_OP(IREE::VMLA::FloorOp, "vmla.floor");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::CeilOp, "vmla.ceil");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::RoundOp, "vmla.round");
+ VMLA_TYPED_IMPORT_OP(IREE::VMLA::SortOp, "vmla.sort");
patterns.insert<VMLAConvertImportOpConversion>(context, importSymbols,
typeConverter, "vmla.convert");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f5cd0fe..422fed3 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -422,7 +422,7 @@
}
//===----------------------------------------------------------------------===//
-// VMLA Ops: Convultion
+// VMLA Ops: Convolution
//===----------------------------------------------------------------------===//
def VLMA_ConvOp : VMLA_Op<"conv", [VMLA_IncludeShapes]> {
@@ -460,6 +460,46 @@
}
//===----------------------------------------------------------------------===//
+// VMLA Ops: Sorting
+//===----------------------------------------------------------------------===//
+
+def VMLA_SortPseudoOp : VMLA_Op<"sort.pseudo"> {
+ let summary = "Tensor-level pseudo-op of VMLA::SortOp.";
+ let description = [{
+ This is a tensor-level version of VMLA::SortOp, to facilitate
+ the lowering process.
+
+    This operation generates a sorted index list along the last dimension,
+    sorting batch-wise along all other dimensions.
+ }];
+ let arguments = (ins
+ AnyTensor:$value
+ );
+ let results = (outs
+ I32Tensor:$dst
+ );
+
+ let assemblyFormat = [{
+ $value attr-dict `:` `(`type($value)`)` `->` type($dst)
+ }];
+}
+
+def VMLA_SortOp : VMLA_ElementTypeOp<"sort", [VMLA_IncludeShapes]> {
+ let arguments = (ins
+ VMLA_Buffer:$src,
+ VMLA_Shape:$src_shape,
+ VMLA_Buffer:$dst,
+ VMLA_AnyTypeAttr:$element_type
+ );
+
+ let assemblyFormat = [{
+ $src`(`$src_shape `:` type($src_shape)`)``,`
+ `out` $dst attr-dict `:` $element_type
+ }];
+}
+
+
+//===----------------------------------------------------------------------===//
// VMLA Ops: GEMM/GEMV
//===----------------------------------------------------------------------===//
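The description of `vmla.sort.pseudo` above (a sorted index list along the last dimension, batched over all other dimensions, with the values reordered by a gather afterwards) can be sketched in NumPy. This illustrates the intended semantics only and is not IREE code:

```python
import numpy as np


def sort_pseudo(value):
    # Sorted index list along the last dimension, batch-wise over the rest.
    return np.argsort(value, axis=-1, kind="stable").astype(np.int32)


x = np.array([[3.0, 1.0, 2.0],
              [6.0, 5.0, 4.0]], dtype=np.float32)
indices = sort_pseudo(x)                            # the vmla.sort.pseudo step
gathered = np.take_along_axis(x, indices, axis=-1)  # the torch_index_select step
assert (gathered == np.sort(x, axis=-1)).all()
```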
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index d937026..62dd921 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -277,6 +277,96 @@
}
};
+// Lower mhlo::SortOp to a pseudo SortOp in the VMLA dialect. This
+// pseudo op generates a set of ordered indices for that array along the last
+// dimension. A torch_index_select then uses those indices to reorder the
+// values, which supports arbitrary inputs.
+//
+// TODO(suderman): This lowering only covers the case of ascending values; we
+// should support a separate descending case by adding separate
+// SortAscending and SortDescending operations.
+class LowerSortOp : public OpRewritePattern<mhlo::SortOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(mhlo::SortOp op,
+ PatternRewriter &rewriter) const override {
+ auto operandTy = op.getOperand(0).getType().cast<RankedTensorType>();
+ bool lastDimension =
+ (op.dimension() == -1) || (op.dimension() == (operandTy.getRank() - 1));
+
+ // TODO(suderman): Add transpose to sort along the last dimension.
+ if (!lastDimension) return failure();
+
+ auto &comparator = op.comparator();
+ auto &block = comparator.getBlocks().front();
+ auto &operations = block.getOperations();
+ auto comparison = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
+
+ // First verify that the block is purely a return of a comparison. This
+ // handles sorting a single tensor of values.
+ if (!comparison) return failure();
+
+ auto returnOp =
+ dyn_cast_or_null<mhlo::ReturnOp>(&(*(++operations.begin())));
+ if (!returnOp) return failure();
+
+ if (returnOp.getOperand(0) != comparison.getResult()) return failure();
+
+  // Determine which operands are being compared.
+ auto lhs = comparison.getOperand(0);
+ auto rhs = comparison.getOperand(1);
+ auto lhsIndex = -1;
+ auto rhsIndex = -1;
+ for (auto arg : llvm::enumerate(block.getArguments())) {
+ if (arg.value() == lhs) lhsIndex = arg.index();
+ if (arg.value() == rhs) rhsIndex = arg.index();
+ }
+
+ // This should never happen but best to check.
+ if (lhsIndex == -1) return failure();
+ if (rhsIndex == -1) return failure();
+
+ // They should not be the same.
+ if (lhsIndex == rhsIndex) return failure();
+
+  // Comparisons need to pull from the same Sort operand.
+ auto lhsOperand = lhsIndex / 2;
+ auto rhsOperand = rhsIndex / 2;
+ if (lhsOperand != rhsOperand) return failure();
+
+ // Must be GT, GE, LT, or LE.
+ auto isGt = comparison.comparison_direction() == "GT" ||
+ comparison.comparison_direction() == "GE";
+ auto isLt = comparison.comparison_direction() == "LT" ||
+ comparison.comparison_direction() == "LE";
+ if (!isGt && !isLt) return failure();
+
+ bool operandParity = lhsIndex > rhsIndex;
+ auto isAscending = operandParity ^ isGt;
+  // TODO(suderman): Add support for descending sorting.
+ if (!isAscending) return failure();
+
+ auto operand = op.getOperand(lhsOperand);
+ auto sortedIndices = rewriter.create<VMLA::SortPseudoOp>(
+ op.getLoc(),
+ RankedTensorType::get(operandTy.getShape(), rewriter.getI32Type()),
+ operand);
+
+ llvm::SmallVector<Value, 6> sortedResults;
+ for (auto operand : op.getOperands()) {
+ auto tensorTy = operand.getType().cast<RankedTensorType>();
+ auto gathered = rewriter.create<mhlo::TorchIndexSelectOp>(
+ op.getLoc(), tensorTy, operand, sortedIndices,
+ /**dim=*/operandTy.getRank() - 1,
+ /**batch_dims=*/operandTy.getRank() - 1);
+ sortedResults.push_back(gathered);
+ }
+
+ rewriter.replaceOp(op, sortedResults);
+ return success();
+ }
+};
+
class PreConversionLoweringPass
: public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
public:
@@ -310,6 +400,8 @@
patterns.insert<LowerBroadcastInDimOp>(context);
target.addIllegalOp<mhlo::BroadcastOp>();
patterns.insert<LowerBroadcastOp>(context);
+ target.addIllegalOp<mhlo::SortOp>();
+ patterns.insert<LowerSortOp>(context);
if (failed(applyPartialConversion(getOperation(), target, patterns))) {
return signalPassFailure();
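`LowerSortOp` infers the sort direction from the comparator region: block arguments `2k` and `2k + 1` both belong to sort operand `k`, and the direction is the XOR of the argument order with the comparison direction. A small Python sketch of that check, mirroring the C++ above under the same conventions:

```python
def is_ascending(lhs_index, rhs_index, direction):
    # Both compared block arguments must come from the same sort operand.
    assert lhs_index // 2 == rhs_index // 2 and lhs_index != rhs_index
    is_gt = direction in ("GT", "GE")
    assert is_gt or direction in ("LT", "LE")
    operand_parity = lhs_index > rhs_index
    return operand_parity ^ is_gt


# compare(%arg0, %arg1) {GT} lowers to the ascending index sort.
assert is_ascending(0, 1, "GT")
# compare(%arg1, %arg0) {GT} would be descending; the pattern bails out.
assert not is_ascending(1, 0, "GT")
```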
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
index 3473e44..0b9cd82 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -17,6 +17,38 @@
// -----
// CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+ // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 0 : i64, dim = 0 : i64}
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+ // CHECK: return [[GATHER]]
+ return %sort : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+ // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 1 : i64, dim = 1 : i64}
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  // CHECK: return [[GATHER]]
+ return %sort : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
func @f(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
// CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 17d7e85..ff575b0 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -333,6 +333,20 @@
vm.import @ceil.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
vm.import @round.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
+
+vm.import @sort.i8(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i16(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i32(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.f32(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+
//===----------------------------------------------------------------------===//
// VMLA Ops: conversion
//===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index c6ab9d6..ba5b8bc 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -168,6 +168,12 @@
absl::Span<const int32_t> dimensions);
};
+struct Sort {
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<int32_t> dst_buffer, ShapeSpan src_shape);
+};
+
struct Broadcast {
template <typename T>
static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index d3545d1..0b4f904 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -15,7 +15,10 @@
#ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
#define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
+#include <algorithm>
#include <cmath>
+#include <iostream>
+#include <numeric>
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
@@ -519,6 +522,25 @@
}
template <typename T>
+Status Sort::Execute(absl::Span<const T> src_buffer,
+ absl::Span<int32_t> dst_buffer, ShapeSpan src_shape) {
+ int elements = src_buffer.size();
+ const int sort_size = src_shape.back();
+
+ for (int i = 0; i < elements; i += sort_size) {
+ auto src_subspan = src_buffer.subspan(i, sort_size);
+ auto dst_subspan = dst_buffer.subspan(i, sort_size);
+ std::iota(dst_subspan.begin(), dst_subspan.end(), 0);
+ std::stable_sort(dst_subspan.begin(), dst_subspan.end(),
+ [&src_subspan](int32_t i1, int32_t i2) {
+ return src_subspan[i1] < src_subspan[i2];
+ });
+ }
+
+ return OkStatus();
+}
+
+template <typename T>
Status Broadcast::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
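`Sort::Execute` above is an argsort kernel: for each contiguous slice of length `src_shape.back()` it fills the destination with `0..n-1` via `std::iota`, then stable-sorts those indices by the source values. An equivalent Python sketch of the loop, with plain lists standing in for the spans:

```python
def sort_kernel(src_buffer, src_shape):
    # One stable argsort per trailing-dimension slice.
    sort_size = src_shape[-1]
    dst_buffer = []
    for i in range(0, len(src_buffer), sort_size):
        src_slice = src_buffer[i:i + sort_size]
        indices = list(range(sort_size))          # std::iota
        indices.sort(key=lambda j: src_slice[j])  # std::stable_sort (list.sort is stable)
        dst_buffer.extend(indices)
    return dst_buffer


assert sort_kernel([3, 1, 2, 6, 5, 4], [2, 3]) == [1, 2, 0, 2, 1, 0]
```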
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 09dbb3c..5852de0 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -642,6 +642,19 @@
IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);
IREE_VMLA_UNARY_OP(RoundF32, kernels::Round, float);
+#define IREE_VMLA_SORT_OP(name, type) \
+ Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape, \
+ const vm::ref<Buffer>& dst) { \
+ IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
+ return kernels::Sort::Execute<type>(src->As<type>(), dst->As<int32_t>(), \
+ src_shape); \
+ }
+
+ IREE_VMLA_SORT_OP(SortI8, int8_t);
+ IREE_VMLA_SORT_OP(SortI16, int16_t);
+ IREE_VMLA_SORT_OP(SortI32, int32_t);
+ IREE_VMLA_SORT_OP(SortF32, float);
+
//===--------------------------------------------------------------------===//
// VMLA Ops: conversion
//===--------------------------------------------------------------------===//
@@ -970,6 +983,10 @@
vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
vm::MakeNativeFunction("round.f32", &VMLAModuleState::RoundF32),
+ vm::MakeNativeFunction("sort.i8", &VMLAModuleState::SortI8),
+ vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16),
+ vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32),
+ vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32),
vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
diff --git a/iree/test/e2e/xla_ops/sort.mlir b/iree/test/e2e/xla_ops/sort.mlir
new file mode 100644
index 0000000..1820d8d
--- /dev/null
+++ b/iree/test/e2e/xla_ops/sort.mlir
@@ -0,0 +1,40 @@
+func @sort1D() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[3, 2, 1, 4]> : tensor<4xi32>
+
+ %sort = "mhlo.sort"(%input) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 0 : i64, is_stable = false} : (tensor<4xi32>) -> tensor<4xi32>
+
+ check.expect_eq_const(%sort, dense<[1, 2, 3, 4]> : tensor<4xi32>) : tensor<4xi32>
+ return
+}
+
+func @sort2D() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[[1, 2, 3, 4],
+ [4, 3, 2, 1]]> : tensor<2x4xi32>
+
+ %sort = "mhlo.sort"(%input) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = false} : (tensor<2x4xi32>) -> tensor<2x4xi32>
+
+ check.expect_eq_const(%sort, dense<[[1, 2, 3, 4], [1, 2, 3, 4]]> : tensor<2x4xi32>) : tensor<2x4xi32>
+ return
+}
+
+func @sort3D() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[[[1, 2, 3, 4],
+ [4, 3, 2, 1]]]> : tensor<1x2x4xi32>
+
+ %sort = "mhlo.sort"(%input) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 2 : i64, is_stable = false} : (tensor<1x2x4xi32>) -> tensor<1x2x4xi32>
+
+ check.expect_eq_const(%sort, dense<[[[1, 2, 3, 4], [1, 2, 3, 4]]]> : tensor<1x2x4xi32>) : tensor<1x2x4xi32>
+ return
+}
diff --git a/scripts/git/google_to_main.sh b/scripts/git/google_to_main.sh
index 8b8543d..b00fd9a 100755
--- a/scripts/git/google_to_main.sh
+++ b/scripts/git/google_to_main.sh
@@ -56,4 +56,7 @@
echo "${BODY?}"
exit 1
fi
-gh pr create --base main --title="${TITLE?}" --body="${BODY?}"
+
+# Workaround https://github.com/cli/cli/issues/1820
+GITHUB_USERNAME="$(gh config get -h github.com user)"
+gh pr create --base main --head="${GITHUB_USERNAME?}:${PR_BRANCH?}" --title="${TITLE?}" --body="${BODY?}"
diff --git a/scripts/git/main_to_google.sh b/scripts/git/main_to_google.sh
index ee07573..b9e56ae 100755
--- a/scripts/git/main_to_google.sh
+++ b/scripts/git/main_to_google.sh
@@ -56,4 +56,7 @@
echo "${BODY?}"
exit 1
fi
-gh pr create --base google --title="${TITLE?}" --body="${BODY?}"
+
+# Workaround https://github.com/cli/cli/issues/1820
+GITHUB_USERNAME="$(gh config get -h github.com user)"
+gh pr create --base google --head="${GITHUB_USERNAME?}:${PR_BRANCH?}" --title="${TITLE?}" --body="${BODY?}"
diff --git a/scripts/git/update_tf_submodule.sh b/scripts/git/update_tf_submodule.sh
index 0031cfa..28799d1 100755
--- a/scripts/git/update_tf_submodule.sh
+++ b/scripts/git/update_tf_submodule.sh
@@ -72,4 +72,7 @@
echo "${BODY?}"
exit 1
fi
-gh pr create --title="${TITLE?}" --body="${BODY?}" --base="${BASE_BRANCH?}"
+
+# Workaround https://github.com/cli/cli/issues/1820
+GITHUB_USERNAME="$(gh config get -h github.com user)"
+gh pr create --base="${BASE_BRANCH?}" --head="${GITHUB_USERNAME?}:${PR_BRANCH?}" --title="${TITLE?}" --body="${BODY?}"