Merge google -> main (#3488)
* 84467e66 Merge pull request #3486 from ScottTodd:main-to-google
* 66bddbfd Merge branch 'google' into main-to-google
* 6d2473bd Synchronize submodules
* a8a8242c Integrate TF at tensorflow/tensorflow@069b9b2f1079
* dabd7364 Synchronize submodules
* 3510a109 Integrate LLVM at llvm/llvm-project@50df5f24dc33
* 227e12d4 Synchronize submodules
* 14c5540c Integrate LLVM at llvm/llvm-project@196bee9648a9
* ddbead4d Synchronize submodules
* 340a030b Integrate LLVM at llvm/llvm-project@220de1f32add
* 7f12e791 Add support for building TFHub intree.
* 535f63eb Synchronize submodules
* e001092b Integrate LLVM at llvm/llvm-project@89657b3a3b57
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
index 752a6fd..ab3c3a4 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
@@ -41,7 +41,7 @@
namespace {
-bool AreInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
+bool areInterfacesEquivalent(IREE::HAL::InterfaceOp lhs,
IREE::HAL::InterfaceOp rhs) {
auto lhsBindings = lhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
auto rhsBindings = rhs.getBlock().getOps<IREE::HAL::InterfaceBindingOp>();
@@ -125,7 +125,8 @@
// }
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
- auto executableOps = moduleOp.getOps<IREE::HAL::ExecutableOp>();
+ auto executableOps =
+ llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>());
// Create our new "linked" hal.executable.
auto linkedExecutableOp = builder.create<IREE::HAL::ExecutableOp>(
@@ -143,9 +144,14 @@
auto linkedVmModuleOp =
builder.create<IREE::VM::ModuleOp>(moduleOp.getLoc(), "linked_module");
- int executablesLinked = 0;
llvm::SmallVector<IREE::HAL::InterfaceOp, 4> interfaceOps;
int nextEntryPointOrdinal = 0;
+ DenseMap<StringRef, Operation *> symbolMap;
+ DenseMap<Attribute, Attribute> entryPointRefReplacements;
+ auto linkedExecutableBuilder =
+ OpBuilder::atBlockBegin(linkedExecutableOp.getBody());
+ auto linkedTargetBuilder =
+ OpBuilder::atBlockBegin(linkedTargetOp.getBody());
for (auto executableOp : executableOps) {
auto targetOps = llvm::to_vector<4>(
executableOp.getOps<IREE::HAL::ExecutableTargetOp>());
@@ -157,103 +163,112 @@
IREE::HAL::InterfaceOp interfaceOpForExecutable;
for (auto interfaceOp : interfaceOps) {
- if (AreInterfacesEquivalent(interfaceOp,
+ if (areInterfacesEquivalent(interfaceOp,
executableOp.getFirstInterfaceOp())) {
interfaceOpForExecutable = interfaceOp;
+ break;
}
}
if (!interfaceOpForExecutable) {
- builder.setInsertionPoint(linkedTargetOp);
- interfaceOpForExecutable = dyn_cast<IREE::HAL::InterfaceOp>(
- builder.clone(*executableOp.getFirstInterfaceOp()));
+ interfaceOpForExecutable =
+ dyn_cast<IREE::HAL::InterfaceOp>(linkedExecutableBuilder.clone(
+ *executableOp.getFirstInterfaceOp()));
interfaceOpForExecutable.setName(
llvm::formatv("legacy_io_{0}", interfaceOps.size()).str());
interfaceOps.push_back(interfaceOpForExecutable);
}
- // Clone entry point ops, remapping ordinals and updating symbol refs.
- builder.setInsertionPoint(linkedModuleOp);
+ // Clone entry point ops and queue remapping ordinals and updating
+ // symbol refs.
for (auto entryPointOp :
targetOp.getOps<IREE::HAL::ExecutableEntryPointOp>()) {
auto newEntryPointOp =
- builder.create<IREE::HAL::ExecutableEntryPointOp>(
+ linkedTargetBuilder.create<IREE::HAL::ExecutableEntryPointOp>(
entryPointOp.getLoc(), entryPointOp.sym_nameAttr(),
builder.getI32IntegerAttr(nextEntryPointOrdinal++),
builder.getSymbolRefAttr(interfaceOpForExecutable.getName()),
entryPointOp.signatureAttr());
- // Update references to @executable::@target::@entry symbols.
- // SymbolTable::replaceAllSymbolUses only looks at root symbols,
- // which we can't blindly replace (other targets will map to other
- // linked executables).
- auto executableUses =
- SymbolTable::getSymbolUses(executableOp, moduleOp);
- if (!executableUses.hasValue()) continue;
- for (auto executableUse : executableUses.getValue()) {
- auto executableUser = executableUse.getUser();
- // Only process symbols for this @target::@entry.
- auto nestedRefs =
- executableUse.getSymbolRef().getNestedReferences();
- if (nestedRefs.size() != 2 ||
- nestedRefs[0].getValue() != targetOp.sym_name() ||
- nestedRefs[1].getValue() != entryPointOp.sym_name()) {
- continue;
- }
- if (auto dispatchOp =
- dyn_cast<IREE::HAL::CommandBufferDispatchSymbolOp>(
- executableUser)) {
- // New nested reference to the linked exe/target/entry.
- StringRef newExecutableOpSymName =
- linkedExecutableOp
- .getAttrOfType<StringAttr>(
- SymbolTable::getSymbolAttrName())
- .getValue();
- auto newSymbolRefAttr = builder.getSymbolRefAttr(
- newExecutableOpSymName,
- {builder.getSymbolRefAttr(linkedTargetOp),
- builder.getSymbolRefAttr(newEntryPointOp)});
- dispatchOp.setAttr("entry_point", newSymbolRefAttr);
- }
- }
+ // Add to replacement table for fixing up dispatch calls referencing
+ // this entry point.
+ auto oldSymbolRefAttr = builder.getSymbolRefAttr(
+ executableOp.getName(), {builder.getSymbolRefAttr(targetOp),
+ builder.getSymbolRefAttr(entryPointOp)});
+ auto newSymbolRefAttr = builder.getSymbolRefAttr(
+ linkedExecutableOp.getName(),
+ {builder.getSymbolRefAttr(linkedTargetOp),
+ builder.getSymbolRefAttr(newEntryPointOp)});
+ entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
}
- // Clone vm.module ops, including their contents.
+ // Merge the existing vm.module op into the new linked vm.module op.
auto vmModuleOps =
targetOp.getInnerModule().getOps<IREE::VM::ModuleOp>();
if (vmModuleOps.empty()) {
return targetOp.getInnerModule().emitError()
<< "target's outer module does not contain a vm.module op";
}
- auto vmModuleOp = *vmModuleOps.begin();
- builder.setInsertionPoint(&linkedVmModuleOp.getBlock().back());
- // Use a SymbolTable to guard against inserting duplicate symbols.
- SymbolTable symbolTable(linkedVmModuleOp.getOperation());
+ mergeModuleInto(*vmModuleOps.begin(), linkedVmModuleOp, symbolMap);
- for (auto &op : vmModuleOp.getBody()->getOperations()) {
- if (auto terminatorOp = dyn_cast<IREE::VM::ModuleTerminatorOp>(op)) {
- continue;
- }
- if (op.hasTrait<SymbolOpInterface::Trait>() &&
- symbolTable.lookup(dyn_cast<SymbolOpInterface>(op).getName())) {
- continue;
- }
- builder.clone(op);
- }
-
- // Now that we're done cloning its ops, delete the original target op.
targetOp.erase();
+ }
- executablesLinked++;
+ if (executableOp.getOps<IREE::HAL::ExecutableTargetOp>().empty()) {
+ executableOp.erase();
}
}
- if (executablesLinked == 0) {
+ // Update references to @executable::@target::@entry symbols.
+ replaceEntryPointUses(moduleOp, entryPointRefReplacements);
+
+ // Remove if we didn't add anything.
+ if (linkedTargetOp.getOps<IREE::HAL::ExecutableEntryPointOp>().empty()) {
+ linkedTargetOp.erase();
linkedExecutableOp.erase();
}
return success();
}
+ // Destructively merges |sourceModuleOp| into |targetModuleOp|.
+ //
+ // Every op in the source module body except its terminator is moved (not
+ // cloned) to just before the terminator of |targetModuleOp|. A symbol op whose
+ // name is already present in |targetSymbolMap| is not moved and is destroyed
+ // together with |sourceModuleOp|. Every other symbol op is moved and recorded
+ // in |targetSymbolMap|. |sourceModuleOp| is erased before returning.
+ void mergeModuleInto(IREE::VM::ModuleOp sourceModuleOp,
+                      IREE::VM::ModuleOp targetModuleOp,
+                      DenseMap<StringRef, Operation *> &targetSymbolMap) {
+   // Early-increment iteration lets the current op be spliced out of the
+   // source block without invalidating the loop iterator.
+   for (Operation &op :
+        llvm::make_early_inc_range(sourceModuleOp.getBlock())) {
+     if (op.isKnownTerminator()) continue;
+     if (auto symbolInterface = dyn_cast<SymbolOpInterface>(&op)) {
+       // try_emplace inserts only when the name is new and reports whether it
+       // did, so a single hash lookup replaces count() + operator[].
+       if (!targetSymbolMap.try_emplace(symbolInterface.getName(), &op)
+                .second) {
+         // TODO(scotttodd): compare ops to ensure we aren't copying different
+         // things with the same name.
+         continue;
+       }
+     }
+     op.moveBefore(&targetModuleOp.getBlock().back());
+   }
+
+   // Everything worth keeping has been moved out; drop the remaining shell
+   // (terminator plus any skipped duplicate symbols).
+   sourceModuleOp.erase();
+ }
+
+ // Rewrites entry point references on dispatch ops after executables have been
+ // linked.
+ //
+ // For every IREE::HAL::CommandBufferDispatchSymbolOp nested anywhere under
+ // |moduleOp|, if its current `entry_point` attribute is a key in
+ // |replacements|, the attribute is set to the mapped value (which must be a
+ // SymbolRefAttr). Dispatch ops whose entry point is not in the table are left
+ // untouched.
+ //
+ // The whole module is walked, not just top-level mlir::FuncOps, so dispatches
+ // nested inside other region-holding ops are also updated. This matches the
+ // module-wide scope of SymbolTable::getSymbolUses used previously.
+ void replaceEntryPointUses(
+     mlir::ModuleOp moduleOp,
+     const DenseMap<Attribute, Attribute> &replacements) {
+   if (replacements.empty()) return;
+   moduleOp.walk([&](IREE::HAL::CommandBufferDispatchSymbolOp dispatchOp) {
+     auto it = replacements.find(dispatchOp.entry_point());
+     if (it == replacements.end()) return;
+     dispatchOp.entry_pointAttr(it->second.cast<SymbolRefAttr>());
+   });
+ }
+
LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
// Serialize the VM module to bytes.