blob: d9097e321c610baeb116af6e1ac9877d1b5543f3 [file] [log] [blame]
rsuderman5c447842021-09-27 15:57:30 -07001// Copyright 2021 The IREE Authors
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
Scott Toddc344e262024-03-06 16:06:51 -08007#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
8#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
rsuderman5c447842021-09-27 15:57:30 -07009#include "mlir/IR/PatternMatch.h"
Diego Caballero02a355e2023-09-13 20:32:52 -070010#include "mlir/Interfaces/FunctionInterfaces.h"
Scott Todd961acb92023-11-13 08:13:25 -080011#include "mlir/Transforms/DialectConversion.h"
rsuderman5c447842021-09-27 15:57:30 -070012
Ben Vanike5f95c22023-12-06 19:48:43 -080013namespace mlir::iree_compiler {
rsuderman5c447842021-09-27 15:57:30 -070014
15namespace {
16
17class StripSignednessPass : public StripSignednessBase<StripSignednessPass> {
Jakub Kuderski3b652d42023-06-23 20:51:35 -040018public:
rsuderman5c447842021-09-27 15:57:30 -070019 explicit StripSignednessPass() {}
20 void runOnOperation() override;
21};
22
23class IntegerTypeConverter : public TypeConverter {
Jakub Kuderski3b652d42023-06-23 20:51:35 -040024public:
rsuderman5c447842021-09-27 15:57:30 -070025 static Type convertType(Type type) {
Jacques Pienaar53faf3a2023-05-25 08:12:06 -070026 if (auto iType = llvm::dyn_cast<IntegerType>(type)) {
rsuderman5c447842021-09-27 15:57:30 -070027 if (!iType.isSignless()) {
28 return IntegerType::get(type.getContext(),
29 iType.getIntOrFloatBitWidth());
30 }
31 }
32 return type;
33 }
34 static Type convertTensor(RankedTensorType type) {
35 auto newType = RankedTensorType::get(type.getShape(),
36 convertType(type.getElementType()));
37 return newType;
38 }
39 explicit IntegerTypeConverter() {
40 addConversion([](Type type) { return convertType(type); });
41 addConversion(convertTensor);
42 }
43};
44
45// Handles the type conversion component of the TypeConversion. This updates
46// conversion patterns that used the original Quant types to be updated to
47// the non-quant variants.
48class GenericTypeConvert : public ConversionPattern {
Jakub Kuderski3b652d42023-06-23 20:51:35 -040049public:
50 GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
rsuderman5c447842021-09-27 15:57:30 -070051 : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
Jakub Kuderski3b652d42023-06-23 20:51:35 -040052 LogicalResult
53 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
54 ConversionPatternRewriter &rewriter) const override {
Han-Chung Wang27448a82023-06-13 13:45:19 -070055 llvm::SmallVector<Type> newResults;
Stella Laurenzo2f0a48c2022-03-18 16:24:59 -070056 if (isa<FunctionOpInterface>(op)) {
rsuderman5c447842021-09-27 15:57:30 -070057 return failure();
58 }
59
60 (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
61 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
62 newResults, op->getAttrs(), op->getSuccessors());
Jakub Kuderski3b652d42023-06-23 20:51:35 -040063 for (Region &r : op->getRegions()) {
64 Region *newRegion = state.addRegion();
rsuderman5c447842021-09-27 15:57:30 -070065 rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
66 TypeConverter::SignatureConversion result(newRegion->getNumArguments());
67 (void)getTypeConverter()->convertSignatureArgs(
68 newRegion->getArgumentTypes(), result);
Lei Zhangac418d12024-06-20 19:24:08 -070069 rewriter.applySignatureConversion(&newRegion->front(), result);
rsuderman5c447842021-09-27 15:57:30 -070070 }
Jakub Kuderski3b652d42023-06-23 20:51:35 -040071 Operation *newOp = rewriter.create(state);
rsuderman5c447842021-09-27 15:57:30 -070072 rewriter.replaceOp(op, newOp->getResults());
73 return success();
74 }
75};
76
77static bool isIllegalType(Type type) {
Jacques Pienaar53faf3a2023-05-25 08:12:06 -070078 if (IntegerType ity = llvm::dyn_cast<IntegerType>(type))
79 return !ity.isSignless();
80 if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) {
rsuderman5c447842021-09-27 15:57:30 -070081 return isIllegalType(shapedType.getElementType());
82 }
83 return false;
84}
85
86void StripSignednessPass::runOnOperation() {
87 IntegerTypeConverter converter;
88 ConversionTarget target(getContext());
89
90 // Operations are legal if they don't contain any illegal type.
Jakub Kuderski3b652d42023-06-23 20:51:35 -040091 target.markUnknownOpDynamicallyLegal([](Operation *op) {
Stella Laurenzo2f0a48c2022-03-18 16:24:59 -070092 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
93 for (Type type : funcOp.getArgumentTypes()) {
Jakub Kuderski3b652d42023-06-23 20:51:35 -040094 if (isIllegalType(type))
95 return false;
rsuderman5c447842021-09-27 15:57:30 -070096 }
Stella Laurenzo2f0a48c2022-03-18 16:24:59 -070097 for (Type type : funcOp.getResultTypes()) {
Jakub Kuderski3b652d42023-06-23 20:51:35 -040098 if (isIllegalType(type))
99 return false;
rsuderman5c447842021-09-27 15:57:30 -0700100 }
101 }
102 for (Type type : op->getResultTypes()) {
Jakub Kuderski3b652d42023-06-23 20:51:35 -0400103 if (type && isIllegalType(type))
104 return false;
rsuderman5c447842021-09-27 15:57:30 -0700105 }
106 for (Type type : op->getOperandTypes()) {
Jakub Kuderski3b652d42023-06-23 20:51:35 -0400107 if (type && isIllegalType(type))
108 return false;
rsuderman5c447842021-09-27 15:57:30 -0700109 }
110 return true;
111 });
112
Jakub Kuderski3b652d42023-06-23 20:51:35 -0400113 auto *ctx = &getContext();
rsuderman5c447842021-09-27 15:57:30 -0700114
115 RewritePatternSet patterns(&getContext());
116 patterns.insert<GenericTypeConvert>(ctx, converter);
Stella Laurenzo2f0a48c2022-03-18 16:24:59 -0700117 populateFunctionOpInterfaceTypeConversionPattern(
118 getOperation()->getName().getStringRef(), patterns, converter);
rsuderman5c447842021-09-27 15:57:30 -0700119
Stella Laurenzo2f0a48c2022-03-18 16:24:59 -0700120 if (failed(
121 applyFullConversion(getOperation(), target, std::move(patterns)))) {
rsuderman5c447842021-09-27 15:57:30 -0700122 signalPassFailure();
123 }
124}
125
Jakub Kuderski3b652d42023-06-23 20:51:35 -0400126} // namespace
rsuderman5c447842021-09-27 15:57:30 -0700127
Stella Laurenzo2f0a48c2022-03-18 16:24:59 -0700128std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
129createStripSignednessPass() {
rsuderman5c447842021-09-27 15:57:30 -0700130 return std::make_unique<StripSignednessPass>();
131}
132
Ben Vanike5f95c22023-12-06 19:48:43 -0800133} // namespace mlir::iree_compiler