Improving linking support for ROCM and ukernels. (#19211)
To support externally-defined ukernels on ROCM the ROCMTarget has been
brought in line with LLVMCPU/CUDA by calling `linkBitcodeObjects`. To
make authoring passes that include object references easier,
`#hal.executable.object` now allows any data type to be associated so
long as it is serializable, allowing for external resource attrs and
other custom attributes that may serialize based on other information.
To allow patterns to attach object references, all ops within an
executable variant can now declare a `hal.executable.objects` array that
will be hoisted and merged into the top-level variant objects after our
executable linking pass (before serialization, where they are used).
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index b56b9bb..a49780f 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -527,6 +527,26 @@
}
}
+ // Link bitcode (*.bc) object attrs specified by the input program.
+ // Note that this happens after the command-line files so that the command
+ // line ones override the symbols coming from the embedded files.
+ auto specializationCallback = [&](llvm::Module &userModule) {
+ // TODO: inject __nvvm_reflect-style functions/globals for
+ // bitcode specialization based on the targetMachine and configuration.
+ // These could use any information we have on the IREE side as well as
+ // the TargetMachine.
+ };
+ unsigned linkerFlags =
+ llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc;
+ if (failed(linkBitcodeObjects(variantOp.getLoc(), linker, linkerFlags,
+ *targetMachine, variantOp.getObjectsAttr(),
+ llvmModule->getContext(),
+ specializationCallback))) {
+ return mlir::emitError(variantOp.getLoc())
+ << "failed linking in user objects for target triple '"
+ << targetArch.str() << "'";
+ }
+
// Link module to HIP device library.
if (bitcodeDirectory.empty()) {
return variantOp.emitError()
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
index 95cf53e..dd1002a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp
@@ -231,7 +231,8 @@
}
auto pathAttr = llvm::dyn_cast_if_present<StringAttr>(dict.get("path"));
auto dataAttr =
- llvm::dyn_cast_if_present<DenseIntElementsAttr>(dict.get("data"));
+ llvm::dyn_cast_if_present<IREE::Util::SerializableAttrInterface>(
+ dict.get("data"));
return get(p.getContext(), pathAttr, dataAttr);
}
@@ -312,12 +313,14 @@
std::optional<std::string> ExecutableObjectAttr::loadData() {
if (auto dataAttr = getData()) {
- // This is shady but so is using this feature.
- // TODO(benvanik): figure out a way to limit the attribute to signless int8.
- // We could share the attribute -> byte array code with the VM constant
- // serialization if we wanted.
- auto rawData = dataAttr.getRawData();
- return std::string(rawData.data(), rawData.size());
+ std::string buffer;
+ buffer.resize(dataAttr.getStorageSize());
+ if (failed(dataAttr.serializeToBuffer(
+ UnknownLoc::get(dataAttr.getContext()), llvm::endianness::native,
+ ArrayRef(buffer.data(), buffer.size())))) {
+ return std::nullopt;
+ }
+ return buffer;
} else if (auto pathAttr = getPath()) {
// Search for file and try to load it if found.
auto filePath =
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
index 3c1ebd7..c0dbc59 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td
@@ -542,7 +542,7 @@
let parameters = (ins
AttrParameter<"StringAttr", "">:$path,
- OptionalParameter<"DenseIntElementsAttr", "">:$data
+ OptionalParameter<"IREE::Util::SerializableAttrInterface", "">:$data
);
let hasCustomAssemblyFormat = 1;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index 1d07611..b39f541 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -24,6 +24,7 @@
"DumpExecutableSources.cpp",
"ElideRedundantCommands.cpp",
"FixupLegacySync.cpp",
+ "HoistExecutableObjects.cpp",
"InitializeDevices.cpp",
"InlineMemoizeRegions.cpp",
"LinkExecutables.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 1d57198..7dccc49 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@
"DumpExecutableSources.cpp"
"ElideRedundantCommands.cpp"
"FixupLegacySync.cpp"
+ "HoistExecutableObjects.cpp"
"InitializeDevices.cpp"
"InlineMemoizeRegions.cpp"
"LinkExecutables.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/HoistExecutableObjects.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/HoistExecutableObjects.cpp
new file mode 100644
index 0000000..713c6a7
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/HoistExecutableObjects.cpp
@@ -0,0 +1,69 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_HOISTEXECUTABLEOBJECTSPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-hoist-executable-objects
+//===----------------------------------------------------------------------===//
+
+struct HoistExecutableObjectsPass
+ : public IREE::HAL::impl::HoistExecutableObjectsPassBase<
+ HoistExecutableObjectsPass> {
+ void runOnOperation() override {
+ // Note that some executables may be external and not have any contents.
+ if (getOperation().isExternal()) {
+ return;
+ }
+
+ auto objectsAttrName =
+ StringAttr::get(&getContext(), "hal.executable.objects");
+
+ // Seed with existing variant-level object attrs, if any present.
+ SetVector<Attribute> allObjectAttrs;
+ if (auto existingAttr = getOperation().getObjectsAttr()) {
+ allObjectAttrs.insert(existingAttr.begin(), existingAttr.end());
+ }
+
+ // Move all op-level attributes into a unique set. Note that order can be
+ // important so we use an ordered set.
+ //
+ // We could do this first as a gather step in parallel if this walk gets too
+ // expensive.
+ bool foundAnyAttrs = false;
+ getOperation().getInnerModule().walk([&](Operation *op) {
+ auto objectsAttr = op->getAttrOfType<ArrayAttr>(objectsAttrName);
+ if (objectsAttr) {
+ allObjectAttrs.insert(objectsAttr.begin(), objectsAttr.end());
+ op->removeAttr(objectsAttrName);
+ foundAnyAttrs = true;
+ }
+ });
+
+ // Update the variant if any changes were made.
+ if (foundAnyAttrs) {
+ getOperation().setObjectsAttr(
+ ArrayAttr::get(&getContext(), allObjectAttrs.getArrayRef()));
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 1de6f1c..46048f1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -471,7 +471,7 @@
// TODO(benvanik): move translation down to here.
// After all executables are translated and before resolving export
- // ordinals, we allow the backends to link executables together. For
+ // ordinals we allow the backends to link executables together. For
// example, the LLVM AOT backend may combine all executable targets for the
// same architecture into a single executable and link it as a shared
// library.
@@ -479,6 +479,14 @@
passManager.addPass(IREE::HAL::createLinkExecutablesPass({targetRegistry}));
}
+ // If any executable variants have external objects referenced within them
+ // we hoist them up to the top-level variant. This is done after linking so
+ // that we have the greatest chance of combining executables without different
+ // object attrs preventing the merging.
+ passManager.nest<IREE::HAL::ExecutableOp>()
+ .addNestedPass<IREE::HAL::ExecutableVariantOp>(
+ IREE::HAL::createHoistExecutableObjectsPass());
+
// Resolve export ordinals from nested symbol references prior to
// serialization. As this pass creates lookup ops it should run before
// MaterializeResourceCachesPass.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
index b017b81..6bec7a6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
@@ -415,6 +415,15 @@
];
}
+def HoistExecutableObjectsPass :
+ Pass<"iree-hal-hoist-executable-objects", "IREE::HAL::ExecutableVariantOp"> {
+ let summary = "Hoists local executable object annotations to the parent `hal.executable.variant`.";
+ let description = [{
+ Finds all `hal.executable.objects` attrs on all ops within an executable
+ inner module and moves them to the parent `hal.executable.variant` op.
+ }];
+}
+
def PruneExecutablesPass :
Pass<"iree-hal-prune-executables", "mlir::ModuleOp"> {
let summary = "Prunes executable variants and exports that are not referenced.";
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp
index 1623a27..c460d4c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/SubstituteExecutables.cpp
@@ -173,12 +173,13 @@
// objects in case there were any as this does entire executable replacement -
// there may have been microkernel libraries or something referenced by the
// existing module.
- auto dataObjectAttr = builder.getAttr<IREE::HAL::ExecutableObjectAttr>(
- builder.getStringAttr(llvm::sys::path::filename(filePath)),
- DenseIntElementsAttr::get(
+ auto dataAttr =
+ cast<IREE::Util::SerializableAttrInterface>(DenseIntElementsAttr::get(
VectorType::get({static_cast<int64_t>(fileContents->size())},
builder.getI8Type()),
ArrayRef(fileContents->data(), fileContents->size())));
+ auto dataObjectAttr = builder.getAttr<IREE::HAL::ExecutableObjectAttr>(
+ builder.getStringAttr(llvm::sys::path::filename(filePath)), dataAttr);
variantOp.setObjectsAttr(builder.getArrayAttr({dataObjectAttr}));
// Drop the inner module if present (may already be external).
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
index b949fda..aa2130b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
@@ -24,6 +24,7 @@
"dump_executable_sources.mlir",
"elide_redundant_commands.mlir",
"fixup_legacy_sync.mlir",
+ "hoist_executable_objects.mlir",
"initialize_devices.mlir",
"inline_memoize_regions.mlir",
"materialize_dispatch_instrumentation.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index 1b9d35e..cd20f38 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -22,6 +22,7 @@
"dump_executable_sources.mlir"
"elide_redundant_commands.mlir"
"fixup_legacy_sync.mlir"
+ "hoist_executable_objects.mlir"
"initialize_devices.mlir"
"inline_memoize_regions.mlir"
"materialize_dispatch_instrumentation.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/hoist_executable_objects.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/hoist_executable_objects.mlir
new file mode 100644
index 0000000..17330a6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/hoist_executable_objects.mlir
@@ -0,0 +1,61 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-hoist-executable-objects)))" %s | FileCheck %s
+
+// Tests that attributes on top-level ops and nested ops are all detected,
+// deduplicated, and moved to the variant.
+
+// CHECK: hal.executable public @executable
+hal.executable public @executable {
+ // CHECK: hal.executable.variant public @backend
+ // CHECK-SAME: objects([
+ // CHECK-SAME: #hal.executable.object<{path = "existing_variant.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "extern_fn_common.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "extern_fn_a.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "extern_fn_b.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "nested_common.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "nested_a.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "nested_b.obj"}>
+ hal.executable.variant public @backend target(#hal.executable.target<"backend", "format">) objects([
+ #hal.executable.object<{path = "existing_variant.obj"}>
+ ]) {
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+ ]>)
+ builtin.module {
+ // CHECK: func.func private @extern_fn_a
+ // CHECK-NOT: hal.executable.objects
+ func.func private @extern_fn_a() attributes {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "extern_fn_common.obj"}>,
+ #hal.executable.object<{path = "extern_fn_a.obj"}>
+ ]
+ }
+ // CHECK: func.func private @extern_fn_b
+ // CHECK-NOT: hal.executable.objects
+ func.func private @extern_fn_b() attributes {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "extern_fn_common.obj"}>,
+ #hal.executable.object<{path = "extern_fn_b.obj"}>
+ ]
+ }
+ func.func @entry0() {
+ // CHECK: call @extern_fn_a
+ // CHECK-NOT: hal.executable.objects
+ call @extern_fn_a() {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "nested_common.obj"}>,
+ #hal.executable.object<{path = "nested_a.obj"}>
+ ]
+ } : () -> ()
+ call @extern_fn_b() {
+ // CHECK: call @extern_fn_b
+ // CHECK-NOT: hal.executable.objects
+ hal.executable.objects = [
+ #hal.executable.object<{path = "nested_common.obj"}>,
+ #hal.executable.object<{path = "nested_b.obj"}>
+ ]
+ } : () -> ()
+ return
+ }
+ }
+ }
+}
diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
index 809fe8f..09ee551 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp
@@ -126,8 +126,10 @@
static IREE::HAL::ExecutableObjectAttr
convertExecutableObject(IREE::Input::ExecutableObjectAttr src) {
- return IREE::HAL::ExecutableObjectAttr::get(src.getContext(), src.getPath(),
- src.getData());
+ return IREE::HAL::ExecutableObjectAttr::get(
+ src.getContext(), src.getPath(),
+ dyn_cast_if_present<IREE::Util::SerializableAttrInterface>(
+ src.getData()));
}
static IREE::HAL::ExecutableTargetAttr