blob: 0db4fafec6bcc0bb59925aa53a72c858f28d3683 [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements logic for lowering StableHLO dot product ops to Linalg dialect.
// These patterns are separated out to their own file to save on the compilation
// times, given that we instantiate a large number of class templates here.
#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
namespace {
// Classifies a stablehlo.dot op by the ranks of its two operands; each
// category maps to a distinct linalg named op in the patterns below.
enum class DotOperationType {
  kVectorDot = 0,  // 1-D x 1-D (lowered via linalg.dot)
  kMatrixVector,   // 2-D x 1-D (lowered via linalg.matvec)
  kVectorMatrix,   // 1-D x 2-D (lowered via linalg.vecmat)
  kMatrixMatrix,   // 2-D x 2-D (lowered via linalg.matmul)
  kUnsupported     // any other rank combination
};
// Determines which linalg named op a stablehlo.dot should lower to, based on
// the ranks of its operands and whether the contracted extents line up.
// Returns kUnsupported if no rank combination matches.
DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) {
  ArrayRef<int64_t> lhsDims =
      cast<ShapedType>(dotOp.getLhs().getType()).getShape();
  ArrayRef<int64_t> rhsDims =
      cast<ShapedType>(dotOp.getRhs().getType()).getShape();
  // Dynamic extents are treated as compatible with anything.
  auto compatible = [](int64_t a, int64_t b) {
    return a == b || ShapedType::isDynamic(a) || ShapedType::isDynamic(b);
  };
  size_t lhsRank = lhsDims.size();
  size_t rhsRank = rhsDims.size();
  if (lhsRank == 1 && rhsRank == 1 && compatible(lhsDims[0], rhsDims[0]))
    return DotOperationType::kVectorDot;
  if (lhsRank == 2 && rhsRank == 1 && compatible(lhsDims[1], rhsDims[0]))
    return DotOperationType::kMatrixVector;
  if (lhsRank == 1 && rhsRank == 2 && compatible(lhsDims[0], rhsDims[0]))
    return DotOperationType::kVectorMatrix;
  if (lhsRank == 2 && rhsRank == 2 && compatible(lhsDims[1], rhsDims[0]))
    return DotOperationType::kMatrixMatrix;
  return DotOperationType::kUnsupported;
}
// Collects the dynamic dimension sizes of the dot op's result, in result-dim
// order, for creating the empty output tensor. The result's row count comes
// from lhs dim 0 and its column count from rhs dim 1, when those dims exist
// for the given operation type.
SmallVector<Value, 2> getDotOpEmptyTensorDynSizes(OpBuilder &b, Location loc,
                                                  Value lhs, Value rhs,
                                                  DotOperationType type) {
  SmallVector<Value, 2> dynSizes;
  // Emits a tensor.dim for `operand`'s dim `dim` iff that dim is dynamic.
  auto appendIfDynamic = [&](Value operand, int64_t dim) {
    if (llvm::cast<ShapedType>(operand.getType()).isDynamicDim(dim))
      dynSizes.push_back(b.create<tensor::DimOp>(loc, operand, dim));
  };
  // kVectorDot yields a scalar (rank-0) result, and kUnsupported never maps
  // to an output tensor here, so neither contributes any dynamic sizes.
  if (type == DotOperationType::kMatrixMatrix ||
      type == DotOperationType::kMatrixVector)
    appendIfDynamic(lhs, 0);
  if (type == DotOperationType::kMatrixMatrix ||
      type == DotOperationType::kVectorMatrix)
    appendIfDynamic(rhs, 1);
  return dynSizes;
}
// Lowers a stablehlo.dot whose operand ranks match `op_type` to the linalg
// named op `LinalgOp`, accumulating into a zero-filled (possibly sparse)
// output tensor. One instantiation exists per supported rank combination.
template <DotOperationType op_type, typename LinalgOp>
struct DotOpConversion final : OpConversionPattern<mlir::stablehlo::DotOp> {
  using OpConversionPattern<mlir::stablehlo::DotOp>::OpConversionPattern;
  using OpAdaptor = mlir::stablehlo::DotOp::Adaptor;
  LogicalResult
  matchAndRewrite(mlir::stablehlo::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(verifyHloOpBufferOrTensorSemantics(op)))
      return failure();
    // Each instantiation handles exactly one rank combination; bail out so a
    // sibling instantiation (or a generic pattern) can match instead.
    if (getDotOperationType(op) != op_type)
      return failure();

    Location loc = op.getLoc();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto resultType =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    SmallVector<Value, 2> dynSizes = getDotOpEmptyTensorDynSizes(
        rewriter, loc, adaptor.getLhs(), adaptor.getRhs(), op_type);
    // Sparse results need a sparse init tensor; dense results a plain one.
    Value init = sparse_tensor::getSparseTensorEncoding(resultType)
                     ? getEmptySparseTensor(rewriter, loc, resultType, dynSizes)
                     : getEmptyTensor(rewriter, loc, resultType, dynSizes);
    Value acc = fillTensorWithZeros(rewriter, loc, init);
    rewriter.replaceOpWithNewOp<LinalgOp>(
        op, TypeRange{resultType},
        ValueRange{adaptor.getLhs(), adaptor.getRhs()}, ValueRange{acc},
        linalg::getPrunedAttributeList(op));
    return success();
  }
};
// Lowers stablehlo.dot_general to linalg.batch_matmul when its dimension
// numbers match the canonical batch-matmul layout: a single batching dim (0
// on both sides), lhs contracting dim 2, and rhs contracting dim 1. All
// other configurations are rejected and handled by DotGeneralOpConversion.
struct DotGeneralBatchMatMulOpConversion final
    : OpConversionPattern<mlir::stablehlo::DotGeneralOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(verifyHloOpBufferOrTensorSemantics(op)))
      return failure();
    if (llvm::cast<RankedTensorType>(op.getType()).getRank() != 3)
      return rewriter.notifyMatchFailure(op, "expected a batch matmul");

    mlir::stablehlo::DotDimensionNumbersAttr dims =
        op.getDotDimensionNumbers();
    // True iff `dimList` contains exactly one entry equal to `expected`.
    auto isExactly = [](ArrayRef<int64_t> dimList, int64_t expected) {
      return dimList.size() == 1 && dimList[0] == expected;
    };
    if (!isExactly(dims.getLhsBatchingDimensions(), 0))
      return rewriter.notifyMatchFailure(
          op, "expected lhs batching dimensions exactly {0}");
    if (!isExactly(dims.getRhsBatchingDimensions(), 0))
      return rewriter.notifyMatchFailure(
          op, "expected rhs batching dimensions exactly {0}");
    if (!isExactly(dims.getLhsContractingDimensions(), 2))
      return rewriter.notifyMatchFailure(
          op, "expected lhs contracting dimensions exactly {2}");
    if (!isExactly(dims.getRhsContractingDimensions(), 1))
      return rewriter.notifyMatchFailure(
          op, "expected rhs contracting dimensions exactly {1}");

    Location loc = op.getLoc();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto resultType =
        cast<ShapedType>(typeConverter->convertType(op.getType()));
    Value init =
        getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
    Value acc = fillTensorWithZeros(rewriter, loc, init);
    rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
        op, /*resultTensorTypes=*/TypeRange{resultType},
        /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        /*outputs=*/ValueRange{acc}, linalg::getPrunedAttributeList(op));
    return success();
  }
};
// Fallback lowering of stablehlo.dot_general to a linalg.generic with
// explicit indexing maps. Handles arbitrary batching/contracting dimension
// arrangements that the specialized batch-matmul pattern rejects.
//
// Loop dimensions of the emitted generic op are numbered as:
//   [0, targetRank)                        -> parallel dims, matching the
//                                             result's dimension order
//                                             (batch, lhs free, rhs free);
//   [targetRank, targetRank+numContracting) -> reduction (contracting) dims.
struct DotGeneralOpConversion final
    : OpConversionPattern<mlir::stablehlo::DotGeneralOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(mlir::stablehlo::DotGeneralOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (failed(verifyHloOpBufferOrTensorSemantics(op))) {
      return failure();
    }
    // Get various dimension iterator information
    mlir::stablehlo::DotDimensionNumbersAttr dimNumbers =
        op.getDotDimensionNumbers();
    ArrayRef<int64_t> lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
    ArrayRef<int64_t> rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
    ArrayRef<int64_t> lhsContractingDims =
        dimNumbers.getLhsContractingDimensions();
    ArrayRef<int64_t> rhsContractingDims =
        dimNumbers.getRhsContractingDimensions();
    // Get shape information and initialize output
    assert(lhsContractingDims.size() == rhsContractingDims.size() &&
           "number of contracting dims must be equal");
    size_t numContracting = lhsContractingDims.size();
    // Convert unsigned to signed. This works because signed and unsigned
    // integer matmul is the same operation in two's complement.
    auto outputType =
        cast<ShapedType>(typeConverter->convertType(op.getType()));
    size_t targetRank = outputType.getRank();
    // One loop per result dim (parallel) plus one per contracting dim
    // (reduction).
    size_t totalLoopCount = numContracting + targetRank;
    int64_t lhsRank =
        llvm::cast<ShapedType>(adaptor.getLhs().getType()).getRank();
    // Count of lhs "free" dims: neither batching nor contracting.
    size_t lhsExtraDims =
        lhsRank - lhsBatchingDims.size() - lhsContractingDims.size();
    int64_t rhsRank =
        llvm::cast<ShapedType>(adaptor.getRhs().getType()).getRank();
    Location loc = op.getLoc();
    Value emptyTensor =
        getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands());
    Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor);
    SmallVector<AffineMap, 3> indexingMaps;
    // Builds the indexing map for one operand: its batching dims map to the
    // leading loop dims, its contracting dims map to the trailing reduction
    // loop dims (offset by targetRank), and any remaining (free) operand
    // dims are assigned loop dims sequentially starting at `extraDims`.
    auto getMap = [&](int64_t rank, ArrayRef<int64_t> batchingDims,
                      ArrayRef<int64_t> contractingDims, size_t extraDims) {
      llvm::SmallVector<AffineExpr> indices(rank);
      for (const auto &i : llvm::enumerate(batchingDims)) {
        indices[i.value()] = rewriter.getAffineDimExpr(i.index());
      }
      for (const auto &i : llvm::enumerate(contractingDims)) {
        indices[i.value()] = rewriter.getAffineDimExpr(i.index() + targetRank);
      }
      for (int i = 0; i < rank; ++i) {
        // Dims not claimed above are free dims; hand out loop dims in order.
        if (!indices[i]) {
          indices[i] = rewriter.getAffineDimExpr(extraDims++);
        }
      }
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
                                            /*symbolCount=*/0, indices,
                                            op->getContext()));
    };
    // lhs free dims start right after the batch dims; rhs free dims follow
    // the lhs free dims in the result, hence the lhsExtraDims offset.
    getMap(lhsRank, lhsBatchingDims, lhsContractingDims,
           lhsBatchingDims.size());
    getMap(rhsRank, rhsBatchingDims, rhsContractingDims,
           rhsBatchingDims.size() + lhsExtraDims);
    {
      // The output map is the identity over the first targetRank loop dims.
      SmallVector<AffineExpr> dimExprs;
      dimExprs.reserve(targetRank);
      for (unsigned i = 0; i < targetRank; ++i)
        dimExprs.push_back(rewriter.getAffineDimExpr(i));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
                                            /*symbolCount=*/0, dimExprs,
                                            op.getContext()));
    }
    Operation *linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/TypeRange{outputType},
        /*inputs=*/ValueRange{adaptor.getLhs(), adaptor.getRhs()},
        /*outputBuffers=*/ValueRange{zeroTensor}, indexingMaps,
        getParallelAndReductionIterators(
            /*nLoops=*/totalLoopCount,
            /*nReduction=*/numContracting),
        // Reuse linalg.matmul's region builder for the scalar
        // multiply-accumulate body.
        [](OpBuilder &b, Location loc, ValueRange) {
          ImplicitLocOpBuilder builder(loc, b);
          linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {});
        },
        linalg::getPrunedAttributeList(op));
    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
} // namespace
namespace detail {
// Registers all dot-product lowering patterns on `patterns`. The rank-
// specialized dot patterns and the canonical batch-matmul pattern get a
// higher benefit so they are tried before the generic dot_general lowering.
void populateStableHloDotProdToLinalgConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    RewritePatternSet *patterns) {
  // Ensure specialized patterns are higher priority than their generic
  // versions.
  const PatternBenefit specializedBenefit(2);
  const PatternBenefit genericBenefit(1);
  patterns->add<
      DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>,
      DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>,
      DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>,
      DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>,
      DotGeneralBatchMatMulOpConversion>(typeConverter, context,
                                         specializedBenefit);
  patterns->add<DotGeneralOpConversion>(typeConverter, context,
                                        genericBenefit);
}
} // namespace detail
} // namespace mlir::iree_compiler::stablehlo