[LinalgExt] Update scatter to allow dropping unit dims (#19704)
Changes `scatter` to enforce that `updates` is only batch dims +
contiguous slice. This allows us to drop fold the unit dimension from
`indices` when index_depth is 1.
---------
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index dc8b002..93db396 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -10,6 +10,8 @@
#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -61,10 +63,33 @@
for (int i = 0; i < numIndices; i++)
dimMap[i] = i;
- auto scatterOp = rewriter.create<IREE::LinalgExt::ScatterOp>(
- op.getLoc(), op->getResultTypes(), op.getInputs(), op.getOutputs(),
- dimMap, op.getUniqueIndices());
+ auto updatesTy = op.getUpdateType();
+ // Create a reassociation that drops all unit dims from the indexed portion
+ // slice.
+ Value updateVal = op.updates();
+ SmallVector<int64_t> collapsedShape;
+ collapsedShape.push_back(updatesTy.getShape().front());
+ if (op.getUpdateSliceRank() > 0) {
+ llvm::append_range(collapsedShape,
+ updatesTy.getShape().take_back(
+ op.getUpdateSliceRank() - op.getIndexDepth()));
+ }
+ if (collapsedShape != updatesTy.getShape()) {
+ auto reassocIndices = getReassociationIndicesForCollapse(
+ updatesTy.getShape(), collapsedShape);
+ if (!reassocIndices.has_value()) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to compute reassociation indices");
+ }
+ updateVal = rewriter.create<tensor::CollapseShapeOp>(
+ op.getLoc(), updateVal, reassocIndices.value());
+ }
+
+ Value indicesVal = op.indices();
+ auto scatterOp = rewriter.create<IREE::LinalgExt::ScatterOp>(
+ op.getLoc(), op->getResultTypes(), ValueRange{updateVal, indicesVal},
+ op.getOutputs(), dimMap, op.getUniqueIndices());
rewriter.inlineRegionBefore(op.getRegion(), scatterOp.getRegion(),
scatterOp.getRegion().begin());
rewriter.replaceOp(op, scatterOp->getResults());
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c1f452e..eed2a45 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -147,93 +147,77 @@
}
auto indicesType = getIndicesType();
- if (indicesType.getRank() < 2 ||
+ if (indicesType.getRank() < 1 ||
!isa<IntegerType>(indicesType.getElementType())) {
- return op->emitOpError("expected indices to be of rank 2 or greater and of "
+ return op->emitOpError("expected indices to be of rank 1 or greater and of "
"integer element type");
}
- auto indexDepth = getIndexDepth();
- if (ShapedType::isDynamic(indexDepth)) {
- return op->emitOpError("expected index depth is static");
- }
ArrayRef<int64_t> dimMap = getDimensionMap();
- if (dimMap.size() != indexDepth) {
- return op->emitOpError("invalid number of dimension map entries");
- }
-
- auto originalType = getOriginalType();
if (failed(isPermSequence(
[&]() { return this->emitOpError("dimension map is invalid."); },
dimMap))) {
return failure();
}
- if (indexDepth > originalType.getShape().size()) {
- return op->emitOpError(
- "index depth is greater than the rank of the original value");
+ if (dimMap.size() == 0) {
+ return op->emitOpError("dimension map must have at least one element");
}
+ const size_t indexDepth = getIndexDepth();
+ auto originalType = getOriginalType();
auto updateType = getUpdateType();
- auto batchRank = indicesType.getRank() - 1;
- if (updateType.getRank() < batchRank) {
- return op->emitOpError("expected update value to be of rank greater than "
- "or equal to rank(indices) - 1")
- << batchRank;
- }
-
- // Validate the shape of indices and update value match for the first
- // `batchRank` dims.
- auto [indicesIt, updateIt] =
- llvm::mismatch(indicesType.getShape().take_front(batchRank),
- updateType.getShape().take_front(batchRank));
- if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
+ const auto originalSliceRank = originalType.getRank() - indexDepth;
+ if (originalSliceRank < 0) {
return op->emitOpError(
- "mismatch in shape of indices and update value at dim#")
- << (indicesIt - indicesType.getShape().begin());
+ "expected original rank to be greater or equal to index depth");
}
-
- if (updateType.getRank() - batchRank > originalType.getRank()) {
- return op->emitOpError("update operand's slice rank (")
- << updateType.getRank() - batchRank
- << " = rank(updates) - batch rank) exceeds the rank of the original "
- "value ("
- << originalType.getRank() << ")";
- }
-
- // TODO: make it illegal for `numImplicitDims` to be non-zero.
- auto numImplicitDims = originalType.getRank() - getUpdateSliceRank();
- if (numImplicitDims > indexDepth) {
+ if (updateType.getRank() < originalSliceRank) {
return op->emitOpError(
- "update and index depth does not fully index original");
+ "expected update to be at least the rank of non indexed original dims");
+ }
+ const size_t batchRank = updateType.getRank() - originalSliceRank;
+
+ if (updateType.getRank() - batchRank != originalSliceRank) {
+ return op->emitOpError("expected rank of update value - batch rank to be "
+ "equal to rank of original value - index depth");
+ }
+
+ if ((indicesType.getRank() != batchRank || indexDepth != 1) &&
+ indicesType.getRank() != batchRank + 1) {
+ return op->emitOpError("expected indices to be equal to batch rank "
+ "or batch rank + 1");
+ }
+
+ {
+ // Validate the shape of indices and update value match for the first
+ // `batchRank` dims.
+ auto [indicesIt, updateIt] =
+ llvm::mismatch(indicesType.getShape().take_front(batchRank),
+ updateType.getShape().take_front(batchRank));
+ if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
+ return op->emitOpError(
+ "mismatch in shape of indices and update value at dim#")
+ << (indicesIt - indicesType.getShape().begin());
+ }
+ }
+ if (batchRank + 1 < indicesType.getShape().size() &&
+ dimMap.size() != indicesType.getShape().back()) {
+ return op->emitOpError(
+ "size of dimension map must match the last dimension of indices");
}
// updateSlice[0..indexDepth] <= original[0..indexDepth]
// updateSlice[indexDepth..] == original[indexDepth..]
- auto updateSliceShape = getUpdateSliceShape();
- for (uint64_t fullSliceIdx :
- llvm::seq<uint64_t>(numImplicitDims, indexDepth)) {
- int64_t originalDim = fullSliceIdx;
- int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
- if (!originalType.isDynamicDim(originalDim) &&
- updateSliceShape[updateSliceDim] >
- originalType.getDimSize(originalDim)) {
- return op->emitOpError("shape of update value dim#")
- << updateSliceDim + batchRank << " exceeds original value at dim#"
- << originalDim;
- }
- }
- for (auto fullSliceIdx :
- llvm::seq<int64_t>(indexDepth, originalType.getRank())) {
- int64_t originalDim = fullSliceIdx;
- int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
- if (!originalType.isDynamicDim(originalDim) &&
- updateSliceShape[updateSliceDim] !=
- originalType.getDimSize(originalDim)) {
+ {
+ auto [updateIt, originalIt] = llvm::mismatch(
+ getUpdateSliceShape(), originalType.getShape().drop_front(indexDepth));
+ if (updateIt != getUpdateSliceShape().end()) {
return op->emitOpError("shape of update value dim#")
- << updateSliceDim + batchRank
- << " must match original value at dim#" << originalDim;
+ << (updateIt - updateType.getShape().begin())
+ << " must match original value at dim#"
+ << (originalIt - originalType.getShape().begin());
}
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 7d43c09..b6c5791 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -182,9 +182,7 @@
static constexpr unsigned kOriginalOpNum = 2;
int64_t getIndexDepth() {
- return cast<ShapedType>(getDpsInputOperand(1)->get().getType())
- .getShape()
- .back();
+ return getDimensionMap().size();
}
Value getUpdates() {
@@ -214,7 +212,7 @@
/// Utility to get the rank of the portion of `indices` that
/// represents the batch dimensions
int64_t getBatchRank() {
- return getIndicesType().getRank() - 1;
+ return getUpdateType().getRank() - getUpdateSliceRank();
}
/// Utility to get the shape of the portion of `indices` that
@@ -226,7 +224,7 @@
/// Utility to get the rank of the portion of `updates` that
/// is scattered into `original`.
int64_t getUpdateSliceRank() {
- return getUpdateType().getRank() - getBatchRank();
+ return getOriginalType().getRank() - getIndexDepth();
}
/// Utility to get the shape of the portion of `updates` that
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
index ee7d053..21160d5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
@@ -117,9 +117,11 @@
// Slice of indices.
auto indicesRank = getIndicesType().getRank();
SmallVector<OpFoldResult> indicesOffsets(offsets.take_front(getBatchRank()));
- indicesOffsets.push_back(zeroAttr);
SmallVector<OpFoldResult> indicesSizes(sizes.take_front(getBatchRank()));
- indicesSizes.push_back(builder.getIndexAttr(getIndexDepth()));
+ if (getBatchRank() != getIndicesType().getRank()) {
+ indicesOffsets.push_back(zeroAttr);
+ indicesSizes.push_back(builder.getIndexAttr(getIndexDepth()));
+ }
SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
Operation *indicesSlice = getSlice(builder, loc, getIndices(), indicesOffsets,
@@ -228,7 +230,6 @@
SmallVector<Value> starts;
SmallVector<Value> loadIndices;
append_range(loadIndices, ivs.take_front(getBatchRank()));
- loadIndices.push_back(Value());
// Populate with empty values.
auto originalTy = getOriginalType();
@@ -242,8 +243,13 @@
ArrayRef<int64_t> dimMap = getDimensionMap();
+ if (getIndicesType().getRank() > getBatchRank()) {
+ loadIndices.push_back(Value());
+ }
for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
- loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
+ if (getIndicesType().getRank() > getBatchRank()) {
+ loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
+ }
Value idx = b.create<memref::LoadOp>(loc, getIndices(), loadIndices);
Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 93581a1..8df9b8d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -57,12 +57,12 @@
// -----
-func.func @scatter_mistmatch_dim_map_entries(
- %update : tensor<?x?xf32>, %indices : tensor<?x1xi32>,
+func.func @scatter_empty_dim_map(
+ %update : tensor<?x?xf32>, %indices : tensor<?x2xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
- // expected-error @+1 {{invalid number of dimension map entries}}
- %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
- ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ // expected-error @below {{'iree_linalg_ext.scatter' op dimension map must have at least one element}}
+ %0 = iree_linalg_ext.scatter dimension_map = [] unique_indices(true)
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x2xi32>)
outs(%original : tensor<?x?xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
@@ -190,16 +190,16 @@
func.func @scatter_dim_mismatch(
%update : tensor<48x?x2x11xf32>, %indices : tensor<48x?x1xi32>,
- %original : tensor<?x10xf32>) -> tensor<?x10xf32> {
- // expected-error @below {{'iree_linalg_ext.scatter' op shape of update value dim#3 must match original value at dim#1}}
+ %original : tensor<2x?x10xf32>) -> tensor<2x?x10xf32> {
+ // expected-error @below {{'iree_linalg_ext.scatter' op shape of update value dim#2 must match original value at dim#1}}
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
ins(%update, %indices : tensor<48x?x2x11xf32>, tensor<48x?x1xi32>)
- outs(%original : tensor<?x10xf32>) {
+ outs(%original : tensor<2x?x10xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%1 = arith.addf %arg1, %arg2 : f32
iree_linalg_ext.yield %1 : f32
- } -> tensor<?x10xf32>
- return %0 : tensor<?x10xf32>
+ } -> tensor<2x?x10xf32>
+ return %0 : tensor<2x?x10xf32>
}
// -----
@@ -207,7 +207,7 @@
func.func @scatter_rank_mismatch(
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
- // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}}
+ // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
@@ -223,7 +223,7 @@
func.func @scatter_rank_mismatch(
%update : tensor<?x?x?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
- // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}}
+ // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
ins(%update, %indices : tensor<?x?x?x?xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?xf32>) {
@@ -239,7 +239,7 @@
func.func @scatter_rank_mismatch(
%update : tensor<?x?x?x?x?xf32>, %indices : tensor<?x?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
- // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}}
+ // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
ins(%update, %indices : tensor<?x?x?x?x?xf32>, tensor<?x?x1xi32>)
outs(%original : tensor<?x?xf32>) {
@@ -387,27 +387,10 @@
// -----
-func.func @scatter_index_depth_dynamic(
- %update : tensor<?x?xi64>, %indices : tensor<?x?xi32>,
- %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
- // expected-error @+1 {{expected index depth is static}}
- %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
- ins(%update, %indices : tensor<?x?xi64>, tensor<?x?xi32>)
- outs(%original : tensor<?x?xi64>) {
- ^bb0(%arg1: i64, %arg2: i64):
- %1 = arith.addi %arg1, %arg2 : i64
- %2 = arith.trunci %1 : i64 to i32
- iree_linalg_ext.yield %1, %2 : i64, i32
- } -> tensor<?x?xi64>
- return %0 : tensor<?x?xi64>
-}
-
-// -----
-
func.func @scatter_index_depth_too_large(
%original: tensor<?x?xf32>, %indices: tensor<?x3xi32>,
%update: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // expected-error @below {{'iree_linalg_ext.scatter' op index depth is greater than the rank of the original value}}
+ // expected-error @below {{'iree_linalg_ext.scatter' op expected update to be at least the rank of non indexed original dims}}
%0 = iree_linalg_ext.scatter
dimension_map = [0, 1, 2]
unique_indices(true)
@@ -425,7 +408,7 @@
func.func @scatter_index_depth_too_small(
%update : tensor<?x1xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?x1xf32>) -> tensor<?x?xf32> {
- // expected-error @below {{'iree_linalg_ext.scatter' op update and index depth does not fully index original}}
+ // expected-error @below {{'iree_linalg_ext.scatter' op expected indices to be equal to batch rank or batch rank + 1}}
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
ins(%update, %indices : tensor<?x1xf32>, tensor<?x1xi32>)
outs(%original : tensor<?x?x1xf32>) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 1bc505f..c9c9b4b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -101,6 +101,34 @@
// -----
+func.func @scatter_tensor_dynamic_implicit_indices(
+ %original: tensor<?x?xf32>, %indices: tensor<?xi32>,
+ %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = iree_linalg_ext.scatter
+ dimension_map = [0]
+ unique_indices(true)
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+ outs(%original: tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = arith.addf %arg1, %arg2 : f32
+ iree_linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @scatter_tensor_dynamic_implicit_indices(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: dimension_map = [0]
+// CHECK-SAME: unique_indices(true)
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @scatter_repeated_tensor_dynamic(
%original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
%update: tensor<?x?xf32>) -> tensor<?x?xf32> {
@@ -157,6 +185,34 @@
// -----
+func.func @scatter_tensor_static_implicit_indices(
+ %original: tensor<128x3xf32>, %indices: tensor<48xi32>,
+ %update: tensor<48x3xf32>) -> tensor<128x3xf32> {
+ %0 = iree_linalg_ext.scatter
+ dimension_map = [0]
+ unique_indices(true)
+ ins(%update, %indices : tensor<48x3xf32>, tensor<48xi32>)
+ outs(%original: tensor<128x3xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = arith.addf %arg1, %arg2 : f32
+ iree_linalg_ext.yield %1 : f32
+ } -> tensor<128x3xf32>
+ return %0 : tensor<128x3xf32>
+}
+// CHECK-LABEL: func.func @scatter_tensor_static_implicit_indices(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK: dimension_map = [0]
+// CHECK-SAME: unique_indices(true)
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @scatter_tensor_multi_index_depth(
%original: tensor<1x128x3xf32>, %indices: tensor<48x2xi32>,
%update: tensor<48x3xf32>) -> tensor<1x128x3xf32> {
@@ -270,12 +326,12 @@
// -----
func.func @scatter_update_scalar_1D(
- %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+ %original: tensor<8xi32>, %indices: tensor<3xi32>,
%updates: tensor<3xi32>) -> tensor<8xi32> {
%0 = iree_linalg_ext.scatter
dimension_map = [0]
unique_indices(true)
- ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>)
+ ins(%updates, %indices : tensor<3xi32>, tensor<3xi32>)
outs(%original : tensor<8xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
@@ -483,11 +539,11 @@
func.func @scatter_update_slice_2D(
%original: tensor<4x?xi32>, %indices: tensor<1x1xi32>,
- %updates: tensor<1x3xi32>) -> tensor<4x?xi32> {
+ %updates: tensor<1x?xi32>) -> tensor<4x?xi32> {
%0 = iree_linalg_ext.scatter
dimension_map = [0]
unique_indices(true)
- ins(%updates, %indices : tensor<1x3xi32>, tensor<1x1xi32>)
+ ins(%updates, %indices : tensor<1x?xi32>, tensor<1x1xi32>)
outs(%original : tensor<4x?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 6ad28e5..b49f1e2 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -8,7 +8,6 @@
// The content of this file is adapted from linalg's ElemenwiseOpFusion.cpp and
// modified to work with LinalgExt ops, specifically `LinalgExt::AttentionOp`.
-#include <optional>
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
@@ -19,6 +18,9 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
+#include <cstdint>
+#include <optional>
+
namespace mlir::iree_compiler::IREE::LinalgExt {
static bool
@@ -315,13 +317,7 @@
ReshapeOperandInfo updateInfo;
updateInfo.originalShape = scatterOp.getUpdateType().getShape();
llvm::append_range(updateInfo.operandToIterationSpace,
- llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
- updateInfo.operandToIterationSpace.append(
- updateRank - (rankOfContiguousSlice + scatterOp.getBatchRank()),
- ReshapeOperandInfo::kNoMapping);
- llvm::append_range(
- updateInfo.operandToIterationSpace,
- llvm::seq(updateRank - rankOfContiguousSlice, updateRank));
+ llvm::seq<int64_t>(0, updateRank));
infos.push_back(std::move(updateInfo));
}
@@ -331,8 +327,9 @@
indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
llvm::append_range(indicesInfo.operandToIterationSpace,
llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
- indicesInfo.operandToIterationSpace.push_back(
- ReshapeOperandInfo::kNoMapping);
+ if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank())
+ indicesInfo.operandToIterationSpace.push_back(
+ ReshapeOperandInfo::kNoMapping);
infos.push_back(std::move(indicesInfo));
}
@@ -629,71 +626,30 @@
linalg::ControlFusionFn controlFoldingReshapes;
};
-/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
-/// The dims in `update` between the batch dims and the continuous slice
-/// represent the indexed dimensions. Remove the leading unit dims from the
-/// indexed dims.
-struct FoldScatterNonIterationUnitDims final
- : public OpRewritePattern<ScatterOp> {
- FoldScatterNonIterationUnitDims(MLIRContext *context,
- linalg::ControlDropUnitDims options,
- PatternBenefit benefit = 1)
- : OpRewritePattern<ScatterOp>(context, benefit),
- options(std::move(options)) {}
-
+struct DropScatterUnitIndexDepth final : public OpRewritePattern<ScatterOp> {
+ using OpRewritePattern<ScatterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp scatterOp,
PatternRewriter &rewriter) const override {
- if (options.rankReductionStrategy !=
- linalg::ControlDropUnitDims::RankReductionStrategy::
- ReassociativeReshape) {
- return rewriter.notifyMatchFailure(
- scatterOp, "Only reassociative reshape strategy supported");
- }
- llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
- const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();
-
- // Find the number of leading unit dimensions
- int64_t rankOfContiguousSlice =
- scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
- ArrayRef<int64_t> indexedDims =
- scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice);
- int64_t numDimsToDrop =
- llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) -
- scatterOp.getUpdateSliceShape().begin() - 1;
-
+ llvm::ArrayRef<int64_t> indicesShape =
+ scatterOp.getIndicesType().getShape();
int64_t batchRank = scatterOp.getBatchRank();
- llvm::erase_if(canDrop, [&](unsigned dimPos) {
- return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
- });
- if (canDrop.empty()) {
+ if (indicesShape.size() == batchRank || indicesShape.back() != 1) {
return failure();
}
-
- SmallVector<int64_t> droppedUpdateShape;
- droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
- for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
- if (!llvm::is_contained(canDrop, idx)) {
- droppedUpdateShape.push_back(dimLen);
- }
+ SmallVector<ReassociationIndices> reassoc;
+ reassoc.reserve(indicesShape.size());
+ for (auto i : llvm::seq<int64_t>(0, batchRank - 1)) {
+ reassoc.emplace_back(1, i);
}
-
- auto reassoc =
- getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
- assert(reassoc.has_value() && "expected reassociation to be valid");
+ reassoc.push_back(ReassociationIndices{batchRank - 1, batchRank});
auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
- scatterOp.getLoc(),
- RankedTensorType::get(droppedUpdateShape,
- scatterOp.getUpdateType().getElementType()),
- scatterOp.getUpdates(), reassoc.value());
+ scatterOp.getLoc(), scatterOp.getIndices(), reassoc);
rewriter.modifyOpInPlace(scatterOp, [&]() {
- scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult());
+ scatterOp.setOperand(ScatterOp::kIndicesOpNum, collapseOp.getResult());
});
return success();
}
-
-private:
- linalg::ControlDropUnitDims options;
};
struct FoldScatterWithProducerReshapeByExpansion final
@@ -997,7 +953,7 @@
void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
- patterns.add<FoldScatterNonIterationUnitDims>(patterns.getContext(), options);
+ patterns.add<DropScatterUnitIndexDepth>(patterns.getContext());
}
} // namespace mlir::iree_compiler::IREE::LinalgExt
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index c5b624b..0cfd0b7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -356,42 +356,6 @@
// -----
-func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) {
- iree_linalg_ext.scatter
- dimension_map = [0, 1, 2]
- unique_indices(true)
- ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>)
- outs(%arg0 : memref<2x64x12xf32>) {
- ^bb0(%arg3: f32, %arg4: f32):
- iree_linalg_ext.yield %arg4 : f32
- }
- return
-}
-
-// CHECK-LABEL: @scatter_partial_slices
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] {
-// CHECK-NEXT: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK-NEXT: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] {
-// CHECK-NEXT: %[[LOAD0:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C0]]] : memref<2x3xi32>
-// CHECK-NEXT: %[[CAST0:.+]] = arith.index_cast %[[LOAD0]] : i32 to index
-// CHECK-NEXT: %[[LOAD1:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C1]]] : memref<2x3xi32>
-// CHECK-NEXT: %[[CAST1:.+]] = arith.index_cast %[[LOAD1]] : i32 to index
-// CHECK-NEXT: %[[ADD1:.+]] = arith.addi %[[CAST1]], %[[ARG4]] : index
-// CHECK-NEXT: %[[LOAD2:.+]] = memref.load %[[ARG1]][%[[ARG3]], %[[C2]]] : memref<2x3xi32>
-// CHECK-NEXT: %[[CAST2:.+]] = arith.index_cast %[[LOAD2]] : i32 to index
-// CHECK-NEXT: %[[ADD2:.+]] = arith.addi %[[CAST2]], %[[ARG5]] : index
-// CHECK-NEXT: %[[LOAD3:.+]] = memref.load %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
-// CHECK-NEXT: memref.store %[[LOAD3]], %[[ARG0]][%[[CAST0]], %[[ADD1]], %[[ADD2]]] : memref<2x64x12xf32>
-
-// -----
-
func.func @fft_1D(%real: memref<16xf32>, %imag: memref<16xf32>) {
%stage = arith.constant 1 : index
iree_linalg_ext.fft
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 403f5e2..53cb5e4 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -198,9 +198,9 @@
func.func @scatter_batch_2D(
%original: memref<?xi32>, %indices: memref<?x?x1xi32>,
- %updates: memref<?x?x?xi32>) {
+ %updates: memref<?x?xi32>) {
iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
- ins(%updates, %indices : memref<?x?x?xi32>, memref<?x?x1xi32>)
+ ins(%updates, %indices : memref<?x?xi32>, memref<?x?x1xi32>)
outs(%original : memref<?xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
iree_linalg_ext.yield %arg0 : i32
@@ -221,27 +221,23 @@
// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]]
-// CHECK-DAG: %[[D2:.+]] = memref.dim %[[UPDATES]], %[[C2]]
// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
// CHECK: %[[SZ:.+]] = affine.min #[[MAP]](%[[I]])[%[[D1]]]
// CHECK: %[[UPDATES_TILE:.+]] = memref.subview
-// CHECK-SAME: %[[UPDATES]][0, %[[I]], 0]
-// CHECK-SAME: [%[[D0]], %[[SZ]], %[[D2]]]
+// CHECK-SAME: %[[UPDATES]][0, %[[I]]]
+// CHECK-SAME: [%[[D0]], %[[SZ]]]
// CHECK: %[[INDICES_TILE:.+]] = memref.subview
// CHECK-SAME: %[[INDICES]][0, %[[I]], 0]
// CHECK-SAME: [%[[D0]], %[[SZ]], 1]
// CHECK: %[[ORIGINAL_TILE:.+]] = memref.subview
// CHECK-SAME: %[[ORIGINAL]][0]
-// CHECK-SAME: [%[[D2]]]
-// CHECK: %[[ORIG_CAST:.+]] = memref.cast %[[ORIGINAL_TILE]]
// CHECK: iree_linalg_ext.scatter
// CHECK-SAME: unique_indices(true)
// CHECK-SAME: ins(%[[UPDATES_TILE]], %[[INDICES_TILE]]
-// CHECK-SAME: outs(%[[ORIG_CAST]]
+// CHECK-SAME: outs(%[[ORIGINAL_TILE]]
// -----
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
index 244b02e..247fc7e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/attention_fuse_by_expansion.mlir
@@ -507,8 +507,8 @@
// -----
util.func @scatter_collapse_updates_partial(%arg0: tensor<4x?x2x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<10x16x4x128xf16>) -> tensor<10x16x4x128xf16> {
- %collapsed = tensor.collapse_shape %arg0[[0, 1], [2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x4x16x4x128xf16>
- %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
+ %collapsed = tensor.collapse_shape %arg0[[0, 1, 2, 3], [4], [5], [6]] : tensor<4x?x2x2x16x4x128xf16> into tensor<?x16x4x128xf16>
+ %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<10x16x4x128xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<10x16x4x128xf16>
@@ -519,10 +519,9 @@
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
-// CHECK-DAG: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x1xi32>
-// CHECK-DAG: %[[UPDATES:.+]] = tensor.collapse_shape %[[ARG0]] {{.*}} tensor<4x?x2x2x16x4x128xf16> into tensor<4x?x4x16x4x128xf16>
+// CHECK-DAG: %[[INDICES:.+]] = tensor.expand_shape %[[ARG1]] {{.*}} tensor<?x1xi32> into tensor<4x?x2x2x1xi32>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
-// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: ins(%[[ARG0]], %[[INDICES]]
// CHECK-SAME: outs(%[[ARG2]]
// CHECK: util.return %[[SCATTER]]
diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
index 62dbba7..e92ed2e 100644
--- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
+++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir
@@ -109,44 +109,21 @@
// -----
-util.func public @scatter0(%arg0: tensor<?x1x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
- %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
- ^bb0(%arg3: f16, %arg4: f16):
- iree_linalg_ext.yield %arg3 : f16
- } -> tensor<?x2x16x4x128xf16>
- util.return %0 : tensor<?x2x16x4x128xf16>
+util.func public @scatter(%arg0 : tensor<4xi64>, %arg1 : tensor<4x1xi32>, %arg2 : tensor<4xi64>) -> tensor<4xi64> {
+ %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%arg0, %arg1: tensor<4xi64>, tensor<4x1xi32>) outs(%arg2 : tensor<4xi64>) {
+ ^bb0(%arg3: i64, %arg4: i64):
+ %16 = arith.addi %arg4, %arg3 : i64
+ iree_linalg_ext.yield %16 : i64
+ } -> tensor<4xi64>
+ util.return %0 : tensor<4xi64>
}
-// CHECK-LABEL: func public @scatter0
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
-// CHECK-SAME: to tensor<?x2x16x4x128xf16>
+// CHECK-LABEL: func public @scatter
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG1]]
+// CHECK-SAME: tensor<4x1xi32> into tensor<4xi32>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
-// CHECK-SAME: ins(%[[COLLAPSE]]
-
-// -----
-
-util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
- %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x2xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
- ^bb0(%arg3: f16, %arg4: f16):
- iree_linalg_ext.yield %arg3 : f16
- } -> tensor<?x2x16x4x128xf16>
- util.return %0 : tensor<?x2x16x4x128xf16>
-}
-// CHECK-LABEL: func public @scatter1
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
-// CHECK-SAME: to tensor<?x16x4x128xf16>
-// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
-// CHECK-SAME: ins(%[[COLLAPSE]]
-
-// -----
-
-// TODO: remove other unit dims.
-util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor<?x2x1x4x128xf16>) -> tensor<?x2x1x4x128xf16> {
- %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor<?x2x1x4x128xf16>) {
- ^bb0(%arg3: f16, %arg4: f16):
- iree_linalg_ext.yield %arg3 : f16
- } -> tensor<?x2x1x4x128xf16>
- util.return %0 : tensor<?x2x1x4x128xf16>
-}
-// CHECK-LABEL: func public @scatter_noop
-// CHECK-NOT: tensor.collapse_shape
-// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins(%[[ARG0]], %[[COLLAPSED]]
+// CHECK-SAME: outs(%[[ARG2]]
+// CHECK: util.return %[[SCATTER]]
diff --git a/tests/e2e/linalg_ext_ops/scatter.mlir b/tests/e2e/linalg_ext_ops/scatter.mlir
index 808475a..8980f20 100644
--- a/tests/e2e/linalg_ext_ops/scatter.mlir
+++ b/tests/e2e/linalg_ext_ops/scatter.mlir
@@ -15,39 +15,6 @@
return
}
-func.func @scatter_2d_origin_slice_horizontal() {
- %original = util.unfoldable_constant dense<0> : tensor<2x2xi32>
- %update = util.unfoldable_constant dense<1> : tensor<1x2xi32>
- %indices = util.unfoldable_constant dense<0> : tensor<1x2xi32>
- %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
- ins(%update, %indices : tensor<1x2xi32>, tensor<1x2xi32>)
- outs(%original : tensor<2x2xi32>) {
- ^bb0(%arg0: i32, %arg1: i32):
- iree_linalg_ext.yield %arg0 : i32
- } -> tensor<2x2xi32>
-
- check.expect_eq_const(%result, dense<[[1, 1], [0, 0]]> : tensor<2x2xi32>) : tensor<2x2xi32>
-
- return
-}
-
-
-func.func @scatter_2d_origin_slice_vertical() {
- %original = util.unfoldable_constant dense<0> : tensor<2x2xi32>
- %update = util.unfoldable_constant dense<1> : tensor<1x2x1xi32>
- %indices = util.unfoldable_constant dense<0> : tensor<1x2xi32>
- %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
- ins(%update, %indices : tensor<1x2x1xi32>, tensor<1x2xi32>)
- outs(%original : tensor<2x2xi32>) {
- ^bb0(%arg0: i32, %arg1: i32):
- iree_linalg_ext.yield %arg0 : i32
- } -> tensor<2x2xi32>
-
- check.expect_eq_const(%result, dense<[[1, 0], [1, 0]]> : tensor<2x2xi32>) : tensor<2x2xi32>
-
- return
-}
-
func.func @scatter_2d_offset() {
%original = util.unfoldable_constant dense<0> : tensor<2x2xi32>
%update = util.unfoldable_constant dense<1> : tensor<1xi32>
@@ -127,35 +94,3 @@
return
}
-
-func.func @scatter_2d_multiple_slice() {
- %original = util.unfoldable_constant dense<0> : tensor<3x3xi32>
- %update = util.unfoldable_constant dense<1> : tensor<2x2xi32>
- %indices = util.unfoldable_constant dense<[[0, 1], [1, 0]]> : tensor<2x2xi32>
- %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true)
- ins(%update, %indices : tensor<2x2xi32>, tensor<2x2xi32>)
- outs(%original : tensor<3x3xi32>) {
- ^bb0(%arg0: i32, %arg1: i32):
- iree_linalg_ext.yield %arg0 : i32
- } -> tensor<3x3xi32>
-
- check.expect_eq_const(%result, dense<[[0, 1, 1], [1, 1, 0], [0, 0, 0]]> : tensor<3x3xi32>) : tensor<3x3xi32>
-
- return
-}
-
-func.func @scatter_2d_multiple_slice_transpose() {
- %original = util.unfoldable_constant dense<0> : tensor<3x4xi32>
- %update = util.unfoldable_constant dense<1> : tensor<2x2xi32>
- %indices = util.unfoldable_constant dense<[[0, 1], [2, 0]]> : tensor<2x2xi32>
- %result = iree_linalg_ext.scatter dimension_map = [1, 0] unique_indices(true)
- ins(%update, %indices : tensor<2x2xi32>, tensor<2x2xi32>)
- outs(%original : tensor<3x4xi32>) {
- ^bb0(%arg0: i32, %arg1: i32):
- iree_linalg_ext.yield %arg0 : i32
- } -> tensor<3x4xi32>
-
- check.expect_eq_const(%result, dense<[[0, 0, 1, 1], [1, 1, 0, 0], [0, 0, 0, 0]]> : tensor<3x4xi32>) : tensor<3x4xi32>
-
- return
-}