Materialize inserted indices for mhlo.scatter (#10463)
mhlo.scatter supports inserted dimensions at any index that is length-1 or
indexed into. Linalg's scatter requires that indices be explicitly
defined (excluding optionally inserting the indexed dims). Adds an
mhlo-to-mhlo rewriter that materializes the cases where the inserted
indices occur after index dims.
Includes fixing the broken `mhlo.scatter` tests in `mhlo-to-mhlo-preprocessing`.
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
index f1cb2e6..0401f65 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp
@@ -471,7 +471,7 @@
indices, reassociationMap);
auto newScatter = rewriter.create<mhlo::ScatterOp>(
- op.getLoc(), op.getResultTypes(), op.getOperands(), indices,
+ op.getLoc(), op.getResultTypes(), op.operands(), indices,
op.getUpdates(), dimNumbers, op.getIndicesAreSorted(),
op.getUniqueIndices());
Region &region = newScatter.getUpdateComputation();
@@ -543,7 +543,7 @@
dimNumbers.getIndexVectorDim() + 1);
auto newScatter = rewriter.create<mhlo::ScatterOp>(
- op.getLoc(), op.getResultTypes(), op.getOperands(), indices, updates,
+ op.getLoc(), op.getResultTypes(), op.operands(), indices, updates,
newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
Region &region = newScatter.getUpdateComputation();
rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
@@ -640,7 +640,7 @@
/*indexVectorDim=*/1);
auto newScatter = rewriter.create<mhlo::ScatterOp>(
- op.getLoc(), op.getResultTypes(), op.getOperands(), indices, updates,
+ op.getLoc(), op.getResultTypes(), op.operands(), indices, updates,
newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
Region &region = newScatter.getUpdateComputation();
rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
@@ -649,6 +649,141 @@
}
};
+// mhlo.scatter can materialize a unit dimension at both indexed dimensions or
+// at unary dimensions in the destination matrix. linalg_ext.scatter only
+// allows unit dimensions at indexed dimensions. This pattern inserts all
+// unary dimensions that are not index dimensions to be compatible with
+// linalg_ext.scatter.
+//
+// It converts an mhlo.scatter as below:
+// %result = "mhlo.scatter"(...) ({
+// indices_are_sorted = true,
+// scatter_dimension_numbers = #mhlo.scatter<
+// update_window_dims = [1],
+// inserted_window_dims = [0, 2],
+// scatter_dims_to_operand_dims = [0],
+// index_vector_dim = 1>,
+// unique_indices = true} :
+// (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4xi32>)
+//
+// To:
+// %result = "mhlo.scatter"(...) ({
+// indices_are_sorted = true,
+// scatter_dimension_numbers = #mhlo.scatter<
+// update_window_dims = [1, 2],
+// inserted_window_dims = [0],
+// scatter_dims_to_operand_dims = [0],
+// index_vector_dim = 1>,
+// unique_indices = true} :
+// (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4x1xi32>)
+// return %0 : tensor<5x4x1xi32>
+struct ScatterMaterializeInsertedDim
+ : public OpRewritePattern<mhlo::ScatterOp> {
+ using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mhlo::ScatterOp op,
+ PatternRewriter &rewriter) const final {
+ auto indices = op.getScatterIndices();
+ auto operand = op.operands().front();
+ auto indicesTy = indices.getType().cast<ShapedType>();
+ auto operandTy = operand.getType().cast<ShapedType>();
+ if (!operandTy.hasRank() || !indicesTy.hasRank()) {
+ return rewriter.notifyMatchFailure(op, "operand/indices have no rank");
+ }
+
+ auto dimNumbers = op.getScatterDimensionNumbers();
+ auto updateDims = dimNumbers.getUpdateWindowDims();
+
+ if (indicesTy.getRank() != 2 || dimNumbers.getIndexVectorDim() != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "indices is not of shape [batch, indices]");
+ }
+
+ if (!updateDims.empty() && updateDims.front() == 0) {
+ return rewriter.notifyMatchFailure(
+ op, "updates is not of shape [batch, ...]");
+ }
+
+ auto scatterDimsToOperandDims = dimNumbers.getScatterDimsToOperandDims();
+ llvm::SmallVector<bool> isIndexDim(operandTy.getRank(), false);
+ for (auto val : scatterDimsToOperandDims) {
+ isIndexDim[val] = true;
+ }
+
+ int64_t firstNonIndex = 0;
+ for (int64_t s = scatterDimsToOperandDims.size(); firstNonIndex < s;
+ ++firstNonIndex) {
+ if (!isIndexDim[firstNonIndex]) break;
+ }
+
+ llvm::SmallVector<bool> isInsertDims(operandTy.getRank(), false);
+ for (auto val : dimNumbers.getInsertedWindowDims()) {
+ isInsertDims[val] = true;
+ }
+
+ int64_t frontInsertedDims = 0;
+ for (; frontInsertedDims < firstNonIndex; ++frontInsertedDims) {
+ if (!isInsertDims[frontInsertedDims]) {
+ break;
+ }
+ }
+
+ llvm::ArrayRef<bool> toInsertDims =
+ llvm::ArrayRef<bool>(isInsertDims).drop_front(frontInsertedDims);
+ if (!llvm::any_of(toInsertDims, [](auto d) { return d; })) {
+ return rewriter.notifyMatchFailure(op, "no dimensions to insert");
+ }
+
+ // Create a reassociation map that starts with the batch dims.
+ SmallVector<ReassociationExprs, 4> reassociationMap;
+ reassociationMap.push_back({rewriter.getAffineDimExpr(0)});
+
+ for (auto it : llvm::enumerate(llvm::ArrayRef<bool>(toInsertDims))) {
+ if (!it.value()) reassociationMap.push_back({});
+ reassociationMap.back().push_back(
+ rewriter.getAffineDimExpr(it.index() + 1));
+ }
+
+ llvm::SmallVector<Value> expandedUpdates;
+ for (auto update : op.getUpdates()) {
+ auto updatesTy = update.getType().cast<ShapedType>();
+
+ llvm::SmallVector<int64_t> newShape;
+ for (int i = 0, s = reassociationMap.size(); i < s; ++i) {
+ newShape.push_back(updatesTy.getDimSize(i));
+ for (int j = 1, s = reassociationMap[i].size(); j < s; ++j) {
+ newShape.push_back(1);
+ }
+ }
+
+ Value expandUpdate = rewriter.create<tensor::ExpandShapeOp>(
+ op.getLoc(),
+ RankedTensorType::get(newShape, updatesTy.getElementType()), update,
+ reassociationMap);
+ expandedUpdates.push_back(expandUpdate);
+ }
+
+ llvm::SmallVector<int64_t> newUpdatedWindowDims(toInsertDims.size());
+ llvm::SmallVector<int64_t> newInsertedWindowDims(frontInsertedDims);
+ std::iota(newUpdatedWindowDims.begin(), newUpdatedWindowDims.end(), 1);
+ std::iota(newInsertedWindowDims.begin(), newInsertedWindowDims.end(), 0);
+
+ auto newDimNumbers = mhlo::ScatterDimensionNumbersAttr::get(
+ op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
+ dimNumbers.getScatterDimsToOperandDims(),
+ /*indexVectorDim=*/1);
+
+ auto newScatter = rewriter.create<mhlo::ScatterOp>(
+ op.getLoc(), op.getResultTypes(), op.operands(), op.getScatterIndices(),
+ expandedUpdates, newDimNumbers, op.getIndicesAreSorted(),
+ op.getUniqueIndices());
+ Region &region = newScatter.getUpdateComputation();
+ rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
+ rewriter.replaceOp(op, newScatter.getResults());
+ return success();
+ }
+};
+
// Traverse upward past common operations to see if the value came from a
// boolean tensor.
bool isFromBool(Value val) {
@@ -1162,7 +1297,8 @@
// scatter canonicalization patterns
patterns.insert<ScatterOpImplicitIndex, ScatterOpImplicitBatch,
- ScatterOpCollapseBatch>(context);
+ ScatterMaterializeInsertedDim, ScatterOpCollapseBatch>(
+ context);
// dot_general canoncalization patterns.
mhlo::populateGeneralDotOpLoweringPatterns(&patterns, context);
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
index b083052..99e5b4f 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD
@@ -28,7 +28,7 @@
"mhlo_to_linalg.mlir",
"mhlo_to_mhlo_preprocessing.mlir",
"mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir",
- "mhlo_to_mhlo_preprocessing_disabled.mlir",
+ "mhlo_to_mhlo_scatter.mlir",
"missing_legalizations.mlir",
"transformation_pipeline.mlir",
"verify_compiler_mhlo_input_legality.mlir",
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
index 341b4de..8404a20 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt
@@ -24,7 +24,7 @@
"mhlo_to_linalg.mlir"
"mhlo_to_mhlo_preprocessing.mlir"
"mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir"
- "mhlo_to_mhlo_preprocessing_disabled.mlir"
+ "mhlo_to_mhlo_scatter.mlir"
"missing_legalizations.mlir"
"transformation_pipeline.mlir"
"verify_compiler_mhlo_input_legality.mlir"
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
index 1e2da37..f0efca3 100644
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir
@@ -243,7 +243,7 @@
// CHECK: %[[RES:.+]] = mhlo.reshape %[[SLICE]] : (tensor<15xf32>) -> tensor<3x5xf32>
// CHECK: return %[[RES]]
-//-----
+// -----
func.func @mul_float_bool_cast(%arg0 : tensor<?xi1>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
%0 = "mhlo.convert"(%arg0) : (tensor<?xi1>) -> tensor<?xf32>
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_disabled.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_disabled.mlir
deleted file mode 100644
index bf8af5d..0000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_disabled.mlir
+++ /dev/null
@@ -1,87 +0,0 @@
-
-// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s
-
-// XFAIL: *
-// FIXME(#10774): Fix and re-enable MHLO scatter tests.
-
-func.func @scatter_implicit_batch(%arg0: tensor<5x5xi32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<5x5xi32> {
- %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
- "mhlo.return"(%arg4) : (tensor<i32>) -> ()
- }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1]>, unique_indices = true} : (tensor<5x5xi32>, tensor<2xi32>, tensor<i32>) -> tensor<5x5xi32>
- return %0 : tensor<5x5xi32>
-}
-
-// CHECK-LABEL: func.func @scatter_implicit_batch
-// CHECK-DAG: %[[RE_I:.+]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1]] : tensor<2xi32> into tensor<1x2xi32>
-// CHECK-DAG: %[[RE_U:.+]] = tensor.expand_shape %{{.*}} [] : tensor<i32> into tensor<1xi32>
-// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%{{.*}}, %[[RE_I]], %[[RE_U]])
-// CHECK: mhlo.return %{{.*}}
-
-// -----
-
-func.func @scatter_implicit_indices(%arg0: tensor<17x11xf32>,
- %arg1: tensor<7xi32>, %arg2: tensor<7x11xf32>) -> tensor<17x11xf32> {
- %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
- ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
- %1 = mhlo.add %arg3, %arg4 : tensor<f32>
- "mhlo.return"(%1) : (tensor<f32>) -> ()
- }) {indices_are_sorted = false,
- scatter_dimension_numbers = #mhlo.scatter<
- update_window_dims = [1],
- inserted_window_dims = [0],
- scatter_dims_to_operand_dims = [0],
- index_vector_dim = 1>,
- unique_indices = false
- } : (tensor<17x11xf32>, tensor<7xi32>, tensor<7x11xf32>) -> tensor<17x11xf32>
- return %0 : tensor<17x11xf32>
-}
-
-// CHECK-LABEL: func.func @scatter_implicit_indices
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg1 {{\[\[}}0, 1]] : tensor<7xi32> into tensor<7x1xi32>
-// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[EXPAND]], %arg2) ({
-// CHECK-NEXT: ^bb0(%[[A0:.+]]: tensor<f32>, %[[A1:.+]]: tensor<f32>):
-// CHECK-NEXT: %[[ADD:.+]] = mhlo.add %[[A0]], %[[A1]] : tensor<f32>
-// CHECK-NEXT: mhlo.return %[[ADD]]
-// CHECK-NEXT: })
-// CHECK-SAME: indices_are_sorted = false,
-// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<
-// CHECK-SAME: update_window_dims = [1],
-// CHECK-SAME: inserted_window_dims = [0],
-// CHECK-SAME: scatter_dims_to_operand_dims = [0],
-// CHECK-SAME: index_vector_dim = 1>,
-// CHECK-SAME: unique_indices = false
-
-// -----
-
-func.func @scatter_collapse_batch(%arg0: tensor<1x24x512xi32>,
- %arg1: tensor<2x3x2xi32>, %arg2: tensor<2x3x512xi32>) -> tensor<1x24x512xi32> {
- %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
- ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
- "mhlo.return"(%arg4) : (tensor<i32>) -> ()
- }) {indices_are_sorted = false,
- scatter_dimension_numbers = #mhlo.scatter<
- update_window_dims = [2],
- inserted_window_dims = [0, 1],
- scatter_dims_to_operand_dims = [0, 1],
- index_vector_dim = 2,
- >,
- unique_indices = true
- } : (tensor<1x24x512xi32>, tensor<2x3x2xi32>, tensor<2x3x512xi32>) -> tensor<1x24x512xi32>
- return %0 : tensor<1x24x512xi32>
-}
-
-// CHECK-LABEL: func.func @scatter_collapse_batch
-// CHECK: %[[COLLAPSE0:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<2x3x2xi32> into tensor<6x2xi32>
-// CHECK: %[[COLLAPSE1:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0, 1], [2]] : tensor<2x3x512xi32> into tensor<6x512xi32>
-// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[COLLAPSE0]], %[[COLLAPSE1]])
-// CHECK: ^bb0(%[[ARG0:.+]]: tensor<i32>, %[[ARG1:.+]]: tensor<i32>):
-// CHECK: mhlo.return %[[ARG1]]
-// CHECK: }) {
-// CHECK: indices_are_sorted = false,
-// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1]
-// CHECK-SAME: inserted_window_dims = [0, 1]
-// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
-// CHECK-SAME: index_vector_dim = 1>
-// CHECK-SAME: unique_indices = true
-// CHECK: return %[[SCATTER]]
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir
new file mode 100644
index 0000000..a71312c
--- /dev/null
+++ b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir
@@ -0,0 +1,208 @@
+// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s
+
+func.func @scatter_implicit_batch(%arg0: tensor<5x5xi32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<5x5xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1]>, unique_indices = true} : (tensor<5x5xi32>, tensor<2xi32>, tensor<i32>) -> tensor<5x5xi32>
+ return %0 : tensor<5x5xi32>
+}
+
+// CHECK-LABEL: func.func @scatter_implicit_batch
+// CHECK-DAG: %[[RE_I:.+]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1]] : tensor<2xi32> into tensor<1x2xi32>
+// CHECK-DAG: %[[RE_U:.+]] = tensor.expand_shape %{{.*}} [] : tensor<i32> into tensor<1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%{{.*}}, %[[RE_I]], %[[RE_U]])
+// CHECK: mhlo.return %{{.*}}
+
+// -----
+
+func.func @scatter_implicit_indices(%arg0: tensor<17x11xf32>,
+ %arg1: tensor<7xi32>, %arg2: tensor<7x11xf32>) -> tensor<17x11xf32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = mhlo.add %arg3, %arg4 : tensor<f32>
+ "mhlo.return"(%1) : (tensor<f32>) -> ()
+ }) {indices_are_sorted = false,
+ scatter_dimension_numbers = #mhlo.scatter<
+ update_window_dims = [1],
+ inserted_window_dims = [0],
+ scatter_dims_to_operand_dims = [0],
+ index_vector_dim = 1>,
+ unique_indices = false
+ } : (tensor<17x11xf32>, tensor<7xi32>, tensor<7x11xf32>) -> tensor<17x11xf32>
+ return %0 : tensor<17x11xf32>
+}
+
+// CHECK-LABEL: func.func @scatter_implicit_indices
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg1 {{\[\[}}0, 1]] : tensor<7xi32> into tensor<7x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[EXPAND]], %arg2) ({
+// CHECK-NEXT: ^bb0(%[[A0:.+]]: tensor<f32>, %[[A1:.+]]: tensor<f32>):
+// CHECK-NEXT: %[[ADD:.+]] = mhlo.add %[[A0]], %[[A1]] : tensor<f32>
+// CHECK-NEXT: mhlo.return %[[ADD]]
+// CHECK-NEXT: })
+// CHECK-SAME: indices_are_sorted = false,
+// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1],
+// CHECK-SAME: inserted_window_dims = [0],
+// CHECK-SAME: scatter_dims_to_operand_dims = [0],
+// CHECK-SAME: index_vector_dim = 1>,
+// CHECK-SAME: unique_indices = false
+
+// -----
+
+func.func @scatter_collapse_batch(%arg0: tensor<1x24x512xi32>,
+ %arg1: tensor<2x3x2xi32>, %arg2: tensor<2x3x512xi32>) -> tensor<1x24x512xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {indices_are_sorted = false,
+ scatter_dimension_numbers = #mhlo.scatter<
+ update_window_dims = [2],
+ inserted_window_dims = [0, 1],
+ scatter_dims_to_operand_dims = [0, 1],
+ index_vector_dim = 2,
+ >,
+ unique_indices = true
+ } : (tensor<1x24x512xi32>, tensor<2x3x2xi32>, tensor<2x3x512xi32>) -> tensor<1x24x512xi32>
+ return %0 : tensor<1x24x512xi32>
+}
+
+// CHECK-LABEL: func.func @scatter_collapse_batch
+// CHECK: %[[COLLAPSE0:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<2x3x2xi32> into tensor<6x2xi32>
+// CHECK: %[[COLLAPSE1:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0, 1], [2]] : tensor<2x3x512xi32> into tensor<6x512xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[COLLAPSE0]], %[[COLLAPSE1]])
+// CHECK: ^bb0(%[[ARG0:.+]]: tensor<i32>, %[[ARG1:.+]]: tensor<i32>):
+// CHECK: mhlo.return %[[ARG1]]
+// CHECK: }) {
+// CHECK: indices_are_sorted = false,
+// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1]
+// CHECK-SAME: inserted_window_dims = [0, 1]
+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
+// CHECK-SAME: index_vector_dim = 1>
+// CHECK-SAME: unique_indices = true
+// CHECK: return %[[SCATTER]]
+
+// -----
+
+func.func @scatter_materialize_index_update(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x4xi32>) -> tensor<5x1x1xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {
+ indices_are_sorted = true,
+ scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1],
+ inserted_window_dims = [1, 2],
+ scatter_dims_to_operand_dims = [0, 1],
+ index_vector_dim = 1>,
+ unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x2xi32>, tensor<1x4xi32>) -> tensor<5x1x1xi32>
+ return %0 : tensor<5x1x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_index_update
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0], [1, 2, 3]] : tensor<1x4xi32> into tensor<1x4x1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1, 2, 3]
+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
+// CHECK-SAME: index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_materialize_one_dim(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1xi32>) -> tensor<5x1x1xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {
+ indices_are_sorted = true,
+ scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [],
+ inserted_window_dims = [0, 1, 2],
+ scatter_dims_to_operand_dims = [0, 1],
+ index_vector_dim = 1>,
+ unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<5x1x1xi32>
+ return %0 : tensor<5x1x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_one_dim
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1]] : tensor<1xi32> into tensor<1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1]
+// CHECK-SAME: inserted_window_dims = [0, 1]
+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
+// CHECK-SAME: index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_materialize_two_dims(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<5x1x1xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {
+ indices_are_sorted = true,
+ scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [],
+ inserted_window_dims = [0, 1, 2],
+ scatter_dims_to_operand_dims = [0],
+ index_vector_dim = 1>,
+ unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<5x1x1xi32>
+ return %0 : tensor<5x1x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_two_dims
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1, 2]] : tensor<1xi32> into tensor<1x1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1, 2]
+// CHECK-SAME: inserted_window_dims = [0]
+// CHECK-SAME: scatter_dims_to_operand_dims = [0]
+// CHECK-SAME: index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_materialize_comprehensive(%arg0: tensor<5x4x1xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1x4xi32>) -> tensor<5x4x1xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {
+ indices_are_sorted = true,
+ scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1],
+ inserted_window_dims = [0, 2],
+ scatter_dims_to_operand_dims = [0],
+ index_vector_dim = 1>,
+ unique_indices = true} : (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4xi32>) -> tensor<5x4x1xi32>
+ return %0 : tensor<5x4x1xi32>
+}
+
+// CHECK-LABEL: @scatter_materialize_comprehensive
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0], [1, 2]] : tensor<1x4xi32> into tensor<1x4x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1, 2]
+// CHECK-SAME: inserted_window_dims = [0]
+// CHECK-SAME: scatter_dims_to_operand_dims = [0]
+// CHECK-SAME: index_vector_dim = 1>, unique_indices = true
+
+// -----
+
+func.func @scatter_operand_map(%arg0: tensor<5x4x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1xi32>) -> tensor<5x4x1xi32> {
+ %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
+ "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+ }) {
+ indices_are_sorted = true,
+ scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [],
+ inserted_window_dims = [0, 1, 2],
+ scatter_dims_to_operand_dims = [0, 2],
+ index_vector_dim = 1>,
+ unique_indices = true} : (tensor<5x4x1xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<5x4x1xi32>
+ return %0 : tensor<5x4x1xi32>
+}
+
+// CHECK-LABEL: @scatter_operand_map
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1, 2]] : tensor<1xi32> into tensor<1x1x1xi32>
+// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]])
+// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<
+// CHECK-SAME: update_window_dims = [1, 2],
+// CHECK-SAME: inserted_window_dims = [0],
+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2],
+// CHECK-SAME: index_vector_dim = 1>, unique_indices = true