Add python-only and native flag parsing to the Python API. (#8261)
* Add python-only and native flag parsing to the Python API.
* Exposes one python flag `iree.runtime.flags.FUNCTION_INPUT_VALIDATION` (default True), which can be used to disable some expensive validation checks.
* Adds `iree.runtime.flags.parse_flags("--flag1", "--flag2")` to set native flags from Python.
diff --git a/bindings/python/iree/runtime/CMakeLists.txt b/bindings/python/iree/runtime/CMakeLists.txt
index fcae8c8..ee59990 100644
--- a/bindings/python/iree/runtime/CMakeLists.txt
+++ b/bindings/python/iree/runtime/CMakeLists.txt
@@ -37,6 +37,7 @@
DEPS
iree::base
iree::base::cc
+ iree::base::internal::flags
iree::hal
iree::hal::drivers
iree::modules::hal
@@ -50,6 +51,7 @@
SRCS
"__init__.py"
"array_interop.py"
+ "flags.py"
"function.py"
"system_api.py"
"tracing.py"
@@ -100,6 +102,13 @@
iree_py_test(
NAME
+ flags_test
+ SRCS
+ "flags_test.py"
+)
+
+iree_py_test(
+ NAME
function_test
SRCS
"function_test.py"
diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py
index 9d2e1f0..e0d8643 100644
--- a/bindings/python/iree/runtime/__init__.py
+++ b/bindings/python/iree/runtime/__init__.py
@@ -43,3 +43,5 @@
from .system_api import *
from .function import *
from .tracing import *
+
+from . import flags
diff --git a/bindings/python/iree/runtime/flags.py b/bindings/python/iree/runtime/flags.py
new file mode 100644
index 0000000..a7b1020
--- /dev/null
+++ b/bindings/python/iree/runtime/flags.py
@@ -0,0 +1,12 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .binding import parse_flags
+
+# When enabled, performs additional function input validation checks. In the
+# event of errors, this will yield nicer error messages but comes with a
+# runtime cost.
+FUNCTION_INPUT_VALIDATION = True
diff --git a/bindings/python/iree/runtime/flags_test.py b/bindings/python/iree/runtime/flags_test.py
new file mode 100644
index 0000000..886176a
--- /dev/null
+++ b/bindings/python/iree/runtime/flags_test.py
@@ -0,0 +1,24 @@
+# Copyright 2019 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree import runtime as rt
+import numpy as np
+import unittest
+
+
+class FlagsTest(unittest.TestCase):
+
+ def testParse(self):
+ # We always have the logging verbose level available so use it.
+ rt.flags.parse_flags("--iree_v=1")
+
+ def testParseError(self):
+ with self.assertRaisesRegex(ValueError, "flag 'barbar' not recognized"):
+ rt.flags.parse_flags("--barbar")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/bindings/python/iree/runtime/function.py b/bindings/python/iree/runtime/function.py
index df9565a..d0ebd79 100644
--- a/bindings/python/iree/runtime/function.py
+++ b/bindings/python/iree/runtime/function.py
@@ -28,6 +28,8 @@
map_dtype_to_element_type,
DeviceArray,
)
+from .flags import (
+ FUNCTION_INPUT_VALIDATION,)
__all__ = [
"FunctionInvoker",
@@ -140,7 +142,7 @@
_merge_python_sequence_to_vm(inv, arg_list, args, self._arg_descs)
if call_trace:
call_trace.add_vm_list(arg_list, "args")
- self._vm_context.invoke(self._vm_function, arg_list, ret_list)
+ self._invoke(arg_list, ret_list)
if call_trace:
call_trace.add_vm_list(ret_list, "results")
@@ -162,6 +164,10 @@
if call_trace:
call_trace.end_call()
+ # Break out invoke so it shows up in profiles.
+ def _invoke(self, arg_list, ret_list):
+ self._vm_context.invoke(self._vm_function, arg_list, ret_list)
+
def _parse_abi_dict(self, vm_function: VmFunction):
reflection = vm_function.reflection
abi_json = reflection.get("iree.abi")
@@ -284,7 +290,7 @@
def _ndarray_to_vm(inv: Invocation, t: VmVariantList, x, desc):
# Validate and implicit conversion against type descriptor.
- if desc is not None:
+ if FUNCTION_INPUT_VALIDATION and desc is not None:
desc_type = desc[0]
if desc_type != "ndarray":
_raise_argument_error(inv, f"passed an ndarray but expected {desc_type}")
@@ -532,7 +538,7 @@
# For dynamic mode, just assume we have the right arity.
if descs is None:
descs = [None] * len(py_list)
- else:
+ elif FUNCTION_INPUT_VALIDATION:
len_py_list = sum([1 for x in py_list if x is not MissingArgument])
if len(py_list) != len_py_list:
_raise_argument_error(
diff --git a/bindings/python/iree/runtime/initialize_module.cc b/bindings/python/iree/runtime/initialize_module.cc
index 334668a..b48586c 100644
--- a/bindings/python/iree/runtime/initialize_module.cc
+++ b/bindings/python/iree/runtime/initialize_module.cc
@@ -8,6 +8,7 @@
#include "bindings/python/iree/runtime/hal.h"
#include "bindings/python/iree/runtime/status_utils.h"
#include "bindings/python/iree/runtime/vm.h"
+#include "iree/base/internal/flags.h"
#include "iree/base/status_cc.h"
#include "iree/hal/drivers/init.h"
@@ -21,6 +22,26 @@
m.doc() = "IREE Binding Backend Helpers";
SetupHalBindings(m);
SetupVmBindings(m);
+
+ m.def("parse_flags", [](py::args py_flags) {
+ std::vector<std::string> alloced_flags;
+ alloced_flags.push_back("python");
+ for (auto &py_flag : py_flags) {
+ alloced_flags.push_back(py::cast<std::string>(py_flag));
+ }
+
+ // Must build pointer vector after filling so pointers are stable.
+ std::vector<char *> flag_ptrs;
+ for (auto &alloced_flag : alloced_flags) {
+ flag_ptrs.push_back(const_cast<char *>(alloced_flag.c_str()));
+ }
+
+ char **argv = &flag_ptrs[0];
+ int argc = flag_ptrs.size();
+ CheckApiStatus(
+ iree_flags_parse(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv),
+ "Error parsing flags");
+ });
}
} // namespace python