Add linalg_ext.scatter operation. (#6397)
Also some minor changes
- move some checks to tablegen
- add some extra verifications to the Interface.
- move the ops to use declerative assembly format.
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 07257b6..2ce4f8d 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -2391,7 +2391,7 @@
%c0 = constant 0 : index
%0 = hal.interface.binding.subspan @io::@rw[%c0] : !flow.dispatch.tensor<readwrite:128xi32>
%1 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:128xi32> -> tensor<128xi32>
- %2 = linalg_ext.sort {dimension = 0 : i64} outs(%1 : tensor<128xi32>) {
+ %2 = linalg_ext.sort dimension(0) outs(%1 : tensor<128xi32>) {
^bb0(%arg0: i32, %arg1: i32): // no predecessors
%3 = cmpi sgt, %arg0, %arg1 : i32
linalg_ext.yield %3 : i1
@@ -2402,5 +2402,5 @@
// CHECK-LABEL: func @linalg_ext_sort_1d()
// CHECK-DAG: %[[INOUT:.+]] = hal.interface.binding.subspan @io::@rw
// CHECK: linalg_ext.sort
-// CHECK-SAME: dimension = 0 : i64
+// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[INOUT]] : memref<128xi32>)
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
index 8d25cf6..a75066d 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -9,6 +9,10 @@
include "mlir/IR/OpBase.td"
+//===----------------------------------------------------------------------===//
+// Dialect definition
+//===----------------------------------------------------------------------===//
+
def LinalgExt_Dialect : Dialect {
let name = "linalg_ext";
let cppNamespace = "::mlir::iree_compiler::linalg_ext";
@@ -18,4 +22,15 @@
}];
}
+//===----------------------------------------------------------------------===//
+// Type definitions
+//===----------------------------------------------------------------------===//
+
+class RankedTensorOrMemRefOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes,
+ Or<[IsMemRefTypePred, And<[IsTensorTypePred, HasRankPred]>]>,
+ "ranked tensor or memref", "::mlir::ShapedType">;
+
+def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>;
+
#endif // IREE_DIALECT_LINALGEXT_BASE
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
index 98ed791..941dea1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
@@ -19,7 +19,47 @@
}
namespace detail {
-LogicalResult verifyLinalgExtOpInterface(Operation *op) { return success(); }
+LogicalResult verifyLinalgExtOpInterface(Operation *op) {
+ LinalgExtOp linalgExtOp = cast<LinalgExtOp>(op);
+ if (op->getNumResults()) {
+ for (auto en : llvm::enumerate(linalgExtOp.inputs())) {
+ if (!en.value().getType().isa<RankedTensorType>()) {
+ return linalgExtOp.emitOpError("expected `ins` operand #")
+ << en.index() << " to be of RankedTensorType";
+ }
+ }
+ if (op->getNumResults() != linalgExtOp.outputs().size()) {
+ return linalgExtOp.emitOpError(
+ "expected number of outputs to be same as the number of results");
+ }
+ for (auto en : llvm::enumerate(op->getResultTypes())) {
+ if (!en.value().isa<RankedTensorType>()) {
+ return linalgExtOp.emitOpError("expected result #")
+ << en.index() << " to be of RankedTensorType";
+ }
+ Type outputType = linalgExtOp.outputs()[en.index()].getType();
+ if (en.value() != outputType) {
+ return linalgExtOp.emitOpError("expected type of `outs` operand #")
+ << en.index() << " " << outputType
+ << " to be same as result type " << en.value();
+ }
+ }
+ } else {
+ for (auto en : llvm::enumerate(linalgExtOp.inputs())) {
+ if (!en.value().getType().isa<MemRefType>()) {
+ return linalgExtOp.emitOpError("expected `ins` operand #")
+ << en.index() << " to be of MemRefType";
+ }
+ }
+ for (auto en : llvm::enumerate(linalgExtOp.outputs())) {
+ if (!en.value().getType().isa<MemRefType>()) {
+ return linalgExtOp.emitOpError("expected `outs` operand #")
+ << en.index() << " to be of MemRefType";
+ }
+ }
+ }
+ return success();
+}
} // namespace detail
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp.inc" // IWYU pragma: export
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 8c1d62e..ce8a4e0 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -53,66 +53,140 @@
// Common methods from Linalg dialect.
//===----------------------------------------------------------------------===//
-static ParseResult parseCommonStructuredOpParts(
- OpAsmParser &parser, OperationState &result,
- SmallVectorImpl<Type> &inputTypes, SmallVectorImpl<Type> &outputTypes) {
- llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
- SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
-
- parser.parseOptionalAttrDict(result.attributes);
-
- if (succeeded(parser.parseOptionalKeyword("ins"))) {
- if (parser.parseLParen()) return failure();
-
- inputsOperandsLoc = parser.getCurrentLocation();
- if (parser.parseOperandList(inputsOperands) ||
- parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+static ParseResult parseLinalgExtOperandList(
+ OpAsmParser &parser, StringRef keyword,
+ SmallVectorImpl<OpAsmParser::OperandType> &values,
+ SmallVectorImpl<Type> &types) {
+ StringRef parsedKeyword;
+ if (succeeded(parser.parseOptionalKeyword(&parsedKeyword, {keyword}))) {
+ if (parser.parseLParen() || parser.parseOperandList(values) ||
+ parser.parseColonTypeList(types) || parser.parseRParen()) {
return failure();
+ }
}
-
- if (succeeded(parser.parseOptionalKeyword("outs"))) {
- outputsOperandsLoc = parser.getCurrentLocation();
- if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
- parser.parseColonTypeList(outputTypes) || parser.parseRParen())
- return failure();
- }
-
- if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
- result.operands) ||
- parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
- result.operands))
- return failure();
-
- result.addAttribute("operand_segment_sizes",
- parser.getBuilder().getI32VectorAttr(
- {static_cast<int32_t>(inputsOperands.size()),
- static_cast<int32_t>(outputsOperands.size())}));
return success();
}
-static ParseResult parseNamedStructuredOpResults(
- OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes) {
- if (parser.parseOptionalArrowTypeList(resultTypes)) return failure();
+static ParseResult parseLinalgExtInsList(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+ SmallVectorImpl<Type> &types) {
+ return parseLinalgExtOperandList(parser, "ins", values, types);
+}
+
+static ParseResult parseLinalgExtOutsList(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+ SmallVectorImpl<Type> &types) {
+ return parseLinalgExtOperandList(parser, "outs", values, types);
+}
+
+static void printLinalgExtOperandList(OpAsmPrinter &printer, Operation *op,
+ StringRef keyword, OperandRange values,
+ TypeRange types) {
+ if (!values.empty()) {
+ printer << keyword << "(";
+ printer.printOperands(values);
+ printer << " : " << types << ")";
+ }
+}
+
+static void printLinalgExtInsList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values, TypeRange types) {
+ return printLinalgExtOperandList(printer, op, "ins", values, types);
+}
+
+static void printLinalgExtOutsList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values, TypeRange types) {
+ return printLinalgExtOperandList(printer, op, "outs", values, types);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+void ScatterOp::getEffects(
+ SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+ &effects) {
+ SmallVector<Value> inputBuffers = getInputBufferOperands();
+ SmallVector<Value> outputBuffers = getOutputBufferOperands();
+ getEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+ outputBuffers);
+}
+
+static LogicalResult verifyScatterOp(ScatterOp op) {
+ if (op.inputs().size() != 2) {
+ return op.emitOpError("expected two input operands");
+ }
+ if (op.outputs().size() != 1) {
+ return op.emitOpError("expected one output operand");
+ }
+ auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
+ return t1.getShape()[dim] == t2.getShape()[dim];
+ };
+ auto indicesType = op.inputs()[1].getType().cast<ShapedType>();
+ if (indicesType.getRank() != 1 ||
+ !indicesType.getElementType().isInteger(32)) {
+ return op.emitOpError(
+ "expected indices to be of rank 1 of i32 element type");
+ }
+ // The first dimension of the indices should match the first dimension of the
+ // output.
+ auto updateType = op.inputs()[0].getType().cast<ShapedType>();
+ if (updateType.getRank() < 1) {
+ return op.emitOpError("expected update value to be at least rank 1");
+ }
+ if (!checkDimensionsMatch(indicesType, updateType, 0)) {
+ return op.emitOpError(
+ "mismatch in shape of indices and update value at dim#0");
+ }
+ auto originalType = op.outputs()[0].getType().cast<ShapedType>();
+ if (originalType.getRank() != updateType.getRank()) {
+ return op.emitOpError(
+ "mismatch in rank of update value and original value");
+ }
+ for (auto dim : llvm::seq<unsigned>(1, originalType.getRank())) {
+ if (!checkDimensionsMatch(updateType, originalType, dim)) {
+ return op.emitOpError(
+ "mismatch in shape of update value and original value at dim#")
+ << dim;
+ }
+ }
+ Region ®ion = op.region();
+ Block *body = ®ion.front();
+ if (body->getNumArguments() != 2) {
+ return op.emitOpError("expected region to have two arguments");
+ }
+ Type arg0Type = body->getArgument(0).getType();
+ Type arg1Type = body->getArgument(1).getType();
+ if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
+ return op.emitOpError(
+ "expected region to have scalar argument of integer or float types");
+ }
+ if (arg0Type != updateType.getElementType()) {
+ return op.emitOpError("mismatch in argument 0 of region ")
+ << arg0Type << " and element type of update value "
+ << updateType.getElementType();
+ }
+ if (arg1Type != originalType.getElementType()) {
+ return op.emitOpError("mismatch in argument 1 of region ")
+ << arg1Type << " and element type of original value "
+ << originalType.getElementType();
+ }
+ if (arg0Type != arg1Type) {
+ return op.emitOpError("mismatch in region argument types ")
+ << arg0Type << " and " << arg1Type;
+ }
+ auto yieldOp = cast<linalg_ext::YieldOp>(body->getTerminator());
+ if (yieldOp->getNumOperands() != 1) {
+ return yieldOp.emitOpError("expected region to yield a single value");
+ }
+ auto yieldedType = yieldOp->getOperand(0).getType();
+ if (yieldedType != arg0Type) {
+ return yieldOp.emitOpError("mismatch in type of yielded value ")
+ << yieldedType << " and argument of the region " << arg0Type;
+ }
return success();
}
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
- NamedStructuredOpType op) {
- if (!op.inputs().empty()) {
- p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
- }
- if (!op.outputs().empty()) {
- p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
- }
-}
-
-static void printNamedStructuredOpResults(OpAsmPrinter &p,
- TypeRange resultTypes) {
- if (resultTypes.empty()) return;
- p.printOptionalArrowTypeList(resultTypes);
-}
-
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
@@ -126,44 +200,6 @@
outputBuffers);
}
-static ParseResult parseSortOp(OpAsmParser &parser, OperationState &result) {
- DictionaryAttr dictAttr;
- parser.parseOptionalAttribute(dictAttr, "_", result.attributes);
- if (dictAttr) {
- result.attributes.assign(dictAttr.getValue().begin(),
- dictAttr.getValue().end());
- }
-
- // Parsing is shared with named ops, except for the region.
- SmallVector<Type> inputTypes, outputTypes;
- if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
- return failure();
-
- SmallVector<OpAsmParser::OperandType> regionOperands;
- std::unique_ptr<Region> region = std::make_unique<Region>();
- SmallVector<Type> operandTypes, regionTypes;
- if (parser.parseRegion(*region, regionOperands, regionTypes))
- return failure();
- result.addRegion(std::move(region));
-
- SmallVector<Type> outputTensorsTypes;
- if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
- return failure();
- result.addTypes(outputTensorsTypes);
- return success();
-}
-
-static void printSortOp(OpAsmPrinter &p, SortOp op) {
- p << op.getOperationName();
- p.printOptionalAttrDict(op->getAttrs(),
- /*elidedAttrs=*/{"operand_segment_sizes"});
- printCommonStructuredOpParts(p, op);
- if (!op.region().empty()) {
- p.printRegion(op.region());
- }
- printNamedStructuredOpResults(p, op.result_tensors().getTypes());
-}
-
static LogicalResult verifySortOp(SortOp op) {
if (op.getNumInputs()) {
return op.emitOpError("does not expect to take any inputs");
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index be65b29..7d8e64b 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -21,7 +21,10 @@
}
class LinalgExt_Op<string mnemonic, list<OpTrait> traits = []> :
- LinalgExt_PureOp<mnemonic, !listconcat(traits, [LinalgExtInterface])> {
+ LinalgExt_PureOp<mnemonic, !listconcat(traits,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ LinalgExtInterface, SingleBlockImplicitTerminator<"YieldOp">])> {
let verifier = [{ return verify$cppClass(*this); }];
let printer = [{ return print$cppClass(p, *this); }];
let parser = [{ return parse$cppClass(parser, result); }];
@@ -31,10 +34,33 @@
// Non-structured ops
//===----------------------------------------------------------------------===//
-def LinalgExt_SortOp : LinalgExt_Op<"sort", [
- AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- SingleBlockImplicitTerminator<"YieldOp">]> {
+def LinalgExt_ScatterOp : LinalgExt_Op<"scatter"> {
+ let summary = "Scatter operator";
+ let description = [{
+ Based on XLA operation semantics, takes two `inputs` (`update` and
+ `indices`) and `outputs` value (`original`). The operation updates
+ the value at the slices specified by `indices` by combining the
+ current value with the value in `updates` using the computation
+ specified in `region`. The `region` specifies a binary operation
+ of signature (T, T) -> T, where `T` is the element-type of
+ `updates` (and `original`). The first argument correspond the
+ value to be updated (i.e. from `updates`), and the second the
+ current value (i.e. value from `original`).
+ }];
+ let arguments = (ins
+ Variadic<AnyRankedTensorOrMemRefType>:$inputs,
+ Variadic<AnyRankedTensorOrMemRefType>:$outputs
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = [{
+ attr-dict custom<LinalgExtInsList>($inputs, type($inputs))
+ custom<LinalgExtOutsList>($outputs, type($outputs))
+ $region (`->` type($results)^)?
+ }];
+}
+
+def LinalgExt_SortOp : LinalgExt_Op<"sort"> {
let summary = "Sort operator";
let description = [{
Based on XLA operation semantics, sorts the given `operands` at the given
@@ -51,8 +77,14 @@
Variadic<AnyShaped>:$outputs,
OptionalAttr<I64Attr>:$dimension
);
- let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+ let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
+ let assemblyFormat = [{
+ (`dimension` `(` $dimension^ `)`)?
+ attr-dict custom<LinalgExtInsList>($inputs, type($inputs))
+ custom<LinalgExtOutsList>($outputs, type($outputs))
+ $region (`->` type($results)^)?
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 1d6d6b1..70c59020 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -2,7 +2,7 @@
func @sort_invalid_dimension(%arg0: tensor<128xi32>) -> tensor<128xi32> {
// expected-error @+1 {{dimension must be within (0, 1]}}
- %0 = linalg_ext.sort {dimension = 1 : i64}
+ %0 = linalg_ext.sort dimension(1)
outs(%arg0 : tensor<128xi32>) {
^bb0(%arg1: i32, %arg2: i32): // no predecessors
%1 = cmpi sgt, %arg1, %arg2 : i32
@@ -23,3 +23,298 @@
} -> tensor<3x4xi32>
return %0 : tensor<3x4xi32>
}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+ %update : memref<?x?xf32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{expected `ins` operand #0 to be of RankedTensorType}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+ %update : tensor<?x?xf32>, %indices : memref<?xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{expected `ins` operand #1 to be of RankedTensorType}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, memref<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_extra_outputs(
+ %update : tensor<?x?xf32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ // expected-error @+1 {{expected number of outputs to be same as the number of results}}
+ %0, %1 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>, tensor<?x?xf32>
+ return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+ %update : tensor<?x?xf32>, %indices : tensor<?xi32>,
+ %original : memref<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{expected type of `outs` operand #0 'memref<?x?xf32>' to be same as result type 'tensor<?x?xf32>'}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+ outs(%original : memref<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+ %update : tensor<?x?xf32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xf32>) -> memref<?x?xf32> {
+ // expected-error @+1 {{expected result #0 to be of RankedTensorType}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> memref<?x?xf32>
+ return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+ %update : memref<?x?xf32>, %indices : tensor<?xi32>,
+ %original : memref<?x?xf32>) {
+ // expected-error @+1 {{expected `ins` operand #1 to be of MemRefType}}
+ linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, tensor<?xi32>)
+ outs(%original : memref<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ }
+ return
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+ %update : memref<?x?xf32>, %indices : memref<?xi32>,
+ %original : tensor<?x?xf32>) {
+ // expected-error @+1 {{expected `outs` operand #0 to be of MemRefType}}
+ linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, memref<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ }
+ return
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+ %update : tensor<?x?xf32>, %indices : tensor<48xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<48xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+ %update : tensor<64x?xf32>, %indices : tensor<48xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<64x?xf32>, tensor<48xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+ %update : tensor<?x?x?xf32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in rank of update value and original value}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?x?xf32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+ %update : tensor<?x4xf32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // expected-error @+1 {{mismatch in shape of update value and original value at dim#1}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x4xf32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ // expected-error @+1 {{expected region to have scalar argument of integer or float types}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi32>) {
+ ^bb0(%arg1: index, %arg2: index):
+ %1 = addi %arg1, %arg2 : index
+ %2 = index_cast %1 : index to i32
+ linalg_ext.yield %2 : i32
+ } -> tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi32>) {
+ ^bb0(%arg1: i64, %arg2: i32):
+ %1 = trunci %arg1 : i64 to i32
+ %2 = addi %1, %arg2 : i32
+ linalg_ext.yield %2 : i32
+ } -> tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi32>) {
+ ^bb0(%arg1: i32, %arg2: i64):
+ %1 = trunci %arg2 : i64 to i32
+ %2 = addi %1, %arg1 : i32
+ linalg_ext.yield %2 : i32
+ } -> tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%arg1: i32, %arg2: i64):
+ %1 = sexti %arg1 : i32 to i64
+ %2 = addi %1, %arg2 : i64
+ linalg_ext.yield %2 : i64
+ } -> tensor<?x?xi64>
+ return %0 : tensor<?x?xi64>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ // expected-error @+1 {{expected region to have two arguments}}
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64):
+ %1 = addi %arg1, %arg2 : i64
+ linalg_ext.yield %1 : i64
+ } -> tensor<?x?xi64>
+ return %0 : tensor<?x?xi64>
+}
+
+
+// -----
+
+func @scatter_yield_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%arg1: i64, %arg2: i64):
+ %1 = addi %arg1, %arg2 : i64
+ %2 = trunci %1 : i64 to i32
+ // expected-error @+1 {{mismatch in type of yielded value 'i32' and argument of the region 'i64'}}
+ linalg_ext.yield %2 : i32
+ } -> tensor<?x?xi64>
+ return %0 : tensor<?x?xi64>
+}
+
+// -----
+
+func @scatter_yield_mismatch(
+ %update : tensor<?x?xi64>, %indices : tensor<?xi32>,
+ %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xi64>, tensor<?xi32>)
+ outs(%original : tensor<?x?xi64>) {
+ ^bb0(%arg1: i64, %arg2: i64):
+ %1 = addi %arg1, %arg2 : i64
+ %2 = trunci %1 : i64 to i32
+ // expected-error @+1 {{expected region to yield a single value}}
+ linalg_ext.yield %1, %2 : i64, i32
+ } -> tensor<?x?xi64>
+ return %0 : tensor<?x?xi64>
+}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 70f22a7..457ba78 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -14,12 +14,14 @@
return %0 : tensor<128xi32>
}
+// -----
+
// CHECK-LABEL: func @sort_memref
// CHECK: linalg_ext.sort
// CHECK-SAME: outs({{.*}})
// CHECK: linalg_ext.yield
func @sort_memref(%arg0: memref<128xi32>) {
- linalg_ext.sort {dimension = 0 : i64}
+ linalg_ext.sort dimension(0)
outs(%arg0 : memref<128xi32>) {
^bb0(%arg1: i32, %arg2: i32): // no predecessors
%0 = cmpi sgt, %arg1, %arg2 : i32
@@ -27,3 +29,99 @@
}
return
}
+
+// -----
+
+func @scatter_tensor_dynamic(
+ %original: tensor<?x?xf32>, %indices: tensor<?xi32>,
+ %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+ outs(%original: tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_dynamic(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_tensor_static(
+ %original: tensor<128x3xf32>, %indices: tensor<48xi32>,
+ %update: tensor<48x3xf32>) -> tensor<128x3xf32> {
+ %0 = linalg_ext.scatter
+ ins(%update, %indices : tensor<48x3xf32>, tensor<48xi32>)
+ outs(%original: tensor<128x3xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<128x3xf32>
+ return %0 : tensor<128x3xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_static(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32>
+// CHECK: %[[RESULT:.+]] = linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: linalg_ext.yield %{{.+}} : f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_memref_dynamic(
+ %original: memref<?x?xf32>, %indices: memref<?xi32>,
+ %update: memref<?x?xf32>) {
+ linalg_ext.scatter
+ ins(%update, %indices : memref<?x?xf32>, memref<?xi32>)
+ outs(%original: memref<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_memref_dynamic(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<?xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: linalg_ext.yield %{{.+}} : f32
+// CHECK: return
+
+// -----
+
+func @scatter_memref_static(
+ %original: memref<128x3xf32>, %indices: memref<48xi32>,
+ %update: memref<48x3xf32>) {
+ linalg_ext.scatter
+ ins(%update, %indices : memref<48x3xf32>, memref<48xi32>)
+ outs(%original: memref<128x3xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ }
+ return
+}
+// CHECK-LABEL: func @scatter_memref_static(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<128x3xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<48xi32>
+// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]]: memref<48x3xf32>
+// CHECK: linalg_ext.scatter
+// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: linalg_ext.yield %{{.+}} : f32
+// CHECK: return
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 7af12e8..319f50c 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -1,7 +1,7 @@
// RUN: iree-opt -split-input-file -iree-linalg-ext-to-loops %s | IreeFileCheck %s
func @sort_1d(%arg0: memref<128xi32>) {
- linalg_ext.sort {dimension = 0 : i64}
+ linalg_ext.sort dimension(0)
outs(%arg0 : memref<128xi32>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
%0 = cmpi sgt, %arg2, %arg3 : i32
@@ -30,7 +30,7 @@
// -----
func @sort_2d(%arg0: memref<16x32xi32>) {
- linalg_ext.sort {dimension = 0 : i64}
+ linalg_ext.sort dimension(0)
outs(%arg0 : memref<16x32xi32>) {
^bb0(%arg2: i32, %arg3: i32): // no predecessors
%0 = cmpi sgt, %arg2, %arg3 : i32
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
index 7902e19..8464723 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
@@ -16,7 +16,7 @@
// CHECK: %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:128xi32>
// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]]
// CHECK: %[[SORT:.+]] = linalg_ext.sort
-// CHECK-SAME: dimension = 0 : i64
+// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[IN]] : tensor<128xi32>)
// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32)
// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
@@ -42,7 +42,7 @@
// CHECK: %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:16x32xi32>
// CHECK: %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]]
// CHECK: %[[SORT:.+]] = linalg_ext.sort
-// CHECK-SAME: dimension = 0 : i64
+// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[IN]] : tensor<16x32xi32>)
// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32)
// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
@@ -72,7 +72,7 @@
// CHECK: %[[IN1:.+]] = flow.dispatch.tensor.load %[[ARG2]]
// CHECK: %[[IN2:.+]] = flow.dispatch.tensor.load %[[ARG3]]
// CHECK: %[[SORT:.+]]:2 = linalg_ext.sort
-// CHECK-SAME: dimension = 0 : i64
+// CHECK-SAME: dimension(0)
// CHECK-SAME: outs(%[[IN1]], %[[IN2]] : tensor<128xi32>, tensor<128xi32>)
// CHECK: ^bb0(%[[ARG4:.+]]: i32, %[[ARG5:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32)
// CHECK: %[[CMP:.+]] = cmpi sgt, %[[ARG4]], %[[ARG5]]