Add util.assume.int folder. (#18805)

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
index fb466bf..4e13877 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOpFolders.cpp
@@ -27,6 +27,115 @@
 namespace mlir::iree_compiler::IREE::Util {
 
 //===----------------------------------------------------------------------===//
+// util.assume.int
+//===----------------------------------------------------------------------===//
+
+LogicalResult AssumeIntOp::canonicalize(AssumeIntOp op,
+                                        PatternRewriter &rewriter) {
+  bool needsRewrite = false;
+  ArrayAttr assumptions = op.getAssumptions();
+
+  // We do a fast check for the canonical form here, making any in-place updates
+  // we can and signalling needsRewrite=true when the op needs to be updated
+  // to a new canonical form.
+  SmallPtrSet<Value, 4> seenOperands;
+  seenOperands.reserve(op.getNumOperands());
+  for (auto [idx, operand] : llvm::enumerate(op.getOperands())) {
+    // Match constant.
+    if (matchPattern(operand, m_Constant())) {
+      needsRewrite = true;
+      rewriter.replaceAllUsesWith(op.getResult(idx), operand);
+      continue;
+    }
+
+    // Check for a duplicate.
+    auto [foundIt, inserted] = seenOperands.insert(operand);
+    if (!inserted) {
+      // This should be the non-common path: find the original index number
+      // and rewrite.
+      for (auto [seenIdx, seenOperand] : llvm::enumerate(op.getOperands())) {
+        if (seenOperand == operand) {
+          needsRewrite = true;
+          rewriter.replaceAllUsesWith(op.getResult(idx), op.getResult(seenIdx));
+          break;
+        }
+      }
+      continue;
+    }
+
+    // Detect whether assumptions need to be normalized.
+    ArrayAttr assumptionRow = llvm::cast<ArrayAttr>(assumptions[idx]);
+    if (assumptionRow.size() > 1) {
+      bool allAssumptionsSame = true;
+      for (unsigned i = 1; i < assumptionRow.size(); ++i) {
+        if (assumptionRow[i] != assumptionRow[0]) {
+          allAssumptionsSame = false;
+          break;
+        }
+      }
+      if (allAssumptionsSame) {
+        needsRewrite = true;
+      }
+    }
+  }
+  if (!needsRewrite)
+    return failure();
+
+  // Need to rewrite the assumption.
+  auto normalizeAssumptions = [](Attribute row, bool &madeChange) {
+    auto rowArray = llvm::cast<ArrayAttr>(row);
+    if (rowArray.size() <= 1)
+      return rowArray;
+
+    bool allSame = true;
+    for (unsigned i = 1; i < rowArray.size(); ++i) {
+      if (rowArray[0] != rowArray[i]) {
+        allSame = false;
+        break;
+      }
+    }
+
+    if (!allSame)
+      return rowArray;
+
+    // All entries are the same: compress down to a single column.
+    madeChange = true;
+    return ArrayAttr::get(row.getContext(), {rowArray[0]});
+  };
+  SmallVector<ArrayAttr> newAssumptions;
+  SmallVector<Value> newOperands;
+  SmallVector<Value> retainedResults;
+  bool madeChange = false;
+  for (auto [idx, operand] : llvm::enumerate(op.getOperands())) {
+    // If the result has no uses, do not retain it.
+    if (op.getResult(idx).use_empty()) {
+      madeChange = true;
+      continue;
+    }
+
+    newAssumptions.push_back(
+        normalizeAssumptions(assumptions[idx], madeChange));
+    newOperands.push_back(operand);
+    retainedResults.push_back(op.getResult(idx));
+  }
+
+  // It is important to avoid canonicalizer looping that if we determined at
+  // the top that a rewrite was needed, that we actually made a change.
+  (void)madeChange;
+  assert(madeChange && "util.assume.int canonicalizer signaled a rewrite was "
+                       "needed but it produced the same op");
+
+  if (!newOperands.empty()) {
+    auto newOp =
+        rewriter.create<AssumeIntOp>(op.getLoc(), newOperands, newAssumptions);
+    rewriter.replaceAllUsesWith(retainedResults, newOp.getResults());
+  }
+
+  rewriter.eraseOp(op);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // util.null
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 3de051b..1a009f1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1264,6 +1264,9 @@
   for (auto [index, operandAssumptionsAttr] :
        llvm::enumerate(allOperandAssumptions)) {
     auto operandAssumptions = cast<ArrayAttr>(operandAssumptionsAttr);
+    // We always allow a single row to broadcast to any requested size.
+    if (operandAssumptions.size() == 1)
+      continue;
     if (rank && *rank != operandAssumptions.size())
       return emitOpError() << "expected operand #" << index << " to have "
                            << *rank << " assumptions but it has "
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index aaa10da..b1c17bd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -518,6 +518,7 @@
     std::optional<uint64_t> getUnionedUnsignedDivisor(unsigned operandIndex);
   }];
 
+  let hasCanonicalizeMethod  = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
index 2df2bfd..a1c6040 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
@@ -20,6 +20,7 @@
             "alignment_ops.mlir",
             "assignment_folding.mlir",
             "assignment_ops.mlir",
+            "assume_folding.mlir",
             "assume_ops.mlir",
             "attributes.mlir",
             "buffer_folding.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
index b6ac5d8..2dad4d1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
@@ -18,6 +18,7 @@
     "alignment_ops.mlir"
     "assignment_folding.mlir"
     "assignment_ops.mlir"
+    "assume_folding.mlir"
     "assume_ops.mlir"
     "attributes.mlir"
     "buffer_folding.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir
new file mode 100644
index 0000000..ffc2aae
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_folding.mlir
@@ -0,0 +1,50 @@
+// RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @already_canonical
+util.func public @already_canonical(%arg0 : index) -> index  {
+  // CHECK: util.assume.int
+  %0 = util.assume.int %arg0<umin=0> : index
+  util.return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @elide_constant_assumption
+util.func public @elide_constant_assumption() -> index  {
+  %cst = arith.constant 1 : index
+  %0 = util.assume.int %cst<umin=0> : index
+  // CHECK: %[[CST:.*]] = arith.constant 1 : index
+  // CHECK: util.return %[[CST]]
+  util.return %0 : index
+}
+
+// -----
+// CHECK-LABEL: @elide_multi_constant_assumption
+util.func public @elide_multi_constant_assumption(%arg0 : index, %arg1 : index) -> index, index, index {
+  %cst = arith.constant 1 : index
+  // CHECK: %[[CST:.*]] = arith.constant 1 : index
+  // CHECK: %[[ASSUME:.*]]:2 = util.assume.int
+  // CHECK-NEXT: %arg0<udiv = 2>,
+  // CHECK-NEXT: %arg1<udiv = 4>
+  // CHECK-NEXT: : index, index
+  %0:3 = util.assume.int %arg0<udiv=2>, %cst<umin=0>, %arg1<udiv=4> : index, index, index
+  // CHECK: util.return %[[ASSUME]]#0, %[[CST]], %[[ASSUME]]#1
+  util.return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+// CHECK-LABEL: @broadcast_duplicate_assumptions
+util.func public @broadcast_duplicate_assumptions(%arg0 : index) -> index  {
+  // CHECK: util.assume.int %arg0<umin = 0>
+  %0 = util.assume.int %arg0[<umin=0>, <umin=0>] : index
+  util.return %0 : index
+}
+
+// -----
+// CHECK-LABEL: @dedup_duplicate_operands
+util.func public @dedup_duplicate_operands(%arg0 : index) -> index, index {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int %arg0<umax = 2> : index
+  %0:2 = util.assume.int %arg0[<umax=2>, <umax=2>], %arg0<umin=0> : index, index
+  // CHECK: util.return %[[ASSUME]], %[[ASSUME]]
+  util.return %0#0, %0#1 : index, index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
index 2be8dc5..e9d1708 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
@@ -1,7 +1,15 @@
 // RUN: iree-opt --split-input-file --verify-diagnostics %s
 
 util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64  {
-  // expected-error @+1 {{expected operand #1 to have 1 assumptions but it has 2}}
+  // expected-error @+1 {{expected operand #1 to have 3 assumptions but it has 2}}
+  %0:2 = util.assume.int %arg0[<umin=0>, <umax=2>, <udiv=16>], %arg1[<umax=10>, <udiv=6>] : index, i64
+  util.return %0#0, %0#1 : index, i64
+}
+
+// -----
+
+util.func public @assume.int.multi_operand_broadcast(%arg0 : index, %arg1 : i64) -> index, i64  {
+  // It is legal to have a mismatched arity if 1.
   %0:2 = util.assume.int %arg0[<umin=0>], %arg1[<umax=10>, <udiv=6>] : index, i64
   util.return %0#0, %0#1 : index, i64
 }