// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef IREE_DIALECT_VMLA_OPS
#define IREE_DIALECT_VMLA_OPS

include "iree/compiler/Dialect/VMLA/IR/VMLABase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/SideEffects.td"

// Base class for VMLA ops that carry the NoSideEffect trait. Traits supplied
// by the derived op are kept, and NoSideEffect is appended after them.
class VMLA_PureOp<string mnemonic, list<OpTrait> traits = []>
    : VMLA_Op<mnemonic, !listconcat(traits, [NoSideEffect])>;

//===----------------------------------------------------------------------===//
// VMLA Ops: pseudo ops
//===----------------------------------------------------------------------===//

// `constant` pseudo-op: one ElementsAttr argument (`value`), one buffer result.
def VMLA_ConstantOp : VMLA_PureOp<"constant"> {
  let summary = "constant buffer declaration";
  let description = [{
    A pseudo-op used to represent a buffer with constant contents. This is later
    expanded into VM ops and the vmla.buffer.const op.
  }];

  let arguments = (ins ElementsAttr:$value);
  let results = (outs VMLA_Buffer:$result);

  // Convenience builder: the caller passes only the attribute and the result
  // type is filled in as IREE::VMLA::BufferType.
  let builders = [
    OpBuilder<
      "OpBuilder &builder, OperationState &result, ElementsAttr value", [{
        build(builder, result,
              IREE::VMLA::BufferType::get(builder.getContext()), value);
      }]>,
  ];
}

//===----------------------------------------------------------------------===//
// VMLA Ops: buffer manipulation
//===----------------------------------------------------------------------===//

// `buffer.const`: takes a host buffer operand and produces a VMLA buffer.
// Both the operand type and the result type are written in the assembly.
def VMLA_BufferConstOp : VMLA_PureOp<"buffer.const"> {
  let arguments = (ins VMLA_HostBuffer:$value);
  let results = (outs VMLA_Buffer:$result);

  let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)";
}

// `buffer.alloc`: takes a `byte_length` operand and produces a buffer.
// Derives directly from VMLA_Op, so it does not receive the NoSideEffect trait
// that VMLA_PureOp appends.
def VMLA_BufferAllocOp : VMLA_Op<"buffer.alloc"> {
  let arguments = (ins VMLA_DeviceSize:$byte_length);
  let results = (outs VMLA_Buffer:$result);

  let assemblyFormat = [{
    `byte_length` `=` $byte_length attr-dict `:` type($result)
  }];
}

// `buffer.clone`: takes a `src` buffer operand and produces a buffer result.
// Only the result type is written in the assembly. Derives directly from
// VMLA_Op rather than VMLA_PureOp.
def VMLA_BufferCloneOp : VMLA_Op<"buffer.clone"> {
  let arguments = (ins VMLA_Buffer:$src);
  let results = (outs VMLA_Buffer:$result);

  let assemblyFormat = "$src attr-dict `:` type($result)";
}

// `buffer.byte_length`: takes a buffer operand and produces a device-size
// result. Only the result type is written in the assembly.
def VMLA_BufferByteLengthOp : VMLA_PureOp<"buffer.byte_length"> {
  let arguments = (ins VMLA_Buffer:$value);
  let results = (outs VMLA_DeviceSize:$result);

  let assemblyFormat = "$value attr-dict `:` type($result)";
}

// `buffer.view`: takes a `src` buffer plus `byte_offset` and `byte_length`
// device-size operands and produces a buffer result.
// Assembly shape: src[byte_offset], byte_length = <length> : <result type>
def VMLA_BufferViewOp : VMLA_PureOp<"buffer.view"> {
  let arguments = (ins
    VMLA_Buffer:$src,
    VMLA_DeviceSize:$byte_offset, VMLA_DeviceSize:$byte_length
  );
  let results = (outs VMLA_Buffer:$result);

  let assemblyFormat = [{
    $src`[`$byte_offset`]``,` `byte_length` `=` $byte_length
    attr-dict `:` type($result)
  }];
}

// `buffer.copy`: operands are a `src` buffer with its byte offset, a `dst`
// buffer with its byte offset, and a `byte_length`. No results are declared;
// `dst` is printed after the `out` keyword.
def VMLA_BufferCopyOp : VMLA_Op<"buffer.copy"> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_DeviceSize:$src_byte_offset,
    VMLA_Buffer:$dst, VMLA_DeviceSize:$dst_byte_offset,
    VMLA_DeviceSize:$byte_length
  );

  let assemblyFormat = [{
    $src`[`$src_byte_offset`]``,`
    `out` $dst`[`$dst_byte_offset`]``,` `byte_length` `=` $byte_length
    attr-dict
  }];
}

// `buffer.fill`: operands are a `value` buffer and a `dst` buffer. No results
// are declared; `dst` is printed after the `out` keyword.
def VMLA_BufferFillOp : VMLA_Op<"buffer.fill"> {
  let arguments = (ins VMLA_Buffer:$value, VMLA_Buffer:$dst);

  let assemblyFormat = "$value`,` `out` $dst attr-dict";
}

// `buffer.load.i32`: takes a `src` buffer and a `byte_offset` operand and
// produces an i32 result. Assembly shape: src[byte_offset] : <result type>
def VMLA_BufferLoadI32Op : VMLA_PureOp<"buffer.load.i32"> {
  let arguments = (ins VMLA_Buffer:$src, VMLA_DeviceSize:$byte_offset);
  let results = (outs I32:$result);

  let assemblyFormat = "$src`[`$byte_offset`]` attr-dict `:` type($result)";
}

//===----------------------------------------------------------------------===//
// VMLA Ops: comparison
//===----------------------------------------------------------------------===//

// `cmp`: a VMLA_CmpPredicateAttr (`predicate`), three buffer operands (`lhs`,
// `rhs`, `dst`) and an `element_type` type attribute. No results are declared.
def VMLA_CmpOp : VMLA_ElementTypeOp<"cmp"> {
  let arguments = (ins
    VMLA_CmpPredicateAttr:$predicate,
    VMLA_Buffer:$lhs, VMLA_Buffer:$rhs,
    VMLA_Buffer:$dst,
    VMLA_AnyTypeAttr:$element_type
  );
}

// `select`: four buffer operands (`cond`, `lhs`, `rhs`, `dst`) and an
// `element_type` type attribute. No results are declared.
def VMLA_SelectOp : VMLA_ElementTypeOp<"select"> {
  let arguments = (ins
    VMLA_Buffer:$cond,
    VMLA_Buffer:$lhs, VMLA_Buffer:$rhs,
    VMLA_Buffer:$dst,
    VMLA_AnyTypeAttr:$element_type
  );
}

//===----------------------------------------------------------------------===//
// VMLA Ops: shape/structure
//===----------------------------------------------------------------------===//

// `copy`: `src` and `dst` buffers, each with a shape and a variadic list of
// indices, plus a variadic `lengths` list and an `element_type` attribute.
// SameVariadicOperandSize requires the three variadic operand groups
// (`src_indices`, `dst_indices`, `lengths`) to have the same number of
// elements.
def VMLA_CopyOp
    : VMLA_ElementTypeOp<"copy",
                         [VMLA_IncludeShapes, SameVariadicOperandSize]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    Variadic<VMLA_Index>:$src_indices,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    Variadic<VMLA_Index>:$dst_indices,
    Variadic<VMLA_Index>:$lengths,
    VMLA_AnyTypeAttr:$element_type
  );
}

// `transpose`: `src` and `dst` buffers with their shapes, a `permutation`
// elements attribute and an `element_type` attribute.
// The custom assembly prints each buffer followed by its shape and shape type
// in parentheses, with `dst` after the `out` keyword; `permutation` has no
// dedicated directive in the format.
def VMLA_TransposeOp : VMLA_ElementTypeOp<"transpose", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    ElementsAttr:$permutation,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_AnyTypeAttr:$element_type
  );

  let assemblyFormat = [{
    $src`(`$src_shape `:` type($src_shape)`)``,`
    `out` $dst`(`$dst_shape `:` type($dst_shape)`)` attr-dict `:` $element_type
  }];
}

// `reverse`: `src` and `dst` buffers with their shapes, a `dimensions`
// elements attribute and an `element_type` attribute. No results are declared.
def VMLA_ReverseOp : VMLA_ElementTypeOp<"reverse", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    ElementsAttr:$dimensions,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_AnyTypeAttr:$element_type
  );
}

// `pad`: three buffer/shape pairs (`src`, `value`, `dst`), three elements
// attributes (`edge_padding_low`, `edge_padding_high`, `interior_padding`)
// and an `element_type` attribute. No results are declared.
def VMLA_PadOp : VMLA_ElementTypeOp<"pad", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    VMLA_Buffer:$value, VMLA_Shape:$value_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    ElementsAttr:$edge_padding_low,
    ElementsAttr:$edge_padding_high,
    ElementsAttr:$interior_padding,
    VMLA_AnyTypeAttr:$element_type
  );
}

// `broadcast`: `src` and `dst` buffers with their shapes, plus an
// `element_type` attribute. No results are declared.
def VMLA_BroadcastOp : VMLA_ElementTypeOp<"broadcast", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_AnyTypeAttr:$element_type
  );
}

// `tile`: `src` and `dst` buffers with their shapes, plus an `element_type`
// attribute. Same argument list as the `broadcast` op. No results are
// declared.
def VMLA_TileOp : VMLA_ElementTypeOp<"tile", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_AnyTypeAttr:$element_type
  );
}

// `gather`: three buffer/shape pairs (`src`, `indices`, `dst`), two 64-bit
// integer attributes (`dim`, `batch_dims`) and an `element_type` attribute.
// No results are declared.
def VMLA_GatherOp : VMLA_ElementTypeOp<"gather", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    VMLA_Buffer:$indices, VMLA_Shape:$indices_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    I64Attr:$dim, I64Attr:$batch_dims,
    VMLA_AnyTypeAttr:$element_type
  );
}

//===----------------------------------------------------------------------===//
// VMLA Ops: bit manipulation
//===----------------------------------------------------------------------===//

// Each def instantiates VMLA_UnaryOp or VMLA_BinaryOp (declared outside this
// section) with the op mnemonic and a type-attribute constraint. Every op in
// this group uses VMLA_AnyTypeAttr.
// NOTE(review): the constraint presumably applies to the op's element type
// attribute - confirm against the VMLA_UnaryOp/VMLA_BinaryOp definitions.
def VMLA_NotOp : VMLA_UnaryOp<"not", VMLA_AnyTypeAttr>;
def VMLA_AndOp : VMLA_BinaryOp<"and", VMLA_AnyTypeAttr>;
def VMLA_OrOp : VMLA_BinaryOp<"or", VMLA_AnyTypeAttr>;
def VMLA_XorOp : VMLA_BinaryOp<"xor", VMLA_AnyTypeAttr>;
def VMLA_ShlOp : VMLA_BinaryOp<"shl", VMLA_AnyTypeAttr>;
def VMLA_ShrOp : VMLA_BinaryOp<"shr", VMLA_AnyTypeAttr>;

//===----------------------------------------------------------------------===//
// VMLA Ops: arithmetic
//===----------------------------------------------------------------------===//

// Each def instantiates VMLA_UnaryOp, VMLA_BinaryOp or VMLA_TernaryOp
// (declared outside this section) with the op mnemonic and a type-attribute
// constraint.
// add/sub/abs/neg/mul/div/rem and min/max/clamp use VMLA_AnyTypeAttr.
// pow/exp/log/rsqrt/sqrt/cos/sin/tanh/atan2 and floor/ceil use
// VMLA_FloatTypeAttr.
def VMLA_AddOp : VMLA_BinaryOp<"add", VMLA_AnyTypeAttr>;
def VMLA_SubOp : VMLA_BinaryOp<"sub", VMLA_AnyTypeAttr>;
def VMLA_AbsOp : VMLA_UnaryOp<"abs", VMLA_AnyTypeAttr>;
def VMLA_NegOp : VMLA_UnaryOp<"neg", VMLA_AnyTypeAttr>;
def VMLA_MulOp : VMLA_BinaryOp<"mul", VMLA_AnyTypeAttr>;
def VMLA_DivOp : VMLA_BinaryOp<"div", VMLA_AnyTypeAttr>;
def VMLA_RemOp : VMLA_BinaryOp<"rem", VMLA_AnyTypeAttr>;
def VMLA_PowOp : VMLA_BinaryOp<"pow", VMLA_FloatTypeAttr>;
def VMLA_ExpOp : VMLA_UnaryOp<"exp", VMLA_FloatTypeAttr>;
def VMLA_LogOp : VMLA_UnaryOp<"log", VMLA_FloatTypeAttr>;
def VMLA_RsqrtOp : VMLA_UnaryOp<"rsqrt", VMLA_FloatTypeAttr>;
def VMLA_SqrtOp : VMLA_UnaryOp<"sqrt", VMLA_FloatTypeAttr>;
def VMLA_CosOp : VMLA_UnaryOp<"cos", VMLA_FloatTypeAttr>;
def VMLA_SinOp : VMLA_UnaryOp<"sin", VMLA_FloatTypeAttr>;
def VMLA_TanhOp : VMLA_UnaryOp<"tanh", VMLA_FloatTypeAttr>;
def VMLA_Atan2Op : VMLA_BinaryOp<"atan2", VMLA_FloatTypeAttr>;

// Min/max/clamp accept any element type; floor/ceil are float-only.
def VMLA_MinOp : VMLA_BinaryOp<"min", VMLA_AnyTypeAttr>;
def VMLA_MaxOp : VMLA_BinaryOp<"max", VMLA_AnyTypeAttr>;
def VMLA_ClampOp : VMLA_TernaryOp<"clamp", VMLA_AnyTypeAttr>;
def VMLA_FloorOp : VMLA_UnaryOp<"floor", VMLA_FloatTypeAttr>;
def VMLA_CeilOp : VMLA_UnaryOp<"ceil", VMLA_FloatTypeAttr>;

//===----------------------------------------------------------------------===//
// VMLA Ops: conversion
//===----------------------------------------------------------------------===//

// `convert`: `src` and `dst` buffer operands plus `src_type` and `dst_type`
// type attributes. No results are declared.
// Assembly shape: src, out dst : <src_type> -> <dst_type>
def VMLA_ConvertOp : VMLA_Op<"convert", [VMLA_OpInterface]> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Buffer:$dst,
    VMLA_AnyTypeAttr:$src_type, VMLA_AnyTypeAttr:$dst_type
  );

  // Records the first operand type as `src_type` and the first result type as
  // `dst_type` on the operation state being built.
  let extraClassDeclaration = [{
    static void extractTypeAttributes(OperationState &state,
                                      ArrayRef<Type> operandTypes,
                                      ArrayRef<Type> resultTypes) {
      state.addAttribute("src_type", TypeAttr::get(operandTypes[0]));
      state.addAttribute("dst_type", TypeAttr::get(resultTypes[0]));
    }
  }];

  let assemblyFormat = [{
    $src`,` `out` $dst attr-dict `:` $src_type `->` $dst_type
  }];
}

//===----------------------------------------------------------------------===//
// VMLA Ops: convolution
//===----------------------------------------------------------------------===//

// `conv`: three buffer/shape pairs (`input`, `filter`, `dst`), four i32
// elements attributes (`window_strides`, `padding`, `lhs_dilation`,
// `rhs_dilation`), two i32 group-count attributes and three float type
// attributes. No results are declared.
//
// NOTE(review): the def name is spelled `VLMA_` while every other def in this
// file uses `VMLA_`. It looks like a typo; check for references elsewhere
// before renaming.
// NOTE(review): unlike the other ops here that declare extractTypeAttributes,
// this op does not list VMLA_OpInterface in its traits - confirm whether that
// is intentional.
def VLMA_ConvOp : VMLA_Op<"conv", [VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$input, VMLA_Shape:$input_shape,
    VMLA_Buffer:$filter, VMLA_Shape:$filter_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    I32ElementsAttr:$window_strides,
    I32ElementsAttr:$padding,
    I32ElementsAttr:$lhs_dilation,
    I32ElementsAttr:$rhs_dilation,
    I32Attr:$feature_group_count,
    I32Attr:$batch_group_count,
    VMLA_FloatTypeAttr:$input_type,
    VMLA_FloatTypeAttr:$filter_type,
    VMLA_FloatTypeAttr:$dst_type
  );

  // Records the first two operand types as `input_type` and `filter_type`, and
  // the first result type as `dst_type`, on the operation state being built.
  let extraClassDeclaration = [{
    static void extractTypeAttributes(OperationState &state,
                                      ArrayRef<Type> operandTypes,
                                      ArrayRef<Type> resultTypes) {
      state.addAttribute("input_type", TypeAttr::get(operandTypes[0]));
      state.addAttribute("filter_type", TypeAttr::get(operandTypes[1]));
      state.addAttribute("dst_type", TypeAttr::get(resultTypes[0]));
    }
  }];
}

//===----------------------------------------------------------------------===//
// VMLA Ops: GEMM/GEMV
//===----------------------------------------------------------------------===//

// `batch.matmul.pseudo`: two tensor operands (`lhs`, `rhs`) and one tensor
// result (`dst`). Unlike the buffer-based ops in this file it declares a
// result and works on AnyTensor values.
def VMLA_BatchMatMulPseudoOp : VMLA_Op<"batch.matmul.pseudo"> {
  let summary = "Tensor-level pseudo-op of VMLA::BatchMatMulOp.";
  let description = [{
    This is a tensor-level version of VMLA::BatchMatMulOp, to facilitate
    the lowering process.

    All operands are rank-3 with the following dimension structure:
    - lhs = [B, FLHS, C]
    - rhs = [B, FRHS, C]
    - dst = [B, FRHS, FLHS]
    Where:
    - B = batch dimension
    - C = contracting dimension
    - FLHS and FRHS are the free dimensions of each operand

    To put this in terms closer to the mathematics of matrix multiplication,
    if we ignore the leading B dimension and focus on what is mathematically an
    MxKxN matmul, then this corresponds to:
    - lhs = [M, K] = [LHSROWS, K]
    - rhs = [N, K] = [RHSCOLS, K]
    - dst = [N, M] = [RHSCOLS, LHSROWS]
    Note that dst is transposed from what one would expect.
    This is due to an implementation detail of this op in the runtime.
    This op is backed by an invocation of the Ruy matrix multiplication library,
    which prefers its matrices in this layout (in matrix terminology:
    lhs = row-major, rhs = column-major, dst = column-major).
    We insert the relevant transposes as needed in the compiler.
  }];

  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs AnyTensor:$dst);

  // Assembly shape: lhs, rhs : (<lhs type>, <rhs type>) -> <dst type>
  let assemblyFormat = [{
    $lhs`,` $rhs attr-dict `:`
    `(`type($lhs)`,` type($rhs)`)` `->` type($dst)
  }];
}

// `batch.matmul`: three buffer/shape pairs (`lhs`, `rhs`, `dst`) and one
// float type attribute per buffer. No results are declared.
// The custom assembly prints each buffer as
//   buffer(shape : shape-type) : element-type
// with `dst` introduced by the `out` keyword.
def VMLA_BatchMatMulOp
    : VMLA_Op<"batch.matmul", [VMLA_OpInterface, VMLA_IncludeShapes]> {
  let arguments = (ins
    VMLA_Buffer:$lhs, VMLA_Shape:$lhs_shape,
    VMLA_Buffer:$rhs, VMLA_Shape:$rhs_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_FloatTypeAttr:$lhs_type,
    VMLA_FloatTypeAttr:$rhs_type,
    VMLA_FloatTypeAttr:$dst_type
  );

  // Records the first two operand types as `lhs_type` and `rhs_type`, and the
  // first result type as `dst_type`, on the operation state being built.
  let extraClassDeclaration = [{
    static void extractTypeAttributes(OperationState &state,
                                      ArrayRef<Type> operandTypes,
                                      ArrayRef<Type> resultTypes) {
      state.addAttribute("lhs_type", TypeAttr::get(operandTypes[0]));
      state.addAttribute("rhs_type", TypeAttr::get(operandTypes[1]));
      state.addAttribute("dst_type", TypeAttr::get(resultTypes[0]));
    }
  }];

  let assemblyFormat = [{
    $lhs`(`$lhs_shape `:` type($lhs_shape)`)` `:` $lhs_type`,`
    $rhs`(`$rhs_shape `:` type($rhs_shape)`)` `:` $rhs_type`,`
    `out` $dst`(`$dst_shape `:` type($dst_shape)`)` `:` $dst_type attr-dict
  }];
}

//===----------------------------------------------------------------------===//
// VMLA Ops: reduction
//===----------------------------------------------------------------------===//

// Shared base for the reduce.* ops. Appends VMLA_IncludeShapes to the derived
// op's traits and fixes the argument list: three buffer/shape pairs (`src`,
// `init`, `dst`), an i32 `dimension` attribute and an `element_type`
// attribute. Note that `dimension` sits between the `init` and `dst` pairs.
class VMLA_ReduceOp<string mnemonic, list<OpTrait> traits = []>
    : VMLA_ElementTypeOp<mnemonic,
                         !listconcat(traits, [VMLA_IncludeShapes])> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    VMLA_Buffer:$init, VMLA_Shape:$init_shape,
    I32Attr:$dimension,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_AnyTypeAttr:$element_type
  );
}

// Concrete reduce ops. Each one only supplies a mnemonic; arguments and traits
// come from VMLA_ReduceOp.
def VMLA_ReduceSumOp : VMLA_ReduceOp<"reduce.sum">;
def VMLA_ReduceMinOp : VMLA_ReduceOp<"reduce.min">;
def VMLA_ReduceMaxOp : VMLA_ReduceOp<"reduce.max">;

// Shared base for the pooling.* ops. Appends VMLA_IncludeShapes to the derived
// op's traits and fixes the argument list: three buffer/shape pairs (`src`,
// `init`, `dst`), an `element_type` attribute and three i32 elements
// attributes (`window_dimensions`, `window_strides`, `padding`).
class VMLA_PoolingOp<string mnemonic, list<OpTrait> traits = []>
    : VMLA_ElementTypeOp<mnemonic,
                         !listconcat(traits, [VMLA_IncludeShapes])> {
  let arguments = (ins
    VMLA_Buffer:$src, VMLA_Shape:$src_shape,
    VMLA_Buffer:$init, VMLA_Shape:$init_shape,
    VMLA_Buffer:$dst, VMLA_Shape:$dst_shape,
    VMLA_AnyTypeAttr:$element_type,
    I32ElementsAttr:$window_dimensions,
    I32ElementsAttr:$window_strides,
    I32ElementsAttr:$padding
  );
}

// Concrete pooling ops. Each one only supplies a mnemonic; arguments and
// traits come from VMLA_PoolingOp.
def VMLA_PoolingSumOp : VMLA_PoolingOp<"pooling.sum">;
def VMLA_PoolingMinOp : VMLA_PoolingOp<"pooling.min">;
def VMLA_PoolingMaxOp : VMLA_PoolingOp<"pooling.max">;

//===----------------------------------------------------------------------===//
// VMLA Ops: ABI
//===----------------------------------------------------------------------===//

// `interface.const`: takes an interface operand and an index attribute
// (`offset`); the result is either an i32 or a VMLA_Index value.
def VMLA_InterfaceConstOp
    : VMLA_PureOp<"interface.const", [VMLA_OpInterface]> {
  let arguments = (ins VMLA_Interface:$interface, IREE_IndexAttr:$offset);
  let results = (outs AnyTypeOf<[I32, VMLA_Index]>:$result);
}

// `interface.binding`: takes an interface operand and two i32 attributes
// (`set`, `binding`); the result is a buffer.
def VMLA_InterfaceBindingOp
    : VMLA_PureOp<"interface.binding", [VMLA_OpInterface]> {
  let arguments = (ins
    VMLA_Interface:$interface,
    I32Attr:$set, I32Attr:$binding
  );
  let results = (outs VMLA_Buffer:$result);
}

#endif  // IREE_DIALECT_VMLA_OPS
