Adds support for HAL executable object linkage.
Objects are passed through flow/stream and handled by the HAL
infrastructure during interface materialization (where we create the
variants based on target configuration). Each backend then gets the
objects specified for it and can use those in backend-dependent ways.
For now the LLVM-CPU backend only supports external function calls
useful for microkernels and such. This allows for a majority of IREE's
features when defining flow/stream executables that call out to externs
(binding/operand packing/optimization, inlining, linking, and
automatic multi-targeting). In the future support can be added for
generating the boilerplate for external device functions called all the
way from (annotated) source inputs.
The GPU backends (CUDA/Vulkan SPIR-V) currently only support entire
top-level function definition (CUDA kernels or SPIR-V compute shaders).
In the future support can be added for linking (PTX linking or
spirv-link) to enable the microkernel-style substitution of ops that
supports fusion and interface optimization.
Objects can be embedded as data or referenced by file path allowing for
both JIT and precompilation approaches (a codegen backend could take
some input IR, produce via an external tool some objects, and then
rewrite the IR to reference those objects). Because paths are hard the
`--iree-hal-executable-object-search-path=` flag can be used (repeatedly)
to add search paths. When coming from frontends it's probably best to
rely on embedding.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
index 70a3b20..ee92720 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp
@@ -1077,6 +1077,16 @@
llvmFuncOp, symbol);
if (calleeOp && !calleeOp.isExternal()) return failure();
+ // If the function is marked as statically linked we don't touch it. That'll
+ // let it fall through to the linker stage where it can be picked up either
+ // from the runtime build (in the case of us producing static libraries) or
+ // the user-specified object files (when producing dynamic libraries).
+ if (calleeOp->hasAttr("hal.import.static")) {
+ return rewriter.notifyMatchFailure(callOp,
+ "external function is marked static "
+ "and does not need an import wrapper");
+ }
+
// TODO(benvanik): way to determine if weak (maybe via linkage?).
bool weak = false;
diff --git a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
index 31f885d..75ae2f3 100644
--- a/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Utils/LinkingUtils.cpp
@@ -196,6 +196,9 @@
OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
auto linkedModuleOp = getInnerModuleFn(linkedTargetOp.getInnerModule());
+ // Aggregation of all external objects specified on variants used.
+ SetVector<Attribute> objectAttrs;
+
// Iterate over all source executable ops, linking as many as we can.
for (auto sourceExecutableOp : sourceExecutableOps) {
// Remap root executable refs.
@@ -211,6 +214,11 @@
// use function multi-versioning to let LLVM insert runtime switches.
if (variantOp.getTarget() != linkedTargetOp.getTarget()) continue;
+ // Add any required object files to the set we will link in the target.
+ if (auto objectsAttr = variantOp.getObjectsAttr()) {
+ objectAttrs.insert(objectsAttr.begin(), objectsAttr.end());
+ }
+
// Remap variant refs.
auto oldVariantRefAttr =
SymbolRefAttr::get(builder.getContext(), sourceExecutableOp.getName(),
@@ -267,6 +275,12 @@
}
}
+ // Attach object files from source variants.
+ if (!objectAttrs.empty()) {
+ linkedTargetOp.setObjectsAttr(
+ builder.getArrayAttr(objectAttrs.takeVector()));
+ }
+
// Update references to @executable::@target::@entry symbols.
replaceEntryPointUses(moduleOp, symbolReplacements);
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
index 8ae38de..0847f76 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.h
@@ -36,6 +36,8 @@
SmallVector<DescriptorSetLayoutBinding> bindings;
};
+using PipelineResourceMap = SmallVector<std::pair<unsigned, unsigned>>;
+
struct PipelineLayout {
// Total number of 32-bit push constants allocated. Not all dispatchable
// functions using this layout will use all constants.
@@ -44,7 +46,7 @@
SmallVector<DescriptorSetLayout> setLayouts;
// Mapping of flattened source resource bindings into the descriptor sets.
// Matches 1:1 with the IREE::Stream::CmdDispatchOp::resources.
- SmallVector<std::pair<unsigned, unsigned>> resourceMap;
+ PipelineResourceMap resourceMap;
void print(llvm::raw_ostream &os) const;
};
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
index c9b65f4..ba4a477 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALBase.td
@@ -635,6 +635,10 @@
// device that can load an executable of this target.
Attribute getMatchExpression();
+ // Returns true if this attribute is a generic version of |specificAttr|.
+ // A more generic version will match with many specific versions.
+ bool isGenericOf(IREE::HAL::ExecutableTargetAttr specificAttr);
+
// Returns the executable target configuration for the given operation.
// This will recursively walk parent operations until one with the
// `hal.executable.target` attribute is found or a `hal.executable.variant`
@@ -646,6 +650,100 @@
}
//===----------------------------------------------------------------------===//
+// #hal.executable.object<*>
+//===----------------------------------------------------------------------===//
+
+def HAL_ExecutableObjectAttr : AttrDef<HAL_Dialect, "ExecutableObject"> {
+ let mnemonic = "executable.object";
+ let summary = [{object file reference}];
+ let description = [{
+ WIP; defines an object file that can be linked into executables.
+ Today this is only supported for external file references with paths the
+ compiler can successfully resolve from its current working directory.
+
+ Future revisions may change this to an interface that allows both internal
+ and external resources to define the object contents. Linking needs to be
+ updated to support various object compositions and certain backends may
+ require additional infrastructure support.
+
+ In the long term the goal is to allow combinations of declared objects and
+ generated code in order to give control of linking behavior to frontends.
+ Instead of needing global command line flags to link in additional blobs
+ the frontend can emit executables with the dependencies already defined per
+ variant without needing to reach into the IREE compiler code.
+
+ Example:
+ ```mlir
+ #hal.executable.object<{path = "some/file.obj"}>
+ #hal.executable.object<{data = dense<[...]> : vector<2048xi8>}>
+ ```
+ }];
+
+ let parameters = (ins
+ OptionalParameter<"StringAttr", "">:$path,
+ OptionalParameter<"DenseIntElementsAttr", "">:$data
+ );
+
+ let hasCustomAssemblyFormat = 1;
+
+ let extraClassDeclaration = [{
+ // Returns the absolute path of the referenced object file if it exists.
+ FailureOr<std::string> getAbsolutePath();
+
+ // Returns the contents of the object file or None if loading failed.
+ // TODO(benvanik): better return type to support mapping/etc? eh
+ Optional<std::string> loadData();
+ }];
+}
+
+def HAL_ExecutableObjectArrayAttr :
+ TypedArrayAttrBase<HAL_ExecutableObjectAttr,
+ "HAL executable object references">;
+
+def HAL_ExecutableObjectsAttr : AttrDef<HAL_Dialect, "ExecutableObjects"> {
+ let mnemonic = "executable.objects";
+ let summary = [{target-specific object file references}];
+ let description = [{
+ A dictionary mapping executable target specifications to a list of objects.
+ This is used to allow layers of the stack that support multi-targeting to
+ specify information used during lowering into each particular target.
+
+ The key attributes are matched against each target variant based on the
+ backend and format as well as any configuration data provided. When
+ comparing the configuration only fields present in both the key and
+ target variant will be checked and must match. This allows specification of
+ generic sets ("all x86_64 targets get these objects") as well as specific
+ ones ("only x86_64 targets with vector_size = 64 get these objects").
+
+ Example:
+ ```mlir
+ #hal.executable.objects<{
+ #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64"> = [
+ #hal.executable.object<{path = "some/file_arm_64.obj"}>
+ ],
+ #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> = [
+ #hal.executable.object<{path = "some/file_x86_64.obj"}>
+ ]
+ }>
+ ```
+ }];
+
+ let parameters = (ins
+ AttrParameter<"ArrayAttr", "">:$targets,
+ AttrParameter<"ArrayAttr", "">:$targetObjects
+ );
+
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+
+ let extraClassDeclaration = [{
+ // Returns the objects specified for the given generic target.
+ Optional<ArrayAttr> getApplicableObjects(
+ IREE::HAL::ExecutableTargetAttr specificTargetAttr);
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// #hal.affinity.queue<*>
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index 3a9012b..1561a83 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -1484,7 +1484,8 @@
let arguments = (ins
OptionalAttr<StrAttr>:$sym_visibility,
- SymbolNameAttr:$sym_name
+ SymbolNameAttr:$sym_name,
+ OptionalAttr<HAL_ExecutableObjectsAttr>:$objects
);
let regions = (region SizedRegion<1>:$body);
@@ -1500,6 +1501,10 @@
let extraClassDeclaration = [{
Block& getBlock() { return getBody().front(); }
+ bool isExternal() {
+ return getBlock().getOps<::mlir::ModuleOp>().empty();
+ }
+
::mlir::ModuleOp getInnerModule() {
auto moduleOps = getBlock().getOps<::mlir::ModuleOp>();
assert(!moduleOps.empty() && "source ops need inner modules");
@@ -1648,7 +1653,8 @@
let arguments = (ins
OptionalAttr<StrAttr>:$sym_visibility,
SymbolNameAttr:$sym_name,
- HAL_ExecutableTargetAttr:$target
+ HAL_ExecutableTargetAttr:$target,
+ OptionalAttr<HAL_ExecutableObjectArrayAttr>:$objects
);
let regions = (region SizedRegion<1>:$body);
@@ -1657,6 +1663,7 @@
custom<SymbolVisibility>($sym_visibility)
$sym_name
`,` `target` `=` $target
+ (`,` `objects` `=` $objects^ )?
attr-dict-with-keyword
regions
}];
@@ -1669,6 +1676,10 @@
let extraClassDeclaration = [{
Block& getBlock() { return getBody().front(); }
+ bool isExternal() {
+ return getBlock().getOps<::mlir::ModuleOp>().empty();
+ }
+
::mlir::ModuleOp getInnerModule() {
auto moduleOps = getBlock().getOps<::mlir::ModuleOp>();
assert(!moduleOps.empty() && "source ops need inner modules");
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
index dcc42b6..7fe4a78 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp
@@ -11,6 +11,10 @@
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Utils/StringUtils.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
@@ -313,6 +317,49 @@
return DeviceMatchExecutableFormatAttr::get(getContext(), getFormat());
}
+// For now this is very simple: if there are any specified fields that are
+// present in this attribute they must match. We could allow target backends
+// to customize this via attribute interfaces in the future if we needed.
+bool ExecutableTargetAttr::isGenericOf(
+ IREE::HAL::ExecutableTargetAttr specificAttr) {
+ if (getBackend() != specificAttr.getBackend() ||
+ getFormat() != specificAttr.getFormat()) {
+ // Totally different backends and binary formats.
+ // There may be cases where we want to share things - such as when targeting
+ // both DLLs and dylibs or something - but today almost all of these are
+ // unique situations.
+ return false;
+ }
+
+ // If the config is empty on either we can quickly match.
+ // This is the most common case for users manually specifying targets.
+ auto genericConfigAttr = getConfiguration();
+ auto specificConfigAttr = specificAttr.getConfiguration();
+ if (!genericConfigAttr || !specificConfigAttr) return true;
+
+ // Ensure all fields in specificConfigAttr either don't exist or match.
+ for (auto expectedAttr : specificConfigAttr.getValue()) {
+ auto actualValue = genericConfigAttr.getNamed(expectedAttr.getName());
+ if (!actualValue) {
+ continue; // ignore, not present in generic
+ }
+ if (actualValue->getValue() != expectedAttr.getValue()) {
+ return false; // mismatch, both have values but they differ
+ }
+ }
+
+ // Ensure all fields in genericConfigAttr exist in the specific one.
+ // If missing then the generic is _more_ specific and can't match.
+ for (auto actualAttr : genericConfigAttr.getValue()) {
+ if (!specificConfigAttr.getNamed(actualAttr.getName())) {
+ return false; // mismatch, present in generic but not specific
+ }
+ }
+
+ // All fields match or are omitted in the generic version.
+ return true;
+}
+
// static
ExecutableTargetAttr ExecutableTargetAttr::lookup(Operation *op) {
auto *context = op->getContext();
@@ -334,6 +381,181 @@
}
//===----------------------------------------------------------------------===//
+// #hal.executable.object
+//===----------------------------------------------------------------------===//
+
+// static
+Attribute ExecutableObjectAttr::parse(AsmParser &p, Type type) {
+ NamedAttrList dict;
+ // `<{` dict `}>`
+ if (failed(p.parseLess()) || failed(p.parseOptionalAttrDict(dict)) ||
+ failed(p.parseGreater())) {
+ return {};
+ }
+ auto pathAttr = dict.get("path").dyn_cast_or_null<StringAttr>();
+ auto dataAttr = dict.get("data").dyn_cast_or_null<DenseIntElementsAttr>();
+ return get(p.getContext(), pathAttr, dataAttr);
+}
+
+void ExecutableObjectAttr::print(AsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << "<{";
+ if (auto pathAttr = getPath()) {
+ os << "path = ";
+ p.printAttribute(getPath());
+ } else if (auto dataAttr = getData()) {
+ os << "data = ";
+ p.printAttribute(getData());
+ }
+ os << "}>";
+}
+
+// Tries to find |filePath| on disk either at its absolute path or joined with
+// any of the specified |searchPaths| in order.
+// Returns the absolute file path when found or a failure if there are no hits.
+static FailureOr<std::string> findFileInPaths(
+ StringRef filePath, ArrayRef<std::string> searchPaths) {
+ // First try to see if it's an absolute path - we don't want to perform any
+ // additional processing on top of that.
+ if (llvm::sys::path::is_absolute(filePath)) {
+ if (llvm::sys::fs::exists(filePath)) return filePath.str();
+ return failure();
+ }
+
+ // Try a relative lookup from the current working directory.
+ if (llvm::sys::fs::exists(filePath)) return filePath.str();
+
+ // Search each path in turn for a file that exists.
+ // It doesn't mean we can open it but we'll get a better error out of the
+ // actual open attempt than what we could produce here.
+ for (auto searchPath : searchPaths) {
+ SmallVector<char> tryPath{searchPath.begin(), searchPath.end()};
+ llvm::sys::path::append(tryPath, filePath);
+ if (llvm::sys::fs::exists(Twine(tryPath))) return Twine(tryPath).str();
+ }
+
+ // Not found in either the user-specified absolute path, cwd, or the search
+ // paths.
+ return failure();
+}
+
+static llvm::cl::list<std::string> clExecutableObjectSearchPath(
+ "iree-hal-executable-object-search-path",
+ llvm::cl::desc("Additional search paths for resolving "
+ "#hal.executable.object file references."),
+ llvm::cl::ZeroOrMore);
+
+FailureOr<std::string> ExecutableObjectAttr::getAbsolutePath() {
+ auto pathAttr = getPath();
+ if (!pathAttr) return failure(); // not a file reference
+ return findFileInPaths(pathAttr.getValue(), clExecutableObjectSearchPath);
+}
+
+Optional<std::string> ExecutableObjectAttr::loadData() {
+ if (auto dataAttr = getData()) {
+ // This is shady but so is using this feature.
+ // TODO(benvanik): figure out a way to limit the attribute to signless int8.
+ // We could share the attribute -> byte array code with the VM constant
+ // serialization if we wanted.
+ auto rawData = dataAttr.getRawData();
+ return std::string(rawData.data(), rawData.size());
+ } else if (auto pathAttr = getPath()) {
+ // Search for file and try to load it if found.
+ auto filePath =
+ findFileInPaths(pathAttr.getValue(), clExecutableObjectSearchPath);
+ if (failed(filePath)) {
+ llvm::errs()
+ << "ERROR: referenced object file not found on any path; use "
+ "--iree-hal-executable-object-search-path= to add search paths: "
+ << *this << "\n";
+ return None;
+ }
+ auto file = llvm::MemoryBuffer::getFile(*filePath);
+ if (!file) return None;
+ return std::string((*file)->getBuffer());
+ }
+ return None;
+}
+
+//===----------------------------------------------------------------------===//
+// #hal.executable.objects
+//===----------------------------------------------------------------------===//
+
+// static
+LogicalResult ExecutableObjectsAttr::verify(
+ function_ref<mlir::InFlightDiagnostic()> emitError, ArrayAttr targetsAttr,
+ ArrayAttr targetObjectsAttr) {
+ if (targetsAttr.size() != targetObjectsAttr.size()) {
+ return emitError() << "targets and objects must be 1:1";
+ }
+ for (auto targetAttr : targetsAttr) {
+ if (!targetAttr.isa<IREE::HAL::ExecutableTargetAttr>()) {
+ return emitError()
+ << "target keys must be #hal.executable.target attributes";
+ }
+ }
+ for (auto objectsAttr : targetObjectsAttr) {
+ auto objectsArrayAttr = objectsAttr.dyn_cast<ArrayAttr>();
+ if (!objectsArrayAttr) {
+ return emitError() << "target objects must be an array of "
+ "#hal.executable.object attributes";
+ }
+ }
+ return success();
+}
+
+// static
+Attribute ExecutableObjectsAttr::parse(AsmParser &p, Type type) {
+ // `<{` target = [objects, ...], ... `}>`
+ SmallVector<Attribute> targetAttrs;
+ SmallVector<Attribute> objectsAttrs;
+ if (failed(p.parseLess())) return {};
+ if (succeeded(p.parseLBrace()) && !succeeded(p.parseOptionalRBrace())) {
+ do {
+ Attribute targetAttr;
+ ArrayAttr objectsAttr;
+ if (failed(p.parseAttribute(targetAttr)) || failed(p.parseEqual()) ||
+ failed(p.parseAttribute(objectsAttr))) {
+ return {};
+ }
+ targetAttrs.push_back(targetAttr);
+ objectsAttrs.push_back(objectsAttr);
+ } while (succeeded(p.parseOptionalComma()));
+ if (failed(p.parseRBrace())) return {};
+ }
+ if (failed(p.parseGreater())) return {};
+ return get(p.getContext(), ArrayAttr::get(p.getContext(), targetAttrs),
+ ArrayAttr::get(p.getContext(), objectsAttrs));
+}
+
+void ExecutableObjectsAttr::print(AsmPrinter &p) const {
+ auto &os = p.getStream();
+ os << "<{";
+ llvm::interleaveComma(llvm::zip(getTargets(), getTargetObjects()), os,
+ [&](std::tuple<Attribute, Attribute> keyValue) {
+ p.printAttribute(std::get<0>(keyValue));
+ os << " = ";
+ p.printAttributeWithoutType(std::get<1>(keyValue));
+ });
+ os << "}>";
+}
+
+Optional<ArrayAttr> ExecutableObjectsAttr::getApplicableObjects(
+ IREE::HAL::ExecutableTargetAttr specificTargetAttr) {
+ SmallVector<Attribute> allObjectAttrs;
+ for (auto [targetAttr, objectsAttr] :
+ llvm::zip(getTargets(), getTargetObjects())) {
+ auto genericTargetAttr = targetAttr.cast<IREE::HAL::ExecutableTargetAttr>();
+ if (genericTargetAttr.isGenericOf(specificTargetAttr)) {
+ auto objectsArrayAttr = objectsAttr.cast<ArrayAttr>();
+ allObjectAttrs.append(objectsArrayAttr.begin(), objectsArrayAttr.end());
+ }
+ }
+ if (allObjectAttrs.empty()) return None;
+ return ArrayAttr::get(specificTargetAttr.getContext(), allObjectAttrs);
+}
+
+//===----------------------------------------------------------------------===//
// #hal.affinity.queue
//===----------------------------------------------------------------------===//
@@ -632,6 +854,9 @@
#include "iree/compiler/Dialect/HAL/IR/HALTypeInterfaces.cpp.inc"
void HALDialect::registerAttributes() {
+ // Register command line flags:
+ (void)clExecutableObjectSearchPath;
+
addAttributes<
#define GET_ATTRDEF_LIST
#include "iree/compiler/Dialect/HAL/IR/HALAttrs.cpp.inc" // IWYU pragma: keep
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
index b96edd0..525c1a2 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir
@@ -31,6 +31,35 @@
// -----
+"executable.objects"() {
+ // CHECK: data = #hal.executable.object<{data = dense<[4, 5, 6, 7]> : vector<4xi8>}>
+ data = #hal.executable.object<{data = dense<[4, 5, 6, 7]> : vector<4xi8>}>,
+ // CHECK: path = #hal.executable.object<{path = "foo"}>
+ path = #hal.executable.object<{path = "foo"}>
+} : () -> ()
+
+// -----
+
+#target_a = #hal.executable.target<"llvm-cpu", "a">
+#target_b = #hal.executable.target<"llvm-cpu", "b">
+#target_c = #hal.executable.target<"llvm-cpu", "c">
+// CHECK-LABEL: "executable.target_objects"
+"executable.target_objects"() {
+ // CHECK-SAME: empty = #hal.executable.objects<{}>
+ empty = #hal.executable.objects<{}>,
+ // CHECK-SAME: targets_a = #hal.executable.objects<{#hal.executable.target<"llvm-cpu", "a"> = [#hal.executable.object<{path = "a.o"}>]}>
+ targets_a = #hal.executable.objects<{
+ #target_a = [#hal.executable.object<{path = "a.o"}>]
+ }>,
+ // CHECK-SAME: targets_bc = #hal.executable.objects<{#hal.executable.target<"llvm-cpu", "b"> = [#hal.executable.object<{path = "b.o"}>], #hal.executable.target<"llvm-cpu", "c"> = [#hal.executable.object<{path = "c.o"}>]}>
+ targets_bc = #hal.executable.objects<{
+ #target_b = [#hal.executable.object<{path = "b.o"}>],
+ #target_c = [#hal.executable.object<{path = "c.o"}>]
+ }>
+} : () -> ()
+
+// -----
+
"affinity.queue"() {
// CHECK: any = #hal.affinity.queue<*>
any = #hal.affinity.queue<*>,
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
index 56c5aae..b73c8c7 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
@@ -202,7 +202,13 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
+ // For now we disable translation if the variant has external object files.
+ // We could instead perform linking with those objects (if they're bitcode
+ // ala libdevice.bc, etc).
+ if (variantOp.isExternal()) return;
+
buildLLVMGPUTransformPassPipeline(passManager, false);
}
@@ -220,44 +226,14 @@
auto libraryName =
variantOp->getParentOfType<IREE::HAL::ExecutableOp>().getName().str();
- ModuleOp innerModuleOp = variantOp.getInnerModule();
+ // TODO(thomasraoux): property handle export ordinals; this code is assuming
+ // that ordinals are dense starting at 0 but that is not required.
- // Remove all the functions that are not part of the CUDA kernel.
- // TODO: Find a better solution to handle this.
- auto illegalFuncOps =
- llvm::to_vector<4>(innerModuleOp.getOps<func::FuncOp>());
- for (auto funcOp : illegalFuncOps) {
- funcOp.erase();
- }
-
- auto llvmModule =
- mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
- if (!llvmModule) {
- return variantOp.emitError() << "failed to translate the MLIR LLVM "
- "dialect to the native llvm::Module";
- }
-
- // Collect all the entry point names.
- llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOps;
- for (auto op : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
- exportOps[op.getSymName()] = op;
- }
- std::vector<std::array<int32_t, 3>> workgroupSizes;
- std::vector<std::string> entryPointNames;
- std::vector<uint32_t> workgroupLocalMemories;
-
- for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
- auto *llvmFunc = llvmModule->getFunction(func.getName());
- if (llvmFunc->isDeclaration()) continue;
- // setName will make sure the function name is unique.
- llvmFunc->setName(sanitizeSymbolName(func.getName()));
- entryPointNames.emplace_back(llvmFunc->getName());
+ // Collect all the entry point parameters.
+ SmallVector<std::array<int32_t, 3>> workgroupSizes;
+ SmallVector<uint32_t> workgroupLocalMemories;
+ for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
std::array<int32_t, 3> workgroupSize;
- uint32_t workgroupLocalMemory = 0;
- auto exportOp = exportOps[func.getName()];
- if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
- workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
- }
if (Optional<ArrayAttr> workgroupSizeAttr = exportOp.getWorkgroupSize()) {
for (auto it : llvm::enumerate(workgroupSizeAttr.value())) {
workgroupSize[it.index()] = it.value().cast<IntegerAttr>().getInt();
@@ -265,91 +241,149 @@
} else {
workgroupSize = {1, 1, 1};
}
- workgroupLocalMemories.push_back(workgroupLocalMemory);
workgroupSizes.push_back(workgroupSize);
- llvm::Metadata *llvmMetadata[] = {
- llvm::ValueAsMetadata::get(llvmFunc),
- llvm::MDString::get(llvmModule->getContext(), "kernel"),
- llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
- llvm::Type::getInt32Ty(llvmModule->getContext()), 1))};
- llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmModule->getContext(), llvmMetadata);
- llvmModule->getOrInsertNamedMetadata("nvvm.annotations")
- ->addOperand(llvmMetadataNode);
- /* Set maximum number of threads in the thread block (CTA). */
- auto generateMetadata = [&](int dim, StringRef name) {
- llvm::Metadata *llvmMetadata[] = {
- llvm::ValueAsMetadata::get(llvmFunc),
- llvm::MDString::get(llvmModule->getContext(), name),
- llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
- llvm::Type::getInt32Ty(llvmModule->getContext()), dim))};
- llvm::MDNode *llvmMetadataNode =
- llvm::MDNode::get(llvmModule->getContext(), llvmMetadata);
- llvmModule->getOrInsertNamedMetadata("nvvm.annotations")
- ->addOperand(llvmMetadataNode);
- };
- generateMetadata(workgroupSize[0], "maxntidx");
- generateMetadata(workgroupSize[1], "maxntidy");
- generateMetadata(workgroupSize[2], "maxntidz");
+ uint32_t workgroupLocalMemory = 0;
+ if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
+ workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
+ }
+ workgroupLocalMemories.push_back(workgroupLocalMemory);
}
- std::unique_ptr<llvm::TargetMachine> targetMachine;
- {
- llvm::Triple triple("nvptx64-nvidia-cuda");
- std::string targetChip = clTargetChip;
- std::string features = "+ptx60";
- std::string error;
- const llvm::Target *target =
- llvm::TargetRegistry::lookupTarget("", triple, error);
- if (target == nullptr) {
- return variantOp.emitError() << "cannot initialize target triple";
+ SmallVector<std::string> entryPointNames;
+ std::string ptxImage;
+ if (variantOp.isExternal()) {
+ if (!variantOp.getObjects().has_value()) {
+ return variantOp.emitOpError()
+ << "no objects defined for external variant";
+ } else if (variantOp.getObjects()->getValue().size() != 1) {
+ // For now we assume there will be exactly one object file.
+ // In the future we will want to perform a linking step here and ideally
+ // support _also_ linking in the codegen results.
+ return variantOp.emitOpError() << "only one object reference is "
+ "supported for external variants";
}
- targetMachine.reset(target->createTargetMachine(triple.str(), targetChip,
- features, {}, {}));
- if (targetMachine == nullptr) {
- return variantOp.emitError() << "cannot initialize target machine";
+
+ // Take exported names verbatim. The user must have already sanitized
+ // these to match the names in their kernels. We don't support any kind of
+ // mangling and if the user was silly enough to rely on nvcc C++ mangling
+ // they'll have to figure that out.
+ for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ entryPointNames.emplace_back(exportOp.getSymName());
}
+
+ auto objectAttr = variantOp.getObjects()
+ ->getValue()
+ .front()
+ .cast<IREE::HAL::ExecutableObjectAttr>();
+ if (auto data = objectAttr.loadData()) {
+ ptxImage = data.value();
+ } else {
+ return variantOp.emitOpError()
+ << "object file could not be loaded: " << objectAttr;
+ }
+ } else {
+ ModuleOp innerModuleOp = variantOp.getInnerModule();
+
+ // Remove all the functions that are not part of the CUDA kernel.
+ // TODO(thomasraoux): remove this? this should not be required.
+ auto illegalFuncOps =
+ llvm::to_vector<4>(innerModuleOp.getOps<func::FuncOp>());
+ for (auto funcOp : illegalFuncOps) {
+ funcOp.erase();
+ }
+
+ auto llvmModule =
+ mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
+ if (!llvmModule) {
+ return variantOp.emitError() << "failed to translate the MLIR LLVM "
+ "dialect to the native llvm::Module";
+ }
+
+ for (auto [exportOp, workgroupSize] :
+ llvm::zip(variantOp.getOps<IREE::HAL::ExecutableExportOp>(),
+ workgroupSizes)) {
+ auto *llvmFunc = llvmModule->getFunction(exportOp.getName());
+ if (llvmFunc->isDeclaration()) continue;
+
+ // setName will make sure the function name is unique.
+ llvmFunc->setName(sanitizeSymbolName(exportOp.getName()));
+ entryPointNames.emplace_back(llvmFunc->getName());
+
+ auto *annotations =
+ llvmModule->getOrInsertNamedMetadata("nvvm.annotations");
+ auto setMetadataValueI32 = [&](StringRef name, int value) {
+ llvm::Metadata *llvmMetadata[] = {
+ llvm::ValueAsMetadata::get(llvmFunc),
+ llvm::MDString::get(llvmModule->getContext(), name),
+ llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(llvmModule->getContext()), value))};
+ annotations->addOperand(
+ llvm::MDNode::get(llvmModule->getContext(), llvmMetadata));
+ };
+ // Mark the entry point as a kernel.
+ setMetadataValueI32("kernel", 1);
+ // Set the maximum number of threads in the thread block (CTA).
+ setMetadataValueI32("maxntidx", workgroupSize[0]);
+ setMetadataValueI32("maxntidy", workgroupSize[1]);
+ setMetadataValueI32("maxntidz", workgroupSize[2]);
+ }
+
+ std::unique_ptr<llvm::TargetMachine> targetMachine;
+ {
+ llvm::Triple triple("nvptx64-nvidia-cuda");
+ std::string targetChip = clTargetChip;
+ std::string features = "+ptx60";
+ std::string error;
+ const llvm::Target *target =
+ llvm::TargetRegistry::lookupTarget("", triple, error);
+ if (target == nullptr) {
+ return variantOp.emitError() << "cannot initialize target triple";
+ }
+ targetMachine.reset(target->createTargetMachine(
+ triple.str(), targetChip, features, {}, {}));
+ if (targetMachine == nullptr) {
+ return variantOp.emitError() << "cannot initialize target machine";
+ }
+ }
+
+ llvmModule->setDataLayout(targetMachine->createDataLayout());
+
+ linkAndOptimize(*llvmModule, *targetMachine);
+
+ // Serialize CUDA kernel into the binary that we will embed in the
+ // final FlatBuffer.
+ ptxImage = translateModuleToISA(*llvmModule, *targetMachine);
}
- llvmModule->setDataLayout(targetMachine->createDataLayout());
-
- linkAndOptimize(*llvmModule, *targetMachine);
+ if (dumpPtx) {
+ llvm::dbgs() << ptxImage;
+ }
+ if (!options.dumpBinariesPath.empty()) {
+ dumpDataToPath(options.dumpBinariesPath, options.dumpBaseName,
+ variantOp.getName(), ".ptx", ptxImage);
+ }
FlatbufferBuilder builder;
iree_CUDAExecutableDef_start_as_root(builder);
- // Serialize cuda kernel into the binary that we will embed in the
- // final FlatBuffer.
- std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine);
- if (dumpPtx) {
- llvm::dbgs() << targetISA;
- }
- if (!options.dumpBinariesPath.empty()) {
- dumpDataToPath(options.dumpBinariesPath, options.dumpBaseName,
- variantOp.getName(), ".ptx", targetISA);
- }
- auto ptxCudeRef = flatbuffers_uint8_vec_create(
- builder, reinterpret_cast<const uint8_t *>(targetISA.c_str()),
- targetISA.size());
-
- auto entryPointsRef = builder.createStringVec(entryPointNames);
- auto workgroupLocalMemoriesRef =
- builder.createInt32Vec(workgroupLocalMemories);
-
+ auto ptxImageRef = flatbuffers_uint8_vec_create(
+ builder, reinterpret_cast<const uint8_t *>(ptxImage.c_str()),
+ ptxImage.size());
iree_CUDABlockSizeDef_vec_start(builder);
- auto blockSizes = workgroupSizes.begin();
- for (auto shader : entryPointNames) {
- iree_CUDABlockSizeDef_vec_push_create(builder, (*blockSizes)[0],
- (*blockSizes)[1], (*blockSizes)[2]);
- ++blockSizes;
+ for (const auto &workgroupSize : workgroupSizes) {
+ iree_CUDABlockSizeDef_vec_push_create(builder, workgroupSize[0],
+ workgroupSize[1], workgroupSize[2]);
}
auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder);
+ auto workgroupLocalMemoriesRef =
+ builder.createInt32Vec(workgroupLocalMemories);
+ auto entryPointsRef = builder.createStringVec(entryPointNames);
iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_CUDAExecutableDef_shared_memory_size_add(builder,
workgroupLocalMemoriesRef);
- iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef);
- iree_CUDAExecutableDef_ptx_image_add(builder, ptxCudeRef);
+ iree_CUDAExecutableDef_ptx_image_add(builder, ptxImageRef);
iree_CUDAExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp
index 853cb4b..4e8e3af 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LLVMCPUTarget.cpp
@@ -149,7 +149,8 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
buildLLVMCPUCodegenPassPipeline(passManager);
}
@@ -467,6 +468,32 @@
bitcodeFile.outputFile->keep();
}
}
+
+ // If custom object files were specified then add those to our artifact set.
+ // These will either be combined into the resulting static library or linked
+ // statically into the resulting dynamic library.
+ if (auto objectAttrs = variantOp.getObjects()) {
+ for (auto [index, attr] : llvm::enumerate(objectAttrs.value())) {
+ auto objectAttr = attr.cast<IREE::HAL::ExecutableObjectAttr>();
+ if (objectAttr.getPath()) {
+ auto absolutePath = objectAttr.getAbsolutePath();
+ if (failed(absolutePath)) {
+ llvm::errs()
+ << "ERROR: referenced object file not found on any path; use "
+ "--iree-hal-executable-object-search-path= to add search "
+ "paths: "
+ << objectAttr << "\n";
+ return failure();
+ }
+ objectFiles.push_back(Artifact::fromFile(*absolutePath));
+ } else if (auto dataAttr = objectAttr.getData()) {
+ objectFiles.push_back(Artifact::createTemporary(
+ objectFiles.front().path + "_object_" + std::to_string(index),
+ ".o"));
+ }
+ }
+ }
+
if (options_.linkStatic) {
return serializeStaticLibraryExecutable(options, variantOp,
executableBuilder, libraryName,
@@ -528,7 +555,7 @@
<< " " << linkArtifacts.libraryFile.path;
linkArtifacts.keepAllFiles();
for (auto &objectFile : objectFiles) {
- objectFile.outputFile->keep();
+ objectFile.keep();
}
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp
index aba0a6f..943760e 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.cpp
@@ -18,6 +18,9 @@
namespace HAL {
// static
+Artifact Artifact::fromFile(StringRef path) { return {path.str(), nullptr}; }
+
+// static
Artifact Artifact::createTemporary(StringRef prefix, StringRef suffix) {
auto sanitizedPrefix = sanitizeFileName(prefix);
auto sanitizedSuffix = sanitizeFileName(suffix);
@@ -54,6 +57,10 @@
return {filePath.str().str(), std::move(file)};
}
+void Artifact::keep() const {
+ if (outputFile) outputFile->keep();
+}
+
Optional<std::vector<int8_t>> Artifact::read() const {
auto fileData = llvm::MemoryBuffer::getFile(path);
if (!fileData) {
@@ -85,10 +92,10 @@
void Artifact::close() { outputFile->os().close(); }
void Artifacts::keepAllFiles() {
- if (libraryFile.outputFile) libraryFile.outputFile->keep();
- if (debugFile.outputFile) debugFile.outputFile->keep();
+ libraryFile.keep();
+ debugFile.keep();
for (auto &file : otherFiles) {
- file.outputFile->keep();
+ file.keep();
}
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h
index 119116d..0036762 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/LinkerTool.h
@@ -24,6 +24,10 @@
namespace HAL {
struct Artifact {
+ // Wraps an existing file on the file system.
+ // The file will not be deleted when the artifact is destroyed.
+ static Artifact fromFile(StringRef path);
+
// Creates an output file path/container pair.
// By default the file will be deleted when the link completes; callers must
// use llvm::ToolOutputFile::keep() to prevent deletion upon success (or if
@@ -41,6 +45,9 @@
std::string path;
std::unique_ptr<llvm::ToolOutputFile> outputFile;
+ // Preserves the file contents on disk after the artifact has been destroyed.
+ void keep() const;
+
// Reads the artifact file contents as bytes.
Optional<std::vector<int8_t>> read() const;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
index c82b9b4..ce86f66 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/LLVM/internal/EmbeddedLinkerTool.cpp
@@ -181,9 +181,11 @@
// command themselves.
if (targetOptions.keepLinkerArtifacts) {
for (auto &objectFile : objectFiles) {
- llvm::errs() << "linker input preserved: "
- << objectFile.outputFile->getFilename();
- objectFile.outputFile->keep();
+ if (objectFile.outputFile) {
+ llvm::errs() << "linker input preserved: "
+ << objectFile.outputFile->getFilename();
+ objectFile.keep();
+ }
}
}
return llvm::None;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
index 6abec12..b9f2cb4 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp
@@ -66,7 +66,13 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
+ // For now we disable translation if the variant has external object files.
+ // We could instead perform linking with those objects (if they're Metal
+ // archives, etc).
+ if (variantOp.isExternal()) return;
+
buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
index 3e48d40..e36f44b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.cpp
@@ -98,7 +98,13 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
+ // For now we disable translation if the variant has external object files.
+ // We could instead perform linking with those objects (if they're bitcode
+ // ala libdevice.bc, etc).
+ if (variantOp.isExternal()) return;
+
buildLLVMGPUTransformPassPipeline(passManager, true);
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
index 64f6e28..a2fd896 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/TargetBackend.h
@@ -166,7 +166,8 @@
// module { spirv.module { ... } }
// }
// }
- virtual void buildTranslationPassPipeline(OpPassManager &passManager) = 0;
+ virtual void buildTranslationPassPipeline(
+ IREE::HAL::ExecutableVariantOp variantOp, OpPassManager &passManager) = 0;
// Inserts passes used to link `hal.executable.variant` ops together.
// The pass manager will be nested on the parent module of `hal.executable`
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp
index 039d248..f1a3d97 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VMVX/VMVXTarget.cpp
@@ -53,7 +53,8 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
IREE::VMVX::buildVMVXTransformPassPipeline(passManager);
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
@@ -158,7 +159,8 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
IREE::VMVX::buildVMVXTransformPassPipeline(passManager);
}
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 6828d2c..41dc40a 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -113,13 +113,26 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
+ // For now we disable translation if the variant has external object files.
+ // We could instead perform linking with those objects (if they're .spv
+ // files we could use spirv-link or import them into MLIR and merge here).
+ if (variantOp.isExternal()) return;
+
buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
}
LogicalResult serializeExecutable(const SerializationOptions &options,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
+ // Today we special-case external variants but in the future we could allow
+ // for a linking approach allowing both code generation and external .spv
+ // files to be combined together.
+ if (variantOp.isExternal()) {
+ return serializeExternalExecutable(options, variantOp, executableBuilder);
+ }
+
ModuleOp innerModuleOp = variantOp.getInnerModule();
auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
if (!llvm::hasSingleElement(spirvModuleOps)) {
@@ -180,6 +193,67 @@
return success();
}
+ LogicalResult serializeExternalExecutable(
+ const SerializationOptions &options,
+ IREE::HAL::ExecutableVariantOp variantOp, OpBuilder &executableBuilder) {
+ if (!variantOp.getObjects().has_value()) {
+ return variantOp.emitOpError()
+ << "no objects defined for external variant";
+ } else if (variantOp.getObjects()->getValue().size() != 1) {
+ // For now we assume there will be exactly one object file.
+ // TODO(#7824): support multiple .spv files in a single flatbuffer archive
+ // so that we can combine executables.
+ return variantOp.emitOpError() << "only one object reference is "
+ "supported for external variants";
+ }
+
+ // Take exported names verbatim for passing into VkShaderModuleCreateInfo.
+ SmallVector<StringRef, 8> entryPointNames;
+ for (auto exportOp : variantOp.getOps<IREE::HAL::ExecutableExportOp>()) {
+ entryPointNames.emplace_back(exportOp.getSymName());
+ }
+
+ // Load .spv object file.
+ auto objectAttr = variantOp.getObjects()
+ ->getValue()
+ .front()
+ .cast<IREE::HAL::ExecutableObjectAttr>();
+ std::string spvBinary;
+ if (auto data = objectAttr.loadData()) {
+ spvBinary = data.value();
+ } else {
+ return variantOp.emitOpError()
+ << "object file could not be loaded: " << objectAttr;
+ }
+ if (spvBinary.size() % 4 != 0) {
+ return variantOp.emitOpError()
+ << "object file is not 4-byte aligned as expected for SPIR-V";
+ }
+
+ FlatbufferBuilder builder;
+ iree_SpirVExecutableDef_start_as_root(builder);
+
+ auto spvCodeRef = flatbuffers_uint32_vec_create(
+ builder, reinterpret_cast<const uint32_t *>(spvBinary.data()),
+ spvBinary.size() / sizeof(uint32_t));
+
+ auto entryPointsRef = builder.createStringVec(entryPointNames);
+
+ iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef);
+ iree_SpirVExecutableDef_code_add(builder, spvCodeRef);
+ iree_SpirVExecutableDef_end_as_root(builder);
+
+ // Add the binary data to the target executable.
+ auto binaryOp = executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
+ variantOp.getLoc(), variantOp.getSymName(),
+ variantOp.getTarget().getFormat(),
+ builder.getBufferAttr(executableBuilder.getContext()));
+ binaryOp.setMimeTypeAttr(
+ executableBuilder.getStringAttr("application/x-flatbuffers"));
+
+ return success();
+ }
+
private:
ArrayAttr getExecutableTargets(MLIRContext *context) const {
SmallVector<Attribute> targetAttrs;
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
index db2ea5e..ae8a544 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/WebGPU/WebGPUTarget.cpp
@@ -84,7 +84,11 @@
context, b.getStringAttr(deviceID()), configAttr);
}
- void buildTranslationPassPipeline(OpPassManager &passManager) override {
+ void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
+ OpPassManager &passManager) override {
+ // For now we disable translation if the variant has external object files.
+ if (variantOp.isExternal()) return;
+
// WebGPU does not support push constants (yet?), so replace loads from
// push constants with loads from uniform buffers.
// The corresponding runtime code must perform similar emulation, based
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
index 3c9e9d9..85b808d 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp
@@ -34,6 +34,20 @@
namespace {
//===----------------------------------------------------------------------===//
+// Linkage utilities
+//===----------------------------------------------------------------------===//
+
+static void setApplicableObjects(Operation *sourceOp,
+ IREE::HAL::ExecutableVariantOp targetOp) {
+ auto objectsAttr = sourceOp->getAttrOfType<IREE::HAL::ExecutableObjectsAttr>(
+ "hal.executable.objects");
+ if (!objectsAttr) return;
+ auto objects = objectsAttr.getApplicableObjects(targetOp.getTarget());
+ if (!objects) return;
+ targetOp.setObjectsAttr(*objects);
+}
+
+//===----------------------------------------------------------------------===//
// hal.executable.source materialization
//===----------------------------------------------------------------------===//
@@ -50,13 +64,13 @@
// With this hand-authored path all variants have the same layout and entry
// points and we can just clone them.
auto sourceEntryPointOps = sourceOp.getOps<IREE::HAL::ExecutableExportOp>();
- auto sourceModuleOp = sourceOp.getInnerModule();
// Materialize all of the hal.executable.variant ops for all backends we are
// targeting.
SymbolTable targetSymbolTable(executableOp);
OpBuilder targetBuilder(&executableOp.getBlock().back());
for (auto targetAttr : targetAttrs) {
+ // Create new variant and clone the exports.
auto targetVariantOp = targetBuilder.create<IREE::HAL::ExecutableVariantOp>(
sourceOp->getLoc(), targetAttr.getSymbolNameFragment(), targetAttr);
targetSymbolTable.insert(targetVariantOp);
@@ -64,8 +78,20 @@
for (auto sourceEntryPointOp : sourceEntryPointOps) {
variantBuilder.clone(*sourceEntryPointOp);
}
- variantBuilder.clone(*sourceModuleOp);
+
+ // Clone any target-specific object files specified.
+ if (auto objectsAttr = sourceOp.getObjectsAttr()) {
+ auto objects = objectsAttr.getApplicableObjects(targetAttr);
+ if (objects) targetVariantOp.setObjectsAttr(*objects);
+ }
+
+ // Clone inner module contents.
+ if (!sourceOp.isExternal()) {
+ auto sourceModuleOp = sourceOp.getInnerModule();
+ variantBuilder.clone(*sourceModuleOp);
+ }
}
+
// Remove the original.
sourceOp.erase();
@@ -210,9 +236,9 @@
// Annotates |dispatchOp| with resource binding to interface binding mappings.
// TODO(benvanik): have a HAL op with structured information instead.
static void annotateDispatchSite(IREE::Stream::CmdDispatchOp dispatchOp,
- const PipelineLayout &pipelineLayout) {
+ const PipelineResourceMap &resourceMap) {
SmallVector<Attribute> bindingAttrs;
- for (auto setBinding : pipelineLayout.resourceMap) {
+ for (auto setBinding : resourceMap) {
bindingAttrs.push_back(IREE::HAL::InterfaceBindingAttr::get(
dispatchOp.getContext(), setBinding.first, setBinding.second));
}
@@ -227,34 +253,39 @@
IREE::Stream::ExecutableOp sourceExecutableOp,
IREE::HAL::ExecutableOp targetExecutableOp,
const BindingLayoutAnalysis &layoutAnalysis) {
+ auto sourceModuleOp = sourceExecutableOp.getInnerModule();
auto variantOps =
targetExecutableOp.getBlock().getOps<IREE::HAL::ExecutableVariantOp>();
OpBuilder executableBuilder(&targetExecutableOp.getBlock().front());
- // For each exported function create a HAL export decl and dispatch thunk.
+ // Build a map of source function definitions to their version with the
+ // updated interface.
+ DenseMap<Operation *, Operation *> targetFuncOps;
int nextOrdinal = 0;
for (auto exportOp : sourceExecutableOp.getBody()
.getOps<IREE::Stream::ExecutableExportOp>()) {
- int ordinal = nextOrdinal++;
- auto sourceFuncOp =
- sourceExecutableOp.getInnerModule().lookupSymbol<mlir::func::FuncOp>(
- exportOp.getFunctionRef());
+ auto sourceFuncOp = sourceModuleOp.lookupSymbol<mlir::func::FuncOp>(
+ exportOp.getFunctionRef());
if (failed(verifyEntryPointTypes(sourceFuncOp))) return failure();
- const auto &pipelineLayout = layoutAnalysis.getPipelineLayout(exportOp);
-
// Create the interface for this entry point based on the analysis of its
// usage within the program.
+ const auto &pipelineLayout = layoutAnalysis.getPipelineLayout(exportOp);
auto layoutAttr = makePipelineLayoutAttr(pipelineLayout, executableBuilder);
- // Clone the source function and update it to use the new interface.
- auto baseFuncOp =
- cloneFuncWithInterface(sourceFuncOp, pipelineLayout, layoutAttr);
+ // Update all dispatch sites with the binding information required for
+ // conversion into the HAL dialect. By doing this here we ensure that the
+ // dialect conversion needs only local information on the ops and that it's
+ // not possible for the dispatches and their targets to get out of sync.
+ for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
+ annotateDispatchSite(dispatchOp, pipelineLayout.resourceMap);
+ }
- // Clone the updated function into each variant.
+ // Clone the updated function declaration into each variant.
+ int ordinal = nextOrdinal++;
for (auto variantOp : variantOps) {
// Declare the entry point on the target.
- OpBuilder targetBuilder(&variantOp.getBlock().front());
+ OpBuilder targetBuilder(variantOp.getInnerModule());
auto newExportOp = targetBuilder.create<IREE::HAL::ExecutableExportOp>(
exportOp.getLoc(),
targetBuilder.getStringAttr(exportOp.getFunctionRef()),
@@ -272,20 +303,38 @@
newExportOp.getWorkgroupCount().insertArgument(0u, deviceType,
newExportOp.getLoc());
}
-
- // Clone the updated interface-based function into the target.
- auto targetFuncOp = baseFuncOp.clone();
- variantOp.getInnerModule().push_back(targetFuncOp);
}
- // Update all dispatch sites with the binding information.
- for (auto dispatchOp : layoutAnalysis.getExportDispatches(exportOp)) {
- annotateDispatchSite(dispatchOp, pipelineLayout);
- }
-
- baseFuncOp.erase();
+ // Clone the source function and update it to use the new interface.
+ auto targetFuncOp =
+ cloneFuncWithInterface(sourceFuncOp, pipelineLayout, layoutAttr);
+ targetFuncOps[sourceFuncOp] = targetFuncOp;
}
+ // Clone all of the ops in the source module to each variant.
+ // We'll use the exported functions with the updated interfaces in place of
+ // the original versions and copy everything else verbatim.
+ for (auto variantOp : variantOps) {
+ auto targetBuilder = OpBuilder::atBlockBegin(
+ &variantOp.getInnerModule().getBodyRegion().front());
+ for (auto &op : sourceModuleOp.getOps()) {
+ auto targetFuncOp = targetFuncOps.find(&op);
+ if (targetFuncOp != targetFuncOps.end()) {
+ // Clone the updated function instead of the original.
+ targetBuilder.clone(*targetFuncOp->second);
+ } else {
+ // Regular op (globals, external function declarations, etc).
+ targetBuilder.clone(op);
+ }
+ }
+ }
+
+ // Drop the temporary target functions. We could avoid an additional clone if
+ // we only had one variant but this is relatively small in cost (once per
+ // variant).
+ for (auto it : targetFuncOps) it.second->erase();
+ targetFuncOps.clear();
+
return success();
}
@@ -425,6 +474,7 @@
targetBuilder.create<IREE::HAL::ExecutableVariantOp>(
sourceOp->getLoc(), targetAttr.getSymbolNameFragment(),
targetAttr);
+ setApplicableObjects(sourceOp, targetContainerOp);
targetSymbolTable.insert(targetContainerOp);
OpBuilder containerBuilder(&targetContainerOp.getBlock().back());
containerBuilder.create<mlir::ModuleOp>(sourceOp->getLoc());
@@ -444,6 +494,49 @@
sourceOp.erase();
}
+
+ // Do a cleanup pass for any dispatches that don't yet have interfaces
+ // assigned. If we had dispatches to externally-defined HAL executables we
+ // won't have materialized them from the stream ops above. We do expect to
+ // be able to find the dispatch targets such that we can pull out the
+ // pipeline layout, though, and any that fall through are errors.
+ auto annotateDispatchOp = [&](IREE::Stream::CmdDispatchOp dispatchOp) {
+ if (dispatchOp->hasAttr("hal.interface.bindings")) {
+ // Already have bindings defined.
+ return WalkResult::advance();
+ }
+ PipelineResourceMap resourceMap;
+ auto exportOp =
+ symbolTable.lookupNearestSymbolFrom<IREE::HAL::ExecutableExportOp>(
+ dispatchOp, dispatchOp.getEntryPointAttr());
+ if (exportOp) {
+ // Export found - we can use the pipeline layout defined there to infer
+ // the bindings. This allows for bindings to be sparse or have
+ // additional information declared.
+ for (auto setLayout : exportOp.getLayoutAttr().getSetLayouts()) {
+ for (auto binding : setLayout.getBindings()) {
+ resourceMap.emplace_back(setLayout.getOrdinal(),
+ binding.getOrdinal());
+ }
+ }
+ } else {
+ // No export found - this is likely an external executable and we can
+ // infer a dense pipeline layout. This is kind of shady as we may want
+ // to error in these cases where users have something special explicitly
+ // defined but then typo things but the ergonomic improvements in the
+ // normal case are worth that risk.
+ size_t resourceCount = dispatchOp.getResources().size();
+ for (int i = 0; i < resourceCount; ++i) {
+ // set=0, binding=resource ordinal
+ resourceMap.emplace_back(0, i);
+ }
+ }
+ annotateDispatchSite(dispatchOp, resourceMap);
+ return WalkResult::advance();
+ };
+ if (getOperation()->walk(annotateDispatchOp).wasInterrupted()) {
+ return signalPassFailure();
+ }
}
};
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
index 33ba80e..394c5e8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/TranslateExecutables.cpp
@@ -63,7 +63,7 @@
}
OpPassManager passManager(variantOp.getOperationName());
- targetBackend->buildTranslationPassPipeline(passManager);
+ targetBackend->buildTranslationPassPipeline(variantOp, passManager);
if (failed(runPipeline(passManager, variantOp))) {
variantOp.emitError() << "failed to run translation of source "
"executable to target executable for backend "
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
index d1bdbad..bdcdf81 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir
@@ -26,6 +26,7 @@
// CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
// CHECK-NEXT: }
// CHECK: builtin.module
+ // CHECK-NEXT: func.func private @extern_func()
// CHECK-NEXT: func.func @entry
// CHECK: hal.executable.variant public @embedded_elf_x86_64, target = #executable_target_embedded_elf_x86_64
// CHECK: hal.executable.export public @entry ordinal(0) layout(#pipeline_layout) {
@@ -33,6 +34,7 @@
// CHECK-NEXT: hal.return %[[ARG0]], %[[ARG1]], %[[ARG0]] : index, index, index
// CHECK-NEXT: }
// CHECK: builtin.module
+ // CHECK-NEXT: func.func private @extern_func()
// CHECK-NEXT: func.func @entry
stream.executable private @ex_workgroups {
@@ -40,6 +42,7 @@
stream.return %arg0, %arg1, %arg0 : index, index, index
}
builtin.module {
+ func.func private @extern_func()
func.func @entry(%operand: i32, %arg0: !stream.binding {stream.alignment = 64 : index}, %arg1: !stream.binding {stream.alignment = 64 : index}, %arg2: !stream.binding {stream.alignment = 64 : index}) {
return
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
index 47f8228..2b96754 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp
@@ -316,11 +316,12 @@
}
}
- rewriter.replaceOpWithNewOp<IREE::Stream::AsyncDispatchOp>(
+ auto newOp = rewriter.replaceOpWithNewOp<IREE::Stream::AsyncDispatchOp>(
op, resultTypes, adaptor.getWorkload(), adaptor.getEntryPoint(),
dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets,
dispatchOperandEnds, dispatchOperandLengths, resultSizes,
adaptor.getTiedOperandsAttr(), getAffinityFor(op));
+ newOp->setDialectAttrs(op->getDialectAttrs());
return success();
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index d65b4d8..e48959b 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -644,10 +644,11 @@
newResourceAccesses.push_back(resourceAccess);
}
- builder.create<IREE::Stream::CmdDispatchOp>(
+ auto newOp = builder.create<IREE::Stream::CmdDispatchOp>(
asyncOp.getLoc(), asyncOp.getWorkload(), asyncOp.getEntryPoint(),
newOperands, newResources, newResourceSizes, newResourceOffsets,
newResourceLengths, builder.getArrayAttr(newResourceAccesses));
+ newOp->setDialectAttrs(asyncOp->getDialectAttrs());
asyncOp.erase();
return success();
}
diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.c b/runtime/src/iree/hal/drivers/cuda/native_executable.c
index 2c2bb1b..3c0d241 100644
--- a/runtime/src/iree/hal/drivers/cuda/native_executable.c
+++ b/runtime/src/iree/hal/drivers/cuda/native_executable.c
@@ -93,22 +93,28 @@
"cuModuleLoadDataEx");
}
- executable->entry_count = entry_count;
- for (iree_host_size_t i = 0; i < entry_count; i++) {
- if (iree_status_is_ok(status)) {
+ if (iree_status_is_ok(status)) {
+ executable->entry_count = entry_count;
+ for (iree_host_size_t i = 0; i < entry_count; i++) {
CUfunction function = NULL;
const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
status = CU_RESULT_TO_STATUS(
context->syms, cuModuleGetFunction(&function, module, entry_name),
"cuModuleGetFunction");
- if (iree_status_is_ok(status)) {
- status = CU_RESULT_TO_STATUS(
- context->syms,
- cuFuncSetAttribute(function,
- CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
- shared_memory_sizes[i]),
- "cuFuncSetAttribute");
+ if (!iree_status_is_ok(status)) break;
+ if (!function) {
+ status = iree_make_status(IREE_STATUS_NOT_FOUND,
+ "exported module function %s not found",
+ entry_name);
+ break;
}
+ status = CU_RESULT_TO_STATUS(
+ context->syms,
+ cuFuncSetAttribute(function,
+ CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+ shared_memory_sizes[i]),
+ "cuFuncSetAttribute");
+ if (!iree_status_is_ok(status)) break;
executable->entry_functions[i].cu_function = function;
executable->entry_functions[i].block_size_x = block_sizes_vec[i].x;
executable->entry_functions[i].block_size_y = block_sizes_vec[i].y;
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
index 3ddfcc9..06563bc 100644
--- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -353,6 +353,7 @@
iree_hal_cuda_pipeline_layout_num_constants(layout);
iree_host_size_t constant_base_index =
iree_hal_cuda_push_constant_index(layout);
+
// Patch the push constants in the kernel arguments.
for (iree_host_size_t i = 0; i < num_constants; i++) {
*((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =