blob: 877fab0bd7f9679b4114040f8a1cdd5046b43823 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- IndexComputation.cpp ------------------------------------*- C++//-*-===//
//
// For an IREE dispatch function, compute the map from workitem ID to index of
// tensor computed within that workitem.
//
//===----------------------------------------------------------------------===//
#include "compiler/Translation/SPIRV/IndexComputation.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
static llvm::cl::opt<bool> doAffineExprSimplify(
"simplify-spirv-affine-exprs",
llvm::cl::desc("Simplify affine expressions during code-generation."),
llvm::cl::init(true));
namespace mlir {
namespace iree_compiler {
//===----------------------------------------------------------------------===//
// Reshape Utility Functions
//===----------------------------------------------------------------------===//
namespace {
/// Handles shapes for scalars. Shape of scalars are represented as empty vetor,
/// i.e. {}. Its easier to do index propogation to handle the scalar as vector
/// of size 1.
inline SmallVector<int64_t, 4> handleIfScalar(ArrayRef<int64_t> shape) {
SmallVector<int64_t, 4> resultShape;
if (shape.empty()) {
return {1};
}
return SmallVector<int64_t, 4>(shape.begin(), shape.end());
}
/// Reshapes are often used to either add a dimension of size 1 or remove a
/// dimension of size 1. Recognizing such cases can make the code-generation
/// easier. The AffineMap needs to either add a constant 0 in the range for such
/// added dimensions or drop those dimensions.
inline LogicalResult getAffineExprForAddOrRemoveDimension(
Builder &builder, ArrayRef<AffineExpr> resultExprs,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> operandShape,
SmallVectorImpl<AffineExpr> &operandExprs) {
auto resultIndex = resultShape.size();
auto operandIndex = operandShape.size();
operandExprs.resize(operandShape.size());
// Try to match up the dimensions of the operand and result by ignoring any
// dimensions of size of 1 that are introduced.
while (resultIndex > 0 && operandIndex > 0) {
if (resultShape[resultIndex - 1] == -1 ||
operandShape[operandIndex - 1] == -1) {
return failure();
}
if (resultShape[resultIndex - 1] == operandShape[operandIndex - 1]) {
operandExprs[operandIndex - 1] = resultExprs[resultIndex - 1];
resultIndex--;
operandIndex--;
continue;
}
if (resultShape[resultIndex - 1] == 1) {
// This is a dimension that is added on the operand. This affine
// expression corresponding to this dimension is dropped.
resultIndex--;
continue;
}
if (operandShape[operandIndex - 1] == 1) {
// This is a dimension of size 1 of the operand that is dropped. Add a
// constant expr 0.
operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
operandIndex--;
continue;
}
return failure();
}
// Any remaining dimensions should be 1.
while (resultIndex > 0) {
if (resultShape[resultIndex - 1] != 1) {
return failure();
}
resultIndex--;
}
while (operandIndex > 0) {
if (operandShape[operandIndex - 1] != 1) {
return failure();
}
// This is a dimension of size 1 that is dropped. Add a constant expression
// 0.
operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0);
operandIndex--;
}
return success();
}
/// Constructs the strides of an array assuming a row-major packed layout.
// TODO(ravishankarm): This assumes the shape are static. When using dynamic
// shapes, parameters of each dimension can be used to construct AffineExpr for
// strides along each dimension. Note that multiplying two symbolic constants is
// technically not affine, but you could use another symbol to represent the
// product, so it should be still representable as affine exprs.
inline LogicalResult getRowMajorPackedStrides(
Builder &builder, ArrayRef<int64_t> shape,
SmallVectorImpl<AffineExpr> &strides) {
strides.resize(shape.size());
int64_t stride = 1;
for (auto dim : enumerate(reverse(shape))) {
if (dim.value() < 0) {
// TODO(ravishankarm) : Better error message.
return failure();
}
strides[shape.size() - 1 - dim.index()] =
builder.getAffineConstantExpr(stride);
stride *= dim.value();
}
return success();
}
/// Linearizes the index of the result position accessed using the shape of the
/// result tensor and delinearizes it to get the position of the operand.
inline LogicalResult getAffineExprForReshape(
Builder &builder, unsigned numDims, unsigned numSymbols,
ArrayRef<AffineExpr> resultExprs, ArrayRef<int64_t> resultShape,
ArrayRef<int64_t> operandShape, SmallVectorImpl<AffineExpr> &operandExprs) {
// To linearize the index, assume that the memory is laid out in
// packed-row-major layout based on the shape.
// TODO(ravishankarm) : When there is stride information, use that to map from
// index to memory location.
SmallVector<AffineExpr, 4> resultStrides;
if (failed(getRowMajorPackedStrides(builder, resultShape, resultStrides))) {
return failure();
}
AffineExpr linearizedExpr;
for (auto index : enumerate(resultExprs)) {
auto val = getAffineBinaryOpExpr(AffineExprKind::Mul, index.value(),
resultStrides[index.index()]);
if (doAffineExprSimplify) {
val = simplifyAffineExpr(val, numDims, numSymbols);
}
linearizedExpr = (index.index() ? getAffineBinaryOpExpr(AffineExprKind::Add,
linearizedExpr, val)
: val);
if (doAffineExprSimplify) {
linearizedExpr = simplifyAffineExpr(val, numDims, numSymbols);
}
}
// Unlinearize the index, assuming row-major-packed layout.
// TODO(ravishankarm) : When there is stride information, use that to map from
// memory location to index.
SmallVector<AffineExpr, 4> operandStrides;
if (failed(getRowMajorPackedStrides(builder, operandShape, operandStrides))) {
return failure();
}
operandExprs.resize(operandStrides.size());
for (auto stride : enumerate(operandStrides)) {
if (stride.index() == operandStrides.size() - 1) {
operandExprs[stride.index()] = linearizedExpr;
break;
}
auto expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, linearizedExpr,
stride.value());
operandExprs[stride.index()] =
(doAffineExprSimplify ? simplifyAffineExpr(expr, numDims, numSymbols)
: expr);
linearizedExpr = getAffineBinaryOpExpr(AffineExprKind::Mod, linearizedExpr,
stride.value());
if (doAffineExprSimplify) {
linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols);
}
}
return success();
}
} // namespace
LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap,
ArrayRef<int64_t> resultShapeRef,
ArrayRef<int64_t> operandShapeRef,
AffineMap &operandIndexMap) {
auto resultShape = handleIfScalar(resultShapeRef);
auto operandShape = handleIfScalar(operandShapeRef);
auto resultExprs = resultIndexMap.getResults();
assert(resultShape.size() == resultExprs.size() &&
"Ranks of the Domain of index map and result must be the same");
SmallVector<AffineExpr, 4> operandExprs;
if (failed(getAffineExprForAddOrRemoveDimension(
builder, resultExprs, resultShape, operandShape, operandExprs)) &&
failed(getAffineExprForReshape(
builder, resultIndexMap.getNumDims(), resultIndexMap.getNumSymbols(),
resultExprs, resultShape, operandShape, operandExprs))) {
return failure();
}
assert(operandExprs.size() == operandShape.size() &&
"expected as many exprs for the operand as the rank of the operand");
operandIndexMap =
AffineMap::get(resultIndexMap.getNumDims(),
resultIndexMap.getNumSymbols(), operandExprs);
return success();
}
LogicalResult IndexPropagation::propagateIndexMap(
Operation *op, IndexComputationCache &indexMap) const {
if (op->getNumResults() == 0) {
// Nothing to do for this op.
return success();
}
if (op->getNumResults() != 1) {
return op->emitError(
"default index propagation handles case with a single-return value");
}
// Initialize the storage for all the operands.
for (auto arg : op->getOperands()) {
indexMap[arg];
}
for (auto &resultIndexMap : indexMap[op->getResult(0)]) {
SmallVector<AffineMap, 4> operandIndices;
if (failed(this->propagateIndexMap(op, resultIndexMap.first,
operandIndices))) {
return failure();
}
assert(operandIndices.size() == op->getNumOperands() &&
"Expected as many indices as operands");
for (auto arg : enumerate(op->getOperands())) {
indexMap[arg.value()][operandIndices[arg.index()]];
resultIndexMap.second.push_back(operandIndices[arg.index()]);
}
}
return success();
}
void dumpIndexCache(IndexComputationCache &indexMap) {
for (auto &el : indexMap) {
// llvm::errs() << "Value : " << *(el.first);
// llvm::errs().flush();
if (isa<OpResult>(el.first)) {
llvm::errs() << "Operation : " << el.first->getDefiningOp()->getName();
} else if (isa<BlockArgument>(el.first)) {
llvm::errs() << "BlockArgument";
}
for (auto &used : el.second) {
llvm::errs() << "\n\t" << used.first << " : [";
std::string sep = "";
for (auto &operand : used.second) {
llvm::errs() << sep << operand;
sep = ", ";
}
llvm::errs() << "]";
}
llvm::errs() << "\n";
}
llvm::errs() << "\n";
}
} // namespace iree_compiler
} // namespace mlir