Enabling linking in the ROCM/CUDA compiler targets. (#18936)
This does exactly what the LLVMCPU side does - which is bad for compile
time (serializes LLVM codegen) but much better for runtime. Future
improvements should move LLVM codegen to the linking phase so it can
happen in parallel and then perform the linking using LLVM's linker
(each executable turned into a .o and then combined into a .so, or
last-level bitcode if then we just want serialization to be bitcode to
machine code). This is definitely a compile-time regression but we can't
keep pessimizing runtime.
diff --git a/compiler/plugins/target/CUDA/CUDATarget.cpp b/compiler/plugins/target/CUDA/CUDATarget.cpp
index 18896f2..ffc49b5 100644
--- a/compiler/plugins/target/CUDA/CUDATarget.cpp
+++ b/compiler/plugins/target/CUDA/CUDATarget.cpp
@@ -461,6 +461,10 @@
buildLLVMGPUCodegenPassPipeline(passManager, false);
}
+ void buildLinkingPassPipeline(OpPassManager &passManager) override {
+ buildLLVMGPULinkingPassPipeline(passManager, "cuda");
+ }
+
LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
diff --git a/compiler/plugins/target/CUDA/test/smoketest.mlir b/compiler/plugins/target/CUDA/test/smoketest.mlir
index 6e6fa94..0c12f06 100644
--- a/compiler/plugins/target/CUDA/test/smoketest.mlir
+++ b/compiler/plugins/target/CUDA/test/smoketest.mlir
@@ -1,8 +1,6 @@
// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-hal-transformation-pipeline --iree-gpu-test-target=sm_60 --iree-hal-dump-executable-binaries-to=- %s 2>&1 | FileCheck %s --check-prefix=PTX
-#map = affine_map<(d0) -> (d0)>
-
module attributes {
hal.device.targets = [
#hal.device.target<"cuda", [
@@ -11,13 +9,13 @@
]
} {
-stream.executable public @add_dispatch_0 {
- stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
+stream.executable public @add_dispatch_executable {
+ stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
- func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
+ func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
@@ -26,7 +24,7 @@
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
@@ -36,12 +34,42 @@
}
}
+stream.executable public @mul_dispatch_executable {
+ stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ stream.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
+ %c0 = arith.constant 0 : index
+ %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
+ %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
+ %arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
+ %0 = tensor.empty() : tensor<16xf32>
+ %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
+ %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %4 = arith.mulf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
+ } -> tensor<16xf32>
+ flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
+ return
+ }
+ }
+}
+
}
-// PTX: .entry add_dispatch_0
+// PTX: .entry add_dispatch
// PTX: .maxntid 64, 1, 1
// PTX: add.rn.f32
-// CHECK: hal.executable.binary public @cuda_nvptx_fb attributes {
+// PTX: .entry mul_dispatch
+// PTX: .maxntid 64, 1, 1
+// PTX: mul.rn.f32
+
+// CHECK: hal.executable public @smoketest_linked
+// CHECK-NEXT: hal.executable.binary public @cuda_nvptx_fb attributes {
// CHECK-SAME: data = dense
// CHECK-SAME: format = "cuda-nvptx-fb"
diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
index 7db50ac..ee8e256 100644
--- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
+++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
@@ -241,7 +241,7 @@
}
void buildLinkingPassPipeline(OpPassManager &passManager) override {
- buildLLVMCPULinkingPassPipeline(passManager);
+ buildLLVMCPULinkingPassPipeline(passManager, "llvm-cpu");
}
// Gets the LLVM target from |variantOp|.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index c860b63..05ab667 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -269,6 +269,10 @@
buildLLVMGPUCodegenPassPipeline(passManager, true);
}
+ void buildLinkingPassPipeline(OpPassManager &passManager) override {
+ buildLLVMGPULinkingPassPipeline(passManager, "rocm");
+ }
+
// Performs optimizations on |module| (including LTO-style whole-program
// ones). Inspired by code section in
// https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp
diff --git a/compiler/plugins/target/ROCM/test/smoketest.mlir b/compiler/plugins/target/ROCM/test/smoketest.mlir
index 1afe688..a25547b 100644
--- a/compiler/plugins/target/ROCM/test/smoketest.mlir
+++ b/compiler/plugins/target/ROCM/test/smoketest.mlir
@@ -2,19 +2,19 @@
module attributes {
hal.device.targets = [
- #hal.device.target<"hip", [
- #hal.executable.target<"rocm", "rocm-hsaco-fb">
+ #hal.device.target<"amdgpu", [
+ #hal.executable.target<"rocm", "amdgcn-amd-amdhsa">
]> : !hal.device
]
} {
-stream.executable public @add_dispatch_0 {
- stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
+stream.executable public @add_dispatch_executable {
+ stream.executable.export @add_dispatch workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
- func.func @add_dispatch_0(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
+ func.func @add_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
@@ -23,7 +23,7 @@
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%4 = arith.addf %arg3, %arg4 : f32
linalg.yield %4 : f32
} -> tensor<16xf32>
@@ -33,11 +33,37 @@
}
}
+stream.executable public @mul_dispatch_executable {
+ stream.executable.export @mul_dispatch workgroups(%arg0 : index) -> (index, index, index) {
+ %x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
+ stream.return %x, %y, %z : index, index, index
+ }
+ builtin.module {
+ func.func @mul_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding, %arg2_binding: !stream.binding) {
+ %c0 = arith.constant 0 : index
+ %arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
+ %arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
+ %arg2 = stream.binding.subspan %arg2_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
+ %0 = tensor.empty() : tensor<16xf32>
+ %1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
+ %2 = flow.dispatch.tensor.load %arg1, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%1, %2 : tensor<16xf32>, tensor<16xf32>) outs(%0 : tensor<16xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+ %4 = arith.mulf %arg3, %arg4 : f32
+ linalg.yield %4 : f32
+ } -> tensor<16xf32>
+ flow.dispatch.tensor.store %3, %arg2, offsets=[0], sizes=[16], strides=[1] : tensor<16xf32> -> !flow.dispatch.tensor<writeonly:tensor<16xf32>>
+ return
+ }
+ }
+}
+
}
-// CHECK: hal.executable.binary public @rocm_hsaco_fb attributes {
+// CHECK: hal.executable public @smoketest_linked
+// CHECK: hal.executable.binary public @amdgcn_amd_amdhsa attributes {
// CHECK-SAME: data = dense
-// CHECK-SAME: format = "rocm-hsaco-fb"
+// CHECK-SAME: format = "amdgcn-amd-amdhsa"
// -----
@@ -52,13 +78,13 @@
]
} {
-stream.executable public @add_dispatch_0 {
- stream.executable.export @add_dispatch_0 workgroups(%arg0 : index) -> (index, index, index) {
+stream.executable public @executable {
+ stream.executable.export @export workgroups(%arg0 : index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
} loc(#loc)
builtin.module {
- func.func @add_dispatch_0() {
+ func.func @export() {
return
} loc(#loc)
} loc(#loc)
diff --git a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
index 9b4639b..555601c 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
@@ -1,120 +1,120 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
-#include "iree/compiler/dialects/iree_gpu.h"
-#include "mlir-c/IR.h"
-#include "mlir/CAPI/IR.h"
-#include "mlir/CAPI/Support.h"
-
-bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
- return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
- unwrap(attr));
-}
-
-MlirAttribute
-ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory,
- bool *noReduceSharedMemoryBankConflicts,
- MlirAttribute *reorderWorkgroupsStrategy) {
- mlir::MLIRContext *ctx = unwrap(mlirCtx);
- mlir::Builder b(ctx);
- auto prefetchSharedMemoryAttr = mlir::BoolAttr();
- if (prefetchSharedMemory) {
- prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory);
- }
- auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr();
- if (noReduceSharedMemoryBankConflicts) {
- noReduceSharedMemoryBankConflictsAttr =
- b.getBoolAttr(*noReduceSharedMemoryBankConflicts);
- }
- auto strategyAttr =
- mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr();
- if (reorderWorkgroupsStrategy) {
- strategyAttr = llvm::dyn_cast<
- mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
- unwrap(*reorderWorkgroupsStrategy));
- }
- return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get(
- ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr,
- strategyAttr));
-}
-
-MlirAttribute
-ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) {
- auto gpuAttr =
- llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
- unwrap(attr));
- return wrap(gpuAttr.getPrefetchSharedMemory());
-}
-
-MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts(
- MlirAttribute attr) {
- auto gpuAttr =
- llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
- unwrap(attr));
- return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts());
-}
-
-MlirAttribute
-ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) {
- auto gpuAttr =
- llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
- unwrap(attr));
- return wrap(gpuAttr.getReorderWorkgroupsStrategy());
-}
-
-MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
- return wrap(
- mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID());
-}
-
-static_assert(
- static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumNone) ==
- static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
- ReorderWorkgroupsStrategy::None) &&
- static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) ==
- static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
- ReorderWorkgroupsStrategy::Swizzle) &&
- static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
- static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
- ReorderWorkgroupsStrategy::Transpose) &&
- static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
- mlir::iree_compiler::IREE::GPU::
- getMaxEnumValForReorderWorkgroupsStrategy(),
- "ireeGPUReorderWorkgroupsStrategyEnum and "
- "mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions "
- "have diverged");
-
-bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
- return llvm::isa<
- mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
- unwrap(attr));
-}
-
-MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() {
- return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::
- getTypeID());
-}
-
-MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
- MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) {
- mlir::MLIRContext *ctx = unwrap(mlirCtx);
- return wrap(
- mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get(
- ctx, static_cast<
- mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>(
- value)));
-}
-
-ireeGPUReorderWorkgroupsStrategyEnum
-ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
- assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) &&
- "attr is not a GPUReorderWorkgroupsStrategyAttr");
- return static_cast<ireeGPUReorderWorkgroupsStrategyEnum>(
- llvm::cast<mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
- unwrap(attr))
- .getValue());
-}
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/dialects/iree_gpu.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+
+bool ireeAttributeIsAGPUPipelineOptionsAttr(MlirAttribute attr) {
+ return llvm::isa<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
+ unwrap(attr));
+}
+
+MlirAttribute
+ireeGPUPipelineOptionsAttrGet(MlirContext mlirCtx, bool *prefetchSharedMemory,
+ bool *noReduceSharedMemoryBankConflicts,
+ MlirAttribute *reorderWorkgroupsStrategy) {
+ mlir::MLIRContext *ctx = unwrap(mlirCtx);
+ mlir::Builder b(ctx);
+ auto prefetchSharedMemoryAttr = mlir::BoolAttr();
+ if (prefetchSharedMemory) {
+ prefetchSharedMemoryAttr = b.getBoolAttr(*prefetchSharedMemory);
+ }
+ auto noReduceSharedMemoryBankConflictsAttr = mlir::BoolAttr();
+ if (noReduceSharedMemoryBankConflicts) {
+ noReduceSharedMemoryBankConflictsAttr =
+ b.getBoolAttr(*noReduceSharedMemoryBankConflicts);
+ }
+ auto strategyAttr =
+ mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr();
+ if (reorderWorkgroupsStrategy) {
+ strategyAttr = llvm::dyn_cast<
+ mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
+ unwrap(*reorderWorkgroupsStrategy));
+ }
+ return wrap(mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::get(
+ ctx, prefetchSharedMemoryAttr, noReduceSharedMemoryBankConflictsAttr,
+ strategyAttr));
+}
+
+MlirAttribute
+ireeGPUPipelineOptionsAttrGetPrefetchSharedMemory(MlirAttribute attr) {
+ auto gpuAttr =
+ llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
+ unwrap(attr));
+ return wrap(gpuAttr.getPrefetchSharedMemory());
+}
+
+MlirAttribute ireeGPUPipelineOptionsAttrGetNoReduceSharedMemoryBankConflicts(
+ MlirAttribute attr) {
+ auto gpuAttr =
+ llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
+ unwrap(attr));
+ return wrap(gpuAttr.getNoReduceSharedMemoryBankConflicts());
+}
+
+MlirAttribute
+ireeGPUPipelineOptionsAttrGetReorderWorkgroupsStrategy(MlirAttribute attr) {
+ auto gpuAttr =
+ llvm::cast<mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr>(
+ unwrap(attr));
+ return wrap(gpuAttr.getReorderWorkgroupsStrategy());
+}
+
+MlirTypeID ireeGPUPipelineOptionsAttrGetTypeID() {
+ return wrap(
+ mlir::iree_compiler::IREE::GPU::GPUPipelineOptionsAttr::getTypeID());
+}
+
+static_assert(
+ static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumNone) ==
+ static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
+ ReorderWorkgroupsStrategy::None) &&
+ static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumSwizzle) ==
+ static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
+ ReorderWorkgroupsStrategy::Swizzle) &&
+ static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
+ static_cast<uint32_t>(mlir::iree_compiler::IREE::GPU::
+ ReorderWorkgroupsStrategy::Transpose) &&
+ static_cast<uint32_t>(ireeGPUReorderWorkgroupsStrategyEnumTranspose) ==
+ mlir::iree_compiler::IREE::GPU::
+ getMaxEnumValForReorderWorkgroupsStrategy(),
+ "ireeGPUReorderWorkgroupsStrategyEnum and "
+ "mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy definitions "
+ "have diverged");
+
+bool ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(MlirAttribute attr) {
+ return llvm::isa<
+ mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
+ unwrap(attr));
+}
+
+MlirTypeID ireeGPUReorderWorkgroupsStrategyAttrGetTypeID() {
+ return wrap(mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::
+ getTypeID());
+}
+
+MlirAttribute ireeGPUReorderWorkgroupsStrategyAttrGet(
+ MlirContext mlirCtx, ireeGPUReorderWorkgroupsStrategyEnum value) {
+ mlir::MLIRContext *ctx = unwrap(mlirCtx);
+ return wrap(
+ mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr::get(
+ ctx, static_cast<
+ mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategy>(
+ value)));
+}
+
+ireeGPUReorderWorkgroupsStrategyEnum
+ireeGPUReorderWorkgroupsStrategyAttrGetValue(MlirAttribute attr) {
+ assert(ireeAttributeIsAGPUReorderWorkgroupsStrategyAttr(attr) &&
+ "attr is not a GPUReorderWorkgroupsStrategyAttr");
+ return static_cast<ireeGPUReorderWorkgroupsStrategyEnum>(
+ llvm::cast<mlir::iree_compiler::IREE::GPU::ReorderWorkgroupsStrategyAttr>(
+ unwrap(attr))
+ .getValue());
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp
index 8a2e91c..7bfe586 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPULinkExecutables.cpp
@@ -19,7 +19,8 @@
struct LLVMCPULinkExecutablesPass
: public impl::LLVMCPULinkExecutablesPassBase<LLVMCPULinkExecutablesPass> {
- LLVMCPULinkExecutablesPass() = default;
+ using impl::LLVMCPULinkExecutablesPassBase<
+ LLVMCPULinkExecutablesPass>::LLVMCPULinkExecutablesPassBase;
void runOnOperation() override {
auto moduleOp = getOperation();
auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
@@ -30,29 +31,36 @@
return;
// Guess a module name, if needed, to make the output files readable.
- auto moduleName = guessModuleName(moduleOp, "llvm_module");
+ auto moduleName = guessModuleName(moduleOp, "module");
// Create our new "linked" hal.executable.
- std::string linkedExecutableName =
- llvm::formatv("{0}_linked_{1}", moduleName, "llvm_cpu");
+ SymbolTable moduleTable(moduleOp);
+ std::string linkedExecutableName = llvm::formatv("{0}_linked", moduleName);
auto linkedExecutableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
moduleOp.getLoc(), linkedExecutableName);
linkedExecutableOp.setVisibility(
sourceExecutableOps.front().getVisibility());
+ moduleTable.insert(linkedExecutableOp);
auto executableBuilder =
OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock());
// Gather all unique executable targets - we may have multiple.
auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps);
- for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) {
+ for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) {
+ // Only link the target specified. If none specified link all.
+ if (!target.empty() && targetAttr.getBackend().getValue() != target) {
+ continue; // not linking this target
+ }
+
// Add our hal.executable.variant with an empty module.
std::string linkedVariantName =
executableTargetAttrs.size() == 1
- ? attr.getSymbolNameFragment()
- : llvm::formatv("{0}_{1}", attr.getSymbolNameFragment(), index);
+ ? targetAttr.getSymbolNameFragment()
+ : llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(),
+ index);
auto linkedTargetOp =
executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
- moduleOp.getLoc(), linkedVariantName, attr);
+ moduleOp.getLoc(), linkedVariantName, targetAttr);
auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
targetBuilder.create<mlir::ModuleOp>(moduleOp.getLoc());
@@ -71,5 +79,6 @@
}
}
};
+
} // namespace
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
index 71b3aec..9ef65e2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -827,9 +827,12 @@
// NOTE: this runs on the top-level program module containing all
// hal.executable ops.
-void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager) {
+void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager,
+ std::optional<std::string> target) {
// Link together executables. This may produce some IR duplication.
- modulePassManager.addPass(createLLVMCPULinkExecutablesPass());
+ LLVMCPULinkExecutablesPassOptions linkOptions;
+ linkOptions.target = target.value_or("");
+ modulePassManager.addPass(createLLVMCPULinkExecutablesPass(linkOptions));
// Cleanup IR duplication.
modulePassManager.addNestedPass<IREE::HAL::ExecutableOp>(
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
index 42d4035..4696bc8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -12,6 +12,8 @@
#ifndef IREE_COMPILER_CODEGEN_LLVMCPU_PASSES_H_
#define IREE_COMPILER_CODEGEN_LLVMCPU_PASSES_H_
+#include <optional>
+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "mlir/Pass/Pass.h"
@@ -156,7 +158,9 @@
//----------------------------------------------------------------------------//
/// Populates passes needed to link HAL executables across LLVMCPU targets.
-void buildLLVMCPULinkingPassPipeline(OpPassManager &modulePassManager);
+void buildLLVMCPULinkingPassPipeline(
+ OpPassManager &modulePassManager,
+ std::optional<std::string> target = std::nullopt);
//----------------------------------------------------------------------------//
// Register LLVMCPU Passes
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
index c9aec67..12f90be 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -69,6 +69,13 @@
def LLVMCPULinkExecutablesPass :
Pass<"iree-llvmcpu-link-executables", "mlir::ModuleOp"> {
let summary = "Links LLVMCPU HAL executables within the top-level program module.";
+ let options = [
+ Option<
+ "target", "target",
+ "std::string", "",
+ "Target backend name whose executables will be linked by this pass."
+ >,
+ ];
}
def LLVMCPULowerExecutableTargetPass :
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
index 3d8c7a2..19af0c4 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel
@@ -91,11 +91,13 @@
"ConvertToROCDL.cpp",
"ExtractAddressComputationGPUPass.cpp",
"KernelConfig.cpp",
+ "LLVMGPUAssignConstantOrdinals.cpp",
"LLVMGPUCastAddressSpaceFunction.cpp",
"LLVMGPUCastTypeToFitMMA.cpp",
"LLVMGPUConfigureTensorLayouts.cpp",
"LLVMGPUConfigureVectorLayouts.cpp",
"LLVMGPUConvolutionToIGEMM.cpp",
+ "LLVMGPULinkExecutables.cpp",
"LLVMGPULowerExecutableTarget.cpp",
"LLVMGPUPackSharedMemoryAlloc.cpp",
"LLVMGPUPrefetching.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
index 9016d63..aa2c5a5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt
@@ -76,11 +76,13 @@
"ConvertToROCDL.cpp"
"ExtractAddressComputationGPUPass.cpp"
"KernelConfig.cpp"
+ "LLVMGPUAssignConstantOrdinals.cpp"
"LLVMGPUCastAddressSpaceFunction.cpp"
"LLVMGPUCastTypeToFitMMA.cpp"
"LLVMGPUConfigureTensorLayouts.cpp"
"LLVMGPUConfigureVectorLayouts.cpp"
"LLVMGPUConvolutionToIGEMM.cpp"
+ "LLVMGPULinkExecutables.cpp"
"LLVMGPULowerExecutableTarget.cpp"
"LLVMGPUPackSharedMemoryAlloc.cpp"
"LLVMGPUPrefetching.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp
new file mode 100644
index 0000000..c789b92
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUAssignConstantOrdinals.cpp
@@ -0,0 +1,53 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPUASSIGNCONSTANTORDINALSPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+struct LLVMGPUAssignConstantOrdinalsPass
+ : public impl::LLVMGPUAssignConstantOrdinalsPassBase<
+ LLVMGPUAssignConstantOrdinalsPass> {
+ void runOnOperation() override {
+ auto variantOp = getOperation();
+
+ // Get a constant key -> ordinal mapping.
+ auto keyOrdinals = variantOp.gatherConstantOrdinals();
+ if (keyOrdinals.empty())
+ return;
+
+ // Update placeholders to hold the concrete ordinal values.
+ // Eventually MLIR or LLVM will inline them.
+ auto moduleOp = variantOp.getInnerModule();
+ for (auto globalOp :
+ llvm::make_early_inc_range(moduleOp.getOps<LLVM::GlobalOp>())) {
+ auto keyAttr = globalOp->getAttr(
+ IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName());
+ if (!keyAttr)
+ continue;
+ auto it = keyOrdinals.find(keyAttr);
+ if (it == keyOrdinals.end()) {
+ globalOp.emitOpError()
+ << "no constant block providing key '" << keyAttr << "'";
+ return signalPassFailure();
+ }
+ globalOp->removeAttr(
+ IREE::HAL::ExecutableConstantBlockOp::getKeyAttrName());
+ globalOp.setConstantAttr(UnitAttr::get(globalOp.getContext()));
+ globalOp.setValueAttr(IntegerAttr::get(
+ IntegerType::get(globalOp.getContext(), 32), it->second));
+ }
+ }
+};
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp
new file mode 100644
index 0000000..5ffaff9
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPULinkExecutables.cpp
@@ -0,0 +1,123 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/Utils/LinkingUtils.h"
+#include "iree/compiler/Utils/ModuleUtils.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_LLVMGPULINKEXECUTABLESPASS
+#include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"
+
+namespace {
+
+// Returns true if the address space of a global symbol is private to the module
+// scope it originates in. AMD and NVIDIA disagree on the naming but the values
+// match. LLVM is a mess here.
+static bool isSymbolAddressSpacePrivate(uint32_t addressSpace) {
+ return addressSpace == /*local*/ 3 || addressSpace == /*private*/ 5;
+}
+
+static SymbolTable::Visibility
+convertLinkageToVisibility(LLVM::Linkage linkage) {
+ switch (linkage) {
+ case LLVM::Linkage::Private:
+ return SymbolTable::Visibility::Private;
+ case LLVM::Linkage::External:
+ return SymbolTable::Visibility::Public;
+ default:
+ return SymbolTable::Visibility::Public;
+ }
+}
+
+// Returns true if we are allowed to rename |op| as part of merging.
+// The LLVMGPU lowering is super careful about assigning linkage so we err on
+// the side of renaming (as 100% of usage today does not reference external
+// things).
+static bool allowRenamingPrivateLLVMSymbols(Operation *op) {
+ if (auto globalOp = dyn_cast<LLVM::GlobalOp>(op)) {
+ if (isSymbolAddressSpacePrivate(globalOp.getAddrSpace())) {
+ return true;
+ }
+ return convertLinkageToVisibility(globalOp.getLinkage()) ==
+ SymbolTable::Visibility::Private;
+ } else if (auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op)) {
+ return convertLinkageToVisibility(funcOp.getLinkage()) ==
+ SymbolTable::Visibility::Private;
+ }
+ return SymbolTable::getSymbolVisibility(op) ==
+ SymbolTable::Visibility::Private;
+}
+
+struct LLVMGPULinkExecutablesPass
+ : public impl::LLVMGPULinkExecutablesPassBase<LLVMGPULinkExecutablesPass> {
+ using impl::LLVMGPULinkExecutablesPassBase<
+ LLVMGPULinkExecutablesPass>::LLVMGPULinkExecutablesPassBase;
+ void runOnOperation() override {
+ auto moduleOp = getOperation();
+ auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());
+
+ auto sourceExecutableOps =
+ llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
+ if (sourceExecutableOps.size() <= 1)
+ return;
+
+ // Guess a module name, if needed, to make the output files readable.
+ auto moduleName = guessModuleName(moduleOp, "module");
+
+ // Create our new "linked" hal.executable.
+ SymbolTable moduleTable(moduleOp);
+ std::string linkedExecutableName = llvm::formatv("{0}_linked", moduleName);
+ auto linkedExecutableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>(
+ moduleOp.getLoc(), linkedExecutableName);
+ linkedExecutableOp.setVisibility(
+ sourceExecutableOps.front().getVisibility());
+ moduleTable.insert(linkedExecutableOp);
+ auto executableBuilder =
+ OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock());
+
+ // Gather all unique executable targets - we may have multiple.
+ auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps);
+ for (auto [index, targetAttr] : llvm::enumerate(executableTargetAttrs)) {
+ // Only link the target specified. If none specified link all.
+ if (!target.empty() && targetAttr.getBackend().getValue() != target) {
+ continue; // not linking this target
+ }
+
+ // Add our hal.executable.variant with an empty module.
+ std::string linkedVariantName =
+ executableTargetAttrs.size() == 1
+ ? targetAttr.getSymbolNameFragment()
+ : llvm::formatv("{0}_{1}", targetAttr.getSymbolNameFragment(),
+ index);
+ auto linkedTargetOp =
+ executableBuilder.create<IREE::HAL::ExecutableVariantOp>(
+ moduleOp.getLoc(), linkedVariantName, targetAttr);
+ auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
+ targetBuilder.create<mlir::ModuleOp>(moduleOp.getLoc());
+
+ auto mergeModuleFn = [](mlir::ModuleOp sourceInnerModule,
+ mlir::ModuleOp linkedInnerModule,
+ DenseMap<StringRef, Operation *> &symbolMap) {
+ return mergeModuleInto(sourceInnerModule, linkedInnerModule, symbolMap,
+ allowRenamingPrivateLLVMSymbols);
+ };
+
+ // Try linking together all executables in moduleOp.
+ if (failed(linkExecutablesInto(moduleOp, sourceExecutableOps,
+ linkedExecutableOp, linkedTargetOp,
+ mergeModuleFn))) {
+ return signalPassFailure();
+ }
+ }
+ }
+};
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 86f65e1..3c7eaf8 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1220,6 +1220,25 @@
});
}
+// NOTE: this runs on the top-level program module containing all
+// hal.executable ops.
+void buildLLVMGPULinkingPassPipeline(OpPassManager &modulePassManager,
+ std::optional<std::string> target) {
+ // Link together executables. This may produce some IR duplication.
+ LLVMGPULinkExecutablesPassOptions linkOptions;
+ linkOptions.target = target.value_or("");
+ modulePassManager.addPass(createLLVMGPULinkExecutablesPass(linkOptions));
+
+ // Cleanup IR duplication.
+ modulePassManager.addNestedPass<IREE::HAL::ExecutableOp>(
+ mlir::createCanonicalizerPass());
+
+ // Assign final executable constant and import ordinals.
+ auto &variantPassManager = modulePassManager.nest<IREE::HAL::ExecutableOp>()
+ .nest<IREE::HAL::ExecutableVariantOp>();
+ variantPassManager.addPass(createLLVMGPUAssignConstantOrdinalsPass());
+}
+
//===----------------------------------------------------------------------===//
// ROCDL Pass Pipelines
//===----------------------------------------------------------------------===//
@@ -1298,6 +1317,13 @@
[](OpPassManager &passManager) {
buildLLVMGPUCodegenPassPipeline(passManager, true);
});
+
+ static PassPipelineRegistration<> LLVMGPULinkingPipeline(
+ "iree-codegen-llvmgpu-linking-pipeline",
+ "Runs the LLVMGPU HAL executable linking pipeline",
+ [](OpPassManager &modulePassManager) {
+ buildLLVMGPULinkingPassPipeline(modulePassManager);
+ });
}
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
index d932564..e7132c7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h
@@ -12,6 +12,8 @@
#ifndef IREE_COMPILER_CODEGEN_LLVMGPU_PASSES_H_
#define IREE_COMPILER_CODEGEN_LLVMGPU_PASSES_H_
+#include <optional>
+
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
@@ -22,7 +24,7 @@
using IREE::GPU::GPUPipelineOptions;
//----------------------------------------------------------------------------//
-// LLVMGPU backend Pass Pipelines.
+// LLVMGPU Backend Pass Pipelines
//----------------------------------------------------------------------------//
/// Lowering using SIMT CUDA core operations.
@@ -99,8 +101,17 @@
IREE::Codegen::TranslationInfoAttr translationInfo,
ArrayRef<int64_t> workgroupSize);
+//----------------------------------------------------------------------------//
+// LLVMGPU Linking Passes and Pipelines
+//----------------------------------------------------------------------------//
+
+/// Populates passes needed to link HAL executables across LLVMGPU targets.
+void buildLLVMGPULinkingPassPipeline(
+ OpPassManager &modulePassManager,
+ std::optional<std::string> target = std::nullopt);
+
//------------------------------------------------------------------------------
-// Wrappers that not use tablegen options.
+// Wrappers that do not use tablegen options
//------------------------------------------------------------------------------
enum class GPUTensorCoreType {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
index aa6b552..0b8df81 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.td
@@ -66,6 +66,11 @@
];
}
+def LLVMGPUAssignConstantOrdinalsPass :
+ Pass<"iree-llvmgpu-assign-constant-ordinals", "IREE::HAL::ExecutableVariantOp"> {
+ let summary = "Assigns executable constant ordinals across all LLVMGPU variants.";
+}
+
def LLVMGPUCastAddressSpaceFunctionPass :
Pass<"iree-llvmgpu-cast-address-space-function", "ModuleOp"> {
let summary = "Cast address space to generic in CallOp and FuncOp";
@@ -98,6 +103,18 @@
];
}
+def LLVMGPULinkExecutablesPass :
+ Pass<"iree-llvmgpu-link-executables", "mlir::ModuleOp"> {
+ let summary = "Links LLVMGPU HAL executables within the top-level program module.";
+ let options = [
+ Option<
+ "target", "target",
+ "std::string", "",
+ "Target backend name whose executables will be linked by this pass."
+ >,
+ ];
+}
+
def LLVMGPULowerExecutableTargetPass :
InterfacePass<"iree-llvmgpu-lower-executable-target", "mlir::FunctionOpInterface"> {
let summary = "Perform lowering of executable target using one of the IREE::HAL::DispatchLoweringPassPipeline";
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 4097320..1088035 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -21,6 +21,7 @@
"amdgpu_chained_matmul.mlir",
"amdgpu_contraction_distribution.mlir",
"amdgpu_set_anchor_layouts.mlir",
+ "assign_constant_ordinals.mlir",
"conv_pipeline_test_cuda.mlir",
"conv_pipeline_test_rocm.mlir",
"convert_to_nvvm.mlir",
@@ -38,6 +39,7 @@
"gpu_set_num_workgroups.mlir",
"gpu_pipeline_generalize_named_ops.mlir",
"gpu_pipeline_igemm.mlir",
+ "link_executables.mlir",
"nvvm_extract_address_computation.mlir",
"nvvm_pipeline_test.mlir",
"nvvm_mma_sync_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 2a86fd3..795ee25 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -17,6 +17,7 @@
"amdgpu_chained_matmul.mlir"
"amdgpu_contraction_distribution.mlir"
"amdgpu_set_anchor_layouts.mlir"
+ "assign_constant_ordinals.mlir"
"cast_address_space_function.mlir"
"cast_type_to_fit_mma.mlir"
"config_custom_op.mlir"
@@ -39,6 +40,7 @@
"illegal_configuration.mlir"
"legalize.mlir"
"linalg_transform.mlir"
+ "link_executables.mlir"
"llvmgpu_bufferize.mlir"
"llvmgpu_convolution_to_igemm.mlir"
"nvvm_extract_address_computation.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir
new file mode 100644
index 0000000..8a133f9
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/assign_constant_ordinals.mlir
@@ -0,0 +1,22 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-llvmgpu-assign-constant-ordinals)))" --split-input-file %s | FileCheck %s
+
+hal.executable private @executable {
+ hal.executable.variant public @variant target(#hal.executable.target<"rocm", "rocm-hsaco-fb">) {
+ hal.executable.constant.block(%device: !hal.device) -> i32 as "foo" {
+ %c0 = arith.constant 0 : i32
+ hal.return %c0 : i32
+ }
+ hal.executable.constant.block(%device: !hal.device) -> i32 as "bar" {
+ %c1 = arith.constant 1 : i32
+ hal.return %c1 : i32
+ }
+ builtin.module {
+ // CHECK: llvm.mlir.global internal constant @__constant_ordinal_foo_a(0 : i32)
+ llvm.mlir.global internal @__constant_ordinal_foo_a() {addr_space = 4 : i32, hal.executable.constant.key = "foo", sym_visibility = "private"} : i32
+ // CHECK: llvm.mlir.global internal constant @__constant_ordinal_foo_b(0 : i32)
+ llvm.mlir.global internal @__constant_ordinal_foo_b() {addr_space = 4 : i32, hal.executable.constant.key = "foo", sym_visibility = "private"} : i32
+ // CHECK: llvm.mlir.global internal constant @__constant_ordinal_bar(1 : i32)
+ llvm.mlir.global internal @__constant_ordinal_bar() {addr_space = 4 : i32, hal.executable.constant.key = "bar", sym_visibility = "private"} : i32
+ }
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir
new file mode 100644
index 0000000..5655992
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/link_executables.mlir
@@ -0,0 +1,150 @@
+// RUN: iree-opt --iree-llvmgpu-link-executables --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-TARGET
+// RUN: iree-opt --pass-pipeline='builtin.module(iree-llvmgpu-link-executables{target="cuda"},iree-llvmgpu-link-executables{target="rocm"})' --split-input-file %s | FileCheck %s --check-prefix=CHECK-MULTI
+
+#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb">
+
+// Expect a single executable with both exports and correct ordinals.
+// CHECK: hal.executable private @link_executables_linked
+// CHECK: hal.executable.variant public @rocm_hsaco_fb
+// CHECK: hal.executable.export public @export0 ordinal(0)
+// CHECK: hal.executable.export public @export1 ordinal(1)
+
+// Expect one LLVM module with all globals and functions.
+// Note that shared memory is duplicated but dynamic shared memory is not.
+// CHECK: builtin.module
+// CHECK-NEXT: llvm.mlir.global external @__dynamic_shared_memory__
+// CHECK-NEXT: llvm.mlir.global private @__shared_memory__{{.+}} : !llvm.array<2 x array<64 x i32>>
+// CHECK-NEXT: llvm.func @export0
+// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
+// CHECK-NEXT: llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
+// CHECK: llvm.mlir.global private @__shared_memory___0{{.+}} : !llvm.array<2 x array<128 x i32>>
+// CHECK-NEXT: llvm.func @export1
+// CHECK-NEXT: llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
+// CHECK-NEXT: llvm.mlir.addressof @__shared_memory___0 : !llvm.ptr<3>
+
+hal.executable private @executable0 {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) {
+ hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<2 x array<64 x i32>>
+ llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
+ %0 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
+ %1 = llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
+ llvm.return
+ }
+ }
+ }
+}
+hal.executable private @executable1 {
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) {
+ hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.mlir.global external @__dynamic_shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<2 x array<128 x i32>>
+ llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
+ %0 = llvm.mlir.addressof @__dynamic_shared_memory__ : !llvm.ptr<3>
+ %1 = llvm.mlir.addressof @__shared_memory__ : !llvm.ptr<3>
+ llvm.return
+ }
+ }
+ }
+}
+
+// -----
+
+#executable_target_cuda = #hal.executable.target<"cuda", "cuda-nvptx-fb">
+#executable_target_rocm = #hal.executable.target<"rocm", "rocm-hsaco-fb">
+
+// Expect a single executable with multiple variants when not specifying target.
+// CHECK: hal.executable private @link_executables_linked
+// CHECK: hal.executable.variant public @cuda_nvptx_fb_0
+// CHECK: hal.executable.export public @export0 ordinal(0)
+// CHECK: hal.executable.export public @export1 ordinal(1)
+// CHECK: hal.executable.variant public @rocm_hsaco_fb_1
+// CHECK: hal.executable.export public @export0 ordinal(0)
+// CHECK: hal.executable.export public @export1 ordinal(1)
+
+// Expect only one target be linked when specified.
+// CHECK-TARGET: hal.executable private @link_executables_linked
+// CHECK-TARGET: hal.executable.variant public @rocm_hsaco_fb_1
+// CHECK-TARGET: hal.executable.export public @export0 ordinal(0)
+// CHECK-TARGET: hal.executable.export public @export1 ordinal(1)
+// CHECK-TARGET: hal.executable private @executable0
+// CHECK-TARGET: hal.executable.variant public @cuda_nvptx_fb
+// CHECK-TARGET: hal.executable.export public @export0 ordinal(0)
+// CHECK-TARGET: hal.executable private @executable1
+// CHECK-TARGET: hal.executable.variant public @cuda_nvptx_fb
+// CHECK-TARGET: hal.executable.export public @export1 ordinal(0)
+
+// Multiple applications of the pass per target should not conflict.
+// CHECK-MULTI: hal.executable private @link_executables_linked_0
+// CHECK-MULTI: hal.executable.variant public @rocm_hsaco_fb_1
+// CHECK-MULTI: hal.executable.export public @export0 ordinal(0)
+// CHECK-MULTI: hal.executable.export public @export1 ordinal(1)
+// CHECK-MULTI: hal.executable private @link_executables_linked
+// CHECK-MULTI: hal.executable.variant public @cuda_nvptx_fb_0
+// CHECK-MULTI: hal.executable.export public @export0 ordinal(0)
+// CHECK-MULTI: hal.executable.export public @export1 ordinal(1)
+
+hal.executable private @executable0 {
+ hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda) {
+ hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
+ llvm.return
+ }
+ }
+ }
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) {
+ hal.executable.export public @export0 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.func @export0(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
+ llvm.return
+ }
+ }
+ }
+}
+hal.executable private @executable1 {
+ hal.executable.variant public @cuda_nvptx_fb target(#executable_target_cuda) {
+ hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
+ llvm.return
+ }
+ }
+ }
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm) {
+ hal.executable.export public @export1 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>) {
+ ^bb0(%arg0: !hal.device):
+ %c1 = arith.constant 1 : index
+ hal.return %c1, %c1, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.func @export1(%arg0: !llvm.ptr<1> {llvm.align = 16 : i32, llvm.noalias}) {
+ llvm.return
+ }
+ }
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
index ad4e543..003d3f7 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
@@ -67,7 +67,8 @@
// symbol tracked in |targetSymbolMap|.
LogicalResult
mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp,
- DenseMap<StringRef, Operation *> &targetSymbolMap) {
+ DenseMap<StringRef, Operation *> &targetSymbolMap,
+ std::function<bool(mlir::Operation *op)> canRenameSymbol) {
auto &sourceBlock = sourceModuleOp->getRegion(0).front();
auto &targetBlock = targetModuleOp->getRegion(0).front();
SymbolTable sourceSymbolTable(sourceModuleOp);
@@ -90,15 +91,19 @@
// use the existing target op.
continue;
}
- if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) {
+ if (canRenameSymbol(symbolOp)) {
// Since the source symbol is private we can rename it as all uses
// are known to be local to the source module.
renameWithDisambiguatedName(sourceOp, sourceModuleOp, targetSymbolMap,
&sourceSymbolTable);
} else {
// The source symbol has 'nested' or 'public' visibility.
- if (SymbolTable::getSymbolVisibility(targetOp) !=
- SymbolTable::Visibility::Private) {
+ if (canRenameSymbol(targetOp)) {
+ // Keep the original name for our new op, rename the target op.
+ renameWithDisambiguatedName(targetOp, targetModuleOp,
+ targetSymbolMap,
+ /*optionalSymbolTable=*/nullptr);
+ } else {
// Oops! Both symbols are public and we can't safely rename either.
// If you hit this with ops that you think are safe to rename, mark
// them private.
@@ -109,11 +114,6 @@
// where that isn't true.
return sourceOp->emitError()
<< "multiple public symbols with the name: " << symbolName;
- } else {
- // Keep the original name for our new op, rename the target op.
- renameWithDisambiguatedName(targetOp, targetModuleOp,
- targetSymbolMap,
- /*optionalSymbolTable=*/nullptr);
}
}
}
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h
index cf4ca4d..a33f168 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.h
@@ -19,6 +19,11 @@
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h version.
// Only difference is one has the symbol map that we don't even need.
+static inline bool allowRenamingPrivateSymbols(Operation *op) {
+ return SymbolTable::getSymbolVisibility(op) ==
+ SymbolTable::Visibility::Private;
+}
+
// Destructively merges |sourceModuleOp| into |targetModuleOp|.
// |targetSymbolMap| is updated with the new symbols.
//
@@ -29,7 +34,9 @@
// symbol tracked in |targetSymbolMap|.
LogicalResult
mergeModuleInto(Operation *sourceModuleOp, Operation *targetModuleOp,
- DenseMap<StringRef, Operation *> &targetSymbolMap);
+ DenseMap<StringRef, Operation *> &targetSymbolMap,
+ std::function<bool(mlir::Operation *op)> canRenameSymbol =
+ allowRenamingPrivateSymbols);
// Links all executables for the current target found in |moduleOp| into
// |linkedExecutableOp|. Functions will be moved into |linkedModuleOp|.
diff --git a/experimental/web/sample_static/device_multithreaded.c b/experimental/web/sample_static/device_multithreaded.c
index c70924b..8b5ba39 100644
--- a/experimental/web/sample_static/device_multithreaded.c
+++ b/experimental/web/sample_static/device_multithreaded.c
@@ -18,7 +18,7 @@
// Register the statically linked executable library.
const iree_hal_executable_library_query_fn_t libraries[] = {
- mnist_linked_llvm_cpu_library_query,
+ mnist_linked_library_query,
};
iree_hal_executable_loader_t* library_loader = NULL;
iree_status_t status = iree_hal_static_library_loader_create(
diff --git a/experimental/web/sample_static/device_sync.c b/experimental/web/sample_static/device_sync.c
index 3fbe3ee..f072903 100644
--- a/experimental/web/sample_static/device_sync.c
+++ b/experimental/web/sample_static/device_sync.c
@@ -15,7 +15,7 @@
// Register the statically linked executable library.
const iree_hal_executable_library_query_fn_t libraries[] = {
- mnist_linked_llvm_cpu_library_query,
+ mnist_linked_library_query,
};
iree_hal_executable_loader_t* library_loader = NULL;
iree_status_t status = iree_hal_static_library_loader_create(
diff --git a/tests/e2e/stablehlo_models/CMakeLists.txt b/tests/e2e/stablehlo_models/CMakeLists.txt
index f12f2fa..896a852 100644
--- a/tests/e2e/stablehlo_models/CMakeLists.txt
+++ b/tests/e2e/stablehlo_models/CMakeLists.txt
@@ -42,7 +42,7 @@
SRC
"mnist_fake_weights.mlir"
STATIC_LIB_PREFIX
- mnist_fake_weights_linked_llvm_cpu
+ mnist_fake_weights_linked
ENTRY_FUNCTION
"predict"
FUNCTION_INPUTS
@@ -57,7 +57,7 @@
SRC
"mnist_fake_weights.mlir"
STATIC_LIB_PREFIX
- mnist_fake_weights_linked_llvm_cpu
+ mnist_fake_weights_linked
ENTRY_FUNCTION
"predict"
FUNCTION_INPUTS