Tweaks to e2e matmul tests (#15930)
This restricts the range of random `bf16` test matrix elements, so that
matmul tests with larger accumulation sizes no longer produce
accumulators that reach 256. Beyond that point, not every integer is
exactly representable in `bf16`, and accumulation becomes very
inaccurate. This is what was causing test failures, specifically on
large tests, in
https://github.com/openxla/iree/pull/15911 .
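This also fixes the computation of the window of the test matrices that
is displayed on error. The problem showed up when diagnosing a matmul
test error in cases where M and N are both small and K is larger: the
old code derived the K range of the window from the M and N ranges. The
new code always starts the K range at 0 and ends it at `2 * context`
(clamped to the K size), which is also just simpler.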
diff --git a/runtime/src/iree/tooling/trace_replay.c b/runtime/src/iree/tooling/trace_replay.c
index 5f2350a..7bfdcf7 100644
--- a/runtime/src/iree/tooling/trace_replay.c
+++ b/runtime/src/iree/tooling/trace_replay.c
@@ -392,10 +392,13 @@
case IREE_HAL_ELEMENT_TYPE_INT_16:
case IREE_HAL_ELEMENT_TYPE_SINT_16:
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
- case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
*min = -4;
*max = +4;
break;
+ case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
+ *min = -2;
+ *max = +2;
+ break;
case IREE_HAL_ELEMENT_TYPE_UINT_16:
*min = 0;
*max = +4;
diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c
index c2a1fc9..5d76c48 100644
--- a/tools/iree-e2e-matmul-test.c
+++ b/tools/iree-e2e-matmul-test.c
@@ -828,12 +828,8 @@
int m_end = iree_min(m_size, row + context);
int n_start = iree_max(0, (int)col - (int)context);
int n_end = iree_min(n_size, col + context);
- // We have a lot more freedom to pick k_start, k_end, since these parameters
- // only affect which regions of the input lhs and rhs matrices are printed.
- // If we were only testing random lhs and rhs, we would just pick
- // k_start = 0 and any reasonable k_end value.
- int k_start = iree_max(0, iree_min(m_start, n_start));
- int k_end = iree_min(k_size, iree_max(m_end, n_end));
+ int k_start = 0;
+ int k_end = iree_min(k_size, 2 * context);
// [k_start, k_end) could be arbitrarily long at this point. Constrain it a
// bit to avoid huge output.
k_end = iree_min(k_end, k_start + 4 * context);