The dominance check was literally backward in the hoisting code. Need to work on my testing skills. Also, since DominanceInfo can get out of date, use Operation::isBeforeInBlock since it stays up to date and we only need local dominance checks in this code. That removes the need for DominanceInfo altogether. PiperOrigin-RevId: 305975281
diff --git a/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculationsPass.cpp b/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculationsPass.cpp index 98fa0eb..f075b80 100644 --- a/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculationsPass.cpp +++ b/iree/compiler/Dialect/Shape/Transforms/HoistShapeCalculationsPass.cpp
@@ -93,11 +93,10 @@ return opsToHoistSet; } -void hoistOps(DenseSet<Operation *> opsToHoistSet, Block &block, - DominanceInfo &domInfo) { +void hoistOps(DenseSet<Operation *> opsToHoistSet, Block &block) { auto opsToHoist = llvm::to_vector<16>(opsToHoistSet); - llvm::sort(opsToHoist, [&](Operation *lhs, Operation *rhs) { - return domInfo.properlyDominates(lhs, rhs); + llvm::stable_sort(opsToHoist, [&](Operation *lhs, Operation *rhs) { + return lhs->isBeforeInBlock(rhs); }); for (Operation *op : opsToHoist) { @@ -105,7 +104,7 @@ for (Value operand : op->getOperands()) { if (Operation *definingOp = getDefiningOpInBlock(operand, block)) { if (insertAfter == nullptr || - domInfo.properlyDominates(definingOp, insertAfter)) { + insertAfter->isBeforeInBlock(definingOp)) { insertAfter = definingOp; } } @@ -140,10 +139,9 @@ public: void runOnFunction() override { auto func = getFunction(); - DominanceInfo domInfo(func); for (Block &block : func) { DenseSet<Operation *> opsToHoist = calculateOpsToHoist(block); - hoistOps(opsToHoist, block, domInfo); + hoistOps(opsToHoist, block); } } };
diff --git a/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir b/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir index 3121f70..1d503e8 100644 --- a/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir +++ b/iree/compiler/Dialect/Shape/Transforms/test/hoist_shape_calculations.mlir
@@ -53,3 +53,17 @@ shapex.tie_shape %arg0, %shape : tensor<?xf32>, !shapex.ranked_shape<[?]> return } + +// ----- + +// CHECK-LABEL: func @f +func @f(%arg0: tensor<?x?xf32>) { + // CHECK: constant + // CHECK: constant + // CHECK: shapex.make_ranked_shape + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = shapex.make_ranked_shape %c0, %c1 : (index, index) -> !shapex.ranked_shape<[?,?]> + %1 = shapex.tie_shape %arg0, %0 : tensor<?x?xf32>, !shapex.ranked_shape<[?,?]> + return +}