[PJRT] Allow to pass extra compile options via env variables (#19418)

Sometime it's useful to pass some extra IREE compiler options to the
PJRT plugin by environment variables to debug/do some
experiment/performance tuning without recompilation.

This is a rewrite to the following code which was commented as a TODO.


https://github.com/iree-org/iree/blob/b68c535ece28e139492606f391493f3e95242420/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc#L231-L245

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <twice@apache.org>
Co-authored-by: Scott Todd <scott.todd0@gmail.com>
diff --git a/integrations/pjrt/README.md b/integrations/pjrt/README.md
index 9c018be..f4f1545 100644
--- a/integrations/pjrt/README.md
+++ b/integrations/pjrt/README.md
@@ -36,6 +36,19 @@
 JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
 ```
 
+## Advanced settings
+
+To pass additional compile options to IREE during JIT compilation, you can use
+the `IREE_PJRT_IREE_COMPILER_OPTIONS` environment variable. This variable can
+be set to a space-delimited list of flags that would be passed to the
+`iree-compile` command-line tool.
+
+For example:
+```shell
+export IREE_PJRT_IREE_COMPILER_OPTIONS=--iree-scheduling-dump-statistics-format=csv
+JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
+```
+
 ## Incrementally developing
 
 If you did an editable install (`-e`) above, then you should be able to incrementally
diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
index c5855fc..7127091 100644
--- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
+++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
@@ -9,6 +9,7 @@
     common
   HDRS
     "api_impl.h"
+    "command_line_utils.h"
     "dylib_entry_point.cc.inc"
     "iree_helpers.h"
     "layout_utils.h"
@@ -16,6 +17,7 @@
     "tensor_utils.h"
   SRCS
     "api_impl.cc"
+    "command_line_utils.cc"
     "layout_utils.cc"
     "platform.cc"
     "tensor_utils.cc"
@@ -60,6 +62,17 @@
     PUBLIC
 )
 
+iree_cc_test(
+  NAME
+    command_line_utils_test
+  SRCS
+    "command_line_utils_test.cc"
+  DEPS
+    ::common
+    iree::testing::gtest
+    iree::testing::gtest_main
+)
+
 iree_cc_library(
     NAME
       debugging
diff --git a/integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc
new file mode 100644
index 0000000..31f6af0
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.cc
@@ -0,0 +1,54 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "command_line_utils.h"
+
+namespace iree {
+namespace pjrt {
+
+// TODO: currently this function doesn't handle escape sequences,
+// it just ensure that single/double quotes are interpreted corrently.
+std::optional<std::vector<std::string>> ParseOptionsFromCommandLine(
+    std::string_view options_str) {
+  std::vector<std::string> options;
+  std::string current;
+
+  enum { NORMAL, SINGLE_QUOTE, DOUBLE_QUOTE } state = NORMAL;
+  for (auto it = options_str.begin(); it != options_str.end(); ++it) {
+    if (std::isspace(*it) && state == NORMAL) {
+      if (!current.empty()) {
+        options.push_back(std::move(current));
+        current.clear();
+      }
+    } else if (*it == '"' && state != SINGLE_QUOTE) {
+      if (state == NORMAL)
+        state = DOUBLE_QUOTE;
+      else if (state == DOUBLE_QUOTE)
+        state = NORMAL;
+    } else if (*it == '\'' && state != DOUBLE_QUOTE) {
+      if (state == NORMAL)
+        state = SINGLE_QUOTE;
+      else if (state == SINGLE_QUOTE)
+        state = NORMAL;
+    } else {
+      current.push_back(*it);
+    }
+  }
+
+  if (!current.empty()) {
+    options.push_back(std::move(current));
+  }
+
+  // if it's still in a quote, then return nullopt
+  if (state != NORMAL) {
+    return std::nullopt;
+  }
+
+  return options;
+}
+
+}  // namespace pjrt
+}  // namespace iree
diff --git a/integrations/pjrt/src/iree_pjrt/common/command_line_utils.h b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.h
new file mode 100644
index 0000000..b8df54d
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/common/command_line_utils.h
@@ -0,0 +1,26 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_COMMAND_LINE_UTILS_H_
+#define IREE_PJRT_PLUGIN_PJRT_COMMON_COMMAND_LINE_UTILS_H_
+
+#include <optional>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace iree {
+namespace pjrt {
+
+// parse command line options (maybe with quotes) to an array of options
+// e.g. `a b "c d"` -> {"a", "b", "c d"}
+std::optional<std::vector<std::string>> ParseOptionsFromCommandLine(
+    std::string_view options_str);
+
+}  // namespace pjrt
+}  // namespace iree
+
+#endif
diff --git a/integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc b/integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc
new file mode 100644
index 0000000..ffd7852
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/common/command_line_utils_test.cc
@@ -0,0 +1,24 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/common/command_line_utils.h"
+
+#include <gtest/gtest.h>
+
+using namespace iree::pjrt;
+
+TEST(CommandLineUtils, ParseOptionsFromCommandLine) {
+  EXPECT_EQ(ParseOptionsFromCommandLine("--help --verbose"),
+            (std::vector<std::string>{"--help", "--verbose"}));
+  EXPECT_EQ(ParseOptionsFromCommandLine("-a='x y' -b \"n m\""),
+            (std::vector<std::string>{"-a=x y", "-b", "n m"}));
+  EXPECT_EQ(ParseOptionsFromCommandLine("'\"' \"'\""),
+            (std::vector<std::string>{"\"", "'"}));
+  EXPECT_EQ(ParseOptionsFromCommandLine("ab   abc d 'e f g' h  "),
+            (std::vector<std::string>{"ab", "abc", "d", "e f g", "h"}));
+  EXPECT_EQ(ParseOptionsFromCommandLine("a 'b"), std::nullopt);
+  EXPECT_EQ(ParseOptionsFromCommandLine("x\"y"), std::nullopt);
+}
diff --git a/integrations/pjrt/src/iree_pjrt/common/compiler.h b/integrations/pjrt/src/iree_pjrt/common/compiler.h
index 9969a00..e7dcdc5 100644
--- a/integrations/pjrt/src/iree_pjrt/common/compiler.h
+++ b/integrations/pjrt/src/iree_pjrt/common/compiler.h
@@ -76,12 +76,16 @@
 // An AbstractCompiler based on IREE.
 class IREECompiler : public AbstractCompiler {
  public:
+  IREECompiler(std::vector<std::string> extra_options = {})
+      : extra_options_(std::move(extra_options)) {}
+
   std::unique_ptr<CompilerJob> StartJob() override;
   std::string GetRevision() override;
   std::string GetErrorMessage() override { return error_message_; }
 
  private:
   std::string error_message_;
+  std::vector<std::string> extra_options_;
 };
 
 // An AbstractCompiler based on the HLO partitioner.
diff --git a/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc b/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
index f42d5a5..6feb5c8 100644
--- a/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/dylib_platform.cc
@@ -14,6 +14,7 @@
 #include "iree/base/internal/path.h"
 #include "iree/compiler/embedding_api.h"
 #include "iree/compiler/loader.h"
+#include "iree_pjrt/common/command_line_utils.h"
 #include "iree_pjrt/partitioner_api/embedding_api.h"
 #include "iree_pjrt/partitioner_api/loader.h"
 
@@ -98,7 +99,15 @@
     message.append(*loaded_compiler);
     logger().debug(message);
   }
-  compiler_ = std::make_unique<IREECompiler>();
+
+  std::vector<std::string> extra_compiler_options;
+  if (auto options_str = config_vars().Lookup("IREE_COMPILER_OPTIONS")) {
+    if (auto options = ParseOptionsFromCommandLine(*options_str)) {
+      extra_compiler_options = std::move(*options);
+      logger().debug("Extra compile options: " + *options_str);
+    }
+  }
+  compiler_ = std::make_unique<IREECompiler>(std::move(extra_compiler_options));
   {
     std::string message("Compiler Version: ");
     message.append(compiler_->GetRevision());
diff --git a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
index 1cddbc3..87bc3b8 100644
--- a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
@@ -228,20 +228,12 @@
   }
 
   // Propagate all options set via environment variable.
-  // TODO: Excise/translate to something that doesn't rely on LLVM.
-  // if (std::optional<std::string> env_value = llvm::sys::Process::GetEnv(
-  //         llvm::StringRef("IREE_COMPILER_OPTIONS"))) {
-  //   llvm::SmallVector<const char*, 20> new_argv;
-  //   llvm::BumpPtrAllocator a;
-  //   llvm::StringSaver saver(a);
-
-  //   llvm::cl::TokenizeGNUCommandLine(*env_value, saver, new_argv);
-  //   for (auto arg : new_argv)
-  //     if (!job->SetFlag(arg)) {
-  //       error_message_ = job->GetErrorMessage();
-  //       return nullptr;
-  //     }
-  // }
+  for (auto arg : extra_options_) {
+    if (!job->SetFlag(arg.c_str())) {
+      error_message_ = job->GetErrorMessage();
+      return nullptr;
+    }
+  }
 
   return job;
 }