Update to PJRT 0.38 (#15492)
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 865dfa5..d6f1dc4 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -1449,7 +1449,7 @@
return iree_ok_status();
}
-PJRT_Error* ClientInstance::Compile(PJRT_Program* program,
+PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
/*xla::CompileOptions options,*/
LoadedExecutableInstance** out_executable) {
std::unique_ptr<ArtifactDumper::Transaction> artifact_tx;
@@ -2155,7 +2155,7 @@
api->struct_size = PJRT_Api_STRUCT_SIZE;
api->extension_start = nullptr;
api->pjrt_api_version.major_version = PJRT_API_MAJOR;
- api->pjrt_api_version.minor_version = PJRT_API_MINOR - 1;
+ api->pjrt_api_version.minor_version = PJRT_API_MINOR;
// This is a bare implementation throwing UNDEFINED errors. This way new
// functions will not segmentation fault on invocation.
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index ecbea06..c6debca 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -451,8 +451,9 @@
// Compiles.
// See TODOs in PJRT_Client_Compile.
- PJRT_Error* Compile(PJRT_Program* program, /*xla::CompileOptions options, */
- LoadedExecutableInstance** executable);
+ PJRT_Error* Compile(
+ const PJRT_Program* program, /*xla::CompileOptions options, */
+ LoadedExecutableInstance** executable);
// ---------------------------------------------------------------------------
// Subclass hooks.
@@ -530,7 +531,7 @@
// Populate config_vars() from the client create_options.
for (size_t i = 0; i < args->num_options; ++i) {
- PJRT_NamedValue* nv = args->create_options + i;
+ const PJRT_NamedValue* nv = args->create_options + i;
// For now, we only support string types.
if (nv->type != PJRT_NamedValue_kString) continue;
std::string name(nv->name, nv->name_size);
diff --git a/integrations/pjrt/src/iree_pjrt/common/stubs.inc b/integrations/pjrt/src/iree_pjrt/common/stubs.inc
index e63f6e8..2204b3c 100644
--- a/integrations/pjrt/src/iree_pjrt/common/stubs.inc
+++ b/integrations/pjrt/src/iree_pjrt/common/stubs.inc
@@ -93,4 +93,4 @@
_STUB(PJRT_Buffer_CopyToMemory);
_STUB(PJRT_Client_CreateViewOfDeviceBuffer);
_STUB(PJRT_Executable_Fingerprint);
-
+ _STUB(PJRT_Client_TopologyDescription);
diff --git a/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
index 18b6492..122f941 100644
--- a/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
+++ b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
@@ -25,7 +25,7 @@
#define PJRT_DEFINE_STRUCT_TRAITS(sname, last_field) \
typedef struct sname sname; \
- const size_t sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field);
+ enum { sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field) }
#ifdef __cplusplus
extern "C" {
@@ -53,7 +53,7 @@
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
-#define PJRT_API_MINOR 35
+#define PJRT_API_MINOR 38
// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
@@ -183,8 +183,8 @@
size_t struct_size;
void* priv;
// Returned attributes have the lifetime of the process.
- PJRT_NamedValue* attributes; // out
- size_t num_attributes; // out
+ const PJRT_NamedValue* attributes; // out
+ size_t num_attributes; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Plugin_Attributes_Args, attributes);
@@ -282,6 +282,7 @@
typedef struct PJRT_Device PJRT_Device;
typedef struct PJRT_Memory PJRT_Memory;
typedef struct PJRT_DeviceDescription PJRT_DeviceDescription;
+typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
typedef struct PJRT_Executable PJRT_Executable;
typedef struct PJRT_LoadedExecutable PJRT_LoadedExecutable;
typedef struct PJRT_Buffer PJRT_Buffer;
@@ -345,7 +346,7 @@
size_t struct_size;
void* priv;
// Extra platform-specific options to create a client.
- PJRT_NamedValue* create_options;
+ const PJRT_NamedValue* create_options;
size_t num_options;
// Key-value get/put callback provided by the caller of PJRT_Client_Create.
// PJRT client can use these callbacks to share information between
@@ -418,12 +419,26 @@
typedef PJRT_Error* PJRT_Client_PlatformVersion(
PJRT_Client_PlatformVersion_Args* args);
+struct PJRT_Client_TopologyDescription_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // Is owned by and has the same lifetime as `client`.
+ PJRT_TopologyDescription* topology; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_TopologyDescription_Args, topology);
+
+// Returns the topology description of the runtime topology. The returned
+// topology is owned by the client and should not be deleted by the caller.
+typedef PJRT_Error* PJRT_Client_TopologyDescription(
+ PJRT_Client_TopologyDescription_Args* args);
+
struct PJRT_Client_Devices_Args {
size_t struct_size;
void* priv;
PJRT_Client* client;
- PJRT_Device** devices; // out
- size_t num_devices; // out
+ PJRT_Device* const* devices; // out
+ size_t num_devices; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Devices_Args, num_devices);
@@ -435,8 +450,8 @@
size_t struct_size;
void* priv;
PJRT_Client* client;
- PJRT_Device** addressable_devices; // out
- size_t num_addressable_devices; // out
+ PJRT_Device* const* addressable_devices; // out
+ size_t num_addressable_devices; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableDevices_Args,
num_addressable_devices);
@@ -483,8 +498,8 @@
size_t struct_size;
void* priv;
PJRT_Client* client;
- PJRT_Memory** addressable_memories; // out
- size_t num_addressable_memories; // out
+ PJRT_Memory* const* addressable_memories; // out
+ size_t num_addressable_memories; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_AddressableMemories_Args,
num_addressable_memories);
@@ -518,7 +533,7 @@
PJRT_Client* client;
// Only needs to stay alive for the duration of the Compile call.
// `program->format` and `program->format_size` are owned by the caller.
- PJRT_Program* program;
+ const PJRT_Program* program;
// TODO(b/240560013): consider putting some of option fields in priv.
// Serialized CompileOptionsProto
// (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto)
@@ -799,8 +814,8 @@
size_t struct_size;
void* priv;
PJRT_DeviceDescription* device_description;
- size_t num_attributes; // out
- PJRT_NamedValue* attributes; // out
+ size_t num_attributes; // out
+ const PJRT_NamedValue* attributes; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_DeviceDescription_Attributes_Args, attributes);
@@ -898,8 +913,8 @@
void* priv;
PJRT_Device* device;
// Has the lifetime of `device`.
- PJRT_Memory** memories; // out
- size_t num_memories; // out
+ PJRT_Memory* const* memories; // out
+ size_t num_memories; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Device_AddressableMemories_Args, memories);
@@ -1025,8 +1040,8 @@
size_t struct_size;
void* priv;
PJRT_Memory* memory;
- PJRT_Device** devices; // out
- size_t num_devices; // out
+ PJRT_Device* const* devices; // out
+ size_t num_devices; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Memory_AddressableByDevices_Args, num_devices);
@@ -1114,8 +1129,8 @@
size_t struct_size;
void* priv;
PJRT_LoadedExecutable* executable;
- PJRT_Device** addressable_devices; // out
- size_t num_addressable_devices; // out
+ PJRT_Device* const* addressable_devices; // out
+ size_t num_addressable_devices; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_AddressableDevices_Args,
num_addressable_devices);
@@ -1259,7 +1274,7 @@
// Only needs to stay alive for the duration of the Execute call.
PJRT_ExecuteOptions* options;
// Execution input of size [`num_devices`, `num_args`].
- PJRT_Buffer*** argument_lists;
+ PJRT_Buffer* const* const* argument_lists;
size_t num_devices;
size_t num_args;
// Execution output of size [`num_devices`, num_outputs`], where `num_outputs`
@@ -1267,7 +1282,7 @@
// outer (`PJRT_Buffer***`) and inner lists (`PJRT_Buffer**`) must be
// allocated and deallocated by the caller. PJRT_Buffer_Destroy must be called
// on the output PJRT_Buffer*.
- PJRT_Buffer*** output_lists; // in/out
+ PJRT_Buffer** const* output_lists; // in/out
// If `device_complete_events` isn't nullptr, `device_complete_events` needs
// to be the same length as `output_lists` (i.e. of length `num_devices`), and
// each `PJRT_Event` will become ready once the corresponding device execution
@@ -1340,7 +1355,7 @@
size_t num_properties; // out
// `properties` and any embedded data are owned by and have the same lifetime
// as `executable`.
- PJRT_NamedValue* properties; // out
+ const PJRT_NamedValue* properties; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCostAnalysis_Args, properties);
@@ -1389,7 +1404,7 @@
PJRT_Executable* executable;
size_t num_outputs;
// Has length `num_outputs`.
- const char** memory_kinds; // out
+ const char* const* memory_kinds; // out
// Has length `num_outputs`.
const size_t* memory_kind_sizes; // out
};
@@ -1829,15 +1844,13 @@
// ------------------------------ Device Topology ------------------------------
-typedef struct PJRT_TopologyDescription PJRT_TopologyDescription;
-
struct PJRT_TopologyDescription_Create_Args {
size_t struct_size;
void* priv;
const char* topology_name;
size_t topology_name_size;
// Extra platform-specific options to create a client.
- PJRT_NamedValue* create_options;
+ const PJRT_NamedValue* create_options;
size_t num_options;
PJRT_TopologyDescription* topology; // out
};
@@ -1897,8 +1910,8 @@
void* priv;
PJRT_TopologyDescription* topology;
// Has the same lifetime as topology.
- PJRT_DeviceDescription** descriptions; // out
- size_t num_descriptions; // out
+ PJRT_DeviceDescription* const* descriptions; // out
+ size_t num_descriptions; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_GetDeviceDescriptions_Args,
num_descriptions);
@@ -1939,8 +1952,8 @@
PJRT_TopologyDescription* topology;
// Only lives as long as topology.
- PJRT_NamedValue* attributes; // out
- size_t num_attributes; // out
+ const PJRT_NamedValue* attributes; // out
+ size_t num_attributes; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_TopologyDescription_Attributes_Args,
num_attributes);
@@ -1955,7 +1968,7 @@
const PJRT_TopologyDescription* topology;
// Only needs to stay alive for the duration of the Compile call.
// `program->format` and `program->format_size` are owned by the caller.
- PJRT_Program* program;
+ const PJRT_Program* program;
// TODO(b/240560013): consider putting some of option fields in priv.
// Serialized CompileOptionsProto
// (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/compile_options.proto)
@@ -2107,12 +2120,18 @@
_PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputDimensions);
_PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);
+
_PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);
+
_PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_TopologyDescription);
} PJRT_Api;
-const size_t PJRT_Api_STRUCT_SIZE =
- PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateViewOfDeviceBuffer);
+enum {
+ PJRT_Api_STRUCT_SIZE =
+ PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_TopologyDescription)
+};
#undef _PJRT_API_STRUCT_FIELD