blob: ba58d34397ad1aa16f841f7f1fa1ca1d7c53541b [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/Common/LinalgOpInfo.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
using namespace mlir::linalg;
namespace mlir {
namespace iree_compiler {
/// Returns true if `map` is a transpose. A transpose map is a projected
/// permutation, with or without zeros in its results, in which at least two
/// dimensions di and dj satisfy di < dj and result_pos(di) > result_pos(dj),
/// i.e. some input dimension shows up in the results after a higher-numbered
/// one.
///
/// Examples:
///
///   (d0, d1, d2) -> (d0, d2)     is not a transpose map.
///   (d0, d1, d2) -> (d2, d0)     is a transpose map.
///   (d0, d1, d2) -> (d1, d2)     is not a transpose map.
///   (d0, d1, d2) -> (d0, 0, d1)  is not a transpose map.
///   (d0, d1, d2) -> (d2, 0, d1)  is a transpose map.
///   (d0, d1, d2) -> (d1, 0)      is not a transpose map.
///
// TODO(dcaballe): Discern between "memcopy" transposes and "shuffle"
// transposes.
// TODO(dcaballe): Move to Affine utils?
static bool isTransposeMap(AffineMap map) {
  // Only projected permutations (broadcast/reduction zeros allowed in the
  // results) are candidates for being a transpose.
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) {
    return false;
  }

  // Scan the results in order while tracking the input position of the most
  // recent dimension. Meeting a dimension whose input position is smaller than
  // the tracked one means two dimensions appear out of order.
  unsigned lastDimPos = 0;
  for (AffineExpr resultExpr : map.getResults()) {
    // Constant results are the zeros admitted by `allowZeroInResults` above;
    // they carry no ordering information.
    if (resultExpr.isa<AffineConstantExpr>()) {
      continue;
    }

    // Any result that is neither a constant nor a plain dimension disqualifies
    // the map.
    auto dimExpr = resultExpr.dyn_cast<AffineDimExpr>();
    if (!dimExpr) {
      return false;
    }

    const unsigned dimPos = dimExpr.getPosition();
    if (dimPos < lastDimPos) {
      return true;
    }
    lastDimPos = dimPos;
  }

  // Every dimension appeared in non-decreasing input order.
  return false;
}
/// Default transpose-map filter: accepts every map it is given, so no operand
/// is ever rejected by it.
static bool defaultTransposeMapFilter(AffineMap /*map*/) {
  return true;
}
/// Builds the info for `linalgOp` with the default transpose-map filter, which
/// accepts every map. Delegates to the filter-taking constructor, which stores
/// the filter and runs `computeInfo`.
LinalgOpInfo::LinalgOpInfo(linalg::LinalgOp linalgOp)
    : LinalgOpInfo(linalgOp, defaultTransposeMapFilter) {}
/// Builds the info for `linalgOp`, using the caller-provided
/// `transposeMapFilter` to decide which transposed operands get recorded.
LinalgOpInfo::LinalgOpInfo(linalg::LinalgOp linalgOp,
                           TransposeMapFilter transposeMapFilter)
    : transposeMapFilter(transposeMapFilter) {
  // The filter member is initialized above, before the analysis that reads it
  // is run.
  this->computeInfo(linalgOp);
}
/// Returns the operands of `linalgOp` whose (inverted) indexing map is a
/// transpose according to `isTransposeMap` and is accepted by
/// `transposeMapFilter`. Qualifying inputs are listed first, in operand order,
/// followed by the single output if it qualifies too.
///
/// The result is empty when the op has any reduction loop. Scanning stops
/// early, returning the operands collected so far, as soon as an input
/// indexing map is not a projected permutation (zeros allowed in results). The
/// output is only considered when the op has exactly one init operand.
// TODO(dcaballe):
//   * Consider transpose + reductions.
//   * Consider input and output transposes.
static SmallVector<OpOperand *> computeTransposeInfo(
    LinalgOp linalgOp, TransposeMapFilter transposeMapFilter) {
  SmallVector<OpOperand *> transposeOperands;

  // Reductions are not supported.
  if (linalgOp.getNumReductionLoops() > 0) {
    return transposeOperands;
  }

  for (OpOperand *linalgOperand : linalgOp.getDpsInputOperands()) {
    AffineMap map = linalgOp.getMatchingIndexingMap(linalgOperand);
    if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) {
      return transposeOperands;
    }
    // Inverse map to use transfer op permutation logic.
    AffineMap inverseMap = inverseAndBroadcastProjectedPermutation(map);
    if (isTransposeMap(inverseMap) && transposeMapFilter(inverseMap)) {
      transposeOperands.push_back(linalgOperand);
    }
  }

  // Multiple outputs are not supported yet. Checking the init count before
  // touching init operand 0 also keeps an op without inits from indexing a
  // non-existent operand.
  if (linalgOp.getNumDpsInits() != 1) {
    return transposeOperands;
  }

  // Inverse map to use transfer op permutation logic. `inversePermutation`
  // yields a null map when the indexing map is not invertible; such an output
  // is treated as "not a transpose" instead of being queried.
  OpOperand *initOperand = linalgOp.getDpsInitOperand(0);
  AffineMap outputInversedMap =
      inversePermutation(linalgOp.getMatchingIndexingMap(initOperand));
  if (outputInversedMap && isTransposeMap(outputInversedMap) &&
      transposeMapFilter(outputInversedMap)) {
    transposeOperands.push_back(initOperand);
  }

  return transposeOperands;
}
/// Returns true when `linalgOp` has more than one reduction loop.
// NOTE(review): an op with exactly one reduction loop yields false here, while
// computeTransposeInfo in this file treats `> 0` reduction loops as "has
// reductions". Confirm that the `> 1` threshold is intended.
static bool computeReductionInfo(LinalgOp linalgOp) {
  const auto numReductionLoops = linalgOp.getNumReductionLoops();
  return numReductionLoops > 1;
}
/// Returns whether `linalgOp` reports a dynamic shape, as answered by
/// `LinalgOp::hasDynamicShape`.
static bool computeDynamicInfo(LinalgOp linalgOp) {
  const bool isDynamic = linalgOp.hasDynamicShape();
  return isDynamic;
}
/// Populates every cached trait of this object from `linalgOp`: the reduction
/// and dynamic-shape flags, and the list of transposed operands that pass the
/// stored `transposeMapFilter`.
void LinalgOpInfo::computeInfo(LinalgOp linalgOp) {
  // Scalar traits: each one is derived from `linalgOp` alone.
  dynamicTrait = computeDynamicInfo(linalgOp);
  reductionTrait = computeReductionInfo(linalgOp);

  // Operand-level analysis: which operands are accessed through a transposed
  // indexing map accepted by the filter.
  transposeOperands = computeTransposeInfo(linalgOp, transposeMapFilter);
}
} // namespace iree_compiler
} // namespace mlir