// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//===- IndexComputation.cpp ------------------------------------*- C++//-*-===//
//
// For an IREE dispatch function, compute the map from workitem ID to index of
// tensor computed within that workitem.
//
//===----------------------------------------------------------------------===//
#include "compiler/Translation/SPIRV/IndexComputation.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"

static llvm::cl::opt<bool> doAffineExprSimplify(
    "simplify-spirv-affine-exprs",
    llvm::cl::desc("Simplify affine expressions during code-generation."),
    llvm::cl::init(true));

namespace mlir {
namespace iree_compiler {

//===----------------------------------------------------------------------===//
// Reshape Utility Functions
//===----------------------------------------------------------------------===//

namespace {
/// Handles shapes for scalars. Shape of scalars are represented as empty vetor,
/// i.e. {}. Its easier to do index propogation to handle the scalar as vector
/// of size 1.
inline SmallVector<int64_t, 4> handleIfScalar(ArrayRef<int64_t> shape) {
  SmallVector<int64_t, 4> resultShape;
  if (shape.empty()) {
    return {1};
  }
  return SmallVector<int64_t, 4>(shape.begin(), shape.end());
}

/// Reshapes are often used to either add a dimension of size 1 or remove a
/// dimension of size 1. Recognizing such cases can make the code-generation
/// easier. The AffineMap needs to either add a constant 0 in the range for such
/// added dimensions or drop those dimensions.
inline LogicalResult getAffineExprForAddOrRemoveDimension(
    Builder &builder, ArrayRef<AffineExpr> resultExprs,
    ArrayRef<int64_t> resultShape, ArrayRef<int64_t> operandShape,
    SmallVectorImpl<AffineExpr> &operandExprs) {
  auto resultIndex = resultShape.size();
  auto operandIndex = operandShape.size();
  operandExprs.resize(operandShape.size());
  // Try to match up the dimensions of the operand and result by ignoring any
  // dimensions of size of 1 that are introduced.
  while (resultIndex > 0 && operandIndex > 0) {
    if (resultShape[resultIndex - 1] == -1 ||
        operandShape[operandIndex - 1] == -1) {
      return failure();
    }
    if (resultShape[resultIndex - 1] == operandShape[operandIndex - 1]) {
      operandExprs[operandIndex - 1] = resultExprs[resultIndex - 1];
      resultIndex--;
      operandIndex--;
      continue;
    }
    if (resultShape[resultIndex - 1] == 1) {
      // This is a dimension that is added on the operand. This affine
      // expression corresponding to this dimension is dropped.
      resultIndex--;
      continue;
    }
    if (operandShape[operandIndex - 1] == 1) {
      // This is a dimension of size 1 of the operand that is dropped. Add a
      // constant expr 0.
      operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
      operandIndex--;
      continue;
    }
    return failure();
  }
  // Any remaining dimensions should be 1.
  while (resultIndex > 0) {
    if (resultShape[resultIndex - 1] != 1) {
      return failure();
    }
    resultIndex--;
  }
  while (operandIndex > 0) {
    if (operandShape[operandIndex - 1] != 1) {
      return failure();
    }
    // This is a dimension of size 1 that is dropped. Add a constant expression
    // 0.
    operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
    operandIndex--;
  }
  return success();
}

/// Constructs the strides of an array assuming a row-major packed layout.
// TODO(ravishankarm): This assumes the shape are static. When using dynamic
// shapes, parameters of each dimension can be used to construct AffineExpr for
// strides along each dimension. Note that multiplying two symbolic constants is
// technically not affine, but you could use another symbol to represent the
// product, so it should be still representable as affine exprs.
inline LogicalResult getRowMajorPackedStrides(
    Builder &builder, ArrayRef<int64_t> shape,
    SmallVectorImpl<AffineExpr> &strides) {
  strides.resize(shape.size());
  int64_t stride = 1;
  for (auto dim : enumerate(reverse(shape))) {
    if (dim.value() < 0) {
      // TODO(ravishankarm) : Better error message.
      return failure();
    }
    strides[shape.size() - 1 - dim.index()] =
        builder.getAffineConstantExpr(stride);
    stride *= dim.value();
  }
  return success();
}

/// Linearizes the index of the result position accessed using the shape of the
/// result tensor and delinearizes it to get the position of the operand.
inline LogicalResult getAffineExprForReshape(
    Builder &builder, unsigned numDims, unsigned numSymbols,
    ArrayRef<AffineExpr> resultExprs, ArrayRef<int64_t> resultShape,
    ArrayRef<int64_t> operandShape, SmallVectorImpl<AffineExpr> &operandExprs) {
  // To linearize the index, assume that the memory is laid out in
  // packed-row-major layout based on the shape.
  // TODO(ravishankarm) : When there is stride information, use that to map from
  // index to memory location.
  SmallVector<AffineExpr, 4> resultStrides;
  if (failed(getRowMajorPackedStrides(builder, resultShape, resultStrides))) {
    return failure();
  }
  AffineExpr linearizedExpr;
  for (auto index : enumerate(resultExprs)) {
    auto val = getAffineBinaryOpExpr(AffineExprKind::Mul, index.value(),
                                     resultStrides[index.index()]);
    if (doAffineExprSimplify) {
      val = simplifyAffineExpr(val, numDims, numSymbols);
    }
    linearizedExpr = (index.index() ? getAffineBinaryOpExpr(AffineExprKind::Add,
                                                            linearizedExpr, val)
                                    : val);
    if (doAffineExprSimplify) {
      linearizedExpr = simplifyAffineExpr(val, numDims, numSymbols);
    }
  }

  // Unlinearize the index, assuming row-major-packed layout.
  // TODO(ravishankarm) : When there is stride information, use that to map from
  // memory location to index.
  SmallVector<AffineExpr, 4> operandStrides;
  if (failed(getRowMajorPackedStrides(builder, operandShape, operandStrides))) {
    return failure();
  }
  operandExprs.resize(operandStrides.size());
  for (auto stride : enumerate(operandStrides)) {
    if (stride.index() == operandStrides.size() - 1) {
      operandExprs[stride.index()] = linearizedExpr;
      break;
    }
    auto expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, linearizedExpr,
                                      stride.value());
    operandExprs[stride.index()] =
        (doAffineExprSimplify ? simplifyAffineExpr(expr, numDims, numSymbols)
                              : expr);

    linearizedExpr = getAffineBinaryOpExpr(AffineExprKind::Mod, linearizedExpr,
                                           stride.value());
    if (doAffineExprSimplify) {
      linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols);
    }
  }
  return success();
}
}  // namespace

LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap,
                                   ArrayRef<int64_t> resultShapeRef,
                                   ArrayRef<int64_t> operandShapeRef,
                                   AffineMap &operandIndexMap) {
  auto resultShape = handleIfScalar(resultShapeRef);
  auto operandShape = handleIfScalar(operandShapeRef);
  auto resultExprs = resultIndexMap.getResults();
  assert(resultShape.size() == resultExprs.size() &&
         "Ranks of the Domain of index map and result must be the same");
  SmallVector<AffineExpr, 4> operandExprs;
  if (failed(getAffineExprForAddOrRemoveDimension(
          builder, resultExprs, resultShape, operandShape, operandExprs)) &&
      failed(getAffineExprForReshape(
          builder, resultIndexMap.getNumDims(), resultIndexMap.getNumSymbols(),
          resultExprs, resultShape, operandShape, operandExprs))) {
    return failure();
  }
  assert(operandExprs.size() == operandShape.size() &&
         "expected as many exprs for the operand as the rank of the operand");
  operandIndexMap =
      AffineMap::get(resultIndexMap.getNumDims(),
                     resultIndexMap.getNumSymbols(), operandExprs);

  return success();
}

LogicalResult IndexPropagation::propagateIndexMap(
    Operation *op, IndexComputationCache &indexMap) const {
  if (op->getNumResults() == 0) {
    // Nothing to do for this op.
    return success();
  }
  if (op->getNumResults() != 1) {
    return op->emitError(
        "default index propagation handles case with a single-return value");
  }
  // Initialize the storage for all the operands.
  for (auto arg : op->getOperands()) {
    indexMap[arg];
  }
  for (auto &resultIndexMap : indexMap[op->getResult(0)]) {
    SmallVector<AffineMap, 4> operandIndices;
    if (failed(this->propagateIndexMap(op, resultIndexMap.first,
                                       operandIndices))) {
      return failure();
    }
    assert(operandIndices.size() == op->getNumOperands() &&
           "Expected as many indices as operands");
    for (auto arg : enumerate(op->getOperands())) {
      indexMap[arg.value()][operandIndices[arg.index()]];
      resultIndexMap.second.push_back(operandIndices[arg.index()]);
    }
  }
  return success();
}

void dumpIndexCache(IndexComputationCache &indexMap) {
  for (auto &el : indexMap) {
    // llvm::errs() << "Value : " << *(el.first);
    // llvm::errs().flush();
    if (isa<OpResult>(el.first)) {
      llvm::errs() << "Operation : " << el.first->getDefiningOp()->getName();
    } else if (isa<BlockArgument>(el.first)) {
      llvm::errs() << "BlockArgument";
    }
    for (auto &used : el.second) {
      llvm::errs() << "\n\t" << used.first << " : [";
      std::string sep = "";
      for (auto &operand : used.second) {
        llvm::errs() << sep << operand;
        sep = ", ";
      }
      llvm::errs() << "]";
    }
    llvm::errs() << "\n";
  }
  llvm::errs() << "\n";
}

}  // namespace iree_compiler
}  // namespace mlir
