[Util] Support loop IVs in divisibility analysis (#22729)

Adds support for analyzing the induction variables of loop-like ops in
the `IntegerDivisibilityAnalysis`. This just uses the lower bound and
step divisibilities to compute the IV divisibilities based on the simple
`lb + i * step` expression.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
index 502818b..10901f8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
@@ -31,7 +31,9 @@
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ControlFlowInterfaces",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LoopLikeInterface",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:SCFDialect",
         "@llvm-project//mlir:Support",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt
index 04c4c80..87f0b6b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt
@@ -28,6 +28,7 @@
     MLIRAnalysis
     MLIRControlFlowInterfaces
     MLIRIR
+    MLIRLoopLikeInterface
     MLIRPass
     MLIRSCFDialect
     MLIRSupport
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp
index 2b17acd..96ca759 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp
@@ -8,6 +8,8 @@
 
 #include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
 #include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 
 #define DEBUG_TYPE "iree-util-int-divisibility-analysis"
 
@@ -67,4 +69,60 @@
   return success();
 }
 
+void IntegerDivisibilityAnalysis::visitNonControlFlowArguments(
+    Operation *op, const RegionSuccessor &successor,
+    ArrayRef<IntegerDivisibilityLattice *> argLattices, unsigned firstIndex) {
+  // Get the constant divisibility, or query the lattice for Values.
+  auto getDivFromOfr = [&](std::optional<OpFoldResult> ofr, Block *block,
+                           bool isUnsigned) -> uint64_t {
+    if (ofr.has_value()) {
+      if (auto constBound = getConstantIntValue(*ofr)) {
+        return constBound.value();
+      }
+      auto value = cast<Value>(ofr.value());
+      const IntegerDivisibilityLattice *lattice =
+          getLatticeElementFor(getProgramPointBefore(block), value);
+      if (lattice != nullptr && !lattice->getValue().isUninitialized()) {
+        return isUnsigned ? lattice->getValue().getValue().udiv()
+                          : lattice->getValue().getValue().sdiv();
+      }
+    }
+    return isUnsigned
+               ? IntegerDivisibility::getMinDivisibility().getValue().udiv()
+               : IntegerDivisibility::getMinDivisibility().getValue().sdiv();
+  };
+
+  // Infer bounds for loop arguments that have static bounds
+  if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+    std::optional<SmallVector<Value>> ivs = loop.getLoopInductionVars();
+    std::optional<SmallVector<OpFoldResult>> lbs = loop.getLoopLowerBounds();
+    std::optional<SmallVector<OpFoldResult>> steps = loop.getLoopSteps();
+    if (!ivs || !lbs || !steps) {
+      return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+          op, successor, argLattices, firstIndex);
+    }
+    for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) {
+      IntegerDivisibilityLattice *ivEntry = getLatticeElement(iv);
+      Block *block = iv.getParentBlock();
+      uint64_t stepUDiv = getDivFromOfr(step, block, /*unsigned=*/true);
+      uint64_t stepSDiv = getDivFromOfr(step, block, /*unsigned=*/false);
+      uint64_t lbUDiv = getDivFromOfr(lb, block, /*unsigned=*/true);
+      uint64_t lbSDiv = getDivFromOfr(lb, block, /*unsigned=*/false);
+      ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv);
+      ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv);
+
+      // Loop induction variables are computed as `lb + i * step`. The
+      // divisibility for `i * step` is just the divisibility of `step`, so
+      // the total divisibility is obtained by unioning the step divisibility
+      // with the lower bound divisibility, which takes the GCD of the two.
+      ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv);
+      propagateIfChanged(ivEntry, ivEntry->join(ivDiv));
+    }
+    return;
+  }
+
+  return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+      op, successor, argLattices, firstIndex);
+}
+
 } // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h
index 2a550f6..72d3292 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h
@@ -35,6 +35,15 @@
   visitOperation(Operation *op,
                  ArrayRef<const IntegerDivisibilityLattice *> operands,
                  ArrayRef<IntegerDivisibilityLattice *> results) override;
+
+  /// Visit block arguments or operation results of an operation with region
+  /// control-flow for which values are not defined by region control-flow. This
+  /// function tries to infer the divisibility of loop induction variables based
+  /// on known loop bounds and steps.
+  void visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<IntegerDivisibilityLattice *> argLattices,
+      unsigned firstIndex) override;
 };
 
 } // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
index 21954d4..7c8714d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
@@ -6,6 +6,7 @@
 
 #include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
 #include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 
@@ -41,6 +42,10 @@
     // control flow. We include it to make the divisibility analysis more
     // powerful.
     solver.load<dataflow::DeadCodeAnalysis>();
+    // SparseConstantPropagation is needed because DeadCodeAnalysis is too
+    // conservative. It allows the analysis to call visitNonControlFlowArguments
+    // and analyze arguments like loop induction variables.
+    solver.load<dataflow::SparseConstantPropagation>();
     solver.load<IntegerDivisibilityAnalysis>();
     if (failed(solver.initializeAndRun(rootOp))) {
       return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
index 998b6f9..068cb30 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
@@ -186,3 +186,93 @@
   %div_3 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index
   util.return %div_3 : index
 }
+
+// -----
+
+// CHECK-LABEL: @scf_for_constant_step_from_zero
+util.func @scf_for_constant_step_from_zero() {
+  %c0 = arith.constant 0 : index
+  %c100 = arith.constant 100 : index
+  %c8 = arith.constant 8 : index
+  scf.for %iv = %c0 to %c100 step %c8 {
+    // CHECK: divisibility = "udiv = 8, sdiv = 8"
+    %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+  }
+  util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_nontrivial_gcd
+util.func @scf_for_nontrivial_gcd() {
+  %c12 = arith.constant 12 : index
+  %c100 = arith.constant 100 : index
+  %c18 = arith.constant 18 : index
+  scf.for %iv = %c12 to %c100 step %c18 {
+    // CHECK: divisibility = "udiv = 6, sdiv = 6"
+    %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+  }
+  util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_dynamic_bounds_and_step
+util.func @scf_for_dynamic_bounds_and_step(%arg0 : index, %arg1 : index) {
+  %lb = util.assume.int %arg0<udiv = 16> : index
+  %step = util.assume.int %arg1<udiv = 24> : index
+  %c100 = arith.constant 100 : index
+  scf.for %iv = %lb to %c100 step %step {
+    // CHECK: divisibility = "udiv = 8, sdiv = 8"
+    %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+  }
+  util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_coprime_bounds_and_step
+util.func @scf_for_coprime_bounds_and_step() {
+  %c15 = arith.constant 15 : index
+  %c100 = arith.constant 100 : index
+  %c8 = arith.constant 8 : index
+  scf.for %iv = %c15 to %c100 step %c8 {
+    // CHECK: divisibility = "udiv = 1, sdiv = 1"
+    %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+  }
+  util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_forall_dynamic_bounds
+util.func @scf_forall_dynamic_bounds(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %lb = util.assume.int %arg0<udiv = 20> : index
+  %ub = util.assume.int %arg1<udiv = 1> : index
+  %step = util.assume.int %arg2<udiv = 15> : index
+  scf.forall (%iv) = (%lb) to (%ub) step (%step) {
+    // CHECK: divisibility = "udiv = 5, sdiv = 5"
+    %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+  }
+  util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_forall_multiple_ivs
+util.func @scf_forall_multiple_ivs(%arg0 : index, %arg1 : index, %arg2 : index,
+                                   %arg3 : index, %arg4 : index, %arg5 : index) {
+  %lbs:2 = util.assume.int %arg0<udiv = 20>,
+                           %arg3<udiv = 10> : index, index
+  %ubs:2 = util.assume.int %arg1<udiv = 1>,
+                           %arg4<udiv = 10> : index, index
+  %steps:2 = util.assume.int %arg2<udiv = 15>,
+                             %arg5<udiv = 3> : index, index
+  scf.forall (%iv0, %iv1) = (%lbs#0, %lbs#1) to (%ubs#0, %ubs#1) step (%steps#0, %steps#1) {
+    // CHECK: divisibility = "udiv = 5, sdiv = 5"
+    %0 = "iree_unregistered.test_int_divisibility"(%iv0) : (index) -> index
+    // CHECK: divisibility = "udiv = 1, sdiv = 1"
+    %1 = "iree_unregistered.test_int_divisibility"(%iv1) : (index) -> index
+  }
+  util.return
+}