Allow dynamic dimensions during folding of `tensor.expand_shape/collapse_shape` into `flow.dispatch.tensor.load/store`. (#18873)

This also cleans up the implementation of these patterns, replacing templated
code that was hard to read and maintain with separate patterns for
`tensor.collapse_shape` and `tensor.expand_shape`.
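
For example (a sketch adapted from the new `fold_collapse_into_loads_dynamic`
test below, with SSA names shortened), collapsing a dynamic dimension now folds
into the binding and the load by recomputing the binding's dynamic dims with an
`affine.apply`:

    %d = hal.interface.constant.load ... ordinal(0) : index
    %in = hal.interface.binding.subspan ... : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%d}
    %t = flow.dispatch.tensor.load %in, offsets = [0, 0, 0], sizes = [2, %d, 32], strides = [1, 1, 1]
        : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%d} -> tensor<2x?x32xf32>
    %c = tensor.collapse_shape %t [[0, 1], [2]] : tensor<2x?x32xf32> into tensor<?x32xf32>

becomes:

    %d = hal.interface.constant.load ... ordinal(0) : index
    %collapsed = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%d]
    %in = hal.interface.binding.subspan ... : !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%collapsed}
    %c = flow.dispatch.tensor.load %in, offsets = [0, 0], sizes = [%collapsed, 32], strides = [1, 1]
        : !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%collapsed} -> tensor<?x32xf32>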

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
index 50ca569..6f1cc19 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/canonicalize_interface_load_store.mlir
@@ -71,26 +71,36 @@
 // -----
 
 #pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
+  #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-// CHECK-LABEL: func.func @dont_fold_dynamic_reshape()
-func.func @dont_fold_dynamic_reshape() {
+func.func @fold_dynamic_reshape() {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %dim0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
   %dim1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
   %dim2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
   %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<readonly:tensor<?x?x96xf32>>{%dim0, %dim1}
-  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
+  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) : !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
   %3 = flow.dispatch.tensor.load %1, offsets=[0, 0, 0], sizes =[%dim0, %dim1, 96], strides=[1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<?x?x96xf32>>{%dim0, %dim1} -> tensor<?x?x96xf32>
-  // CHECK: tensor.collapse_shape
-  // CHECK: tensor.expand_shape
   %4 = tensor.collapse_shape %3 [[0, 1], [2]] : tensor<?x?x96xf32> into tensor<?x96xf32>
   %dyn = tensor.dim %4, %c0 : tensor<?x96xf32>
   %5 = tensor.expand_shape %4 [[0], [1, 2]] output_shape [%dyn, 12, 8] : tensor<?x96xf32> into tensor<?x12x8xf32>
-  flow.dispatch.tensor.store %5, %2, offsets = [%c0, %c0, %c0], sizes = [%c1, 12, 8], strides = [%c1, %c1, %c1] : tensor<?x12x8xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
+  flow.dispatch.tensor.store %5, %2, offsets = [0, 0, 0], sizes = [%dim2, 12, 8], strides = [1, 1, 1] : tensor<?x12x8xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x12x8xf32>>{%dim2}
   return
 }
+//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+//       CHECK: func.func @fold_dynamic_reshape()
+//   CHECK-DAG:   %[[CST0:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(0)
+//   CHECK-DAG:   %[[CST1:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(1)
+//   CHECK-DAG:   %[[CST2:.+]] = hal.interface.constant.load layout(#{{.+}}) ordinal(2)
+//       CHECK:   %[[COLLAPSED:.+]] = affine.apply #[[MAP]]()[%[[CST0]], %[[CST1]]]
+//       CHECK:   %[[IN_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(0) : !flow.dispatch.tensor<readonly:tensor<?x96xf32>>{%[[COLLAPSED]]}
+//       CHECK:   %[[OUT_BINDING:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       binding(1) : !flow.dispatch.tensor<writeonly:tensor<?x96xf32>>{%[[CST2]]}
+//       CHECK:   %[[IN:.+]] = flow.dispatch.tensor.load %[[IN_BINDING]]
+//       CHECK:   flow.dispatch.tensor.store %[[IN]], %[[OUT_BINDING]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
index 7dd745e..fc9e85e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-propagate-reshapes-by-expansion))" --split-input-file %s --mlir-print-local-scope | FileCheck %s
 
 func.func @reshape_and_lowering_config(%src: tensor<3x4xf16>, %dest: tensor<12xf16>, %dest2: tensor<12xf16>) -> tensor<12xf16> {
   %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<3x4xf16> into tensor<12xf16>
@@ -14,3 +14,75 @@
 //       CHECK:   linalg.copy
 //  CHECK-SAME:     lowering_config = #iree_gpu.derived_thread_config
 //  CHECK-SAME:     ins(%[[COLLAPSE]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+    #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>
+func.func @fold_collapse_into_loads_dynamic() -> tensor<?x32xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0}
+  %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0} -> tensor<2x?x32xf32>
+  %3 = tensor.collapse_shape %2 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor<?x32xf32>
+  return %3 : tensor<?x32xf32>
+}
+// CHECK-LABEL: func @fold_collapse_into_loads_dynamic()
+//       CHECK:   %[[CONST:.+]] = hal.interface.constant.load
+//       CHECK:   %[[SHAPE:.+]] = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%[[CONST]]]
+//       CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%[[SHAPE]]}
+//       CHECK:   %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]]
+//  CHECK-SAME:       offsets = [0, 0], sizes = [%[[SHAPE]], 32], strides = [1, 1]
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<?x32xf32>>{%[[SHAPE]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 2, bindings = [
+    #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">], flags = Indirect>
+func.func @fold_expand_into_loads_dynamic() -> tensor<2x?x16x32xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0}
+  %2 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [2, %0, 32], strides = [1, 1, 1]
+      : !flow.dispatch.tensor<readonly:tensor<2x?x32xf32>>{%0} -> tensor<2x?x32xf32>
+  %3 = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%0]
+  %4 = tensor.expand_shape %2 [[0], [1, 2], [3]] output_shape [2, %3, 16, 32] : tensor<2x?x32xf32> into tensor<2x?x16x32xf32>
+  return %4 : tensor<2x?x16x32xf32>
+}
+// CHECK-LABEL: func @fold_expand_into_loads_dynamic()
+//   CHECK-DAG:   %[[C16:.+]] = arith.constant 16 : index
+//   CHECK-DAG:   %[[CONST:.+]] = hal.interface.constant.load
+//       CHECK:   %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C16]]
+//       CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<2x?x16x32xf32>>{%[[SHAPE]]}
+//       CHECK:   %[[LOAD:.+]] = flow.dispatch.tensor.load %[[SUBSPAN]]
+//  CHECK-SAME:       offsets = [0, 0, 0, 0], sizes = [2, %[[SHAPE]], 16, 32], strides = [1, 1, 1, 1]
+//  CHECK-SAME:       !flow.dispatch.tensor<readonly:tensor<2x?x16x32xf32>>{%[[SHAPE]]}
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<constants = 1, bindings = [
+    #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>
+func.func @fold_collapse_into_stores_dynamic(%arg0 : tensor<2x?x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0)
+      flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<?x32xf32>>{%0}
+  %2 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x?x32xf32> into tensor<?x32xf32>
+  flow.dispatch.tensor.store %2, %1, offsets = [0, 0], sizes = [%0, 32], strides = [1, 1]
+      : tensor<?x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x32xf32>>{%0}
+  return
+}
+// CHECK-LABEL: func @fold_collapse_into_stores_dynamic(
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//       CHECK:   %[[CONST:.+]] = hal.interface.constant.load
+//       CHECK:   %[[SHAPE:.+]] = arith.divui %[[CONST]], %[[C2]]
+//       CHECK:   %[[SUBSPAN:.+]] = hal.interface.binding.subspan
+//  CHECK-SAME:       !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}
+//       CHECK:   flow.dispatch.tensor.store %{{.+}}, %[[SUBSPAN]]
+//  CHECK-SAME:       offsets = [0, 0, 0], sizes = [2, %[[SHAPE]], 32], strides = [1, 1, 1]
+//  CHECK-SAME:       !flow.dispatch.tensor<writeonly:tensor<2x?x32xf32>>{%[[SHAPE]]}
diff --git a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
index 0b8c49c..6d0d052 100644
--- a/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Transforms/Transforms.cpp
@@ -36,20 +36,29 @@
 
 namespace mlir::iree_compiler {
 
-static bool isAllConstantValue(SmallVector<OpFoldResult> ofrs, int64_t v) {
+static bool isAllConstantValue(ArrayRef<OpFoldResult> ofrs, int64_t v) {
   return llvm::all_of(
       ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, v); });
 }
 
-static bool isFullSlice(SmallVector<OpFoldResult> mixedOffsets,
-                        SmallVector<OpFoldResult> mixedSizes,
-                        SmallVector<OpFoldResult> mixedStrides,
-                        IREE::Flow::DispatchTensorType tensorType) {
-  std::optional<SmallVector<int64_t>> constSizes =
-      getConstantIntValues(mixedSizes);
+static bool isFullSlice(ArrayRef<OpFoldResult> mixedOffsets,
+                        ArrayRef<OpFoldResult> mixedSizes,
+                        ArrayRef<OpFoldResult> mixedStrides,
+                        IREE::Flow::DispatchTensorType tensorType,
+                        ValueRange dynamicDims) {
+  OpBuilder builder(tensorType.getContext());
+  SmallVector<OpFoldResult> mixedTensorShape =
+      mlir::getMixedValues(tensorType.getShape(), dynamicDims, builder);
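+  // Compare mixed (static + dynamic) sizes so that a slice whose dynamic
+  // sizes are the same SSA values as the tensor's dynamic dims also counts
+  // as a full slice.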
   return isAllConstantValue(mixedOffsets, 0) &&
-         isAllConstantValue(mixedStrides, 1) && constSizes &&
-         llvm::equal(tensorType.getShape(), *constSizes);
+         isAllConstantValue(mixedStrides, 1) && mixedTensorShape == mixedSizes;
+}
+
+static bool isFullSlice(OffsetSizeAndStrideOpInterface sliceLoadStoreOp,
+                        IREE::Flow::DispatchTensorType tensorType,
+                        ValueRange dynamicDims) {
+  return isFullSlice(
+      sliceLoadStoreOp.getMixedOffsets(), sliceLoadStoreOp.getMixedSizes(),
+      sliceLoadStoreOp.getMixedStrides(), tensorType, dynamicDims);
 }
 
 static bool sliceFilter(Operation *op, ValueRange nonIndexComputationOperands,
@@ -546,14 +555,29 @@
 
 namespace {
 
-// TODO(antigainst): enable dynamic shape support once they are needed.
-template <typename TensorReshapeOp>
-static std::optional<Value> getStaticReshapeOpSrc(TensorReshapeOp reshapeOp) {
-  auto reshapeSrcType = llvm::cast<ShapedType>(reshapeOp.getSrc().getType());
-  auto reshapeDstType = llvm::cast<ShapedType>(reshapeOp.getType());
-  if (!reshapeSrcType.hasStaticShape() || !reshapeDstType.hasStaticShape())
-    return std::nullopt;
-  return reshapeOp.getSrc();
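+/// Returns the mixed (static/dynamic) shape obtained by collapsing
+/// `expandedType` according to `reassociations`: the size of each collapsed
+/// dimension is the product of the sizes in its reassociation group, built as
+/// a folded affine.apply over the expanded sizes.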
+static SmallVector<OpFoldResult>
+inferCollapsedShape(RewriterBase &rewriter, Location loc,
+                    RankedTensorType expandedType,
+                    ArrayRef<ReassociationIndices> reassociations,
+                    ValueRange expandedDynamicDims) {
+  ArrayRef<int64_t> expandedStaticShape = expandedType.getShape();
+  SmallVector<OpFoldResult> expandedMixedShape =
+      mlir::getMixedValues(expandedStaticShape, expandedDynamicDims, rewriter);
+  SmallVector<OpFoldResult> collapsedShape;
+  unsigned expandedShapeDim = 0;
+  for (auto reassociation : reassociations) {
+    AffineExpr mulExpr = rewriter.getAffineSymbolExpr(0);
+    for (auto i : llvm::seq<unsigned>(1, reassociation.size())) {
+      mulExpr = mulExpr * rewriter.getAffineSymbolExpr(i);
+    }
+    auto collapsedDim = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, mulExpr,
+        ArrayRef(expandedMixedShape)
+            .slice(expandedShapeDim, reassociation.size()));
+    collapsedShape.push_back(collapsedDim);
+    expandedShapeDim += reassociation.size();
+  }
+  return collapsedShape;
 }
 
-/// Folds tensor.expand/collapse_shape into the source
+/// Folds tensor.collapse_shape into the source
@@ -576,35 +600,38 @@
 ///       !flow.dispatch.tensor<readonly:tensor<864xf32>>
 ///   %0 = flow.dispatch.tensor.load %subspan :
 ///       !flow.dispatch.tensor<readonly:tensor<864xf32>> -> tensor<864xf32>
-template <typename TensorReshapeOp>
-struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+struct FoldCollapseShapeIntoInterfaceTensorLoad
+    : OpRewritePattern<tensor::CollapseShapeOp> {
+  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    std::optional<Value> reshapeSrc =
-        getStaticReshapeOpSrc<TensorReshapeOp>(reshapeOp);
-    if (!reshapeSrc)
-      return failure();
-
-    auto loadOp =
-        reshapeSrc->template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+    Value reshapeSrc = reshapeOp.getSrc();
+    auto reshapeSrcType = cast<RankedTensorType>(reshapeSrc.getType());
+    auto loadOp = reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
     if (!loadOp)
       return failure();
 
     // Make sure we are loading the full incoming subspan. Otherwise we cannot
     // simply adjust the subspan's resultant type later.
-    if (!isFullSlice(loadOp.getMixedOffsets(), loadOp.getMixedSizes(),
-                     loadOp.getMixedStrides(), loadOp.getSourceType())) {
+    if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
       return failure();
     }
 
-    auto subspanOp =
-        loadOp.getSource()
-            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    auto subspanOp = loadOp.getSource()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!subspanOp)
       return failure();
-    assert(subspanOp.getDynamicDims().empty());
+
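+    // Compute the collapsed dynamic dims right before the subspan op so that
+    // they dominate the rebuilt binding.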
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+    SmallVector<OpFoldResult> collapsedShape = inferCollapsedShape(
+        rewriter, subspanOp.getLoc(), reshapeSrcType,
+        reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims());
+    SmallVector<int64_t> collapsedStaticShape;
+    SmallVector<Value> collapsedDynamicShape;
+    dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape,
+                               collapsedStaticShape);
 
     auto tensorAccess =
         llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
@@ -615,12 +642,111 @@
     Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
         subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
         subspanOp.getBinding(), subspanOp.getByteOffset(),
-        subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
+        collapsedDynamicShape, subspanOp.getAlignmentAttr(),
         subspanOp.getDescriptorFlagsAttr());
 
+    rewriter.setInsertionPoint(reshapeOp);
     rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
         reshapeOp, reshapeOp.getResultType(), newSubspanOp,
-        loadOp.getSourceDims());
+        collapsedDynamicShape);
+
+    return success();
+  }
+};
+
+/// Folds tensor.expand_shape into the source
+/// hal.interface.binding.subspan.
+///
+/// For example, this matches the following pattern:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor<readonly:tensor<864xf32>>
+///   %tensor = flow.dispatch.tensor.load %subspan :
+///       !flow.dispatch.tensor<readonly:tensor<864xf32>> -> tensor<864xf32>
+///   %0 = tensor.expand_shape %tensor [[0, 1, 2, 3]]
+///       : tensor<864xf32> into tensor<3x3x1x96xf32>
+///
+/// And turns it into:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>>
+///   %0 = flow.dispatch.tensor.load %subspan :
+///       !flow.dispatch.tensor<readonly:tensor<3x3x1x96xf32>> ->
+///       tensor<3x3x1x96xf32>
+struct FoldExpandShapeIntoInterfaceTensorLoad
+    : OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    Value reshapeSrc = reshapeOp.getSrc();
+    auto loadOp = reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
+    if (!loadOp) {
+      return failure();
+    }
+
+    // Make sure we are loading the full incoming subspan. Otherwise we cannot
+    // simply adjust the subspan's resultant type later.
+    if (!isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
+      return failure();
+    }
+
+    // In the corner case where the expand_shape is the source of a store,
+    // don't fold with the load; fold with the store instead to reduce the
+    // dimensionality.
+    if (reshapeOp->hasOneUse()) {
+      if (auto storeOp = dyn_cast<IREE::Flow::DispatchTensorStoreOp>(
+              *reshapeOp->getUsers().begin())) {
+        if (isFullSlice(storeOp, storeOp.getTargetType(),
+                        storeOp.getTargetDims())) {
+          return rewriter.notifyMatchFailure(reshapeOp,
+                                             "fold with store instead");
+        }
+      }
+    }
+
+    auto subspanOp = loadOp.getSource()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    if (!subspanOp)
+      return failure();
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+
+    auto currDynamicDims = subspanOp.getDynamicDims();
+    auto currStaticDims = loadOp.getType().getShape();
+    auto currOfrDynamicDims =
+        mlir::getMixedValues(currStaticDims, currDynamicDims, rewriter);
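+    // Infer the expanded mixed shape from the load's current shape and the
+    // expand_shape's reassociation indices.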
+    std::optional<SmallVector<OpFoldResult>> expandedDims =
+        mlir::inferExpandShapeOutputShape(
+            rewriter, subspanOp.getLoc(), reshapeOp.getType(),
+            reshapeOp.getReassociationIndices(), currOfrDynamicDims);
+    if (!expandedDims) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "failed to infer expanded shape");
+    }
+
+    auto tensorAccess =
+        llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
+            .getAccess();
+    auto newSubspanType = IREE::Flow::DispatchTensorType::get(
+        tensorAccess, reshapeOp.getResultType());
+
+    SmallVector<Value> expandedDynamicDims;
+    SmallVector<int64_t> expandedStaticDims;
+    dispatchIndexOpFoldResults(expandedDims.value(), expandedDynamicDims,
+                               expandedStaticDims);
+
+    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+        subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicDims,
+        subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+
+    rewriter.setInsertionPoint(reshapeOp);
+    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
+        reshapeOp, reshapeOp.getResultType(), newSubspanOp,
+        expandedDynamicDims);
 
     return success();
   }
@@ -652,8 +778,8 @@
                                 PatternRewriter &rewriter) const override {
     // Make sure we are storing the full incoming subspan. Otherwise we cannot
     // simply adjust the subspan's resultant type later.
-    if (!isFullSlice(storeOp.getMixedOffsets(), storeOp.getMixedSizes(),
-                     storeOp.getMixedStrides(), storeOp.getTargetType())) {
+    if (!isFullSlice(storeOp, storeOp.getTargetType(),
+                     storeOp.getTargetDims())) {
       return failure();
     }
 
@@ -662,38 +788,136 @@
       return failure();
     }
 
-    // Dynamic shapes are currently unsupported.
-    std::optional<Value> reshapeSrc =
-        getStaticReshapeOpSrc<tensor::ExpandShapeOp>(reshapeOp);
-    if (!reshapeSrc)
-      return failure();
+    Value reshapeSrc = reshapeOp.getSrc();
+    // If the source is a `flow.dispatch.tensor.load`, fold with the load
+    // instead to reduce the dimensionality of the problem.
+    if (auto loadOp =
+            reshapeSrc.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>()) {
+      if (isFullSlice(loadOp, loadOp.getSourceType(), loadOp.getSourceDims())) {
+        return rewriter.notifyMatchFailure(
+            storeOp, "fold expand_shape with load instead");
+      }
+    }
 
-    auto subspanOp =
-        storeOp.getTarget()
-            .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    auto subspanOp = storeOp.getTarget()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!subspanOp)
       return failure();
-    assert(subspanOp.getDynamicDims().empty());
+
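+    // Collapse the store target's expanded shape back to the shape of the
+    // reshape source, inserting the computation before the subspan op so the
+    // new dynamic dims dominate it.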
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+    SmallVector<OpFoldResult> collapsedShape = inferCollapsedShape(
+        rewriter, subspanOp.getLoc(), reshapeOp.getResultType(),
+        reshapeOp.getReassociationIndices(), subspanOp.getDynamicDims());
+    SmallVector<int64_t> collapsedStaticShape;
+    SmallVector<Value> collapsedDynamicShape;
+    dispatchIndexOpFoldResults(collapsedShape, collapsedDynamicShape,
+                               collapsedStaticShape);
 
     auto tensorAccess =
         llvm::cast<IREE::Flow::DispatchTensorType>(subspanOp.getType())
             .getAccess();
-    auto newSubspanType = IREE::Flow::DispatchTensorType::get(
-        tensorAccess, reshapeSrc->getType());
+    auto newSubspanType =
+        IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrc.getType());
 
-    Value newSubspanOp;
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointAfter(subspanOp);
-      newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
-          subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
-          subspanOp.getBinding(), subspanOp.getByteOffset(),
-          subspanOp.getDynamicDims(), subspanOp.getAlignmentAttr(),
-          subspanOp.getDescriptorFlagsAttr());
+    Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+        subspanOp.getBinding(), subspanOp.getByteOffset(),
+        collapsedDynamicShape, subspanOp.getAlignmentAttr(),
+        subspanOp.getDescriptorFlagsAttr());
+
+    rewriter.setInsertionPoint(storeOp);
+    rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
+        storeOp, reshapeSrc, newSubspanOp, collapsedDynamicShape);
+
+    return success();
+  }
+};
+
+/// Folds tensor.collapse_shape into the source hal.interface.binding.subspan.
+///
+/// For example, this matches the following pattern:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%dim}
+///   %0 = tensor.collapse_shape %tensor [[0, 1, 2, 3]]
+///       : tensor<3x?x?x96xf32> into tensor<?xf32>
+///   flow.dispatch.tensor.store %0, %subspan :
+///       tensor<?xf32> -> !flow.dispatch.tensor<writeonly:tensor<?xf32>>{%dim}
+///
+/// And turns it into:
+///
+///   %subspan = hal.interface.binding.subspan ... :
+///       !flow.dispatch.tensor<writeonly:tensor<3x?x?x96xf32>>{%d0, %d1}
+///   flow.dispatch.tensor.store %tensor, %subspan :
+///       tensor<3x?x?x96xf32> ->
+///       !flow.dispatch.tensor<writeonly:tensor<3x?x?x96xf32>>{%d0, %d1}
+///
+/// TODO: This pattern handles full slices. The pattern below
+/// (`FoldCollapseShapeIntoTensorInsertSlice`) handles cases where the slice is
+/// not a full slice but requires the shapes to be static, whereas this pattern
+/// handles dynamic shapes as well. Combine the two if possible (it isn't clear
+/// that it is).
+struct FoldCollapseShapeIntoInterfaceTensorStoreFullSlice
+    : OpRewritePattern<IREE::Flow::DispatchTensorStoreOp> {
+  using OpRewritePattern<IREE::Flow::DispatchTensorStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::Flow::DispatchTensorStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    // Make sure we are storing the full incoming subspan. Otherwise we cannot
+    // simply adjust the subspan's resultant type later.
+    if (!isFullSlice(storeOp, storeOp.getTargetType(),
+                     storeOp.getTargetDims())) {
+      return failure();
     }
 
+    auto reshapeOp =
+        storeOp.getValue().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp) {
+      return failure();
+    }
+    auto subspanOp = storeOp.getTarget()
+                         .getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
+    if (!subspanOp)
+      return failure();
+
+    Value reshapeSrc = reshapeOp.getSrc();
+    auto reshapeSrcType = cast<RankedTensorType>(reshapeSrc.getType());
+
+    // Compute the type and dynamic dims of the interface binding.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(subspanOp);
+    auto dynamicDims = subspanOp.getDynamicDims();
+    ArrayRef<int64_t> staticShape = reshapeOp.getType().getShape();
+    SmallVector<OpFoldResult> mixedShape =
+        mlir::getMixedValues(staticShape, dynamicDims, rewriter);
+    std::optional<SmallVector<OpFoldResult>> expandedShape =
+        mlir::inferExpandShapeOutputShape(
+            rewriter, subspanOp.getLoc(), reshapeSrcType,
+            reshapeOp.getReassociationIndices(), mixedShape);
+    if (!expandedShape) {
+      return rewriter.notifyMatchFailure(
+          storeOp, "failed to compute expand shape for interface binding");
+    }
+    SmallVector<int64_t> expandedStaticShape;
+    SmallVector<Value> expandedDynamicShape;
+    dispatchIndexOpFoldResults(*expandedShape, expandedDynamicShape,
+                               expandedStaticShape);
+
+    auto tensorAccess =
+        cast<IREE::Flow::DispatchTensorType>(subspanOp.getType()).getAccess();
+    auto newSubspanType =
+        IREE::Flow::DispatchTensorType::get(tensorAccess, reshapeSrcType);
+
+    auto newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
+        subspanOp.getLoc(), newSubspanType, subspanOp.getLayout(),
+        subspanOp.getBinding(), subspanOp.getByteOffset(), expandedDynamicShape,
+        subspanOp.getAlignmentAttr(), subspanOp.getDescriptorFlagsAttr());
+
+    rewriter.setInsertionPoint(storeOp);
     rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorStoreOp>(
-        storeOp, *reshapeSrc, newSubspanOp, storeOp.getTargetDims());
+        storeOp, reshapeSrc, newSubspanOp, expandedDynamicShape);
 
     return success();
   }
@@ -840,12 +1064,11 @@
 } // namespace
 
 void populateReshapeToInterfaceTensorPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldReshapeIntoInterfaceTensorLoad<tensor::CollapseShapeOp>,
-                  FoldReshapeIntoInterfaceTensorLoad<tensor::ExpandShapeOp>>(
-      patterns.getContext());
-  patterns.insert<FoldExpandShapeIntoInterfaceTensorStore>(
-      patterns.getContext());
-  patterns.insert<FoldCollapseShapeIntoInterfaceTensorStore>(
+  patterns.insert<FoldCollapseShapeIntoInterfaceTensorLoad,
+                  FoldCollapseShapeIntoInterfaceTensorStore,
+                  FoldCollapseShapeIntoInterfaceTensorStoreFullSlice,
+                  FoldExpandShapeIntoInterfaceTensorLoad,
+                  FoldExpandShapeIntoInterfaceTensorStore>(
       patterns.getContext());
 }