Switching VM's EraseUnusedCallOp pattern to a pass. (#18950)

This lets it use a symbol table to speed things up _a lot_.

Includes various fixes found during testing.
diff --git a/compiler/src/iree/compiler/API/Internal/BUILD.bazel b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
index 2413bed..883f7fb 100644
--- a/compiler/src/iree/compiler/API/Internal/BUILD.bazel
+++ b/compiler/src/iree/compiler/API/Internal/BUILD.bazel
@@ -78,7 +78,9 @@
     deps = [
         "//compiler/bindings/c:headers",
         "//compiler/src/iree/compiler/Dialect/HAL/Target",
+        "//compiler/src/iree/compiler/Dialect/VM/Target:init_targets",
         "//compiler/src/iree/compiler/PluginAPI:PluginManager",
+        "//compiler/src/iree/compiler/Tools:init_llvmir_translations",
         "//compiler/src/iree/compiler/Tools:init_passes_and_dialects",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Debug",
diff --git a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
index 191ea93..3ee76d9 100644
--- a/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
+++ b/compiler/src/iree/compiler/API/Internal/CMakeLists.txt
@@ -80,7 +80,9 @@
     MLIRPass
     MLIRSupport
     iree::compiler::Dialect::HAL::Target
+    iree::compiler::Dialect::VM::Target::init_targets
     iree::compiler::PluginAPI::PluginManager
+    iree::compiler::Tools::init_llvmir_translations
     iree::compiler::Tools::init_passes_and_dialects
     iree::compiler::bindings::c::headers
   PUBLIC
diff --git a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
index b621724..4d3b07c 100644
--- a/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
+++ b/compiler/src/iree/compiler/API/Internal/IREEOptToolEntryPoint.cpp
@@ -9,8 +9,10 @@
 // Based on mlir-opt but registers the passes and dialects we care about.
 
 #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/VM/Target/init_targets.h"
 #include "iree/compiler/PluginAPI/PluginManager.h"
 #include "iree/compiler/Tools/init_dialects.h"
+#include "iree/compiler/Tools/init_llvmir_translations.h"
 #include "iree/compiler/Tools/init_passes.h"
 #include "iree/compiler/tool_entry_points_api.h"
 #include "llvm/Support/InitLLVM.h"
@@ -145,6 +147,8 @@
   mlir::DialectRegistry registry;
   mlir::iree_compiler::registerAllDialects(registry);
   mlir::iree_compiler::registerAllPasses();
+  mlir::iree_compiler::registerVMTargets();
+  mlir::iree_compiler::registerLLVMIRTranslations(registry);
 
   // Register the pass to drop embedded transform dialect IR.
   // TODO: this should be upstreamed.
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
index bac70cd..bec93d4 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
@@ -3142,49 +3142,8 @@
                  SwapInvertedCondBranchOpTargets>(context);
 }
 
-namespace {
-
-/// Removes vm.call ops to functions that are marked as having no side-effects
-/// if the results are unused.
-template <typename T>
-struct EraseUnusedCallOp : public OpRewritePattern<T> {
-  using OpRewritePattern<T>::OpRewritePattern;
-  LogicalResult matchAndRewrite(T op,
-                                PatternRewriter &rewriter) const override {
-    // First check if the call is unused - this ensures we only do the symbol
-    // lookup if we are actually going to use it.
-    for (auto result : op.getResults()) {
-      if (!result.use_empty()) {
-        return failure();
-      }
-    }
-
-    auto *calleeOp = SymbolTable::lookupSymbolIn(
-        op->template getParentOfType<ModuleOp>(), op.getCallee());
-
-    bool hasNoSideEffects = false;
-    if (calleeOp->getAttr("nosideeffects")) {
-      hasNoSideEffects = true;
-    } else if (auto import = dyn_cast<ImportInterface>(calleeOp)) {
-      hasNoSideEffects = !import.hasSideEffects();
-    }
-    if (!hasNoSideEffects) {
-      // Op has side-effects (or may have them); can't remove.
-      return failure();
-    }
-
-    // Erase op as it is unused.
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
-} // namespace
-
 void CallOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                         MLIRContext *context) {
-  results.insert<EraseUnusedCallOp<CallOp>>(context);
-}
+                                         MLIRContext *context) {}
 
 namespace {
 
@@ -3210,8 +3169,7 @@
 
 void CallVariadicOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.insert<EraseUnusedCallOp<CallVariadicOp>, ConvertNonVariadicToCallOp>(
-      context);
+  results.insert<ConvertNonVariadicToCallOp>(context);
 }
 
 namespace {
diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir
index 17f2a89..82d16d3 100644
--- a/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir
+++ b/compiler/src/iree/compiler/Dialect/VM/IR/test/control_flow_folding.mlir
@@ -55,17 +55,12 @@
   // ^bb2(%1 : i32):
   //   vm.return %1 : i32
   // }
+}
 
-  // CHECK-LABEL: @erase_unused_pure_call
-  vm.func @erase_unused_pure_call(%arg0 : i32) {
-    %0 = vm.call @nonvariadic_pure_func(%arg0) : (i32) -> i32
-    %1 = vm.call.variadic @variadic_pure_func([%arg0]) : (i32 ...) -> i32
-    // CHECK-NEXT: vm.return
-    vm.return
-  }
-  vm.import private @nonvariadic_pure_func(%arg0 : i32) -> i32 attributes {nosideeffects}
-  vm.import private @variadic_pure_func(%arg0 : i32 ...) -> i32 attributes {nosideeffects}
+// -----
 
+// CHECK-LABEL: @call_folds
+vm.module @call_folds {
   // CHECK-LABEL: @convert_nonvariadic_to_call
   vm.func @convert_nonvariadic_to_call(%arg0 : i32) -> (i32, i32) {
     // CHECK-NEXT: vm.call @nonvariadic_func(%arg0) : (i32) -> i32
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel
index bf11310..381e452 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/BUILD.bazel
@@ -18,6 +18,7 @@
         "Conversion.cpp",
         "DeduplicateRodata.cpp",
         "DropEmptyModuleInitializers.cpp",
+        "DropUnusedCalls.cpp",
         "GlobalInitialization.cpp",
         "HoistInlinedRodata.cpp",
         "OrdinalAllocation.cpp",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 71c5aa4..c3c8968 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
     "Conversion.cpp"
     "DeduplicateRodata.cpp"
     "DropEmptyModuleInitializers.cpp"
+    "DropUnusedCalls.cpp"
     "GlobalInitialization.cpp"
     "HoistInlinedRodata.cpp"
     "OrdinalAllocation.cpp"
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
new file mode 100644
index 0000000..9690bca
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/DropUnusedCalls.cpp
@@ -0,0 +1,104 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir::iree_compiler::IREE::VM {
+
+namespace {
+
+/// Removes vm.call ops to functions that are marked as having no side-effects
+/// if the results are unused.
+template <typename T>
+struct EraseUnusedCallOp : public OpRewritePattern<T> {
+  DenseSet<StringRef> &noSideEffectsSymbols;
+  EraseUnusedCallOp(MLIRContext *context,
+                    DenseSet<StringRef> &noSideEffectsSymbols,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<T>(context, benefit),
+        noSideEffectsSymbols(noSideEffectsSymbols) {}
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    // First check if the call is unused - this ensures we only do the symbol
+    // lookup if we are actually going to use it.
+    for (auto result : op.getResults()) {
+      if (!result.use_empty()) {
+        return failure();
+      }
+    }
+
+    // Check that
+    bool hasNoSideEffects = noSideEffectsSymbols.contains(op.getCallee());
+    if (!hasNoSideEffects) {
+      // Op has side-effects (or may have them); can't remove.
+      return failure();
+    }
+
+    // Erase op as it is unused.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+} // namespace
+
+class DropUnusedCallsPass
+    : public PassWrapper<DropUnusedCallsPass,
+                         OperationPass<IREE::VM::ModuleOp>> {
+public:
+  StringRef getArgument() const override { return "iree-vm-drop-unused-calls"; }
+
+  StringRef getDescription() const override {
+    return "Drops vm.call ops that have no side effects and are unused.";
+  }
+
+  void runOnOperation() override {
+    auto moduleOp = getOperation();
+    SymbolTable symbolTable(moduleOp);
+
+    // Find all top-level symbols that have no side effects.
+    DenseSet<StringRef> noSideEffectsSymbols;
+    for (auto symbolOp : moduleOp.getOps<SymbolOpInterface>()) {
+      if (symbolOp->getAttr("nosideeffects")) {
+        noSideEffectsSymbols.insert(symbolOp.getName());
+      } else if (auto importOp =
+                     dyn_cast<ImportInterface>(symbolOp.getOperation())) {
+        if (!importOp.hasSideEffects()) {
+          noSideEffectsSymbols.insert(symbolOp.getName());
+        }
+      }
+    }
+
+    // Remove all unused calls.
+    // Note that we want to remove entire chains of unused calls and run this
+    // as a pattern application.
+    RewritePatternSet patterns(&getContext());
+    // patterns
+    patterns.insert<EraseUnusedCallOp<IREE::VM::CallOp>,
+                    EraseUnusedCallOp<IREE::VM::CallVariadicOp>>(
+        &getContext(), noSideEffectsSymbols);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createDropUnusedCallsPass() {
+  return std::make_unique<DropUnusedCallsPass>();
+}
+
+static PassRegistration<DropUnusedCallsPass> pass;
+
+} // namespace mlir::iree_compiler::IREE::VM
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index 05852aa..665d4f0 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -38,6 +38,11 @@
   passManager.addPass(mlir::createCanonicalizerPass());
   passManager.addPass(mlir::createCSEPass());
 
+  // Aggressive MLIR cleanup.
+  passManager.addNestedPass<IREE::VM::ModuleOp>(
+      IREE::VM::createDropUnusedCallsPass());
+  passManager.addPass(mlir::createSymbolDCEPass());
+
   // Simplify util.global accesses; this can help with data flow tracking as
   // redundant store-loads are removed.
   FunctionLikeNest(passManager)
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
index 9b569e6..88c0ade 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/Passes.h
@@ -86,6 +86,9 @@
 std::unique_ptr<OperationPass<IREE::VM::ModuleOp>>
 createDropEmptyModuleInitializersPass();
 
+// Drops unused calls to functions marked as having no side effects.
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createDropUnusedCallsPass();
+
 // Sinks defining ops with few uses to their use-sites to reduce the total
 // number of live registers at the cost of additional storage requirements.
 std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createSinkDefiningOpsPass();
@@ -101,6 +104,7 @@
   createHoistInlinedRodataPass();
   createDeduplicateRodataPass();
   createDropEmptyModuleInitializersPass();
+  createDropUnusedCallsPass();
   createGlobalInitializationPass();
   createOrdinalAllocationPass();
   createResolveRodataLoadsPass();
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel
index 6f0bdd9..937957d 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/BUILD.bazel
@@ -18,6 +18,7 @@
         [
             "deduplicate_rodata.mlir",
             "drop_empty_module_initializers.mlir",
+            "drop_unused_calls.mlir",
             "global_initialization.mlir",
             "hoist_inlined_rodata.mlir",
             "ordinal_allocation.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
index 7e2336f..e854e95 100644
--- a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/CMakeLists.txt
@@ -16,6 +16,7 @@
   SRCS
     "deduplicate_rodata.mlir"
     "drop_empty_module_initializers.mlir"
+    "drop_unused_calls.mlir"
     "global_initialization.mlir"
     "hoist_inlined_rodata.mlir"
     "ordinal_allocation.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir
new file mode 100644
index 0000000..bbd8b0b
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/VM/Transforms/test/drop_unused_calls.mlir
@@ -0,0 +1,36 @@
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(vm.module(iree-vm-drop-unused-calls))" %s | FileCheck %s
+
+// CHECK-LABEL: @drop_calls
+vm.module public @drop_calls {
+  // CHECK: vm.func @fn
+  vm.func @fn(%arg0 : i32) {
+    // CHECK-NOT: vm.call @nonvariadic_pure_func
+    %0 = vm.call @nonvariadic_pure_func(%arg0) : (i32) -> i32
+    // CHECK-NOT: vm.call.variadic @variadic_pure_func
+    %1 = vm.call.variadic @variadic_pure_func([%arg0]) : (i32 ...) -> i32
+    // CHECK-NEXT: vm.return
+    vm.return
+  }
+  vm.import private @nonvariadic_pure_func(%arg0 : i32) -> i32 attributes {nosideeffects}
+  vm.import private @variadic_pure_func(%arg0 : i32 ...) -> i32 attributes {nosideeffects}
+}
+
+// -----
+
+// CHECK-LABEL: @drop_call_trees
+vm.module public @drop_call_trees {
+  // CHECK: vm.func @fn
+  vm.func @fn(%arg0 : i32) {
+    // CHECK: vm.call @impure_func
+    %0 = vm.call @impure_func(%arg0) : (i32) -> i32
+    // CHECK-NOT: vm.call @pure_func_a
+    %1 = vm.call @pure_func_a(%0) : (i32) -> i32
+    // CHECK-NOT: vm.call @pure_func_b
+    %2 = vm.call @pure_func_b(%1) : (i32) -> i32
+    // CHECK-NEXT: vm.return
+    vm.return
+  }
+  vm.import private @impure_func(%arg0 : i32) -> i32
+  vm.import private @pure_func_a(%arg0 : i32) -> i32 attributes {nosideeffects}
+  vm.import private @pure_func_b(%arg0 : i32) -> i32 attributes {nosideeffects}
+}