blob: 6f2985abd50b1fa1c82b9f2c8da044d76489e90a [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements logic for lowering StableHLO pointwise ops to Linalg dialect.
// These patterns are separated out to their own file to save on the compilation
// times, given that we instantiate a large number of class templates here.
#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/MapStableHLOToScalarOp.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
namespace {
/// Returns the rank of `v`. The type of `v` is unconditionally cast to
/// `ShapedType`, so callers must only pass shaped (and ranked) values.
int64_t getRank(Value v) {
  auto shapedTy = cast<ShapedType>(v.getType());
  return shapedTy.getRank();
}
/// Returns the largest rank found among `operands`, or 0 when `operands` is
/// empty.
int64_t getMaxRank(ValueRange operands) {
  int64_t highestRank = 0;
  for (Value v : operands) {
    const int64_t rank = getRank(v);
    if (rank > highestRank) {
      highestRank = rank;
    }
  }
  return highestRank;
}
/// Returns true when `v` has rank zero.
bool isScalar(Value v) {
  const int64_t rank = getRank(v);
  return rank == 0;
}
/// Builds the operand list for the scalar computation by walking
/// `scalarInputs` in order: non-null entries are used as-is, and every null
/// entry is replaced by the next not-yet-consumed value from `blockArgs`.
SmallVector<Value> interleaveScalarAndBlockArgs(ValueRange scalarInputs,
                                                ValueRange blockArgs) {
  SmallVector<Value> merged;
  merged.reserve(scalarInputs.size());
  auto nextBlockArg = blockArgs.begin();
  for (Value candidate : scalarInputs) {
    if (!candidate) {
      // Null placeholder: consume one block argument.
      merged.push_back(*nextBlockArg);
      ++nextBlockArg;
      continue;
    }
    merged.push_back(candidate);
  }
  return merged;
}
/// Values produced by `checkOperandsAndResults` on success.
struct PointwiseConversionInfo {
  /// Largest rank among the op's (converted) operands.
  int64_t maxOperandRank{0};
  /// Converted type of the op's first result.
  ShapedType resultType;
};
/// Checks the preconditions for conversion of pointwise HLO ops to linalg.
///
/// The following are checked, in order:
///   * every operand has a ranked shaped type;
///   * every operand is either rank 0 or has the maximum operand rank;
///   * the converted type of the first result is a ranked shaped type whose
///     rank equals the maximum operand rank and whose element type is a
///     signless integer, a float, or a complex type;
///   * the op is not an all-scalar op located in the body of a linalg op.
///
/// Returns the max operand rank and the result type on success. On failure a
/// match failure is reported through `rewriter` and failure is returned.
FailureOr<PointwiseConversionInfo>
checkOperandsAndResults(Operation *op, ValueRange operands,
                        const TypeConverter &typeConverter,
                        ConversionPatternRewriter &rewriter) {
  // `getRank` performs an unchecked cast to `ShapedType` and queries the rank,
  // so reject anything that is not a ranked shaped value before calling it.
  if (!llvm::all_of(operands, [](Value v) {
        auto shapedTy = dyn_cast<ShapedType>(v.getType());
        return shapedTy && shapedTy.hasRank();
      })) {
    return rewriter.notifyMatchFailure(
        op, "Operands must be ranked shaped types.");
  }

  int64_t maxRank = getMaxRank(operands);

  // Apply only if all operands are scalar or have the same rank. Some ops,
  // like `stablehlo.select`, support implicit broadcasting of scalars.
  if (!llvm::all_of(operands, [&](Value v) {
        int64_t r = getRank(v);
        return r == 0 || r == maxRank;
      })) {
    return rewriter.notifyMatchFailure(
        op, "Operands must be of same rank or scalar.");
  }

  // Find result type, if on tensors.
  auto resultTy = dyn_cast_or_null<ShapedType>(
      typeConverter.convertType(op->getResultTypes().front()));

  // Check result type compatibility.
  if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
      !(resultTy.getElementType().isSignlessIntOrFloat() ||
        isa<ComplexType>(resultTy.getElementType()))) {
    return rewriter.notifyMatchFailure(
        op, "mismatched operand/result types or iterator count");
  }

  // All-scalar pointwise ops inside of linalg ops are processed by
  // ScalarHloToArithmeticPattern.
  if (maxRank == 0 && isInBodyOfLinalgOps(op)) {
    return rewriter.notifyMatchFailure(
        op, "all-scalar pointwise op inside of a linalg op body");
  }

  return PointwiseConversionInfo{maxRank, resultTy};
}
/// Converts a HLO operation to a linalg.map op that contains the corresponding
/// scalar operations.
///
/// Operands whose rank equals the maximum operand rank become inputs of the
/// `linalg.map`. All remaining operands are read once with `tensor.extract`
/// ahead of the map and the extracted values are used directly in its body.
/// The match fails if the preconditions in `checkOperandsAndResults` do not
/// hold or if the op cannot be mapped to scalar operations.
template <typename OpTy>
struct PointwiseToLinalgMapConverter final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;

  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto conversionInfo = checkOperandsAndResults(
        op, adaptor.getOperands(), *this->typeConverter, rewriter);
    if (failed(conversionInfo)) {
      return failure();
    }

    int64_t maxRank = conversionInfo->maxOperandRank;
    ShapedType resultTy = conversionInfo->resultType;
    Location loc = op.getLoc();

    // Find input/output values and types.
    Value emptyTensor =
        getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());

    // Mapped inputs are cast to the same shape as the init tensor.
    // Values from scalar inputs are extracted and used directly in the block.
    // `scalarInputs` has one entry per operand; a null entry marks an operand
    // that is supplied by a block argument of the map body instead.
    SmallVector<Value> mappedInputs;
    SmallVector<Value> scalarInputs;
    for (Value input : adaptor.getOperands()) {
      if (getRank(input) == maxRank) {
        mappedInputs.push_back(coerceTensorShape(
            rewriter, loc, cast<TypedValue<ShapedType>>(input),
            cast<ShapedType>(emptyTensor.getType())));
        scalarInputs.push_back(nullptr);
      } else {
        scalarInputs.push_back(rewriter.create<tensor::ExtractOp>(loc, input));
      }
    }

    // The scalar mapping may yield a null value; record that instead of
    // building a `linalg.yield` with a null operand.
    bool scalarMappingFailed = false;
    auto mapOp = rewriter.create<linalg::MapOp>(
        loc, mappedInputs, emptyTensor,
        [&](OpBuilder &b, Location loc, ValueRange args) {
          Value innerResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
              op, getElementTypeOrSelf(emptyTensor),
              interleaveScalarAndBlockArgs(scalarInputs, args), &b);
          if (!innerResult) {
            scalarMappingFailed = true;
            return;
          }
          b.create<linalg::YieldOp>(loc, innerResult);
        },
        linalg::getPrunedAttributeList(op));
    if (scalarMappingFailed) {
      return rewriter.notifyMatchFailure(
          op, "failed to map the op to scalar operations");
    }

    rewriter.replaceOp(op, mapOp->getResults());
    return success();
  }
};
/// Lowers a pointwise HLO operation to a `linalg.generic` op with all-parallel
/// iterators whose body holds the equivalent scalar operations.
///
/// Rank-0 operands are indexed through a map with no results, all other
/// operands and the output through the identity map. The match fails when the
/// preconditions in `checkOperandsAndResults` do not hold or when no scalar
/// result could be produced for the op.
template <typename OpTy>
struct PointwiseToLinalgConverter final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpTy::Adaptor;

  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto info = checkOperandsAndResults(op, adaptor.getOperands(),
                                        *this->typeConverter, rewriter);
    if (failed(info))
      return failure();

    const int64_t maxRank = info->maxOperandRank;
    ShapedType resultTy = info->resultType;
    Location loc = op.getLoc();

    // Gather the operands and materialize the destination tensor.
    ValueRange inputs = adaptor.getOperands();
    Value output =
        getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands());

    // One indexing map per input, followed by the identity map for the output.
    AffineMap scalarMap = AffineMap::get(maxRank, 0, rewriter.getContext());
    AffineMap idMap = rewriter.getMultiDimIdentityMap(maxRank);
    SmallVector<AffineMap> maps;
    for (Value input : inputs) {
      if (isScalar(input)) {
        maps.push_back(scalarMap);
      } else {
        maps.push_back(idMap);
      }
    }
    maps.push_back(idMap);

    // Emit the `linalg.generic`; the body builder raises the flag below when
    // no scalar result could be produced.
    bool scalarMappingFailed = false;
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy ? resultTy : TypeRange{}, inputs, output, maps,
        getNParallelLoopsAttrs(maxRank),
        [&](OpBuilder &nestedBuilder, Location /*nested_loc*/,
            ValueRange args) {
          Type innerResultTy = getElementTypeOrSelf(output);
          // Only the leading block arguments correspond to the inputs.
          auto argvec = llvm::to_vector<2>(args.take_front(inputs.size()));
          Value semiring = preSparsify(op, argvec, innerResultTy, &rewriter);
          Value scalarResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
              op, innerResultTy, argvec, &rewriter);
          if (!scalarResult) {
            scalarMappingFailed = true;
            return;
          }
          scalarResult = postSparsify(op, semiring, scalarResult, &rewriter);
          nestedBuilder.create<linalg::YieldOp>(loc, scalarResult);
        },
        linalg::getPrunedAttributeList(op));
    if (scalarMappingFailed)
      return failure();

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
} // namespace
namespace detail {
void populatePointwiseStableHloToLinalgConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet *patterns, bool enablePrimitiveOps) {
if (enablePrimitiveOps) {
patterns->add<
PointwiseToLinalgMapConverter<mlir::stablehlo::AbsOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::AddOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::AndOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::Atan2Op>,
PointwiseToLinalgMapConverter<mlir::stablehlo::BitcastConvertOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::CbrtOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::CeilOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ClampOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ClzOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::CompareOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ComplexOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ConvertOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::CosineOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::DivOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ExpOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::Expm1Op>,
PointwiseToLinalgMapConverter<mlir::stablehlo::FloorOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ImagOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::IsFiniteOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::Log1pOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::LogOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::LogisticOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::MaxOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::MinOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::MulOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::NegOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::NotOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::OrOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::PopulationCountOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::PowOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::RealOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ReducePrecisionOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::RemOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::RoundNearestEvenOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::RoundOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::RsqrtOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::SelectOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ShiftLeftOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ShiftRightArithmeticOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::ShiftRightLogicalOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::SignOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::SineOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::SqrtOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::SubtractOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::TanhOp>,
PointwiseToLinalgMapConverter<mlir::stablehlo::XorOp>>(typeConverter,
context);
return;
}
patterns
->add<PointwiseToLinalgConverter<mlir::stablehlo::AbsOp>,
PointwiseToLinalgConverter<mlir::stablehlo::AddOp>,
PointwiseToLinalgConverter<mlir::stablehlo::AndOp>,
PointwiseToLinalgConverter<mlir::stablehlo::Atan2Op>,
PointwiseToLinalgConverter<mlir::stablehlo::BitcastConvertOp>,
PointwiseToLinalgConverter<mlir::stablehlo::CbrtOp>,
PointwiseToLinalgConverter<mlir::stablehlo::CeilOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ClampOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ClzOp>,
PointwiseToLinalgConverter<mlir::stablehlo::CompareOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ComplexOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ConvertOp>,
PointwiseToLinalgConverter<mlir::stablehlo::CosineOp>,
PointwiseToLinalgConverter<mlir::stablehlo::DivOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ExpOp>,
PointwiseToLinalgConverter<mlir::stablehlo::Expm1Op>,
PointwiseToLinalgConverter<mlir::stablehlo::FloorOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ImagOp>,
PointwiseToLinalgConverter<mlir::stablehlo::IsFiniteOp>,
PointwiseToLinalgConverter<mlir::stablehlo::Log1pOp>,
PointwiseToLinalgConverter<mlir::stablehlo::LogOp>,
PointwiseToLinalgConverter<mlir::stablehlo::LogisticOp>,
PointwiseToLinalgConverter<mlir::stablehlo::MaxOp>,
PointwiseToLinalgConverter<mlir::stablehlo::MinOp>,
PointwiseToLinalgConverter<mlir::stablehlo::MulOp>,
PointwiseToLinalgConverter<mlir::stablehlo::NegOp>,
PointwiseToLinalgConverter<mlir::stablehlo::NotOp>,
PointwiseToLinalgConverter<mlir::stablehlo::OrOp>,
PointwiseToLinalgConverter<mlir::stablehlo::PopulationCountOp>,
PointwiseToLinalgConverter<mlir::stablehlo::PowOp>,
PointwiseToLinalgConverter<mlir::stablehlo::RealOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ReducePrecisionOp>,
PointwiseToLinalgConverter<mlir::stablehlo::RemOp>,
PointwiseToLinalgConverter<mlir::stablehlo::RoundNearestEvenOp>,
PointwiseToLinalgConverter<mlir::stablehlo::RoundOp>,
PointwiseToLinalgConverter<mlir::stablehlo::RsqrtOp>,
PointwiseToLinalgConverter<mlir::stablehlo::SelectOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ShiftLeftOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ShiftRightArithmeticOp>,
PointwiseToLinalgConverter<mlir::stablehlo::ShiftRightLogicalOp>,
PointwiseToLinalgConverter<mlir::stablehlo::SignOp>,
PointwiseToLinalgConverter<mlir::stablehlo::SineOp>,
PointwiseToLinalgConverter<mlir::stablehlo::SqrtOp>,
PointwiseToLinalgConverter<mlir::stablehlo::SubtractOp>,
PointwiseToLinalgConverter<mlir::stablehlo::TanhOp>,
PointwiseToLinalgConverter<mlir::stablehlo::XorOp>>(typeConverter,
context);
}
} // namespace detail
} // namespace mlir::iree_compiler::stablehlo