Improving VM conversion performance. (#18957)
The major change here is using a precomputed import table in VM
conversion patterns. This removes the symbol lookup that was happening
on each call. In models with 100k calls to imports this speeds things up
a lot.
Also squashed a few more perf issues involving symbol lookups while
profiling and made some passes that could nest on function-like ops do
so.
These changes drop VM translation of the 405b model from 3.5mins to
~1.5min. Disabling verification (`-verify-each=0` to iree-opt or
`-verify=false` to iree-compile) takes it to 1min.
Remaining work is mostly around parallelizing some passes that are not
trivially parallelizable (FoldGlobals, DropUnusedCalls, etc) and
parallelizing some analysis (Explorer global init, call graph walking)
that tends to get real expensive when there are 250k calls and 500k ops.
Any place that does a symbol use walk is going to suffer. Many of these
fixes are in our code but there's several upstream components that fall
over with this amount of IR (CallGraph, DataFlowSolver, the verifier,
etc).
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
index 1357e07..a84205a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
@@ -88,10 +88,12 @@
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
index 07df5d8..1de6f1c 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/Passes.cpp
@@ -149,18 +149,20 @@
//===----------------------------------------------------------------------===//
static void addCleanupPatterns(OpPassManager &passManager) {
- // Standard MLIR cleanup.
- passManager.addPass(mlir::createCSEPass());
- passManager.addPass(mlir::createCanonicalizerPass());
- passManager.addPass(mlir::createCSEPass());
- // Simplify util.global accesses; this can help with data flow tracking as
- // redundant store-loads are removed.
FunctionLikeNest(passManager)
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ // Standard MLIR cleanup.
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
+ // Simplify util.global accesses; this can help with data flow tracking as
+ // redundant store-loads are removed.
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
}
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
index 408bb02..2234c62 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp
@@ -44,10 +44,12 @@
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
index 745f029..32c7819 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.cpp
@@ -126,20 +126,33 @@
// TODO(benvanik): filter the use list by traversal actions; where this runs
// today we don't yet have the actions specified so we can't.
+ // Initialize the full list of globals.
+ for (auto globalOp :
+ symbolTableOp->getRegion(0).getOps<IREE::Util::GlobalOpInterface>()) {
+ auto globalInfo = std::make_unique<GlobalInfo>();
+ globalInfo->op = globalOp;
+ globalInfosByName[globalOp.getGlobalName().getValue()] = globalInfo.get();
+ globalInfos[globalOp] = std::move(globalInfo);
+ }
+
+ // Walk the module and gather uses.
+ //
+ // TODO: find a way to do this more efficiently when the module is large.
+ // We could parallelize on top-level functions and then merge at the end.
auto allUses = symbolTable.getSymbolUses(&symbolTableOp->getRegion(0));
- if (!allUses.has_value())
- return;
- for (auto use : allUses.value()) {
- auto *symbolOp =
- symbolTable.lookupNearestSymbolFrom(use.getUser(), use.getSymbolRef());
- if (!isa_and_nonnull<IREE::Util::GlobalOpInterface>(symbolOp))
- continue;
- auto &globalInfo = globalInfos[symbolOp];
- globalInfo.op = cast<IREE::Util::GlobalOpInterface>(symbolOp);
- if (isa<IREE::Util::GlobalAddressOpInterface>(use.getUser())) {
- globalInfo.isIndirect = true;
- } else {
- globalInfo.uses.push_back(use.getUser());
+ if (allUses.has_value()) {
+ for (auto use : allUses.value()) {
+ auto globalInfoIt = globalInfosByName.find(
+ use.getSymbolRef().getLeafReference().getValue());
+ if (globalInfoIt == globalInfosByName.end()) {
+ continue; // not a global
+ }
+ auto *globalInfo = globalInfoIt->second;
+ if (isa<IREE::Util::GlobalAddressOpInterface>(use.getUser())) {
+ globalInfo->isIndirect = true;
+ } else {
+ globalInfo->uses.push_back(use.getUser());
+ }
}
}
}
@@ -175,7 +188,7 @@
auto it = globalInfos.find(globalOp);
if (it == globalInfos.end())
return nullptr;
- return &it->second;
+ return it->second.get();
}
const Explorer::GlobalInfo *Explorer::queryGlobalInfoFrom(StringRef globalName,
@@ -189,12 +202,12 @@
auto it = globalInfos.find(op);
if (it == globalInfos.end())
return nullptr;
- return &it->second;
+ return it->second.get();
}
void Explorer::forEachGlobal(std::function<void(const GlobalInfo *)> fn) {
- for (auto it : globalInfos) {
- fn(&it.second);
+ for (auto &it : globalInfos) {
+ fn(it.second.get());
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
index 35ee12a..4c9482b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Explorer.h
@@ -403,7 +403,8 @@
DenseMap<StringRef, TraversalAction> dialectActions;
DenseMap<OperationName, TraversalAction> opActions;
- DenseMap<Operation *, GlobalInfo> globalInfos;
+ DenseMap<Operation *, std::unique_ptr<GlobalInfo>> globalInfos;
+ DenseMap<StringRef, GlobalInfo *> globalInfosByName;
ModuleAnalysisManager analysisManager;
};
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
index 7fdc301..70387f6 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/GlobalTable.cpp
@@ -69,6 +69,8 @@
globalOrder.push_back(globalName);
}
+ // TODO: parallelize this by gathering on multiple threads per callable and
+ // then merging at the end.
for (auto callableOp : moduleOp.getOps<CallableOpInterface>()) {
if (auto uses = SymbolTable::getSymbolUses(callableOp)) {
for (auto use : *uses) {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
index ac527d7..e7c722d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Patterns.cpp
@@ -46,7 +46,7 @@
if (!op.getCallableRegion())
return failure();
auto ®ion = *op.getCallableRegion();
- if (region.empty())
+ if (region.empty() || region.hasOneBlock())
return failure();
// Analyze all branches in the op to compute the information we'll need to
@@ -501,7 +501,6 @@
context->getOrLoadDialect<IREE::Util::UtilDialect>()
->getCanonicalizationPatterns(patterns);
- // TODO(benvanik): same as branch folding but for calls.
patterns.insert<FoldBlockArgumentsPattern, ElideBranchOperandsPattern>(
context);
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
index fa837d4..ea5e257 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp
@@ -18,6 +18,56 @@
namespace mlir::iree_compiler {
+LogicalResult ImportTable::build(Operation *rootOp,
+ const TypeConverter &typeConverter) {
+ for (auto funcOp : rootOp->getRegion(0).getOps<FunctionOpInterface>()) {
+ if (!funcOp.isExternal()) {
+ continue; // only external functions are imports
+ }
+
+ ImportTable::Import import;
+ import.name = funcOp.getNameAttr();
+ import.fallback = funcOp->getAttrOfType<SymbolRefAttr>("vm.fallback");
+
+ // Try to use an assigned signature or fall back to converting the input.
+ if (auto importOp = dyn_cast<IREE::VM::ImportOp>(funcOp.getOperation())) {
+ // Import ops have their signature used directly.
+ import.signature = importOp.getFunctionType();
+ } else if (auto signatureAttr =
+ funcOp->getAttrOfType<TypeAttr>("vm.signature")) {
+ // Directly use the specified signature.
+ import.signature =
+ dyn_cast_if_present<FunctionType>(signatureAttr.getValue());
+ }
+ if (!import.signature) {
+ // Convert the signature using the type converter.
+ SmallVector<Type> argumentTypes;
+ if (failed(typeConverter.convertTypes(funcOp.getArgumentTypes(),
+ argumentTypes))) {
+ return funcOp.emitError() << "unable to convert import argument types";
+ }
+ SmallVector<Type> resultTypes;
+ if (failed(typeConverter.convertTypes(funcOp.getResultTypes(),
+ resultTypes))) {
+ return funcOp.emitError() << "unable to convert import result types";
+ }
+ import.signature =
+ FunctionType::get(rootOp->getContext(), argumentTypes, resultTypes);
+ }
+
+ symbols[import.name.getValue()] = std::move(import);
+ }
+
+ return success();
+}
+
+std::optional<ImportTable::Import> ImportTable::find(StringRef symbolName) {
+ auto it = symbols.find(symbolName);
+ if (it == symbols.end())
+ return std::nullopt;
+ return it->second;
+}
+
// TODO(benvanik): replace with iree/compiler/Utils/ModuleUtils.h.
// There may be some special insertion order arrangement required based on the
// nested vm.module here.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
index b2f0a8f..c5d557f 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h
@@ -20,6 +20,33 @@
// segment_sizes array.
constexpr int kFixedSingleValue = -1;
+// A table of import information.
+class ImportTable {
+public:
+ // Information about an import function.
+ struct Import {
+ // Used to ensure the StringRef in the map stays live.
+ StringAttr name;
+ // Function signature derived from the type or overridden by `vm.signature`.
+ FunctionType signature;
+ // Optional fallback function that should be used when the import is
+ // unavailable at runtime taken from `vm.fallback`.
+ SymbolRefAttr fallback;
+ };
+
+ // Builds a table of all import functions nested within the given |rootOp|.
+ // Clones any information such that the original ops can be mutated/erased.
+ // Must only be called once the type converter has been fully populated.
+ LogicalResult build(Operation *rootOp, const TypeConverter &typeConverter);
+
+ // Finds an import with the given name if there exists one.
+ std::optional<Import> find(StringRef symbolName);
+
+private:
+ // Map of symbol names within the root op to import symbol info.
+ DenseMap<StringRef, Import> symbols;
+};
+
// Appends a set of vm.import ops from a module to a target VM module.
// Imports will only be added if they are not already present in the target
// module.
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp
index f9a5c26..1cb4f41 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.cpp
@@ -170,7 +170,6 @@
constexpr const char *kRetainedAttributes[] = {
"nosideeffects",
"vm.fallback",
- "vm.signature",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
@@ -241,7 +240,11 @@
};
struct CallOpConversion : public OpConversionPattern<func::CallOp> {
- using OpConversionPattern::OpConversionPattern;
+ ImportTable &importTable;
+ CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
+ ImportTable &importTable, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ importTable(importTable) {}
LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -262,7 +265,8 @@
// conversion if imports have fallbacks that are themselves imports.
auto callResults = convertCallOp(
callOp->getParentOfType<IREE::VM::ModuleOp>(), callOp.getLoc(),
- callOp.getCallee(), adaptor.getOperands(), resultTypes, rewriter);
+ callOp.getCallee(), adaptor.getOperands(), resultTypes, importTable,
+ rewriter);
if (failed(callResults)) {
return rewriter.notifyMatchFailure(
callOp, "unable to convert call (results mismatch)");
@@ -277,36 +281,12 @@
FailureOr<SmallVector<Value>>
convertCallOp(Operation *rootOp, Location loc, StringRef calleeName,
ValueRange operands, TypeRange resultTypes,
+ ImportTable &importTable,
ConversionPatternRewriter &rewriter) const {
- // (Slow) lookup of the target function, which may be an import that we need
- // to perform type conversion for.
- auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName);
- if (auto funcOp = dyn_cast_or_null<FunctionOpInterface>(calleeOp)) {
- if (funcOp.isExternal()) {
- // Import that may require conversion.
- // This case handles when funcs are declared after the call.
- FunctionType convertedSignature;
- if (auto signatureAttr =
- funcOp->getAttrOfType<TypeAttr>("vm.signature")) {
- if (auto importSignature =
- llvm::dyn_cast<FunctionType>(signatureAttr.getValue())) {
- convertedSignature = importSignature;
- }
- }
- if (!convertedSignature) {
- convertedSignature =
- rewriter.getFunctionType(TypeRange(operands), resultTypes);
- }
- return convertImportCallOp(rootOp, loc, calleeName, operands,
- resultTypes, convertedSignature, funcOp,
- rewriter);
- }
- } else if (auto importOp = dyn_cast_or_null<IREE::VM::ImportOp>(calleeOp)) {
- // Calling an import.
- // This case handles when funcs are declared before the call and have
- // already been converted.
- return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes,
- importOp.getFunctionType(), importOp,
+ // Lookup the target and detect if it is an import.
+ auto import = importTable.find(calleeName);
+ if (import.has_value()) {
+ return convertImportCallOp(rootOp, loc, *import, operands, resultTypes,
rewriter);
}
@@ -319,19 +299,19 @@
// Converts a call to an import that may be optional.
// Returns the new converted call results.
FailureOr<SmallVector<Value>>
- convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName,
- ValueRange operands, TypeRange resultTypes,
- FunctionType importSignature, Operation *calleeOp,
+ convertImportCallOp(Operation *rootOp, Location loc,
+ ImportTable::Import &import, ValueRange operands,
+ TypeRange resultTypes,
ConversionPatternRewriter &rewriter) const {
- auto fallbackAttr = calleeOp->getAttrOfType<SymbolRefAttr>("vm.fallback");
- return fallbackAttr
- ? convertOptionalImportCallOp(
- rootOp, loc, calleeName, operands, resultTypes,
- importSignature,
- fallbackAttr.getLeafReference().getValue(), rewriter)
- : convertMandatoryImportCallOp(rootOp, loc, calleeName, operands,
- resultTypes, importSignature,
- rewriter);
+ if (import.fallback) {
+ return convertOptionalImportCallOp(
+ rootOp, loc, import.name, operands, resultTypes, import.signature,
+ import.fallback.getLeafReference().getValue(), rewriter);
+ } else {
+ return convertMandatoryImportCallOp(rootOp, loc, import.name, operands,
+ resultTypes, import.signature,
+ rewriter);
+ }
}
// Converts a call to an optional import by adding logic to check whether it
@@ -374,7 +354,7 @@
// Not resolved: call fallback as a normal function.
rewriter.setInsertionPointToStart(fallbackBlock);
auto fallbackResults = convertCallOp(rootOp, loc, fallbackName, operands,
- resultTypes, rewriter);
+ resultTypes, importTable, rewriter);
if (failed(fallbackResults))
return failure();
rewriter.create<IREE::VM::BranchOp>(loc, exitBlock, *fallbackResults);
@@ -557,12 +537,14 @@
void populateStandardToVMPatterns(MLIRContext *context,
TypeConverter &typeConverter,
+ ImportTable &importTable,
RewritePatternSet &patterns) {
patterns
- .insert<AssertOpConversion, BranchOpConversion, CallOpConversion,
- CondBranchOpConversion, SwitchOpConversion, ModuleOpConversion,
- FuncOpConversion, ExternalFuncOpConversion, ReturnOpConversion>(
- typeConverter, context);
+ .insert<AssertOpConversion, BranchOpConversion, CondBranchOpConversion,
+ SwitchOpConversion, ModuleOpConversion, FuncOpConversion,
+ ExternalFuncOpConversion, ReturnOpConversion>(typeConverter,
+ context);
+ patterns.insert<CallOpConversion>(typeConverter, context, importTable);
patterns.insert<CastingOpConversion<mlir::UnrealizedConversionCastOp>>(
typeConverter, context);
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h
index 37af2bf..b26e9ff 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/StandardToVM/Patterns.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_STANDARDTOVM_CONVERTSTANDARDTOVM_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_STANDARDTOVM_CONVERTSTANDARDTOVM_H_
+#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -15,6 +16,7 @@
// Appends standard dialect to vm dialect patterns to the given pattern list.
void populateStandardToVMPatterns(MLIRContext *context,
TypeConverter &typeConverter,
+ ImportTable &importTable,
RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp
index 87f0b56..bc0ce19 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/ConvertStructuralOps.cpp
@@ -146,7 +146,6 @@
constexpr const char *kRetainedAttributes[] = {
"nosideeffects",
"vm.fallback",
- "vm.signature",
};
auto retainedAttributes = ArrayRef<const char *>(
kRetainedAttributes,
@@ -217,8 +216,12 @@
}
};
-class CallOpConversion : public OpConversionPattern<IREE::Util::CallOp> {
- using OpConversionPattern::OpConversionPattern;
+struct CallOpConversion : public OpConversionPattern<IREE::Util::CallOp> {
+ ImportTable &importTable;
+ CallOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
+ ImportTable &importTable, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ importTable(importTable) {}
LogicalResult
matchAndRewrite(IREE::Util::CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -255,35 +258,10 @@
convertCallOp(Operation *rootOp, Location loc, StringRef calleeName,
ValueRange operands, TypeRange resultTypes,
ConversionPatternRewriter &rewriter) const {
- // (Slow) lookup of the target function, which may be an import that we need
- // to perform type conversion for.
- auto calleeOp = SymbolTable::lookupSymbolIn(rootOp, calleeName);
- if (auto funcOp = dyn_cast_or_null<IREE::Util::FuncOp>(calleeOp)) {
- if (funcOp.isExternal()) {
- // Import that may require conversion.
- // This case handles when funcs are declared after the call.
- FunctionType convertedSignature;
- if (auto signatureAttr =
- funcOp->getAttrOfType<TypeAttr>("vm.signature")) {
- if (auto importSignature =
- llvm::dyn_cast<FunctionType>(signatureAttr.getValue())) {
- convertedSignature = importSignature;
- }
- }
- if (!convertedSignature) {
- convertedSignature =
- rewriter.getFunctionType(TypeRange(operands), resultTypes);
- }
- return convertImportCallOp(rootOp, loc, calleeName, operands,
- resultTypes, convertedSignature, funcOp,
- rewriter);
- }
- } else if (auto importOp = dyn_cast_or_null<IREE::VM::ImportOp>(calleeOp)) {
- // Calling an import.
- // This case handles when funcs are declared before the call and have
- // already been converted.
- return convertImportCallOp(rootOp, loc, calleeName, operands, resultTypes,
- importOp.getFunctionType(), importOp,
+ // Lookup the target and detect if it is an import.
+ auto import = importTable.find(calleeName);
+ if (import.has_value()) {
+ return convertImportCallOp(rootOp, loc, *import, operands, resultTypes,
rewriter);
}
@@ -296,19 +274,19 @@
// Converts a call to an import that may be optional.
// Returns the new converted call results.
FailureOr<SmallVector<Value>>
- convertImportCallOp(Operation *rootOp, Location loc, StringRef calleeName,
- ValueRange operands, TypeRange resultTypes,
- FunctionType importSignature, Operation *calleeOp,
+ convertImportCallOp(Operation *rootOp, Location loc,
+ ImportTable::Import &import, ValueRange operands,
+ TypeRange resultTypes,
ConversionPatternRewriter &rewriter) const {
- auto fallbackAttr = calleeOp->getAttrOfType<SymbolRefAttr>("vm.fallback");
- return fallbackAttr
- ? convertOptionalImportCallOp(
- rootOp, loc, calleeName, operands, resultTypes,
- importSignature,
- fallbackAttr.getLeafReference().getValue(), rewriter)
- : convertMandatoryImportCallOp(rootOp, loc, calleeName, operands,
- resultTypes, importSignature,
- rewriter);
+ if (import.fallback) {
+ return convertOptionalImportCallOp(
+ rootOp, loc, import.name, operands, resultTypes, import.signature,
+ import.fallback.getLeafReference().getValue(), rewriter);
+ } else {
+ return convertMandatoryImportCallOp(rootOp, loc, import.name, operands,
+ resultTypes, import.signature,
+ rewriter);
+ }
}
// Converts a call to an optional import by adding logic to check whether it
@@ -405,13 +383,14 @@
void populateUtilStructuralToVMPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
+ ImportTable &importTable,
RewritePatternSet &patterns) {
conversionTarget.addIllegalOp<IREE::Util::InitializerOp, IREE::Util::FuncOp,
IREE::Util::CallOp, IREE::Util::ReturnOp>();
- patterns
- .insert<InitializerOpConversion, FuncOpConversion,
- ExternalFuncOpConversion, CallOpConversion, ReturnOpConversion>(
- typeConverter, context);
+ patterns.insert<InitializerOpConversion, FuncOpConversion,
+ ExternalFuncOpConversion, ReturnOpConversion>(typeConverter,
+ context);
+ patterns.insert<CallOpConversion>(typeConverter, context, importTable);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp
index 2d60a81..f257369 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.cpp
@@ -45,6 +45,7 @@
void populateUtilStructuralToVMPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
+ ImportTable &importTable,
RewritePatternSet &patterns);
namespace {
@@ -127,6 +128,7 @@
void populateUtilToVMPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
+ ImportTable &importTable,
RewritePatternSet &patterns) {
patterns.insert<NullOpConversion>(typeConverter, context);
patterns.insert<CmpEQOpConversion>(typeConverter, context);
@@ -146,7 +148,7 @@
populateUtilStatusToVMPatterns(context, conversionTarget, typeConverter,
patterns);
populateUtilStructuralToVMPatterns(context, conversionTarget, typeConverter,
- patterns);
+ importTable, patterns);
}
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h
index 4faf2d8..baa13fa 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/UtilToVM/Patterns.h
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_VM_CONVERSION_UTILTOVM_PATTERNS_H_
#define IREE_COMPILER_DIALECT_VM_CONVERSION_UTILTOVM_PATTERNS_H_
+#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -16,6 +17,7 @@
void populateUtilToVMPatterns(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
+ ImportTable &importTable,
RewritePatternSet &patterns);
} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
index 266b1e1..203a15b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp
@@ -791,12 +791,12 @@
result.addAttributes(attrs);
}
-LogicalResult ConstRefRodataOp::verify() {
+LogicalResult
+ConstRefRodataOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
Operation *op = getOperation();
- auto *rodataOp =
- op->getParentOfType<VM::ModuleOp>().lookupSymbol(getRodata());
- if (!rodataOp) {
- return op->emitOpError() << "Undefined rodata section: " << getRodata();
+ if (!symbolTable.lookupNearestSymbolFrom(op, getRodataAttr())) {
+ return op->emitError() << "undefined rodata section: '" << getRodata()
+ << "'";
}
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
index 6e9899d..c23e687 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
@@ -1137,6 +1137,7 @@
}
def VM_ConstRefRodataOp : VM_PureOp<"const.ref.rodata", [
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<VM_SerializableOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
@@ -1170,8 +1171,6 @@
OpBuilder<(ins "RodataOp":$rodataOp,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
];
-
- let hasVerifier = 1;
}
def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
index 38c659a..3d59c2b 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
@@ -129,11 +129,14 @@
}
}
+ // Populated below after all type converters are registered.
+ ImportTable importTable;
+
RewritePatternSet patterns(&getContext());
populateUtilConversionPatterns(context, conversionTarget, typeConverter,
patterns);
populateUtilToVMPatterns(context, conversionTarget, typeConverter,
- patterns);
+ importTable, patterns);
conversionTarget.addIllegalDialect<affine::AffineDialect>();
populateAffineToStdConversionPatterns(patterns);
@@ -146,7 +149,7 @@
populateMathToVMPatterns(context, typeConverter, patterns);
conversionTarget.addIllegalDialect<func::FuncDialect>();
- populateStandardToVMPatterns(context, typeConverter, patterns);
+ populateStandardToVMPatterns(context, typeConverter, importTable, patterns);
// Populate patterns from all used dialects, providing the imports they
// registered earlier.
@@ -156,6 +159,12 @@
importSymbols, patterns, conversionTarget, typeConverter);
}
+ // Build an import table so that we can quickly look up import information
+ // during conversion.
+ if (failed(importTable.build(innerModuleOp, typeConverter))) {
+ return signalPassFailure(); // error emitted already
+ }
+
if (failed(applyPartialConversion(outerModuleOp, conversionTarget,
std::move(patterns)))) {
outerModuleOp.emitError() << "conversion to vm.module failed";
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
index 871087c..ffce82e 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/GlobalInitialization.cpp
@@ -93,28 +93,22 @@
explorer.initialize();
SmallVector<Operation *> deadOps;
explorer.forEachGlobal([&](const Explorer::GlobalInfo *globalInfo) {
- if (globalInfo->uses.empty()) {
- // No uses - erase the global entirely.
- deadOps.push_back(globalInfo->op);
- } else {
- // TODO(benvanik): verify we want this behavior - we likely want to change
- // this to be mutable only if stores exist outside of initializers.
- //
- // If there are stores mark the global as mutable. We need to update all
- // of the loads if this changes anything.
- bool hasStores = !globalInfo->getStores().empty();
- bool didChange = globalInfo->op.isGlobalMutable() != hasStores;
+ if (globalInfo->uses.empty())
+ return;
+ // TODO(benvanik): verify we want this behavior - we likely want to change
+ // this to be mutable only if stores exist outside of initializers.
+ //
+ // If there are stores mark the global as mutable. We need to update all
+ // of the loads if this changes anything.
+ bool hasStores = !globalInfo->getStores().empty();
+ bool didChange = globalInfo->op.isGlobalMutable() != hasStores;
+ if (didChange) {
globalInfo->op.setGlobalMutable(hasStores);
- if (didChange) {
- for (auto loadOp : globalInfo->getLoads())
- loadOp.setGlobalImmutable(!hasStores);
+ for (auto loadOp : globalInfo->getLoads()) {
+ loadOp.setGlobalImmutable(!hasStores);
}
}
- for (auto loadOp : globalInfo->getLoads())
- loadOp.setGlobalImmutable(!globalInfo->op.isGlobalMutable());
});
- for (auto *deadOp : deadOps)
- deadOp->erase();
}
} // namespace
@@ -171,8 +165,7 @@
InlinerInterface inlinerInterface(&getContext());
SmallVector<Operation *> deadOps;
for (auto &op : moduleOp.getBlock().getOperations()) {
- if (auto globalOp = dyn_cast<IREE::VM::GlobalRefOp>(op)) {
- } else if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
+ if (auto globalOp = dyn_cast<IREE::Util::GlobalOpInterface>(op)) {
if (llvm::isa<IREE::VM::RefType>(globalOp.getGlobalType())) {
if (failed(appendRefInitialization(globalOp, initBuilder))) {
globalOp.emitOpError()
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 665d4f0..890d9c5 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -35,21 +35,24 @@
// TODO(benvanik): run in a fixed-point iteration pipeline.
// Standard MLIR cleanup.
- passManager.addPass(mlir::createCanonicalizerPass());
- passManager.addPass(mlir::createCSEPass());
+ FunctionLikeNest(passManager)
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass);
// Aggressive MLIR cleanup.
passManager.addNestedPass<IREE::VM::ModuleOp>(
IREE::VM::createDropUnusedCallsPass());
passManager.addPass(mlir::createSymbolDCEPass());
- // Simplify util.global accesses; this can help with data flow tracking as
- // redundant store-loads are removed.
FunctionLikeNest(passManager)
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ // Simplify util.global accesses; this can help with data flow tracking as
+ // redundant store-loads are removed.
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
}
diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
index afee21c..3fc5682 100644
--- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
+++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp
@@ -108,10 +108,12 @@
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
diff --git a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
index bd61d4b..94c78b9 100644
--- a/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
+++ b/compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
@@ -182,16 +182,17 @@
FunctionLikeNest(mainPassManager)
.addPass(createGlobalLoopInvariantCodeMotionPass)
.addPass(IREE::Flow::createCanonicalizerPass)
- .addPass(mlir::createCSEPass);
+ .addPass(mlir::createCSEPass)
- // Simplify util.global accesses early on; this can help with dispatch
- // region formation as redundant store-loads are removed.
- FunctionLikeNest(mainPassManager)
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ // Simplify util.global accesses early on; this can help with dispatch
+ // region formation as redundant store-loads are removed.
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Module level cleanup and canonicalization of util.global (and other
// util ops).
- mainPassManager.addPass(IREE::Util::createApplyPatternsPass());
mainPassManager.addPass(IREE::Util::createFoldGlobalsPass());
mainPassManager.addPass(IREE::Util::createIPOPass());
diff --git a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp
index 0d68029c..68b4799 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Inline/Transforms/Passes.cpp
@@ -27,17 +27,19 @@
//===----------------------------------------------------------------------===//
static void addCleanupPatterns(OpPassManager &passManager) {
- // Standard MLIR cleanup.
- passManager.addPass(mlir::createCanonicalizerPass());
- passManager.addPass(mlir::createCSEPass());
-
FunctionLikeNest(passManager)
+ // Standard MLIR cleanup.
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
}
diff --git a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp
index 96c7eb8..7adccce 100644
--- a/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Modules/HAL/Loader/Transforms/Passes.cpp
@@ -27,17 +27,19 @@
//===----------------------------------------------------------------------===//
static void addCleanupPatterns(OpPassManager &passManager) {
- // Standard MLIR cleanup.
- passManager.addPass(mlir::createCanonicalizerPass());
- passManager.addPass(mlir::createCSEPass());
-
FunctionLikeNest(passManager)
+ // Standard MLIR cleanup.
+ .addPass(mlir::createCanonicalizerPass)
+ .addPass(mlir::createCSEPass)
+
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
- .addPass(IREE::Util::createSimplifyGlobalAccessesPass);
+ .addPass(IREE::Util::createSimplifyGlobalAccessesPass)
+
+ // Aggressive cleanup.
+ .addPass(IREE::Util::createApplyPatternsPass);
// Cleanup and canonicalization of util.global (and other util ops).
- passManager.addPass(IREE::Util::createApplyPatternsPass());
passManager.addPass(IREE::Util::createFoldGlobalsPass());
passManager.addPass(IREE::Util::createFuseGlobalsPass());
}