[torch] Materialize all derivable bounds and divisor information in the IR. (#18646)
* Adds new util ops: util.assume.divisible, util.assume.narrow,
util.assume.range
* Adds new pass torch-iree-bind-symbolic-shapes which will lower
torch.bind_symbolic_shape ops if present in the IR (these are currently
suppressed in the frontend with a flag, so adding this pass
unconditionally is a no-op)
* Canonicalizes all dynamics dims so that equal-dims are represented
program wide with the same SSA value and related-dims are derived from
the same root SSA values.
* Followon steps will clone the assume annotations into dispatches so
that codegen can make decisions based on the knowledge
---------
Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
Co-authored-by: Ben Vanik <ben.vanik@gmail.com>
diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
new file mode 100644
index 0000000..b37ab37
--- /dev/null
+++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
@@ -0,0 +1,472 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Pass/Pass.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+#include <limits>
+
+namespace Torch = mlir::torch::Torch;
+namespace TorchConversion = mlir::torch::TorchConversion;
+
+namespace mlir::iree_compiler::TorchInput {
+
+#define GEN_PASS_DEF_BINDSYMBOLICSHAPESPASS
+#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"
+
+namespace {
+
+Type getNarrowestType(Builder &builder,
+ std::optional<std::pair<int64_t, int64_t>> minMaxBounds) {
+ if (!minMaxBounds)
+ return {};
+
+ auto maxBound = minMaxBounds->second;
+ if (maxBound <= std::numeric_limits<int32_t>::max())
+ return builder.getIntegerType(32);
+ else
+ return builder.getIntegerType(64);
+}
+
+// Torch "binds" symbolic shape information to all tensors in the program
+// which are not static. It does this by emitting side-effecting
+// torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops
+// which match 1:1 to terminal symbols in the Torch program.
+//
+// This is a somewhat different representation than we need in order to be
+// usable within IREE:
+//
+// 1. We only want shape information and assertion at the boundaries where
+// they can come from runtime values of unknown lineage.
+// 2. IREE operates in terms of index values and "binding" them to tensors
+// so that later dim lookups are memoized.
+// 3. IREE's value analyses operate on real index SSA values, not "symbolic"
+// values that only exist in the ether.
+//
+// These constraints can only be met if we assume that all Torch symbols are
+// "backed" by a dimension or argument, so just a free-floating relational
+// symbol. Such "backed" symbols are the most dominant form of Torch programs,
+// but it is possible to create them such that symbols do not relate to any
+// one dimension (although this typically does not happen naturally at
+// program boundaries). In this pass we assume that any such relational
+// symbols are not actionable by us, and we therefore drop them. It is possible
+// for the frontend or user to fix this situation, and we therefore assume
+// that anyone who cares will have done so. These cases are emitted as warnings
+// in this pass because they signal potential missed optimization opportunties
+// that we would like to know about.
+//
+// The approach we use from here will roughly map a torch.bind_symbolic_shape
+// op to a flow.tensor.tie_shape op, preserving only the needed dynamic
+// dimensions. Dimensions will be derived from util ops which annotate
+// constraints and relationships.
+//
+// All other bind_symbolic_shape ops will be dropped.
+class BindSymbolicShapesPass final
+ : public impl::BindSymbolicShapesPassBase<BindSymbolicShapesPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<arith::ArithDialect>();
+ registry.insert<tensor::TensorDialect>();
+ registry.insert<IREE::Flow::FlowDialect>();
+ registry.insert<IREE::Util::UtilDialect>();
+ registry.insert<torch::Torch::TorchDialect>();
+ registry.insert<torch::TorchConversion::TorchConversionDialect>();
+ }
+
+ bool isEligibleBinding(Torch::BindSymbolicShapeOp bindOp) {
+ auto operand = bindOp.getOperand();
+ // Torch programs are single block and use structured control flow, so
+ // presume this is an entrypoint.
+ if (llvm::isa<BlockArgument>(operand))
+ return true;
+
+ // Mutable tensors can exist at the boundary and must be "copied" to a
+ // vtensor prior to use. Therefore, we anchor on the point of copy.
+ if (operand.getDefiningOp<Torch::CopyToValueTensorOp>())
+ return true;
+
+ return false;
+ }
+
+ struct SymbolInfo {
+ SymbolInfo(Torch::SymbolicIntOp symbolDefOp) : symbolDefOp(symbolDefOp) {
+ auto minVal = symbolDefOp.getMinValAttr();
+ auto maxVal = symbolDefOp.getMaxValAttr();
+ if (minVal && maxVal) {
+ uint64_t minValInt = minVal.getValue().getZExtValue();
+ uint64_t maxValInt = maxVal.getValue().getZExtValue();
+ // Note that torch represents open ranges in strange ways with various
+ // magic numbers in the high range of the uint64_t type. We somewhat
+ // arbitrarily say that anything over a fourth of the uint64_t
+ // range (which is half of the positive int64_t range, should these have
+ // originated as signed quantities), is a ridiculously large number not
+ // suitable as a shape dimension, and we drop the hint.
+ if (maxValInt >= minValInt &&
+ maxValInt < std::numeric_limits<uint64_t>::max() / 4) {
+ // Note that in Torch, min values are "weird" because they encode
+ // some special cases about broadcast behavior. Here we just discard
+ // them, but in the future, there may be more to derive here.
+ minMaxBounds = std::make_pair(1, maxValInt);
+ }
+ }
+ }
+
+ // Gets the canonical dim for this symbol, returning {} if there
+ // is no canonical dim.
+ Value getCanonicalDimValue(OpBuilder &builder) {
+ if (canonicalDimValue)
+ return canonicalDimValue;
+ if (equalityDimInfos.empty())
+ return {};
+ canonicalDimValue = getEqualityDimValue(builder, 0);
+ return canonicalDimValue;
+ }
+
+ // Gets the dim value for one of the entries in equalityDimInfos,
+ // materializing an op if needed.
+ Value getEqualityDimValue(OpBuilder &builder, unsigned index) {
+ auto [producer, position] = equalityDimInfos[index];
+ // Scrunch all dim ops up as far as they will go so that they can be
+ // shared among any legal consumers.
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfterValue(producer);
+ Value dimValue =
+ builder.create<tensor::DimOp>(producer.getLoc(), producer, position);
+ return dimValue;
+ }
+
+ Operation *symbolDefOp;
+
+ // If the symbol carries min/max bounds, note them here.
+ std::optional<std::pair<int64_t, int64_t>> minMaxBounds;
+
+ // All dimensions that should be considered equal by {producer_tensor,
+ // position}. When materializing shape expressions, we always use the
+ // first from this list so that simple SSA equality can be used across
+ // the graph.
+ SmallVector<std::pair<Value, unsigned>> equalityDimInfos;
+
+ Value canonicalDimValue;
+ };
+
+ struct TensorBinding {
+ Operation *bindOp;
+
+ // Symbol ops that that bind to symbols of the affine map.
+ llvm::SmallVector<Value> symbols;
+
+ // The value (tensor) this binding annotates.
+ Value annotatesValue;
+
+ // Torch type of the annotated tensor.
+ Torch::ValueTensorType torchType;
+
+ // Corresponding builtin tensor type.
+ RankedTensorType builtinTensorType;
+
+ // The affine map representing the dimensions.
+ AffineMap shapeMap;
+
+ // When prepared, we convert from the torch type to builtin and back. This
+ // is the back value. Our work gets done feeding into this.
+ TorchConversion::FromBuiltinTensorOp rewrittenTorchOp;
+
+ // Anchor op for building IR on native types.
+ Operation *anchorOp = nullptr;
+
+ // All dim materializations we were able to make. If all are defined once
+ // processing is complete, then we can tie the shape. This will be fully
+ // populated after the associateEqualityDims phase, and subsequent
+ // materializations should take the first value so that all related shapes
+ // anchor the same.
+ llvm::SmallVector<Value> materializedDims;
+
+ // Perform IR preparation for any bindings we may want to preserve.
+ void prepare() {
+ OpBuilder builder(bindOp);
+ TorchConversion::ToBuiltinTensorOp builtinConversion;
+ {
+ // Scrunch all ToBuiltinTensor ops as high up as they can go. We'll
+ // hang tensor.dim ops off of these across all dependent bindings so
+ // we need to make sure that it is always topologically legal. The
+ // easiest way to do this is to put common dependencies like this
+ // as far up as they will go, which means that each binding op (which
+ // is already guaranteed to be topologically legal) stays so.
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointAfterValue(annotatesValue);
+ builtinConversion = builder.create<TorchConversion::ToBuiltinTensorOp>(
+ bindOp->getLoc(), builtinTensorType, annotatesValue);
+ }
+ rewrittenTorchOp = builder.create<TorchConversion::FromBuiltinTensorOp>(
+ bindOp->getLoc(), torchType, builtinConversion.getResult());
+ annotatesValue.replaceAllUsesExcept(rewrittenTorchOp.getResult(),
+ builtinConversion);
+ annotatesValue = builtinConversion.getResult();
+ anchorOp = rewrittenTorchOp;
+
+ materializedDims.resize(builtinTensorType.getRank());
+ }
+
+ std::optional<std::pair<int64_t, int64_t>>
+ evaluateExprBounds(AffineExpr expr,
+ llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+ if (!expr.isSymbolicOrConstant())
+ return {};
+ llvm::SmallVector<std::optional<int64_t>> lowerBounds;
+ llvm::SmallVector<std::optional<int64_t>> upperBounds;
+ lowerBounds.reserve(symbols.size());
+ upperBounds.reserve(symbols.size());
+ for (auto [pos, symbolValue] : llvm::enumerate(symbols)) {
+ const SymbolInfo &symbolInfo = symbolInfos.at(symbolValue);
+ if (!symbolInfo.minMaxBounds) {
+ lowerBounds.push_back({});
+ upperBounds.push_back({});
+ } else {
+ lowerBounds.push_back(symbolInfo.minMaxBounds->first);
+ upperBounds.push_back(symbolInfo.minMaxBounds->second);
+ }
+ }
+
+ auto upperBound = getBoundForAffineExpr(
+ expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds,
+ upperBounds, /*isUpper=*/true);
+ if (!upperBound)
+ return {};
+
+ auto lowerBound = getBoundForAffineExpr(
+ expr, /*numDims=*/0, /*numSymbols=*/symbols.size(), lowerBounds,
+ upperBounds, /*isUpper=*/false);
+ if (!lowerBound)
+ return {};
+
+ return std::make_pair(*lowerBound, *upperBound);
+ }
+
+ // For any dims in the shapeMap that are terminal, set up the root
+ // bindings.
+ void associateEqualityDims(llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+ OpBuilder builder(anchorOp);
+ for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) {
+ if (expr.getKind() != AffineExprKind::SymbolId)
+ continue;
+ auto symbolPos = llvm::cast<AffineSymbolExpr>(expr).getPosition();
+ Value symbol = symbols[symbolPos];
+ auto symbolInfoIt = symbolInfos.find(symbol);
+ assert(symbolInfoIt != symbolInfos.end() &&
+ "No symbol info for symbol");
+ auto &symbolInfo = symbolInfoIt->second;
+ symbolInfo.equalityDimInfos.emplace_back(annotatesValue, index);
+ }
+ }
+
+ Value materializeDimExpr(Location loc, OpBuilder &builder,
+ AffineExpr genericExpr,
+ llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+ if (auto binaryExpr = llvm::dyn_cast<AffineBinaryOpExpr>(genericExpr)) {
+ auto lhs =
+ materializeDimExpr(loc, builder, binaryExpr.getLHS(), symbolInfos);
+ if (!lhs)
+ return {};
+ auto rhs =
+ materializeDimExpr(loc, builder, binaryExpr.getRHS(), symbolInfos);
+ if (!rhs)
+ return {};
+
+ switch (binaryExpr.getKind()) {
+ case AffineExprKind::Add:
+ return builder.create<arith::AddIOp>(loc, lhs, rhs);
+ case AffineExprKind::Mul:
+ return builder.create<arith::MulIOp>(loc, lhs, rhs);
+ case AffineExprKind::Mod:
+ return builder.create<arith::RemUIOp>(loc, lhs, rhs);
+ case AffineExprKind::FloorDiv:
+ return builder.create<arith::DivUIOp>(loc, lhs, rhs);
+ case AffineExprKind::CeilDiv:
+ return builder.create<arith::CeilDivUIOp>(loc, lhs, rhs);
+ default:
+ break;
+ }
+ }
+
+ switch (genericExpr.getKind()) {
+ case AffineExprKind::Constant:
+ return builder.create<arith::ConstantOp>(
+ loc, builder.getIndexAttr(
+ llvm::cast<AffineConstantExpr>(genericExpr).getValue()));
+ case AffineExprKind::DimId:
+ // Unsupported.
+ break;
+ case AffineExprKind::SymbolId: {
+ auto symExpr = llvm::cast<AffineSymbolExpr>(genericExpr);
+ auto pos = symExpr.getPosition();
+ if (pos >= symbols.size())
+ break;
+ Value symbolValue = symbols[pos];
+ auto foundIt = symbolInfos.find(symbolValue);
+ if (foundIt == symbolInfos.end())
+ break;
+ SymbolInfo &info = foundIt->second;
+ return info.getCanonicalDimValue(builder); // May legally return {}
+ }
+ default:
+ break;
+ }
+
+ std::string s;
+ llvm::raw_string_ostream os(s);
+ genericExpr.print(os);
+ emitWarning(loc) << "Symbolic shape expression not supported: " << s
+ << " (falling back to runtime symbol resolution)";
+ return {};
+ }
+
+ void materializeDims(llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+ OpBuilder builder(anchorOp);
+ for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) {
+ if (!builtinTensorType.isDynamicDim(index))
+ continue;
+
+ Value dimValue =
+ materializeDimExpr(anchorOp->getLoc(), builder, expr, symbolInfos);
+ if (!dimValue) {
+ // Certain classes of symbolic expressions may not terminate on
+ // distinct dimensions (i.e. `s0 * 4` with no symbol that corresponds)
+ // to `s0`. In this case, we just do runtime resolution of the symbol.
+ dimValue = builder.create<tensor::DimOp>(bindOp->getLoc(),
+ annotatesValue, index);
+ }
+
+ // Add optimization assumptions if the divisor or bounds are known.
+ int64_t divisor = expr.getLargestKnownDivisor();
+ auto bounds = evaluateExprBounds(expr, symbolInfos);
+ if (divisor != 1 || bounds) {
+ Type narrowType = getNarrowestType(builder, bounds);
+ if (narrowType) {
+ dimValue = builder.create<IREE::Util::AssumeNarrowOp>(
+ bindOp->getLoc(), dimValue, TypeAttr::get(narrowType));
+ }
+ if (bounds) {
+ dimValue = builder.create<IREE::Util::AssumeRangeOp>(
+ bindOp->getLoc(), dimValue, bounds->first, bounds->second);
+ }
+ if (divisor != 1) {
+ dimValue = builder.create<IREE::Util::AssumeDivisibleOp>(
+ bindOp->getLoc(), dimValue, divisor);
+ }
+ }
+
+ materializedDims[index] = dimValue;
+ }
+ }
+
+ void tieShape(llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+ llvm::SmallVector<Value> dynamicDims;
+ dynamicDims.reserve(materializedDims.size());
+ for (size_t pos = 0; pos < materializedDims.size(); ++pos) {
+ if (builtinTensorType.isDynamicDim(pos)) {
+ Value dimValue = materializedDims[pos];
+ if (!dimValue) {
+ emitWarning(bindOp->getLoc())
+ << "Discarding symbolic shape information from PyTorch: Not "
+ << "all symbols resolved to a known dim value (first missing "
+ << "at position " << pos << ")";
+ return;
+ }
+
+ dynamicDims.push_back(dimValue);
+ }
+ }
+
+ OpBuilder builder(anchorOp);
+ Value tieShape = builder.create<IREE::Flow::TensorTieShapeOp>(
+ bindOp->getLoc(), builtinTensorType, annotatesValue, dynamicDims);
+ rewrittenTorchOp.setOperand(tieShape);
+ }
+ };
+
+ void runOnOperation() override {
+ ConversionTarget target(getContext());
+ TypeConverter typeConverter;
+ TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+ llvm::SmallVector<Operation *> cleanupOpList;
+ llvm::SmallVector<TensorBinding> bindings;
+ // Mapping of SSA value for a torch.symbolic_int (or related op) to its
+ // info.
+ llvm::DenseMap<Value, SymbolInfo> symbolInfos;
+
+ // Walk the ops we care about and stash for analysis.
+ getOperation()->walk([&](Operation *childOp) {
+ if (auto symbolOp = llvm::dyn_cast<Torch::SymbolicIntOp>(childOp)) {
+ cleanupOpList.push_back(symbolOp);
+ symbolInfos.insert_or_assign(symbolOp.getResult(),
+ SymbolInfo(symbolOp));
+ } else if (auto bindOp =
+ llvm::dyn_cast<Torch::BindSymbolicShapeOp>(childOp)) {
+ cleanupOpList.push_back(bindOp);
+ if (!isEligibleBinding(bindOp))
+ return;
+ auto torchType =
+ llvm::cast<Torch::ValueTensorType>(bindOp.getOperand().getType());
+ auto builtinType = llvm::dyn_cast_or_null<RankedTensorType>(
+ typeConverter.convertType(torchType));
+ if (!builtinType) {
+ emitError(childOp->getLoc())
+ << "cannot convert torch type to builtin: " << torchType;
+ return signalPassFailure();
+ }
+ bindings.push_back(TensorBinding{
+ /*bindOp=*/childOp,
+ /*symbols=*/bindOp.getShapeSymbols(),
+ /*annotatesValue=*/bindOp.getOperand(),
+ /*torchType=*/torchType,
+ /*builtinType=*/builtinType,
+ /*shapeMap=*/bindOp.getShapeExpressions().getAffineMap()});
+ }
+ });
+
+ // For every tensor value of interest, convert to a builtin tensor type and
+ // back, RAUW'ing the result. This will meet the eventual final conversion
+ // with additional graph forking.
+ for (auto &binding : bindings) {
+ binding.prepare();
+ }
+
+ // Find all associations to a single symbol and set up the roots.
+ for (auto &binding : bindings) {
+ binding.associateEqualityDims(symbolInfos);
+ }
+
+ // Materialize all dimension expressions and constraints.
+ for (auto &binding : bindings) {
+ binding.materializeDims(symbolInfos);
+ }
+
+ // Now that all is known, insert tie shape.
+ for (auto &binding : bindings) {
+ binding.tieShape(symbolInfos);
+ }
+
+ // Erase all found ops.
+ for (auto *op : llvm::reverse(cleanupOpList)) {
+ op->erase();
+ }
+ }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::TorchInput
diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
index 1db4085..4e48784 100644
--- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
@@ -34,6 +34,7 @@
HDRS
"Passes.h"
SRCS
+ "BindSymbolicShapes.cpp"
"BitCastQuantTensor.cpp"
"ConvertTMTensorToLinalgExt.cpp"
"FuncConversion.cpp"
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
index 00ab1a4..a0682b5 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
@@ -36,6 +36,10 @@
// model) and those constants get somewhat obscured by TorchToArith.
llvm::ArrayRef<std::string> emptyArrayRef;
+ // Dynamic shape bindings add a lot of structure to the IR which we prefer to
+ // leverage and eliminate prior to any other activity, so do this first.
+ pm.addNestedPass<func::FuncOp>(createBindSymbolicShapesPass());
+
if (options.strictSymbolicShapes) {
pm.addNestedPass<func::FuncOp>(createSetStrictSymbolicShapesPass());
// Run canonicalization in case any previously non-strict dynamic code can
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td
index 91b7792..251cbb2 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.td
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.td
@@ -9,6 +9,11 @@
include "mlir/Pass/PassBase.td"
+def BindSymbolicShapesPass :
+ InterfacePass<"torch-iree-bind-symbolic-shapes", "mlir::FunctionOpInterface"> {
+ let summary = "Process torch dynamic shape bindings into IREE analyzable forms";
+}
+
def BitCastQuantTensorPass :
InterfacePass<"torch-iree-bitcast-quant-tensor", "mlir::FunctionOpInterface"> {
let summary = "Bitcasts i8 packed tensors of sub-byte types to the actual bit width";
diff --git a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt
index 6f86276..cabc6b2 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt
@@ -6,6 +6,7 @@
"assume_strict_symbols.mlir"
"auto_input_conversion.mlir"
"attention.mlir"
+ "bind_symbolic_shapes.mlir"
"bitcast_quant_tensor.mlir"
"func_conversion.mlir"
"func_conversion_invalid.mlir"
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
new file mode 100644
index 0000000..e3d6061
--- /dev/null
+++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
@@ -0,0 +1,178 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(torch-iree-bind-symbolic-shapes))" --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// This example was captured from a program which has a dynamic batch size and
+// tiled inner dim on one of the arguments, causing a symbolic relationship on
+// the second dimension.
+// CHECK-LABEL: @basic_example
+module @basic_example {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
+ // CHECK-DAG: %[[ARG1_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+ // CHECK-DAG: %[[ARG0_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+ // CHECK-DAG: %[[POS0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[POS1:.*]] = arith.constant 1 : index
+ // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %1, %[[POS0]] :
+ // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %1, %[[POS1]] :
+ // CHECK: %[[ARG0_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
+ // CHECK: %[[ARG0_DIM0_RANGE:.*]] = util.assume.range %[[ARG0_DIM0_NARROW]] in [1, 1024] : index
+ // CHECK: %[[ARG0_DIM1_NARROW:.*]] = util.assume.narrow %[[DIM1]] : index to i32
+ // CHECK: %[[ARG0_DIM1_RANGE:.*]] = util.assume.range %[[ARG0_DIM1_NARROW]] in [1, 1024] : index
+ // CHECK: %[[ARG0_TIE:.*]] = flow.tensor.tie_shape %[[ARG0_ANCHOR]] : tensor<?x?xf32>{%[[ARG0_DIM0_RANGE]], %[[ARG0_DIM1_RANGE]]}
+ // CHECK: %[[ARG0_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG0_TIE]]
+ // CHECK: %[[ARG1_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
+ // CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.range %[[ARG1_DIM0_NARROW]] in [1, 1024]
+ // CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index
+ // CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]]
+ // CHECK: %[[ARG1_DIM1_NARROW:.*]] = util.assume.narrow %[[ARG1_DIM1]] : index to i32
+ // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.range %[[ARG1_DIM1_NARROW]] in [2, 2048] : index
+ // CHECK: %[[ARG1_DIM1_DIV:.*]] = util.assume.divisible %[[ARG1_DIM1_RANGE]] by 2
+ // CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor<?x?xf32>{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_DIV]]}
+ // CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]]
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+ %2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+ %4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ %int1_0 = torch.constant.int 1
+ %5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ return %5 : !torch.vtensor<[?,?],f32>
+ }
+}
+
+// -----
+// This example was captured from a torch program that used a symbol that did
+// not correspond to any dimension (being used in an expression as part of
+// distinct dimensions). This exercises a special case in the pass for deferring
+// to runtime resolution of the dim.
+// We just verify that the vital information has been captured.
+// CHECK-LABEL: @unbacked_symbol
+module @unbacked_symbol {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+ // CHECK: util.assume.narrow
+ // CHECK: util.assume.range{{.*}} [1, 1024]
+ // CHECK: util.assume.narrow
+ // CHECK: util.assume.range{{.*}} [2, 2048]
+ // CHECK: util.assume.divisible{{.*}} by 2
+ // CHECK: tie_shape
+ // CHECK: util.assume.narrow
+ // CHECK: util.assume.range{{.*}} [1, 1024]
+ // CHECK: util.assume.narrow
+ // CHECK: util.assume.range{{.*}} [4, 4096]
+ // CHECK: util.assume.divisible{{.*}} by 4
+ // CHECK: tie_shape
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int
+ %2 = torch.symbolic_int "4*s4" {min_val = 0, max_val = 4096} : !torch.int
+ %3 = torch.symbolic_int "s4" {min_val = 2, max_val = 1024} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %4 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+ %5 = torch.aten.repeat %arg0, %4 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %5, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
+ %int1_0 = torch.constant.int 1
+ %6 = torch.aten.add.Tensor %5, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %6, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
+ return %6 : !torch.vtensor<[?,?],f32>
+ }
+}
+
+// -----
+// CHECK-LABEL: @all_bindings_dropped
+module @all_bindings_dropped {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+ // CHECK-NOT: torch.symbolic_int
+ // CHECK-NOT: torch.bind_symbolic_shape
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+ %2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ %int1 = torch.constant.int 1
+ %int2 = torch.constant.int 2
+ %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+ %4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ %int1_0 = torch.constant.int 1
+ %5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+ return %5 : !torch.vtensor<[?,?],f32>
+ }
+}
+
+// -----
+// CHECK-LABEL: @add_expr
+module @add_expr {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+ // CHECK: addi
+ // CHECK-NOT: divisible
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 + 2)> : !torch.vtensor<[?,?],f32>
+ return
+ }
+}
+
+// -----
+// CHECK-LABEL: @mod_expr
+module @mod_expr {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+ // CHECK: remui
+ // CHECK-NOT: divisible
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 mod 2)> : !torch.vtensor<[?,?],f32>
+ return
+ }
+}
+
+// -----
+// CHECK-LABEL: @floordiv_expr
+module @floordiv_expr {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+ // CHECK: divui
+ // CHECK-NOT: divisible
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 floordiv 2)> : !torch.vtensor<[?,?],f32>
+ return
+ }
+}
+
+// -----
+// Verifies that unsupported dim expressions warn (and do not assert).
+// CHECK-LABEL: @unsupported_non_symbolic
+module @unsupported_non_symbolic {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ // expected-warning@+1 {{Symbolic shape expression not supported: d0}}
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<(d0)[s0, s1] -> (s0, s1 + d0)> : !torch.vtensor<[?,?],f32>
+ return
+ }
+}
+
+// -----
+// Torch uses high values to signal unbounded ranges. Ensure they are
+// suppressed.
+// CHECK-LABEL: @torch_unbounded_max_range
+module @torch_unbounded_max_range {
+ func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+ // CHECK-NOT: util.assume.range
+ %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 4611686018427387903} : !torch.int
+ %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775806} : !torch.int
+ torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+ torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32>
+ return
+ }
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index b6466d6..881d8d6 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -458,6 +458,80 @@
let opDocGroup = OpGroupCompilerHintOps in {
+def Util_AssumeDivisibleOp :
+ Util_PureOp<"assume.divisible", [SameOperandsAndResultType]> {
+ let summary = "Memorializes knowledge that an index/integer value is divisible by some constant.";
+
+ let arguments = (ins
+ Util_Range:$operand,
+ Util_IndexAttr:$divisor
+ );
+ let results = (outs
+ Util_Range:$result
+ );
+ let assemblyFormat = [{
+ $operand `by` $divisor attr-dict `:` type($operand)
+ }];
+ let builders = [
+ OpBuilder<(ins
+ "Value":$operand,
+ "uint64_t":$divisor
+ ),
+ [{
+ IntegerAttr divisorAttr = $_builder.getIntegerAttr(
+ $_builder.getIndexType(), divisor);
+ build($_builder, $_state, operand.getType(), operand, divisorAttr);
+ }]>,
+ ];
+}
+
+def Util_AssumeNarrowOp :
+ Util_PureOp<"assume.narrow", [SameOperandsAndResultType]> {
+ let summary = "Memorializes knowledge that an index/integer value can be narrowed to a type.";
+
+ let arguments = (ins
+ Util_Range:$operand,
+ TypeAttr:$narrow_type
+ );
+ let results = (outs
+ Util_Range:$result
+ );
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` $narrow_type
+ }];
+}
+
+def Util_AssumeRangeOp :
+ Util_PureOp<"assume.range", [SameOperandsAndResultType]> {
+ let summary = "Memorializes knowledge that an index/integer value is always within some range.";
+
+ let arguments = (ins
+ Util_Range:$operand,
+ Util_IndexAttr:$min_value,
+ Util_IndexAttr:$max_value
+ );
+ let results = (outs
+ Util_Range:$result
+ );
+ let assemblyFormat = [{
+ $operand `in` ` ` `[` $min_value `,` $max_value `]` `:` type($operand) attr-dict
+ }];
+ let builders = [
+ OpBuilder<(ins
+ "Value":$operand,
+ "uint64_t":$minValue,
+ "uint64_t":$maxValue
+ ),
+ [{
+ IntegerAttr minAttr = $_builder.getIntegerAttr(
+ $_builder.getIndexType(), minValue);
+ IntegerAttr maxAttr = $_builder.getIntegerAttr(
+ $_builder.getIndexType(), maxValue);
+ build($_builder, $_state, operand.getType(), operand, minAttr, maxAttr);
+ }]>,
+ ];
+}
+
def Util_OptimizationBarrierOp : Util_Op<"optimization_barrier", [
SameOperandsAndResultType,
]> {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
index ff7cefd..a6f072c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
@@ -19,9 +19,20 @@
void runOnOperation() override {
// We can't use patterns and applyPatternsAndFoldGreedily because that
// automatically does canonicalization.
- getOperation()->walk([&](IREE::Util::OptimizationBarrierOp op) {
- op.replaceAllUsesWith(op.getOperands());
- op.erase();
+ getOperation()->walk([&](Operation *genericOp) {
+ if (auto op = dyn_cast<IREE::Util::OptimizationBarrierOp>(genericOp)) {
+ op.replaceAllUsesWith(op.getOperands());
+ op.erase();
+ } else if (auto op = dyn_cast<IREE::Util::AssumeDivisibleOp>(genericOp)) {
+ op.replaceAllUsesWith({op.getOperand()});
+ op.erase();
+ } else if (auto op = dyn_cast<IREE::Util::AssumeRangeOp>(genericOp)) {
+ op.replaceAllUsesWith({op.getOperand()});
+ op.erase();
+ } else if (auto op = dyn_cast<IREE::Util::AssumeNarrowOp>(genericOp)) {
+ op.replaceAllUsesWith({op.getOperand()});
+ op.erase();
+ }
});
}
};
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
index 717d2bf..c0db60a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
@@ -73,3 +73,33 @@
}
}
}
+
+// -----
+
+// CHECK-LABEL: @assume.divisible
+util.func @assume.divisible() -> i32 {
+ // CHECK-NOT: util.assume.divisible
+ %c1 = arith.constant 12 : i32
+ %0 = util.assume.divisible %c1 by 2 : i32
+ util.return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @assume.narrow
+util.func @assume.narrow() -> i32 {
+ // CHECK-NOT: util.assume.narrow
+ %c1 = arith.constant 12 : i32
+ %0 = util.assume.narrow %c1 : i32 to i8
+ util.return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @assume.range
+util.func @assume.range() -> i32 {
+ // CHECK-NOT: util.assume.range
+ %c1 = arith.constant 12 : i32
+ %0 = util.assume.range %c1 in [2, 20] : i32
+ util.return %0 : i32
+}