[pjrt] Add ROCM and Vulkan backends. (#15106)

* The vulkan backend raises an error on some memory transfer thing that
needs some debugging for d/h transfers.
* I don't have a functioning ROCM installation on this machine but it
compiles and loads to the point of trying to find ROCM libraries at
runtime.
* I'm going to raise an RFC to enable ROCM compiler backend support in
default builds as then ROCM would work out of the box with shipped
project compiler binary packages.
diff --git a/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py b/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py
index 765d529..0feb18d 100644
--- a/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py
+++ b/integrations/pjrt/ctstools/openxla/cts/pytest_artifact_saver.py
@@ -14,9 +14,11 @@
 
 
 def pytest_addoption(parser, pluginmanager) -> None:
-    parser.addoption("--openxla-pjrt-artifact-dir",
-                     dest="OPENXLA_PJRT_ARTIFACT_DIR",
-                     help="Saves OpenXLA PJRT compilation artifacts")
+    parser.addoption(
+        "--openxla-pjrt-artifact-dir",
+        dest="OPENXLA_PJRT_ARTIFACT_DIR",
+        help="Saves OpenXLA PJRT compilation artifacts",
+    )
 
 
 def pytest_sessionstart(session: pytest.Session) -> None:
@@ -26,10 +28,11 @@
 
 def pytest_runtest_setup(item: pytest.Item) -> None:
     artifact_dir = item.session.stash[ARTIFACT_DIR_KEY]
-    if artifact_dir is None: return
-    sanitized_name = (item.nodeid.replace(".py::",
-                                          "::").replace("/", "_").replace(
-                                              "::", "__"))
+    if artifact_dir is None:
+        return
+    sanitized_name = (
+        item.nodeid.replace(".py::", "::").replace("/", "_").replace("::", "__")
+    )
     test_dir = artifact_dir / sanitized_name
     shutil.rmtree(test_dir, ignore_errors=True)
     test_dir.mkdir(parents=True, exist_ok=True)
@@ -43,11 +46,12 @@
 def pytest_runtest_makereport(item, call) -> None:
     outcome = yield
     test_dir = item.stash[TEST_DIR_KEY]
-    if test_dir is None: return
+    if test_dir is None:
+        return
 
     result = outcome.get_result()
 
-    if call.when == 'call' and result.failed:
+    if call.when == "call" and result.failed:
         with open(test_dir / "error.txt", "wt") as f:
             f.write(result.longreprtext)
             f.write("\n\nSTDERR:\n-------\n")
@@ -60,7 +64,8 @@
 
 def pytest_runtest_teardown(item: pytest.Item) -> None:
     test_dir = item.stash[TEST_DIR_KEY]
-    if test_dir is None: return
+    if test_dir is None:
+        return
     dir_entries = list(test_dir.iterdir())
     crash_marker = test_dir / "CRASH_MARKER"
     if crash_marker.is_file():
@@ -68,4 +73,3 @@
     if not dir_entries:
         # Remove empty directories on success.
         test_dir.rmdir()
-
diff --git a/integrations/pjrt/ctstools/setup.py b/integrations/pjrt/ctstools/setup.py
index a7cf6e3..f3ab04d 100644
--- a/integrations/pjrt/ctstools/setup.py
+++ b/integrations/pjrt/ctstools/setup.py
@@ -9,6 +9,8 @@
 setup(
     name="openxla_pjrt_ctstools",
     packages=["openxla.cts"],
-    entry_points={"pytest11": ["openxla_pjrt_artifacts = openxla.cts.pytest_artifact_saver"]},
+    entry_points={
+        "pytest11": ["openxla_pjrt_artifacts = openxla.cts.pytest_artifact_saver"]
+    },
     classifiers=["Framework :: Pytest"],
 )
diff --git a/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py b/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py
index 9b9a897..4f10470 100644
--- a/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py
+++ b/integrations/pjrt/python_packages/iree_cuda_plugin/setup.py
@@ -32,9 +32,7 @@
         print("*****************************", file=sys.stderr)
         self.build_configuration(
             os.path.join(THIS_DIR, "build", "cmake"),
-            extra_cmake_args=(
-                "-DIREE_HAL_DRIVER_CUDA=ON",
-            ),
+            extra_cmake_args=("-DIREE_HAL_DRIVER_CUDA=ON",),
         )
         print("Target populated.", file=sys.stderr)
 
@@ -56,10 +54,10 @@
     author="The IREE Team",
     author_email="iree-discuss@googlegroups.com",
     license="Apache-2.0",
-    description="IREE PJRT Plugin for CPUs (generic)",
+    description="IREE PJRT Plugin for CUDA (generic)",
     long_description=README,
     long_description_content_type="text/markdown",
-    url="https://github.com/openxla/openxla-pjrt-plugin",
+    url="https://github.com/openxla/iree",
     classifiers=[
         "Development Status :: 3 - Alpha",
         "License :: OSI Approved :: Apache Software License",
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/jax_plugins/iree_rocm/__init__.py b/integrations/pjrt/python_packages/iree_rocm_plugin/jax_plugins/iree_rocm/__init__.py
new file mode 100644
index 0000000..9ab8058
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/jax_plugins/iree_rocm/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+from pathlib import Path
+import platform
+import sys
+
+import jax._src.xla_bridge as xb
+
+logger = logging.getLogger(__name__)
+
+
+def probe_iree_compiler_dylib() -> str:
+    """Probes an installed iree.compiler for the compiler dylib."""
+    # TODO: Move this out of the ctypes API initialization.
+    from iree.compiler.api import ctypes_dl
+
+    return ctypes_dl._probe_iree_compiler_dylib()
+
+
+def initialize():
+    import iree._pjrt_libs.rocm as lib_package
+
+    path = Path(lib_package.__file__).resolve().parent / "pjrt_plugin_iree_rocm.so"
+    if not path.exists():
+        logger.warning(
+            f"WARNING: Native library {path} does not exist. "
+            f"This most likely indicates an issue with how {__package__} "
+            f"was built or installed."
+        )
+    xb.register_plugin(
+        "iree_rocm",
+        priority=500,
+        library_path=str(path),
+        options={
+            "COMPILER_LIB_PATH": str(probe_iree_compiler_dylib()),
+        },
+    )
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml
new file mode 100644
index 0000000..f6c1689
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "wheel",
+]
+build-backend = "setuptools.build_meta"
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
new file mode 100644
index 0000000..795aa38
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
@@ -0,0 +1,93 @@
+#!/usr/bin/python3
+
+# Copyright 2023 The OpenXLA Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Early splice the _setup_support directory onto the python path.
+import os
+from pathlib import Path
+import sys
+
+THIS_DIR = os.path.realpath(os.path.dirname(__file__))
+sys.path.insert(0, os.path.join(THIS_DIR, "..", "_setup_support"))
+
+import iree_pjrt_setup
+from setuptools import setup, find_namespace_packages
+
+README = r"""
+OpenXLA PJRT Plugin for ROCM
+"""
+
+# Setup and get version information.
+CMAKE_BUILD_DIR_ABS = os.path.join(THIS_DIR, "build", "cmake")
+
+
+class CMakeBuildPy(iree_pjrt_setup.BaseCMakeBuildPy):
+    def build_default_configuration(self):
+        print("*****************************", file=sys.stderr)
+        print("* Building base runtime     *", file=sys.stderr)
+        print("*****************************", file=sys.stderr)
+        self.build_configuration(
+            os.path.join(THIS_DIR, "build", "cmake"),
+            extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=ROCM",),
+        )
+        print("Target populated.", file=sys.stderr)
+
+
+iree_pjrt_setup.populate_built_package(
+    os.path.join(
+        CMAKE_BUILD_DIR_ABS,
+        "python",
+        "iree",
+        "_pjrt_libs",
+        "rocm",
+    )
+)
+
+
+setup(
+    name=f"iree-pjrt-plugin-rocm{iree_pjrt_setup.PACKAGE_SUFFIX}",
+    version=f"{iree_pjrt_setup.PACKAGE_VERSION}",
+    author="The IREE Team",
+    author_email="iree-discuss@googlegroups.com",
+    license="Apache-2.0",
+    description="IREE PJRT Plugin for ROCM (generic)",
+    long_description=README,
+    long_description_content_type="text/markdown",
+    url="https://github.com/openxla/iree",
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3",
+    ],
+    packages=[
+        "jax_plugins.iree_rocm",
+        "iree._pjrt_libs.rocm",
+    ],
+    package_dir={
+        "jax_plugins.iree_rocm": "jax_plugins/iree_rocm",
+        "iree._pjrt_libs.rocm": "build/cmake/python/iree/_pjrt_libs/rocm",
+    },
+    package_data={
+        "iree._pjrt_libs.rocm": ["pjrt_plugin_iree_rocm.*"],
+    },
+    cmdclass={
+        "build": iree_pjrt_setup.PjrtPluginBuild,
+        "build_py": CMakeBuildPy,
+        "bdist_wheel": iree_pjrt_setup.bdist_wheel,
+        "install": iree_pjrt_setup.platlib_install,
+    },
+    zip_safe=False,  # Needs to reference embedded shared libraries.
+    entry_points={
+        # We must advertise which Python modules should be treated as loadable
+        # plugins. This augments the path based scanning that Jax does, which
+        # is not always robust to all packaging circumstances.
+        "jax_plugins": [
+            "iree-rocm = jax_plugins.iree_rocm",
+        ],
+    },
+    install_requires=iree_pjrt_setup.install_requires,
+)
diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/jax_plugins/iree_vulkan/__init__.py b/integrations/pjrt/python_packages/iree_vulkan_plugin/jax_plugins/iree_vulkan/__init__.py
new file mode 100644
index 0000000..a3fa678
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/jax_plugins/iree_vulkan/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+from pathlib import Path
+import platform
+import sys
+
+import jax._src.xla_bridge as xb
+
+logger = logging.getLogger(__name__)
+
+
+def probe_iree_compiler_dylib() -> str:
+    """Probes an installed iree.compiler for the compiler dylib."""
+    # TODO: Move this out of the ctypes API initialization.
+    from iree.compiler.api import ctypes_dl
+
+    return ctypes_dl._probe_iree_compiler_dylib()
+
+
+def initialize():
+    import iree._pjrt_libs.vulkan as lib_package
+
+    path = Path(lib_package.__file__).resolve().parent / "pjrt_plugin_iree_vulkan.so"
+    if not path.exists():
+        logger.warning(
+            f"WARNING: Native library {path} does not exist. "
+            f"This most likely indicates an issue with how {__package__} "
+            f"was built or installed."
+        )
+    xb.register_plugin(
+        "iree_vulkan",
+        priority=500,
+        library_path=str(path),
+        options={
+            "COMPILER_LIB_PATH": str(probe_iree_compiler_dylib()),
+        },
+    )
diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml
new file mode 100644
index 0000000..f6c1689
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "wheel",
+]
+build-backend = "setuptools.build_meta"
diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/setup.py b/integrations/pjrt/python_packages/iree_vulkan_plugin/setup.py
new file mode 100644
index 0000000..2995c02
--- /dev/null
+++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/setup.py
@@ -0,0 +1,93 @@
+#!/usr/bin/python3
+
+# Copyright 2023 The OpenXLA Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Early splice the _setup_support directory onto the python path.
+import os
+from pathlib import Path
+import sys
+
+THIS_DIR = os.path.realpath(os.path.dirname(__file__))
+sys.path.insert(0, os.path.join(THIS_DIR, "..", "_setup_support"))
+
+import iree_pjrt_setup
+from setuptools import setup, find_namespace_packages
+
+README = r"""
+OpenXLA PJRT Plugin for Vulkan
+"""
+
+# Setup and get version information.
+CMAKE_BUILD_DIR_ABS = os.path.join(THIS_DIR, "build", "cmake")
+
+
+class CMakeBuildPy(iree_pjrt_setup.BaseCMakeBuildPy):
+    def build_default_configuration(self):
+        print("*****************************", file=sys.stderr)
+        print("* Building base runtime     *", file=sys.stderr)
+        print("*****************************", file=sys.stderr)
+        self.build_configuration(
+            os.path.join(THIS_DIR, "build", "cmake"),
+            extra_cmake_args=("-DIREE_HAL_DRIVER_VULKAN=ON",),
+        )
+        print("Target populated.", file=sys.stderr)
+
+
+iree_pjrt_setup.populate_built_package(
+    os.path.join(
+        CMAKE_BUILD_DIR_ABS,
+        "python",
+        "iree",
+        "_pjrt_libs",
+        "vulkan",
+    )
+)
+
+
+setup(
+    name=f"iree-pjrt-plugin-vulkan{iree_pjrt_setup.PACKAGE_SUFFIX}",
+    version=f"{iree_pjrt_setup.PACKAGE_VERSION}",
+    author="The IREE Team",
+    author_email="iree-discuss@googlegroups.com",
+    license="Apache-2.0",
+    description="IREE PJRT Plugin for Vulkan (generic)",
+    long_description=README,
+    long_description_content_type="text/markdown",
+    url="https://github.com/openxla/iree",
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3",
+    ],
+    packages=[
+        "jax_plugins.iree_vulkan",
+        "iree._pjrt_libs.vulkan",
+    ],
+    package_dir={
+        "jax_plugins.iree_vulkan": "jax_plugins/iree_vulkan",
+        "iree._pjrt_libs.vulkan": "build/cmake/python/iree/_pjrt_libs/vulkan",
+    },
+    package_data={
+        "iree._pjrt_libs.vulkan": ["pjrt_plugin_iree_vulkan.*"],
+    },
+    cmdclass={
+        "build": iree_pjrt_setup.PjrtPluginBuild,
+        "build_py": CMakeBuildPy,
+        "bdist_wheel": iree_pjrt_setup.bdist_wheel,
+        "install": iree_pjrt_setup.platlib_install,
+    },
+    zip_safe=False,  # Needs to reference embedded shared libraries.
+    entry_points={
+        # We must advertise which Python modules should be treated as loadable
+        # plugins. This augments the path based scanning that Jax does, which
+        # is not always robust to all packaging circumstances.
+        "jax_plugins": [
+            "iree-vulkan = jax_plugins.iree_vulkan",
+        ],
+    },
+    install_requires=iree_pjrt_setup.install_requires,
+)
diff --git a/integrations/pjrt/src/CMakeLists.txt b/integrations/pjrt/src/CMakeLists.txt
index ac40a8d..6310d38 100644
--- a/integrations/pjrt/src/CMakeLists.txt
+++ b/integrations/pjrt/src/CMakeLists.txt
@@ -22,8 +22,14 @@
 add_subdirectory(iree_pjrt/partitioner_api)
 
 if(IREE_HAL_DRIVER_LOCAL_TASK)
-    add_subdirectory(iree_pjrt/cpu)
+  add_subdirectory(iree_pjrt/cpu)
 endif()
 if(IREE_HAL_DRIVER_CUDA)
-    add_subdirectory(iree_pjrt/cuda)
+  add_subdirectory(iree_pjrt/cuda)
+endif()
+if("ROCM" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
+  add_subdirectory(iree_pjrt/rocm)
+endif()
+if(IREE_HAL_DRIVER_VULKAN)
+  add_subdirectory(iree_pjrt/vulkan)
 endif()
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
index 2c65500..9ca76e9 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.cc
@@ -945,6 +945,8 @@
 ClientInstance::ClientInstance(std::unique_ptr<Platform> platform)
     : platform_(std::move(platform)) {
   host_allocator_ = iree_allocator_system();
+  IREE_CHECK_OK(
+      iree_hal_driver_registry_allocate(host_allocator_, &driver_registry_));
   cached_platform_version_ = "git";  // TODO: Plumb through version info.
 }
 
@@ -959,6 +961,7 @@
   // ordering (bad shutdown ordering of the driver is a frequent cause of
   // bugs).
   iree_hal_driver_release(driver_);
+  iree_hal_driver_registry_free(driver_registry_);
 }
 
 void ClientInstance::BindApi(PJRT_Api* api) {
diff --git a/integrations/pjrt/src/iree_pjrt/common/api_impl.h b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
index d56a0bd..dc405f1 100644
--- a/integrations/pjrt/src/iree_pjrt/common/api_impl.h
+++ b/integrations/pjrt/src/iree_pjrt/common/api_impl.h
@@ -402,7 +402,7 @@
 // created against an API.
 //===----------------------------------------------------------------------===//
 
-struct ClientInstance {
+class ClientInstance {
  public:
   ClientInstance(std::unique_ptr<Platform> platform);
   virtual ~ClientInstance();
@@ -461,6 +461,7 @@
 
  protected:
   iree_allocator_t host_allocator_;
+  iree_hal_driver_registry_t* driver_registry_ = nullptr;
   std::string cached_platform_name_;
   std::string cached_platform_version_;
 
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt
new file mode 100644
index 0000000..1bc4128
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+  NAME
+    client
+  HDRS
+    "client.h"
+  SRCS
+    "client.cc"
+  DEPS
+    iree_pjrt::common
+    iree::experimental::rocm
+    iree::experimental::rocm::registration
+)
+
+iree_cc_library(
+  SHARED
+  NAME
+    dylib
+  DEFINES
+    # Causes PJRT dynamic linking entry points to be made visible.
+    PJRT_PLUGIN_BUILDING_LIBRARY
+  SRCS
+    "dylib_entry_point.cc"
+  DEPS
+    ::client
+    iree_pjrt::common
+    iree_pjrt::common::dylib_platform
+)
+
+# Output to the project wide python binary directory tree.
+set(_NATIVE_PYTHON_DIR "${IREE_PJRT_PYTHON_BINARY_DIR}/iree/_pjrt_libs/rocm")
+file(WRITE "${_NATIVE_PYTHON_DIR}/__init__.py" "")
+set_target_properties(iree_pjrt_rocm_dylib
+  PROPERTIES
+    PREFIX ""  # Disable "lib" prefix.
+    LIBRARY_OUTPUT_NAME pjrt_plugin_iree_rocm
+    RUNTIME_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+    LIBRARY_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+)
+
+# TODO: Find a better way to decide whether can link with undefined symbols.
+if(NOT IREE_ENABLE_ASAN)
+  target_link_options(iree_pjrt_rocm_dylib PRIVATE "-Wl,--no-undefined")
+endif()
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc
new file mode 100644
index 0000000..5f290d0
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc
@@ -0,0 +1,37 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/rocm/client.h"
+
+#include "experimental/rocm/registration/driver_module.h"
+
+namespace iree::pjrt::rocm {
+
+ROCMClientInstance::ROCMClientInstance(std::unique_ptr<Platform> platform)
+    : ClientInstance(std::move(platform)) {
+  // Seems that it must match how registered. Action at a distance not
+  // great.
+  // TODO: Get this when constructing the client so it is guaranteed to
+  // match.
+  cached_platform_name_ = "iree_rocm";
+  IREE_CHECK_OK(iree_hal_rocm_driver_module_register(driver_registry_));
+}
+
+ROCMClientInstance::~ROCMClientInstance() {}
+
+iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) {
+  iree_string_view_t driver_name = iree_make_cstring_view("rocm");
+  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
+      driver_registry_, driver_name, host_allocator_, out_driver));
+  logger().debug("ROCM driver created");
+  return iree_ok_status();
+}
+
+bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
+  return compiler_job->SetFlag("--iree-hal-target-backends=rocm");
+}
+
+}  // namespace iree::pjrt::rocm
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.h b/integrations/pjrt/src/iree_pjrt/rocm/client.h
new file mode 100644
index 0000000..e2b78da
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/client.h
@@ -0,0 +1,27 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
+#define IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
+
+#include "experimental/rocm/api.h"
+#include "iree_pjrt/common/api_impl.h"
+
+namespace iree::pjrt::rocm {
+
+class ROCMClientInstance final : public ClientInstance {
+ public:
+  ROCMClientInstance(std::unique_ptr<Platform> platform);
+  ~ROCMClientInstance();
+  iree_status_t CreateDriver(iree_hal_driver_t** out_driver) override;
+  bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override;
+
+ private:
+};
+
+}  // namespace iree::pjrt::rocm
+
+#endif  // IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_
diff --git a/integrations/pjrt/src/iree_pjrt/rocm/dylib_entry_point.cc b/integrations/pjrt/src/iree_pjrt/rocm/dylib_entry_point.cc
new file mode 100644
index 0000000..e157a73
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/rocm/dylib_entry_point.cc
@@ -0,0 +1,22 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/common/dylib_platform.h"
+#include "iree_pjrt/rocm/client.h"
+
+// Provides the shared library exports.
+#include "iree_pjrt/common/dylib_entry_point.cc.inc"
+
+namespace iree::pjrt {
+namespace {
+
+// Declared but not implemented by the include file.
+void InitializeAPI(PJRT_Api* api) {
+  BindApi<DylibPlatform, rocm::ROCMClientInstance>(api);
+}
+
+}  // namespace
+}  // namespace iree::pjrt
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/vulkan/CMakeLists.txt
new file mode 100644
index 0000000..4524b28
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/CMakeLists.txt
@@ -0,0 +1,49 @@
+# Copyright 2023 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_cc_library(
+  NAME
+    client
+  HDRS
+    "client.h"
+  SRCS
+    "client.cc"
+  DEPS
+    iree_pjrt::common
+    iree::hal::drivers::vulkan
+    iree::hal::drivers::vulkan::registration
+)
+
+iree_cc_library(
+  SHARED
+  NAME
+    dylib
+  DEFINES
+    # Causes PJRT dynamic linking entry points to be made visible.
+    PJRT_PLUGIN_BUILDING_LIBRARY
+  SRCS
+    "dylib_entry_point.cc"
+  DEPS
+    ::client
+    iree_pjrt::common
+    iree_pjrt::common::dylib_platform
+)
+
+# Output to the project wide python binary directory tree.
+set(_NATIVE_PYTHON_DIR "${IREE_PJRT_PYTHON_BINARY_DIR}/iree/_pjrt_libs/vulkan")
+file(WRITE "${_NATIVE_PYTHON_DIR}/__init__.py" "")
+set_target_properties(iree_pjrt_vulkan_dylib
+  PROPERTIES
+    PREFIX ""  # Disable "lib" prefix.
+    LIBRARY_OUTPUT_NAME pjrt_plugin_iree_vulkan
+    RUNTIME_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+    LIBRARY_OUTPUT_DIRECTORY "${_NATIVE_PYTHON_DIR}"
+)
+
+# TODO: Find a better way to decide whether can link with undefined symbols.
+if(NOT IREE_ENABLE_ASAN)
+  target_link_options(iree_pjrt_vulkan_dylib PRIVATE "-Wl,--no-undefined")
+endif()
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc
new file mode 100644
index 0000000..853ead8
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc
@@ -0,0 +1,38 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/vulkan/client.h"
+
+#include "iree/hal/drivers/vulkan/registration/driver_module.h"
+
+namespace iree::pjrt::vulkan {
+
+VulkanClientInstance::VulkanClientInstance(std::unique_ptr<Platform> platform)
+    : ClientInstance(std::move(platform)) {
+  // Seems that it must match how registered. Action at a distance not
+  // great.
+  // TODO: Get this when constructing the client so it is guaranteed to
+  // match.
+  cached_platform_name_ = "iree_vulkan";
+  IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(driver_registry_));
+}
+
+VulkanClientInstance::~VulkanClientInstance() {}
+
+iree_status_t VulkanClientInstance::CreateDriver(
+    iree_hal_driver_t** out_driver) {
+  iree_string_view_t driver_name = iree_make_cstring_view("vulkan");
+  IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create(
+      driver_registry_, driver_name, host_allocator_, out_driver));
+  logger().debug("Vulkan driver created");
+  return iree_ok_status();
+}
+
+bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) {
+  return compiler_job->SetFlag("--iree-hal-target-backends=vulkan");
+}
+
+}  // namespace iree::pjrt::vulkan
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.h b/integrations/pjrt/src/iree_pjrt/vulkan/client.h
new file mode 100644
index 0000000..7520355
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.h
@@ -0,0 +1,27 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_PJRT_PLUGIN_PJRT_VULKAN_CLIENT_H_
+#define IREE_PJRT_PLUGIN_PJRT_VULKAN_CLIENT_H_
+
+#include "iree/hal/drivers/vulkan/api.h"
+#include "iree_pjrt/common/api_impl.h"
+
+namespace iree::pjrt::vulkan {
+
+class VulkanClientInstance final : public ClientInstance {
+ public:
+  VulkanClientInstance(std::unique_ptr<Platform> platform);
+  ~VulkanClientInstance();
+  iree_status_t CreateDriver(iree_hal_driver_t** out_driver) override;
+  bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override;
+
+ private:
+};
+
+}  // namespace iree::pjrt::vulkan
+
+#endif  // IREE_PJRT_PLUGIN_PJRT_VULKAN_CLIENT_H_
diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/dylib_entry_point.cc b/integrations/pjrt/src/iree_pjrt/vulkan/dylib_entry_point.cc
new file mode 100644
index 0000000..46d1f40
--- /dev/null
+++ b/integrations/pjrt/src/iree_pjrt/vulkan/dylib_entry_point.cc
@@ -0,0 +1,22 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree_pjrt/common/dylib_platform.h"
+#include "iree_pjrt/vulkan/client.h"
+
+// Provides the shared library exports.
+#include "iree_pjrt/common/dylib_entry_point.cc.inc"
+
+namespace iree::pjrt {
+namespace {
+
+// Declared but not implemented by the include file.
+void InitializeAPI(PJRT_Api* api) {
+  BindApi<DylibPlatform, vulkan::VulkanClientInstance>(api);
+}
+
+}  // namespace
+}  // namespace iree::pjrt