Reworking the DeduplicateExecutables pass to properly check the IR. (#4000)

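The previous check compared ops pairwise with OperationEquivalence
(IgnoreOperands) and operand types only, so executables containing the same
sequence of ops could be merged even when their use-def structure differed.
The new isStructurallyEquivalentTo helper walks both regions in parallel,
building an lhs->rhs mapping of blocks and values, and treats executables as
duplicates only if attributes (ignoring symbol names), operands, results,
block successors, and nested regions all map consistently. For example, these
two entry points contain the same ops but must not be merged:

  %0 = mhlo.multiply %arg0, %arg1 : tensor<2xi32>  // @entry_0
  %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>  // @entry_1
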
diff --git a/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp b/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
index de0643e..b48b193 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
@@ -14,8 +14,12 @@
 
 #include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
 #include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -25,6 +29,203 @@
 
 namespace {
 
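+// Returns true if |pred| returns true for each pairwise element of |lhs| and
+// |rhs| and both ranges have the same length. Walks the ranges in lockstep so
+// that linked-list ranges never require an O(n) size query.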
+template <typename Range, typename Pred>
+bool compare_ranges(Range &&lhs, Range &&rhs, Pred pred) {
+  auto lhsIt = lhs.begin();
+  auto rhsIt = rhs.begin();
+  while (lhsIt != lhs.end() && rhsIt != rhs.end()) {
+    if (!pred(*lhsIt++, *rhsIt++)) return false;
+  }
+  if ((lhsIt == lhs.end()) != (rhsIt == rhs.end())) {
+    // Range length mismatch: one range ended before the other. Checking here
+    // avoids the O(n) scan that computing the sizes up front would require.
+    return false;
+  }
+  return true;
+}
+
+static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs,
+                                       BlockAndValueMapping &parentMapping);
+static bool isStructurallyEquivalentTo(Operation &lhs, Operation &rhs,
+                                       BlockAndValueMapping &parentMapping);
+
+// Recursively compares two regions for structural equivalence.
+// Structural equivalence ensures that the operations in both |lhs| and |rhs|
+// have the same attributes and the same use-def structure.
+//
+// Example:
+//   func @lhs(%arg0 : index) -> index {
+//     %c1 = constant 1 : index
+//     %0 = add %arg0, %c1 : index
+//     return %0 : index
+//   }
+//   func @rhs(%arg0 : index) -> index {
+//     %c1 = constant 1 : index
+//     %0 = add %arg0, %c1 : index
+//     return %0 : index
+//   }
+//
+//   assert(isStructurallyEquivalentTo(lhs.getBody(), rhs.getBody()));
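+//
+// Conversely, two functions containing the same sequence of ops but with a
+// different use-def structure (e.g. `add %arg0, %c1` vs `add %arg0, %arg0`)
+// are not structurally equivalent.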
+//
+// TODO(#3996): upstream into mlir::OperationEquivalence if this works.
+// TODO(#3996): add symbol ref comparison (add to BlockAndValueMapping).
+static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs) {
+  BlockAndValueMapping mapping;
+  return isStructurallyEquivalentTo(lhs, rhs, mapping);
+}
+
+static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs,
+                                       BlockAndValueMapping &mapping) {
+  // Use compare_ranges to walk the block lists in parallel, detecting a
+  // length mismatch without an O(n) linked-list size query.
+  if (!compare_ranges(
+          lhs.getBlocks(), rhs.getBlocks(),
+          [&](Block &lhsBlock, Block &rhsBlock) {
+            if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) {
+              return false;
+            }
+            for (auto argPair :
+                 llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
+              auto &lhsArg = std::get<0>(argPair);
+              auto &rhsArg = std::get<1>(argPair);
+              if (lhsArg.getType() != rhsArg.getType()) return false;
+              mapping.map(lhsArg, rhsArg);
+            }
+            mapping.map(&lhsBlock, &rhsBlock);
+            return true;
+          })) {
+    return false;  // block mismatch
+  }
+
+  // Walk the blocks again now that we have a populated mapping, visiting them
+  // in reverse post-order so that all values a block requires have been
+  // mapped by the time we reach it (observing transitive block dominance).
+  llvm::SetVector<Block *> lhsBlocks;
+  for (Block &b : lhs.getBlocks()) {
+    llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+    lhsBlocks.insert(traversal.begin(), traversal.end());
+  }
+  llvm::SetVector<Block *> rhsBlocks;
+  for (Block &b : rhs.getBlocks()) {
+    llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+    rhsBlocks.insert(traversal.begin(), traversal.end());
+  }
+  for (auto blockPair : llvm::zip(lhsBlocks, rhsBlocks)) {
+    auto *lhsBlock = std::get<0>(blockPair);
+    auto *rhsBlock = std::get<1>(blockPair);
+    if (!compare_ranges(lhsBlock->getOperations(), rhsBlock->getOperations(),
+                        [&](Operation &lhsOp, Operation &rhsOp) {
+                          return isStructurallyEquivalentTo(lhsOp, rhsOp,
+                                                            mapping);
+                        })) {
+      return false;  // op mismatch (including differing op counts)
+    }
+  }
+
+  // Equivalent!
+  return true;
+}
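+
+// Recursively compares two operations for structural equivalence, using
+// |parentMapping| to relate |lhs| values and blocks to their |rhs|
+// counterparts.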
+static bool isStructurallyEquivalentTo(Operation &lhs, Operation &rhs,
+                                       BlockAndValueMapping &parentMapping) {
+  // Check operation metadata for early-exit opportunities.
+  if (lhs.getName() != rhs.getName()) return false;
+  if (lhs.getNumOperands() != rhs.getNumOperands()) return false;
+  if (lhs.getNumResults() != rhs.getNumResults()) return false;
+  if (lhs.getNumRegions() != rhs.getNumRegions()) return false;
+  if (lhs.getNumSuccessors() != rhs.getNumSuccessors()) return false;
+
+  // TODO(#3996): symbol mapping; for now allow them to differ unconditionally.
+  if (!compare_ranges(
+          lhs.getAttrs(), rhs.getAttrs(),
+          [&](const NamedAttribute &lhsAttr, const NamedAttribute &rhsAttr) {
+            if (lhsAttr.first == "function_ref" ||
+                lhsAttr.first == SymbolTable::getSymbolAttrName()) {
+              return true;
+            }
+            return lhsAttr == rhsAttr;
+          })) {
+    return false;
+  }
+
+  // If the op references blocks (such as a branch) then we expect the parent
+  // region to have already added them to the mapping for the lhs->rhs lookup.
+  for (auto successorPair :
+       llvm::zip(lhs.getSuccessors(), rhs.getSuccessors())) {
+    auto *lhsSuccessor = std::get<0>(successorPair);
+    auto *rhsSuccessor = std::get<1>(successorPair);
+    if (rhsSuccessor != parentMapping.lookupOrNull(lhsSuccessor)) return false;
+  }
+
+  // Ensure result types match first and add to the block and value mapping.
+  // For many ops, mismatched result types are a good (cheap) indicator that
+  // the operands won't match either, so this still allows a somewhat-early
+  // exit prior to the full traversal.
+  for (auto resultPair : llvm::zip(lhs.getResults(), rhs.getResults())) {
+    auto &lhsValue = std::get<0>(resultPair);
+    auto &rhsValue = std::get<1>(resultPair);
+    if (lhsValue.getType() != rhsValue.getType()) return false;
+    parentMapping.map(lhsValue, rhsValue);
+  }
+
+  // Check operands using the lhs->rhs mapping; since this op only consumes
+  // these values they should already be present in the mapping.
+  for (auto operandPair : llvm::zip(lhs.getOperands(), rhs.getOperands())) {
+    auto &lhsValue = std::get<0>(operandPair);
+    auto &rhsValue = std::get<1>(operandPair);
+    if (lhsValue.getType() != rhsValue.getType()) return false;
+    if (rhsValue != parentMapping.lookupOrNull(lhsValue)) return false;
+  }
+
+  // Recurse into regions.
+  for (auto regionPair : llvm::zip(lhs.getRegions(), rhs.getRegions())) {
+    auto &lhsRegion = std::get<0>(regionPair);
+    auto &rhsRegion = std::get<1>(regionPair);
+
+    // If the region is isolated we don't want to reuse any parent mapping or
+    // pollute it with our mappings.
+    BlockAndValueMapping scopedRegionMapping;
+    BlockAndValueMapping &regionMapping =
+        lhs.isKnownIsolatedFromAbove() ? scopedRegionMapping : parentMapping;
+
+    if (!isStructurallyEquivalentTo(lhsRegion, rhsRegion, regionMapping)) {
+      return false;
+    }
+  }
+
+  // Equivalent!
+  return true;
+}
+
+static bool areExecutablesEquivalent(ExecutableOp lhs, ExecutableOp rhs) {
+  auto lhsModule = lhs.getInnerModule();
+  auto rhsModule = rhs.getInnerModule();
+
+  // Must have the same number of entry point ops, with the same attributes.
+  // Entry point symbol names (and the functions they reference) are expected
+  // to differ, so symbol-related attributes are ignored.
+  auto ignoreSymbolAttrs = [](const NamedAttribute &lhsAttr,
+                              const NamedAttribute &rhsAttr) {
+    if (lhsAttr.first == "function_ref" ||
+        lhsAttr.first == SymbolTable::getSymbolAttrName()) {
+      return true;
+    }
+    return lhsAttr == rhsAttr;
+  };
+  if (!compare_ranges(lhsModule.getOps<DispatchEntryOp>(),
+                      rhsModule.getOps<DispatchEntryOp>(),
+                      [&](DispatchEntryOp lhs, DispatchEntryOp rhs) {
+                        return compare_ranges(lhs.getAttrs(), rhs.getAttrs(),
+                                              ignoreSymbolAttrs);
+                      })) {
+    return false;  // dispatch entry mismatch
+  }
+
+  // Walk all functions and ensure equivalent, again ignoring symbol names.
+  if (!compare_ranges(lhsModule.getOps<FuncOp>(), rhsModule.getOps<FuncOp>(),
+                      [&](FuncOp lhs, FuncOp rhs) {
+                        if (lhs.getType() != rhs.getType()) return false;
+                        if (!compare_ranges(lhs.getAttrs(), rhs.getAttrs(),
+                                            ignoreSymbolAttrs)) {
+                          return false;
+                        }
+                        return isStructurallyEquivalentTo(lhs.getBody(),
+                                                          rhs.getBody());
+                      })) {
+    return false;  // function mismatch
+  }
+
+  return true;
+}
+
 // Replaces each usage of an entry point with its original symbol name with a
 // new symbol name.
 void replaceEntryPointUses(
@@ -40,96 +241,6 @@
   }
 }
 
-bool areRegionsEquivalent(Region *lhs, Region *rhs) {
-  if (lhs->getBlocks().size() != rhs->getBlocks().size()) {
-    return false;
-  }
-
-  for (auto blockPair : llvm::zip(lhs->getBlocks(), rhs->getBlocks())) {
-    auto &lhsBlock = std::get<0>(blockPair);
-    auto &rhsBlock = std::get<1>(blockPair);
-    // Warning: .size() is linear time.
-    // We could instead iterate through both lists of operations explicitly,
-    // stopping when operations are not equivalent, OR either list runs out of
-    // operations early.
-    if (lhsBlock.getOperations().size() != rhsBlock.getOperations().size()) {
-      return false;
-    }
-
-    for (auto opPair :
-         llvm::zip(lhsBlock.getOperations(), rhsBlock.getOperations())) {
-      auto &lhsOp = std::get<0>(opPair);
-      auto &rhsOp = std::get<1>(opPair);
-      if (!OperationEquivalence::isEquivalentTo(
-              &lhsOp, &rhsOp, OperationEquivalence::IgnoreOperands)) {
-        return false;
-      }
-
-      // We want to check the operand _types_, but don't care if the actual
-      // operand references differ (as they live in separate modules anyway).
-      if (!std::equal(lhsOp.operand_type_begin(), lhsOp.operand_type_end(),
-                      rhsOp.operand_type_begin())) {
-        return false;
-      }
-
-      // If the operations have regions, recurse into them (depth-first).
-      if (lhsOp.getNumRegions() != rhsOp.getNumRegions()) {
-        return false;
-      }
-      auto lhsRegions = lhsOp.getRegions();
-      auto rhsRegions = rhsOp.getRegions();
-      for (int i = 0; i < lhsRegions.size(); ++i) {
-        if (!areRegionsEquivalent(&lhsRegions[i], &rhsRegions[i])) {
-          return false;
-        }
-      }
-    }
-  }
-
-  return true;
-}
-
-bool areExecutablesEquivalent(ExecutableOp lhs, ExecutableOp rhs) {
-  auto lhsModule = lhs.getInnerModule();
-  auto rhsModule = rhs.getInnerModule();
-
-  // TODO(scotttodd): Generalize: replace special cases with just calling
-  //   areRegionsEquivalent() on module.getBodyRegion(). We want to ignore
-  //   operation names and sym_name attrs, which
-  //   OperationEquivalence::isEquivalentTo() does not support [yet].
-
-  // Must have the same number of entry point ops, with the same attributes.
-  // Entry point op symbol names are expected to differ, that won't affect
-  // equivalence.
-  auto lhsEntryOps = llvm::to_vector<1>(lhsModule.getOps<DispatchEntryOp>());
-  auto rhsEntryOps = llvm::to_vector<1>(rhsModule.getOps<DispatchEntryOp>());
-  if (lhsEntryOps.size() != rhsEntryOps.size()) {
-    return false;
-  }
-  for (int i = 0; i < lhsEntryOps.size(); ++i) {
-    if (lhsEntryOps[i].getAttrs() != rhsEntryOps[i].getAttrs()) {
-      return false;
-    }
-  }
-
-  // Must have the same number of functions, with each listed in the same order
-  // and with equivalent regions inside.
-  auto lhsFuncOps = llvm::to_vector<1>(lhsModule.getOps<FuncOp>());
-  auto rhsFuncOps = llvm::to_vector<1>(rhsModule.getOps<FuncOp>());
-  if (lhsFuncOps.size() != rhsFuncOps.size()) {
-    return false;
-  }
-  for (int i = 0; i < lhsFuncOps.size(); ++i) {
-    auto lhsRegion = lhsFuncOps[i].getCallableRegion();
-    auto rhsRegion = rhsFuncOps[i].getCallableRegion();
-    if (!areRegionsEquivalent(lhsRegion, rhsRegion)) {
-      return false;
-    }
-  }
-
-  return true;
-}
-
 }  // namespace
 
 class DeduplicateExecutablesPass
@@ -152,9 +263,9 @@
 
       for (int j = 0; j < i; ++j) {
         auto referenceExecutableOp = executableOps[j];
-
-        if (!areExecutablesEquivalent(duplicateExecutableOp,
-                                      referenceExecutableOp)) {
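+        // Structural equivalence on the executable body regions covers the
+        // dispatch entries and inner modules while ignoring symbol names.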
+        if (!isStructurallyEquivalentTo(
+                duplicateExecutableOp.getBodyRegion(),
+                referenceExecutableOp.getBodyRegion())) {
           continue;
         }
 
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
index 946cd10..aa0d239 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
@@ -64,6 +64,38 @@
 
 // -----
 
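+// These executables contain the same sequence of ops but apply them to
+// different operands, so they must not be deduplicated.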
+// CHECK-LABEL: flow.executable @same_ops_diff_operands_ex_0
+flow.executable @same_ops_diff_operands_ex_0 {
+  flow.dispatch.entry @entry_0
+  module {
+    func @entry_0(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> {
+      %0 = mhlo.multiply %arg0, %arg1 : tensor<2xi32>
+      return %0 : tensor<2xi32>
+    }
+  }
+}
+// CHECK-LABEL: flow.executable @same_ops_diff_operands_ex_1
+flow.executable @same_ops_diff_operands_ex_1 {
+  flow.dispatch.entry @entry_1
+  module {
+    func @entry_1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> {
+      %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
+      return %0 : tensor<2xi32>
+    }
+  }
+}
+// CHECK-LABEL: func @same_ops_diff_operands
+func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> {
+  %c4 = constant 4 : index
+  // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: flow.executable @multiple_entry_points_ex_0
 flow.executable @multiple_entry_points_ex_0 {
   flow.dispatch.entry @multiple_entry_points_0_entry_0
@@ -142,7 +174,6 @@
 
 // -----
 
-
 // CHECK-LABEL: flow.executable @nested_ops_ex_0
 flow.executable @nested_ops_ex_0 {
   flow.dispatch.entry @nested_ops_entry_0
@@ -199,3 +230,89 @@
   %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32>
   return %0 : tensor<1xi32>
 }
+
+// -----
+
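+// @attributes_ex_1 differs from @attributes_ex_0 only in the reduce
+// `dimensions` attribute, while @attributes_ex_2 is an exact duplicate of
+// @attributes_ex_0; only the duplicate is removed.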
+// CHECK-LABEL: flow.executable @attributes_ex_0
+flow.executable @attributes_ex_0 {
+  flow.dispatch.entry @attributes_entry_0
+  module {
+    func @attributes_entry_0(%input: tensor<1x4xi32>) -> tensor<1xi32> {
+      %0 = constant dense<0> : tensor<i32>
+      %1 = "mhlo.reduce"(%input, %0) ( {
+      ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):   // no predecessors
+        %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        "mhlo.return"(%3) : (tensor<i32>) -> ()
+      }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
+      return %1 : tensor<1xi32>
+    }
+  }
+}
+
+// CHECK-LABEL: flow.executable @attributes_ex_1
+flow.executable @attributes_ex_1 {
+  flow.dispatch.entry @attributes_entry_1
+  module {
+    func @attributes_entry_1(%input: tensor<1x4xi32>) -> tensor<1xi32> {
+      %0 = constant dense<0> : tensor<i32>
+      %1 = "mhlo.reduce"(%input, %0) ( {
+      ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):   // no predecessors
+        %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        "mhlo.return"(%3) : (tensor<i32>) -> ()
+        // Same as @attributes_ex_0 but with a different dimensions attribute.
+      }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
+      return %1 : tensor<1xi32>
+    }
+  }
+}
+// Duplicate of @attributes_ex_0
+// CHECK-NOT: flow.executable @attributes_ex_2
+flow.executable @attributes_ex_2 {
+  flow.dispatch.entry @attributes_entry_2
+  module {
+    func @attributes_entry_2(%input: tensor<1x4xi32>) -> tensor<1xi32> {
+      %0 = constant dense<0> : tensor<i32>
+      %1 = "mhlo.reduce"(%input, %0) ( {
+      ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):   // no predecessors
+        %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+        "mhlo.return"(%3) : (tensor<i32>) -> ()
+      }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
+      return %1 : tensor<1xi32>
+    }
+  }
+}
+
+// -----
+
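+// These executables are identical except for the order of the cond_br
+// successors, so they must not be deduplicated.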
+// CHECK-LABEL: flow.executable @block_successors_ex_0
+flow.executable @block_successors_ex_0 {
+  flow.dispatch.entry @entry_0
+  module {
+    func @entry_0(%arg0: i32, %arg1: i32) -> i32 {
+      %c0 = constant 0 : i32
+      %c1 = constant 1 : i32
+      %eqz = cmpi "eq", %arg0, %arg1 : i32
+      cond_br %eqz, ^bb_a(%c0 : i32), ^bb_b(%c1 : i32)
+    ^bb_a(%bb_a_arg0 : i32):
+      return %bb_a_arg0 : i32
+    ^bb_b(%bb_b_arg0 : i32):
+      return %bb_b_arg0 : i32
+    }
+  }
+}
+// CHECK-LABEL: flow.executable @block_successors_ex_with_swapped_cond_br
+flow.executable @block_successors_ex_with_swapped_cond_br {
+  flow.dispatch.entry @entry_1
+  module {
+    func @entry_1(%arg0: i32, %arg1: i32) -> i32 {
+      %c0 = constant 0 : i32
+      %c1 = constant 1 : i32
+      %eqz = cmpi "eq", %arg0, %arg1 : i32
+      cond_br %eqz, ^bb_b(%c0 : i32), ^bb_a(%c1 : i32)
+    ^bb_a(%bb_a_arg0 : i32):
+      return %bb_a_arg0 : i32
+    ^bb_b(%bb_b_arg0 : i32):
+      return %bb_b_arg0 : i32
+    }
+  }
+}