blob: c79da46353fcfed18052ec6c9e427b13d65ed8eb [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "./binding.h"
#include "./hal.h"
#include "./invoke.h"
#include "./io.h"
#include "./loop.h"
#include "./numpy_interop.h"
#include "./py_module.h"
#include "./status_utils.h"
#include "./vm.h"
#include "iree/base/internal/flags.h"
#include "iree/hal/drivers/init.h"
namespace {

// The strings belonging to a single parse_flags call.
using FlagStrings = std::vector<std::string>;

// Long-lived backing store for parsed flags. The flag machinery works with
// string views and relies on the caller to keep the underlying strings alive
// for as long as the flags are in use, so every parse_flags invocation
// appends its own FlagStrings here. Each set sits behind a unique_ptr, so
// growing the outer vector never relocates strings that were stored earlier.
std::vector<std::unique_ptr<FlagStrings>> alloced_flag_cache;

}  // namespace
namespace iree {
namespace python {

// Entry point of the `_runtime` extension module: initializes NumPy interop,
// registers all available HAL drivers with the default driver registry,
// invokes each Setup*Bindings hook on the module, and defines the
// module-level helper functions `parse_flags` and `disable_leak_checker`.
NB_MODULE(_runtime, m) {
  numpy::InitializeNumPyInterop();

  IREE_TRACE_APP_ENTER();

  IREE_CHECK_OK(iree_hal_register_all_available_drivers(
      iree_hal_driver_registry_default()));

  m.doc() = "IREE Binding Backend Helpers";
  SetupHalBindings(m);
  SetupInvokeBindings(m);
  SetupIoBindings(m);
  SetupLoopBindings(m);
  SetupPyModuleBindings(m);
  SetupVmBindings(m);

  // Adds the given set of strings to the global flags. These new flags
  // take effect upon the next creation of a driver. They do not affect
  // drivers already created.
  m.def("parse_flags", [](py::args py_flags) {
    // Convert every Python argument before touching the global cache. If a
    // conversion throws, this local set is released and no half-filled entry
    // is left behind in alloced_flag_cache.
    auto flag_strings = std::make_unique<std::vector<std::string>>();
    // The first slot is a placeholder in the argv[0] (program name) position.
    flag_strings->emplace_back("python");
    for (py::handle py_flag : py_flags) {
      flag_strings->push_back(py::cast<std::string>(py_flag));
    }

    // As the flags-processing mechanism of the C API requires long-lived
    // char * strings, create a set of char * strings from the std::strings,
    // with the std::strings responsible for maintaining the storage.
    // The pointer vector is built only after the string vector is complete
    // so that no later push_back can reallocate and invalidate the pointers.
    std::vector<char *> flag_ptrs;
    flag_ptrs.reserve(flag_strings->size());
    for (std::string &flag : *flag_strings) {
      flag_ptrs.push_back(flag.data());
    }

    // Transfer ownership to the cache before parsing: the strings must
    // outlive this call whether or not parsing succeeds. Moving the
    // unique_ptr leaves the heap-allocated strings (and thus flag_ptrs)
    // untouched.
    alloced_flag_cache.push_back(std::move(flag_strings));

    // Send the flags to the C API.
    char **argv = flag_ptrs.data();
    int argc = static_cast<int>(flag_ptrs.size());
    CheckApiStatus(iree_flags_parse(IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP,
                                    &argc, &argv),
                   "Error parsing flags");
  });

  // Exposes a switch that calls py::set_leak_warnings(false).
  m.def("disable_leak_checker", []() { py::set_leak_warnings(false); });
}

}  // namespace python
}  // namespace iree