blob: 0180abbd0749916bf897e1c1dd49a25842b9d389 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implements IREE-specific logic for lowering StableHLO dialect to
// IREE dialects: Linalg, Arith, Math, Tensor, Util, ML Program, etc.
#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/PassDetail.h"
#include "compiler/plugins/input/StableHLO/Conversion/Passes.h"
#include "compiler/plugins/input/StableHLO/Conversion/Preprocessing/Rewriters.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "compiler/plugins/input/StableHLO/Conversion/TypeConversion.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Utils/ConversionUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
namespace mlir::iree_compiler::stablehlo {
#define GEN_PASS_DEF_CONVERTSTABLEHLOTOIREEINPUTDIALECTS
#include "compiler/plugins/input/StableHLO/Conversion/Passes.h.inc"
namespace {
/// Converts stablehlo.concatenate to a tensor.empty op followed by one
/// tensor.insert_slice per operand, each inserted at the running offset along
/// the concatenation dimension.
struct ConcatenateOpConversion final
    : OpConversionPattern<mlir::stablehlo::ConcatenateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::ConcatenateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType =
        getTypeConverter()->convertType<RankedTensorType>(op.getType());
    if (!resultType || !resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected static shape for output");
    }

    Location loc = op.getLoc();
    uint64_t dim = op.getDimension();
    int64_t rank = resultType.getRank();

    // Offsets/strides are 0/1 in every dimension and sizes follow the first
    // operand outside the concatenation dimension; only entry `dim` changes
    // per operand below. The 0/1 constants are shared across dimensions.
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 3> offsets(rank, zero);
    SmallVector<Value, 3> strides(rank, one);
    SmallVector<Value, 3> sizes;
    sizes.reserve(rank);
    for (int64_t i = 0; i < rank; ++i) {
      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          loc, adaptor.getOperands()[0], i));
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType());

    // Index constants become static attributes; anything else stays dynamic.
    auto toOpFoldResult = [](Value v) -> OpFoldResult {
      auto constOp = v.getDefiningOp<arith::ConstantIndexOp>();
      if (!constOp)
        return v;
      return constOp.getValue();
    };

    Value accBound = zero;
    for (Value arg : adaptor.getOperands()) {
      offsets[dim] = accBound;
      sizes[dim] = rewriter.createOrFold<tensor::DimOp>(loc, arg, dim);
      result = rewriter.create<tensor::InsertSliceOp>(
          loc, arg, result, llvm::map_to_vector(offsets, toOpFoldResult),
          llvm::map_to_vector(sizes, toOpFoldResult),
          llvm::map_to_vector(strides, toOpFoldResult));
      // Fold the running offset so it remains a constant (and therefore a
      // static insert_slice offset) when the operand sizes are static.
      accBound =
          rewriter.createOrFold<arith::AddIOp>(loc, accBound, sizes[dim]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Creates a constant DFT coefficient matrix, following the definition at
/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform.
///
/// `matrixType` must be 2-D with shape [N, K]. Entry (i, j) is
/// cos(2*pi*i*j/N) when `isRealPart` is true, and -sin(2*pi*i*j/N) otherwise.
/// The coefficients use the element type of `matrixType`, so the constant
/// matches whatever float type (f32, f64, ...) the matmul operand has.
Value getDFTMatmulCoeff(OpBuilder &b, Location loc, RankedTensorType matrixType,
                        bool isRealPart) {
  assert(matrixType.getRank() == 2 && "expected 2D matrix");
  Type elementType = matrixType.getElementType();
  int64_t numRows = matrixType.getDimSize(0);
  int64_t numCols = matrixType.getDimSize(1);
  // scale = 2 * pi / N
  double scale = 2 * M_PI / numRows;
  SmallVector<Attribute> values;
  values.reserve(numRows * numCols);
  for (int64_t i = 0; i < numRows; ++i) {
    for (int64_t j = 0; j < numCols; ++j) {
      double v = scale * i * j;
      v = isRealPart ? cos(v) : -sin(v);
      // Build the attribute in the matrix element type; a hard-coded f32
      // attribute would not match f64/f16 matrices.
      values.push_back(b.getFloatAttr(elementType, v));
    }
  }
  return b.create<arith::ConstantOp>(
      loc, matrixType, DenseFPElementsAttr::get(matrixType, values));
}
/// Builds `lhs x rhs` as a Linalg op on tensors, accumulating into a
/// zero-filled tensor of `resultType`. A rank-1 `lhs` produces linalg.vecmat
/// and a rank-2 `lhs` produces linalg.matmul. Any other rank is unsupported:
/// it asserts, and in release builds a null Value is returned.
///
/// `b` is taken by reference. A ConversionPatternRewriter passed in is then
/// used directly, instead of being sliced into a plain OpBuilder copy.
Value createLinalgMatmulOnTensors(OpBuilder &b, Location loc,
                                  RankedTensorType resultType, Value lhs,
                                  Value rhs) {
  Type elementType = resultType.getElementType();
  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
  Value emptyTensor = b.create<mlir::tensor::EmptyOp>(
      loc, resultType.getShape(), elementType,
      /*dyn_size=*/ValueRange{});
  Value zeroTensor =
      b.create<linalg::FillOp>(loc, zero, emptyTensor).getResult(0);
  switch (llvm::cast<RankedTensorType>(lhs.getType()).getRank()) {
  case 1:
    return b
        .create<linalg::VecmatOp>(loc, TypeRange{resultType},
                                  ValueRange{lhs, rhs}, ValueRange{zeroTensor})
        .getResult(0);
  case 2:
    return b
        .create<linalg::MatmulOp>(loc, TypeRange{resultType},
                                  ValueRange{lhs, rhs}, ValueRange{zeroTensor})
        .getResult(0);
  default:
    assert(false && "unhandled matmul type");
    return Value();
  }
}
/// Converts a 1-D stablehlo.fft of type RFFT to Linalg ops.
///
/// The transform over the innermost dimension (N samples) is lowered to two
/// matmuls of the real input against constant [N, fft_length / 2 + 1]
/// coefficient matrices: cosine for the real part and negated sine for the
/// imaginary part. The two results are packed back with stablehlo.complex.
struct FftOpConversion final : OpConversionPattern<mlir::stablehlo::FftOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::FftOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getFftType() != mlir::stablehlo::FftType::RFFT) {
      return rewriter.notifyMatchFailure(
          op, "non RFFT types are not supported yet");
    }
    // Only the innermost dimension is transformed below. Without this check,
    // a multi-dimensional RFFT (more than one fft_length entry) would be
    // silently lowered as a 1-D transform and produce wrong values.
    auto fftLengths = op.getFftLength();
    if (fftLengths.size() != 1) {
      return rewriter.notifyMatchFailure(op, "only 1D RFFT is supported");
    }
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getOperand().getType());
    if (!inputType || !inputType.hasStaticShape() || inputType.getRank() < 1 ||
        inputType.getRank() > 2) {
      return rewriter.notifyMatchFailure(op, "only static 1D or 2D dft ops");
    }
    auto opResultType = dyn_cast<RankedTensorType>(op.getType());
    if (!opResultType) {
      return rewriter.notifyMatchFailure(op, "expected ranked result type");
    }

    int64_t rank = inputType.getRank();
    // Number of input samples along the transformed (innermost) dimension.
    int64_t n = inputType.getDimSize(rank - 1);
    // Number of output coefficients kept for a real-input FFT.
    int64_t fftLength = fftLengths.front() / 2 + 1;
    Location loc = op.getLoc();
    auto matrixType =
        RankedTensorType::get({n, fftLength}, inputType.getElementType());
    auto resultType = RankedTensorType::get(opResultType.getShape(),
                                            inputType.getElementType());

    Value realMatrix =
        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
    Value real = createLinalgMatmulOnTensors(rewriter, loc, resultType,
                                             adaptor.getOperand(), realMatrix);
    Value imagMatrix =
        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
    Value imag = createLinalgMatmulOnTensors(rewriter, loc, resultType,
                                             adaptor.getOperand(), imagMatrix);

    // Pack the results back to mlir::stablehlo::ComplexOp.
    rewriter.replaceOpWithNewOp<mlir::stablehlo::ComplexOp>(op, op.getType(),
                                                            real, imag);
    return success();
  }
};
/// Lowers stablehlo.optimization_barrier by wrapping each converted operand in
/// its own util.optimization_barrier op. The original op is replaced by the
/// resulting values, in operand order.
struct OptimizationBarrierOpConversion final
    : OpConversionPattern<mlir::stablehlo::OptimizationBarrierOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::OptimizationBarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value> barrierValues =
        llvm::map_to_vector(adaptor.getOperands(), [&](Value input) -> Value {
          auto barrierOp =
              rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, input);
          return barrierOp.getResult(0);
        });
    rewriter.replaceOp(op, barrierValues);
    return success();
  }
};
// Returns true when no attribute in `attrs` is a frontend-only attribute that
// the IREE input dialects cannot carry. Currently that is just
// `tf.aliasing_output`, which rewriteFuncAttrs maps to `iree.abi.output`.
static bool isValidFuncAttr(DictionaryAttr attrs) {
  // TODO: switch to using a dialect-based exclusion list or some other way that
  // is not a big string table.
  return llvm::none_of(attrs, [](NamedAttribute namedAttr) {
    return namedAttr.getName() == "tf.aliasing_output";
  });
}
// Records the pre-conversion type of every argument and result whose type was
// changed by conversion, as an `iree.abi.encoding` TypeAttr on that
// argument/result. Unchanged arguments/results are left alone.
static void setFuncEncodings(func::FuncOp funcOp, FunctionType oldFuncType,
                             FunctionType newFuncType) {
  StringAttr encodingAttrName =
      StringAttr::get(funcOp.getContext(), "iree.abi.encoding");

  // Arguments.
  for (auto [argIndex, before, after] :
       llvm::enumerate(oldFuncType.getInputs(), newFuncType.getInputs())) {
    if (before == after)
      continue;
    funcOp.setArgAttr(argIndex, encodingAttrName, TypeAttr::get(before));
  }

  // Results.
  for (auto [resultIndex, before, after] :
       llvm::enumerate(oldFuncType.getResults(), newFuncType.getResults())) {
    if (before == after)
      continue;
    funcOp.setResultAttr(resultIndex, encodingAttrName, TypeAttr::get(before));
  }
}
// Rewrites per-argument and per-result attributes produced by HLO-based
// frontends into their IREE equivalents. `tf.aliasing_output = N` becomes an
// index-typed `iree.abi.output = N`. All other attributes are kept as-is.
static void rewriteFuncAttrs(func::FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  Type indexTy = IndexType::get(ctx);
  StringAttr ireeOutputAttrName = StringAttr::get(ctx, "iree.abi.output");
  StringAttr tfAliasAttrName = StringAttr::get(ctx, "tf.aliasing_output");

  // Returns a rewritten copy of one argument/result attribute dictionary.
  auto convertDict = [&](DictionaryAttr dict) -> DictionaryAttr {
    SmallVector<NamedAttribute> converted;
    converted.reserve(dict.size());
    for (NamedAttribute namedAttr : dict) {
      if (namedAttr.getName() != tfAliasAttrName) {
        converted.push_back(namedAttr);
        continue;
      }
      int64_t outputIndex =
          llvm::cast<IntegerAttr>(namedAttr.getValue()).getInt();
      converted.push_back(
          {ireeOutputAttrName, IntegerAttr::get(indexTy, outputIndex)});
    }
    return DictionaryAttr::get(ctx, converted);
  };

  SmallVector<DictionaryAttr> argDicts;
  funcOp.getAllArgAttrs(argDicts);
  for (DictionaryAttr &dict : argDicts)
    dict = convertDict(dict);
  funcOp.setAllArgAttrs(argDicts);

  SmallVector<DictionaryAttr> resultDicts;
  funcOp.getAllResultAttrs(resultDicts);
  for (DictionaryAttr &dict : resultDicts)
    dict = convertDict(dict);
  funcOp.setAllResultAttrs(resultDicts);
}
// Converts func.func ops in place so that their signature and entry-block
// argument types are legal under the type converter. As part of the same
// modification, frontend arg/result attributes are rewritten
// (rewriteFuncAttrs). Every argument/result whose type changed gets an
// `iree.abi.encoding` attribute holding its original type (setFuncEncodings).
struct BuiltinFuncOpPattern final : OpConversionPattern<func::FuncOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(func::FuncOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const TypeConverter &typeConverter = *getTypeConverter();
    FunctionType oldFuncType = srcOp.getFunctionType();

    // Convert function arguments.
    TypeConverter::SignatureConversion signatureConversion(
        srcOp.getNumArguments());
    for (auto [idx, inputTy] : llvm::enumerate(oldFuncType.getInputs())) {
      if (failed(typeConverter.convertSignatureArg(idx, inputTy,
                                                   signatureConversion))) {
        return rewriter.notifyMatchFailure(srcOp, "argument failed to convert");
      }
    }

    // Convert function results.
    SmallVector<Type> convertedResultTypes;
    if (failed(typeConverter.convertTypes(oldFuncType.getResults(),
                                          convertedResultTypes))) {
      return rewriter.notifyMatchFailure(srcOp, "results failed to convert");
    }

    // Create new function type with converted argument and result types.
    auto newFuncType = mlir::FunctionType::get(
        srcOp.getContext(), signatureConversion.getConvertedTypes(),
        convertedResultTypes);

    // Update the function in place.
    rewriter.startOpModification(srcOp);
    srcOp.setType(newFuncType);
    rewriteFuncAttrs(srcOp);
    setFuncEncodings(srcOp, oldFuncType, newFuncType);

    // Tell the rewriter to convert the region signature.
    if (failed(rewriter.convertRegionTypes(&srcOp.getBody(), typeConverter,
                                           &signatureConversion))) {
      // Every startOpModification must be paired with finalizeOpModification
      // or cancelOpModification. Returning with the modification still open
      // violates the rewriter protocol.
      rewriter.cancelOpModification(srcOp);
      return rewriter.notifyMatchFailure(srcOp,
                                         "region types failed to convert");
    }
    rewriter.finalizeOpModification(srcOp);
    return success();
  }
};
// Recreates tensor.empty ops whose result type is changed by the type
// converter. The new op uses the converted tensor type (shape, element type,
// and encoding) and the adaptor's converted dynamic-size operands.
struct TensorEmptyPattern final : OpConversionPattern<tensor::EmptyOp> {
  using OpConversionPattern<tensor::EmptyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto oldType = cast<ShapedType>(op.getType());
    Type newType = getTypeConverter()->convertType(oldType);
    if (newType == oldType)
      return failure();
    if (!newType)
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    auto newTensorType = dyn_cast<RankedTensorType>(newType);
    if (!newTensorType) {
      return rewriter.notifyMatchFailure(op,
                                         "expected ranked tensor result type");
    }
    // Build directly from the converted type so the new op's type matches what
    // the legality check expects. Use adaptor operands, as required inside a
    // conversion pattern.
    rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
        op, newTensorType.getShape(), newTensorType.getElementType(),
        adaptor.getDynamicSizes(), newTensorType.getEncoding());
    return success();
  }
};
// Updates an ml_program.global in place when the type converter changes its
// type. Any initial value attribute is converted to match the new type.
struct GlobalOpPattern final : OpConversionPattern<ml_program::GlobalOp> {
  using OpConversionPattern<ml_program::GlobalOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ml_program::GlobalOp globalOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const TypeConverter &converter = *getTypeConverter();
    Type currentType = globalOp.getType();
    Type convertedType = converter.convertType(currentType);
    if (convertedType == currentType)
      return failure();
    if (!convertedType) {
      return rewriter.notifyMatchFailure(globalOp,
                                         "result type conversion failed");
    }

    rewriter.modifyOpInPlace(globalOp, [&] {
      globalOp.setType(convertedType);
      Attribute initialValue = globalOp.getValueAttr();
      if (!initialValue)
        return;
      globalOp.setValueAttr(
          convertAttribute(globalOp.getLoc(), initialValue, converter));
    });
    return success();
  }
};
// Generic conversion that recreates an op (of type T, or named `rootName`) with
// converted operands, converted result types, and converted region block
// argument types. Attributes and successors are copied unchanged.
template <typename T>
struct GenericTypeConvert final : ConversionPattern {
  GenericTypeConvert(StringRef rootName, TypeConverter &converter,
                     MLIRContext *context, PatternBenefit benefit = 0)
      : ConversionPattern(converter, rootName, benefit, context) {}
  GenericTypeConvert(TypeConverter &converter, MLIRContext *context,
                     PatternBenefit benefit = 0)
      : ConversionPattern(converter, T::getOperationName(), benefit, context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<NamedAttribute> newAttr;
    llvm::append_range(newAttr, op->getAttrs());
    llvm::SmallVector<Type> newResults;
    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
                                                newResults))) {
      return rewriter.notifyMatchFailure(op, "result type conversion failed");
    }
    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                         newResults, newAttr, op->getSuccessors());
    for (Region &r : op->getRegions()) {
      // Compute the block-argument conversion before moving the region, so a
      // failure here does not leave the region already moved.
      TypeConverter::SignatureConversion result(r.getNumArguments());
      if (failed(getTypeConverter()->convertSignatureArgs(r.getArgumentTypes(),
                                                          result))) {
        return rewriter.notifyMatchFailure(op,
                                           "argument type conversion failed");
      }
      Region *newRegion = state.addRegion();
      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
      // A region may be empty (e.g. an scf.if without an else branch). There
      // is no entry block to convert then, and front() would be invalid.
      if (!newRegion->empty())
        rewriter.applySignatureConversion(&newRegion->front(), result);
    }
    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
// Argument materialization callback. It wraps a single non-shaped (scalar)
// value into a 0-d tensor using tensor.from_elements. For inputs that are
// already shaped it declines by returning std::nullopt.
std::optional<Value> scalarToTensor(OpBuilder &builder, Type /*type*/,
                                    ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  Value scalar = inputs.front();
  Type scalarType = scalar.getType();
  if (isa<ShapedType>(scalarType))
    return std::nullopt;

  auto zeroDTensorType = RankedTensorType::get({}, scalarType);
  auto fromElementsOp =
      builder.create<tensor::FromElementsOp>(loc, zeroDTensorType, scalar);
  return fromElementsOp.getResult();
}
// Strips attributes from common StableHLO frontends (JAX, TF, etc.) that are
// not used after conversion into the IREE input dialects. Any attribute whose
// name starts with "mhlo.", "jax.", or "tf." is removed from:
//  - the module's dialect attributes,
//  - the dialect attributes of each top-level callable op, and
//  - the per-argument/per-result attribute dictionaries of each func.func.
// Leaving these attributes in place is confusing, because they can become
// inconsistent during later conversions or leak frontend details further
// down the pipeline than they should.
static void stripFrontendAttrs(mlir::ModuleOp moduleOp) {
  auto isFrontendAttr = [](NamedAttribute namedAttr) {
    StringRef name = namedAttr.getName().getValue();
    return name.starts_with("mhlo.") || name.starts_with("jax.") ||
           name.starts_with("tf.");
  };

  // Drops frontend attributes from an op's dialect (discardable) attributes.
  auto stripOpAttrs = [&](Operation *op) {
    SmallVector<NamedAttribute> kept;
    for (NamedAttribute namedAttr : op->getDialectAttrs()) {
      if (isFrontendAttr(namedAttr))
        continue;
      kept.push_back(namedAttr);
    }
    op->setDialectAttrs(kept);
  };

  // Appends a filtered copy of each dictionary in `oldDicts` to `newDicts`.
  // Returns false, appending nothing, when `oldDicts` is null.
  auto stripAttrDicts = [&](ArrayAttr oldDicts,
                            SmallVectorImpl<DictionaryAttr> &newDicts) {
    if (!oldDicts)
      return false;
    MLIRContext *ctx = oldDicts.getContext();
    for (DictionaryAttr oldDict : oldDicts.getAsRange<DictionaryAttr>()) {
      SmallVector<NamedAttribute> kept;
      kept.reserve(oldDict.size());
      for (NamedAttribute namedAttr : oldDict) {
        if (isFrontendAttr(namedAttr))
          continue;
        kept.push_back(namedAttr);
      }
      newDicts.push_back(DictionaryAttr::get(ctx, kept));
    }
    return true;
  };

  stripOpAttrs(moduleOp);
  for (mlir::CallableOpInterface callableOp :
       moduleOp.getOps<mlir::CallableOpInterface>()) {
    stripOpAttrs(callableOp);
    auto funcOp = dyn_cast<func::FuncOp>(callableOp.getOperation());
    if (!funcOp)
      continue;
    SmallVector<DictionaryAttr> argDicts;
    if (stripAttrDicts(funcOp.getAllArgAttrs(), argDicts))
      funcOp.setAllArgAttrs(argDicts);
    SmallVector<DictionaryAttr> resultDicts;
    if (stripAttrDicts(funcOp.getAllResultAttrs(), resultDicts))
      funcOp.setAllResultAttrs(resultDicts);
  }
}
/// Pass that lowers StableHLO/CHLO to IREE input dialects. It then:
///  - rewrites structural ops (functions, globals, control flow, terminators)
///    so their types are legal under the StableHLO-to-Linalg type converter,
///  - erases unused linalg operands/results, and
///  - strips frontend-specific attributes.
struct ConvertStableHloToIreeInputDialects final
    : impl::ConvertStableHloToIreeInputDialectsBase<
          ConvertStableHloToIreeInputDialects> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<
        IREE::Flow::FlowDialect, IREE::Util::UtilDialect, linalg::LinalgDialect,
        arith::ArithDialect, tensor::TensorDialect, shape::ShapeDialect,
        math::MathDialect, memref::MemRefDialect, complex::ComplexDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    std::unique_ptr<TypeConverter> typeConverter =
        createStableHloToLinalgTypeConverter();
    typeConverter->addArgumentMaterialization(scalarToTensor);

    // Run stablehlo canonicalization patterns with a high benefit to avoid some
    // expensive expansions.
    populateCanonicalizationPatterns(context, &patterns, /*benefit=*/1024);

    populateStableHloToLinalgOnTensorsConversionPatterns(
        context, *typeConverter, &patterns);
    populateStableHloCollectivesConversionPatterns(context, *typeConverter,
                                                   &patterns);

    // TODO(#12678): Handle remaining complex ops.

    // TODO(*): expose patterns that do this much better from
    // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp

    // Structural patterns (functions, cfg, terminators).
    patterns.add<BuiltinFuncOpPattern>(*typeConverter, context);
    patterns.add<GlobalOpPattern, TensorEmptyPattern>(*typeConverter, context);
    // One generic type-conversion pattern per op kind.
    patterns.add<
        GenericTypeConvert<cf::CondBranchOp>, GenericTypeConvert<cf::BranchOp>,
        GenericTypeConvert<func::ReturnOp>, GenericTypeConvert<func::CallOp>,
        GenericTypeConvert<ml_program::GlobalLoadOp>,
        GenericTypeConvert<ml_program::GlobalLoadConstOp>,
        GenericTypeConvert<ml_program::GlobalStoreOp>,
        GenericTypeConvert<scf::ForOp>, GenericTypeConvert<scf::IfOp>,
        GenericTypeConvert<scf::YieldOp>, GenericTypeConvert<scf::ConditionOp>,
        GenericTypeConvert<scf::WhileOp>,
        GenericTypeConvert<tensor::FromElementsOp>,
        GenericTypeConvert<tensor::CollapseShapeOp>,
        GenericTypeConvert<tensor::ExpandShapeOp>,
        GenericTypeConvert<arith::IndexCastUIOp>,
        GenericTypeConvert<arith::SelectOp>>(*typeConverter, context);

    ConversionTarget target(*context);
    auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); };
    // An op is legally typed when all of its operand and result types are.
    auto isLegallyTypedOp = [&](Operation *op) -> bool {
      for (Type type : op->getResultTypes()) {
        if (isIllegalType(type))
          return false;
      }
      for (Type type : op->getOperandTypes()) {
        if (isIllegalType(type))
          return false;
      }
      return true;
    };

    target.addIllegalDialect<mlir::chlo::ChloDialect>();
    target.addIllegalDialect<mlir::stablehlo::StablehloDialect>();

    // Functions must have legal types and no frontend-only arg/result
    // attributes.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
      if (auto attrs = funcOp.getAllArgAttrs()) {
        if (!llvm::all_of(attrs.getAsRange<DictionaryAttr>(),
                          isValidFuncAttr)) {
          return false;
        }
      }
      if (auto attrs = funcOp.getAllResultAttrs()) {
        if (!llvm::all_of(attrs.getAsRange<DictionaryAttr>(),
                          isValidFuncAttr)) {
          return false;
        }
      }
      for (Type type : funcOp.getFunctionType().getInputs()) {
        if (isIllegalType(type))
          return false;
      }
      for (Type type : funcOp.getFunctionType().getResults()) {
        if (isIllegalType(type))
          return false;
      }
      for (Block &block : funcOp.getFunctionBody()) {
        for (Type type : block.getArgumentTypes()) {
          if (isIllegalType(type))
            return false;
        }
      }
      return true;
    });

    target.addDynamicallyLegalOp<ml_program::GlobalOp>(
        [&](ml_program::GlobalOp op) {
          return typeConverter->isLegal(op.getType());
        });

    target.addDynamicallyLegalOp<tensor::EmptyOp>([&](tensor::EmptyOp op) {
      return typeConverter->isLegal(op.getType());
    });

    // Let the rest fall through.
    target.addLegalDialect<BuiltinDialect>();
    target.addLegalDialect<IREE::LinalgExt::IREELinalgExtDialect>();
    target.markUnknownOpDynamicallyLegal(isLegallyTypedOp);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      return signalPassFailure();
    }

    {
      // Apply the patterns to remove unused operands and results.
      RewritePatternSet removeUnusedOperandsResultsPatterns(context);
      linalg::populateEraseUnusedOperandsAndResultsPatterns(
          removeUnusedOperandsResultsPatterns);
      if (failed(applyPatternsAndFoldGreedily(
              getOperation(),
              std::move(removeUnusedOperandsResultsPatterns)))) {
        return signalPassFailure();
      }
    }

    // Drop module/function attributes now that they are no longer required.
    stripFrontendAttrs(getOperation());
  }
};
} // namespace
/// Adds the IREE-specific StableHLO lowerings defined in this file
/// (concatenate, fft, optimization_barrier) to `patterns`, followed by the
/// shared StableHLO-to-Linalg patterns with primitive ops disabled.
void populateStableHloToLinalgOnTensorsConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    RewritePatternSet *patterns) {
  // The IREE-specific patterns get a high benefit, presumably so they win over
  // the generic lowerings of the same ops.
  // TODO(#5809): Drop ConcatenateOp lowering in favor of the upstream version
  // then remove the PatternBenefit here
  const PatternBenefit ireeSpecificBenefit(1000);
  patterns->add<ConcatenateOpConversion, FftOpConversion,
                OptimizationBarrierOpConversion>(typeConverter, context,
                                                 ireeSpecificBenefit);

  const bool enablePrimitiveOps = false;
  populateStableHloToLinalgConversionPatterns(context, typeConverter, patterns,
                                              enablePrimitiveOps);
}
} // namespace mlir::iree_compiler::stablehlo