Add tests for strided/strided-segmented loads/stores.

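Check combinations of SEW and EEW. Also rework the LSU base address
update for strided accesses: instead of sharing the unit-stride
increment, the base now advances by the element stride multiplied by
bytesPerSlot, bytesPerSlot/2 or bytesPerSlot/4 depending on the element
width.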

Change-Id: I733ca6b95c3abd8b1200ba9f003eb59b333aa353
diff --git a/hdl/chisel/src/kelvin/scalar/Lsu.scala b/hdl/chisel/src/kelvin/scalar/Lsu.scala
index e62d108..755309a 100644
--- a/hdl/chisel/src/kelvin/scalar/Lsu.scala
+++ b/hdl/chisel/src/kelvin/scalar/Lsu.scala
@@ -504,9 +504,19 @@
     result.baseAddr := MuxCase(baseAddr, Seq(
       (!writeback || !lmulUpdate) -> baseAddr,
       // For Unit and strided updates
-      op.isOneOf(LsuOp.VLOAD_UNIT, LsuOp.VSTORE_UNIT,
-                 LsuOp.VLOAD_STRIDED, LsuOp.VSTORE_STRIDED) ->
+      op.isOneOf(LsuOp.VLOAD_UNIT, LsuOp.VSTORE_UNIT) ->
           (baseAddr + (vectorLoop.segment.max * 16.U) + 16.U),
+      op.isOneOf(LsuOp.VLOAD_STRIDED, LsuOp.VSTORE_STRIDED) ->
+          MuxCase(baseAddr + (elemStride * bytesPerSlot.U), Seq(
+            (elemWidth === "b000".U) ->
+                (baseAddr + (elemStride * bytesPerSlot.U)),
+            (elemWidth === "b101".U) ->
+                (baseAddr + (elemStride * (bytesPerSlot/2).U)),
+            (elemWidth === "b110".U) ->
+                (baseAddr + (elemStride * (bytesPerSlot/4).U)),
+          ))
+          // Strided step: elemStride * (bytesPerSlot / N), where N is 1, 2 or
+          // 4 depending on elemWidth (presumably the element size in bytes).
       // Indexed don't have base addr changed.
     ))
     result.rd := result.vectorLoop.rd
diff --git a/tests/cocotb/BUILD b/tests/cocotb/BUILD
index 69bce27..21b6f84 100644
--- a/tests/cocotb/BUILD
+++ b/tests/cocotb/BUILD
@@ -176,6 +176,7 @@
     "load_store_bits",
     "load_unit_masked",
     "load_unit_all_vtypes_test",
+    "load_strided_all_vtypes_test",
     "load8_index8",
     "load8_index8_seg",
     "load8_index16",
@@ -206,6 +207,7 @@
     "load16_segment2_stride6_m1",
     "store_unit_masked",
     "store_unit_all_vtypes_test",
+    "store_strided_all_vtypes_test",
     "store8_index8",
     "store8_index8_seg",
     "store16_index8",
diff --git a/tests/cocotb/rvv/load_store/BUILD b/tests/cocotb/rvv/load_store/BUILD
index 32ede03..73d5979 100644
--- a/tests/cocotb/rvv/load_store/BUILD
+++ b/tests/cocotb/rvv/load_store/BUILD
@@ -29,6 +29,9 @@
         "load_unit_vtype": {
             "srcs": ["load_unit_vtype.cc"],
         },
+        "load_stride_vtype": {
+            "srcs": ["load_stride_vtype.cc"],
+        },
         "load8_index8": {
             "srcs": ["load8_index8.cc"],
         },
@@ -152,6 +155,9 @@
         "store_unit_vtype": {
             "srcs": ["store_unit_vtype.cc"],
         },
+        "store_strided_vtype": {
+            "srcs": ["store_strided_vtype.cc"],
+        },
     },
 )
 
@@ -161,6 +167,7 @@
         ":load_store_bits.elf",
         ":load_unit_masked.elf",
         ":load_unit_vtype.elf",
+        ":load_stride_vtype.elf",
         ":load8_index8.elf",
         ":load8_index8_seg.elf",
         ":load8_index16.elf",
@@ -202,5 +209,6 @@
         ":store32_seg_unit",
         ":store_unit_masked.elf",
         ":store_unit_vtype.elf",
+        ":store_strided_vtype.elf",
     ],
 )
\ No newline at end of file
diff --git a/tests/cocotb/rvv/load_store/load_stride_vtype.cc b/tests/cocotb/rvv/load_store/load_stride_vtype.cc
new file mode 100644
index 0000000..8658e0e
--- /dev/null
+++ b/tests/cocotb/rvv/load_store/load_stride_vtype.cc
@@ -0,0 +1,90 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <riscv_vector.h>
+#include <stdint.h>
+
+// Tests strided and segmented strided loads after setting vtype. The loaded
+// data should not depend on the SEW programmed in vtype.
+// The cocotb testbench locates the symbols below by name and pokes/reads them.
+size_t vl __attribute__((section(".data"))) = 16;
+size_t vtype __attribute__((section(".data"))) = 0;
+size_t stride __attribute__((section(".data"))) = 1;  // Byte stride.
+uint8_t load_data[8192] __attribute__((section(".data")));
+uint8_t store_data[256] __attribute__((section(".data")));
+
+extern "C" {
+// Defines name(): vlse<data_bits>.v into v8, then stores v8-v15 to store_data.
+#define CREATE_STRIDED_LOAD_FN(name, data_bits)                           \
+  __attribute__((used, retain)) void name() {                             \
+    size_t store_vl = 8 * __riscv_vlenb();                                \
+    asm("vsetvl zero, %[vl], %[vtype];"                                    \
+        "vlse" #data_bits ".v v8, %[load_data], %[stride];"                \
+        "vsetvli zero, %[store_vl], e8, m8, ta, ma;"                       \
+        "vse8.v v8, %[store_data];"                                       \
+        : [store_data] "=m"(store_data)                                   \
+        : [vl] "r"(vl), [store_vl] "r"(store_vl), [vtype] "r"(vtype),      \
+          [stride] "r"(stride), [load_data] "m"(load_data)                \
+        : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "vl",      \
+          "vtype");                                                       \
+  }
+// Defines name(): vlsseg<segment>e<data_bits>.v into v8, then stores v8-v15.
+#define CREATE_SEGMENT_STRIDED_LOAD_FN(name, data_bits, segment)            \
+  __attribute__((used, retain)) void name() {                               \
+    size_t store_vl = 8 * __riscv_vlenb();                                  \
+    asm("vsetvl zero, %[vl], %[vtype];"                                      \
+        "vlsseg" #segment "e" #data_bits ".v v8, %[load_data], %[stride];"   \
+        "vsetvli zero, %[store_vl], e8, m8, ta, ma;"                         \
+        "vse8.v v8, %[store_data];"                                         \
+        : [store_data] "=m"(store_data)                                     \
+        : [vl] "r"(vl), [store_vl] "r"(store_vl), [vtype] "r"(vtype),        \
+          [stride] "r"(stride), [load_data] "m"(load_data)                  \
+        : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "vl",        \
+          "vtype");                                                         \
+  }
+// One entry point per element width (8, 16, 32) and segment count (1 to 8).
+CREATE_STRIDED_LOAD_FN(test_vlse8, 8)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg2e8, 8, 2)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg3e8, 8, 3)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg4e8, 8, 4)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg5e8, 8, 5)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg6e8, 8, 6)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg7e8, 8, 7)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg8e8, 8, 8)
+
+CREATE_STRIDED_LOAD_FN(test_vlse16, 16)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg2e16, 16, 2)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg3e16, 16, 3)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg4e16, 16, 4)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg5e16, 16, 5)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg6e16, 16, 6)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg7e16, 16, 7)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg8e16, 16, 8)
+
+CREATE_STRIDED_LOAD_FN(test_vlse32, 32)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg2e32, 32, 2)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg3e32, 32, 3)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg4e32, 32, 4)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg5e32, 32, 5)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg6e32, 32, 6)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg7e32, 32, 7)
+CREATE_SEGMENT_STRIDED_LOAD_FN(test_vlsseg8e32, 32, 8)
+}
+// Test to run; the cocotb testbench overwrites this pointer before each run.
+void (*impl)() __attribute__((section(".data"))) = &test_vlse8;
+
+int main(int argc, char** argv) {
+  impl();  // Runs whichever test function `impl` currently points to.
+  return 0;
+}
diff --git a/tests/cocotb/rvv/load_store/store_strided_vtype.cc b/tests/cocotb/rvv/load_store/store_strided_vtype.cc
new file mode 100644
index 0000000..0729ad2
--- /dev/null
+++ b/tests/cocotb/rvv/load_store/store_strided_vtype.cc
@@ -0,0 +1,96 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <riscv_vector.h>
+#include <stdint.h>
+
+// Tests strided stores after setting vtype. Strided stores should be vtype-agnostic.
+// The cocotb testbench locates the symbols below by name and pokes/reads them.
+size_t vl __attribute__((section(".data"))) = 16;
+size_t vtype __attribute__((section(".data"))) = 0;
+size_t stride __attribute__((section(".data"))) = 0;  // Byte stride.
+uint8_t load_data[256] __attribute__((section(".data")));
+uint8_t store_data[8192] __attribute__((section(".data")));
+
+extern "C" {
+// Defines name(): fills v8-v15 from load_data, then runs vsse<data_bits>.v.
+#define CREATE_STRIDED_STORE_FN(name, data_bits) \
+__attribute__((used, retain)) void name() { \
+  size_t load_vl = 8*__riscv_vlenb(); \
+  asm("vsetvli zero, %[load_vl], e8, m8, ta, ma;" \
+      "vle8.v v8, %[load_data];" \
+      "vsetvl zero, %[vl], %[vtype];" \
+      "vsse" #data_bits ".v v8, %[store_data], %[stride];" \
+      : [store_data] "=m"(store_data) \
+      : [vl] "r"(vl), \
+        [load_vl] "r"(load_vl), \
+        [vtype] "r"(vtype), \
+        [stride] "r"(stride), \
+        [load_data] "m"(load_data) \
+      : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", \
+        "vl", "vtype"); \
+}
+// Defines name(): fills v8-v15, then runs vssseg<segment>e<data_bits>.v.
+#define CREATE_STRIDED_SEGMENT_STORE_FN(name, data_bits, segment) \
+__attribute__((used, retain)) void name() { \
+  size_t load_vl = 8*__riscv_vlenb(); \
+  asm("vsetvli zero, %[load_vl], e8, m8, ta, ma;" \
+      "vle8.v v8, %[load_data];" \
+      "vsetvl zero, %[vl], %[vtype];" \
+      "vssseg" #segment "e" #data_bits ".v v8, %[store_data], %[stride];" \
+      : [store_data] "=m"(store_data) \
+      : [vl] "r"(vl), \
+        [load_vl] "r"(load_vl), \
+        [vtype] "r"(vtype), \
+        [stride] "r"(stride), \
+        [load_data] "m"(load_data) \
+      : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", \
+        "vl", "vtype"); \
+}
+// One entry point per element width (8, 16, 32) and segment count (1 to 8).
+CREATE_STRIDED_STORE_FN(test_vsse8, 8)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg2e8, 8, 2)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg3e8, 8, 3)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg4e8, 8, 4)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg5e8, 8, 5)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg6e8, 8, 6)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg7e8, 8, 7)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg8e8, 8, 8)
+
+CREATE_STRIDED_STORE_FN(test_vsse16, 16)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg2e16, 16, 2)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg3e16, 16, 3)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg4e16, 16, 4)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg5e16, 16, 5)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg6e16, 16, 6)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg7e16, 16, 7)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg8e16, 16, 8)
+
+CREATE_STRIDED_STORE_FN(test_vsse32, 32)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg2e32, 32, 2)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg3e32, 32, 3)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg4e32, 32, 4)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg5e32, 32, 5)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg6e32, 32, 6)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg7e32, 32, 7)
+CREATE_STRIDED_SEGMENT_STORE_FN(test_vssseg8e32, 32, 8)
+
+}
+// Test to run; the cocotb testbench overwrites this pointer before each run.
+void (*impl)() __attribute__((section(".data"))) = &test_vsse8;
+
+int main(int argc, char** argv) {
+  impl();  // Runs whichever test function `impl` currently points to.
+  return 0;
+}
diff --git a/tests/cocotb/rvv_load_store_test.py b/tests/cocotb/rvv_load_store_test.py
index bfb22d1..d9ff26c 100644
--- a/tests/cocotb/rvv_load_store_test.py
+++ b/tests/cocotb/rvv_load_store_test.py
@@ -2657,3 +2657,168 @@
             ]
             expected_result = np.transpose(np.stack(segment_regs)).flatten()
             assert (expected_result == store_data).all()
+
+
+@cocotb.test()
+async def load_strided_all_vtypes_test(dut):
+    """Checks strided and segmented-strided loads for every EEW, NF and LMUL."""
+    fixture = await Fixture.Create(dut)
+    r = runfiles.Create()
+    functions = [
+        ("test_vlse8",      np.uint8, 1),
+        ("test_vlsseg2e8",  np.uint8, 2),
+        ("test_vlsseg3e8",  np.uint8, 3),
+        ("test_vlsseg4e8",  np.uint8, 4),
+        ("test_vlsseg5e8",  np.uint8, 5),
+        ("test_vlsseg6e8",  np.uint8, 6),
+        ("test_vlsseg7e8",  np.uint8, 7),
+        ("test_vlsseg8e8",  np.uint8, 8),
+        ("test_vlse16",     np.uint16, 1),
+        ("test_vlsseg2e16", np.uint16, 2),
+        ("test_vlsseg3e16", np.uint16, 3),
+        ("test_vlsseg4e16", np.uint16, 4),
+        ("test_vlsseg5e16", np.uint16, 5),
+        ("test_vlsseg6e16", np.uint16, 6),
+        ("test_vlsseg7e16", np.uint16, 7),
+        ("test_vlsseg8e16", np.uint16, 8),
+        ("test_vlse32",     np.uint32, 1),
+        ("test_vlsseg2e32", np.uint32, 2),
+        ("test_vlsseg3e32", np.uint32, 3),
+        ("test_vlsseg4e32", np.uint32, 4),
+        ("test_vlsseg5e32", np.uint32, 5),
+        ("test_vlsseg6e32", np.uint32, 6),
+        ("test_vlsseg7e32", np.uint32, 7),
+        ("test_vlsseg8e32", np.uint32, 8),
+    ]
+
+    await fixture.load_elf_and_lookup_symbols(
+        r.Rlocation('kelvin_hw/tests/cocotb/rvv/load_store/load_stride_vtype.elf'),
+        ['vl', 'vtype', 'stride', 'load_data', 'store_data', 'impl'] +
+            [f[0] for f in functions],
+    )
+
+    vlenb = 16  # Bytes per vector register; assumes VLEN = 128 bits.
+    with tqdm.tqdm(functions) as t:
+      for (function, dtype, segments) in t:
+        for sew in SEWS:
+          for lmul, vlmax in SEW_TO_LMULS_AND_VLMAXS[DTYPE_TO_SEW[dtype]]:
+            if (LMUL_TO_EMUL[lmul] * segments) > 8:
+              continue  # EMUL * NF > 8 is not a legal register grouping.
+
+            t.set_postfix({
+                'function': function,
+                'sew': sew,
+                'lmul': lmul,
+            })
+            # TODO(derekjchow): Use sew instead of DTYPE_TO_SEW[dtype]
+            vtype = construct_vtype(1, 1, DTYPE_TO_SEW[dtype], lmul)
+            stride = 32
+            await fixture.write_ptr('impl', function)
+            await fixture.write_word('vtype', vtype)
+            await fixture.write_word('vl', vlmax)
+            await fixture.write_word('stride', stride)
+            load_data = np.random.randint(
+                0, 256, 8192, dtype=np.uint8)  # High bound is exclusive.
+            await fixture.write('load_data', load_data)
+            await fixture.write('store_data',
+                                np.zeros(256, dtype=np.uint8))
+
+            await fixture.run_to_halt()
+            # Expected: element i holds the fields found at byte offset stride * i.
+            segment_size = segments * np.dtype(dtype).itemsize
+            expected_segment_data = [
+                load_data[(stride*i):(stride*i) + segment_size].view(dtype)
+                for i in range(vlmax)]
+            expected_segment_data = np.transpose(
+                np.stack(expected_segment_data))
+
+            store_data = (await fixture.read('store_data', 256))
+            # Field k of every element lands in the k-th destination register group.
+            regsize = vlenb * LMUL_TO_EMUL[lmul]
+            for segment in range(segments):
+              segment_reg = store_data[segment*regsize:(segment+1)*regsize]
+              store_result = segment_reg.view(dtype)[0:vlmax]
+              expected_result = expected_segment_data[segment]
+              assert (store_result == expected_result).all(), (function, lmul, segment)
+
+
+@cocotb.test()
+async def store_strided_all_vtypes_test(dut):
+    """Checks strided and segmented-strided stores for every EEW, NF and LMUL."""
+    fixture = await Fixture.Create(dut)
+    r = runfiles.Create()
+    functions = [
+        ("test_vsse8",      np.uint8, 1),
+        ("test_vssseg2e8",  np.uint8, 2),
+        ("test_vssseg3e8",  np.uint8, 3),
+        ("test_vssseg4e8",  np.uint8, 4),
+        ("test_vssseg5e8",  np.uint8, 5),
+        ("test_vssseg6e8",  np.uint8, 6),
+        ("test_vssseg7e8",  np.uint8, 7),
+        ("test_vssseg8e8",  np.uint8, 8),
+        ("test_vsse16",     np.uint16, 1),
+        ("test_vssseg2e16", np.uint16, 2),
+        ("test_vssseg3e16", np.uint16, 3),
+        ("test_vssseg4e16", np.uint16, 4),
+        ("test_vssseg5e16", np.uint16, 5),
+        ("test_vssseg6e16", np.uint16, 6),
+        ("test_vssseg7e16", np.uint16, 7),
+        ("test_vssseg8e16", np.uint16, 8),
+        ("test_vsse32",     np.uint32, 1),
+        ("test_vssseg2e32", np.uint32, 2),
+        ("test_vssseg3e32", np.uint32, 3),
+        ("test_vssseg4e32", np.uint32, 4),
+        ("test_vssseg5e32", np.uint32, 5),
+        ("test_vssseg6e32", np.uint32, 6),
+        ("test_vssseg7e32", np.uint32, 7),
+        ("test_vssseg8e32", np.uint32, 8),
+    ]
+
+    await fixture.load_elf_and_lookup_symbols(
+        r.Rlocation('kelvin_hw/tests/cocotb/rvv/load_store/store_strided_vtype.elf'),
+        ['vl', 'vtype', 'stride', 'load_data', 'store_data', 'impl'] +
+            [f[0] for f in functions],
+    )
+
+    vlenb = 16  # Bytes per vector register; assumes VLEN = 128 bits.
+    with tqdm.tqdm(functions) as t:
+      for (function, dtype, segments) in t:
+        for sew in SEWS:
+          for lmul, vlmax in SEW_TO_LMULS_AND_VLMAXS[DTYPE_TO_SEW[dtype]]:
+            if (LMUL_TO_EMUL[lmul] * segments) > 8:
+              continue  # EMUL * NF > 8 is not a legal register grouping.
+
+            t.set_postfix({
+                'function': function,
+                'sew': sew,
+                'lmul': lmul,
+            })
+            # TODO(derekjchow): Use sew instead of DTYPE_TO_SEW[dtype]
+            vtype = construct_vtype(1, 1, DTYPE_TO_SEW[dtype], lmul)
+            stride = 32
+            await fixture.write_ptr('impl', function)
+            await fixture.write_word('vtype', vtype)
+            await fixture.write_word('vl', vlmax)
+            await fixture.write_word('stride', stride)
+            load_data = np.random.randint(
+                0, 256, 256, dtype=np.uint8)  # High bound is exclusive.
+            await fixture.write('load_data', load_data)
+            await fixture.write('store_data',
+                                np.zeros(8192, dtype=np.uint8))
+
+            await fixture.run_to_halt()
+            # Expected: register group k supplies field k; segments sit stride bytes apart.
+            regstride = vlenb * LMUL_TO_EMUL[lmul]
+            regsize = vlmax * np.dtype(dtype).itemsize
+            regs = [
+                load_data[(i*regstride):(i*regstride)+regsize].view(dtype)
+                for i in range(segments)
+            ]
+            segment_regs = np.transpose(np.stack(regs)).copy()  # Contiguous rows for .view().
+            expected_store_data = np.zeros(8192, dtype=np.uint8)
+            for v in range(vlmax):
+              data = segment_regs[v].view(np.uint8)
+              expected_store_data[(v*stride):(v*stride)+ len(data)] = data
+
+            store_data = (await fixture.read('store_data', 8192))
+            assert (expected_store_data == store_data).all(), (function, lmul)