compiler/plugins/input/TOSA: fix: TOSA arith lowering must handle apply scale introduced by linalg lowering (#24121)
TOSA to linalg lowering may (re-)introduce some additional TOSA
operations as not all TOSA operations can be lowering to linalg. These
need a following run of TOSA to arith lowering for element wise
operations. The current pass pipeline is missing the option to enable
lowering of tosa.apply_scale, which may be introduced during the
lowering to linalg. This causes errors in the later stage of the
compilation flow such as vectorization.
Signed-off-by: Florian Walbroel <walbroel@roofline.ai>
diff --git a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
index b0fd806..9208c5b 100644
--- a/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
+++ b/compiler/plugins/input/TOSA/InputConversion/Passes.cpp
@@ -63,7 +63,10 @@
iree_compiler::createConverti48Toi64Pass());
// Sometimes we generate more TOSA operations during the lowering to linalg.
- passManager.addNestedPass<func::FuncOp>(createTosaToArithPass());
+ TosaToArithPassOptions tosaToArithPassOptions;
+ tosaToArithPassOptions.includeApplyRescale = true;
+ passManager.addNestedPass<func::FuncOp>(
+ createTosaToArithPass(tosaToArithPassOptions));
passManager.addNestedPass<func::FuncOp>(createTosaToTensorPass());
passManager.addNestedPass<func::FuncOp>(
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
index 9c9e6a6..ce2ef38 100644
--- a/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
+++ b/compiler/plugins/input/TOSA/InputConversion/test/BUILD.bazel
@@ -18,6 +18,7 @@
# keep sorted
[
"apply_pdl_patterns_tosa.mlir",
+ "apply_scale_lowering.mlir",
"auto_input_conversion.mlir",
"convert_i48_to_i64.mlir",
"strip_signedness.mlir",
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
index 6c58e41..799d324 100644
--- a/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
+++ b/compiler/plugins/input/TOSA/InputConversion/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"apply_pdl_patterns_tosa.mlir"
+ "apply_scale_lowering.mlir"
"auto_input_conversion.mlir"
"convert_i48_to_i64.mlir"
"strip_signedness.mlir"
diff --git a/compiler/plugins/input/TOSA/InputConversion/test/apply_scale_lowering.mlir b/compiler/plugins/input/TOSA/InputConversion/test/apply_scale_lowering.mlir
new file mode 100644
index 0000000..6e996a5
--- /dev/null
+++ b/compiler/plugins/input/TOSA/InputConversion/test/apply_scale_lowering.mlir
@@ -0,0 +1,29 @@
+// RUN: iree-compile --compile-to=input --split-input-file %s | FileCheck %s
+
+// Make sure that tosa.apply_scale operations generated during TOSA to
+// linalg lowering, are properly lowered during TOSA to arith lowering
+
+// tosa.mul with a non-zero integer shift lowers through tosa.apply_scale
+//
+// CHECK-LABEL: util.func public @shifted_mul
+// CHECK-NOT: tosa.apply_scale
+// CHECK: return
+func.func @shifted_mul(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
+ %shift = "tosa.const"() {values = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ return %0 : tensor<4xi32>
+}
+
+// -----
+
+// Quantized tosa.avg_pool2d also lowers through tosa.apply_scale
+//
+// CHECK-LABEL: util.func public @quantized_avg_pool
+// CHECK-NOT: tosa.apply_scale
+// CHECK: return
+func.func @quantized_avg_pool(%arg0: tensor<1x2x4x1xi8>) -> tensor<1x1x3x1xi8> {
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, stride = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x2x4x1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x3x1xi8>
+ return %0 : tensor<1x1x3x1xi8>
+}