Add vstart and vxrm to Csr.
Change-Id: I4f37b5a8404940dea3893f8c2fad4d93e7fce078
diff --git a/hdl/chisel/src/kelvin/rvv/RvvCore.scala b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
index 4e57b99..fc1c8a9 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvCore.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
@@ -400,8 +400,10 @@
rvvCoreWrapper.io.rd <> io.rd
rvvCoreWrapper.io.async_rd <> io.async_rd
- rvvCoreWrapper.io.vstart := vstart
- rvvCoreWrapper.io.vxrm := vxrm
+ rvvCoreWrapper.io.vstart := Mux(
+ io.csr.csr_vstart.valid, io.csr.csr_vstart.bits, vstart)
+ rvvCoreWrapper.io.vxrm := Mux(
+ io.csr.csr_vxrm.valid, io.csr.csr_vxrm.bits, vxrm)
rvvCoreWrapper.io.vxsat := 0.U
rvvCoreWrapper.io.vcsr_ready := true.B
@@ -418,6 +420,18 @@
io.configState.bits.lmul := rvvCoreWrapper.io.configLmul
io.rvv_idle := rvvCoreWrapper.io.rvv_idle
- vstart := Mux(rvvCoreWrapper.io.vcsr_valid, rvvCoreWrapper.io.vcsr_vstart, vstart)
- vxrm := Mux(rvvCoreWrapper.io.vcsr_valid, rvvCoreWrapper.io.vcsr_xrm, vxrm)
+ val vstart_wdata = MuxCase(vstart, Seq(
+ rvvCoreWrapper.io.vcsr_valid -> rvvCoreWrapper.io.vcsr_vstart,
+ io.csr.csr_vstart.valid -> io.csr.csr_vstart.bits,
+ ))
+ vstart := vstart_wdata
+
+ val vxrm_wdata = MuxCase(vxrm, Seq(
+ rvvCoreWrapper.io.vcsr_valid -> rvvCoreWrapper.io.vcsr_xrm,
+ io.csr.csr_vxrm.valid -> io.csr.csr_vxrm.bits,
+ ))
+ vxrm := vxrm_wdata
+
+ io.csr.vstart := vstart
+ io.csr.vxrm := vxrm
}
diff --git a/hdl/chisel/src/kelvin/rvv/RvvInterface.scala b/hdl/chisel/src/kelvin/rvv/RvvInterface.scala
index 139de23..6c68488 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvInterface.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvInterface.scala
@@ -64,5 +64,15 @@
// Async scalar regfile writes.
val async_rd = Decoupled(new RegfileWriteDataIO)
+ // CSR interface: exposes vstart/vxrm to the scalar Csr block and accepts its writes.
+ val csr = new RvvCsrIO(p)
+
val rvv_idle = Output(Bool())
+}
+
+class RvvCsrIO(p: Parameters) extends Bundle { // RvvCore-side view of the vstart/vxrm CSR exchange.
+ val vstart = Output(UInt(log2Ceil(p.rvvVlen).W)) // Current value of the vstart register held in RvvCore.
+ val vxrm = Output(UInt(2.W)) // Current value of the vxrm register held in RvvCore.
+ val csr_vstart = Input(Valid(UInt(log2Ceil(p.rvvVlen).W))) // CSR write to vstart: valid flag plus the new value.
+ val csr_vxrm = Input(Valid(UInt(2.W))) // CSR write to vxrm: valid flag plus the new value.
}
\ No newline at end of file
diff --git a/hdl/chisel/src/kelvin/scalar/Csr.scala b/hdl/chisel/src/kelvin/scalar/Csr.scala
index d5dd4e0..e8c4040 100644
--- a/hdl/chisel/src/kelvin/scalar/Csr.scala
+++ b/hdl/chisel/src/kelvin/scalar/Csr.scala
@@ -18,6 +18,15 @@
import chisel3.util._
import kelvin.float.{CsrFloatIO}
+class CsrRvvIO(p: Parameters) extends Bundle { // Csr-side view of the vstart/vxrm exchange with RvvCore.
+ // To Csr from RvvCore: current register values, returned when VSTART / VXRM are read.
+ val vstart = Input(UInt(log2Ceil(p.rvvVlen).W))
+ val vxrm = Input(UInt(2.W))
+ // From Csr to RvvCore: valid is asserted when a request targets VSTART / VXRM; bits is the low part of wdata.
+ val vstart_write = Output(Valid(UInt(log2Ceil(p.rvvVlen).W)))
+ val vxrm_write = Output(Valid(UInt(2.W)))
+}
+
object Csr {
def apply(p: Parameters): Csr = {
return Module(new Csr(p))
@@ -28,6 +37,8 @@
val FFLAGS = Value(0x001.U(12.W))
val FRM = Value(0x002.U(12.W))
val FCSR = Value(0x003.U(12.W))
+ val VSTART = Value(0x008.U(12.W))
+ val VXRM = Value(0x009.U(12.W))
val MSTATUS = Value(0x300.U(12.W))
val MISA = Value(0x301.U(12.W))
val MIE = Value(0x304.U(12.W))
@@ -170,6 +181,7 @@
val rd = Valid(Flipped(new RegfileWriteDataIO))
val bru = Flipped(new CsrBruIO(p))
val float = Option.when(p.enableFloat) { Flipped(new CsrFloatIO(p)) }
+ val rvv = Option.when(p.enableRvv) { new CsrRvvIO(p) }
// Vector core.
val vcore = (if (p.enableVector) {
@@ -282,6 +294,8 @@
val fflagsEn = csr_address === CsrAddress.FFLAGS
val frmEn = csr_address === CsrAddress.FRM
val fcsrEn = csr_address === CsrAddress.FCSR
+ val vstartEn = Option.when(p.enableRvv) { csr_address === CsrAddress.VSTART }
+ val vxrmEn = Option.when(p.enableRvv) { csr_address === CsrAddress.VXRM }
val mstatusEn = csr_address === CsrAddress.MSTATUS
val misaEn = csr_address === CsrAddress.MISA
val mieEn = csr_address === CsrAddress.MIE
@@ -389,6 +403,8 @@
) ++
Option.when(p.enableRvv) {
Seq(
+ vstartEn.get -> io.rvv.get.vstart,
+ vxrmEn.get -> io.rvv.get.vxrm,
vlenbEn.get -> 16.U(32.W), // Vector length in Bytes
)
}.getOrElse(Seq())
@@ -443,6 +459,13 @@
}
}
+ if (p.enableRvv) {
+ io.rvv.get.vstart_write.valid := req.valid && vstartEn.get
+ io.rvv.get.vstart_write.bits := wdata(log2Ceil(p.rvvVlen)-1, 0)
+ io.rvv.get.vxrm_write.valid := req.valid && vxrmEn.get
+ io.rvv.get.vxrm_write.bits := wdata(1,0)
+ }
+
// mcycle implementation
// If one of the enable signals for
// the register are true, overwrite the enabled half
diff --git a/hdl/chisel/src/kelvin/scalar/SCore.scala b/hdl/chisel/src/kelvin/scalar/SCore.scala
index 8905f2f..2a7d1ad 100644
--- a/hdl/chisel/src/kelvin/scalar/SCore.scala
+++ b/hdl/chisel/src/kelvin/scalar/SCore.scala
@@ -425,6 +425,11 @@
// Register inputs
io.rvvcore.get.rs := regfile.io.readData
+
+ io.rvvcore.get.csr.csr_vstart <> csr.io.rvv.get.vstart_write
+ io.rvvcore.get.csr.csr_vxrm <> csr.io.rvv.get.vxrm_write
+ csr.io.rvv.get.vstart := io.rvvcore.get.csr.vstart
+ csr.io.rvv.get.vxrm := io.rvvcore.get.csr.vxrm
}
// ---------------------------------------------------------------------------
diff --git a/tests/cocotb/BUILD b/tests/cocotb/BUILD
index a7da55c..989df39 100644
--- a/tests/cocotb/BUILD
+++ b/tests/cocotb/BUILD
@@ -183,7 +183,8 @@
RVV_TEST_BINARY_TARGETS = [
"//tests/cocotb/rvv:rvv_load.elf",
- "//tests/cocotb/rvv:rvv_add.elf"] + [
+ "//tests/cocotb/rvv:rvv_add.elf",
+ "//tests/cocotb/rvv:vstart_store.elf"] + [
"//tests/cocotb/rvv/arithmetics:rvv_{}_{}_m1.elf".format(TEST_OP, DTYPE)
for DTYPE in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]
for TEST_OP in ["add", "sub", "mul", "div"]
diff --git a/tests/cocotb/rvv/BUILD b/tests/cocotb/rvv/BUILD
index 2a4a4c1..b53ebe7 100644
--- a/tests/cocotb/rvv/BUILD
+++ b/tests/cocotb/rvv/BUILD
@@ -26,5 +26,8 @@
"rvv_load": {
"srcs": ["rvv_load.S"],
},
+ "vstart_store": {
+ "srcs": ["vstart_store.S"],
+ },
},
)
diff --git a/tests/cocotb/rvv/vstart_store.S b/tests/cocotb/rvv/vstart_store.S
new file mode 100644
index 0000000..0a2dcb2
--- /dev/null
+++ b/tests/cocotb/rvv/vstart_store.S
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+.section .data
+.align 16 /* NOTE(review): on RISC-V gas .align N means 2^N bytes (64 KiB here) - confirm .balign 16 was not intended */
+
+.global input_data /* exported so the test bench can look the symbol up */
+input_data:
+ .space 64 /* source buffer; the test bench fills the first 16 bytes */
+
+.global output_data /* exported so the test bench can look the symbol up */
+output_data:
+ .space 64 /* destination buffer; the test bench zeroes and checks the first 16 bytes */
+
+.section .text
+.global main
+.option norelax
+
+main: /* load 16 bytes, set vstart to 4, then store them back */
+ la t0, input_data /* t0 = source buffer address */
+ la t1, output_data /* t1 = destination buffer address */
+
+ vsetivli x0, 16, e8, m1, ta, ma /* request 16 elements of 8 bits, LMUL = 1; new vl is discarded (x0) */
+ vle8.v v1, (t0) /* v1 = bytes loaded from input_data */
+
+ li t2, 4
+ csrw vstart, t2 /* write 4 to the vstart CSR before the store */
+
+ vse8.v v1, (t1) /* store v1 to output_data; the test expects elements 0-3 to stay untouched */
+
+ wfi /* park here; the test bench waits for WFI before reading results */
+ ret
diff --git a/tests/cocotb/rvv_assembly_cocotb_test.py b/tests/cocotb/rvv_assembly_cocotb_test.py
index 8930e9d..f8b0154 100644
--- a/tests/cocotb/rvv_assembly_cocotb_test.py
+++ b/tests/cocotb/rvv_assembly_cocotb_test.py
@@ -100,3 +100,45 @@
print(f" number of values supposed to be printed {num_values}", flush=True)
await core_mini_axi.raise_irq()
await core_mini_axi.wait_for_halted()
+
+
+@cocotb.test()
+async def core_mini_vstart_store(dut):
+ """Testbench checking that a vector store honors a vstart value written with csrw.
+ """
+ # Test bench setup
+ core_mini_axi = CoreMiniAxiInterface(dut)
+ await core_mini_axi.init()
+ await core_mini_axi.reset()
+ cocotb.start_soon(core_mini_axi.clock.start())
+
+ elf_path = "../tests/cocotb/rvv/vstart_store.elf"
+ if not elf_path:
+ raise ValueError("elf_path must consist a valid path")
+ with open(elf_path, "rb") as f:
+ entry_point = await core_mini_axi.load_elf(f)
+
+ # Look up the data buffers, then write the random input and zero the output buffer.
+ with open(elf_path, "rb") as f:
+ input_addr = core_mini_axi.lookup_symbol(f, "input_data")
+ output_addr = core_mini_axi.lookup_symbol(f, "output_data")
+
+ input_data = np.random.randint(
+ np.iinfo(np.uint8).min, np.iinfo(np.uint8).max, 16, dtype=np.uint8)
+ await core_mini_axi.write(input_addr, input_data)
+ await core_mini_axi.write(output_addr, np.zeros(16, dtype=np.uint8))
+
+ await core_mini_axi.execute_from(entry_point)
+ await core_mini_axi.wait_for_wfi()
+
+ output_data = (await core_mini_axi.read(output_addr, 16)).view(np.uint8)
+
+ # vstart is 4, so the first 4 elements are skipped and must keep their zero fill;
+ # the remaining 12 elements are stored and must match the input.
+ print(f"input_data={input_data}", flush=True)
+ print(f"output_data={output_data}", flush=True)
+ assert np.array_equal(output_data[0:4], np.zeros(4, dtype=np.uint8))
+ assert np.array_equal(output_data[4:], input_data[4:])
+
+ await core_mini_axi.raise_irq()
+ await core_mini_axi.wait_for_halted()