[pjrt] Add ROCM and Vulkan backends. (#15106)
* The vulkan backend raises an error on some memory transfer thing that
needs some debugging for d/h transfers.
* I don't have a functioning ROCM installation on this machine but it
compiles and loads to the point of trying to find ROCM libraries at
runtime.
* I'm going to raise an RFC to enable ROCM compiler backend support in
default builds as then ROCM would work out of the box with shipped
project compiler binary packages.
diff --git a/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py b/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py
index 765d529..0feb18d 100644
--- a/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py
+++ b/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py
@@ -14,9 +14,11 @@
def pytest_addoption(parser, pluginmanager) -> None:
- parser.addoption("--openxla-pjrt-artifact-dir",
- dest="OPENXLA_PJRT_ARTIFACT_DIR",
- help="Saves OpenXLA PJRT compilation artifacts")
+ parser.addoption(
+ "--openxla-pjrt-artifact-dir",
+ dest="OPENXLA_PJRT_ARTIFACT_DIR",
+ help="Saves OpenXLA PJRT compilation artifacts",
+ )
def pytest_sessionstart(session: pytest.Session) -> None:
@@ -26,10 +28,11 @@
def pytest_runtest_setup(item: pytest.Item) -> None:
artifact_dir = item.session.stash[ARTIFACT_DIR_KEY]
- if artifact_dir is None: return
- sanitized_name = (item.nodeid.replace(".py::",
- "::").replace("/", "_").replace(
- "::", "__"))
+ if artifact_dir is None:
+ return
+ sanitized_name = (
+ item.nodeid.replace(".py::", "::").replace("/", "_").replace("::", "__")
+ )
test_dir = artifact_dir / sanitized_name
shutil.rmtree(test_dir, ignore_errors=True)
test_dir.mkdir(parents=True, exist_ok=True)
@@ -43,11 +46,12 @@
def pytest_runtest_makereport(item, call) -> None:
outcome = yield
test_dir = item.stash[TEST_DIR_KEY]
- if test_dir is None: return
+ if test_dir is None:
+ return
result = outcome.get_result()
- if call.when == 'call' and result.failed:
+ if call.when == "call" and result.failed:
with open(test_dir / "error.txt", "wt") as f:
f.write(result.longreprtext)
f.write("\n\nSTDERR:\n-------\n")
@@ -60,7 +64,8 @@
def pytest_runtest_teardown(item: pytest.Item) -> None:
test_dir = item.stash[TEST_DIR_KEY]
- if test_dir is None: return
+ if test_dir is None:
+ return
dir_entries = list(test_dir.iterdir())
crash_marker = test_dir / "CRASH_MARKER"
if crash_marker.is_file():
@@ -68,4 +73,3 @@
if not dir_entries:
# Remove empty directories on success.
test_dir.rmdir()
-
diff --git a/integrations/pjrt/ctstools/setup.py b/integrations/pjrt/ctstools/setup.py
index a7cf6e3..f3ab04d 100644
--- a/integrations/pjrt/ctstools/setup.py
+++ b/integrations/pjrt/ctstools/setup.py
@@ -9,6 +9,8 @@
setup(
name="openxla_pjrt_ctstools",
packages=["openxla.cts"],
- entry_points={"pytest11": ["openxla_pjrt_artifacts = openxla.cts.pytest_artifact_saver"]},
+ entry_points={
+ "pytest11": ["openxla_pjrt_artifacts = openxla.cts.pytest_artifact_saver"]
+ },
classifiers=["Framework :: Pytest"],
)
diff --git a/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py b/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py
index 9b9a897..4f10470 100644
--- a/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py
+++ b/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py
@@ -32,9 +32,7 @@
print("*****************************", file=sys.stderr)
self.build_configuration(
os.path.join(THIS_DIR, "build", "cmake"),
- extra_cmake_args=(
- "-DIREE_HAL_DRIVER_CUDA=ON",
- ),
+ extra_cmake_args=("-DIREE_HAL_DRIVER_CUDA=ON",),
)
print("Target populated.", file=sys.stderr)
@@ -56,10 +54,10 @@
author="The IREE Team",
author_email="iree-discuss@googlegroups.com",
license="Apache-2.0",
- description="IREE PJRT Plugin for CPUs (generic)",
+ description="IREE PJRT Plugin for CUDA (generic)",
long_description=README,
long_description_content_type="text/markdown",
- url="https://github.com/openxla/openxla-pjrt-plugin",
+ url="https://github.com/openxla/iree",
classifiers=[
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/jax_plugins/iree_rocm/__init__.py b/integrations/pjrt/python_packages/iree_rocm_plugin/jax_plugins/iree_rocm/__init__.py
new file mode 100644
index 0000000..9ab8058
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/jax_plugins/iree_rocm/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+from pathlib import Path
+import platform
+import sys
+
+import jax._src.xla_bridge as xb
+
+logger = logging.getLogger(__name__)
+
+
+def probe_iree_compiler_dylib() -> str:
+ """Probes an installed iree.compiler for the compiler dylib."""
+ # TODO: Move this out of the ctypes API initialization.
+ from iree.compiler.api import ctypes_dl
+
+ return ctypes_dl._probe_iree_compiler_dylib()
+
+
+def initialize():
+ import iree._pjrt_libs.rocm as lib_package
+
+ path = Path(lib_package.__file__).resolve().parent / "pjrt_plugin_iree_rocm.so"
+ if not path.exists():
+ logger.warning(
+ f"WARNING: Native library {path} does not exist. "
+ f"This most likely indicates an issue with how {__package__} "
+ f"was built or installed."
+ )
+ xb.register_plugin(
+ "iree_rocm",
+ priority=500,
+ library_path=str(path),
+ options={
+ "COMPILER_LIB_PATH": str(probe_iree_compiler_dylib()),
+ },
+ )
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml
new file mode 100644
index 0000000..f6c1689
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+ "setuptools>=42",
+ "wheel",
+]
+build-backend = "setuptools.build_meta"
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
new file mode 100644
index 0000000..795aa38
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
@@ -0,0 +1,93 @@
+#!/usr/bin/python3
+
+# Copyright 2023 The OpenXLA Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Early splice the _setup_support directory onto the python path.
+import os
+from pathlib import Path
+import sys
+
+THIS_DIR = os.path.realpath(os.path.dirname(__file__))
+sys.path.insert(0, os.path.join(THIS_DIR, "..", "_setup_support"))
+
+import iree_pjrt_setup
+from setuptools import setup, find_namespace_packages
+
+README = r"""
+OpenXLA PJRT Plugin for ROCM
+"""
+
+# Setup and get version information.
+CMAKE_BUILD_DIR_ABS = os.path.join(THIS_DIR, "build", "cmake")
+
+
+class CMakeBuildPy(iree_pjrt_setup.BaseCMakeBuildPy):
+ def build_default_configuration(self):
+ print("*****************************", file=sys.stderr)
+ print("* Building base runtime *", file=sys.stderr)
+ print("*****************************", file=sys.stderr)
+ self.build_configuration(
+ os.path.join(THIS_DIR, "build", "cmake"),
+ extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=ROCM",),
+ )
+ print("Target populated.", file=sys.stderr)
+
+
+iree_pjrt_setup.populate_built_package(
+ os.path.join(
+ CMAKE_BUILD_DIR_ABS,
+ "python",
+ "iree",
+ "_pjrt_libs",
+ "rocm",
+ )
+)
+
+
+setup(
+ name=f"iree-pjrt-plugin-rocm{iree_pjrt_setup.PACKAGE_SUFFIX}",
+ version=f"{iree_pjrt_setup.PACKAGE_VERSION}",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache-2.0",
+ description="IREE PJRT Plugin for ROCM (generic)",
+ long_description=README,
+ long_description_content_type="text/markdown",
+ url="https://github.com/openxla/iree",
+ classifiers=[
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ ],
+ packages=[
+ "jax_plugins.iree_rocm",
+ "iree._pjrt_libs.rocm",
+ ],
+ package_dir={
+ "jax_plugins.iree_rocm": "jax_plugins/iree_rocm",
+ "iree._pjrt_libs.rocm": "build/cmake/python/iree/_pjrt_libs/rocm",
+ },
+ package_data={
+ "iree._pjrt_libs.rocm": ["pjrt_plugin_iree_rocm.*"],
+ },
+ cmdclass={
+ "build": iree_pjrt_setup.PjrtPluginBuild,
+ "build_py": CMakeBuildPy,
+ "bdist_wheel": iree_pjrt_setup.bdist_wheel,
+ "install": iree_pjrt_setup.platlib_install,
+ },
+ zip_safe=False, # Needs to reference embedded shared libraries.
+ entry_points={
+ # We must advertise which Python modules should be treated as loadable
+ # plugins. This augments the path based scanning that Jax does, which
+ # is not always robust to all packaging circumstances.
+ "jax_plugins": [
+ "iree-rocm = jax_plugins.iree_rocm",
+ ],
+ },
+ install_requires=iree_pjrt_setup.install_requires,
+)
diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/jax_plugins/iree_vulkan/__init__.py b/integrations/pjrt/python_packages/iree_vulkan_plugin/jax_plugins/iree_vulkan/__init__.py
new file mode 100644
index 0000000..a3fa678
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/jax_plugins/iree_vulkan/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+from pathlib import Path
+import platform
+import sys
+
+import jax._src.xla_bridge as xb
+
+logger = logging.getLogger(__name__)
+
+
+def probe_iree_compiler_dylib() -> str:
+ """Probes an installed iree.compiler for the compiler dylib."""
+ # TODO: Move this out of the ctypes API initialization.
+ from iree.compiler.api import ctypes_dl
+
+ return ctypes_dl._probe_iree_compiler_dylib()
+
+
+def initialize():
+ import iree._pjrt_libs.vulkan as lib_package
+
+ path = Path(lib_package.__file__).resolve().parent / "pjrt_plugin_iree_vulkan.so"
+ if not path.exists():
+ logger.warning(
+ f"WARNING: Native library {path} does not exist. "
+ f"This most likely indicates an issue with how {__package__} "
+ f"was built or installed."
+ )
+ xb.register_plugin(
+ "iree_vulkan",
+ priority=500,
+ library_path=str(path),
+ options={
+ "COMPILER_LIB_PATH": str(probe_iree_compiler_dylib()),
+ },
+ )
diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml
new file mode 100644
index 0000000..f6c1689
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+ "setuptools>=42",
+ "wheel",
+]
+build-backend = "setuptools.build_meta"
diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/setup.py b/integrations/pjrt/python_packages/iree_vulkan_plugin/setup.py
new file mode 100644
index 0000000..2995c02
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/setup.py
@@ -0,0 +1,93 @@
+#!/usr/bin/python3
+
+# Copyright 2023 The OpenXLA Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Early splice the _setup_support directory onto the python path.
+import os
+from pathlib import Path
+import sys
+
+THIS_DIR = os.path.realpath(os.path.dirname(__file__))
+sys.path.insert(0, os.path.join(THIS_DIR, "..", "_setup_support"))
+
+import iree_pjrt_setup
+from setuptools import setup, find_namespace_packages
+
+README = r"""
+OpenXLA PJRT Plugin for Vulkan
+"""
+
+# Setup and get version information.
+CMAKE_BUILD_DIR_ABS = os.path.join(THIS_DIR, "build", "cmake")
+
+
+class CMakeBuildPy(iree_pjrt_setup.BaseCMakeBuildPy):
+ def build_default_configuration(self):
+ print("*****************************", file=sys.stderr)
+ print("* Building base runtime *", file=sys.stderr)
+ print("*****************************", file=sys.stderr)
+ self.build_configuration(
+ os.path.join(THIS_DIR, "build", "cmake"),
+ extra_cmake_args=("-DIREE_HAL_DRIVER_VULKAN=ON",),
+ )
+ print("Target populated.", file=sys.stderr)
+
+
+iree_pjrt_setup.populate_built_package(
+ os.path.join(
+ CMAKE_BUILD_DIR_ABS,
+ "python",
+ "iree",
+ "_pjrt_libs",
+ "vulkan",
+ )
+)
+
+
+setup(
+ name=f"iree-pjrt-plugin-vulkan{iree_pjrt_setup.PACKAGE_SUFFIX}",
+ version=f"{iree_pjrt_setup.PACKAGE_VERSION}",
+ author="The IREE Team",
+ author_email="iree-discuss@googlegroups.com",
+ license="Apache-2.0",
+ description="IREE PJRT Plugin for Vulkan (generic)",
+ long_description=README,
+ long_description_content_type="text/markdown",
+ url="https://github.com/openxla/iree",
+ classifiers=[
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ ],
+ packages=[
+ "jax_plugins.iree_vulkan",
+ "iree._pjrt_libs.vulkan",
+ ],
+ package_dir={
+ "jax_plugins.iree_vulkan": "jax_plugins/iree_vulkan",
+ "iree._pjrt_libs.vulkan": "build/cmake/python/iree/_pjrt_libs/vulkan",
+ },
+ package_data={
+ "iree._pjrt_libs.vulkan": ["pjrt_plugin_iree_vulkan.*"],
+ },
+ cmdclass={
+ "build": iree_pjrt_setup.PjrtPluginBuild,
+ "build_py": CMakeBuildPy,
+ "bdist_wheel": iree_pjrt_setup.bdist_wheel,
+ "install": iree_pjrt_setup.platlib_install,
+ },
+ zip_safe=False, # Needs to reference embedded shared libraries.
+ entry_points={
+ # We must advertise which Python modules should be treated as loadable
+ # plugins. This augments the path based scanning that Jax does, which
+ # is not always robust to all packaging circumstances.
+ "jax_plugins": [
+ "iree-vulkan = jax_plugins.iree_vulkan",
+ ],
+ },
+ install_requires=iree_pjrt_setup.install_requires,
+)
diff --git a/integrations/pjrt/src/CMakeLists.txt b/integrations/pjrt/src/CMakeLists.txt
index ac40a8d..6310d38 100644
--- a/integrations/pjrt/src/CMakeLists.txt
+++ b/integrations/pjrt/src/CMakeLists.txt
@@ -22,8 +22,14 @@
add_subdirectory(iree_pjrt/partitioner_api)
if(IREE_HAL_DRIVER_LOCAL_TASK)
- add_subdirectory(iree_pjrt/cpu)
+ add_subdirectory(iree_pjrt/cpu)
endif()
if(IREE_HAL_DRIVER_CUDA)
- add_subdirectory(iree_pjrt/cuda)
+ add_subdirectory(iree_pjrt/cuda)
+endif()
+if("ROCM" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
+ add_subdirectory(iree_pjrt/rocm)
+endif()
+if(IREE_HAL_DRIVER_VULKAN)
+ add_subdirectory(iree_pjrt/vulkan)
endif()
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 2c65500..9ca76e9 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -945,6 +945,8 @@
ClientInstance::ClientInstance(std::unique_ptr<Platform> platform)
: platform_(std::move(platform)) {
host_allocator_ = iree_allocator_system();
+ IREE_CHECK_OK(
+ iree_hal_driver_registry_allocate(host_allocator_, &driver_registry_));
cached_platform_version_ = "git"; // TODO: Plumb through version info.
}
@@ -959,6 +961,7 @@
// ordering (bad shutdown ordering of the driver is a frequent cause of
// bugs).
iree_hal_driver_release(driver_);
+ iree_hal_driver_registry_free(driver_registry_);
}
void ClientInstance::BindApi(PJRT_Api* api) {
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index d56a0bd..dc405f1 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -402,7 +402,7 @@
// created against an API.
//===----------------------------------------------------------------------===//
-struct ClientInstance {
+class ClientInstance {
public:
ClientInstance(std::unique_ptr<Platform> platform);
virtual ~ClientInstance();
@@ -461,6 +461,7 @@
protected:
iree_allocator_t host_allocator_;
+ iree_hal_driver_registry_t* driver_registry_ = nullptr;
std::string cached_platform_name_;
std::string cached_platform_version_;
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt
new file mode 100644
index 0000000..1bc4128
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+ NAME
+ client
+ HDRS
+ "client.h"
+ SRCS
+ "client.cc"
+ DEPS
+ iree_pjrt::common
+ iree::experimental::rocm
+ iree::experimental::rocm::registration
+)
+
+iree_cc_library(
+ SHARED
+ NAME
+ dylib
+ DEFINES
+ # Causes PJRT dynamic linking entry points to be made visible.
+ PJRT_PLUGIN_BUILDING_LIBRARY
+ SRCS
+ "dylib_entry_point.cc"
+ DEPS
+ ::client
+ iree_pjrt::common
+ iree_pjrt::common::dylib_platform
+)
+
+# Output to the project wide python binary directory tree.
+set(_NATIVE_PYTHON_DIR "${IREE_PJRT_PYTHON_BINARY_DIR}/iree/_pjrt_libs/rocm")
+file(WRITE "${_NATIVE_PYTHON_DIR}/__init__.py" "")
+set_target_properties(iree_pjrt_rocm_dylib
+ PROPERTIES
+ PREFIX "" # Disable "lib" prefix.
+ LIBRARY_OUTPUT_NAME pjrt_plugin_iree_rocm
+ RUNTIME_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+ LIBRARY_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+)
+
+# TODO: Find a better way to decide whether can link with undefined symbols.
+if(NOT IREE_ENABLE_ASAN)
+ target_link_options(iree_pjrt_rocm_dylib PRIVATE "-Wl,--no-undefined")
+endif()
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc
new file mode 100644
index 0000000..5f290d0
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc
@@ -0,0 +1,37 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/rocm/client.h"
+
+#include "experimental/rocm/registration/driver_module.h"
+
+namespace iree::pjrt::rocm {
+
+ROCMClientInstance::ROCMClientInstance(std::unique_ptr<Platform> platform)
+ : ClientInstance(std::move(platform)) {
+ // Seems that it must match how registered. Action at a distance not
+ // great.
+ // TODO: Get this when constructing the client so it is guaranteed to
+ // match.
+ cached_platform_name_ = "iree_rocm";
+ IREE_CHECK_OK(iree_hal_rocm_driver_module_register(driver_registry_));
+}
+
+ROCMClientInstance::~ROCMClientInstance() {}
+
+iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
+ iree_string_view_t driver_name = iree_make_cstring_view("rocm");
+ IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
+ driver_registry_, driver_name, host_allocator_, out_driver));
+ logger().debug("ROCM driver created");
+ return iree_ok_status();
+}
+
+bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
+ return compiler_job->SetFlag("--iree-hal-target-backends=rocm");
+}
+
+} // namespace iree::pjrt::rocm
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.h b/integrations/pjrt/src/iree_pjrt/rocm/client.h
new file mode 100644
index 0000000..e2b78da
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/client.h
@@ -0,0 +1,27 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
+#define IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
+
+#include "experimental/rocm/api.h"
+#include "iree_pjrt/common/api_impl.h"
+
+namespace iree::pjrt::rocm {
+
+class ROCMClientInstance final : public ClientInstance {
+ public:
+ ROCMClientInstance(std::unique_ptr<Platform> platform);
+ ~ROCMClientInstance();
+ iree_status_t CreateDriver(iree_hal_driver_t** out_driver) override;
+ bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override;
+
+ private:
+};
+
+} // namespace iree::pjrt::rocm
+
+#endif // IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/dylib_entry_point.cc b/integrations/pjrt/src/iree_pjrt/rocm/dylib_entry_point.cc
new file mode 100644
index 0000000..e157a73
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/dylib_entry_point.cc
@@ -0,0 +1,22 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/common/dylib_platform.h"
+#include "iree_pjrt/rocm/client.h"
+
+// Provides the shared library exports.
+#include "iree_pjrt/common/dylib_entry_point.cc.inc"
+
+namespace iree::pjrt {
+namespace {
+
+// Declared but not implemented by the include file.
+void InitializeAPI(PJRT_Api* api) {
+ BindApi<DylibPlatform, rocm::ROCMClientInstance>(api);
+}
+
+} // namespace
+} // namespace iree::pjrt
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/vulkan/CMakeLists.txt
new file mode 100644
index 0000000..4524b28
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+ NAME
+ client
+ HDRS
+ "client.h"
+ SRCS
+ "client.cc"
+ DEPS
+ iree_pjrt::common
+ iree::hal::drivers::vulkan
+ iree::hal::drivers::vulkan::registration
+)
+
+iree_cc_library(
+ SHARED
+ NAME
+ dylib
+ DEFINES
+ # Causes PJRT dynamic linking entry points to be made visible.
+ PJRT_PLUGIN_BUILDING_LIBRARY
+ SRCS
+ "dylib_entry_point.cc"
+ DEPS
+ ::client
+ iree_pjrt::common
+ iree_pjrt::common::dylib_platform
+)
+
+# Output to the project wide python binary directory tree.
+set(_NATIVE_PYTHON_DIR "${IREE_PJRT_PYTHON_BINARY_DIR}/iree/_pjrt_libs/vulkan")
+file(WRITE "${_NATIVE_PYTHON_DIR}/__init__.py" "")
+set_target_properties(iree_pjrt_vulkan_dylib
+ PROPERTIES
+ PREFIX "" # Disable "lib" prefix.
+ LIBRARY_OUTPUT_NAME pjrt_plugin_iree_vulkan
+ RUNTIME_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+ LIBRARY_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+)
+
+# TODO: Find a better way to decide whether can link with undefined symbols.
+if(NOT IREE_ENABLE_ASAN)
+ target_link_options(iree_pjrt_vulkan_dylib PRIVATE "-Wl,--no-undefined")
+endif()
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc
new file mode 100644
index 0000000..853ead8
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc
@@ -0,0 +1,38 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/vulkan/client.h"
+
+#include "iree/hal/drivers/vulkan/registration/driver_module.h"
+
+namespace iree::pjrt::vulkan {
+
+VulkanClientInstance::VulkanClientInstance(std::unique_ptr<Platform> platform)
+ : ClientInstance(std::move(platform)) {
+ // Seems that it must match how registered. Action at a distance not
+ // great.
+ // TODO: Get this when constructing the client so it is guaranteed to
+ // match.
+ cached_platform_name_ = "iree_vulkan";
+ IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(driver_registry_));
+}
+
+VulkanClientInstance::~VulkanClientInstance() {}
+
+iree_status_t VulkanClientInstance::CreateDriver(
+ iree_hal_driver_t** out_driver) {
+ iree_string_view_t driver_name = iree_make_cstring_view("vulkan");
+ IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
+ driver_registry_, driver_name, host_allocator_, out_driver));
+ logger().debug("Vulkan driver created");
+ return iree_ok_status();
+}
+
+bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
+ return compiler_job->SetFlag("--iree-hal-target-backends=vulkan");
+}
+
+} // namespace iree::pjrt::vulkan
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.h b/integrations/pjrt/src/iree_pjrt/vulkan/client.h
new file mode 100644
index 0000000..7520355
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.h
@@ -0,0 +1,27 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_PJRT_PLUGIN_PJRT_VULKAN_CLIENT_H_
+#define IREE_PJRT_PLUGIN_PJRT_VULKAN_CLIENT_H_
+
+#include "iree/hal/drivers/vulkan/api.h"
+#include "iree_pjrt/common/api_impl.h"
+
+namespace iree::pjrt::vulkan {
+
+class VulkanClientInstance final : public ClientInstance {
+ public:
+ VulkanClientInstance(std::unique_ptr<Platform> platform);
+ ~VulkanClientInstance();
+ iree_status_t CreateDriver(iree_hal_driver_t** out_driver) override;
+ bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override;
+
+ private:
+};
+
+} // namespace iree::pjrt::vulkan
+
+#endif // IREE_PJRT_PLUGIN_PJRT_VULKAN_CLIENT_H_
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/dylib_entry_point.cc b/integrations/pjrt/src/iree_pjrt/vulkan/dylib_entry_point.cc
new file mode 100644
index 0000000..46d1f40
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/dylib_entry_point.cc
@@ -0,0 +1,22 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/common/dylib_platform.h"
+#include "iree_pjrt/vulkan/client.h"
+
+// Provides the shared library exports.
+#include "iree_pjrt/common/dylib_entry_point.cc.inc"
+
+namespace iree::pjrt {
+namespace {
+
+// Declared but not implemented by the include file.
+void InitializeAPI(PJRT_Api* api) {
+ BindApi<DylibPlatform, vulkan::VulkanClientInstance>(api);
+}
+
+} // namespace
+} // namespace iree::pjrt