[torch] Materialize all derivable bounds and divisor information in the IR. (#18646)

* Adds new util ops: util.assume.divisible, util.assume.narrow,
util.assume.range
* Adds new pass torch-iree-bind-symbolic-shapes which will lower
torch.bind_symbolic_shape ops if present in the IR (these are currently
suppressed in the frontend with a flag, so adding this pass
unconditionally is a no-op)
* Canonicalizes all dynamic dims so that equal-dims are represented
program wide with the same SSA value and related-dims are derived from
the same root SSA values.
* Follow-on steps will clone the assume annotations into dispatches so
that codegen can make decisions based on the knowledge.

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
Co-authored-by: Ben Vanik <ben.vanik@gmail.com>
diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
new file mode 100644
index 0000000..b37ab37
--- /dev/null
+++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
@@ -0,0 +1,472 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Pass/Pass.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+#include <limits>
+
+namespace Torch = mlir::torch::Torch;
+namespace TorchConversion = mlir::torch::TorchConversion;
+
+namespace mlir::iree_compiler::TorchInput {
+
+#define GEN_PASS_DEF_BINDSYMBOLICSHAPESPASS
+#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"
+
+namespace {
+
+Type getNarrowestType(Builder &builder,
+                      std::optional<std::pair<int64_t, int64_t>> minMaxBounds) {
+  // Without known bounds there is nothing to narrow to: return a null Type.
+  if (!minMaxBounds.has_value())
+    return Type();
+
+  // Only the upper bound picks the width: i32 when it fits, i64 otherwise.
+  const bool fitsInI32 =
+      minMaxBounds->second <= std::numeric_limits<int32_t>::max();
+  return builder.getIntegerType(fitsInI32 ? 32 : 64);
+}
+
+// Torch "binds" symbolic shape information to all tensors in the program
+// which are not static. It does this by emitting side-effecting
+// torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops
+// which match 1:1 to terminal symbols in the Torch program.
+//
+// This is a somewhat different representation than we need in order to be
+// usable within IREE:
+//
+//   1. We only want shape information and assertion at the boundaries where
+//      they can come from runtime values of unknown lineage.
+//   2. IREE operates in terms of index values and "binding" them to tensors
+//      so that later dim lookups are memoized.
+//   3. IREE's value analyses operate on real index SSA values, not "symbolic"
+//      values that only exist in the ether.
+//
+// These constraints can only be met if we assume that all Torch symbols are
+// "backed" by a dimension or argument, not just a free-floating relational
+// symbol. Such "backed" symbols are the most dominant form of Torch programs,
+// but it is possible to create them such that symbols do not relate to any
+// one dimension (although this typically does not happen naturally at
+// program boundaries). In this pass we assume that any such relational
+// symbols are not actionable by us, and we therefore drop them. It is possible
+// for the frontend or user to fix this situation, and we therefore assume
+// that anyone who cares will have done so. These cases are emitted as warnings
+// in this pass because they signal potential missed optimization opportunities
+// that we would like to know about.
+//
+// The approach we use from here will roughly map a torch.bind_symbolic_shape
+// op to a flow.tensor.tie_shape op, preserving only the needed dynamic
+// dimensions. Dimensions will be derived from util ops which annotate
+// constraints and relationships.
+//
+// All other bind_symbolic_shape ops will be dropped.
+class BindSymbolicShapesPass final
+    : public impl::BindSymbolicShapesPassBase<BindSymbolicShapesPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Dialects this pass depends on: the builtin-side dialects it emits ops
+    // from, plus the Torch dialects it reads and bridges through.
+    registry.insert<arith::ArithDialect, tensor::TensorDialect,
+                    IREE::Flow::FlowDialect, IREE::Util::UtilDialect,
+                    torch::Torch::TorchDialect,
+                    torch::TorchConversion::TorchConversionDialect>();
+  }
+
+  bool isEligibleBinding(Torch::BindSymbolicShapeOp bindOp) {
+    auto boundTensor = bindOp.getOperand();
+
+    // A binding on a block argument is presumed to sit at an entrypoint:
+    // Torch programs are single block and use structured control flow.
+    const bool bindsBlockArgument = llvm::isa<BlockArgument>(boundTensor);
+
+    // Mutable tensors can exist at the boundary and must be "copied" to a
+    // vtensor prior to use, so the point of copy serves as the anchor.
+    const bool bindsBoundaryCopy = static_cast<bool>(
+        boundTensor.getDefiningOp<Torch::CopyToValueTensorOp>());
+
+    return bindsBlockArgument || bindsBoundaryCopy;
+  }
+
+  struct SymbolInfo {
+    // Records the defining op and derives bounds from its min/max attributes.
+    SymbolInfo(Torch::SymbolicIntOp symbolDefOp) : symbolDefOp(symbolDefOp) {
+      auto minVal = symbolDefOp.getMinValAttr();
+      auto maxVal = symbolDefOp.getMaxValAttr();
+      if (!minVal || !maxVal)
+        return;
+      uint64_t minValInt = minVal.getValue().getZExtValue();
+      uint64_t maxValInt = maxVal.getValue().getZExtValue();
+      // Torch represents open ranges with various magic numbers high in the
+      // uint64_t range. Somewhat arbitrarily, anything from a fourth of that
+      // range up (half of the positive int64_t range) is treated as too
+      // large to be a real shape dimension, and the hint is dropped.
+      if (maxValInt < minValInt ||
+          maxValInt >= std::numeric_limits<uint64_t>::max() / 4)
+        return;
+      // Torch min values encode broadcast special cases, so the min is
+      // discarded in favor of 1. A max of 0 would then yield the inverted
+      // range [1, 0], so no bounds are recorded in that case.
+      if (maxValInt >= 1)
+        minMaxBounds.emplace(1, static_cast<int64_t>(maxValInt));
+    }
+
+    // Returns the canonical dim value for this symbol, or a null Value when
+    // no equal dimension has been associated with it.
+    Value getCanonicalDimValue(OpBuilder &builder) {
+      // Lazily materialize from the first equality entry and cache the
+      // result, so that every later request yields the same SSA value. With
+      // no entries the cache stays null and that null is what is returned.
+      if (!canonicalDimValue && !equalityDimInfos.empty())
+        canonicalDimValue = getEqualityDimValue(builder, 0);
+      return canonicalDimValue;
+    }
+
+    // Materializes a tensor.dim op for entry `index` of equalityDimInfos and
+    // returns its result. A new op is created on every call.
+    Value getEqualityDimValue(OpBuilder &builder, unsigned index) {
+      const std::pair<Value, unsigned> &dimInfo = equalityDimInfos[index];
+      Value producer = dimInfo.first;
+      // Place the dim op right after its producer, as far up as it can go,
+      // so that it can be shared among any legal consumers.
+      OpBuilder::InsertionGuard restoreInsertionPoint(builder);
+      builder.setInsertionPointAfterValue(producer);
+      return builder.create<tensor::DimOp>(producer.getLoc(), producer,
+                                           dimInfo.second);
+    }
+
+    Operation *symbolDefOp; // The symbolic_int op this info was built from.
+
+    // Bounds for the symbol, set by the constructor when usable; else unset.
+    std::optional<std::pair<int64_t, int64_t>> minMaxBounds;
+
+    // All dimensions that should be considered equal by {producer_tensor,
+    // position}. When materializing shape expressions, we always use the
+    // first from this list so that simple SSA equality can be used across
+    // the graph.
+    SmallVector<std::pair<Value, unsigned>> equalityDimInfos;
+
+    Value canonicalDimValue; // Cache for getCanonicalDimValue(); null if unset.
+  };
+
+  struct TensorBinding {
+    Operation *bindOp; // The torch.bind_symbolic_shape op being processed.
+
+    // Symbol values that bind to the symbols of the affine map, by position.
+    llvm::SmallVector<Value> symbols;
+
+    // The value (tensor) this binding annotates (builtin-typed after prepare()).
+    Value annotatesValue;
+
+    // Torch type of the annotated tensor.
+    Torch::ValueTensorType torchType;
+
+    // Corresponding builtin tensor type.
+    RankedTensorType builtinTensorType;
+
+    // The affine map representing the dimensions.
+    AffineMap shapeMap;
+
+    // When prepared, we convert from the torch type to builtin and back. This
+    // is the back value. Our work gets done feeding into this.
+    TorchConversion::FromBuiltinTensorOp rewrittenTorchOp;
+
+    // Anchor op for building IR on native types.
+    Operation *anchorOp = nullptr;
+
+    // All dim materializations we were able to make, indexed by dimension.
+    // Sized (with null entries) by prepare() and filled in by
+    // materializeDims(); entries for static dims are left null. If every
+    // dynamic dim is defined once processing is complete, then tieShape() can
+    // tie the shape.
+    llvm::SmallVector<Value> materializedDims;
+
+    // Perform IR preparation: route the tensor through a to/from builtin pair.
+    void prepare() {
+      OpBuilder builder(bindOp); // Default insertion point: at the bind op.
+      TorchConversion::ToBuiltinTensorOp builtinConversion;
+      {
+        // Scrunch all ToBuiltinTensor ops as high up as they can go. We'll
+        // hang tensor.dim ops off of these across all dependent bindings so
+        // we need to make sure that it is always topologically legal. The
+        // easiest way to do this is to put common dependencies like this
+        // as far up as they will go, which means that each binding op (which
+        // is already guaranteed to be topologically legal) stays so.
+        OpBuilder::InsertionGuard guard(builder);
+        builder.setInsertionPointAfterValue(annotatesValue);
+        builtinConversion = builder.create<TorchConversion::ToBuiltinTensorOp>(
+            bindOp->getLoc(), builtinTensorType, annotatesValue);
+      }
+      rewrittenTorchOp = builder.create<TorchConversion::FromBuiltinTensorOp>(
+          bindOp->getLoc(), torchType, builtinConversion.getResult());
+      annotatesValue.replaceAllUsesExcept(rewrittenTorchOp.getResult(),
+                                          builtinConversion);
+      annotatesValue = builtinConversion.getResult(); // Now the builtin tensor.
+      anchorOp = rewrittenTorchOp;
+      // One (initially null) slot per dimension of the tensor.
+      materializedDims.resize(builtinTensorType.getRank());
+    }
+
+    // Computes the [min, max] range of `expr` from the bounds recorded for
+    // the bound symbols, or nullopt when no such range can be derived.
+    std::optional<std::pair<int64_t, int64_t>>
+    evaluateExprBounds(AffineExpr expr,
+                       llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+      if (!expr.isSymbolicOrConstant())
+        return std::nullopt;
+      // One (possibly unknown) bound per symbol, in symbol order.
+      const size_t numSymbols = symbols.size();
+      llvm::SmallVector<std::optional<int64_t>> symbolMins(numSymbols);
+      llvm::SmallVector<std::optional<int64_t>> symbolMaxes(numSymbols);
+      for (size_t i = 0; i < numSymbols; ++i) {
+        const auto &symbolBounds = symbolInfos.at(symbols[i]).minMaxBounds;
+        if (!symbolBounds)
+          continue;
+        symbolMins[i] = symbolBounds->first;
+        symbolMaxes[i] = symbolBounds->second;
+      }
+
+      // Query the upper bound first and bail before the lower if it fails.
+      auto exprMax = getBoundForAffineExpr(expr, /*numDims=*/0, numSymbols,
+                                           symbolMins, symbolMaxes,
+                                           /*isUpper=*/true);
+      if (!exprMax)
+        return std::nullopt;
+
+      auto exprMin = getBoundForAffineExpr(expr, /*numDims=*/0, numSymbols,
+                                           symbolMins, symbolMaxes,
+                                           /*isUpper=*/false);
+      if (!exprMin)
+        return std::nullopt;
+
+      return std::make_pair(*exprMin, *exprMax);
+    }
+
+    // For any dims in the shapeMap that are terminal (a bare symbol), record
+    // this tensor dim on that symbol as a root binding.
+    void associateEqualityDims(llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+      for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) {
+        auto symbolExpr = llvm::dyn_cast<AffineSymbolExpr>(expr);
+        // Skip non-terminal exprs, and (as materializeDimExpr does) symbols
+        // that are out of range or untracked, rather than relying on asserts.
+        if (!symbolExpr || symbolExpr.getPosition() >= symbols.size())
+          continue;
+        auto infoIt = symbolInfos.find(symbols[symbolExpr.getPosition()]);
+        if (infoIt == symbolInfos.end())
+          continue;
+        infoIt->second.equalityDimInfos.emplace_back(annotatesValue,
+                                                     index);
+      }
+    }
+
+    Value materializeDimExpr(Location loc, OpBuilder &builder,
+                             AffineExpr genericExpr,
+                             llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+      if (auto binaryExpr = llvm::dyn_cast<AffineBinaryOpExpr>(genericExpr)) {
+        auto lhs =
+            materializeDimExpr(loc, builder, binaryExpr.getLHS(), symbolInfos);
+        if (!lhs)
+          return {}; // An operand could not be materialized: give up.
+        auto rhs =
+            materializeDimExpr(loc, builder, binaryExpr.getRHS(), symbolInfos);
+        if (!rhs)
+          return {};
+
+        switch (binaryExpr.getKind()) { // mod/div map to unsigned arith ops.
+        case AffineExprKind::Add:
+          return builder.create<arith::AddIOp>(loc, lhs, rhs);
+        case AffineExprKind::Mul:
+          return builder.create<arith::MulIOp>(loc, lhs, rhs);
+        case AffineExprKind::Mod:
+          return builder.create<arith::RemUIOp>(loc, lhs, rhs);
+        case AffineExprKind::FloorDiv:
+          return builder.create<arith::DivUIOp>(loc, lhs, rhs);
+        case AffineExprKind::CeilDiv:
+          return builder.create<arith::CeilDivUIOp>(loc, lhs, rhs);
+        default:
+          break; // Continue into the leaf handling below.
+        }
+      }
+
+      switch (genericExpr.getKind()) {
+      case AffineExprKind::Constant:
+        return builder.create<arith::ConstantOp>(
+            loc, builder.getIndexAttr(
+                     llvm::cast<AffineConstantExpr>(genericExpr).getValue()));
+      case AffineExprKind::DimId:
+        // Unsupported: reported through the warning at the end.
+        break;
+      case AffineExprKind::SymbolId: {
+        auto symExpr = llvm::cast<AffineSymbolExpr>(genericExpr);
+        auto pos = symExpr.getPosition();
+        if (pos >= symbols.size())
+          break; // Symbol position is out of range for this binding.
+        Value symbolValue = symbols[pos];
+        auto foundIt = symbolInfos.find(symbolValue);
+        if (foundIt == symbolInfos.end())
+          break; // No SymbolInfo was recorded for this symbol.
+        SymbolInfo &info = foundIt->second;
+        return info.getCanonicalDimValue(builder); // May legally return {}
+      }
+      default:
+        break;
+      }
+
+      std::string s; // Receives the printed form of the expression.
+      llvm::raw_string_ostream os(s);
+      genericExpr.print(os);
+      emitWarning(loc) << "Symbolic shape expression not supported: " << s
+                       << " (falling back to runtime symbol resolution)";
+      return {};
+    }
+
+    void materializeDims(llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+      OpBuilder builder(anchorOp); // New ops are created at the anchor op.
+      for (auto [index, expr] : llvm::enumerate(shapeMap.getResults())) {
+        if (!builtinTensorType.isDynamicDim(index))
+          continue; // Static dims are left as null entries.
+
+        Value dimValue =
+            materializeDimExpr(anchorOp->getLoc(), builder, expr, symbolInfos);
+        if (!dimValue) {
+          // Certain classes of symbolic expressions may not terminate on
+          // distinct dimensions (i.e. `s0 * 4` with no dimension that
+          // corresponds to `s0`). In this case, we just do runtime resolution.
+          dimValue = builder.create<tensor::DimOp>(bindOp->getLoc(),
+                                                   annotatesValue, index);
+        }
+
+        // Add optimization assumptions if the divisor or bounds are known.
+        int64_t divisor = expr.getLargestKnownDivisor();
+        auto bounds = evaluateExprBounds(expr, symbolInfos);
+        if (divisor != 1 || bounds) {
+          Type narrowType = getNarrowestType(builder, bounds); // Null if none.
+          if (narrowType) {
+            dimValue = builder.create<IREE::Util::AssumeNarrowOp>(
+                bindOp->getLoc(), dimValue, TypeAttr::get(narrowType));
+          }
+          if (bounds) {
+            dimValue = builder.create<IREE::Util::AssumeRangeOp>(
+                bindOp->getLoc(), dimValue, bounds->first, bounds->second);
+          }
+          if (divisor != 1) {
+            dimValue = builder.create<IREE::Util::AssumeDivisibleOp>(
+                bindOp->getLoc(), dimValue, divisor);
+          }
+        }
+
+        materializedDims[index] = dimValue; // Consumed later by tieShape().
+      }
+    }
+
+    void tieShape(llvm::DenseMap<Value, SymbolInfo> &symbolInfos) {
+      // Collect the dynamic dim values in order; bail if any is unresolved.
+      llvm::SmallVector<Value> dynamicDims;
+      dynamicDims.reserve(materializedDims.size());
+      for (auto [pos, dimValue] : llvm::enumerate(materializedDims)) {
+        if (!builtinTensorType.isDynamicDim(pos))
+          continue;
+        if (!dimValue) {
+          emitWarning(bindOp->getLoc())
+              << "Discarding symbolic shape information from PyTorch: Not "
+              << "all symbols resolved to a known dim value (first missing "
+              << "at position " << pos << ")";
+          return;
+        }
+        dynamicDims.push_back(dimValue);
+      }
+
+      // Tie the dims to the builtin tensor and route it into the torch value.
+      OpBuilder builder(anchorOp);
+      Value tiedTensor = builder.create<IREE::Flow::TensorTieShapeOp>(
+          bindOp->getLoc(), builtinTensorType, annotatesValue, dynamicDims);
+      rewrittenTorchOp.setOperand(tiedTensor);
+    }
+  };
+
+  void runOnOperation() override {
+    ConversionTarget target(getContext()); // Only feeds the converter setup.
+    TypeConverter typeConverter;
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    llvm::SmallVector<Operation *> cleanupOpList; // Erased at the very end.
+    llvm::SmallVector<TensorBinding> bindings;
+    // Mapping of SSA value for a torch.symbolic_int (or related op) to its
+    // info.
+    llvm::DenseMap<Value, SymbolInfo> symbolInfos;
+
+    // Walk the ops we care about and stash for analysis.
+    getOperation()->walk([&](Operation *childOp) {
+      if (auto symbolOp = llvm::dyn_cast<Torch::SymbolicIntOp>(childOp)) {
+        cleanupOpList.push_back(symbolOp);
+        symbolInfos.insert_or_assign(symbolOp.getResult(),
+                                     SymbolInfo(symbolOp));
+      } else if (auto bindOp =
+                     llvm::dyn_cast<Torch::BindSymbolicShapeOp>(childOp)) {
+        cleanupOpList.push_back(bindOp);
+        if (!isEligibleBinding(bindOp))
+          return; // Ineligible bindings are only erased, never processed.
+        auto torchType =
+            llvm::cast<Torch::ValueTensorType>(bindOp.getOperand().getType());
+        auto builtinType = llvm::dyn_cast_or_null<RankedTensorType>(
+            typeConverter.convertType(torchType));
+        if (!builtinType) {
+          emitError(childOp->getLoc())
+              << "cannot convert torch type to builtin: " << torchType;
+          return signalPassFailure(); // NOTE: exits this callback only.
+        }
+        bindings.push_back(TensorBinding{
+            /*bindOp=*/childOp,
+            /*symbols=*/bindOp.getShapeSymbols(),
+            /*annotatesValue=*/bindOp.getOperand(),
+            /*torchType=*/torchType,
+            /*builtinType=*/builtinType,
+            /*shapeMap=*/bindOp.getShapeExpressions().getAffineMap()});
+      }
+    });
+
+    // For every tensor value of interest, convert to a builtin tensor type and
+    // back, RAUW'ing the result. This will meet the eventual final conversion
+    // with additional graph forking.
+    for (auto &binding : bindings) {
+      binding.prepare();
+    }
+
+    // Find all associations to a single symbol and set up the roots.
+    for (auto &binding : bindings) {
+      binding.associateEqualityDims(symbolInfos);
+    }
+
+    // Materialize all dimension expressions and constraints.
+    for (auto &binding : bindings) {
+      binding.materializeDims(symbolInfos);
+    }
+
+    // Now that all is known, insert tie shape.
+    for (auto &binding : bindings) {
+      binding.tieShape(symbolInfos);
+    }
+
+    // Erase all found ops, in the reverse of the order they were visited.
+    for (auto *op : llvm::reverse(cleanupOpList)) {
+      op->erase();
+    }
+  }
+};
+
+} // namespace
+
+} // namespace mlir::iree_compiler::TorchInput
diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
index 1db4085..4e48784 100644
--- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
@@ -34,6 +34,7 @@
   HDRS
     "Passes.h"
   SRCS
+    "BindSymbolicShapes.cpp"
     "BitCastQuantTensor.cpp"
     "ConvertTMTensorToLinalgExt.cpp"
     "FuncConversion.cpp"
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
index 00ab1a4..a0682b5 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp
@@ -36,6 +36,10 @@
   // model) and those constants get somewhat obscured by TorchToArith.
   llvm::ArrayRef<std::string> emptyArrayRef;
 
+  // Dynamic shape bindings add a lot of structure to the IR which we prefer to
+  // leverage and eliminate prior to any other activity, so do this first.
+  pm.addNestedPass<func::FuncOp>(createBindSymbolicShapesPass());
+
   if (options.strictSymbolicShapes) {
     pm.addNestedPass<func::FuncOp>(createSetStrictSymbolicShapesPass());
     // Run canonicalization in case any previously non-strict dynamic code can
diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td
index 91b7792..251cbb2 100644
--- a/compiler/plugins/input/Torch/InputConversion/Passes.td
+++ b/compiler/plugins/input/Torch/InputConversion/Passes.td
@@ -9,6 +9,11 @@
 
 include "mlir/Pass/PassBase.td"
 
+def BindSymbolicShapesPass :
+    InterfacePass<"torch-iree-bind-symbolic-shapes", "mlir::FunctionOpInterface"> {
+  let summary = "Process torch dynamic shape bindings into IREE analyzable forms";
+}
+
 def BitCastQuantTensorPass :
     InterfacePass<"torch-iree-bitcast-quant-tensor", "mlir::FunctionOpInterface"> {
   let summary = "Bitcasts i8 packed tensors of sub-byte types to the actual bit width";
diff --git a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt
index 6f86276..cabc6b2 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt
@@ -6,6 +6,7 @@
     "assume_strict_symbols.mlir"
     "auto_input_conversion.mlir"
     "attention.mlir"
+    "bind_symbolic_shapes.mlir"
     "bitcast_quant_tensor.mlir"
     "func_conversion.mlir"
     "func_conversion_invalid.mlir"
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
new file mode 100644
index 0000000..e3d6061
--- /dev/null
+++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
@@ -0,0 +1,178 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(torch-iree-bind-symbolic-shapes))" --split-input-file --verify-diagnostics %s | FileCheck %s
+
+// This example was captured from a program which has a dynamic batch size and
+// tiled inner dim on one of the arguments, causing a symbolic relationship on
+// the second dimension.
+// CHECK-LABEL: @basic_example
+module @basic_example {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
+    // CHECK-DAG: %[[ARG1_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+    // CHECK-DAG: %[[ARG0_ANCHOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+    // CHECK-DAG: %[[POS0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[POS1:.*]] = arith.constant 1 : index
+    // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0_ANCHOR]], %[[POS0]] :
+    // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0_ANCHOR]], %[[POS1]] :
+    // CHECK: %[[ARG0_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
+    // CHECK: %[[ARG0_DIM0_RANGE:.*]] = util.assume.range %[[ARG0_DIM0_NARROW]] in [1, 1024] : index
+    // CHECK: %[[ARG0_DIM1_NARROW:.*]] = util.assume.narrow %[[DIM1]] : index to i32
+    // CHECK: %[[ARG0_DIM1_RANGE:.*]] = util.assume.range %[[ARG0_DIM1_NARROW]] in [1, 1024] : index
+    // CHECK: %[[ARG0_TIE:.*]] = flow.tensor.tie_shape %[[ARG0_ANCHOR]] : tensor<?x?xf32>{%[[ARG0_DIM0_RANGE]], %[[ARG0_DIM1_RANGE]]}
+    // CHECK: %[[ARG0_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG0_TIE]]
+    // CHECK: %[[ARG1_DIM0_NARROW:.*]] = util.assume.narrow %[[DIM0]] : index to i32
+    // CHECK: %[[ARG1_DIM0_RANGE:.*]] = util.assume.range %[[ARG1_DIM0_NARROW]] in [1, 1024]
+    // CHECK: %[[MULTIPLIER0:.*]] = arith.constant 2 : index
+    // CHECK: %[[ARG1_DIM1:.*]] = arith.muli %[[DIM1]], %[[MULTIPLIER0]]
+    // CHECK: %[[ARG1_DIM1_NARROW:.*]] = util.assume.narrow %[[ARG1_DIM1]] : index to i32
+    // CHECK: %[[ARG1_DIM1_RANGE:.*]] = util.assume.range %[[ARG1_DIM1_NARROW]] in [2, 2048] : index
+    // CHECK: %[[ARG1_DIM1_DIV:.*]] = util.assume.divisible %[[ARG1_DIM1_RANGE]] by 2
+    // CHECK: %[[ARG1_TIE:.*]] = flow.tensor.tie_shape %[[ARG1_ANCHOR]] : tensor<?x?xf32>{%[[ARG1_DIM0_RANGE]], %[[ARG1_DIM1_DIV]]}
+    // CHECK: %[[ARG1_EXPORT:.*]] = torch_c.from_builtin_tensor %[[ARG1_TIE]]
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+    %2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    %int1 = torch.constant.int 1
+    %int2 = torch.constant.int 2
+    %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    %int1_0 = torch.constant.int 1
+    %5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    return %5 : !torch.vtensor<[?,?],f32>
+  }
+}
+
+// -----
+// This example was captured from a torch program that used a symbol that did
+// not correspond to any dimension (being used in an expression as part of
+// distinct dimensions). This exercises a special case in the pass for deferring
+// to runtime resolution of the dim.
+// We just verify that the vital information has been captured.
+// CHECK-LABEL: @unbacked_symbol
+module @unbacked_symbol {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+    // CHECK: util.assume.narrow
+    // CHECK: util.assume.range{{.*}} [1, 1024]
+    // CHECK: util.assume.narrow
+    // CHECK: util.assume.range{{.*}} [2, 2048]
+    // CHECK: util.assume.divisible{{.*}} by 2
+    // CHECK: tie_shape
+    // CHECK: util.assume.narrow
+    // CHECK: util.assume.range{{.*}} [1, 1024]
+    // CHECK: util.assume.narrow
+    // CHECK: util.assume.range{{.*}} [4, 4096]
+    // CHECK: util.assume.divisible{{.*}} by 4
+    // CHECK: tie_shape
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "2*s4" {min_val = 0, max_val = 2048} : !torch.int
+    %2 = torch.symbolic_int "4*s4" {min_val = 0, max_val = 4096} : !torch.int
+    %3 = torch.symbolic_int "s4" {min_val = 2, max_val = 1024} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
+    %int1 = torch.constant.int 1
+    %int2 = torch.constant.int 2
+    %4 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+    %5 = torch.aten.repeat %arg0, %4 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %5, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
+    %int1_0 = torch.constant.int 1
+    %6 = torch.aten.add.Tensor %5, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %6, [%0, %3], affine_map<()[s0, s1] -> (s0, s1 * 4)> : !torch.vtensor<[?,?],f32>
+    return %6 : !torch.vtensor<[?,?],f32>
+  }
+}
+
+// -----
+// CHECK-LABEL: @all_bindings_dropped
+module @all_bindings_dropped {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+    // CHECK-NOT: torch.symbolic_int
+    // CHECK-NOT: torch.bind_symbolic_shape
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+    %2 = torch.symbolic_int "2*s1" {min_val = 0, max_val = 2048} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    %int1 = torch.constant.int 1
+    %int2 = torch.constant.int 2
+    %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+    %4 = torch.aten.repeat %arg0, %3 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    %int1_0 = torch.constant.int 1
+    %5 = torch.aten.add.Tensor %4, %arg1, %int1_0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %5, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 2)> : !torch.vtensor<[?,?],f32>
+    return %5 : !torch.vtensor<[?,?],f32>
+  }
+}
+
+// -----
+// CHECK-LABEL: @add_expr
+module @add_expr {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+    // CHECK: addi
+    // CHECK-NOT: divisible
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 + 2)> : !torch.vtensor<[?,?],f32>
+    return
+  }
+}
+
+// -----
+// CHECK-LABEL: @mod_expr
+module @mod_expr {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+    // CHECK: remui
+    // CHECK-NOT: divisible
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 mod 2)> : !torch.vtensor<[?,?],f32>
+    return
+  }
+}
+
+// -----
+// CHECK-LABEL: @floordiv_expr
+module @floordiv_expr {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+    // CHECK: divui
+    // CHECK-NOT: divisible
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 floordiv 2)> : !torch.vtensor<[?,?],f32>
+    return
+  }
+}
+
+// -----
+// Verifies that unsupported dim expressions warn (and do not assert).
+// CHECK-LABEL: @unsupported_non_symbolic
+module @unsupported_non_symbolic {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 1024} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 1024} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    // expected-warning@+1 {{Symbolic shape expression not supported: d0}}
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<(d0)[s0, s1] -> (s0, s1 + d0)> : !torch.vtensor<[?,?],f32>
+    return
+  }
+}
+
+// -----
+// Torch uses high values to signal unbounded ranges. Ensure they are
+// suppressed.
+// CHECK-LABEL: @torch_unbounded_max_range
+module @torch_unbounded_max_range {
+  func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
+    // CHECK-NOT: util.assume.range
+    %0 = torch.symbolic_int "s0" {min_val = 0, max_val = 4611686018427387903} : !torch.int
+    %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 9223372036854775806} : !torch.int
+    torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32>
+    return
+  }
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index b6466d6..881d8d6 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -458,6 +458,80 @@
 
 let opDocGroup = OpGroupCompilerHintOps in {
 
+def Util_AssumeDivisibleOp :
+    Util_PureOp<"assume.divisible", [SameOperandsAndResultType]> {
+  let summary = "Memorializes knowledge that an index/integer value is divisible by some constant.";
+
+  let arguments = (ins
+    Util_Range:$operand,  // Value the divisibility assumption is attached to.
+    Util_IndexAttr:$divisor  // Constant that $operand is assumed to be divisible by.
+  );
+  let results = (outs
+    Util_Range:$result  // Same type as $operand (SameOperandsAndResultType).
+  );
+  let assemblyFormat = [{
+    $operand `by` $divisor attr-dict `:` type($operand)
+  }];  // Example: util.assume.divisible %c1 by 2 : i32
+  let builders = [
+    OpBuilder<(ins
+      "Value":$operand,
+      "uint64_t":$divisor
+    ),  // Convenience builder: wraps the raw divisor in an index-typed IntegerAttr.
+    [{
+      IntegerAttr divisorAttr = $_builder.getIntegerAttr(
+          $_builder.getIndexType(), divisor);
+      build($_builder, $_state, operand.getType(), operand, divisorAttr);
+    }]>,
+  ];
+}
+
+def Util_AssumeNarrowOp :
+    Util_PureOp<"assume.narrow", [SameOperandsAndResultType]> {
+  let summary = "Memorializes knowledge that an index/integer value can be narrowed to a type.";
+
+  let arguments = (ins
+    Util_Range:$operand,  // Value the narrowing assumption is attached to.
+    TypeAttr:$narrow_type  // Type that $operand is assumed to be narrowable to.
+  );
+  let results = (outs
+    Util_Range:$result  // Same type as $operand (SameOperandsAndResultType); not the narrow type.
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` $narrow_type
+  }];  // Example: util.assume.narrow %c1 : i32 to i8
+}
+
+def Util_AssumeRangeOp :
+    Util_PureOp<"assume.range", [SameOperandsAndResultType]> {
+  let summary = "Memorializes knowledge that an index/integer value is always within some range.";
+
+  let arguments = (ins
+    Util_Range:$operand,  // Value the range assumption is attached to.
+    Util_IndexAttr:$min_value,  // Lower bound of the assumed range.
+    Util_IndexAttr:$max_value  // Upper bound of the assumed range.
+  );
+  let results = (outs
+    Util_Range:$result  // Same type as $operand (SameOperandsAndResultType).
+  );
+  let assemblyFormat = [{
+    $operand `in` ` ` `[` $min_value `,` $max_value `]` `:` type($operand) attr-dict
+  }];  // Example: util.assume.range %c1 in [2, 20] : i32 (attr-dict trails the type here).
+  let builders = [
+    OpBuilder<(ins
+      "Value":$operand,
+      "uint64_t":$minValue,
+      "uint64_t":$maxValue
+    ),  // Convenience builder: wraps both raw bounds in index-typed IntegerAttrs.
+    [{
+      IntegerAttr minAttr = $_builder.getIntegerAttr(
+          $_builder.getIndexType(), minValue);
+      IntegerAttr maxAttr = $_builder.getIntegerAttr(
+          $_builder.getIndexType(), maxValue);
+      build($_builder, $_state, operand.getType(), operand, minAttr, maxAttr);
+    }]>,
+  ];
+}
+
 def Util_OptimizationBarrierOp : Util_Op<"optimization_barrier", [
   SameOperandsAndResultType,
 ]> {
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
index ff7cefd..a6f072c 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp
@@ -19,9 +19,16 @@
   void runOnOperation() override {
     // We can't use patterns and applyPatternsAndFoldGreedily because that
     // automatically does canonicalization.
-    getOperation()->walk([&](IREE::Util::OptimizationBarrierOp op) {
-      op.replaceAllUsesWith(op.getOperands());
-      op.erase();
+    getOperation()->walk([&](Operation *candidate) {
+      // Barrier and assume.* hints are pass-throughs: each result stands in
+      // for the operand at the same position, so all are dropped alike.
+      const bool isCompilerHint =
+          isa<IREE::Util::OptimizationBarrierOp, IREE::Util::AssumeDivisibleOp,
+              IREE::Util::AssumeRangeOp, IREE::Util::AssumeNarrowOp>(candidate);
+      if (!isCompilerHint)
+        return;
+      candidate->replaceAllUsesWith(candidate->getOperands());
+      candidate->erase();
     });
   }
 };
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
index 717d2bf..c0db60a 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/drop_compiler_hints.mlir
@@ -73,3 +73,33 @@
     }
   }
 }
+
+// -----
+
+// CHECK-LABEL: @assume.divisible
+util.func @assume.divisible() -> i32 {
+  // CHECK-NOT: util.assume.divisible
+  %c1 = arith.constant 12 : i32  // Note: the SSA name says c1 but the value is 12.
+  %0 = util.assume.divisible %c1 by 2 : i32  // Hint under test; must be absent from the output.
+  util.return %0 : i32  // Gives the hint result a use that has to be rewired.
+}
+
+// -----
+
+// CHECK-LABEL: @assume.narrow
+util.func @assume.narrow() -> i32 {
+  // CHECK-NOT: util.assume.narrow
+  %c1 = arith.constant 12 : i32  // Note: the SSA name says c1 but the value is 12.
+  %0 = util.assume.narrow %c1 : i32 to i8  // Hint under test; the result stays i32, only the hint names i8.
+  util.return %0 : i32  // Gives the hint result a use that has to be rewired.
+}
+
+// -----
+
+// CHECK-LABEL: @assume.range
+util.func @assume.range() -> i32 {
+  // CHECK-NOT: util.assume.range
+  %c1 = arith.constant 12 : i32  // Note: the SSA name says c1 but the value is 12.
+  %0 = util.assume.range %c1 in [2, 20] : i32  // Hint under test; must be absent from the output.
+  util.return %0 : i32  // Gives the hint result a use that has to be rewired.
+}