blob: 7a62ec2aca80b2e1227f1f97b88adddb9858cea9 [file] [log] [blame]
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements logic for lowering scalar StableHLO ops to arith dialect.
#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
namespace {
template <typename OpTy>
struct ScalarHloToFuncPatterns final : OpConversionPattern<OpTy> {
ScalarHloToFuncPatterns(TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern<OpTy>(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isa<mlir::func::FuncOp>(op->getParentOp())) {
return rewriter.notifyMatchFailure(op,
"Return must be inside a function");
}
mlir::Operation::operand_range operands = op.getOperands();
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, operands);
return success();
}
};
template <typename OpTy>
struct ScalarHloToArithmeticPattern final : OpConversionPattern<OpTy> {
ScalarHloToArithmeticPattern(
TypeConverter &typeConverter, MLIRContext *context,
llvm::function_ref<bool(Operation *)> filterFn = nullptr,
PatternBenefit benefit = 1)
: OpConversionPattern<OpTy>(typeConverter, context, benefit),
filterFn(filterFn) {}
LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (filterFn && !filterFn(op))
return failure();
auto isScalar = [](Value v) {
return cast<ShapedType>(v.getType()).getRank() == 0;
};
if (!llvm::all_of(adaptor.getOperands(), isScalar))
return rewriter.notifyMatchFailure(op, "All operands must be scalar.");
Location loc = op.getLoc();
auto resultTy = dyn_cast_or_null<ShapedType>(
this->getTypeConverter()->convertType(op->getResultTypes().front()));
if (!resultTy)
return failure();
SmallVector<Value> operands;
for (Value operand : adaptor.getOperands()) {
operands.push_back(
rewriter.create<tensor::ExtractOp>(loc, operand, ValueRange()));
}
Value scalarResult = mlir::stablehlo::StableHloOpToStdScalarOp::mapOp(
op, resultTy.getElementType(), operands, &rewriter);
if (!scalarResult)
return failure();
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, resultTy,
scalarResult);
return success();
}
private:
llvm::function_ref<bool(Operation *)> filterFn;
};
} // namespace
namespace detail {
void populateScalarHloToArithConversionPatterns(
MLIRContext *context, TypeConverter &typeConverter,
RewritePatternSet *patterns,
llvm::function_ref<bool(Operation *)> filterFn) {
// TODO(#12678): Handle the XLA rng op.
patterns->add<
ScalarHloToArithmeticPattern<mlir::stablehlo::AbsOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::AddOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::AndOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::Atan2Op>,
ScalarHloToArithmeticPattern<mlir::stablehlo::BitcastConvertOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::CbrtOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::CeilOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ClampOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ClzOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::CompareOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ComplexOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ConvertOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::CosineOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::DivOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ExpOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::Expm1Op>,
ScalarHloToArithmeticPattern<mlir::stablehlo::FloorOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ImagOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::IsFiniteOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::Log1pOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::LogOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::LogisticOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::MaxOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::MinOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::MulOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::NegOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::NotOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::OrOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::PopulationCountOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::PowOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::RealOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ReducePrecisionOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::RemOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::RoundNearestEvenOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::RoundOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::RsqrtOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::SelectOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ShiftLeftOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ShiftRightArithmeticOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::ShiftRightLogicalOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::SignOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::SineOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::SqrtOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::SubtractOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::TanhOp>,
ScalarHloToArithmeticPattern<mlir::stablehlo::XorOp>>(typeConverter,
context, filterFn);
patterns->add<ScalarHloToFuncPatterns<mlir::stablehlo::ReturnOp>>(
typeConverter, context);
}
} // namespace detail
} // namespace mlir::iree_compiler::stablehlo