blob: d598af4ec2f8ae40febb25d67fa45406eaa7945c [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_DIALECT_LINALGEXT_OPS
#define IREE_DIALECT_LINALGEXT_OPS
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//
// Root class for every op in the LinalgExt dialect. It only binds the op to
// LinalgExt_Dialect and forwards the caller's `traits` verbatim; unlike
// LinalgExt_Op it injects no extra traits or hooks, which is what helper ops
// such as the yield terminator want.
class LinalgExt_PureOp<string mnemonic, list<OpTrait> traits = []>
    : Op<LinalgExt_Dialect, mnemonic, traits>;
// Base class for the LinalgExt "compute" ops (scatter, sort, ...).
// On top of the caller-supplied `traits` it always appends:
//   - AttrSizedOperandSegments: operands are split into several variadic
//     groups (the derived ops declare `inputs` and `outputs`);
//   - MemoryEffectsOpInterface, whose methods must be implemented in C++;
//   - LinalgExtInterface;
//   - SingleBlockImplicitTerminator<"YieldOp">: the op's region is a single
//     block terminated by linalg_ext.yield.
class LinalgExt_Op<string mnemonic, list<OpTrait> traits = []> :
LinalgExt_PureOp<mnemonic, !listconcat(traits,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
LinalgExtInterface, SingleBlockImplicitTerminator<"YieldOp">])> {
// Verification, printing and parsing dispatch to free functions named after
// the op's C++ class (`$cppClass` expands to that name), e.g. verifyScatterOp.
// Those functions must be defined in the dialect's C++ sources.
let verifier = [{ return verify$cppClass(*this); }];
let printer = [{ return print$cppClass(p, *this); }];
let parser = [{ return parse$cppClass(parser, result); }];
// Shared C++ snippet that derived ops splice into their own
// extraClassDeclaration (via `#`). getDestinationOperands() returns a copy of
// the `outputs` operand group, so any op using this snippet must declare an
// `outputs` operand.
code extraLinalgExtOpClassDeclaration = [{
SmallVector<Value> getDestinationOperands() {
SmallVector<Value> dest(outputs().begin(), outputs().end());
return dest;
}
}];
}
//===----------------------------------------------------------------------===//
// Non-structured ops
//===----------------------------------------------------------------------===//
def LinalgExt_ScatterOp : LinalgExt_Op<"scatter",
    [DeclareOpInterfaceMethods<TiledOpInterface,
        ["getTiledImplementation", "generateScalarImplementation"]>]> {
  let summary = "Scatter operator";
  let description = [{
    Based on XLA operation semantics, takes two `inputs` (`updates` and
    `indices`) and one `outputs` value (`original`). The operation updates
    the value at the slices specified by `indices` by combining the
    current value with the value in `updates` using the computation
    specified in `region`. The `region` specifies a binary operation
    of signature (T, T) -> T, where `T` is the element type of
    `updates` (and `original`). The first argument corresponds to the
    update value (i.e. from `updates`), and the second to the
    current value (i.e. from `original`).

    `indices` is a 2D tensor/memref type. The first dim is the number of
    updates, and the second dim is the index depth. The index depth should
    always be static.

    The first dim of `updates` and `indices` is identical, since both
    represent the number of updates.

    The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
    The first `index_depth` indices are derived from `indices`, and the shape
    of each update slice must match the remaining (trailing) dims of
    `original`.

    The shape definitions follow TensorFlow operations, except that the batch
    dims are forced to be 1D. See more information in
    https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
  }];

  let arguments = (ins
    Variadic<AnyRankedTensorOrMemRefType>:$inputs,
    Variadic<AnyRankedTensorOrMemRefType>:$outputs
  );
  let results = (outs Variadic<AnyRankedTensor>:$results);
  let regions = (region AnyRegion:$region);
  let assemblyFormat = [{
    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
    `outs` `(` $outputs `:` type($outputs) `)`
    $region (`->` type($results)^)?
  }];

  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
    // Size of the last dim of `indices`, i.e. the index depth. The op
    // description requires this dim to be static.
    int64_t getIndexDepth() {
      return getIndicesType().getShape().back();
    }

    // `updates` is the first input operand.
    Value updates() {
      return getInputOperand(0)->get();
    }
    ShapedType getUpdateType() {
      return updates().getType().cast<ShapedType>();
    }

    // `indices` is the second input operand.
    Value indices() {
      return getInputOperand(1)->get();
    }
    ShapedType getIndicesType() {
      return indices().getType().cast<ShapedType>();
    }

    // `original` is the first output operand.
    Value original() {
      return getOutputOperand(0)->get();
    }
    ShapedType getOriginalType() {
      return original().getType().cast<ShapedType>();
    }

    // Rank of a single update slice: the rank of `updates` minus its leading
    // dim, which counts the updates.
    int64_t getUpdateSliceRank() {
      return getUpdateType().getRank() - 1;
    }

    // True when each update writes a single element (slice rank 0).
    bool isScalarUpdate() {
      return getUpdateSliceRank() == 0;
    }
  }];
}
def LinalgExt_SortOp : LinalgExt_Op<"sort",
    [DeclareOpInterfaceMethods<TiledOpInterface,
        ["getPartitionableLoops", "generateScalarImplementation",
         "getTiledImplementation"]>]> {
  let summary = "Sort operator";
  let description = [{
    Based on XLA operation semantics, sorts the given `operands` at the given
    `dimension` with the given `comparator`.
    See https://www.tensorflow.org/xla/operation_semantics#sort.
  }];

  // Operands and results are laid out like linalg.generic. `dimension` has
  // the same meaning as mhlo.sort's `dimension` attribute: it must be set when
  // the operands have rank > 1 and may be omitted when the rank is exactly 1.
  let arguments = (ins Variadic<AnyType>:$inputs,
                       Variadic<AnyShaped>:$outputs,
                       OptionalAttr<I64Attr>:$dimension
  );
  let results = (outs Variadic<AnyRankedTensor>:$results);
  let regions = (region AnyRegion:$region);
  let assemblyFormat = [{
    (`dimension` `(` $dimension^ `)`)?
    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
    `outs` `(` $outputs `:` type($outputs) `)`
    $region (`->` type($results)^)?
  }];

  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
    // Returns the `index`-th value being sorted; the sorted values are the
    // `outputs` operands.
    Value operand(int index) {
      return outputs()[index];
    }

    // Shaped type of the `index`-th sorted value.
    ShapedType getOperandType(int index) {
      Value sortedValue = operand(index);
      return sortedValue.getType().cast<ShapedType>();
    }

    // Rank and shape are taken from the first sorted value.
    int64_t getOperandRank() {
      ShapedType firstType = getOperandType(0);
      return firstType.getRank();
    }
    ArrayRef<int64_t> getOperandShape() {
      ShapedType firstType = getOperandType(0);
      return firstType.getShape();
    }

    // Returns the `dimension` attribute when it is set, and 0 otherwise.
    uint64_t getSortedDimension() {
      if (auto explicitDim = dimension())
        return *explicitDim;
      return 0;
    }
  }];
}
//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//
// Terminator for the single-block regions of LinalgExt ops; LinalgExt_Op
// installs it through SingleBlockImplicitTerminator<"YieldOp">. It derives
// from LinalgExt_PureOp so it does not pick up LinalgExt_Op's extra traits
// (operand segments, memory-effect interface, custom verifier hooks).
def LinalgExt_YieldOp : LinalgExt_PureOp<"yield", [NoSideEffect, ReturnLike, Terminator]> {
let summary = "LinalgExt yield op";
let description = [{
`linalg_ext.yield` is a special terminator operation for blocks inside
regions in `linalg_ext` ops.
}];
// Values forwarded out of the enclosing region; the list may be empty.
let arguments = (ins Variadic<AnyType>:$operands);
// Zero-argument builder that creates a yield with no operands (likely what
// SingleBlockImplicitTerminator relies on to insert the implicit terminator).
// NOTE(review): the body is intentionally a comment rather than empty,
// presumably because ODS treats an empty builder body as declaration-only;
// keep it non-empty.
let builders = [
OpBuilder<(ins), [{ /* nothing to do */ }]>,
];
// The operand list and its types are printed only when operands are present.
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
#endif // IREE_DIALECT_LINALGEXT_OPS