Sinks single uses of ops in the entry block to their use. (#3523)
This dramatically reduces the total register pressure on a function after
CSE runs and moves every constant to the entry block.
diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
index b4fb233..771b568 100644
--- a/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
+++ b/iree/compiler/Dialect/HAL/Target/VMLA/test/smoketest.mlir
@@ -108,12 +108,12 @@
// CHECK-NEXT: vm.module @linked_module {
// CHECK-NEXT: vm.rodata @reduction_ex_dispatch_0_const dense<0.000000e+00> : tensor<1xf32>
// CHECK-NEXT: vm.func @reduction_ex_dispatch_0(%arg0: !vm.ref<!vmla.interface>, %arg1: i32, %arg2: i32, %arg3: i32) {
-// CHECK-NEXT: %zero = vm.const.i32.zero : i32
-// CHECK-NEXT: %c128 = vm.const.i32 128 : i32
-// CHECK-NEXT: %c16 = vm.const.i32 16 : i32
-// CHECK-NEXT: %c4 = vm.const.i32 4 : i32
-// CHECK-NEXT: %c8 = vm.const.i32 8 : i32
-// CHECK-NEXT: %c1 = vm.const.i32 1 : i32
+// CHECK-DAG: %zero = vm.const.i32.zero : i32
+// CHECK-DAG: %c128 = vm.const.i32 128 : i32
+// CHECK-DAG: %c16 = vm.const.i32 16 : i32
+// CHECK-DAG: %c4 = vm.const.i32 4 : i32
+// CHECK-DAG: %c8 = vm.const.i32 8 : i32
+// CHECK-DAG: %c1 = vm.const.i32 1 : i32
// CHECK-NEXT: %reduction_ex_dispatch_0_const = vm.const.ref.rodata @reduction_ex_dispatch_0_const : !vm.ref<!iree.byte_buffer>
// CHECK-NEXT: %ref = vm.call @vmla.buffer.const(%reduction_ex_dispatch_0_const) : (!vm.ref<!iree.byte_buffer>) -> !vm.ref<!vmla.buffer>
// CHECK-NEXT: %ref_0 = vm.call @vmla.buffer.alloc(%c4) : (i32) -> !vm.ref<!vmla.buffer>
diff --git a/iree/compiler/Dialect/VM/Transforms/BUILD b/iree/compiler/Dialect/VM/Transforms/BUILD
index db938f0..6b152f9 100644
--- a/iree/compiler/Dialect/VM/Transforms/BUILD
+++ b/iree/compiler/Dialect/VM/Transforms/BUILD
@@ -27,6 +27,7 @@
"MarkPublicSymbolsExported.cpp",
"OrdinalAllocation.cpp",
"Passes.cpp",
+ "SinkDefiningOps.cpp",
],
hdrs = [
"Passes.h",
@@ -42,6 +43,7 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
diff --git a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
index 5704913..0a946fc 100644
--- a/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/VM/Transforms/CMakeLists.txt
@@ -26,10 +26,12 @@
"MarkPublicSymbolsExported.cpp"
"OrdinalAllocation.cpp"
"Passes.cpp"
+ "SinkDefiningOps.cpp"
DEPS
LLVMSupport
MLIRIR
MLIRPass
+ MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
MLIRTransforms
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.cpp b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
index b72ba7d..900666f 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.cpp
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.cpp
@@ -35,6 +35,7 @@
passManager.addPass(createInlinerPass());
passManager.addPass(createCSEPass());
passManager.addPass(createSymbolDCEPass());
+ passManager.addPass(createSinkDefiningOpsPass());
}
void registerVMTransformPassPipeline() {
diff --git a/iree/compiler/Dialect/VM/Transforms/Passes.h b/iree/compiler/Dialect/VM/Transforms/Passes.h
index a54fc05..fb453b6 100644
--- a/iree/compiler/Dialect/VM/Transforms/Passes.h
+++ b/iree/compiler/Dialect/VM/Transforms/Passes.h
@@ -83,6 +83,14 @@
createOrdinalAllocationPass();
//===----------------------------------------------------------------------===//
+// Optimization passes
+//===----------------------------------------------------------------------===//
+
+// Sinks side-effect-free, single-result ops defined in function entry blocks
+// down to their use-sites to reduce the total number of live registers.
+std::unique_ptr<OperationPass<IREE::VM::ModuleOp>> createSinkDefiningOpsPass();
+
+//===----------------------------------------------------------------------===//
// Test passes
//===----------------------------------------------------------------------===//
@@ -100,6 +108,7 @@
createHoistInlinedRodataPass();
createGlobalInitializationPass();
createOrdinalAllocationPass();
+ createSinkDefiningOpsPass();
}
inline void registerVMTestPasses() {
diff --git a/iree/compiler/Dialect/VM/Transforms/SinkDefiningOps.cpp b/iree/compiler/Dialect/VM/Transforms/SinkDefiningOps.cpp
new file mode 100644
index 0000000..9a8d36f
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Transforms/SinkDefiningOps.cpp
@@ -0,0 +1,101 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/compiler/Dialect/VM/IR/VMOps.h"
+#include "iree/compiler/Dialect/VM/Transforms/Passes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace IREE {
+namespace VM {
+
+class SinkDefiningOpsPass
+ : public PassWrapper<SinkDefiningOpsPass, OperationPass<ModuleOp>> {
+ public:
+ void runOnOperation() override {  // For each FuncOp in the module, sinks movable entry-block ops toward their users.
+ for (auto funcOp : getOperation().getOps<FuncOp>()) {
+ DominanceInfo domInfo(funcOp);  // Used below to find the nearest common dominator of user blocks.
+
+ // Collect (op, insertion point) pairs for movable entry-block ops; any single-result op without memory effects qualifies, not only constants.
+ SmallVector<std::pair<Operation *, Operation *>, 8> sinks;
+ for (auto &op : funcOp.getBlocks().front()) {  // NOTE(review): front() assumes the function has at least one block - confirm.
+ if (op.getNumResults() != 1 ||
+ !MemoryEffectOpInterface::hasNoEffect(&op)) {
+ // Not exactly one result, or may have memory effects: probably not safe to move, so leave it in place.
+ continue;
+ }
+
+ auto users = llvm::to_vector<4>(op.getUsers());
+ if (users.empty()) {
+ // No users (probably leftover needing DCE); there is nothing to sink toward.
+ continue;
+ } else if (users.size() == 1) {
+ // Only a single user; record a move to directly in front of that user.
+ sinks.push_back(std::make_pair(&op, users.front()));
+ continue;
+ }
+
+ // Find the nearest block that dominates every user's block. This may be the
+ // entry block itself, in which case the op only moves within the entry block.
+ Block *commonDominator = users.front()->getBlock();
+ for (auto user : users) {
+ commonDominator = domInfo.findNearestCommonDominator(
+ commonDominator, user->getBlock());
+ }
+
+ // Pick the earliest user located directly in the dominator block; if no user
+ // lives in that block, fall back to just before the block's terminator.
+ Operation *firstUserInDominator = commonDominator->getTerminator();  // NOTE(review): assumes commonDominator is non-null (all user blocks reachable) - confirm.
+ for (auto user : users) {
+ if (user->getBlock() == commonDominator) {
+ if (user->isBeforeInBlock(firstUserInDominator)) {
+ firstUserInDominator = user;
+ }
+ }
+ }
+
+ sinks.push_back(std::make_pair(&op, firstUserInDominator));
+ }
+
+ // Apply the recorded moves only after the scan so the entry block is not reordered while it is being iterated.
+ for (auto &sink : sinks) {
+ sink.first->moveBefore(sink.second);
+ }
+ }
+ }
+};
+
+std::unique_ptr<OperationPass<ModuleOp>> createSinkDefiningOpsPass() {  // Factory: returns a newly allocated SinkDefiningOpsPass owned by the caller.
+ return std::make_unique<SinkDefiningOpsPass>();
+}
+
+static PassRegistration<SinkDefiningOpsPass> pass(  // Static registration of the pass under the flag below.
+ "iree-vm-sink-defining-ops",  // Pass argument; the RUN line of sink_defining_ops.mlir passes it as -iree-vm-sink-defining-ops.
+ "Sinks defining ops with few uses to their use-sites.");  // Pass description string.
+
+} // namespace VM
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/VM/Transforms/test/sink_defining_ops.mlir b/iree/compiler/Dialect/VM/Transforms/test/sink_defining_ops.mlir
new file mode 100644
index 0000000..5db14fb
--- /dev/null
+++ b/iree/compiler/Dialect/VM/Transforms/test/sink_defining_ops.mlir
@@ -0,0 +1,70 @@
+// RUN: iree-opt -split-input-file -iree-vm-sink-defining-ops %s | IreeFileCheck %s
+
+vm.module @module {
+ // CHECK-LABEL: @single_uses
+ vm.func @single_uses(%arg0 : i32) -> i32 {
+ %c4 = vm.const.i32 4 : i32
+ %c5 = vm.const.i32 5 : i32
+ // CHECK: vm.cond_br
+ vm.cond_br %arg0, ^bb1, ^bb2
+ ^bb1:
+ // CHECK: %c4 = vm.const.i32 4 : i32
+ // CHECK-NEXT: vm.return %c4 : i32
+ vm.return %c4 : i32
+ ^bb2:
+ // CHECK: %c5 = vm.const.i32 5 : i32
+ // CHECK-NEXT: vm.return %c5 : i32
+ vm.return %c5 : i32
+ }
+}
+
+// -----
+
+vm.module @module {
+ // CHECK-LABEL: @multiple_uses
+ vm.func @multiple_uses(%arg0 : i32) -> (i32, i32) {
+ %c4 = vm.const.i32 4 : i32
+ %c5 = vm.const.i32 5 : i32
+ // CHECK: %c5 = vm.const.i32 5 : i32
+ // CHECK: vm.cond_br
+ vm.cond_br %arg0, ^bb1, ^bb2
+ ^bb1:
+ // CHECK: %c4 = vm.const.i32 4 : i32
+ // CHECK-NEXT: vm.return %c4, %c5
+ vm.return %c4, %c5 : i32, i32
+ ^bb2:
+ // CHECK: vm.return %c5 : i32
+ vm.return %c5 : i32
+ }
+}
+
+// -----
+
+vm.module @module {
+ // CHECK-LABEL: @common_dominator
+ vm.func @common_dominator(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
+ %c4 = vm.const.i32 4 : i32
+ %c5 = vm.const.i32 5 : i32
+ %c6 = vm.const.i32 6 : i32
+ // CHECK: %c5 = vm.const.i32 5 : i32
+ // CHECK-NEXT: vm.cond_br %arg0
+ vm.cond_br %arg0, ^bb1, ^bb_end
+ ^bb1:
+ // CHECK: "test.do_something_else"
+ "test.do_something_else"() : () -> ()
+ // CHECK-NEXT: %c4 = vm.const.i32 4 : i32
+ // CHECK-NEXT: "test.do_thing"(%c4) : (i32) -> ()
+ "test.do_thing"(%c4) : (i32) -> ()
+ // CHECK-NEXT: vm.cond_br %arg1
+ vm.cond_br %arg1, ^bb2, ^bb_end
+ ^bb2:
+ "test.do_thing"(%c4) : (i32) -> ()
+ "test.do_thing"(%c5) : (i32) -> ()
+ // CHECK: vm.br
+ vm.br ^bb_end
+ ^bb_end:
+ // CHECK: %c6 = vm.const.i32 6 : i32
+ // CHECK-NEXT: vm.return %c5, %c6
+ vm.return %c5, %c6 : i32, i32
+ }
+}