Add util.assume.int folder. (#18805)
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index fb466bf..4e13877 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -27,6 +27,115 @@
namespace mlir::iree_compiler::IREE::Util {
//===----------------------------------------------------------------------===//
+// util.assume.int
+//===----------------------------------------------------------------------===//
+
+LogicalResult AssumeIntOp::canonicalize(AssumeIntOp op,
+ PatternRewriter &rewriter) {
+ bool needsRewrite = false;
+ ArrayAttr assumptions = op.getAssumptions();
+
+ // We do a fast check for the canonical form here, making any in-place updates
+ // we can and signalling needsRewrite=true when the op needs to be updated
+ // to a new canonical form.
+ SmallPtrSet<Value, 4> seenOperands;
+ seenOperands.reserve(op.getNumOperands());
+ for (auto [idx, operand] : llvm::enumerate(op.getOperands())) {
+ // Match constant.
+ if (matchPattern(operand, m_Constant())) {
+ needsRewrite = true;
+ rewriter.replaceAllUsesWith(op.getResult(idx), operand);
+ continue;
+ }
+
+ // Check for a duplicate.
+ auto [foundIt, inserted] = seenOperands.insert(operand);
+ if (!inserted) {
+ // This should be the non-common path: find the original index number
+ // and rewrite.
+ for (auto [seenIdx, seenOperand] : llvm::enumerate(op.getOperands())) {
+ if (seenOperand == operand) {
+ needsRewrite = true;
+ rewriter.replaceAllUsesWith(op.getResult(idx), op.getResult(seenIdx));
+ break;
+ }
+ }
+ continue;
+ }
+
+ // Detect whether assumptions need to be normalized.
+ ArrayAttr assumptionRow = llvm::cast<ArrayAttr>(assumptions[idx]);
+ if (assumptionRow.size() > 1) {
+ bool allAssumptionsSame = true;
+ for (unsigned i = 1; i < assumptionRow.size(); ++i) {
+ if (assumptionRow[i] != assumptionRow[0]) {
+ allAssumptionsSame = false;
+ break;
+ }
+ }
+ if (allAssumptionsSame) {
+ needsRewrite = true;
+ }
+ }
+ }
+ if (!needsRewrite)
+ return failure();
+
+ // Need to rewrite the assumption.
+ auto normalizeAssumptions = [](Attribute row, bool &madeChange) {
+ auto rowArray = llvm::cast<ArrayAttr>(row);
+ if (rowArray.size() <= 1)
+ return rowArray;
+
+ bool allSame = true;
+ for (unsigned i = 1; i < rowArray.size(); ++i) {
+ if (rowArray[0] != rowArray[i]) {
+ allSame = false;
+ break;
+ }
+ }
+
+ if (!allSame)
+ return rowArray;
+
+ // All entries are the same: compress down to a single column.
+ madeChange = true;
+ return ArrayAttr::get(row.getContext(), {rowArray[0]});
+ };
+ SmallVector<ArrayAttr> newAssumptions;
+ SmallVector<Value> newOperands;
+ SmallVector<Value> retainedResults;
+ bool madeChange = false;
+ for (auto [idx, operand] : llvm::enumerate(op.getOperands())) {
+ // If the result has no uses, do not retain it.
+ if (op.getResult(idx).use_empty()) {
+ madeChange = true;
+ continue;
+ }
+
+ newAssumptions.push_back(
+ normalizeAssumptions(assumptions[idx], madeChange));
+ newOperands.push_back(operand);
+ retainedResults.push_back(op.getResult(idx));
+ }
+
+ // It is important to avoid canonicalizer looping that if we determined at
+ // the top that a rewrite was needed, that we actually made a change.
+ (void)madeChange;
+ assert(madeChange && "util.assume.int canonicalizer signaled a rewrite was "
+ "needed but it produced the same op");
+
+ if (!newOperands.empty()) {
+ auto newOp =
+ rewriter.create<AssumeIntOp>(op.getLoc(), newOperands, newAssumptions);
+ rewriter.replaceAllUsesWith(retainedResults, newOp.getResults());
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// util.null
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 3de051b..1a009f1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1264,6 +1264,9 @@
for (auto [index, operandAssumptionsAttr] :
llvm::enumerate(allOperandAssumptions)) {
auto operandAssumptions = cast<ArrayAttr>(operandAssumptionsAttr);
+ // We always allow a single row to broadcast to any requested size.
+ if (operandAssumptions.size() == 1)
+ continue;
if (rank && *rank != operandAssumptions.size())
return emitOpError() << "expected operand #" << index << " to have "
<< *rank << " assumptions but it has "
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index aaa10da..b1c17bd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -518,6 +518,7 @@
std::optional<uint64_t> getUnionedUnsignedDivisor(unsigned operandIndex);
}];
+ let hasCanonicalizeMethod = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
index 2df2bfd..a1c6040 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
@@ -20,6 +20,7 @@
"alignment_ops.mlir",
"assignment_folding.mlir",
"assignment_ops.mlir",
+ "assume_folding.mlir",
"assume_ops.mlir",
"attributes.mlir",
"buffer_folding.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
index b6ac5d8..2dad4d1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
@@ -18,6 +18,7 @@
"alignment_ops.mlir"
"assignment_folding.mlir"
"assignment_ops.mlir"
+ "assume_folding.mlir"
"assume_ops.mlir"
"attributes.mlir"
"buffer_folding.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir
new file mode 100644
index 0000000..ffc2aae
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir
@@ -0,0 +1,50 @@
+// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @already_canonical
+util.func public @already_canonical(%arg0 : index) -> index {
+ // CHECK: util.assume.int
+ %0 = util.assume.int %arg0<umin=0> : index
+ util.return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @elide_constant_assumption
+util.func public @elide_constant_assumption() -> index {
+ %cst = arith.constant 1 : index
+ %0 = util.assume.int %cst<umin=0> : index
+ // CHECK: %[[CST:.*]] = arith.constant 1 : index
+ // CHECK: util.return %[[CST]]
+ util.return %0 : index
+}
+
+// -----
+// CHECK-LABEL: @elide_multi_constant_assumption
+util.func public @elide_multi_constant_assumption(%arg0 : index, %arg1 : index) -> index, index, index {
+ %cst = arith.constant 1 : index
+ // CHECK: %[[CST:.*]] = arith.constant 1 : index
+ // CHECK: %[[ASSUME:.*]]:2 = util.assume.int
+ // CHECK-NEXT: %arg0<udiv = 2>,
+ // CHECK-NEXT: %arg1<udiv = 4>
+ // CHECK-NEXT: : index, index
+ %0:3 = util.assume.int %arg0<udiv=2>, %cst<umin=0>, %arg1<udiv=4> : index, index, index
+ // CHECK: util.return %[[ASSUME]]#0, %[[CST]], %[[ASSUME]]#1
+ util.return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+// CHECK-LABEL: @broadcast_duplicate_assumptions
+util.func public @broadcast_duplicate_assumptions(%arg0 : index) -> index {
+ // CHECK: util.assume.int %arg0<umin = 0>
+ %0 = util.assume.int %arg0[<umin=0>, <umin=0>] : index
+ util.return %0 : index
+}
+
+// -----
+// CHECK-LABEL: @dedup_duplicate_operands
+util.func public @dedup_duplicate_operands(%arg0 : index) -> index, index {
+ // CHECK: %[[ASSUME:.*]] = util.assume.int %arg0<umax = 2> : index
+ %0:2 = util.assume.int %arg0[<umax=2>, <umax=2>], %arg0<umin=0> : index, index
+ // CHECK: util.return %[[ASSUME]], %[[ASSUME]]
+ util.return %0#0, %0#1 : index, index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
index 2be8dc5..e9d1708 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
@@ -1,7 +1,15 @@
// RUN: iree-opt --split-input-file --verify-diagnostics %s
util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
- // expected-error @+1 {{expected operand #1 to have 1 assumptions but it has 2}}
+ // expected-error @+1 {{expected operand #1 to have 3 assumptions but it has 2}}
+ %0:2 = util.assume.int %arg0[<umin=0>, <umax=2>, <udiv=16>], %arg1[<umax=10>, <udiv=6>] : index, i64
+ util.return %0#0, %0#1 : index, i64
+}
+
+// -----
+
+util.func public @assume.int.multi_operand_broadcast(%arg0 : index, %arg1 : i64) -> index, i64 {
+ // It is legal to have a mismatched arity if 1.
%0:2 = util.assume.int %arg0[<umin=0>], %arg1[<umax=10>, <udiv=6>] : index, i64
util.return %0#0, %0#1 : index, i64
}