[LinalgExt] Add mask operand to linalg_ext gather/scatter (#24126)
Currently gather/scatter operations have no way of representing masked
read/writes. For gather, this can be bypassed by reading from a sentinel
value like 0, but for scatter, this is impossible. Scatters generally
need masking to be correct.
This is generally required in LLM models where the input tokens may be
padded and some of them may need to be masked out accordingly before
writing to the KVCache (masked tokens shouldn't contribute to anything).
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
index b6c9c8c..448ca99 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
@@ -283,8 +283,8 @@
}
auto scatterOp = IREE::LinalgExt::ScatterOp::create(
- rewriter, op.getLoc(), originalType, updates, indices, original,
- scatterDimMap, op.getUniqueIndices());
+ rewriter, op.getLoc(), originalType, updates, indices,
+ /*mask=*/Value(), original, scatterDimMap, op.getUniqueIndices());
rewriter.inlineRegionBefore(op.getUpdateComputation(),
scatterOp.getRegion(),
diff --git a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp
index 4b706bc..dbb5a00 100644
--- a/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp
@@ -137,6 +137,7 @@
// Create the LinalgExt scatter operation.
auto scatter = IREE::LinalgExt::ScatterOp::create(
builder, TypeRange{values.getType()}, updates, indices,
+ /*mask=*/Value(),
/*original=*/values, builder.getDenseI64ArrayAttr({0, 1}),
builder.getBoolAttr(true));
diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index 8ef8d72..85121b5 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -92,7 +92,8 @@
auto scatterOp = IREE::LinalgExt::ScatterOp::create(
rewriter, op.getLoc(), op->getResultTypes(),
/*updates=*/updateVal, /*indices=*/indicesVal,
- /*original=*/op.getOutputs()[0], dimMap, op.getUniqueIndices());
+ /*mask=*/Value(), /*original=*/op.getOutputs()[0], dimMap,
+ op.getUniqueIndices());
rewriter.inlineRegionBefore(op.getRegion(), scatterOp.getRegion(),
scatterOp.getRegion().begin());
rewriter.replaceOp(op, scatterOp->getResults());
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
index 0777484..1148f25 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUConvertToCoalescedDMA.cpp
@@ -681,6 +681,10 @@
LogicalResult matchAndRewrite(IREE::LinalgExt::GatherOp gatherOp,
PatternRewriter &rewriter) const override {
+ // TODO: Add support for masked gather.
+ if (gatherOp.getMask()) {
+ return failure();
+ }
auto forallOp = gatherOp->getParentOfType<scf::ForallOp>();
if (!hasWarpMapping(forallOp)) {
return failure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagation.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagation.cpp
index ce040b9..60c7bcd 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TypePropagation.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagation.cpp
@@ -372,19 +372,22 @@
LogicalResult
matchAndRewrite(IREE::LinalgExt::ScatterOp scatterOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
- auto opOperands = scatterOp->getOpOperands();
- Type inputType = opOperands[0].get().getType();
+ Type inputType = scatterOp.getUpdates().getType();
Type legalizedInputType = this->getTypeConverter()->convertType(inputType);
+ Type maskType =
+ scatterOp.getMask() ? scatterOp.getMask().getType() : Type();
+ Type legalizedMaskType =
+ maskType ? this->getTypeConverter()->convertType(maskType) : Type();
+ Type resultType = scatterOp.getOriginal().getType();
+ Type legalizedResultType =
+ this->getTypeConverter()->convertType(resultType);
- if (inputType == legalizedInputType) {
+ if (inputType == legalizedInputType && maskType == legalizedMaskType &&
+ resultType == legalizedResultType) {
return scatterOp.emitOpError(
"unexpected all types legal within conversion pattern");
}
- Type resultType = opOperands[2].get().getType();
- Type legalizedResultType =
- this->getTypeConverter()->convertType(resultType);
-
// Create a clone of the operation without cloning its regions.
auto modifiedOp =
cast<IREE::LinalgExt::ScatterOp>(mlir::cloneWithoutRegions(
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
index 77c01e8..64662cc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/type_propagation.mlir
@@ -451,6 +451,83 @@
#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @scatter_mask_only_illegal() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8xi32>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8x1xi32>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8xi8>>
+ %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<3xi32>>
+ %4 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0], sizes = [8], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8xi32>> -> tensor<8xi32>
+ %5 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [8, 1], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8x1xi32>> -> tensor<8x1xi32>
+ %6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0], sizes = [8], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8xi8>> -> tensor<8xi8>
+ %7 = arith.trunci %6 : tensor<8xi8> to tensor<8xi1>
+ %8 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0], sizes = [3], strides = [1] : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<3xi32>> -> tensor<3xi32>
+ %9 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%4, %5, %7 : tensor<8xi32>, tensor<8x1xi32>, tensor<8xi1>)
+ outs(%8 : tensor<3xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32):
+ iree_linalg_ext.yield %arg0 : i32
+ } -> tensor<3xi32>
+ iree_tensor_ext.dispatch.tensor.store %9, %3, offsets = [0], sizes = [3], strides = [1] : tensor<3xi32> -> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<3xi32>>
+ return
+}
+// CHECK-LABEL: func.func @scatter_mask_only_illegal()
+// CHECK-DAG: %[[UPDATES:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<8xi32>
+// CHECK-DAG: %[[INDICES:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<8x1xi32>
+// CHECK-DAG: %[[MASK:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<8xi8>
+// CHECK-DAG: %[[OUT:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<3xi32>
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]], %[[MASK]] : tensor<8xi32>, tensor<8x1xi32>, tensor<8xi8>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<3xi32>)
+// CHECK-NEXT: ^bb0(%[[ARG0:[a-zA-Z0-9]+]]: i32, %[[ARG1:[a-zA-Z0-9]+]]: i32)
+// CHECK: iree_linalg_ext.yield %[[ARG0]] : i32
+// CHECK: iree_tensor_ext.dispatch.tensor.store %[[SCATTER]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>,
+ #hal.pipeline.binding<storage_buffer>
+]>
+func.func @gather_mask_only_illegal() {
+ %c0 = arith.constant 0 : index
+ %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x4xi32>>
+ %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8x1xi32>>
+ %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8xi8>>
+ %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<8x4xi32>>
+ %4 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [16, 4], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<16x4xi32>> -> tensor<16x4xi32>
+ %5 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [8, 1], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8x1xi32>> -> tensor<8x1xi32>
+ %6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0], sizes = [8], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<8xi8>> -> tensor<8xi8>
+ %7 = arith.trunci %6 : tensor<8xi8> to tensor<8xi1>
+ %8 = iree_tensor_ext.dispatch.tensor.load %3, offsets = [0, 0], sizes = [8, 4], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readwrite:tensor<8x4xi32>> -> tensor<8x4xi32>
+ %9 = iree_linalg_ext.gather
+ dimension_map = [0]
+ ins(%4, %5, %7 : tensor<16x4xi32>, tensor<8x1xi32>, tensor<8xi1>)
+ outs(%8 : tensor<8x4xi32>) -> tensor<8x4xi32>
+ iree_tensor_ext.dispatch.tensor.store %9, %3, offsets = [0, 0], sizes = [8, 4], strides = [1, 1] : tensor<8x4xi32> -> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<8x4xi32>>
+ return
+}
+// CHECK-LABEL: func.func @gather_mask_only_illegal()
+// CHECK-DAG: %[[SOURCE:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<16x4xi32>
+// CHECK-DAG: %[[INDICES:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<8x1xi32>
+// CHECK-DAG: %[[MASK:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<8xi8>
+// CHECK-DAG: %[[OUT:.+]] = iree_tensor_ext.dispatch.tensor.load %{{.+}} : {{.+}} -> tensor<8x4xi32>
+// CHECK: %[[GATHER:.+]] = iree_linalg_ext.gather
+// CHECK-SAME: dimension_map = [0]
+// CHECK-SAME: ins(%[[SOURCE]], %[[INDICES]], %[[MASK]] : tensor<16x4xi32>, tensor<8x1xi32>, tensor<8xi8>)
+// CHECK-SAME: outs(%[[OUT]] : tensor<8x4xi32>) -> tensor<8x4xi32>
+// CHECK: iree_tensor_ext.dispatch.tensor.store %[[GATHER]]
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+ #hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
func.func @sort() {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir
index 7b5d2aa..a782891 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir
@@ -173,6 +173,28 @@
// -----
+func.func @linalg_ext_gather_no_vectorize_operand_mask(%source : tensor<1024x128xi32>,
+ %indices : tensor<10xi32>,
+ %mask : tensor<10xi1>,
+ %seed : tensor<10x128xi32>)
+ -> tensor<10x128xi32> {
+ %result = iree_linalg_ext.gather dimension_map = [0]
+ ins(%source, %indices, %mask : tensor<1024x128xi32>, tensor<10xi32>, tensor<10xi1>)
+ outs(%seed : tensor<10x128xi32>) -> tensor<10x128xi32>
+ return %result : tensor<10x128xi32>
+}
+// CHECK-LABEL: @linalg_ext_gather_no_vectorize_operand_mask
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[GATHER:.+]] = iree_linalg_ext.gather
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]]
+// CHECK-SAME: outs(%[[ARG3]]
+// CHECK: return %[[GATHER]]
+
+// -----
+
func.func @linalg_ext_gather_unit_dim(%source : tensor<1024x128xi32>, %indices : tensor<10x1xi32>) -> (tensor<10x128xi32>) {
%empty = tensor.empty() : tensor<10x128xi32>
%result = iree_linalg_ext.gather dimension_map = [0]
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp
index 10e8571..5d05601 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/VectorizableOpInterface.cpp
@@ -58,6 +58,10 @@
ArrayRef<bool> scalableDims,
DictionaryAttr options) const {
auto gatherOp = cast<IREE::LinalgExt::GatherOp>(op);
+ // TODO: Support operand masks by plumbing them through transfer_gather.
+ if (gatherOp.getMask()) {
+ return false;
+ }
// TODO: Support indexDepth > 1 by splitting the innermost dim of
// `indices` into `indexDepth` vectors so that each independent index can
// be passed to the transfer_gather op.
@@ -72,6 +76,10 @@
ArrayRef<bool> scalableDims,
DictionaryAttr options) const {
auto gatherOp = cast<IREE::LinalgExt::GatherOp>(op);
+ // TODO: Support operand masks by plumbing them through transfer_gather.
+ if (gatherOp.getMask()) {
+ return failure();
+ }
int64_t batchRank = gatherOp.getBatchRank();
Location loc = gatherOp.getLoc();
RewriterBase::InsertionGuard g(rewriter);
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index a10cfbc..316338e 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -143,6 +143,21 @@
});
}
+static AffineMap getLeadingDimsProjectionMap(MLIRContext *ctx, int64_t dimCount,
+ int64_t projectedDimCount) {
+ SmallVector<AffineExpr> exprs;
+ exprs.reserve(projectedDimCount);
+ for (int64_t i = 0; i < projectedDimCount; ++i) {
+ exprs.push_back(getAffineDimExpr(i, ctx));
+ }
+ return AffineMap::get(dimCount, /*symbolCount=*/0, exprs, ctx);
+}
+
+static bool isSupportedMaskElementType(Type type) {
+ auto intType = dyn_cast<IntegerType>(type);
+ return intType && (intType.getWidth() == 1 || intType.getWidth() == 8);
+}
+
/// Helper function to verify both `scatter` and `gather`. Since both ops share
/// the same semantics, we can use the same function to verify them. Note: this
/// is written from the perspective of `scatter` op. For gather, `updateType`
@@ -219,6 +234,26 @@
"size of dimension map must match the last dimension of indices");
}
+ if (std::optional<ShapedType> maybeMaskType = op.getMaskType()) {
+ auto maskType = *maybeMaskType;
+ if (!isSupportedMaskElementType(maskType.getElementType())) {
+ return op->emitOpError(
+ "expected mask to have i1 or storage-legalized i8 element type");
+ }
+ if (maskType.getRank() != static_cast<int64_t>(batchRank)) {
+ return op->emitOpError("expected mask rank to match batch rank");
+ }
+ for (auto dim : llvm::seq<int64_t>(0, static_cast<int64_t>(batchRank))) {
+ if (maskType.isDynamicDim(dim) || updateType.isDynamicDim(dim)) {
+ continue;
+ }
+ if (maskType.getDimSize(dim) != updateType.getDimSize(dim)) {
+ return op->emitOpError("mask shape must match batch dimensions at dim#")
+ << dim;
+ }
+ }
+ }
+
{
for (auto idx : llvm::seq<int64_t>(0, sliceRank)) {
int64_t updateDim = idx + batchRank;
@@ -453,9 +488,15 @@
SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
Builder builder(getContext());
- return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
- builder.getMultiDimIdentityMap(getIndicesType().getRank()),
- /*output=*/AffineMap(nullptr)};
+ SmallVector<AffineMap> maps = {
+ builder.getMultiDimIdentityMap(getUpdateType().getRank()),
+ builder.getMultiDimIdentityMap(getIndicesType().getRank())};
+ if (getMask()) {
+ maps.push_back(getLeadingDimsProjectionMap(
+ getContext(), getUpdateType().getRank(), getBatchRank()));
+ }
+ maps.push_back(/*output=*/AffineMap(nullptr));
+ return maps;
}
SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
@@ -484,10 +525,15 @@
SmallVector<AffineMap> GatherOp::getIndexingMapsForOperands() {
Builder builder(getContext());
- return SmallVector<AffineMap>{
+ SmallVector<AffineMap> maps = {
AffineMap(nullptr),
- builder.getMultiDimIdentityMap(getIndicesType().getRank()),
- builder.getMultiDimIdentityMap(getOutputType().getRank())};
+ builder.getMultiDimIdentityMap(getIndicesType().getRank())};
+ if (getMask()) {
+ maps.push_back(getLeadingDimsProjectionMap(
+ getContext(), getOutputType().getRank(), getBatchRank()));
+ }
+ maps.push_back(builder.getMultiDimIdentityMap(getOutputType().getRank()));
+ return maps;
}
SmallVector<AffineMap> GatherOp::getIndexingMapsForResults() {
@@ -502,7 +548,7 @@
LogicalResult matchAndRewrite(IREE::LinalgExt::GatherOp gatherOp,
PatternRewriter &rewriter) const override {
// TODO: support memref case.
- if (!gatherOp.hasPureTensorSemantics()) {
+ if (!gatherOp.hasPureTensorSemantics() || gatherOp.getMask()) {
return failure();
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 01f198d..980c0bc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -59,7 +59,8 @@
"getTiledImplementationFromOperandTiles"]>]> {
let summary = [{Scatters an input in slices based on a tensor of indices.}];
let description = [{
- Takes two `inputs` (`update` and `indices`) and `outputs` value (`original`).
+ Takes two or three `inputs` (`update`, `indices`, and optional `mask`) and
+ one `outputs` value (`original`).
The operation updates the value at the slices specified by `indices` by
combining the current value with the value in `updates` using the computation
specified in `region`. The `region` specifies a binary operation
@@ -88,6 +89,11 @@
`update` gets scattered to.
Where `rank(original) = rank(update_slice) + index_depth`
+ If the optional operand `mask` is present, it must have batch shape and a
+ boolean element type (`i1`, or storage-legalized `i8` in backend
+ pipelines). A `false` mask value suppresses the corresponding scatter
+ update, leaving `original` unchanged for that batch element.
+
The unique_indices attribute carries the information whether all the
indices are unique. If `unique_indices` is `true` and two or more updates
scatter to the same location in `original` the final value in `original` is
@@ -100,6 +106,7 @@
let arguments = (ins
AnyRankedTensorOrMemRef:$updates,
AnyRankedTensorOrMemRef:$indices,
+ Optional<AnyRankedTensorOrMemRef>:$mask,
AnyRankedTensorOrMemRef:$original,
DenseI64ArrayAttr:$dimension_map,
DefaultValuedAttr<BoolAttr, "true">:$unique_indices
@@ -109,7 +116,7 @@
let assemblyFormat = [{
attr-dict `dimension_map` `=` $dimension_map
`unique_indices` `(` $unique_indices `)`
- `ins` `(` $updates `,` $indices `:` type($updates) `,` type($indices) `)`
+ `ins` `(` $updates `,` $indices (`,` $mask^)? `:` type($updates) `,` type($indices) (`,` type($mask)^)? `)`
`outs` `(` $original `:` type($original) `)`
$region (`->` type($results)^)?
}];
@@ -126,6 +133,13 @@
return cast<ShapedType>(getIndices().getType());
}
+ std::optional<ShapedType> getMaskType() {
+ if (Value mask = getMask()) {
+ return cast<ShapedType>(mask.getType());
+ }
+ return std::nullopt;
+ }
+
ShapedType getOriginalType() {
return cast<ShapedType>(getOriginal().getType());
}
@@ -174,8 +188,9 @@
"generateResultTileValue"]>]> {
let summary = [{Gathers slices from a source based on a tensor of indices.}];
let description = [{
- Takes two inputs (`source` and `indices`) and outputs value (`output`).
- The operation returns the value at the slices specified by `indices`.
+ Takes two or three inputs (`source`, `indices`, and optional `mask`) and
+ one output value (`output`). The operation returns the value at the slices
+ specified by `indices`.
The size of the `dimension_map` attribute is used to determine how many
indices are used to index into `source`, i.e. `index_depth`. The
@@ -187,17 +202,23 @@
`source` into `output` using the indices in `indices`. See the documentation
on `iree_linalg_ext.scatter` for more details regarding the indexing/shape
semantics.
+
+ If the optional operand `mask` is present, it must have batch shape and a
+ boolean element type (`i1`, or storage-legalized `i8` in backend
+ pipelines). A `false` mask value suppresses the corresponding gather
+ update, leaving `output` unchanged for that batch element.
}];
let arguments = (ins
AnyRankedTensorOrMemRef:$source,
AnyRankedTensorOrMemRef:$indices,
+ Optional<AnyRankedTensorOrMemRef>:$mask,
AnyRankedTensorOrMemRef:$output,
DenseI64ArrayAttr:$dimension_map
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let assemblyFormat = [{
attr-dict `dimension_map` `=` $dimension_map
- `ins` `(` $source `,` $indices `:` type($source) `,` type($indices) `)`
+ `ins` `(` $source `,` $indices (`,` $mask^)? `:` type($source) `,` type($indices) (`,` type($mask)^)? `)`
`outs` `(` $output `:` type($output) `)`
(`->` type($results)^)?
}];
@@ -227,6 +248,13 @@
return cast<ShapedType>(getIndices().getType());
}
+ std::optional<ShapedType> getMaskType() {
+ if (Value mask = getMask()) {
+ return cast<ShapedType>(mask.getType());
+ }
+ return std::nullopt;
+ }
+
ShapedType getOutputType(){
return cast<ShapedType>(getOutput().getType());
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
index 35b0ea0..5ef01ea 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp
@@ -77,6 +77,15 @@
return 0;
}
+static Value normalizeMaskValue(OpBuilder &builder, Location loc, Value mask) {
+ auto intType = dyn_cast<IntegerType>(mask.getType());
+ assert(intType && "expected integer mask type");
+ if (intType.getWidth() == 1) {
+ return mask;
+ }
+ return arith::TruncIOp::create(builder, loc, builder.getI1Type(), mask);
+}
+
/// Method similar to `LinalgOp`s that concatenates shapes of all operands.
static SmallVector<OpFoldResult>
createFlatListOfOperandDims(OpBuilder &b, Location loc, Operation *op) {
@@ -166,6 +175,22 @@
Value tiledIndices = indicesSlice->getResult(0);
slices.push_back(indicesSlice);
+ Value tiledMask;
+ if (Value mask = getMask()) {
+ std::optional<ShapedType> maskType = getMaskType();
+ if (maskType->getRank() == 0) {
+ tiledMask = mask;
+ } else {
+ SmallVector<OpFoldResult> maskOffsets(offsets.take_front(getBatchRank()));
+ SmallVector<OpFoldResult> maskSizes(sizes.take_front(getBatchRank()));
+ SmallVector<OpFoldResult> maskStrides(maskType->getRank(), oneAttr);
+ Operation *maskSlice =
+ getSlice(builder, loc, mask, maskOffsets, maskSizes, maskStrides);
+ tiledMask = maskSlice->getResult(0);
+ slices.push_back(maskSlice);
+ }
+ }
+
// Slice of the original.
SmallVector<OpFoldResult> originalOffsets, originalSizes;
if (failed(getResultTilePosition(builder, 0, offsets, sizes, originalOffsets,
@@ -184,9 +209,13 @@
if (getNumResults()) {
resultTypes.push_back(tiledOriginal.getType());
}
+ SmallVector<Value> tiledOperands = {tiledUpdate, tiledIndices};
+ if (tiledMask) {
+ tiledOperands.push_back(tiledMask);
+ }
+ tiledOperands.push_back(tiledOriginal);
Operation *tiledScatterOp =
- mlir::clone(builder, getOperation(), resultTypes,
- ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
return TilingResult{{tiledScatterOp},
SmallVector<Value>(tiledScatterOp->getResults()),
slices};
@@ -298,20 +327,37 @@
starts[dim] = ret;
}
- Value init = memref::LoadOp::create(b, loc, getOriginal(), starts);
+ auto emitScatterUpdate = [&](OpBuilder &builder, Location nestedLoc) {
+ Value init =
+ memref::LoadOp::create(builder, nestedLoc, getOriginal(), starts);
- IRMapping bvm;
- Block &block = getRegion().front();
- bvm.map(block.getArgument(0), update);
- bvm.map(block.getArgument(1), init);
- for (auto &blockOp : block.without_terminator()) {
- b.clone(blockOp, bvm);
+ IRMapping bvm;
+ Block &block = getRegion().front();
+ bvm.map(block.getArgument(0), update);
+ bvm.map(block.getArgument(1), init);
+ for (auto &blockOp : block.without_terminator()) {
+ builder.clone(blockOp, bvm);
+ }
+ // The last op is linalg_ext.yield op. Store the operand to destination.
+ memref::StoreOp::create(
+ builder, nestedLoc,
+ bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
+ getOriginal(), starts);
+ };
+
+ if (Value mask = getMask()) {
+ SmallVector<Value> maskIndices(ivs.take_front(getBatchRank()));
+ Value maskValue = memref::LoadOp::create(b, loc, mask, maskIndices);
+ maskValue = normalizeMaskValue(b, loc, maskValue);
+ scf::IfOp::create(b, loc, maskValue,
+ [&](OpBuilder &thenBuilder, Location thenLoc) {
+ emitScatterUpdate(thenBuilder, thenLoc);
+ scf::YieldOp::create(thenBuilder, thenLoc);
+ });
+ return success();
}
- // The last op is linalg_ext.yield op. Store the operand to
- // destination.
- memref::StoreOp::create(
- b, loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
- getOriginal(), starts);
+
+ emitScatterUpdate(b, loc);
return success();
}
@@ -367,6 +413,22 @@
indicesSizes, indicesStrides);
Value tiledIndices = indicesSlice->getResult(0);
+ Value tiledMask;
+ if (Value mask = getMask()) {
+ std::optional<ShapedType> maskType = getMaskType();
+ if (maskType->getRank() == 0) {
+ tiledMask = mask;
+ } else {
+ SmallVector<OpFoldResult> maskOffsets(offsets.take_front(getBatchRank()));
+ SmallVector<OpFoldResult> maskSizes(sizes.take_front(getBatchRank()));
+ SmallVector<OpFoldResult> maskStrides(maskType->getRank(), oneAttr);
+ Operation *maskSlice =
+ getSlice(builder, loc, mask, maskOffsets, maskSizes, maskStrides);
+ tiledMask = maskSlice->getResult(0);
+ slices.push_back(maskSlice);
+ }
+ }
+
// Slice of the source.
auto sourceRank = getSourceType().getRank();
auto indexDepth = getIndexDepth();
@@ -394,9 +456,13 @@
if (getNumResults()) {
resultTypes.push_back(tiledResult.getType());
}
+ SmallVector<Value> tiledOperands = {tiledSource, tiledIndices};
+ if (tiledMask) {
+ tiledOperands.push_back(tiledMask);
+ }
+ tiledOperands.push_back(tiledResult);
Operation *tiledGatherOp =
- mlir::clone(builder, getOperation(), resultTypes,
- ValueRange{tiledSource, tiledIndices, tiledResult});
+ mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
return TilingResult{
{tiledGatherOp}, SmallVector<Value>(tiledGatherOp->getResults()), slices};
}
@@ -449,11 +515,26 @@
starts[dim] = ret;
}
- Value init = memref::LoadOp::create(b, loc, getSource(), starts);
+ auto emitGatherStore = [&](OpBuilder &builder, Location nestedLoc) {
+ Value init =
+ memref::LoadOp::create(builder, nestedLoc, getSource(), starts);
+ // The last op is linalg_ext.yield op. Store the operand to destination.
+ memref::StoreOp::create(builder, nestedLoc, init, getOutput(), ivs);
+ };
- // The last op is linalg_ext.yield op. Store the operand to
- // destination.
- memref::StoreOp::create(b, loc, init, getOutput(), ivs);
+ if (Value mask = getMask()) {
+ SmallVector<Value> maskIndices(ivs.take_front(getBatchRank()));
+ Value maskValue = memref::LoadOp::create(b, loc, mask, maskIndices);
+ maskValue = normalizeMaskValue(b, loc, maskValue);
+ scf::IfOp::create(b, loc, maskValue,
+ [&](OpBuilder &thenBuilder, Location thenLoc) {
+ emitGatherStore(thenBuilder, thenLoc);
+ scf::YieldOp::create(thenBuilder, thenLoc);
+ });
+ return success();
+ }
+
+ emitGatherStore(b, loc);
return success();
}
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 5529e94..c329464 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -188,6 +188,36 @@
// -----
+func.func @scatter_mask_wrong_element_type(
+ %update : tensor<4xf32>, %indices : tensor<4x1xi32>,
+ %mask : tensor<4xi32>, %original : tensor<8xf32>) -> tensor<8xf32> {
+ // expected-error @below {{'iree_linalg_ext.scatter' op expected mask to have i1 or storage-legalized i8 element type}}
+ %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%update, %indices, %mask : tensor<4xf32>, tensor<4x1xi32>, tensor<4xi32>)
+ outs(%original : tensor<8xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ iree_linalg_ext.yield %arg1 : f32
+ } -> tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
+
+// -----
+
+func.func @scatter_mask_wrong_shape(
+ %update : tensor<4xf32>, %indices : tensor<4x1xi32>,
+ %mask : tensor<5xi1>, %original : tensor<8xf32>) -> tensor<8xf32> {
+ // expected-error @below {{'iree_linalg_ext.scatter' op mask shape must match batch dimensions at dim#0}}
+ %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%update, %indices, %mask : tensor<4xf32>, tensor<4x1xi32>, tensor<5xi1>)
+ outs(%original : tensor<8xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ iree_linalg_ext.yield %arg1 : f32
+ } -> tensor<8xf32>
+ return %0 : tensor<8xf32>
+}
+
+// -----
+
func.func @scatter_dim_mismatch(
%update : tensor<48x?x2x11xf32>, %indices : tensor<48x?x1xi32>,
%original : tensor<2x?x10xf32>) -> tensor<2x?x10xf32> {
@@ -473,6 +503,19 @@
// -----
+func.func @gather_mask_wrong_rank(
+ %source : tensor<10x10xf32>, %idx : tensor<3x1xi32>,
+ %mask : tensor<3x1xi1>, %output : tensor<3x10xf32>) -> tensor<3x10xf32> {
+ // expected-error @below {{'iree_linalg_ext.gather' op expected mask rank to match batch rank}}
+ %0 = iree_linalg_ext.gather
+ dimension_map = [0]
+ ins(%source, %idx, %mask : tensor<10x10xf32>, tensor<3x1xi32>, tensor<3x1xi1>)
+ outs(%output : tensor<3x10xf32>) -> tensor<3x10xf32>
+ return %0 : tensor<3x10xf32>
+}
+
+// -----
+
func.func @map_store_mixed_element_types(
%input: memref<4xf16>, %output: memref<4xf32>
) {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 691f9df..ba9e945 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -101,6 +101,35 @@
// -----
+func.func @scatter_tensor_masked(
+ %original: tensor<8xi32>, %indices: tensor<3x1xi32>,
+ %mask: tensor<3xi1>, %update: tensor<3xi32>) -> tensor<8xi32> {
+ %0 = iree_linalg_ext.scatter
+ dimension_map = [0]
+ unique_indices(true)
+ ins(%update, %indices, %mask : tensor<3xi32>, tensor<3x1xi32>, tensor<3xi1>)
+ outs(%original: tensor<8xi32>) {
+ ^bb0(%arg1: i32, %arg2: i32):
+ %1 = arith.addi %arg1, %arg2 : i32
+ iree_linalg_ext.yield %1 : i32
+ } -> tensor<8xi32>
+ return %0 : tensor<8xi32>
+}
+// CHECK-LABEL: func.func @scatter_tensor_masked(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<8xi32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<3x1xi32>
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9_]+]]: tensor<3xi1>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<3xi32>
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: dimension_map = [0]
+// CHECK-SAME: unique_indices(true)
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]], %[[MASK]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: iree_linalg_ext.yield %{{.+}} : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @scatter_tensor_partial_dynamic(
%original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
%update: tensor<?x10xf32>) -> tensor<?x?xf32> {
@@ -613,6 +642,28 @@
// -----
+func.func @gather_static_masked(
+ %source : tensor<10xf32>, %idx : tensor<1xi32>,
+ %mask : tensor<1xi1>, %result : tensor<1xf32>) -> tensor<1xf32> {
+ %0 = iree_linalg_ext.gather
+ dimension_map = [0]
+ ins(%source, %idx, %mask : tensor<10xf32>, tensor<1xi32>, tensor<1xi1>)
+ outs(%result : tensor<1xf32>) -> tensor<1xf32>
+ return %0 : tensor<1xf32>
+}
+// CHECK-LABEL: func.func @gather_static_masked(
+// CHECK-SAME: %[[SOURCE:[a-zA-Z0-9_]+]]: tensor<10xf32>
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: tensor<1xi32>
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9_]+]]: tensor<1xi1>
+// CHECK-SAME: %[[RESULT:[a-zA-Z0-9_]+]]: tensor<1xf32>
+// CHECK: %[[VAL:.+]] = iree_linalg_ext.gather
+// CHECK-SAME: dimension_map = [0]
+// CHECK-SAME: ins(%[[SOURCE]], %[[IDX]], %[[MASK]]
+// CHECK-SAME: outs(%[[RESULT]]
+// CHECK: return %[[VAL]]
+
+// -----
+
func.func @gather_static_2D_batch(
%source : tensor<4x3xf32>, %idx : tensor<1x1xi32>,
%result : tensor<1xf32>) -> tensor<1xf32> {
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
index 6aff8ab..0a540ab 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp
@@ -394,13 +394,21 @@
ReshapeOperandInfo indicesInfo;
indicesInfo.originalShape = getDimSizes(scatterOp.getIndices());
llvm::append_range(indicesInfo.operandToIterationSpace,
- llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
+ llvm::seq<int64_t>(scatterOp.getBatchRank()));
if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank()) {
indicesInfo.operandToIterationSpace.push_back(
ReshapeOperandInfo::kNoMapping);
}
infos.push_back(std::move(indicesInfo));
+ if (Value mask = scatterOp.getMask()) {
+ ReshapeOperandInfo maskInfo;
+ maskInfo.originalShape = getDimSizes(mask);
+ llvm::append_range(maskInfo.operandToIterationSpace,
+ llvm::seq<int64_t>(scatterOp.getBatchRank()));
+ infos.push_back(std::move(maskInfo));
+ }
+
ReshapeOperandInfo originalInfo;
originalInfo.originalShape = getDimSizes(scatterOp.getOriginal());
originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
@@ -428,13 +436,21 @@
ReshapeOperandInfo indicesInfo;
indicesInfo.originalShape = getDimSizes(gatherOp.getIndices());
llvm::append_range(indicesInfo.operandToIterationSpace,
- llvm::seq<int64_t>(0, gatherOp.getBatchRank()));
+ llvm::seq<int64_t>(gatherOp.getBatchRank()));
if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank()) {
indicesInfo.operandToIterationSpace.push_back(
ReshapeOperandInfo::kNoMapping);
}
infos.push_back(std::move(indicesInfo));
+ if (Value mask = gatherOp.getMask()) {
+ ReshapeOperandInfo maskInfo;
+ maskInfo.originalShape = getDimSizes(mask);
+ llvm::append_range(maskInfo.operandToIterationSpace,
+ llvm::seq<int64_t>(gatherOp.getBatchRank()));
+ infos.push_back(std::move(maskInfo));
+ }
+
ReshapeOperandInfo outputInfo;
outputInfo.originalShape = getDimSizes(gatherOp.getOutput());
llvm::append_range(outputInfo.operandToIterationSpace,
@@ -716,6 +732,7 @@
// Drop batch dimensions.
Value reducedSource = gatherOp.getSource();
Value reducedIndices = gatherOp.getIndices();
+ Value reducedMask = gatherOp.getMask();
Value reducedOutput = gatherOp.getOutput();
if (gatherOp.getBatchRank() > 1) {
// The only reaason we have to do these rank reductions separate is
@@ -728,9 +745,20 @@
FailureOr<Value> newOutput = rankReduceOperand(
rewriter, loc, /*startDim=*/0, /*numDims=*/gatherOp.getBatchRank(),
gatherOp.getOutput(), gatherOp.getOutputType(), options);
- if (succeeded(newIndices) && succeeded(newOutput)) {
+ FailureOr<Value> newMask = failure();
+ if (reducedMask) {
+ newMask =
+ rankReduceOperand(rewriter, loc, /*startDim=*/0,
+ /*numDims=*/gatherOp.getBatchRank(), reducedMask,
+ cast<ShapedType>(reducedMask.getType()), options);
+ }
+ if (succeeded(newIndices) && succeeded(newOutput) &&
+ (!reducedMask || succeeded(newMask))) {
reducedIndices = newIndices.value();
reducedOutput = newOutput.value();
+ if (reducedMask) {
+ reducedMask = newMask.value();
+ }
changed = true;
}
}
@@ -758,7 +786,8 @@
auto newGather = GatherOp::create(
rewriter, gatherOp.getLoc(), TypeRange{reducedOutput.getType()},
/*source=*/reducedSource, /*indices=*/reducedIndices,
- /*output=*/reducedOutput, gatherOp.getDimensionMap());
+ /*mask=*/reducedMask, /*output=*/reducedOutput,
+ gatherOp.getDimensionMap());
rewriter.replaceOp(gatherOp,
rankExpandValue(rewriter, loc, gatherOp.getOutput(),
newGather.getResult(0), options));
@@ -787,6 +816,7 @@
// Drop batch dimensions.
Value original = scatterOp.getOriginal();
Value indices = scatterOp.getIndices();
+ Value mask = scatterOp.getMask();
Value updates = scatterOp.getUpdates();
if (scatterOp.getBatchRank() > 1) {
FailureOr<Value> newIndices = rankReduceOperand(
@@ -795,9 +825,19 @@
FailureOr<Value> newOutput = rankReduceOperand(
rewriter, loc, /*startDim=*/0, /*numDims=*/scatterOp.getBatchRank(),
updates, cast<ShapedType>(updates.getType()), options);
- if (succeeded(newIndices) && succeeded(newOutput)) {
+ FailureOr<Value> newMask = failure();
+ if (mask) {
+ newMask = rankReduceOperand(rewriter, loc, /*startDim=*/0,
+ /*numDims=*/scatterOp.getBatchRank(), mask,
+ cast<ShapedType>(mask.getType()), options);
+ }
+ if (succeeded(newIndices) && succeeded(newOutput) &&
+ (!mask || succeeded(newMask))) {
indices = newIndices.value();
updates = newOutput.value();
+ if (mask) {
+ mask = newMask.value();
+ }
changed = true;
}
}
@@ -824,7 +864,7 @@
auto newScatter = ScatterOp::create(
rewriter, scatterOp.getLoc(), TypeRange{original.getType()}, updates,
- indices, original, scatterOp.getDimensionMap(),
+ indices, mask, original, scatterOp.getDimensionMap(),
scatterOp.getUniqueIndices());
rewriter.inlineRegionBefore(scatterOp.getRegion(), newScatter.getRegion(),
newScatter.getRegion().begin());
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 11e8f6a..7af867d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -124,6 +124,54 @@
// -----
+func.func @scatter_update_scalar_1D_masked(
+ %original: memref<8xi32>, %indices: memref<3x1xi32>,
+ %mask: memref<3xi1>, %updates: memref<3xi32>) {
+ iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%updates, %indices, %mask : memref<3xi32>, memref<3x1xi32>, memref<3xi1>)
+ outs(%original : memref<8xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32):
+ iree_linalg_ext.yield %arg0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func.func @scatter_update_scalar_1D_masked
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[MASK:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C3]] step %[[C1]] {
+// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I]]] : memref<3xi32>
+// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I]], %[[C0]]] : memref<3x1xi32>
+// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index
+// CHECK: %[[MASK_VAL:.+]] = memref.load %[[MASK]][%[[I]]] : memref<3xi1>
+// CHECK: scf.if %[[MASK_VAL]] {
+// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]]
+
+// -----
+
+func.func @scatter_update_scalar_1D_masked_i8(
+ %original: memref<8xi32>, %indices: memref<3x1xi32>,
+ %mask: memref<3xi8>, %updates: memref<3xi32>) {
+ iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%updates, %indices, %mask : memref<3xi32>, memref<3x1xi32>, memref<3xi8>)
+ outs(%original : memref<8xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32):
+ iree_linalg_ext.yield %arg0 : i32
+ }
+ return
+}
+// CHECK-LABEL: func.func @scatter_update_scalar_1D_masked_i8
+// CHECK: scf.for
+// CHECK: %[[MASK_VAL:.+]] = memref.load %{{.+}}[%{{.+}}] : memref<3xi8>
+// CHECK: %[[MASK_I1:.+]] = arith.trunci %[[MASK_VAL]] : i8 to i1
+// CHECK: scf.if %[[MASK_I1]] {
+
+// -----
+
func.func @scatter_batch_2D(
%original: memref<8xi32>, %indices: memref<1x3x1xi32>,
%updates: memref<1x3xi32>) {
@@ -1003,6 +1051,48 @@
// -----
+func.func @gather_1d_indices_masked(%arg0 : memref<10x10xi32>, %arg1 : memref<1xi32>, %arg2 : memref<1xi1>, %arg3 : memref<1x10xi32>) {
+ iree_linalg_ext.gather
+ dimension_map = [0]
+ ins(%arg0, %arg1, %arg2 : memref<10x10xi32>, memref<1xi32>, memref<1xi1>)
+ outs(%arg3: memref<1x10xi32>)
+ return
+}
+// CHECK-LABEL: func @gather_1d_indices_masked
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[C1]] step %[[C1]] {
+// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK: %[[IDX:.+]] = memref.load %[[ARG1]][%[[I]]] : memref<1xi32>
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX]] : i32 to index
+// CHECK: %[[MASK_VAL:.+]] = memref.load %[[ARG2]][%[[I]]] : memref<1xi1>
+// CHECK: scf.if %[[MASK_VAL]] {
+// CHECK: %[[LOAD:.+]] = memref.load %[[ARG0]][%[[CAST]], %[[J]]] : memref<10x10xi32>
+// CHECK: memref.store %[[LOAD]], %[[ARG3]][%[[I]], %[[J]]] : memref<1x10xi32>
+
+// -----
+
+func.func @gather_1d_indices_masked_i8(%arg0 : memref<10x10xi32>, %arg1 : memref<1xi32>, %arg2 : memref<1xi8>, %arg3 : memref<1x10xi32>) {
+ iree_linalg_ext.gather
+ dimension_map = [0]
+ ins(%arg0, %arg1, %arg2 : memref<10x10xi32>, memref<1xi32>, memref<1xi8>)
+ outs(%arg3: memref<1x10xi32>)
+ return
+}
+// CHECK-LABEL: func @gather_1d_indices_masked_i8
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[MASK_VAL:.+]] = memref.load %{{.+}}[%{{.+}}] : memref<1xi8>
+// CHECK: %[[MASK_I1:.+]] = arith.trunci %[[MASK_VAL]] : i8 to i1
+// CHECK: scf.if %[[MASK_I1]] {
+
+// -----
+
func.func @gather_2d_indices(%arg0 : memref<2x2xi32>, %arg1 : memref<2x2xi32>, %arg2 : memref<2xi32>) {
iree_linalg_ext.gather
dimension_map = [0, 1]
diff --git a/tests/e2e/linalg_ext_ops/gather.mlir b/tests/e2e/linalg_ext_ops/gather.mlir
index bc79e83..f266038 100644
--- a/tests/e2e/linalg_ext_ops/gather.mlir
+++ b/tests/e2e/linalg_ext_ops/gather.mlir
@@ -71,3 +71,17 @@
check.expect_eq_const(%generic, dense<[4, 6]> : tensor<2xi32>) : tensor<2xi32>
return
}
+
+func.func @gather_operand_mask_preserves_output() {
+ %source = util.unfoldable_constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
+ %output = util.unfoldable_constant dense<[[99, 98], [97, 96]]> : tensor<2x2xi32>
+ %indices = util.unfoldable_constant dense<[1, 0]> : tensor<2xi32>
+ %mask = util.unfoldable_constant dense<[true, false]> : tensor<2xi1>
+ %result = iree_linalg_ext.gather dimension_map = [0]
+ ins(%source, %indices, %mask : tensor<2x2xi32>, tensor<2xi32>, tensor<2xi1>)
+ outs(%output: tensor<2x2xi32>) -> tensor<2x2xi32>
+
+ check.expect_eq_const(%result, dense<[[2, 3], [97, 96]]> : tensor<2x2xi32>)
+ : tensor<2x2xi32>
+ return
+}
diff --git a/tests/e2e/linalg_ext_ops/scatter.mlir b/tests/e2e/linalg_ext_ops/scatter.mlir
index 8980f20..099dbed 100644
--- a/tests/e2e/linalg_ext_ops/scatter.mlir
+++ b/tests/e2e/linalg_ext_ops/scatter.mlir
@@ -94,3 +94,20 @@
return
}
+
+func.func @scatter_operand_mask_preserves_original() {
+ %original = util.unfoldable_constant dense<[10, 20, 30, 40]> : tensor<4xi32>
+ %update = util.unfoldable_constant dense<[1, 2]> : tensor<2xi32>
+ %indices = util.unfoldable_constant dense<[[1], [3]]> : tensor<2x1xi32>
+ %mask = util.unfoldable_constant dense<[true, false]> : tensor<2xi1>
+ %result = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
+ ins(%update, %indices, %mask : tensor<2xi32>, tensor<2x1xi32>, tensor<2xi1>)
+ outs(%original : tensor<4xi32>) {
+ ^bb0(%arg0: i32, %arg1: i32):
+ iree_linalg_ext.yield %arg0 : i32
+ } -> tensor<4xi32>
+
+ check.expect_eq_const(%result, dense<[10, 1, 30, 40]> : tensor<4xi32>) : tensor<4xi32>
+
+ return
+}