Preserve reflection attrs on functions when wrapping for the native ABI. (#16129)
(required for #16130)
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index 2d0e90e..6216eab 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -400,6 +400,12 @@
auto *context = exportOp.getContext();
SmallVector<NamedAttribute> attrs;
+ if (auto reflectionAttr =
+ exportOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+ attrs.append(reflectionAttr.getValue().begin(),
+ reflectionAttr.getValue().end());
+ }
+
if (auto abiAttr = exportOp->getAttr("iree.abi")) {
attrs.emplace_back(StringAttr::get(context, "iree.abi"), abiAttr);
}
@@ -487,6 +493,7 @@
// Populate the reflection attrs based on the original types.
populateReflectionAttrs(invocationModel, exportOp, wrapperOp);
+ exportOp->removeAttr("iree.reflection");
auto *entryBlock = wrapperOp.addEntryBlock();
auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 16045fc..5aa9da3 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -33,6 +33,25 @@
// -----
+// Tests that an existing iree.reflection dictionary is merged with the new
+// reflection information.
+
+// CHECK-LABEL: func.func @existingReflection
+// CHECK-SAME: iree.reflection =
+// CHECK-SAME: iree.abi.declaration = "sync func @existingReflection
+// CHECK-SAME: some.attr = 4 : index
+// CHECK: func.func private @_existingReflection
+// CHECK-NOT: iree.reflection = {some.attr = 4 : index}
+func.func @existingReflection() attributes {
+ iree.reflection = {
+ some.attr = 4 : index
+ }
+} {
+ return
+}
+
+// -----
+
// Tests that iree.abi.declaration is added when needed and otherwise the user
// provided value is passed through.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp
index 3404799..36851d3 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp
@@ -19,6 +19,11 @@
if (!innerModuleOp) {
innerModuleOp =
ModuleOp::create(outerModuleOp.getLoc(), outerModuleOp.getName());
+ if (auto reflectionAttr =
+ outerModuleOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+ innerModuleOp->setAttr("iree.reflection", reflectionAttr);
+ outerModuleOp->removeAttr("iree.reflection");
+ }
innerModuleOp.getBodyRegion().takeBody(outerModuleOp.getBodyRegion());
outerModuleOp.getBodyRegion().getBlocks().push_back(new Block());
outerModuleOp.push_back(innerModuleOp);
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index de92cc8..36d87ac 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -46,6 +46,10 @@
if (auto version = srcOp->getAttrOfType<IntegerAttr>("vm.version")) {
newModuleOp.setVersionAttr(version);
}
+ if (auto reflectionAttr =
+ srcOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+ newModuleOp->setAttr("iree.reflection", reflectionAttr);
+ }
Block *firstCreatedBlock = &newModuleOp.getBodyRegion().front();
rewriter.inlineRegionBefore(srcOp.getBodyRegion(), firstCreatedBlock);
auto blockRange = llvm::make_range(Region::iterator(firstCreatedBlock),
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index dac72f4..4cd16ab 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -193,6 +193,27 @@
return success();
}
+// Returns a list of reflection AttrDefs with entries from |attrs| (or an
+// empty/null list).
+static iree_vm_AttrDef_vec_ref_t makeAttrDefs(DictionaryAttr attrs,
+ FlatbufferBuilder &fbb) {
+ if (!attrs || attrs.empty())
+ return 0;
+ SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
+ for (auto attr : attrs) {
+ auto key = attr.getName().strref();
+ auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
+ if (!value || key.empty())
+ continue;
+ // NOTE: if we actually want to keep these we should dedupe them (as the
+ // keys and likely several of the values are shared across all functions).
+ auto valueRef = fbb.createString(value.getValue());
+ auto keyRef = fbb.createString(key);
+ attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
+ }
+ return iree_vm_AttrDef_vec_create(fbb, attrRefs.data(), attrRefs.size());
+}
+
// Creates a FunctionSignatureDef based on the given function metadata.
// Some fields are not used on all signature defs and added only when present on
// the argument objects/attrs.
@@ -236,24 +257,9 @@
if (!cconv.has_value())
return {};
- // Reflection attributes.
- iree_vm_AttrDef_vec_ref_t attrsRef = 0;
- if (auto attrs = funcOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
- SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
- for (auto attr : attrs) {
- auto key = attr.getName().strref();
- auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
- if (!value || key.empty())
- continue;
- // NOTE: if we actually want to keep these we should dedupe them (as the
- // keys and likely several of the values are shared across all functions).
- auto valueRef = fbb.createString(value.getValue());
- auto keyRef = fbb.createString(key);
- attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
- }
- attrsRef =
- iree_vm_AttrDef_vec_create(fbb, attrRefs.data(), attrRefs.size());
- }
+ // Encode reflection attributes.
+ iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs(
+ funcOp->getAttrOfType<DictionaryAttr>("iree.reflection"), fbb);
return createFunctionSignatureDef(funcOp.getFunctionType(), typeTable,
cconv.value(), attrsRef, fbb);
@@ -474,6 +480,10 @@
return iree_vm_TypeDef_end(fbb);
});
+ // Encode reflection attributes.
+ iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs(
+ moduleOp->getAttrOfType<DictionaryAttr>("iree.reflection"), fbb);
+
// NOTE: we keep the vectors clustered here so that we can hopefully keep the
// pages mapped at runtime; vector dereferences in FlatBuffers require
// touching these structs to get length/etc and as such we don't want to be
@@ -525,7 +535,7 @@
iree_vm_BytecodeModuleDef_version_add(fbb,
moduleOp.getVersion().value_or(0u));
iree_vm_BytecodeModuleDef_requirements_add(fbb, moduleRequirements);
- // TODO(benvanik): iree_vm_BytecodeModuleDef_attrs_add
+ iree_vm_BytecodeModuleDef_attrs_add(fbb, attrsRef);
iree_vm_BytecodeModuleDef_types_add(fbb, typesRef);
iree_vm_BytecodeModuleDef_dependencies_add(fbb, dependenciesRef);
iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef);
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel
index 5c93641..fc411fb 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel
@@ -18,8 +18,8 @@
[
"constant_encoding.mlir",
"dependencies.mlir",
- "function_attrs.mlir",
"module_encoding_smoke.mlir",
+ "reflection_attrs.mlir",
],
include = ["*.mlir"],
),
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt
index 5e10f37..32cf14a 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt
@@ -16,8 +16,8 @@
SRCS
"constant_encoding.mlir"
"dependencies.mlir"
- "function_attrs.mlir"
"module_encoding_smoke.mlir"
+ "reflection_attrs.mlir"
TOOLS
FileCheck
iree-compile
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/function_attrs.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/function_attrs.mlir
deleted file mode 100644
index 243b17a..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/function_attrs.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: iree-compile --split-input-file --compile-mode=vm \
-// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s | FileCheck %s
-
-// CHECK-LABEL: simple_module
-vm.module @simple_module {
- vm.export @func
- // CHECK: "exported_functions":
- // CHECK: "attrs":
- // CHECK: "key": "f"
- // CHECK: "value": "FOOBAR"
- // CHECK: "key": "fv"
- // CHECK: "value": "INFINITY"
- vm.func @func(%arg0 : i32) -> i32
- attributes { iree.reflection = { f = "FOOBAR", fv = "INFINITY" } }
- {
- vm.return %arg0 : i32
- }
-}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
new file mode 100644
index 0000000..01463b1
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-compile --split-input-file --compile-mode=vm \
+// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s | FileCheck %s
+
+// CHECK-LABEL: simple_module
+// CHECK: "attrs":
+// CHECK: "key": "module_attr_0"
+// CHECK: "value": "MODULE_ATTR_0"
+// CHECK: "key": "module_attr_1"
+// CHECK: "value": "MODULE_ATTR_1"
+vm.module @simple_module attributes {
+ iree.reflection = {
+ module_attr_0 = "MODULE_ATTR_0",
+ module_attr_1 = "MODULE_ATTR_1"
+ }
+} {
+ vm.export @func
+ // CHECK: "exported_functions":
+ // CHECK: "attrs":
+ // CHECK: "key": "func_attr_0"
+ // CHECK: "value": "FUNC_ATTR_0"
+ // CHECK: "key": "func_attr_1"
+ // CHECK: "value": "FUNC_ATTR_1"
+ vm.func @func(%arg0 : i32) -> i32 attributes {
+ iree.reflection = {
+ func_attr_0 = "FUNC_ATTR_0",
+ func_attr_1 = "FUNC_ATTR_1"
+ }
+ } {
+ vm.return %arg0 : i32
+ }
+}
diff --git a/tools/iree-dump-module-main.c b/tools/iree-dump-module-main.c
index 70780f1..8d284bb 100644
--- a/tools/iree-dump-module-main.c
+++ b/tools/iree-dump-module-main.c
@@ -401,7 +401,7 @@
fprintf(stdout, "Attributes:\n");
iree_tooling_print_attr_defs(iree_vm_BytecodeModuleDef_attrs(module_def),
2);
- fprintf(stdout, "\n\n");
+ fprintf(stdout, "\n");
}
if (iree_vm_BytecodeModuleDef_types_is_present(module_def)) {