Limit the latency of e2e matmul tests  (#15180)

Now that I have an AMD workstation, I'm running e2e matmul tests with
Vulkan for the first time, and I'm horrified at how slow they are: test
latencies are around 1 minute.

Profiling shows that this is due to the slow reference matmul code that
the results are compared against. It has nothing to do with the GPU; it's
just that the GPU e2e matmul tests are the only ones exercising really
large sizes. The reference code is roughly equally slow on all element
types.

There isn't a perfect solution to that problem. Tests need to compare
against some reference. Golden test data is a pain to maintain.
"Reference" implementations that are actually optimized are both less
reliable and more work to maintain. And so on.

I do believe, though, that test latency is very important, as it affects
everyone's productivity.

This PR adds a command-line flag to the e2e matmul test runner:

```c
IREE_FLAG(
    int32_t, max_elements_to_check, 10000,
    "Maximum number of matrix elements to check for each matmul. For larger "
    "matrices, only every n-th element will be checked for some n chosed to "
    "stay just under that threshold and to avoid being a divisor of the inner "
    "dimension size to avoid special patterns. As the check uses a slow "
    "reference implementation, this is a trade-off between test latency and "
    "coverage. The value 0 means check all elements.");
```

Timing this test command:
```
ctest -R e2e_matmul_.*_vulkan -j32
```

Before: total latency 48 seconds, `266.22 sec*proc`.
After: total latency 3.7 seconds, `20.43 sec*proc`.

=> 13x reduction in latency
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index d1082a2..83b7343 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -27,6 +27,14 @@
 IREE_FLAG(
     float, acceptable_fp_delta, 1e-5,
     "Maximum absolute difference allowed with inexact floating point results.");
+IREE_FLAG(
+    int32_t, max_elements_to_check, 10000,
+    "Maximum number of matrix elements to check for each matmul. For larger "
+    "matrices, only every n-th element will be checked for some n chosed to "
+    "stay just under that threshold and to avoid being a divisor of the inner "
+    "dimension size to avoid special patterns. As the check uses a slow "
+    "reference implementation, this is a trade-off between test latency and "
+    "coverage. The value 0 means check all elements.");
 
 IREE_FLAG(bool, trace_execution, false, "Traces VM execution to stderr.");
 
@@ -552,7 +560,8 @@
 
 // Reference matmul implementation, used to compare matmul results against.
 static iree_status_t reference_matmul(iree_vm_list_t* input_list,
-                                      iree_hal_buffer_view_t* result) {
+                                      iree_hal_buffer_view_t* result,
+                                      int compute_every) {
   iree_hal_buffer_view_t* lhs = NULL;
   iree_hal_buffer_view_t* rhs = NULL;
   iree_hal_buffer_view_t* acc = NULL;
@@ -581,8 +590,11 @@
   iree_hal_element_type_t lhs_type = iree_hal_buffer_view_element_type(lhs);
   iree_hal_element_type_t rhs_type = iree_hal_buffer_view_element_type(rhs);
   iree_hal_element_type_t acc_type = iree_hal_buffer_view_element_type(result);
+  int count = 0;
   for (iree_hal_dim_t m = 0; m < m_size; ++m) {
     for (iree_hal_dim_t n = 0; n < n_size; ++n) {
+      if (++count < compute_every) continue;
+      count = 0;
       reference_matmul_element(m_size, k_size, n_size, lhs_type, rhs_type,
                                acc_type, lhs_mapping.contents.data,
                                rhs_mapping.contents.data,
@@ -792,8 +804,11 @@
     iree_hal_dim_t col, iree_hal_buffer_view_t* lhs,
     iree_hal_buffer_view_t* rhs, iree_hal_buffer_view_t* acc,
     iree_hal_buffer_view_t* actual_result,
-    iree_hal_buffer_view_t* expected_result) {
-  if (!file) {
+    iree_hal_buffer_view_t* expected_result, int check_every) {
+  if (!file || check_every > 1) {
+    // No logging of errors with check_every>1 as most of the reference matrix
+    // elements have not been computed. The caller is expected to retry with
+    // check_every=1.
     return iree_make_status(IREE_STATUS_ABORTED);
   }
   fprintf(file,
@@ -870,7 +885,7 @@
     iree_hal_dim_t n_size, iree_hal_buffer_view_t* lhs,
     iree_hal_buffer_view_t* rhs, iree_hal_buffer_view_t* acc,
     iree_hal_buffer_view_t* actual_result,
-    iree_hal_buffer_view_t* expected_result) {
+    iree_hal_buffer_view_t* expected_result, int check_every) {
   iree_hal_buffer_mapping_t actual_result_mapping;
   iree_hal_buffer_mapping_t expected_result_mapping;
   IREE_RETURN_IF_ERROR(map_host_local_row_major_data(
@@ -879,8 +894,11 @@
       expected_result, IREE_HAL_MEMORY_ACCESS_READ, &expected_result_mapping));
   iree_hal_element_type_t result_type =
       iree_hal_buffer_view_element_type(actual_result);
+  int count = 0;
   for (iree_hal_dim_t m = 0; m < m_size; ++m) {
     for (iree_hal_dim_t n = 0; n < n_size; ++n) {
+      if (++count < check_every) continue;
+      count = 0;
       iree_e2e_test_value_t actual_value =
           read_matrix_element(m_size, n_size, result_type,
                               actual_result_mapping.contents.data, m, n);
@@ -890,7 +908,7 @@
       if (!matmul_result_elements_agree(actual_value, expected_value)) {
         return check_matmul_failure(file, actual_value, expected_value, m, n,
                                     lhs, rhs, acc, actual_result,
-                                    expected_result);
+                                    expected_result, check_every);
       }
     }
   }
@@ -919,8 +937,31 @@
   IREE_RETURN_IF_ERROR(get_matmul_sizes(lhs, rhs, acc, actual_result, &m_size,
                                         &k_size, &n_size));
 
-  return check_matmul_results_impl(file, m_size, k_size, n_size, lhs, rhs, acc,
-                                   actual_result, expected_result);
+  int check_every = 1;
+  if (FLAG_max_elements_to_check) {
+    check_every = (iree_hal_buffer_view_element_count(actual_result) +
+                   FLAG_max_elements_to_check - 1) /
+                  FLAG_max_elements_to_check;
+    if (check_every < 1) check_every = 1;
+    if (check_every > 1)
+      while ((n_size % check_every) == 0) ++check_every;
+  }
+
+  IREE_CHECK_OK(reference_matmul(input_list, expected_result, check_every));
+
+  iree_status_t status =
+      check_matmul_results_impl(file, m_size, k_size, n_size, lhs, rhs, acc,
+                                actual_result, expected_result, check_every);
+
+  if (!iree_status_is_ok(status) && check_every > 1) {
+    // If we got a failure with check_every>1, that didn't log a useful
+    // numerical summary, as most of the reference matrix entries hadn't been
+    // computed. Rerun now with check_every=1 to get that numerical logging.
+    status = check_matmul_results_impl(file, m_size, k_size, n_size, lhs, rhs,
+                                       acc, actual_result, expected_result, 1);
+  }
+
+  return status;
 }
 
 /*****************************************************************************
@@ -1040,9 +1081,6 @@
                                                host_actual_result,
                                                &host_expected_result));
 
-  // Use the reference matmul implementation to fill host_expected_result
-  IREE_CHECK_OK(reference_matmul(host_inputs, host_expected_result));
-
   // Check that host_actual_result and host_expected_result agree.
   iree_status_t status = check_matmul_results(
       file, host_inputs, host_actual_result, host_expected_result);