Merge google -> main (#3337)

* 1e81da7a Merge main -> google
* 62c8861c Merge branch 'google' into main-to-google
* 8896f35a Synchronize submodules
* 2c02b9c1 Integrate TF at tensorflow/tensorflow@1454ee0907ee
* b843a99e Synchronize submodules
* bb076514 Integrate LLVM at llvm/llvm-project@8825fec37e73
* 9adef4b1 Synchronize submodules
* b56cefca Integrate LLVM at llvm/llvm-project@bfd7ee92ccec
* 52b019e0 Fix vkGetInstanceProcAddr returning nil issue.
* c795db9d Rename filecheck-lib target to match style of other targets
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 774412d..ba966db 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -99,7 +99,7 @@
 
 def tf_saved_model_to_compiler_module(
     saved_model_dir: str,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> Module:
   """Converts a TensorFlow SavedModel into a MLIR module.
@@ -108,8 +108,7 @@
 
   Args:
     saved_model_dir: Directory of the saved model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -131,18 +130,17 @@
 
 def compile_tf_saved_model(
     saved_model_dir: str,
-    exported_names: Collection[str] = (),
-    target_backends: Collection[str] = (),
+    exported_names: Sequence[str] = (),
+    target_backends: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
   """Compiles a TensorFlow SavedModel to IREE in one shot.
 
   Args:
     saved_model_dir: Directory of the saved model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    exported_names: Optional sequence representing the exported names to keep.
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -160,7 +158,7 @@
 def tf_signature_def_saved_model_to_compiler_module(
     saved_model_dir: str,
     saved_model_tags: Set[str] = set(),
-    exported_names: Collection[str] = [],
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> Module:
   """Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
@@ -168,8 +166,7 @@
   Args:
     saved_model_dir: Directory of the saved model.
     saved_model_tags: Optional set of tags to use when loading the model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -194,8 +191,8 @@
 def compile_tf_signature_def_saved_model(
     saved_model_dir: str,
     saved_model_tags: Set[str] = set(),
-    exported_names: Collection[str] = (),
-    target_backends: Collection[str] = (),
+    exported_names: Sequence[str] = (),
+    target_backends: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
   """Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
@@ -203,10 +200,9 @@
   Args:
     saved_model_dir: Directory of the saved model.
     saved_model_tags: Optional set of tags to use when loading the model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    exported_names: Optional sequence representing the exported names to keep.
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -222,7 +218,7 @@
 
 def tf_module_to_compiler_module(
     module: tf.Module,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None,
     saved_model_dir: str = None) -> Module:
@@ -230,8 +226,7 @@
 
   Args:
     module: The tf.Module instance to convert to MLIR
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -259,8 +254,8 @@
 
 
 def compile_tf_module(module: tf.Module,
-                      exported_names: Collection[str] = (),
-                      target_backends: Collection[str] = (),
+                      exported_names: Sequence[str] = (),
+                      target_backends: Sequence[str] = (),
                       pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
                       compiler_context: Optional[Context] = None,
                       saved_model_dir: str = None):
@@ -268,10 +263,9 @@
 
   Args:
     module: The tf.Module instance to convert to MLIR
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    exported_names: Optional sequence representing the exported names to keep.
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
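
Usage sketch for the loosened signatures above (illustrative only; the
SimpleModule class and the "vmla" backend name are assumptions, not taken
from this patch):

  import tensorflow as tf
  from pyiree.tf import compiler

  class SimpleModule(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
    def abs(self, x):
      return tf.abs(x)

  # Both arguments now accept any Sequence, e.g. lists as well as tuples.
  blob = compiler.compile_tf_module(
      SimpleModule(),
      exported_names=["abs"],
      target_backends=["vmla"])
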
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index fa3df03..aa9481e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -66,22 +66,22 @@
 
 
 def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
-  """Decodes --target_backends and creates unique names for their artifacts."""
+  """Decodes --target_backends and creates unique ids for them."""
   backend_names = FLAGS.target_backends.split(",")
   backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
-  artifact_names = []
+  backend_ids = []
 
   # If there are multiple copies of the same backend_name, index them. e.g.
   # backend_names = ["tf", "iree_vmla", "tf"]
-  # --> artifact_names = ["tf_0", "iree_vmla", "tf_1"]
+  # --> backend_ids = ["tf_0", "iree_vmla", "tf_1"]
   for backend_name in backend_names:
     if backend_name in backend_to_index:
-      artifact_names.append(f"{backend_name}_{backend_to_index[backend_name]}")
+      backend_ids.append(f"{backend_name}_{backend_to_index[backend_name]}")
       backend_to_index[backend_name] += 1
     else:
-      artifact_names.append(backend_name)
+      backend_ids.append(backend_name)
 
-  return backend_names, artifact_names
+  return backend_names, backend_ids
 
 
 def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
@@ -95,10 +95,10 @@
   """
   if FLAGS.target_backends is not None:
     logging.info("Using backends from command line: %s", FLAGS.target_backends)
-    backend_names, names = _parse_target_backends()
+    backend_names, backend_ids = _parse_target_backends()
     backends = [
-        tf_utils.BackendInfo(backend, name)
-        for backend, name in zip(backend_names, names)
+        tf_utils.BackendInfo(backend_name, backend_id)
+        for backend_name, backend_id in zip(backend_names, backend_ids)
     ]
   else:
     # If no backends are specified, use them all.
@@ -261,10 +261,11 @@
       # Extract metadata from module and function.
       self.module_name = module.module_name
       self.compiled_paths = module.compiled_paths
-      self.backend_name = module.backend
+      self.backend_name = module.backend_info.backend_name
+      self.backend_id = module.backend_info.backend_id
+      self.backend_driver = module.backend_info.driver
       self.iree_serializable = module.iree_serializable()
       self.tflite_serializable = module.tflite_serializable()
-      self.backend_driver = module.backend_driver
       self.function_name = function.__name__
       self.function_sourcefile = inspect.getsourcefile(function)
       source, start_line = inspect.getsourcelines(function)
@@ -276,9 +277,10 @@
       self.module_name = _load_dict["module_name"]
       self.compiled_paths = _load_dict["compiled_paths"]
       self.backend_name = _load_dict["backend_name"]
+      self.backend_id = _load_dict["backend_id"]
+      self.backend_driver = _load_dict["backend_driver"]
       self.iree_serializable = _load_dict["iree_serializable"]
       self.tflite_serializable = _load_dict["tflite_serializable"]
-      self.backend_driver = _load_dict["backend_driver"]
       self.function_name = _load_dict["function_name"]
       self.function_sourcefile = _load_dict["function_sourcefile"]
       self.function_line_numbers = _load_dict["function_line_numbers"]
@@ -286,7 +288,7 @@
       self.calls = _load_dict["calls"]
 
   def __str__(self):
-    header = (f"Trace of {self.module_name} compiled to '{self.backend_name}' "
+    header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
               f"on function '{self.function_name}':")
     # Give each call a number so it's easier to compare between multiple traces.
     calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
@@ -327,11 +329,11 @@
 
       if not calls_match:
         logging.error("Comparision between '%s' and '%s' failed on method '%s'",
-                      ref_trace.backend_name, tar_trace.backend_name,
+                      ref_trace.backend_id, tar_trace.backend_id,
                       ref_call.method)
-        logging.error("Reference call '%s':\n%s", ref_trace.backend_name,
+        logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
                       ref_call)
-        logging.error("Target call '%s':\n%s", tar_trace.backend_name, tar_call)
+        logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
 
       traces_match = traces_match and calls_match
     return traces_match
@@ -434,14 +436,20 @@
       trace_dir: str, path to the directory to serialize the trace to.
     """
 
+    compiled_paths = None
+    if self.compiled_paths is not None:
+      # Convert to a dict to avoid the issues with serializing defaultdicts.
+      compiled_paths = dict(self.compiled_paths)
+
     # Python serialization.
     metadata = {
         "module_name": self.module_name,
-        "compiled_paths": self.compiled_paths,
+        "compiled_paths": compiled_paths,
         "backend_name": self.backend_name,
+        "backend_id": self.backend_id,
+        "backend_driver": self.backend_driver,
         "iree_serializable": self.iree_serializable,
         "tflite_serializable": self.tflite_serializable,
-        "backend_driver": self.backend_driver,
         "function_name": self.function_name,
         "function_sourcefile": self.function_sourcefile,
         "function_line_numbers": self.function_line_numbers,
@@ -493,7 +501,7 @@
 
 
 def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
-  trace_dir = os.path.join(artifacts_dir, trace.backend_name, "traces",
+  trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
                            trace.function_name)
   os.makedirs(trace_dir, exist_ok=True)
   return trace_dir
@@ -559,17 +567,17 @@
 def compile_tf_module(
     module_class: Type[tf.Module], exported_names: Sequence[str] = ()
 ) -> Callable[[Any], Any]:
-  """CompiledModuleTestCase decorator that compiles a tf.Module.
-
-  A CompiledModule is created for each backend in --target_backends. They can
-  be accessed individually via self.compiled_modules.backend_name or as a union
-  via self.get_module().
+  """Compiles module_class to each backend that we test.
 
   Args:
     module_class: the tf.Module subclass to compile.
     exported_names: optional iterable of strings representing which of
       module_class's functions to compile. If exported_names is empty all
       functions will be compiled.
+
+  Returns:
+    A 'Modules' namedtuple containing the reference module, target modules and
+    artifacts directory.
   """
 
   # Setup the directory for saving compilation artifacts and traces.
@@ -580,7 +588,7 @@
                                           f"{FLAGS.reference_backend}_ref")
   tar_backend_infos = get_target_backends()
 
-  compile_backend = lambda backend_info: backend_info.compile(
+  compile_backend = lambda backend_info: backend_info.compile_from_class(
       module_class, exported_names, artifacts_dir)
 
   ref_module = compile_backend(ref_backend_info)
@@ -631,7 +639,7 @@
     failed_backend_indices = []
     for i, tar_trace in enumerate(tar_traces):
       logging.info("Comparing the reference backend '%s' with '%s'",
-                   ref_trace.backend_name, tar_trace.backend_name)
+                   ref_trace.backend_id, tar_trace.backend_id)
       traces_match = Trace.compare_traces(ref_trace, tar_trace)
       if not traces_match:
         failed_backend_indices.append(i)
@@ -649,7 +657,7 @@
     if failed_backend_indices:
       # Extract info for logging.
       failed_backends = [
-          tar_traces[i].backend_name for i in failed_backend_indices
+          tar_traces[i].backend_id for i in failed_backend_indices
       ]
       self.fail(
           "Comparision between the reference backend and the following targets "
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 6157c68..e1fb6e3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -117,8 +117,8 @@
       # Only outputs
       module.get_count()
 
-    module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                       tf_utils.BackendInfo('tf'))
+    module = tf_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('tf'))
     trace = tf_test_utils.Trace(module, trace_function)
     trace_function(tf_test_utils.TracedModule(module, trace))
 
@@ -140,13 +140,13 @@
       module.increment()
       module.decrement()
 
-    tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                          tf_utils.BackendInfo('tf'))
+    tf_module = tf_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('tf'))
     tf_trace = tf_test_utils.Trace(tf_module, tf_function)
     tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
 
-    vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
-                                              tf_utils.BackendInfo('iree_vmla'))
+    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
@@ -161,13 +161,13 @@
     def vmla_function(module):
       module.increment_by(np.array([22.], dtype=np.float32))
 
-    tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                          tf_utils.BackendInfo('tf'))
+    tf_module = tf_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('tf'))
     tf_trace = tf_test_utils.Trace(tf_module, tf_function)
     tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
 
-    vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
-                                              tf_utils.BackendInfo('iree_vmla'))
+    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
@@ -178,12 +178,12 @@
     def trace_function(module):
       module.increment()
       module.increment_by(np.array([81.], dtype=np.float32))
-      module.increment_by_max(
-          np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
+      module.increment_by_max(np.array([81], dtype=np.float32),
+                              np.array([92], dtype=np.float32))
       module.get_count()
 
-    module = tf_utils.IreeCompiledModule(StatefulCountingModule,
-                                         tf_utils.BackendInfo('iree_vmla'))
+    module = tf_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     trace = tf_test_utils.Trace(module, trace_function)
     trace_function(tf_test_utils.TracedModule(module, trace))
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index fc8bab6..b3237f9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -16,6 +16,7 @@
 
 # pylint: disable=protected-access
 
+import collections
 import os
 import random
 import re
@@ -91,7 +92,7 @@
 def _setup_mlir_crash_reproducer(
     function: Callable[[Any], Any],
     artifacts_dir: str,
-    backend_name: str,
+    backend_id: str,
 ) -> Callable[[Any], Any]:
   """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
 
@@ -100,7 +101,7 @@
   Args:
     function: The callable to decorate.
     artifacts_dir: The directory to write the reproducer to.
-    backend_name: The name of the backend `function` compiles to.
+    backend_id: The unique backend id to use when writing the reproducer.
 
   Returns:
     A function with the same API as the passed function.
@@ -110,7 +111,7 @@
     # Set up a crash reproducer for debugging.
     if artifacts_dir is not None:
       compiler.Context.default_crash_reproducer_path = os.path.join(
-          artifacts_dir, f"reproducer__{backend_name}.mlir")
+          artifacts_dir, f"reproducer__{backend_id}.mlir")
     try:
       results = function(*args, **kwargs)
     except Exception:  # pylint: disable=broad-except
@@ -135,7 +136,7 @@
       MLIR for the module in TF's input dialect.
     iree_input.mlir:
       The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
-    backend_name/compiled.vmfb:
+    backend_id/compiled.vmfb:
       A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
 
   Args:
@@ -167,7 +168,7 @@
 
   compiled_path = None
   if artifacts_dir is not None:
-    backend_dir = os.path.join(artifacts_dir, backend_info.name)
+    backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
     os.makedirs(backend_dir, exist_ok=True)
     compiled_path = os.path.join(backend_dir, "compiled.vmfb")
     logging.info("Saving compiled IREE module to: %s", compiled_path)
@@ -180,7 +181,7 @@
     module: Type[tf.Module],
     backend_info: "BackendInfo",
     exported_names: Sequence[str] = (),
-    artifacts_dir: str = None
+    artifacts_dir: str = None,
 ) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
   """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
 
@@ -193,8 +194,7 @@
   Args:
     module: A tf.Module.
     backend_info: BackendInfo with the details for compiling module to IREE.
-    exported_names: Iterable of dotted function names to consider for
-      compilation.
+    exported_names: Optional sequence representing the exported names to keep.
     artifacts_dir: An optional string pointing to where compilation artifacts
       should be saved. No compilation artifacts will be saved if this is not
       provided.
@@ -212,29 +212,72 @@
                                                 artifacts_dir)
 
   _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
-                                                 backend_info.name)
+                                                 backend_info.backend_id)
   return _compile_module(module, exported_names, backend_info, artifacts_dir)
 
 
 class CompiledModule(object):
   """Base class for the TF and IREE compiled modules."""
 
-  def __init__(self, module_class: Type[tf.Module], backend_info: "BackendInfo",
-               exported_names: Sequence[str], artifacts_dir: str):
-    """Shared base constructor – not useful on its own."""
-    self._module_class = module_class
-    self._backend_info = backend_info
-    self._exported_names = exported_names
-    self._artifacts_dir = artifacts_dir
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+  ):
+    """Shared base constructor – not useful on its own.
 
-    # Public attributes:
-    self.backend = self._backend_info.name
-    self.backend_driver = self._backend_info.driver
-    self.module_name = self._module_class.__name__
-    self.compiled_paths = None
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+    """
+    self.module_name = module_name
+    self.backend_info = backend_info
+    self.compiled_paths = compiled_paths
 
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    This is only implemented for IreeCompiledModule.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
     raise NotImplementedError()
 
   def iree_serializable(self):
@@ -244,14 +287,6 @@
     return False
 
 
-def _get_non_inhereted_function_names(cls):
-  """Gets all methods that cls has that its parents don't have."""
-  names = set(dir(cls))
-  for parent in cls.__bases__:
-    names -= set(dir(parent))
-  return list(names)
-
-
 class _FunctionWrapper(object):
 
   def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
@@ -277,57 +312,94 @@
 class IreeCompiledModule(CompiledModule):
   """Iree compiled module."""
 
-  def __init__(self,
-               module_class: Type[tf.Module],
-               backend_info: "BackendInfo",
-               exported_names: Sequence[str] = (),
-               artifacts_dir: str = None):
-    """Compile a tf.Module to the target backend in backend_info.
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      vm_module: rt.VmModule,
+      config: rt.Config,
+  ):
+    """Base constructor – Use one of the named constructors instead.
 
     Args:
-      module_class: the tf.Module subclass to compile.
-      backend_info: an element of BackendInfo corresponding to the IREE backend
-        to compile to.
-      exported_names: an optional iterable of strings representing which of the
-        module_class's functions to compile. If exported_names is empty all
-        functions will be compiled.
-      artifacts_dir: an optional path to save compilation artifacts to.
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      vm_module: A rt.VmModule containing compilation info to wrap.
+      config: A rt.Config containing compilation info to wrap.
     """
-    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._vm_module = vm_module
+    self._config = config
+    self.reinitialize()
 
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
     set_random_seed()
-    self._module_blob, compiled_path = _incrementally_compile_tf_module(
-        module=module_class(),
+    module_instance = module_class()
+    return cls.create_from_instance(module_instance, backend_info,
+                                    exported_names, artifacts_dir)
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_blob, compiled_path = _incrementally_compile_tf_module(
+        module=module_instance,
         backend_info=backend_info,
         exported_names=exported_names,
         artifacts_dir=artifacts_dir)
-    self._module = rt.VmModule.from_flatbuffer(self._module_blob)
-    self._config = rt.Config(driver_name=backend_info.driver)
+    vm_module = rt.VmModule.from_flatbuffer(module_blob)
+    config = rt.Config(driver_name=backend_info.driver)
 
-    self.compiled_paths = None
+    compiled_paths = None
     if compiled_path is not None:
-      if not len(exported_names):
-        # Get all method names on 'module_class' that aren't on 'tf.Module'.
-        # This doesn't address all possbile scenarios.
-        # TODO(meadowlark): Figure out how to get a list of all of the functions
-        # that this module has access to via `pyiree.rt.system_api.BoundModule`.
-        exported_names = _get_non_inhereted_function_names(module_class)
-      self.compiled_paths = dict([
-          (method, compiled_path) for method in exported_names
-      ])
+      # IREE bundles every compiled method into the same compiled module.
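+      # (A defaultdict maps any method name to this single compiled_path,
+      # since the one .vmfb file contains all of the compiled methods.)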
+      compiled_paths = collections.defaultdict(lambda: compiled_path)
 
-    self.reinitialize()
+    module_name = type(module_instance).__name__
+
+    return cls(module_name, backend_info, compiled_paths, vm_module, config)
 
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
     # set_random_seed is not needed here because the model_class.__init__ is not
     # called.
-    self._context = rt.SystemContext(modules=[self._module],
+    self._context = rt.SystemContext(modules=[self._vm_module],
                                      config=self._config)
 
   def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
     # Try to resolve it as a function.
-    m = self._context.modules[self._module.name]
+    m = self._context.modules[self._vm_module.name]
     f = m[attr]
     return _IreeFunctionWrapper(self._context, f)
 
@@ -376,29 +448,54 @@
   normalize TensorFlow's output to Numpy.
   """
 
-  def __init__(self,
-               module_class: Type[tf.Module],
-               backend_info: "BackendInfo",
-               exported_names: Sequence[str] = (),
-               artifacts_dir: str = None):
-    """Wrap a tf.Module in a TFCompiledModule facade.
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      constructor: Callable[[], tf.Module],
+      exported_names: Sequence[str],
+  ):
+    """Base constructor – Use one of the named constructors instead.
 
     Args:
-      module_class: the tf.Module subclass to 'compile'.
-      backend_info: one of the 'tf*' elements in BackendInfo.
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      constructor: A callable (class or function) which returns the tf.Module
+        subclass instance to wrap.
       exported_names: an optional iterable of strings representing which of the
-        module_class's functions should be callable. If exported_names is empty
-        then all functions will be callable.
-      artifacts_dir: an optional path to save compilation artifacts to. Has no
-        effect for this subclass as nothing is compiled.
+        tf.Module subclass instance's functions should be callable. If
+        exported_names is empty then all functions will be callable.
     """
-    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+    super().__init__(module_name, backend_info, compiled_paths=None)
+    self._constructor = constructor
+    self._exported_names = exported_names
     self.reinitialize()
 
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_name = module_class.__name__
+    constructor = module_class
+    return cls(module_name, backend_info, constructor, exported_names)
+
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
     set_random_seed()
-    self._tf_module = self._module_class()
+    self._tf_module = self._constructor()
 
   def __getattr__(self, attr: str) -> _TfFunctionWrapper:
     # Try to resolve it as a function.
@@ -412,6 +509,14 @@
     return _TfFunctionWrapper(f)
 
 
+def _get_non_inhereted_function_names(cls):
+  """Gets all methods that cls has that its parents don't have."""
+  names = set(dir(cls))
+  for parent in cls.__bases__:
+    names -= set(dir(parent))
+  return list(names)
+
+
 def _get_concrete_functions(module_class: Type[tf.Module],
                             exported_names: Sequence[str] = ()):
   """Get concrete functions from non-inherited methods or exported_names."""
@@ -425,12 +530,12 @@
   return functions, exported_names
 
 
-def compile_to_tflite(
+def tf_module_to_tflite_interpreters(
     module_class: Type[tf.Module],
     exported_names: Sequence[str] = (),
     artifacts_dir: str = None
 ) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str], None]]:
-  """Compile a dict of TFLite interpreters for the methods on module_class.
+  """Compile a tf.Module to TFLite interpreters for each of its methods.
 
   Args:
     module_class: A tf.Module subclass to compile with TFLite. If module_class
@@ -463,22 +568,14 @@
     if artifacts_dir is not None:
       compiled_paths[name] = tflite_path
 
+  # Convert module_class's methods into TFLite module byte-strings.
   tflite_modules = []
-  names = []
-  if hasattr(module_class, "get_legacy_tflite_saved_model_converter_kwargs"):
-    kwargs = module_class.get_legacy_tflite_saved_model_converter_kwargs()
-    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
-        kwargs["model_path"],
-        input_arrays=kwargs["input_arrays"],
-        output_arrays=kwargs["output_arrays"])
+  functions, names = _get_concrete_functions(module_class, exported_names)
+  for function in functions:
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
     tflite_modules.append(converter.convert())
-    names.append(kwargs["exported_name"])
-  else:
-    functions, names = _get_concrete_functions(module_class, exported_names)
-    for function in functions:
-      converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
-      tflite_modules.append(converter.convert())
 
+  # Load each of the converted methods above into tf.lite.Interpreters.
   for name, tflite_module in zip(names, tflite_modules):
     if artifacts_dir is None:
       with tempfile.TemporaryDirectory() as base_dir:
@@ -537,18 +634,50 @@
 class TfLiteCompiledModule(CompiledModule):
   """Compiles a tf.Module with TFLite and allows it to be called."""
 
-  def __init__(self,
-               module_class: Type[tf.Module],
-               backend_info: "BackendInfo",
-               exported_names: Sequence[str] = (),
-               artifacts_dir: str = None):
-    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      interpreters: Dict[str, tf.lite.Interpreter],
+  ):
+    """Base constructor – Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      interpreters: A dict of tf.lite.Interpreters to make callable.
+    """
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._interpreters = interpreters
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
     set_random_seed()
-    self._interpreters, self.compiled_paths = compile_to_tflite(
+    interpreters, compiled_paths = tf_module_to_tflite_interpreters(
         module_class, exported_names, artifacts_dir)
+    module_name = module_class.__name__
+    return cls(module_name, backend_info, compiled_paths, interpreters)
 
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
     # This is a noop because TFLite (mostly) doesn't support stateful modules.
     pass
 
@@ -594,36 +723,43 @@
       },
   }
 
-  def __init__(self, backend_name: str, artifact_name: str = None):
+  def __init__(self, backend_name: str, backend_id: str = None):
     """Creates a BackendInfo with the compilation details for backend_name.
 
     Args:
       backend_name: a str specifying which backend to use. Should be one of
         'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
-      artifact_name: an optional str specifying what name to use when saving
-        compiled artifacts.
+      backend_id: an optional str specifying what name to use when saving
+        compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
 
     Raises:
       KeyError: if backend_name is not one of ['tf', 'iree_vmla',
       'iree_llvmjit', 'iree_vulkan'].
+      ValueError: if backend_id doesn't start with backend_name.
     """
     if backend_name not in self._name_to_info:
       raise KeyError(
           "Expected backend_name to be one of "
           f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+    if backend_id is not None and not backend_id.startswith(backend_name):
+      raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+                       f"but got '{backend_id}'.")
+
+    self.backend_name = backend_name
+    self.backend_id = backend_name if backend_id is None else backend_id
+
     info = self._name_to_info[backend_name]
     self._compiled_module_class = info["compiled_module_class"]
     self.driver = info["driver"]
     self.compiler_targets = info["compiler_targets"]
-    self.name = backend_name if artifact_name is None else artifact_name
 
-  def compile(self,
-              module: Type[tf.Module],
-              exported_names: Sequence[str] = (),
-              artifacts_dir: str = None) -> CompiledModule:
+  def compile_from_class(self,
+                         module_class: Type[tf.Module],
+                         exported_names: Sequence[str] = (),
+                         artifacts_dir: str = None) -> CompiledModule:
     """Creates a 'CompiledModule' for this backend."""
-    return self._compiled_module_class(module, self, exported_names,
-                                       artifacts_dir)
+    return self._compiled_module_class.create_from_class(
+        module_class, self, exported_names, artifacts_dir)
 
   @classmethod
   def get_all_backends(cls) -> Sequence["BackendInfo"]:
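
Sketch of the renamed BackendInfo API above (illustrative; it reuses the
StatefulCountingModule test module exercised in the tests below):

  backend_info = tf_utils.BackendInfo("iree_vmla", backend_id="iree_vmla_0")
  module = backend_info.compile_from_class(StatefulCountingModule)
  # BackendInfo raises ValueError unless backend_id starts with backend_name.
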
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 1a441d0..a9c259c 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -89,7 +89,7 @@
   ])
   def test_unaltered_state(self, backend_name):
     backend_info = tf_utils.BackendInfo(backend_name)
-    module = backend_info.compile(StatefulCountingModule)
+    module = backend_info.compile_from_class(StatefulCountingModule)
 
     # Test that incrementing works properly.
     self.assertEqual([0.], module.get_count())
@@ -126,8 +126,8 @@
     backend_info = tf_utils.BackendInfo(backend_name)
 
     # Test compilation is the same.
-    module_1 = backend_info.compile(RandomInitModule)
-    module_2 = backend_info.compile(RandomInitModule)
+    module_1 = backend_info.compile_from_class(RandomInitModule)
+    module_2 = backend_info.compile_from_class(RandomInitModule)
     self.assertAllEqual(module_1.get(), module_2.get())
 
     # Test reinitialization is the same.
diff --git a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
index 78baeb6..df9b317 100644
--- a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
@@ -54,7 +54,7 @@
 def xla_load_module_proto(
     xla_computation,
     compiler_context: Optional[Context] = None,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE) -> Module:
   """Loads a XLA saved model from its persistent representation.
 
@@ -63,8 +63,7 @@
   Args:
    xla_computation: XLA Computation generated from the XLA Python client.
     compiler_context: The pyiree.compiler.Context() backing the module.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to XLA_IMPORT_PASS_PIPELINE.
 
@@ -85,21 +84,20 @@
 def xla_compile_module_proto(
     xla_computation,
     compiler_context: Optional[Context] = None,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE,
-    target_backends: Collection[str] = ()
+    target_backends: Sequence[str] = ()
 ) -> binding.OpaqueBlob:
   """Loads and compiles a XLA saved model in one shot.
 
   Args:
    xla_computation: XLA Computation generated from the XLA Python client.
     compiler_context: The pyiree.compiler.Context() backing the module.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to XLA_IMPORT_PASS_PIPELINE.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled-in targets).
 
   Returns:
     An OpaqueBlob representing the compiled module.
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 2a0b458..dc289d6 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -95,6 +95,7 @@
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
+    "sort_test.py",
     "strings_test.py",
 ]
 
@@ -112,6 +113,7 @@
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
+    "sort_test.py",
     "strings_test.py",
 ]
 
@@ -218,7 +220,6 @@
     name = "mobile_bert_squad_tests",
     backends_to_srcs = {
         "tf": ["mobile_bert_squad_test.py"],
-        "tflite": ["mobile_bert_squad_test.py"],
     },
     reference_backend = "tf",
     tags = [
@@ -234,6 +235,7 @@
 iree_e2e_test_suite(
     name = "mobile_bert_squad_tests_failing",
     backends_to_srcs = {
+        "tflite": ["mobile_bert_squad_test.py"],
         "iree_vmla": ["mobile_bert_squad_test.py"],
         "iree_llvmjit": ["mobile_bert_squad_test.py"],
         "iree_vulkan": ["mobile_bert_squad_test.py"],
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
index 9c55ead..ae5d0d2 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/BUILD
@@ -22,7 +22,7 @@
 
 iree_cmake_extra_content(
     content = """
-if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} AND NOT ${IREE_TARGET_BACKEND_METAL-SPIRV})
   return()
 endif()
 """,
diff --git a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
index 0922d4a..45d0b21 100644
--- a/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/SPIRVCommon/CMakeLists.txt
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV})
+if(NOT ${IREE_TARGET_BACKEND_VULKAN-SPIRV} AND NOT ${IREE_TARGET_BACKEND_METAL-SPIRV})
   return()
 endif()
 
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index a80a260..6af672a 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -42,6 +42,7 @@
   // Pseudo-ops are illegal.
   // If we end up with a lot of these, consider using an "is pseudo" trait.
   addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
+  addIllegalOp<IREE::VMLA::SortPseudoOp>();
 
   // Allow other ops to pass through so long as their type is valid (not a
   // tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index ccd1a3f..75b640f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -704,6 +704,29 @@
   TypeConverter &typeConverter;
 };
 
+struct SortOpConversion : public OpConversionPattern<IREE::VMLA::SortPseudoOp> {
+  SortOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  LogicalResult matchAndRewrite(
+      IREE::VMLA::SortPseudoOp srcOp, ArrayRef<Value> rawOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto elementType =
+        srcOp.getOperand().getType().cast<ShapedType>().getElementType();
+    auto src = rawOperands[0];
+    auto srcShape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.value(), typeConverter, rewriter);
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    rewriter.createOrFold<IREE::VMLA::SortOp>(srcOp.getLoc(), src, srcShape,
+                                              dst, TypeAttr::get(elementType));
+    rewriter.replaceOp(srcOp, {dst});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
   ConvertOpConversion(MLIRContext *context, TypeConverter &typeConverter)
       : OpConversionPattern(context), typeConverter(typeConverter) {}
@@ -769,6 +792,9 @@
                                    IREE::VMLA::BatchMatMulOp>>(context,
                                                                typeConverter);
 
+  // vmla.sort.pseudo
+  patterns.insert<SortOpConversion>(context, typeConverter);
+
   // Simple 1:1 conversion patterns using the automated trait-based converter.
   // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
   patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
new file mode 100644
index 0000000..0903793
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @sort1D
+func @sort1D(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[C16:%.+]] = constant 16 : index
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
+  // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+  // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4]>), out [[BL]] : f32
+  // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+  // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4]>) {batch_dims = 0 : i64, dim = 0 : i64} : f32
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+  // CHECK: return [[BUF]] : !vmla.buffer
+  return %sort : tensor<4xf32>
+}
+
+
+// CHECK-LABEL: func @sort2D
+func @sort2D(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[C64:%.+]] = constant 64 : index
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,4]>
+  // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+  // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BL]] : f32
+  // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+  // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4,4]>) {batch_dims = 1 : i64, dim = 1 : i64} : f32
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  // CHECK: return [[BUF]] : !vmla.buffer
+  return %sort : tensor<4x4xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 7e4f6ea..1b66485 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -321,6 +321,7 @@
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::FloorOp, "vmla.floor");
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::CeilOp, "vmla.ceil");
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::RoundOp, "vmla.round");
+  VMLA_TYPED_IMPORT_OP(IREE::VMLA::SortOp, "vmla.sort");
 
   patterns.insert<VMLAConvertImportOpConversion>(context, importSymbols,
                                                  typeConverter, "vmla.convert");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f5cd0fe..422fed3 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -422,7 +422,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// VMLA Ops: Convultion
+// VMLA Ops: Convolution
 //===----------------------------------------------------------------------===//
 
 def VLMA_ConvOp : VMLA_Op<"conv", [VMLA_IncludeShapes]> {
@@ -460,6 +460,46 @@
 }
 
 //===----------------------------------------------------------------------===//
+// VMLA Ops: Sorting
+//===----------------------------------------------------------------------===//
+
+def VMLA_SortPseudoOp : VMLA_Op<"sort.pseudo"> {
+  let summary = "Tensor-level pseudo-op of VMLA::SortOp.";
+  let description = [{
+    This is a tensor-level version of VMLA::SortOp, to facilitate
+    the lowering process.
+
+    This operation generates a sorted index list along the last dimension,
+    operating batch-wise along all other dimensions.
+  }];
+  let arguments = (ins
+    AnyTensor:$value
+  );
+  let results = (outs
+    I32Tensor:$dst
+  );
+
+  let assemblyFormat = [{
+    $value attr-dict `:` `(`type($value)`)` `->` type($dst)
+  }];
+}
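+
+// Illustrative printed form implied by the assembly format above (an
+// example, not taken from this patch):
+//   %indices = vmla.sort.pseudo %values : (tensor<4xf32>) -> tensor<4xi32>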
+
+def VMLA_SortOp : VMLA_ElementTypeOp<"sort", [VMLA_IncludeShapes]> {
+  let arguments = (ins
+    VMLA_Buffer:$src,
+    VMLA_Shape:$src_shape,
+    VMLA_Buffer:$dst,
+    VMLA_AnyTypeAttr:$element_type
+  );
+
+  let assemblyFormat = [{
+    $src`(`$src_shape `:` type($src_shape)`)``,`
+    `out` $dst attr-dict `:` $element_type
+  }];
+}
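+
+// Printed form, as exercised by the new test in
+// iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir:
+//   vmla.sort %src(%src_shape : !shapex.ranked_shape<[4]>), out %dst : f32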
+
+//===----------------------------------------------------------------------===//
 // VMLA Ops: GEMM/GEMV
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index d937026..62dd921 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -277,6 +277,96 @@
   }
 };
 
+// Lowers mhlo::SortOp to a pseudo SortOp in the VMLA dialect. The pseudo op
+// generates a set of ordered indices for the operand along the last
+// dimension. A torch_index_select then reorders the values, which supports
+// arbitrary inputs.
+//
+// TODO(suderman): This lowering only covers the case of ascending values; we
+// should support a separate descending case via separate SortAscending and
+// SortDescending operations.
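+//
+// Sketch of the resulting IR for a rank-1 input (the generic op syntax here
+// is illustrative; see Transforms/test/pre_conversion_lowering.mlir for the
+// checked output):
+//   %idx = "vmla.sort.pseudo"(%input) : (tensor<4xf32>) -> tensor<4xi32>
+//   %out = "mhlo.torch_index_select"(%input, %idx)
+//       {batch_dims = 0 : i64, dim = 0 : i64}
+//       : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32>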
+class LowerSortOp : public OpRewritePattern<mhlo::SortOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::SortOp op,
+                                PatternRewriter &rewriter) const override {
+    auto operandTy = op.getOperand(0).getType().cast<RankedTensorType>();
+    bool lastDimension =
+        (op.dimension() == -1) || (op.dimension() == (operandTy.getRank() - 1));
+
+    // TODO(suderman): Add transpose to sort along the last dimension.
+    if (!lastDimension) return failure();
+
+    auto &comparator = op.comparator();
+    auto &block = comparator.getBlocks().front();
+    auto &operations = block.getOperations();
+    auto comparison = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
+
+    // First verify that the block is purely a return of a comparison. This
+    // handles sorting a single tensor of values.
+    if (!comparison) return failure();
+
+    auto returnOp =
+        dyn_cast_or_null<mhlo::ReturnOp>(&(*(++operations.begin())));
+    if (!returnOp) return failure();
+
+    if (returnOp.getOperand(0) != comparison.getResult()) return failure();
+
+    // Determine which block arguments are being compared.
+    auto lhs = comparison.getOperand(0);
+    auto rhs = comparison.getOperand(1);
+    auto lhsIndex = -1;
+    auto rhsIndex = -1;
+    for (auto arg : llvm::enumerate(block.getArguments())) {
+      if (arg.value() == lhs) lhsIndex = arg.index();
+      if (arg.value() == rhs) rhsIndex = arg.index();
+    }
+
+    // This should never happen, but best to check.
+    if (lhsIndex == -1) return failure();
+    if (rhsIndex == -1) return failure();
+
+    // They should not be the same.
+    if (lhsIndex == rhsIndex) return failure();
+
+    // Comparisons need to pull from the same sort operand.
+    auto lhsOperand = lhsIndex / 2;
+    auto rhsOperand = rhsIndex / 2;
+    if (lhsOperand != rhsOperand) return failure();
+
+    // Must be GT, GE, LT, or LE.
+    auto isGt = comparison.comparison_direction() == "GT" ||
+                comparison.comparison_direction() == "GE";
+    auto isLt = comparison.comparison_direction() == "LT" ||
+                comparison.comparison_direction() == "LE";
+    if (!isGt && !isLt) return failure();
+
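+    // Each sort operand contributes an (lhs, rhs) pair of block arguments; the
+    // sort direction follows from the order in which the comparator reads
+    // them, XOR-ed with the comparison direction.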
+    bool operandParity = lhsIndex > rhsIndex;
+    auto isAscending = operandParity ^ isGt;
+    // TODO(suderman): Add support for descending sort.
+    if (!isAscending) return failure();
+
+    auto operand = op.getOperand(lhsOperand);
+    auto sortedIndices = rewriter.create<VMLA::SortPseudoOp>(
+        op.getLoc(),
+        RankedTensorType::get(operandTy.getShape(), rewriter.getI32Type()),
+        operand);
+
+    llvm::SmallVector<Value, 6> sortedResults;
+    for (auto operand : op.getOperands()) {
+      auto tensorTy = operand.getType().cast<RankedTensorType>();
+      auto gathered = rewriter.create<mhlo::TorchIndexSelectOp>(
+          op.getLoc(), tensorTy, operand, sortedIndices,
+          /*dim=*/operandTy.getRank() - 1,
+          /*batch_dims=*/operandTy.getRank() - 1);
+      sortedResults.push_back(gathered);
+    }
+
+    rewriter.replaceOp(op, sortedResults);
+    return success();
+  }
+};
+
 class PreConversionLoweringPass
     : public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
  public:
@@ -310,6 +400,8 @@
     patterns.insert<LowerBroadcastInDimOp>(context);
     target.addIllegalOp<mhlo::BroadcastOp>();
     patterns.insert<LowerBroadcastOp>(context);
+    target.addIllegalOp<mhlo::SortOp>();
+    patterns.insert<LowerSortOp>(context);
 
     if (failed(applyPartialConversion(getOperation(), target, patterns))) {
       return signalPassFailure();
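Note: the comparator matching in LowerSortOp condenses to a small predicate.
A sketch of that logic (MatchesAscendingSort is a hypothetical distillation,
not part of the patch), assuming block argument i belongs to sort operand
i / 2:

#include <string>

// Accepts the comparator shapes the lowering handles: both block arguments
// must come from the same sort operand, the direction must be one of
// GT/GE/LT/LE, and the combination of argument order and direction must
// yield an ascending sort.
bool MatchesAscendingSort(int lhsIndex, int rhsIndex,
                          const std::string& direction) {
  if (lhsIndex < 0 || rhsIndex < 0 || lhsIndex == rhsIndex) return false;
  if (lhsIndex / 2 != rhsIndex / 2) return false;  // Same sort operand.
  const bool isGt = direction == "GT" || direction == "GE";
  const bool isLt = direction == "LT" || direction == "LE";
  if (!isGt && !isLt) return false;
  const bool operandParity = lhsIndex > rhsIndex;
  return operandParity ^ isGt;
}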
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
index 3473e44..0b9cd82 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -17,6 +17,38 @@
 // -----
 
 // CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+  // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 0 : i64, dim = 0 : i64}
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+  // CHECK: return [[GATHER]]
+  return %sort : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+  // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 1 : i64, dim = 1 : i64}
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  // CHECK: return [[GATHER]]
+  return %sort : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
 func @f(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
   // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
   %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 17d7e85..ff575b0 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -333,6 +333,19 @@
 vm.import @ceil.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
 vm.import @round.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
 
+vm.import @sort.i8(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i16(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.f32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: conversion
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index c6ab9d6..ba5b8bc 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -168,6 +168,12 @@
                         absl::Span<const int32_t> dimensions);
 };
 
+struct Sort {
+  template <typename T>
+  static Status Execute(absl::Span<const T> src_buffer,
+                        absl::Span<int32_t> dst_buffer, ShapeSpan src_shape);
+};
+
 struct Broadcast {
   template <typename T>
   static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index d3545d1..0b4f904 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -15,7 +15,9 @@
 #ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
 #define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
 
+#include <algorithm>
 #include <cmath>
+#include <numeric>
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
@@ -519,6 +522,25 @@
 }
 
 template <typename T>
+Status Sort::Execute(absl::Span<const T> src_buffer,
+                     absl::Span<int32_t> dst_buffer, ShapeSpan src_shape) {
+  int elements = src_buffer.size();
+  const int sort_size = src_shape.back();
+
+  for (int i = 0; i < elements; i += sort_size) {
+    auto src_subspan = src_buffer.subspan(i, sort_size);
+    auto dst_subspan = dst_buffer.subspan(i, sort_size);
+    std::iota(dst_subspan.begin(), dst_subspan.end(), 0);
+    std::stable_sort(dst_subspan.begin(), dst_subspan.end(),
+                     [&src_subspan](int32_t i1, int32_t i2) {
+                       return src_subspan[i1] < src_subspan[i2];
+                     });
+  }
+
+  return OkStatus();
+}
+
+template <typename T>
 Status Broadcast::Execute(absl::Span<const T> src_buffer,
                           absl::Span<T> dst_buffer) {
   for (size_t i = 0; i < dst_buffer.size(); ++i) {
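Note: a minimal harness for the Sort kernel above. The main() wrapper, the
namespace qualification, and ShapeSpan being absl::Span<const int32_t> are
assumptions for illustration; the Execute signature itself is from this patch.

#include <cstdint>
#include <vector>

#include "absl/types/span.h"
#include "iree/hal/vmla/op_kernels.h"

int main() {
  // A 2x4 buffer; each row is argsorted independently along the last dim.
  std::vector<float> src = {3, 2, 1, 4, 4, 3, 2, 1};
  std::vector<int32_t> dst(src.size());
  std::vector<int32_t> shape = {2, 4};
  auto status = iree::hal::vmla::kernels::Sort::Execute<float>(
      absl::MakeConstSpan(src), absl::MakeSpan(dst),
      absl::MakeConstSpan(shape));
  // dst now holds {2, 1, 0, 3, 3, 2, 1, 0}: per-row ascending indices.
  return status.ok() ? 0 : 1;
}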
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 09dbb3c..5852de0 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -642,6 +642,19 @@
   IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);
   IREE_VMLA_UNARY_OP(RoundF32, kernels::Round, float);
 
+#define IREE_VMLA_SORT_OP(name, type)                                        \
+  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,       \
+              const vm::ref<Buffer>& dst) {                                  \
+    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                            \
+    return kernels::Sort::Execute<type>(src->As<type>(), dst->As<int32_t>(), \
+                                        src_shape);                          \
+  }
+
+  IREE_VMLA_SORT_OP(SortI8, int8_t);
+  IREE_VMLA_SORT_OP(SortI16, int16_t);
+  IREE_VMLA_SORT_OP(SortI32, int32_t);
+  IREE_VMLA_SORT_OP(SortF32, float);
+
   //===--------------------------------------------------------------------===//
   // VMLA Ops: conversion
   //===--------------------------------------------------------------------===//
@@ -970,6 +983,10 @@
     vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
     vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
     vm::MakeNativeFunction("round.f32", &VMLAModuleState::RoundF32),
+    vm::MakeNativeFunction("sort.i8", &VMLAModuleState::SortI8),
+    vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16),
+    vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32),
+    vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32),
     vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
 
     vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
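Note: for readability, the hand-expanded form of one instantiation,
IREE_VMLA_SORT_OP(SortF32, float), is simply:

Status SortF32(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,
               const vm::ref<Buffer>& dst) {
  IREE_TRACE_SCOPE0("VMLAModuleState::SortF32");
  // Values are read as float; the destination always receives i32 indices.
  return kernels::Sort::Execute<float>(src->As<float>(), dst->As<int32_t>(),
                                       src_shape);
}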
diff --git a/iree/test/e2e/xla_ops/sort.mlir b/iree/test/e2e/xla_ops/sort.mlir
new file mode 100644
index 0000000..1820d8d
--- /dev/null
+++ b/iree/test/e2e/xla_ops/sort.mlir
@@ -0,0 +1,40 @@
+func @sort1D() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[3, 2, 1, 4]> : tensor<4xi32>
+
+  %sort = "mhlo.sort"(%input) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xi32>) -> tensor<4xi32>
+
+  check.expect_eq_const(%sort, dense<[1, 2, 3, 4]> : tensor<4xi32>) : tensor<4xi32>
+  return
+}
+
+func @sort2D() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[[1, 2, 3, 4],
+                                           [4, 3, 2, 1]]> : tensor<2x4xi32>
+
+  %sort = "mhlo.sort"(%input) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = false} : (tensor<2x4xi32>) -> tensor<2x4xi32>
+
+  check.expect_eq_const(%sort, dense<[[1, 2, 3, 4], [1, 2, 3, 4]]> : tensor<2x4xi32>) : tensor<2x4xi32>
+  return
+}
+
+func @sort3D() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[[[1, 2, 3, 4],
+                                            [4, 3, 2, 1]]]> : tensor<1x2x4xi32>
+
+  %sort = "mhlo.sort"(%input) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 2 : i64, is_stable = false} : (tensor<1x2x4xi32>) -> tensor<1x2x4xi32>
+
+  check.expect_eq_const(%sort, dense<[[[1, 2, 3, 4], [1, 2, 3, 4]]]> : tensor<1x2x4xi32>) : tensor<1x2x4xi32>
+  return
+}
diff --git a/scripts/git/google_to_main.sh b/scripts/git/google_to_main.sh
index 8b8543d..b00fd9a 100755
--- a/scripts/git/google_to_main.sh
+++ b/scripts/git/google_to_main.sh
@@ -56,4 +56,7 @@
   echo "${BODY?}"
   exit 1
 fi
-gh pr create --base main --title="${TITLE?}" --body="${BODY?}"
+
+# Workaround https://github.com/cli/cli/issues/1820
+GITHUB_USERNAME="$(gh config get -h github.com user)"
+gh pr create --base main --head="${GITHUB_USERNAME?}:${PR_BRANCH?}" --title="${TITLE?}" --body="${BODY?}"
diff --git a/scripts/git/main_to_google.sh b/scripts/git/main_to_google.sh
index ee07573..b9e56ae 100755
--- a/scripts/git/main_to_google.sh
+++ b/scripts/git/main_to_google.sh
@@ -56,4 +56,7 @@
   echo "${BODY?}"
   exit 1
 fi
-gh pr create --base google --title="${TITLE?}" --body="${BODY?}"
+
+# Workaround https://github.com/cli/cli/issues/1820
+GITHUB_USERNAME="$(gh config get -h github.com user)"
+gh pr create --base google --head="${GITHUB_USERNAME?}:${PR_BRANCH?}" --title="${TITLE?}" --body="${BODY?}"
diff --git a/scripts/git/update_tf_submodule.sh b/scripts/git/update_tf_submodule.sh
index 0031cfa..28799d1 100755
--- a/scripts/git/update_tf_submodule.sh
+++ b/scripts/git/update_tf_submodule.sh
@@ -72,4 +72,7 @@
   echo "${BODY?}"
   exit 1
 fi
-gh pr create --title="${TITLE?}" --body="${BODY?}" --base="${BASE_BRANCH?}"
+
+# Workaround https://github.com/cli/cli/issues/1820
+GITHUB_USERNAME="$(gh config get -h github.com user)"
+gh pr create --base="${BASE_BRANCH?}" --head="${GITHUB_USERNAME?}:${PR_BRANCH?}" --title="${TITLE?}" --body="${BODY?}"