Preserve reflection attrs on functions when wrapping for the native ABI. (#16129)

(required for #16130)
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
index 2d0e90e..6216eab 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/WrapEntryPoints.cpp
@@ -400,6 +400,12 @@
   auto *context = exportOp.getContext();
   SmallVector<NamedAttribute> attrs;
 
+  if (auto reflectionAttr =
+          exportOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+    attrs.append(reflectionAttr.getValue().begin(),
+                 reflectionAttr.getValue().end());
+  }
+
   if (auto abiAttr = exportOp->getAttr("iree.abi")) {
     attrs.emplace_back(StringAttr::get(context, "iree.abi"), abiAttr);
   }
@@ -487,6 +493,7 @@
 
   // Populate the reflection attrs based on the original types.
   populateReflectionAttrs(invocationModel, exportOp, wrapperOp);
+  exportOp->removeAttr("iree.reflection");
 
   auto *entryBlock = wrapperOp.addEntryBlock();
   auto entryBuilder = OpBuilder::atBlockBegin(entryBlock);
diff --git a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
index 16045fc..5aa9da3 100644
--- a/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
+++ b/compiler/src/iree/compiler/Bindings/Native/Transforms/test/wrap_entry_points.mlir
@@ -33,6 +33,25 @@
 
 // -----
 
+// Tests that an existing iree.reflection dictionary is merged with the new
+// reflection information.
+
+// CHECK-LABEL: func.func @existingReflection
+// CHECK-SAME: iree.reflection =
+// CHECK-SAME:     iree.abi.declaration = "sync func @existingReflection
+// CHECK-SAME:     some.attr = 4 : index
+// CHECK: func.func private @_existingReflection
+// CHECK-NOT: iree.reflection = {some.attr = 4 : index}
+func.func @existingReflection() attributes {
+  iree.reflection = {
+    some.attr = 4 : index
+  }
+} {
+  return
+}
+
+// -----
+
 // Tests that iree.abi.declaration is added when needed and otherwise the user
 // provided value is passed through.
 
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp
index 3404799..36851d3 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ConversionTarget.cpp
@@ -19,6 +19,11 @@
   if (!innerModuleOp) {
     innerModuleOp =
         ModuleOp::create(outerModuleOp.getLoc(), outerModuleOp.getName());
+    if (auto reflectionAttr =
+            outerModuleOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+      innerModuleOp->setAttr("iree.reflection", reflectionAttr);
+      outerModuleOp->removeAttr("iree.reflection");
+    }
     innerModuleOp.getBodyRegion().takeBody(outerModuleOp.getBodyRegion());
     outerModuleOp.getBodyRegion().getBlocks().push_back(new Block());
     outerModuleOp.push_back(innerModuleOp);
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
index de92cc8..36d87ac 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/ConvertStandardToVM.cpp
@@ -46,6 +46,10 @@
     if (auto version = srcOp->getAttrOfType<IntegerAttr>("vm.version")) {
       newModuleOp.setVersionAttr(version);
     }
+    if (auto reflectionAttr =
+            srcOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
+      newModuleOp->setAttr("iree.reflection", reflectionAttr);
+    }
     Block *firstCreatedBlock = &newModuleOp.getBodyRegion().front();
     rewriter.inlineRegionBefore(srcOp.getBodyRegion(), firstCreatedBlock);
     auto blockRange = llvm::make_range(Region::iterator(firstCreatedBlock),
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
index dac72f4..4cd16ab 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp
@@ -193,6 +193,27 @@
   return success();
 }
 
+// Returns a list of reflection AttrDefs with entries from |attrs| (or an
+// empty/null list).
+static iree_vm_AttrDef_vec_ref_t makeAttrDefs(DictionaryAttr attrs,
+                                              FlatbufferBuilder &fbb) {
+  if (!attrs || attrs.empty())
+    return 0;
+  SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
+  for (auto attr : attrs) {
+    auto key = attr.getName().strref();
+    auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
+    if (!value || key.empty())
+      continue;
+    // NOTE: if we actually want to keep these we should dedupe them (as the
+    // keys and likely several of the values are shared across all functions).
+    auto valueRef = fbb.createString(value.getValue());
+    auto keyRef = fbb.createString(key);
+    attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
+  }
+  return iree_vm_AttrDef_vec_create(fbb, attrRefs.data(), attrRefs.size());
+}
+
 // Creates a FunctionSignatureDef based on the given function metadata.
 // Some fields are not used on all signature defs and added only when present on
 // the argument objects/attrs.
@@ -236,24 +257,9 @@
   if (!cconv.has_value())
     return {};
 
-  // Reflection attributes.
-  iree_vm_AttrDef_vec_ref_t attrsRef = 0;
-  if (auto attrs = funcOp->getAttrOfType<DictionaryAttr>("iree.reflection")) {
-    SmallVector<iree_vm_AttrDef_ref_t> attrRefs;
-    for (auto attr : attrs) {
-      auto key = attr.getName().strref();
-      auto value = llvm::dyn_cast<StringAttr>(attr.getValue());
-      if (!value || key.empty())
-        continue;
-      // NOTE: if we actually want to keep these we should dedupe them (as the
-      // keys and likely several of the values are shared across all functions).
-      auto valueRef = fbb.createString(value.getValue());
-      auto keyRef = fbb.createString(key);
-      attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef));
-    }
-    attrsRef =
-        iree_vm_AttrDef_vec_create(fbb, attrRefs.data(), attrRefs.size());
-  }
+  // Encode reflection attributes.
+  iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs(
+      funcOp->getAttrOfType<DictionaryAttr>("iree.reflection"), fbb);
 
   return createFunctionSignatureDef(funcOp.getFunctionType(), typeTable,
                                     cconv.value(), attrsRef, fbb);
@@ -474,6 +480,10 @@
     return iree_vm_TypeDef_end(fbb);
   });
 
+  // Encode reflection attributes.
+  iree_vm_AttrDef_vec_ref_t attrsRef = makeAttrDefs(
+      moduleOp->getAttrOfType<DictionaryAttr>("iree.reflection"), fbb);
+
   // NOTE: we keep the vectors clustered here so that we can hopefully keep the
   // pages mapped at runtime; vector dereferences in FlatBuffers require
   // touching these structs to get length/etc and as such we don't want to be
@@ -525,7 +535,7 @@
   iree_vm_BytecodeModuleDef_version_add(fbb,
                                         moduleOp.getVersion().value_or(0u));
   iree_vm_BytecodeModuleDef_requirements_add(fbb, moduleRequirements);
-  // TODO(benvanik): iree_vm_BytecodeModuleDef_attrs_add
+  iree_vm_BytecodeModuleDef_attrs_add(fbb, attrsRef);
   iree_vm_BytecodeModuleDef_types_add(fbb, typesRef);
   iree_vm_BytecodeModuleDef_dependencies_add(fbb, dependenciesRef);
   iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef);
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel
index 5c93641..fc411fb 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/BUILD.bazel
@@ -18,8 +18,8 @@
         [
             "constant_encoding.mlir",
             "dependencies.mlir",
-            "function_attrs.mlir",
             "module_encoding_smoke.mlir",
+            "reflection_attrs.mlir",
         ],
         include = ["*.mlir"],
     ),
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt
index 5e10f37..32cf14a 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/CMakeLists.txt
@@ -16,8 +16,8 @@
   SRCS
     "constant_encoding.mlir"
     "dependencies.mlir"
-    "function_attrs.mlir"
     "module_encoding_smoke.mlir"
+    "reflection_attrs.mlir"
   TOOLS
     FileCheck
     iree-compile
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/function_attrs.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/function_attrs.mlir
deleted file mode 100644
index 243b17a..0000000
--- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/function_attrs.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: iree-compile --split-input-file --compile-mode=vm \
-// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s | FileCheck %s
-
-// CHECK-LABEL: simple_module
-vm.module @simple_module {
-  vm.export @func
-  // CHECK: "exported_functions":
-  // CHECK: "attrs":
-  // CHECK:   "key": "f"
-  // CHECK:   "value": "FOOBAR"
-  // CHECK:   "key": "fv"
-  // CHECK:   "value": "INFINITY"
-  vm.func @func(%arg0 : i32) -> i32
-    attributes { iree.reflection = { f = "FOOBAR", fv = "INFINITY" } }
-  {
-    vm.return %arg0 : i32
-  }
-}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
new file mode 100644
index 0000000..01463b1
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/test/reflection_attrs.mlir
@@ -0,0 +1,31 @@
+// RUN: iree-compile --split-input-file --compile-mode=vm \
+// RUN: --iree-vm-bytecode-module-output-format=flatbuffer-text %s | FileCheck %s
+
+// CHECK-LABEL: simple_module
+// CHECK: "attrs":
+// CHECK:   "key": "module_attr_0"
+// CHECK:   "value": "MODULE_ATTR_0"
+// CHECK:   "key": "module_attr_1"
+// CHECK:   "value": "MODULE_ATTR_1"
+vm.module @simple_module attributes {
+  iree.reflection = {
+    module_attr_0 = "MODULE_ATTR_0",
+    module_attr_1 = "MODULE_ATTR_1"
+  }
+} {
+  vm.export @func
+  // CHECK: "exported_functions":
+  // CHECK: "attrs":
+  // CHECK:   "key": "func_attr_0"
+  // CHECK:   "value": "FUNC_ATTR_0"
+  // CHECK:   "key": "func_attr_1"
+  // CHECK:   "value": "FUNC_ATTR_1"
+  vm.func @func(%arg0 : i32) -> i32 attributes {
+    iree.reflection = {
+      func_attr_0 = "FUNC_ATTR_0",
+      func_attr_1 = "FUNC_ATTR_1"
+    }
+  } {
+    vm.return %arg0 : i32
+  }
+}
diff --git a/tools/iree-dump-module-main.c b/tools/iree-dump-module-main.c
index 70780f1..8d284bb 100644
--- a/tools/iree-dump-module-main.c
+++ b/tools/iree-dump-module-main.c
@@ -401,7 +401,7 @@
     fprintf(stdout, "Attributes:\n");
     iree_tooling_print_attr_defs(iree_vm_BytecodeModuleDef_attrs(module_def),
                                  2);
-    fprintf(stdout, "\n\n");
+    fprintf(stdout, "\n");
   }
 
   if (iree_vm_BytecodeModuleDef_types_is_present(module_def)) {