Fixes a range inference overflow with util.align. (#18808)
* Previously, a util.align whose lhs had no analyzable range was folded
to a constant zero, because aligning the maximal value of that range
wraps around to zero.
* Detects overflow in the aligned bounds and skips range inference when it occurs.
* Fixes some issues with an RHS of zero that were discovered while writing
tests for this case (a zero alignment isn't really valid, but it was
triggering assertions in the compiler).
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index 4e13877..7153080 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -505,7 +505,7 @@
APInt staticAlignment;
bool hasStaticAlignment =
matchPattern(alignment, m_ConstantInt(&staticAlignment));
- if (hasStaticValue && hasStaticAlignment) {
+ if (hasStaticValue && hasStaticAlignment && !staticAlignment.isZero()) {
// If this value is itself a multiple of the alignment then we can fold.
if (staticValue.urem(staticAlignment).isZero()) {
return true; // value % alignment == 0
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 1a009f1..01d0e42 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1110,7 +1110,7 @@
auto constantAlignment = argRanges[1].getConstantValue();
// Note that for non constant alignment, there may still be something we
// want to infer, but this is left for the future.
- if (constantAlignment) {
+ if (constantAlignment && !constantAlignment->isZero()) {
// We can align the range directly.
// (value + (alignment - 1)) & ~(alignment - 1)
// https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
@@ -1119,11 +1119,19 @@
APInt one(constantAlignment->getBitWidth(), 1);
APInt alignmentM1 = *constantAlignment - one;
APInt alignmentM1Inv = ~alignmentM1;
- auto align = [&](APInt value) -> APInt {
- return (value + alignmentM1) & alignmentM1Inv;
+ auto align = [&](APInt value, bool &invalid) -> APInt {
+ APInt aligned = (value + alignmentM1) & alignmentM1Inv;
+ // Detect overflow, which commonly happens at max range.
+ if (aligned.ult(value))
+ invalid = true;
+ return aligned;
};
- setResultRange(getResult(),
- ConstantIntRanges::fromUnsigned(align(umin), align(umax)));
+ bool invalid = false;
+ auto alignedUmin = align(umin, invalid);
+ auto alignedUmax = align(umax, invalid);
+ if (!invalid)
+ setResultRange(getResult(),
+ ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax));
}
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
index 022beaa..1049f39 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
@@ -216,12 +216,16 @@
return failure();
uint64_t rhsValue = rhsConstant.getZExtValue();
- if (rhsValue > 0 && lhsDiv.udiv() > 0 && lhsDiv.udiv() % rhsValue != 0)
- return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");
+ if (rhsValue > 0 && lhsDiv.udiv() > 0) {
+ if (lhsDiv.udiv() % rhsValue != 0)
+ return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- op, rewriter.getZeroAttr(op.getResult().getType()));
- return success();
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, rewriter.getZeroAttr(op.getResult().getType()));
+ return success();
+ }
+
+ return failure();
}
DataFlowSolver &solver;
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
index 41b304a..1924f42 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
@@ -462,3 +462,34 @@
// CHECK: util.return %[[ALIGN]], %[[ZERO]], %[[REM128]], %[[TRUE]], %[[FALSE]]
util.return %2, %rem64, %rem128, %in_bounds, %out_bounds : index, index, index, i1, i1
}
+
+// -----
+// An unbounded lhs of util.align technically has a range that extends to the
+// max value of the bit width. Aligning that upper bound overflows (to zero).
+// If not caught, this will most likely lead the optimizer to conclude that
+// the aligned result is a constant zero. Range inference checks for overflow
+// generally, so the util.align must survive unfolded here.
+// CHECK-LABEL: @util_align_overflow
+util.func @util_align_overflow(%arg0 : i64) -> i64 {
+ %c64 = arith.constant 64 : i64
+ // CHECK: util.align
+ %0 = util.align %arg0, %c64 : i64
+ util.return %0 : i64
+}
+
+// -----
+// Aligning to an alignment of zero doesn't make a lot of sense, but it isn't
+// numerically an error. We don't fold or optimize this case; this test checks
+// that (and that no division-by-zero errors come up along the way).
+// CHECK-LABEL: @util_align_zero
+util.func @util_align_zero(%arg0 : i64) -> i64 {
+ %c0 = arith.constant 0 : i64
+ %c16 = arith.constant 16 : i64
+ %assume = util.assume.int %arg0<umin=0, umax=15> : i64
+ %c128 = arith.constant 128 : i64
+ // CHECK: util.align
+ // CHECK: arith.remui
+ %0 = util.align %assume, %c0 : i64
+ %rem16 = arith.remui %0, %c16 : i64
+ util.return %rem16 : i64
+}