| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #ifndef IREE_DIALECT_LINALGEXT_OPS |
| #define IREE_DIALECT_LINALGEXT_OPS |
| |
| include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td" |
| include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td" |
| include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td" |
| include "mlir/Interfaces/SideEffectInterfaces.td" |
| include "mlir/Interfaces/ControlFlowInterfaces.td" |
| |
| //===----------------------------------------------------------------------===// |
| // Base class. |
| //===----------------------------------------------------------------------===// |
| |
// Base class for LinalgExt ops that add no implicit traits of their own:
// whatever the subclass passes in `traits` is forwarded verbatim to `Op`.
class LinalgExt_PureOp<string mnemonic, list<OpTrait> traits = []>
    : Op<LinalgExt_Dialect, mnemonic, traits>;
| |
// Base class for the "structured" LinalgExt ops: every subclass gets
// size-tracked variadic operand segments, memory-effect reporting, the
// dialect's LinalgExtInterface, and a single-block region implicitly
// terminated by `YieldOp`.
class LinalgExt_Op<string mnemonic, list<OpTrait> traits = []> :
    LinalgExt_PureOp<mnemonic, !listconcat(traits,
        [AttrSizedOperandSegments,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
         LinalgExtInterface, SingleBlockImplicitTerminator<"YieldOp">])> {
  // Dispatch verification, printing, and parsing to free functions named
  // after the concrete op's C++ class ($cppClass is substituted by
  // TableGen, e.g. verifyScatterOp / printScatterOp / parseScatterOp).
  let verifier = [{ return verify$cppClass(*this); }];
  let printer = [{ return print$cppClass(p, *this); }];
  let parser = [{ return parse$cppClass(parser, result); }];
  // Shared C++ snippet concatenated into each subclass's
  // extraClassDeclaration.
  code extraLinalgExtOpClassDeclaration = [{
    // The `outputs` operands double as the destination buffers/tensors the
    // op writes into (destination-passing style).
    SmallVector<Value> getDestinationOperands() {
      SmallVector<Value> dest(outputs().begin(), outputs().end());
      return dest;
    }
  }];
}
| |
| //===----------------------------------------------------------------------===// |
| // Non-structured ops |
| //===----------------------------------------------------------------------===// |
| |
def LinalgExt_ScatterOp : LinalgExt_Op<"scatter",
    [DeclareOpInterfaceMethods<TiledOpInterface,
        ["getTiledImplementation", "generateScalarImplementation"]>]> {
  let summary = "Scatter operator";
  let description = [{
    Based on XLA operation semantics, takes two `inputs` (`update` and
    `indices`) and `outputs` value (`original`). The operation updates
    the value at the slices specified by `indices` by combining the
    current value with the value in `updates` using the computation
    specified in `region`. The `region` specifies a binary operation
    of signature (T, T) -> T, where `T` is the element-type of
    `updates` (and `original`). The first argument corresponds to the
    value to be updated (i.e. from `updates`), and the second to the
    current value (i.e. value from `original`).

    The `indices` is a 2D tensor/memref type. The first dim is the number of
    updates, and the second dim is index depth. The index depth should always be
    static.

    The first dim of `updates` and `indices` is identical, since they represent
    the number of updates.

    The rank of the `original`/`result` is `index_depth + rank(%updates) - 1`.
    The first `index_depth` indices are derived from `indices` and the shape of
    the update value must match the rest of the shape of `original`.

    The shapes definition follows tensorflow operations except that it forces
    batch dims to be 1D. See more information in
    https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
  }];
  let arguments = (ins
      Variadic<AnyRankedTensorOrMemRefType>:$inputs,
      Variadic<AnyRankedTensorOrMemRefType>:$outputs
  );
  let results = (outs Variadic<AnyRankedTensor>:$results);
  let regions = (region AnyRegion:$region);
  let assemblyFormat = [{
    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
    `outs` `(` $outputs `:` type($outputs) `)`
    $region (`->` type($results)^)?
  }];
  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{

    // Returns the index depth: the static extent of the last (second)
    // dimension of the `indices` operand.
    int64_t getIndexDepth() {
      return getInputOperand(1)
          ->get()
          .getType()
          .cast<ShapedType>()
          .getShape()
          .back();
    }

    // The `updates` value is the first input operand.
    Value updates() {
      return getInputOperand(0)->get();
    }

    ShapedType getUpdateType() {
      return updates().getType().cast<ShapedType>();
    }

    // The `indices` value is the second input operand.
    Value indices() {
      return getInputOperand(1)->get();
    }

    ShapedType getIndicesType() {
      return indices().getType().cast<ShapedType>();
    }

    // The `original` value being scattered into is the first output operand.
    Value original() {
      return getOutputOperand(0)->get();
    }

    ShapedType getOriginalType() {
      return original().getType().cast<ShapedType>();
    }

    // Rank of a single update slice: the leading dim of `updates` counts
    // updates, so it is excluded.
    int64_t getUpdateSliceRank() {
      return updates().getType().cast<ShapedType>().getRank() - 1;
    }

    // True when each update is a single scalar element (rank-0 slice).
    bool isScalarUpdate() {
      return getUpdateSliceRank() == 0;
    }
  }];
}
| |
def LinalgExt_SortOp : LinalgExt_Op<"sort",
    [DeclareOpInterfaceMethods<TiledOpInterface,
        ["getPartitionableLoops", "generateScalarImplementation",
         "getTiledImplementation"]>]> {
  let summary = "Sort operator";
  let description = [{
    Based on XLA operation semantics, sorts the given `operands` at the given
    `dimension` with the given `comparator`.

    See https://www.tensorflow.org/xla/operation_semantics#sort.
  }];

  // Define arguments and results like linalg.generic op. The attribute has the
  // same definition as mhlo.sort::dimension. If the rank is greater than 1,
  // the attribute must be set. If the rank is exactly 1, the dimension is
  // optional.
  let arguments = (ins Variadic<AnyType>:$inputs,
                       Variadic<AnyShaped>:$outputs,
                       OptionalAttr<I64Attr>:$dimension
  );
  let results = (outs Variadic<AnyRankedTensor>:$results);
  let regions = (region AnyRegion:$region);
  let assemblyFormat = [{
    (`dimension` `(` $dimension^ `)`)?
    attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
    `outs` `(` $outputs `:` type($outputs) `)`
    $region (`->` type($results)^)?
  }];
  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
    // Returns the index-th value being sorted. Sorting happens in place on
    // the `outputs` operands, so operands are drawn from there.
    Value operand(int index) {
      return outputs()[index];
    }
    ShapedType getOperandType(int index) {
      return operand(index).getType().cast<ShapedType>();
    }
    // All sorted operands share the same rank/shape; query operand 0.
    int64_t getOperandRank() {
      return getOperandType(0).getRank();
    }
    ArrayRef<int64_t> getOperandShape() {
      return getOperandType(0).getShape();
    }
    // Returns the dimension to sort along; defaults to 0 when the optional
    // `dimension` attribute is absent (the rank-1 case).
    uint64_t getSortedDimension() {
      uint64_t sortedDim = 0;
      if (auto setSortedDim = dimension()) {
        sortedDim = *setSortedDim;
      }
      return sortedDim;
    }
  }];
}
| |
| //===----------------------------------------------------------------------===// |
| // Pure ops |
| //===----------------------------------------------------------------------===// |
| |
def LinalgExt_YieldOp
    : LinalgExt_PureOp<"yield", [NoSideEffect, ReturnLike, Terminator]> {
  let summary = "LinalgExt yield op";
  let description = [{
    `linalg_ext.yield` is a special terminator operation for blocks inside
    regions in `linalg_ext` ops.
  }];

  // Variadic so both the 0-operand (implicit terminator) and n-operand
  // forms are expressible.
  let arguments = (ins Variadic<AnyType>:$operands);

  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";

  let builders = [
    // Zero-operand builder used when the terminator is created implicitly.
    OpBuilder<(ins), [{ /* nothing to do */ }]>,
  ];
}
| |
| #endif // IREE_DIALECT_LINALGEXT_OPS |