Merge google -> main (#3488)

* 84467e66 Merge pull request #3486 from ScottTodd:main-to-google
* 66bddbfd Merge branch 'google' into main-to-google
* 6d2473bd Synchronize submodules
* a8a8242c Integrate TF at tensorflow/tensorflow@069b9b2f1079
* dabd7364 Synchronize submodules
* 3510a109 Integrate LLVM at llvm/llvm-project@50df5f24dc33
* 227e12d4 Synchronize submodules
* 14c5540c Integrate LLVM at llvm/llvm-project@196bee9648a9
* ddbead4d Synchronize submodules
* 340a030b Integrate LLVM at llvm/llvm-project@220de1f32add
* 7f12e791 Add support for building TFHub intree.
* 535f63eb Synchronize submodules
* e001092b Integrate LLVM at llvm/llvm-project@89657b3a3b57
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index 752a6fd..ab3c3a4 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -41,7 +41,7 @@
 
 namespace {
 
-bool AreInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
+bool areInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
                              IREE::HAL::InterfaceOp rhs) {
   auto lhsBindings = lhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
   auto rhsBindings = rhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
@@ -125,7 +125,8 @@
     //   }
 
     OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
-    auto executableOps = moduleOp.getOps<IREE::HAL::ExecutableOp>();
+    auto executableOps =
+        llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
 
     // Create our new "linked" hal.executable.
     auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
@@ -143,9 +144,14 @@
     auto linkedVmModuleOp =
         builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");
 
-    int executablesLinked = 0;
     llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
     int nextEntryPointOrdinal = 0;
+    DenseMap<StringRef, Operation *> symbolMap;
+    DenseMap<Attribute, Attribute> entryPointRefReplacements;
+    auto linkedExecutableBuilder =
+        OpBuilder::atBlockBegin(linkedExecutableOp.getBody());
+    auto linkedTargetBuilder =
+        OpBuilder::atBlockBegin(linkedTargetOp.getBody());
     for (auto executableOp : executableOps) {
       auto targetOps = llvm::to_vector<4>(
           executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
@@ -157,103 +163,112 @@
 
         IREE::HAL::InterfaceOp interfaceOpForExecutable;
         for (auto interfaceOp : interfaceOps) {
-          if (AreInterfacesEquivalent(interfaceOp,
+          if (areInterfacesEquivalent(interfaceOp,
                                       executableOp.getFirstInterfaceOp())) {
             interfaceOpForExecutable = interfaceOp;
+            break;
           }
         }
         if (!interfaceOpForExecutable) {
-          builder.setInsertionPoint(linkedTargetOp);
-          interfaceOpForExecutable = dyn_cast<IREE::HAL::InterfaceOp>(
-              builder.clone(*executableOp.getFirstInterfaceOp()));
+          interfaceOpForExecutable =
+              dyn_cast<IREE::HAL::InterfaceOp>(linkedExecutableBuilder.clone(
+                  *executableOp.getFirstInterfaceOp()));
           interfaceOpForExecutable.setName(
               llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
           interfaceOps.push_back(interfaceOpForExecutable);
         }
 
-        // Clone entry point ops, remapping ordinals and updating symbol refs.
-        builder.setInsertionPoint(linkedModuleOp);
+        // Clone entry point ops with remapped ordinals and queue the
+        // corresponding symbol ref updates.
         for (auto entryPointOp :
              targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
           auto newEntryPointOp =
-              builder.create<IREE::HAL::ExecutableEntryPointOp>(
+              linkedTargetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
                   entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
                   builder.getI32IntegerAttr(nextEntryPointOrdinal++),
                   builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
                   entryPointOp.signatureAttr());
 
-          // Update references to @executable::@target::@entry symbols.
-          // SymbolTable::replaceAllSymbolUses only looks at root symbols,
-          // which we can't blindly replace (other targets will map to other
-          // linked executables).
-          auto executableUses =
-              SymbolTable::getSymbolUses(executableOp, moduleOp);
-          if (!executableUses.hasValue()) continue;
-          for (auto executableUse : executableUses.getValue()) {
-            auto executableUser = executableUse.getUser();
-            // Only process symbols for this @target::@entry.
-            auto nestedRefs =
-                executableUse.getSymbolRef().getNestedReferences();
-            if (nestedRefs.size() != 2 ||
-                nestedRefs[0].getValue() != targetOp.sym_name() ||
-                nestedRefs[1].getValue() != entryPointOp.sym_name()) {
-              continue;
-            }
-            if (auto dispatchOp =
-                    dyn_cast<IREE::HAL::CommandBufferDispatchSymbolOp>(
-                        executableUser)) {
-              // New nested reference to the linked exe/target/entry.
-              StringRef newExecutableOpSymName =
-                  linkedExecutableOp
-                      .getAttrOfType<StringAttr>(
-                          SymbolTable::getSymbolAttrName())
-                      .getValue();
-              auto newSymbolRefAttr = builder.getSymbolRefAttr(
-                  newExecutableOpSymName,
-                  {builder.getSymbolRefAttr(linkedTargetOp),
-                   builder.getSymbolRefAttr(newEntryPointOp)});
-              dispatchOp.setAttr("entry_point", newSymbolRefAttr);
-            }
-          }
+          // Add to replacement table for fixing up dispatch calls referencing
+          // this entry point.
+          auto oldSymbolRefAttr = builder.getSymbolRefAttr(
+              executableOp.getName(), {builder.getSymbolRefAttr(targetOp),
+                                       builder.getSymbolRefAttr(entryPointOp)});
+          auto newSymbolRefAttr = builder.getSymbolRefAttr(
+              linkedExecutableOp.getName(),
+              {builder.getSymbolRefAttr(linkedTargetOp),
+               builder.getSymbolRefAttr(newEntryPointOp)});
+          entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
         }
 
-        // Clone vm.module ops, including their contents.
+        // Merge the existing vm.module op into the new linked vm.module op.
         auto vmModuleOps =
             targetOp.getInnerModule().getOps<IREE::VM::ModuleOp>();
         if (vmModuleOps.empty()) {
           return targetOp.getInnerModule().emitError()
                  << "target's outer module does not contain a vm.module op";
         }
-        auto vmModuleOp = *vmModuleOps.begin();
-        builder.setInsertionPoint(&linkedVmModuleOp.getBlock().back());
-        // Use a SymbolTable to guard against inserting duplicate symbols.
-        SymbolTable symbolTable(linkedVmModuleOp.getOperation());
+        mergeModuleInto(*vmModuleOps.begin(), linkedVmModuleOp, symbolMap);
 
-        for (auto &op : vmModuleOp.getBody()->getOperations()) {
-          if (auto terminatorOp = dyn_cast<IREE::VM::ModuleTerminatorOp>(op)) {
-            continue;
-          }
-          if (op.hasTrait<SymbolOpInterface::Trait>() &&
-              symbolTable.lookup(dyn_cast<SymbolOpInterface>(op).getName())) {
-            continue;
-          }
-          builder.clone(op);
-        }
-
-        // Now that we're done cloning its ops, delete the original target op.
         targetOp.erase();
+      }
 
-        executablesLinked++;
+      if (executableOp.getOps<IREE::HAL::ExecutableTargetOp>().empty()) {
+        executableOp.erase();
       }
     }
 
-    if (executablesLinked == 0) {
+    // Update references to @executable::@target::@entry symbols.
+    replaceEntryPointUses(moduleOp, entryPointRefReplacements);
+
+    // Remove the linked executable if no entry points were added to it.
+    if (linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty()) {
+      linkedTargetOp.erase();
       linkedExecutableOp.erase();
     }
 
     return success();
   }
 
+  // Destructively merges |sourceModuleOp| into |targetModuleOp|.
+  // |targetSymbolMap| is updated with the new symbols.
+  void mergeModuleInto(IREE::VM::ModuleOp sourceModuleOp,
+                       IREE::VM::ModuleOp targetModuleOp,
+                       DenseMap<StringRef, Operation *> &targetSymbolMap) {
+    auto allOps = llvm::to_vector<8>(llvm::map_range(
+        sourceModuleOp.getBlock(), [&](Operation &op) { return &op; }));
+    for (auto &op : allOps) {
+      if (op->isKnownTerminator()) continue;
+      if (auto symbolInterface = dyn_cast<SymbolOpInterface>(op)) {
+        if (targetSymbolMap.count(symbolInterface.getName())) {
+          // TODO(scotttodd): compare ops to ensure we aren't copying different
+          // things with the same name.
+          continue;
+        }
+        targetSymbolMap[symbolInterface.getName()] = op;
+      }
+      op->moveBefore(&targetModuleOp.getBlock().back());
+    }
+
+    // Now that its ops have been moved out, delete the source module op.
+    sourceModuleOp.erase();
+  }
+
+  // Rewrites each dispatch op whose entry point symbol ref is a key in
+  // |replacements| to use the mapped (new) symbol ref instead.
+  void replaceEntryPointUses(
+      mlir::ModuleOp moduleOp,
+      const DenseMap<Attribute, Attribute> &replacements) {
+    for (auto funcOp : moduleOp.getOps<mlir::FuncOp>()) {
+      funcOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
+        auto it = replacements.find(dispatchOp.entry_point());
+        if (it != replacements.end()) {
+          dispatchOp.entry_pointAttr(it->second.cast<SymbolRefAttr>());
+        }
+      });
+    }
+  }
+
   LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
                                     OpBuilder &executableBuilder) override {
     // Serialize the VM module to bytes.