Sinks single uses of ops in the entry block to their use. (#3523)

This dramatically reduces the total register pressure on a function after
CSE runs and moves every constant to the entry block.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index b4fb233..771b568 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -108,12 +108,12 @@
 //  CHECK-NEXT:       vm.module @linked_module {
 //  CHECK-NEXT:         vm.rodata @reduction_ex_dispatch_0_const dense<0.000000e+00> : tensor<1xf32>
 //  CHECK-NEXT:         vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-//  CHECK-NEXT:           %zero = vm.const.i32.zero : i32
-//  CHECK-NEXT:           %c128 = vm.const.i32 128 : i32
-//  CHECK-NEXT:           %c16 = vm.const.i32 16 : i32
-//  CHECK-NEXT:           %c4 = vm.const.i32 4 : i32
-//  CHECK-NEXT:           %c8 = vm.const.i32 8 : i32
-//  CHECK-NEXT:           %c1 = vm.const.i32 1 : i32
+//   CHECK-DAG:           %zero = vm.const.i32.zero : i32
+//   CHECK-DAG:           %c128 = vm.const.i32 128 : i32
+//   CHECK-DAG:           %c16 = vm.const.i32 16 : i32
+//   CHECK-DAG:           %c4 = vm.const.i32 4 : i32
+//   CHECK-DAG:           %c8 = vm.const.i32 8 : i32
+//   CHECK-DAG:           %c1 = vm.const.i32 1 : i32
 //  CHECK-NEXT:           %reduction_ex_dispatch_0_const = vm.const.ref.rodata @reduction_ex_dispatch_0_const : !vm.ref<!iree.byte_buffer>
 //  CHECK-NEXT:           %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
 //  CHECK-NEXT:           %ref_0 = vm.call @vmla.buffer.alloc(%c4) : (i32) -> !vm.ref<!vmla.buffer>
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD
index db938f0..6b152f9 100644
--- a/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -27,6 +27,7 @@
         "MarkPublicSymbolsExported.cpp",
         "OrdinalAllocation.cpp",
         "Passes.cpp",
+        "SinkDefiningOps.cpp",
     ],
     hdrs = [
         "Passes.h",
@@ -42,6 +43,7 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:SideEffectInterfaces",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
         "@llvm-project//mlir:Transforms",
diff --git a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 5704913..0a946fc 100644
--- a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -26,10 +26,12 @@
     "MarkPublicSymbolsExported.cpp"
     "OrdinalAllocation.cpp"
     "Passes.cpp"
+    "SinkDefiningOps.cpp"
   DEPS
     LLVMSupport
     MLIRIR
     MLIRPass
+    MLIRSideEffectInterfaces
     MLIRSupport
     MLIRTransformUtils
     MLIRTransforms
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index b72ba7d..900666f 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -35,6 +35,7 @@
   passManager.addPass(createInlinerPass());
   passManager.addPass(createCSEPass());
   passManager.addPass(createSymbolDCEPass());
+  passManager.addPass(createSinkDefiningOpsPass());
 }
 
 void registerVMTransformPassPipeline() {
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.h b/iree/compiler/Dialect/VM/Transforms/Passes.h
index a54fc05..fb453b6 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.h
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.h
@@ -83,6 +83,14 @@
 createOrdinalAllocationPass();
 
 //===----------------------------------------------------------------------===//
+// Optimization passes
+//===----------------------------------------------------------------------===//
+
+// Sinks defining ops with few uses to their use-sites to reduce the total
+// number of live registers at the cost of additional storage requirements.
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createSinkDefiningOpsPass();
+
+//===----------------------------------------------------------------------===//
 // Test passes
 //===----------------------------------------------------------------------===//
 
@@ -100,6 +108,7 @@
   createHoistInlinedRodataPass();
   createGlobalInitializationPass();
   createOrdinalAllocationPass();
+  createSinkDefiningOpsPass();
 }
 
 inline void registerVMTestPasses() {
diff --git a/iree/compiler/Dialect/VM/Transforms/SinkDefiningOps.cpp b/iree/compiler/Dialect/VM/Transforms/SinkDefiningOps.cpp
new file mode 100644
index 0000000..9a8d36f
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Transforms/SinkDefiningOps.cpp
@@ -0,0 +1,101 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+class SinkDefiningOpsPass
+    : public PassWrapper<SinkDefiningOpsPass, OperationPass<ModuleOp>> {
+ public:
+  void runOnOperation() override {
+    for (auto funcOp : getOperation().getOps<FuncOp>()) {
+      DominanceInfo domInfo(funcOp);
+
+      // Consider only those constant ops in the entry block.
+      SmallVector<std::pair<Operation *, Operation *>, 8> sinks;
+      for (auto &op : funcOp.getBlocks().front()) {
+        if (op.getNumResults() != 1 ||
+            !MemoryEffectOpInterface::hasNoEffect(&op)) {
+          // Probably not safe to move.
+          continue;
+        }
+
+        auto users = llvm::to_vector<4>(op.getUsers());
+        if (users.empty()) {
+          // No users (probably leftover needing DCE).
+          continue;
+        } else if (users.size() == 1) {
+          // Only a single user; safe to move.
+          sinks.push_back(std::make_pair(&op, users.front()));
+          continue;
+        }
+
+        // Find the common dominator block across all uses. This may be the
+        // entry block itself.
+        Block *commonDominator = users.front()->getBlock();
+        for (auto user : users) {
+          commonDominator = domInfo.findNearestCommonDominator(
+              commonDominator, user->getBlock());
+        }
+
+        // Find the first use within the dominator block (if any) so that we
+        // can sink down to it.
+        Operation *firstUserInDominator = commonDominator->getTerminator();
+        for (auto user : users) {
+          if (user->getBlock() == commonDominator) {
+            if (user->isBeforeInBlock(firstUserInDominator)) {
+              firstUserInDominator = user;
+            }
+          }
+        }
+
+        sinks.push_back(std::make_pair(&op, firstUserInDominator));
+      }
+
+      // Sink values after iterating.
+      for (auto &sink : sinks) {
+        sink.first->moveBefore(sink.second);
+      }
+    }
+  }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createSinkDefiningOpsPass() {
+  return std::make_unique<SinkDefiningOpsPass>();
+}
+
+static PassRegistration<SinkDefiningOpsPass> pass(
+    "iree-vm-sink-defining-ops",
+    "Sinks defining ops with few uses to their use-sites.");
+
+}  // namespace VM
+}  // namespace IREE
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/VM/Transforms/test/sink_defining_ops.mlir b/iree/compiler/Dialect/VM/Transforms/test/sink_defining_ops.mlir
new file mode 100644
index 0000000..5db14fb
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Transforms/test/sink_defining_ops.mlir
@@ -0,0 +1,70 @@
+// RUN: iree-opt -split-input-file -iree-vm-sink-defining-ops %s | IreeFileCheck %s
+
+vm.module @module {
+  // CHECK-LABEL: @single_uses
+  vm.func @single_uses(%arg0 : i32) -> i32 {
+    %c4 = vm.const.i32 4 : i32
+    %c5 = vm.const.i32 5 : i32
+    // CHECK: vm.cond_br
+    vm.cond_br %arg0, ^bb1, ^bb2
+  ^bb1:
+    //      CHECK: %c4 = vm.const.i32 4 : i32
+    // CHECK-NEXT: vm.return %c4 : i32
+    vm.return %c4 : i32
+  ^bb2:
+    //      CHECK: %c5 = vm.const.i32 5 : i32
+    // CHECK-NEXT: vm.return %c5 : i32
+    vm.return %c5 : i32
+  }
+}
+
+// -----
+
+vm.module @module {
+  // CHECK-LABEL: @multiple_uses
+  vm.func @multiple_uses(%arg0 : i32) -> (i32, i32) {
+    %c4 = vm.const.i32 4 : i32
+    %c5 = vm.const.i32 5 : i32
+    // CHECK: %c5 = vm.const.i32 5 : i32
+    // CHECK: vm.cond_br
+    vm.cond_br %arg0, ^bb1, ^bb2
+  ^bb1:
+    //      CHECK: %c4 = vm.const.i32 4 : i32
+    // CHECK-NEXT: vm.return %c4, %c5
+    vm.return %c4, %c5 : i32, i32
+  ^bb2:
+    // CHECK: vm.return %c5 : i32
+    vm.return %c5 : i32
+  }
+}
+
+// -----
+
+vm.module @module {
+  // CHECK-LABEL: @common_dominator
+  vm.func @common_dominator(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
+    %c4 = vm.const.i32 4 : i32
+    %c5 = vm.const.i32 5 : i32
+    %c6 = vm.const.i32 6 : i32
+    //      CHECK: %c5 = vm.const.i32 5 : i32
+    // CHECK-NEXT: vm.cond_br %arg0
+    vm.cond_br %arg0, ^bb1, ^bb_end
+  ^bb1:
+    //      CHECK: "test.do_something_else"
+    "test.do_something_else"() : () -> ()
+    // CHECK-NEXT: %c4 = vm.const.i32 4 : i32
+    // CHECK-NEXT: "test.do_thing"(%c4) : (i32) -> ()
+    "test.do_thing"(%c4) : (i32) -> ()
+    // CHECK-NEXT: vm.cond_br %arg1
+    vm.cond_br %arg1, ^bb2, ^bb_end
+  ^bb2:
+    "test.do_thing"(%c4) : (i32) -> ()
+    "test.do_thing"(%c5) : (i32) -> ()
+    // CHECK: vm.br
+    vm.br ^bb_end
+  ^bb_end:
+    //      CHECK: %c6 = vm.const.i32 6 : i32
+    // CHECK-NEXT: vm.return %c5, %c6
+    vm.return %c5, %c6 : i32, i32
+  }
+}