[LinalgExt] Update scatter to allow dropping unit dims (#19704)

Changes `scatter` to enforce that `updates` consists only of batch dims
followed by the contiguous slice. This allows us to fold away the unit
dimension of `indices` when index_depth is 1.

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index dc8b002..93db396 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -10,6 +10,8 @@
 #include "compiler/plugins/input/Torch/InputConversion/Passes.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -61,10 +63,33 @@
     for (int i = 0; i < numIndices; i++)
       dimMap[i] = i;
 
-    auto scatterOp = rewriter.create<IREE::LinalgExt::ScatterOp>(
-        op.getLoc(), op->getResultTypes(), op.getInputs(), op.getOutputs(),
-        dimMap, op.getUniqueIndices());
+    auto updatesTy = op.getUpdateType();
 
+    // Create a reassociation that drops all unit dims from the indexed portion
+    // of `updates`.
+    Value updateVal = op.updates();
+    SmallVector<int64_t> collapsedShape;
+    collapsedShape.push_back(updatesTy.getShape().front());
+    if (op.getUpdateSliceRank() > 0) {
+      llvm::append_range(collapsedShape,
+                         updatesTy.getShape().take_back(
+                             op.getUpdateSliceRank() - op.getIndexDepth()));
+    }
+    if (collapsedShape != updatesTy.getShape()) {
+      auto reassocIndices = getReassociationIndicesForCollapse(
+          updatesTy.getShape(), collapsedShape);
+      if (!reassocIndices.has_value()) {
+        return rewriter.notifyMatchFailure(
+            op, "failed to compute reassociation indices");
+      }
+      updateVal = rewriter.create<tensor::CollapseShapeOp>(
+          op.getLoc(), updateVal, reassocIndices.value());
+    }
+
+    Value indicesVal = op.indices();
+    auto scatterOp = rewriter.create<IREE::LinalgExt::ScatterOp>(
+        op.getLoc(), op->getResultTypes(), ValueRange{updateVal, indicesVal},
+        op.getOutputs(), dimMap, op.getUniqueIndices());
     rewriter.inlineRegionBefore(op.getRegion(), scatterOp.getRegion(),
                                 scatterOp.getRegion().begin());
     rewriter.replaceOp(op, scatterOp->getResults());
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c1f452e..eed2a45 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -147,93 +147,77 @@
   }
 
   auto indicesType = getIndicesType();
-  if (indicesType.getRank() < 2 ||
+  if (indicesType.getRank() < 1 ||
       !isa<IntegerType>(indicesType.getElementType())) {
-    return op->emitOpError("expected indices to be of rank 2 or greater and of "
+    return op->emitOpError("expected indices to be of rank 1 or greater and of "
                            "integer element type");
   }
-  auto indexDepth = getIndexDepth();
-  if (ShapedType::isDynamic(indexDepth)) {
-    return op->emitOpError("expected index depth is static");
-  }
 
   ArrayRef<int64_t> dimMap = getDimensionMap();
-  if (dimMap.size() != indexDepth) {
-    return op->emitOpError("invalid number of dimension map entries");
-  }
-
-  auto originalType = getOriginalType();
   if (failed(isPermSequence(
           [&]() { return this->emitOpError("dimension map is invalid."); },
           dimMap))) {
     return failure();
   }
 
-  if (indexDepth > originalType.getShape().size()) {
-    return op->emitOpError(
-        "index depth is greater than the rank of the original value");
+  if (dimMap.size() == 0) {
+    return op->emitOpError("dimension map must have at least one element");
   }
 
+  const size_t indexDepth = getIndexDepth();
+  auto originalType = getOriginalType();
   auto updateType = getUpdateType();
-  auto batchRank = indicesType.getRank() - 1;
-  if (updateType.getRank() < batchRank) {
-    return op->emitOpError("expected update value to be of rank greater than "
-                           "or equal to rank(indices) - 1")
-           << batchRank;
-  }
-
-  // Validate the shape of indices and update value match for the first
-  // `batchRank` dims.
-  auto [indicesIt, updateIt] =
-      llvm::mismatch(indicesType.getShape().take_front(batchRank),
-                     updateType.getShape().take_front(batchRank));
-  if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
+  const auto originalSliceRank = originalType.getRank() - indexDepth;
+  if (originalSliceRank < 0) {
     return op->emitOpError(
-               "mismatch in shape of indices and update value at dim#")
-           << (indicesIt - indicesType.getShape().begin());
+        "expected original rank to be greater or equal to index depth");
   }
-
-  if (updateType.getRank() - batchRank > originalType.getRank()) {
-    return op->emitOpError("update operand's slice rank (")
-           << updateType.getRank() - batchRank
-           << " = rank(updates) - batch rank) exceeds the rank of the original "
-              "value ("
-           << originalType.getRank() << ")";
-  }
-
-  // TODO: make it illegal for `numImplicitDims` to be non-zero.
-  auto numImplicitDims = originalType.getRank() - getUpdateSliceRank();
-  if (numImplicitDims > indexDepth) {
+  if (updateType.getRank() < originalSliceRank) {
     return op->emitOpError(
-        "update and index depth does not fully index original");
+        "expected update to be at least the rank of non indexed original dims");
+  }
+  const size_t batchRank = updateType.getRank() - originalSliceRank;
+
+  if (updateType.getRank() - batchRank != originalSliceRank) {
+    return op->emitOpError("expected rank of update value - batch rank to be "
+                           "equal to rank of original value - index depth");
+  }
+
+  if ((indicesType.getRank() != batchRank || indexDepth != 1) &&
+      indicesType.getRank() != batchRank + 1) {
+    return op->emitOpError("expected indices to be equal to batch rank "
+                           "or batch rank + 1");
+  }
+
+  {
+    // Validate the shape of indices and update value match for the first
+    // `batchRank` dims.
+    auto [indicesIt, updateIt] =
+        llvm::mismatch(indicesType.getShape().take_front(batchRank),
+                       updateType.getShape().take_front(batchRank));
+    if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
+      return op->emitOpError(
+                 "mismatch in shape of indices and update value at dim#")
+             << (indicesIt - indicesType.getShape().begin());
+    }
+  }
+  if (batchRank + 1 < indicesType.getShape().size() &&
+      dimMap.size() != indicesType.getShape().back()) {
+    return op->emitOpError(
+        "size of dimension map must match the last dimension of indices");
   }
 
   // updateSlice[0..indexDepth] <= original[0..indexDepth]
   // updateSlice[indexDepth..] == original[indexDepth..]
-  auto updateSliceShape = getUpdateSliceShape();
-  for (uint64_t fullSliceIdx :
-       llvm::seq<uint64_t>(numImplicitDims, indexDepth)) {
-    int64_t originalDim = fullSliceIdx;
-    int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
-    if (!originalType.isDynamicDim(originalDim) &&
-        updateSliceShape[updateSliceDim] >
-            originalType.getDimSize(originalDim)) {
-      return op->emitOpError("shape of update value dim#")
-             << updateSliceDim + batchRank << " exceeds original value at dim#"
-             << originalDim;
-    }
-  }
 
-  for (auto fullSliceIdx :
-       llvm::seq<int64_t>(indexDepth, originalType.getRank())) {
-    int64_t originalDim = fullSliceIdx;
-    int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
-    if (!originalType.isDynamicDim(originalDim) &&
-        updateSliceShape[updateSliceDim] !=
-            originalType.getDimSize(originalDim)) {
+  {
+    auto [updateIt, originalIt] = llvm::mismatch(
+        getUpdateSliceShape(), originalType.getShape().drop_front(indexDepth));
+    if (updateIt != getUpdateSliceShape().end()) {
       return op->emitOpError("shape of update value dim#")
-             << updateSliceDim + batchRank
-             << " must match original value at dim#" << originalDim;
+             << (updateIt - updateType.getShape().begin())
+             << " must match original value at dim#"
+             << (originalIt - originalType.getShape().begin());
     }
   }
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 7d43c09..b6c5791 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -182,9 +182,7 @@
     static constexpr unsigned kOriginalOpNum = 2;
 
     int64_t getIndexDepth() {
-      return cast<ShapedType>(getDpsInputOperand(1)->get().getType())
-          .getShape()
-          .back();
+      return getDimensionMap().size();
     }
 
     Value getUpdates() {
@@ -214,7 +212,7 @@
     /// Utility to get the rank of the portion of `indices` that
     /// represents the batch dimensions
     int64_t getBatchRank() {
-      return getIndicesType().getRank() - 1;
+      return getUpdateType().getRank() - getUpdateSliceRank();
     }
 
     /// Utility to get the shape of the portion of `indices` that
@@ -226,7 +224,7 @@
     /// Utility to get the rank of the portion of `updates` that
     /// is scattered into `original`.
     int64_t getUpdateSliceRank() {
-      return getUpdateType().getRank() - getBatchRank();
+      return getOriginalType().getRank() - getIndexDepth();
     }
 
     /// Utility to get the shape of the portion of `updates` that
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
index ee7d053..21160d5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
@@ -117,9 +117,11 @@
   // Slice of indices.
   auto indicesRank = getIndicesType().getRank();
   SmallVector<OpFoldResult> indicesOffsets(offsets.take_front(getBatchRank()));
-  indicesOffsets.push_back(zeroAttr);
   SmallVector<OpFoldResult> indicesSizes(sizes.take_front(getBatchRank()));
-  indicesSizes.push_back(builder.getIndexAttr(getIndexDepth()));
+  if (getBatchRank() != getIndicesType().getRank()) {
+    indicesOffsets.push_back(zeroAttr);
+    indicesSizes.push_back(builder.getIndexAttr(getIndexDepth()));
+  }
   SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
 
   Operation *indicesSlice = getSlice(builder, loc, getIndices(), indicesOffsets,
@@ -228,7 +230,6 @@
   SmallVector<Value> starts;
   SmallVector<Value> loadIndices;
   append_range(loadIndices, ivs.take_front(getBatchRank()));
-  loadIndices.push_back(Value());
 
   // Populate with empty values.
   auto originalTy = getOriginalType();
@@ -242,8 +243,13 @@
 
   ArrayRef<int64_t> dimMap = getDimensionMap();
 
+  if (getIndicesType().getRank() > getBatchRank()) {
+    loadIndices.push_back(Value());
+  }
   for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
-    loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
+    if (getIndicesType().getRank() > getBatchRank()) {
+      loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
+    }
     Value idx = b.create<memref::LoadOp>(loc, getIndices(), loadIndices);
     Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 93581a1..8df9b8d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -57,12 +57,12 @@
 
 // -----
 
-func.func @scatter_mistmatch_dim_map_entries(
-    %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
+func.func @scatter_empty_dim_map(
+    %update : tensor<?x?xf32>, %indices : tensor<?x2xi32>,
     %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // expected-error @+1 {{invalid number of dimension map entries}}
-  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
-      ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+  // expected-error @below {{'iree_linalg_ext.scatter' op dimension map must have at least one element}}
+  %0 = iree_linalg_ext.scatter dimension_map = [] unique_indices(true)
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?x2xi32>)
       outs(%original : tensor<?x?xf32>) {
       ^bb0(%arg1: f32, %arg2: f32):
         %1 = arith.addf %arg1, %arg2 : f32
@@ -190,16 +190,16 @@
 
 func.func @scatter_dim_mismatch(
     %update : tensor<48x?x2x11xf32>, %indices : tensor<48x?x1xi32>,
-    %original : tensor<?x10xf32>) -> tensor<?x10xf32> {
-  // expected-error @below {{'iree_linalg_ext.scatter' op shape of update value dim#3 must match original value at dim#1}}
+    %original : tensor<2x?x10xf32>) -> tensor<2x?x10xf32> {
+  // expected-error @below {{'iree_linalg_ext.scatter' op shape of update value dim#2 must match original value at dim#1}}
   %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
     ins(%update, %indices : tensor<48x?x2x11xf32>, tensor<48x?x1xi32>)
-    outs(%original : tensor<?x10xf32>) {
+    outs(%original : tensor<2x?x10xf32>) {
     ^bb0(%arg1: f32, %arg2: f32):
       %1 = arith.addf %arg1, %arg2 : f32
       iree_linalg_ext.yield %1 : f32
-    } -> tensor<?x10xf32>
-  return %0 : tensor<?x10xf32>
+    } -> tensor<2x?x10xf32>
+  return %0 : tensor<2x?x10xf32>
 }
 
 // -----
@@ -207,7 +207,7 @@
 func.func @scatter_rank_mismatch(
     %update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
     %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}}
+  // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
   %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
     ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
     outs(%original : tensor<?x?xf32>) {
@@ -223,7 +223,7 @@
 func.func @scatter_rank_mismatch(
     %update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
     %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}}
+  // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
   %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
     ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
     outs(%original : tensor<?x?xf32>) {
@@ -239,7 +239,7 @@
 func.func @scatter_rank_mismatch(
     %update : tensor<?x?x?x?x?xf32>, %indices : tensor<?x?x1xi32>,
     %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}}
+  // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
   %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
     ins(%update, %indices : tensor<?x?x?x?x?xf32>, tensor<?x?x1xi32>)
     outs(%original : tensor<?x?xf32>) {
@@ -387,27 +387,10 @@
 
 // -----
 
-func.func @scatter_index_depth_dynamic(
-    %update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
-    %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
-  // expected-error @+1 {{expected index depth is static}}
-  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
-    ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
-    outs(%original : tensor<?x?xi64>) {
-    ^bb0(%arg1: i64, %arg2: i64):
-      %1 = arith.addi %arg1, %arg2 : i64
-      %2 = arith.trunci %1 : i64 to i32
-      iree_linalg_ext.yield %1, %2 : i64, i32
-    } -> tensor<?x?xi64>
-  return %0 : tensor<?x?xi64>
-}
-
-// -----
-
 func.func @scatter_index_depth_too_large(
     %original: tensor<?x?xf32>, %indices: tensor<?x3xi32>,
     %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // expected-error @below {{'iree_linalg_ext.scatter' op index depth is greater than the rank of the original value}}
+  // expected-error @below {{'iree_linalg_ext.scatter' op expected update to be at least the rank of non indexed original dims}}
   %0 = iree_linalg_ext.scatter
     dimension_map = [0, 1, 2]
     unique_indices(true)
@@ -425,7 +408,7 @@
 func.func @scatter_index_depth_too_small(
     %update : tensor<?x1xf32>, %indices : tensor<?x1xi32>,
     %original : tensor<?x?x1xf32>) -> tensor<?x?xf32> {
-  // expected-error @below {{'iree_linalg_ext.scatter' op update and index depth does not fully index original}}
+  // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
   %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
     ins(%update, %indices : tensor<?x1xf32>, tensor<?x1xi32>)
     outs(%original : tensor<?x?x1xf32>) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 1bc505f..c9c9b4b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -101,6 +101,34 @@
 
 // -----
 
+func.func @scatter_tensor_dynamic_implicit_indices(
+    %original: tensor<?x?xf32>, %indices: tensor<?xi32>,
+    %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = iree_linalg_ext.scatter
+    dimension_map = [0]
+    unique_indices(true)
+    ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+    outs(%original: tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.addf %arg1, %arg2 : f32
+      iree_linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @scatter_tensor_dynamic_implicit_indices(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?xi32>
+//  CHECK-SAME:   %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:     dimension_map = [0]
+//  CHECK-SAME:     unique_indices(true)
+//  CHECK-SAME:     ins(%[[UPDATE]], %[[INDICES]]
+//  CHECK-SAME:     outs(%[[ORIGINAL]]
+//       CHECK:     iree_linalg_ext.yield %{{.+}} : f32
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @scatter_repeated_tensor_dynamic(
     %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
     %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -157,6 +185,34 @@
 
 // -----
 
+func.func @scatter_tensor_static_implicit_indices(
+    %original: tensor<128x3xf32>, %indices: tensor<48xi32>,
+    %update: tensor<48x3xf32>) -> tensor<128x3xf32> {
+  %0 = iree_linalg_ext.scatter
+    dimension_map = [0]
+    unique_indices(true)
+    ins(%update, %indices : tensor<48x3xf32>, tensor<48xi32>)
+    outs(%original: tensor<128x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = arith.addf %arg1, %arg2 : f32
+      iree_linalg_ext.yield %1 : f32
+    } -> tensor<128x3xf32>
+  return %0 : tensor<128x3xf32>
+}
+// CHECK-LABEL: func.func @scatter_tensor_static_implicit_indices(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<128x3xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48xi32>
+//  CHECK-SAME:   %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32>
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.scatter
+//       CHECK:     dimension_map = [0]
+//  CHECK-SAME:     unique_indices(true)
+//  CHECK-SAME:     ins(%[[UPDATE]], %[[INDICES]]
+//  CHECK-SAME:     outs(%[[ORIGINAL]]
+//       CHECK:     iree_linalg_ext.yield %{{.+}} : f32
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @scatter_tensor_multi_index_depth(
     %original: tensor<1x128x3xf32>, %indices: tensor<48x2xi32>,
     %update: tensor<48x3xf32>) -> tensor<1x128x3xf32> {
@@ -270,12 +326,12 @@
 // -----
 
 func.func @scatter_update_scalar_1D(
-    %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+    %original: tensor<8xi32>, %indices: tensor<3xi32>,
     %updates: tensor<3xi32>) -> tensor<8xi32> {
   %0 = iree_linalg_ext.scatter
     dimension_map = [0]
     unique_indices(true)
-    ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+    ins(%updates, %indices : tensor<3xi32>, tensor<3xi32>)
     outs(%original : tensor<8xi32>)  {
     ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
       iree_linalg_ext.yield %arg0 : i32
@@ -483,11 +539,11 @@
 
 func.func @scatter_update_slice_2D(
     %original: tensor<4x?xi32>, %indices: tensor<1x1xi32>,
-    %updates: tensor<1x3xi32>) -> tensor<4x?xi32> {
+    %updates: tensor<1x?xi32>) -> tensor<4x?xi32> {
   %0 = iree_linalg_ext.scatter
     dimension_map = [0]
     unique_indices(true)
-    ins(%updates, %indices : tensor<1x3xi32>, tensor<1x1xi32>)
+    ins(%updates, %indices : tensor<1x?xi32>, tensor<1x1xi32>)
     outs(%original : tensor<4x?xi32>)  {
     ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
       iree_linalg_ext.yield %arg0 : i32
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 6ad28e5..b49f1e2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -8,7 +8,6 @@
 // The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and
 // modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`.
 
-#include <optional>
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
@@ -19,6 +18,9 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 
+#include <cstdint>
+#include <optional>
+
 namespace mlir::iree_compiler::IREE::LinalgExt {
 
 static bool
@@ -315,13 +317,7 @@
     ReshapeOperandInfo updateInfo;
     updateInfo.originalShape = scatterOp.getUpdateType().getShape();
     llvm::append_range(updateInfo.operandToIterationSpace,
-                       llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
-    updateInfo.operandToIterationSpace.append(
-        updateRank - (rankOfContiguousSlice + scatterOp.getBatchRank()),
-        ReshapeOperandInfo::kNoMapping);
-    llvm::append_range(
-        updateInfo.operandToIterationSpace,
-        llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
+                       llvm::seq<int64_t>(0, updateRank));
     infos.push_back(std::move(updateInfo));
   }
 
@@ -331,8 +327,9 @@
     indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
     llvm::append_range(indicesInfo.operandToIterationSpace,
                        llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
-    indicesInfo.operandToIterationSpace.push_back(
-        ReshapeOperandInfo::kNoMapping);
+    if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank())
+      indicesInfo.operandToIterationSpace.push_back(
+          ReshapeOperandInfo::kNoMapping);
     infos.push_back(std::move(indicesInfo));
   }
 
@@ -629,71 +626,30 @@
   linalg::ControlFusionFn controlFoldingReshapes;
 };
 
-/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
-/// The dims in `update` between the batch dims and the continuous slice
-/// represent the indexed dimensions. Remove the leading unit dims from the
-/// indexed dims.
-struct FoldScatterNonIterationUnitDims final
-    : public OpRewritePattern<ScatterOp> {
-  FoldScatterNonIterationUnitDims(MLIRContext *context,
-                                  linalg::ControlDropUnitDims options,
-                                  PatternBenefit benefit = 1)
-      : OpRewritePattern<ScatterOp>(context, benefit),
-        options(std::move(options)) {}
-
+struct DropScatterUnitIndexDepth final : public OpRewritePattern<ScatterOp> {
+  using OpRewritePattern<ScatterOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp scatterOp,
                                 PatternRewriter &rewriter) const override {
-    if (options.rankReductionStrategy !=
-        linalg::ControlDropUnitDims::RankReductionStrategy::
-            ReassociativeReshape) {
-      return rewriter.notifyMatchFailure(
-          scatterOp, "Only reassociative reshape strategy supported");
-    }
-    llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
-    const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();
-
-    // Find the number of leading unit dimensions
-    int64_t rankOfContiguousSlice =
-        scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
-    ArrayRef<int64_t> indexedDims =
-        scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice);
-    int64_t numDimsToDrop =
-        llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) -
-        scatterOp.getUpdateSliceShape().begin() - 1;
-
+    llvm::ArrayRef<int64_t> indicesShape =
+        scatterOp.getIndicesType().getShape();
     int64_t batchRank = scatterOp.getBatchRank();
-    llvm::erase_if(canDrop, [&](unsigned dimPos) {
-      return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
-    });
-    if (canDrop.empty()) {
+    if (indicesShape.size() == batchRank || indicesShape.back() != 1) {
       return failure();
     }
-
-    SmallVector<int64_t> droppedUpdateShape;
-    droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
-    for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
-      if (!llvm::is_contained(canDrop, idx)) {
-        droppedUpdateShape.push_back(dimLen);
-      }
+    SmallVector<ReassociationIndices> reassoc;
+    reassoc.reserve(indicesShape.size());
+    for (auto i : llvm::seq<int64_t>(0, batchRank - 1)) {
+      reassoc.emplace_back(1, i);
     }
-
-    auto reassoc =
-        getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
-    assert(reassoc.has_value() && "expected reassociation to be valid");
+    reassoc.push_back(ReassociationIndices{batchRank - 1, batchRank});
     auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
-        scatterOp.getLoc(),
-        RankedTensorType::get(droppedUpdateShape,
-                              scatterOp.getUpdateType().getElementType()),
-        scatterOp.getUpdates(), reassoc.value());
+        scatterOp.getLoc(), scatterOp.getIndices(), reassoc);
 
     rewriter.modifyOpInPlace(scatterOp, [&]() {
-      scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult());
+      scatterOp.setOperand(ScatterOp::kIndicesOpNum, collapseOp.getResult());
     });
     return success();
   }
-
-private:
-  linalg::ControlDropUnitDims options;
 };
 
 struct FoldScatterWithProducerReshapeByExpansion final
@@ -997,7 +953,7 @@
 
 void populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
-  patterns.add<FoldScatterNonIterationUnitDims>(patterns.getContext(), options);
+  patterns.add<DropScatterUnitIndexDepth>(patterns.getContext());
 }
 
 } // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index c5b624b..0cfd0b7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -356,42 +356,6 @@
 
 // -----
 
-func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
-  iree_linalg_ext.scatter
-    dimension_map = [0, 1, 2]
-    unique_indices(true)
-    ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
-    outs(%arg0 : memref<2x64x12xf32>) {
-  ^bb0(%arg3: f32, %arg4: f32):
-    iree_linalg_ext.yield %arg4 : f32
-  }
-  return
-}
-
-// CHECK-LABEL: @scatter_partial_slices
-// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[C12:.+]] = arith.constant 12 : index
-// CHECK:     scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
-// CHECK-NEXT:   scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK-NEXT:     scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
-// CHECK-NEXT:       %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
-// CHECK-NEXT:       %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
-// CHECK-NEXT:       %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
-// CHECK-NEXT:       %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
-// CHECK-NEXT:       %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
-// CHECK-NEXT:       %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
-// CHECK-NEXT:       %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
-// CHECK-NEXT:       %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
-// CHECK-NEXT:       %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
-// CHECK-NEXT:       memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
-
-// -----
-
 func.func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) {
   %stage = arith.constant 1 : index
   iree_linalg_ext.fft
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 403f5e2..53cb5e4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -198,9 +198,9 @@
 
 func.func @scatter_batch_2D(
     %original: memref<?xi32>, %indices: memref<?x?x1xi32>,
-    %updates: memref<?x?x?xi32>) {
+    %updates: memref<?x?xi32>) {
   iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
-    ins(%updates, %indices : memref<?x?x?xi32>, memref<?x?x1xi32>)
+    ins(%updates, %indices : memref<?x?xi32>, memref<?x?x1xi32>)
     outs(%original : memref<?xi32>)  {
   ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
     iree_linalg_ext.yield %arg0 : i32
@@ -221,27 +221,23 @@
 // CHECK-SAME:    %[[UPDATES:[a-zA-Z0-9]+]]
 // CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:     %[[C2:.+]] = arith.constant 2 : index
 // CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 // CHECK-DAG:     %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]]
 // CHECK-DAG:     %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]]
-// CHECK-DAG:     %[[D2:.+]] = memref.dim %[[UPDATES]], %[[C2]]
 // CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
 // CHECK:           %[[SZ:.+]] = affine.min #[[MAP]](%[[I]])[%[[D1]]]
 // CHECK:           %[[UPDATES_TILE:.+]] = memref.subview
-// CHECK-SAME:             %[[UPDATES]][0, %[[I]], 0]
-// CHECK-SAME:             [%[[D0]], %[[SZ]], %[[D2]]]
+// CHECK-SAME:             %[[UPDATES]][0, %[[I]]]
+// CHECK-SAME:             [%[[D0]], %[[SZ]]]
 // CHECK:           %[[INDICES_TILE:.+]] = memref.subview
 // CHECK-SAME:             %[[INDICES]][0, %[[I]], 0]
 // CHECK-SAME:             [%[[D0]], %[[SZ]], 1]
 // CHECK:           %[[ORIGINAL_TILE:.+]] = memref.subview
 // CHECK-SAME:             %[[ORIGINAL]][0]
-// CHECK-SAME:             [%[[D2]]]
-// CHECK:           %[[ORIG_CAST:.+]] = memref.cast %[[ORIGINAL_TILE]]
 // CHECK:           iree_linalg_ext.scatter
 // CHECK-SAME:             unique_indices(true)
 // CHECK-SAME:             ins(%[[UPDATES_TILE]], %[[INDICES_TILE]]
-// CHECK-SAME:             outs(%[[ORIG_CAST]]
+// CHECK-SAME:             outs(%[[ORIGINAL_TILE]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 244b02e..247fc7e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -507,8 +507,8 @@
 // -----
 
 util.func @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> {
-  %collapsed = tensor.collapse_shape %arg0[[0, 1], [2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x4x16x4x128xf16>
-  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
+  %collapsed = tensor.collapse_shape %arg0[[0, 1, 2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x16x4x128xf16>
+  %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
   ^bb0(%arg7: f16, %arg8: f16):
     iree_linalg_ext.yield %arg7 : f16
   } -> tensor<10x16x4x128xf16>
@@ -519,10 +519,9 @@
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]:
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]:
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]:
-//   CHECK-DAG:   %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x1xi32>
-//   CHECK-DAG:   %[[UPDATES:.+]] = tensor.collapse_shape %[[ARG0]] {{.*}} tensor<4x?x2x2x16x4x128xf16> into tensor<4x?x4x16x4x128xf16>
+//   CHECK-DAG:   %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x2x2x1xi32>
 //       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
-//  CHECK-SAME:       ins(%[[UPDATES]], %[[INDICES]]
+//  CHECK-SAME:       ins(%[[ARG0]], %[[INDICES]]
 //  CHECK-SAME:       outs(%[[ARG2]]
 //       CHECK:   util.return %[[SCATTER]]
 
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
index 62dbba7..e92ed2e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
@@ -109,44 +109,21 @@
 
 // -----
 
-util.func public @scatter0(%arg0: tensor<?x1x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
-  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
-  ^bb0(%arg3: f16, %arg4: f16):
-    iree_linalg_ext.yield %arg3 : f16
-  } -> tensor<?x2x16x4x128xf16>
-  util.return %0 : tensor<?x2x16x4x128xf16>
+util.func public @scatter(%arg0 : tensor<4xi64>, %arg1 : tensor<4x1xi32>, %arg2 : tensor<4xi64>) -> tensor<4xi64> {
+  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%arg0, %arg1: tensor<4xi64>, tensor<4x1xi32>) outs(%arg2 : tensor<4xi64>) {
+  ^bb0(%arg3: i64, %arg4: i64):
+    %16 = arith.addi %arg4, %arg3 : i64
+    iree_linalg_ext.yield %16 : i64
+  } -> tensor<4xi64>
+  util.return %0 : tensor<4xi64>
 }
-// CHECK-LABEL: func public @scatter0
-//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
-//  CHECK-SAME:     to tensor<?x2x16x4x128xf16>
+// CHECK-LABEL: func public @scatter
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]]
+//  CHECK-SAME:     tensor<4x1xi32> into tensor<4xi32>
 //       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
-//  CHECK-SAME:     ins(%[[COLLAPSE]]
-
-// -----
-
-util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
-  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x2xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
-  ^bb0(%arg3: f16, %arg4: f16):
-    iree_linalg_ext.yield %arg3 : f16
-  } -> tensor<?x2x16x4x128xf16>
-  util.return %0 : tensor<?x2x16x4x128xf16>
-}
-// CHECK-LABEL: func public @scatter1
-//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
-//  CHECK-SAME:     to tensor<?x16x4x128xf16>
-//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
-//  CHECK-SAME:     ins(%[[COLLAPSE]]
-
-// -----
-
-// TODO: remove other unit dims.
-util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor<?x2x1x4x128xf16>) -> tensor<?x2x1x4x128xf16> {
-  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor<?x2x1x4x128xf16>) {
-  ^bb0(%arg3: f16, %arg4: f16):
-    iree_linalg_ext.yield %arg3 : f16
-  } -> tensor<?x2x1x4x128xf16>
-  util.return %0 : tensor<?x2x1x4x128xf16>
-}
-// CHECK-LABEL: func public @scatter_noop
-//   CHECK-NOT:   tensor.collapse_shape
-//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
+//  CHECK-SAME:     ins(%[[ARG0]], %[[COLLAPSED]]
+//  CHECK-SAME:     outs(%[[ARG2]]
+//       CHECK:   util.return %[[SCATTER]]
diff --git a/tests/e2e/linalg_ext_ops/scatter.mlir b/tests/e2e/linalg_ext_ops/scatter.mlir
index 808475a..8980f20 100644
--- a/tests/e2e/linalg_ext_ops/scatter.mlir
+++ b/tests/e2e/linalg_ext_ops/scatter.mlir
@@ -15,39 +15,6 @@
   return
 }
 
-func.func @scatter_2d_origin_slice_horizontal() {
-  %original = util.unfoldable_constant dense<0> : tensor<2x2xi32>
-  %update = util.unfoldable_constant dense<1> : tensor<1x2xi32>
-  %indices = util.unfoldable_constant dense<0> : tensor<1x2xi32>
-  %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
-                          ins(%update, %indices : tensor<1x2xi32>, tensor<1x2xi32>)
-                          outs(%original : tensor<2x2xi32>) {
-                    ^bb0(%arg0: i32, %arg1: i32):
-                      iree_linalg_ext.yield %arg0 : i32
-  } -> tensor<2x2xi32>
-
-  check.expect_eq_const(%result, dense<[[1, 1], [0, 0]]> : tensor<2x2xi32>) : tensor<2x2xi32>
-
-  return
-}
-
-
-func.func @scatter_2d_origin_slice_vertical() {
-  %original = util.unfoldable_constant dense<0> : tensor<2x2xi32>
-  %update = util.unfoldable_constant dense<1> : tensor<1x2x1xi32>
-  %indices = util.unfoldable_constant dense<0> : tensor<1x2xi32>
-  %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
-                          ins(%update, %indices : tensor<1x2x1xi32>, tensor<1x2xi32>)
-                          outs(%original : tensor<2x2xi32>) {
-                    ^bb0(%arg0: i32, %arg1: i32):
-                      iree_linalg_ext.yield %arg0 : i32
-  } -> tensor<2x2xi32>
-
-  check.expect_eq_const(%result, dense<[[1, 0], [1, 0]]> : tensor<2x2xi32>) : tensor<2x2xi32>
-
-  return
-}
-
 func.func @scatter_2d_offset() {
   %original = util.unfoldable_constant dense<0> : tensor<2x2xi32>
   %update = util.unfoldable_constant dense<1> : tensor<1xi32>
@@ -127,35 +94,3 @@
 
   return
 }
-
-func.func @scatter_2d_multiple_slice() {
-  %original = util.unfoldable_constant dense<0> : tensor<3x3xi32>
-  %update = util.unfoldable_constant dense<1> : tensor<2x2xi32>
-  %indices = util.unfoldable_constant dense<[[0, 1], [1, 0]]> : tensor<2x2xi32>
-  %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
-                          ins(%update, %indices : tensor<2x2xi32>, tensor<2x2xi32>)
-                          outs(%original : tensor<3x3xi32>) {
-                    ^bb0(%arg0: i32, %arg1: i32):
-                      iree_linalg_ext.yield %arg0 : i32
-  } -> tensor<3x3xi32>
-
-  check.expect_eq_const(%result, dense<[[0, 1, 1], [1, 1, 0], [0, 0, 0]]> : tensor<3x3xi32>) : tensor<3x3xi32>
-
-  return
-}
-
-func.func @scatter_2d_multiple_slice_transpose() {
-  %original = util.unfoldable_constant dense<0> : tensor<3x4xi32>
-  %update = util.unfoldable_constant dense<1> : tensor<2x2xi32>
-  %indices = util.unfoldable_constant dense<[[0, 1], [2, 0]]> : tensor<2x2xi32>
-  %result = iree_linalg_ext.scatter dimension_map = [1, 0] unique_indices(true)
-                          ins(%update, %indices : tensor<2x2xi32>, tensor<2x2xi32>)
-                          outs(%original : tensor<3x4xi32>) {
-                    ^bb0(%arg0: i32, %arg1: i32):
-                      iree_linalg_ext.yield %arg0 : i32
-  } -> tensor<3x4xi32>
-
-  check.expect_eq_const(%result, dense<[[0, 0, 1, 1], [1, 1, 0, 0], [0, 0, 0, 0]]> : tensor<3x4xi32>) : tensor<3x4xi32>
-
-  return
-}