Cleanup compiler plugin directory and include paths. (#16691)
Progress on https://github.com/openxla/iree/issues/15468. (I noticed these
issues while migrating each plugin and finally got around to forming a
clear opinion and fixing them.)
* Qualify include paths relative to the repository root, not the plugin
root or some artificial prefix like `torch-iree/`
* Collapse directory structures, e.g. `StableHLO/stablehlo-iree/` ->
`StableHLO/`
* Generate more CMake files from Bazel (still needed a few TableGen
tweaks though)
* Fix up `#ifdef` guards after code moves
* Fixup a few copyright headers that were using LLVM format instead of
IREE
diff --git a/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
new file mode 100644
index 0000000..dc3ba5b
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
@@ -0,0 +1,133 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
+#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::iree_compiler {
+
+namespace {
+
+// Pass whose runOnOperation() converts non-signless integer types to signless.
+struct StripSignednessPass : StripSignednessBase<StripSignednessPass> {
+  explicit StripSignednessPass() = default;
+  void runOnOperation() override;
+};
+
+// Type converter that maps signed/unsigned integer types, either bare or as
+// the element type of a ranked tensor, to signless integers of equal width.
+class IntegerTypeConverter : public TypeConverter {
+public:
+  static Type convertType(Type type) {
+    auto intType = llvm::dyn_cast<IntegerType>(type);
+    if (!intType || intType.isSignless())
+      return type; // Non-integers and signless integers pass through as-is.
+    return IntegerType::get(type.getContext(), intType.getWidth());
+  }
+  // Rebuilds the tensor with a converted element type, keeping the encoding.
+  static Type convertTensor(RankedTensorType type) {
+    return RankedTensorType::get(type.getShape(),
+                                 convertType(type.getElementType()),
+                                 type.getEncoding());
+  }
+  explicit IntegerTypeConverter() {
+    addConversion([](Type type) { return convertType(type); });
+    addConversion(convertTensor);
+  }
+};
+
+// Generic pattern that rebuilds any non-function operation with its result
+// types and region argument types run through the type converter.
+class GenericTypeConvert : public ConversionPattern {
+public:
+  GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (isa<FunctionOpInterface>(op))
+      return failure(); // Function-like ops are not handled by this pattern.
+    // Bail out before touching the IR if a result type cannot be converted.
+    llvm::SmallVector<Type> newResults;
+    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
+                                                newResults)))
+      return failure();
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, op->getAttrs(), op->getSuccessors());
+    for (Region &r : op->getRegions()) { // Move each region to the new op.
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      (void)getTypeConverter()->convertSignatureArgs(
+          newRegion->getArgumentTypes(), result);
+      rewriter.applySignatureConversion(newRegion, result);
+    }
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// Returns true if `type`, or its element type when `type` is shaped, is an
+// integer type that is not signless.
+static bool isIllegalType(Type type) {
+  if (auto shaped = llvm::dyn_cast<ShapedType>(type))
+    return isIllegalType(shaped.getElementType());
+  auto intType = llvm::dyn_cast<IntegerType>(type);
+  return intType && !intType.isSignless();
+}
+
+// Entry point of the pass: runs a full dialect conversion over the root
+// operation in which any op that still mentions a non-signless integer type
+// is illegal, and signals pass failure if that conversion does not succeed.
+void StripSignednessPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  IntegerTypeConverter converter;
+  ConversionTarget target(*ctx);
+
+  // An operation is legal only when no illegal (non-signless integer) type
+  // shows up in its operand or result types or, for function-like ops, in
+  // its argument or result signature.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) {
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+      if (llvm::any_of(funcOp.getArgumentTypes(), isIllegalType) ||
+          llvm::any_of(funcOp.getResultTypes(), isIllegalType)) {
+        return false;
+      }
+    }
+
+    // Null types are skipped rather than handed to isIllegalType.
+    auto illegal = [](Type type) { return type && isIllegalType(type); };
+    return llvm::none_of(op->getResultTypes(), illegal) &&
+           llvm::none_of(op->getOperandTypes(), illegal);
+  });
+
+  // GenericTypeConvert handles ordinary operations, while the function
+  // interface pattern is registered for the root operation's op name.
+  RewritePatternSet patterns(ctx);
+  patterns.insert<GenericTypeConvert>(ctx, converter);
+  populateFunctionOpInterfaceTypeConversionPattern(
+      getOperation()->getName().getStringRef(), patterns, converter);
+
+  // Full conversion: every op left in the IR must satisfy the legality
+  // callback above, otherwise the pass is marked as failed.
+  if (failed(
+          applyFullConversion(getOperation(), target, std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+} // namespace
+
+// Factory returning a newly allocated StripSignednessPass behind the
+// function-interface pass type.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass() { return std::make_unique<StripSignednessPass>(); }
+
+} // namespace mlir::iree_compiler