rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 1 | // Copyright 2021 The IREE Authors |
| 2 | // |
| 3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | |
Scott Todd | c344e26 | 2024-03-06 16:06:51 -0800 | [diff] [blame] | 7 | #include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h" |
| 8 | #include "compiler/plugins/input/TOSA/InputConversion/Passes.h" |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 9 | #include "mlir/IR/PatternMatch.h" |
Diego Caballero | 02a355e | 2023-09-13 20:32:52 -0700 | [diff] [blame] | 10 | #include "mlir/Interfaces/FunctionInterfaces.h" |
Scott Todd | 961acb9 | 2023-11-13 08:13:25 -0800 | [diff] [blame] | 11 | #include "mlir/Transforms/DialectConversion.h" |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 12 | |
Ben Vanik | e5f95c2 | 2023-12-06 19:48:43 -0800 | [diff] [blame] | 13 | namespace mlir::iree_compiler { |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 14 | |
| 15 | namespace { |
| 16 | |
| 17 | class StripSignednessPass : public StripSignednessBase<StripSignednessPass> { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 18 | public: |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 19 | explicit StripSignednessPass() {} |
| 20 | void runOnOperation() override; |
| 21 | }; |
| 22 | |
| 23 | class IntegerTypeConverter : public TypeConverter { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 24 | public: |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 25 | static Type convertType(Type type) { |
Jacques Pienaar | 53faf3a | 2023-05-25 08:12:06 -0700 | [diff] [blame] | 26 | if (auto iType = llvm::dyn_cast<IntegerType>(type)) { |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 27 | if (!iType.isSignless()) { |
| 28 | return IntegerType::get(type.getContext(), |
| 29 | iType.getIntOrFloatBitWidth()); |
| 30 | } |
| 31 | } |
| 32 | return type; |
| 33 | } |
| 34 | static Type convertTensor(RankedTensorType type) { |
| 35 | auto newType = RankedTensorType::get(type.getShape(), |
| 36 | convertType(type.getElementType())); |
| 37 | return newType; |
| 38 | } |
| 39 | explicit IntegerTypeConverter() { |
| 40 | addConversion([](Type type) { return convertType(type); }); |
| 41 | addConversion(convertTensor); |
| 42 | } |
| 43 | }; |
| 44 | |
| 45 | // Handles the type conversion component of the TypeConversion. This updates |
| 46 | // conversion patterns that used the original Quant types to be updated to |
| 47 | // the non-quant variants. |
| 48 | class GenericTypeConvert : public ConversionPattern { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 49 | public: |
| 50 | GenericTypeConvert(MLIRContext *context, TypeConverter &converter) |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 51 | : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {} |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 52 | LogicalResult |
| 53 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 54 | ConversionPatternRewriter &rewriter) const override { |
Han-Chung Wang | 27448a8 | 2023-06-13 13:45:19 -0700 | [diff] [blame] | 55 | llvm::SmallVector<Type> newResults; |
Stella Laurenzo | 2f0a48c | 2022-03-18 16:24:59 -0700 | [diff] [blame] | 56 | if (isa<FunctionOpInterface>(op)) { |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 57 | return failure(); |
| 58 | } |
| 59 | |
| 60 | (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults); |
| 61 | OperationState state(op->getLoc(), op->getName().getStringRef(), operands, |
| 62 | newResults, op->getAttrs(), op->getSuccessors()); |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 63 | for (Region &r : op->getRegions()) { |
| 64 | Region *newRegion = state.addRegion(); |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 65 | rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin()); |
| 66 | TypeConverter::SignatureConversion result(newRegion->getNumArguments()); |
| 67 | (void)getTypeConverter()->convertSignatureArgs( |
| 68 | newRegion->getArgumentTypes(), result); |
Lei Zhang | ac418d1 | 2024-06-20 19:24:08 -0700 | [diff] [blame^] | 69 | rewriter.applySignatureConversion(&newRegion->front(), result); |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 70 | } |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 71 | Operation *newOp = rewriter.create(state); |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 72 | rewriter.replaceOp(op, newOp->getResults()); |
| 73 | return success(); |
| 74 | } |
| 75 | }; |
| 76 | |
| 77 | static bool isIllegalType(Type type) { |
Jacques Pienaar | 53faf3a | 2023-05-25 08:12:06 -0700 | [diff] [blame] | 78 | if (IntegerType ity = llvm::dyn_cast<IntegerType>(type)) |
| 79 | return !ity.isSignless(); |
| 80 | if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) { |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 81 | return isIllegalType(shapedType.getElementType()); |
| 82 | } |
| 83 | return false; |
| 84 | } |
| 85 | |
| 86 | void StripSignednessPass::runOnOperation() { |
| 87 | IntegerTypeConverter converter; |
| 88 | ConversionTarget target(getContext()); |
| 89 | |
| 90 | // Operations are legal if they don't contain any illegal type. |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 91 | target.markUnknownOpDynamicallyLegal([](Operation *op) { |
Stella Laurenzo | 2f0a48c | 2022-03-18 16:24:59 -0700 | [diff] [blame] | 92 | if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { |
| 93 | for (Type type : funcOp.getArgumentTypes()) { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 94 | if (isIllegalType(type)) |
| 95 | return false; |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 96 | } |
Stella Laurenzo | 2f0a48c | 2022-03-18 16:24:59 -0700 | [diff] [blame] | 97 | for (Type type : funcOp.getResultTypes()) { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 98 | if (isIllegalType(type)) |
| 99 | return false; |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 100 | } |
| 101 | } |
| 102 | for (Type type : op->getResultTypes()) { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 103 | if (type && isIllegalType(type)) |
| 104 | return false; |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 105 | } |
| 106 | for (Type type : op->getOperandTypes()) { |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 107 | if (type && isIllegalType(type)) |
| 108 | return false; |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 109 | } |
| 110 | return true; |
| 111 | }); |
| 112 | |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 113 | auto *ctx = &getContext(); |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 114 | |
| 115 | RewritePatternSet patterns(&getContext()); |
| 116 | patterns.insert<GenericTypeConvert>(ctx, converter); |
Stella Laurenzo | 2f0a48c | 2022-03-18 16:24:59 -0700 | [diff] [blame] | 117 | populateFunctionOpInterfaceTypeConversionPattern( |
| 118 | getOperation()->getName().getStringRef(), patterns, converter); |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 119 | |
Stella Laurenzo | 2f0a48c | 2022-03-18 16:24:59 -0700 | [diff] [blame] | 120 | if (failed( |
| 121 | applyFullConversion(getOperation(), target, std::move(patterns)))) { |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 122 | signalPassFailure(); |
| 123 | } |
| 124 | } |
| 125 | |
Jakub Kuderski | 3b652d4 | 2023-06-23 20:51:35 -0400 | [diff] [blame] | 126 | } // namespace |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 127 | |
Stella Laurenzo | 2f0a48c | 2022-03-18 16:24:59 -0700 | [diff] [blame] | 128 | std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> |
| 129 | createStripSignednessPass() { |
rsuderman | 5c44784 | 2021-09-27 15:57:30 -0700 | [diff] [blame] | 130 | return std::make_unique<StripSignednessPass>(); |
| 131 | } |
| 132 | |
Ben Vanik | e5f95c2 | 2023-12-06 19:48:43 -0800 | [diff] [blame] | 133 | } // namespace mlir::iree_compiler |