PJRT C API v0.35 (#15269)
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 9ca76e9..bd9b3ce 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -1817,19 +1817,144 @@
return status;
}
+// Installs a placeholder for every PJRT_Api entry point listed below. Each
+// placeholder ignores its arguments and returns the error produced by
+// MakeError() for an IREE_STATUS_UNIMPLEMENTED status whose message is the
+// entry point's own name. Entry points that are not listed here are left
+// untouched by this function.
+//
+// Keep this list in sync with the _PJRT_API_STRUCT_FIELD entries of PJRT_Api:
+// a field without a placeholder stays unbound unless something else binds it.
+static void BindUndefineds(PJRT_Api* api) {
+// The unary `+` converts the captureless lambda to a plain function pointer so
+// that it can be stored in the C function-pointer field.
+#define _STUB(API)                                                       \
+  api->API = +[](API##_Args* args) -> PJRT_Error* {                      \
+    return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED, #API)); \
+  }
+
+  _STUB(PJRT_Plugin_Initialize);
+  _STUB(PJRT_Plugin_Attributes);
+
+  _STUB(PJRT_Event_Destroy);
+  _STUB(PJRT_Event_IsReady);
+  _STUB(PJRT_Event_Error);
+  _STUB(PJRT_Event_Await);
+  _STUB(PJRT_Event_OnReady);
+
+  _STUB(PJRT_Client_Create);
+  _STUB(PJRT_Client_Destroy);
+  _STUB(PJRT_Client_PlatformName);
+  _STUB(PJRT_Client_ProcessIndex);
+  _STUB(PJRT_Client_PlatformVersion);
+  _STUB(PJRT_Client_Devices);
+  _STUB(PJRT_Client_AddressableDevices);
+  _STUB(PJRT_Client_LookupDevice);
+  _STUB(PJRT_Client_LookupAddressableDevice);
+  _STUB(PJRT_Client_AddressableMemories);
+  _STUB(PJRT_Client_Compile);
+  _STUB(PJRT_Client_DefaultDeviceAssignment);
+  _STUB(PJRT_Client_BufferFromHostBuffer);
+
+  _STUB(PJRT_DeviceDescription_Id);
+  _STUB(PJRT_DeviceDescription_ProcessIndex);
+  _STUB(PJRT_DeviceDescription_Attributes);
+  _STUB(PJRT_DeviceDescription_Kind);
+  _STUB(PJRT_DeviceDescription_DebugString);
+  _STUB(PJRT_DeviceDescription_ToString);
+
+  _STUB(PJRT_Device_GetDescription);
+  _STUB(PJRT_Device_IsAddressable);
+  _STUB(PJRT_Device_LocalHardwareId);
+  _STUB(PJRT_Device_AddressableMemories);
+  _STUB(PJRT_Device_DefaultMemory);
+  _STUB(PJRT_Device_MemoryStats);
+
+  _STUB(PJRT_Memory_Id);
+  _STUB(PJRT_Memory_Kind);
+  _STUB(PJRT_Memory_DebugString);
+  _STUB(PJRT_Memory_ToString);
+  _STUB(PJRT_Memory_AddressableByDevices);
+
+  _STUB(PJRT_Executable_Destroy);
+  _STUB(PJRT_Executable_Name);
+  _STUB(PJRT_Executable_NumReplicas);
+  _STUB(PJRT_Executable_NumPartitions);
+  _STUB(PJRT_Executable_NumOutputs);
+  _STUB(PJRT_Executable_SizeOfGeneratedCodeInBytes);
+  _STUB(PJRT_Executable_GetCostAnalysis);
+  _STUB(PJRT_Executable_OutputMemoryKinds);
+  _STUB(PJRT_Executable_OptimizedProgram);
+  _STUB(PJRT_Executable_Serialize);
+
+  _STUB(PJRT_LoadedExecutable_Destroy);
+  _STUB(PJRT_LoadedExecutable_GetExecutable);
+  _STUB(PJRT_LoadedExecutable_AddressableDevices);
+  _STUB(PJRT_LoadedExecutable_Delete);
+  _STUB(PJRT_LoadedExecutable_IsDeleted);
+  _STUB(PJRT_LoadedExecutable_Execute);
+  _STUB(PJRT_Executable_DeserializeAndLoad);
+  _STUB(PJRT_LoadedExecutable_Fingerprint);
+
+  _STUB(PJRT_Buffer_Destroy);
+  _STUB(PJRT_Buffer_ElementType);
+  _STUB(PJRT_Buffer_Dimensions);
+  _STUB(PJRT_Buffer_UnpaddedDimensions);
+  _STUB(PJRT_Buffer_DynamicDimensionIndices);
+  _STUB(PJRT_Buffer_GetMemoryLayout);
+  _STUB(PJRT_Buffer_OnDeviceSizeInBytes);
+  _STUB(PJRT_Buffer_Device);
+  _STUB(PJRT_Buffer_Memory);
+  _STUB(PJRT_Buffer_Delete);
+  _STUB(PJRT_Buffer_IsDeleted);
+  _STUB(PJRT_Buffer_CopyToDevice);
+  _STUB(PJRT_Buffer_ToHostBuffer);
+  _STUB(PJRT_Buffer_IsOnCpu);
+  _STUB(PJRT_Buffer_ReadyEvent);
+  _STUB(PJRT_Buffer_UnsafePointer);
+  _STUB(PJRT_Buffer_IncreaseExternalReferenceCount);
+  _STUB(PJRT_Buffer_DecreaseExternalReferenceCount);
+  _STUB(PJRT_Buffer_OpaqueDeviceMemoryDataPointer);
+
+  _STUB(PJRT_CopyToDeviceStream_Destroy);
+  _STUB(PJRT_CopyToDeviceStream_AddChunk);
+  _STUB(PJRT_CopyToDeviceStream_TotalBytes);
+  _STUB(PJRT_CopyToDeviceStream_GranuleSize);
+  _STUB(PJRT_CopyToDeviceStream_CurrentBytes);
+
+  _STUB(PJRT_TopologyDescription_Create);
+  _STUB(PJRT_TopologyDescription_Destroy);
+  _STUB(PJRT_TopologyDescription_PlatformName);
+  _STUB(PJRT_TopologyDescription_PlatformVersion);
+  _STUB(PJRT_TopologyDescription_GetDeviceDescriptions);
+  _STUB(PJRT_TopologyDescription_Serialize);
+  _STUB(PJRT_TopologyDescription_Attributes);
+
+  _STUB(PJRT_Compile);
+
+  // Always add new fields to the end of the struct. Move fields below to their
+  // corresponding places after each major version bump.
+  _STUB(PJRT_Executable_OutputElementTypes);
+  _STUB(PJRT_Executable_OutputDimensions);
+
+  _STUB(PJRT_Buffer_CopyToMemory);
+
+  _STUB(PJRT_Client_CreateViewOfDeviceBuffer);
+
+  // Last field appended to PJRT_Api by this API revision; without a
+  // placeholder it would stay unbound.
+  _STUB(PJRT_Executable_Fingerprint);
+
+// The helper macro is only meaningful inside this function.
+#undef _STUB
+}
+
//===----------------------------------------------------------------------===//
// Top-level API binding.
//===----------------------------------------------------------------------===//
void BindMonomorphicApi(PJRT_Api* api) {
api->struct_size = PJRT_Api_STRUCT_SIZE;
+ api->extension_start = nullptr;
+ api->pjrt_api_version.major_version = PJRT_API_MAJOR;
+ api->pjrt_api_version.minor_version = PJRT_API_MINOR;
+
+ // This is a bare implementation returning UNIMPLEMENTED errors. This way new
+ // functions will not cause a segmentation fault when invoked.
+ BindUndefineds(api);
+ ErrorInstance::BindApi(api);
+
+ api->PJRT_Plugin_Initialize =
+ +[](PJRT_Plugin_Initialize_Args* args) -> PJRT_Error* { return nullptr; };
// Bind by object types.
BufferInstance::BindApi(api);
ClientInstance::BindApi(api);
DeviceDescription::BindApi(api);
DeviceInstance::BindApi(api);
- ErrorInstance::BindApi(api);
EventInstance::BindApi(api);
ExecutableImage::BindApi(api);
LoadedExecutableInstance::BindApi(api);
diff --git a/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
index f3064d1..18b6492 100644
--- a/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
+++ b/integrations/pjrt/third_party/pjrt_c_api/xla/pjrt/c/pjrt_c_api.h
@@ -53,7 +53,7 @@
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
-#define PJRT_API_MINOR 31
+#define PJRT_API_MINOR 35
// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
@@ -718,6 +718,43 @@
typedef PJRT_Error* PJRT_Client_BufferFromHostBuffer(
PJRT_Client_BufferFromHostBuffer_Args* args);
+// Argument struct for PJRT_Client_CreateViewOfDeviceBuffer. `buffer` is the
+// only field marked as an output.
+struct PJRT_Client_CreateViewOfDeviceBuffer_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Client* client;
+ // A pointer to a non-owned device buffer. A PJRT_Buffer that is a non-owned
+ // view of this device buffer will be created.
+ void* device_buffer_ptr;
+ // NOTE(review): presumably `dims` points to `num_dims` extents -- confirm.
+ const int64_t* dims;
+ size_t num_dims;
+ PJRT_Buffer_Type element_type;
+ PJRT_Buffer_MemoryLayout* layout;
+ // The device that `device_buffer_ptr` is on.
+ PJRT_Device* device;
+ // A callback to be performed when the PJRT_Buffer is done with the on-device
+ // buffer. This callback is optional and can be a nullptr.
+ void (*on_delete_callback)(void* device_buffer_ptr, void* user_arg);
+ // `on_delete_callback_arg` will be passed to `on_delete_callback` as
+ // the `user_arg` argument.
+ void* on_delete_callback_arg;
+ // A platform-specific stream handle that should contain the work or events
+ // needed to materialize the on-device buffer. It is optional and can be
+ // cast from a nullptr. PJRT_Client_CreateViewOfDeviceBuffer will
+ // append an event to `stream` that indicates when the returned buffer is
+ // ready to use. This is intended to support dlpack on GPU and is not expected
+ // to be supported on all hardware platforms.
+ intptr_t stream;
+ PJRT_Buffer* buffer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_CreateViewOfDeviceBuffer_Args, buffer);
+
+// Creates a PJRT buffer that is a non-owned view of an on-device buffer
+// (typically allocated by another library). The buffer may be mutated,
+// for example, if the buffer is donated to an Execute operation. This method is
+// not required on all hardware platforms.
+typedef PJRT_Error* PJRT_Client_CreateViewOfDeviceBuffer(
+ PJRT_Client_CreateViewOfDeviceBuffer_Args* args);
+
// -------------------------- Device Descriptions ------------------------------
// Device descriptions may be associated with an actual device
@@ -1278,6 +1315,24 @@
typedef PJRT_Error* PJRT_Executable_SizeOfGeneratedCodeInBytes(
PJRT_Executable_SizeOfGeneratedCodeInBytes_Args* args);
+// Argument struct for PJRT_Executable_Fingerprint. `executable` is the input;
+// the two fields marked `// out` carry the result.
+struct PJRT_Executable_Fingerprint_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Executable* executable;
+ // Has the lifetime of `executable`
+ // NOTE(review): presumably a pointer/length pair that is not necessarily
+ // NUL-terminated -- confirm against the implementation.
+ const char* executable_fingerprint; // out
+ size_t executable_fingerprint_size; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_Fingerprint_Args,
+ executable_fingerprint_size);
+
+// A unique fingerprint for `executable`. Two executables that were produced by
+// compiling with identical inputs (same program, compile options, compiler
+// version, etc.) should have the same fingerprint. May not be implemented by
+// all platforms.
+typedef PJRT_Error* PJRT_Executable_Fingerprint(
+ PJRT_Executable_Fingerprint_Args* args);
+
struct PJRT_Executable_GetCostAnalysis_Args {
size_t struct_size;
void* priv;
@@ -1397,10 +1452,11 @@
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_LoadedExecutable_Fingerprint_Args,
executable_fingerprint_size);
-// A unique fingerprint for `executable`. Two executables that were produced by
-// compiling with identical inputs (same program, compile options, compiler
-// version, etc.) should have the same fingerprint. May not be implemented by
-// all platforms.
+// DEPRECATED. Will be removed in PJRT version 2.0. Please use
+// PJRT_Executable_Fingerprint instead. A unique fingerprint for `executable`.
+// Two executables that were produced by compiling with identical inputs (same
+// program, compile options, compiler version, etc.) should have the same
+// fingerprint. May not be implemented by all platforms.
typedef PJRT_Error* PJRT_LoadedExecutable_Fingerprint(
PJRT_LoadedExecutable_Fingerprint_Args* args);
@@ -1565,12 +1621,27 @@
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToDevice_Args, dst_buffer);
-// Copies the buffer to device `dst_device`. Caller is responsible for freeing
-// returned `dst_buffer` with PJRT_Buffer_Destroy. Returns an error if the
-// buffer is already on `dst_device`.
+// Copies the buffer to device `dst_device` within the same client. Caller is
+// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy.
+// Returns an error if the buffer is already on `dst_device`.
typedef PJRT_Error* PJRT_Buffer_CopyToDevice(
PJRT_Buffer_CopyToDevice_Args* args);
+// Argument struct for PJRT_Buffer_CopyToMemory: `buffer` is the source,
+// `dst_memory` the destination memory, and `dst_buffer` receives the copy.
+struct PJRT_Buffer_CopyToMemory_Args {
+ size_t struct_size;
+ void* priv;
+ PJRT_Buffer* buffer;
+ PJRT_Memory* dst_memory;
+ // Returned to the caller, who frees it with PJRT_Buffer_Destroy.
+ PJRT_Buffer* dst_buffer; // out
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_Buffer_CopyToMemory_Args, dst_buffer);
+
+// Copies the buffer to memory `dst_memory` within the same client. Caller is
+// responsible for freeing returned `dst_buffer` with PJRT_Buffer_Destroy.
+// Returns an error if the buffer is already on `dst_memory`.
+typedef PJRT_Error* PJRT_Buffer_CopyToMemory(
+ PJRT_Buffer_CopyToMemory_Args* args);
+
struct PJRT_Buffer_IsOnCpu_Args {
size_t struct_size;
void* priv;
@@ -1905,6 +1976,7 @@
+// Type tags for PJRT structures.
+// NOTE(review): presumably this is the `type` stored in PJRT_Structure_Base --
+// confirm against its definition.
typedef enum {
PJRT_Structure_Type_Gpu_Custom_Call = 0,
+ // No explicit initializer: takes the value after
+ // PJRT_Structure_Type_Gpu_Custom_Call, i.e. 1.
+ PJRT_Structure_Type_Profiler,
} PJRT_Structure_Type;
// PJRT_Structure_Base contains a type and a pointer to next
@@ -2033,10 +2105,14 @@
// corresponding places after each major version bump.
_PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputElementTypes);
_PJRT_API_STRUCT_FIELD(PJRT_Executable_OutputDimensions);
+
+ _PJRT_API_STRUCT_FIELD(PJRT_Buffer_CopyToMemory);
+ _PJRT_API_STRUCT_FIELD(PJRT_Client_CreateViewOfDeviceBuffer);
+ _PJRT_API_STRUCT_FIELD(PJRT_Executable_Fingerprint);
} PJRT_Api;
+// NOTE(review): the last field declared in PJRT_Api is
+// PJRT_Executable_Fingerprint, but the size here is computed through
+// PJRT_Client_CreateViewOfDeviceBuffer, so it presumably does not cover the
+// final field. Confirm against upstream before changing this vendored header,
+// and make sure the plugin binds PJRT_Executable_Fingerprint before a larger
+// size is advertised.
const size_t PJRT_Api_STRUCT_SIZE =
- PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Executable_OutputDimensions);
+ PJRT_STRUCT_SIZE(PJRT_Api, PJRT_Client_CreateViewOfDeviceBuffer);
#undef _PJRT_API_STRUCT_FIELD