Adding `HoistExecutableObjectsPass`.
This runs nested on variants to find all `hal.executable.objects`
attrs nested in the inner module and move them to the parent
`hal.executable.variant`. This allows codegen/plugin/etc passes running
on executable contents to declare an object they want to include by
making only local changes (such as in a pattern rewriter) and then
letting the pass move them to the variant where they belong.
This only handles arrays of objects as expected after
`MaterializeInterfacesPass` runs - target object dictionaries are not
very easy to merge and we generally want to run after executable
translation/linking anyway where they have already been baked out.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
index 1d07611..b39f541 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/BUILD.bazel
@@ -24,6 +24,7 @@
"DumpExecutableSources.cpp",
"ElideRedundantCommands.cpp",
"FixupLegacySync.cpp",
+ "HoistExecutableObjects.cpp",
"InitializeDevices.cpp",
"InlineMemoizeRegions.cpp",
"LinkExecutables.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
index 1d57198..7dccc49 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@
"DumpExecutableSources.cpp"
"ElideRedundantCommands.cpp"
"FixupLegacySync.cpp"
+ "HoistExecutableObjects.cpp"
"InitializeDevices.cpp"
"InlineMemoizeRegions.cpp"
"LinkExecutables.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/HoistExecutableObjects.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/HoistExecutableObjects.cpp
new file mode 100644
index 0000000..713c6a7
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/HoistExecutableObjects.cpp
@@ -0,0 +1,69 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::iree_compiler::IREE::HAL {
+
+#define GEN_PASS_DEF_HOISTEXECUTABLEOBJECTSPASS
+#include "iree/compiler/Dialect/HAL/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// --iree-hal-hoist-executable-objects
+//===----------------------------------------------------------------------===//
+
+/// Collects `hal.executable.objects` array attrs from ops nested in a variant's
+/// inner module and re-attaches them, deduplicated, to the variant itself.
+struct HoistExecutableObjectsPass
+    : public IREE::HAL::impl::HoistExecutableObjectsPassBase<
+          HoistExecutableObjectsPass> {
+  void runOnOperation() override {
+    auto variantOp = getOperation();
+
+    // External variants may not have any contents to scan.
+    if (variantOp.isExternal())
+      return;
+
+    auto *context = &getContext();
+    auto attrName = StringAttr::get(context, "hal.executable.objects");
+
+    // Insertion-ordered set: ordering can matter and duplicates must collapse.
+    // Objects already present on the variant are placed first.
+    SetVector<Attribute> hoistedObjects;
+    if (auto variantObjects = variantOp.getObjectsAttr())
+      hoistedObjects.insert(variantObjects.begin(), variantObjects.end());
+
+    // Strip the attr from each nested op while collecting its entries. This
+    // could become a parallel gather step if the walk gets too expensive.
+    bool didHoist = false;
+    variantOp.getInnerModule().walk([&](Operation *nestedOp) {
+      auto nestedObjects = nestedOp->getAttrOfType<ArrayAttr>(attrName);
+      if (!nestedObjects)
+        return;
+      hoistedObjects.insert(nestedObjects.begin(), nestedObjects.end());
+      nestedOp->removeAttr(attrName);
+      didHoist = true;
+    });
+
+    // Leave the variant untouched unless something was actually hoisted.
+    if (!didHoist)
+      return;
+    variantOp.setObjectsAttr(
+        ArrayAttr::get(context, hoistedObjects.getArrayRef()));
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 1de6f1c..46048f1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -471,7 +471,7 @@
// TODO(benvanik): move translation down to here.
// After all executables are translated and before resolving export
- // ordinals, we allow the backends to link executables together. For
+ // ordinals we allow the backends to link executables together. For
// example, the LLVM AOT backend may combine all executable targets for the
// same architecture into a single executable and link it as a shared
// library.
@@ -479,6 +479,14 @@
passManager.addPass(IREE::HAL::createLinkExecutablesPass({targetRegistry}));
}
+ // If any executable variants have external objects referenced within them
+ // we hoist them up to the top-level variant. This is done after linking so
+ // that we have the greatest chance of combining executables without different
+ // object attrs preventing the merging.
+ passManager.nest<IREE::HAL::ExecutableOp>()
+ .addNestedPass<IREE::HAL::ExecutableVariantOp>(
+ IREE::HAL::createHoistExecutableObjectsPass());
+
// Resolve export ordinals from nested symbol references prior to
// serialization. As this pass creates lookup ops it should run before
// MaterializeResourceCachesPass.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
index b017b81..6bec7a6 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.td
@@ -415,6 +415,15 @@
];
}
+def HoistExecutableObjectsPass :
+    Pass<"iree-hal-hoist-executable-objects", "IREE::HAL::ExecutableVariantOp"> {
+  let summary = "Hoists local executable object annotations to the parent `hal.executable.variant`.";
+  let description = [{
+    Moves array-valued `hal.executable.objects` attrs from ops in the variant's
+    inner module to the parent `hal.executable.variant` op, dropping duplicates.
+  }];
+}
+
def PruneExecutablesPass :
Pass<"iree-hal-prune-executables", "mlir::ModuleOp"> {
let summary = "Prunes executable variants and exports that are not referenced.";
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
index b949fda..aa2130b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/BUILD.bazel
@@ -24,6 +24,7 @@
"dump_executable_sources.mlir",
"elide_redundant_commands.mlir",
"fixup_legacy_sync.mlir",
+ "hoist_executable_objects.mlir",
"initialize_devices.mlir",
"inline_memoize_regions.mlir",
"materialize_dispatch_instrumentation.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
index 1b9d35e..cd20f38 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt
@@ -22,6 +22,7 @@
"dump_executable_sources.mlir"
"elide_redundant_commands.mlir"
"fixup_legacy_sync.mlir"
+ "hoist_executable_objects.mlir"
"initialize_devices.mlir"
"inline_memoize_regions.mlir"
"materialize_dispatch_instrumentation.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/hoist_executable_objects.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/hoist_executable_objects.mlir
new file mode 100644
index 0000000..17330a6
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/hoist_executable_objects.mlir
@@ -0,0 +1,61 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-hoist-executable-objects)))" %s | FileCheck %s
+
+// Tests that attributes on top-level ops and nested ops are all detected,
+// deduplicated, and moved to the variant.
+
+// CHECK: hal.executable public @executable
+hal.executable public @executable {
+ // CHECK: hal.executable.variant public @backend
+ // CHECK-SAME: objects([
+ // CHECK-SAME: #hal.executable.object<{path = "existing_variant.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "extern_fn_common.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "extern_fn_a.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "extern_fn_b.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "nested_common.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "nested_a.obj"}>,
+ // CHECK-SAME: #hal.executable.object<{path = "nested_b.obj"}>
+ hal.executable.variant public @backend target(#hal.executable.target<"backend", "format">) objects([
+ #hal.executable.object<{path = "existing_variant.obj"}>
+ ]) {
+ hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>
+ ]>)
+ builtin.module {
+ // CHECK: func.func private @extern_fn_a
+ // CHECK-NOT: hal.executable.objects
+ func.func private @extern_fn_a() attributes {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "extern_fn_common.obj"}>,
+ #hal.executable.object<{path = "extern_fn_a.obj"}>
+ ]
+ }
+ // CHECK: func.func private @extern_fn_b
+ // CHECK-NOT: hal.executable.objects
+ func.func private @extern_fn_b() attributes {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "extern_fn_common.obj"}>,
+ #hal.executable.object<{path = "extern_fn_b.obj"}>
+ ]
+ }
+ func.func @entry0() {
+ // CHECK: call @extern_fn_a
+ // CHECK-NOT: hal.executable.objects
+ call @extern_fn_a() {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "nested_common.obj"}>,
+ #hal.executable.object<{path = "nested_a.obj"}>
+ ]
+ } : () -> ()
+ // CHECK: call @extern_fn_b
+ // CHECK-NOT: hal.executable.objects
+ call @extern_fn_b() {
+ hal.executable.objects = [
+ #hal.executable.object<{path = "nested_common.obj"}>,
+ #hal.executable.object<{path = "nested_b.obj"}>
+ ]
+ } : () -> ()
+ return
+ }
+ }
+ }
+}