Add integer range inference to hal.buffer_view.dim and rank ops. (#18943)

This matches that default range behavior of runtime dimensions we get
from frontends.

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index f1a820f..81f8da8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -22,6 +22,27 @@
 
 namespace mlir::iree_compiler::IREE::HAL {
 
+namespace {
+
+// We aribtrarily say that unbounded dimensions in a torch program cannot
+// exceed 53bits, making the maximum safe dimension 9007199254740991. The
+// astute reader will note that this is also the maximum safe value in
+// JavaScript, which also "happens" to be the largest mantissa value in a
+// 64bit double. We need a maximum and in the absence of a better choice,
+// with this one we are at least in good company. This limit is also used
+// in the frontends.
+static constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;
+
+// Similarly we use a very conservative maximum rank value for specifying
+// ranges of runtime rank resolution functions. Various frameworks have hard
+// and practical limits ranging from 32 (numpy) to hundreds. At the time of
+// writing, PyTorch throws weird errors if trying to print a tensor with a rank
+// greater than 992. We really just want a smallish integer value to bound
+// arithmetic, so we use an arbitrary maximum.
+static constexpr uint64_t MAX_RANK_VALUE = 4096;
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // custom<DescriptorType>($descriptor_type)
 //===----------------------------------------------------------------------===//
@@ -1025,6 +1046,30 @@
 }
 
 //===----------------------------------------------------------------------===//
+// hal.buffer_view.dim
+//===----------------------------------------------------------------------===//
+
+void BufferViewDimOp::inferResultRangesFromOptional(
+    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
+  const unsigned indexTypeNumBits = 64;
+  setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
+                                  APInt::getZero(indexTypeNumBits),
+                                  APInt(indexTypeNumBits, MAX_DIM_VALUE))));
+}
+
+//===----------------------------------------------------------------------===//
+// hal.buffer_view.dim
+//===----------------------------------------------------------------------===//
+
+void BufferViewRankOp::inferResultRangesFromOptional(
+    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
+  const unsigned indexTypeNumBits = 64;
+  setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
+                                  APInt::getZero(indexTypeNumBits),
+                                  APInt(indexTypeNumBits, MAX_RANK_VALUE))));
+}
+
+//===----------------------------------------------------------------------===//
 // hal.channel.create
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
index ae58127..16dd46b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index fdd43b7..9e370a1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -18,6 +18,7 @@
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
@@ -1010,7 +1011,10 @@
   }];
 }
 
-def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> {
+def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank", [
+    DeclareOpInterfaceMethods<InferIntRangeInterface,
+        ["inferResultRangesFromOptional"]>,
+]> {
   let summary = [{buffer view rank query}];
   let description = [{
     Returns the rank of the buffer view.
@@ -1030,7 +1034,10 @@
   }];
 }
 
-def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> {
+def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim", [
+    DeclareOpInterfaceMethods<InferIntRangeInterface,
+        ["inferResultRangesFromOptional"]>,
+]> {
   let summary = [{buffer view dimension value query}];
   let description = [{
     Returns the value of the given dimension.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
index 1924f42..f78817c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
@@ -493,3 +493,33 @@
   %rem16 = arith.remui %0, %c16 : i64
   util.return %rem16 : i64
 }
+
+// -----
+
+util.func @hal_buffer_view_dim_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) {
+  %zero = arith.constant 0 : index
+  %max = arith.constant 9007199254740991 : index
+  %0 = hal.buffer_view.dim<%bv : !hal.buffer_view>[0] : index
+  %1 = arith.cmpi slt, %0, %zero : index
+  %2 = arith.cmpi uge, %0, %zero : index
+  %3 = arith.cmpi ugt, %0, %max : index
+  // CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]]
+  util.return %1, %2, %3 : i1, i1, i1
+}
+
+// -----
+
+util.func @hal_buffer_view_rank_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) {
+  %zero = arith.constant 0 : index
+  %max = arith.constant 4096 : index
+  %0 = hal.buffer_view.rank<%bv : !hal.buffer_view> : index
+  %1 = arith.cmpi slt, %0, %zero : index
+  %2 = arith.cmpi uge, %0, %zero : index
+  %3 = arith.cmpi ugt, %0, %max : index
+  // CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]]
+  util.return %1, %2, %3 : i1, i1, i1
+}