[Util] Support loop IVs in divisibility analysis (#22729)
Adds support for analyzing the induction variables of loop-like ops in
the `IntegerDivisibilityAnalysis`. This just uses the lower bound and
step divisibilities to compute the IV divisibilities based on the simple
`lb + i * step` expression.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
index 502818b..10901f8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/BUILD.bazel
@@ -31,7 +31,9 @@
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ControlFlowInterfaces",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt
index 04c4c80..87f0b6b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/CMakeLists.txt
@@ -28,6 +28,7 @@
MLIRAnalysis
MLIRControlFlowInterfaces
MLIRIR
+ MLIRLoopLikeInterface
MLIRPass
MLIRSCFDialect
MLIRSupport
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp
index 2b17acd..96ca759 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.cpp
@@ -8,6 +8,8 @@
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#define DEBUG_TYPE "iree-util-int-divisibility-analysis"
@@ -67,4 +69,60 @@
return success();
}
+void IntegerDivisibilityAnalysis::visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerDivisibilityLattice *> argLattices, unsigned firstIndex) {
+ // Get the constant divisibility, or query the lattice for Values.
+ auto getDivFromOfr = [&](std::optional<OpFoldResult> ofr, Block *block,
+ bool isUnsigned) -> uint64_t {
+ if (ofr.has_value()) {
+ if (auto constBound = getConstantIntValue(*ofr)) {
+ return constBound.value();
+ }
+ auto value = cast<Value>(ofr.value());
+ const IntegerDivisibilityLattice *lattice =
+ getLatticeElementFor(getProgramPointBefore(block), value);
+ if (lattice != nullptr && !lattice->getValue().isUninitialized()) {
+ return isUnsigned ? lattice->getValue().getValue().udiv()
+ : lattice->getValue().getValue().sdiv();
+ }
+ }
+ return isUnsigned
+ ? IntegerDivisibility::getMinDivisibility().getValue().udiv()
+ : IntegerDivisibility::getMinDivisibility().getValue().sdiv();
+ };
+
+ // Infer bounds for loop arguments that have static bounds
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+ std::optional<SmallVector<Value>> ivs = loop.getLoopInductionVars();
+ std::optional<SmallVector<OpFoldResult>> lbs = loop.getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> steps = loop.getLoopSteps();
+ if (!ivs || !lbs || !steps) {
+ return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, argLattices, firstIndex);
+ }
+ for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) {
+ IntegerDivisibilityLattice *ivEntry = getLatticeElement(iv);
+ Block *block = iv.getParentBlock();
+ uint64_t stepUDiv = getDivFromOfr(step, block, /*unsigned=*/true);
+ uint64_t stepSDiv = getDivFromOfr(step, block, /*unsigned=*/false);
+ uint64_t lbUDiv = getDivFromOfr(lb, block, /*unsigned=*/true);
+ uint64_t lbSDiv = getDivFromOfr(lb, block, /*unsigned=*/false);
+ ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv);
+ ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv);
+
+ // Loop induction variables are computed as `lb + i * step`. The
+ // divisibility for `i * step` is just the divisibility of `step`, so
+ // the total divisibility is obtained by unioning the step divisibility
+ // with the lower bound divisibility, which takes the GCD of the two.
+ ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv);
+ propagateIfChanged(ivEntry, ivEntry->join(ivDiv));
+ }
+ return;
+ }
+
+ return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, argLattices, firstIndex);
+}
+
} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h
index 2a550f6..72d3292 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h
@@ -35,6 +35,15 @@
visitOperation(Operation *op,
ArrayRef<const IntegerDivisibilityLattice *> operands,
ArrayRef<IntegerDivisibilityLattice *> results) override;
+
+ /// Visit block arguments or operation results of an operation with region
+ /// control-flow for which values are not defined by region control-flow. This
+ /// function tries to infer the divisibility of loop induction variables based
+ /// on known loop bounds and steps.
+ void visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerDivisibilityLattice *> argLattices,
+ unsigned firstIndex) override;
};
} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
index 21954d4..7c8714d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
@@ -6,6 +6,7 @@
#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
@@ -41,6 +42,10 @@
// control flow. We include it to make the divisibility analysis more
// powerful.
solver.load<dataflow::DeadCodeAnalysis>();
+ // SparseConstantPropagation is needed because DeadCodeAnalysis is too
+ // conservative. It allows the analysis to call visitNonControlFlowArguments
+ // and analyze arguments like loop induction variables.
+ solver.load<dataflow::SparseConstantPropagation>();
solver.load<IntegerDivisibilityAnalysis>();
if (failed(solver.initializeAndRun(rootOp))) {
return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
index 998b6f9..068cb30 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
@@ -186,3 +186,93 @@
%div_3 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index
util.return %div_3 : index
}
+
+// -----
+
+// CHECK-LABEL: @scf_for_constant_step_from_zero
+util.func @scf_for_constant_step_from_zero() {
+ %c0 = arith.constant 0 : index
+ %c100 = arith.constant 100 : index
+ %c8 = arith.constant 8 : index
+ scf.for %iv = %c0 to %c100 step %c8 {
+ // CHECK: divisibility = "udiv = 8, sdiv = 8"
+ %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+ }
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_nontrivial_gcd
+util.func @scf_for_nontrivial_gcd() {
+ %c12 = arith.constant 12 : index
+ %c100 = arith.constant 100 : index
+ %c18 = arith.constant 18 : index
+ scf.for %iv = %c12 to %c100 step %c18 {
+ // CHECK: divisibility = "udiv = 6, sdiv = 6"
+ %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+ }
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_dynamic_bounds_and_step
+util.func @scf_for_dynamic_bounds_and_step(%arg0 : index, %arg1 : index) {
+ %lb = util.assume.int %arg0<udiv = 16> : index
+ %step = util.assume.int %arg1<udiv = 24> : index
+ %c100 = arith.constant 100 : index
+ scf.for %iv = %lb to %c100 step %step {
+ // CHECK: divisibility = "udiv = 8, sdiv = 8"
+ %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+ }
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_for_coprime_bounds_and_step
+util.func @scf_for_coprime_bounds_and_step() {
+ %c15 = arith.constant 15 : index
+ %c100 = arith.constant 100 : index
+ %c8 = arith.constant 8 : index
+ scf.for %iv = %c15 to %c100 step %c8 {
+ // CHECK: divisibility = "udiv = 1, sdiv = 1"
+ %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+ }
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_forall_dynamic_bounds
+util.func @scf_forall_dynamic_bounds(%arg0 : index, %arg1 : index, %arg2 : index) {
+ %lb = util.assume.int %arg0<udiv = 20> : index
+ %ub = util.assume.int %arg1<udiv = 1> : index
+ %step = util.assume.int %arg2<udiv = 15> : index
+ scf.forall (%iv) = (%lb) to (%ub) step (%step) {
+ // CHECK: divisibility = "udiv = 5, sdiv = 5"
+ %0 = "iree_unregistered.test_int_divisibility"(%iv) : (index) -> index
+ }
+ util.return
+}
+
+// -----
+
+// CHECK-LABEL: @scf_forall_multiple_ivs
+util.func @scf_forall_multiple_ivs(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index, %arg5 : index) {
+ %lbs:2 = util.assume.int %arg0<udiv = 20>,
+ %arg3<udiv = 10> : index, index
+ %ubs:2 = util.assume.int %arg1<udiv = 1>,
+ %arg4<udiv = 10> : index, index
+ %steps:2 = util.assume.int %arg2<udiv = 15>,
+ %arg5<udiv = 3> : index, index
+ scf.forall (%iv0, %iv1) = (%lbs#0, %lbs#1) to (%ubs#0, %ubs#1) step (%steps#0, %steps#1) {
+ // CHECK: divisibility = "udiv = 5, sdiv = 5"
+ %0 = "iree_unregistered.test_int_divisibility"(%iv0) : (index) -> index
+ // CHECK: divisibility = "udiv = 1, sdiv = 1"
+ %1 = "iree_unregistered.test_int_divisibility"(%iv1) : (index) -> index
+ }
+ util.return
+}