Add vstart and vxrm to Csr.

Change-Id: I4f37b5a8404940dea3893f8c2fad4d93e7fce078
diff --git a/hdl/chisel/src/kelvin/rvv/RvvCore.scala b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
index 4e57b99..fc1c8a9 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvCore.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
@@ -400,8 +400,10 @@
   rvvCoreWrapper.io.rd <> io.rd
   rvvCoreWrapper.io.async_rd <> io.async_rd
 
-  rvvCoreWrapper.io.vstart := vstart
-  rvvCoreWrapper.io.vxrm := vxrm
+  // Bypass: while a CSR write is valid, drive the wrapper with the write data
+  // instead of the stored vstart/vxrm value.
+  rvvCoreWrapper.io.vstart := Mux(io.csr.csr_vstart.valid, io.csr.csr_vstart.bits, vstart)
+  rvvCoreWrapper.io.vxrm := Mux(io.csr.csr_vxrm.valid, io.csr.csr_vxrm.bits, vxrm)
   rvvCoreWrapper.io.vxsat := 0.U
   rvvCoreWrapper.io.vcsr_ready := true.B
 
@@ -418,6 +420,18 @@
   io.configState.bits.lmul    := rvvCoreWrapper.io.configLmul
   io.rvv_idle                 := rvvCoreWrapper.io.rvv_idle
 
-  vstart := Mux(rvvCoreWrapper.io.vcsr_valid, rvvCoreWrapper.io.vcsr_vstart, vstart)
-  vxrm := Mux(rvvCoreWrapper.io.vcsr_valid, rvvCoreWrapper.io.vcsr_xrm, vxrm)
+  val vstart_wdata = MuxCase(vstart, Seq(  // first true condition wins, else vstart
+      rvvCoreWrapper.io.vcsr_valid -> rvvCoreWrapper.io.vcsr_vstart,  // wrapper update
+      io.csr.csr_vstart.valid -> io.csr.csr_vstart.bits,  // CSR write
+  ))
+  vstart := vstart_wdata
+
+  val vxrm_wdata = MuxCase(vxrm, Seq(  // same priority as vstart: wrapper, then CSR
+      rvvCoreWrapper.io.vcsr_valid -> rvvCoreWrapper.io.vcsr_xrm,  // wrapper update
+      io.csr.csr_vxrm.valid -> io.csr.csr_vxrm.bits,  // CSR write
+  ))
+  vxrm := vxrm_wdata
+
+  io.csr.vstart := vstart  // current value, exported on the CSR interface
+  io.csr.vxrm := vxrm  // current value, exported on the CSR interface
 }
diff --git a/hdl/chisel/src/kelvin/rvv/RvvInterface.scala b/hdl/chisel/src/kelvin/rvv/RvvInterface.scala
index 139de23..6c68488 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvInterface.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvInterface.scala
@@ -64,5 +64,15 @@
     // Async scalar regfile writes.
     val async_rd = Decoupled(new RegfileWriteDataIO)
 
+    // CSR interface: exports vstart/vxrm and receives CSR writes to them.
+    val csr = new RvvCsrIO(p)
+
     val rvv_idle = Output(Bool())
+}
+
+class RvvCsrIO(p: Parameters) extends Bundle {
+  val vstart = Output(UInt(log2Ceil(p.rvvVlen).W))  // current vstart value
+  val vxrm = Output(UInt(2.W))  // current vxrm value
+  val csr_vstart = Input(Valid(UInt(log2Ceil(p.rvvVlen).W)))  // incoming vstart write
+  val csr_vxrm = Input(Valid(UInt(2.W)))  // incoming vxrm write
 }
\ No newline at end of file
diff --git a/hdl/chisel/src/kelvin/scalar/Csr.scala b/hdl/chisel/src/kelvin/scalar/Csr.scala
index d5dd4e0..e8c4040 100644
--- a/hdl/chisel/src/kelvin/scalar/Csr.scala
+++ b/hdl/chisel/src/kelvin/scalar/Csr.scala
@@ -18,6 +18,15 @@
 import chisel3.util._
 import kelvin.float.{CsrFloatIO}
 
+class CsrRvvIO(p: Parameters) extends Bundle {
+  // From RvvCore to Csr: values returned when VSTART/VXRM are read.
+  val vstart = Input(UInt(log2Ceil(p.rvvVlen).W))
+  val vxrm = Input(UInt(2.W))
+  // From Csr to RvvCore: valid when a CSR request addresses VSTART/VXRM.
+  val vstart_write = Output(Valid(UInt(log2Ceil(p.rvvVlen).W)))
+  val vxrm_write = Output(Valid(UInt(2.W)))
+}
+
 object Csr {
   def apply(p: Parameters): Csr = {
     return Module(new Csr(p))
@@ -28,6 +37,8 @@
   val FFLAGS    = Value(0x001.U(12.W))
   val FRM       = Value(0x002.U(12.W))
   val FCSR      = Value(0x003.U(12.W))
+  val VSTART    = Value(0x008.U(12.W))  // vstart
+  val VXRM      = Value(0x00A.U(12.W))  // vxrm is 0x00A; 0x009 is vxsat
   val MSTATUS   = Value(0x300.U(12.W))
   val MISA      = Value(0x301.U(12.W))
   val MIE       = Value(0x304.U(12.W))
@@ -170,6 +181,7 @@
     val rd  = Valid(Flipped(new RegfileWriteDataIO))
     val bru = Flipped(new CsrBruIO(p))
     val float = Option.when(p.enableFloat) { Flipped(new CsrFloatIO(p)) }
+    val rvv = Option.when(p.enableRvv) { new CsrRvvIO(p) }  // vstart/vxrm exchange with RvvCore
 
     // Vector core.
     val vcore = (if (p.enableVector) {
@@ -282,6 +294,8 @@
   val fflagsEn    = csr_address === CsrAddress.FFLAGS
   val frmEn       = csr_address === CsrAddress.FRM
   val fcsrEn      = csr_address === CsrAddress.FCSR
+  val vstartEn    = Option.when(p.enableRvv) { csr_address === CsrAddress.VSTART }  // None without RVV
+  val vxrmEn      = Option.when(p.enableRvv) { csr_address === CsrAddress.VXRM }  // None without RVV
   val mstatusEn   = csr_address === CsrAddress.MSTATUS
   val misaEn      = csr_address === CsrAddress.MISA
   val mieEn       = csr_address === CsrAddress.MIE
@@ -389,6 +403,8 @@
     ) ++
       Option.when(p.enableRvv) {
         Seq(
+          vstartEn.get -> io.rvv.get.vstart,  // read value comes from the rvv port
+          vxrmEn.get   -> io.rvv.get.vxrm,  // read value comes from the rvv port
           vlenbEn.get -> 16.U(32.W),  // Vector length in Bytes
         )
       }.getOrElse(Seq())
@@ -443,6 +459,13 @@
     }
   }
 
+  if (p.enableRvv) {  // the rvv port (and vstartEn/vxrmEn) exist only when RVV is enabled
+    io.rvv.get.vstart_write.valid := req.valid && vstartEn.get  // request addresses VSTART
+    io.rvv.get.vstart_write.bits  := wdata(log2Ceil(p.rvvVlen)-1, 0)  // low bits of wdata
+    io.rvv.get.vxrm_write.valid   := req.valid && vxrmEn.get  // request addresses VXRM
+    io.rvv.get.vxrm_write.bits    := wdata(1,0)  // low 2 bits of wdata
+  }
+
   // mcycle implementation
   // If one of the enable signals for
   // the register are true, overwrite the enabled half
diff --git a/hdl/chisel/src/kelvin/scalar/SCore.scala b/hdl/chisel/src/kelvin/scalar/SCore.scala
index 8905f2f..2a7d1ad 100644
--- a/hdl/chisel/src/kelvin/scalar/SCore.scala
+++ b/hdl/chisel/src/kelvin/scalar/SCore.scala
@@ -425,6 +425,11 @@
 
     // Register inputs
     io.rvvcore.get.rs := regfile.io.readData
+
+    io.rvvcore.get.csr.csr_vstart <> csr.io.rvv.get.vstart_write  // Csr write port -> RvvCore
+    io.rvvcore.get.csr.csr_vxrm <> csr.io.rvv.get.vxrm_write  // Csr write port -> RvvCore
+    csr.io.rvv.get.vstart := io.rvvcore.get.csr.vstart  // RvvCore value -> Csr read path
+    csr.io.rvv.get.vxrm := io.rvvcore.get.csr.vxrm  // RvvCore value -> Csr read path
   }
 
   // ---------------------------------------------------------------------------
diff --git a/tests/cocotb/BUILD b/tests/cocotb/BUILD
index a7da55c..989df39 100644
--- a/tests/cocotb/BUILD
+++ b/tests/cocotb/BUILD
@@ -183,7 +183,8 @@
 
 RVV_TEST_BINARY_TARGETS = [
         "//tests/cocotb/rvv:rvv_load.elf",
-        "//tests/cocotb/rvv:rvv_add.elf"] + [
+        "//tests/cocotb/rvv:rvv_add.elf",
+        "//tests/cocotb/rvv:vstart_store.elf"] + [
         "//tests/cocotb/rvv/arithmetics:rvv_{}_{}_m1.elf".format(TEST_OP, DTYPE)
         for DTYPE in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]
         for TEST_OP in ["add", "sub", "mul", "div"]
diff --git a/tests/cocotb/rvv/BUILD b/tests/cocotb/rvv/BUILD
index 2a4a4c1..b53ebe7 100644
--- a/tests/cocotb/rvv/BUILD
+++ b/tests/cocotb/rvv/BUILD
@@ -26,5 +26,8 @@
         "rvv_load": {
             "srcs": ["rvv_load.S"],
         },
+        "vstart_store": {
+            "srcs": ["vstart_store.S"],
+        },
     },
 )
diff --git a/tests/cocotb/rvv/vstart_store.S b/tests/cocotb/rvv/vstart_store.S
new file mode 100644
index 0000000..0a2dcb2
--- /dev/null
+++ b/tests/cocotb/rvv/vstart_store.S
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+.section .data
+.balign 16  /* 16-byte alignment; on RISC-V `.align 16` means 2^16 bytes */
+
+.global input_data
+input_data:
+    .space 64  /* source buffer, filled by the testbench */
+
+.global output_data
+output_data:
+    .space 64  /* destination buffer, read back by the testbench */
+
+.section .text
+.global main
+.option norelax
+
+main:
+    la t0, input_data   /* t0 = source buffer */
+    la t1, output_data  /* t1 = destination buffer */
+
+    vsetivli x0, 16, e8, m1, ta, ma  /* request 16 elements of 8 bits, LMUL = 1 */
+    vle8.v v1, (t0)                  /* v1 = bytes loaded from input_data */
+
+    li t2, 4
+    csrw vstart, t2                  /* vstart = 4 */
+
+    vse8.v v1, (t1)                  /* store v1 to output_data, starting at element vstart */
+
+    wfi                              /* the testbench waits for this WFI */
+    ret
diff --git a/tests/cocotb/rvv_assembly_cocotb_test.py b/tests/cocotb/rvv_assembly_cocotb_test.py
index 8930e9d..f8b0154 100644
--- a/tests/cocotb/rvv_assembly_cocotb_test.py
+++ b/tests/cocotb/rvv_assembly_cocotb_test.py
@@ -100,3 +100,45 @@
         print(f" number of values supposed to be printed {num_values}", flush=True)
         await core_mini_axi.raise_irq()
     await core_mini_axi.wait_for_halted()
+
+
+@cocotb.test()
+async def core_mini_vstart_store(dut):
+    """Check that a vector store honours a non-zero vstart (4 of 16 bytes
+    are expected to be skipped).
+    """
+    # Test bench setup
+    core_mini_axi = CoreMiniAxiInterface(dut)
+    await core_mini_axi.init()
+    await core_mini_axi.reset()
+    cocotb.start_soon(core_mini_axi.clock.start())
+
+    elf_path = "../tests/cocotb/rvv/vstart_store.elf"
+    with open(elf_path, "rb") as f:
+        entry_point = await core_mini_axi.load_elf(f)
+
+    # Resolve the buffers the program reads from and writes to.
+    with open(elf_path, "rb") as f:
+        input_addr = core_mini_axi.lookup_symbol(f, "input_data")
+        output_addr = core_mini_axi.lookup_symbol(f, "output_data")
+
+    # randint's upper bound is exclusive, hence max + 1 to include 255.
+    input_data = np.random.randint(
+        np.iinfo(np.uint8).min, np.iinfo(np.uint8).max + 1, 16, dtype=np.uint8)
+    await core_mini_axi.write(input_addr, input_data)
+    await core_mini_axi.write(output_addr, np.zeros(16, dtype=np.uint8))
+
+    await core_mini_axi.execute_from(entry_point)
+    await core_mini_axi.wait_for_wfi()
+
+    output_data = (await core_mini_axi.read(output_addr, 16)).view(np.uint8)
+
+    # vstart is 4, so the first 4 elements keep their initial zeros and the
+    # remaining 12 elements are stored.
+    print(f"input_data={input_data}", flush=True)
+    print(f"output_data={output_data}", flush=True)
+    assert np.array_equal(output_data[0:4], np.zeros(4, dtype=np.uint8))
+    assert np.array_equal(output_data[4:], input_data[4:])
+
+    await core_mini_axi.raise_irq()
+    await core_mini_axi.wait_for_halted()