[LinalgExt] Scatter fusion by expansion 3/3 (#19588)

Implements fusion with reshapes by expansion for `LinalgExt::ScatterOp`.

See main issue https://github.com/iree-org/iree/issues/19091

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 901288b..6ad28e5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -14,14 +14,50 @@
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
+static bool
+isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
+  for (auto &index : indices) {
+    if (index.size() != 1)
+      return false;
+  }
+  return true;
+};
+
+static SmallVector<ReassociationIndices>
+computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+  SmallVector<ReassociationIndices> reassoc;
+  int64_t dimCount = 0;
+  for (auto &shape : shapeMap) {
+    reassoc.emplace_back(
+        llvm::to_vector(llvm::seq<int64_t>(dimCount, dimCount + shape.size())));
+    dimCount += shape.size();
+  }
+  return reassoc;
+}
+
 namespace {
 
+/// Helper class that supports fusing reshapes with operands when not all of the
+/// shape dims map to the iteration space.
+struct ReshapeOperandInfo {
+  static constexpr int64_t kNoMapping = -1;
+
+  // Original shape of this operand.
+  ArrayRef<int64_t> originalShape;
+
+  // Similar to the results of the operand's `AffineMap` except `kNoMapping` if
+  // that dim doesn't map to the iteration space. For example, the indexed
+  // dimensions in a LinalgExt::ScatterOp.
+  SmallVector<int64_t> operandToIterationSpace;
+};
+
 /// Information needed to expand an operation to fold the reshape with
 /// it.
 class ExpansionInfo {
@@ -30,32 +66,78 @@
   // of the expanded op given the `indexingMap` of the fused operand/result of
   // the op, the `reassocationMaps` of the reshape op and the shape of
   // the expanded op.
-  template <typename OpTy>
-  LogicalResult compute(OpTy op, OpOperand *fusableOpOperand,
-                        ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
-                        PatternRewriter &rewriter);
-  unsigned getOrigOpNumDims() const { return reassociation.size(); }
-  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
-  ReassociationIndicesRef getExpandedDims(unsigned i) const {
-    return reassociation[i];
+  LogicalResult compute(SmallVector<ReshapeOperandInfo> infos,
+                        SmallVector<int64_t> loopRanges,
+                        OpOperand *fusableOpOperand,
+                        ArrayRef<ReassociationIndices> operandReassoc,
+                        ArrayRef<int64_t> expandedShape);
+
+  std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
+                                           RewriterBase &rewriter) {
+    auto shapeMap = this->getShapeMap(operand);
+    auto reassoc = computeReassocFromShapeMap(shapeMap);
+    if (isIdentityReassoc(reassoc)) {
+      return operand->get();
+    }
+    SmallVector<int64_t> flattenedArray;
+    for (auto &shape : shapeMap) {
+      flattenedArray.append(shape.begin(), shape.end());
+    }
+    auto oldType = cast<ShapedType>(operand->get().getType());
+    auto newType =
+        RankedTensorType::get(flattenedArray, oldType.getElementType());
+    if (failed(reshapeLikeShapesAreCompatible(
+            [&](const Twine &msg) {
+              return rewriter.notifyMatchFailure(loc, msg);
+            },
+            oldType.getShape(), newType.getShape(), reassoc,
+            /*isExpandingReshape=*/true))) {
+      return {};
+    }
+    return rewriter.create<tensor::ExpandShapeOp>(loc, newType, operand->get(),
+                                                  reassoc);
+  };
+
+  /// Get the shape map for the operand.
+  SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
+    auto info = reshapeInfos[operand->getOperandNumber()];
+    SmallVector<SmallVector<int64_t>> shapeMap;
+    for (auto [operandIdx, loopIdx] :
+         llvm::enumerate(info.operandToIterationSpace)) {
+      if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+        shapeMap.push_back(
+            SmallVector<int64_t>{info.originalShape[operandIdx]});
+      } else {
+        shapeMap.push_back(loopShapeMap[loopIdx]);
+      }
+    }
+    return shapeMap;
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
-    return expandedShapeMap[i];
+
+  SmallVector<ReassociationIndices> getReassoc(OpOperand *operand) const {
+    auto shapeMap = this->getShapeMap(operand);
+    return computeReassocFromShapeMap(shapeMap);
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+
+  unsigned getOrigNumLoops() const { return loopReassoc.size(); }
+  unsigned getExpandedNumLoops() const { return expandedOpNumDims; }
+  ReassociationIndicesRef getExpandedLoops(unsigned i) const {
+    return loopReassoc[i];
+  }
+  ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
+    return loopShapeMap[i];
+  }
 
 private:
-  /// Reassociation from the dimensions in the original operation to the
-  /// dimension of the expanded operation.
-  SmallVector<ReassociationIndices> reassociation;
+  /// Extent of the iteration space in the original operation.
+  SmallVector<int64_t> loopRanges;
+  SmallVector<ReassociationIndices> loopReassoc;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
-  /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<SmallVector<int64_t>> loopShapeMap;
   unsigned expandedOpNumDims;
+  /// Info about the reassociation and original shape for each operand.
+  SmallVector<ReshapeOperandInfo> reshapeInfos;
 };
 
 class CollapsingInfo {
@@ -109,50 +191,46 @@
 
 } // namespace
 
-template <typename OpTy>
-LogicalResult ExpansionInfo::compute(OpTy op, OpOperand *fusableOpOperand,
-                                     ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
-                                     PatternRewriter &rewriter) {
-  if (reassociationMaps.empty())
+LogicalResult ExpansionInfo::compute(
+    SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
+    OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
+    ArrayRef<int64_t> expandedShape) {
+  if (operandReassoc.empty())
     return failure();
-  AffineMap fusedIndexMap = op.getMatchingIndexingMap(fusableOpOperand);
-  FailureOr<SmallVector<int64_t>> originalLoopRange = op.getStaticLoopRanges();
-  if (failed(originalLoopRange)) {
+
+  int64_t operandNum = fusableOpOperand->getOperandNumber();
+  ReshapeOperandInfo &fusionOperandInfo = infos[operandNum];
+  this->loopShapeMap.clear();
+  this->loopShapeMap.resize(loopRanges.size());
+  for (auto [operandIdx, loopIdx] :
+       llvm::enumerate(fusionOperandInfo.operandToIterationSpace)) {
+    if (loopIdx == ReshapeOperandInfo::kNoMapping) {
+      continue;
+    }
+
+    // Compute the shape map at element `loopIdx`
+    ReassociationIndicesRef indices = operandReassoc[operandIdx];
+    for (auto [dimIdx, shapeIdx] : llvm::enumerate(indices)) {
+      this->loopShapeMap[loopIdx].push_back(expandedShape[shapeIdx]);
+    }
+  }
+
+  // Fill in the remaining elements with `loopRanges`
+  this->expandedOpNumDims = 0;
+  for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
+    if (shapeMap.empty()) {
+      this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
+    }
+    this->expandedOpNumDims += shapeMap.size();
+  }
+
+  if (llvm::all_of(this->loopShapeMap,
+                   [&](auto vec) { return vec.size() == 1; })) {
     return failure();
   }
-  originalLoopExtent.assign(originalLoopRange->begin(),
-                            originalLoopRange->end());
-
-  reassociation.clear();
-  expandedShapeMap.clear();
-  // Compute the number of dimension in the expanded op that correspond to each
-  // dimension of the original op.
-  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
-  expandedShapeMap.resize(fusedIndexMap.getNumDims());
-  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
-    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
-    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
-    numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
-        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
-    expandedShapeMap[pos].assign(shape.begin(), shape.end());
-  }
-  // The remaining dimensions remain the same.
-  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
-    if (expandedShapeMap[i].empty())
-      expandedShapeMap[i] = {originalLoopExtent[i]};
-
-  // Compute reassociation map from the original op to the expanded op.
-  unsigned sum = 0;
-  reassociation.reserve(fusedIndexMap.getNumDims());
-  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
-    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
-    reassociation.emplace_back(seq.begin(), seq.end());
-    sum += numFoldedDim.value();
-  }
-  expandedOpNumDims = sum;
+  this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
+  this->reshapeInfos = std::move(infos);
+  this->loopRanges = std::move(loopRanges);
   return success();
 }
 
@@ -201,6 +279,77 @@
   return success();
 }
 
+static SmallVector<ReshapeOperandInfo>
+getAttentionReshapeInfo(LinalgExt::AttentionOp attentionOp) {
+  return llvm::map_to_vector(
+      attentionOp->getOpOperands(), [&](OpOperand &opOperand) {
+        ReshapeOperandInfo operandInfo;
+        auto operandType = dyn_cast<ShapedType>(opOperand.get().getType());
+        if (!operandType) {
+          assert(
+              attentionOp.getMatchingIndexingMap(&opOperand).getNumResults() ==
+                  0 &&
+              "expected non-shaped type to have no results in indexing map");
+          return operandInfo;
+        }
+
+        operandInfo.originalShape = operandType.getShape();
+        for (auto result :
+             attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
+          operandInfo.operandToIterationSpace.push_back(
+              cast<AffineDimExpr>(result).getPosition());
+        }
+        return operandInfo;
+      });
+}
+
+static SmallVector<ReshapeOperandInfo>
+getScatterReshapeInfo(LinalgExt::ScatterOp scatterOp) {
+  SmallVector<ReshapeOperandInfo> infos;
+  auto rankOfContiguousSlice =
+      scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
+  auto updateRank = scatterOp.getUpdateType().getRank();
+
+  // Operand #0 Updates
+  {
+    ReshapeOperandInfo updateInfo;
+    updateInfo.originalShape = scatterOp.getUpdateType().getShape();
+    llvm::append_range(updateInfo.operandToIterationSpace,
+                       llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+    updateInfo.operandToIterationSpace.append(
+        updateRank - (rankOfContiguousSlice + scatterOp.getBatchRank()),
+        ReshapeOperandInfo::kNoMapping);
+    llvm::append_range(
+        updateInfo.operandToIterationSpace,
+        llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
+    infos.push_back(std::move(updateInfo));
+  }
+
+  // Operand#1 Indices
+  {
+    ReshapeOperandInfo indicesInfo;
+    indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
+    llvm::append_range(indicesInfo.operandToIterationSpace,
+                       llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+    indicesInfo.operandToIterationSpace.push_back(
+        ReshapeOperandInfo::kNoMapping);
+    infos.push_back(std::move(indicesInfo));
+  }
+
+  // Operand #2 Original
+  {
+    ReshapeOperandInfo originalInfo;
+    originalInfo.originalShape = scatterOp.getOriginalType().getShape();
+    originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
+                                                ReshapeOperandInfo::kNoMapping);
+    llvm::append_range(
+        originalInfo.operandToIterationSpace,
+        llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
+    infos.push_back(std::move(originalInfo));
+  }
+  return infos;
+};
+
 static AffineMap
 getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                            const ExpansionInfo &expansionInfo) {
@@ -208,46 +357,17 @@
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned pos = cast<AffineDimExpr>(expr).getPosition();
     auto expandedExprs = llvm::to_vector_of<AffineExpr, 6>(
-        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
+        llvm::map_range(expansionInfo.getExpandedLoops(pos), [&](int64_t v) {
           return builder.getAffineDimExpr(static_cast<unsigned>(v));
         }));
     newExprs.append(expandedExprs.begin(), expandedExprs.end());
   }
-  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
+  return AffineMap::get(expansionInfo.getExpandedNumLoops(),
                         indexingMap.getNumSymbols(), newExprs,
                         builder.getContext());
 }
 
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
-    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
-  }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
-}
-
-static SmallVector<ReassociationIndices>
-getReassociationForExpansion(AffineMap indexingMap,
-                             const ExpansionInfo &expansionInfo) {
-  SmallVector<ReassociationIndices> reassociation;
-  unsigned numReshapeDims = 0;
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
-    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
-        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
-    reassociation.emplace_back(std::move(indices));
-    numReshapeDims += numExpandedDims;
-  }
-  return reassociation;
-}
-
-template <typename OpTy>
-static bool isFusableWithReshapeByDimExpansion(OpTy op,
+static bool isFusableWithReshapeByDimExpansion(AttentionOp op,
                                                OpOperand *fusableOpOperand) {
   // Is fusable only if:
   // - All the indexing maps for operands and results are projected
@@ -256,10 +376,11 @@
   // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes = op.getLoopIteratorTypes();
   AffineMap operandMap = op.getMatchingIndexingMap(fusableOpOperand);
-  return op.hasPureTensorSemantics() &&
-         llvm::all_of(
-             op.getIndexingMapsArray(),
-             [](AffineMap map) { return map.isProjectedPermutation(); }) &&
+  return operandMap && op.hasPureTensorSemantics() &&
+         llvm::all_of(op.getIndexingMapsArray(),
+                      [](AffineMap map) {
+                        return map && map.isProjectedPermutation();
+                      }) &&
          operandMap.getNumResults() > 0;
 }
 
@@ -277,16 +398,13 @@
   RankedTensorType expandedType = isExpanding
                                       ? expandingReshapeOp.getResultType()
                                       : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
-
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          attentionOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
+          getAttentionReshapeInfo(attentionOp),
+          attentionOp.getStaticLoopRanges().value(), fusableOpOperand,
+          isExpanding ? expandingReshapeOp.getReassociationIndices()
+                      : collapsingReshapeOp.getReassociationIndices(),
+          expandedType.getShape())))
     return std::nullopt;
   auto expandedOpIndexingMaps = llvm::to_vector_of<AffineMap, 6>(
       llvm::map_range(attentionOp.getIndexingMapsArray(), [&](AffineMap m) {
@@ -305,54 +423,23 @@
                                                : collapsingReshapeOp.getSrc());
       continue;
     }
-    if (auto opOperandType =
-            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
-      AffineMap indexingMap = attentionOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
-      if (expandedOperandType != opOperand->get().getType()) {
-        // Reshape the operand to get the right type.
-        SmallVector<ReassociationIndices> reassociation =
-            getReassociationForExpansion(indexingMap, expansionInfo);
-        if (failed(reshapeLikeShapesAreCompatible(
-                [&](const Twine &msg) {
-                  return rewriter.notifyMatchFailure(attentionOp, msg);
-                },
-                opOperandType.getShape(), expandedOperandType.getShape(),
-                reassociation,
-                /*isExpandingReshape=*/true)))
-          return std::nullopt;
-        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
-        continue;
-      }
-    }
-    expandedOpOperands.push_back(opOperand->get());
+    // Reshape the operand to get the right type.
+    std::optional<Value> expanded =
+        expansionInfo.getOrCreateExpanded(loc, opOperand, rewriter);
+    if (!expanded)
+      return std::nullopt;
+    expandedOpOperands.push_back(*expanded);
+    continue;
   }
 
   Value output;
   OpOperand &outOperand = attentionOp.getOutputMutable();
 
-  AffineMap indexingMap = attentionOp.getMatchingIndexingMap(&outOperand);
-  auto opOperandType = cast<RankedTensorType>(outOperand.get().getType());
-  RankedTensorType expandedOutputType =
-      getExpandedType(opOperandType, indexingMap, expansionInfo);
-  if (expandedOutputType != outOperand.get().getType()) {
-    SmallVector<ReassociationIndices> reassociation =
-        getReassociationForExpansion(indexingMap, expansionInfo);
-    if (failed(reshapeLikeShapesAreCompatible(
-            [&](const Twine &msg) {
-              return rewriter.notifyMatchFailure(attentionOp, msg);
-            },
-            opOperandType.getShape(), expandedOutputType.getShape(),
-            reassociation,
-            /*isExpandingReshape=*/true)))
-      return std::nullopt;
-    output = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandedOutputType, outOperand.get(), reassociation);
-  } else {
-    output = outOperand.get();
-  }
+  std::optional<Value> maybeOutput =
+      expansionInfo.getOrCreateExpanded(loc, &outOperand, rewriter);
+  if (!maybeOutput)
+    return std::nullopt;
+  output = *maybeOutput;
 
   Value maskOperand;
   if (expandedOpOperands.size() > 4) {
@@ -377,9 +464,7 @@
     int64_t resultNumber = opResult.getResultNumber();
     if (resultTypes[resultNumber] != opResult.getType()) {
       SmallVector<ReassociationIndices> reassociation =
-          getReassociationForExpansion(
-              attentionOp.getIndexingMapsForResults()[resultNumber],
-              expansionInfo);
+          expansionInfo.getReassoc(attentionOp.getTiedOpOperand(opResult));
       resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
           attentionOp.getLoc(), opResult.getType(),
           fusedOp->getResult(resultNumber), reassociation));
@@ -391,6 +476,64 @@
   return resultVals;
 }
 
+static std::optional<Value>
+fuseScatterWithReshapeByExpansion(ScatterOp scatterOp, Operation *reshapeOp,
+                                  OpOperand *fusableOpOperand,
+                                  PatternRewriter &rewriter) {
+  Location loc = scatterOp.getLoc();
+  // Check if reshape is expanding or collapsing.
+  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
+  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
+  bool isExpanding = (expandingReshapeOp != nullptr);
+  RankedTensorType expandedType = isExpanding
+                                      ? expandingReshapeOp.getResultType()
+                                      : collapsingReshapeOp.getSrcType();
+  ExpansionInfo info;
+  if (failed(info.compute(
+          getScatterReshapeInfo(scatterOp),
+          scatterOp.getStaticLoopRanges().value(), fusableOpOperand,
+          isExpanding ? expandingReshapeOp.getReassociationIndices()
+                      : collapsingReshapeOp.getReassociationIndices(),
+          expandedType.getShape()))) {
+    return std::nullopt;
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(scatterOp);
+
+  OpOperand *update = scatterOp.getDpsInputOperand(0);
+  OpOperand *indices = scatterOp.getDpsInputOperand(1);
+  OpOperand *original = scatterOp.getDpsInitOperand(0);
+  auto newUpdates = info.getOrCreateExpanded(loc, update, rewriter).value();
+  auto newIndices = info.getOrCreateExpanded(loc, indices, rewriter).value();
+  auto newOriginal = info.getOrCreateExpanded(loc, original, rewriter).value();
+
+  auto newScatter = rewriter.create<ScatterOp>(
+      loc, newOriginal.getType(), ValueRange{newUpdates, newIndices},
+      ValueRange{newOriginal}, scatterOp.getDimensionMap(),
+      scatterOp.getUniqueIndices());
+  rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
+                              newScatter.getRegion().begin());
+
+  auto originalShapeMap = info.getShapeMap(original);
+  SmallVector<ReassociationIndices> originalReassoc =
+      computeReassocFromShapeMap(originalShapeMap);
+
+  // Collapse back to original shape.
+  if (isIdentityReassoc(originalReassoc)) {
+    return {newScatter.getResult(0)};
+  }
+  auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
+      loc, scatterOp.getOriginalType(), newScatter.getResult(0),
+      originalReassoc);
+
+  return {newCollapse};
+}
+
+//===----------------------------------------------------------------------===//
+// Fuse By Expansion Patterns
+//===----------------------------------------------------------------------===//
+
 namespace {
 
 // Fold attention with its consumer expand_shape op.
@@ -553,6 +696,74 @@
   linalg::ControlDropUnitDims options;
 };
 
+struct FoldScatterWithProducerReshapeByExpansion final
+    : public OpRewritePattern<ScatterOp> {
+  FoldScatterWithProducerReshapeByExpansion(
+      MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<ScatterOp>(context, benefit),
+        controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+  LogicalResult matchAndRewrite(ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    for (OpOperand &opOperand : scatterOp->getOpOperands()) {
+      tensor::CollapseShapeOp reshapeOp =
+          opOperand.get().getDefiningOp<tensor::CollapseShapeOp>();
+      if (!reshapeOp)
+        continue;
+      if (!controlFoldingReshapes(&opOperand))
+        continue;
+
+      std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+          scatterOp, reshapeOp, &opOperand, rewriter);
+      if (!replacementValue)
+        return failure();
+      rewriter.replaceOp(scatterOp, *replacementValue);
+      return success();
+    }
+    return failure();
+  }
+
+  linalg::ControlFusionFn controlFoldingReshapes;
+};
+
+struct FoldScatterWithConsumerReshapeByExpansion final
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  FoldScatterWithConsumerReshapeByExpansion(
+      MLIRContext *context, linalg::ControlFusionFn controlFoldingReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(controlFoldingReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto producerResult = dyn_cast<OpResult>(expandOp.getSrc());
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "source not produced by an operation");
+    }
+
+    auto scatterOp = producerResult.getDefiningOp<LinalgExt::ScatterOp>();
+    if (!scatterOp) {
+      return failure();
+    }
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return failure();
+    }
+
+    std::optional<Value> replacementValue = fuseScatterWithReshapeByExpansion(
+        scatterOp, expandOp, scatterOp.getTiedOpOperand(producerResult),
+        rewriter);
+    if (!replacementValue)
+      return failure();
+    rewriter.replaceOp(scatterOp, *replacementValue);
+    return success();
+  }
+
+  linalg::ControlFusionFn controlFoldingReshapes;
+};
+
 } // namespace
 
 /// Return the `reassociation` indices to use to collapse the operand when the
@@ -773,6 +984,10 @@
       patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldAttentionWithProducerReshapeByExpansion>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldScatterWithProducerReshapeByExpansion>(
+      patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldScatterWithConsumerReshapeByExpansion>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 25e1553..b3fdd90 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -481,3 +481,204 @@
 //  CHECK-SAME:       affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
 //  CHECK-SAME:       ins(%[[ARG2]], %[[ARG1]], %[[COLLAPSED]], %[[ARG3]] :
 //       CHECK:   util.return %[[ATTENTION]]
+
+
+// -----
+
+util.func @scatter_collapse_updates(%arg0: tensor<4x?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0[[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x2x16x4x128xf16>
+  util.return %1 : tensor<?x2x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_updates
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//       CHECK:   %[[INDICES:.+]] = tensor.expand_shape
+//  CHECK-SAME:     tensor<?x1xi32> into tensor<4x?x1xi32>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[ARG0]], %[[INDICES]]
+//  CHECK-SAME:       outs(%[[ARG2]]
+//       CHECK:   util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0[[0, 1], [2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x4x16x4x128xf16>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<10x16x4x128xf16>
+  util.return %1 : tensor<10x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_updates_partial
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//   CHECK-DAG:   %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x1xi32>
+//   CHECK-DAG:   %[[UPDATES:.+]] = tensor.collapse_shape %[[ARG0]] {{.*}} tensor<4x?x2x2x16x4x128xf16> into tensor<4x?x4x16x4x128xf16>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[INDICES]]
+//  CHECK-SAME:       outs(%[[ARG2]]
+//       CHECK:   util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_original(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+    %collapsed = tensor.collapse_shape %arg2 [[0], [1, 2], [3, 4], [5, 6]] : tensor<?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16>
+    %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+    ^bb0(%arg6: f16, %arg7: f16):
+      iree_linalg_ext.yield %arg6 : f16
+    } -> tensor<?x32x8x128xf16>
+  util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_original
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//       CHECK:   %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[ARG1]]
+//  CHECK-SAME:       outs(%[[ARG2]]
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]]
+//       CHECK:   util.return %[[COLLAPSE]]
+
+// -----
+
+util.func @scatter_original_noop(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x80x2x32x8x128xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+    %collapsed = tensor.collapse_shape %arg2 [[0, 1, 2], [3], [4], [5]] : tensor<?x80x2x32x8x128xf16> into tensor<?x32x8x128xf16>
+    %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+    ^bb0(%arg6: f16, %arg7: f16):
+      iree_linalg_ext.yield %arg6 : f16
+    } -> tensor<?x32x8x128xf16>
+  util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_original_noop
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG2]]
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:       outs(%[[COLLAPSE]]
+//       CHECK:   util.return %[[SCATTER]]
+
+
+// -----
+
+util.func @scatter_collapse_original_partial(%arg0: tensor<?x1x32x8x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<5x?x2x16x4x2x64x2xf16>, %arg3 : index) -> tensor<?x32x8x128xf16> {
+    %collapsed = tensor.collapse_shape %arg2 [[0, 1], [2, 3], [4, 5], [6, 7]] : tensor<5x?x2x16x4x2x64x2xf16> into tensor<?x32x8x128xf16>
+    %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>) outs(%collapsed : tensor<?x32x8x128xf16>) {
+    ^bb0(%arg6: f16, %arg7: f16):
+      iree_linalg_ext.yield %arg6 : f16
+    } -> tensor<?x32x8x128xf16>
+  util.return %1 : tensor<?x32x8x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_original_partial
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//   CHECK-DAG:   %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x1x32x8x128xf16> into tensor<?x1x2x16x4x2x64x2xf16>
+// TODO(IanWood1): fix this so the collapse folds with the expand
+//   CHECK-DAG:   %[[ORIGINAL:.+]] = tensor.expand_shape {{.*}} tensor<?x32x8x128xf16> into tensor<?x2x16x4x2x64x2xf16>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[ARG1]]
+//  CHECK-SAME:       outs(%[[ORIGINAL]]
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[SCATTER]]
+//       CHECK:   util.return %[[COLLAPSE]]
+
+// -----
+
+util.func @scatter_collapse_indices(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg1[[0, 1], [2]] : tensor<4x?x1xi32> into tensor<?x1xi32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x2x16x4x128xf16>
+  util.return %1 : tensor<?x2x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_indices
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//   CHECK-DAG:   %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[ARG1]]
+//  CHECK-SAME:       outs(%[[ARG2]]
+//       CHECK:   util.return %[[SCATTER]]
+
+// -----
+
+util.func @scatter_collapse_indices_partial(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<4x?x1x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
+  %collapsed = tensor.collapse_shape %arg1[[0, 1], [2, 3]] : tensor<4x?x1x1xi32> into tensor<?x1xi32>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %collapsed: tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x2x16x4x128xf16>
+  util.return %1 : tensor<?x2x16x4x128xf16>
+}
+
+// CHECK-LABEL: util.func public @scatter_collapse_indices_partial
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//   CHECK-DAG:   %[[UPDATES:.+]] = tensor.expand_shape {{.*}} tensor<?x2x16x4x128xf16> into tensor<4x?x2x16x4x128xf16>
+//   CHECK-DAG:   %[[ORIGINAL:.+]] = tensor.collapse_shape {{.*}} tensor<4x?x1x1xi32> into tensor<4x?x1xi32>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[ORIGINAL]]
+//  CHECK-SAME:       outs(%[[ARG2]]
+//       CHECK:   util.return %[[SCATTER]]
+
+// -----
+
+util.func public @scatter_collapse(%arg0: tensor<?x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x4x32xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
+  ^bb0(%arg3: f16, %arg4: f16):
+    iree_linalg_ext.yield %arg3 : f16
+  } -> tensor<?x2x16x4x128xf16>
+  %dim = tensor.dim %arg0, %c0 : tensor<?x2x16x4x128xf16>
+  %expanded = tensor.expand_shape %0 [[0], [1], [2], [3], [4, 5]] output_shape [%dim, 2, 16, 4, 4, 32] : tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+  util.return %expanded : tensor<?x2x16x4x4x32xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//   CHECK-DAG:   %[[UPDATES:.+]] = tensor.expand_shape %[[ARG0]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+//   CHECK-DAG:   %[[ORIGINAL:.+]] = tensor.expand_shape %[[ARG2]] {{.*}} tensor<?x2x16x4x128xf16> into tensor<?x2x16x4x4x32xf16>
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[UPDATES]], %[[ARG1]]
+//  CHECK-SAME:       outs(%[[ORIGINAL]]
+//       CHECK:   util.return %[[SCATTER]]
+
+// -----
+
+util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x1xi32>, %arg2: tensor<128xf16>) -> tensor<4x4x4x2xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<10xf16>, tensor<10x1xi32>) outs(%arg2 : tensor<128xf16>) {
+  ^bb0(%arg3: f16, %arg4: f16):
+    iree_linalg_ext.yield %arg3 : f16
+  } -> tensor<128xf16>
+  %expanded = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape[4, 4, 4, 2] : tensor<128xf16> into tensor<4x4x4x2xf16>
+  util.return %expanded : tensor<4x4x4x2xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse_noop
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
+//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:       outs(%[[ARG2]]
+//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[SCATTER]]
+//       CHECK:   util.return %[[EXPANDED]]