Finish implementing native ABI support. (#5844)

* Does not yet flip it on.
* Adds a pass and a commented-out usage to generate ABI metadata for MHLO, allowing JAX tests to pass.
  * I was originally going to let this go down the untyped path, but it turns out people have been relying on implicit casting from a scalar to a 0d tensor, so I implemented it properly.
* Switches TF returns to emit tuples instead of lists.
* Passes pytree_test.py, MobileBERT (uses old kwarg-style TF arg passing), and JAX tests.
* Once landed, I will start a branch to flip everything for real, and there will likely be additional triage.
diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py
index b3c5533..f032ab9 100644
--- a/bindings/python/iree/runtime/__init__.py
+++ b/bindings/python/iree/runtime/__init__.py
@@ -24,7 +24,7 @@
 # FunctionAbi imports
 from .binding import FunctionAbi
 # Hal imports
-from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, MemoryAccess, MemoryType, Shape
+from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, HalElementType, MemoryAccess, MemoryType, Shape
 # HostTypeFactory imports
 from .binding import HostTypeFactory
 # Vm imports
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 6433635..52ddba5 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -150,9 +150,9 @@
           f"Malformed function reflection metadata structure: {reflection}")
 
     # See if kwargs are expected.
-    if self._ret_descs:
-      maybe_kwargs_desc = self._ret_descs[-1]
-      if maybe_kwargs_desc and maybe_kwargs_desc[0] == "kwargs_sdict":
+    if self._arg_descs:
+      maybe_kwargs_desc = self._arg_descs[-1]
+      if maybe_kwargs_desc and maybe_kwargs_desc[0] == "sdict_kwargs":
         self._has_kwargs = True
 
   def __repr__(self):
@@ -171,10 +171,21 @@
 
 
 def _int_to_vm(inv: Invocation, t: VmVariantList, x, desc):
+  # Implicit conversion to a 0d tensor.
+  if _is_0d_ndarray_descriptor(desc):
+    casted = _cast_scalar_to_ndarray(inv, x, desc)
+    _ndarray_to_vm(inv, t, casted, desc)
+    return
+
   _raise_argument_error(inv, "Python int arguments not yet supported")
 
 
 def _float_to_vm(inv: Invocation, t: VmVariantList, x, desc):
+  # Implicit conversion to a 0d tensor.
+  if _is_0d_ndarray_descriptor(desc):
+    casted = _cast_scalar_to_ndarray(inv, x, desc)
+    _ndarray_to_vm(inv, t, casted, desc)
+    return
   _raise_argument_error(inv, "Python float arguments not yet supported")
 
 
@@ -198,7 +209,7 @@
 
 def _dict_to_vm(inv: Invocation, t: VmVariantList, x, desc):
   desc_type = desc[0]
-  if desc_type != "sdict" and desc_type != "kwargs_sdict":
+  if desc_type != "sdict" and desc_type != "sdict_kwargs":
     _raise_argument_error(inv, f"passed a dict but expected {desc_type}")
   # When decoding a dict, the desc object is like:
   # ['sdict', ['key0', [...value_type_0...]], ['key1', [...value_type_1...]]]]
@@ -277,8 +288,37 @@
   return vm_list.get_as_ndarray(vm_index)
 
 
+def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  # The descriptor for an sdict is like:
+  #   ['sdict', ['key1', value1], ...]
+  sub_vm_list = vm_list.get_as_list(vm_index)
+  item_keys = []
+  item_descs = []
+  for k, d in desc[1:]:
+    item_keys.append(k)
+    item_descs.append(d)
+  py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
+  return dict(zip(item_keys, py_items))
+
+
+def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  # The descriptor for an slist is like:
+  #   ['slist', item1, ...]
+  sub_vm_list = vm_list.get_as_list(vm_index)
+  item_descs = desc[1:]
+  py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
+  return py_items
+
+
+def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+  return tuple(_vm_to_slist(inv, vm_list, vm_index, desc))
+
+
 VM_TO_PYTHON_CONVERTERS = {
     "ndarray": _vm_to_ndarray,
+    "sdict": _vm_to_sdict,
+    "slist": _vm_to_slist,
+    "stuple": _vm_to_stuple,
 }
 
 ABI_TYPE_TO_DTYPE = {
@@ -306,6 +346,21 @@
 )
 
 
+def _is_0d_ndarray_descriptor(desc):
+  # Example: ["ndarray", "f32", 0]
+  return desc[0] == "ndarray" and desc[2] == 0
+
+
+def _cast_scalar_to_ndarray(inv: Invocation, x, desc):
+  # Example descriptor: ["ndarray", "f32", 0]
+  dtype_str = desc[1]
+  try:
+    dtype = ABI_TYPE_TO_DTYPE[dtype_str]
+  except KeyError:
+    _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'")
+  return dtype(x)
+
+
 def _raise_argument_error(inv: Invocation, summary: str, e: Exception = None):
   new_e = ValueError(f"Error passing argument: {summary} "
                      f"(while encoding argument {inv.summarize_arg_error()})")
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index 77f1d9a..c0656b5 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -291,6 +291,17 @@
                  "Error moving buffer view");
 }
 
+VmVariantList VmVariantList::GetAsList(int index) {
+  iree_vm_ref_t ref = {0};
+  CheckApiStatus(iree_vm_list_get_ref_assign(raw_ptr(), index, &ref),
+                 "Could not access list element");
+  iree_vm_list_t* sub_list = NULL;
+  CheckApiStatus(iree_vm_list_check_deref(ref, &sub_list),
+                 "Could not deref list (wrong type?)");
+  iree_vm_list_retain(sub_list);
+  return VmVariantList(sub_list);
+}
+
 py::object VmVariantList::GetAsNdarray(int index) {
   iree_vm_variant_t v = iree_vm_variant_empty();
   CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
@@ -468,6 +479,7 @@
       .def_property_readonly("size", &VmVariantList::size)
       .def("__len__", &VmVariantList::size)
       .def("get_as_ndarray", &VmVariantList::GetAsNdarray)
+      .def("get_as_list", &VmVariantList::GetAsList)
       .def("push_list", &VmVariantList::PushList)
       .def("push_buffer_view", &VmVariantList::PushBufferView)
       .def("__repr__", &VmVariantList::DebugString);
diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h
index 2b43328..33422c0 100644
--- a/bindings/python/iree/runtime/vm.h
+++ b/bindings/python/iree/runtime/vm.h
@@ -101,6 +101,7 @@
   void PushList(VmVariantList& other);
   void PushBufferView(HalDevice& device, py::object py_buffer_object,
                       iree_hal_element_type_e element_type);
+  VmVariantList GetAsList(int index);
   py::object GetAsNdarray(int index);
 
  private:
diff --git a/bindings/python/iree/runtime/vm_test.py b/bindings/python/iree/runtime/vm_test.py
index 38e177a..e74b071 100644
--- a/bindings/python/iree/runtime/vm_test.py
+++ b/bindings/python/iree/runtime/vm_test.py
@@ -85,6 +85,32 @@
     logging.info("variant_list: %s", l)
     self.assertEqual(l.size, 0)
 
+  def test_variant_list_buffers(self):
+    ET = iree.runtime.HalElementType
+    for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
+                   (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
+                   (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
+                   (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
+                   (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
+      # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
+      lst = iree.runtime.VmVariantList(5)
+      ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
+      lst.push_buffer_view(self.device, ary1, et)
+      ary2 = lst.get_as_ndarray(0)
+      np.testing.assert_array_equal(ary1, ary2)
+      with self.assertRaises(IndexError):
+        lst.get_as_ndarray(1)
+
+  def test_variant_list_list(self):
+    lst1 = iree.runtime.VmVariantList(5)
+    lst2 = iree.runtime.VmVariantList(5)
+    lst1.push_list(lst2)
+    self.assertEqual("<VmVariantList(1): [List[]]>", str(lst1))
+    lstout = lst1.get_as_list(0)
+    self.assertEqual("<VmVariantList(0): []>", str(lstout))
+    with self.assertRaises(IndexError):
+      lst1.get_as_list(1)
+
   def test_context_id(self):
     instance = iree.runtime.VmInstance()
     context1 = iree.runtime.VmContext(instance)
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index c7a9002..585cb22 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -22,6 +22,7 @@
     name = "TF",
     srcs = [
         "ConvertToMHLO.cpp",
+        "EmitDefaultIREEABI.cpp",
         "FlattenTuplesInCFG.cpp",
         "LowerExportedFunctions.cpp",
         "LowerGlobalTensors.cpp",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/EmitDefaultIREEABI.cpp b/integrations/tensorflow/iree_tf_compiler/TF/EmitDefaultIREEABI.cpp
new file mode 100644
index 0000000..0d3ff47
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TF/EmitDefaultIREEABI.cpp
@@ -0,0 +1,123 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TF/Passes.h"
+#include "llvm/Support/JSON.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace json = llvm::json;
+
+namespace mlir {
+namespace iree_integrations {
+namespace TF {
+
+class EmitDefaultIREEABIPass
+    : public PassWrapper<EmitDefaultIREEABIPass, OperationPass<FuncOp>> {
+ public:
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    if (SymbolTable::getSymbolVisibility(funcOp) !=
+        SymbolTable::Visibility::Public) {
+      return;
+    }
+    if (funcOp->hasAttr("iree.abi")) {
+      return;
+    }
+
+    json::Array refArgs;
+    for (Type t : funcOp.getArgumentTypes()) {
+      auto descriptor = mapTypeToJsonTypeRecord(t);
+      if (!descriptor) {
+        funcOp.emitWarning()
+            << "unable to generate reflection descriptor for argument type "
+            << t;
+        return;
+      }
+      refArgs.push_back(*descriptor);
+    }
+
+    json::Array refReturns;
+    for (Type t : funcOp.getCallableResults()) {
+      auto descriptor = mapTypeToJsonTypeRecord(t);
+      if (!descriptor) {
+        funcOp.emitWarning()
+            << "unable to generate reflection descriptor for result type " << t;
+        return;
+      }
+      refReturns.push_back(*descriptor);
+    }
+
+    Builder builder(&getContext());
+    json::Object refDict;
+    refDict["v"] = json::Value(1);
+    refDict["a"] = json::Value(std::move(refArgs));
+    refDict["r"] = json::Value(std::move(refReturns));
+    json::Value refDictValue(std::move(refDict));
+    std::string refStr;
+    llvm::raw_string_ostream refOut(refStr);
+    refOut << refDictValue;
+    refOut.flush();
+    funcOp->setAttr("iree.abi", builder.getStringAttr(refStr));
+  }
+
+  llvm::Optional<json::Value> mapTypeToJsonTypeRecord(Type type) {
+    if (auto shapedType = type.dyn_cast<ShapedType>()) {
+      json::Array record({
+          json::Value("ndarray"),
+          mapTypeToJsonTypeRecord(shapedType.getElementType()),
+          shapedType.hasRank() ? json::Value(shapedType.getRank())
+                               : json::Value(nullptr),
+      });
+      if (shapedType.hasRank()) {
+        for (auto dim : shapedType.getShape()) {
+          record.push_back(dim == ShapedType::kDynamicSize
+                               ? json::Value(nullptr)
+                               : json::Value(dim));
+        }
+      }
+      return json::Value(std::move(record));
+    }
+
+    // Primitives.
+    if (auto integerType = type.dyn_cast<IntegerType>()) {
+      std::string name = (Twine("i") + Twine(integerType.getWidth())).str();
+      return json::Value(std::move(name));
+    }
+    if (auto floatType = type.dyn_cast<FloatType>()) {
+      if (floatType == FloatType::getBF16(floatType.getContext())) {
+        // bf16 is 16 bits wide like f16, so it needs its own distinct name.
+        return json::Value("bf16");
+      }
+      std::string name = (Twine("f") + Twine(floatType.getWidth())).str();
+      return json::Value(std::move(name));
+    }
+
+    return llvm::None;
+  }
+};
+
+std::unique_ptr<OperationPass<FuncOp>> createEmitDefaultIREEABIPass() {
+  return std::make_unique<EmitDefaultIREEABIPass>();
+}
+
+static PassRegistration<EmitDefaultIREEABIPass> funcPass(
+    "iree-tf-emit-default-iree-abi", "Emits simple default ABI metadata");
+
+}  // namespace TF
+}  // namespace iree_integrations
+}  // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
index 3985472..7a4ed52 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
@@ -120,6 +120,10 @@
   // - It removes tf_saved_model.semantics from the module, which we can only
   //   do at the very end.
   pm.addPass(createLowerExportedFunctionsPass());
+  // TODO: Remove the above and uncomment the below to enable IREE native ABI.
+  // pm.addPass(createSavedModelToIREEABIPass());
+  // // Inline the wrapper functions.
+  // pm.addPass(createInlinerPass());
 
   //----------------------------------------------------------------------------
   // Ensure that all Tensorflow has been legalized away
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
index a34f563..8ebc87a 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
@@ -46,6 +46,11 @@
 // Converts the TF dialect to the XLA MHLO dialect.
 std::unique_ptr<FunctionPass> createConvertToMHLOPass();
 
+// Annotates an appropriate iree.abi attribute on public functions that
+// operate exclusively on tensor types. This corresponds to the expectations
+// of MHLO and is suitable for such programs.
+std::unique_ptr<OperationPass<FuncOp>> createEmitDefaultIREEABIPass();
+
 // Flattens tuple values in function signatures and blocks.
 std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass();
 
@@ -93,6 +98,7 @@
   registerMHLOImportPassPipeline();
 
   createConvertToMHLOPass();
+  createEmitDefaultIREEABIPass();
   createFlattenTuplesInCFGPass();
   createLowerGlobalTensorsPass();
   createLowerExportedFunctionsPass();
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp b/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
index 9d946b6..b5ee466 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
@@ -327,7 +327,7 @@
   }
 
   StructureLevel *bindValue(Location loc, int newValueIndex, Type valueType,
-                            ArrayAttr indexPathAttr) {
+                            ArrayAttr indexPathAttr, bool bindTuple = false) {
     StructureLevel *current = this;
     // Move forward through non terminal path segments.
     for (Attribute indexAttr : indexPathAttr) {
@@ -337,7 +337,8 @@
         if (!current) return nullptr;
       } else if (auto intAttr = indexAttr.dyn_cast<IntegerAttr>()) {
         int childIndex = intAttr.getInt();
-        current = current->allocateChild(loc, childIndex);
+        current =
+            current->allocateChild(loc, childIndex, /*asTuple=*/bindTuple);
         if (!current) return nullptr;
       } else {
         emitError(loc)
@@ -380,8 +381,10 @@
     return &children.back();
   }
 
-  StructureLevel *allocateChild(Location loc, int childIndex) {
-    if (type == LevelType::None) type = LevelType::List;
+  StructureLevel *allocateChild(Location loc, int childIndex,
+                                bool asTuple = false) {
+    if (type == LevelType::None)
+      type = asTuple ? LevelType::Tuple : LevelType::List;
     if (type != LevelType::List && type != LevelType::Tuple) {
       emitError(loc) << "structure path mismatch: dereference a non-sequence "
                      << "with a sequence key " << childIndex;
@@ -439,8 +442,12 @@
              << " on result " << i;
     }
     internalFunc.removeResultAttr(i, savedModelIndexPathIdent);
+    // TODO: The TensorFlow SavedModel attribute system does not distinguish
+    // lists from tuples, but TensorFlow internally does. Until this is
+    // plumbed through somehow, arbitrarily emit results as tuples as that
+    // was determined by someone at some point to be more canonical.
     if (!resultsRoot.bindValue(loc, i, internalFuncType.getResult(i),
-                               indexPathAttr)) {
+                               indexPathAttr, /*bindTuple=*/true)) {
       return failure();
     }
   }
@@ -449,7 +456,7 @@
   // towards multi-return safe by converting to tuple.
   // TODO: Investigate upstream whether there are additional signals to be
   // plumbed.
-  bool isMultiResult = resultsRoot.type == LevelType::List;
+  bool isMultiResult = resultsRoot.type == LevelType::Tuple;
 
   // Build the wrapper function type.
   SmallVector<Type> wrapperArgTypes;
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD
index 9ee3af8..1275d72 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD
@@ -26,6 +26,7 @@
     srcs = enforce_glob(
         [
             "convert_to_mhlo.mlir",
+            "emit_default_iree_abi.mlir",
             "lower_global_tensors.mlir",
             "lower_global_tensors_complex.mlir",
             "lower_global_tensors_invalid.mlir",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/emit_default_iree_abi.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/emit_default_iree_abi.mlir
new file mode 100644
index 0000000..4598dbd
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/emit_default_iree_abi.mlir
@@ -0,0 +1,7 @@
+// RUN: iree-tf-opt %s -iree-tf-emit-default-iree-abi -split-input-file -verify-diagnostics | IreeFileCheck %s
+
+// CHECK-LABEL: func @valid
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22ndarray\22,\22f32\22,2,2,3],[\22ndarray\22,\22f32\22,1,3]],\22r\22:[[\22ndarray\22,\22f32\22,1,3],[\22ndarray\22,\22f32\22,2,2,3]],\22v\22:1}"
+func @valid(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>, tensor<2x3xf32>) {
+  return %arg1, %arg0 : tensor<3xf32>, tensor<2x3xf32>
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
index 2c27083..24ee86c 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
@@ -50,7 +50,7 @@
 // -----
 // CHECK-LABEL: module @dict_nest
 // CHECK: func @dict_nest
-// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22stuple\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
 // CHECK: %[[c0:.+]] = constant 0 : index
 // CHECK: %[[L0:.+]] = iree.list.get %arg0[%[[c0]]] : !iree.list<?> -> !iree.list<?>
 // CHECK: %[[c0_0:.+]] = constant 0 : index
@@ -98,7 +98,7 @@
 // -----
 // CHECK-LABEL: module @kwargs
 // CHECK: func @dict_nest
-// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0],[\22sdict_kwargs\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0],[\22sdict_kwargs\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22stuple\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
 module @kwargs attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 729 : i32}, tf_saved_model.semantics}  {
   func @__inference_dict_nest_190(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["a"]}, %arg1: tensor<16xf32> {tf_saved_model.index_path = ["b"]}, %arg2: tensor<16xf32> {tf._user_specified_name = "mapping", tf_saved_model.index_path = [0, "list", 0]}, %arg3: tensor<16xf32> {tf._user_specified_name = "mapping", tf_saved_model.index_path = [0, "list", 1]}, %arg4: tensor<f32> {tf._user_specified_name = "scalar", tf_saved_model.index_path = [1]}) -> (tensor<16xf32> {tf_saved_model.index_path = ["dict", "a"]}, tensor<16xf32> {tf_saved_model.index_path = ["dict", "b"]}, tensor<16xf32> {tf_saved_model.index_path = ["list", 0]}, tensor<16xf32> {tf_saved_model.index_path = ["list", 1]}) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf.shape<16>, #tf.shape<16>, #tf.shape<16>, #tf.shape<16>, #tf.shape<>], tf_saved_model.exported_names = ["dict_nest"]} {
     %0 = "tf.Identity"(%arg0) {device = ""} : (tensor<16xf32>) -> tensor<16xf32>
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 78abe0c..e18d1eb 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -189,7 +189,7 @@
     return 2;
   }
 
-  // Find the entry function an annotate it as exported.
+  // Find the entry function and annotate it as exported.
   // Note that the XLA importer always produced an MLIR module with a @main
   // function.
   std::string entryName = "main";
@@ -221,6 +221,13 @@
   applyPassManagerCLOptions(pm);
 
   iree_integrations::TF::buildMHLOImportPassPipeline(pm);
+
+  // Note that we emit the ABI last since any needed function-level
+  // transformations (e.g. de-tupling) should have been done.
+  // TODO: Uncomment this to enable IREE native bindings.
+  // pm.addNestedPass<FuncOp>(
+  //     iree_integrations::TF::createEmitDefaultIREEABIPass());
+
   if (failed(pm.run(*module))) {
     llvm::errs()
         << "Running iree-xla-import pass pipeline failed (see diagnostics)\n";