Add integer range inference to hal.buffer_view.dim and rank ops. (#18943)
This matches that default range behavior of runtime dimensions we get
from frontends.
---------
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
index f1a820f..81f8da8 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp
@@ -22,6 +22,27 @@
namespace mlir::iree_compiler::IREE::HAL {
+namespace {
+
+// We aribtrarily say that unbounded dimensions in a torch program cannot
+// exceed 53bits, making the maximum safe dimension 9007199254740991. The
+// astute reader will note that this is also the maximum safe value in
+// JavaScript, which also "happens" to be the largest mantissa value in a
+// 64bit double. We need a maximum and in the absence of a better choice,
+// with this one we are at least in good company. This limit is also used
+// in the frontends.
+static constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;
+
+// Similarly we use a very conservative maximum rank value for specifying
+// ranges of runtime rank resolution functions. Various frameworks have hard
+// and practical limits ranging from 32 (numpy) to hundreds. At the time of
+// writing, PyTorch throws weird errors if trying to print a tensor with a rank
+// greater than 992. We really just want a smallish integer value to bound
+// arithmetic, so we use an arbitrary maximum.
+static constexpr uint64_t MAX_RANK_VALUE = 4096;
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// custom<DescriptorType>($descriptor_type)
//===----------------------------------------------------------------------===//
@@ -1025,6 +1046,30 @@
}
//===----------------------------------------------------------------------===//
+// hal.buffer_view.dim
+//===----------------------------------------------------------------------===//
+
+void BufferViewDimOp::inferResultRangesFromOptional(
+ ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
+ const unsigned indexTypeNumBits = 64;
+ setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
+ APInt::getZero(indexTypeNumBits),
+ APInt(indexTypeNumBits, MAX_DIM_VALUE))));
+}
+
+//===----------------------------------------------------------------------===//
+// hal.buffer_view.dim
+//===----------------------------------------------------------------------===//
+
+void BufferViewRankOp::inferResultRangesFromOptional(
+ ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
+ const unsigned indexTypeNumBits = 64;
+ setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
+ APInt::getZero(indexTypeNumBits),
+ APInt(indexTypeNumBits, MAX_RANK_VALUE))));
+}
+
+//===----------------------------------------------------------------------===//
// hal.channel.create
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
index ae58127..16dd46b 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.h
@@ -20,6 +20,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
index fdd43b7..9e370a1 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
+++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -18,6 +18,7 @@
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -1010,7 +1011,10 @@
}];
}
-def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank"> {
+def HAL_BufferViewRankOp : HAL_PureOp<"buffer_view.rank", [
+ DeclareOpInterfaceMethods<InferIntRangeInterface,
+ ["inferResultRangesFromOptional"]>,
+]> {
let summary = [{buffer view rank query}];
let description = [{
Returns the rank of the buffer view.
@@ -1030,7 +1034,10 @@
}];
}
-def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim"> {
+def HAL_BufferViewDimOp : HAL_PureOp<"buffer_view.dim", [
+ DeclareOpInterfaceMethods<InferIntRangeInterface,
+ ["inferResultRangesFromOptional"]>,
+]> {
let summary = [{buffer view dimension value query}];
let description = [{
Returns the value of the given dimension.
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
index 1924f42..f78817c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
@@ -493,3 +493,33 @@
%rem16 = arith.remui %0, %c16 : i64
util.return %rem16 : i64
}
+
+// -----
+
+util.func @hal_buffer_view_dim_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) {
+ %zero = arith.constant 0 : index
+ %max = arith.constant 9007199254740991 : index
+ %0 = hal.buffer_view.dim<%bv : !hal.buffer_view>[0] : index
+ %1 = arith.cmpi slt, %0, %zero : index
+ %2 = arith.cmpi uge, %0, %zero : index
+ %3 = arith.cmpi ugt, %0, %max : index
+ // CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+ // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+ // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]]
+ util.return %1, %2, %3 : i1, i1, i1
+}
+
+// -----
+
+util.func @hal_buffer_view_rank_min_max(%bv : !hal.buffer_view) -> (i1, i1, i1) {
+ %zero = arith.constant 0 : index
+ %max = arith.constant 4096 : index
+ %0 = hal.buffer_view.rank<%bv : !hal.buffer_view> : index
+ %1 = arith.cmpi slt, %0, %zero : index
+ %2 = arith.cmpi uge, %0, %zero : index
+ %3 = arith.cmpi ugt, %0, %max : index
+ // CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+ // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+ // CHECK: util.return %[[FALSE]], %[[TRUE]], %[[FALSE]]
+ util.return %1, %2, %3 : i1, i1, i1
+}