| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #ifndef IREE_DIALECT_LINALGEXT_OPS |
| #define IREE_DIALECT_LINALGEXT_OPS |
| |
| include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtBase.td" |
| include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtInterfaces.td" |
| include "mlir/Interfaces/ControlFlowInterfaces.td" |
| include "mlir/Interfaces/DestinationStyleOpInterface.td" |
| include "mlir/Interfaces/InferTypeOpInterface.td" |
| include "mlir/Interfaces/SideEffectInterfaces.td" |
| include "mlir/Interfaces/TilingInterface.td" |
| include "mlir/Interfaces/ViewLikeInterface.td" |
| |
| //===----------------------------------------------------------------------===// |
| // Base class. |
| //===----------------------------------------------------------------------===// |
| |
| // Thin base for every op of this dialect: forwards the mnemonic and the trait |
| // list to `Op` with the dialect fixed to IREELinalgExt_Dialect. It adds no |
| // traits and no fields of its own. |
| class IREELinalgExt_PureOp<string mnemonic, list<Trait> traits = []> |
| : Op<IREELinalgExt_Dialect, mnemonic, traits>; |
| |
| // Base class for the region-carrying, destination-style LinalgExt ops. |
| // The caller's `traits` are kept and the following are appended to them: |
| // AttrSizedOperandSegments, declared MemoryEffectsOpInterface methods, |
| // DestinationStyleOpInterface, LinalgExtInterface, and a single-block region |
| // contract whose implicit terminator is the LinalgExt YieldOp. |
| class IREELinalgExt_Op<string mnemonic, list<Trait> traits = []> : |
| IREELinalgExt_PureOp<mnemonic, !listconcat(traits, |
| [AttrSizedOperandSegments, |
| DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, |
| DestinationStyleOpInterface, LinalgExtInterface, |
| SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp"> |
| ])> { |
| // Every derived op gets a C++ verifier hook. |
| let hasVerifier = 1; |
| let hasCustomAssemblyFormat = 1; |
| // Shared C++ snippet (empty here) that derived ops prepend to their own |
| // `extraClassDeclaration` via `extraLinalgExtOpClassDeclaration # [{...}]`. |
| code extraLinalgExtOpClassDeclaration = ""; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Utility ops |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation group under which the utility ops below are listed. |
| def OpGroupUtilityOps : OpDocGroup { |
| let description = ""; |
| let summary = "Utility ops"; |
| } |
| |
| let opDocGroup = OpGroupUtilityOps in { |
| |
| // Result-less op declared with an empty trait list; it accepts any number of |
| // operands of any type. |
| def IREELinalgExt_DoNotDCEOperandsOp : |
| Op<IREELinalgExt_Dialect, "transform.do_not_dce_operands", []> { |
| let arguments = (ins Variadic<AnyType>:$operands); |
| let results = (outs); |
| let assemblyFormat = "attr-dict $operands `:` type($operands)"; |
| |
| let summary = "Unfoldable op that just keeps its operands live"; |
| let description = [{ |
| Unfoldable op that just keeps its operands live. This is to use with the |
| transform dialect in case where transforms introduce IR that would be |
| otherwise DCE'd by canonicalizations. |
| |
| This op should be added to the transform dialect in the fullness of time but |
| it can't be registered dynamically on the IREE side as that triggers errors |
| since the op does not implement any transform interface. |
| }]; |
| } |
| |
| // Terminator for the regions of iree_linalg_ext ops. The printed form omits |
| // the operand list and its types entirely when there are no operands. |
| def IREELinalgExt_YieldOp : IREELinalgExt_PureOp<"yield", [Pure, ReturnLike, Terminator]> { |
| let arguments = (ins Variadic<AnyType>:$operands); |
| let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; |
| |
| // Extra builder taking no arguments; its body intentionally does nothing. |
| let builders = [ |
| OpBuilder<(ins), [{ /* nothing to do */ }]>, |
| ]; |
| |
| let summary = "LinalgExt yield op"; |
| let description = [{ |
| `iree_linalg_ext.yield` is a special terminator operation for blocks inside |
| regions in `iree_linalg_ext` ops. |
| }]; |
| } |
| |
| } // OpGroupUtilityOps |
| |
| //===----------------------------------------------------------------------===// |
| // Non-structured ops |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation group under which the non-structured ops below are listed. |
| def OpGroupNonStructuredOps : OpDocGroup { |
| let description = ""; |
| let summary = "Non-structured ops"; |
| } |
| |
| let opDocGroup = OpGroupNonStructuredOps in { |
| |
| // Scatter op. Input operand 0 is `updates`, input operand 1 is `indices`, and |
| // init operand 0 is `original` (see the accessors in the class declaration). |
| def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter", |
| [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["generateScalarImplementation", |
| "getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Scatter operator"; |
| let description = [{ |
| Based on XLA operation semantics, takes two `inputs` (`updates` and |
| `indices`) and `outputs` value (`original`). The operation updates |
| the value at the slices specified by `indices` by combining the |
| current value with the value in `updates` using the computation |
| specified in `region`. The `region` specifies a binary operation |
| of signature (T, T) -> T, where `T` is the element-type of |
| `updates` (and `original`). The first argument corresponds to the |
| value to be updated (i.e. from `updates`), and the second the |
| current value (i.e. value from `original`). |
| |
| The `indices` is a 2D tensor/memref type. The first dim is the number of |
| updates, and the second dim is index depth. The index depth should always be |
| static. |
| |
| The first dim of `updates` and `indices` is identical, since they represent |
| the number of updates. |
| |
| The rank of the `original`/`result` is at least |
| `index_depth + rank(%updates) - 1`. The first `index_depth` indices are |
| derived from `indices` and the shape of update value has the last |
| rank(%original) - index_depth values match %original's last dimensions, |
| with the previous dims extending from the index offsets. |
| |
| The dimension_map attribute describes which index value maps to which |
| dimension in the destination. It cannot contain duplicate values, must |
| have as many entries as index depth, and values must be within the rank of |
| the destination. |
| |
| The unique_indices attribute carries the information whether all the indices |
| are unique. If there are repeated indices, the first iteration loop will be |
| marked as reduction. |
| |
| The shapes definition follows tensorflow operations except that it forces |
| batch dims to be 1D. See more information in |
| https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update |
| }]; |
| let arguments = (ins |
| Variadic<AnyRankedTensorOrMemRefType>:$inputs, |
| Variadic<AnyRankedTensorOrMemRefType>:$outputs, |
| DenseI64ArrayAttr:$dimension_map, |
| DefaultValuedAttr<BoolAttr, "true">:$unique_indices |
| ); |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let regions = (region AnyRegion:$region); |
| let assemblyFormat = [{ |
| attr-dict `dimension_map` `=` $dimension_map |
| `unique_indices` `(` $unique_indices `)` |
| (`ins` `(` $inputs^ `:` type($inputs) `)`)? |
| `outs` `(` $outputs `:` type($outputs) `)` |
| $region (`->` type($results)^)? |
| }]; |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| |
| // `updates`: input operand 0. |
| Value updates() { |
| return getDpsInputOperand(0)->get(); |
| } |
| |
| ShapedType getUpdateType() { |
| return updates().getType().cast<ShapedType>(); |
| } |
| |
| // `indices`: input operand 1. |
| Value indices() { |
| return getDpsInputOperand(1)->get(); |
| } |
| |
| ShapedType getIndicesType() { |
| return indices().getType().cast<ShapedType>(); |
| } |
| |
| // `original`: init operand 0. |
| Value original() { |
| return getDpsInitOperand(0)->get(); |
| } |
| |
| ShapedType getOriginalType() { |
| return original().getType().cast<ShapedType>(); |
| } |
| |
| // Index depth: the size of the last dimension of `indices`. |
| int64_t getIndexDepth() { |
| return getIndicesType().getShape().back(); |
| } |
| |
| // Rank of `updates` minus one. |
| int64_t getUpdateSliceRank() { |
| return getUpdateType().getRank() - 1; |
| } |
| |
| // True when the update slice rank is zero. |
| bool isScalarUpdate() { |
| return getUpdateSliceRank() == 0; |
| } |
| |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| // Sort op. The helper accessors in the class declaration all read from the |
| // `outputs` operand group. |
| def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort", |
| [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["generateScalarImplementation", |
| "getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Sort operator"; |
| let description = [{ |
| Based on XLA operation semantics, sorts the given `operands` at the given |
| `dimension` with the given `comparator`. |
| |
| See https://www.tensorflow.org/xla/operation_semantics#sort. |
| }]; |
| |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let regions = (region AnyRegion:$region); |
| let arguments = (ins Variadic<AnyType>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| I64Attr:$dimension |
| ); |
| let assemblyFormat = [{ |
| attr-dict |
| `dimension` `(` $dimension `)` |
| (`ins` `(` $inputs^ `:` type($inputs) `)`)? |
| `outs` `(` $outputs `:` type($outputs) `)` |
| $region (`->` type($results)^)? |
| }]; |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // DestinationStyleOpInterface hook: the `outputs` group is the init range. |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| |
| // Output operand at position `idx`. |
| Value operand(int idx) { |
| return getOutputs()[idx]; |
| } |
| // Shaped type of the output operand at position `idx`. |
| ShapedType getOperandType(int idx) { |
| auto operandType = operand(idx).getType(); |
| return operandType.cast<ShapedType>(); |
| } |
| // Shape and rank are both taken from the first output operand. |
| ArrayRef<int64_t> getOperandShape() { |
| return getOperandType(0).getShape(); |
| } |
| int64_t getOperandRank() { |
| return getOperandType(0).getRank(); |
| } |
| }]; |
| } |
| |
| // FFT op. Input 0 is the stage; outputs 0 and 1 are the real and imaginary |
| // parts. Optional coefficient operands follow the stage in `inputs`. |
| def IREELinalgExt_FftOp : IREELinalgExt_Op<"fft", [ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["generateScalarImplementation", |
| "getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Fft operator"; |
| let description = [{ |
| Apply 1D FFT to innermost dim. This is an iterative FFT, not recursive. |
| Thus, the bit reversal is assumed applied on the input. The op carries an |
| input -- stage, which indicates the level of reduction loop in the |
| algorithm. It represents the computation body. For more details, see |
| "Data reordering, bit reversal, and in-place algorithms" section in |
| https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm |
| |
| The size of innermost dim is expected to be a power of 2. |
| |
| It is optional to carry coefficient tensors/buffers as inputs. In this |
| context, they will be the second and third inputs. |
| }]; |
| |
| let arguments = (ins Variadic<AnyType>:$inputs, |
| Variadic<AnyShaped>:$outputs |
| ); |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let assemblyFormat = [{ |
| attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)? |
| `outs` `(` $outputs `:` type($outputs) `)` |
| (`:` type($results)^)? |
| }]; |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // Stage is input 0; real and imaginary parts are outputs 0 and 1. |
| Value getStage() { return getInputs()[0]; } |
| Value getReal() { return getOutputs()[0]; } |
| Value getImag() { return getOutputs()[1]; } |
| // True when more than one input is present, i.e. something besides the |
| // stage was passed. |
| bool hasCoeff() { return getNumDpsInputs() > 1; } |
| void generateScalarImplWithoutCoeffBuf( |
| OpBuilder & b, Location loc, ArrayRef<Value> operands, Value wholeSize); |
| void generateScalarImplWithCoeffBuf(OpBuilder & b, Location loc, |
| ArrayRef<Value> operands); |
| // Coefficient accessors return a null Value when hasCoeff() is false. |
| Value getRealCoeff() { |
| if (!hasCoeff()) return Value(); |
| return getInputs()[1]; |
| } |
| // NOTE(review): this reads input 2 while hasCoeff() only guarantees two |
| // inputs; presumably the verifier requires both coefficients together -- |
| // confirm. |
| Value getImagCoeff() { |
| if (!hasCoeff()) return Value(); |
| return getInputs()[2]; |
| } |
| // Type, rank and shape queries are all based on the real output. |
| ShapedType getOperandType() { |
| return getReal().getType().cast<ShapedType>(); |
| } |
| int64_t getOperandRank() { |
| return getOperandType().getRank(); |
| } |
| ArrayRef<int64_t> getOperandShape() { |
| return getOperandType().getShape(); |
| } |
| // FFT length: size of the last dimension of the real output. |
| int64_t getFftLength() { |
| return getOperandShape().back(); |
| } |
| |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| // Scan op. Input operand 0 is the scanned value; init operand 0 is the output |
| // and init operand 1 is the accumulator. |
| def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan", |
| [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["generateScalarImplementation", |
| "getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Scan operator"; |
| let description = [{ |
| Computes the inclusive/exclusive scan along a given dimension. |
| }]; |
| |
| let hasFolder = 1; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| I64Attr:$dimension, |
| BoolAttr:$inclusive |
| ); |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let regions = (region AnyRegion:$region); |
| |
| // Builder declaration only; `dimension` defaults to 0 and `inclusive` to true. |
| let builders = [ |
| OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, |
| CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)> |
| ]; |
| |
| let assemblyFormat = [{ |
| attr-dict |
| `dimension` `(` $dimension `)` |
| `inclusive` `(` $inclusive `)` |
| `ins` `(` $inputs `:` type($inputs) `)` |
| `outs` `(` $outputs `:` type($outputs) `)` |
| $region (`->` type($results)^)? |
| }]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // DestinationStyleOpInterface hook: the `outputs` group is the init range. |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| |
| // Input operand 0. |
| Value input() { |
| return getDpsInputOperand(0)->get(); |
| } |
| // Init operand 0. |
| Value output() { |
| return getDpsInitOperand(0)->get(); |
| } |
| // Init operand 1. |
| Value accumulator() { |
| return getDpsInitOperand(1)->get(); |
| } |
| // Type and rank are those of the input operand. |
| ShapedType getOperandType() { |
| auto inputType = input().getType(); |
| return inputType.cast<ShapedType>(); |
| } |
| int64_t getOperandRank() { |
| return getOperandType().getRank(); |
| } |
| }]; |
| } |
| |
| // Reverse op. Input operand 0 is the source, init operand 0 the destination; |
| // `dimensions` is an I64 elements attribute. |
| def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods< |
| TilingInterface, |
| ["generateScalarImplementation", |
| "getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>, |
| DeclareOpInterfaceMethods<LinalgExtInterface>]> { |
| let summary = "Reverse operator"; |
| let description = [{ |
| A temporary solution for lowering reverse ops into IREE, allowing IREE to |
| tile and distribute them. |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| I64ElementsAttr:$dimensions |
| ); |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let assemblyFormat = [{ |
| attr-dict `dimensions` `(` $dimensions `)` |
| (`ins` `(` $inputs^ `:` type($inputs) `)`)? |
| (`outs` `(` $outputs^ `:` type($outputs) `)`)? |
| (`:` type($results)^)? |
| }]; |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // Input operand 0. |
| Value input() { |
| return getDpsInputOperand(0)->get(); |
| } |
| // Init operand 0. |
| Value output() { |
| return getDpsInitOperand(0)->get(); |
| } |
| // Type, rank and shape are those of the input operand. |
| ShapedType getOperandType() { |
| return input().getType().cast<ShapedType>(); |
| } |
| int64_t getOperandRank() { |
| return getOperandType().getRank(); |
| } |
| // The misspelled name is kept on purpose: renaming it would break any |
| // existing caller of this accessor. |
| ArrayRef<int64_t> getOprerandShape() { |
| return getOperandType().getShape(); |
| } |
| // Copies the `dimensions` attribute into a vector of int64_t, one entry |
| // per attribute element, converted with APInt::getLimitedValue(). |
| SmallVector<int64_t> dims() { |
| SmallVector<int64_t> ret; |
| for (const APInt& elem : getDimensions()) { |
| ret.push_back(elem.getLimitedValue()); |
| } |
| return ret; |
| } |
| |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| // Top-K op. Input operand 0 holds the values and the optional input operand 1 |
| // the indices; init operands 0 and 1 are the output values and indices. |
| def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<LinalgExtInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["generateScalarImplementation", |
| "getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]> |
| ]>{ |
| let summary = "Top-K operator"; |
| let description = [{ |
| A Top-K operation for N-D tensors. Reduces the target dimension from the input |
| size N down to K elements based on the supplied binary region. |
| |
| Accepts an N-D tensor input consisting of values and an optional N-D tensor |
| for indices of those values (i32 type). If input indices aren't provided, the |
| index mapping is inferred based on the k dim. Both input values/indices |
| tensors and output values/indices tensors must have the same shape. Top-K is |
| computed along the target dimension (from dimension()). Returns two output |
| tensors of values and the indices of Top-K results. The output dimensions |
| must match the input save for the dimension that is reduced to K results. |
| |
| Region accepts lhs=[next N input] and rhs=[existing K output] and yields an |
| i1. If true, the two values are swapped: |
| - For Top-K comparison: > |
| - For Min-K comparison: < |
| Note: when the two values are equal, the first occurrence is always selected. |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| I64Attr:$dimension |
| ); |
| |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let regions = (region AnyRegion:$region); |
| let assemblyFormat = [{ |
| attr-dict |
| `dimension` `(` $dimension `)` |
| `ins` `(` $inputs `:` type($inputs) `)` |
| `outs` `(` $outputs `:` type($outputs) `)` |
| $region (`->` type($results)^)? |
| }]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // Input values: input operand 0. |
| Value values() { |
| return getDpsInputOperand(0)->get(); |
| } |
| // Input indices: input operand 1, or an empty optional when fewer than |
| // two inputs were supplied. |
| std::optional<Value> indices() { |
| if (getNumDpsInputs() < 2) { |
| return {}; |
| } else { |
| return getDpsInputOperand(1)->get(); |
| } |
| } |
| // Output values and output indices: init operands 0 and 1. |
| Value outputValues() { |
| return getDpsInitOperand(0)->get(); |
| } |
| Value outputIndices() { |
| return getDpsInitOperand(1)->get(); |
| } |
| // Type and rank of the input values operand. |
| ShapedType getInputType() { |
| return values().getType().cast<ShapedType>(); |
| } |
| int64_t getInputRank() { |
| return getInputType().getRank(); |
| } |
| |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Attention |
| //===----------------------------------------------------------------------===// |
| |
| // Attention op. Input operands 0..2 are query, key and value. Init operand 0 |
| // is the output; init operands 1 and 2, when present, are the `max` and `sum` |
| // statistics. |
| def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention", |
| [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Attention operator"; |
| let description = [{ |
| This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes |
| the attention. For self-attention, all inputs have the same shape BxNxd where B is the |
| size of the batch dimension, N is the sequence length and d is head dimension. |
| Typically N >>> d. Mathematically, the attention is defined as |
| matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually, |
| this operator also performs scaling, masking and dropout, but we leave |
| that out of the current implementation. For cross-attention, the query and output |
| have the same shape (BxNxd), while the key and value differ in sequence length |
| (they have shape BxLxd, where L != N). |
| This operator after tiling results in a tiled result as per flash attention and results |
| in the current `max` and `sum` statistics while processing the current tile. |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs |
| ); |
| |
| let builders = [ |
| OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)> |
| ]; |
| |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let hasFolder = 1; |
| let assemblyFormat = [{ |
| attr-dict |
| `ins` `(` $inputs `:` type($inputs) `)` |
| `outs` `(` $outputs `:` type($outputs) `)` |
| (`->` type($results)^)? |
| }]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // Query, key and value: input operands 0, 1 and 2. |
| Value getQuery() { |
| return getDpsInputOperand(0)->get(); |
| } |
| Value getKey() { |
| return getDpsInputOperand(1)->get(); |
| } |
| Value getValue() { |
| return getDpsInputOperand(2)->get(); |
| } |
| // Attention output: init operand 0. |
| Value getOutput() { |
| return getDpsInitOperand(0)->get(); |
| } |
| // `max` statistic: init operand 1. The guard checks the number of init |
| // operands (what is actually indexed) rather than the number of results, |
| // so an op with a single output and no results yields std::nullopt |
| // instead of indexing past the init operands. |
| std::optional<Value> getMax() { |
| if (getNumDpsInits() < 2) |
| return std::nullopt; |
| return getDpsInitOperand(1)->get(); |
| } |
| // `sum` statistic: init operand 2, guarded the same way as getMax(). |
| std::optional<Value> getSum() { |
| if (getNumDpsInits() < 3) |
| return std::nullopt; |
| return getDpsInitOperand(2)->get(); |
| } |
| ShapedType getQueryType() { |
| return getQuery().getType().cast<ShapedType>(); |
| } |
| ShapedType getKeyType() { |
| return getKey().getType().cast<ShapedType>(); |
| } |
| ShapedType getValueType() { |
| return getValue().getType().cast<ShapedType>(); |
| } |
| ShapedType getOutputType() { |
| return getOutput().getType().cast<ShapedType>(); |
| } |
| // Types of the optional statistics; std::nullopt when the operand is absent. |
| std::optional<ShapedType> getMaxType() { |
| std::optional<Value> maxValue = getMax(); |
| if (!maxValue.has_value()) |
| return std::nullopt; |
| return maxValue->getType().cast<ShapedType>(); |
| } |
| std::optional<ShapedType> getSumType() { |
| std::optional<Value> sumValue = getSum(); |
| if (!sumValue.has_value()) |
| return std::nullopt; |
| return sumValue->getType().cast<ShapedType>(); |
| } |
| int64_t getQueryRank() { |
| return getQueryType().getRank(); |
| } |
| int64_t getKeyRank() { |
| return getKeyType().getRank(); |
| } |
| int64_t getValueRank() { |
| return getValueType().getRank(); |
| } |
| int64_t getOutputRank() { |
| return getOutputType().getRank(); |
| } |
| // Ranks of the optional statistics; std::nullopt when the operand is absent. |
| std::optional<int64_t> getMaxRank() { |
| std::optional<ShapedType> maxType = getMaxType(); |
| if (!maxType.has_value()) |
| return std::nullopt; |
| return maxType->getRank(); |
| } |
| std::optional<int64_t> getSumRank() { |
| std::optional<ShapedType> sumType = getSumType(); |
| if (!sumType.has_value()) |
| return std::nullopt; |
| return sumType->getRank(); |
| } |
| // The iteration domain rank is the constant 2. |
| int64_t getIterationDomainRank() { |
| return 2; |
| } |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| } // OpGroupNonStructuredOps |
| |
| //===----------------------------------------------------------------------===// |
| // Data tiling ops |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation group under which the data tiling ops below are listed. |
| def OpGroupDataTilingOps : OpDocGroup { |
| let description = [{ |
| Operations for working with data layouts, padding, encodings, and other |
| properties useful for tiling computations across iteration space dimensions. |
| }]; |
| let summary = "Data tiling ops"; |
| } |
| |
| let opDocGroup = OpGroupDataTilingOps in { |
| |
| // Pack op. Input operand 0 is the source and init operand 0 the packed |
| // destination. Tile sizes are split between the dynamic `inner_tiles` |
| // operands and the `static_inner_tiles` attribute. |
| def IREELinalgExt_PackOp : IREELinalgExt_Op<"pack", [ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<LinalgExtInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["getIterationDomain", |
| "generateScalarImplementation"]>, |
| DeclareOpInterfaceMethods<MemoryEffectsOpInterface> |
| ]>{ |
| let summary = "pack operation"; |
| let description = [{ |
| The pack operation converts an `input` into a tiled and packed layout. The |
| dimensions to be tiled are obtained from `inner_dims_pos` and the size of the |
| tile is obtained from `inner_tiles`. The dimensions listed in `inner_dims_pos` |
| do not need to be contiguous in which case the tile will get transposed. We |
| handle only full tiles if `padding_value` is not set; it is UB if the tile does |
| not perfectly divide the dimension. If `padding_value` is set, it will pad |
| along high dimensions, i.e., it pads at the bottom and on the right if the |
| input has rank 2, and the result type shape will be dynamic in any dimension |
| if and only if the input shape is. As optional input, the operation takes |
| `outer_dims_perm` that allows to permute the tiled loops. |
| |
| Example KC_to_KCck: |
| |
| ```mlir |
| iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0] |
| inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<16x8x32x8xf32>) |
| ``` |
| |
| Example NC_to_NCnc: |
| |
| ```mlir |
| iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1] |
| inner_tiles = [8, 32] into %arg1 : (memref<128x256xf32> memref<16x8x8x32xf32>) |
| ``` |
| |
| Example KC_to_CKkc: |
| |
| ```mlir |
| iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] |
| inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>) |
| ``` |
| |
| In all cases, dimension at position 0 in the input memref (128) is tiled |
| with a factor of 8, while dimension at position 1 (256) is tiled with a factor |
| of 32. In the KC_to_KCck example, the point loops are interchanged, while in the |
| KC_to_CKkc example the tiled loops. |
| |
| Example NC_to_NCnc with padding: |
| |
| ```mlir |
| iree_linalg_ext.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1] |
| inner_tiles = [8, 2] into %arg1 : (memref<13x15xf32> memref<2x8x8x2xf32>) |
| ``` |
| |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm, |
| DenseI64ArrayAttr:$inner_dims_pos, |
| Variadic<Index>:$inner_tiles, |
| DenseI64ArrayAttr:$static_inner_tiles, |
| Optional<AnyType>:$padding_value); |
| |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let assemblyFormat = [{ |
| attr-dict |
| $inputs |
| (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? |
| (`outer_dims_perm` `=` $outer_dims_perm^)? |
| `inner_dims_pos` `=` $inner_dims_pos |
| `inner_tiles` `=` |
| custom<DynamicIndexList>($inner_tiles, $static_inner_tiles) |
| `into` $outputs `:` `(` type($inputs) type($outputs) `)` |
| (`->` type($results)^)? |
| }]; |
| |
| // Convenience builder; the padding value and the outer permutation are |
| // optional and default to "absent" and "empty". |
| let builders = [ |
| OpBuilder<(ins "Value":$source, "Value":$output, |
| "ArrayRef<int64_t>":$innerDimsPos, |
| "ArrayRef<OpFoldResult>":$innerTiles, |
| CArg<"std::optional<Value>", "std::nullopt">:$paddingValue, |
| CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)> |
| ]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| |
| // Return the output operand. |
| Value getOutput() { |
| return getDpsInitOperand(0)->get(); |
| } |
| |
| // Return the input operand. |
| Value getInput() { |
| return getDpsInputOperand(0)->get(); |
| } |
| |
| // Return the output rank. |
| int64_t getOutputRank() { |
| return getOutputType().getRank(); |
| } |
| |
| // Return the output type. |
| ShapedType getOutputType() { |
| return getOutput().getType().cast<ShapedType>(); |
| } |
| |
| // Return the input type. |
| ShapedType getInputType() { |
| return getInput().getType().cast<ShapedType>(); |
| } |
| |
| // Return the output shape. |
| ArrayRef<int64_t> getOutputShape() { |
| return getOutputType().getShape(); |
| } |
| |
| // Return the input shape. |
| ArrayRef<int64_t> getInputShape() { |
| return getInputType().getShape(); |
| } |
| |
| // Return the element type of the input operand. |
| Type getElementType() { |
| return getInputType().getElementType(); |
| } |
| |
| // Return the rank of the input operand. |
| int64_t getInputRank() { |
| return getInputType().getRank(); |
| } |
| |
| // Return the tile sizes. |
| SmallVector<OpFoldResult> getMixedTiles(); |
| SmallVector<int64_t> getStaticTiles(); |
| |
| // Return a mapping from positions `inner_dims_pos` to their tile factors. |
| DenseMap<int64_t, OpFoldResult> getDimAndTileMapping(); |
| |
| // Method to get the shape of the result as `SmallVector<OpFoldResult>`. |
| // This is a static method to allow getting the shape of the destination |
| // expected while creating a `pack` op. |
| static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder, |
| Location loc, ArrayRef<OpFoldResult> sourceDims, |
| ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos, |
| ArrayRef<int64_t> outerDimsPerm = {}); |
| // Method to return the shape of the result as `SmallVector<OpFoldResult>`. |
| SmallVector<OpFoldResult> getResultShape(OpBuilder &builder); |
| |
| // Method to get the `ShapedType` of the result. This is a static method |
| // to allow getting the type of the destination while creating the `pack` |
| // op. |
| static ShapedType getPackedType(ShapedType sourceType, |
| ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos, |
| ArrayRef<int64_t> outerDimsPerm = {}); |
| |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| // Unpack op. Input operand 0 is the packed source and init operand 0 the |
| // unpacked destination. Tile sizes are split between the dynamic |
| // `inner_tiles` operands and the `static_inner_tiles` attribute. |
| def IREELinalgExt_UnPackOp : IREELinalgExt_Op<"unpack", [ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<LinalgExtInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["getIterationDomain", |
| "generateScalarImplementation"]>, |
| DeclareOpInterfaceMethods<MemoryEffectsOpInterface> |
| ]>{ |
| let summary = "unpack operation"; |
| |
| let description = [{ |
| The unpack operation converts a tiled and packed input to an unpacked |
| output. See `pack` for more details on `inner_tiles` and `inner_dims_pos`; it |
| is UB if the tile does not perfectly divide the dimension. Optionally, the |
| operation also supports permuting the tiled loops. |
| |
| Example KCck_to_KC: |
| |
| ```mlir |
| iree_linalg_ext.unpack %arg0 inner_dims_pos = [1, 0] |
| inner_tiles = [32, 8] into %arg1 : (memref<16x8x32x8xf32> memref<128x256xf32>) |
| ``` |
| |
| Example NCnc_to_NC: |
| |
| ```mlir |
| iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1] |
| inner_tiles = [8, 32] into %arg1 : (memref<16x8x8x32xf32> memref<128x256xf32>) |
| ``` |
| |
| Example CKkc_to_KC: |
| |
| ```mlir |
| iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] |
| inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>) |
| ``` |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm, |
| DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$inner_dims_pos, |
| Variadic<Index>:$inner_tiles, |
| DenseI64ArrayAttr:$static_inner_tiles); |
| |
| let results = (outs Variadic<AnyRankedTensor>:$results); |
| let assemblyFormat = [{ |
| attr-dict |
| $inputs |
| (`outer_dims_perm` `=` $outer_dims_perm^)? |
| `inner_dims_pos` `=` $inner_dims_pos |
| `inner_tiles` `=` |
| custom<DynamicIndexList>($inner_tiles, $static_inner_tiles) |
| `into` $outputs `:` `(` type($inputs) type($outputs) `)` |
| (`->` type($results)^)? |
| }]; |
| |
| // Convenience builder; the outer permutation is optional and defaults to |
| // empty. |
| let builders = [ |
| OpBuilder<(ins "Value":$source, "Value":$output, |
| "ArrayRef<int64_t>":$innerDimsPos, |
| "ArrayRef<OpFoldResult>":$innerTiles, |
| CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)> |
| ]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| |
| // Return the output operand. |
| Value getOutput() { |
| return getDpsInitOperand(0)->get(); |
| } |
| |
| // Return the input operand. |
| Value getInput() { |
| return getDpsInputOperand(0)->get(); |
| } |
| |
| // Return the output rank. |
| int64_t getOutputRank() { |
| return getOutputType().getRank(); |
| } |
| |
| // Return the output type. |
| ShapedType getOutputType() { |
| return getOutput().getType().cast<ShapedType>(); |
| } |
| |
| // Return the input type. |
| ShapedType getInputType() { |
| return getInput().getType().cast<ShapedType>(); |
| } |
| |
| // Return the output shape. |
| ArrayRef<int64_t> getOutputShape() { |
| return getOutputType().getShape(); |
| } |
| |
| // Return the input shape. |
| ArrayRef<int64_t> getInputShape() { |
| return getInputType().getShape(); |
| } |
| |
| // Return the rank of the input operand. |
| int64_t getInputRank() { |
| return getInputType().getRank(); |
| } |
| |
| // Return the tile sizes. |
| SmallVector<OpFoldResult> getMixedTiles(); |
| SmallVector<int64_t> getStaticTiles(); |
| |
| // Return a mapping from positions `inner_dims_pos` to their tile factors. |
| DenseMap<int64_t, OpFoldResult> getDimAndTileMapping(); |
| |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| // Op that assigns an encoding to a ranked tensor. It is side-effect free |
| // (`Pure`) and declares the ReifyRankedShapedTypeOpInterface methods. |
| def IREELinalgExt_SetEncodingOp : IREELinalgExt_PureOp<"set_encoding",[ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, Pure |
| ]> { |
| let summary = "perform pack and pad operation on source"; |
| let description = [{ |
| Operation to assign an encoding to a tensor. The operation |
| does not change the rank or extent of a tensor. Instead it |
| adds an encoding attribute to the tensor type to represent |
| a change in layout. |
| }]; |
| |
| // One ranked tensor operand, one ranked tensor result. |
| let arguments = (ins AnyRankedTensor:$source); |
| let results = (outs AnyRankedTensor:$result); |
| |
| let assemblyFormat = [{ |
| attr-dict $source `:` type($source) `->` type($result) |
| }]; |
| |
| let hasVerifier = 1; |
| |
| let extraClassDeclaration = [{ |
| // Return the type of the `source` operand. |
| RankedTensorType getSourceType() { |
| return llvm::cast<RankedTensorType>(getSource().getType()); |
| } |
| // Return the type of the `result` value. |
| RankedTensorType getResultType() { |
| return llvm::cast<RankedTensorType>(getResult().getType()); |
| } |
| }]; |
| } |
| |
| // Side-effect-free (`Pure`) op that takes no operands: the tensor type being |
| // queried is carried by the `tensorType` type attribute, and the op yields a |
| // variadic list of `index` results. |
| def IREELinalgExt_UpperBoundTileSizeOp : IREELinalgExt_PureOp<"upper_bound_tile_size", |
| [Pure]> { |
| let summary = "returns an upper bound on tile sizes"; |
| let description = [{ |
| This returns the largest tile sizes that might result from materialization |
| of the given encoding. This can be used outside of target-specific code, so |
| there may be multiple targets, and this will return the maximum tile size |
| from iterating over all of them. The evaluation happens in the |
| MaterializeUpperBoundTileSize pass. |
| }]; |
| |
| // `tensorType` must hold a ranked tensor type. |
| let arguments = (ins TypeAttrOf<AnyRankedTensor>:$tensorType); |
| let results = (outs Variadic<Index>:$results); |
| |
| // Only the result types are spelled out in the textual form; the queried |
| // tensor type is printed as the attribute value itself. |
| let assemblyFormat = [{ |
| attr-dict $tensorType `->` type($results) |
| }]; |
| } |
| |
| // Op that converts an encoded ranked tensor back to a tensor without an |
| // encoding. It is side-effect free (`Pure`) and declares the |
| // ReifyRankedShapedTypeOpInterface methods. |
| def IREELinalgExt_UnsetEncodingOp : IREELinalgExt_PureOp<"unset_encoding", [ |
| DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, Pure |
| ]> { |
| let summary = "perform unpack and extract operation on source"; |
| let description = [{ |
| Operation to convert a tensor with encoding that represents |
| its data layout into a tensor with default layout (i.e. no encoding). |
| For now in IREE the default layout is row-major. |
| }]; |
| // One ranked tensor operand, one ranked tensor result. |
| let arguments = (ins AnyRankedTensor:$source); |
| let results = (outs AnyRankedTensor:$result); |
| |
| let assemblyFormat = [{ |
| attr-dict $source `:` type($source) `->` type($result) |
| }]; |
| |
| let hasVerifier = 1; |
| |
| let extraClassDeclaration = [{ |
| // Return the type of the `source` operand. |
| RankedTensorType getSourceType() { |
| return llvm::cast<RankedTensorType>(getSource().getType()); |
| } |
| // Return the type of the `result` value. |
| RankedTensorType getResultType() { |
| return llvm::cast<RankedTensorType>(getResult().getType()); |
| } |
| }]; |
| } |
| |
| } // OpGroupDataTilingOps |
| |
| //===----------------------------------------------------------------------===// |
| // Winograd ops |
| //===----------------------------------------------------------------------===// |
| |
| // Documentation group for the Winograd ops. It is attached to the op |
| // definitions that follow through `let opDocGroup = OpGroupWinogradOps in`. |
| def OpGroupWinogradOps : OpDocGroup { |
| let summary = "Winograd ops"; |
| let description = ""; |
| } |
| |
| let opDocGroup = OpGroupWinogradOps in { |
| |
| // Winograd input transform. Destination-style LinalgExt op (via the |
| // IREELinalgExt_Op base class) that additionally declares the |
| // ReifyRankedShapedTypeOpInterface methods and the listed TilingInterface |
| // methods. |
| def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_transform", |
| [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Winograd Input Transform operator"; |
| let description = [{ |
| This operator is the first step in converting a convolution to |
| its Winograd equivalent. Given a tile of an input image (I), |
| this operator computes matmul(transpose(B), matmul(I, B)). |
| The input tile is assumed to be square with each side of size m + r - 1, |
| where the convolutional kernel is m x m and the output tile size is r x r. |
| B is a constant 2-d square matrix of the same shape as the input tile I. |
| The input to the operator is an image of shape (N, H, W, C) or (N, C, H, W) |
| and the output is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C) |
| where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result |
| of this operator is first collapsed and then fed to a batch matmul op. |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| I64Attr:$output_tile_size, |
| I64Attr:$kernel_size, |
| DenseI64ArrayAttr:$image_dimensions |
| ); |
| |
| // Convenience builder. When omitted, the attributes default to |
| // output_tile_size = 8, kernel_size = 3 and image_dimensions = {1, 2}. |
| let builders = [ |
| OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, |
| CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size, |
| CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)> |
| ]; |
| |
| let results = (outs Variadic<AnyRankedTensor>:$result); |
| let hasFolder = 1; |
| let assemblyFormat = [{ |
| attr-dict |
| `output_tile_size` `(` $output_tile_size `)` |
| `kernel_size` `(` $kernel_size `)` |
| `image_dimensions` `(` $image_dimensions `)` |
| `ins` `(` $inputs `:` type($inputs) `)` |
| `outs` `(` $outputs `:` type($outputs) `)` |
| (`->` type($result)^)? |
| }]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // Return the first input operand. |
| Value input() { |
| return getDpsInputOperand(0)->get(); |
| } |
| // Return the first init (destination) operand. |
| Value output() { |
| return getDpsInitOperand(0)->get(); |
| } |
| // Return the shaped type of the first input operand. |
| ShapedType getInputOperandType() { |
| return llvm::cast<ShapedType>(input().getType()); |
| } |
| // Return the shaped type of the first init operand. |
| ShapedType getOutputOperandType() { |
| return llvm::cast<ShapedType>(output().getType()); |
| } |
| // Return the rank of the first input operand. |
| int64_t getInputOperandRank() { |
| return getInputOperandType().getRank(); |
| } |
| // Return the rank of the first init operand. |
| int64_t getOutputOperandRank() { |
| return getOutputOperandType().getRank(); |
| } |
| // Return the side length of an input tile: |
| // output_tile_size + kernel_size - 1. |
| int64_t getInputTileSize() { |
| return getOutputTileSize() + getKernelSize() - 1; |
| } |
| // Return an owning copy of the `image_dimensions` attribute values. |
| SmallVector<int64_t> imageDimensions() { |
| return llvm::to_vector(getImageDimensions()); |
| } |
| // `image_dimensions` value that identifies the NHWC layout. |
| std::array<int64_t, 2> nhwcImageDimensions() { |
| return {1, 2}; |
| } |
| // `image_dimensions` value that identifies the NCHW layout. |
| std::array<int64_t, 2> nchwImageDimensions() { |
| return {2, 3}; |
| } |
| // Return true if `image_dimensions` equals the NHWC positions. The |
| // attribute is compared in place, without copying it into a vector. |
| bool isNhwc() { |
| std::array<int64_t, 2> nhwcImageDims = nhwcImageDimensions(); |
| return getImageDimensions() == ArrayRef<int64_t>(nhwcImageDims); |
| } |
| // Return true if `image_dimensions` equals the NCHW positions. The |
| // attribute is compared in place, without copying it into a vector. |
| bool isNchw() { |
| std::array<int64_t, 2> nchwImageDims = nchwImageDimensions(); |
| return getImageDimensions() == ArrayRef<int64_t>(nchwImageDims); |
| } |
| // Return the channel dimension index: 3 for NHWC, 1 otherwise. |
| int channelDim() { |
| return isNhwc() ? 3 : 1; |
| } |
| // Return the input operand rank minus the number of image dimensions. |
| int64_t getIterationDomainRank() { |
| return getInputOperandRank() - |
| static_cast<int64_t>(getImageDimensions().size()); |
| } |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| // Winograd output transform. Destination-style LinalgExt op (via the |
| // IREELinalgExt_Op base class) that additionally declares the |
| // ReifyRankedShapedTypeOpInterface methods and the listed TilingInterface |
| // methods. |
| def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_transform", |
| [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, |
| DeclareOpInterfaceMethods<TilingInterface, |
| ["getIterationDomain", |
| "getLoopIteratorTypes", |
| "getResultTilePosition", |
| "getTiledImplementation"]>]> { |
| let summary = "Winograd Output Transform operator"; |
| let description = [{ |
| This operator is the last transform in converting a convolution to |
| its Winograd equivalent. After convolution in the Winograd domain |
| (which turns into an elementwise product for a single channel and |
| batch matrix multiplication for many channels), this operator converts |
| the output back into the original domain. Given a tile of the |
| output (O) in the Winograd domain, this operator computes |
| matmul(transpose(A), matmul(O, A)). The tile O is square with |
| each side of size m + r - 1, where the convolutional kernel is m x m |
| and the output tile size is r x r. A is a constant 2-d matrix of |
| shape (m + r - 1) x r. The input to the operator is a tensor of |
| shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a |
| tensor of shape (N, H, W, C) or (N, C, H, W) where H = r H' and W = r W'. |
| This operator is followed by a tensor.extract_slice which extracts |
| only the non-padded part of the output. |
| }]; |
| |
| let arguments = (ins Variadic<AnyShaped>:$inputs, |
| Variadic<AnyShaped>:$outputs, |
| I64Attr:$output_tile_size, |
| I64Attr:$kernel_size, |
| DenseI64ArrayAttr:$image_dimensions |
| ); |
| |
| // Convenience builder. When omitted, the attributes default to |
| // output_tile_size = 8, kernel_size = 3 and image_dimensions = {1, 2}. |
| let builders = [ |
| OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs, |
| CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size, |
| CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)> |
| ]; |
| |
| let results = (outs Variadic<AnyRankedTensor>:$result); |
| let hasFolder = 1; |
| let assemblyFormat = [{ |
| attr-dict |
| `output_tile_size` `(` $output_tile_size `)` |
| `kernel_size` `(` $kernel_size `)` |
| `image_dimensions` `(` $image_dimensions `)` |
| `ins` `(` $inputs `:` type($inputs) `)` |
| `outs` `(` $outputs `:` type($outputs) `)` |
| (`->` type($result)^)? |
| }]; |
| |
| let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{ |
| // Return the first input operand. |
| Value input() { |
| return getDpsInputOperand(0)->get(); |
| } |
| // Return the first init (destination) operand. |
| Value output() { |
| return getDpsInitOperand(0)->get(); |
| } |
| // Return the shaped type of the first input operand. |
| ShapedType getInputOperandType() { |
| return llvm::cast<ShapedType>(input().getType()); |
| } |
| // Return the shaped type of the first init operand. |
| ShapedType getOutputOperandType() { |
| return llvm::cast<ShapedType>(output().getType()); |
| } |
| // Return an owning copy of the `image_dimensions` attribute values. |
| SmallVector<int64_t> imageDimensions() { |
| return llvm::to_vector(getImageDimensions()); |
| } |
| // `image_dimensions` value that identifies the NHWC layout. |
| std::array<int64_t, 2> nhwcImageDimensions() { |
| return {1, 2}; |
| } |
| // `image_dimensions` value that identifies the NCHW layout. |
| std::array<int64_t, 2> nchwImageDimensions() { |
| return {2, 3}; |
| } |
| // Return true if `image_dimensions` equals the NHWC positions. The |
| // attribute is compared in place, without copying it into a vector. |
| bool isNhwc() { |
| std::array<int64_t, 2> nhwcImageDims = nhwcImageDimensions(); |
| return getImageDimensions() == ArrayRef<int64_t>(nhwcImageDims); |
| } |
| // Return true if `image_dimensions` equals the NCHW positions. The |
| // attribute is compared in place, without copying it into a vector. |
| bool isNchw() { |
| std::array<int64_t, 2> nchwImageDims = nchwImageDimensions(); |
| return getImageDimensions() == ArrayRef<int64_t>(nchwImageDims); |
| } |
| // Return the channel dimension index: 3 for NHWC, 1 otherwise. |
| int channelDim() { |
| return isNhwc() ? 3 : 1; |
| } |
| // Return the rank of the first input operand. |
| int64_t getInputOperandRank() { |
| return getInputOperandType().getRank(); |
| } |
| // Return the rank of the first init operand. |
| int64_t getOutputOperandRank() { |
| return getOutputOperandType().getRank(); |
| } |
| // Return the output operand rank minus the number of image dimensions. |
| int64_t getIterationDomainRank() { |
| return getOutputOperandRank() - |
| static_cast<int64_t>(getImageDimensions().size()); |
| } |
| // Return the side length of a Winograd-domain tile: |
| // output_tile_size + kernel_size - 1. |
| int64_t getInputTileSize() { |
| return getOutputTileSize() + getKernelSize() - 1; |
| } |
| // Method to implement for specifying output range for |
| // DestinationStyleOpInterface |
| MutableOperandRange getDpsInitsMutable() { |
| return getOutputsMutable(); |
| } |
| }]; |
| } |
| |
| } // OpGroupWinogradOps |
| |
| #endif // IREE_DIALECT_LINALGEXT_OPS |