// blob: 2ee46924e719a6969b4f78ad3328ecb542367190 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <algorithm>
#include <iterator>
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTraits.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "iree-util-fold-globals"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace Util {
namespace {
// Returns the number of elements in |range|, measured as the iterator distance
// between range.begin() and range.end().
template <typename R>
static size_t count(R &&range) {
  auto first = range.begin();
  auto last = range.end();
  return static_cast<size_t>(std::distance(first, last));
}
// Bookkeeping for a single util.global op and the ops that reference it.
// Member order matters: GlobalTable brace-initializes {ordinal, op}.
struct Global {
  // Index of the global in GlobalTable::globalOrder at the time the table was
  // (re)built. Not renumbered when other globals are deleted afterwards.
  size_t ordinal{0};
  // The global op itself.
  IREE::Util::GlobalOp op;
  // Set when a util.global.address op references the global.
  bool isIndirect{false};
  // All util.global.load ops found that read the global.
  SmallVector<IREE::Util::GlobalLoadOp> loadOps;
  // All util.global.store ops found that write the global.
  SmallVector<IREE::Util::GlobalStoreOp> storeOps;
};
// Result returned by a GlobalTable::forEach callback for each visited global.
enum class GlobalAction {
  // Nothing changed; keep the global and move on.
  PRESERVE = 0,
  // The global (or ops using it) changed; keep the global and report a change.
  UPDATE = 1,
  // Erase the global and any stores still recorded for it; report a change.
  DELETE = 2,
};
struct GlobalTable {
mlir::ModuleOp moduleOp;
SmallVector<StringRef> globalOrder;
DenseMap<StringRef, Global> globalMap;
size_t size() const { return globalOrder.size(); }
explicit GlobalTable(mlir::ModuleOp moduleOp) : moduleOp(moduleOp) {
rebuild();
}
void rebuild() {
globalOrder.clear();
globalMap.clear();
for (auto globalOp : moduleOp.getOps<IREE::Util::GlobalOp>()) {
globalMap[globalOp.getName()] = Global{globalOrder.size(), globalOp};
globalOrder.push_back(globalOp.getName());
}
for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
callableOp.walk([&](Operation *op) {
if (auto addressOp = dyn_cast<IREE::Util::GlobalAddressOp>(op)) {
globalMap[addressOp.global()].isIndirect = true;
} else if (auto loadOp = dyn_cast<IREE::Util::GlobalLoadOp>(op)) {
globalMap[loadOp.global()].loadOps.push_back(loadOp);
} else if (auto storeOp = dyn_cast<IREE::Util::GlobalStoreOp>(op)) {
globalMap[storeOp.global()].storeOps.push_back(storeOp);
}
});
}
}
bool forEach(std::function<GlobalAction(Global &global)> fn) {
bool didChange = false;
for (size_t i = 0; i < size();) {
auto globalName = globalOrder[i];
auto action = fn(globalMap[globalName]);
switch (action) {
case GlobalAction::PRESERVE: {
++i;
break;
}
case GlobalAction::UPDATE: {
didChange |= true;
++i;
break;
}
case GlobalAction::DELETE: {
didChange |= true;
auto &global = globalMap[globalName];
assert(!global.op.isPublic() && "can't delete public globals");
assert(global.loadOps.empty() && "must not be used");
for (auto storeOp : global.storeOps) {
storeOp.erase();
}
global.op.erase();
globalMap.erase(globalName);
globalOrder.erase(globalOrder.begin() + i);
break;
}
}
}
return didChange;
}
};
// Inlines constant stores into global initializers if always stored to the
// same value.
//
// Example:
//  util.global mutable @a : i32
//  builtin.func @foo() {
//    %c5 = arith.constant 5 : i32
//    util.global.store %c5, @a : i32
//  }
//  builtin.func @bar() {
//    %c5 = arith.constant 5 : i32
//    util.global.store %c5, @a : i32
//  }
// ->
//  util.global mutable @a = 5 : i32
//
// The mutable flag is not touched here; updateGlobalImmutability drops it once
// no stores remain.
//
// Returns true only if at least one store was folded away. Globals that have
// no stores are reported as unchanged so that the caller's fixed-point loop
// can terminate.
static bool inlineConstantGlobalStores(GlobalTable &globalTable) {
  return globalTable.forEach([&](Global &global) {
    if (global.isIndirect) return GlobalAction::PRESERVE;

    // Without stores there is nothing to inline. Reporting UPDATE here (as a
    // global that merely has an initial value would otherwise do) makes every
    // round look like progress and forces the pass to run all of its rounds.
    if (global.storeOps.empty()) return GlobalAction::PRESERVE;

    // Find the constant value used in all stores.
    // All stores must match the initial value of the global _or_ the global
    // must be uninitialized. A proper dataflow analysis would let us identify
    // the cases where there are loads before stores and the implicit 0 initial
    // value for uninitialized globals would be observable.
    Attribute constantValue = global.op.initial_valueAttr();
    for (auto storeOp : global.storeOps) {
      Attribute valueAttr;
      if (!matchPattern(storeOp.value(), m_Constant(&valueAttr))) {
        // Stored value is not a constant.
        constantValue = {};
        break;
      }
      if (!constantValue) {
        // First found.
        constantValue = valueAttr;
      } else if (constantValue != valueAttr) {
        // Non-uniform.
        constantValue = {};
        break;
      }
    }
    if (!constantValue) return GlobalAction::PRESERVE;

    // Propagate constant into the initial value. Note that there may have been
    // a previous initial value that is being replaced (with an equal value).
    global.op.initial_valueAttr(constantValue);

    // Remove all of the stores; they are now redundant with the initializer.
    for (auto storeOp : global.storeOps) {
      storeOp.erase();
    }
    global.storeOps.clear();

    return GlobalAction::UPDATE;
  });
}
// Renames globals that are always chained through each other.
//
// Example:
//  util.global mutable @chained0 : i32
//  util.global mutable @chained1 : i32
//  builtin.func @foo() {
//    %0 = util.global.load @chained0 : i32
//    util.global.store %0, @chained1 : i32
//  }
// ->
//  util.global mutable @chained0 : i32
//
// A non-public, non-indirect global is replaced by the global it is chained
// to when every store to it stores a value produced directly by a
// util.global.load of that same other global. Its loads are redirected to the
// other global, its stores are erased, and the global itself is deleted.
//
// Returns true if any global was renamed away.
static bool renameChainedGlobals(GlobalTable &globalTable) {
  return globalTable.forEach([&](Global &global) {
    if (global.isIndirect) return GlobalAction::PRESERVE;
    if (global.op.isPublic()) return GlobalAction::PRESERVE;

    // Find the other symbol this global is chained with by looking for uniform
    // stores. Note that we don't care about initializers.
    FlatSymbolRefAttr aliasName;
    for (auto storeOp : global.storeOps) {
      // The stored value must come straight from a load of another global and
      // every store must agree on which global that is.
      auto loadOp = dyn_cast_or_null<IREE::Util::GlobalLoadOp>(
          storeOp.value().getDefiningOp());
      if (!loadOp || (aliasName && aliasName != loadOp.globalAttr())) {
        aliasName = {};
        break;
      }
      aliasName = loadOp.globalAttr();
    }
    if (!aliasName) return GlobalAction::PRESERVE;

    // A global can end up chained to itself: with a cycle such as
    // @a <- @b and @b <- @a, renaming @a's loads to @b turns @b's store into a
    // store of a load of @b. Aliasing a global to itself would append to
    // |global.loadOps| while iterating it and then delete a global that is
    // still loaded, so leave it alone.
    if (aliasName.getValue() == global.op.getName()) {
      return GlobalAction::PRESERVE;
    }

    // Look the alias up without operator[]: inserting a missing key could
    // rehash the map and invalidate |global|, which refers into it.
    auto aliasIt = globalTable.globalMap.find(aliasName.getValue());
    if (aliasIt == globalTable.globalMap.end()) return GlobalAction::PRESERVE;
    Global &aliasGlobal = aliasIt->getSecond();

    // Replace all loads from the global with the aliased global and keep the
    // table in sync by handing the loads over to the alias.
    for (auto loadOp : global.loadOps) {
      loadOp.globalAttr(aliasName);
      aliasGlobal.loadOps.push_back(loadOp);
    }
    global.loadOps.clear();

    // Erase all stores to the global.
    for (auto storeOp : global.storeOps) {
      storeOp.erase();
    }
    global.storeOps.clear();

    return GlobalAction::DELETE;
  });
}
// Marks mutable globals as immutable once nothing stores to them anymore.
//
// Only globals that are not public, not referenced by address, currently
// mutable, and have no recorded stores are changed.
//
// Example:
//  util.global private mutable @a = 5 : i32
// ->
//  util.global private @a = 5 : i32
//
// Returns true if the mutable flag was removed from any global.
static bool updateGlobalImmutability(GlobalTable &globalTable) {
  return globalTable.forEach([](Global &global) {
    const bool canFreeze = !global.isIndirect && !global.op.isPublic() &&
                           global.storeOps.empty() && global.op.isMutable();
    if (!canFreeze) return GlobalAction::PRESERVE;
    global.op.removeIs_mutableAttr();
    return GlobalAction::UPDATE;
  });
}
// Attempts to create a constant op producing |attr| as a value of |type| at
// the insertion point of |builder|.
//
// Tries arith.constant first, then the standard dialect constant op, and
// finally asks the dialect that owns |attr| to materialize it.
// Returns a null Value if none of them can build the constant.
static Value tryMaterializeConstant(Location loc, Type type, Attribute attr,
                                    OpBuilder &builder) {
  // Common case fast-path.
  if (arith::ConstantOp::isBuildableWith(attr, type)) {
    return builder.create<arith::ConstantOp>(loc, type, attr);
  }
  if (mlir::ConstantOp::isBuildableWith(attr, type)) {
    return builder.create<mlir::ConstantOp>(loc, type, attr);
  }

  // Last resort: let the attribute's dialect try. This may fail!
  Operation *materializedOp =
      attr.getDialect().materializeConstant(builder, attr, type, loc);
  if (!materializedOp) return nullptr;
  return materializedOp->getResult(0);
}
// Inlines constant global values that are known to not change.
//
// Applies to immutable, non-indirect globals with an initial value, no stores,
// no "noinline" attribute, and a non-reference type. Every load of such a
// global is replaced with a constant materialized from the initial value.
// Private globals are deleted once all of their loads are gone.
//
// Returns true if any load was inlined or any global was deleted. A global
// whose loads were all inlined in an earlier round is reported as unchanged so
// that the caller's fixed-point loop can terminate.
static bool inlineConstantGlobalLoads(GlobalTable &globalTable) {
  return globalTable.forEach([&](Global &global) {
    if (global.isIndirect) return GlobalAction::PRESERVE;
    if (!global.storeOps.empty()) return GlobalAction::PRESERVE;
    if (global.op.isMutable()) return GlobalAction::PRESERVE;
    if (global.op->hasAttr("noinline")) return GlobalAction::PRESERVE;
    if (!global.op.initial_valueAttr()) return GlobalAction::PRESERVE;
    if (global.op.type().isa<IREE::Util::ReferenceTypeInterface>()) {
      // We only inline value types; reference types have meaning as globals.
      return GlobalAction::PRESERVE;
    }

    // Inline initial value into the loads, stopping at the first failure.
    size_t inlinedCount = 0;
    for (auto loadOp : global.loadOps) {
      OpBuilder builder(loadOp);
      auto constantValue =
          tryMaterializeConstant(loadOp.getLoc(), loadOp.result().getType(),
                                 global.op.initial_valueAttr(), builder);
      if (!constantValue) {
        // Failed to materialize the constant. This is expected to fail for all
        // the others as well so we bail here. This may be intentional for
        // example if some value-type has no default value (similar to a
        // deleted default constructor in C++) - in which case the global is
        // preserved as-is.
        break;
      }
      loadOp.replaceAllUsesWith(constantValue);
      loadOp.erase();
      ++inlinedCount;
    }

    // Forget exactly the loads that were erased so the table never holds
    // handles to dead ops, even if materialization failed part-way through.
    const bool inlinedAll = inlinedCount == global.loadOps.size();
    global.loadOps.erase(global.loadOps.begin(),
                         global.loadOps.begin() + inlinedCount);

    const GlobalAction keptAction =
        inlinedCount > 0 ? GlobalAction::UPDATE : GlobalAction::PRESERVE;
    if (!inlinedAll) return keptAction;

    // Only delete if private; otherwise report a change only if one was made.
    return global.op.isPrivate() ? GlobalAction::DELETE : keptAction;
  });
}
// Erases all globals in the module that are never loaded.
// This differs from SymbolDCE in that initializations and storage operations
// are discarded along with the global.
//
// Public globals and globals referenced by address are always kept.
// Returns true if any global was erased.
static bool eraseUnusedGlobals(GlobalTable &globalTable) {
  return globalTable.forEach([](Global &global) {
    const bool isRemovable = !global.isIndirect && !global.op.isPublic() &&
                             global.loadOps.empty();
    // No loads: the table removes the global and its stores entirely.
    return isRemovable ? GlobalAction::DELETE : GlobalAction::PRESERVE;
  });
}
// Deduplicates globals with identical constant initial values.
// This is a simplified and safer version of the global fusion pass.
//
// Eligible globals are non-public, not referenced by address, have an initial
// value, and have no stores. Eligible globals sharing the same initial value
// attribute are merged into whichever of them appears first in the module:
// loads of the others are redirected to it and the others are erased. The
// surviving global receives a fused location covering all merged globals.
//
// Returns true if any globals were merged, in which case |globalTable| has
// been rebuilt.
static bool deduplicateConstantGlobals(GlobalTable &globalTable) {
  // Build sets of all equivalent globals.
  llvm::EquivalenceClasses<StringRef> ec;
  DenseMap<Attribute, StringRef> leaderMap;
  // Iterate by reference: copying each entry would deep-copy the load/store
  // op lists of every global in the module.
  for (auto &globalIt : globalTable.globalMap) {
    auto &global = globalIt.getSecond();
    if (global.isIndirect || global.op.isPublic() ||
        !global.op.initial_value().hasValue() || !global.storeOps.empty()) {
      // Not eligible for deduplication.
      continue;
    }
    auto it =
        leaderMap.insert({global.op.initial_valueAttr(), global.op.getName()});
    if (it.second) {
      // Inserted new.
      ec.insert(global.op.getName());
    } else {
      // Existing global with the same value.
      ec.unionSets(global.op.getName(), it.first->second);
    }
  }

  bool didChange = false;
  for (auto it = ec.begin(), end = ec.end(); it != end; ++it) {
    if (!it->isLeader()) continue;  // Ignore non-leader sets.
    // Sets with a single member have nothing to deduplicate.
    if (++ec.member_begin(it) == ec.member_end()) continue;
    IREE::Util::GlobalOp baseGlobalOp = globalTable.globalMap[it->getData()].op;

    // Build fused location from all of the globals and pick the global that
    // comes first in the module as the one to keep.
    SmallVector<Location> locs;
    for (auto mi = ec.member_begin(it); mi != ec.member_end(); ++mi) {
      Global &global = globalTable.globalMap[*mi];
      locs.push_back(global.op.getLoc());
      if (global.op->isBeforeInBlock(baseGlobalOp)) {
        baseGlobalOp = global.op;
      }
    }
    auto fusedLoc = FusedLoc::get(baseGlobalOp.getContext(), locs);

    // Update base global location.
    baseGlobalOp->setLoc(fusedLoc);

    // Replace all other globals to point at the new one.
    auto baseGlobalNameAttr = FlatSymbolRefAttr::get(
        baseGlobalOp.getContext(), baseGlobalOp.getSymbolName());
    for (auto mi = ec.member_begin(it); mi != ec.member_end(); ++mi) {
      Global &global = globalTable.globalMap[*mi];
      if (global.op == baseGlobalOp) continue;
      for (auto loadOp : global.loadOps) {
        loadOp.globalAttr(baseGlobalNameAttr);
      }
      global.op.erase();
    }

    didChange |= true;
  }

  if (didChange) {
    // We could keep the table up to date to avoid the rebuilds by merging all
    // loads into the base global.
    globalTable.rebuild();
  }

  return didChange;
}
// Module pass that repeatedly canonicalizes the module and then applies the
// global folding steps defined above until a round makes no change, or 10
// rounds have run.
class FoldGlobalsPass
    : public PassWrapper<FoldGlobalsPass, OperationPass<mlir::ModuleOp>> {
 public:
  explicit FoldGlobalsPass() = default;
  // Copies start from fresh state; nothing is carried over from |pass|.
  FoldGlobalsPass(const FoldGlobalsPass &pass) {}

  StringRef getArgument() const override { return "iree-util-fold-globals"; }

  StringRef getDescription() const override {
    return "Folds duplicate globals and propagates constants.";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::StandardOpsDialect>();
    registry.insert<mlir::arith::ArithmeticDialect>();
    registry.insert<IREE::Util::UtilDialect>();
  }

  void runOnOperation() override {
    // Gather canonicalization patterns from every loaded dialect and every
    // registered op once up front; they are reused for each round.
    auto *context = &getContext();
    OwningRewritePatternList patterns(context);
    for (auto *dialect : context->getLoadedDialects()) {
      dialect->getCanonicalizationPatterns(patterns);
    }
    for (auto op : context->getRegisteredOperations()) {
      op.getCanonicalizationPatterns(patterns, context);
    }
    FrozenRewritePatternSet frozenPatterns(std::move(patterns));

    // Folding steps in the order they are applied within a round. Each one
    // returns true if it changed the module.
    struct FoldStep {
      const char *banner;
      bool (*apply)(GlobalTable &);
    };
    const FoldStep foldSteps[] = {
        {"==== inlineConstantGlobalStores ====\n", inlineConstantGlobalStores},
        {"==== renameChainedGlobals ====\n", renameChainedGlobals},
        {"==== updateGlobalImmutability ====\n", updateGlobalImmutability},
        {"==== inlineConstantGlobalLoads ====\n", inlineConstantGlobalLoads},
        {"==== eraseUnusedGlobals ====\n", eraseUnusedGlobals},
        {"==== deduplicateConstantGlobals ====\n", deduplicateConstantGlobals},
    };

    auto moduleOp = getOperation();
    beforeFoldingGlobals = count(moduleOp.getOps<IREE::Util::GlobalOp>());
    for (int i = 0; i < 10; ++i) {
      // TODO(benvanik): determine if we need this expensive folding.
      if (failed(applyPatternsAndFoldGreedily(moduleOp, frozenPatterns))) {
        // The pass has failed; do not keep rewriting the module.
        signalPassFailure();
        return;
      }

      // The table is rebuilt every round because canonicalization may have
      // added or removed ops that reference globals.
      GlobalTable globalTable(moduleOp);
      bool didChange = false;
      for (const FoldStep &step : foldSteps) {
        LLVM_DEBUG(llvm::dbgs() << step.banner);
        if (step.apply(globalTable)) {
          LLVM_DEBUG(moduleOp.dump());
          didChange = true;
        }
      }

      if (!didChange) break;
    }
    afterFoldingGlobals = count(moduleOp.getOps<IREE::Util::GlobalOp>());
  }

 private:
  Statistic beforeFoldingGlobals{this, "global ops before folding",
                                 "Number of util.global ops before folding"};
  Statistic afterFoldingGlobals{this, "global ops after folding",
                                "Number of util.global ops after folding"};
};
} // namespace
std::unique_ptr<OperationPass<mlir::ModuleOp>> createFoldGlobalsPass() {
return std::make_unique<FoldGlobalsPass>();
}
// Static registration object for FoldGlobalsPass; constructed during static
// initialization of this translation unit.
static PassRegistration<FoldGlobalsPass> pass;
} // namespace Util
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir