Switching VM's EraseUnusedCallOp pattern to a pass. (#18950)
This lets it collect the side-effect-free symbols once per module instead of doing a symbol lookup for every call, which speeds things up _a lot_.
Includes various fixes found during testing.
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index 2413bed..883f7fb 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -78,7 +78,9 @@
deps = [
"//compiler/bindings/c:headers",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
+ "//compiler/src/iree/compiler/Dialect/VM/Target:init_targets",
"//compiler/src/iree/compiler/PluginAPI:PluginManager",
+ "//compiler/src/iree/compiler/Tools:init_llvmir_translations",
"//compiler/src/iree/compiler/Tools:init_passes_and_dialects",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Debug",
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index 191ea93..3ee76d9 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -80,7 +80,9 @@
MLIRPass
MLIRSupport
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::VM::Target::init_targets
iree::compiler::PluginAPI::PluginManager
+ iree::compiler::Tools::init_llvmir_translations
iree::compiler::Tools::init_passes_and_dialects
iree::compiler::bindings::c::headers
PUBLIC
diff --git a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
index b621724..4d3b07c 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
@@ -9,8 +9,10 @@
// Based on mlir-opt but registers the passes and dialects we care about.
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/VM/Target/init_targets.h"
#include "iree/compiler/PluginAPI/PluginManager.h"
#include "iree/compiler/Tools/init_dialects.h"
+#include "iree/compiler/Tools/init_llvmir_translations.h"
#include "iree/compiler/Tools/init_passes.h"
#include "iree/compiler/tool_entry_points_api.h"
#include "llvm/Support/InitLLVM.h"
@@ -145,6 +147,8 @@
mlir::DialectRegistry registry;
mlir::iree_compiler::registerAllDialects(registry);
mlir::iree_compiler::registerAllPasses();
+ mlir::iree_compiler::registerVMTargets();
+ mlir::iree_compiler::registerLLVMIRTranslations(registry);
// Register the pass to drop embedded transform dialect IR.
// TODO: this should be upstreamed.
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index bac70cd..bec93d4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -3142,49 +3142,8 @@
SwapInvertedCondBranchOpTargets>(context);
}
-namespace {
-
-/// Removes vm.call ops to functions that are marked as having no side-effects
-/// if the results are unused.
-template <typename T>
-struct EraseUnusedCallOp : public OpRewritePattern<T> {
- using OpRewritePattern<T>::OpRewritePattern;
- LogicalResult matchAndRewrite(T op,
- PatternRewriter &rewriter) const override {
- // First check if the call is unused - this ensures we only do the symbol
- // lookup if we are actually going to use it.
- for (auto result : op.getResults()) {
- if (!result.use_empty()) {
- return failure();
- }
- }
-
- auto *calleeOp = SymbolTable::lookupSymbolIn(
- op->template getParentOfType<ModuleOp>(), op.getCallee());
-
- bool hasNoSideEffects = false;
- if (calleeOp->getAttr("nosideeffects")) {
- hasNoSideEffects = true;
- } else if (auto import = dyn_cast<ImportInterface>(calleeOp)) {
- hasNoSideEffects = !import.hasSideEffects();
- }
- if (!hasNoSideEffects) {
- // Op has side-effects (or may have them); can't remove.
- return failure();
- }
-
- // Erase op as it is unused.
- rewriter.eraseOp(op);
- return success();
- }
-};
-
-} // namespace
-
void CallOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.insert<EraseUnusedCallOp<CallOp>>(context);
-}
+ MLIRContext *context) {}
namespace {
@@ -3210,8 +3169,7 @@
void CallVariadicOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.insert<EraseUnusedCallOp<CallVariadicOp>, ConvertNonVariadicToCallOp>(
- context);
+ results.insert<ConvertNonVariadicToCallOp>(context);
}
namespace {
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir
index 17f2a89..82d16d3 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir
@@ -55,17 +55,12 @@
// ^bb2(%1 : i32):
// vm.return %1 : i32
// }
+}
- // CHECK-LABEL: @erase_unused_pure_call
- vm.func @erase_unused_pure_call(%arg0 : i32) {
- %0 = vm.call @nonvariadic_pure_func(%arg0) : (i32) -> i32
- %1 = vm.call.variadic @variadic_pure_func([%arg0]) : (i32 ...) -> i32
- // CHECK-NEXT: vm.return
- vm.return
- }
- vm.import private @nonvariadic_pure_func(%arg0 : i32) -> i32 attributes {nosideeffects}
- vm.import private @variadic_pure_func(%arg0 : i32 ...) -> i32 attributes {nosideeffects}
+// -----
+// CHECK-LABEL: @call_folds
+vm.module @call_folds {
// CHECK-LABEL: @convert_nonvariadic_to_call
vm.func @convert_nonvariadic_to_call(%arg0 : i32) -> (i32, i32) {
// CHECK-NEXT: vm.call @nonvariadic_func(%arg0) : (i32) -> i32
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel
index bf11310..381e452 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel
@@ -18,6 +18,7 @@
"Conversion.cpp",
"DeduplicateRodata.cpp",
"DropEmptyModuleInitializers.cpp",
+ "DropUnusedCalls.cpp",
"GlobalInitialization.cpp",
"HoistInlinedRodata.cpp",
"OrdinalAllocation.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 71c5aa4..c3c8968 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
"Conversion.cpp"
"DeduplicateRodata.cpp"
"DropEmptyModuleInitializers.cpp"
+ "DropUnusedCalls.cpp"
"GlobalInitialization.cpp"
"HoistInlinedRodata.cpp"
"OrdinalAllocation.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
new file mode 100644
index 0000000..9690bca
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
@@ -0,0 +1,104 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::VM {
+
+namespace {
+
+/// Erases call ops of type T that have no used results and whose callee is in
+/// the noSideEffectsSymbols set (held by reference; not owned by the pattern).
+template <typename T>
+struct EraseUnusedCallOp : public OpRewritePattern<T> {
+ DenseSet<StringRef> &noSideEffectsSymbols;
+ EraseUnusedCallOp(MLIRContext *context,
+ DenseSet<StringRef> &noSideEffectsSymbols,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<T>(context, benefit),
+ noSideEffectsSymbols(noSideEffectsSymbols) {}
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ // Bail out early if any result is still used so that the set lookup below
+ // only happens for calls that are actual candidates for erasure.
+ for (auto result : op.getResults()) {
+ if (!result.use_empty()) {
+ return failure();
+ }
+ }
+
+ // Check that the callee was recorded as having no side effects.
+ bool hasNoSideEffects = noSideEffectsSymbols.contains(op.getCallee());
+ if (!hasNoSideEffects) {
+ // Callee has side-effects (or may have them); the call can't be removed.
+ return failure();
+ }
+
+ // Erase the op: its results are unused and its callee is side-effect free.
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+} // namespace
+
+/// Module-level pass that erases vm.call and vm.call.variadic ops whose
+/// results are unused and whose callee is known to have no side effects.
+class DropUnusedCallsPass
+    : public PassWrapper<DropUnusedCallsPass,
+                         OperationPass<IREE::VM::ModuleOp>> {
+public:
+  StringRef getArgument() const override { return "iree-vm-drop-unused-calls"; }
+
+  StringRef getDescription() const override {
+    return "Drops vm.call ops that have no side effects and are unused.";
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+
+    // Find all top-level symbols that have no side effects. This is computed
+    // once up front so the patterns only need a set lookup per call op.
+    DenseSet<StringRef> noSideEffectsSymbols;
+    for (auto symbolOp : moduleOp.getOps<SymbolOpInterface>()) {
+      if (symbolOp->getAttr("nosideeffects")) {
+        noSideEffectsSymbols.insert(symbolOp.getName());
+      } else if (auto importOp =
+                     dyn_cast<ImportInterface>(symbolOp.getOperation())) {
+        if (!importOp.hasSideEffects()) {
+          noSideEffectsSymbols.insert(symbolOp.getName());
+        }
+      }
+    }
+
+    // Remove all unused calls. Erasing one call may leave the calls that fed
+    // it unused as well, so this runs as a greedy pattern application in order
+    // to remove entire chains of unused calls.
+    RewritePatternSet patterns(&getContext());
+    patterns.insert<EraseUnusedCallOp<IREE::VM::CallOp>,
+                    EraseUnusedCallOp<IREE::VM::CallVariadicOp>>(
+        &getContext(), noSideEffectsSymbols);
+    if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createDropUnusedCallsPass() {
+  return std::make_unique<DropUnusedCallsPass>();
+}
+
+static PassRegistration<DropUnusedCallsPass> pass;
+
+} // namespace mlir::iree_compiler::IREE::VM
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 05852aa..665d4f0 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -38,6 +38,11 @@
passManager.addPass(mlir::createCanonicalizerPass());
passManager.addPass(mlir::createCSEPass());
+ // Aggressive MLIR cleanup.
+ passManager.addNestedPass<IREE::VM::ModuleOp>(
+ IREE::VM::createDropUnusedCallsPass());
+ passManager.addPass(mlir::createSymbolDCEPass());
+
// Simplify util.global accesses; this can help with data flow tracking as
// redundant store-loads are removed.
FunctionLikeNest(passManager)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
index 9b569e6..88c0ade 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
@@ -86,6 +86,9 @@
std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
createDropEmptyModuleInitializersPass();
+// Drops unused calls to functions marked as having no side effects.
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createDropUnusedCallsPass();
+
// Sinks defining ops with few uses to their use-sites to reduce the total
// number of live registers at the cost of additional storage requirements.
std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createSinkDefiningOpsPass();
@@ -101,6 +104,7 @@
createHoistInlinedRodataPass();
createDeduplicateRodataPass();
createDropEmptyModuleInitializersPass();
+ createDropUnusedCallsPass();
createGlobalInitializationPass();
createOrdinalAllocationPass();
createResolveRodataLoadsPass();
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel
index 6f0bdd9..937957d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel
@@ -18,6 +18,7 @@
[
"deduplicate_rodata.mlir",
"drop_empty_module_initializers.mlir",
+ "drop_unused_calls.mlir",
"global_initialization.mlir",
"hoist_inlined_rodata.mlir",
"ordinal_allocation.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
index 7e2336f..e854e95 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
@@ -16,6 +16,7 @@
SRCS
"deduplicate_rodata.mlir"
"drop_empty_module_initializers.mlir"
+ "drop_unused_calls.mlir"
"global_initialization.mlir"
"hoist_inlined_rodata.mlir"
"ordinal_allocation.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir
new file mode 100644
index 0000000..bbd8b0b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(vm.module(iree-vm-drop-unused-calls))" %s | FileCheck %s
+
+// CHECK-LABEL: @drop_calls
+vm.module public @drop_calls {
+ // CHECK: vm.func @fn
+ vm.func @fn(%arg0 : i32) {
+ // CHECK-NOT: vm.call @nonvariadic_pure_func
+ %0 = vm.call @nonvariadic_pure_func(%arg0) : (i32) -> i32
+ // CHECK-NOT: vm.call.variadic @variadic_pure_func
+ %1 = vm.call.variadic @variadic_pure_func([%arg0]) : (i32 ...) -> i32
+ // CHECK-NEXT: vm.return
+ vm.return
+ }
+ vm.import private @nonvariadic_pure_func(%arg0 : i32) -> i32 attributes {nosideeffects}
+ vm.import private @variadic_pure_func(%arg0 : i32 ...) -> i32 attributes {nosideeffects}
+}
+
+// -----
+
+// CHECK-LABEL: @drop_call_trees
+vm.module public @drop_call_trees {
+ // CHECK: vm.func @fn
+ vm.func @fn(%arg0 : i32) {
+ // CHECK: vm.call @impure_func
+ %0 = vm.call @impure_func(%arg0) : (i32) -> i32
+ // CHECK-NOT: vm.call @pure_func_a
+ %1 = vm.call @pure_func_a(%0) : (i32) -> i32
+ // CHECK-NOT: vm.call @pure_func_b
+ %2 = vm.call @pure_func_b(%1) : (i32) -> i32
+ // CHECK-NEXT: vm.return
+ vm.return
+ }
+ vm.import private @impure_func(%arg0 : i32) -> i32
+ vm.import private @pure_func_a(%arg0 : i32) -> i32 attributes {nosideeffects}
+ vm.import private @pure_func_b(%arg0 : i32) -> i32 attributes {nosideeffects}
+}