Add statistics to Util/Transforms/FoldGlobals.cpp. (#7535)
Following up on https://github.com/google/iree/pull/7312#discussion_r728261108, this adds more [statistics](https://mlir.llvm.org/docs/PassManagement/#pass-statistics) to IREE's compiler passes. I personally build with the CMake option `-DLLVM_FORCE_ENABLE_STATS=ON` so the `-pass-statistics` flag works in `RelWithDebInfo` builds, but we don't yet have anything like that configured on our CI or recommended in our developer documentation. See also https://github.com/google/iree/issues/6161#issuecomment-912635307 for how this could be useful.
Sample output on [bert_encoder_unrolled_fake_weights.mlir](https://github.com/google/iree/blob/main/iree/test/e2e/models/bert_encoder_unrolled_fake_weights.mlir) (fake weights are folded aggressively 😛):
```
mlir::iree_compiler::IREE::Util::`anonymous-namespace'::FoldGlobalsPass
(S) 1113 global ops before folding - Number of util.global ops before folding
(S) 7 global ops after folding - Number of util.global ops after folding
```
diff --git a/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp b/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
index 94d5b26..10b012c 100644
--- a/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
+++ b/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
@@ -31,6 +31,11 @@
namespace Util {
namespace {
+template <typename R>
+static size_t count(R &&range) {
+ return std::distance(range.begin(), range.end());
+}
+
struct Global {
size_t ordinal = 0;
IREE::Util::GlobalOp op;
@@ -356,6 +361,9 @@
class FoldGlobalsPass
: public PassWrapper<FoldGlobalsPass, OperationPass<mlir::ModuleOp>> {
public:
+ explicit FoldGlobalsPass() = default;
+ FoldGlobalsPass(const FoldGlobalsPass &pass) {}
+
StringRef getArgument() const override { return "iree-util-fold-globals"; }
StringRef getDescription() const override {
@@ -382,6 +390,7 @@
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
auto moduleOp = getOperation();
+ beforeFoldingGlobals = count(moduleOp.getOps<IREE::Util::GlobalOp>());
for (int i = 0; i < 10; ++i) {
// TODO(benvanik): determine if we need this expensive folding.
if (failed(applyPatternsAndFoldGreedily(moduleOp, frozenPatterns))) {
@@ -429,7 +438,15 @@
if (!didChange) break;
}
+
+ afterFoldingGlobals = count(moduleOp.getOps<IREE::Util::GlobalOp>());
}
+
+ private:
+ Statistic beforeFoldingGlobals{this, "global ops before folding",
+ "Number of util.global ops before folding"};
+ Statistic afterFoldingGlobals{this, "global ops after folding",
+ "Number of util.global ops after folding"};
};
} // namespace