Rework util.assume.* ops to util.assume.int and base on attributes. (#18703)
This arguably should have been done from the get-go but the need wasn't
apparent until implementing propagation across executable boundaries.
The key additional requirement which arises there is in relation to
capturing annotations from multiple specializations in a way that can be
actioned. Having one op which carries all of the information and can
represent both multiple constraints and multiple operands solves for all
needed cases.
Also standardized on umin/umax verbiage to align with the upstream
dataflow terminology, which will be hooked up in the following patch.
---------
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
index b37ab37..b1c90d2 100644
--- a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
@@ -30,18 +30,6 @@
namespace {
-Type getNarrowestType(Builder &builder,
- std::optional<std::pair<int64_t, int64_t>> minMaxBounds) {
- if (!minMaxBounds)
- return {};
-
- auto maxBound = minMaxBounds->second;
- if (maxBound <= std::numeric_limits<int32_t>::max())
- return builder.getIntegerType(32);
- else
- return builder.getIntegerType(64);
-}
-
// Torch "binds" symbolic shape information to all tensors in the program
// which are not static. It does this by emitting side-effecting
// torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops
@@ -352,20 +340,25 @@
// Add optimization assumptions if the divisor or bounds are known.
int64_t divisor = expr.getLargestKnownDivisor();
auto bounds = evaluateExprBounds(expr, symbolInfos);
- if (divisor != 1 || bounds) {
- Type narrowType = getNarrowestType(builder, bounds);
- if (narrowType) {
- dimValue = builder.create<IREE::Util::AssumeNarrowOp>(
- bindOp->getLoc(), dimValue, TypeAttr::get(narrowType));
- }
- if (bounds) {
- dimValue = builder.create<IREE::Util::AssumeRangeOp>(
- bindOp->getLoc(), dimValue, bounds->first, bounds->second);
- }
- if (divisor != 1) {
- dimValue = builder.create<IREE::Util::AssumeDivisibleOp>(
- bindOp->getLoc(), dimValue, divisor);
- }
+ std::optional<uint64_t> optionalUmin;
+ std::optional<uint64_t> optionalUmax;
+ std::optional<int64_t> optionalDivisor;
+ if (bounds) {
+ optionalUmin = bounds->first;
+ optionalUmax = bounds->second;
+ }
+ if (divisor != 1) {
+ optionalDivisor = divisor;
+ }
+ if (optionalUmin || optionalUmax || optionalDivisor) {
+ auto assumption = builder.getAttr<IREE::Util::IntAssumptionAttr>(
+ /*umin=*/optionalUmin,
+ /*umax=*/optionalUmax,
+ /*divisor=*/optionalDivisor);
+ dimValue = builder
+ .create<IREE::Util::AssumeIntOp>(bindOp->getLoc(),
+ dimValue, assumption)
+ .getResult(0);
}
materializedDims[index] = dimValue;
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
index e3d6061..641013f 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
@@ -12,20 +12,15 @@
// CHECK-DAG: %[[POS1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %1, %[[POS0]] :
// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %1, %[[POS1]] :
- // CHECK: %[[ARG0_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
- // CHECK: %[[ARG0_DIM0_RANGE:.*]] = util.assume.range %[[ARG0_DIM0_NARROW]] in [1, 1024] : index
- // CHECK: %[[ARG0_DIM1_NARROW:.*]] = util.assume.narrow %[[DIM1]] : index to i32
- // CHECK: %[[ARG0_DIM1_RANGE:.*]] = util.assume.range %[[ARG0_DIM1_NARROW]] in [1, 1024] : index
+ // CHECK: %[[ARG0_DIM0_RANGE:.*]] = util.assume.int %[[DIM0]]<umin = 1, umax = 1024> : index
+ // CHECK: %[[ARG0_DIM1_RANGE:.*]] = util.assume.int %[[DIM1]]<umin = 1, umax = 1024> : index
// CHECK: %[[ARG0_TIE:.*]] = flow.tensor.tie_shape %[[ARG0_ANCHOR]] : tensor<?x?xf32>{%[[ARG0_DIM0_RANGE]], %[[ARG0_DIM1_RANGE]]}
// CHECK: %[[ARG0_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG0_TIE]]
- // CHECK: %[[ARG1_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
- // CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.range %[[ARG1_DIM0_NARROW]] in [1, 1024]
+ // CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.int %[[DIM0]]<umin = 1, umax = 1024>
// CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index
// CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]]
- // CHECK: %[[ARG1_DIM1_NARROW:.*]] = util.assume.narrow %[[ARG1_DIM1]] : index to i32
- // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.range %[[ARG1_DIM1_NARROW]] in [2, 2048] : index
- // CHECK: %[[ARG1_DIM1_DIV:.*]] = util.assume.divisible %[[ARG1_DIM1_RANGE]] by 2
- // CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor<?x?xf32>{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_DIV]]}
+ // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.int %[[ARG1_DIM1]]<umin = 2, umax = 2048, divisor = 2> : index
+ // CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor<?x?xf32>{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_RANGE]]}
// CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]]
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
@@ -53,17 +48,11 @@
// CHECK-LABEL: @unbacked_symbol
module @unbacked_symbol {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
- // CHECK: util.assume.narrow
- // CHECK: util.assume.range{{.*}} [1, 1024]
- // CHECK: util.assume.narrow
- // CHECK: util.assume.range{{.*}} [2, 2048]
- // CHECK: util.assume.divisible{{.*}} by 2
+ // CHECK: util.assume.int{{.*}}<umin = 1, umax = 1024>
+ // CHECK: util.assume.int{{.*}}<umin = 2, umax = 2048, divisor = 2>
// CHECK: tie_shape
- // CHECK: util.assume.narrow
- // CHECK: util.assume.range{{.*}} [1, 1024]
- // CHECK: util.assume.narrow
- // CHECK: util.assume.range{{.*}} [4, 4096]
- // CHECK: util.assume.divisible{{.*}} by 4
+ // CHECK: util.assume.int{{.*}}<umin = 1, umax = 1024>
+ // CHECK: util.assume.int{{.*}}<umin = 4, umax = 4096, divisor = 4>
// CHECK: tie_shape
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int
@@ -111,7 +100,7 @@
module @add_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: addi
- // CHECK-NOT: divisible
+ // CHECK-NOT: divisor
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
@@ -125,7 +114,7 @@
module @mod_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: remui
- // CHECK-NOT: divisible
+ // CHECK-NOT: divisor
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
@@ -139,7 +128,7 @@
module @floordiv_expr {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
// CHECK: divui
- // CHECK-NOT: divisible
+ // CHECK-NOT: divisor
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
@@ -168,7 +157,7 @@
// CHECK-LABEL: @torch_unbounded_max_range
module @torch_unbounded_max_range {
func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
- // CHECK-NOT: util.assume.range
+ // CHECK-NOT: util.assume.int<umin
%0 = torch.symbolic_int "s0" {min_val = 0, max_val = 4611686018427387903} : !torch.int
%1 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775806} : !torch.int
torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
index ecb3fde..310e9f4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/ConversionPatterns.cpp
@@ -26,35 +26,13 @@
// Utilities
//===----------------------------------------------------------------------===//
-namespace {
-
-struct ConvertAssumeNarrowOp
- : public OpConversionPattern<IREE::Util::AssumeNarrowOp> {
- using OpConversionPattern::OpConversionPattern;
- LogicalResult
- matchAndRewrite(IREE::Util::AssumeNarrowOp assumeOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // The op is a pass-through but we leave the narrow type unconverted since
- // it is a hint and should not itself be subject to type conversion
- // (narrowing, widening, etc).
- rewriter.replaceOpWithNewOp<IREE::Util::AssumeNarrowOp>(
- assumeOp, adaptor.getOperand().getType(), adaptor.getOperand(),
- assumeOp.getNarrowType());
- return success();
- }
-};
-
-} // namespace
-
void populateUtilConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns
- .insert<GenericConvertTypesPattern<IREE::Util::AssumeDivisibleOp>,
- GenericConvertTypesPattern<IREE::Util::AssumeRangeOp>,
+ .insert<GenericConvertTypesPattern<IREE::Util::AssumeIntOp>,
GenericConvertTypesPattern<IREE::Util::OptimizationBarrierOp>>(
typeConverter, context);
- patterns.insert<ConvertAssumeNarrowOp>(typeConverter, context);
typeConverter.addConversion([&](IREE::Util::PtrType type,
SmallVectorImpl<Type> &results) {
@@ -83,11 +61,7 @@
ConversionTarget &conversionTarget,
TypeConverter &typeConverter,
RewritePatternSet &patterns) {
- addGenericLegalOp<IREE::Util::AssumeDivisibleOp>(conversionTarget,
- typeConverter);
- addGenericLegalOp<IREE::Util::AssumeNarrowOp>(conversionTarget,
- typeConverter);
- addGenericLegalOp<IREE::Util::AssumeRangeOp>(conversionTarget, typeConverter);
+ addGenericLegalOp<IREE::Util::AssumeIntOp>(conversionTarget, typeConverter);
addGenericLegalOp<IREE::Util::OptimizationBarrierOp>(conversionTarget,
typeConverter);
addGenericLegalOp<IREE::Util::ListCreateOp>(conversionTarget, typeConverter);
diff --git a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/compiler_hints.mlir b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/compiler_hints.mlir
index 5cbc524..e77ca93 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Conversion/test/compiler_hints.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Conversion/test/compiler_hints.mlir
@@ -1,25 +1,10 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-util-test-conversion{widen-integers})' %s | FileCheck %s
-// CHECK-LABEL: @assumeDivisibleOp
-util.func public @assumeDivisibleOp(%arg0 : i16) -> i16 {
- // CHECK: util.assume.divisible {{.*}} by 4 : i32
- %0 = util.assume.divisible %arg0 by 4 : i16
- util.return %0 : i16
-}
-
// -----
-// CHECK-LABEL: @assumeNarrowOp
-util.func public @assumeNarrowOp(%arg0 : i16) -> i16 {
- // CHECK: util.assume.narrow %arg0 : i32 to i8
- %0 = util.assume.narrow %arg0 : i16 to i8
- util.return %0 : i16
-}
-
-// -----
-// CHECK-LABEL: @assumeRangeOp
-util.func public @assumeRangeOp(%arg0 : i16) -> i16 {
- // CHECK: util.assume.range %arg0 in [4, 12] : i32
- %0 = util.assume.range %arg0 in [4, 12] : i16
+// CHECK-LABEL: @assumeIntOp
+util.func public @assumeIntOp(%arg0 : i16) -> i16 {
+ // CHECK: util.assume.int %arg0<umin = 1> : i32
+ %0 = util.assume.int %arg0<umin = 1> : i16
util.return %0 : i16
}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
index 4473236..e43a2b6 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilAttrs.td
@@ -14,6 +14,37 @@
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
+// #util.int.assumptions
+//===----------------------------------------------------------------------===//
+
+def Util_IntAssumptionAttr : AttrDef<Util_Dialect, "IntAssumption", []
+> {
+ let mnemonic = "int.assumption";
+ let summary = [{specifies assumptions that can be made about an integer's value}];
+ let description = [{
+ This is typically used to memorialize the result of some integer analysis
+ or outside knowledge. All components of the attribute are optional.
+
+ See the op `util.assume.int` for binding assumptions to values.
+ }];
+ let parameters = (ins
+ DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$umin,
+ DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$umax,
+ DefaultValuedParameter<"std::optional<uint64_t>", "std::nullopt">:$divisor
+ );
+ let assemblyFormat = [{
+ `<` struct($umin, $umax, $divisor) `>`
+ }];
+}
+
+def Util_IntAssumptionAttrList : TypedArrayAttrBase<
+ Util_IntAssumptionAttr, "list of int assumption attributes">;
+
+def Util_MultiValueIntAssumptionAttrList : TypedArrayAttrBase<
+ Util_IntAssumptionAttrList,
+ "list of int attribute assumptions for multiple values">;
+
+//===----------------------------------------------------------------------===//
// #util.byte_pattern
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index 581023a..73230ac 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1099,6 +1099,131 @@
namespace mlir::iree_compiler::IREE::Util {
//===----------------------------------------------------------------------===//
+// util.assume.int
+//===----------------------------------------------------------------------===//
+
+void AssumeIntOp::build(OpBuilder &builder, OperationState &state,
+ Value singleOperand,
+ IntAssumptionAttr singleAssumption) {
+ state.addOperands({singleOperand});
+ state.addTypes({singleOperand.getType()});
+ state.addAttribute("assumptions", builder.getArrayAttr(builder.getArrayAttr(
+ {singleAssumption})));
+}
+
+LogicalResult AssumeIntOp::verify() {
+ ArrayAttr allOperandAssumptions = getAssumptions();
+ // Verify that there is an assumption row per operand.
+ if (getNumOperands() != allOperandAssumptions.size()) {
+ return emitOpError() << "expected " << getNumOperands()
+ << " assumption rows to match number of operands";
+ }
+
+ std::optional<int> rank;
+ for (auto [index, operandAssumptionsAttr] :
+ llvm::enumerate(allOperandAssumptions)) {
+ auto operandAssumptions = cast<ArrayAttr>(operandAssumptionsAttr);
+ if (rank && *rank != operandAssumptions.size())
+ return emitOpError() << "expected operand #" << index << " to have "
+ << *rank << " assumptions but it has "
+ << operandAssumptions.size();
+ rank = operandAssumptions.size();
+ }
+
+ return success();
+}
+
+ParseResult AssumeIntOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<Attribute> allOperandAssumptions;
+ SmallVector<OpAsmParser::UnresolvedOperand> parsedOperands;
+ SmallVector<Type> parsedOperandTypes;
+
+ if (parser.parseCommaSeparatedList([&]() {
+ parsedOperands.emplace_back();
+ OpAsmParser::UnresolvedOperand &parsedOperand = parsedOperands.back();
+ SmallVector<Attribute> operandAssumptions;
+
+ if (parser.parseOperand(parsedOperand))
+ return failure();
+
+ // Parse as a single assumption or a list.
+ if (failed(parser.parseOptionalLSquare())) {
+ // Single assumption.
+ IntAssumptionAttr singleAssumption;
+ if (parser.parseCustomAttributeWithFallback(singleAssumption))
+ return failure();
+ operandAssumptions.push_back(singleAssumption);
+ } else {
+ // Multiple assumptions.
+ if (failed(parser.parseOptionalRSquare())) {
+ if (parser.parseCommaSeparatedList([&]() {
+ IntAssumptionAttr singleAssumption;
+ if (parser.parseCustomAttributeWithFallback(singleAssumption))
+ return failure();
+ operandAssumptions.push_back(singleAssumption);
+ return success();
+ }))
+ return failure();
+ if (parser.parseRSquare())
+ return failure();
+ }
+ }
+
+ // Finalize operand.
+ allOperandAssumptions.push_back(
+ parser.getBuilder().getArrayAttr(operandAssumptions));
+
+ return success();
+ }))
+ return failure();
+
+ // Parse `:` type.
+ if (parser.parseColon() || parser.parseTypeList(parsedOperandTypes))
+ return failure();
+ result.addTypes(parsedOperandTypes);
+
+ if (parser.resolveOperands(parsedOperands, parsedOperandTypes,
+ parser.getNameLoc(), result.operands))
+ return failure();
+
+ result.attributes.append(
+ "assumptions", parser.getBuilder().getArrayAttr(allOperandAssumptions));
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ return success();
+}
+
+void AssumeIntOp::print(OpAsmPrinter &p) {
+ p << " ";
+ ArrayAttr allOperandAssumptions = getAssumptions();
+ for (auto [index, operand] : llvm::enumerate(getOperands())) {
+ if (index > 0)
+ p << ", ";
+ ArrayAttr operandAssumptions =
+ cast<ArrayAttr>(allOperandAssumptions[index]);
+ p.printOperand(operand);
+
+ // Print the assumptions, either as a single assumption or list.
+ if (operandAssumptions.size() == 1) {
+ p.printStrippedAttrOrType(cast<IntAssumptionAttr>(operandAssumptions[0]));
+ } else {
+ p << "[";
+ llvm::interleaveComma(
+ operandAssumptions, p.getStream(), [&](Attribute attr) {
+ p.printStrippedAttrOrType(cast<IntAssumptionAttr>(attr));
+ });
+ p << "]";
+ }
+ }
+
+ p << " : ";
+ llvm::interleaveComma(getOperands(), p.getStream(),
+ [&](Value operand) { p.printType(operand.getType()); });
+ p.printOptionalAttrDict((*this)->getAttrs(), {"assumptions"});
+}
+
+//===----------------------------------------------------------------------===//
// util.optimization_barrier
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 881d8d6..2b62cd2 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -458,78 +458,40 @@
let opDocGroup = OpGroupCompilerHintOps in {
-def Util_AssumeDivisibleOp :
- Util_PureOp<"assume.divisible", [SameOperandsAndResultType]> {
- let summary = "Memorializes knowledge that an index/integer value is divisible by some constant.";
+def Util_AssumeIntOp : Util_PureOp<"assume.int", []> {
+ let summary = "memorializes assumptions about index/integer values.";
+ let description = [{
+ This op is used to memorialize the result of some integer analysis or
+ outside knowledge across a boundary beyond which such information can
+ not be easily recovered. Assumptions are made per op/result pair.
+
+ Assumptions are tied to operands as rows of permutations of an
+ `#util.assume.int` per operand. The number of permutations is the rank.
+ Typically multiple permutations record a specific subset of assumptions
+ broken down per call-site in some way that is meaningful to the receiver.
+ Implementations can use this information to specialize on each
+ permutation if it is meaninful to do so (i.e. vs unioning across them).
+ In such cases, there will typically be one such op at the top of a
+ function or scope which passes all covered operands through it.
+ }];
let arguments = (ins
- Util_Range:$operand,
- Util_IndexAttr:$divisor
+ Variadic<AnySignlessIntegerOrIndex>:$operands,
+ Util_MultiValueIntAssumptionAttrList:$assumptions
);
let results = (outs
- Util_Range:$result
+ Variadic<AnySignlessIntegerOrIndex>:$results
);
- let assemblyFormat = [{
- $operand `by` $divisor attr-dict `:` type($operand)
- }];
let builders = [
+ // Helper for building simple single operand/single assumption ops.
OpBuilder<(ins
- "Value":$operand,
- "uint64_t":$divisor
- ),
- [{
- IntegerAttr divisorAttr = $_builder.getIntegerAttr(
- $_builder.getIndexType(), divisor);
- build($_builder, $_state, operand.getType(), operand, divisorAttr);
- }]>,
+ "Value":$singleOperand,
+ "IntAssumptionAttr":$singleAssumption
+ )>,
];
-}
-def Util_AssumeNarrowOp :
- Util_PureOp<"assume.narrow", [SameOperandsAndResultType]> {
- let summary = "Memorializes knowledge that an index/integer value can be narrowed to a type.";
-
- let arguments = (ins
- Util_Range:$operand,
- TypeAttr:$narrow_type
- );
- let results = (outs
- Util_Range:$result
- );
- let assemblyFormat = [{
- $operand attr-dict `:` type($operand) `to` $narrow_type
- }];
-}
-
-def Util_AssumeRangeOp :
- Util_PureOp<"assume.range", [SameOperandsAndResultType]> {
- let summary = "Memorializes knowledge that an index/integer value is always within some range.";
-
- let arguments = (ins
- Util_Range:$operand,
- Util_IndexAttr:$min_value,
- Util_IndexAttr:$max_value
- );
- let results = (outs
- Util_Range:$result
- );
- let assemblyFormat = [{
- $operand `in` ` ` `[` $min_value `,` $max_value `]` `:` type($operand) attr-dict
- }];
- let builders = [
- OpBuilder<(ins
- "Value":$operand,
- "uint64_t":$minValue,
- "uint64_t":$maxValue
- ),
- [{
- IntegerAttr minAttr = $_builder.getIntegerAttr(
- $_builder.getIndexType(), minValue);
- IntegerAttr maxAttr = $_builder.getIntegerAttr(
- $_builder.getIndexType(), maxValue);
- build($_builder, $_state, operand.getType(), operand, minAttr, maxAttr);
- }]>,
- ];
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
}
def Util_OptimizationBarrierOp : Util_Op<"optimization_barrier", [
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
index a26b891..2df2bfd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/BUILD.bazel
@@ -20,6 +20,7 @@
"alignment_ops.mlir",
"assignment_folding.mlir",
"assignment_ops.mlir",
+ "assume_ops.mlir",
"attributes.mlir",
"buffer_folding.mlir",
"buffer_ops.mlir",
@@ -31,6 +32,7 @@
"hint_ops.mlir",
"list_ops.mlir",
"numeric_ops.mlir",
+ "op_verification.mlir",
"range_folding.mlir",
"range_ops.mlir",
"structural_folding.mlir",
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
index 062e79a..b6ac5d8 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/CMakeLists.txt
@@ -18,6 +18,7 @@
"alignment_ops.mlir"
"assignment_folding.mlir"
"assignment_ops.mlir"
+ "assume_ops.mlir"
"attributes.mlir"
"buffer_folding.mlir"
"buffer_ops.mlir"
@@ -29,6 +30,7 @@
"hint_ops.mlir"
"list_ops.mlir"
"numeric_ops.mlir"
+ "op_verification.mlir"
"range_folding.mlir"
"range_ops.mlir"
"structural_folding.mlir"
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir
new file mode 100644
index 0000000..0886d38
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/assume_ops.mlir
@@ -0,0 +1,24 @@
+// RUN: iree-opt --split-input-file %s | iree-opt --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @assume.int.single_assumption
+util.func public @assume.int.single_assumption(%arg0 : index) -> index {
+ // CHECK: util.assume.int %arg0<umin = 0> : index
+ %0 = util.assume.int %arg0<umin=0> : index
+ util.return %0 : index
+}
+
+// -----
+// CHECK-LABEL: @assume.int.multi_assumption
+util.func public @assume.int.multi_assumption(%arg0 : index) -> index {
+ // CHECK: util.assume.int %arg0[<umin = 0>, <divisor = 5>] : index
+ %0 = util.assume.int %arg0[<umin=0>, <divisor=5>] : index
+ util.return %0 : index
+}
+
+// -----
+// CHECK-LABEL: @assume.int.multi_operand
+util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
+ // CHECK: util.assume.int %arg0[<umin = 0>, <divisor = 5>], %arg1[<umax = 10>, <divisor = 6>] : index, i64
+ %0:2 = util.assume.int %arg0[<umin=0>, <divisor=5>], %arg1[<umax=10>, <divisor=6>] : index, i64
+ util.return %0#0, %0#1 : index, i64
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir
index 58d1def..5220794 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/attributes.mlir
@@ -1,5 +1,19 @@
// RUN: iree-opt --split-input-file --mlir-print-local-scope %s | iree-opt --split-input-file --mlir-print-local-scope | FileCheck %s
+// CHECK-LABEL: @assume_int
+builtin.module @assume_int attributes {
+ // CHECK: util.all = #util.int.assumption<umin = 1, umax = 2, divisor = 16>
+ // CHECK-SAME: util.divisor = #util.int.assumption<divisor = 32>
+ // CHECK-SAME: util.umax = #util.int.assumption<umax = 10>
+ // CHECK-SAME: util.umin = #util.int.assumption<umin = 5>
+ util.all = #util.int.assumption<umin = 1, umax = 2, divisor = 16>,
+ util.divisor = #util.int.assumption<divisor = 32>,
+ util.umax = #util.int.assumption<umax = 10>,
+ util.umin = #util.int.assumption<umin = 5>
+} {}
+
+// -----
+
// CHECK-LABEL: @byte_pattern
builtin.module @byte_pattern attributes {
// CHECK: r0 = #util.byte_pattern<0> : i8
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
new file mode 100644
index 0000000..e4a6f6e
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/test/op_verification.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics %s
+
+util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
+ // expected-error @+1 {{expected operand #1 to have 1 assumptions but it has 2}}
+ %0:2 = util.assume.int %arg0[<umin=0>], %arg1[<umax=10>, <divisor=6>] : index, i64
+ util.return %0#0, %0#1 : index, i64
+}
+
+// -----
+
+util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
+ // expected-error @+1 {{expected 2 assumption rows to match number of operands}}
+ %0:2 = "util.assume.int"(%arg0, %arg1) {
+ assumptions = []
+ } : (index, i64) -> (index, i64)
+ util.return %0#0, %0#1 : index, i64
+}
+
+// -----
+
+util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
+ // expected-error @+1 {{failed to satisfy constraint}}
+ %0:2 = "util.assume.int"(%arg0, %arg1) {
+ assumptions = [[32], [32]]
+ } : (index, i64) -> (index, i64)
+ util.return %0#0, %0#1 : index, i64
+}
+
+// -----
+
+util.func public @assume.int.multi_operand(%arg0 : index, %arg1 : i64) -> index, i64 {
+ // expected-error @+1 {{failed to satisfy constraint}}
+ %0:2 = "util.assume.int"(%arg0, %arg1) {
+ assumptions = [32, [32]]
+ } : (index, i64) -> (index, i64)
+ util.return %0#0, %0#1 : index, i64
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
index a6f072c..15a8775 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
@@ -23,14 +23,8 @@
if (auto op = dyn_cast<IREE::Util::OptimizationBarrierOp>(genericOp)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
- } else if (auto op = dyn_cast<IREE::Util::AssumeDivisibleOp>(genericOp)) {
- op.replaceAllUsesWith({op.getOperand()});
- op.erase();
- } else if (auto op = dyn_cast<IREE::Util::AssumeRangeOp>(genericOp)) {
- op.replaceAllUsesWith({op.getOperand()});
- op.erase();
- } else if (auto op = dyn_cast<IREE::Util::AssumeNarrowOp>(genericOp)) {
- op.replaceAllUsesWith({op.getOperand()});
+ } else if (auto op = dyn_cast<IREE::Util::AssumeIntOp>(genericOp)) {
+ op.replaceAllUsesWith(op.getOperands());
op.erase();
}
});
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
index c0db60a..89a8f28 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
@@ -76,30 +76,10 @@
// -----
-// CHECK-LABEL: @assume.divisible
-util.func @assume.divisible() -> i32 {
- // CHECK-NOT: util.assume.divisible
+// CHECK-LABEL: @assume.int
+util.func @assume.int() -> i32 {
+ // CHECK-NOT: util.assume.int
%c1 = arith.constant 12 : i32
- %0 = util.assume.divisible %c1 by 2 : i32
- util.return %0 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @assume.narrow
-util.func @assume.narrow() -> i32 {
- // CHECK-NOT: util.assume.narrow
- %c1 = arith.constant 12 : i32
- %0 = util.assume.narrow %c1 : i32 to i8
- util.return %0 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @assume.range
-util.func @assume.range() -> i32 {
- // CHECK-NOT: util.assume.range
- %c1 = arith.constant 12 : i32
- %0 = util.assume.range %c1 in [2, 20] : i32
+ %0 = util.assume.int %c1<umin = 1> : i32
util.return %0 : i32
}