Fault for reduction instructions when vstart != 0.
Change-Id: I69c8f28402e760971d23b63776bec3afd68110b7
diff --git a/hdl/chisel/src/kelvin/rvv/RvvCore.scala b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
index 75ae92a..2917755 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvCore.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
@@ -459,16 +459,22 @@
io.rvv2lsu <> rvvCoreWrapper.io.rvv2lsu
io.lsu2rvv <> rvvCoreWrapper.io.lsu2rvv
- io.configState.valid := rvvCoreWrapper.io.configStateValid
- io.configState.bits.vl := rvvCoreWrapper.io.configVl
- io.configState.bits.vstart := rvvCoreWrapper.io.configVstart
- io.configState.bits.ma := rvvCoreWrapper.io.configMa
- io.configState.bits.ta := rvvCoreWrapper.io.configTa
- io.configState.bits.xrm := rvvCoreWrapper.io.configXrm
- io.configState.bits.sew := rvvCoreWrapper.io.configSew
- io.configState.bits.lmul := rvvCoreWrapper.io.configLmul
- io.rvv_idle := rvvCoreWrapper.io.rvv_idle
- io.queue_capacity := rvvCoreWrapper.io.queue_capacity
+  // Conservatively mark config state as invalid in any cycle in which vstart, vxrm
+  // or vxsat is being written (by a CSR instruction, or while vcsr_valid is set).
+ io.configState.valid := rvvCoreWrapper.io.configStateValid &&
+ !rvvCoreWrapper.io.vcsr_valid &&
+ !io.csr.vstart_write.valid &&
+ !io.csr.vxrm_write.valid &&
+ !io.csr.vxsat_write.valid
+ io.configState.bits.vl := rvvCoreWrapper.io.configVl
+ io.configState.bits.vstart := rvvCoreWrapper.io.configVstart
+ io.configState.bits.ma := rvvCoreWrapper.io.configMa
+ io.configState.bits.ta := rvvCoreWrapper.io.configTa
+ io.configState.bits.xrm := rvvCoreWrapper.io.configXrm
+ io.configState.bits.sew := rvvCoreWrapper.io.configSew
+ io.configState.bits.lmul := rvvCoreWrapper.io.configLmul
+ io.rvv_idle := rvvCoreWrapper.io.rvv_idle
+ io.queue_capacity := rvvCoreWrapper.io.queue_capacity
val vstart_wdata = MuxCase(vstart, Seq(
rvvCoreWrapper.io.vcsr_valid -> rvvCoreWrapper.io.vcsr_vstart,
diff --git a/hdl/chisel/src/kelvin/rvv/RvvDecode.scala b/hdl/chisel/src/kelvin/rvv/RvvDecode.scala
index 03f1b7d..1a85a16 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvDecode.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvDecode.scala
@@ -45,6 +45,23 @@
bits(7, 5)
}
+  // True for integer reductions (vred*: OPMVV, vwredsum[u]: OPIVV), which require vstart=0.
+  def isReduction(): Bool = {
+    val (opivv, opmvv) = (funct3() === "b000".U, funct3() === "b010".U)
+    (opcode === RvvCompressedOpcode.RVVALU) && MuxLookup(funct6(), false.B)(Seq(
+        "b000000".U -> opmvv,  // vredsum
+        "b000001".U -> opmvv,  // vredand
+        "b000010".U -> opmvv,  // vredor
+        "b000011".U -> opmvv,  // vredxor
+        "b000100".U -> opmvv,  // vredminu
+        "b000101".U -> opmvv,  // vredmin
+        "b000110".U -> opmvv,  // vredmaxu
+        "b000111".U -> opmvv,  // vredmax
+        "b110000".U -> opivv,  // vwredsumu (110000 under OPMVV is vwaddu)
+        "b110001".U -> opivv,  // vwredsum  (110001 under OPMVV is vwadd)
+    ))
+  }
+
// "Addressing Mode" for loads/store (see Section 7.2 of RVV Spec)
def mop: RvvAddressingMode.Type = {
RvvAddressingMode(bits(20, 19))
diff --git a/hdl/chisel/src/kelvin/scalar/Decode.scala b/hdl/chisel/src/kelvin/scalar/Decode.scala
index b808286..828666c 100644
--- a/hdl/chisel/src/kelvin/scalar/Decode.scala
+++ b/hdl/chisel/src/kelvin/scalar/Decode.scala
@@ -263,6 +263,8 @@
val jalrFault = Output(Vec(p.instructionLanes, Bool()))
val bxxFault = Output(Vec(p.instructionLanes, Bool()))
val undefFault = Output(Vec(p.instructionLanes, Bool()))
+ val rvvFault = Option.when(p.enableRvv)(
+ Output(Vec(p.instructionLanes, Bool())))
val bruTarget = Output(Vec(p.instructionLanes, UInt(32.W)))
val jalrTarget = Input(Vec(p.instructionLanes, new RegfileBranchTargetIO))
@@ -420,12 +422,27 @@
val fence = decodedInsts.map(x => x.isFency() && (io.mactive || io.lsuActive))
// ---------------------------------------------------------------------------
+  // Csr interlock
+  // A CSR instruction in lane 0 dispatches alone: lane 0 itself is never
+  // blocked by this rule, while every later lane is held back in that cycle.
+  val csrInterlock = {
+    val lane0IsCsr = decodedInsts(0).isCsr()
+    val laterLanes = Seq.fill(p.instructionLanes - 1)(!lane0IsCsr)
+    true.B +: laterLanes
+  }
+
+ // ---------------------------------------------------------------------------
// Rvv config interlock rules
// RVV Load store unit requires valid config state on dispatch.
- val rvvConfigInterlock = if (p.enableRvv) {
+ val configInvalid = if (p.enableRvv) {
val configChange = decodedInsts.map(
x => x.rvv.get.valid && x.rvv.get.bits.isVset())
- val configInvalid = configChange.scan(!io.rvvState.get.valid)(_ || _)
+ configChange.scan(!io.rvvState.get.valid)(_ || _)
+ } else {
+ Seq.fill(p.instructionLanes)(false.B)
+ }
+
+ val rvvConfigInterlock = if (p.enableRvv) {
val canDispatchRvv = (0 until p.instructionLanes).map(i =>
!decodedInsts(i).rvv.get.valid || // Don't lock non-rvv
!decodedInsts(i).rvv.get.bits.isLoadStore() || // Non-LSU can handle change
@@ -437,6 +454,21 @@
}
// ---------------------------------------------------------------------------
+  // Rvv reduction
+  // A reduction may only dispatch when its config state is usable and vstart == 0.
+  val rvvReductionInterlock = if (!p.enableRvv) {
+    Seq.fill(p.instructionLanes)(true.B)
+  } else {
+    val vstartNonZero = io.rvvState.get.bits.vstart =/= 0.U
+    decodedInsts.zip(configInvalid).map { case (inst, stale) =>
+      val rvv = inst.rvv.get
+      // `stale`: config state is not valid, or an earlier lane changes the config.
+      val blocked = rvv.valid && rvv.bits.isReduction() && (stale || vstartNonZero)
+      !blocked
+    }
+  }
+
+ // ---------------------------------------------------------------------------
// Rvv Interlock
val rvvInterlock = if (p.enableRvv) {
val isRvv = decodedInsts.map(x => x.rvv.get.valid)
@@ -499,7 +531,9 @@
!floatWriteAfterWrite(i) && // Avoid WAW hazards
!branchInterlock(i) && // Only branch/alu can be dispatched after a branch
!fence(i) && // Don't dispatch if fence interlocked
+ csrInterlock(i) &&
rvvConfigInterlock(i) && // Rvv interlock rules
+ rvvReductionInterlock(i) && // Don't dispatch reduction if vstart != 0
// rvvLsuInterlock(i) && // Dispatch only one Rvv LsuOp
lsuInterlock(i) && // Ensure lsu instructions can be dispatched into queue
rvvInterlock(i) && // Ensure rvv instructions can be dispatched into queue
@@ -721,6 +755,22 @@
io.inst(i).ready := lastReady(i + 1)
}
+  // Fault handling for RVV
+  if (p.enableRvv) {
+    for (i <- 0 until p.instructionLanes) {
+      io.rvvFault.get(i) := (if (i == 0) {
+        // A valid reduction with vstart != 0 faults in lane 0; other lanes never report it.
+        val isReduction = decodedInsts(i).rvv.get.valid &&
+                          decodedInsts(i).rvv.get.bits.isReduction()
+        val vStartNotZero = io.rvvState.get.valid &&
+                            (io.rvvState.get.bits.vstart =/= 0.U)
+        io.inst(i).valid && isReduction && vStartNotZero
+      } else {
+        false.B
+      })
+    }
+  }
+
for (i <- 0 until p.instructionLanes) {
val d = decodedInsts(i)
val rs3Addr = io.inst(i).bits.inst(31,27)
@@ -796,6 +846,8 @@
io.jalFault(i) := decode(i).io.jalFault && !io.branchTaken && decode(i).io.inst.valid
io.bxxFault(i) := decode(i).io.bxxFault & !io.branchTaken && decode(i).io.inst.valid
io.undefFault(i) := decode(i).io.undefFault & !io.branchTaken && decode(i).io.inst.valid
+ if (p.enableRvv) { io.rvvFault.get(i) := false.B }
+
io.bruTarget(i) := decode(i).io.bruTarget
}
diff --git a/hdl/chisel/src/kelvin/scalar/FaultManager.scala b/hdl/chisel/src/kelvin/scalar/FaultManager.scala
index e8ff77b..5470413 100644
--- a/hdl/chisel/src/kelvin/scalar/FaultManager.scala
+++ b/hdl/chisel/src/kelvin/scalar/FaultManager.scala
@@ -32,6 +32,7 @@
val jalr = Bool()
val bxx = Bool()
val undef = Bool()
+ val rvv = if (p.enableRvv) Some(Bool()) else None
}))
val pc = Input(Vec(p.instructionLanes, new Bundle {
val pc = UInt(32.W)
@@ -51,7 +52,13 @@
val out = Output(Valid(new FaultManagerOutput))
})
- val faults = VecInit((0 until p.instructionLanes).map(x => (io.in.fault(x).csr | io.in.fault(x).jal | io.in.fault(x).jalr | io.in.fault(x).bxx | io.in.fault(x).undef)))
+ val faults = VecInit((0 until p.instructionLanes).map(x => (
+ io.in.fault(x).csr |
+ io.in.fault(x).jal |
+ io.in.fault(x).jalr |
+ io.in.fault(x).bxx |
+ io.in.fault(x).undef |
+ io.in.fault(x).rvv.getOrElse(false.B))))
val fault = faults.reduce(_|_)
val first_fault = PriorityEncoder(faults)
val undef_fault = io.in.fault.map(_.undef).reduce(_|_)
@@ -64,6 +71,8 @@
val jalr_fault_idx = PriorityEncoder(io.in.fault.map(_.jalr))
val bxx_fault = io.in.fault.map(_.bxx).reduce(_|_)
val bxx_fault_idx = PriorityEncoder(io.in.fault.map(_.bxx))
+ val rvv_fault = io.in.fault.map(_.rvv.getOrElse(false.B)).reduce(_|_)
+ val rvv_fault_idx = PriorityEncoder(io.in.fault.map(_.rvv.getOrElse(false.B)))
val instr_access_fault = io.in.memory_fault.valid && io.in.ibus_fault
val load_fault = io.in.memory_fault.valid && !io.in.memory_fault.bits.write && !io.in.ibus_fault
val store_fault = io.in.memory_fault.valid && io.in.memory_fault.bits.write && !io.in.ibus_fault
@@ -84,6 +93,7 @@
(jalr_fault && (jalr_fault_idx === first_fault)) -> 0.U(32.W),
(bxx_fault && (bxx_fault_idx === first_fault)) -> 0.U(32.W),
(undef_fault && (undef_fault_idx === first_fault)) -> 2.U(32.W),
+ (rvv_fault && (rvv_fault_idx === first_fault)) -> 2.U(32.W),
))
io.out.bits.mtval := MuxCase(0.U(32.W), Seq(
load_fault -> io.in.memory_fault.bits.addr,
@@ -94,5 +104,6 @@
(jalr_fault && (jalr_fault_idx === first_fault)) -> (io.in.jalr(jalr_fault_idx).target & "xFFFFFFFE".U),
(bxx_fault && (bxx_fault_idx === first_fault)) -> 0.U(32.W),
(undef_fault && (undef_fault_idx === first_fault)) -> io.in.undef(undef_fault_idx).inst,
+ (rvv_fault && (rvv_fault_idx === first_fault)) -> io.in.undef(rvv_fault_idx).inst,
))
}
\ No newline at end of file
diff --git a/hdl/chisel/src/kelvin/scalar/SCore.scala b/hdl/chisel/src/kelvin/scalar/SCore.scala
index 83f4c54..80cbb86 100644
--- a/hdl/chisel/src/kelvin/scalar/SCore.scala
+++ b/hdl/chisel/src/kelvin/scalar/SCore.scala
@@ -142,6 +142,9 @@
fault_manager.io.in.fault(i).jalr := dispatch.io.jalrFault(i)
fault_manager.io.in.fault(i).bxx := dispatch.io.bxxFault(i)
fault_manager.io.in.fault(i).undef := dispatch.io.undefFault(i)
+ if (p.enableRvv) {
+ fault_manager.io.in.fault(i).rvv.get := dispatch.io.rvvFault.get(i)
+ }
fault_manager.io.in.pc(i).pc := fetch.io.inst.lanes(i).bits.addr
fault_manager.io.in.jalr(i).target := regfile.io.target(i).data
fault_manager.io.in.undef(i).inst := fetch.io.inst.lanes(i).bits.inst
diff --git a/kelvin_test_utils/core_mini_axi_interface.py b/kelvin_test_utils/core_mini_axi_interface.py
index 78c3bc8..162e8b5 100644
--- a/kelvin_test_utils/core_mini_axi_interface.py
+++ b/kelvin_test_utils/core_mini_axi_interface.py
@@ -801,6 +801,15 @@
timeout_cycles = timeout_cycles - 1
assert timeout_cycles > 0
+ async def wait_for_fault(self, timeout_cycles=1000):
+ cycle_count = 0
+ while self.dut.io_fault.value != 1 and timeout_cycles > 0:
+ await ClockCycles(self.dut.io_aclk, 1)
+ timeout_cycles = timeout_cycles - 1
+ cycle_count += 1
+ assert timeout_cycles > 0
+ return cycle_count
+
async def watch(self, addr, timeout_cycles=1_000_000):
elem_addr = addr % 16
line_addr = addr - elem_addr
diff --git a/kelvin_test_utils/sim_test_fixture.py b/kelvin_test_utils/sim_test_fixture.py
index b387213..ccb2c82 100644
--- a/kelvin_test_utils/sim_test_fixture.py
+++ b/kelvin_test_utils/sim_test_fixture.py
@@ -54,3 +54,10 @@
async def run_to_halt(self, timeout_cycles=10000):
await self.core_mini_axi.execute_from(self.entry_point)
return (await self.core_mini_axi.wait_for_halted(timeout_cycles=timeout_cycles))
+
+ async def run_to_fault(self, timeout_cycles=10000):
+ await self.core_mini_axi.execute_from(self.entry_point)
+ return (await self.core_mini_axi.wait_for_fault(timeout_cycles=timeout_cycles))
+
+ def fault(self):
+ return (self.core_mini_axi.dut.io_fault.value == 1)
\ No newline at end of file
diff --git a/tests/cocotb/BUILD b/tests/cocotb/BUILD
index 3045e5b..1dfea33 100644
--- a/tests/cocotb/BUILD
+++ b/tests/cocotb/BUILD
@@ -190,6 +190,7 @@
RVV_ARITHMETIC_TESTCASES = [
"arithmetic_m1_vanilla_ops",
"reduction_m1_vanilla_ops",
+ "reduction_m1_failure_ops",
"widen_math_ops_test_impl",
]
# END_TESTCASES_FOR_rvv_arithmetic_cocotb_test
diff --git a/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc b/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc
index 780e3e4..e7a06ab 100644
--- a/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc
+++ b/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc
@@ -16,16 +16,36 @@
#include <riscv_vector.h>
+#define __ATTRIBUTE_IN_DTCM__ \
+ __attribute__((section(".data"))) __attribute__((aligned(16)))
-{DTYPE}_t in_buf_1[{IN_DATA_SIZE}] __attribute__((section(".data"))) __attribute__((aligned(16)));
-{DTYPE}_t scalar_input __attribute__((section(".data"))) __attribute__((aligned(16)));
-{DTYPE}_t out_buf __attribute__((section(".data"))) __attribute__((aligned(16)));
+{DTYPE}_t in_buf_1[{IN_DATA_SIZE}] __ATTRIBUTE_IN_DTCM__;
+{DTYPE}_t scalar_input __ATTRIBUTE_IN_DTCM__;
+{DTYPE}_t out_buf __ATTRIBUTE_IN_DTCM__;
+uint32_t vstart __ATTRIBUTE_IN_DTCM__ = 0;
+uint32_t vl __ATTRIBUTE_IN_DTCM__ = {NUM_OPERANDS};
+uint32_t faulted __ATTRIBUTE_IN_DTCM__ = 0;
+uint32_t mcause __ATTRIBUTE_IN_DTCM__ = 0;
+
+// Fault handler: records that a trap was taken and its mcause, then executes ebreak.
+extern "C" {
+void kelvin_exception_handler() {
+  faulted = 1;
+  uint32_t local_mcause;
+  asm volatile("csrr %0, mcause" : "=r"(local_mcause));
+  mcause = local_mcause;
+  // "memory" clobber keeps the stores above from being moved past the ebreak.
+  asm volatile("ebreak" ::: "memory");
+  while (1) {}
+}
+}
void {REDUCTION_OP}_{SIGN}{SEW}_m1(const {DTYPE}_t* in_buf_1, const {DTYPE}_t scalar_input, {DTYPE}_t* out_buf){
- v{DTYPE}m1_t input_v1 = __riscv_vle{SEW}_v_{SIGN}{SEW}m1(in_buf_1, {NUM_OPERANDS});
- v{DTYPE}m1_t input_s1 = __riscv_vmv_v_x_{SIGN}{SEW}m1(scalar_input, {NUM_OPERANDS});
- v{DTYPE}m1_t {REDUCTION_OP}_result = __riscv_v{REDUCTION_OP}_vs_{SIGN}{SEW}m1_{SIGN}{SEW}m1(input_v1, input_s1, {NUM_OPERANDS});
+ v{DTYPE}m1_t input_v1 = __riscv_vle{SEW}_v_{SIGN}{SEW}m1(in_buf_1, vl);
+ v{DTYPE}m1_t input_s1 = __riscv_vmv_v_x_{SIGN}{SEW}m1(scalar_input, vl);
+ asm("csrw vstart, %0" : : "r"(vstart));
+ v{DTYPE}m1_t {REDUCTION_OP}_result = __riscv_v{REDUCTION_OP}_vs_{SIGN}{SEW}m1_{SIGN}{SEW}m1(input_v1, input_s1, vl);
*out_buf = __riscv_vmv_x_s_{SIGN}{SEW}m1_{SIGN}{SEW}({REDUCTION_OP}_result);
}
diff --git a/tests/cocotb/rvv_arithmetic_cocotb_test.py b/tests/cocotb/rvv_arithmetic_cocotb_test.py
index 607064d..d85d827 100644
--- a/tests/cocotb/rvv_arithmetic_cocotb_test.py
+++ b/tests/cocotb/rvv_arithmetic_cocotb_test.py
@@ -191,6 +191,64 @@
num_bytes=16)
+async def reduction_m1_failure_test(dut, dtypes, math_ops: str, num_bytes: int):
+ """RVV reduction test template.
+
+ Each test performs a reduction op loading `in_buf_1` and storing the output to `out_buf`.
+ """
+ m1_failure_op_elfs = [
+ f"rvv_{math_op}_{dtype}_m1.elf" for math_op in math_ops
+ for dtype in dtypes
+ ]
+ pattern_extract = re.compile("rvv_(.*)_(.*)_m1.elf")
+
+ r = runfiles.Create()
+ fixture = await Fixture.Create(dut)
+
+ with tqdm.tqdm(m1_failure_op_elfs) as t:
+ for elf_name in t:
+ t.set_postfix({"binary": os.path.basename(elf_name)})
+ elf_path = r.Rlocation(
+ f"kelvin_hw/tests/cocotb/rvv/arithmetics/{elf_name}")
+ await fixture.load_elf_and_lookup_symbols(
+ elf_path,
+ ['in_buf_1', 'scalar_input', 'out_buf', 'vstart', 'vl',
+ 'faulted', 'mcause'],
+ )
+ math_op, dtype = pattern_extract.match(elf_name).groups()
+ np_type = STR_TO_NP_TYPE[dtype]
+ itemsize = np.dtype(np_type).itemsize
+ num_test_values = int(num_bytes / np.dtype(np_type).itemsize)
+
+ min_value = np.iinfo(np_type).min
+ max_value = np.iinfo(np_type).max + 1 # One above.
+ input_1 = np.random.randint(min_value,
+ max_value,
+ num_test_values,
+ dtype=np_type)
+ input_2 = np.random.randint(min_value, max_value, 1, dtype=np_type)
+
+ await fixture.write('in_buf_1', input_1)
+ await fixture.write('scalar_input', input_2)
+ await fixture.write('vstart', np.array([1], dtype=np.uint32))
+ await fixture.write('out_buf', np.zeros(1, dtype=np_type))
+
+ await fixture.run_to_halt()
+ faulted = (await fixture.read('faulted', 4)).view(np.uint32)
+ mcause = (await fixture.read('mcause', 4)).view(np.uint32)
+ assert(faulted == True)
+ assert(mcause == 0x2) # Invalid instruction
+
+
+@cocotb.test()
+async def reduction_m1_failure_ops(dut):
+ await reduction_m1_failure_test(
+ dut=dut,
+ dtypes=["int8", "int16", "int32", "uint8", "uint16", "uint32"],
+ math_ops=["redsum", "redmin", "redmax"],
+ num_bytes=16)
+
+
async def _widen_math_ops_test_impl(
dut,
dtypes,
diff --git a/toolchain/crt/BUILD b/toolchain/crt/BUILD
index 611960b..70279e8 100644
--- a/toolchain/crt/BUILD
+++ b/toolchain/crt/BUILD
@@ -18,6 +18,7 @@
name = "crt",
srcs = [
"crt.S",
+ "kelvin_exceptions.cc",
"kelvin_gloss.cc",
"kelvin_start.S",
],
diff --git a/toolchain/crt/kelvin_exceptions.cc b/toolchain/crt/kelvin_exceptions.cc
new file mode 100644
index 0000000..0a7297e
--- /dev/null
+++ b/toolchain/crt/kelvin_exceptions.cc
@@ -0,0 +1,20 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+void __attribute__((weak)) kelvin_exception_handler() {
+ asm volatile("ebreak");
+ while (1) {}
+}
+}
\ No newline at end of file
diff --git a/toolchain/crt/kelvin_start.S b/toolchain/crt/kelvin_start.S
index d87ffce..dc44a8e 100644
--- a/toolchain/crt/kelvin_start.S
+++ b/toolchain/crt/kelvin_start.S
@@ -16,6 +16,8 @@
// A starting functions for simple kelvin programs.
+.extern kelvin_exception_handler
+
/**
* Entry point.
*/
@@ -82,7 +84,7 @@
# simply call ebreak.
# Users who require real trap handling should
# install their own trap vector.
- la t0, failure
+ la t0, kelvin_exception_handler
csrw mtvec, t0
# Set up sentinel value in _ret