[PJRT] Allow to pass extra compile options via env variables (#19418)
Sometime it's useful to pass some extra IREE compiler options to the
PJRT plugin by environment variables to debug/do some
experiment/performance tuning without recompilation.
This is a rewrite to the following code which was commented as a TODO.
https://github.com/iree-org/iree/blob/b68c535ece28e139492606f391493f3e95242420/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc#L231-L245
ci-exactly: build_packages, test_pjrt
---------
Signed-off-by: PragmaTwice <twice@apache.org>
Co-authored-by: Scott Todd <scott.todd0@gmail.com>
diff --git a/integrations/pjrt/README.md b/integrations/pjrt/README.md
index 9c018be..f4f1545 100644
--- a/integrations/pjrt/README.md
+++ b/integrations/pjrt/README.md
@@ -36,6 +36,19 @@
JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
```
+## Advanced settings
+
+To pass additional compile options to IREE during JIT compilation, you can use
+the `IREE_PJRT_IREE_COMPILER_OPTIONS` environment variable. This variable can
+be set to a space-delimited list of flags that would be passed to the
+`iree-compile` command-line tool.
+
+For example:
+```shell
+export IREE_PJRT_IREE_COMPILER_OPTIONS=--iree-scheduling-dump-statistics-format=csv
+JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
+```
+
## Incrementally developing
If you did an editable install (`-e`) above, then you should be able to incrementally
diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
index c5855fc..7127091 100644
--- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
+++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
@@ -9,6 +9,7 @@
common
HDRS
"api_impl.h"
+ "command_line_utils.h"
"dylib_entry_point.cc.inc"
"iree_helpers.h"
"layout_utils.h"
@@ -16,6 +17,7 @@
"tensor_utils.h"
SRCS
"api_impl.cc"
+ "command_line_utils.cc"
"layout_utils.cc"
"platform.cc"
"tensor_utils.cc"
@@ -60,6 +62,17 @@
PUBLIC
)
+iree_cc_test(
+ NAME
+ command_line_utils_test
+ SRCS
+ "command_line_utils_test.cc"
+ DEPS
+ ::common
+ iree::testing::gtest
+ iree::testing::gtest_main
+)
+
iree_cc_library(
NAME
debugging
diff --git a/integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc
new file mode 100644
index 0000000..31f6af0
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc
@@ -0,0 +1,54 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "command_line_utils.h"
+
+namespace iree {
+namespace pjrt {
+
+// TODO: currently this function doesn't handle escape sequences,
+// it just ensure that single/double quotes are interpreted corrently.
+std::optional<std::vector<std::string>> ParseOptionsFromCommandLine(
+ std::string_view options_str) {
+ std::vector<std::string> options;
+ std::string current;
+
+ enum { NORMAL, SINGLE_QUOTE, DOUBLE_QUOTE } state = NORMAL;
+ for (auto it = options_str.begin(); it != options_str.end(); ++it) {
+ if (std::isspace(*it) && state == NORMAL) {
+ if (!current.empty()) {
+ options.push_back(std::move(current));
+ current.clear();
+ }
+ } else if (*it == '"' && state != SINGLE_QUOTE) {
+ if (state == NORMAL)
+ state = DOUBLE_QUOTE;
+ else if (state == DOUBLE_QUOTE)
+ state = NORMAL;
+ } else if (*it == '\'' && state != DOUBLE_QUOTE) {
+ if (state == NORMAL)
+ state = SINGLE_QUOTE;
+ else if (state == SINGLE_QUOTE)
+ state = NORMAL;
+ } else {
+ current.push_back(*it);
+ }
+ }
+
+ if (!current.empty()) {
+ options.push_back(std::move(current));
+ }
+
+ // if it's still in a quote, then return nullopt
+ if (state != NORMAL) {
+ return std::nullopt;
+ }
+
+ return options;
+}
+
+} // namespace pjrt
+} // namespace iree
diff --git a/integrations/pjrt/src/iree_pjrt/common/command_line_utils.h b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.h
new file mode 100644
index 0000000..b8df54d
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.h
@@ -0,0 +1,26 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_COMMAND_LINE_UTILS_H_
+#define IREE_PJRT_PLUGIN_PJRT_COMMON_COMMAND_LINE_UTILS_H_
+
+#include <optional>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace iree {
+namespace pjrt {
+
+// parse command line options (maybe with quotes) to an array of options
+// e.g. `a b "c d"` -> {"a", "b", "c d"}
+std::optional<std::vector<std::string>> ParseOptionsFromCommandLine(
+ std::string_view options_str);
+
+} // namespace pjrt
+} // namespace iree
+
+#endif
diff --git a/integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc b/integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc
new file mode 100644
index 0000000..ffd7852
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc
@@ -0,0 +1,24 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/common/command_line_utils.h"
+
+#include <gtest/gtest.h>
+
+using namespace iree::pjrt;
+
+TEST(CommandLineUtils, ParseOptionsFromCommandLine) {
+ EXPECT_EQ(ParseOptionsFromCommandLine("--help --verbose"),
+ (std::vector<std::string>{"--help", "--verbose"}));
+ EXPECT_EQ(ParseOptionsFromCommandLine("-a='x y' -b \"n m\""),
+ (std::vector<std::string>{"-a=x y", "-b", "n m"}));
+ EXPECT_EQ(ParseOptionsFromCommandLine("'\"' \"'\""),
+ (std::vector<std::string>{"\"", "'"}));
+ EXPECT_EQ(ParseOptionsFromCommandLine("ab abc d 'e f g' h "),
+ (std::vector<std::string>{"ab", "abc", "d", "e f g", "h"}));
+ EXPECT_EQ(ParseOptionsFromCommandLine("a 'b"), std::nullopt);
+ EXPECT_EQ(ParseOptionsFromCommandLine("x\"y"), std::nullopt);
+}
diff --git a/integrations/pjrt/src/iree_pjrt/common/compiler.h b/integrations/pjrt/src/iree_pjrt/common/compiler.h
index 9969a00..e7dcdc5 100644
--- a/integrations/pjrt/src/iree_pjrt/common/compiler.h
+++ b/integrations/pjrt/src/iree_pjrt/common/compiler.h
@@ -76,12 +76,16 @@
// An AbstractCompiler based on IREE.
class IREECompiler : public AbstractCompiler {
public:
+ IREECompiler(std::vector<std::string> extra_options = {})
+ : extra_options_(std::move(extra_options)) {}
+
std::unique_ptr<CompilerJob> StartJob() override;
std::string GetRevision() override;
std::string GetErrorMessage() override { return error_message_; }
private:
std::string error_message_;
+ std::vector<std::string> extra_options_;
};
// An AbstractCompiler based on the HLO partitioner.
diff --git a/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc b/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
index f42d5a5..6feb5c8 100644
--- a/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
@@ -14,6 +14,7 @@
#include "iree/base/internal/path.h"
#include "iree/compiler/embedding_api.h"
#include "iree/compiler/loader.h"
+#include "iree_pjrt/common/command_line_utils.h"
#include "iree_pjrt/partitioner_api/embedding_api.h"
#include "iree_pjrt/partitioner_api/loader.h"
@@ -98,7 +99,15 @@
message.append(*loaded_compiler);
logger().debug(message);
}
- compiler_ = std::make_unique<IREECompiler>();
+
+ std::vector<std::string> extra_compiler_options;
+ if (auto options_str = config_vars().Lookup("IREE_COMPILER_OPTIONS")) {
+ if (auto options = ParseOptionsFromCommandLine(*options_str)) {
+ extra_compiler_options = std::move(*options);
+ logger().debug("Extra compile options: " + *options_str);
+ }
+ }
+ compiler_ = std::make_unique<IREECompiler>(std::move(extra_compiler_options));
{
std::string message("Compiler Version: ");
message.append(compiler_->GetRevision());
diff --git a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
index 1cddbc3..87bc3b8 100644
--- a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
@@ -228,20 +228,12 @@
}
// Propagate all options set via environment variable.
- // TODO: Excise/translate to something that doesn't rely on LLVM.
- // if (std::optional<std::string> env_value = llvm::sys::Process::GetEnv(
- // llvm::StringRef("IREE_COMPILER_OPTIONS"))) {
- // llvm::SmallVector<const char*, 20> new_argv;
- // llvm::BumpPtrAllocator a;
- // llvm::StringSaver saver(a);
-
- // llvm::cl::TokenizeGNUCommandLine(*env_value, saver, new_argv);
- // for (auto arg : new_argv)
- // if (!job->SetFlag(arg)) {
- // error_message_ = job->GetErrorMessage();
- // return nullptr;
- // }
- // }
+ for (auto arg : extra_options_) {
+ if (!job->SetFlag(arg.c_str())) {
+ error_message_ = job->GetErrorMessage();
+ return nullptr;
+ }
+ }
return job;
}