[dif/rv_timer] add function to write to counter

This adds an additional function, dif_rv_timer_counter_write(),
to fix #7921.

Signed-off-by: Timothy Trippel <ttrippel@google.com>
diff --git a/sw/device/lib/dif/dif_rv_timer.c b/sw/device/lib/dif/dif_rv_timer.c
index 3d9a946..2b6051e 100644
--- a/sw/device/lib/dif/dif_rv_timer.c
+++ b/sw/device/lib/dif/dif_rv_timer.c
@@ -7,6 +7,7 @@
 #include <stddef.h>
 
 #include "sw/device/lib/base/bitfield.h"
+
 #include "rv_timer_regs.h"  // Generated.
 
 /**
@@ -158,6 +159,36 @@
   }
 }
 
+dif_rv_timer_result_t dif_rv_timer_counter_write(const dif_rv_timer_t *timer,
+                                                 uint32_t hart_id,
+                                                 uint64_t count) {
+  if (timer == NULL || hart_id >= timer->config.hart_count) {
+    return kDifRvTimerBadArg;
+  }
+
+  // Disable the counter.
+  uint32_t ctrl_reg =
+      mmio_region_read32(timer->base_addr, RV_TIMER_CTRL_REG_OFFSET);
+  uint32_t ctrl_reg_cleared = bitfield_bit32_write(ctrl_reg, hart_id, false);
+  mmio_region_write32(timer->base_addr, RV_TIMER_CTRL_REG_OFFSET,
+                      ctrl_reg_cleared);
+
+  // Write the new count.
+  uint32_t lower_count = count;
+  uint32_t upper_count = count >> 32;
+  mmio_region_write32(timer->base_addr,
+                      reg_for_hart(hart_id, RV_TIMER_TIMER_V_LOWER0_REG_OFFSET),
+                      lower_count);
+  mmio_region_write32(timer->base_addr,
+                      reg_for_hart(hart_id, RV_TIMER_TIMER_V_UPPER0_REG_OFFSET),
+                      upper_count);
+
+  // Re-enable the counter (if it was previously enabled).
+  mmio_region_write32(timer->base_addr, RV_TIMER_CTRL_REG_OFFSET, ctrl_reg);
+
+  return kDifRvTimerOk;
+}
+
 dif_rv_timer_result_t dif_rv_timer_arm(const dif_rv_timer_t *timer,
                                        uint32_t hart_id, uint32_t comp_id,
                                        uint64_t threshold) {
diff --git a/sw/device/lib/dif/dif_rv_timer.h b/sw/device/lib/dif/dif_rv_timer.h
index 471a2df..f73b5a9 100644
--- a/sw/device/lib/dif/dif_rv_timer.h
+++ b/sw/device/lib/dif/dif_rv_timer.h
@@ -200,7 +200,7 @@
     dif_rv_timer_enabled_t state);
 
 /**
- * Reads the current value on a particlar hart's timer.
+ * Reads the current value on a particular hart's timer.
  *
  * @param timer A timer device.
  * @param hart_id The hart counter to read.
@@ -212,6 +212,18 @@
                                                 uint64_t *out);
 
 /**
+ * Writes the given value to a particular hart's timer.
+ *
+ * @param timer A timer device.
+ * @param hart_id The hart counter to write.
+ * @param count The counter value to write.
+ * @return The result of the operation.
+ */
+dif_rv_timer_result_t dif_rv_timer_counter_write(const dif_rv_timer_t *timer,
+                                                 uint32_t hart_id,
+                                                 uint64_t count);
+
+/**
  * Arms the timer to go off once the counter value is greater than
  * or equal to `threshold`, by setting up the given comparator.
  *
diff --git a/sw/device/lib/dif/dif_rv_timer_unittest.cc b/sw/device/lib/dif/dif_rv_timer_unittest.cc
index 5291f54..6b81809 100644
--- a/sw/device/lib/dif/dif_rv_timer_unittest.cc
+++ b/sw/device/lib/dif/dif_rv_timer_unittest.cc
@@ -2,18 +2,19 @@
 // Licensed under the Apache License, Version 2.0, see LICENSE for details.
 // SPDX-License-Identifier: Apache-2.0
 
+#include "sw/device/lib/dif/dif_rv_timer.h"
+
 #include <cstring>
 #include <limits>
 #include <ostream>
 #include <stdint.h>
 
-#include "sw/device/lib/dif/dif_rv_timer.h"
-#include "rv_timer_regs.h"  // Generated.
-
 #include "gtest/gtest.h"
 #include "sw/device/lib/base/mmio.h"
 #include "sw/device/lib/base/testing/mock_mmio.h"
 
+#include "rv_timer_regs.h"  // Generated.
+
 // We define global namespace == and << to make `dif_i2c_timing_params_t` work
 // nicely with EXPECT_EQ.
 bool operator==(dif_rv_timer_tick_params_t a, dif_rv_timer_tick_params_t b) {
@@ -46,7 +47,8 @@
   // The timer frequency devices the clock speed, so their quotient minus 1 is
   // the prescale.
   dif_rv_timer_tick_params_t params, expected = {
-                                         .prescale = 49, .tick_step = 1,
+                                         .prescale = 49,
+                                         .tick_step = 1,
                                      };
   EXPECT_EQ(
       dif_rv_timer_approximate_tick_params(kClockSpeed, kSlowTimer, &params),
@@ -57,7 +59,8 @@
 TEST(ApproximateParamsTest, WithStep) {
   // 50 MHz / 5 is 10 MHz; multiplied by 12, we get 120 MHz.
   dif_rv_timer_tick_params_t params, expected = {
-                                         .prescale = 4, .tick_step = 12,
+                                         .prescale = 4,
+                                         .tick_step = 12,
                                      };
   EXPECT_EQ(
       dif_rv_timer_approximate_tick_params(kClockSpeed, kFastTimer, &params),
@@ -94,7 +97,8 @@
  protected:
   dif_rv_timer_t MakeTimer(dif_rv_timer_config_t config) {
     return {
-        .base_addr = dev().region(), .config = config,
+        .base_addr = dev().region(),
+        .config = config,
     };
   }
 };
@@ -133,7 +137,8 @@
   dif_rv_timer timer;
   EXPECT_EQ(dif_rv_timer_init(dev().region(),
                               {
-                                  .hart_count = 1, .comparator_count = 1,
+                                  .hart_count = 1,
+                                  .comparator_count = 1,
                               },
                               &timer),
             kDifRvTimerOk);
@@ -164,7 +169,8 @@
   dif_rv_timer timer;
   EXPECT_EQ(dif_rv_timer_init(dev().region(),
                               {
-                                  .hart_count = 1, .comparator_count = 4,
+                                  .hart_count = 1,
+                                  .comparator_count = 4,
                               },
                               &timer),
             kDifRvTimerOk);
@@ -198,7 +204,8 @@
   dif_rv_timer timer;
   EXPECT_EQ(dif_rv_timer_init(dev().region(),
                               {
-                                  .hart_count = 4, .comparator_count = 4,
+                                  .hart_count = 4,
+                                  .comparator_count = 4,
                               },
                               &timer),
             kDifRvTimerOk);
@@ -207,7 +214,8 @@
 TEST_F(InitTest, NullArgs) {
   EXPECT_EQ(dif_rv_timer_init(dev().region(),
                               {
-                                  .hart_count = 1, .comparator_count = 1,
+                                  .hart_count = 1,
+                                  .comparator_count = 1,
                               },
                               nullptr),
             kDifRvTimerBadArg);
@@ -217,13 +225,15 @@
   dif_rv_timer_t timer;
   EXPECT_EQ(dif_rv_timer_init(dev().region(),
                               {
-                                  .hart_count = 0, .comparator_count = 1,
+                                  .hart_count = 0,
+                                  .comparator_count = 1,
                               },
                               &timer),
             kDifRvTimerBadArg);
   EXPECT_EQ(dif_rv_timer_init(dev().region(),
                               {
-                                  .hart_count = 1, .comparator_count = 0,
+                                  .hart_count = 1,
+                                  .comparator_count = 0,
                               },
                               &timer),
             kDifRvTimerBadArg);
@@ -407,6 +417,34 @@
   EXPECT_EQ(dif_rv_timer_counter_read(&timer, 5, &value), kDifRvTimerBadArg);
 }
 
+class CounterWriteTest : public TimerTest {};
+
+TEST_F(CounterWriteTest, Baseline) {
+  EXPECT_READ32(RV_TIMER_CTRL_REG_OFFSET, 0x0000'0001);
+  EXPECT_WRITE32(RV_TIMER_CTRL_REG_OFFSET, 0x0000'0000);
+  EXPECT_WRITE32(RegForHart(0, RV_TIMER_TIMER_V_LOWER0_REG_OFFSET),
+                 0xDEAD'BEEF);
+  EXPECT_WRITE32(RegForHart(0, RV_TIMER_TIMER_V_UPPER0_REG_OFFSET),
+                 0xCAFE'FEED);
+  EXPECT_WRITE32(RV_TIMER_CTRL_REG_OFFSET, 0x0000'0001);
+
+  auto timer = MakeTimer({1, 1});
+  uint64_t count = 0xCAFE'FEED'DEAD'BEEF;
+  EXPECT_EQ(dif_rv_timer_counter_write(&timer, 0, count), kDifRvTimerOk);
+}
+
+TEST_F(CounterWriteTest, NullArgs) {
+  uint64_t count = 0xCAFE'FEED'DEAD'BEEF;
+  EXPECT_EQ(dif_rv_timer_counter_write(nullptr, 0, count), kDifRvTimerBadArg);
+}
+
+TEST_F(CounterWriteTest, BadHart) {
+  auto timer = MakeTimer({1, 1});
+  uint64_t count = 0xCAFE'FEED'DEAD'BEEF;
+  EXPECT_EQ(dif_rv_timer_counter_write(&timer, 1, count), kDifRvTimerBadArg);
+  EXPECT_EQ(dif_rv_timer_counter_write(&timer, 2, count), kDifRvTimerBadArg);
+}
+
 class ArmTest : public TimerTest {};
 
 TEST_F(ArmTest, Baseline) {