Raise an illegal-instruction fault for RVV reduction instructions when vstart != 0

The RVV spec requires vstart == 0 for reductions. Decode now recognizes the
reduction opcodes (vredsum/vredand/vredor/vredxor, vredmin[u]/vredmax[u],
vwredsum[u]), holds them at dispatch while the vector config state is stale,
and asserts a new rvvFault signal when a reduction is presented with a nonzero
vstart. FaultManager reports the fault as mcause = 2 (illegal instruction)
with the offending instruction in mtval. The config-state handshake is also
marked invalid on cycles where a CSR write touches vstart/vxrm/vxsat, and a
CSR interlock keeps younger lanes from dispatching alongside a CSR op in
lane 0. The CRT gains a weak kelvin_exception_handler trap hook, and a new
cocotb test (reduction_m1_failure_ops) sets vstart = 1 before each reduction
and checks that the fault is taken.
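
For reference, a minimal sketch of the software-visible contract, using the
standard <riscv_vector.h> intrinsics already used by the test templates (the
helper name and the explicit vstart clear are illustrative; vstart is normally
already zero unless a trap interrupted a vector instruction):

    #include <riscv_vector.h>
    #include <stddef.h>
    #include <stdint.h>

    int32_t sum_i32(const int32_t* buf, size_t vl) {
      // Reductions now fault unless vstart == 0, so clear it if an earlier
      // trap may have left it nonzero.
      asm volatile("csrw vstart, zero");
      vint32m1_t v    = __riscv_vle32_v_i32m1(buf, vl);
      vint32m1_t init = __riscv_vmv_v_x_i32m1(0, vl);
      vint32m1_t acc  = __riscv_vredsum_vs_i32m1_i32m1(v, init, vl);
      return __riscv_vmv_x_s_i32m1_i32(acc);
    }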

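Bare-metal programs that want to observe the fault can override the new weak
CRT hook, much as the updated reduction test template does; a rough sketch
(the g_mcause log slot is illustrative):

    #include <stdint.h>

    volatile uint32_t g_mcause = 0;  // Hypothetical slot read back by a test.

    extern "C" void kelvin_exception_handler() {
      uint32_t cause;
      asm volatile("csrr %0, mcause" : "=r"(cause));  // 2 == illegal instruction
      g_mcause = cause;
      asm volatile("ebreak");  // Halt so the simulation harness can inspect state.
      while (true) {}
    }
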
Change-Id: I69c8f28402e760971d23b63776bec3afd68110b7
diff --git a/hdl/chisel/src/kelvin/rvv/RvvCore.scala b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
index 75ae92a..2917755 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvCore.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvCore.scala
@@ -459,16 +459,22 @@
   io.rvv2lsu <> rvvCoreWrapper.io.rvv2lsu
   io.lsu2rvv <> rvvCoreWrapper.io.lsu2rvv
 
-  io.configState.valid        := rvvCoreWrapper.io.configStateValid
-  io.configState.bits.vl      := rvvCoreWrapper.io.configVl
-  io.configState.bits.vstart  := rvvCoreWrapper.io.configVstart
-  io.configState.bits.ma      := rvvCoreWrapper.io.configMa
-  io.configState.bits.ta      := rvvCoreWrapper.io.configTa
-  io.configState.bits.xrm     := rvvCoreWrapper.io.configXrm
-  io.configState.bits.sew     := rvvCoreWrapper.io.configSew
-  io.configState.bits.lmul    := rvvCoreWrapper.io.configLmul
-  io.rvv_idle                 := rvvCoreWrapper.io.rvv_idle
-  io.queue_capacity           := rvvCoreWrapper.io.queue_capacity
+  // Conservatively mark the config state invalid on cycles when a CSR
+  // instruction updates vstart, vxrm, or vxsat.
+  io.configState.valid := rvvCoreWrapper.io.configStateValid &&
+      !rvvCoreWrapper.io.vcsr_valid &&
+      !io.csr.vstart_write.valid &&
+      !io.csr.vxrm_write.valid &&
+      !io.csr.vxsat_write.valid
+  io.configState.bits.vl     := rvvCoreWrapper.io.configVl
+  io.configState.bits.vstart := rvvCoreWrapper.io.configVstart
+  io.configState.bits.ma     := rvvCoreWrapper.io.configMa
+  io.configState.bits.ta     := rvvCoreWrapper.io.configTa
+  io.configState.bits.xrm    := rvvCoreWrapper.io.configXrm
+  io.configState.bits.sew    := rvvCoreWrapper.io.configSew
+  io.configState.bits.lmul   := rvvCoreWrapper.io.configLmul
+  io.rvv_idle                := rvvCoreWrapper.io.rvv_idle
+  io.queue_capacity          := rvvCoreWrapper.io.queue_capacity
 
   val vstart_wdata = MuxCase(vstart, Seq(
       rvvCoreWrapper.io.vcsr_valid -> rvvCoreWrapper.io.vcsr_vstart,
diff --git a/hdl/chisel/src/kelvin/rvv/RvvDecode.scala b/hdl/chisel/src/kelvin/rvv/RvvDecode.scala
index 03f1b7d..1a85a16 100644
--- a/hdl/chisel/src/kelvin/rvv/RvvDecode.scala
+++ b/hdl/chisel/src/kelvin/rvv/RvvDecode.scala
@@ -45,6 +45,23 @@
     bits(7, 5)
   }
 
+  // True if this instruction is a reduction instruction, which requires vstart == 0.
+  def isReduction(): Bool = {
+    (opcode === RvvCompressedOpcode.RVVALU) && (funct3() === "b010".U) &&
+        MuxLookup(funct6(), false.B)(Seq(
+            "b000000".U -> true.B,  // vredsum
+            "b000001".U -> true.B,  // vredand
+            "b000010".U -> true.B,  // vredor
+            "b000011".U -> true.B,  // vredxor
+            "b000100".U -> true.B,  // vredminu
+            "b000101".U -> true.B,  // vredmin
+            "b000110".U -> true.B,  // vredmaxu
+            "b000111".U -> true.B,  // vredmax
+            "b110000".U -> true.B,  // vwredsumu
+            "b110001".U -> true.B,  // vwredsum
+        ))
+  }
+
   // "Addressing Mode" for loads/store (see Section 7.2 of RVV Spec)
   def mop: RvvAddressingMode.Type = {
     RvvAddressingMode(bits(20, 19))
diff --git a/hdl/chisel/src/kelvin/scalar/Decode.scala b/hdl/chisel/src/kelvin/scalar/Decode.scala
index b808286..828666c 100644
--- a/hdl/chisel/src/kelvin/scalar/Decode.scala
+++ b/hdl/chisel/src/kelvin/scalar/Decode.scala
@@ -263,6 +263,8 @@
     val jalrFault = Output(Vec(p.instructionLanes, Bool()))
     val bxxFault = Output(Vec(p.instructionLanes, Bool()))
     val undefFault = Output(Vec(p.instructionLanes, Bool()))
+    val rvvFault = Option.when(p.enableRvv)(
+        Output(Vec(p.instructionLanes, Bool())))
     val bruTarget = Output(Vec(p.instructionLanes, UInt(32.W)))
     val jalrTarget = Input(Vec(p.instructionLanes, new RegfileBranchTargetIO))
 
@@ -420,12 +422,27 @@
   val fence = decodedInsts.map(x => x.isFency() && (io.mactive || io.lsuActive))
 
   // ---------------------------------------------------------------------------
+  // CSR interlock: don't dispatch younger lanes alongside a CSR instruction in lane 0.
+  val csrInterlock = (0 until p.instructionLanes).map(i =>
+    if (i == 0) {
+      true.B
+    } else {
+      !decodedInsts(0).isCsr()
+    }
+  )
+
+  // ---------------------------------------------------------------------------
   // Rvv config interlock rules
   // RVV Load store unit requires valid config state on dispatch.
-  val rvvConfigInterlock = if (p.enableRvv) {
+  val configInvalid = if (p.enableRvv) {
     val configChange = decodedInsts.map(
         x => x.rvv.get.valid && x.rvv.get.bits.isVset())
-    val configInvalid = configChange.scan(!io.rvvState.get.valid)(_ || _)
+    configChange.scan(!io.rvvState.get.valid)(_ || _)
+  } else {
+    Seq.fill(p.instructionLanes)(false.B)
+  }
+
+  val rvvConfigInterlock = if (p.enableRvv) {
     val canDispatchRvv = (0 until p.instructionLanes).map(i =>
         !decodedInsts(i).rvv.get.valid || // Don't lock non-rvv
         !decodedInsts(i).rvv.get.bits.isLoadStore() || // Non-LSU can handle change
@@ -437,6 +454,21 @@
   }
 
   // ---------------------------------------------------------------------------
+  // Rvv reduction interlock
+  // Don't dispatch reduction instructions while the config is invalid or vstart != 0
+  val rvvReductionInterlock = if (p.enableRvv) {
+    (0 until p.instructionLanes).map(i => {
+        val invalidReduction =
+            decodedInsts(i).rvv.get.valid &&
+            decodedInsts(i).rvv.get.bits.isReduction() &&
+            (configInvalid(i) || (io.rvvState.get.bits.vstart =/= 0.U))
+        !invalidReduction
+    })
+  } else {
+    Seq.fill(p.instructionLanes)(true.B)
+  }
+
+  // ---------------------------------------------------------------------------
   // Rvv Interlock
   val rvvInterlock = if (p.enableRvv) {
     val isRvv = decodedInsts.map(x => x.rvv.get.valid)
@@ -499,7 +531,9 @@
       !floatWriteAfterWrite(i) && // Avoid WAW hazards
       !branchInterlock(i) && // Only branch/alu can be dispatched after a branch
       !fence(i) &&           // Don't dispatch if fence interlocked
+      csrInterlock(i) &&     // Don't dispatch alongside a CSR op in lane 0
       rvvConfigInterlock(i) &&     // Rvv interlock rules
+      rvvReductionInterlock(i) && // Don't dispatch reduction if vstart != 0
       // rvvLsuInterlock(i) &&  // Dispatch only one Rvv LsuOp
       lsuInterlock(i) && // Ensure lsu instructions can be dispatched into queue
       rvvInterlock(i) && // Ensure rvv instructions can be dispatched into queue
@@ -721,6 +755,22 @@
     io.inst(i).ready := lastReady(i + 1)
   }
 
+  // Fault handling for RVV
+  if (p.enableRvv) {
+    for (i <- 0 until p.instructionLanes) {
+      io.rvvFault.get(i) := (if (i == 0) {
+        // Raise a fault when a reduction is presented with vstart != 0.
+        val isReduction = decodedInsts(0).rvv.get.valid &&
+            decodedInsts(0).rvv.get.bits.isReduction()
+        val vStartNotZero = io.rvvState.get.valid &&
+            (io.rvvState.get.bits.vstart =/= 0.U)
+        io.inst(0).valid && isReduction && vStartNotZero
+      } else {
+        false.B
+      })
+    }
+  }
+
   for (i <- 0 until p.instructionLanes) {
     val d = decodedInsts(i)
     val rs3Addr = io.inst(i).bits.inst(31,27)
@@ -796,6 +846,8 @@
     io.jalFault(i) := decode(i).io.jalFault && !io.branchTaken && decode(i).io.inst.valid
     io.bxxFault(i) := decode(i).io.bxxFault & !io.branchTaken && decode(i).io.inst.valid
     io.undefFault(i) := decode(i).io.undefFault & !io.branchTaken && decode(i).io.inst.valid
+    if (p.enableRvv) { io.rvvFault.get(i) := false.B }
+
     io.bruTarget(i) := decode(i).io.bruTarget
   }
 
diff --git a/hdl/chisel/src/kelvin/scalar/FaultManager.scala b/hdl/chisel/src/kelvin/scalar/FaultManager.scala
index e8ff77b..5470413 100644
--- a/hdl/chisel/src/kelvin/scalar/FaultManager.scala
+++ b/hdl/chisel/src/kelvin/scalar/FaultManager.scala
@@ -32,6 +32,7 @@
         val jalr = Bool()
         val bxx = Bool()
         val undef = Bool()
+        val rvv = if (p.enableRvv) Some(Bool()) else None
       }))
       val pc = Input(Vec(p.instructionLanes, new Bundle {
         val pc = UInt(32.W)
@@ -51,7 +52,13 @@
     val out = Output(Valid(new FaultManagerOutput))
   })
 
-  val faults = VecInit((0 until p.instructionLanes).map(x => (io.in.fault(x).csr | io.in.fault(x).jal | io.in.fault(x).jalr | io.in.fault(x).bxx | io.in.fault(x).undef)))
+  val faults = VecInit((0 until p.instructionLanes).map(x => (
+      io.in.fault(x).csr |
+      io.in.fault(x).jal |
+      io.in.fault(x).jalr |
+      io.in.fault(x).bxx |
+      io.in.fault(x).undef |
+      io.in.fault(x).rvv.getOrElse(false.B))))
   val fault = faults.reduce(_|_)
   val first_fault = PriorityEncoder(faults)
   val undef_fault = io.in.fault.map(_.undef).reduce(_|_)
@@ -64,6 +71,8 @@
   val jalr_fault_idx = PriorityEncoder(io.in.fault.map(_.jalr))
   val bxx_fault = io.in.fault.map(_.bxx).reduce(_|_)
   val bxx_fault_idx = PriorityEncoder(io.in.fault.map(_.bxx))
+  val rvv_fault = io.in.fault.map(_.rvv.getOrElse(false.B)).reduce(_|_)
+  val rvv_fault_idx = PriorityEncoder(io.in.fault.map(_.rvv.getOrElse(false.B)))
   val instr_access_fault = io.in.memory_fault.valid && io.in.ibus_fault
   val load_fault = io.in.memory_fault.valid && !io.in.memory_fault.bits.write && !io.in.ibus_fault
   val store_fault = io.in.memory_fault.valid && io.in.memory_fault.bits.write && !io.in.ibus_fault
@@ -84,6 +93,7 @@
     (jalr_fault && (jalr_fault_idx === first_fault)) -> 0.U(32.W),
     (bxx_fault && (bxx_fault_idx === first_fault)) -> 0.U(32.W),
     (undef_fault && (undef_fault_idx === first_fault)) -> 2.U(32.W),
+    (rvv_fault && (rvv_fault_idx === first_fault)) -> 2.U(32.W),
   ))
   io.out.bits.mtval := MuxCase(0.U(32.W), Seq(
     load_fault -> io.in.memory_fault.bits.addr,
@@ -94,5 +104,6 @@
     (jalr_fault && (jalr_fault_idx === first_fault)) -> (io.in.jalr(jalr_fault_idx).target & "xFFFFFFFE".U),
     (bxx_fault && (bxx_fault_idx === first_fault)) -> 0.U(32.W),
     (undef_fault && (undef_fault_idx === first_fault)) -> io.in.undef(undef_fault_idx).inst,
+    (rvv_fault && (rvv_fault_idx === first_fault)) -> io.in.undef(rvv_fault_idx).inst,
   ))
 }
\ No newline at end of file
diff --git a/hdl/chisel/src/kelvin/scalar/SCore.scala b/hdl/chisel/src/kelvin/scalar/SCore.scala
index 83f4c54..80cbb86 100644
--- a/hdl/chisel/src/kelvin/scalar/SCore.scala
+++ b/hdl/chisel/src/kelvin/scalar/SCore.scala
@@ -142,6 +142,9 @@
     fault_manager.io.in.fault(i).jalr := dispatch.io.jalrFault(i)
     fault_manager.io.in.fault(i).bxx := dispatch.io.bxxFault(i)
     fault_manager.io.in.fault(i).undef := dispatch.io.undefFault(i)
+    if (p.enableRvv) {
+      fault_manager.io.in.fault(i).rvv.get := dispatch.io.rvvFault.get(i)
+    }
     fault_manager.io.in.pc(i).pc := fetch.io.inst.lanes(i).bits.addr
     fault_manager.io.in.jalr(i).target := regfile.io.target(i).data
     fault_manager.io.in.undef(i).inst := fetch.io.inst.lanes(i).bits.inst
diff --git a/kelvin_test_utils/core_mini_axi_interface.py b/kelvin_test_utils/core_mini_axi_interface.py
index 78c3bc8..162e8b5 100644
--- a/kelvin_test_utils/core_mini_axi_interface.py
+++ b/kelvin_test_utils/core_mini_axi_interface.py
@@ -801,6 +801,15 @@
         timeout_cycles = timeout_cycles - 1
         assert timeout_cycles > 0
 
+  async def wait_for_fault(self, timeout_cycles=1000):
+    cycle_count = 0
+    while self.dut.io_fault.value != 1 and timeout_cycles > 0:
+      await ClockCycles(self.dut.io_aclk, 1)
+      timeout_cycles = timeout_cycles - 1
+      cycle_count += 1
+    assert timeout_cycles > 0
+    return cycle_count
+
   async def watch(self, addr, timeout_cycles=1_000_000):
     elem_addr = addr % 16
     line_addr = addr - elem_addr
diff --git a/kelvin_test_utils/sim_test_fixture.py b/kelvin_test_utils/sim_test_fixture.py
index b387213..ccb2c82 100644
--- a/kelvin_test_utils/sim_test_fixture.py
+++ b/kelvin_test_utils/sim_test_fixture.py
@@ -54,3 +54,10 @@
     async def run_to_halt(self, timeout_cycles=10000):
         await self.core_mini_axi.execute_from(self.entry_point)
         return (await self.core_mini_axi.wait_for_halted(timeout_cycles=timeout_cycles))
+
+    async def run_to_fault(self, timeout_cycles=10000):
+        await self.core_mini_axi.execute_from(self.entry_point)
+        return (await self.core_mini_axi.wait_for_fault(timeout_cycles=timeout_cycles))
+
+    def fault(self):
+        return (self.core_mini_axi.dut.io_fault.value == 1)
\ No newline at end of file
diff --git a/tests/cocotb/BUILD b/tests/cocotb/BUILD
index 3045e5b..1dfea33 100644
--- a/tests/cocotb/BUILD
+++ b/tests/cocotb/BUILD
@@ -190,6 +190,7 @@
 RVV_ARITHMETIC_TESTCASES = [
     "arithmetic_m1_vanilla_ops",
     "reduction_m1_vanilla_ops",
+    "reduction_m1_failure_ops",
     "widen_math_ops_test_impl",
 ]
 # END_TESTCASES_FOR_rvv_arithmetic_cocotb_test
diff --git a/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc b/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc
index 780e3e4..e7a06ab 100644
--- a/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc
+++ b/tests/cocotb/rvv/arithmetics/rvv_reduction_template.cc
@@ -16,16 +16,36 @@
 
 #include <riscv_vector.h>
 
+#define __ATTRIBUTE_IN_DTCM__ \
+    __attribute__((section(".data"))) __attribute__((aligned(16)))
 
-{DTYPE}_t in_buf_1[{IN_DATA_SIZE}] __attribute__((section(".data"))) __attribute__((aligned(16)));
-{DTYPE}_t scalar_input __attribute__((section(".data"))) __attribute__((aligned(16)));
-{DTYPE}_t out_buf __attribute__((section(".data"))) __attribute__((aligned(16)));
+{DTYPE}_t in_buf_1[{IN_DATA_SIZE}] __ATTRIBUTE_IN_DTCM__;
+{DTYPE}_t scalar_input __ATTRIBUTE_IN_DTCM__;
+{DTYPE}_t out_buf __ATTRIBUTE_IN_DTCM__;
+uint32_t vstart __ATTRIBUTE_IN_DTCM__ = 0;
+uint32_t vl __ATTRIBUTE_IN_DTCM__ = {NUM_OPERANDS};
+uint32_t faulted __ATTRIBUTE_IN_DTCM__ = 0;
+uint32_t mcause __ATTRIBUTE_IN_DTCM__ = 0;
+
+// Exception handler that records the fault cause for the test before halting.
+extern "C" {
+void kelvin_exception_handler() {
+  faulted = 1;
+  uint32_t local_mcause;
+  asm volatile("csrr %0, mcause" : "=r"(local_mcause));
+  mcause = local_mcause;
+
+  asm volatile("ebreak");
+  while (1) {}
+}
+}
 
 void {REDUCTION_OP}_{SIGN}{SEW}_m1(const {DTYPE}_t* in_buf_1, const {DTYPE}_t scalar_input, {DTYPE}_t* out_buf){
 
-    v{DTYPE}m1_t input_v1 = __riscv_vle{SEW}_v_{SIGN}{SEW}m1(in_buf_1, {NUM_OPERANDS});
-    v{DTYPE}m1_t input_s1 = __riscv_vmv_v_x_{SIGN}{SEW}m1(scalar_input, {NUM_OPERANDS});
-    v{DTYPE}m1_t {REDUCTION_OP}_result = __riscv_v{REDUCTION_OP}_vs_{SIGN}{SEW}m1_{SIGN}{SEW}m1(input_v1, input_s1, {NUM_OPERANDS});
+    v{DTYPE}m1_t input_v1 = __riscv_vle{SEW}_v_{SIGN}{SEW}m1(in_buf_1, vl);
+    v{DTYPE}m1_t input_s1 = __riscv_vmv_v_x_{SIGN}{SEW}m1(scalar_input, vl);
+    asm("csrw vstart, %0" : : "r"(vstart));
+    v{DTYPE}m1_t {REDUCTION_OP}_result = __riscv_v{REDUCTION_OP}_vs_{SIGN}{SEW}m1_{SIGN}{SEW}m1(input_v1, input_s1, vl);
     *out_buf = __riscv_vmv_x_s_{SIGN}{SEW}m1_{SIGN}{SEW}({REDUCTION_OP}_result);
 }
 
diff --git a/tests/cocotb/rvv_arithmetic_cocotb_test.py b/tests/cocotb/rvv_arithmetic_cocotb_test.py
index 607064d..d85d827 100644
--- a/tests/cocotb/rvv_arithmetic_cocotb_test.py
+++ b/tests/cocotb/rvv_arithmetic_cocotb_test.py
@@ -191,6 +191,64 @@
         num_bytes=16)
 
 
+async def reduction_m1_failure_test(dut, dtypes, math_ops, num_bytes: int):
+    """RVV reduction failure test template.
+
+    Each test sets a nonzero vstart before a reduction op and expects an illegal-instruction fault (mcause == 2).
+    """
+    m1_failure_op_elfs = [
+        f"rvv_{math_op}_{dtype}_m1.elf" for math_op in math_ops
+        for dtype in dtypes
+    ]
+    pattern_extract = re.compile("rvv_(.*)_(.*)_m1.elf")
+
+    r = runfiles.Create()
+    fixture = await Fixture.Create(dut)
+
+    with tqdm.tqdm(m1_failure_op_elfs) as t:
+        for elf_name in t:
+            t.set_postfix({"binary": os.path.basename(elf_name)})
+            elf_path = r.Rlocation(
+                f"kelvin_hw/tests/cocotb/rvv/arithmetics/{elf_name}")
+            await fixture.load_elf_and_lookup_symbols(
+                elf_path,
+                ['in_buf_1', 'scalar_input', 'out_buf', 'vstart', 'vl',
+                 'faulted', 'mcause'],
+            )
+            math_op, dtype = pattern_extract.match(elf_name).groups()
+            np_type = STR_TO_NP_TYPE[dtype]
+            itemsize = np.dtype(np_type).itemsize
+            num_test_values = num_bytes // itemsize
+
+            min_value = np.iinfo(np_type).min
+            max_value = np.iinfo(np_type).max + 1  # One above.
+            input_1 = np.random.randint(min_value,
+                                        max_value,
+                                        num_test_values,
+                                        dtype=np_type)
+            input_2 = np.random.randint(min_value, max_value, 1, dtype=np_type)
+
+            await fixture.write('in_buf_1', input_1)
+            await fixture.write('scalar_input', input_2)
+            await fixture.write('vstart', np.array([1], dtype=np.uint32))
+            await fixture.write('out_buf', np.zeros(1, dtype=np_type))
+
+            await fixture.run_to_halt()
+            faulted = (await fixture.read('faulted', 4)).view(np.uint32)
+            mcause = (await fixture.read('mcause', 4)).view(np.uint32)
+            assert faulted == 1
+            assert mcause == 0x2  # Illegal instruction
+
+
+@cocotb.test()
+async def reduction_m1_failure_ops(dut):
+    await reduction_m1_failure_test(
+        dut=dut,
+        dtypes=["int8", "int16", "int32", "uint8", "uint16", "uint32"],
+        math_ops=["redsum", "redmin", "redmax"],
+        num_bytes=16)
+
+
 async def _widen_math_ops_test_impl(
     dut,
     dtypes,
diff --git a/toolchain/crt/BUILD b/toolchain/crt/BUILD
index 611960b..70279e8 100644
--- a/toolchain/crt/BUILD
+++ b/toolchain/crt/BUILD
@@ -18,6 +18,7 @@
     name = "crt",
     srcs = [
         "crt.S",
+        "kelvin_exceptions.cc",
         "kelvin_gloss.cc",
         "kelvin_start.S",
     ],
diff --git a/toolchain/crt/kelvin_exceptions.cc b/toolchain/crt/kelvin_exceptions.cc
new file mode 100644
index 0000000..0a7297e
--- /dev/null
+++ b/toolchain/crt/kelvin_exceptions.cc
@@ -0,0 +1,20 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern "C" {
+void __attribute__((weak)) kelvin_exception_handler() {
+  asm volatile("ebreak");
+  while (1) {}
+}
+}
\ No newline at end of file
diff --git a/toolchain/crt/kelvin_start.S b/toolchain/crt/kelvin_start.S
index d87ffce..dc44a8e 100644
--- a/toolchain/crt/kelvin_start.S
+++ b/toolchain/crt/kelvin_start.S
@@ -16,6 +16,8 @@
 
 // A starting functions for simple kelvin programs.
 
+.extern kelvin_exception_handler
+
 /**
  * Entry point.
  */
@@ -82,7 +84,7 @@
         # simply call ebreak.
         # Users who require real trap handling should
         # install their own trap vector.
-        la t0, failure
+        la t0, kelvin_exception_handler
         csrw mtvec, t0
 
         # Set up sentinel value in _ret