Simplify indexed load testing

Non-segmented indexed-load cases are now run through
vector_load_segmented_indexed with 'segments': 1, so the separate
vector_load_indexed template is removed.

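Minor improvement to the expected output computation: the byte indices
in use are now built with NumPy broadcasting instead of a nested list
comprehension.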

Change-Id: Ide24ee6e4e33bcb04966266b9d8c43cf2638723c
diff --git a/tests/cocotb/rvv_load_store_test.py b/tests/cocotb/rvv_load_store_test.py
index bc79775..58f30ce 100644
--- a/tests/cocotb/rvv_load_store_test.py
+++ b/tests/cocotb/rvv_load_store_test.py
@@ -116,69 +116,6 @@
         assert (actual_outputs == expected_outputs).all(), debug_msg
 
 
-async def vector_load_indexed(
-        dut,
-        elf_name: str,
-        cases: list[dict],  # keys: impl, vl, in_bytes, out_size.
-        dtype,
-        index_dtype,
-):
-    """RVV load-store test template for indexed loads.
-
-    Each test performs a gather operation and writes the result to an output.
-    """
-    fixture = await Fixture.Create(dut)
-    r = runfiles.Create()
-    await fixture.load_elf_and_lookup_symbols(
-        r.Rlocation('kelvin_hw/tests/cocotb/rvv/load_store/' + elf_name),
-        ['impl', 'vl', 'in_buf', 'out_buf', 'index_buf'] +
-            list({c['impl'] for c in cases}),
-    )
-
-    rng = np.random.default_rng()
-    for c in tqdm.tqdm(cases):
-        impl = c['impl']
-        vl = c['vl']
-        in_bytes = c['in_bytes']
-        out_size = c['out_size']
-
-        # Don't go beyond the buffer.
-        index_max = in_bytes - np.dtype(dtype).itemsize + 1
-        # TODO(davidgao): currently assuming the vl is supported.
-        # We'll eventually want to test unsupported vl.
-        indices = rng.integers(0, index_max, out_size, dtype=index_dtype)
-        # Index is in bytes so input needs to be in bytes.
-        input_data = rng.integers(0, 256, in_bytes, dtype=np.uint8)
-        # Input needs to be reinterpreted. Note indices in use can reach
-        # beyond index_dtype when dtype is wider than uint8.
-        indices_in_use = np.array([
-            np.arange(x, x + np.dtype(dtype).itemsize)
-            for x in indices[:vl].astype(np.uint32)])
-        expected_outputs = input_data[indices_in_use].view(dtype)[..., 0]
-        sbz = np.zeros(out_size - vl, dtype=dtype)
-        expected_outputs = np.concat((expected_outputs, sbz))
-
-        await fixture.write_ptr('impl', impl)
-        await fixture.write_word('vl', vl)
-        await fixture.write('index_buf', indices)
-        await fixture.write('in_buf', input_data)
-        await fixture.write('out_buf', np.zeros([out_size], dtype=dtype))
-
-        await fixture.run_to_halt()
-
-        actual_outputs = (await fixture.read(
-            'out_buf', out_size * np.dtype(dtype).itemsize)).view(dtype)
-
-        debug_msg = str({
-            'impl': impl,
-            'input': input_data,
-            'indices': indices,
-            'expected': expected_outputs,
-            'actual': actual_outputs,
-        })
-        assert (actual_outputs == expected_outputs).all(), debug_msg
-
-
 async def vector_load_segmented_indexed(
         dut,
         elf_name: str,
@@ -216,13 +153,14 @@
         # Index is in bytes so input needs to be in bytes.
         input_data = rng.integers(0, 256, in_bytes, dtype=np.uint8)
         # Input needs to be reinterpreted. Note indices in use can reach
-        # beyond index_dtype when dtype is wider than uint8.
-        indices_in_use = np.array([
-            np.arange(x + s * np.dtype(dtype).itemsize, x + (s + 1) * np.dtype(dtype).itemsize)
-            for s in range(segments)
-            for x in indices[:vl].astype(np.uint32)
-        ])
-        expected_outputs = input_data[indices_in_use].view(dtype)[..., 0]
+        # beyond index_dtype when dtype is wider than uint8 or when segments
+        # is >1.
+        indices_in_use = \
+            np.arange(segments).reshape(-1, 1, 1) * np.dtype(dtype).itemsize + \
+            indices[:vl].reshape(1, -1, 1) + \
+            np.arange(np.dtype(dtype).itemsize).reshape(1, 1, -1)
+        indices_in_use = indices_in_use.reshape(-1)
+        expected_outputs = input_data[indices_in_use].view(dtype)
         sbz = np.zeros(out_size - vl * segments, dtype=dtype)
         expected_outputs = np.concat((expected_outputs, sbz))
 
@@ -693,11 +631,12 @@
         return {
             'impl': impl,
             'vl': vl,
+            'segments': 1,
             'in_bytes': 256,
             'out_size': vl * 2,
         }
 
-    await vector_load_indexed(
+    await vector_load_segmented_indexed(
         dut = dut,
         elf_name = 'load8_index8.elf',
         cases = [
@@ -876,11 +815,12 @@
         return {
             'impl': impl,
             'vl': vl,
+            'segments': 1,
             'in_bytes': 32000,  # DTCM is 32KB
             'out_size': vl * 2,
         }
 
-    await vector_load_indexed(
+    await vector_load_segmented_indexed(
         dut = dut,
         elf_name = 'load8_index16.elf',
         cases = [
@@ -1135,11 +1075,12 @@
         return {
             'impl': impl,
             'vl': vl,
+            'segments': 1,
             'in_bytes': 32000,  # DTCM is 32KB
             'out_size': vl * 2,
         }
 
-    await vector_load_indexed(
+    await vector_load_segmented_indexed(
         dut = dut,
         elf_name = 'load8_index32.elf',
         cases = [
@@ -1174,11 +1115,12 @@
         return {
             'impl': impl,
             'vl': vl,
+            'segments': 1,
             'in_bytes': 257,  # 2 bytes at offset 255 reachable.
             'out_size': vl * 2,
         }
 
-    await vector_load_indexed(
+    await vector_load_segmented_indexed(
         dut = dut,
         elf_name = 'load16_index8.elf',
         cases = [
@@ -1499,11 +1441,12 @@
         return {
             'impl': impl,
             'vl': vl,
+            'segments': 1,
             'in_bytes': 259,  # 4 bytes at offset 255 reachable.
             'out_size': vl * 2,
         }
 
-    await vector_load_indexed(
+    await vector_load_segmented_indexed(
         dut = dut,
         elf_name = 'load32_index8.elf',
         cases = [