Add linalg_ext.scatter operation. (#6397)

Also some minor changes

- move some checks to tablegen
- add some extra verifications to the Interface.
- move the ops to use declarative assembly format.
diff --git a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
index 07257b6..2ce4f8d 100644
--- a/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
+++ b/iree/compiler/Codegen/Common/test/linalg_bufferize.mlir
@@ -2391,7 +2391,7 @@
   %c0 = constant 0 : index
   %0 = hal.interface.binding.subspan @io::@rw[%c0] : !flow.dispatch.tensor<readwrite:128xi32>
   %1 = flow.dispatch.tensor.load %0, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readwrite:128xi32> -> tensor<128xi32>
-  %2 = linalg_ext.sort {dimension = 0 : i64} outs(%1 : tensor<128xi32>) {
+  %2 = linalg_ext.sort dimension(0) outs(%1 : tensor<128xi32>) {
   ^bb0(%arg0: i32, %arg1: i32):  // no predecessors
     %3 = cmpi sgt, %arg0, %arg1 : i32
     linalg_ext.yield %3 : i1
@@ -2402,5 +2402,5 @@
 // CHECK-LABEL: func @linalg_ext_sort_1d()
 //   CHECK-DAG:   %[[INOUT:.+]] = hal.interface.binding.subspan @io::@rw
 //       CHECK:   linalg_ext.sort
-//  CHECK-SAME:     dimension = 0 : i64
+//  CHECK-SAME:     dimension(0)
 //  CHECK-SAME:     outs(%[[INOUT]] : memref<128xi32>)
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
index 8d25cf6..a75066d 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td
@@ -9,6 +9,10 @@
 
 include "mlir/IR/OpBase.td"
 
+//===----------------------------------------------------------------------===//
+// Dialect definition
+//===----------------------------------------------------------------------===//
+
 def LinalgExt_Dialect : Dialect {
   let name = "linalg_ext";
   let cppNamespace = "::mlir::iree_compiler::linalg_ext";
@@ -18,4 +22,15 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Type definitions
+//===----------------------------------------------------------------------===//
+
+class RankedTensorOrMemRefOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes,
+      Or<[IsMemRefTypePred, And<[IsTensorTypePred, HasRankPred]>]>,
+  "ranked tensor or memref", "::mlir::ShapedType">;
+
+def AnyRankedTensorOrMemRefType : RankedTensorOrMemRefOf<[AnyType]>;
+
 #endif // IREE_DIALECT_LINALGEXT_BASE
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
index 98ed791..941dea1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp
@@ -19,7 +19,47 @@
 }
 
 namespace detail {
-LogicalResult verifyLinalgExtOpInterface(Operation *op) { return success(); }
+LogicalResult verifyLinalgExtOpInterface(Operation *op) {
+  LinalgExtOp linalgExtOp = cast<LinalgExtOp>(op);
+  if (op->getNumResults()) {
+    for (auto en : llvm::enumerate(linalgExtOp.inputs())) {
+      if (!en.value().getType().isa<RankedTensorType>()) {
+        return linalgExtOp.emitOpError("expected `ins` operand #")
+               << en.index() << " to be of RankedTensorType";
+      }
+    }
+    if (op->getNumResults() != linalgExtOp.outputs().size()) {
+      return linalgExtOp.emitOpError(
+          "expected number of outputs to be same as the number of results");
+    }
+    for (auto en : llvm::enumerate(op->getResultTypes())) {
+      if (!en.value().isa<RankedTensorType>()) {
+        return linalgExtOp.emitOpError("expected result #")
+               << en.index() << " to be of RankedTensorType";
+      }
+      Type outputType = linalgExtOp.outputs()[en.index()].getType();
+      if (en.value() != outputType) {
+        return linalgExtOp.emitOpError("expected type of `outs` operand #")
+               << en.index() << " " << outputType
+               << " to be same as result type " << en.value();
+      }
+    }
+  } else {
+    for (auto en : llvm::enumerate(linalgExtOp.inputs())) {
+      if (!en.value().getType().isa<MemRefType>()) {
+        return linalgExtOp.emitOpError("expected `ins` operand #")
+               << en.index() << " to be of MemRefType";
+      }
+    }
+    for (auto en : llvm::enumerate(linalgExtOp.outputs())) {
+      if (!en.value().getType().isa<MemRefType>()) {
+        return linalgExtOp.emitOpError("expected `outs` operand #")
+               << en.index() << " to be of MemRefType";
+      }
+    }
+  }
+  return success();
+}
 }  // namespace detail
 
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.cpp.inc"  // IWYU pragma: export
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 8c1d62e..ce8a4e0 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -53,66 +53,140 @@
 // Common methods from Linalg dialect.
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseCommonStructuredOpParts(
-    OpAsmParser &parser, OperationState &result,
-    SmallVectorImpl<Type> &inputTypes, SmallVectorImpl<Type> &outputTypes) {
-  llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
-  SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
-
-  parser.parseOptionalAttrDict(result.attributes);
-
-  if (succeeded(parser.parseOptionalKeyword("ins"))) {
-    if (parser.parseLParen()) return failure();
-
-    inputsOperandsLoc = parser.getCurrentLocation();
-    if (parser.parseOperandList(inputsOperands) ||
-        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+static ParseResult parseLinalgExtOperandList(
+    OpAsmParser &parser, StringRef keyword,
+    SmallVectorImpl<OpAsmParser::OperandType> &values,
+    SmallVectorImpl<Type> &types) {
+  StringRef parsedKeyword;
+  if (succeeded(parser.parseOptionalKeyword(&parsedKeyword, {keyword}))) {
+    if (parser.parseLParen() || parser.parseOperandList(values) ||
+        parser.parseColonTypeList(types) || parser.parseRParen()) {
       return failure();
+    }
   }
-
-  if (succeeded(parser.parseOptionalKeyword("outs"))) {
-    outputsOperandsLoc = parser.getCurrentLocation();
-    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
-        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
-      return failure();
-  }
-
-  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
-                             result.operands) ||
-      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
-                             result.operands))
-    return failure();
-
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(
-                          {static_cast<int32_t>(inputsOperands.size()),
-                           static_cast<int32_t>(outputsOperands.size())}));
   return success();
 }
 
-static ParseResult parseNamedStructuredOpResults(
-    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes) {
-  if (parser.parseOptionalArrowTypeList(resultTypes)) return failure();
+static ParseResult parseLinalgExtInsList(
+    OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+    SmallVectorImpl<Type> &types) {
+  return parseLinalgExtOperandList(parser, "ins", values, types);
+}
+
+static ParseResult parseLinalgExtOutsList(
+    OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+    SmallVectorImpl<Type> &types) {
+  return parseLinalgExtOperandList(parser, "outs", values, types);
+}
+
+static void printLinalgExtOperandList(OpAsmPrinter &printer, Operation *op,
+                                      StringRef keyword, OperandRange values,
+                                      TypeRange types) {
+  if (!values.empty()) {
+    printer << keyword << "(";
+    printer.printOperands(values);
+    printer << " : " << types << ")";
+  }
+}
+
+static void printLinalgExtInsList(OpAsmPrinter &printer, Operation *op,
+                                  OperandRange values, TypeRange types) {
+  return printLinalgExtOperandList(printer, op, "ins", values, types);
+}
+
+static void printLinalgExtOutsList(OpAsmPrinter &printer, Operation *op,
+                                   OperandRange values, TypeRange types) {
+  return printLinalgExtOperandList(printer, op, "outs", values, types);
+}
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+void ScatterOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  SmallVector<Value> inputBuffers = getInputBufferOperands();
+  SmallVector<Value> outputBuffers = getOutputBufferOperands();
+  getEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+                 outputBuffers);
+}
+
+static LogicalResult verifyScatterOp(ScatterOp op) {
+  if (op.inputs().size() != 2) {
+    return op.emitOpError("expected two input operands");
+  }
+  if (op.outputs().size() != 1) {
+    return op.emitOpError("expected one output operand");
+  }
+  auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
+    return t1.getShape()[dim] == t2.getShape()[dim];
+  };
+  auto indicesType = op.inputs()[1].getType().cast<ShapedType>();
+  if (indicesType.getRank() != 1 ||
+      !indicesType.getElementType().isInteger(32)) {
+    return op.emitOpError(
+        "expected indices to be of rank 1 of i32 element type");
+  }
+  // The first dimension of the indices must match the first dimension of the
+  // update value (one index per update slice).
+  auto updateType = op.inputs()[0].getType().cast<ShapedType>();
+  if (updateType.getRank() < 1) {
+    return op.emitOpError("expected update value to be at least rank 1");
+  }
+  if (!checkDimensionsMatch(indicesType, updateType, 0)) {
+    return op.emitOpError(
+        "mismatch in shape of indices and update value at dim#0");
+  }
+  auto originalType = op.outputs()[0].getType().cast<ShapedType>();
+  if (originalType.getRank() != updateType.getRank()) {
+    return op.emitOpError(
+        "mismatch in rank of update value and original value");
+  }
+  for (auto dim : llvm::seq<unsigned>(1, originalType.getRank())) {
+    if (!checkDimensionsMatch(updateType, originalType, dim)) {
+      return op.emitOpError(
+                 "mismatch in shape of update value and original value at dim#")
+             << dim;
+    }
+  }
+  Region &region = op.region();
+  Block *body = &region.front();
+  if (body->getNumArguments() != 2) {
+    return op.emitOpError("expected region to have two arguments");
+  }
+  Type arg0Type = body->getArgument(0).getType();
+  Type arg1Type = body->getArgument(1).getType();
+  if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
+    return op.emitOpError(
+        "expected region to have scalar argument of integer or float types");
+  }
+  if (arg0Type != updateType.getElementType()) {
+    return op.emitOpError("mismatch in argument 0 of region ")
+           << arg0Type << " and element type of update value "
+           << updateType.getElementType();
+  }
+  if (arg1Type != originalType.getElementType()) {
+    return op.emitOpError("mismatch in argument 1 of region ")
+           << arg1Type << " and element type of original value "
+           << originalType.getElementType();
+  }
+  if (arg0Type != arg1Type) {
+    return op.emitOpError("mismatch in region argument types ")
+           << arg0Type << " and " << arg1Type;
+  }
+  auto yieldOp = cast<linalg_ext::YieldOp>(body->getTerminator());
+  if (yieldOp->getNumOperands() != 1) {
+    return yieldOp.emitOpError("expected region to yield a single value");
+  }
+  auto yieldedType = yieldOp->getOperand(0).getType();
+  if (yieldedType != arg0Type) {
+    return yieldOp.emitOpError("mismatch in type of yielded value ")
+           << yieldedType << " and argument of the region " << arg0Type;
+  }
   return success();
 }
 
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
-                                         NamedStructuredOpType op) {
-  if (!op.inputs().empty()) {
-    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
-  }
-  if (!op.outputs().empty()) {
-    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
-  }
-}
-
-static void printNamedStructuredOpResults(OpAsmPrinter &p,
-                                          TypeRange resultTypes) {
-  if (resultTypes.empty()) return;
-  p.printOptionalArrowTypeList(resultTypes);
-}
-
 //===----------------------------------------------------------------------===//
 // SortOp
 //===----------------------------------------------------------------------===//
@@ -126,44 +200,6 @@
                  outputBuffers);
 }
 
-static ParseResult parseSortOp(OpAsmParser &parser, OperationState &result) {
-  DictionaryAttr dictAttr;
-  parser.parseOptionalAttribute(dictAttr, "_", result.attributes);
-  if (dictAttr) {
-    result.attributes.assign(dictAttr.getValue().begin(),
-                             dictAttr.getValue().end());
-  }
-
-  // Parsing is shared with named ops, except for the region.
-  SmallVector<Type> inputTypes, outputTypes;
-  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
-    return failure();
-
-  SmallVector<OpAsmParser::OperandType> regionOperands;
-  std::unique_ptr<Region> region = std::make_unique<Region>();
-  SmallVector<Type> operandTypes, regionTypes;
-  if (parser.parseRegion(*region, regionOperands, regionTypes))
-    return failure();
-  result.addRegion(std::move(region));
-
-  SmallVector<Type> outputTensorsTypes;
-  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
-    return failure();
-  result.addTypes(outputTensorsTypes);
-  return success();
-}
-
-static void printSortOp(OpAsmPrinter &p, SortOp op) {
-  p << op.getOperationName();
-  p.printOptionalAttrDict(op->getAttrs(),
-                          /*elidedAttrs=*/{"operand_segment_sizes"});
-  printCommonStructuredOpParts(p, op);
-  if (!op.region().empty()) {
-    p.printRegion(op.region());
-  }
-  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
-}
-
 static LogicalResult verifySortOp(SortOp op) {
   if (op.getNumInputs()) {
     return op.emitOpError("does not expect to take any inputs");
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index be65b29..7d8e64b 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -21,7 +21,10 @@
 }
 
 class LinalgExt_Op<string mnemonic, list<OpTrait> traits = []> :
-    LinalgExt_PureOp<mnemonic, !listconcat(traits, [LinalgExtInterface])> {
+    LinalgExt_PureOp<mnemonic, !listconcat(traits,
+        [AttrSizedOperandSegments,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+         LinalgExtInterface, SingleBlockImplicitTerminator<"YieldOp">])> {
   let verifier = [{ return verify$cppClass(*this); }];
   let printer = [{ return print$cppClass(p, *this); }];
   let parser = [{ return parse$cppClass(parser, result); }];
@@ -31,10 +34,33 @@
 // Non-structured ops
 //===----------------------------------------------------------------------===//
 
-def LinalgExt_SortOp : LinalgExt_Op<"sort", [
-    AttrSizedOperandSegments,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-    SingleBlockImplicitTerminator<"YieldOp">]> {
+def LinalgExt_ScatterOp : LinalgExt_Op<"scatter"> {
+  let summary = "Scatter operator";
+  let description = [{
+    Based on XLA operation semantics, takes two `inputs` (`update` and
+    `indices`) and `outputs` value (`original`). The operation updates
+    the value at the slices specified by `indices` by combining the
+    current value with the value in `updates` using the computation
+    specified in `region`. The `region` specifies a binary operation
+    of signature (T, T) -> T, where `T` is the element-type of
+    `updates` (and `original`). The first argument corresponds to the
+    value to be updated (i.e. from `updates`), and the second the
+    current value (i.e. value from `original`).
+  }];
+  let arguments = (ins
+      Variadic<AnyRankedTensorOrMemRefType>:$inputs,
+      Variadic<AnyRankedTensorOrMemRefType>:$outputs
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = [{
+    attr-dict custom<LinalgExtInsList>($inputs, type($inputs))
+    custom<LinalgExtOutsList>($outputs, type($outputs))
+    $region (`->` type($results)^)?
+  }];
+}
+
+def LinalgExt_SortOp : LinalgExt_Op<"sort"> {
   let summary = "Sort operator";
   let description = [{
     Based on XLA operation semantics, sorts the given `operands` at the given
@@ -51,8 +77,14 @@
                        Variadic<AnyShaped>:$outputs,
                        OptionalAttr<I64Attr>:$dimension
   );
-  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+  let results = (outs Variadic<AnyRankedTensor>:$results);
   let regions = (region AnyRegion:$region);
+  let assemblyFormat = [{
+    (`dimension` `(` $dimension^ `)`)?
+    attr-dict custom<LinalgExtInsList>($inputs, type($inputs))
+    custom<LinalgExtOutsList>($outputs, type($outputs))
+    $region (`->` type($results)^)?
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 1d6d6b1..70c59020 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -2,7 +2,7 @@
 
 func @sort_invalid_dimension(%arg0: tensor<128xi32>) -> tensor<128xi32> {
   // expected-error @+1 {{dimension must be within (0, 1]}}
-  %0 = linalg_ext.sort {dimension = 1 : i64}
+  %0 = linalg_ext.sort dimension(1)
     outs(%arg0 : tensor<128xi32>) {
   ^bb0(%arg1: i32, %arg2: i32):  // no predecessors
     %1 = cmpi sgt, %arg1, %arg2 : i32
@@ -23,3 +23,298 @@
   } -> tensor<3x4xi32>
   return %0 : tensor<3x4xi32>
 }
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+    %update : memref<?x?xf32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{expected `ins` operand #0 to be of RankedTensorType}}
+  %0 = linalg_ext.scatter
+      ins(%update, %indices : memref<?x?xf32>, tensor<?xi32>)
+      outs(%original : tensor<?x?xf32>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %1 = addf %arg1, %arg2 : f32
+        linalg_ext.yield %1 : f32
+      } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+    %update : tensor<?x?xf32>, %indices : memref<?xi32>,
+    %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{expected `ins` operand #1 to be of RankedTensorType}}
+  %0 = linalg_ext.scatter
+      ins(%update, %indices : tensor<?x?xf32>, memref<?xi32>)
+      outs(%original : tensor<?x?xf32>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %1 = addf %arg1, %arg2 : f32
+        linalg_ext.yield %1 : f32
+      } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_extra_outputs(
+    %update : tensor<?x?xf32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  // expected-error @+1 {{expected number of outputs to be same as the number of results}}
+  %0, %1 = linalg_ext.scatter
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+      outs(%original : tensor<?x?xf32>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %1 = addf %arg1, %arg2 : f32
+        linalg_ext.yield %1 : f32
+      } -> tensor<?x?xf32>, tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+    %update : tensor<?x?xf32>, %indices : tensor<?xi32>,
+    %original : memref<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{expected type of `outs` operand #0 'memref<?x?xf32>' to be same as result type 'tensor<?x?xf32>'}}
+  %0 = linalg_ext.scatter
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+      outs(%original : memref<?x?xf32>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %1 = addf %arg1, %arg2 : f32
+        linalg_ext.yield %1 : f32
+      } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+    %update : tensor<?x?xf32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xf32>) -> memref<?x?xf32> {
+  // expected-error @+1 {{expected result #0 to be of RankedTensorType}}
+  %0 = linalg_ext.scatter
+      ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+      outs(%original : tensor<?x?xf32>) {
+      ^bb0(%arg1: f32, %arg2: f32):
+        %1 = addf %arg1, %arg2 : f32
+        linalg_ext.yield %1 : f32
+      } -> memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+    %update : memref<?x?xf32>, %indices : tensor<?xi32>,
+    %original : memref<?x?xf32>) {
+  // expected-error @+1 {{expected `ins` operand #1 to be of MemRefType}}
+  linalg_ext.scatter
+    ins(%update, %indices : memref<?x?xf32>, tensor<?xi32>)
+    outs(%original : memref<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    }
+  return
+}
+
+// -----
+
+func @scatter_mixed_tensor_memref(
+    %update : memref<?x?xf32>, %indices : memref<?xi32>,
+    %original : tensor<?x?xf32>) {
+  // expected-error @+1 {{expected `outs` operand #0 to be of MemRefType}}
+  linalg_ext.scatter
+    ins(%update, %indices : memref<?x?xf32>, memref<?xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    }
+  return
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+    %update : tensor<?x?xf32>, %indices : tensor<48xi32>,
+    %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xf32>, tensor<48xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+    %update : tensor<64x?xf32>, %indices : tensor<48xi32>,
+    %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<64x?xf32>, tensor<48xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+    %update : tensor<?x?x?xf32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{mismatch in rank of update value and original value}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?x?xf32>, tensor<?xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_dim_mismatch(
+    %update : tensor<?x4xf32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // expected-error @+1 {{mismatch in shape of update value and original value at dim#1}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x4xf32>, tensor<?xi32>)
+    outs(%original : tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+    %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  // expected-error @+1 {{expected region to have scalar argument of integer or float types}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi32>) {
+    ^bb0(%arg1: index, %arg2: index):
+      %1 = addi %arg1, %arg2 : index
+      %2 = index_cast %1 : index to i32
+      linalg_ext.yield %2 : i32
+    } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+    %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi32>) {
+    ^bb0(%arg1: i64, %arg2: i32):
+      %1 = trunci %arg1 : i64 to i32
+      %2 = addi %1, %arg2 : i32
+      linalg_ext.yield %2 : i32
+    } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+    %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi32>) {
+    ^bb0(%arg1: i32, %arg2: i64):
+      %1 = trunci %arg2 : i64 to i32
+      %2 = addi %1, %arg1 : i32
+      linalg_ext.yield %2 : i32
+    } -> tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+    %update : tensor<?x?xi32>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+  // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi32>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi64>) {
+    ^bb0(%arg1: i32, %arg2: i64):
+      %1 = sexti %arg1 : i32 to i64
+      %2 = addi %1, %arg2 : i64
+      linalg_ext.yield %2 : i64
+    } -> tensor<?x?xi64>
+  return %0 : tensor<?x?xi64>
+}
+
+// -----
+
+func @scatter_region_type_mismatch(
+    %update : tensor<?x?xi64>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+  // expected-error @+1 {{expected region to have two arguments}}
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi64>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi64>) {
+    ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64):
+      %1 = addi %arg1, %arg2 : i64
+      linalg_ext.yield %1 : i64
+    } -> tensor<?x?xi64>
+  return %0 : tensor<?x?xi64>
+}
+
+
+// -----
+
+func @scatter_yield_mismatch(
+    %update : tensor<?x?xi64>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi64>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi64>) {
+    ^bb0(%arg1: i64, %arg2: i64):
+      %1 = addi %arg1, %arg2 : i64
+      %2 = trunci %1 : i64 to i32
+      // expected-error @+1 {{mismatch in type of yielded value 'i32' and argument of the region 'i64'}}
+      linalg_ext.yield %2 : i32
+    } -> tensor<?x?xi64>
+  return %0 : tensor<?x?xi64>
+}
+
+// -----
+
+func @scatter_yield_mismatch(
+    %update : tensor<?x?xi64>, %indices : tensor<?xi32>,
+    %original : tensor<?x?xi64>) -> tensor<?x?xi64> {
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xi64>, tensor<?xi32>)
+    outs(%original : tensor<?x?xi64>) {
+    ^bb0(%arg1: i64, %arg2: i64):
+      %1 = addi %arg1, %arg2 : i64
+      %2 = trunci %1 : i64 to i32
+      // expected-error @+1 {{expected region to yield a single value}}
+      linalg_ext.yield %1, %2 : i64, i32
+    } -> tensor<?x?xi64>
+  return %0 : tensor<?x?xi64>
+}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 70f22a7..457ba78 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -14,12 +14,14 @@
   return %0 : tensor<128xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @sort_memref
 // CHECK:         linalg_ext.sort
 // CHECK-SAME:      outs({{.*}})
 // CHECK:           linalg_ext.yield
 func @sort_memref(%arg0: memref<128xi32>) {
-  linalg_ext.sort {dimension = 0 : i64}
+  linalg_ext.sort dimension(0)
     outs(%arg0 : memref<128xi32>) {
   ^bb0(%arg1: i32, %arg2: i32):  // no predecessors
     %0 = cmpi sgt, %arg1, %arg2 : i32
@@ -27,3 +29,99 @@
   }
   return
 }
+
+// -----
+
+func @scatter_tensor_dynamic(
+    %original: tensor<?x?xf32>, %indices: tensor<?xi32>,
+    %update: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<?x?xf32>, tensor<?xi32>)
+    outs(%original: tensor<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_dynamic(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?xi32>
+//  CHECK-SAME:   %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   %[[RESULT:.+]] = linalg_ext.scatter
+//  CHECK-SAME:     ins(%[[UPDATE]], %[[INDICES]]
+//  CHECK-SAME:     outs(%[[ORIGINAL]]
+//       CHECK:     linalg_ext.yield %{{.+}} : f32
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @scatter_tensor_static(
+    %original: tensor<128x3xf32>, %indices: tensor<48xi32>,
+    %update: tensor<48x3xf32>) -> tensor<128x3xf32> {
+  %0 = linalg_ext.scatter
+    ins(%update, %indices : tensor<48x3xf32>, tensor<48xi32>)
+    outs(%original: tensor<128x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    } -> tensor<128x3xf32>
+  return %0 : tensor<128x3xf32>
+}
+// CHECK-LABEL: func @scatter_tensor_static(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<128x3xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: tensor<48xi32>
+//  CHECK-SAME:   %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<48x3xf32>
+//       CHECK:   %[[RESULT:.+]] = linalg_ext.scatter
+//  CHECK-SAME:     ins(%[[UPDATE]], %[[INDICES]]
+//  CHECK-SAME:     outs(%[[ORIGINAL]]
+//       CHECK:     linalg_ext.yield %{{.+}} : f32
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func @scatter_memref_dynamic(
+    %original: memref<?x?xf32>, %indices: memref<?xi32>,
+    %update: memref<?x?xf32>) {
+  linalg_ext.scatter
+    ins(%update, %indices : memref<?x?xf32>, memref<?xi32>)
+    outs(%original: memref<?x?xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    }
+  return
+}
+// CHECK-LABEL: func @scatter_memref_dynamic(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: memref<?xi32>
+//  CHECK-SAME:   %[[UPDATE:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//       CHECK:   linalg_ext.scatter
+//  CHECK-SAME:     ins(%[[UPDATE]], %[[INDICES]]
+//  CHECK-SAME:     outs(%[[ORIGINAL]]
+//       CHECK:     linalg_ext.yield %{{.+}} : f32
+//       CHECK:   return
+
+// -----
+
+func @scatter_memref_static(
+    %original: memref<128x3xf32>, %indices: memref<48xi32>,
+    %update: memref<48x3xf32>) {
+  linalg_ext.scatter
+    ins(%update, %indices : memref<48x3xf32>, memref<48xi32>)
+    outs(%original: memref<128x3xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1 = addf %arg1, %arg2 : f32
+      linalg_ext.yield %1 : f32
+    }
+  return
+}
+// CHECK-LABEL: func @scatter_memref_static(
+//  CHECK-SAME:   %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<128x3xf32>
+//  CHECK-SAME:   %[[INDICES:[a-zA-Z0-9_]+]]: memref<48xi32>
+//  CHECK-SAME:   %[[UPDATE:[a-zA-Z0-9_]+]]: memref<48x3xf32>
+//       CHECK:   linalg_ext.scatter
+//  CHECK-SAME:     ins(%[[UPDATE]], %[[INDICES]]
+//  CHECK-SAME:     outs(%[[ORIGINAL]]
+//       CHECK:     linalg_ext.yield %{{.+}} : f32
+//       CHECK:   return
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 7af12e8..319f50c 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -1,7 +1,7 @@
 // RUN: iree-opt -split-input-file -iree-linalg-ext-to-loops %s | IreeFileCheck %s
 
 func @sort_1d(%arg0: memref<128xi32>) {
-  linalg_ext.sort {dimension = 0 : i64}
+  linalg_ext.sort dimension(0)
     outs(%arg0 : memref<128xi32>) {
   ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
     %0 = cmpi sgt, %arg2, %arg3 : i32
@@ -30,7 +30,7 @@
 // -----
 
 func @sort_2d(%arg0: memref<16x32xi32>) {
-  linalg_ext.sort {dimension = 0 : i64}
+  linalg_ext.sort dimension(0)
     outs(%arg0 : memref<16x32xi32>) {
   ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
     %0 = cmpi sgt, %arg2, %arg3 : i32
diff --git a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir b/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
index 7902e19..8464723 100644
--- a/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
+++ b/iree/compiler/InputConversion/MHLO/test/convert_and_distribute_mhlo_to_linalg_ext.mlir
@@ -16,7 +16,7 @@
 // CHECK:           %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:128xi32>
 // CHECK:           %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]]
 // CHECK:           %[[SORT:.+]] = linalg_ext.sort
-// CHECK-SAME:        dimension = 0 : i64
+// CHECK-SAME:        dimension(0)
 // CHECK-SAME:        outs(%[[IN]] : tensor<128xi32>)
 // CHECK:          ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32)
 // CHECK:            %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
@@ -42,7 +42,7 @@
 // CHECK:           %[[ARG1:.+]]: !flow.dispatch.tensor<readwrite:16x32xi32>
 // CHECK:           %[[IN:.+]] = flow.dispatch.tensor.load %[[ARG1]]
 // CHECK:           %[[SORT:.+]] = linalg_ext.sort
-// CHECK-SAME:        dimension = 0 : i64
+// CHECK-SAME:        dimension(0)
 // CHECK-SAME:        outs(%[[IN]] : tensor<16x32xi32>)
 // CHECK:          ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32)
 // CHECK:            %[[CMP:.+]] = cmpi sgt, %[[ARG2]], %[[ARG3]]
@@ -72,7 +72,7 @@
 // CHECK:           %[[IN1:.+]] = flow.dispatch.tensor.load %[[ARG2]]
 // CHECK:           %[[IN2:.+]] = flow.dispatch.tensor.load %[[ARG3]]
 // CHECK:           %[[SORT:.+]]:2 = linalg_ext.sort
-// CHECK-SAME:        dimension = 0 : i64
+// CHECK-SAME:        dimension(0)
 // CHECK-SAME:        outs(%[[IN1]], %[[IN2]] : tensor<128xi32>, tensor<128xi32>)
 // CHECK:          ^bb0(%[[ARG4:.+]]: i32, %[[ARG5:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32)
 // CHECK:            %[[CMP:.+]] = cmpi sgt, %[[ARG4]], %[[ARG5]]