Move TOSA StripSignednessPass from Flow to tosa-iree plugin. (#15541)

Follow-up to https://github.com/openxla/iree/pull/15495, part of
https://github.com/openxla/iree/issues/15468

This pass is only used in TOSA import and has no dependencies on the
Flow dialect. (https://github.com/openxla/iree/pull/7192 probably
predated much of the organization we later applied to some dialects.)
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
index 0feea75..32560ba 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/BUILD.bazel
@@ -55,6 +55,7 @@
     srcs = [
         "Converti48Toi64.cpp",
         "Passes.cpp",
+        "StripSignedness.cpp",
         "TosaToLinalgExt.cpp",
         "VerifyCompilerTOSAInputLegality.cpp",
     ],
@@ -68,11 +69,11 @@
     deps = [
         ":PassHeaders",
         ":PassesIncGen",
-        "//compiler/src/iree/compiler/Dialect/Flow/Transforms",
         "//compiler/src/iree/compiler/InputConversion/Common",
         "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:Pass",
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
index 68e36ce..e6e614c 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/CMakeLists.txt
@@ -52,6 +52,7 @@
   SRCS
     "Converti48Toi64.cpp"
     "Passes.cpp"
+    "StripSignedness.cpp"
     "TosaToLinalgExt.cpp"
     "VerifyCompilerTOSAInputLegality.cpp"
   DEPS
@@ -60,6 +61,7 @@
     IREELinalgExtDialect
     MLIRArithDialect
     MLIRFuncDialect
+    MLIRFunctionInterfaces
     MLIRIR
     MLIRLinalgDialect
     MLIRPass
@@ -71,7 +73,6 @@
     MLIRTosaToSCF
     MLIRTosaToTensor
     MLIRTransforms
-    iree::compiler::Dialect::Flow::Transforms
     iree::compiler::InputConversion::Common
   PUBLIC
 )
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
index 532cb9d..e974685 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.cpp
@@ -6,7 +6,6 @@
 
 #include "tosa-iree/InputConversion/Passes.h"
 
-#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
 #include "iree/compiler/InputConversion/Common/Passes.h"
 #include "mlir/Conversion/TosaToArith/TosaToArith.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
@@ -14,6 +13,7 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassOptions.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/Passes.h"
@@ -60,10 +60,8 @@
   passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
   passManager.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
 
-  // TODO(scotttodd): move IREE::Flow::createStripSignednessPass into plugin
-  //     (should in-tree plugins even depend on other in-tree code?)
   passManager.addNestedPass<func::FuncOp>(
-      IREE::Flow::createStripSignednessPass());
+      iree_compiler::createStripSignednessPass());
   passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
 
   passManager.addNestedPass<func::FuncOp>(
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
index 4eb693a..06eb634 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.h
@@ -24,22 +24,26 @@
 void registerTOSAConversionPassPipeline();
 
 //------------------------------------------------------------------------------
-// Conversions into Linalg
+// Conversions from TOSA into Linalg and other core IREE dialects
 //------------------------------------------------------------------------------
 
-// Verifies a module being input to the core compiler pipeline only contains
-// IR structures that are supported at that level.
-std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyCompilerTOSAInputLegality();
-
 // Set of patterns for materializing TOSA operations to linalg_ext.
 void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns);
 
-// Creates a pass that converts TOSA operations to linalg_ext.
+// Converts i48 to i64.
+std::unique_ptr<OperationPass<func::FuncOp>> createConverti48Toi64();
+
+// Strips integer signedness, e.g. tensor<4xui8> becomes tensor<4xi8>.
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass();
+
+// Converts TOSA operations to linalg_ext.
 std::unique_ptr<OperationPass<func::FuncOp>> createTosaToLinalgExt();
 
-// Creates a pass that converts i48 to i64.
-std::unique_ptr<OperationPass<func::FuncOp>> createConverti48Toi64();
+// Verifies that a module only contains IR structures that are supported by the
+// core compiler.
+std::unique_ptr<OperationPass<ModuleOp>>
+createVerifyCompilerTOSAInputLegality();
 
 //===----------------------------------------------------------------------===//
 // Register all Passes
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
index 3a6bbf9..ac2d8e8 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/Passes.td
@@ -9,6 +9,18 @@
 
 include "mlir/Pass/PassBase.td"
 
+def Converti48Toi64 :
+    Pass<"iree-tosa-convert-i48-to-i64", "mlir::func::FuncOp"> {
+  let summary = "Converts all i48s to i64s";
+  let constructor = "mlir::iree_compiler::createConverti48Toi64()";
+}
+
+def StripSignedness :
+    InterfacePass<"iree-tosa-strip-signedness", "mlir::FunctionOpInterface"> {
+  let summary = "Converts signed/unsigned integer types to signless integers";
+  let constructor = "mlir::iree_compiler::createStripSignednessPass()";
+}
+
 def TosaToLinalgExt :
     Pass<"iree-tosa-to-linalg-ext", "mlir::func::FuncOp"> {
   let summary = "Convert TOSA operations to their equivalent linalg-ext operations.";
@@ -27,10 +39,4 @@
   let constructor = "mlir::iree_compiler::createVerifyCompilerTOSAInputLegality()";
 }
 
-def Converti48Toi64 :
-    Pass<"iree-convert-i48-to-i64", "mlir::func::FuncOp"> {
-  let summary = "Converts all i48s to i64s";
-  let constructor = "mlir::iree_compiler::createConverti48Toi64()";
-}
-
 #endif // TOSA_IREE_INPUTCONVERSION_PASSES
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/StripSignedness.cpp b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/StripSignedness.cpp
new file mode 100644
index 0000000..787626d
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/StripSignedness.cpp
@@ -0,0 +1,135 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "tosa-iree/InputConversion/PassDetail.h"
+#include "tosa-iree/InputConversion/Passes.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+class StripSignednessPass : public StripSignednessBase<StripSignednessPass> {
+public:
+  explicit StripSignednessPass() {}
+  void runOnOperation() override;
+};
+
+class IntegerTypeConverter : public TypeConverter {
+public:
+  static Type convertType(Type type) {
+    if (auto iType = llvm::dyn_cast<IntegerType>(type)) {
+      if (!iType.isSignless()) {
+        return IntegerType::get(type.getContext(),
+                                iType.getIntOrFloatBitWidth());
+      }
+    }
+    return type;
+  }
+  static Type convertTensor(RankedTensorType type) {
+    auto newType = RankedTensorType::get(type.getShape(),
+                                         convertType(type.getElementType()));
+    return newType;
+  }
+  explicit IntegerTypeConverter() {
+    addConversion([](Type type) { return convertType(type); });
+    addConversion(convertTensor);
+  }
+};
+
+// Rebuilds any non-function op with its result types and region block
+// arguments run through the type converter, so signed/unsigned integer types
+// become signless. Function ops are rejected here and left to other patterns.
+class GenericTypeConvert : public ConversionPattern {
+public:
+  GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
+      : ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<Type> newResults;
+    if (isa<FunctionOpInterface>(op)) {
+      return failure();
+    }
+
+    (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, op->getAttrs(), op->getSuccessors());
+    for (Region &r : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      (void)getTypeConverter()->convertSignatureArgs(
+          newRegion->getArgumentTypes(), result);
+      rewriter.applySignatureConversion(newRegion, result);
+    }
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+static bool isIllegalType(Type type) {
+  if (IntegerType ity = llvm::dyn_cast<IntegerType>(type))
+    return !ity.isSignless();
+  if (auto shapedType = llvm::dyn_cast<ShapedType>(type)) {
+    return isIllegalType(shapedType.getElementType());
+  }
+  return false;
+}
+
+void StripSignednessPass::runOnOperation() {
+  IntegerTypeConverter converter;
+  ConversionTarget target(getContext());
+
+  // Operations are legal if they don't contain any illegal type.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) {
+    if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+      for (Type type : funcOp.getArgumentTypes()) {
+        if (isIllegalType(type))
+          return false;
+      }
+      for (Type type : funcOp.getResultTypes()) {
+        if (isIllegalType(type))
+          return false;
+      }
+    }
+    for (Type type : op->getResultTypes()) {
+      if (type && isIllegalType(type))
+        return false;
+    }
+    for (Type type : op->getOperandTypes()) {
+      if (type && isIllegalType(type))
+        return false;
+    }
+    return true;
+  });
+
+  auto *ctx = &getContext();
+
+  RewritePatternSet patterns(&getContext());
+  patterns.insert<GenericTypeConvert>(ctx, converter);
+  populateFunctionOpInterfaceTypeConversionPattern(
+      getOperation()->getName().getStringRef(), patterns, converter);
+
+  if (failed(
+          applyFullConversion(getOperation(), target, std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+} // namespace
+
+std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
+createStripSignednessPass() {
+  return std::make_unique<StripSignednessPass>();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
index 234b517..b1a0a24 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/BUILD.bazel
@@ -17,6 +17,7 @@
     srcs = enforce_glob(
         [
             "convert_i48_to_i64.mlir",
+            "strip_signedness.mlir",
             "tosa_to_linalg_ext.mlir",
             "verify_compiler_tosa_input_legality.mlir",
         ],
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
index c226bc4..06ddb1c 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "convert_i48_to_i64.mlir"
+    "strip_signedness.mlir"
     "tosa_to_linalg_ext.mlir"
     "verify_compiler_tosa_input_legality.mlir"
   TOOLS
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
index 48c1df2..e0bd07c 100644
--- a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/convert_i48_to_i64.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --iree-convert-i48-to-i64 --verify-diagnostics %s | FileCheck %s
+// RUN: iree-opt --split-input-file --iree-tosa-convert-i48-to-i64 --verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL: @test_all_i48_converted
 func.func @test_all_i48_converted(%arg0: tensor<2x2xi48>) -> tensor<2x2xi48> {
diff --git a/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/strip_signedness.mlir b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/strip_signedness.mlir
new file mode 100644
index 0000000..cf30e57
--- /dev/null
+++ b/compiler/plugins/input/TOSA/tosa-iree/InputConversion/test/strip_signedness.mlir
@@ -0,0 +1,22 @@
+
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-tosa-strip-signedness))' %s | FileCheck %s
+
+// CHECK-LABEL: @strip_signedness_arg
+// CHECK-SAME: tensor<4xi8>
+func.func @strip_signedness_arg(%arg0 : tensor<4xui8>) -> (tensor<4xui8>) {
+    // CHECK: return
+    // CHECK-SAME: tensor<4xi8>
+    return %arg0 : tensor<4xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @strip_signedness_const
+func.func @strip_signedness_const() -> (tensor<4xi8>) {
+    // CHECK: constant
+    // CHECK-SAME: tensor<4xi8>
+    %0 = arith.constant dense<[0, 2, 3, 7]> : tensor<4xi8>
+    // CHECK: return
+    // CHECK-SAME: tensor<4xi8>
+    return %0 : tensor<4xi8>
+}