Finish implementing native ABI support. (#5844)
* Does not yet flip it on.
* Adds a pass and a commented out usage to generate ABI metadata for MHLO, allowing JAX tests to pass.
* I was just going to let this go down the untyped path, but it turns out people have been relying on implicit casting between scalar -> 0d tensor, so I just did it right.
* Switches TF returns to emit tuples instead of lists.
* Passes pytree_test.py, mobile bert (uses old kwarg style TF arg passing), and JAX tests.
* Once landed, I will start a branch to flip everything for real, and there will likely be additional triage.
diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py
index b3c5533..f032ab9 100644
--- a/bindings/python/iree/runtime/__init__.py
+++ b/bindings/python/iree/runtime/__init__.py
@@ -24,7 +24,7 @@
# FunctionAbi imports
from .binding import FunctionAbi
# Hal imports
-from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, MemoryAccess, MemoryType, Shape
+from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, HalElementType, MemoryAccess, MemoryType, Shape
# HostTypeFactory imports
from .binding import HostTypeFactory
# Vm imports
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index 6433635..52ddba5 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -150,9 +150,9 @@
f"Malformed function reflection metadata structure: {reflection}")
# See if kwargs are expected.
- if self._ret_descs:
- maybe_kwargs_desc = self._ret_descs[-1]
- if maybe_kwargs_desc and maybe_kwargs_desc[0] == "kwargs_sdict":
+ if self._arg_descs:
+ maybe_kwargs_desc = self._arg_descs[-1]
+ if maybe_kwargs_desc and maybe_kwargs_desc[0] == "sdict_kwargs":
self._has_kwargs = True
def __repr__(self):
@@ -171,10 +171,21 @@
def _int_to_vm(inv: Invocation, t: VmVariantList, x, desc):
+ # Implicit conversion to a 0d tensor.
+ if _is_0d_ndarray_descriptor(desc):
+ casted = _cast_scalar_to_ndarray(inv, x, desc)
+ _ndarray_to_vm(inv, t, casted, desc)
+ return
+
_raise_argument_error(inv, "Python int arguments not yet supported")
def _float_to_vm(inv: Invocation, t: VmVariantList, x, desc):
+ # Implicit conversion to a 0d tensor.
+ if _is_0d_ndarray_descriptor(desc):
+ casted = _cast_scalar_to_ndarray(inv, x, desc)
+ _ndarray_to_vm(inv, t, casted, desc)
+ return
_raise_argument_error(inv, "Python float arguments not yet supported")
@@ -198,7 +209,7 @@
def _dict_to_vm(inv: Invocation, t: VmVariantList, x, desc):
desc_type = desc[0]
- if desc_type != "sdict" and desc_type != "kwargs_sdict":
+ if desc_type != "sdict" and desc_type != "sdict_kwargs":
_raise_argument_error(inv, f"passed a dict but expected {desc_type}")
# When decoding a dict, the desc object is like:
# ['sdict', ['key0', [...value_type_0...]], ['key1', [...value_type_1...]]]]
@@ -277,8 +288,37 @@
return vm_list.get_as_ndarray(vm_index)
+def _vm_to_sdict(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+ # The descriptor for an sdict is like:
+ # ['sdict', ['key1', value1], ...]
+ sub_vm_list = vm_list.get_as_list(vm_index)
+ item_keys = []
+ item_descs = []
+ for k, d in desc[1:]:
+ item_keys.append(k)
+ item_descs.append(d)
+ py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
+ return dict(zip(item_keys, py_items))
+
+
+def _vm_to_slist(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+ # The descriptor for an slist is like:
+ # ['slist, item1, ...]
+ sub_vm_list = vm_list.get_as_list(vm_index)
+ item_descs = desc[1:]
+ py_items = _extract_vm_sequence_to_python(inv, sub_vm_list, item_descs)
+ return py_items
+
+
+def _vm_to_stuple(inv: Invocation, vm_list: VmVariantList, vm_index: int, desc):
+ return tuple(_vm_to_slist(inv, vm_list, vm_index, desc))
+
+
VM_TO_PYTHON_CONVERTERS = {
"ndarray": _vm_to_ndarray,
+ "sdict": _vm_to_sdict,
+ "slist": _vm_to_slist,
+ "stuple": _vm_to_stuple,
}
ABI_TYPE_TO_DTYPE = {
@@ -306,6 +346,21 @@
)
+def _is_0d_ndarray_descriptor(desc):
+ # Example: ["ndarray", "f32", 0]
+ return desc[0] == "ndarray" and desc[2] == 0
+
+
+def _cast_scalar_to_ndarray(inv: Invocation, x, desc):
+ # Example descriptor: ["ndarray", "f32", 0]
+ dtype_str = desc[1]
+ try:
+ dtype = ABI_TYPE_TO_DTYPE[dtype_str]
+ except KeyError:
+ _raise_argument_error(inv, f"unrecognized dtype '{dtype_str}'")
+ return dtype(x)
+
+
def _raise_argument_error(inv: Invocation, summary: str, e: Exception = None):
new_e = ValueError(f"Error passing argument: {summary} "
f"(while encoding argument {inv.summarize_arg_error()})")
diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc
index 77f1d9a..c0656b5 100644
--- a/bindings/python/iree/runtime/vm.cc
+++ b/bindings/python/iree/runtime/vm.cc
@@ -291,6 +291,17 @@
"Error moving buffer view");
}
+VmVariantList VmVariantList::GetAsList(int index) {
+ iree_vm_ref_t ref = {0};
+ CheckApiStatus(iree_vm_list_get_ref_assign(raw_ptr(), index, &ref),
+ "Could not access list element");
+ iree_vm_list_t* sub_list = NULL;
+ CheckApiStatus(iree_vm_list_check_deref(ref, &sub_list),
+ "Could not deref list (wrong type?)");
+ iree_vm_list_retain(sub_list);
+ return VmVariantList(sub_list);
+}
+
py::object VmVariantList::GetAsNdarray(int index) {
iree_vm_variant_t v = iree_vm_variant_empty();
CheckApiStatus(iree_vm_list_get_variant(raw_ptr(), index, &v),
@@ -468,6 +479,7 @@
.def_property_readonly("size", &VmVariantList::size)
.def("__len__", &VmVariantList::size)
.def("get_as_ndarray", &VmVariantList::GetAsNdarray)
+ .def("get_as_list", &VmVariantList::GetAsList)
.def("push_list", &VmVariantList::PushList)
.def("push_buffer_view", &VmVariantList::PushBufferView)
.def("__repr__", &VmVariantList::DebugString);
diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h
index 2b43328..33422c0 100644
--- a/bindings/python/iree/runtime/vm.h
+++ b/bindings/python/iree/runtime/vm.h
@@ -101,6 +101,7 @@
void PushList(VmVariantList& other);
void PushBufferView(HalDevice& device, py::object py_buffer_object,
iree_hal_element_type_e element_type);
+ VmVariantList GetAsList(int index);
py::object GetAsNdarray(int index);
private:
diff --git a/bindings/python/iree/runtime/vm_test.py b/bindings/python/iree/runtime/vm_test.py
index 38e177a..e74b071 100644
--- a/bindings/python/iree/runtime/vm_test.py
+++ b/bindings/python/iree/runtime/vm_test.py
@@ -85,6 +85,32 @@
logging.info("variant_list: %s", l)
self.assertEqual(l.size, 0)
+ def test_variant_list_buffers(self):
+ ET = iree.runtime.HalElementType
+ for dt, et in ((np.int8, ET.SINT_8), (np.int16, ET.SINT_16),
+ (np.int32, ET.SINT_32), (np.int64, ET.SINT_64),
+ (np.uint8, ET.UINT_8), (np.uint16, ET.UINT_16),
+ (np.uint32, ET.UINT_32), (np.uint64, ET.UINT_64),
+ (np.float32, ET.FLOAT_32), (np.float64, ET.FLOAT_64)):
+ # TODO: Unimplemented: (np.float16, ET.FLOAT_16)
+ lst = iree.runtime.VmVariantList(5)
+ ary1 = np.asarray([1, 2, 3, 4], dtype=dt)
+ lst.push_buffer_view(self.device, ary1, et)
+ ary2 = lst.get_as_ndarray(0)
+ np.testing.assert_array_equal(ary1, ary2)
+ with self.assertRaises(IndexError):
+ lst.get_as_ndarray(1)
+
+ def test_variant_list_list(self):
+ lst1 = iree.runtime.VmVariantList(5)
+ lst2 = iree.runtime.VmVariantList(5)
+ lst1.push_list(lst2)
+ self.assertEqual("<VmVariantList(1): [List[]]>", str(lst1))
+ lstout = lst1.get_as_list(0)
+ self.assertEqual("<VmVariantList(0): []>", str(lstout))
+ with self.assertRaises(IndexError):
+ lst1.get_as_list(1)
+
def test_context_id(self):
instance = iree.runtime.VmInstance()
context1 = iree.runtime.VmContext(instance)
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
index c7a9002..585cb22 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD
@@ -22,6 +22,7 @@
name = "TF",
srcs = [
"ConvertToMHLO.cpp",
+ "EmitDefaultIREEABI.cpp",
"FlattenTuplesInCFG.cpp",
"LowerExportedFunctions.cpp",
"LowerGlobalTensors.cpp",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/EmitDefaultIREEABI.cpp b/integrations/tensorflow/iree_tf_compiler/TF/EmitDefaultIREEABI.cpp
new file mode 100644
index 0000000..0d3ff47
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TF/EmitDefaultIREEABI.cpp
@@ -0,0 +1,123 @@
+// Copyright 2021 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree_tf_compiler/TF/Passes.h"
+#include "llvm/Support/JSON.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace json = llvm::json;
+
+namespace mlir {
+namespace iree_integrations {
+namespace TF {
+
+class EmitDefaultIREEABIPass
+ : public PassWrapper<EmitDefaultIREEABIPass, OperationPass<FuncOp>> {
+ public:
+ void runOnOperation() override {
+ auto funcOp = getOperation();
+ if (SymbolTable::getSymbolVisibility(funcOp) !=
+ SymbolTable::Visibility::Public) {
+ return;
+ }
+ if (funcOp->hasAttr("iree.abi")) {
+ return;
+ }
+
+ json::Array refArgs;
+ for (Type t : funcOp.getArgumentTypes()) {
+ auto descriptor = mapTypeToJsonTypeRecord(t);
+ if (!descriptor) {
+ funcOp.emitWarning()
+ << "unable to generate reflection descriptor for argument type "
+ << t;
+ return;
+ }
+ refArgs.push_back(*descriptor);
+ }
+
+ json::Array refReturns;
+ for (Type t : funcOp.getCallableResults()) {
+ auto descriptor = mapTypeToJsonTypeRecord(t);
+ if (!descriptor) {
+ funcOp.emitWarning()
+ << "unable to generate reflection descriptor for result type " << t;
+ return;
+ }
+ refReturns.push_back(*descriptor);
+ }
+
+ Builder builder(&getContext());
+ json::Object refDict;
+ refDict["v"] = json::Value(1);
+ refDict["a"] = json::Value(std::move(refArgs));
+ refDict["r"] = json::Value(std::move(refReturns));
+ json::Value refDictValue(std::move(refDict));
+ std::string refStr;
+ llvm::raw_string_ostream refOut(refStr);
+ refOut << refDictValue;
+ refOut.flush();
+ funcOp->setAttr("iree.abi", builder.getStringAttr(refStr));
+ }
+
+ llvm::Optional<json::Value> mapTypeToJsonTypeRecord(Type type) {
+ if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ json::Array record({
+ json::Value("ndarray"),
+ mapTypeToJsonTypeRecord(shapedType.getElementType()),
+ shapedType.hasRank() ? json::Value(shapedType.getRank())
+ : json::Value(nullptr),
+ });
+ if (shapedType.hasRank()) {
+ for (auto dim : shapedType.getShape()) {
+ record.push_back(dim == ShapedType::kDynamicSize
+ ? json::Value(nullptr)
+ : json::Value(dim));
+ }
+ }
+ return json::Value(std::move(record));
+ }
+
+ // Primitives.
+ if (auto integerType = type.dyn_cast<IntegerType>()) {
+ std::string name = (Twine("i") + Twine(integerType.getWidth())).str();
+ return json::Value(std::move(name));
+ }
+ if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (floatType == FloatType::getBF16(floatType.getContext())) {
+ // Why Google?
+ return json::Value("bf16");
+ }
+ std::string name = (Twine("f") + Twine(floatType.getWidth())).str();
+ return json::Value(std::move(name));
+ }
+
+ return llvm::None;
+ }
+};
+
+std::unique_ptr<OperationPass<FuncOp>> createEmitDefaultIREEABIPass() {
+ return std::make_unique<EmitDefaultIREEABIPass>();
+}
+
+static PassRegistration<EmitDefaultIREEABIPass> funcPass(
+ "iree-tf-emit-default-iree-abi", "Emits simple default ABI metadata");
+
+} // namespace TF
+} // namespace iree_integrations
+} // namespace mlir
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
index 3985472..7a4ed52 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.cpp
@@ -120,6 +120,10 @@
// - It removes tf_saved_model.semantics from the module, which we can only
// do at the very end.
pm.addPass(createLowerExportedFunctionsPass());
+ // TODO: Remove the above and uncomment the below to enable IREE native ABI.
+ // pm.addPass(createSavedModelToIREEABIPass());
+ // // Inline the wrapper functions.
+ // pm.addPass(createInlinerPass());
//----------------------------------------------------------------------------
// Ensure that all Tensorflow has been legalized away
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
index a34f563..8ebc87a 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
+++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h
@@ -46,6 +46,11 @@
// Converts the TF dialect to the XLA MHLO dialect.
std::unique_ptr<FunctionPass> createConvertToMHLOPass();
+// Annotates an appropriate iree.abi attribute on public functions that
+// operate exclusively on tensor types. This corresponds to the expectations
+// of MHLO and is suitable for such programs.
+std::unique_ptr<OperationPass<FuncOp>> createEmitDefaultIREEABIPass();
+
// Flattens tuple values in function signatures and blocks.
std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass();
@@ -93,6 +98,7 @@
registerMHLOImportPassPipeline();
createConvertToMHLOPass();
+ createEmitDefaultIREEABIPass();
createFlattenTuplesInCFGPass();
createLowerGlobalTensorsPass();
createLowerExportedFunctionsPass();
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp b/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
index 9d946b6..b5ee466 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TF/SavedModelToIreeABI.cpp
@@ -327,7 +327,7 @@
}
StructureLevel *bindValue(Location loc, int newValueIndex, Type valueType,
- ArrayAttr indexPathAttr) {
+ ArrayAttr indexPathAttr, bool bindTuple = false) {
StructureLevel *current = this;
// Move forward through non terminal path segments.
for (Attribute indexAttr : indexPathAttr) {
@@ -337,7 +337,8 @@
if (!current) return nullptr;
} else if (auto intAttr = indexAttr.dyn_cast<IntegerAttr>()) {
int childIndex = intAttr.getInt();
- current = current->allocateChild(loc, childIndex);
+ current =
+ current->allocateChild(loc, childIndex, /*asTuple=*/bindTuple);
if (!current) return nullptr;
} else {
emitError(loc)
@@ -380,8 +381,10 @@
return &children.back();
}
- StructureLevel *allocateChild(Location loc, int childIndex) {
- if (type == LevelType::None) type = LevelType::List;
+ StructureLevel *allocateChild(Location loc, int childIndex,
+ bool asTuple = false) {
+ if (type == LevelType::None)
+ type = asTuple ? LevelType::Tuple : LevelType::List;
if (type != LevelType::List && type != LevelType::Tuple) {
emitError(loc) << "structure path mismatch: dereference a non-sequence "
<< "with a sequence key " << childIndex;
@@ -439,8 +442,12 @@
<< " on result " << i;
}
internalFunc.removeResultAttr(i, savedModelIndexPathIdent);
+ // TODO: The TensorFlow SavedModel attribute system does not distinguish
+ // lists from tuples, but TensorFlow internally does. Until this is
+ // plumbed through somehow, arbitrarily emit results as tuples as that
+ // was determined by someone at some point to be more canonical.
if (!resultsRoot.bindValue(loc, i, internalFuncType.getResult(i),
- indexPathAttr)) {
+ indexPathAttr, /*bindTuple=*/true)) {
return failure();
}
}
@@ -449,7 +456,7 @@
// towards multi-return safe by converting to tuple.
// TODO: Investigate upstream whether there are additional signals to be
// plumbed.
- bool isMultiResult = resultsRoot.type == LevelType::List;
+ bool isMultiResult = resultsRoot.type == LevelType::Tuple;
// Build the wrapper function type.
SmallVector<Type> wrapperArgTypes;
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD
index 9ee3af8..1275d72 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/BUILD
@@ -26,6 +26,7 @@
srcs = enforce_glob(
[
"convert_to_mhlo.mlir",
+ "emit_default_iree_abi.mlir",
"lower_global_tensors.mlir",
"lower_global_tensors_complex.mlir",
"lower_global_tensors_invalid.mlir",
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/emit_default_iree_abi.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/emit_default_iree_abi.mlir
new file mode 100644
index 0000000..4598dbd
--- /dev/null
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/emit_default_iree_abi.mlir
@@ -0,0 +1,7 @@
+// RUN: iree-tf-opt %s -iree-tf-emit-default-iree-abi -split-input-file -verify-diagnostics | IreeFileCheck %s
+
+// CHECK-LABEL: func @valid
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22ndarray\22,\22f32\22,2,2,3],[\22ndarray\22,\22f32\22,1,3]],\22r\22:[[\22ndarray\22,\22f32\22,1,3],[\22ndarray\22,\22f32\22,2,2,3]],\22v\22:1}"
+func @valid(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> (tensor<3xf32>, tensor<2x3xf32>) {
+ return %arg1, %arg0 : tensor<3xf32>, tensor<2x3xf32>
+}
diff --git a/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir b/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
index 2c27083..24ee86c 100644
--- a/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
+++ b/integrations/tensorflow/iree_tf_compiler/TF/test/saved_model_to_iree_abi.mlir
@@ -50,7 +50,7 @@
// -----
// CHECK-LABEL: module @dict_nest
// CHECK: func @dict_nest
-// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22stuple\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
// CHECK: %[[c0:.+]] = constant 0 : index
// CHECK: %[[L0:.+]] = iree.list.get %arg0[%[[c0]]] : !iree.list<?> -> !iree.list<?>
// CHECK: %[[c0_0:.+]] = constant 0 : index
@@ -98,7 +98,7 @@
// -----
// CHECK-LABEL: module @kwargs
// CHECK: func @dict_nest
-// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0],[\22sdict_kwargs\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
+// CHECK-SAME{LITERAL}: iree.abi = "{\22a\22:[[\22sdict\22,[\22list\22,[\22slist\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]],[\22ndarray\22,\22f32\22,0],[\22sdict_kwargs\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],\22r\22:[[\22sdict\22,[\22dict\22,[\22sdict\22,[\22a\22,[\22ndarray\22,\22f32\22,1,16]],[\22b\22,[\22ndarray\22,\22f32\22,1,16]]]],[\22list\22,[\22stuple\22,[\22ndarray\22,\22f32\22,1,16],[\22ndarray\22,\22f32\22,1,16]]]]],\22v\22:1}"
module @kwargs attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 729 : i32}, tf_saved_model.semantics} {
func @__inference_dict_nest_190(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["a"]}, %arg1: tensor<16xf32> {tf_saved_model.index_path = ["b"]}, %arg2: tensor<16xf32> {tf._user_specified_name = "mapping", tf_saved_model.index_path = [0, "list", 0]}, %arg3: tensor<16xf32> {tf._user_specified_name = "mapping", tf_saved_model.index_path = [0, "list", 1]}, %arg4: tensor<f32> {tf._user_specified_name = "scalar", tf_saved_model.index_path = [1]}) -> (tensor<16xf32> {tf_saved_model.index_path = ["dict", "a"]}, tensor<16xf32> {tf_saved_model.index_path = ["dict", "b"]}, tensor<16xf32> {tf_saved_model.index_path = ["list", 0]}, tensor<16xf32> {tf_saved_model.index_path = ["list", 1]}) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf.shape<16>, #tf.shape<16>, #tf.shape<16>, #tf.shape<16>, #tf.shape<>], tf_saved_model.exported_names = ["dict_nest"]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<16xf32>) -> tensor<16xf32>
diff --git a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
index 78abe0c..e18d1eb 100644
--- a/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/iree-import-xla-main.cpp
@@ -189,7 +189,7 @@
return 2;
}
- // Find the entry function an annotate it as exported.
+ // Find the entry function and annotate it as exported.
// Note that the XLA importer always produced an MLIR module with a @main
// function.
std::string entryName = "main";
@@ -221,6 +221,13 @@
applyPassManagerCLOptions(pm);
iree_integrations::TF::buildMHLOImportPassPipeline(pm);
+
+ // Note that we emit the ABI last since any needed function-level
+ // transformations (i.e. de-tupling, etc) should have been done.
+ // TODO: Uncomment this to enable IREE native bindings.
+ // pm.addNestedPass<FuncOp>(
+ // iree_integrations::TF::createEmitDefaultIREEABIPass());
+
if (failed(pm.run(*module))) {
llvm::errs()
<< "Running iree-xla-import pass pipeline failed (see diagnostics)\n";