// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Implements logic for lowering StableHLO einsum op to dot_general ops.
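//
// Illustrative example: the equation "abc,acd->abd" lowers to a dot_general
// with batching dimension 'a' (lhs dim 0, rhs dim 0), contracting dimension
// 'c' (lhs dim 2, rhs dim 1), and free result dimensions 'b' and 'd'.
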
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h"
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
#define GEN_PASS_DEF_EINSUMTODOTGENERAL
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Passes.h.inc"
namespace {
struct EinsumToDotGeneralPattern final
: OpRewritePattern<mlir::stablehlo::EinsumOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::stablehlo::EinsumOp einsum,
PatternRewriter &rewriter) const override {
StringRef equation = einsum.getEinsumConfig();
SmallVector<char> lhsTokens, rhsTokens;
SmallVector<char> resultTokens;
size_t index = 0;
enum EquationVariable { kIsLhs, kIsRhs, kIsResult };
EquationVariable currentVariable = kIsLhs;
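    // Tokenize the equation; e.g. for "ij,jk->ik" this produces
    // lhsTokens = {i, j}, rhsTokens = {j, k}, and resultTokens = {i, k}.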
while (index < equation.size()) {
      if (std::isalpha(static_cast<unsigned char>(equation[index]))) {
if (currentVariable == kIsLhs) {
lhsTokens.push_back(equation[index]);
} else if (currentVariable == kIsRhs) {
rhsTokens.push_back(equation[index]);
} else {
resultTokens.push_back(equation[index]);
}
      } else if (equation[index] == ',') {
        currentVariable = kIsRhs;
      } else if ((index < (equation.size() - 1)) &&
                 (equation.substr(index, 2) == "->")) {
currentVariable = kIsResult;
++index;
} else {
return einsum.emitError("unexpected character ")
<< equation.substr(index, 1) << " encountered";
}
++index;
}
auto lhsType = cast<RankedTensorType>(einsum.getLhs().getType());
auto rhsType = cast<RankedTensorType>(einsum.getRhs().getType());
assert(static_cast<int64_t>(lhsTokens.size()) == lhsType.getRank());
assert(static_cast<int64_t>(rhsTokens.size()) == rhsType.getRank());
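    // Classify each operand dimension: a token absent from the result is a
    // contracting dimension, a token present in both the result and the other
    // operand is a batching dimension, and any remaining (free) token is
    // appended to the natural dot_general result.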
auto collectOperandDims =
[resultTokens](
RankedTensorType operandType, SmallVector<char> operandTokens,
SmallVector<char> others, SmallVectorImpl<int64_t> &contractingDims,
SmallVectorImpl<int64_t> &batchingDims,
SmallVector<char> &dotResultTokens,
SmallVector<int64_t> &dotResultShape) {
llvm::SmallDenseSet<char> othersSet(others.begin(), others.end());
llvm::SmallDenseSet<char> resultTokensSet(resultTokens.begin(),
resultTokens.end());
for (auto [idx, token] : llvm::enumerate(operandTokens)) {
bool isResultToken = resultTokensSet.contains(token);
bool isOtherToken = othersSet.contains(token);
if (!isResultToken) {
contractingDims.push_back(idx);
} else if (isOtherToken) {
batchingDims.push_back(idx);
} else {
dotResultTokens.push_back(token);
dotResultShape.push_back(operandType.getShape()[idx]);
}
}
};
// Indices of batch and contracting dims, relative to each operand's
// dimensions.
SmallVector<int64_t> lhsContractingDims, lhsBatchingDims,
rhsContractingDims, rhsBatchingDims;
    // Tokens representing the natural result order of the dot_general op:
    // the free (non-contracting, non-batching) lhs tokens followed by the
    // free rhs tokens, with the batching tokens prepended afterwards.
SmallVector<char> dotResultTokens;
SmallVector<int64_t> dotResultShape;
collectOperandDims(lhsType, lhsTokens, rhsTokens, lhsContractingDims,
lhsBatchingDims, dotResultTokens, dotResultShape);
collectOperandDims(rhsType, rhsTokens, lhsTokens, rhsContractingDims,
rhsBatchingDims, dotResultTokens, dotResultShape);
// Prepend batch tokens.
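    // dot_general results are laid out as [batch dims, lhs free dims, rhs
    // free dims], so each lhs batching token is inserted at the front while
    // preserving the batching dimensions' relative order.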
for (auto [idx, dim] : llvm::enumerate(lhsBatchingDims)) {
char batchingToken = lhsTokens[dim];
int64_t batchingShapeDim = lhsType.getShape()[dim];
dotResultTokens.insert(dotResultTokens.begin() + idx, batchingToken);
dotResultShape.insert(dotResultShape.begin() + idx, batchingShapeDim);
}
// Lowering to dot_general does not support a mismatch between the number
// of result dims and the number of non-contracting dims.
if (dotResultTokens.size() != resultTokens.size()) {
return rewriter.notifyMatchFailure(einsum,
"rank reducing einsum not supported");
}
// Generate a permutation sequence based on result tokens.
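    // If the requested order differs from dot_general's natural order, a
    // transpose is emitted below; e.g. "ij,jk->ki" requires the permutation
    // [1, 0].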
SmallVector<int64_t> resultPerms;
bool isNaturalOrder = true;
for (char resultToken : resultTokens) {
auto foundIt = std::find(dotResultTokens.begin(), dotResultTokens.end(),
resultToken);
if (foundIt == dotResultTokens.end()) {
return rewriter.notifyMatchFailure(
einsum, "result token not found in operands");
}
auto resultIndex = std::distance(dotResultTokens.begin(), foundIt);
if (resultPerms.empty()) {
if (resultIndex != 0) {
isNaturalOrder = false;
}
} else if (resultIndex != (resultPerms.back() + 1)) {
isNaturalOrder = false;
}
resultPerms.push_back(resultIndex);
}
// Emit the dot_general, using its native result ordering.
auto dotGeneralResultType = RankedTensorType::get(
ArrayRef<int64_t>(dotResultShape), lhsType.getElementType());
auto dimNumbers = mlir::stablehlo::DotDimensionNumbersAttr::get(
rewriter.getContext(), lhsBatchingDims, rhsBatchingDims,
lhsContractingDims, rhsContractingDims);
auto dotGeneralOp = mlir::stablehlo::DotGeneralOp::create(
rewriter, einsum.getLoc(), dotGeneralResultType, einsum.getLhs(),
einsum.getRhs(), dimNumbers,
        /*precision_config=*/ArrayAttr{},
        /*algorithm=*/mlir::stablehlo::DotAlgorithmAttr{});
if (isNaturalOrder) {
// The dot_general is already in an appropriate result order.
rewriter.replaceOp(einsum, ValueRange{dotGeneralOp});
} else {
// Generate a transpose.
rewriter.replaceOpWithNewOp<mlir::stablehlo::TransposeOp>(
einsum, dotGeneralOp, rewriter.getDenseI64ArrayAttr(resultPerms));
}
return success();
}
};
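
// Pass that greedily applies the pattern above, rewriting each
// stablehlo.einsum into a stablehlo.dot_general plus, when needed, a
// stablehlo.transpose.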
struct EinsumToDotGeneral final
: impl::EinsumToDotGeneralBase<EinsumToDotGeneral> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePreprocessingEinsumToDotGeneralPatterns(&getContext(), &patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace

void populatePreprocessingEinsumToDotGeneralPatterns(
mlir::MLIRContext *context, RewritePatternSet *patterns) {
patterns->add<EinsumToDotGeneralPattern>(context);
}
} // namespace mlir::iree_compiler::stablehlo