blob: 9cf30f2e4993dc4f2fa334060be9fee2d13cbd32 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//===- IndexComputation.h ---------------------------------------*- C++ -*-===//
//
// For an IREE dispatch function, compute the map from workitem ID to index of
// tensor computed within that workitem.
//
//===----------------------------------------------------------------------===//
#ifndef IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_H
#define IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace iree_compiler {
/// Cache of the index maps computed for the tensor values of a dispatch
/// function.
///
/// The cache is keyed by the tensor `Value *`. For each tensor it holds an
/// `llvm::MapVector` whose keys are the index maps (`AffineMap`) of that tensor
/// needed for a workitem, and whose mapped values are the index maps of the
/// operands needed for that index. The operand index maps are stored alongside
/// so that they do not have to be recomputed later on.
/// TODO(ravishankarm): This sort of assumes each operation has a single result
/// value. Might need to be changed later on.
using IndexComputationCache =
DenseMap<Value *, llvm::MapVector<AffineMap, SmallVector<AffineMap, 4>>>;
/// Abstract interface for objects that propagate an index map from the
/// result(s) of an operation to its operands. Implementations report the name
/// of the operation they handle through `getOpName`, which allows building a
/// map from operation name to the propagation logic.
class IndexPropagation {
 public:
  virtual ~IndexPropagation() = default;

  /// Returns the name of the operation this object handles.
  virtual StringRef getOpName() const = 0;

  /// Operation-level entry point. Given the index maps describing which
  /// element(s) of the result(s) a thread uses (held in `indexMap`), computes
  /// the indices of the operands that are needed. A class overriding this
  /// method is expected to evaluate the operand indices for every index of the
  /// result value that is needed. The default implementation only handles
  /// operations with a zero or single-return value.
  virtual LogicalResult propagateIndexMap(
      Operation *operation, IndexComputationCache &indexMap) const;

  /// Per-index entry point. Computes into `operandIndices` the operand index
  /// maps that correspond to a single index map `resultIndex` of the result.
  /// The base implementation leaves `operandIndices` untouched and reports
  /// success.
  virtual LogicalResult propagateIndexMap(
      Operation *operation, AffineMap resultIndex,
      SmallVectorImpl<AffineMap> &operandIndices) const {
    // Nothing to propagate unless a derived class says otherwise.
    return success();
  }
};
/// Base class templated on the operation type. Its only purpose is to provide
/// a common implementation of `getOpName` that forwards to the static
/// `OpTy::getOperationName()`.
template <typename OpTy>
class IndexPropagationOp : public IndexPropagation {
 public:
  using IndexPropagation::IndexPropagation;
  ~IndexPropagationOp() override = default;

  /// Returns the operation name of `OpTy`. Marked `override` so that a change
  /// of the signature in `IndexPropagation` is diagnosed at compile time
  /// instead of silently introducing an unrelated virtual method.
  StringRef getOpName() const override { return OpTy::getOperationName(); }
};
// ===-------------------------------------------------------------------- ===//
// NoBroadcastPwOp
// ===-------------------------------------------------------------------- ===//
/// Propagates the index map from result to operands for operations that are
/// point-wise and whose operands are not implicitly broadcasted. The index map
/// of the result is simply copied to every operand.
template <typename OpTy>
class NoBroadcastPwOpIndexPropagation : public IndexPropagationOp<OpTy> {
 public:
  using IndexPropagationOp<OpTy>::IndexPropagationOp;

  /// Appends one copy of `resultIndex` to `operandIndices` for each operand of
  /// `operation`.
  ///
  /// Emits an error on `operation` (and returns failure) when the operation
  /// has no operands, or when some operand is not a ranked tensor with the
  /// same shape as the first operand. When an error is reported for a later
  /// operand, `operandIndices` may already contain the entries appended for
  /// the operands preceding it.
  LogicalResult propagateIndexMap(
      Operation *operation, AffineMap resultIndex,
      SmallVectorImpl<AffineMap> &operandIndices) const override {
    // The first operand is used as the reference below; make sure it exists
    // before accessing it.
    if (operation->getNumOperands() == 0) {
      return operation->emitError(
          "expected operation to have at least one operand");
    }

    // All operands must be ranked tensors of the same shape. Note that only
    // the shape is compared, not the element type.
    auto argRefType =
        operation->getOperand(0)->getType().dyn_cast<RankedTensorType>();
    if (!argRefType) {
      return operation->emitError("expected operands to be of tensortype");
    }
    for (auto arg : operation->getOperands()) {
      auto argType = arg->getType().dyn_cast<RankedTensorType>();
      if (!argType || argType.getShape() != argRefType.getShape()) {
        return operation->emitError("expected operands to have same shape");
      }
      operandIndices.push_back(resultIndex);
    }
    return success();
  }
};
// ===-------------------------------------------------------------------- ===//
// ReshapeOp
// ===-------------------------------------------------------------------- ===//
/// Computes the index map representing the element of the operand accessed
/// given the index map of the result for a reshape operation.
///
/// Only the declaration is visible in this header; the notes below describe
/// how the function is used here.
///   - `builder`: builder forwarded by the caller (presumably used to create
///     the affine expressions of the operand map -- confirm against the
///     definition).
///   - `resultIndexMap`: index map of the result of the reshape.
///   - `resultShapeRef` / `operandShapeRef`: shapes of the result and of the
///     operand. The caller in this header only passes static shapes with the
///     same number of elements.
///   - `operandIndexMap`: output parameter receiving the index map of the
///     operand; it is only consumed by the caller when the call succeeds.
/// Returns failure when the propagation is not handled; the caller in this
/// header reports that as "unhandled reshape index propagation".
LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap,
ArrayRef<int64_t> resultShapeRef,
ArrayRef<int64_t> operandShapeRef,
AffineMap &operandIndexMap);
/// Index propagation for reshape-like operations: derives the index map of the
/// single operand from an index map of the result.
template <typename OpTy>
class ReshapeOpIndexPropagation final : public IndexPropagationOp<OpTy> {
 public:
  using IndexPropagationOp<OpTy>::IndexPropagationOp;

  /// Appends to `operandIndices` the operand index map that corresponds to
  /// `resultIndex`. Emits an error on `op` when result/operand are not shaped
  /// types, are not statically shaped, differ in number of elements, or when
  /// `getReshapeOperandMap` cannot handle the reshape.
  LogicalResult propagateIndexMap(
      Operation *op, AffineMap resultIndex,
      SmallVectorImpl<AffineMap> &operandIndices) const override {
    auto dstType = op->getResult(0)->getType().template dyn_cast<ShapedType>();
    auto srcType =
        op->getOperand(0)->getType().template dyn_cast<ShapedType>();
    if (!dstType || !srcType) {
      return op->emitError("expected result and operand to be shaped types");
    }

    const bool shapesAreStatic =
        dstType.hasStaticShape() && srcType.hasStaticShape();
    if (!shapesAreStatic) {
      return op->emitError(
          "unhandled non static shape of result/operands of reshape op");
    }

    if (dstType.getNumElements() != srcType.getNumElements()) {
      return op->emitError(
          "invalid reshape operation, mismatch in number of elements in "
          "operand and result");
    }

    // Check whether the reshape is adding or removing a dimension of size 1,
    // and build the index map of the operand accordingly.
    Builder builder(op->getContext());
    AffineMap operandIndexMap;
    LogicalResult status =
        getReshapeOperandMap(builder, resultIndex, dstType.getShape(),
                             srcType.getShape(), operandIndexMap);
    if (failed(status)) {
      return op->emitError("unhandled reshape index propagation");
    }

    operandIndices.push_back(operandIndexMap);
    return success();
  }
};
// ===-------------------------------------------------------------------- ===//
// ReverseOp
// ===-------------------------------------------------------------------- ===//
/// Propagates the index map from result to operands of reverse type operation.
/// See https://www.tensorflow.org/xla/operation_semantics#rev_reverse for more
/// details. The set of reversed dimensions must be supplied by derived classes
/// when they call `propagateIndexMapImpl`.
template <typename OpTy>
class ReverseOpIndexPropagation : public IndexPropagationOp<OpTy> {
 public:
  using IndexPropagationOp<OpTy>::IndexPropagationOp;
  virtual ~ReverseOpIndexPropagation() = default;

 protected:
  /// Appends to `operandIndices` the index map of operand 0 that corresponds
  /// to `resultIndex`, where every dimension `d` listed in `dimensions` is
  /// accessed at position `size(d) - d - 1` and all other dimensions are
  /// accessed unchanged.
  ///
  /// Emits an error on `op` when operand 0 is not a shaped type with a static
  /// shape: the reversed index needs the concrete size of the dimension, which
  /// is not available for a dynamic dimension.
  LogicalResult propagateIndexMapImpl(
      Operation *op, const DenseSet<unsigned> &dimensions,
      AffineMap resultIndex,
      SmallVectorImpl<AffineMap> &operandIndices) const {
    auto shapedType = op->getOperand(0)->getType().dyn_cast<ShapedType>();
    if (!shapedType || !shapedType.hasStaticShape()) {
      return op->emitError(
          "expected operand of reverse op to be a shaped type with static "
          "shape");
    }

    Builder builder(op->getContext());
    const unsigned rank = static_cast<unsigned>(shapedType.getRank());
    SmallVector<AffineExpr, 4> dimensionsExprs;
    dimensionsExprs.reserve(rank);
    for (unsigned index = 0; index < rank; ++index) {
      AffineExpr dimExpr = builder.getAffineDimExpr(index);
      if (dimensions.count(index)) {
        // Reversed dimension: element `d` of the result reads element
        // `size - d - 1` of the operand.
        dimExpr = shapedType.getDimSize(index) - dimExpr - 1;
      }
      dimensionsExprs.push_back(dimExpr);
    }
    auto dimensionsAffineMap = AffineMap::get(
        dimensionsExprs.size(), /*symbolCount=*/0, dimensionsExprs);

    // Compose the index map of the result with the index map for the
    // dimensions.
    auto operandMap = dimensionsAffineMap.compose(resultIndex);
    operandIndices.push_back(operandMap);
    return success();
  }
};
/// Index map of the operand of a transpose op is obtained by composing the
/// affine map of the result with the affine map that represents the inverse of
/// the transpose permutation vector. The permutation vector must be supplied by
/// derived classes in the definition of the method `propagateIndexMap`.
template <typename OpTy>
class TransposeOpIndexPropagation : public IndexPropagationOp<OpTy> {
 public:
  using IndexPropagationOp<OpTy>::IndexPropagationOp;
  virtual ~TransposeOpIndexPropagation() = default;

 protected:
  /// Appends to `operandIndices` the operand index map that corresponds to
  /// `resultIndex` for the transpose described by `permutation`.
  ///
  /// Emits an error on `op` when `permutation` is not a valid permutation of
  /// `[0, permutation.size())`, i.e. when an entry is out of range or appears
  /// more than once. Such a vector has no inverse, so it is rejected up front
  /// rather than handed to `inversePermutation`.
  LogicalResult propagateIndexMapImpl(
      Operation *op, ArrayRef<unsigned> permutation, AffineMap resultIndex,
      SmallVectorImpl<AffineMap> &operandIndices) const {
    Builder builder(op->getContext());
    const auto rank = permutation.size();
    SmallVector<bool, 4> seen(rank, false);
    SmallVector<AffineExpr, 4> permutationExprs;
    permutationExprs.reserve(rank);
    for (unsigned index : permutation) {
      if (index >= rank || seen[index]) {
        return op->emitError(
            "expected permutation vector of transpose op to be a valid "
            "permutation");
      }
      seen[index] = true;
      permutationExprs.push_back(builder.getAffineDimExpr(index));
    }
    auto permutationAffineMap =
        AffineMap::get(permutationExprs.size(), 0, permutationExprs);

    // Compute the inverse of the permutation map.
    auto invPermutationMap = inversePermutation(permutationAffineMap);

    // Compose the index map of the result with the index map
    // for the inverse of the permutation.
    auto operandMap = invPermutationMap.compose(resultIndex);
    operandIndices.push_back(operandMap);
    return success();
  }
};
/// Maintains list of IndexPropagation objects that propagate the index
/// information from result to the operands of instructions. One object of each
/// type in `Ts` is default-constructed and registered under the operation name
/// returned by its `getOpName`.
template <typename... Ts>
class IndexPropagationList {
  using IndexPropagationListT =
      llvm::StringMap<std::unique_ptr<IndexPropagation>>;

 public:
  explicit IndexPropagationList() { insert(); }

  /// Performs the propagation over all operations of `region`, visiting the
  /// operations of its block in reverse order so that results are processed
  /// before the operands that produce them.
  ///
  /// Emits an error and returns failure when:
  ///   - the region does not have exactly one block,
  ///   - an operation has no registered propagation object,
  ///   - a result of an operation is not a ranked tensor,
  ///   - a result of an operation has no entry in `indexMap`,
  ///   - the propagation object of an operation fails.
  LogicalResult propagate(Region &region,
                          IndexComputationCache &indexMap) const {
    if (region.getBlocks().size() != 1) {
      return emitError(
          region.getLoc(),
          "unimplemented handling multiple blocks within a region");
    }
    // TODO(ravishankarm) : Need to process blocks in reverse topological order.
    for (auto it = region.rbegin(), ie = region.rend(); it != ie; ++it) {
      for (auto jt = it->rbegin(), je = it->rend(); jt != je; ++jt) {
        auto &op = *jt;

        // Single lookup of the propagation object registered for this op.
        auto handler = indexPropagationList.find(op.getName().getStringRef());
        if (handler == indexPropagationList.end()) {
          return op.emitError("unhandled index propagation");
        }

        // Every result must be a ranked tensor whose index map is already
        // known before it can be propagated to the operands.
        for (auto resultValue : op.getResults()) {
          auto type = resultValue->getType().dyn_cast<RankedTensorType>();
          if (!type) {
            return op.emitError("expected return value to be a tensor");
          }
          if (!indexMap.count(resultValue)) {
            return op.emitError("missing index map of result");
          }
        }

        if (failed(handler->getValue()->propagateIndexMap(&op, indexMap))) {
          return failure();
        }
      }
    }
    return success();
  }

 private:
  /// Builds the list by unpacking the parameter pack. The objects are created
  /// and registered in the order in which the types appear in `Ts`.
  void insert() {
    // TODO(ravishankarm) : This uses the fold logic from
    // mlir/IR/PatternMatch.h. There might be a simpler/more efficient way of
    // implementing this.
    using dummy = int[];
    (void)dummy{0, (registerPropagation(std::make_unique<Ts>()), 0)...};
  }

  /// Registers a single propagation object under its operation name. If an
  /// object is already registered for that name the existing one is kept.
  void registerPropagation(std::unique_ptr<IndexPropagation> propagation) {
    // Read the name before the pointer is moved into the map.
    StringRef opName = propagation->getOpName();
    indexPropagationList.try_emplace(opName, std::move(propagation));
  }

  /// List of methods for propagation indexed using opname.
  IndexPropagationListT indexPropagationList;
};
/// Debug method to just dump the indexMap to llvm::errs.
/// Only the declaration is visible in this header. The cache is taken by
/// non-const reference; NOTE(review): presumably it is not modified by the
/// dump -- confirm against the definition.
void dumpIndexCache(IndexComputationCache &indexMap);
} // namespace iree_compiler
} // namespace mlir
#endif // IREE_COMPILER_TRANSLATION_SPIRV_INDEXCOMPUTATION_H