e2e test for linalg.quantized_matmul (vs linalg.matmul + generics) (#8370)

This test compares the output of linalg.quantized_matmul to an equivalent
computation where the matmul part itself is done by linalg.matmul and the
zero-points are handled by additional linalg.generic ops.

Reference: Section 2.3 of https://arxiv.org/abs/1712.05877.
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index bc0e05b..9a81eda 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -53,6 +53,7 @@
             "dynamic_linalg_matmul_on_tensors_fuse_0.mlir",
             "dynamic_linalg_matmul_on_tensors_fuse_1.mlir",
             "dynamic_linalg_matmul_on_tensors_fuse_2.mlir",
+            "linalg_quantized_matmul_vs_linalg_matmul.mlir",
             "lowering_config.mlir",
         ] + BACKEND_TESTS,
     ),
@@ -68,6 +69,7 @@
 iree_check_single_backend_test_suite(
     name = "check_regression_dylib-llvm-aot",
     srcs = [
+        "linalg_quantized_matmul_vs_linalg_matmul.mlir",
         "lowering_config.mlir",
     ] + BACKEND_TESTS,
     compiler_flags = ["-iree-input-type=mhlo"],
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index 31cc3df..920dfec 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -43,6 +43,7 @@
     "dynamic_torch_index_select_vector.mlir"
     "linalg_ext_ops.mlir"
     "linalg_ops.mlir"
+    "linalg_quantized_matmul_vs_linalg_matmul.mlir"
     "lowering_config.mlir"
   TARGET_BACKEND
     "dylib-llvm-aot"
diff --git a/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir b/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir
new file mode 100644
index 0000000..52b117a
--- /dev/null
+++ b/iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir
@@ -0,0 +1,121 @@
+// This test compares the output of linalg.quantized_matmul to an equivalent
+// computation where the matmul part itself is done by linalg.matmul and the
+// zero-points are handled by additional linalg.generic ops.
+//
+// Reference: Section 2.3 of https://arxiv.org/abs/1712.05877.
+
+// Equivalent to linalg.quantized_matmul, but not using linalg.quantized_matmul
+func private @quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>,  %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<3x5xi32>) -> tensor<3x5xi32> {
+  // compute the matmul itself, which would be the end result already in the case
+  // where both zero-point values %lhs_zp and %rhs_zp are zero.
+  %matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<3x4xi8>, tensor<4x5xi8>) outs(%acc : tensor<3x5xi32>) -> tensor<3x5xi32>
+
+  %c_0 = arith.constant 0 : i32
+  %k_size = arith.constant 4 : i32  // = dim 1 of %lhs = dim 0 of %rhs
+
+  // compute the sums along rows of %lhs.
+  %lhs_i32 = arith.extsi %lhs : tensor<3x4xi8> to tensor<3x4xi32>
+  %init_lhs_sums_uninitialized = linalg.init_tensor [3] : tensor<3xi32>
+  %zero_lhs_sums = linalg.fill(%c_0, %init_lhs_sums_uninitialized) : i32, tensor<3xi32> -> tensor<3xi32>
+  %lhs_sums = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%lhs_i32 : tensor<3x4xi32>)
+      outs(%zero_lhs_sums : tensor<3xi32>) {
+      ^bb0(%arg0: i32, %arg1: i32) :
+          %1 = arith.addi %arg0, %arg1 : i32
+          linalg.yield %1 : i32
+      } -> tensor<3xi32>
+
+  // compute the sums along columns of %rhs.
+  %rhs_i32 = arith.extsi %rhs : tensor<4x5xi8> to tensor<4x5xi32>
+  %init_rhs_sums_uninitialized = linalg.init_tensor [5] : tensor<5xi32>
+  %zero_rhs_sums = linalg.fill(%c_0, %init_rhs_sums_uninitialized) : i32, tensor<5xi32> -> tensor<5xi32>
+  %rhs_sums = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d1)>],
+      iterator_types = ["reduction", "parallel"]}
+      ins(%rhs_i32 : tensor<4x5xi32>)
+      outs(%zero_rhs_sums : tensor<5xi32>) {
+      ^bb0(%arg0: i32, %arg1: i32) :
+          %1 = arith.addi %arg0, %arg1 : i32
+          linalg.yield %1 : i32
+      } -> tensor<5xi32>
+
+  // add all the terms together.
+  %init_acc_uninitialized =  linalg.init_tensor [3, 5] : tensor<3x5xi32>
+  %quantized_matmul_from_matmul_result = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1) -> (d0, d1)>,
+        affine_map<(d0, d1) -> (d0)>,
+        affine_map<(d0, d1) -> (d1)>,
+        affine_map<(d0, d1) -> ()>,
+        affine_map<(d0, d1) -> ()>,
+        affine_map<(d0, d1) -> ()>,
+        affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%matmul_result, %lhs_sums, %rhs_sums, %lhs_zp, %rhs_zp, %k_size : tensor<3x5xi32>, tensor<3xi32>, tensor<5xi32>, i32, i32, i32)
+      outs(%init_acc_uninitialized : tensor<3x5xi32>) {
+      ^bb0(%matmul_result_val : i32, %lhs_sums_val: i32, %rhs_sums_val: i32, %lhs_zp_val: i32, %rhs_zp_val: i32, %k : i32, %acc_val: i32) :
+          %linear_term_in_rhs_zp = arith.muli %lhs_sums_val, %rhs_zp_val : i32
+          %linear_term_in_lhs_zp = arith.muli %rhs_sums_val, %lhs_zp_val : i32
+          %linear_term = arith.addi %linear_term_in_rhs_zp, %linear_term_in_lhs_zp : i32
+          %product_of_zp = arith.muli %lhs_zp_val, %rhs_zp_val : i32
+          %quadratic_term = arith.muli %k, %product_of_zp : i32
+          %corrected_for_linear_term = arith.subi %matmul_result_val, %linear_term : i32
+          %corrected = arith.addi %corrected_for_linear_term, %quadratic_term : i32
+          linalg.yield %corrected : i32
+      } -> tensor<3x5xi32>
+  return %quantized_matmul_from_matmul_result : tensor<3x5xi32>
+}
+
+// Checks that linalg.quantized_matmul agrees with @quantized_matmul_as_matmul_3x4x5
+func private @check_one_quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<3x5xi32>) {
+    %result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) outs(%acc : tensor<3x5xi32>) -> tensor<3x5xi32>
+    %result_of_quantized_matmul_as_matmul = call @quantized_matmul_as_matmul_3x4x5(%lhs, %rhs, %lhs_zp, %rhs_zp, %acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> tensor<3x5xi32>
+    check.expect_eq(%result_of_quantized_matmul, %result_of_quantized_matmul_as_matmul) : tensor<3x5xi32>
+    return
+}
+
+func @test_quantized_matmul_as_matmul_3x4x5() {
+  %lhs_3x4_1 = util.unfoldable_constant dense<[
+      [1, 2, 3, 4],
+      [5, 6, 7, 8],
+      [9, 10, 11, 12]]> : tensor<3x4xi8>
+  %rhs_4x5_1 = util.unfoldable_constant dense<[
+      [5, 4, 3, 2, 9],
+      [1, 0, -1, -2, 8],
+      [-3, -4, -5, -6, 7],
+      [2, 3, 5, 7, 11]]> : tensor<4x5xi8>
+  // matrices with larger values including the interval bounds -128 and +127.
+  %lhs_3x4_2 = util.unfoldable_constant dense<[
+      [127, -128, 0, 51],
+      [-47, 101, -119, 0],
+      [-128, 89, -63, 127]]> : tensor<3x4xi8>
+  %rhs_4x5_2 = util.unfoldable_constant dense<[
+      [123, -125, 127, -128, 91],
+      [-70, 37, 0, -40, 57],
+      [-128, 127, -121, -100, 99],
+      [127, 105, 83, 51, -128]]> : tensor<4x5xi8>
+  %c_0 = arith.constant 0 : i32
+  %c_minus2 = arith.constant -2 : i32
+  %c_plus3 = arith.constant 3 : i32
+  %c_plus41 = arith.constant 41 : i32
+  %c_minus57 = arith.constant -57 : i32
+  %c_minus128 = arith.constant -128 : i32
+  %c_plus127 = arith.constant 127 : i32
+
+  %init_acc_uninitialized =  linalg.init_tensor [3, 5] : tensor<3x5xi32>
+  %zero_acc = linalg.fill(%c_0, %init_acc_uninitialized) : i32, tensor<3x5xi32> -> tensor<3x5xi32>
+  // Test special case: both zero points are 0
+  call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_0, %c_0, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+  // Test special cases: one of the zero points is 0
+  call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_0, %c_plus3, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+  call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_0, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+  // Test general cases: both zero points are nonzero
+  call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_plus3, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+  call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_plus41, %c_minus57, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+  call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_minus128, %c_plus127, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
+  return
+}