Bucket before deduping executables. (#13652)
Serves as little preprocessing step to reduce the number of pairwise
comparisons that need to be done by only comparing pairwise per bucket
(in one case >>50x speedup).
Only change is moving these to operate per bucket.
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
index e36179a..fa71ba4 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
@@ -207,47 +207,66 @@
void runOnOperation() override {
auto moduleOp = getOperation();
- auto executableOps = llvm::to_vector<8>(moduleOp.getOps<ExecutableOp>());
- auto builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+ // Bucket based on the hash of the names of at most the first 5 ops.
+ // 5 was randomly chosen to be small enough to not increase overhead much,
+ // but giving at least enough of a sample that there is some bucketing. This
+ // was not empiraclly deetermined.
+ llvm::MapVector<uint32_t, SmallVector<ExecutableOp, 3>> executableOpsMap;
+ totalExecutables = 0;
+ for (auto op : moduleOp.getOps<ExecutableOp>()) {
+ int count = 0;
+ llvm::hash_code hash(1);
+ op.walk([&](Operation *it) {
+ hash = llvm::hash_combine(hash, it->getName());
+ return (++count >= 5) ? WalkResult::interrupt() : WalkResult::advance();
+ });
+ executableOpsMap[hash_value(hash)].push_back(op);
+ ++totalExecutables;
+ }
+
+ auto builder = OpBuilder::atBlockBegin(moduleOp.getBody());
SmallVector<ExecutableOp, 3> duplicateExecutableOps;
DenseMap<Attribute, SymbolRefAttr> entryPointRefReplacements;
// For each executable, find the first executable which it is equivalent to.
- for (int i = executableOps.size() - 1; i >= 0; --i) {
- auto duplicateExecutableOp = executableOps[i];
+ for (auto &[key, executableOps] : executableOpsMap) {
+ (void)key;
+ for (int i = executableOps.size() - 1; i >= 0; --i) {
+ auto duplicateExecutableOp = executableOps[i];
- for (int j = 0; j < i; ++j) {
- auto referenceExecutableOp = executableOps[j];
- if (!isStructurallyEquivalentTo(duplicateExecutableOp.getBody(),
- referenceExecutableOp.getBody())) {
- continue;
+ for (int j = 0; j < i; ++j) {
+ auto referenceExecutableOp = executableOps[j];
+ if (!isStructurallyEquivalentTo(duplicateExecutableOp.getBody(),
+ referenceExecutableOp.getBody())) {
+ continue;
+ }
+
+ // Found an equivalent executable! Record it and move on to the next.
+ duplicateExecutableOps.push_back(duplicateExecutableOp);
+
+ // Record entry point reference replacements.
+ for (auto [oldExportOp, newExportOp] :
+ llvm::zip_equal(duplicateExecutableOp.getBlock()
+ .getOps<ExecutableExportOp>(),
+ referenceExecutableOp.getBlock()
+ .getOps<ExecutableExportOp>())) {
+ auto oldSymbolRefAttr = SymbolRefAttr::get(
+ builder.getContext(), duplicateExecutableOp.getName(),
+ {SymbolRefAttr::get(builder.getContext(),
+ oldExportOp.getSymName())});
+ auto newSymbolRefAttr = SymbolRefAttr::get(
+ builder.getContext(), referenceExecutableOp.getName(),
+ {SymbolRefAttr::get(builder.getContext(),
+ newExportOp.getSymName())});
+ entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
+ }
+
+ break;
}
-
- // Found an equivalent executable! Record it and move on to the next.
- duplicateExecutableOps.push_back(duplicateExecutableOp);
-
- // Record entry point reference replacements.
- for (auto [oldExportOp, newExportOp] : llvm::zip_equal(
- duplicateExecutableOp.getBlock().getOps<ExecutableExportOp>(),
- referenceExecutableOp.getBlock()
- .getOps<ExecutableExportOp>())) {
- auto oldSymbolRefAttr = SymbolRefAttr::get(
- builder.getContext(), duplicateExecutableOp.getName(),
- {SymbolRefAttr::get(builder.getContext(),
- oldExportOp.getSymName())});
- auto newSymbolRefAttr = SymbolRefAttr::get(
- builder.getContext(), referenceExecutableOp.getName(),
- {SymbolRefAttr::get(builder.getContext(),
- newExportOp.getSymName())});
- entryPointRefReplacements[oldSymbolRefAttr] = newSymbolRefAttr;
- }
-
- break;
}
}
- totalExecutables = executableOps.size();
executablesDeduplicated = duplicateExecutableOps.size();
remainingExecutables = totalExecutables - executablesDeduplicated;