Move TOSA StripSignednessPass from Flow to tosa-iree plugin. (#15541)
Follow-up to https://github.com/openxla/iree/pull/15495, part of
https://github.com/openxla/iree/issues/15468
This pass is only used in TOSA import and has no dependencies on the
Flow dialect (https://github.com/openxla/iree/pull/7192 probably
predated a bunch of the organization we later applied to some dialects)
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
index 0feea75..32560ba 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
@@ -55,6 +55,7 @@
srcs = [
"Converti48Toi64.cpp",
"Passes.cpp",
+ "StripSignedness.cpp",
"TosaToLinalgExt.cpp",
"VerifyCompilerTOSAInputLegality.cpp",
],
@@ -68,11 +69,11 @@
deps = [
":PassHeaders",
":PassesIncGen",
- "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
"//compiler/src/iree/compiler/InputConversion/Common",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Pass",
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
index 68e36ce..e6e614c 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
@@ -52,6 +52,7 @@
SRCS
"Converti48Toi64.cpp"
"Passes.cpp"
+ "StripSignedness.cpp"
"TosaToLinalgExt.cpp"
"VerifyCompilerTOSAInputLegality.cpp"
DEPS
@@ -60,6 +61,7 @@
IREELinalgExtDialect
MLIRArithDialect
MLIRFuncDialect
+ MLIRFunctionInterfaces
MLIRIR
MLIRLinalgDialect
MLIRPass
@@ -71,7 +73,6 @@
MLIRTosaToSCF
MLIRTosaToTensor
MLIRTransforms
- iree::compiler::Dialect::Flow::Transforms
iree::compiler::InputConversion::Common
PUBLIC
)
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
index 532cb9d..e974685 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
@@ -6,7 +6,6 @@
#include "tosa-iree/InputConversion/Passes.h"
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/InputConversion/Common/Passes.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
@@ -14,6 +13,7 @@
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
@@ -60,10 +60,8 @@
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
- // TODO(scotttodd): move IREE::Flow::createStripSignednessPass into plugin
- // (should in-tree plugins even depend on other in-tree code?)
passManager.addNestedPass<func::FuncOp>(
- IREE::Flow::createStripSignednessPass());
+ iree_compiler::createStripSignednessPass());
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
passManager.addNestedPass<func::FuncOp>(
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
index 4eb693a..06eb634 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
@@ -24,22 +24,26 @@
void registerTOSAConversionPassPipeline();
//------------------------------------------------------------------------------
-// Conversions into Linalg
+// Conversions from TOSA into Linalg and other core IREE dialects
//------------------------------------------------------------------------------
-// Verifies a module being input to the core compiler pipeline only contains
-// IR structures that are supported at that level.
-std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyCompilerTOSAInputLegality();
-
// Set of patterns for materializing TOSA operations to linalg_ext.
void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns);
-// Creates a pass that converts TOSA operations to linalg_ext.
+// Converts i48 to i64.
+std::unique_ptr<OperationPass<func::FuncOp>> createConverti48Toi64();
+
+// Rewrites signed/unsigned integer types (and tensors of them) as signless.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass();
+
+// Converts TOSA operations to linalg_ext.
std::unique_ptr<OperationPass<func::FuncOp>> createTosaToLinalgExt();
-// Creates a pass that converts i48 to i64.
-std::unique_ptr<OperationPass<func::FuncOp>> createConverti48Toi64();
+// Verifies that a module only contains IR structures that are supported by the
+// core compiler.
+std::unique_ptr<OperationPass<ModuleOp>>
+createVerifyCompilerTOSAInputLegality();
//===----------------------------------------------------------------------===//
// Register all Passes
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
index 3a6bbf9..ac2d8e8 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
@@ -9,6 +9,18 @@
include "mlir/Pass/PassBase.td"
+// Function pass that rewrites i48 integer types to i64.
+def Converti48Toi64 :
+    Pass<"iree-tosa-convert-i48-to-i64", "mlir::func::FuncOp"> {
+  let constructor = "mlir::iree_compiler::createConverti48Toi64()";
+  let summary = "Converts all i48s to i64s";
+}
+
+def StripSignedness :
+    InterfacePass<"iree-tosa-strip-signedness", "mlir::FunctionOpInterface"> {
+  let summary = "Converts signed/unsigned integer types to signless integers";
+  let description = [{
+    Rewrites `si`/`ui` integer types, including tensor element types, in
+    function signatures and op result/operand types to the signless integer
+    type of the same bit width.
+  }];
+  let constructor = "mlir::iree_compiler::createStripSignednessPass()";
+}
+
def TosaToLinalgExt :
Pass<"iree-tosa-to-linalg-ext", "mlir::func::FuncOp"> {
let summary = "Convert TOSA operations to their equivalent linalg-ext operations.";
@@ -27,10 +39,4 @@
let constructor = "mlir::iree_compiler::createVerifyCompilerTOSAInputLegality()";
}
-def Converti48Toi64 :
- Pass<"iree-convert-i48-to-i64", "mlir::func::FuncOp"> {
- let summary = "Converts all i48s to i64s";
- let constructor = "mlir::iree_compiler::createConverti48Toi64()";
-}
-
#endif // TOSA_IREE_INPUTCONVERSION_PASSES
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/StripSignedness.cpp
new file mode 100644
index 0000000..787626d
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/StripSignedness.cpp
@@ -0,0 +1,163 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tosa-iree/InputConversion/PassDetail.h"
+#include "tosa-iree/InputConversion/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+// Rewrites signed (`si`) and unsigned (`ui`) integer types, including those
+// used as the element type of a shaped type, to the signless integer type of
+// the same width everywhere inside a function-like op.
+class StripSignednessPass : public StripSignednessBase<StripSignednessPass> {
+public:
+  StripSignednessPass() = default;
+  void runOnOperation() override;
+};
+
+// Type converter mapping signed/unsigned integer types to signless ones.
+// Shaped types (tensors, memrefs, vectors, ...) are rebuilt with a converted
+// element type; every other type is returned unchanged.
+class IntegerTypeConverter : public TypeConverter {
+public:
+  // Returns the signless equivalent of `type`. Shaped types are rebuilt with
+  // `clone`, which keeps the shape and the other type parameters (e.g. a
+  // ranked tensor's encoding) and only replaces the element type.
+  static Type convertType(Type type) {
+    if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+      if (intType.isSignless())
+        return type;
+      return IntegerType::get(type.getContext(), intType.getWidth());
+    }
+    if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+      return shapedType.clone(convertType(shapedType.getElementType()));
+    return type;
+  }
+
+  IntegerTypeConverter() {
+    // One catch-all conversion suffices: `convertType` never fails and
+    // recurses through shaped types exactly like `isIllegalType` below, so
+    // every type the legality check rejects gets a legal replacement.
+    addConversion([](Type type) { return convertType(type); });
+  }
+};
+
+// Recreates an arbitrary non-function op with signless result types and
+// region block argument types, using the `operands` supplied by the
+// conversion driver. Function-like ops are rejected here; their signatures
+// are handled by the FunctionOpInterface conversion pattern instead.
+class GenericTypeConvert : public ConversionPattern {
+public:
+  GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (isa<FunctionOpInterface>(op))
+      return failure();
+
+    auto *converter = getTypeConverter();
+
+    // Validate every result and region argument type before mutating the IR:
+    // a pattern must not modify the IR and then report failure.
+    SmallVector<Type> newResults;
+    if (failed(converter->convertTypes(op->getResultTypes(), newResults)))
+      return failure();
+    for (Region &region : op->getRegions()) {
+      SmallVector<Type> newArgTypes;
+      if (failed(converter->convertTypes(region.getArgumentTypes(),
+                                         newArgTypes)))
+        return failure();
+    }
+
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, op->getAttrs(), op->getSuccessors());
+    for (Region &region : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion signature(
+          newRegion->getNumArguments());
+      // Cannot fail: these argument types were validated above.
+      (void)converter->convertSignatureArgs(newRegion->getArgumentTypes(),
+                                            signature);
+      rewriter.applySignatureConversion(newRegion, signature);
+    }
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// Returns true if `type` is a signed/unsigned integer type, or a shaped type
+// whose (possibly nested) element type is one.
+static bool isIllegalType(Type type) {
+  if (auto intType = llvm::dyn_cast<IntegerType>(type))
+    return !intType.isSignless();
+  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+    return isIllegalType(shapedType.getElementType());
+  return false;
+}
+
+// Returns true if any non-null type in `types` is illegal.
+template <typename RangeT>
+static bool containsIllegalType(RangeT &&types) {
+  return llvm::any_of(types,
+                      [](Type type) { return type && isIllegalType(type); });
+}
+
+void StripSignednessPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  IntegerTypeConverter converter;
+  ConversionTarget target(*ctx);
+
+  // An op is legal once no signed/unsigned integer type remains in its
+  // function signature (for function-like ops), its results, its operands,
+  // or the arguments of the blocks in its regions.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) {
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+      if (containsIllegalType(funcOp.getArgumentTypes()) ||
+          containsIllegalType(funcOp.getResultTypes()))
+        return false;
+    }
+    if (containsIllegalType(op->getResultTypes()) ||
+        containsIllegalType(op->getOperandTypes()))
+      return false;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region) {
+        if (containsIllegalType(block.getArgumentTypes()))
+          return false;
+      }
+    }
+    return true;
+  });
+
+  RewritePatternSet patterns(ctx);
+  patterns.insert<GenericTypeConvert>(ctx, converter);
+  populateFunctionOpInterfaceTypeConversionPattern(
+      getOperation()->getName().getStringRef(), patterns, converter);
+
+  if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+} // namespace
+
+// Creates the `iree-tosa-strip-signedness` pass (see Passes.td).
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass() {
+  return std::make_unique<StripSignednessPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
index 234b517..b1a0a24 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
@@ -17,6 +17,7 @@
srcs = enforce_glob(
[
"convert_i48_to_i64.mlir",
+ "strip_signedness.mlir",
"tosa_to_linalg_ext.mlir",
"verify_compiler_tosa_input_legality.mlir",
],
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
index c226bc4..06ddb1c 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"convert_i48_to_i64.mlir"
+ "strip_signedness.mlir"
"tosa_to_linalg_ext.mlir"
"verify_compiler_tosa_input_legality.mlir"
TOOLS
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
index 48c1df2..e0bd07c 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-convert-i48-to-i64 --verify-diagnostics %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-tosa-convert-i48-to-i64 --verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: @test_all_i48_converted
func.func @test_all_i48_converted(%arg0: tensor<2x2xi48>) -> tensor<2x2xi48> {
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/strip_signedness.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/strip_signedness.mlir
new file mode 100644
index 0000000..cf30e57
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/strip_signedness.mlir
@@ -0,0 +1,22 @@
+
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-tosa-strip-signedness))' %s | FileCheck %s
+
+// A `ui8` tensor used as both the argument and the returned value is expected
+// to be rewritten to the signless `i8` element type.
+// CHECK-LABEL: @strip_signedness_arg
+// CHECK-SAME: tensor<4xi8>
+func.func @strip_signedness_arg(%arg0 : tensor<4xui8>) -> (tensor<4xui8>) {
+ // CHECK: return
+ // CHECK-SAME: tensor<4xi8>
+ return %arg0 : tensor<4xui8>
+}
+
+// -----
+
+// Values whose types are already signless (`i8`) must pass through unchanged.
+// CHECK-LABEL: @strip_signedness_const
+func.func @strip_signedness_const() -> (tensor<4xi8>) {
+ // CHECK: constant
+ // CHECK-SAME: tensor<4xi8>
+ %0 = arith.constant dense<[0, 2, 3, 7]> : tensor<4xi8>
+ // CHECK: return
+ // CHECK-SAME: tensor<4xi8>
+ return %0 : tensor<4xi8>
+}