[PJRT] Add support of passing per-compilation compile options (#19438)

As discussed in
https://github.com/iree-org/iree/pull/19418#discussion_r1876524772,
https://github.com/iree-org/iree/pull/19418#pullrequestreview-2490115222
and https://github.com/iree-org/iree/pull/19418#discussion_r1877191103,
here we support to read `env_option_overrides` as IREE compile flags
from `compile_options` passed by frontends like JAX in a per-compilation
basis.

Most of these code already exists but has been commented due to some
problems: `compile_options` was not yet available in that time, but it's
now introduced by #19369.

A simple use case is shown below, also as a test case:

https://github.com/iree-org/iree/blob/c37a80212dd4a541762fc9fdaaa615b6d0a62829/integrations/pjrt/test/test_compile_options.py#L9-L15

ci-exactly: build_packages, test_pjrt

---------

Signed-off-by: PragmaTwice <twice@apache.org>
Co-authored-by: Scott Todd <scott.todd0@gmail.com>
diff --git a/.github/workflows/pkgci_test_pjrt.yml b/.github/workflows/pkgci_test_pjrt.yml
index c9f4b78..f288352 100644
--- a/.github/workflows/pkgci_test_pjrt.yml
+++ b/.github/workflows/pkgci_test_pjrt.yml
@@ -61,7 +61,7 @@
           source ${VENV_DIR}/bin/activate
           python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
           # install
-          python -m pip install jax==0.4.35
+          python -m pip install jax==0.4.36
       - name: Run tests
         run: |
           source ${VENV_DIR}/bin/activate
diff --git a/build_tools/testing/run_jax_tests.sh b/build_tools/testing/run_jax_tests.sh
index 66b2135..a065d16 100755
--- a/build_tools/testing/run_jax_tests.sh
+++ b/build_tools/testing/run_jax_tests.sh
@@ -48,6 +48,13 @@
 diff_jax_test test/test_degenerate.py
 diff_jax_test test/test_simple.py
 
+# here we test if the compile options is passed to IREE PJRT plugin successfully.
+# we pass --iree-scheduling-dump-statistics-format=csv via jax.jit,
+# and see if there's statistics in the output
+compile_options_test_tmp_out=$(mktemp /tmp/jax_test_result_compile_options.XXXXXX)
+JAX_PLATFORMS=$actual_jax_platform python test/test_compile_options.py 2>&1 | tee $compile_options_test_tmp_out
+cat $compile_options_test_tmp_out | grep '@main_dispatch'
+
 
 # FIXME: we can also utilize the native test cases from JAX,
 # e.g. `tests/nn_test.py` from the JAX repo, as below,
diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
index 4371cd3..c5855fc 100644
--- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
+++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
@@ -56,6 +56,7 @@
       iree::compiler::bindings::c::loader
       iree_pjrt::partitioner_api
       iree_pjrt::partitioner_api::loader
+      iree_pjrt_deps::protos
     PUBLIC
 )
 
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 68b45ec..6f2d787 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -1487,8 +1487,8 @@
     }
 
     // Set flags.
-    // TODO: Plumb CompileOptions through.
-    // if (!job->SetFlags(options)) return MakeCompilerError(*job);
+    if (!job->SetFlags(options)) return MakeCompilerError(*job);
+
     if (artifact_tx) {
       artifact_tx->WriteArtifact(
           /*label=*/"partitioner_flags", /*extension=*/"txt", /*index=*/-1,
@@ -1538,8 +1538,8 @@
     if (!SetDefaultCompilerFlags(job.get())) {
       return MakeCompilerError(*job);
     }
-    // TODO: Plumb CompileOptions through.
-    // if (!job->SetFlags(options)) return MakeCompilerError(*job);
+    if (!job->SetFlags(options)) return MakeCompilerError(*job);
+
     if (artifact_tx) {
       artifact_tx->WriteArtifact(
           /*label=*/"flags", /*extension=*/"txt", /*index=*/-1,
diff --git a/integrations/pjrt/src/iree_pjrt/common/compiler.h b/integrations/pjrt/src/iree_pjrt/common/compiler.h
index faf07db..9969a00 100644
--- a/integrations/pjrt/src/iree_pjrt/common/compiler.h
+++ b/integrations/pjrt/src/iree_pjrt/common/compiler.h
@@ -11,8 +11,7 @@
 #include <string>
 
 #include "iree_pjrt/common/debugging.h"
-// TODO: Excise.
-// #include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/compile_options.pb.h"
 
 namespace iree::pjrt {
 
@@ -37,8 +36,7 @@
   // setup of a job (or if the underlying session will not be re-used).
   // Returns false on failure.
   virtual bool SetFlag(const char* flag) = 0;
-  // TODO: Excise.
-  // virtual bool SetFlags(xla::CompileOptions options) = 0;
+  virtual bool SetFlags(xla::CompileOptionsProto options) = 0;
 
   // Gets all flags as a string. This is intended for debug printing a plausible
   // command line to reproduce compilation.
diff --git a/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc b/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
index b323807..c00f6eb 100644
--- a/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
@@ -97,41 +97,44 @@
     return true;
   }
 
-  // TODO: Find another way to deal with this.
-  // bool SetFlags(xla::CompileOptions options) override {
-  //   int num_partitions = options.executable_build_options.num_partitions();
-  //   int num_replicas = options.executable_build_options.num_replicas();
-  //   bool use_spmd_partitioning =
-  //       options.executable_build_options.use_spmd_partitioning();
-  //   auto allow_spmd_sharding_propagation_to_output =
-  //       options.executable_build_options
-  //           .allow_spmd_sharding_propagation_to_output();
-  //   if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-num-partitions=",
-  //                             num_partitions)
-  //                    .c_str())) {
-  //     return false;
-  //   }
-  //   if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-replica-count=",
-  //                             num_replicas)
-  //                    .c_str())) {
-  //     return false;
-  //   }
-  //   if (!SetFlag(
-  //           absl::StrCat("--openxla-partitioner-gspmd-use-spmd-partitioning=",
-  //                        use_spmd_partitioning)
-  //               .c_str())) {
-  //     return false;
-  //   }
-  //   if (!SetFlag(
-  //           absl::StrCat(
-  //               "--openxla-partitioner-gspmd-allow-spmd-"
-  //               "sharding-propagation-to-output=",
-  //               absl::StrJoin(allow_spmd_sharding_propagation_to_output,
-  //               ",")) .c_str())) {
-  //     return false;
-  //   }
-  //   return true;
-  // }
+  bool SetFlags(xla::CompileOptionsProto options) override {
+    int num_partitions = options.executable_build_options().num_partitions();
+    int num_replicas = options.executable_build_options().num_replicas();
+    bool use_spmd_partitioning =
+        options.executable_build_options().use_spmd_partitioning();
+    auto allow_spmd_sharding_propagation_to_output =
+        options.executable_build_options()
+            .allow_spmd_sharding_propagation_to_output();
+    if (!SetFlag(("--openxla-partitioner-gspmd-num-partitions=" +
+                  std::to_string(num_partitions))
+                     .c_str())) {
+      return false;
+    }
+    if (!SetFlag(("--openxla-partitioner-gspmd-replica-count=" +
+                  std::to_string(num_replicas))
+                     .c_str())) {
+      return false;
+    }
+    if (!SetFlag(("--openxla-partitioner-gspmd-use-spmd-partitioning=" +
+                  std::to_string(use_spmd_partitioning))
+                     .c_str())) {
+      return false;
+    }
+    std::string allow_spmd_sharding_propagation_to_output_str;
+    for (size_t i = 0; i < allow_spmd_sharding_propagation_to_output.size();
+         ++i) {
+      if (i != 0) allow_spmd_sharding_propagation_to_output_str += ",";
+      allow_spmd_sharding_propagation_to_output_str +=
+          std::to_string(allow_spmd_sharding_propagation_to_output[i]);
+    }
+    if (!SetFlag(("--openxla-partitioner-gspmd-allow-spmd-"
+                  "sharding-propagation-to-output=" +
+                  allow_spmd_sharding_propagation_to_output_str)
+                     .c_str())) {
+      return false;
+    }
+    return true;
+  }
 
   std::string GetFlags() override {
     std::string flags;
diff --git a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
index 095e540..1cddbc3 100644
--- a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
@@ -95,29 +95,28 @@
     return true;
   }
 
-  // TODO: Excise: Cannot dep on an internal XLA structure.
-  // bool SetFlags(xla::CompileOptions options) override {
-  //   // Set extra options, overriding env variables if appropriate.
-  //   for (auto [option, option_override] : options.env_option_overrides) {
-  //     std::string override_string;
-  //     if (auto override_val = std::get_if<std::string>(&option_override)) {
-  //       override_string = *override_val;
-  //     } else if (auto override_val = std::get_if<bool>(&option_override)) {
-  //       override_string = *override_val ? "true" : "false";
-  //     } else if (auto override_val = std::get_if<int64_t>(&option_override))
-  //     {
-  //       override_string = std::to_string(*override_val);
-  //     } else {
-  //       assert(false &&
-  //              "option value should be of type string, bool, or int64");
-  //     }
-  //     if (!SetFlag(absl::StrCat("--", option, "=", override_string).c_str()))
-  //     {
-  //       return false;
-  //     }
-  //   }
-  //   return true;
-  // }
+  bool SetFlags(xla::CompileOptionsProto options) override {
+    // Set extra options, overriding env variables if appropriate.
+    for (auto [option, option_override] : options.env_option_overrides()) {
+      std::string override_string;
+      if (option_override.has_string_field()) {
+        override_string = option_override.string_field();
+      } else if (option_override.has_bool_field()) {
+        override_string = option_override.bool_field() ? "true" : "false";
+      } else if (option_override.has_int_field()) {
+        override_string = std::to_string(option_override.int_field());
+      } else if (option_override.has_double_field()) {
+        override_string = std::to_string(option_override.double_field());
+      } else {
+        assert(false &&
+               "option value should be of type string, bool, int, or double");
+      }
+      if (!SetFlag(("--" + option + "=" + override_string).c_str())) {
+        return false;
+      }
+    }
+    return true;
+  }
 
   std::string GetFlags() override {
     std::string flags;
diff --git a/integrations/pjrt/test/test_compile_options.py b/integrations/pjrt/test/test_compile_options.py
new file mode 100644
index 0000000..877ab99
--- /dev/null
+++ b/integrations/pjrt/test/test_compile_options.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from functools import partial
+import jax.numpy as jnp
+from jax import jit
+
+a = jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+
+@partial(jit, compiler_options={"iree-scheduling-dump-statistics-format": "csv"})
+def f(a, b):
+    return a + b
+
+
+print(f(a, a))