Rework the DeduplicateExecutables pass to properly check the IR for structural equivalence. (#4000)
diff --git a/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp b/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp index de0643e..b48b193 100644 --- a/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/DeduplicateExecutables.cpp
@@ -14,8 +14,12 @@ #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/RegionGraphTraits.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -25,6 +29,203 @@ namespace { +template <typename Range, typename Pred> +bool compare_ranges(Range &&lhs, Range &&rhs, Pred pred) { + auto lhsIt = lhs.begin(); + auto rhsIt = rhs.begin(); + while (lhsIt != lhs.end() && rhsIt != rhs.end()) { + if (!pred(*lhsIt++, *rhsIt++)) return false; + } + if ((lhsIt == lhs.end()) != (rhsIt == rhs.end())) { + // Block count mismatch. We do this here so that we avoid the O(n) scan + // that would have been required to calculate the size above. + return false; + } + return true; +} + +static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs, + BlockAndValueMapping &parentMapping); +static bool isStructurallyEquivalentTo(Operation &lhs, Operation &rhs, + BlockAndValueMapping &parentMapping); + +// Recursively compares two regions for structural equivalence. +// Structural equivalence ensures that operations on both the |lhs| and |rhs| +// have the same attributes and same use-def structure. +// +// Example: +// func @lhs(%arg0 : index) -> index { +// %c1 = constant 1 : index +// %0 = add %arg0, %c1 : index +// return %0 : index +// } +// func @rhs(%arg0 : index) -> index { +// %c1 = constant 1 : index +// %0 = add %arg0, %c1 : index +// return %0 : index +// } +// +// assert(isStructurallyEquivalentTo(lhs.getBody(), rhs.getBody())); +// +// TODO(#3996): upstream into mlir::OperationEquivalence if this works. +// TODO(#3996): add symbol ref comparison (add to BlockAndValueMapping). 
+static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs) { + BlockAndValueMapping mapping; + return isStructurallyEquivalentTo(lhs, rhs, mapping); +} + +static bool isStructurallyEquivalentTo(Region &lhs, Region &rhs, + BlockAndValueMapping &mapping) { + // Use compare_ranges to walk the block list in parallel and get a boolean in + // the case of size mismatch without an O(N) linked-list size query. + if (!compare_ranges( + lhs.getBlocks(), rhs.getBlocks(), + [&](Block &lhsBlock, Block &rhsBlock) { + if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) { + return false; + } + for (auto argPair : + llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) { + auto &lhsArg = std::get<0>(argPair); + auto &rhsArg = std::get<1>(argPair); + if (lhsArg.getType() != rhsArg.getType()) return false; + mapping.map(lhsArg, rhsArg); + } + mapping.map(&lhsBlock, &rhsBlock); + return true; + })) { + return false; // block mismatch + } + + // Walk the blocks again now that we have a populated mapping. + // We do this in topological order so that we have all values required by a + // block mapped by the time we reach it observing transitive block dominance. + llvm::SetVector<Block *> lhsBlocks; + for (Block &b : lhs.getBlocks()) { + llvm::ReversePostOrderTraversal<Block *> traversal(&b); + lhsBlocks.insert(traversal.begin(), traversal.end()); + } + llvm::SetVector<Block *> rhsBlocks; + for (Block &b : rhs.getBlocks()) { + llvm::ReversePostOrderTraversal<Block *> traversal(&b); + rhsBlocks.insert(traversal.begin(), traversal.end()); + } + for (auto blockPair : llvm::zip(lhsBlocks, rhsBlocks)) { + auto &lhsBlock = std::get<0>(blockPair); + auto &rhsBlock = std::get<1>(blockPair); + for (auto opPair : + llvm::zip(lhsBlock->getOperations(), rhsBlock->getOperations())) { + auto &lhsOp = std::get<0>(opPair); + auto &rhsOp = std::get<1>(opPair); + if (!isStructurallyEquivalentTo(lhsOp, rhsOp, mapping)) { + return false; + } + } + } + + // Equivalent! 
+ return true; +} +static bool isStructurallyEquivalentTo(Operation &lhs, Operation &rhs, + BlockAndValueMapping &parentMapping) { + // Check operation metadata for early-exit opportunities. + if (lhs.getName() != rhs.getName()) return false; + if (lhs.getNumOperands() != rhs.getNumOperands()) return false; + if (lhs.getNumResults() != rhs.getNumResults()) return false; + if (lhs.getNumRegions() != rhs.getNumRegions()) return false; + if (lhs.getNumSuccessors() != rhs.getNumSuccessors()) return false; + + // TODO(#3996): symbol mapping; for now allow them to differ unconditionally. + if (!compare_ranges( + lhs.getAttrs(), rhs.getAttrs(), + [&](const NamedAttribute &lhs, const NamedAttribute &rhs) { + if (lhs.first == "function_ref" || + lhs.first == SymbolTable::getSymbolAttrName()) { + return true; + } + return lhs == rhs; + })) { + return false; + } + + // If the op references blocks (such as a branch) then we expect to have them + // in the mapping already from the parent region to do the lhs->rhs mapping. + for (auto successorPair : + llvm::zip(lhs.getSuccessors(), rhs.getSuccessors())) { + auto *lhsSuccessor = std::get<0>(successorPair); + auto *rhsSuccessor = std::get<1>(successorPair); + if (rhsSuccessor != parentMapping.lookup(lhsSuccessor)) return false; + } + + // Ensure result types match first and add to the block and value mapping. + // For many ops if the result types don't match it's a good (cheap) indicator + // that the operands won't match either so this still allows a somewhat-early + // exit prior to the full traversal. + for (auto resultPair : llvm::zip(lhs.getResults(), rhs.getResults())) { + auto &lhsValue = std::get<0>(resultPair); + auto &rhsValue = std::get<1>(resultPair); + if (lhsValue.getType() != rhsValue.getType()) return false; + parentMapping.map(lhsValue, rhsValue); + } + + // Check operands using the lhs->rhs mapping; since this op is only consuming + // these values they should already be defined in the mapping. 
+ for (auto operandPair : llvm::zip(lhs.getOperands(), rhs.getOperands())) { + auto &lhsValue = std::get<0>(operandPair); + auto &rhsValue = std::get<1>(operandPair); + if (lhsValue.getType() != rhsValue.getType()) return false; + if (rhsValue != parentMapping.lookup(lhsValue)) return false; + } + + // Recurse into regions. + for (auto regionPair : llvm::zip(lhs.getRegions(), rhs.getRegions())) { + auto &lhsRegion = std::get<0>(regionPair); + auto &rhsRegion = std::get<1>(regionPair); + + // If the region is isolated we don't want to reuse any parent mapping or + // pollute it with our mappings. + BlockAndValueMapping scopedRegionMapping; + BlockAndValueMapping regionMapping = + lhs.isKnownIsolatedFromAbove() ? scopedRegionMapping : parentMapping; + + if (!isStructurallyEquivalentTo(lhsRegion, rhsRegion, regionMapping)) { + return false; + } + } + + // Equivalent! + return true; +} + +bool areExecutablesEquivalent(ExecutableOp lhs, ExecutableOp rhs) { + auto lhsModule = lhs.getInnerModule(); + auto rhsModule = rhs.getInnerModule(); + + // Must have the same number of entry point ops, with the same attributes. + // Entry point op symbol names are expected to differ, that won't affect + // equivalence. + if (!compare_ranges(lhsModule.getOps<DispatchEntryOp>(), + rhsModule.getOps<DispatchEntryOp>(), + [](DispatchEntryOp lhs, DispatchEntryOp rhs) { + return lhs.getAttrs() == rhs.getAttrs(); + })) { + return false; // dispatch entry mismatch + } + + // Walk all functions and ensure equivalent. + if (!compare_ranges(lhsModule.getOps<FuncOp>(), rhsModule.getOps<FuncOp>(), + [](FuncOp lhs, FuncOp rhs) { + if (lhs.getType() != rhs.getType()) return false; + if (lhs.getAttrs() != rhs.getAttrs()) return false; + return isStructurallyEquivalentTo(lhs.getRegion(), + rhs.getRegion()); + })) { + return false; // dispatch entry mismatch + } + + return true; +} + // Replaces each usage of an entry point with its original symbol name with a // new symbol name. 
void replaceEntryPointUses( @@ -40,96 +241,6 @@ } } -bool areRegionsEquivalent(Region *lhs, Region *rhs) { - if (lhs->getBlocks().size() != rhs->getBlocks().size()) { - return false; - } - - for (auto blockPair : llvm::zip(lhs->getBlocks(), rhs->getBlocks())) { - auto &lhsBlock = std::get<0>(blockPair); - auto &rhsBlock = std::get<1>(blockPair); - // Warning: .size() is linear time. - // We could instead iterate through both lists of operations explicitly, - // stopping when operations are not equivalent, OR either list runs out of - // operations early. - if (lhsBlock.getOperations().size() != rhsBlock.getOperations().size()) { - return false; - } - - for (auto opPair : - llvm::zip(lhsBlock.getOperations(), rhsBlock.getOperations())) { - auto &lhsOp = std::get<0>(opPair); - auto &rhsOp = std::get<1>(opPair); - if (!OperationEquivalence::isEquivalentTo( - &lhsOp, &rhsOp, OperationEquivalence::IgnoreOperands)) { - return false; - } - - // We want to check the operand _types_, but don't care if the actual - // operand references differ (as they live in separate modules anyway). - if (!std::equal(lhsOp.operand_type_begin(), lhsOp.operand_type_end(), - rhsOp.operand_type_begin())) { - return false; - } - - // If the operations have regions, recurse into them (depth-first). - if (lhsOp.getNumRegions() != rhsOp.getNumRegions()) { - return false; - } - auto lhsRegions = lhsOp.getRegions(); - auto rhsRegions = rhsOp.getRegions(); - for (int i = 0; i < lhsRegions.size(); ++i) { - if (!areRegionsEquivalent(&lhsRegions[i], &rhsRegions[i])) { - return false; - } - } - } - } - - return true; -} - -bool areExecutablesEquivalent(ExecutableOp lhs, ExecutableOp rhs) { - auto lhsModule = lhs.getInnerModule(); - auto rhsModule = rhs.getInnerModule(); - - // TODO(scotttodd): Generalize: replace special cases with just calling - // areRegionsEquivalent() on module.getBodyRegion(). 
We want to ignore - // operation names and sym_name attrs, which - // OperationEquivalence::isEquivalentTo() does not support [yet]. - - // Must have the same number of entry point ops, with the same attributes. - // Entry point op symbol names are expected to differ, that won't affect - // equivalence. - auto lhsEntryOps = llvm::to_vector<1>(lhsModule.getOps<DispatchEntryOp>()); - auto rhsEntryOps = llvm::to_vector<1>(rhsModule.getOps<DispatchEntryOp>()); - if (lhsEntryOps.size() != rhsEntryOps.size()) { - return false; - } - for (int i = 0; i < lhsEntryOps.size(); ++i) { - if (lhsEntryOps[i].getAttrs() != rhsEntryOps[i].getAttrs()) { - return false; - } - } - - // Must have the same number of functions, with each listed in the same order - // and with equivalent regions inside. - auto lhsFuncOps = llvm::to_vector<1>(lhsModule.getOps<FuncOp>()); - auto rhsFuncOps = llvm::to_vector<1>(rhsModule.getOps<FuncOp>()); - if (lhsFuncOps.size() != rhsFuncOps.size()) { - return false; - } - for (int i = 0; i < lhsFuncOps.size(); ++i) { - auto lhsRegion = lhsFuncOps[i].getCallableRegion(); - auto rhsRegion = rhsFuncOps[i].getCallableRegion(); - if (!areRegionsEquivalent(lhsRegion, rhsRegion)) { - return false; - } - } - - return true; -} - } // namespace class DeduplicateExecutablesPass @@ -152,9 +263,9 @@ for (int j = 0; j < i; ++j) { auto referenceExecutableOp = executableOps[j]; - - if (!areExecutablesEquivalent(duplicateExecutableOp, - referenceExecutableOp)) { + if (!isStructurallyEquivalentTo( + duplicateExecutableOp.getBodyRegion(), + referenceExecutableOp.getBodyRegion())) { continue; }
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir index 946cd10..aa0d239 100644 --- a/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir +++ b/iree/compiler/Dialect/Flow/Transforms/test/deduplicate_executables.mlir
@@ -64,6 +64,38 @@ // ----- +// CHECK: flow.executable @same_ops_diff_operands_ex_0 +flow.executable @same_ops_diff_operands_ex_0 { + flow.dispatch.entry @entry_0 + module { + func @entry_0(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.multiply %arg0, %arg1 : tensor<2xi32> + return %0 : tensor<2xi32> + } + } +} +// CHECK: flow.executable @same_ops_diff_operands_ex_1 +flow.executable @same_ops_diff_operands_ex_1 { + flow.dispatch.entry @entry_1 + module { + func @entry_1(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> + return %0 : tensor<2xi32> + } + } +} +// CHECK-LABEL: func @same_ops_diff_operands +func @same_ops_diff_operands(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi32> { + %c4 = constant 4 : index + // CHECK: %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = flow.dispatch @same_ops_diff_operands_ex_0::@entry_0[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + // CHECK: %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = flow.dispatch @same_ops_diff_operands_ex_1::@entry_1[%c4 : index](%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0 : tensor<2xi32> +} + +// ----- + // CHECK-LABEL: flow.executable @multiple_entry_points_ex_0 flow.executable @multiple_entry_points_ex_0 { flow.dispatch.entry @multiple_entry_points_0_entry_0 @@ -142,7 +174,6 @@ // ----- - // CHECK-LABEL: flow.executable @nested_ops_ex_0 flow.executable @nested_ops_ex_0 { flow.dispatch.entry @nested_ops_entry_0 @@ -199,3 +230,89 @@ %2 = flow.dispatch @nested_ops_ex_2::@nested_ops_entry_2[%c4 : index](%arg0) : (tensor<1x4xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } + +// ----- + +// CHECK-LABEL: flow.executable @attributes_ex_0 +flow.executable @attributes_ex_0 { + 
flow.dispatch.entry @attributes_entry_0 + module { + func @attributes_entry_0(%input: tensor<1x4xi32>) -> tensor<1xi32> { + %0 = constant dense<0> : tensor<i32> + %1 = "mhlo.reduce"(%input, %0) ( { + ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors + %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32> + "mhlo.return"(%3) : (tensor<i32>) -> () + }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32> + return %1 : tensor<1xi32> + } + } +} + +// CHECK-LABEL: flow.executable @attributes_ex_1 +flow.executable @attributes_ex_1 { + flow.dispatch.entry @attributes_entry_1 + module { + func @attributes_entry_1(%input: tensor<1x4xi32>) -> tensor<1xi32> { + %0 = constant dense<0> : tensor<i32> + %1 = "mhlo.reduce"(%input, %0) ( { + ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>): // no predecessors + %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32> + "mhlo.return"(%3) : (tensor<i32>) -> () + // @attributes_ex_0 but with a different attribute. 
      }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
      return %1 : tensor<1xi32>
    }
  }
}
// Duplicate of @attributes_ex_0
// CHECK-NOT: flow.executable @attributes_ex_2
flow.executable @attributes_ex_2 {
  flow.dispatch.entry @attributes_entry_2
  module {
    func @attributes_entry_2(%input: tensor<1x4xi32>) -> tensor<1xi32> {
      %0 = constant dense<0> : tensor<i32>
      %1 = "mhlo.reduce"(%input, %0) ( {
      ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):   // no predecessors
        %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
        "mhlo.return"(%3) : (tensor<i32>) -> ()
      }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<i32>) -> tensor<1xi32>
      return %1 : tensor<1xi32>
    }
  }
}

// -----

// CHECK-LABEL: flow.executable @block_successors_ex_0
flow.executable @block_successors_ex_0 {
  flow.dispatch.entry @entry_0
  module {
    func @entry_0(%arg0: i32, %arg1: i32) -> i32 {
      %c0 = constant 0 : i32
      %c1 = constant 1 : i32
      %eqz = cmpi "eq", %arg0, %arg1 : i32
      cond_br %eqz, ^bb_a(%c0 : i32), ^bb_b(%c1 : i32)
    ^bb_a(%bb_a_arg0 : i32):
      return %bb_a_arg0 : i32
    ^bb_b(%bb_b_arg0 : i32):
      return %bb_b_arg0 : i32
    }
  }
}
// CHECK-LABEL: flow.executable @block_successors_ex_with_swapped_cond_br
flow.executable @block_successors_ex_with_swapped_cond_br {
  flow.dispatch.entry @entry_1
  module {
    func @entry_1(%arg0: i32, %arg1: i32) -> i32 {
      %c0 = constant 0 : i32
      %c1 = constant 1 : i32
      %eqz = cmpi "eq", %arg0, %arg1 : i32
      cond_br %eqz, ^bb_b(%c0 : i32), ^bb_a(%c1 : i32)
    ^bb_a(%bb_a_arg0 : i32):
      return %bb_a_arg0 : i32
    ^bb_b(%bb_b_arg0 : i32):
      return %bb_b_arg0 : i32
    }
  }
}