Materialize inserted indices for mhlo.scatter (#10463)

mhlo.scatter supports inserted dimensions at any index that is length-1 or
indexed into. Linalg's scatter requires that indices be explicitly
defined (excluding optionally inserting the indexed dims). Adds an
mhlo-to-mhlo rewriter that materializes the cases where the inserted
indices occur after index dims.

Includes fixing the broken `mhlo.scatter` tests in `mhlo-to-mhlo-preprocessing`.
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index f1cb2e6..0401f65 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -471,7 +471,7 @@
                                                      indices, reassociationMap);
 
     auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getOperands(), indices,
+        op.getLoc(), op.getResultTypes(), op.operands(), indices,
         op.getUpdates(), dimNumbers, op.getIndicesAreSorted(),
         op.getUniqueIndices());
     Region &region = newScatter.getUpdateComputation();
@@ -543,7 +543,7 @@
         dimNumbers.getIndexVectorDim() + 1);
 
     auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getOperands(), indices, updates,
+        op.getLoc(), op.getResultTypes(), op.operands(), indices, updates,
         newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
     Region &region = newScatter.getUpdateComputation();
     rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
@@ -640,7 +640,7 @@
         /*indexVectorDim=*/1);
 
     auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getOperands(), indices, updates,
+        op.getLoc(), op.getResultTypes(), op.operands(), indices, updates,
         newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
     Region &region = newScatter.getUpdateComputation();
     rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
@@ -649,6 +649,141 @@
   }
 };
 
+// mhlo.scatter can materialize a unit dimension at both indexed dimensions or
+// at unit dimensions in the destination matrix. linalg_ext.scatter only
+// allows unit dimensions at indexed dimensions. This pattern inserts all
+// unit dimensions that are not index dimensions to be compatible with
+// linalg_ext.scatter.
+//
+// It converts an mhlo.scatter as below:
+//  %result = "mhlo.scatter"(...) ({
+//    indices_are_sorted = true,
+//    scatter_dimension_numbers = #mhlo.scatter<
+//            update_window_dims = [1],
+//            inserted_window_dims = [0, 2],
+//            scatter_dims_to_operand_dims = [0],
+//            index_vector_dim = 1>,
+//    unique_indices = true} :
+//        (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4xi32>)
+//
+// To:
+//  %result = "mhlo.scatter"(...) ({
+//    indices_are_sorted = true,
+//    scatter_dimension_numbers = #mhlo.scatter<
+//            update_window_dims = [1, 2],
+//            inserted_window_dims = [0],
+//            scatter_dims_to_operand_dims = [0],
+//            index_vector_dim = 1>,
+//     unique_indices = true} :
+//        (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4x1xi32>)
+//  return %0 : tensor<5x4x1xi32>
+struct ScatterMaterializeInsertedDim
+    : public OpRewritePattern<mhlo::ScatterOp> {
+  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
+                                PatternRewriter &rewriter) const final {
+    auto indices = op.getScatterIndices();
+    auto operand = op.operands().front();
+    auto indicesTy = indices.getType().cast<ShapedType>();
+    auto operandTy = operand.getType().cast<ShapedType>();
+    if (!operandTy.hasRank() || !indicesTy.hasRank()) {
+      return rewriter.notifyMatchFailure(op, "operand/indices have no rank");
+    }
+
+    auto dimNumbers = op.getScatterDimensionNumbers();
+    auto updateDims = dimNumbers.getUpdateWindowDims();
+
+    if (indicesTy.getRank() != 2 || dimNumbers.getIndexVectorDim() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "indices is not of shape [batch, indices]");
+    }
+
+    if (!updateDims.empty() && updateDims.front() == 0) {
+      return rewriter.notifyMatchFailure(
+          op, "updates is not of shape [batch, ...]");
+    }
+
+    auto scatterDimsToOperandDims = dimNumbers.getScatterDimsToOperandDims();
+    llvm::SmallVector<bool> isIndexDim(operandTy.getRank(), false);
+    for (auto val : scatterDimsToOperandDims) {
+      isIndexDim[val] = true;
+    }
+
+    int64_t firstNonIndex = 0;
+    for (int64_t s = scatterDimsToOperandDims.size(); firstNonIndex < s;
+         ++firstNonIndex) {
+      if (!isIndexDim[firstNonIndex]) break;
+    }
+
+    llvm::SmallVector<bool> isInsertDims(operandTy.getRank(), false);
+    for (auto val : dimNumbers.getInsertedWindowDims()) {
+      isInsertDims[val] = true;
+    }
+
+    int64_t frontInsertedDims = 0;
+    for (; frontInsertedDims < firstNonIndex; ++frontInsertedDims) {
+      if (!isInsertDims[frontInsertedDims]) {
+        break;
+      }
+    }
+
+    llvm::ArrayRef<bool> toInsertDims =
+        llvm::ArrayRef<bool>(isInsertDims).drop_front(frontInsertedDims);
+    if (!llvm::any_of(toInsertDims, [](auto d) { return d; })) {
+      return rewriter.notifyMatchFailure(op, "no dimensions to insert");
+    }
+
+    // Create a reassociation map that starts with the batch dims.
+    SmallVector<ReassociationExprs, 4> reassociationMap;
+    reassociationMap.push_back({rewriter.getAffineDimExpr(0)});
+
+    for (auto it : llvm::enumerate(llvm::ArrayRef<bool>(toInsertDims))) {
+      if (!it.value()) reassociationMap.push_back({});
+      reassociationMap.back().push_back(
+          rewriter.getAffineDimExpr(it.index() + 1));
+    }
+
+    llvm::SmallVector<Value> expandedUpdates;
+    for (auto update : op.getUpdates()) {
+      auto updatesTy = update.getType().cast<ShapedType>();
+
+      llvm::SmallVector<int64_t> newShape;
+      for (int i = 0, s = reassociationMap.size(); i < s; ++i) {
+        newShape.push_back(updatesTy.getDimSize(i));
+        for (int j = 1, s = reassociationMap[i].size(); j < s; ++j) {
+          newShape.push_back(1);
+        }
+      }
+
+      Value expandUpdate = rewriter.create<tensor::ExpandShapeOp>(
+          op.getLoc(),
+          RankedTensorType::get(newShape, updatesTy.getElementType()), update,
+          reassociationMap);
+      expandedUpdates.push_back(expandUpdate);
+    }
+
+    llvm::SmallVector<int64_t> newUpdatedWindowDims(toInsertDims.size());
+    llvm::SmallVector<int64_t> newInsertedWindowDims(frontInsertedDims);
+    std::iota(newUpdatedWindowDims.begin(), newUpdatedWindowDims.end(), 1);
+    std::iota(newInsertedWindowDims.begin(), newInsertedWindowDims.end(), 0);
+
+    auto newDimNumbers = mhlo::ScatterDimensionNumbersAttr::get(
+        op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
+        dimNumbers.getScatterDimsToOperandDims(),
+        /*indexVectorDim=*/1);
+
+    auto newScatter = rewriter.create<mhlo::ScatterOp>(
+        op.getLoc(), op.getResultTypes(), op.operands(), op.getScatterIndices(),
+        expandedUpdates, newDimNumbers, op.getIndicesAreSorted(),
+        op.getUniqueIndices());
+    Region &region = newScatter.getUpdateComputation();
+    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
+    rewriter.replaceOp(op, newScatter.getResults());
+    return success();
+  }
+};
+
 // Traverse upward past common operations to see if the value came from a
 // boolean tensor.
 bool isFromBool(Value val) {
@@ -1162,7 +1297,8 @@
 
     // scatter canonicalization patterns
     patterns.insert<ScatterOpImplicitIndex, ScatterOpImplicitBatch,
-                    ScatterOpCollapseBatch>(context);
+                    ScatterMaterializeInsertedDim, ScatterOpCollapseBatch>(
+        context);
 
     // dot_general canoncalization patterns.
     mhlo::populateGeneralDotOpLoweringPatterns(&patterns, context);
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
index b083052..99e5b4f 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
@@ -28,7 +28,7 @@
             "mhlo_to_linalg.mlir",
             "mhlo_to_mhlo_preprocessing.mlir",
             "mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir",
-            "mhlo_to_mhlo_preprocessing_disabled.mlir",
+            "mhlo_to_mhlo_scatter.mlir",
             "missing_legalizations.mlir",
             "transformation_pipeline.mlir",
             "verify_compiler_mhlo_input_legality.mlir",
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
index 341b4de..8404a20 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
@@ -24,7 +24,7 @@
     "mhlo_to_linalg.mlir"
     "mhlo_to_mhlo_preprocessing.mlir"
     "mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir"
-    "mhlo_to_mhlo_preprocessing_disabled.mlir"
+    "mhlo_to_mhlo_scatter.mlir"
     "missing_legalizations.mlir"
     "transformation_pipeline.mlir"
     "verify_compiler_mhlo_input_legality.mlir"
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
index 1e2da37..f0efca3 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
@@ -243,7 +243,7 @@
 // CHECK:         %[[RES:.+]] = mhlo.reshape %[[SLICE]] : (tensor<15xf32>) -> tensor<3x5xf32>
 // CHECK:         return %[[RES]]
 
-//-----
+// -----
 
 func.func @mul_float_bool_cast(%arg0 : tensor<?xi1>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
   %0 = "mhlo.convert"(%arg0) : (tensor<?xi1>) -> tensor<?xf32>
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_disabled.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_disabled.mlir
deleted file mode 100644
index bf8af5d..0000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_disabled.mlir
+++ /dev/null
@@ -1,87 +0,0 @@
-
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s
-
-// XFAIL: *
-// FIXME(#10774): Fix and re-enable MHLO scatter tests.
-
-func.func @scatter_implicit_batch(%arg0: tensor<5x5xi32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<5x5xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1]>, unique_indices = true} : (tensor<5x5xi32>, tensor<2xi32>, tensor<i32>) -> tensor<5x5xi32>
-  return %0 : tensor<5x5xi32>
-}
-
-// CHECK-LABEL: func.func @scatter_implicit_batch
-// CHECK-DAG: %[[RE_I:.+]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1]] : tensor<2xi32> into tensor<1x2xi32>
-// CHECK-DAG: %[[RE_U:.+]] = tensor.expand_shape %{{.*}} [] : tensor<i32> into tensor<1xi32>
-// CHECK:     %[[SCATTER:.+]] = "mhlo.scatter"(%{{.*}}, %[[RE_I]], %[[RE_U]])
-// CHECK:       mhlo.return %{{.*}}
-
-// -----
-
-func.func @scatter_implicit_indices(%arg0: tensor<17x11xf32>,
-  %arg1: tensor<7xi32>, %arg2: tensor<7x11xf32>) -> tensor<17x11xf32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
-  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
-    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
-    "mhlo.return"(%1) : (tensor<f32>) -> ()
-  }) {indices_are_sorted = false,
-      scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1>, 
-      unique_indices = false
-      } : (tensor<17x11xf32>, tensor<7xi32>, tensor<7x11xf32>) -> tensor<17x11xf32>
-  return %0 : tensor<17x11xf32>
-}
-
-// CHECK-LABEL: func.func @scatter_implicit_indices
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg1 {{\[\[}}0, 1]] : tensor<7xi32> into tensor<7x1xi32>
-// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[EXPAND]], %arg2) ({
-// CHECK-NEXT: ^bb0(%[[A0:.+]]: tensor<f32>, %[[A1:.+]]: tensor<f32>):
-// CHECK-NEXT:   %[[ADD:.+]] = mhlo.add %[[A0]], %[[A1]] : tensor<f32>
-// CHECK-NEXT:   mhlo.return %[[ADD]]
-// CHECK-NEXT: })
-// CHECK-SAME: indices_are_sorted = false,
-// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<
-// CHECK-SAME:   update_window_dims = [1],
-// CHECK-SAME:   inserted_window_dims = [0],
-// CHECK-SAME:   scatter_dims_to_operand_dims = [0],
-// CHECK-SAME:   index_vector_dim = 1>,
-// CHECK-SAME:   unique_indices = false
-
-// -----
-
-func.func @scatter_collapse_batch(%arg0: tensor<1x24x512xi32>,
-    %arg1: tensor<2x3x2xi32>, %arg2: tensor<2x3x512xi32>) -> tensor<1x24x512xi32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {indices_are_sorted = false,
-      scatter_dimension_numbers = #mhlo.scatter<
-        update_window_dims = [2],
-        inserted_window_dims = [0, 1],
-        scatter_dims_to_operand_dims = [0, 1],
-        index_vector_dim = 2,
-      >,
-      unique_indices = true
-  } : (tensor<1x24x512xi32>, tensor<2x3x2xi32>, tensor<2x3x512xi32>) -> tensor<1x24x512xi32>
-  return %0 : tensor<1x24x512xi32>
-}
-
-// CHECK-LABEL: func.func @scatter_collapse_batch
-// CHECK: %[[COLLAPSE0:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<2x3x2xi32> into tensor<6x2xi32>
-// CHECK: %[[COLLAPSE1:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0, 1], [2]] : tensor<2x3x512xi32> into tensor<6x512xi32>
-// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[COLLAPSE0]], %[[COLLAPSE1]])
-// CHECK: ^bb0(%[[ARG0:.+]]: tensor<i32>, %[[ARG1:.+]]: tensor<i32>):
-// CHECK:   mhlo.return %[[ARG1]]
-// CHECK: }) {
-// CHECK: indices_are_sorted = false,
-// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1]
-// CHECK-SAME: inserted_window_dims = [0, 1]
-// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
-// CHECK-SAME: index_vector_dim = 1>
-// CHECK-SAME: unique_indices = true
-// CHECK: return %[[SCATTER]]
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir
new file mode 100644
index 0000000..a71312c
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir
@@ -0,0 +1,208 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s
+
+func.func @scatter_implicit_batch(%arg0: tensor<5x5xi32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<5x5xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1]>, unique_indices = true} : (tensor<5x5xi32>, tensor<2xi32>, tensor<i32>) -> tensor<5x5xi32>
+  return %0 : tensor<5x5xi32>
+}
+
+// CHECK-LABEL: func.func @scatter_implicit_batch
+// CHECK-DAG: %[[RE_I:.+]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1]] : tensor<2xi32> into tensor<1x2xi32>
+// CHECK-DAG: %[[RE_U:.+]] = tensor.expand_shape %{{.*}} [] : tensor<i32> into tensor<1xi32>
+// CHECK:     %[[SCATTER:.+]] = "mhlo.scatter"(%{{.*}}, %[[RE_I]], %[[RE_U]])
+// CHECK:       mhlo.return %{{.*}}
+
+// -----
+
+func.func @scatter_implicit_indices(%arg0: tensor<17x11xf32>,
+  %arg1: tensor<7xi32>, %arg2: tensor<7x11xf32>) -> tensor<17x11xf32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {indices_are_sorted = false,
+      scatter_dimension_numbers = #mhlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0],
+      scatter_dims_to_operand_dims = [0],
+      index_vector_dim = 1>, 
+      unique_indices = false
+      } : (tensor<17x11xf32>, tensor<7xi32>, tensor<7x11xf32>) -> tensor<17x11xf32>
+  return %0 : tensor<17x11xf32>
+}
+
+// CHECK-LABEL: func.func @scatter_implicit_indices
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg1 {{\[\[}}0, 1]] : tensor<7xi32> into tensor<7x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[EXPAND]], %arg2) ({
+// CHECK-NEXT: ^bb0(%[[A0:.+]]: tensor<f32>, %[[A1:.+]]: tensor<f32>):
+// CHECK-NEXT:   %[[ADD:.+]] = mhlo.add %[[A0]], %[[A1]] : tensor<f32>
+// CHECK-NEXT:   mhlo.return %[[ADD]]
+// CHECK-NEXT: })
+// CHECK-SAME: indices_are_sorted = false,
+// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME:   update_window_dims = [1],
+// CHECK-SAME:   inserted_window_dims = [0],
+// CHECK-SAME:   scatter_dims_to_operand_dims = [0],
+// CHECK-SAME:   index_vector_dim = 1>,
+// CHECK-SAME:   unique_indices = false
+
+// -----
+
+func.func @scatter_collapse_batch(%arg0: tensor<1x24x512xi32>,
+    %arg1: tensor<2x3x2xi32>, %arg2: tensor<2x3x512xi32>) -> tensor<1x24x512xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {indices_are_sorted = false,
+      scatter_dimension_numbers = #mhlo.scatter<
+        update_window_dims = [2],
+        inserted_window_dims = [0, 1],
+        scatter_dims_to_operand_dims = [0, 1],
+        index_vector_dim = 2,
+      >,
+      unique_indices = true
+  } : (tensor<1x24x512xi32>, tensor<2x3x2xi32>, tensor<2x3x512xi32>) -> tensor<1x24x512xi32>
+  return %0 : tensor<1x24x512xi32>
+}
+
+// CHECK-LABEL: func.func @scatter_collapse_batch
+// CHECK: %[[COLLAPSE0:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<2x3x2xi32> into tensor<6x2xi32>
+// CHECK: %[[COLLAPSE1:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0, 1], [2]] : tensor<2x3x512xi32> into tensor<6x512xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[COLLAPSE0]], %[[COLLAPSE1]])
+// CHECK: ^bb0(%[[ARG0:.+]]: tensor<i32>, %[[ARG1:.+]]: tensor<i32>):
+// CHECK:   mhlo.return %[[ARG1]]
+// CHECK: }) {
+// CHECK: indices_are_sorted = false,
+// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1]
+// CHECK-SAME: inserted_window_dims = [0, 1]
+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
+// CHECK-SAME: index_vector_dim = 1>
+// CHECK-SAME: unique_indices = true
+// CHECK: return %[[SCATTER]]
+
+// -----
+
+func.func @scatter_materialize_index_update(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x4xi32>) -> tensor<5x1x1xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = true,
+    scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1],
+                                              inserted_window_dims = [1, 2],
+                                              scatter_dims_to_operand_dims = [0, 1],
+                                              index_vector_dim = 1>,
+    unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x2xi32>, tensor<1x4xi32>) -> tensor<5x1x1xi32>
+  return %0 : tensor<5x1x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_index_update
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0], [1, 2, 3]] : tensor<1x4xi32> into tensor<1x4x1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK:                   indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME:                update_window_dims = [1, 2, 3]
+// CHECK-SAME:                scatter_dims_to_operand_dims = [0, 1]
+// CHECK-SAME:                index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_materialize_one_dim(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1xi32>) -> tensor<5x1x1xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = true,
+    scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [],
+                                              inserted_window_dims = [0, 1, 2],
+                                              scatter_dims_to_operand_dims = [0, 1],
+                                              index_vector_dim = 1>,
+    unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<5x1x1xi32>
+  return %0 : tensor<5x1x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_one_dim
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1]] : tensor<1xi32> into tensor<1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK:                   indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME:                 update_window_dims = [1]
+// CHECK-SAME:                 inserted_window_dims = [0, 1]
+// CHECK-SAME:                 scatter_dims_to_operand_dims = [0, 1]
+// CHECK-SAME:                 index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_materialize_two_dims(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<5x1x1xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = true,
+    scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [],
+                                              inserted_window_dims = [0, 1, 2],
+                                              scatter_dims_to_operand_dims = [0],
+                                              index_vector_dim = 1>,
+    unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<5x1x1xi32>
+  return %0 : tensor<5x1x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_two_dims
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1, 2]] : tensor<1xi32> into tensor<1x1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK:                   indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME:                 update_window_dims = [1, 2]
+// CHECK-SAME:                 inserted_window_dims = [0]
+// CHECK-SAME:                 scatter_dims_to_operand_dims = [0]
+// CHECK-SAME:                 index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_materialize_comprehensive(%arg0: tensor<5x4x1xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1x4xi32>) -> tensor<5x4x1xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = true,
+    scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1],
+                                              inserted_window_dims = [0, 2],
+                                              scatter_dims_to_operand_dims = [0],
+                                              index_vector_dim = 1>,
+    unique_indices = true} : (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4xi32>) -> tensor<5x4x1xi32>
+  return %0 : tensor<5x4x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_comprehensive
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0], [1, 2]] : tensor<1x4xi32> into tensor<1x4x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK:                   indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME:                 update_window_dims = [1, 2]
+// CHECK-SAME:                 inserted_window_dims = [0]
+// CHECK-SAME:                 scatter_dims_to_operand_dims = [0]
+// CHECK-SAME:                 index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_operand_map(%arg0: tensor<5x4x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1xi32>) -> tensor<5x4x1xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = true,
+    scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [],
+                                              inserted_window_dims = [0, 1, 2],
+                                              scatter_dims_to_operand_dims = [0, 2],
+                                              index_vector_dim = 1>,
+    unique_indices = true} : (tensor<5x4x1xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<5x4x1xi32>
+  return %0 : tensor<5x4x1xi32>
+}
+
+// CHECK-LABEL: @scatter_operand_map
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1, 2]] : tensor<1xi32> into tensor<1x1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK:                   indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME:                 update_window_dims = [1, 2],
+// CHECK-SAME:                 inserted_window_dims = [0],
+// CHECK-SAME:                 scatter_dims_to_operand_dims = [0, 2],
+// CHECK-SAME:                 index_vector_dim = 1>, unique_indices = true