Adjust behavior of Mlu stages

Change-Id: I8d67bfe5a6ab68d729c65d4bbdb29e446447aeac
diff --git a/hdl/chisel/src/kelvin/scalar/Mlu.scala b/hdl/chisel/src/kelvin/scalar/Mlu.scala
index e221851..7ea9bfc 100644
--- a/hdl/chisel/src/kelvin/scalar/Mlu.scala
+++ b/hdl/chisel/src/kelvin/scalar/Mlu.scala
@@ -51,8 +51,10 @@
 
 class MluStage2(p: Parameters) extends Bundle {
   val rd = UInt(5.W)
-  val mul = UInt(32.W)
-  val round = UInt(1.W)
+  val op = MluOp()
+  val prod = SInt(66.W)
+  val rs1 = UInt(32.W)
+  val rs2 = UInt(32.W)
 }
 
 class Mlu(p: Parameters) extends Module {
@@ -84,7 +86,6 @@
   val addr2in = stage2Input.bits.rd
   val sel2in = stage2Input.bits.sel
 
-
   val rs1 = (0 until p.instructionLanes).map(x => MuxOR(valid2in & sel2in(x), io.rs1(x).data)).reduce(_ | _)
   val rs2 = (0 until p.instructionLanes).map(x => MuxOR(valid2in & sel2in(x), io.rs2(x).data)).reduce(_ | _)
 
@@ -95,37 +96,44 @@
   val prod = rs1s * rs2s
   assert(prod.getWidth == 66)
 
-  val round = prod(30) && op2in.isOneOf(MluOp.DMULHR) ||
-              prod(31) && (op2in.isOneOf(MluOp.MULHR, MluOp.MULHSUR, MluOp.MULHUR))
-
-  val maxneg = 2.U(2.W)
-  val halfneg = 1.U(2.W)
-  val sat = rs1(29,0) === 0.U && rs2(29,0) === 0.U &&
-            (rs1(31,30) === maxneg && rs2(31,30) === maxneg ||
-              rs1(31,30) === maxneg && rs2(31,30) === halfneg ||
-              rs2(31,30) === maxneg && rs1(31,30) === halfneg)
-
-  val mul = MuxCase(0.U(32.W), Seq(
-    (op2in === MluOp.MUL) -> prod(31, 0),
-    op2in.isOneOf(MluOp.MULH, MluOp.MULHSU, MluOp.MULHU, MluOp.MULHR, MluOp.MULHSUR, MluOp.MULHUR) -> prod(63,32),
-    op2in.isOneOf(MluOp.DMULH, MluOp.DMULHR) -> Mux(sat, Mux(prod(65), 0x7fffffff.U(32.W), Cat(1.U(1.W), 0.U(31.W))), prod(62,31))
-  ))
-
   val stage2 = Wire(Decoupled(new MluStage2(p)))
   stage2.valid := valid2in
   stage2.bits.rd := addr2in
-  stage2.bits.mul := mul
-  stage2.bits.round := round
+  stage2.bits.op := op2in
+  stage2.bits.prod := prod
+  stage2.bits.rs1 := rs1
+  stage2.bits.rs2 := rs2
   stage2Input.ready := stage2.ready
-  val stage3 = Queue(stage2, 1, true)
+
+  val stage3Input = Queue(stage2, 1, true)
+  val op3in = stage3Input.bits.op
+  val prod3in = stage3Input.bits.prod
+  val rs1_3in = stage3Input.bits.rs1
+  val rs2_3in = stage3Input.bits.rs2
+
+  val maxneg = 2.U(2.W)
+  val halfneg = 1.U(2.W)
+  val sat = rs1_3in(29,0) === 0.U && rs2_3in(29,0) === 0.U &&
+            (rs1_3in(31,30) === maxneg && rs2_3in(31,30) === maxneg ||
+              rs1_3in(31,30) === maxneg && rs2_3in(31,30) === halfneg ||
+              rs2_3in(31,30) === maxneg && rs1_3in(31,30) === halfneg)
+
+  val mul = MuxCase(0.U(32.W), Seq(
+    (op3in === MluOp.MUL) -> prod3in(31, 0),
+    op3in.isOneOf(MluOp.MULH, MluOp.MULHSU, MluOp.MULHU, MluOp.MULHR, MluOp.MULHSUR, MluOp.MULHUR) -> prod3in(63,32),
+    op3in.isOneOf(MluOp.DMULH, MluOp.DMULHR) -> Mux(sat, Mux(prod3in(65), 0x7fffffff.U(32.W), Cat(1.U(1.W), 0.U(31.W))), prod3in(62,31))
+  ))
+
+  val round = prod3in(30) && op3in.isOneOf(MluOp.DMULHR) ||
+              prod3in(31) && (op3in.isOneOf(MluOp.MULHR, MluOp.MULHSUR, MluOp.MULHUR))
 
   // Stage 3 output result
   // Multiplier has a registered output.
-  stage3.ready := io.rd.ready
+  stage3Input.ready := io.rd.ready
 
-  io.rd.valid     := stage3.valid
-  io.rd.bits.addr := stage3.bits.rd
-  io.rd.bits.data := stage3.bits.mul + stage3.bits.round
+  io.rd.valid     := stage3Input.valid
+  io.rd.bits.addr := stage3Input.bits.rd
+  io.rd.bits.data := mul + round
 
   // Assertions.
   for (i <- 0 until p.instructionLanes) {