Basic shape calculation hoisting pass.

This should be enough to unblock basic dispatch region formation.
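
For example (mirroring the first test case below), the pass rewrites

  %t = addf %arg0, %arg0 : tensor<?xf32>
  %shape = shapex.make_ranked_shape %arg1 -> !shapex.ranked_shape<[?]>
  shapex.tie_shape %t, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]>

into

  %shape = shapex.make_ranked_shape %arg1 -> !shapex.ranked_shape<[?]>
  %t = addf %arg0, %arg0 : tensor<?xf32>
  shapex.tie_shape %t, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]>

so that %shape dominates %t.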

PiperOrigin-RevId: 298924505
diff --git a/iree/compiler/Dialect/Shape/Transforms/BUILD b/iree/compiler/Dialect/Shape/Transforms/BUILD
index 70e2747..a32cad3 100644
--- a/iree/compiler/Dialect/Shape/Transforms/BUILD
+++ b/iree/compiler/Dialect/Shape/Transforms/BUILD
@@ -23,6 +23,7 @@
         "CleanupPlaceholders.cpp",
         "ConvertHLOToShapeDialect.cpp",
         "ExpandFunctionDynamicDims.cpp",
+        "HoistShapeCalculations.cpp",
         "MaterializeShapeCalculations.cpp",
         "Passes.cpp",
         "TieDynamicShapes.cpp",
@@ -33,6 +34,7 @@
     deps = [
         "//iree/compiler/Dialect/Shape/IR",
         "//iree/compiler/Dialect/Shape/Plugins/XLA:XlaHloShapeBuilder",
+        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
diff --git a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
index e2bd871..accadbd 100644
--- a/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/Shape/Transforms/CMakeLists.txt
@@ -23,10 +23,12 @@
     "CleanupPlaceholders.cpp"
     "ConvertHLOToShapeDialect.cpp"
     "ExpandFunctionDynamicDims.cpp"
+    "HoistShapeCalculations.cpp"
     "MaterializeShapeCalculations.cpp"
     "Passes.cpp"
     "TieDynamicShapes.cpp"
   DEPS
+    MLIRAnalysis
     MLIRIR
     MLIRPass
     MLIRStandardOps
diff --git a/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculations.cpp b/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculations.cpp
new file mode 100644
index 0000000..2f78ee2
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculations.cpp
@@ -0,0 +1,159 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <iterator>
+
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeTypes.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace {
+
+bool isSimpleShapeCalculationOp(Operation *op) {
+  // The op must have no side effects.
+  if (!op->hasNoSideEffect()) {
+    return false;
+  }
+  // The op should operate on types that plausibly arise in shape
+  // calculations: e.g. `muli` on index values qualifies, while `addf` on
+  // tensors does not. The exact predicate used here isn't too important;
+  // what matters is that we don't hoist ops on tensors.
+  for (Type type : op->getOperandTypes()) {
+    if (!type.isa<IntegerType>() && !type.isa<IndexType>() &&
+        !type.isa<Shape::RankedShapeType>()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Return an operation in `block` that defines `v`, if one exists.
+Operation *getDefiningOpInBlock(Value v, Block &block) {
+  if (OpResult opResult = v.dyn_cast<OpResult>()) {
+    if (opResult.getOwner()->getBlock() == &block) {
+      return opResult.getOwner();
+    }
+  }
+  return nullptr;
+}
+
+DenseSet<Operation *> calculateOpsToHoist(Block &block) {
+  // Strategy: backward DFS from the shape operand (the second operand) of
+  // each shapex.tie_shape op, staying within the block and only visiting ops
+  // that satisfy `isSimpleShapeCalculationOp`.
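+  //
+  // Ops that fail the predicate (e.g. side-effecting ops, or ops on tensors)
+  // terminate the traversal: neither they nor anything reachable only
+  // through them will be hoisted.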
+
+  SmallVector<Operation *, 16> worklist;
+  // The return value, and also used as a "visited" set for our DFS below.
+  DenseSet<Operation *> opsToHoistSet;
+
+  // The worklist is initially populated with shape-producing ops defined in
+  // this block.
+  //
+  // We are only hoisting ops within the block, so block arguments and any
+  // values defined outside the block (which already dominate the entire block)
+  // don't matter.
+  for (Operation &op : block) {
+    if (auto tieShape = dyn_cast<Shape::TieShapeOp>(op)) {
+      if (Operation *definingOp =
+              getDefiningOpInBlock(tieShape.shape(), block)) {
+        worklist.push_back(definingOp);
+      }
+    }
+  }
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isSimpleShapeCalculationOp(op)) {
+      continue;
+    }
+    if (opsToHoistSet.insert(op).second) {
+      for (Value v : op->getOperands()) {
+        if (Operation *definingOp = getDefiningOpInBlock(v, block)) {
+          worklist.push_back(definingOp);
+        }
+      }
+    }
+  }
+  return opsToHoistSet;
+}
+
+void hoistOps(DenseSet<Operation *> opsToHoistSet, Block &block,
+              DominanceInfo &domInfo) {
+  auto opsToHoist = llvm::to_vector<16>(opsToHoistSet);
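+  // Sort the ops into block order (within a single block, "properly
+  // dominates" is just "comes earlier") so that by the time we place an op,
+  // any hoisted ops it depends on have already been placed.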
+  llvm::sort(opsToHoist, [&](Operation *lhs, Operation *rhs) {
+    return domInfo.properlyDominates(lhs, rhs);
+  });
+
+  for (Operation *op : opsToHoist) {
+    // Insert the op immediately after the last op in the block that defines
+    // one of its operands, so that defs still dominate all uses.
+    Operation *insertAfter = nullptr;
+    for (Value operand : op->getOperands()) {
+      if (Operation *definingOp = getDefiningOpInBlock(operand, block)) {
+        if (insertAfter == nullptr ||
+            domInfo.properlyDominates(insertAfter, definingOp)) {
+          insertAfter = definingOp;
+        }
+      }
+    }
+    if (insertAfter != nullptr) {
+      op->moveBefore(&block, std::next(insertAfter->getIterator()));
+    } else {
+      op->moveBefore(&block, block.begin());
+    }
+  }
+}
+
+// Best-effort pass for hoisting shape calculations earlier in the program.
+// We currently don't provide any hard guarantees about exactly what invariants
+// are established by this pass.
+//
+// The goal of this pass is to unblock further progress on dynamic shape
+// support. Pragmatically, before IREE can even form a dispatch region around
+// a `shapex.tie_shape %tensor, %shape` op, it needs `%shape` to dominate
+// `%tensor`, since the dispatch region's "workload" is derived from the
+// shape.
+//
+// This pass doesn't have a cost model, so it shouldn't be considered a generic
+// "hoist stuff to make things faster" type of pass. It's strictly a
+// best-effort pass to make certain lowerings work, albeit on somewhat shaky
+// ground. Longer-term, IREE's dispatch region formation will use a more
+// sophisticated algorithm and the analysis/hoisting done here will be a
+// byproduct of the dispatch region formation legality analysis/preparation.
+class HoistShapeCalculations : public FunctionPass<HoistShapeCalculations> {
+ public:
+  void runOnFunction() override {
+    auto func = getFunction();
+    DominanceInfo domInfo(func);
+    for (Block &block : func) {
+      DenseSet<Operation *> opsToHoist = calculateOpsToHoist(block);
+      hoistOps(opsToHoist, block, domInfo);
+    }
+  }
+};
+}  // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> createHoistShapeCalculationsPass() {
+  return std::make_unique<HoistShapeCalculations>();  // NOLINT
+}
+
+static PassRegistration<HoistShapeCalculations> pass(
+    "iree-shape-hoist-shape-calculations",
+    "Best-effort shape calculation hoisting.");
+
+}  // namespace iree_compiler
+}  // namespace mlir
diff --git a/iree/compiler/Dialect/Shape/Transforms/Passes.h b/iree/compiler/Dialect/Shape/Transforms/Passes.h
index 1ed4997..32fe813 100644
--- a/iree/compiler/Dialect/Shape/Transforms/Passes.h
+++ b/iree/compiler/Dialect/Shape/Transforms/Passes.h
@@ -49,6 +49,10 @@
 // dialect.
 std::unique_ptr<OpPassBase<FuncOp>> createConvertHLOToShapePass();
 
+// Best-effort hoisting of shape calculations to attempt to establish the
+// invariant that a shapex.tie_shape op's shape operand (its second operand)
+// dominates its tensor operand (its first operand).
+std::unique_ptr<OpPassBase<FuncOp>> createHoistShapeCalculationsPass();
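+//
+// A typical use (illustrative only; `funcPassManager` is a hypothetical
+// FuncOp-level pass manager):
+//   funcPassManager.addPass(createHoistShapeCalculationsPass());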
+
 }  // namespace iree_compiler
 }  // namespace mlir
 
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir b/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir
new file mode 100644
index 0000000..8a2cdad
--- /dev/null
+++ b/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir
@@ -0,0 +1,55 @@
+// RUN: iree-opt -split-input-file -verify-diagnostics -iree-shape-hoist-shape-calculations %s | IreeFileCheck %s
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: index) {
+  // CHECK: shapex.make_ranked_shape
+  // CHECK: addf
+  %t = addf %arg0, %arg0 : tensor<?xf32>
+  %shape = shapex.make_ranked_shape %arg1 -> !shapex.ranked_shape<[?]>
+  shapex.tie_shape %t, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>) {
+  // CHECK: addf
+  // CHECK: dim
+  // CHECK: shapex.make_ranked_shape
+  %t = addf %arg0, %arg0 : tensor<?xf32>
+  %dim = dim %t, 0 : tensor<?xf32>
+  %shape = shapex.make_ranked_shape %dim -> !shapex.ranked_shape<[?]>
+  shapex.tie_shape %t, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: index) {
+  // CHECK: addi
+  // CHECK: muli
+  // CHECK: shapex.make_ranked_shape
+  // CHECK: some_dialect.some_op
+  "some_dialect.some_op"() : () -> ()
+  %addi = addi %arg1, %arg1 : index
+  %dim = muli %addi, %addi : index
+  %shape = shapex.make_ranked_shape %dim -> !shapex.ranked_shape<[?]>
+  shapex.tie_shape %arg0, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<?xf32>, %arg1: index) {
+  // CHECK: some_dialect.some_op
+  // CHECK: some_dialect.side_effecting_muli
+  // CHECK: shapex.make_ranked_shape
+  "some_dialect.some_op"() : () -> ()
+  %dim = "some_dialect.side_effecting_muli"(%arg1, %arg1) : (index, index) -> index
+  %shape = shapex.make_ranked_shape %dim -> !shapex.ranked_shape<[?]>
+  shapex.tie_shape %arg0, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]>
+  return
+}