// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_DIALECT_LINALGEXT_OPS
#define IREE_DIALECT_LINALGEXT_OPS
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
//===----------------------------------------------------------------------===//
class IREELinalgExt_PureOp<string mnemonic, list<Trait> traits = []> :
Op<IREELinalgExt_Dialect, mnemonic, traits> {
}
class IREELinalgExt_Op<string mnemonic, list<Trait> traits = []> :
IREELinalgExt_PureOp<mnemonic, !listconcat(traits,
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DestinationStyleOpInterface, LinalgExtInterface,
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">
])> {
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
code extraLinalgExtOpClassDeclaration = "";
}
//===----------------------------------------------------------------------===//
// Utility ops
//===----------------------------------------------------------------------===//
def OpGroupUtilityOps : OpDocGroup {
let summary = "Utility ops";
let description = "";
}
let opDocGroup = OpGroupUtilityOps in {
def IREELinalgExt_IndexOp : IREELinalgExt_PureOp<"index", [Pure]>,
Arguments<(ins ConfinedAttr<I64Attr, [IntMinValue<0>]>:$dim)>,
Results<(outs Index:$result)> {
let summary = "linalg_ext index operation";
let description = [{
This operation is a mirror of the `linalg.index` operation and has the same
semantics, except that `linalg.index` enforces that the parent op is a
`LinalgOp`, while the `iree_linalg_ext.index` operation enforces that the
parent op is an `IREE::LinalgExt::CustomOp`.
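For illustration, a minimal (hypothetical) use inside the region of an
`iree_linalg_ext.custom_op`:
```mlir
%dim0 = iree_linalg_ext.index 0 : index
```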
}];
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
let hasVerifier = 1;
}
def IREELinalgExt_YieldOp : IREELinalgExt_PureOp<"yield", [Pure, ReturnLike, Terminator]> {
let summary = "LinalgExt yield op";
let description = [{
`iree_linalg_ext.yield` is a special terminator operation for blocks inside
regions in `iree_linalg_ext` ops.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [
OpBuilder<(ins), [{ /* nothing to do */ }]>,
];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
} // OpGroupUtilityOps
//===----------------------------------------------------------------------===//
// Non-structured ops
//===----------------------------------------------------------------------===//
def OpGroupNonStructuredOps : OpDocGroup {
let summary = "Non-structured ops";
let description = "";
}
let opDocGroup = OpGroupNonStructuredOps in {
def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation",
"getIterationDomainTileFromOperandTile",
"getTiledImplementationFromOperandTile"]>]> {
let summary = "Scatter operator";
let description = [{
Based on XLA operation semantics, takes two `inputs` (`update` and
`indices`) and `outputs` value (`original`). The operation updates
the value at the slices specified by `indices` by combining the
current value with the value in `updates` using the computation
specified in `region`. The `region` specifies a binary operation
of signature (T, T) -> T, where `T` is the element-type of
`updates` (and `original`). The first argument is from `updates`,
and the second is from `original`.
The operand `indices` is an N-D tensor/memref type that is composed
of two logical parts:
- The first `N-1` dimensions represent the batch of updates.
- The last dim (at index `N-1`) is the `index_depth`, which should
always be static.
For example, given `indices` of shape `[4, 3, 2]`, the batch dimensions
are `[4, 3]` and the `index_depth` is `2`.
The operand `update` is an M-D tensor/memref type and similarly
consists of two parts:
- The first `N-1` dimensions represent the batch of updates. This
must exactly match the first `N-1` dimensions in `indices`
(from the example above: `indices` must start with `[4, 3]`)
- Dimensions `N..M-1` represent the slice scattered into `original`.
The first part of this tensor represents the dimensions indexed
by `indices`. This must be no larger than `index_depth` but can be
less if unit dimensions are omitted.
The second part represents a contiguous slice to be inserted into
`original`.
The operand `original` is a DPS init representing the destination that
`update` gets scattered to.
The rank of `original` is at least `rank(%updates) - batch_rank`.
The first `index_depth` indices are derived from `indices`, and the last
`rank(%original) - index_depth` dimensions of the update value must match
the last dimensions of `original`, with the preceding dims extending from
the index offsets.
The `dimension_map` attribute describes which index value maps to which
dimension in the destination. Its rank must equal `index_depth`, as it
represents a permutation of the indices before indexing into `original`.
The `unique_indices` attribute indicates whether all of the indices are
unique. If `unique_indices` is `true` and two or more updates scatter to
the same location in `original`, the final value in `original` is not
guaranteed. If `unique_indices` is set to `false`, the first
`batch_rank` iteration loops will be marked as reduction.
The shape definition follows TensorFlow's scatter operations. See more
information in
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
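For illustration, a minimal (hypothetical) example that scatters rows of
`%updates` into `%original`, combining the update with the current value by
addition; shapes and value names are illustrative only:
```mlir
%0 = iree_linalg_ext.scatter
    dimension_map = [0]
    unique_indices(true)
    ins(%updates, %indices : tensor<4x8xf32>, tensor<4x1xi32>)
    outs(%original : tensor<16x8xf32>) {
^bb0(%update: f32, %current: f32):
  %sum = arith.addf %update, %current : f32
  iree_linalg_ext.yield %sum : f32
} -> tensor<16x8xf32>
```
Here the batch shape is `[4]`, the `index_depth` is 1, and each scattered
slice is a row of 8 elements.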
}];
let arguments = (ins
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
Variadic<AnyRankedTensorOrMemRefType>:$outputs,
DenseI64ArrayAttr:$dimension_map,
DefaultValuedAttr<BoolAttr, "true">:$unique_indices
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
attr-dict `dimension_map` `=` $dimension_map
`unique_indices` `(` $unique_indices `)`
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
static constexpr unsigned kUpdatesOpNum = 0;
static constexpr unsigned kIndicesOpNum = 1;
static constexpr unsigned kOriginalOpNum = 2;
int64_t getIndexDepth() {
return getDimensionMap().size();
}
Value getUpdates() {
return getDpsInputOperand(0)->get();
}
ShapedType getUpdateType() {
return cast<ShapedType>(getUpdates().getType());
}
Value getIndices() {
return getDpsInputOperand(1)->get();
}
ShapedType getIndicesType() {
return cast<ShapedType>(getIndices().getType());
}
Value getOriginal() {
return getDpsInitOperand(0)->get();
}
ShapedType getOriginalType() {
return cast<ShapedType>(getOriginal().getType());
}
/// Utility to get the rank of the portion of `indices` that
/// represents the batch dimensions
int64_t getBatchRank() {
return getUpdateType().getRank() - getUpdateSliceRank();
}
/// Utility to get the shape of the portion of `indices` that
/// represents the batch dimensions.
ArrayRef<int64_t> getBatchShape() {
return getIndicesType().getShape().slice(0, getBatchRank());
}
/// Utility to get the rank of the portion of `updates` that
/// is scattered into `original`.
int64_t getUpdateSliceRank() {
return getOriginalType().getRank() - getIndexDepth();
}
/// Utility to get the shape of the portion of `updates` that
/// is scattered into `original`.
ArrayRef<int64_t> getUpdateSliceShape() {
return getUpdateType().getShape().slice(getBatchRank(),
getUpdateSliceRank());
}
/// Utility to get the dimension in `updates` that corresponds
/// to the given dimension in `original`.
int64_t convertOriginalDimToUpdatesDim(uint64_t dim) {
assert(dim >= 0 && dim < getOriginalType().getRank() &&
"expected dimension to be within original rank");
int64_t updateDim =
getUpdateType().getRank() - getOriginalType().getRank() + dim;
assert(updateDim >= getBatchRank() &&
"dim doesn't map to a dim in updates");
return updateDim;
}
/// Utility to get the dimension in `original` that corresponds to
/// the given dimension in `updates`.
int64_t convertUpdatesDimToOriginalDim(uint64_t dim) {
assert(dim >= getBatchRank() &&
"update batch dim doesn't map to original");
assert(dim < getUpdateType().getRank() &&
"expected dimension to be within updates rank");
return getOriginalType().getRank() - getUpdateType().getRank() + dim;
}
bool isScalarUpdate() {
return getUpdateSliceRank() == 0;
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Sort operator";
let description = [{
Based on XLA operation semantics, sorts the given `operands` at the given
`dimension` with the given `comparator`.
See https://www.tensorflow.org/xla/operation_semantics#sort.
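For illustration, a minimal (hypothetical) example sorting a 1-D tensor in
ascending order with a less-than comparator; shapes are illustrative only:
```mlir
%0 = iree_linalg_ext.sort
    dimension(0)
    outs(%input : tensor<128xi32>) {
^bb0(%lhs: i32, %rhs: i32):
  %cmp = arith.cmpi slt, %lhs, %rhs : i32
  iree_linalg_ext.yield %cmp : i1
} -> tensor<128xi32>
```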
}];
let arguments = (ins Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
attr-dict
`dimension` `(` $dimension `)`
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getOperand(int index) {
return getOutputs()[index];
}
ShapedType getOperandType(int index) {
return cast<ShapedType>(getOperand(index).getType());
}
int64_t getOperandRank() {
return getOperandType(0).getRank();
}
ArrayRef<int64_t> getOperandShape() {
return getOperandType(0).getShape();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_FftOp : IREELinalgExt_Op<"fft", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Fft operator";
let description = [{
Applies a 1D FFT to the innermost dim. This is an iterative FFT, not
recursive; thus, bit reversal is assumed to have already been applied to the
input. The op carries an input, `stage`, which indicates the level of the
reduction loop in the algorithm; it represents the computation body. For more
details, see the "Data reordering, bit reversal, and in-place algorithms"
section in
https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm
The size of the innermost dim is expected to be a power of 2.
Coefficient tensors/buffers may optionally be carried as inputs. In that
context, they are the second and third inputs.
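For illustration, a minimal (hypothetical) example applying one FFT stage
without coefficient buffers; `%stage` is assumed to be an `index` value
defined earlier:
```mlir
%res:2 = iree_linalg_ext.fft
    ins(%stage : index)
    outs(%real, %imag : tensor<1024xf32>, tensor<1024xf32>)
    : tensor<1024xf32>, tensor<1024xf32>
```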
}];
let arguments = (ins Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let assemblyFormat = [{
attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
`outs` `(` $outputs `:` type($outputs) `)`
(`:` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getStage() { return getInputs()[0]; }
Value getReal() { return getOutputs()[0]; }
Value getImag() { return getOutputs()[1]; }
bool hasCoeff() { return getNumDpsInputs() > 1; }
void generateScalarImplWithoutCoeffBuf(
OpBuilder & b, Location loc, ArrayRef<Value> operands, Value wholeSize);
void generateScalarImplWithCoeffBuf(OpBuilder & b, Location loc,
ArrayRef<Value> operands);
Value getRealCoeff() {
if (!hasCoeff()) return Value();
return getInputs()[1];
}
Value getImagCoeff() {
if (!hasCoeff()) return Value();
return getInputs()[2];
}
ShapedType getOperandType() {
return cast<ShapedType>(getReal().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
}
ArrayRef<int64_t> getOperandShape() {
return getOperandType().getShape();
}
int64_t getFftLength() {
return getOperandShape().back();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Scan operator";
let description = [{
Computes the inclusive/exclusive scan along a given dimension.
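For illustration, a minimal (hypothetical) inclusive prefix-sum along
dimension 0; the second `outs` operand is the accumulator carrying the
running value:
```mlir
%0:2 = iree_linalg_ext.scan
    dimension(0) inclusive(true)
    ins(%input : tensor<128xi32>)
    outs(%output, %acc : tensor<128xi32>, tensor<i32>) {
^bb0(%arg0: i32, %arg1: i32):
  %sum = arith.addi %arg0, %arg1 : i32
  iree_linalg_ext.yield %sum : i32
} -> tensor<128xi32>, tensor<i32>
```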
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension,
BoolAttr:$inclusive
);
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)>
];
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`dimension` `(` $dimension `)`
`inclusive` `(` $inclusive `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getInput() {
return getDpsInputOperand(0)->get();
}
Value getAccumulator() {
return getDpsInitOperand(1)->get();
}
Value getOutput() {
return getDpsInitOperand(0)->get();
}
ShapedType getOperandType() {
return cast<ShapedType>(getInput().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<LinalgExtInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>
]>{
let summary = "Top-K operator";
let description = [{
A Top-K operation for N-D tensors. Reduces the target dimension from the input
size N down to K elements based on the supplied binary region.
Accepts an N-D tensor input consisting of values and an optioanl N-D tensor
for indices of those values (i32 type). If input indices aren't provided, the
index mapping is inferred based on the k dim. Both input values/indices
tensors and output values/indicies tensors must have the same shape. Top-K is
computed along the target dimension (from dimension()). Returns two output
tensors of values and the indicies of Top-K results. The output dimensions
must match the input save for the dimension that is reduced to K results.
Region accepts lhs=[next N input] and rhs=[exiting K output] and yeilds an
i1. If true, the two values are swapped:
- For Top-K compoarision: >
- For Min-K comparision: <
Note: when the two values are equal, the first occurence is always selected.
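For illustration, a minimal (hypothetical) Top-3 along dimension 1 without
explicit input indices; shapes are illustrative only:
```mlir
%0:2 = iree_linalg_ext.topk
    dimension(1)
    ins(%input_values : tensor<2x10xf32>)
    outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) {
^bb0(%lhs: f32, %rhs: f32):
  %cmp = arith.cmpf ogt, %lhs, %rhs : f32
  iree_linalg_ext.yield %cmp : i1
} -> tensor<2x3xf32>, tensor<2x3xi32>
```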
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension
);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region AnyRegion:$region);
let assemblyFormat = [{
attr-dict
`dimension` `(` $dimension `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getValues() {
return getDpsInputOperand(0)->get();
}
std::optional<Value> getIndices() {
if (getNumDpsInputs() < 2) {
return {};
} else {
return getDpsInputOperand(1)->get();
}
}
Value outputValues() {
return getDpsInitOperand(0)->get();
}
Value outputIndices() {
return getDpsInitOperand(1)->get();
}
ShapedType getInputType() {
return cast<ShapedType>(getValues().getType());
}
int64_t getInputRank() {
return getInputType().getRank();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
//===----------------------------------------------------------------------===//
// Attention
//===----------------------------------------------------------------------===//
def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<LinalgFusionInterface,
["getIndexingMapsForResults", "getIndexingMapsForOperands",
"getStaticLoopRanges"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation",
"generateResultTileValue"]>]> {
let summary = "Attention operator";
let description = [{
Computes the scaled dot product attention function:
attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
Here Q, K, V are given tensors and scale is a scalar value specifying
the scale to use.
If an additional mask argument M is included, the result of the first matmul is modified according to:
Q @ K.T += M
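For illustration, a minimal (hypothetical) example with batch 2, sequence
length 10, and head dimension 4, where the iteration dims are
(d0=batch, d1=M, d2=K1, d3=K2, d4=N) and the region forwards the score
unchanged; shapes are illustrative only:
```mlir
%0 = iree_linalg_ext.attention
    {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                      affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
    ins(%query, %key, %value, %scale
        : tensor<2x10x4xf32>, tensor<2x10x4xf32>, tensor<2x10x4xf32>, f32)
    outs(%output : tensor<2x10x4xf32>) {
^bb0(%score: f32):
  iree_linalg_ext.yield %score : f32
} -> tensor<2x10x4xf32>
```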
}];
let arguments = (ins AnyShaped:$query,
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
Optional<AnyShaped>:$mask,
AnyShaped:$output,
AffineMapArrayAttr:$indexing_maps,
OptionalAttr<DictionaryAttr>:$decomposition_config
);
let regions = (region SizedRegion<1>:$region);
let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)`
`outs` `(` $output `:` type($output) `)`
$region
(`->` type($results)^)?
}];
let builders = [
OpBuilder<(ins "TypeRange":$results,
"Value":$query,
"Value":$key,
"Value":$value,
"Value":$scale,
"Value":$output,
"ArrayAttr":$indexing_maps,
CArg<"std::optional<Value>", "std::nullopt">:$mask)>
];
let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable();
SmallVector<AffineMap> getIndexingMapsArray();
AffineMap getQueryMap() {
return cast<AffineMap>(getIndexingMapsArray()[0]);
}
AffineMap getKeyMap() {
return cast<AffineMap>(getIndexingMapsArray()[1]);
}
AffineMap getValueMap() {
return cast<AffineMap>(getIndexingMapsArray()[2]);
}
AffineMap getScaleMap() {
return cast<AffineMap>(getIndexingMapsArray()[3]);
}
std::optional<AffineMap> getMaskMap() {
if (getMask()) {
return cast<AffineMap>(getIndexingMapsArray()[4]);
}
return std::nullopt;
}
AffineMap getOutputMap() {
return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]);
}
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
/* Decomposition control attributes */
// Attributes to set on QK and PV matmul after decomposition.
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
}];
}
//===----------------------------------------------------------------------===//
// OnlineAttention
//===----------------------------------------------------------------------===//
def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">,
DestinationStyleOpInterface, LinalgExtInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>,
DeclareOpInterfaceMethods<PartialReductionOpInterface,
["generateInitialTensorForPartialReduction",
"tileToPartialReduction",
"mergeReductions",
"getPartialResultTilePosition"]>]> {
let summary = "Online Attention operator";
let description = [{
Traditional scaled dot product attention computes:
attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V
Online Attention on the other hand, uses an online normalizer instead of
softmax:
online_attention(Q, K, V, scale, running_max, running_sum)
= online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V
If an additional mask argument M is included, the result of the first matmul is modified according to:
Q @ K.T += M
The advantage of this online_normalizer is that it can be tiled along
its reduction dimension, making the online_attention operator:
- Tilable along softmax reduction dimension
- Associative along softmax reduction dimension
- Commutative along softmax associative dimension
Note: The results of online_attention need to be combined after computing
it over the entire softmax reduction dimension by:
x, _, sum : results
x = (1 / sum) * x
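For illustration, a minimal (hypothetical) example producing the partial
result together with the running max and sum, with iteration dims
(d0=batch, d1=M, d2=K1, d3=K2, d4=N); shapes are illustrative only:
```mlir
%out:3 = iree_linalg_ext.online_attention
    {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                      affine_map<(d0, d1, d2, d3, d4) -> ()>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>]}
    ins(%query, %key, %value, %scale
        : tensor<2x10x4xf32>, tensor<2x10x4xf32>, tensor<2x10x4xf32>, f32)
    outs(%output, %max, %sum
        : tensor<2x10x4xf32>, tensor<2x10xf32>, tensor<2x10xf32>) {
^bb0(%score: f32):
  iree_linalg_ext.yield %score : f32
} -> tensor<2x10x4xf32>, tensor<2x10xf32>, tensor<2x10xf32>
```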
}];
let arguments = (ins AnyShaped:$query,
AnyShaped:$key,
AnyShaped:$value,
AnyFloat:$scale,
Optional<AnyShaped>:$mask,
AnyShaped:$output,
AnyShaped:$max,
AnyShaped:$sum,
AffineMapArrayAttr:$indexing_maps,
OptionalAttr<DictionaryAttr>:$decomposition_config
);
let regions = (region SizedRegion<1>:$region);
let results = (outs Variadic<AnyRankedTensor>:$results);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
attr-dict
`ins` `(` $query `,` $key `,` $value `,` $scale (`,` $mask^)? `:` type($query) `,` type($key) `,` type($value) `,` type($scale) (`,` type($mask)^ )?`)`
`outs` `(` $output `,` $max `,` $sum `:` type($output) `,` type($max) `,` type($sum) `)`
$region
(`->` type($results)^)?
}];
let builders = [
OpBuilder<(ins "TypeRange":$results,
"Value":$query,
"Value":$key,
"Value":$value,
"Value":$scale,
"Value":$output,
"Value":$max,
"Value":$sum,
"ArrayAttr":$indexing_maps,
CArg<"std::optional<Value>", "std::nullopt">:$mask)>
];
let extraClassDeclaration = [{
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable();
SmallVector<AffineMap> getIndexingMapsArray();
AffineMap getQueryMap() {
return getIndexingMapsArray()[0];
}
AffineMap getKeyMap() {
return getIndexingMapsArray()[1];
}
AffineMap getValueMap() {
return getIndexingMapsArray()[2];
}
AffineMap getScaleMap() {
return getIndexingMapsArray()[3];
}
std::optional<AffineMap> getMaskMap() {
if (getMask()) {
return cast<AffineMap>(getIndexingMapsArray()[4]);
}
return std::nullopt;
}
AffineMap getOutputMap() {
return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs()]);
}
AffineMap getMaxMap() {
return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 1]);
}
AffineMap getSumMap() {
return cast<AffineMap>(getIndexingMapsArray()[getNumDpsInputs() + 2]);
}
int64_t getIterationDomainRank() {
return getQueryMap().getNumDims();
}
/* Decomposition control attributes */
// Attributes to set on QK and PV matmul after decomposition.
static StringRef getQKAttrStr() { return "qk_attrs"; }
static StringRef getPVAttrStr() { return "pv_attrs"; }
}];
}
//===----------------------------------------------------------------------===//
// Im2col
//===----------------------------------------------------------------------===//
def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation",
"generateResultTileValue"]>]> {
let summary = "Im2col operation for convolutions";
let description = [{
Im2col op for convolutions. The operation performs a transformation on the
input to convert it from a convolution input to an equivalent gemm input.
The op is defined by its input, output, some conv metadata, and some
indexing metadata. The `strides`, `dilations`, and `kernel_size` are taken
from the convolution from which this op is generated, and they define how
the input operand is indexed when the operation is decomposed. The shape of
the output should be `tensor<BxMxK>`, and the `m_pos`, `k_pos`, and
`batch_pos` indicate which input dimensions map to which output dimensions.
The `k_offset` is an offset within the output K dimension from which the
iteration space of the operation begins. This is used for tiling, since the
tiled implementation must leave the output K dimension untiled. Similarly,
`m_offset` is the offset within the output M dimension from which the
iteration space of the operation begins.
The iteration space is the full output shape of the im2col op, so if the
im2col op were tiled to loops with a scalar inner tile, it would look like
the following:
```
%im2col = iree_linalg_ext.im2col
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
m_offset = [0] * [1] k_offset = [0] * [1]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%in : tensor<2x34x34x640xf32>)
outs(%out : tensor<2x1024x5760xf32>) -> tensor<2x1024x5760xf32>
```
becomes:
```
scf.for %arg0 = %c0 to %c2 step %c1
scf.for %arg1 = %c0 to %c1024 step %c1
scf.for %arg2 = %c0 to %c5760 step %c1
%im2col = iree_linalg_ext.im2col
strides = [1, 1] dilations = [1, 1] kernel_size = [3, 3]
m_offset = [%arg1] * [1] k_offset = [%arg2] * [1]
batch_pos = [0] m_pos = [1, 2] k_pos = [3]
ins(%in_tile : tensor<1x34x34x640xf32>)
outs(%out_tile : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
```
Then, when the tiled op is decomposed, it becomes a loop over the iteration
space of the im2col op, with an extract_slice from the `%in_tile` followed
by an insert_slice to the `%out_tile`. The indices for the extract_slice are
computed using the `m_offset` and `k_offset` as:
(b, m, k) -> (b, M / 32 + K / (640*3), M % 32 + K % (640*3) / 640, K % 640)
where `(b, m, k)` are the indices of the tiled op's iteration space, and
`M = m + m_offset` and `K = k + k_offset`.
The `m_strides` and `k_strides` fields are used as a basis for linearizing
the `m_offset` and `k_offset`. This is used when there are multiple M or K
output dimensions, and therefore multiple `m_offset` or `k_offset` values.
The strides fields are assembled in the IR as if they are multiplied as an
inner product with `m_offset` and `k_offset`, indicating that the total
linear offset along the dimension is equal to this inner product. These
strides fields also determine the strides of the output dimensions along
M and K. For example, an op with `m_strides = [32, 1]`, `k_strides = [4, 1]`,
and output type `tensor<BxM0xM1xK0xK1>` (expanded from `tensor<BxMxK>`),
would have strides along the M dim of 32 for `M0`, meaning as `M0` increases
by 1, the index into the flat `M` increases by 32. Along the K dim, strides
would be 4 for `K0`, and 1 for `K1`, meaning as `K0` increases by 1, the
index into the flat `K` increases by 4. The strides in M from `m_strides`
are orthogonal to the strides in `K` from `k_strides`.
}];
let arguments = (ins AnyShaped:$input, AnyShaped:$output,
DenseI64ArrayAttr:$strides,
DenseI64ArrayAttr:$dilations,
Variadic<Index>:$kernel_size,
DenseI64ArrayAttr:$static_kernel_size,
Variadic<Index>:$m_offset,
DenseI64ArrayAttr:$static_m_offset,
Variadic<Index>:$m_strides,
DenseI64ArrayAttr:$static_m_strides,
Variadic<Index>:$k_offset,
DenseI64ArrayAttr:$static_k_offset,
Variadic<Index>:$k_strides,
DenseI64ArrayAttr:$static_k_strides,
DenseI64ArrayAttr:$batch_pos,
DenseI64ArrayAttr:$m_pos,
DenseI64ArrayAttr:$k_pos);
let results = (outs Variadic<AnyShaped>:$results);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`strides` `=` $strides
`dilations` `=` $dilations
`kernel_size` `=`
custom<DynamicIndexList>($kernel_size, $static_kernel_size)
`m_offset` `=`
custom<DynamicIndexList>($m_offset, $static_m_offset)
`*` custom<DynamicIndexList>($m_strides, $static_m_strides)
`k_offset` `=`
custom<DynamicIndexList>($k_offset, $static_k_offset)
`*` custom<DynamicIndexList>($k_strides, $static_k_strides)
`batch_pos` `=` $batch_pos
`m_pos` `=` $m_pos
`k_pos` `=` $k_pos
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
(`->` type($results)^)?
}];
let builders = [
OpBuilder<(ins "Value":$input, "Value":$output,
"ArrayRef<int64_t>":$strides,
"ArrayRef<int64_t>":$dilations,
"ArrayRef<OpFoldResult>":$kernel_size,
"ArrayRef<OpFoldResult>":$m_offset,
"ArrayRef<OpFoldResult>":$m_strides,
"ArrayRef<OpFoldResult>":$k_offset,
"ArrayRef<OpFoldResult>":$k_strides,
"ArrayRef<int64_t>":$batch_dimensions,
"ArrayRef<int64_t>":$m_dimensions,
"ArrayRef<int64_t>":$k_dimensions)>
];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getInputRank() {
return getInputType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
// Helpers to get output dimensions corresponding to batch, m, and k.
SmallVector<int64_t> getBatchOutputDims();
SmallVector<int64_t> getMOutputDims();
SmallVector<int64_t> getKOutputDims();
// Return op metadata.
SmallVector<OpFoldResult> getMixedKernelSize();
SmallVector<OpFoldResult> getMixedMOffset();
SmallVector<OpFoldResult> getMixedKOffset();
SmallVector<OpFoldResult> getMixedMStrides();
SmallVector<OpFoldResult> getMixedKStrides();
// Set op metadata.
void setMixedKOffset(SmallVector<OpFoldResult> kOffset);
void setMixedMOffset(SmallVector<OpFoldResult> mOffset);
void setMixedKStrides(SmallVector<OpFoldResult> kStrides);
void setMixedMStrides(SmallVector<OpFoldResult> mStrides);
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputMutable();
}
}];
}
} // OpGroupNonStructuredOps
//===----------------------------------------------------------------------===//
// Data tiling ops
//===----------------------------------------------------------------------===//
def OpGroupDataTilingOps : OpDocGroup {
let summary = "Data tiling ops";
let description = [{
Operations for working with data layouts, padding, encodings, and other
properties useful for tiling computations across iteration space dimensions.
}];
}
let opDocGroup = OpGroupDataTilingOps in {
def IREELinalgExt_PackOp : IREELinalgExt_Op<"pack", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<LinalgExtInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"generateScalarImplementation"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]>{
let summary = "pack operation";
let description = [{
The pack operation converts an `input` into a tiled and packed layout. The
dimensions to be tiled are obtained from `inner_dims_pos` and the size of the
tile is obtained from `inner_tiles`. The dimensions listed in `inner_dims_pos`
do not need to be contiguous, in which case the tile will get transposed.
Only full tiles are handled if `padding_value` is not set; it is UB if the
tile does not perfectly divide the dimension. If `padding_value` is set, it
will pad along high dimensions, i.e., it pads at the bottom and on the right
if the input has rank 2, and the result type shape will be dynamic in any
dimension if and only if the input shape is. As an optional input, the
operation takes `outer_dims_perm`, which allows permuting the tiled loops.
Example KC_to_KCck:
```mlir
iree_linalg_ext.pack %arg0 inner_dims_pos = [1, 0]
inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<16x8x32x8xf32>)
```
Example NC_to_NCnc:
```mlir
iree_linalg_ext.pack %arg0 inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %arg1 : (memref<128x256xf32> memref<16x8x8x32xf32>)
```
Example KC_to_CKkc:
```mlir
iree_linalg_ext.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [32, 8] into %arg1 : (memref<128x256xf32> memref<32x4x32x8xf32>)
```
In all cases, the dimension at position 0 in the input memref (128) is tiled
with a factor of 8, while the dimension at position 1 (256) is tiled with a
factor of 32. In the KC_to_KCck example, the point loops are interchanged,
while in the KC_to_CKkc example the tiled loops are interchanged.
Example NC_to_NCnc with padding:
```mlir
iree_linalg_ext.pack %arg padding_value(%pad : f32) inner_dims_pos = [0, 1]
inner_tiles = [8, 2] into %arg1 : (memref<13x15xf32> memref<2x8x8x2xf32>)
```
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles,
Optional<AnyType>:$padding_value);
let results = (outs Variadic<AnyRankedTensor>:$results);
let assemblyFormat = [{
attr-dict
$inputs
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $outputs `:` `(` type($inputs) type($outputs) `)`
(`->` type($results)^)?
}];
let builders = [
OpBuilder<(ins "Value":$source, "Value":$output,
"ArrayRef<int64_t>":$innerDimsPos,
"ArrayRef<OpFoldResult>":$innerTiles,
CArg<"std::optional<Value>", "std::nullopt">:$paddingValue,
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
// Return the output operand.
Value getOutput() {
return getDpsInitOperand(0)->get();
}
// Return the input operand.
Value getInput() {
return getDpsInputOperand(0)->get();
}
// Return the output rank.
int64_t getOutputRank() {
return getOutputType().getRank();
}
// Return the output type.
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
// Return the input type.
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
// Return the output shape.
ArrayRef<int64_t> getOutputShape() {
return getOutputType().getShape();
}
// Return the input shape.
ArrayRef<int64_t> getInputShape() {
return getInputType().getShape();
}
// Return the element type.
Type getElementType() {
return getInputType().getElementType();
}
// Return the rank of the input operand.
int64_t getInputRank() {
return getInputType().getRank();
}
// Return the tile sizes.
SmallVector<OpFoldResult> getMixedTiles();
SmallVector<int64_t> getStaticTiles();
// Return a mapping from positions `dims_pos` to their tile factors.
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
// Method to get the shape of the result as `SmallVector<OpFoldResult>`.
// This is a static method to allow getting the shape of the destination
// expected while creating a `pack` op.
static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
// Method to return the shape of the result as `SmallVector<OpFoldResult>`.
SmallVector<OpFoldResult> getResultShape(OpBuilder &builder);
// Method to get the `ShapedType` of the result. This is a static method
// to allow getting the type of the destination while creating the `pack`
// op.
static ShapedType getPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_UnPackOp : IREELinalgExt_Op<"unpack", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<LinalgExtInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"generateScalarImplementation"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]>{
let summary = "unpack operation";
let description = [{
The unpack operation converts a tiled and packed input to an unpacked
output. See `pack` for more details on `inner_tiles` and `inner_dims_pos`; it is UB
if the tile does not perfectly divide the dimension. Optionally, the operation
also supports permuting the tiled loops.
Example KCck_to_KC:
```mlir
iree_linalg_ext.unpack %arg0 inner_dims_pos = [1, 0]
inner_tiles = [32, 8] into %arg1 : (memref<16x8x32x8xf32> memref<128x256xf32>)
```
Example NCnc_to_NC:
```mlir
iree_linalg_ext.unpack %arg0 inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %arg1 : (memref<16x8x8x32xf32> memref<128x256xf32>)
```
Example CKkc_to_KC:
```mlir
iree_linalg_ext.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [32, 8] into %arg0 : (memref<32x4x32x8xf32> memref<128x256xf32>)
```
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs Variadic<AnyRankedTensor>:$results);
let assemblyFormat = [{
attr-dict
$inputs
(`outer_dims_perm` `=` $outer_dims_perm^)?
`inner_dims_pos` `=` $inner_dims_pos
`inner_tiles` `=`
custom<DynamicIndexList>($inner_tiles, $static_inner_tiles)
`into` $outputs `:` `(` type($inputs) type($outputs) `)`
(`->` type($results)^)?
}];
let builders = [
OpBuilder<(ins "Value":$source, "Value":$output,
"ArrayRef<int64_t>":$innerDimsPos,
"ArrayRef<OpFoldResult>":$innerTiles,
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
// Return the output operand.
Value getOutput() {
return getDpsInitOperand(0)->get();
}
// Return the input operand.
Value getInput() {
return getDpsInputOperand(0)->get();
}
// Return the output rank.
int64_t getOutputRank() {
return getOutputType().getRank();
}
// Return the output type.
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
// Return the input type.
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
// Return the output shape.
ArrayRef<int64_t> getOutputShape() {
return getOutputType().getShape();
}
// Return the input shape.
ArrayRef<int64_t> getInputShape() {
return getInputType().getShape();
}
// Return the rank of the input operand.
int64_t getInputRank() {
return getInputType().getRank();
}
// Return the tile sizes.
SmallVector<OpFoldResult> getMixedTiles();
SmallVector<int64_t> getStaticTiles();
// Return a mapping from positions `dims_pos` to their tile factors.
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
} // OpGroupDataTilingOps
//===----------------------------------------------------------------------===//
// Winograd ops
//===----------------------------------------------------------------------===//
def OpGroupWinogradOps : OpDocGroup {
let summary = "Winograd ops";
let description = "";
}
let opDocGroup = OpGroupWinogradOps in {
def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_transform",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd Input Transform operator";
let description = [{
This operator is part of the first step in converting a convolution to
its Winograd equivalent. Given a tile of an input image (I),
this operator computes matmul(transpose(B), matmul(I, B)).
The input tile is assumed to be square with each side of size m + r - 1,
where the convolutional kernel is m x m and the output tile size is r x r.
B is a constant 2-d square matrix of the same shape as the input tile I.
The input to the operator is an image of shape (N, H, W, C) or (N, C, H, W)
and the output is a tensor of shape (m + r - 1, m + r - 1, N, H', W', C)
where H' = ceil((H - m + 1)/r) and W' = ceil((W - m + 1)/r). The result
of this operator is first collapsed and then fed to a batch matmul op.
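For illustration, a minimal (hypothetical) NHWC example with kernel size 3
and output tile size 6 (so the input tile size is 8); shapes are
illustrative only:
```mlir
%1 = iree_linalg_ext.winograd.input_transform
    output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
    ins(%img : tensor<1x10x10x64xf32>)
    outs(%init : tensor<8x8x1x2x2x64xf32>) -> tensor<8x8x1x2x2x64xf32>
```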
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$output_tile_size,
I64Attr:$kernel_size,
DenseI64ArrayAttr:$image_dimensions
);
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
];
let results = (outs Variadic<AnyRankedTensor>:$result);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`image_dimensions` `(` $image_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getInput() {
return getDpsInputOperand(0)->get();
}
Value getOutput() {
return getDpsInitOperand(0)->get();
}
Value getOriginalOperand() {
return getInput();
}
Value getTransformedOperand() {
return getOutput();
}
ShapedType getOriginalOperandType() {
return getInputType();
}
ShapedType getTransformedOperandType() {
return getOutputType();
}
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getInputRank() {
return getInputType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
int64_t getInputTileSize() {
return getOutputTileSize() + getKernelSize() - 1;
}
ArrayRef<int64_t> getHwDimensions() {
return getImageDimensions();
}
std::array<int64_t, 2> nhwcImageDimensions() {
return {1, 2};
}
std::array<int64_t, 2> nchwImageDimensions() {
return {2, 3};
}
bool isNhwc() {
std::array<int64_t, 2> nhwcImageDims = nhwcImageDimensions();
return getImageDimensions() == ArrayRef<int64_t>(nhwcImageDims);
}
bool isNchw() {
std::array<int64_t, 2> nchwImageDims = nchwImageDimensions();
return getImageDimensions() == ArrayRef<int64_t>(nchwImageDims);
}
int getChannelDim() {
return isNhwc() ? 3 : 1;
}
int64_t getIterationDomainRank() {
return getOutputRank() - getImageDimensions().size();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_WinogradFilterTransformOp : IREELinalgExt_Op<"winograd.filter_transform",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd Filter Transform operator";
let description = [{
This operator is part of the first step in converting a convolution to
its Winograd equivalent. Given a tile of a convolution filter (F),
this operator computes matmul(G, matmul(F, transpose(G))).
The filter tile is assumed to be the full m x m convolutional kernel,
and the result of the transformation on this tile is a square with each
side of size m + r - 1, where the output tile size is r x r. G is a constant
2-d matrix of shape (m + r - 1) x m. The input to the operator is a filter
of shape (H, W, C, F) or (F, C, H, W) and the output is a tensor of shape
(m + r - 1, m + r - 1, C, F). The result of this operator is first collapsed
and then fed to a batch matmul op.
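For illustration, a minimal (hypothetical) HWCF example with kernel size 3
and output tile size 6; shapes are illustrative only:
```mlir
%1 = iree_linalg_ext.winograd.filter_transform
    output_tile_size(6) kernel_size(3) kernel_dimensions([0, 1])
    ins(%filter : tensor<3x3x64x128xf32>)
    outs(%init : tensor<8x8x64x128xf32>) -> tensor<8x8x64x128xf32>
```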
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$output_tile_size,
I64Attr:$kernel_size,
DenseI64ArrayAttr:$kernel_dimensions
);
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
CArg<"ArrayRef<int64_t>", "{0, 1}">:$kernel_dimensions)>
];
let results = (outs Variadic<AnyRankedTensor>:$result);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`kernel_dimensions` `(` $kernel_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getInput() {
return getDpsInputOperand(0)->get();
}
Value getOutput() {
return getDpsInitOperand(0)->get();
}
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
Value getOriginalOperand() {
return getInput();
}
Value getTransformedOperand() {
return getOutput();
}
ShapedType getOriginalOperandType() {
return getInputType();
}
ShapedType getTransformedOperandType() {
return getOutputType();
}
int64_t getInputRank() {
return getInputType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
int64_t getInputTileSize() {
return getOutputTileSize() + getKernelSize() - 1;
}
ArrayRef<int64_t> getHwDimensions() {
return getKernelDimensions();
}
std::array<int64_t, 2> hwcfKernelDimensions() {
return {0, 1};
}
std::array<int64_t, 2> fchwKernelDimensions() {
return {2, 3};
}
bool isHwcf() {
std::array<int64_t, 2> hwcfKernelDims = hwcfKernelDimensions();
ArrayRef<int64_t> kernelDims = getKernelDimensions();
return kernelDims == ArrayRef<int64_t>(hwcfKernelDims);
}
bool isFchw() {
std::array<int64_t, 2> fchwKernelDims = fchwKernelDimensions();
ArrayRef<int64_t> kernelDims = getKernelDimensions();
return kernelDims == ArrayRef<int64_t>(fchwKernelDims);
}
int getChannelDim() {
return isHwcf() ? 2 : 1;
}
int getFilterDim() {
return isHwcf() ? 3 : 0;
}
int64_t getIterationDomainRank() {
return getInputRank() - getKernelDimensions().size();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_transform",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd Output Transform operator";
let description = [{
This operator is the last transform in converting a convolution to
its Winograd equivalent. After convolution in the Winograd domain
(which turns into an elementwise product for a single channel and
batch matrix multiplication for many channels), this operator converts
the output back into the original domain. Given a tile of the
output (O) in the Winograd domain, this operator computes
matmul(transpose(A), matmul(O, A)). The output tile is square with
each side of size m + r - 1, where the convolutional kernel is m x m
and the output tile size is r x r. A is a constant 2-d matrix of
shape (m + r - 1) x r. The input to the operator is a tensor of
shape (m + r - 1, m + r - 1, N, H', W', C) and the output is a
tensor of shape (N, H, W, C) or (N, C, H, W) where H = r H' and W = r W'.
This operator is followed by a tensor.extract_slice which extracts
only the non-padded part of the output.
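For illustration, a minimal (hypothetical) NHWC example with kernel size 3
and output tile size 6, so each 8x8 Winograd tile yields a 6x6 output tile;
shapes are illustrative only:
```mlir
%1 = iree_linalg_ext.winograd.output_transform
    output_tile_size(6) kernel_size(3) image_dimensions([1, 2])
    ins(%wino : tensor<8x8x1x2x2x64xf32>)
    outs(%init : tensor<1x12x12x64xf32>) -> tensor<1x12x12x64xf32>
```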
}];
let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$output_tile_size,
I64Attr:$kernel_size,
DenseI64ArrayAttr:$image_dimensions
);
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"int64_t", "8">:$output_tile_size, CArg<"int64_t", "3">:$kernel_size,
CArg<"ArrayRef<int64_t>", "{1, 2}">:$image_dimensions)>
];
let results = (outs Variadic<AnyRankedTensor>:$result);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`output_tile_size` `(` $output_tile_size `)`
`kernel_size` `(` $kernel_size `)`
`image_dimensions` `(` $image_dimensions `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value getInput() {
return getDpsInputOperand(0)->get();
}
Value getOutput() {
return getDpsInitOperand(0)->get();
}
ShapedType getInputType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
Value getOriginalOperand() {
return getOutput();
}
Value getTransformedOperand() {
return getInput();
}
ShapedType getOriginalOperandType() {
return getOutputType();
}
ShapedType getTransformedOperandType() {
return getInputType();
}
ArrayRef<int64_t> getHwDimensions() {
return getImageDimensions();
}
std::array<int64_t, 2> nhwcImageDimensions() {
return {1, 2};
}
std::array<int64_t, 2> nchwImageDimensions() {
return {2, 3};
}
bool isNhwc() {
std::array<int64_t, 2> nhwcImageDims = nhwcImageDimensions();
return getImageDimensions() == ArrayRef<int64_t>(nhwcImageDims);
}
bool isNchw() {
std::array<int64_t, 2> nchwImageDims = nchwImageDimensions();
return getImageDimensions() == ArrayRef<int64_t>(nchwImageDims);
}
int getChannelDim() {
return isNhwc() ? 3 : 1;
}
int64_t getInputRank() {
return getInputType().getRank();
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
int64_t getIterationDomainRank() {
return getInputRank() - getImageDimensions().size();
}
int64_t getInputTileSize() {
return getOutputTileSize() + getKernelSize() - 1;
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
} // OpGroupWinogradOps
//===---------------------------------------------------------------------===//
// Custom tilable op
//===---------------------------------------------------------------------===//
def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
DeclareOpInterfaceMethods<AggregatedOpInterface, [
"decomposeOperation"]>,
DeclareOpInterfaceMethods<LinalgFusionInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>
]> {
let summary = "Custom operation for compiling with IREE";
let description = [{
This operation is meant to allow computation sequences to be fused at the
tile level prescriptively. This accounts for cases where such fusion cannot
be, or is not yet, discovered automatically.
The operation implements all the interfaces needed to be able to
1. Compile e2e using IREE
2. Still be able to fuse with other operations that the compiler can
figure out automatically.
Similar to how `LinalgOp`s represent a perfectly nested loop computation
with
- `indexing_maps` representing how the `ins`/`outs` are accessed
- `region` representing the scalar computation performed
- `iterator_types` representing the dependence along each iteration space
dimension
this operation represents a tiled computation with a perfectly nested
inter-tile loop nest.
- `indexing_maps` represent what slices of the `ins`/`outs` are
needed for each iteration of the tiled computation.
- `region` represents the tiled computation performed using these slices
- `iterator_types` represents the dependence between tiles along each
iteration space dimension.
Some modifications required to handle the tile-level semantics are
- Some dimensions of operands might not be accessed by dimensions of the
inter-tile iteration space. This means that along these dimensions the
slice size matches the dimension size. This access pattern of operands
is captured in the respective indexing map using a `symbol` to represent
that the entire dimension needs to be sliced.
- The basic block arguments of the region represent the slice of the
operand. These are either scalar types (if the corresponding operand is a
scalar), or a `tensor` type with dynamic shapes (if the corresponding
operand is a `tensor` type).
For example, one could represent a prescriptively fused matmul computation
as follows:
```
%0:2 = iree_linalg_ext.custom_op {
indexing_maps = [affine_map<(d0, d1)[s0, s1] -> (d0, s0)>,
affine_map<(d0, d1)[s0, s1] -> (s0, s1)>,
affine_map<(d0, d1)[s0, s1] -> (s1, d1)>,
affine_map<(d0, d1)[s0, s1] -> (d0, s1)>,
affine_map<(d0, d1)[s0, s1] -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%lhs1, %rhs1, %rhs2
: tensor<1000000x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
outs(%outs1, %outs2 : tensor<1000000x?xf32>, tensor<1000000x?xf32>) {
^bb0(%t0 : tensor<?x?xf32>, %t1 : tensor<?x?xf32>, %t2 : tensor<?x?xf32>,
%t3 : tensor<?x?xf32>, %t4 : tensor<?x?xf32>) :
%0 = linalg.matmul ins(%t0, %t1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%t3 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.matmul ins(%0, %t2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%t4 : tensor<?x?xf32>) -> tensor<?x?xf32>
iree_linalg_ext.yield %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
} -> tensor<1000000x?xf32>, tensor<1000000x?xf32>
```
}];
let arguments = (ins
Variadic<AnyRankedTensorOrScalarType>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
AffineMapArrayAttr:$indexing_maps,
IREELinalgExt_IteratorTypeArrayAttr:$iterator_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$region);
let hasVerifier = 1;
let assemblyFormat = [{
`{` `indexing_maps` `=` $indexing_maps `,`
`iterator_types` `=` $iterator_types `}`
attr-dict-with-keyword
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
(`outs` `(` $outputs^ `:` type($outputs) `)`)?
$region (`->` type($results)^)?
}];
let extraClassDeclaration =[{
// Helper accessor methods.
unsigned getNumLoops();
int64_t getRank(Value);
// Return the number of non-loop dimensions of the op.
unsigned getNumNonLoopDimensions();
// Return the ranges for the loop dimensions and symbol
// dimensions of the operation.
SmallVector<Range> getIterationDomainForDimensions(OpBuilder &builder,
ArrayRef<unsigned> dims, ArrayRef<unsigned> symbols);
// DestinationStyleOpInterface methods
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}
#endif // IREE_DIALECT_LINALGEXT_OPS