Improving global folding and IPO for immutable globals. (#18066)

Parameters and constants end up as globals that previously were not
getting marked as immutable unless the globals were default initialized.
Now the GlobalTable tracks whether each global is stored exclusively within
initializers (or functions only called from initializers) in order to
mark it as immutable. IPO was updated to support propagating uniform
immutable global loads across call edges as if they were constants (as
they effectively are just constants stored on the global scope).

Required for #17875 (to avoid treating constants/parameters as dynamic
binding table values).
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
index 44e76a4..4f5ad49 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/export_benchmark_funcs.mlir
@@ -1,9 +1,9 @@
-// RUN: iree-opt --split-input-file --iree-flow-transformation-pipeline --iree-flow-export-benchmark-funcs --verify-diagnostics %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-flow-export-benchmark-funcs-pass --verify-diagnostics %s | FileCheck %s
 
 // Basic usage from the `--iree-native-bindings-support` flag.
 
 // CHECK-LABEL: func private @simpleMul
-util.func public @simpleMul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.module.export} {
+util.func public @simpleMul(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view {
   %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<4xf32>
   %1 = hal.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32>
   %2 = arith.mulf %0, %1 : tensor<4xf32>
@@ -41,8 +41,8 @@
 //     CHECK: util.global private @[[GLOBAL_ARG1:.+]] {{{.+}}} = 0 : i32
 
 //     CHECK: util.func public @while_benchmark()
-// CHECK-DAG:   %[[ARG0:.+]] = util.global.load immutable @[[GLOBAL_ARG0]] : i32
-// CHECK-DAG:   %[[ARG1:.+]] = util.global.load immutable @[[GLOBAL_ARG1]] : i32
+// CHECK-DAG:   %[[ARG0:.+]] = util.global.load @[[GLOBAL_ARG0]] : i32
+// CHECK-DAG:   %[[ARG1:.+]] = util.global.load @[[GLOBAL_ARG1]] : i32
 //     CHECK:   %[[RET0:.+]] = util.call @while(%[[ARG0]], %[[ARG1]])
 //     CHECK:   util.optimization_barrier %[[RET0]] : i32
 //     CHECK:   util.return
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
index e9f0af1..7fdc301 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
@@ -6,10 +6,57 @@
 
 #include "iree/compiler/Dialect/Util/Analysis/GlobalTable.h"
 
+#include "mlir/Analysis/CallGraph.h"
+
 namespace mlir::iree_compiler::IREE::Util {
 
+// Returns a set of all top-level callable ops that are externally reachable.
+// Callables only reachable from initializers are excluded.
+static DenseSet<Operation *>
+calculateExternallyReachableOps(ModuleOp moduleOp) {
+  DenseSet<Operation *> externallyReachableOps;
+
+  // Expensive; we want to avoid this unless the call graph changes.
+  CallGraph callGraph(moduleOp);
+
+  SetVector<CallGraphNode *> worklist;
+  worklist.insert(callGraph.begin(), callGraph.end());
+  while (!worklist.empty()) {
+    auto *node = worklist.pop_back_val();
+    if (node->isExternal()) {
+      // Skip declarations.
+      continue;
+    }
+    auto *callableOp = node->getCallableRegion()->getParentOp();
+    if (isa<IREE::Util::InitializerOpInterface>(callableOp)) {
+      // Initializers are never externally reachable.
+      continue;
+    }
+    bool isExternallyReachable = externallyReachableOps.contains(callableOp);
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(callableOp)) {
+      // Public functions exported on the module are externally reachable.
+      isExternallyReachable |= funcOp.isPublic();
+    }
+    if (isExternallyReachable) {
+      // Insert into the set of reachable ops and also any outgoing calls.
+      // Queue up the edges in the worklist for further processing.
+      externallyReachableOps.insert(callableOp);
+      for (auto outgoingEdge : *node) {
+        auto *calleeNode = outgoingEdge.getTarget();
+        if (!calleeNode->isExternal()) {
+          externallyReachableOps.insert(
+              calleeNode->getCallableRegion()->getParentOp());
+          worklist.insert(outgoingEdge.getTarget());
+        }
+      }
+    }
+  }
+
+  return externallyReachableOps;
+}
+
 GlobalTable::GlobalTable(mlir::ModuleOp moduleOp) : moduleOp(moduleOp) {
-  rebuild();
+  externallyReachableOps = calculateExternallyReachableOps(moduleOp);
 }
 
 void GlobalTable::rebuild() {
@@ -46,6 +93,18 @@
       }
     }
   }
+
+  for (auto &[globalName, global] : globalMap) {
+    bool anyNonInitializerStores = false;
+    for (auto storeOp : global.storeOps) {
+      auto callableOp = storeOp->getParentOfType<CallableOpInterface>();
+      if (externallyReachableOps.contains(callableOp)) {
+        anyNonInitializerStores = true;
+        break;
+      }
+    }
+    global.onlyInitialized = !anyNonInitializerStores;
+  }
 }
 
 Global &GlobalTable::lookup(StringRef globalName) {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.h
index f7c17ce..5da6610 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.h
@@ -31,6 +31,10 @@
   // currently have any input programs that require doing so.
   bool isIndirect = false;
 
+  // True if all stores to the global are performed within initializers or calls
+  // only reachable from initializers.
+  bool onlyInitialized = false;
+
   // All util.global.load ops referencing the global.
   SmallVector<IREE::Util::GlobalLoadOpInterface> loadOps;
   // All util.global.store ops referencing the global.
@@ -81,12 +85,20 @@
 // A constructed table of analyzed globals in a module with some utilities for
 // manipulating them. This is designed for simple uses and more advanced
 // analysis should be performed with an Explorer or DFX.
+//
+// The global table is not built on creation and `rebuild` must be called before
+// querying it.
 struct GlobalTable {
   GlobalTable() = delete;
   explicit GlobalTable(mlir::ModuleOp moduleOp);
 
   MLIRContext *getContext() { return moduleOp.getContext(); }
 
+  // Rebuilds the global table.
+  // Must be called if the table is to be used after any globals or operations
+  // on globals have changed.
+  void rebuild();
+
   // Total number of globals in the module.
   size_t size() const { return globalOrder.size(); }
 
@@ -114,10 +126,13 @@
   void eraseGlobal(StringRef globalName);
 
 private:
-  void rebuild();
-
   // Module under analysis.
   mlir::ModuleOp moduleOp;
+
+  // Top-level callables that are externally reachable.
+  // Excludes initializers or any callable only reachable from initializers.
+  DenseSet<Operation *> externallyReachableOps;
+
   // All globals in the order they are declared by symbol name.
   SmallVector<StringRef> globalOrder;
   // A map of global symbol names to analysis results.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
index c3e719e..467d45f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FoldGlobals.cpp
@@ -158,7 +158,7 @@
   return globalTable.forEach([&](Global &global) {
     if (!global.isCandidate()) {
       return GlobalAction::PRESERVE;
-    } else if (!global.storeOps.empty()) {
+    } else if (!global.storeOps.empty() && !global.onlyInitialized) {
       return GlobalAction::PRESERVE;
     }
     bool didChangeAny = global.op.isGlobalMutable() != false;
@@ -365,28 +365,29 @@
   void runOnOperation() override {
     auto *context = &getContext();
     RewritePatternSet patterns(context);
-
     for (auto *dialect : context->getLoadedDialects()) {
       dialect->getCanonicalizationPatterns(patterns);
     }
     for (auto op : context->getRegisteredOperations()) {
       op.getCanonicalizationPatterns(patterns, context);
     }
-
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 
     auto moduleOp = getOperation();
-    beforeFoldingGlobals =
-        count(moduleOp.getOps<IREE::Util::GlobalOpInterface>());
+    GlobalTable globalTable(moduleOp);
+    beforeFoldingGlobals = globalTable.size();
     for (int i = 0; i < 10; ++i) {
       // TODO(benvanik): determine if we need this expensive folding.
       if (failed(applyPatternsAndFoldGreedily(moduleOp, frozenPatterns))) {
         signalPassFailure();
+        return;
       }
 
-      GlobalTable globalTable(moduleOp);
       bool didChange = false;
 
+      // Rebuild the global table after potential pattern changes.
+      globalTable.rebuild();
+
       LLVM_DEBUG(llvm::dbgs() << "==== inlineConstantGlobalStores ====\n");
       if (inlineConstantGlobalStores(globalTable)) {
         LLVM_DEBUG(moduleOp.dump());
@@ -424,6 +425,7 @@
       }
 
       if (!didChange) {
+        // No changes; complete fixed-point iteration.
         break;
       }
     }
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp
index d9c2f5b..ea45521 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/FuseGlobals.cpp
@@ -59,6 +59,7 @@
     auto moduleOp = getOperation();
 
     GlobalTable globalTable(moduleOp);
+    globalTable.rebuild();
 
     // Build a map of global symbol to a bitvector indicating which globals are
     // stored with the same values in all instances.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
index 6307926..1d6279d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/IPO.cpp
@@ -56,6 +56,7 @@
   // Which args are uniform from all call sites.
   BitVector callerUniformArgs;
   // Values for each arg if they are uniformly constant at all call sites.
+  // May be any constant attribute or an immutable global symbol ref.
   SmallVector<LocAttr> callerUniformArgValues;
   // Uniform call operand index -> deduplicated index.
   // Base/non-duplicated values will be identity.
@@ -69,6 +70,7 @@
   // Which results are uniform from all return sites in the function.
   BitVector calleeUniformResults;
   // Values for each result if they are uniformly constant at all return sites.
+  // May be any constant attribute or an immutable global symbol ref.
   SmallVector<LocAttr> calleeUniformResultValues;
   // Uniform callee return operand index -> deduplicated index.
   // Base/non-duplicated values will be identity.
@@ -94,9 +96,13 @@
         os << "dupe(%arg" << callerUniformArgDupeMap[i] << ") ";
       }
       os << argTypes[i] << " ";
-      if (callerUniformArgValues[i]) {
-        os << "constant = ";
-        callerUniformArgValues[i].attr.print(os);
+      if (auto constant = callerUniformArgValues[i]) {
+        if (isa<SymbolRefAttr>(constant.attr)) {
+          os << "immutable global = ";
+        } else {
+          os << "constant = ";
+        }
+        constant.attr.print(os);
       }
       os << "\n";
     }
@@ -113,9 +119,13 @@
         os << "pass(%arg" << passthroughResultArgs[i] << ") ";
       }
       os << resultTypes[i] << " ";
-      if (calleeUniformResultValues[i]) {
-        os << "constant = ";
-        calleeUniformResultValues[i].attr.print(os);
+      if (auto constant = calleeUniformResultValues[i]) {
+        if (isa<SymbolRefAttr>(constant.attr)) {
+          os << "immutable global = ";
+        } else {
+          os << "constant = ";
+        }
+        constant.attr.print(os);
       }
       os << "\n";
     }
@@ -128,6 +138,17 @@
   }
 };
 
+// Returns a global symbol ref if the value is loaded from an immutable global.
+static SymbolRefAttr matchImmutableGlobalLoad(Value value) {
+  if (auto loadOp = dyn_cast_if_present<IREE::Util::GlobalLoadOpInterface>(
+          value.getDefiningOp())) {
+    if (loadOp.isGlobalImmutable()) {
+      return loadOp.getGlobalAttr();
+    }
+  }
+  return {};
+}
+
 // Note that the analysis results may be incomplete.
 static FuncAnalysis analyzeFuncOp(IREE::Util::FuncOp funcOp,
                                   Explorer &explorer) {
@@ -184,6 +205,12 @@
             value.getType(),
             constantValue,
         };
+      } else if (auto globalRef = matchImmutableGlobalLoad(value)) {
+        analysis.calleeUniformResultValues[i] = {
+            value.getLoc(),
+            value.getType(),
+            globalRef,
+        };
       }
 
       // Check to see if the value returned is the same as previously seen.
@@ -245,6 +272,20 @@
           // Value constant has changed from prior calls: mark non-uniform.
           analysis.callerUniformArgs.reset(i);
         }
+      } else if (auto globalRef = matchImmutableGlobalLoad(value)) {
+        if (!seenArgAttrs[i]) {
+          // First call site with a constant or immutable global: stash so we
+          // can inline it if it's uniform.
+          seenArgAttrs[i] = globalRef;
+          analysis.callerUniformArgValues[i] = {
+              value.getLoc(),
+              value.getType(),
+              globalRef,
+          };
+        } else if (seenArgAttrs[i] != globalRef) {
+          // Value constant has changed from prior calls: mark non-uniform.
+          analysis.callerUniformArgs.reset(i);
+        }
       } else {
         // Check to see if the value is the same as previously seen.
         // This will ensure that across calling functions we set non-uniform
@@ -367,6 +408,14 @@
                                      OpBuilder &builder) {
   Operation *op = nullptr;
 
+  // Immutable global loads are represented as constant symbol refs.
+  if (auto globalRef = dyn_cast<SymbolRefAttr>(constantValue.attr)) {
+    op = builder.create<IREE::Util::GlobalLoadOp>(
+        constantValue.loc.value(), constantValue.type,
+        globalRef.getLeafReference().getValue(),
+        /*is_immutable=*/true);
+  }
+
   // Handle special builtin types that for some reason can't materialize
   // themselves.
   if (arith::ConstantOp::isBuildableWith(constantValue.attr,
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir
index 5d83864..906cf38 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/fold_globals.mlir
@@ -98,6 +98,37 @@
 
 // -----
 
+// CHECK: util.global private @immutable_initializer_local
+util.global private mutable @immutable_initializer_local : index
+// CHECK: util.global private @immutable_initializer_callee
+util.global private mutable @immutable_initializer_callee : index
+// CHECK: util.global private mutable @mutable : index
+util.global private mutable @mutable : index
+util.func private @generate_value() -> index
+util.initializer {
+  %value = util.call @generate_value() : () -> index
+  util.global.store %value, @immutable_initializer_local : index
+  util.return
+}
+util.func @public_func() -> (index, index, index) {
+  util.call @public_callee() : () -> ()
+  // CHECK-DAG: %[[LOCAL:.+]] = util.global.load immutable @immutable_initializer_local
+  %0 = util.global.load @immutable_initializer_local : index
+  // CHECK-DAG: %[[CALLEE:.+]] = util.global.load immutable @immutable_initializer_callee
+  %1 = util.global.load @immutable_initializer_callee : index
+  // CHECK-DAG: %[[MUTABLE:.+]] = util.global.load @mutable
+  %2 = util.global.load @mutable : index
+  // CHECK: return %[[LOCAL]], %[[CALLEE]], %[[MUTABLE]]
+  util.return %0, %1, %2 : index, index, index
+}
+util.func private @public_callee() {
+  %value = util.call @generate_value() : () -> index
+  util.global.store %value, @mutable : index
+  util.return
+}
+
+// -----
+
 // CHECK: util.global private mutable @used0 = 5 : index
 util.global private mutable @used0 = 5 : index
 // CHECK: util.global private mutable @used1 : index
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/ipo.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/ipo.mlir
index 27ead2b..61a0ec4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/ipo.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/ipo.mlir
@@ -209,6 +209,61 @@
 
 // -----
 
+// Tests that uniform args that come from immutable globals are inlined into
+// callees.
+
+util.global private @global : index
+
+// CHECK-LABEL: util.func private @immutable_global_arg_callee
+// CHECK-SAME: () -> index
+util.func private @immutable_global_arg_callee(%arg0: index) -> index {
+  // CHECK: %[[GLOBAL_VALUE:.+]] = util.global.load immutable @global
+  // CHECK: %[[ADD:.+]] = arith.addi %[[GLOBAL_VALUE]], %[[GLOBAL_VALUE]]
+  %add = arith.addi %arg0, %arg0 : index
+  // CHECK: util.return %[[ADD]]
+  util.return %add : index
+}
+
+// CHECK: util.func public @immutable_global_arg_caller
+util.func public @immutable_global_arg_caller() -> index {
+  %global_value = util.global.load immutable @global : index
+  // CHECK: %[[RET:.+]] = util.call @immutable_global_arg_callee() : () -> index
+  %ret = util.call @immutable_global_arg_callee(%global_value) : (index) -> index
+  // CHECK: util.return %[[RET]]
+  util.return %ret : index
+}
+
+// -----
+
+// Tests that uniform results that are immutable global loads get inlined into
+// callers.
+
+util.global private @global : index
+
+// CHECK-LABEL: util.func private @immutable_global_result_callee
+// CHECK-SAME: (%[[ARG0:.+]]: i1)
+util.func private @immutable_global_result_callee(%arg0: i1) -> index {
+  %global_value = util.global.load immutable @global : index
+  cf.cond_br %arg0, ^bb1, ^bb2
+^bb1:
+  // CHECK: util.return
+  util.return %global_value : index
+^bb2:
+  // CHECK: util.return
+  util.return %global_value : index
+}
+
+// CHECK: util.func public @immutable_global_result_caller(%[[ARG0:.+]]: i1)
+util.func public @immutable_global_result_caller(%arg0: i1) -> index {
+  // CHECK: call @immutable_global_result_callee(%[[ARG0]]) : (i1) -> ()
+  %ret0 = util.call @immutable_global_result_callee(%arg0) : (i1) -> index
+  // CHECK: %[[GLOBAL_VALUE:.+]] = util.global.load immutable @global
+  // CHECK: util.return %[[GLOBAL_VALUE]]
+  util.return %ret0 : index
+}
+
+// -----
+
 // Tests that uniformly duplicate constant results get combined/inlined.
 
 // CHECK-LABEL: util.func private @dupe_constant_result_callee