[Util] Implement InferIntDivisibilityOpInterface for affine ops (#22723)
This PR implements the InferIntDivisibilityOpInterface for affine.apply,
affine.min, and affine.max operations. Affine apply gets the
divisibility of its result expression, and affine.min/max gets the GCD
of all result expression divisibilities. The implementation supports the
following divisibilities and any compositions of them:
- Multiplication: product of operand divisibilities
- Addition/Subtraction: GCD (union) of operand divisibilities
- Division (floor/ceil): quotient when evenly divisible, else 1
- Modulo: falls back to minimum divisibility (1,1)
This PR also adds the TestIntegerDivisibilityAnalysis pass to more
directly test divisibility analysis without relying on IR optimizations.
The pass probes values consumed by
`"iree_unregistered.test_int_divisibility"` ops and annotates them with
computed divisibility attributes.
There is also a small change to the arith.divui divisibility implementation:
it now falls back to the minimum divisibility when the division is not exact,
because the divisibility cannot be inferred when there is a remainder.
---------
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index f8bedfc..997dfe0 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -43,6 +43,7 @@
"StripDebugOps.cpp",
"TestConversion.cpp",
"TestFloatRangeAnalysis.cpp",
+ "TestIntegerDivisibilityAnalysis.cpp",
"VerifyInitializationOrder.cpp",
"VerifyStructuredControlFlow.cpp",
],
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index f73fa2a..e64d47b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -41,6 +41,7 @@
"StripDebugOps.cpp"
"TestConversion.cpp"
"TestFloatRangeAnalysis.cpp"
+ "TestIntegerDivisibilityAnalysis.cpp"
"VerifyInitializationOrder.cpp"
"VerifyStructuredControlFlow.cpp"
DEPS
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
index 5445337..188c545 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.h
@@ -73,6 +73,7 @@
#define GEN_PASS_DECL_STRIPDEBUGOPSPASS
#define GEN_PASS_DECL_TESTCONVERSIONPASS
#define GEN_PASS_DECL_TESTFLOATRANGEANALYSISPASS
+#define GEN_PASS_DECL_TESTINTEGERDIVISIBILITYANALYSISPASS
#define GEN_PASS_DECL_VERIFYINITIALIZATIONORDERPASS
#define GEN_PASS_DECL_VERIFYSTRUCTUREDCONTROLFLOWPASS
#include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc" // IWYU pragma: keep
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
index 7093bed..b3f46f7 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td
@@ -346,4 +346,14 @@
}];
}
+// Test-only pass, registered as `iree-util-test-integer-divisibility-analysis`
+// and not anchored to a specific op type (empty anchor string).
+def TestIntegerDivisibilityAnalysisPass :
+ Pass<"iree-util-test-integer-divisibility-analysis", ""> {
+ let summary = "Tests integer divisibility analysis.";
+ let description = [{
+ Tests integer divisibility analysis by evaluating any
+ 'iree_unregistered.test_int_divisibility' op and setting the results on an
+ attribute.
+ }];
+}
+
#endif // IREE_DIALECT_UTIL_PASSES
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
new file mode 100644
index 0000000..21954d4
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/TestIntegerDivisibilityAnalysis.cpp
@@ -0,0 +1,68 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Util/Analysis/IntegerDivisibilityAnalysis.h"
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+
+namespace mlir::iree_compiler::IREE::Util {
+
+#define GEN_PASS_DEF_TESTINTEGERDIVISIBILITYANALYSISPASS
+#include "iree/compiler/Dialect/Util/Transforms/Passes.h.inc"
+
+namespace {
+
+class TestIntegerDivisibilityAnalysisPass
+    : public impl::TestIntegerDivisibilityAnalysisPassBase<
+          TestIntegerDivisibilityAnalysisPass> {
+public:
+  void runOnOperation() override {
+    Operation *root = getOperation();
+
+    // Probe ops are `iree_unregistered.test_int_divisibility` ops with exactly
+    // one operand; that operand is the value whose divisibility is reported.
+    SmallVector<std::pair<Operation *, Value>> probes;
+    root->walk([&](Operation *candidate) {
+      bool isProbe = candidate->getNumOperands() == 1 &&
+                     candidate->getName().getStringRef() ==
+                         "iree_unregistered.test_int_divisibility";
+      if (isProbe) {
+        probes.emplace_back(candidate, candidate->getOperand(0));
+      }
+    });
+
+    // Run the divisibility analysis over the whole root op. DeadCodeAnalysis
+    // is loaded alongside it so the solver can traverse control flow.
+    DataFlowSolver solver;
+    solver.load<dataflow::DeadCodeAnalysis>();
+    solver.load<IntegerDivisibilityAnalysis>();
+    if (failed(solver.initializeAndRun(root))) {
+      signalPassFailure();
+      return;
+    }
+
+    // Annotate each probe with "udiv = X, sdiv = Y", or "uninitialized" when
+    // the solver has no initialized lattice state for the queried value.
+    MLIRContext *ctx = &getContext();
+    for (auto &[probe, queried] : probes) {
+      std::string text;
+      auto *state = solver.lookupState<IntegerDivisibilityLattice>(queried);
+      if (!state || state->getValue().isUninitialized()) {
+        text = "uninitialized";
+      } else {
+        const auto &div = state->getValue().getValue();
+        llvm::raw_string_ostream os(text);
+        os << "udiv = " << div.udiv() << ", sdiv = " << div.sdiv();
+      }
+      probe->setAttr("divisibility", StringAttr::get(ctx, text));
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::IREE::Util
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
index d3fe868..64bb4a4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/BUILD.bazel
@@ -41,6 +41,7 @@
"strip_debug_ops.mlir",
"test_float_range_analysis.mlir",
"test_float_range_analysis_linalg.mlir",
+ "test_integer_divisibility_analysis.mlir",
"verify_initialization_order.mlir",
"verify_structured_control_flow.mlir",
],
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
index abca385..658f9a9 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/CMakeLists.txt
@@ -39,6 +39,7 @@
"strip_debug_ops.mlir"
"test_float_range_analysis.mlir"
"test_float_range_analysis_linalg.mlir"
+ "test_integer_divisibility_analysis.mlir"
"verify_initialization_order.mlir"
"verify_structured_control_flow.mlir"
TOOLS
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
new file mode 100644
index 0000000..998b6f9
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/test_integer_divisibility_analysis.mlir
@@ -0,0 +1,188 @@
+// RUN: iree-opt --split-input-file --iree-util-test-integer-divisibility-analysis --allow-unregistered-dialect %s | FileCheck %s
+
+// CHECK-LABEL: @affine_apply_mul_divisibility
+util.func @affine_apply_mul_divisibility(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 8> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 * 4)>(%0)
+ // CHECK: divisibility = "udiv = 32, sdiv = 32"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mul_negative
+util.func @affine_apply_mul_negative(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 8> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 * -4)>(%0)
+ // CHECK: divisibility = "udiv = 32, sdiv = 32"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_add_gcd
+util.func @affine_apply_add_gcd(%arg0 : index, %arg1 : index) -> index {
+ %0:2 = util.assume.int %arg0<udiv = 16>,
+ %arg1<udiv = 24> : index, index
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%0#0, %0#1)
+ // CHECK: divisibility = "udiv = 8, sdiv = 8"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_floordiv_exact
+util.func @affine_apply_floordiv_exact(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 64> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%0)
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_ceildiv_exact
+util.func @affine_apply_ceildiv_exact(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 64> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 ceildiv 4)>(%0)
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_floordiv_non_exact
+util.func @affine_apply_floordiv_non_exact(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 20> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 floordiv 3)>(%0)
+ // CHECK: divisibility = "udiv = 1, sdiv = 1"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_mod
+util.func @affine_apply_mod(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 16> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 mod 16)>(%0)
+ // CHECK: divisibility = "udiv = 1, sdiv = 1"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_composition
+util.func @affine_apply_composition(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 8> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 * 4 + 16)>(%0)
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_with_symbol
+util.func @affine_apply_with_symbol(%arg0 : index, %arg1 : index) -> index {
+ %0:2 = util.assume.int %arg0<udiv = 16>,
+ %arg1<udiv = 16> : index, index
+ %1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%0#0)[%0#1]
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_min_uniform_divisibility
+util.func @affine_min_uniform_divisibility(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 16> : index
+ %1 = affine.min affine_map<(d0) -> (d0, d0 + 64)>(%0)
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_min_different_divisibilities
+util.func @affine_min_different_divisibilities(%arg0 : index, %arg1 : index) -> index {
+ %0:2 = util.assume.int %arg0<udiv = 16>,
+ %arg1<udiv = 24> : index, index
+ %1 = affine.min affine_map<(d0, d1) -> (d0, d1)>(%0#0, %0#1)
+ // CHECK: divisibility = "udiv = 8, sdiv = 8"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max_uniform_divisibility
+util.func @affine_max_uniform_divisibility(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 32> : index
+ %1 = affine.max affine_map<(d0) -> (d0, d0 - 64)>(%0)
+ // CHECK: divisibility = "udiv = 32, sdiv = 32"
+ %2 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ util.return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max_different_divisibilities
+util.func @affine_max_different_divisibilities(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
+ %0:3 = util.assume.int %arg0<udiv = 12>,
+ %arg1<udiv = 24>,
+ %arg2<udiv = 18> : index, index, index
+ %3 = affine.max affine_map<(d0, d1, d2) -> (d0, d1, d2)>(%0#0, %0#1, %0#2)
+ // CHECK: divisibility = "udiv = 6, sdiv = 6"
+ %4 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index
+ util.return %4 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_constant
+util.func @affine_apply_constant() -> index {
+ %0 = affine.apply affine_map<() -> (64)>()
+ // CHECK: divisibility = "udiv = 64, sdiv = 64"
+ %1 = "iree_unregistered.test_int_divisibility"(%0) : (index) -> index
+ util.return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @affine_apply_chained_operations
+util.func @affine_apply_chained_operations(%arg0 : index) -> index {
+ %0 = util.assume.int %arg0<udiv = 4> : index
+ %1 = affine.apply affine_map<(d0) -> (d0 * 8)>(%0)
+ %2 = affine.apply affine_map<(d0) -> (d0 + 16)>(%1)
+ // CHECK: divisibility = "udiv = 16, sdiv = 16"
+ %3 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index
+ util.return %3 : index
+}
+
+// -----
+
+// CHECK-LABEL: @complex_chained_affine_ops
+util.func @complex_chained_affine_ops(%arg0 : index, %arg1 : index, %arg2 : index) -> index {
+ %0:3 = util.assume.int %arg0<udiv = 210>,
+ %arg1<udiv = 7>,
+ %arg2<udiv = 15> : index, index, index
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 + 2 * d1)>(%0#0, %0#1)
+ // CHECK: divisibility = "udiv = 14, sdiv = 14"
+ %div_1 = "iree_unregistered.test_int_divisibility"(%1) : (index) -> index
+ %2 = affine.max affine_map<(d0, d1) -> (d0 floordiv 6, d1 * 3)>(%0#0, %0#2)
+ // CHECK: divisibility = "udiv = 5, sdiv = 5"
+ %div_2 = "iree_unregistered.test_int_divisibility"(%2) : (index) -> index
+ %3 = affine.min affine_map<(d0)[s0] -> (2 * (s0 * d0 - 14) ceildiv 7, d0 floordiv 3 * 2)>(%2)[%1]
+ // CHECK: divisibility = "udiv = 2, sdiv = 2"
+ %div_3 = "iree_unregistered.test_int_divisibility"(%3) : (index) -> index
+ util.return %div_3 : index
+}
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
index 4bb1330..61eafdf 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
+++ b/compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel
@@ -42,6 +42,7 @@
"//compiler/src/iree/compiler/Dialect/TensorExt/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationInterfaces",
"@llvm-project//mlir:ControlFlowInterfaces",
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
index a183c55..efb4db2 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
+++ b/compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt
@@ -31,6 +31,7 @@
"UtilExternalModels.cpp"
DEPS
LLVMSupport
+ MLIRAffineDialect
MLIRArithDialect
MLIRControlFlowInterfaces
MLIRGPUDialect
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index 75d442b..78e79aa 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -16,11 +16,13 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -49,6 +51,232 @@
return IREE::Util::ConstantIntDivisibility(1, 1);
}
+/// Visits affine expressions and recursively computes the divisibility of each
+/// subexpression. `divisibilityMap` is a caller-provided seed: an expression
+/// present in the map (e.g. a dim or symbol) takes its divisibility from the
+/// map instead of being analyzed. The visitor only reads the map; computed
+/// divisibilities are returned from `visit` and are not written back to it.
+class AffineExprDivisibilityFinder
+    : public AffineExprVisitor<AffineExprDivisibilityFinder,
+                               IREE::Util::ConstantIntDivisibility> {
+public:
+  using ExprDivisibilityMap =
+      llvm::DenseMap<AffineExpr, IREE::Util::ConstantIntDivisibility>;
+  AffineExprDivisibilityFinder(ExprDivisibilityMap &divisibilityMap)
+      : divisibilityMap(divisibilityMap) {}
+
+  /// A constant is divisible by its own absolute value.
+  IREE::Util::ConstantIntDivisibility
+  visitConstantExpr(AffineConstantExpr expr) {
+    uint64_t constValue = getMagnitude(expr.getValue());
+    return IREE::Util::ConstantIntDivisibility(constValue, constValue);
+  }
+
+  /// Dim expressions cannot be analyzed further: use the seeded divisibility
+  /// if the caller provided one, otherwise the minimum divisibility.
+  IREE::Util::ConstantIntDivisibility visitDimExpr(AffineDimExpr expr) {
+    return getSeededOrMinDivisibility(expr);
+  }
+
+  /// Symbol expressions cannot be analyzed further: use the seeded
+  /// divisibility if the caller provided one, otherwise the minimum
+  /// divisibility.
+  IREE::Util::ConstantIntDivisibility visitSymbolExpr(AffineSymbolExpr expr) {
+    return getSeededOrMinDivisibility(expr);
+  }
+
+  /// Infer the divisibility of an addition or subtraction expression by
+  /// recursively visiting the LHS and RHS, and then unioning the results.
+  IREE::Util::ConstantIntDivisibility visitAddExpr(AffineBinaryOpExpr expr) {
+    if (const IREE::Util::ConstantIntDivisibility *seeded = findSeeded(expr)) {
+      return *seeded;
+    }
+    // The divisibility of an addition is the GCD of its constituents'
+    // divisibilities.
+    IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+    IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
+    return lhsDiv.getUnion(rhsDiv);
+  }
+
+  /// Infer the divisibility of a multiplication expression by recursively
+  /// visiting the LHS and RHS, and then multiplying the results.
+  IREE::Util::ConstantIntDivisibility visitMulExpr(AffineBinaryOpExpr expr) {
+    if (const IREE::Util::ConstantIntDivisibility *seeded = findSeeded(expr)) {
+      return *seeded;
+    }
+    // The divisibility of a multiplication is the product of its constituents'
+    // divisibilities (guarded against 64-bit wraparound).
+    IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+    IREE::Util::ConstantIntDivisibility rhsDiv = visit(expr.getRHS());
+    return IREE::Util::ConstantIntDivisibility(
+        multiplyDivisors(lhsDiv.udiv(), rhsDiv.udiv()),
+        multiplyDivisors(lhsDiv.sdiv(), rhsDiv.sdiv()));
+  }
+
+  IREE::Util::ConstantIntDivisibility
+  visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return visitDivExpr(expr);
+  }
+
+  IREE::Util::ConstantIntDivisibility
+  visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return visitDivExpr(expr);
+  }
+
+  /// Mod expressions could be inferred to be zero in some cases, but for now
+  /// just return the minimum divisibility.
+  /// TODO(Max191): Handle evenly divisible cases, and ensure that the zero
+  /// divisibility propagates properly through parent expressions.
+  IREE::Util::ConstantIntDivisibility visitModExpr(AffineBinaryOpExpr expr) {
+    return visitInvalidExpr(expr);
+  }
+
+private:
+  /// Returns |value| as an unsigned integer. The negation is done in unsigned
+  /// arithmetic so that INT64_MIN is handled without signed overflow (which
+  /// `std::abs` would hit).
+  static uint64_t getMagnitude(int64_t value) {
+    uint64_t bits = static_cast<uint64_t>(value);
+    return value < 0 ? 0 - bits : bits;
+  }
+
+  /// Returns `lhs * rhs`, or the larger of the two factors if the product does
+  /// not fit in 64 bits. A product of two values is a multiple of each value's
+  /// divisor, so a single factor is still a valid (weaker) divisibility,
+  /// whereas a wrapped product would not be.
+  static uint64_t multiplyDivisors(uint64_t lhs, uint64_t rhs) {
+    uint64_t product = lhs * rhs;
+    if (lhs != 0 && product / lhs != rhs) {
+      return lhs > rhs ? lhs : rhs;
+    }
+    return product;
+  }
+
+  /// Returns the seeded divisibility of `expr`, or nullptr if the caller did
+  /// not provide one. Performs a single map lookup and never inserts.
+  const IREE::Util::ConstantIntDivisibility *findSeeded(AffineExpr expr) const {
+    auto it = divisibilityMap.find(expr);
+    if (it == divisibilityMap.end()) {
+      return nullptr;
+    }
+    return &it->second;
+  }
+
+  /// Returns the seeded divisibility of `expr` if present, otherwise the
+  /// minimum divisibility.
+  IREE::Util::ConstantIntDivisibility
+  getSeededOrMinDivisibility(AffineExpr expr) const {
+    if (const IREE::Util::ConstantIntDivisibility *seeded = findSeeded(expr)) {
+      return *seeded;
+    }
+    return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue();
+  }
+
+  IREE::Util::ConstantIntDivisibility
+  visitInvalidExpr(AffineBinaryOpExpr expr) {
+    return IREE::Util::IntegerDivisibility::getMinDivisibility().getValue();
+  }
+
+  /// Helper shared by ceildiv and floordiv implementations. Returns the minimum
+  /// divisibility as a fallback if the divisor is not a constant, because the
+  /// divisibility cannot be inferred in this case. If the divisor is a
+  /// constant, then this function recursively visits the dividend, and returns
+  /// the quotient of the dividend's divisibility with the divisor when that
+  /// quotient is exact, or 1 otherwise.
+  IREE::Util::ConstantIntDivisibility visitDivExpr(AffineBinaryOpExpr expr) {
+    if (const IREE::Util::ConstantIntDivisibility *seeded = findSeeded(expr)) {
+      return *seeded;
+    }
+    auto constRhs = dyn_cast<AffineConstantExpr>(expr.getRHS());
+    // Division by zero is undefined, so return the minimum divisibility.
+    if (!constRhs || constRhs.getValue() == 0) {
+      return IREE::Util::ConstantIntDivisibility(1, 1);
+    }
+    uint64_t constValue = getMagnitude(constRhs.getValue());
+    IREE::Util::ConstantIntDivisibility lhsDiv = visit(expr.getLHS());
+    uint64_t divUDiv =
+        lhsDiv.udiv() % constValue == 0 ? lhsDiv.udiv() / constValue : 1;
+    uint64_t divSDiv =
+        lhsDiv.sdiv() % constValue == 0 ? lhsDiv.sdiv() / constValue : 1;
+    return IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv);
+  }
+
+  ExprDivisibilityMap &divisibilityMap;
+};
+
+/// Computes one divisibility per result expression of `map`.
+/// `dimAndSymbolDivisibilities` supplies the divisibility of every map input:
+/// first all dims in position order, then all symbols in position order.
+static SmallVector<IREE::Util::ConstantIntDivisibility> getResultDivisibilities(
+    AffineMap map,
+    ArrayRef<IREE::Util::ConstantIntDivisibility> dimAndSymbolDivisibilities) {
+  MLIRContext *ctx = map.getContext();
+  // List the map inputs in the same order as `dimAndSymbolDivisibilities`.
+  SmallVector<AffineExpr> mapInputs;
+  for (int64_t dim : llvm::seq<int64_t>(map.getNumDims())) {
+    mapInputs.push_back(getAffineDimExpr(dim, ctx));
+  }
+  for (int64_t sym : llvm::seq<int64_t>(map.getNumSymbols())) {
+    mapInputs.push_back(getAffineSymbolExpr(sym, ctx));
+  }
+
+  // Seed the visitor with the known divisibility of each input.
+  AffineExprDivisibilityFinder::ExprDivisibilityMap seeds;
+  for (auto [input, divisibility] :
+       llvm::zip_equal(mapInputs, dimAndSymbolDivisibilities)) {
+    seeds[input] = divisibility;
+  }
+  AffineExprDivisibilityFinder finder(seeds);
+
+  // Evaluate every result expression, in map order.
+  return llvm::map_to_vector(map.getResults(), [&](AffineExpr result) {
+    return finder.visit(result);
+  });
+}
+
+/// External model providing divisibility inference for affine.apply: operand
+/// divisibilities are pushed through the op's affine map.
+struct AffineApplyInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          AffineApplyInferIntDivisibilityOpInterface, affine::AffineApplyOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto applyOp = cast<affine::AffineApplyOp>(op);
+
+    // Resolve each operand to a concrete divisibility.
+    SmallVector<IREE::Util::ConstantIntDivisibility> inputDivs;
+    for (auto [value, argDiv] : llvm::zip(applyOp.getOperands(), argDivs)) {
+      inputDivs.push_back(getDivisibilityOfOperand(value, argDiv));
+    }
+
+    // Report the divisibility of each map result on the matching op result.
+    SmallVector<IREE::Util::ConstantIntDivisibility> perResultDivs =
+        getResultDivisibilities(applyOp.getMap(), inputDivs);
+    for (auto [result, resultDiv] :
+         llvm::zip_equal(applyOp->getResults(), perResultDivs)) {
+      setResultDivs(result, resultDiv);
+    }
+  }
+};
+
+/// Shared divisibility inference for affine.min and affine.max. Because the op
+/// may yield any one of its map results, the divisibility reported for the op
+/// result is the union (GCD) of the divisibilities of all map results.
+template <typename MinOrMaxTy>
+static void inferAffineMinOrMaxResultDivisibility(
+    MinOrMaxTy minOrMaxOp, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+    IREE::Util::SetIntDivisibilityFn setResultDivs) {
+  static_assert(
+      llvm::is_one_of<MinOrMaxTy, affine::AffineMinOp,
+                      affine::AffineMaxOp>::value,
+      "MinOrMaxTy must be affine::AffineMinOp or affine::AffineMaxOp");
+
+  // Resolve each operand to a concrete divisibility.
+  SmallVector<IREE::Util::ConstantIntDivisibility> inputDivs;
+  for (auto [value, argDiv] : llvm::zip(minOrMaxOp.getOperands(), argDivs)) {
+    inputDivs.push_back(getDivisibilityOfOperand(value, argDiv));
+  }
+
+  SmallVector<IREE::Util::ConstantIntDivisibility> perResultDivs =
+      getResultDivisibilities(minOrMaxOp.getMap(), inputDivs);
+
+  // Start from the last map result and fold the remaining ones into it.
+  IREE::Util::ConstantIntDivisibility combined = perResultDivs.back();
+  for (const IREE::Util::ConstantIntDivisibility &candidate :
+       ArrayRef<IREE::Util::ConstantIntDivisibility>(perResultDivs)
+           .drop_back()) {
+    combined = combined.getUnion(candidate);
+  }
+  setResultDivs(minOrMaxOp.getResult(), combined);
+}
+
+/// External model providing divisibility inference for affine.min; forwards to
+/// the helper shared with affine.max.
+struct AffineMinInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          AffineMinInferIntDivisibilityOpInterface, affine::AffineMinOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    inferAffineMinOrMaxResultDivisibility(cast<affine::AffineMinOp>(op),
+                                          argDivs, setResultDivs);
+  }
+};
+
+/// External model providing divisibility inference for affine.max; forwards to
+/// the helper shared with affine.min.
+struct AffineMaxInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          AffineMaxInferIntDivisibilityOpInterface, affine::AffineMaxOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    inferAffineMinOrMaxResultDivisibility(cast<affine::AffineMaxOp>(op),
+                                          argDivs, setResultDivs);
+  }
+};
+
struct ArithConstantInferIntDivisibilityOpInterface
: public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> {
@@ -104,8 +332,13 @@
auto lhsDivisibility = getDivisibilityOfOperand(divOp.getLhs(), argDivs[0]);
- uint64_t divUDiv = lhsDivisibility.udiv() / intVal.getZExtValue();
- uint64_t divSDiv = lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue());
+ uint64_t divUDiv = lhsDivisibility.udiv() % intVal.getZExtValue() == 0
+ ? lhsDivisibility.udiv() / intVal.getZExtValue()
+ : 1;
+ uint64_t divSDiv =
+ lhsDivisibility.sdiv() % std::abs(intVal.getSExtValue()) == 0
+ ? lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue())
+ : 1;
setResultDivs(divOp, IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv));
}
@@ -906,6 +1139,7 @@
void registerUtilExternalModels(DialectRegistry ®istry) {
// Must ensure that any dependent dialects are registered.
+ registry.insert<affine::AffineDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<linalg::LinalgDialect>();
registry.insert<ml_program::MLProgramDialect>();
@@ -933,6 +1167,16 @@
});
registry.addExtension(
+ +[](MLIRContext *context, affine::AffineDialect *dialect) {
+ affine::AffineApplyOp::attachInterface<
+ AffineApplyInferIntDivisibilityOpInterface>(*context);
+ affine::AffineMinOp::attachInterface<
+ AffineMinInferIntDivisibilityOpInterface>(*context);
+ affine::AffineMaxOp::attachInterface<
+ AffineMaxInferIntDivisibilityOpInterface>(*context);
+ });
+
+ registry.addExtension(
+[](MLIRContext *context, tensor::TensorDialect *dialect) {
tensor::InsertSliceOp::attachInterface<InsertSliceOpTiedOpInterface>(
*context);