Make iree-vm-deduplicate-rodata O(ops) instead of O(replacements*ops). (#14782)
Fixes #14754.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp
index 0cbfdfd..5bb794c 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DeduplicateRodata.cpp
@@ -56,7 +56,7 @@
bucketOps.push_back(rodataOp);
}
- SymbolTable symbolTable(moduleOp);
+ DenseMap<SymbolRefAttr, SymbolRefAttr> replacements;
for (auto bucketKV : bucketedOps) {
auto &bucketOps = bucketKV.second;
@@ -80,17 +80,23 @@
}
// Point all duplicates at the base op.
+ auto baseName = FlatSymbolRefAttr::get(baseOp.getNameAttr());
for (auto duplicateOp : bucketOps) {
- if (failed(symbolTable.replaceAllSymbolUses(
- duplicateOp, baseOp.getNameAttr(), moduleOp))) {
- duplicateOp.emitError()
- << "failed to replace duplicate rodata op with base op "
- << baseOp.getName();
- return signalPassFailure();
- }
+ replacements.insert(std::make_pair(
+ FlatSymbolRefAttr::get(duplicateOp.getSymNameAttr()), baseName));
duplicateOp.erase();
}
}
+
+ AttrTypeReplacer replacer;
+ replacer.addReplacement(
+ [&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
+ auto replacement = replacements.find(attr);
+ if (replacement != replacements.end())
+ return {replacement->getSecond(), WalkResult::skip()};
+ return {attr, WalkResult::skip()};
+ });
+ moduleOp.walk([&](Operation *op) { replacer.replaceElementsIn(op); });
}
};