compiler/plugins/input/TOSA: fix: TOSA arith lowering must handle apply scale introduced by linalg lowering (#24121)

TOSA to linalg lowering may (re-)introduce some additional TOSA
operations, as not all TOSA operations can be lowered to linalg. These
need a subsequent run of TOSA to arith lowering for element-wise
operations. The current pass pipeline is missing the option to enable
lowering of tosa.apply_scale, which may be introduced during the
lowering to linalg. This causes errors in later stages of the
compilation flow, such as vectorization.

Signed-off-by: Florian Walbroel <walbroel@roofline.ai>
diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
index b0fd806..9208c5b 100644
--- a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
@@ -63,7 +63,10 @@
       iree_compiler::createConverti48Toi64Pass());
 
   // Sometimes we generate more TOSA operations during the lowering to linalg.
-  passManager.addNestedPass<func::FuncOp>(createTosaToArithPass());
+  TosaToArithPassOptions tosaToArithPassOptions;
+  tosaToArithPassOptions.includeApplyRescale = true;
+  passManager.addNestedPass<func::FuncOp>(
+      createTosaToArithPass(tosaToArithPassOptions));
   passManager.addNestedPass<func::FuncOp>(createTosaToTensorPass());
 
   passManager.addNestedPass<func::FuncOp>(
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
index 9c9e6a6..ce2ef38 100644
--- a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
@@ -18,6 +18,7 @@
         # keep sorted
         [
             "apply_pdl_patterns_tosa.mlir",
+            "apply_scale_lowering.mlir",
             "auto_input_conversion.mlir",
             "convert_i48_to_i64.mlir",
             "strip_signedness.mlir",
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
index 6c58e41..799d324 100644
--- a/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
@@ -15,6 +15,7 @@
     lit
   SRCS
     "apply_pdl_patterns_tosa.mlir"
+    "apply_scale_lowering.mlir"
     "auto_input_conversion.mlir"
     "convert_i48_to_i64.mlir"
     "strip_signedness.mlir"
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/apply_scale_lowering.mlir b/compiler/plugins/input/TOSA/InputConversion/test/apply_scale_lowering.mlir
new file mode 100644
index 0000000..6e996a5
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/apply_scale_lowering.mlir
@@ -0,0 +1,29 @@
+// RUN: iree-compile --compile-to=input --split-input-file %s | FileCheck %s
+
+// Make sure that tosa.apply_scale operations generated during the TOSA to
+// linalg lowering are properly lowered by the TOSA to arith lowering.
+
+// tosa.mul with a non-zero integer shift lowers through tosa.apply_scale
+//
+// CHECK-LABEL: util.func public @shifted_mul
+// CHECK-NOT: tosa.apply_scale
+// CHECK: return
+func.func @shifted_mul(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+  %shift = "tosa.const"() {values = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
+// Quantized tosa.avg_pool2d also lowers through tosa.apply_scale
+//
+// CHECK-LABEL: util.func public @quantized_avg_pool
+// CHECK-NOT: tosa.apply_scale
+// CHECK: return
+func.func @quantized_avg_pool(%arg0: tensor<1x2x4x1xi8>) -> tensor<1x1x3x1xi8> {
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x2x4x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x3x1xi8>
+  return %0 : tensor<1x1x3x1xi8>
+}