blob: 7b98e024edc1876fb1036fe0ad596731d40ef29d [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Utils/ModuleUtils.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Operation.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
namespace iree_compiler {
// Renames the symbol |op| to a name of the form `<original>_<N>` (N counting
// up from 0) that is not currently present in |symbolTable|, and rewrites the
// symbol uses found under |moduleOp| to reference the new name.
//
// |symbolTable| must be the table that currently tracks |op|. It is updated
// as part of the rename so that the old name becomes available again and the
// new name resolves to |op|.
//
// Uniqueness is only checked against |symbolTable|; callers that are about to
// move |op| into a different symbol table must check that table themselves.
static void renameWithDisambiguatedName(Operation *op, Operation *moduleOp,
                                        SymbolTable &symbolTable) {
  StringRef originalName = SymbolTable::getSymbolName(op).getValue();

  // This could stand to be rewritten noting that nested symbol refs exist.
  // Iteratively try suffixes until we find one that isn't used.
  // It'd be nice if there was a SymbolTable::getUniqueName or something.
  std::string disambiguatedName;
  int uniqueingCounter = 0;
  do {
    disambiguatedName =
        llvm::formatv("{0}_{1}", originalName, uniqueingCounter++).str();
  } while (symbolTable.lookup(disambiguatedName));

  // Point all users at the new name before the symbol itself is renamed.
  SymbolTableCollection symbolTables;
  SymbolUserMap symbolUsers(symbolTables, moduleOp);
  mlir::StringAttr nameAttr =
      mlir::StringAttr::get(op->getContext(), disambiguatedName);
  symbolUsers.replaceAllUsesWith(op, nameAttr);

  // The static SymbolTable::setSymbolName only rewrites the name attribute on
  // |op|; it does not touch any SymbolTable instance. Without re-registering
  // the op, |symbolTable| would keep mapping the old name to |op|, and a later
  // insert() of another op under that old name would be treated as a
  // collision and get renamed. remove() and insert() both key off the op's
  // current name, so the order of these three calls matters.
  symbolTable.remove(op);
  SymbolTable::setSymbolName(op, disambiguatedName);
  symbolTable.insert(op);
}
// Destructively merges the top-level symbol ops of the |sourceOp| module into
// the |targetOp| module at the insertion point of |targetBuilder|, registering
// each moved op in |targetSymbolTable|.
//
// Name conflicts with symbols already in |targetSymbolTable| are resolved as
// follows:
//   - private source symbol that is equivalent to the target symbol (ignoring
//     locations): the source op is skipped and stays in the source module;
//   - private source symbol that differs: the source op is renamed;
//   - non-private source symbol and private target symbol: the target op is
//     renamed and the source op keeps its name;
//   - both non-private: an error is emitted on the source op and failure is
//     returned. Ops handled before that point have already been moved.
//
// NOTE(review): terminators and ops that do not implement SymbolOpInterface
// are not moved; they remain in the source module.
LogicalResult mergeModuleInto(Operation *sourceOp, Operation *targetOp,
                              SymbolTable &targetSymbolTable,
                              OpBuilder &targetBuilder) {
  // Capture source information we need prior to destructively merging.
  SymbolTable sourceSymbolTable(sourceOp);
  auto &sourceBlock = sourceOp->getRegion(0).front();
  auto sourceOps = llvm::to_vector<8>(
      llvm::map_range(sourceBlock, [&](Operation &op) { return &op; }));

  // Resolve conflicts and move the op.
  //
  // The locals below must not reuse the names |sourceOp| / |targetOp|: the
  // rename helper needs the enclosing *module* ops in order to find and
  // rewrite the users of the renamed symbol.
  for (Operation *op : sourceOps) {
    if (op->hasTrait<OpTrait::IsTerminator>()) continue;
    auto symbolOp = dyn_cast<SymbolOpInterface>(op);
    if (!symbolOp) continue;

    auto symbolName = symbolOp.getName();

    // Resolve symbol name conflicts.
    if (Operation *conflictingOp = targetSymbolTable.lookup(symbolName)) {
      if (symbolOp.getVisibility() == SymbolTable::Visibility::Private) {
        // Private symbols can be safely folded into duplicates or renamed.
        if (OperationEquivalence::isEquivalentTo(
                conflictingOp, op, OperationEquivalence::exactValueMatch,
                OperationEquivalence::exactValueMatch,
                OperationEquivalence::Flags::IgnoreLocations)) {
          // Optimization: skip over duplicate private symbols.
          // We could let CSE do this later, but we may as well check here.
          continue;
        }
        // Preserve the op but give it a unique name. The helper only checks
        // the source table, so keep going until the name is also free in the
        // target; otherwise the insert() below would rename the op on its own
        // without updating the users of the symbol.
        do {
          renameWithDisambiguatedName(op, sourceOp, sourceSymbolTable);
        } while (targetSymbolTable.lookup(
            SymbolTable::getSymbolName(op).getValue()));
      } else if (SymbolTable::getSymbolVisibility(conflictingOp) !=
                 SymbolTable::Visibility::Private) {
        // The source symbol has 'nested' or 'public' visibility and so does
        // the target symbol: we can't safely rename either.
        // If you hit this with ops that you think are safe to rename, mark
        // them private.
        return op->emitError()
               << "multiple public symbols with the name: " << symbolName;
      } else {
        // Keep the original name for our new op, rename the target op.
        renameWithDisambiguatedName(conflictingOp, targetOp,
                                    targetSymbolTable);
      }
    }

    op->moveBefore(targetBuilder.getInsertionBlock(),
                   targetBuilder.getInsertionPoint());
    targetSymbolTable.insert(op);
    // The next op is inserted right before the one that was just moved.
    targetBuilder.setInsertionPoint(op);
  }

  return success();
}
// Parses |source| using the context of |targetBuilder| and merges the parsed
// module into |targetOp| via mergeModuleInto. If parsing fails an error is
// emitted at |loc| and returned.
LogicalResult mergeSourceModuleInto(Location loc, StringRef source,
                                    Operation *targetOp,
                                    SymbolTable &targetSymbolTable,
                                    OpBuilder &targetBuilder) {
  // The sources handed to this function are expected to be the files embedded
  // in the compiler binary, so a parse failure indicates a broken build rather
  // than bad user input.
  auto parsedModule =
      mlir::parseSourceString(source, targetBuilder.getContext());
  if (parsedModule) {
    // Merge all of the module contents.
    return mergeModuleInto(*parsedModule, targetOp, targetSymbolTable,
                           targetBuilder);
  }
  return mlir::emitError(
      loc, "source module failed to parse; ensure dialects are registered");
}
} // namespace iree_compiler
} // namespace mlir