blob: b3238076d927d24e281ab4d678c3b871714659ec [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <functional>
#include <iostream> // TODO: Remove
#include <vector>
#include "iree_pjrt/common/compiler.h"
#include "iree_pjrt/partitioner_api/embedding_api.h"
namespace iree::pjrt {
//===----------------------------------------------------------------------===//
// IREE compiler.
//===----------------------------------------------------------------------===//
namespace {
class MMapCompilerOutput : public CompilerOutput {
public:
MMapCompilerOutput(openxla_partitioner_output_t* output, void* data,
size_t length)
: output_(output), data_(data), length_(length) {}
~MMapCompilerOutput() { openxlaPartitionerOutputDestroy(output_); }
void* GetData() { return data_; }
size_t GetDataSize() { return length_; }
private:
openxla_partitioner_output_t* output_;
void* data_;
size_t length_;
};
using SessionRecycler = std::function<void(openxla_partitioner_session_t*)>;
class OpenXLAPartitionerJob : public CompilerJob {
public:
// Takes ownership of both |session| and |inv|. On destruction, destroys
// |inv| and passes |session| to the recycler (this can be used to implement
// session pooling).
OpenXLAPartitionerJob(openxla_partitioner_session_t* session,
openxla_partitioner_invocation_t* inv,
SessionRecycler session_recycler)
: session_(session), inv_(inv), session_recycler_(session_recycler) {}
~OpenXLAPartitionerJob() {
if (error_) {
openxlaPartitionerErrorDestroy(error_);
}
for (auto* source : retained_sources_) {
openxlaPartitionerSourceDestroy(source);
}
openxlaPartitionerInvocationDestroy(inv_);
session_recycler_(session_);
if (output_) {
openxlaPartitionerOutputDestroy(output_);
}
}
std::string GetErrorMessage() override {
if (!error_) return std::string();
const char* cstr = openxlaPartitionerErrorGetMessage(error_);
return std::string(cstr);
}
void EnableCrashDumps(
ArtifactDumper::Transaction* artifact_transaction) override {
if (crash_dump_transaction_) return;
crash_dump_transaction_ = artifact_transaction;
openxlaPartitionerInvocationSetCrashHandler(
inv_, /*genLocalReproducer=*/false,
[](openxla_partitioner_output_t** outOutput,
void* userData) -> openxla_partitioner_error_t* {
auto* self = static_cast<OpenXLAPartitionerJob*>(userData);
auto maybePath = self->crash_dump_transaction_->AllocateArtifactPath(
/*label=*/"crash_reproducer", /*extension=*/"mlir",
/*index=*/self->crash_dump_count_++);
if (!maybePath) {
*outOutput = nullptr;
return nullptr;
}
return openxlaPartitionerOutputOpenFile(maybePath->c_str(),
outOutput);
},
this);
}
bool SetFlag(const char* flag) override {
auto* error = openxlaPartitionerSessionSetFlags(session_, 1, &flag);
if (error) {
SetError(error);
return false;
}
return true;
}
// TODO: Find another way to deal with this.
// bool SetFlags(xla::CompileOptions options) override {
// int num_partitions = options.executable_build_options.num_partitions();
// int num_replicas = options.executable_build_options.num_replicas();
// bool use_spmd_partitioning =
// options.executable_build_options.use_spmd_partitioning();
// auto allow_spmd_sharding_propagation_to_output =
// options.executable_build_options
// .allow_spmd_sharding_propagation_to_output();
// if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-num-partitions=",
// num_partitions)
// .c_str())) {
// return false;
// }
// if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-replica-count=",
// num_replicas)
// .c_str())) {
// return false;
// }
// if (!SetFlag(
// absl::StrCat("--openxla-partitioner-gspmd-use-spmd-partitioning=",
// use_spmd_partitioning)
// .c_str())) {
// return false;
// }
// if (!SetFlag(
// absl::StrCat(
// "--openxla-partitioner-gspmd-allow-spmd-"
// "sharding-propagation-to-output=",
// absl::StrJoin(allow_spmd_sharding_propagation_to_output,
// ",")) .c_str())) {
// return false;
// }
// return true;
// }
std::string GetFlags() override {
std::string flags;
openxlaPartitionerSessionGetFlags(
session_, /*nonDefaultOnly=*/false,
[](const char* flag, size_t length, void* userData) {
std::string* capture_flags = static_cast<std::string*>(userData);
if (!capture_flags->empty()) {
capture_flags->append(" ");
}
capture_flags->append(flag, length);
},
&flags);
return flags;
}
bool ParseSourceBuffer(const void* buffer, size_t length) override {
openxla_partitioner_source_t* source;
auto* error = openxlaPartitionerSourceWrapBuffer(
session_, "<jit>", static_cast<const char*>(buffer), length,
/*isNullTerminated=*/false, &source);
if (error) {
SetError(error);
return false;
}
retained_sources_.push_back(source);
return openxlaPartitionerInvocationParseSource(inv_, source);
}
std::unique_ptr<CompilerOutput> CompileStandardPipeline() override {
if (!openxlaPartitionerInvocationPipeline(inv_, "gspmd")) {
return nullptr;
}
openxla_partitioner_error_t* error =
openxlaPartitionerOutputOpenMembuffer(&output_);
if (error) {
SetError(error);
return nullptr;
}
// Output.
error = openxlaPartitionerInvocationOutputIR(inv_, output_);
if (error) {
SetError(error);
return nullptr;
}
// Map the data.
void* output_data = nullptr;
uint64_t size = -1;
error = openxlaPartitionerOutputMapMemory(output_, &output_data, &size);
if (error) {
SetError(error);
return nullptr;
}
// Transfer the output_ to MMapCompilerOutput since the mapping is only
// valid for the life of the output.
openxla_partitioner_output_t* local_output = output_;
output_ = nullptr;
return std::make_unique<MMapCompilerOutput>(local_output, output_data,
size);
}
private:
void SetError(openxla_partitioner_error_t* error) {
if (error_) {
openxlaPartitionerErrorDestroy(error_);
}
error_ = error;
}
openxla_partitioner_session_t* session_;
openxla_partitioner_invocation_t* inv_;
SessionRecycler session_recycler_;
std::vector<openxla_partitioner_source_t*> retained_sources_;
openxla_partitioner_error_t* error_ = nullptr;
ArtifactDumper::Transaction* crash_dump_transaction_ = nullptr;
int crash_dump_count_ = 0;
// Output.
openxla_partitioner_output_t* output_ = nullptr;
};
} // namespace
std::unique_ptr<CompilerJob> OpenXLAPartitioner::StartJob() {
auto* session = openxlaPartitionerSessionCreate();
auto* inv = openxlaPartitionerInvocationCreate(session);
// TODO: Capture diagnostics, etc vs spewing to stderr.
openxlaPartitionerInvocationEnableConsoleDiagnostics(inv);
return std::make_unique<OpenXLAPartitionerJob>(
session, inv, [](openxla_partitioner_session_t* session) {
openxlaPartitionerSessionDestroy(session);
});
}
std::string OpenXLAPartitioner::GetRevision() {
std::string result;
const char* revision = openxlaPartitionerGetRevision();
result.append(revision[0] ? revision : "<unknown>");
result.append(" (API version ");
int packed_api_version = openxlaPartitionerGetAPIVersion();
result.append(std::to_string(packed_api_version >> 16));
result.append(".");
result.append(std::to_string(packed_api_version & 0xffff));
result.append(")");
return result;
}
} // namespace iree::pjrt