Fixes a range inference overflow with util.align. (#18808)

* In the present state we were folding a non-analyzable util.align lhs
to a constant zero because the next power of two of the maximal range is
zero.
* Detects overflow and will not infer a range.
* Fixes some issues with a RHS of zero that were discovered when writing
tests for this case (which isn't really valid but was asserting the
compiler).

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index 4e13877..7153080 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -505,7 +505,7 @@
   APInt staticAlignment;
   bool hasStaticAlignment =
       matchPattern(alignment, m_ConstantInt(&staticAlignment));
-  if (hasStaticValue && hasStaticAlignment) {
+  if (hasStaticValue && hasStaticAlignment && !staticAlignment.isZero()) {
     // If this value is itself a multiple of the alignment then we can fold.
     if (staticValue.urem(staticAlignment).isZero()) {
       return true; // value % alignment == 0
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 1a009f1..01d0e42 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1110,7 +1110,7 @@
   auto constantAlignment = argRanges[1].getConstantValue();
   // Note that for non constant alignment, there may still be something we
   // want to infer, but this is left for the future.
-  if (constantAlignment) {
+  if (constantAlignment && !constantAlignment->isZero()) {
     // We can align the range directly.
     // (value + (alignment - 1)) & ~(alignment - 1)
     // https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
@@ -1119,11 +1119,19 @@
     APInt one(constantAlignment->getBitWidth(), 1);
     APInt alignmentM1 = *constantAlignment - one;
     APInt alignmentM1Inv = ~alignmentM1;
-    auto align = [&](APInt value) -> APInt {
-      return (value + alignmentM1) & alignmentM1Inv;
+    auto align = [&](APInt value, bool &invalid) -> APInt {
+      APInt aligned = (value + alignmentM1) & alignmentM1Inv;
+      // Detect overflow, which commonly happens at max range.
+      if (aligned.ult(value))
+        invalid = true;
+      return aligned;
     };
-    setResultRange(getResult(),
-                   ConstantIntRanges::fromUnsigned(align(umin), align(umax)));
+    bool invalid = false;
+    auto alignedUmin = align(umin, invalid);
+    auto alignedUmax = align(umax, invalid);
+    if (!invalid)
+      setResultRange(getResult(),
+                     ConstantIntRanges::fromUnsigned(alignedUmin, alignedUmax));
   }
 }
 
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
index 022beaa..1049f39 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
@@ -216,12 +216,16 @@
       return failure();
 
     uint64_t rhsValue = rhsConstant.getZExtValue();
-    if (rhsValue > 0 && lhsDiv.udiv() > 0 && lhsDiv.udiv() % rhsValue != 0)
-      return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");
+    if (rhsValue > 0 && lhsDiv.udiv() > 0) {
+      if (lhsDiv.udiv() % rhsValue != 0)
+        return rewriter.notifyMatchFailure(op, "rhs does not divide lhs");
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-        op, rewriter.getZeroAttr(op.getResult().getType()));
-    return success();
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, rewriter.getZeroAttr(op.getResult().getType()));
+      return success();
+    }
+
+    return failure();
   }
 
   DataFlowSolver &solver;
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
index 41b304a..1924f42 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
@@ -462,3 +462,34 @@
   // CHECK: util.return %[[ALIGN]], %[[ZERO]], %[[REM128]], %[[TRUE]], %[[FALSE]]
   util.return %2, %rem64, %rem128, %in_bounds, %out_bounds : index, index, index, i1, i1
 }
+
+// -----
+// Unbounded lhs of util.align technically has a range that extends to the max
+// value of the bit width. Attempting to align this overflows (to zero). If not
+// caught, this will most likely lead the optimizer to conclude that the
+// aligned result is a constant zero. This code is verified by checking for
+// overflow generally and should handle this case.
+// CHECK-LABEL: @util_align_overflow
+util.func @util_align_overflow(%arg0 : i64) -> i64 {
+  %c64 = arith.constant 64 : i64
+  // CHECK: util.align
+  %0 = util.align %arg0, %c64 : i64
+  util.return %0 : i64
+}
+
+// -----
+// Aligning to an alignment of zero doesn't make a lot of sense but it isn't
+// numerically an error. We don't fold or optimize this case and we verify
+// it as such (and that other division by zero errors don't come up).
+// CHECK-LABEL: @util_align_zero
+util.func @util_align_zero(%arg0 : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %c16 = arith.constant 16 : i64
+  %assume = util.assume.int %arg0<umin=0, umax=15> : i64
+  %c128 = arith.constant 128 : i64
+  // CHECK: util.align
+  // CHECK: arith.remui
+  %0 = util.align %assume, %c0 : i64
+  %rem16 = arith.remui %0, %c16 : i64
+  util.return %rem16 : i64
+}