blob: 1249651a2e653c7bf61b5f27e1048678efddf07e [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace mlir::iree_compiler {
#define GEN_PASS_DEF_CONVERTI48TOI64PASS
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc"
namespace {
// Pass that rewrites i48 integer types to i64 on the operation it is anchored
// on. The conversion logic lives in runOnOperation (defined out of line).
struct Converti48Toi64Pass final
    : impl::Converti48Toi64PassBase<Converti48Toi64Pass> {
  explicit Converti48Toi64Pass() = default;

  void runOnOperation() override;
};
// Type converter that widens i48 integers to (signless) i64, both as scalars
// and as the element type of ranked tensors. Every other type is returned
// unchanged, so no conversion registered here ever fails.
//
// NOTE(review): shaped types other than RankedTensorType (vectors, unranked
// tensors, memrefs) with i48 elements are returned unchanged by this converter.
struct i48Toi64Converter : public TypeConverter {
  /// Returns i64 if `type` is a 48-bit integer type, otherwise `type` itself.
  ///
  /// NOTE: this static member hides the TypeConverter::convertType overloads
  /// when called through an i48Toi64Converter; go through a TypeConverter
  /// reference/pointer to use the full conversion machinery.
  static Type convertType(Type type) {
    if (type.isInteger(48)) {
      return IntegerType::get(type.getContext(), /*width=*/64);
    }
    return type;
  }

  /// Returns `type` with its element type mapped through `convertType`.
  /// The shape and the encoding attribute of the tensor are preserved.
  static Type convertTensor(RankedTensorType type) {
    return RankedTensorType::get(type.getShape(),
                                 convertType(type.getElementType()),
                                 type.getEncoding());
  }

  explicit i48Toi64Converter() {
    // TypeConverter invokes the most recently registered conversion first, so
    // ranked tensors are handled by convertTensor and everything else falls
    // back to the scalar conversion.
    addConversion([](Type type) { return convertType(type); });
    addConversion(convertTensor);
  }
};
// Handles the type conversion component of the TypeConversion. This updates
// conversion patterns that used the original i48 tensor types to be
// updated to the i64 variants.
class GenericTypeConvert : public ConversionPattern {
public:
GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<Type, 4> newResults;
if (isa<mlir::FunctionOpInterface>(op)) {
return rewriter.notifyMatchFailure(op, "is a func op");
}
llvm::SmallVector<Type, 4> oldAttrTypes;
llvm::SmallVector<unsigned, 4> typedIndices;
// Extract the typed attributes for conversion.
for (auto [index, attr] : llvm::enumerate(op->getAttrs())) {
if (auto typedAttr = dyn_cast<TypedAttr>(attr.getValue())) {
oldAttrTypes.push_back(typedAttr.getType());
typedIndices.push_back(index);
}
}
llvm::SmallVector<Type, 4> newAttrTypes;
(void)getTypeConverter()->convertTypes(oldAttrTypes, newAttrTypes);
llvm::SmallVector<NamedAttribute, 4> newAttrs(op->getAttrs());
for (auto [idx, typedIndex] : llvm::enumerate(typedIndices)) {
auto attrValue = newAttrs[typedIndex].getValue();
auto newAttrType = newAttrTypes[idx];
// For integer attributes, create a new integer of new width.
if (auto intAttr = dyn_cast<IntegerAttr>(attrValue)) {
if (auto intType = dyn_cast<IntegerType>(newAttrType)) {
auto value =
IntegerAttr::get(intType, intAttr.getValue().getZExtValue());
newAttrs[typedIndex] =
NamedAttribute(newAttrs[typedIndex].getName(), value);
continue;
}
}
// For shaped types, map the values to the new types.
if (auto shapedType = dyn_cast<ShapedType>(newAttrType)) {
if (auto denseAttr = dyn_cast<DenseIntElementsAttr>(attrValue)) {
auto eType = dyn_cast<IntegerType>(shapedType.getElementType());
auto cast = [&](APInt value) {
return APInt(eType.getWidth(), value.getZExtValue());
};
auto newDenseAttr = denseAttr.mapValues(eType, cast);
newAttrs[typedIndex] =
NamedAttribute(newAttrs[typedIndex].getName(), newDenseAttr);
continue;
}
}
return rewriter.notifyMatchFailure(op, "Unsupported input type");
}
(void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
newResults, newAttrs, op->getSuccessors());
for (Region &r : op->getRegions()) {
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};
// Returns true when `type` is i48, or a shaped type whose innermost element
// type (after peeling any nested shaped types) is i48. Used by the legality
// callback to decide which ops still need conversion.
static bool isIllegalType(Type type) {
  Type elementType = type;
  while (auto shapedType = dyn_cast<ShapedType>(elementType)) {
    elementType = shapedType.getElementType();
  }
  return elementType.isInteger(48);
}
void Converti48Toi64Pass::runOnOperation() {
i48Toi64Converter converter;
ConversionTarget target(getContext());
// Operations are legal if they don't contain any illegal type.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (auto funcOp = dyn_cast<mlir::FunctionOpInterface>(op)) {
for (Type type : funcOp.getArgumentTypes()) {
if (isIllegalType(type))
return false;
}
for (Type type : funcOp.getResultTypes()) {
if (isIllegalType(type))
return false;
}
}
for (Type type : op->getResultTypes()) {
if (type && isIllegalType(type))
return false;
}
for (Type type : op->getOperandTypes()) {
if (type && isIllegalType(type))
return false;
}
for (auto attr : op->getAttrs()) {
if (auto typedAttr = dyn_cast<TypedAttr>(attr.getValue())) {
if (isIllegalType(typedAttr.getType())) {
return false;
}
}
}
return true;
});
auto *ctx = &getContext();
auto func = getOperation();
RewritePatternSet patterns(&getContext());
patterns.add<GenericTypeConvert>(ctx, converter);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
if (failed(applyFullConversion(func, target, std::move(patterns)))) {
signalPassFailure();
}
}
} // namespace
} // namespace mlir::iree_compiler