Adding IREE_EXTERNAL_HAL_DRIVERS cmake var. This allows for out-of-tree drivers by specifying the static library target and registration function, ala LLVM_EXTERNAL_PROJECTS. To test this the ROCM driver has been moved to an external driver and can be enabled with -DIREE_EXTERNAL_HAL_DRIVERS=rocm during cmake config. With some bazel work we'd be able to make the internal drivers work this way too and remove a lot of boilerplate. For now the existing options are preserved.
diff --git a/CMakeLists.txt b/CMakeLists.txt index ff038f8..6999c61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -70,7 +70,6 @@ option(IREE_BUILD_EXPERIMENTAL_REMOTING "Builds experimental remoting support." OFF) option(IREE_BUILD_EXPERIMENTAL_WEB_SAMPLES "Builds experimental web samples." OFF) -option(IREE_HAL_DRIVER_EXPERIMENTAL_ROCM "Builds the experimental ROCm Backend." OFF) #------------------------------------------------------------------------------- # Derived flags based on primary options @@ -85,6 +84,10 @@ # -DIREE_HAL_DRIVER_DEFAULTS=OFF #------------------------------------------------------------------------------- +# External HAL drivers; see runtime/src/iree/hal/drivers/CMakeLists.txt for more +# information on how to declare external drivers. +set(IREE_EXTERNAL_HAL_DRIVERS "" CACHE STRING "") + option(IREE_HAL_DRIVER_DEFAULTS "Sets the default value for all runtime HAL drivers" ON) # CUDA support must be explicitly enabled. set(IREE_HAL_DRIVER_CUDA_DEFAULT OFF) @@ -144,6 +147,9 @@ if(IREE_HAL_DRIVER_VULKAN) message(STATUS " - vulkan") endif() +if(IREE_EXTERNAL_HAL_DRIVERS) + message(STATUS " + external: ${IREE_EXTERNAL_HAL_DRIVERS}") +endif() message(STATUS "IREE HAL local executable library loaders:") if(IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF) @@ -157,6 +163,15 @@ endif() #------------------------------------------------------------------------------- +# Experimental ROCM HAL driver +#------------------------------------------------------------------------------- + +set(IREE_EXTERNAL_ROCM_HAL_DRIVER_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/experimental/rocm") +set(IREE_EXTERNAL_ROCM_HAL_DRIVER_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/experimental/rocm") +set(IREE_EXTERNAL_ROCM_HAL_DRIVER_TARGET "iree::experimental::rocm::registration") +set(IREE_EXTERNAL_ROCM_HAL_DRIVER_REGISTER "iree_hal_rocm_driver_module_register") + +#------------------------------------------------------------------------------- # Compiler Target Options # By default, all compiler targets supported by the current platform which do # not require external deps are enabled by default. This can be changed with: @@ -669,11 +684,6 @@ add_subdirectory(benchmarks) endif() -if(IREE_HAL_DRIVER_EXPERIMENTAL_ROCM) - add_subdirectory(build_tools/third_party/rocm EXCLUDE_FROM_ALL) - add_subdirectory(experimental/rocm) -endif() - if(IREE_BUILD_COMPILER) add_subdirectory(compiler) endif()
diff --git a/build_tools/cmake/iree_hal_cts_test_suite.cmake b/build_tools/cmake/iree_hal_cts_test_suite.cmake index 8c89438..d879b61 100644 --- a/build_tools/cmake/iree_hal_cts_test_suite.cmake +++ b/build_tools/cmake/iree_hal_cts_test_suite.cmake
@@ -50,16 +50,6 @@ ${ARGN} ) - # Omit tests for which the specified driver is not enabled. - string(TOUPPER ${_RULE_DRIVER_NAME} _UPPERCASE_DRIVER) - string(REPLACE "-" "_" _NORMALIZED_DRIVER ${_UPPERCASE_DRIVER}) - if(NOT DEFINED IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) - message(SEND_ERROR "Unknown driver '${_RULE_DRIVER_NAME}'. Check IREE_HAL_DRIVER_* options.") - endif() - if(NOT IREE_HAL_DRIVER_${_NORMALIZED_DRIVER}) - return() - endif() - list(APPEND _RULE_LABELS "driver=${_RULE_DRIVER_NAME}") # Enable executable tests if a compiler target backend capable of producing
diff --git a/build_tools/third_party/rocm/CMakeLists.txt b/build_tools/third_party/rocm/CMakeLists.txt deleted file mode 100644 index d4aba2d..0000000 --- a/build_tools/third_party/rocm/CMakeLists.txt +++ /dev/null
@@ -1,33 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -if(NOT ${IREE_HAL_DRIVER_EXPERIMENTAL_ROCM}) - return() -endif() - -if(NOT ROCM_HEADERS_API_ROOT) - set(ROCM_HEADERS_API_ROOT "/opt/rocm/include") -endif() - -if (EXISTS ${ROCM_HEADERS_API_ROOT}) - message(STATUS "ROCm Header Path: ${ROCM_HEADERS_API_ROOT}") -else() - message(SEND_ERROR "Could not locate ROCm: ${ROCM_HEADERS_API_ROOT}") -endif() - -external_cc_library( - PACKAGE - rocm_headers - NAME - rocm_headers - ROOT - ${ROCM_HEADERS_API_ROOT} - HDRS - "hip/hip_runtime.h" - INCLUDES - ${ROCM_HEADERS_API_ROOT} -) - -unset(ROCM_HEADERS_API_ROOT) \ No newline at end of file
diff --git a/docs/website/docs/deployment-configurations/gpu-cuda-rocm.md b/docs/website/docs/deployment-configurations/gpu-cuda-rocm.md index cc1039d..ac56d45 100644 --- a/docs/website/docs/deployment-configurations/gpu-cuda-rocm.md +++ b/docs/website/docs/deployment-configurations/gpu-cuda-rocm.md
@@ -42,7 +42,7 @@ Please make sure you have followed the [Getting started][get-started] page to build IREE from source, then enable the CUDA HAL driver with the `IREE_HAL_DRIVER_CUDA` option or the experimental ROCm HAL driver with the -`IREE_HAL_DRIVER_EXPERIMENTAL_ROCM` option. +`IREE_EXTERNAL_HAL_DRIVERS=rocm` option. #### Download compiler as Python package
diff --git a/experimental/rocm/CMakeLists.txt b/experimental/rocm/CMakeLists.txt index 142bd1f..e6b042d 100644 --- a/experimental/rocm/CMakeLists.txt +++ b/experimental/rocm/CMakeLists.txt
@@ -4,11 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -if(NOT IREE_HAL_DRIVER_EXPERIMENTAL_ROCM) - return() +iree_add_all_subdirs() + +if(NOT ROCM_HEADERS_API_ROOT) + set(ROCM_HEADERS_API_ROOT "/opt/rocm/include") endif() -iree_add_all_subdirs() +if(EXISTS ${ROCM_HEADERS_API_ROOT}) + message(STATUS "ROCm Header Path: ${ROCM_HEADERS_API_ROOT}") +else() + message(SEND_ERROR "Could not locate ROCm: ${ROCM_HEADERS_API_ROOT}") +endif() iree_cc_library( NAME @@ -44,8 +50,10 @@ INCLUDES "${CMAKE_CURRENT_LIST_DIR}/../.." "${PROJECT_BINARY_DIR}" + "${ROCM_HEADERS_API_ROOT}" DEPS ::dynamic_symbols + rocm_headers iree::base iree::base::core_headers iree::base::internal @@ -60,8 +68,6 @@ PUBLIC ) -add_definitions(-D__HIP_PLATFORM_HCC__) - iree_cc_library( NAME dynamic_symbols @@ -74,6 +80,8 @@ "dynamic_symbols.c" INCLUDES "${CMAKE_CURRENT_LIST_DIR}/../.." + COPTS + "-D__HIP_PLATFORM_HCC__=1" DEPS rocm_headers iree::base::core_headers @@ -95,5 +103,3 @@ LABELS "driver=rocm" ) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/experimental/rocm/cts/CMakeLists.txt b/experimental/rocm/cts/CMakeLists.txt index dd1566c..5710e99 100644 --- a/experimental/rocm/cts/CMakeLists.txt +++ b/experimental/rocm/cts/CMakeLists.txt
@@ -16,7 +16,7 @@ EXECUTABLE_FORMAT "\"PTXE\"" DEPS - experimental::rocm::registration + iree::experimental::rocm::registration EXCLUDED_TESTS # This test depends on iree_hal_rocm_direct_command_buffer_update_buffer # via iree_hal_buffer_view_allocate_buffer, which is not implemented yet.
diff --git a/experimental/rocm/registration/CMakeLists.txt b/experimental/rocm/registration/CMakeLists.txt index c11a941..b0d8409 100644 --- a/experimental/rocm/registration/CMakeLists.txt +++ b/experimental/rocm/registration/CMakeLists.txt
@@ -4,10 +4,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -iree_add_all_subdirs() - -if(IREE_HAL_DRIVER_EXPERIMENTAL_ROCM) - iree_cc_library( NAME registration @@ -20,15 +16,9 @@ iree::base::cc iree::base::core_headers iree::base::tracing + iree::experimental::rocm iree::hal - experimental::rocm - INCLUDES - "${CMAKE_CURRENT_LIST_DIR}/../../.." DEFINES "IREE_HAVE_HAL_EXPERIMENTAL_ROCM_DRIVER_MODULE=1" PUBLIC ) - -endif() - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/runtime/src/iree/hal/drivers/CMakeLists.txt b/runtime/src/iree/hal/drivers/CMakeLists.txt index b1d7275..97b4192 100644 --- a/runtime/src/iree/hal/drivers/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/CMakeLists.txt
@@ -4,25 +4,127 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -set(IREE_HAL_DRIVER_MODULES) +# TODO: we could make all of the internal drivers operate in the same way and +# have init.c generated too. That'd require bazel goo; today by having the hand- +# coded file we can use it in bazel as-is. + +# Link in externally defined drivers. +# This allows users to conditionally enable drivers that live outside of the +# IREE source tree by specifying a few cmake variables. +# +# Drivers are expected to have a CMakeLists.txt that is parsed when enabled. +# If a driver is optional it may set an IREE_EXTERNAL_{name}_HAL_DRIVER_FOUND +# variable to FALSE and be ignored, such as when dependencies are not found or +# other user configuration has disabled them. +# +# Each driver provides a static library target name and a function that is +# called at runtime to register the driver. +# +# Required variables: +# IREE_EXTERNAL_{name}_HAL_DRIVER_TARGET: static library target name. +# IREE_EXTERNAL_{name}_HAL_DRIVER_REGISTER: registration function: +# iree_status_t {name}_register(iree_hal_driver_registry_t* registry) +# Optional variables: +# IREE_EXTERNAL_{name}_HAL_DRIVER_OPTIONAL: true if the driver not being found +# is not an error. +# IREE_EXTERNAL_{name}_HAL_DRIVER_SOURCE_DIR: source directory with a +# CMakeLists.txt included when the driver is enabled. +# IREE_EXTERNAL_{name}_HAL_DRIVER_BINARY_DIR: binary directory for cmake outs. +# IREE_EXTERNAL_{name}_HAL_DRIVER_FOUND: bool to indicate whether the driver +# was found and valid for use. +set(IREE_EXTERNAL_HAL_DRIVERS_USED) +foreach(_DRIVER_NAME ${IREE_EXTERNAL_HAL_DRIVERS}) + string(TOUPPER "IREE_EXTERNAL_${_DRIVER_NAME}_HAL_DRIVER" _DRIVER_VAR) + string(REGEX REPLACE "-" "_" _DRIVER_VAR ${_DRIVER_VAR}) + message(STATUS "Adding IREE external HAL driver: ${_DRIVER_NAME}") + + set(_DRIVER_OPTIONAL ${${_DRIVER_VAR}_OPTIONAL}) + set(_DRIVER_SOURCE_DIR ${${_DRIVER_VAR}_SOURCE_DIR}) + set(_DRIVER_BINARY_DIR ${${_DRIVER_VAR}_BINARY_DIR}) + + # Default to found unless the user overrides it in the driver source. + # This allows the driver to decide to disable itself even if the user + # requested it. + set(${_DRIVER_VAR}_FOUND TRUE CACHE BOOL + "Whether the external driver is valid for use.") + + # Include the driver source CMakeLists.txt if required. + # Users may have already defined the targets and not need this. + if(_DRIVER_SOURCE_DIR) + if(NOT EXISTS "${_DRIVER_SOURCE_DIR}/CMakeLists.txt") + message(FATAL_ERROR "External driver CMakeLists.txt not found at " + "${_DRIVER_SOURCE_DIR}") + endif() + add_subdirectory(${_DRIVER_SOURCE_DIR} ${_DRIVER_BINARY_DIR}) + endif() + + # If found then add to the list of valid drivers. + if(${${_DRIVER_VAR}_FOUND}) + list(APPEND IREE_EXTERNAL_HAL_DRIVERS_USED ${_DRIVER_NAME}) + else() + if(${_DRIVER_OPTIONAL}) + message(STATUS "Optional external driver '${_DRIVER_NAME}' requested " + "but not found; disabling and continuing") + else() + message(FATAL_ERROR "External driver '${_DRIVER_NAME}' not found; may " + "have unavailable dependencies") + endif() + endif() +endforeach() + +# Produce an init_external.c that contains all of the registration calls. +# This will be called by the init.c after internal drivers are registered. +set(_INIT_EXTERNAL_C_SRC) +set(_INIT_EXTERNAL_COPTS) +set(_INIT_EXTERNAL_DEPS) +if(IREE_EXTERNAL_HAL_DRIVERS_USED) + message(STATUS "Registering external HAL drivers: ${IREE_EXTERNAL_HAL_DRIVERS_USED}") + + set(_INIT_EXTERNAL_COPTS "-DIREE_HAVE_HAL_EXTERNAL_DRIVERS=1") + + # Build the list of deps and our source code lines. + set(_INIT_EXTERNAL_DEPS) + set(_INIT_EXTERNAL_REGISTER_DECLS) + set(_INIT_EXTERNAL_REGISTER_CALLS) + foreach(_DRIVER_NAME ${IREE_EXTERNAL_HAL_DRIVERS_USED}) + string(TOUPPER "IREE_EXTERNAL_${_DRIVER_NAME}_HAL_DRIVER" _DRIVER_VAR) + string(REGEX REPLACE "-" "_" _DRIVER_VAR ${_DRIVER_VAR}) + set(_DRIVER_TARGET ${${_DRIVER_VAR}_TARGET}) + set(_DRIVER_REGISTER ${${_DRIVER_VAR}_REGISTER}) + list(APPEND _INIT_EXTERNAL_DEPS ${_DRIVER_TARGET}) + + list(APPEND _INIT_EXTERNAL_REGISTER_DECLS + "extern iree_status_t ${_DRIVER_REGISTER}(iree_hal_driver_registry_t* registry);\n") + list(APPEND _INIT_EXTERNAL_REGISTER_CALLS + "IREE_RETURN_IF_ERROR(${_DRIVER_REGISTER}(registry));\n") + endforeach() + + # Read template file and substitute variables. + set(_INIT_EXTERNAL_C_TPL "${CMAKE_CURRENT_SOURCE_DIR}/init_external.c.in") + set(_INIT_EXTERNAL_C_SRC "${CMAKE_CURRENT_BINARY_DIR}/init_external.c") + file(READ ${_INIT_EXTERNAL_C_TPL} _INIT_EXTERNAL_TEMPLATE) + file( + CONFIGURE OUTPUT ${_INIT_EXTERNAL_C_SRC} + CONTENT "${_INIT_EXTERNAL_TEMPLATE}" + ) +endif() + +set(_INIT_INTERNAL_DEPS) if(IREE_HAL_DRIVER_CUDA) add_subdirectory(cuda) - list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::drivers::cuda::registration) + list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::cuda::registration) endif() if(IREE_HAL_DRIVER_LOCAL_SYNC) add_subdirectory(local_sync) - list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::drivers::local_sync::registration) + list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::local_sync::registration) endif() if(IREE_HAL_DRIVER_LOCAL_TASK) add_subdirectory(local_task) - list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::drivers::local_task::registration) + list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::local_task::registration) endif() if(IREE_HAL_DRIVER_VULKAN) add_subdirectory(vulkan) - list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::drivers::vulkan::registration) -endif() -if(IREE_HAL_DRIVER_EXPERIMENTAL_ROCM) - list(APPEND IREE_HAL_DRIVER_MODULES experimental::rocm::registration) + list(APPEND _INIT_INTERNAL_DEPS iree::hal::drivers::vulkan::registration) endif() iree_cc_library( @@ -32,9 +134,13 @@ "init.h" SRCS "init.c" + ${_INIT_EXTERNAL_C_SRC} + COPTS + ${_INIT_EXTERNAL_COPTS} DEPS iree::base iree::base::tracing - ${IREE_HAL_DRIVER_MODULES} + ${_INIT_INTERNAL_DEPS} + ${_INIT_EXTERNAL_DEPS} PUBLIC )
diff --git a/runtime/src/iree/hal/drivers/init.c b/runtime/src/iree/hal/drivers/init.c index 56c30fe..ef04c8a 100644 --- a/runtime/src/iree/hal/drivers/init.c +++ b/runtime/src/iree/hal/drivers/init.c
@@ -24,9 +24,16 @@ #include "iree/hal/drivers/vulkan/registration/driver_module.h" #endif // IREE_HAVE_HAL_VULKAN_DRIVER_MODULE -#if defined(IREE_HAVE_HAL_EXPERIMENTAL_ROCM_DRIVER_MODULE) -#include "experimental/rocm/registration/driver_module.h" -#endif // IREE_HAVE_HAL_EXPERIMENTAL_ROCM_DRIVER_MODULE +#if defined(IREE_HAVE_HAL_EXTERNAL_DRIVERS) +// Defined in the generated init_external.c file: +extern iree_status_t iree_hal_register_external_drivers( + iree_hal_driver_registry_t* registry); +#else +static iree_status_t iree_hal_register_external_drivers( + iree_hal_driver_registry_t* registry) { + return iree_ok_status(); +} +#endif // IREE_HAVE_HAL_EXTERNAL_DRIVERS IREE_API_EXPORT iree_status_t iree_hal_register_all_available_drivers(iree_hal_driver_registry_t* registry) { @@ -52,10 +59,8 @@ z0, iree_hal_vulkan_driver_module_register(registry)); #endif // IREE_HAVE_HAL_VULKAN_DRIVER_MODULE -#if defined(IREE_HAVE_HAL_EXPERIMENTAL_ROCM_DRIVER_MODULE) IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_rocm_driver_module_register(registry)); -#endif // IREE_HAVE_HAL_EXPERIMENTAL_ROCM_DRIVER_MODULE + z0, iree_hal_register_external_drivers(registry)); IREE_TRACE_ZONE_END(z0); return iree_ok_status();
diff --git a/runtime/src/iree/hal/drivers/init_external.c.in b/runtime/src/iree/hal/drivers/init_external.c.in new file mode 100644 index 0000000..3abe2f2 --- /dev/null +++ b/runtime/src/iree/hal/drivers/init_external.c.in
@@ -0,0 +1,14 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/drivers/init.h" + +${_INIT_EXTERNAL_REGISTER_DECLS} +iree_status_t iree_hal_register_external_drivers( + iree_hal_driver_registry_t* registry) { + ${_INIT_EXTERNAL_REGISTER_CALLS} + return iree_ok_status(); +}