Removing the use of the legacy_sync hack from all but ROCM. (#16493)
The ROCM HAL does not support command buffers at all and thus only
executes with legacy_sync set. No other HAL requires it.
This should be a no-op for Vulkan/WebGPU, which are synchronous but
handle that internally during submission. Only the ROCM HAL relies on
the compiler to insert the waits and to change command buffers to
allow-inline-execution. The CUDA target's opt-in
--iree-hal-cuda-enable-legacy-sync flag (default false) is removed too.
This is in preparation for removing the FixupLegacySyncPass.
diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp
index 1fffeca..b5ef37f 100644
--- a/compiler/plugins/target/CUDA/CUDATarget.cpp
+++ b/compiler/plugins/target/CUDA/CUDATarget.cpp
@@ -58,7 +58,6 @@
bool clUsePtxas = false;
std::string clUsePtxasFrom;
std::string clUsePtxasParams;
- bool enableLegacySync = false;
void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("CUDA HAL Target");
@@ -101,12 +100,6 @@
"iree-hal-cuda-use-ptxas-params", clUsePtxasParams,
llvm::cl::cat(category),
llvm::cl::desc("Passes the given additional parameters to ptxas."));
-
- binder.opt<bool>(
- "iree-hal-cuda-enable-legacy-sync", enableLegacySync,
- llvm::cl::cat(category),
- llvm::cl::desc(
- "Enable legacy sync mode that handles semaphores synchronously."));
}
};
} // namespace
@@ -391,12 +384,6 @@
Builder b(context);
SmallVector<NamedAttribute> configItems;
- // Indicates that the runtime HAL driver operates only in the legacy
- // synchronous mode.
- if (options.enableLegacySync) {
- configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr());
- }
-
configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));
diff --git a/compiler/plugins/target/WebGPU/WebGPUTarget.cpp b/compiler/plugins/target/WebGPU/WebGPUTarget.cpp
index af0293d..6fd63fa 100644
--- a/compiler/plugins/target/WebGPU/WebGPUTarget.cpp
+++ b/compiler/plugins/target/WebGPU/WebGPUTarget.cpp
@@ -75,10 +75,6 @@
Builder b(context);
SmallVector<NamedAttribute> configItems;
- // Indicates that the runtime HAL driver operates only in the legacy
- // synchronous mode.
- configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr());
-
configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
index fb45c20..0dd78f5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/nvvm_extract_address_computation.mlir
@@ -74,7 +74,6 @@
// CHECK: %[[IV_NEXT:.*]] = llvm.mul %[[IV]], %[[C8192]] : i64
#executable_target_cuda_nvptx_fb = #hal.executable.target<"cuda", "cuda-nvptx-fb", {target_arch = "sm_80"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
-#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}>
hal.executable private @matmul_dispatch_0 {
hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
hal.executable.export public @matmul_dispatch_0_matmul_2560x2560x2560 ordinal(0) layout(#pipeline_layout) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir
index d10e408..b16675f 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_transform_cuda.mlir
@@ -402,7 +402,6 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>
-#device_target_cuda = #hal.device.target<"cuda", {executable_targets = [#executable_target_cuda_nvptx_fb], legacy_sync}>
hal.executable @reduction_2d_trailing_elementwise_static_dispatch_0 {
hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda_nvptx_fb) {
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index f47fd80..4b54096 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -128,10 +128,6 @@
Builder b(context);
SmallVector<NamedAttribute> configItems;
- // Indicates that the runtime HAL driver operates only in the legacy
- // synchronous mode.
- configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr());
-
configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));
diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
index 61b2572..9184e99 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
+++ b/compiler/src/iree/compiler/GlobalOptimization/test/materialize_homogeneous_encodings.mlir
@@ -36,7 +36,7 @@
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
-#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb], legacy_sync}>
+#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb]}>
module attributes {hal.device.targets = [#device_target_vulkan]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
@@ -71,7 +71,7 @@
#executable_target_embedded_elf_x86_64_ = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {target_triple = "x86_64-none-elf", cpu_features = "+avx512f"}>
#device_target_llvm_cpu = #hal.device.target<"llvm-cpu", {executable_targets = [#executable_target_embedded_elf_x86_64_]}>
#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan", "vulkan-spirv-fb">
-#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb], legacy_sync}>
+#device_target_vulkan = #hal.device.target<"vulkan", {executable_targets = [#executable_target_vulkan_spirv_fb]}>
module attributes {hal.device.targets = [#device_target_vulkan, #device_target_llvm_cpu]} {
util.func public @lhs_encoding(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = arith.constant 0.000000e+00 : f32
diff --git a/samples/custom_dispatch/cuda/kernels/example.mlir b/samples/custom_dispatch/cuda/kernels/example.mlir
index 22355a8..1438b20 100644
--- a/samples/custom_dispatch/cuda/kernels/example.mlir
+++ b/samples/custom_dispatch/cuda/kernels/example.mlir
@@ -28,9 +28,7 @@
executable_targets = [
#nvptx_sm_52_target,
#nvptx_sm_80_target
- ],
- // HACK: CUDA target currently uses the legacy synchronous execution model.
- legacy_sync
+ ]
}>
module @example attributes {hal.device.targets = [#cuda_target]} {
diff --git a/samples/custom_dispatch/vulkan/shaders/example.mlir b/samples/custom_dispatch/vulkan/shaders/example.mlir
index daf42ff..0aa2b31 100644
--- a/samples/custom_dispatch/vulkan/shaders/example.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example.mlir
@@ -25,9 +25,7 @@
// It's possible, for example, to support targeting multiple devices in the same
// compiled binary.
#vulkan_target = #hal.device.target<"vulkan", {
- executable_targets = [#spirv_target],
- // HACK: Vulkan target currently uses the legacy synchronous execution model.
- legacy_sync
+ executable_targets = [#spirv_target]
}>
module @example attributes {hal.device.targets = [#vulkan_target]} {
diff --git a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
index 5979d76..ea81680 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_inline.mlir
@@ -25,9 +25,7 @@
// It's possible, for example, to support targeting multiple devices in the same
// compiled binary.
#vulkan_target = #hal.device.target<"vulkan", {
- executable_targets = [#spirv_target],
- // HACK: Vulkan target currently uses the legacy synchronous execution model.
- legacy_sync
+ executable_targets = [#spirv_target]
}>
module @example attributes {hal.device.targets = [#vulkan_target]} {
diff --git a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
index 82662ac..d48c2f9 100644
--- a/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
+++ b/samples/custom_dispatch/vulkan/shaders/example_transform.mlir
@@ -32,9 +32,7 @@
// kernel that supports multiple targets by specifying an object per-target, but
// that requires authoring the kernel for multiple targets.
#vulkan_target = #hal.device.target<"vulkan", {
- executable_targets = [#spirv_target],
- // HACK: Vulkan target currently uses the legacy synchronous execution model.
- legacy_sync
+ executable_targets = [#spirv_target]
}>
#map = affine_map<(d0, d1) -> (d0, d1)>
diff --git a/samples/transform_dialect/example_module.mlir b/samples/transform_dialect/example_module.mlir
index 1e4ac4e..7b4743b 100644
--- a/samples/transform_dialect/example_module.mlir
+++ b/samples/transform_dialect/example_module.mlir
@@ -5,7 +5,7 @@
// !B_size = tensor<5x16xf32>
// !C_size = tensor<16x16xf32>
// !O_size = tensor<16xf32>
-//
+//
// module {
// func.func @example_module(%A : !A_size, %B : !B_size, %C : !C_size) -> !O_size {
// %0 = linalg.add ins(%A, %A : !A_size, !A_size)
@@ -16,10 +16,10 @@
// %2 = linalg.reduce
// ins(%1 : !C_size)
// outs(%empty : !O_size)
-// dimensions = [1]
+// dimensions = [1]
// (%in: f32, %out: f32) {
-// %3 = arith.addf %out, %in: f32
-// linalg.yield %3: f32
+// %3 = arith.addf %out, %in: f32
+// linalg.yield %3: f32
// }
// return %2 : !O_size
// }
@@ -27,13 +27,13 @@
#target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>
-module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>], legacy_sync}>]} {
+module attributes {hal.device.targets = [#hal.device.target<"vulkan", {executable_targets = [#hal.executable.target<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, GroupNonUniform], [SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Vulkan, #spirv.resource_limits<max_compute_workgroup_size = [128, 128, 64], subgroup_size = 64, cooperative_matrix_properties_khr = []>>}>]}>]} {
hal.executable private @example_module_dispatch_0 {
hal.executable.variant public @vulkan_spirv_fb target(<"vulkan", "vulkan-spirv-fb", {spirv.target_env = #target_env}>) {
hal.executable.export public @example_module_dispatch_0_generic_80_f32 ordinal(0) layout(
#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
- %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
@@ -59,7 +59,7 @@
hal.executable.export public @example_module_dispatch_1_matmul_16x16x5_f32 ordinal(0) layout(
#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
- %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
@@ -83,7 +83,7 @@
hal.executable.export public @example_module_dispatch_2_generic_16x16_f32 ordinal(0) layout(
#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
^bb0(%arg0: !hal.device):
- %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+ %x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
diff --git a/tests/e2e/stablehlo_ops/BUILD.bazel b/tests/e2e/stablehlo_ops/BUILD.bazel
index 2e9bc1c..9b73263 100644
--- a/tests/e2e/stablehlo_ops/BUILD.bazel
+++ b/tests/e2e/stablehlo_ops/BUILD.bazel
@@ -414,7 +414,6 @@
compiler_flags = [
# TODO(#13984): memset emulation required for graphs.
"--iree-stream-emulate-memset",
- "--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda",
input_type = "stablehlo",
@@ -433,9 +432,6 @@
iree_check_single_backend_test_suite(
name = "check_cuda_stream",
srcs = CUDA_SRCS,
- compiler_flags = [
- "--iree-hal-cuda-enable-legacy-sync=false",
- ],
driver = "cuda",
input_type = "stablehlo",
runner_args = ["--cuda_use_streams=true"],
diff --git a/tests/e2e/stablehlo_ops/CMakeLists.txt b/tests/e2e/stablehlo_ops/CMakeLists.txt
index b322b8e..65cb226 100644
--- a/tests/e2e/stablehlo_ops/CMakeLists.txt
+++ b/tests/e2e/stablehlo_ops/CMakeLists.txt
@@ -373,7 +373,6 @@
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
- "--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
"stablehlo"
RUNNER_ARGS
@@ -455,8 +454,6 @@
"cuda"
DRIVER
"cuda"
- COMPILER_FLAGS
- "--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
"stablehlo"
RUNNER_ARGS
diff --git a/tests/e2e/tosa_ops/BUILD.bazel b/tests/e2e/tosa_ops/BUILD.bazel
index 121155f..3054672 100644
--- a/tests/e2e/tosa_ops/BUILD.bazel
+++ b/tests/e2e/tosa_ops/BUILD.bazel
@@ -301,7 +301,6 @@
compiler_flags = [
# TODO(#13984): memset emulation required for graphs.
"--iree-stream-emulate-memset",
- "--iree-hal-cuda-enable-legacy-sync=false",
],
driver = "cuda",
input_type = "tosa",
@@ -320,9 +319,6 @@
iree_check_single_backend_test_suite(
name = "check_cuda_stream",
srcs = CUDA_SRCS,
- compiler_flags = [
- "--iree-hal-cuda-enable-legacy-sync=false",
- ],
driver = "cuda",
input_type = "tosa",
runner_args = ["--cuda_use_streams=true"],
diff --git a/tests/e2e/tosa_ops/CMakeLists.txt b/tests/e2e/tosa_ops/CMakeLists.txt
index 7615551..f3578a5 100644
--- a/tests/e2e/tosa_ops/CMakeLists.txt
+++ b/tests/e2e/tosa_ops/CMakeLists.txt
@@ -272,7 +272,6 @@
"cuda"
COMPILER_FLAGS
"--iree-stream-emulate-memset"
- "--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
"tosa"
RUNNER_ARGS
@@ -333,8 +332,6 @@
"cuda"
DRIVER
"cuda"
- COMPILER_FLAGS
- "--iree-hal-cuda-enable-legacy-sync=false"
INPUT_TYPE
"tosa"
RUNNER_ARGS