Update quantized matmul tests to not generate stack allocation (#8562)
This fixes the tests that compare quantized matmul against plain matmul.
Function arguments are read-only tensors in IREE, so we should not
create test cases that pass zero_acc as a function argument.
diff --git a/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir b/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir
index 58dfe83..4273127 100644
--- a/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir
+++ b/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir
@@ -10,12 +10,15 @@
// Reference: Section 2.3 of https://arxiv.org/abs/1712.05877.
// Equivalent to linalg.quantized_matmul, but not using linalg.quantized_matmul
-func private @quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<3x5xi32>) -> tensor<3x5xi32> {
+func private @quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32) -> tensor<3x5xi32> {
+ %c_0 = arith.constant 0 : i32
+ %init_acc_uninitialized = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+ %zero_acc = linalg.fill ins(%c_0 : i32) outs(%init_acc_uninitialized : tensor<3x5xi32>) -> tensor<3x5xi32>
+
// compute the matmul itself, which would be the end result already in the case
// where both zero-point values %lhs_zp and %rhs_zp are zero.
- %matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<3x4xi8>, tensor<4x5xi8>) outs(%acc : tensor<3x5xi32>) -> tensor<3x5xi32>
+ %matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<3x4xi8>, tensor<4x5xi8>) outs(%zero_acc : tensor<3x5xi32>) -> tensor<3x5xi32>
- %c_0 = arith.constant 0 : i32
%k_size = arith.constant 4 : i32 // = dim 1 of %lhs = dim 0 of %rhs
// compute the sums along rows of %lhs.
@@ -49,7 +52,6 @@
} -> tensor<5xi32>
// add all the terms together.
- %init_acc_uninitialized = linalg.init_tensor [3, 5] : tensor<3x5xi32>
%quantized_matmul_from_matmul_result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
@@ -76,11 +78,7 @@
}
// Equivalent to linalg.quantized_matmul, but not using linalg.quantized_matmul
-func private @quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<?x?xi32>) -> tensor<?x?xi32> {
- // compute the matmul itself, which would be the end result already in the case
- // where both zero-point values %lhs_zp and %rhs_zp are zero.
- %matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
-
+func private @quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32) -> tensor<?x?xi32> {
%c_0_index = arith.constant 0 : index
%c_1_index = arith.constant 1 : index
%m_size = tensor.dim %lhs, %c_0_index : tensor<?x?xi8>
@@ -89,6 +87,12 @@
%k_size_i32 = arith.index_cast %k_size : index to i32
%c_0 = arith.constant 0 : i32
+ %init_acc_uninitialized = linalg.init_tensor [%m_size, %n_size] : tensor<?x?xi32>
+ %zero_acc = linalg.fill ins(%c_0 : i32) outs(%init_acc_uninitialized : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+ // compute the matmul itself, which would be the end result already in the case
+ // where both zero-point values %lhs_zp and %rhs_zp are zero.
+ %matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%zero_acc : tensor<?x?xi32>) -> tensor<?x?xi32>
// compute the sums along rows of %lhs.
%lhs_i32 = arith.extsi %lhs : tensor<?x?xi8> to tensor<?x?xi32>
@@ -121,7 +125,6 @@
} -> tensor<?xi32>
// add all the terms together.
- %init_acc_uninitialized = linalg.init_tensor [%m_size, %n_size] : tensor<?x?xi32>
%quantized_matmul_from_matmul_result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1) -> (d0, d1)>,
@@ -148,17 +151,29 @@
}
// Checks that linalg.quantized_matmul agrees with @quantized_matmul_as_matmul_3x4x5
-func private @check_one_quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<3x5xi32>) {
- %result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) outs(%acc : tensor<3x5xi32>) -> tensor<3x5xi32>
- %result_of_quantized_matmul_as_matmul = call @quantized_matmul_as_matmul_3x4x5(%lhs, %rhs, %lhs_zp, %rhs_zp, %acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> tensor<3x5xi32>
+func private @check_one_quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32) {
+ %c_0 = arith.constant 0 : i32
+ %init_acc_uninitialized = linalg.init_tensor [3, 5] : tensor<3x5xi32>
+ %zero_acc = linalg.fill ins(%c_0 : i32) outs(%init_acc_uninitialized : tensor<3x5xi32>) -> tensor<3x5xi32>
+ %result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) outs(%zero_acc : tensor<3x5xi32>) -> tensor<3x5xi32>
+ %result_of_quantized_matmul_as_matmul = call @quantized_matmul_as_matmul_3x4x5(%lhs, %rhs, %lhs_zp, %rhs_zp) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> tensor<3x5xi32>
check.expect_eq(%result_of_quantized_matmul, %result_of_quantized_matmul_as_matmul) : tensor<3x5xi32>
return
}
// Checks that linalg.quantized_matmul agrees with @quantized_matmul_as_matmul_dynamic
-func private @check_one_quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<?x?xi32>) {
- %result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
- %result_of_quantized_matmul_as_matmul = call @quantized_matmul_as_matmul_dynamic(%lhs, %rhs, %lhs_zp, %rhs_zp, %acc) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32, tensor<?x?xi32>) -> tensor<?x?xi32>
+func private @check_one_quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32) {
+ %c_0_index = arith.constant 0 : index
+ %c_1_index = arith.constant 1 : index
+ %m_size = tensor.dim %lhs, %c_0_index : tensor<?x?xi8>
+ %n_size = tensor.dim %rhs, %c_1_index : tensor<?x?xi8>
+
+ %c_0 = arith.constant 0 : i32
+ %init_acc_uninitialized = linalg.init_tensor [%m_size, %n_size] : tensor<?x?xi32>
+ %zero_acc = linalg.fill ins(%c_0 : i32) outs(%init_acc_uninitialized : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+ %result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%zero_acc : tensor<?x?xi32>) -> tensor<?x?xi32>
+ %result_of_quantized_matmul_as_matmul = call @quantized_matmul_as_matmul_dynamic(%lhs, %rhs, %lhs_zp, %rhs_zp) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) -> tensor<?x?xi32>
check.expect_eq(%result_of_quantized_matmul, %result_of_quantized_matmul_as_matmul) : tensor<?x?xi32>
return
}
@@ -191,22 +206,19 @@
%c_minus128 = arith.constant -128 : i32
%c_plus127 = arith.constant 127 : i32
- %init_acc_uninitialized = linalg.init_tensor [3, 5] : tensor<3x5xi32>
- %zero_acc = linalg.fill ins(%c_0 : i32) outs(%init_acc_uninitialized : tensor<3x5xi32>) -> tensor<3x5xi32>
// Test special case: both zero points are 0
- call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_0, %c_0, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+ call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_0, %c_0) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> ()
// Test special cases: one of the zero points is 0
- call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_0, %c_plus3, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
- call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_0, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+ call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_0, %c_plus3) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> ()
+ call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_0) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> ()
// Test general cases: both zero points are nonzero
- call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_plus3, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
- call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_plus41, %c_minus57, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
- call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_minus128, %c_plus127, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+ call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_plus3) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> ()
+ call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_plus41, %c_minus57) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> ()
+ call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_minus128, %c_plus127) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) -> ()
%lhs_3x4_dynamic = tensor.cast %lhs_3x4_1 : tensor<3x4xi8> to tensor<?x?xi8>
%rhs_4x5_dynamic = tensor.cast %rhs_4x5_1 : tensor<4x5xi8> to tensor<?x?xi8>
- %zero_acc_dynamic = tensor.cast %zero_acc : tensor<3x5xi32> to tensor<?x?xi32>
- call @check_one_quantized_matmul_as_matmul_dynamic(%lhs_3x4_dynamic, %rhs_4x5_dynamic, %c_minus128, %c_plus127, %zero_acc_dynamic) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32, tensor<?x?xi32>) -> ()
+ call @check_one_quantized_matmul_as_matmul_dynamic(%lhs_3x4_dynamic, %rhs_4x5_dynamic, %c_minus128, %c_plus127) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) -> ()
return
}