Add `InferIntDivisibilityInterface` for `arith.muli`. (#18994)

Fixes #18973

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
index 7a45116..930c9f5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
@@ -60,6 +60,7 @@
 /// any un-propagated `tensor.expand_shape/tensor.collapse_shape` patterns.
 struct BlockDynamicDimensionsPass final
     : impl::BlockDynamicDimensionsPassBase<BlockDynamicDimensionsPass> {
+  using Base::Base;
   void runOnOperation() override;
 };
 } // namespace
@@ -222,6 +223,33 @@
                                 prunedOperandsList);
 }
 
+/// Generic method to block dynamic dimensions for all tensor operands.
+/// The default (non-`AttentionOp`) path is only taken when `test` is set.
+static LogicalResult
+blockDynamicDimensions(RewriterBase &rewriter,
+                       const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+                       Operation *operation, bool test) {
+  return TypeSwitch<Operation *, LogicalResult>(operation)
+      .Case<IREE::LinalgExt::AttentionOp>([&](auto attentionOp) {
+        return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
+                                      attentionOp);
+      })
+      .Default([&](Operation *op) {
+        if (!test) {
+          return success();
+        }
+        // The default path here is for now only for testing.
+        llvm::SmallDenseSet<int64_t> tensorOperandsList;
+        for (OpOperand &opOperand : operation->getOpOperands()) {
+          if (isa<RankedTensorType>(opOperand.get().getType())) {
+            tensorOperandsList.insert(opOperand.getOperandNumber());
+          }
+        }
+        return blockDynamicDimensions(rewriter, dynamicDimAnalysis, operation,
+                                      tensorOperandsList);
+      });
+}
+
 void BlockDynamicDimensionsPass::runOnOperation() {
   Operation *operation = getOperation();
   MLIRContext *context = &getContext();
@@ -231,12 +259,10 @@
   }
 
   IRRewriter rewriter(context);
-  auto walkResult = operation->walk(
-      [&](IREE::LinalgExt::AttentionOp attentionOp) -> WalkResult {
-        rewriter.setInsertionPoint(attentionOp);
-        return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
-                                      attentionOp);
-      });
+  auto walkResult = operation->walk([&](Operation *op) -> WalkResult {
+    rewriter.setInsertionPoint(op);
+    return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op, test);
+  });
   if (walkResult.wasInterrupted()) {
     return signalPassFailure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 5aa3ef4..bb0676f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -23,6 +23,9 @@
     : Pass<"iree-codegen-block-dynamic-dimensions"> {
   let summary = "Expand dynamic dimensions that are known to be multiples of "
                 "statically known values.";
+  let options = [
+    Option<"test", "test", "bool", /*default=*/"false", "Enable test mode">
+  ];
 }
 
 def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
diff --git a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
index b0e7667..fa8f774 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TensorDynamicDimAnalysis.cpp
@@ -59,7 +59,8 @@
   // Update solver info
   auto *divisibilityState =
       solver.lookupState<IREE::Util::IntegerDivisibilityLattice>(dynamicDim);
-  if (divisibilityState && !divisibilityState->getValue().isUninitialized()) {
+  if (divisibilityState && !divisibilityState->getValue().isUninitialized() &&
+      divisibilityState->getValue().getValue().sdiv() != 1) {
     updateDivisibilityInfo(divisibilityInfo, tensorValue, dimIndex,
                            divisibilityState->getValue().getValue());
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
index 819c412..4dab261 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions{test}, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
     #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
@@ -99,3 +99,36 @@
 //       CHECK:       ins(%[[Q]], %[[K]], %[[V]], %{{.+}}, %[[MASK]] :
 //       CHECK:   %[[GENERIC:.+]] = linalg.generic
 //       CHECK:   flow.dispatch.tensor.store %[[GENERIC]], %[[OUTPUT_BINDING]]
+
+// -----
+
+func.func @basic_blocking_test(%arg0 : index) -> tensor<?xf32> {
+  %0 = util.assume.int %arg0<umin = 0, umax = 1024, udiv = 16> : index
+  %1 = tensor.empty(%0) : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK-LABEL: func @basic_blocking_test(
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?x16xf32>
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EMPTY]]
+//       CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_blocking(%arg0 : index) -> tensor<?xf32> {
+  %1 = tensor.empty(%arg0) : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_blocking(
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+//       CHECK:   return %[[EMPTY]]
+
+// -----
+
+func.func @no_unit_blocking(%arg0 : index) -> tensor<?xf32> {
+  %0 = util.assume.int %arg0<umin = 0, umax = 1024, udiv = 1> : index
+  %1 = tensor.empty(%0) : tensor<?xf32>
+  return %1 : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_unit_blocking(
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
+//       CHECK:   return %[[EMPTY]]
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
index 3e2235b..0c80b48 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
@@ -45,3 +45,15 @@
   %1 = arith.remui %0, %cst : index
   util.return %1 : index
 }
+
+// -----
+
+util.func @muli_divisibility(%arg0 : index) -> index {
+  %cst = arith.constant 16 : index
+  %0 = arith.muli %arg0, %cst : index
+  %1 = arith.remui %0, %cst : index
+  util.return %1 : index
+}
+// CHECK-LABEL: @muli_divisibility
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   return %[[C0]]
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index 2097cbf..1ce8e18 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
 
 namespace mlir::iree_compiler {
 
@@ -45,6 +46,28 @@
   }
 };
 
+struct ArithMulIInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          ArithMulIInferIntDivisibilityOpInterface, arith::MulIOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto mulOp = cast<arith::MulIOp>(op);
+    APInt intVal;
+    if (!matchPattern(mulOp.getLhs(), m_ConstantInt(&intVal))) {
+      if (!matchPattern(mulOp.getRhs(), m_ConstantInt(&intVal))) {
+        return;
+      }
+    }
+
+    uint64_t udiv = intVal.getZExtValue();
+    uint64_t sdiv = std::abs(intVal.getSExtValue());
+    setResultDivs(mulOp.getResult(),
+                  IREE::Util::ConstantIntDivisibility(udiv, sdiv));
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // GlobalOpInterface
 //===----------------------------------------------------------------------===//
@@ -328,6 +351,8 @@
         arith::TruncIOp, arith::SIToFPOp, arith::UIToFPOp>(context);
     arith::ConstantOp::attachInterface<
         ArithConstantInferIntDivisibilityOpInterface>(*context);
+    arith::MulIOp::attachInterface<ArithMulIInferIntDivisibilityOpInterface>(
+        *context);
   });
 
   registry.addExtension(