[PJRT] Add support of passing per-compilation compile options (#19438)
As discussed in
https://github.com/iree-org/iree/pull/19418#discussion_r1876524772,
https://github.com/iree-org/iree/pull/19418#pullrequestreview-2490115222
and https://github.com/iree-org/iree/pull/19418#discussion_r1877191103,
here we support to read `env_option_overrides` as IREE compile flags
from `compile_options` passed by frontends like JAX in a per-compilation
basis.
Most of these code already exists but has been commented due to some
problems: `compile_options` was not yet available in that time, but it's
now introduced by #19369.
A simple use case is shown below, also as a test case:
https://github.com/iree-org/iree/blob/c37a80212dd4a541762fc9fdaaa615b6d0a62829/integrations/pjrt/test/test_compile_options.py#L9-L15
ci-exactly: build_packages, test_pjrt
---------
Signed-off-by: PragmaTwice <twice@apache.org>
Co-authored-by: Scott Todd <scott.todd0@gmail.com>
diff --git a/.github/workflows/pkgci_test_pjrt.yml b/.github/workflows/pkgci_test_pjrt.yml
index c9f4b78..f288352 100644
--- a/.github/workflows/pkgci_test_pjrt.yml
+++ b/.github/workflows/pkgci_test_pjrt.yml
@@ -61,7 +61,7 @@
source ${VENV_DIR}/bin/activate
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
# install
- python -m pip install jax==0.4.35
+ python -m pip install jax==0.4.36
- name: Run tests
run: |
source ${VENV_DIR}/bin/activate
diff --git a/build_tools/testing/run_jax_tests.sh b/build_tools/testing/run_jax_tests.sh
index 66b2135..a065d16 100755
--- a/build_tools/testing/run_jax_tests.sh
+++ b/build_tools/testing/run_jax_tests.sh
@@ -48,6 +48,13 @@
diff_jax_test test/test_degenerate.py
diff_jax_test test/test_simple.py
+# here we test if the compile options is passed to IREE PJRT plugin successfully.
+# we pass --iree-scheduling-dump-statistics-format=csv via jax.jit,
+# and see if there's statistics in the output
+compile_options_test_tmp_out=$(mktemp /tmp/jax_test_result_compile_options.XXXXXX)
+JAX_PLATFORMS=$actual_jax_platform python test/test_compile_options.py 2>&1 | tee $compile_options_test_tmp_out
+cat $compile_options_test_tmp_out | grep '@main_dispatch'
+
# FIXME: we can also utilize the native test cases from JAX,
# e.g. `tests/nn_test.py` from the JAX repo, as below,
diff --git a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
index 4371cd3..c5855fc 100644
--- a/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
+++ b/integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt
@@ -56,6 +56,7 @@
iree::compiler::bindings::c::loader
iree_pjrt::partitioner_api
iree_pjrt::partitioner_api::loader
+ iree_pjrt_deps::protos
PUBLIC
)
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 68b45ec..6f2d787 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -1487,8 +1487,8 @@
}
// Set flags.
- // TODO: Plumb CompileOptions through.
- // if (!job->SetFlags(options)) return MakeCompilerError(*job);
+ if (!job->SetFlags(options)) return MakeCompilerError(*job);
+
if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"partitioner_flags", /*extension=*/"txt", /*index=*/-1,
@@ -1538,8 +1538,8 @@
if (!SetDefaultCompilerFlags(job.get())) {
return MakeCompilerError(*job);
}
- // TODO: Plumb CompileOptions through.
- // if (!job->SetFlags(options)) return MakeCompilerError(*job);
+ if (!job->SetFlags(options)) return MakeCompilerError(*job);
+
if (artifact_tx) {
artifact_tx->WriteArtifact(
/*label=*/"flags", /*extension=*/"txt", /*index=*/-1,
diff --git a/integrations/pjrt/src/iree_pjrt/common/compiler.h b/integrations/pjrt/src/iree_pjrt/common/compiler.h
index faf07db..9969a00 100644
--- a/integrations/pjrt/src/iree_pjrt/common/compiler.h
+++ b/integrations/pjrt/src/iree_pjrt/common/compiler.h
@@ -11,8 +11,7 @@
#include <string>
#include "iree_pjrt/common/debugging.h"
-// TODO: Excise.
-// #include "xla/pjrt/pjrt_executable.h"
+#include "xla/pjrt/compile_options.pb.h"
namespace iree::pjrt {
@@ -37,8 +36,7 @@
// setup of a job (or if the underlying session will not be re-used).
// Returns false on failure.
virtual bool SetFlag(const char* flag) = 0;
- // TODO: Excise.
- // virtual bool SetFlags(xla::CompileOptions options) = 0;
+ virtual bool SetFlags(xla::CompileOptionsProto options) = 0;
// Gets all flags as a string. This is intended for debug printing a plausible
// command line to reproduce compilation.
diff --git a/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc b/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
index b323807..c00f6eb 100644
--- a/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc
@@ -97,41 +97,44 @@
return true;
}
- // TODO: Find another way to deal with this.
- // bool SetFlags(xla::CompileOptions options) override {
- // int num_partitions = options.executable_build_options.num_partitions();
- // int num_replicas = options.executable_build_options.num_replicas();
- // bool use_spmd_partitioning =
- // options.executable_build_options.use_spmd_partitioning();
- // auto allow_spmd_sharding_propagation_to_output =
- // options.executable_build_options
- // .allow_spmd_sharding_propagation_to_output();
- // if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-num-partitions=",
- // num_partitions)
- // .c_str())) {
- // return false;
- // }
- // if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-replica-count=",
- // num_replicas)
- // .c_str())) {
- // return false;
- // }
- // if (!SetFlag(
- // absl::StrCat("--openxla-partitioner-gspmd-use-spmd-partitioning=",
- // use_spmd_partitioning)
- // .c_str())) {
- // return false;
- // }
- // if (!SetFlag(
- // absl::StrCat(
- // "--openxla-partitioner-gspmd-allow-spmd-"
- // "sharding-propagation-to-output=",
- // absl::StrJoin(allow_spmd_sharding_propagation_to_output,
- // ",")) .c_str())) {
- // return false;
- // }
- // return true;
- // }
+ bool SetFlags(xla::CompileOptionsProto options) override {
+ int num_partitions = options.executable_build_options().num_partitions();
+ int num_replicas = options.executable_build_options().num_replicas();
+ bool use_spmd_partitioning =
+ options.executable_build_options().use_spmd_partitioning();
+ auto allow_spmd_sharding_propagation_to_output =
+ options.executable_build_options()
+ .allow_spmd_sharding_propagation_to_output();
+ if (!SetFlag(("--openxla-partitioner-gspmd-num-partitions=" +
+ std::to_string(num_partitions))
+ .c_str())) {
+ return false;
+ }
+ if (!SetFlag(("--openxla-partitioner-gspmd-replica-count=" +
+ std::to_string(num_replicas))
+ .c_str())) {
+ return false;
+ }
+ if (!SetFlag(("--openxla-partitioner-gspmd-use-spmd-partitioning=" +
+ std::to_string(use_spmd_partitioning))
+ .c_str())) {
+ return false;
+ }
+ std::string allow_spmd_sharding_propagation_to_output_str;
+ for (size_t i = 0; i < allow_spmd_sharding_propagation_to_output.size();
+ ++i) {
+ if (i != 0) allow_spmd_sharding_propagation_to_output_str += ",";
+ allow_spmd_sharding_propagation_to_output_str +=
+ std::to_string(allow_spmd_sharding_propagation_to_output[i]);
+ }
+ if (!SetFlag(("--openxla-partitioner-gspmd-allow-spmd-"
+ "sharding-propagation-to-output=" +
+ allow_spmd_sharding_propagation_to_output_str)
+ .c_str())) {
+ return false;
+ }
+ return true;
+ }
std::string GetFlags() override {
std::string flags;
diff --git a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
index 095e540..1cddbc3 100644
--- a/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc
@@ -95,29 +95,28 @@
return true;
}
- // TODO: Excise: Cannot dep on an internal XLA structure.
- // bool SetFlags(xla::CompileOptions options) override {
- // // Set extra options, overriding env variables if appropriate.
- // for (auto [option, option_override] : options.env_option_overrides) {
- // std::string override_string;
- // if (auto override_val = std::get_if<std::string>(&option_override)) {
- // override_string = *override_val;
- // } else if (auto override_val = std::get_if<bool>(&option_override)) {
- // override_string = *override_val ? "true" : "false";
- // } else if (auto override_val = std::get_if<int64_t>(&option_override))
- // {
- // override_string = std::to_string(*override_val);
- // } else {
- // assert(false &&
- // "option value should be of type string, bool, or int64");
- // }
- // if (!SetFlag(absl::StrCat("--", option, "=", override_string).c_str()))
- // {
- // return false;
- // }
- // }
- // return true;
- // }
+ bool SetFlags(xla::CompileOptionsProto options) override {
+ // Set extra options, overriding env variables if appropriate.
+ for (auto [option, option_override] : options.env_option_overrides()) {
+ std::string override_string;
+ if (option_override.has_string_field()) {
+ override_string = option_override.string_field();
+ } else if (option_override.has_bool_field()) {
+ override_string = option_override.bool_field() ? "true" : "false";
+ } else if (option_override.has_int_field()) {
+ override_string = std::to_string(option_override.int_field());
+ } else if (option_override.has_double_field()) {
+ override_string = std::to_string(option_override.double_field());
+ } else {
+ assert(false &&
+ "option value should be of type string, bool, int, or double");
+ }
+ if (!SetFlag(("--" + option + "=" + override_string).c_str())) {
+ return false;
+ }
+ }
+ return true;
+ }
std::string GetFlags() override {
std::string flags;
diff --git a/integrations/pjrt/test/test_compile_options.py b/integrations/pjrt/test/test_compile_options.py
new file mode 100644
index 0000000..877ab99
--- /dev/null
+++ b/integrations/pjrt/test/test_compile_options.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from functools import partial
+import jax.numpy as jnp
+from jax import jit
+
+a = jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+
+@partial(jit, compiler_options={"iree-scheduling-dump-statistics-format": "csv"})
+def f(a, b):
+ return a + b
+
+
+print(f(a, a))