[opentitantool]: Allow SPI chip select across transactions

While implementing TPM communication, I came across a use case in which
I wanted to perform a few SPI transfers, inspect the data that was read,
and then possibly issue some more SPI transfers, while keeping CS
asserted throughout.

The current `spi::Target` interface is such that `run_transaction()`
keeps CS asserted for the duration of the list of transfers given in the
parameter list and then deasserts it unconditionally, with no option to
add more based on the result of the previous ones.

This PR introduces an explicit `assert_cs()` function, which can be used
to keep CS asserted across multiple invocations of `run_transaction()`,
like below.

```
let spi = ...;
{
  let _cs_asserted = spi.assert_cs()?;
  spi.run_transaction(...)?;
  if ... {
    spi.run_transaction(...)?;
  }
  // CS will automatically deassert here (or earlier if bailing out.)
}
```

Furthermore, `assert_cs()` counts its invocations, and releases CS only
when every nested assertion has been released, such that if the above
code was part of a helper method, and the called wanted to make several
invocations of the helper, during a single CS session, then it could put
another `assert_cs()` around its logic.

Signed-off-by: Jes B. Klinke <jbk@chromium.org>
Change-Id: I219f3647d29129c4dfc1fc346f855fb4e9da2e53
diff --git a/sw/host/opentitanlib/src/io/spi.rs b/sw/host/opentitanlib/src/io/spi.rs
index 2a80218..2f356a5 100644
--- a/sw/host/opentitanlib/src/io/spi.rs
+++ b/sw/host/opentitanlib/src/io/spi.rs
@@ -151,6 +151,37 @@
         Err(SpiError::InvalidOption("This target does not support set_voltage".to_string()).into())
     }
 
-    /// Runs a SPI transaction composed from the slice of [`Transfer`] objects.
+    /// Runs a SPI transaction composed from the slice of [`Transfer`] objects.  Will assert the
+    /// CS for the duration of the entire transactions.
     fn run_transaction(&self, transaction: &mut [Transfer]) -> Result<()>;
+
+    /// Assert the CS signal.  Uses reference counting, will be deasserted when each and every
+    /// returned `AssertChipSelect` object have gone out of scope.
+    fn assert_cs(self: Rc<Self>) -> Result<AssertChipSelect>;
+}
+
+/// Object that keeps the CS asserted, deasserting when it goes out of scope, (unless another
+/// instance keeps CS asserted longer.)
+pub struct AssertChipSelect {
+    target: Rc<dyn TargetChipDeassert>,
+}
+
+impl AssertChipSelect {
+    // Needs to be public in order for implementation of `Target` to be able to call it.  Never
+    // called by users of `Target`.
+    pub fn new(target: Rc<dyn TargetChipDeassert>) -> Self {
+        Self { target }
+    }
+}
+
+impl Drop for AssertChipSelect {
+    fn drop(&mut self) {
+        self.target.deassert_cs()
+    }
+}
+
+// Needs to be public in order for implementation of `Target` to be able to implement it.  Never
+// called by users of `Target`.
+pub trait TargetChipDeassert {
+    fn deassert_cs(&self);
 }
diff --git a/sw/host/opentitanlib/src/proxy/handler.rs b/sw/host/opentitanlib/src/proxy/handler.rs
index 62c1d50..4cc8600 100644
--- a/sw/host/opentitanlib/src/proxy/handler.rs
+++ b/sw/host/opentitanlib/src/proxy/handler.rs
@@ -4,6 +4,7 @@
 
 use anyhow::{bail, Result};
 
+use std::collections::HashMap;
 use std::time::Duration;
 
 use super::errors::SerializedError;
@@ -18,23 +19,28 @@
 use crate::bootstrap::Bootstrap;
 use crate::io::i2c;
 use crate::io::spi;
+use crate::transport::TransportError;
 
 /// Implementation of the handling of each protocol request, by means of an underlying
 /// `Transport` implementation.
 pub struct TransportCommandHandler<'a> {
     transport: &'a TransportWrapper,
+    spi_chip_select: HashMap<String, Vec<spi::AssertChipSelect>>,
 }
 
 impl<'a> TransportCommandHandler<'a> {
     pub fn new(transport: &'a TransportWrapper) -> Self {
-        Self { transport }
+        Self {
+            transport,
+            spi_chip_select: HashMap::new(),
+        }
     }
 
     /// This method will perform whatever action on the underlying `Transport` that is requested
     /// by the given `Request`, and return a response to be sent to the client.  Any `Err`
     /// return from this method will be propagated to the remote client, without any server-side
     /// logging.
-    fn do_execute_cmd(&self, req: &Request) -> Result<Response> {
+    fn do_execute_cmd(&mut self, req: &Request) -> Result<Response> {
         match req {
             Request::GetCapabilities => {
                 Ok(Response::GetCapabilities(self.transport.capabilities()?))
@@ -175,6 +181,25 @@
                             transaction: resps,
                         }))
                     }
+                    SpiRequest::AssertChipSelect => {
+                        // Add a `spi::AssertChipSelect` object to the stack for this particular
+                        // SPI instance.
+                        self.spi_chip_select
+                            .entry(id.to_string())
+                            .or_insert(Vec::new())
+                            .push(instance.assert_cs()?);
+                        Ok(Response::Spi(SpiResponse::AssertChipSelect))
+                    }
+                    SpiRequest::DeassertChipSelect => {
+                        // Remove a `spi::AssertChipSelect` object from the stack for this
+                        // particular SPI instance.
+                        self.spi_chip_select
+                            .get_mut(id)
+                            .ok_or(TransportError::InvalidOperation)?
+                            .pop()
+                            .ok_or(TransportError::InvalidOperation)?;
+                        Ok(Response::Spi(SpiResponse::DeassertChipSelect))
+                    }
                 }
             }
             Request::I2c { id, command } => {
@@ -258,7 +283,7 @@
     /// by the given `Message`, and return a response to be sent to the client.  Any `Err`
     /// return from this method will be treated as an irrecoverable protocol error, causing an
     /// error message in the server log, and the connection to be terminated.
-    fn execute_cmd(&self, msg: &Message) -> Result<Message> {
+    fn execute_cmd(&mut self, msg: &Message) -> Result<Message> {
         if let Message::Req(req) = msg {
             // Package either `Ok()` or `Err()` into a `Message`, to be sent via network.
             return Ok(Message::Res(
diff --git a/sw/host/opentitanlib/src/proxy/mod.rs b/sw/host/opentitanlib/src/proxy/mod.rs
index ad6f206..37e550d 100644
--- a/sw/host/opentitanlib/src/proxy/mod.rs
+++ b/sw/host/opentitanlib/src/proxy/mod.rs
@@ -19,7 +19,7 @@
 /// Interface for handlers of protocol messages, responding to each message with a single
 /// instance of the same protocol message.
 pub trait CommandHandler<Msg> {
-    fn execute_cmd(&self, msg: &Msg) -> Result<Msg>;
+    fn execute_cmd(&mut self, msg: &Msg) -> Result<Msg>;
 }
 
 /// This is the main entry point for the session proxy.  This struct will either bind on a
diff --git a/sw/host/opentitanlib/src/proxy/protocol.rs b/sw/host/opentitanlib/src/proxy/protocol.rs
index 3f77f23..61fbfe3 100644
--- a/sw/host/opentitanlib/src/proxy/protocol.rs
+++ b/sw/host/opentitanlib/src/proxy/protocol.rs
@@ -116,6 +116,8 @@
     RunTransaction {
         transaction: Vec<SpiTransferRequest>,
     },
+    AssertChipSelect,
+    DeassertChipSelect,
 }
 
 #[derive(Serialize, Deserialize)]
@@ -142,6 +144,8 @@
     RunTransaction {
         transaction: Vec<SpiTransferResponse>,
     },
+    AssertChipSelect,
+    DeassertChipSelect,
 }
 
 #[derive(Serialize, Deserialize)]
diff --git a/sw/host/opentitanlib/src/proxy/socket_server.rs b/sw/host/opentitanlib/src/proxy/socket_server.rs
index d314d97..f161278 100644
--- a/sw/host/opentitanlib/src/proxy/socket_server.rs
+++ b/sw/host/opentitanlib/src/proxy/socket_server.rs
@@ -159,7 +159,7 @@
                 }
                 if event.is_readable() {
                     conn.read()?;
-                    Self::process_any_requests(conn, &self.command_handler)?;
+                    Self::process_any_requests(conn, &mut self.command_handler)?;
                 }
                 // Return whether this connection object should be dropped.
                 return Ok((conn.rx_eof && (conn.tx_buf.len() == 0)) || conn.broken);
@@ -199,7 +199,7 @@
     }
 
     // Look for any completely received requests in the rx_buf, and handle them one by one.
-    fn process_any_requests(conn: &mut Connection, command_handler: &T) -> Result<()> {
+    fn process_any_requests(conn: &mut Connection, command_handler: &mut T) -> Result<()> {
         while let Some(request) = Self::get_complete_request(conn)? {
             // One complete request received, execute it.
             let resp = command_handler.execute_cmd(&request)?;
diff --git a/sw/host/opentitanlib/src/transport/cw310/spi.rs b/sw/host/opentitanlib/src/transport/cw310/spi.rs
index 3bb49bf..68161c5 100644
--- a/sw/host/opentitanlib/src/transport/cw310/spi.rs
+++ b/sw/host/opentitanlib/src/transport/cw310/spi.rs
@@ -6,9 +6,10 @@
 use std::cell::RefCell;
 use std::rc::Rc;
 
-use crate::io::spi::{SpiError, Target, Transfer, TransferMode};
+use crate::io::spi::{AssertChipSelect, SpiError, Target, Transfer, TransferMode};
 use crate::transport::cw310::usb::Backend;
 use crate::transport::cw310::CW310;
+use crate::transport::TransportError;
 
 pub struct CW310Spi {
     device: Rc<RefCell<Backend>>,
@@ -101,4 +102,8 @@
         self.device.borrow().spi1_set_cs_pin(true)?;
         result
     }
+
+    fn assert_cs(self: Rc<Self>) -> Result<AssertChipSelect> {
+        Err(TransportError::UnsupportedOperation.into())
+    }
 }
diff --git a/sw/host/opentitanlib/src/transport/errors.rs b/sw/host/opentitanlib/src/transport/errors.rs
index ba6422b..925f0b8 100644
--- a/sw/host/opentitanlib/src/transport/errors.rs
+++ b/sw/host/opentitanlib/src/transport/errors.rs
@@ -41,6 +41,8 @@
     InvalidStrappingName(String),
     #[error("Transport does not support the requested operation")]
     UnsupportedOperation,
+    #[error("Requested operation invalid at this time")]
+    InvalidOperation,
     #[error("Error communicating with FTDI: {0}")]
     FtdiError(String),
     #[error("Error communicating with debugger: {0}")]
diff --git a/sw/host/opentitanlib/src/transport/hyperdebug/spi.rs b/sw/host/opentitanlib/src/transport/hyperdebug/spi.rs
index f7136a9..e88e028 100644
--- a/sw/host/opentitanlib/src/transport/hyperdebug/spi.rs
+++ b/sw/host/opentitanlib/src/transport/hyperdebug/spi.rs
@@ -4,11 +4,14 @@
 
 use anyhow::{ensure, Result};
 use rusb::{Direction, Recipient, RequestType};
+use std::cell::Cell;
 use std::mem::size_of;
 use std::rc::Rc;
 use zerocopy::{AsBytes, FromBytes};
 
-use crate::io::spi::{SpiError, Target, Transfer, TransferMode};
+use crate::io::spi::{
+    AssertChipSelect, SpiError, Target, TargetChipDeassert, Transfer, TransferMode,
+};
 use crate::transport::hyperdebug::{BulkInterface, Inner};
 use crate::transport::TransportError;
 
@@ -17,6 +20,7 @@
     interface: BulkInterface,
     _target_idx: u8,
     max_chunk_size: usize,
+    cs_asserted_count: Cell<u32>,
 }
 
 const USB_SPI_PKT_ID_CMD_GET_USB_SPI_CONFIG: u16 = 0;
@@ -26,6 +30,8 @@
 //const USB_SPI_PKT_ID_CMD_RESTART_RESPONSE: u16 = 4;
 const USB_SPI_PKT_ID_RSP_TRANSFER_START: u16 = 5;
 const USB_SPI_PKT_ID_RSP_TRANSFER_CONTINUE: u16 = 6;
+const USB_SPI_PKT_ID_CMD_CHIP_SELECT: u16 = 7;
+const USB_SPI_PKT_ID_RSP_CHIP_SELECT: u16 = 8;
 
 //const USB_SPI_REQ_DISABLE: u8 = 1;
 const USB_SPI_REQ_ENABLE: u8 = 0;
@@ -112,6 +118,36 @@
     }
 }
 
+#[derive(AsBytes, FromBytes, Debug)]
+#[repr(C)]
+struct CmdChipSelect {
+    packet_id: u16,
+    flags: u16,
+}
+impl CmdChipSelect {
+    fn new(assert_chip_select: bool) -> Self {
+        Self {
+            packet_id: USB_SPI_PKT_ID_CMD_CHIP_SELECT,
+            flags: if assert_chip_select { 1 } else { 0 },
+        }
+    }
+}
+
+#[derive(AsBytes, FromBytes, Debug, Default)]
+#[repr(C)]
+struct RspChipSelect {
+    packet_id: u16,
+    status_code: u16,
+}
+impl RspChipSelect {
+    fn new() -> Self {
+        Self {
+            packet_id: 0,
+            status_code: 0,
+        }
+    }
+}
+
 impl HyperdebugSpiTarget {
     pub fn open(inner: &Rc<Inner>, spi_interface: &BulkInterface, idx: u8) -> Result<Self> {
         let mut usb_handle = inner.usb_device.borrow_mut();
@@ -160,6 +196,7 @@
             interface: *spi_interface,
             _target_idx: idx,
             max_chunk_size: std::cmp::min(resp.max_write_chunk, resp.max_read_chunk) as usize,
+            cs_asserted_count: Cell::new(0),
         })
     }
 
@@ -235,6 +272,28 @@
         Ok(())
     }
 
+    /// Request assertion or deassertion of chip select
+    fn do_assert_cs(&self, assert: bool) -> Result<()> {
+        let req = CmdChipSelect::new(assert);
+        self.usb_write_bulk(&req.as_bytes())?;
+
+        let mut resp = RspChipSelect::new();
+        let bytecount = self.usb_read_bulk(&mut resp.as_bytes_mut())?;
+        ensure!(
+            bytecount >= 4,
+            TransportError::CommunicationError("Unrecognized reponse to CHIP_SELECT".to_string())
+        );
+        ensure!(
+            resp.packet_id == USB_SPI_PKT_ID_RSP_CHIP_SELECT,
+            TransportError::CommunicationError("Unrecognized reponse to CHIP_SELECT".to_string())
+        );
+        ensure!(
+            resp.status_code == 0,
+            TransportError::CommunicationError("SPI error".to_string())
+        );
+        Ok(())
+    }
+
     /// Send one USB packet.
     fn usb_write_bulk(&self, buf: &[u8]) -> Result<()> {
         self.inner
@@ -346,4 +405,28 @@
         }
         Ok(())
     }
+
+    fn assert_cs(self: Rc<Self>) -> Result<AssertChipSelect> {
+        {
+            let cs_asserted_count = self.cs_asserted_count.get();
+            if cs_asserted_count == 0 {
+                self.do_assert_cs(true)?;
+            }
+            self.cs_asserted_count.set(cs_asserted_count + 1);
+        }
+        Ok(AssertChipSelect::new(self))
+    }
+}
+
+impl TargetChipDeassert for HyperdebugSpiTarget {
+    fn deassert_cs(&self) {
+        let cs_asserted_count = self.cs_asserted_count.get();
+        if cs_asserted_count - 1 == 0 {
+            // We cannot propagate errors through `Drop::drop()`, so panic on any error.  (Logging
+            // would be another option.)
+            self.do_assert_cs(false)
+                .expect("Error while deasserting CS");
+        }
+        self.cs_asserted_count.set(cs_asserted_count - 1);
+    }
 }
diff --git a/sw/host/opentitanlib/src/transport/proxy/spi.rs b/sw/host/opentitanlib/src/transport/proxy/spi.rs
index d8ecd94..4ac3248 100644
--- a/sw/host/opentitanlib/src/transport/proxy/spi.rs
+++ b/sw/host/opentitanlib/src/transport/proxy/spi.rs
@@ -6,7 +6,9 @@
 use std::rc::Rc;
 
 use super::ProxyError;
-use crate::io::spi::{SpiError, Target, Transfer, TransferMode};
+use crate::io::spi::{
+    AssertChipSelect, SpiError, Target, TargetChipDeassert, Transfer, TransferMode,
+};
 use crate::proxy::protocol::{
     Request, Response, SpiRequest, SpiResponse, SpiTransferRequest, SpiTransferResponse,
 };
@@ -143,4 +145,23 @@
             _ => bail!(ProxyError::UnexpectedReply()),
         }
     }
+
+    fn assert_cs(self: Rc<Self>) -> Result<AssertChipSelect> {
+        match self.execute_command(SpiRequest::AssertChipSelect)? {
+            SpiResponse::AssertChipSelect => Ok(AssertChipSelect::new(self)),
+            _ => bail!(ProxyError::UnexpectedReply()),
+        }
+    }
+}
+
+impl TargetChipDeassert for ProxySpi {
+    fn deassert_cs(&self) {
+        match self
+            .execute_command(SpiRequest::DeassertChipSelect)
+            .expect("Error deactivating chip select")
+        {
+            SpiResponse::DeassertChipSelect => (),
+            _ => panic!("Error deactivating chip select"),
+        }
+    }
 }
diff --git a/sw/host/opentitanlib/src/transport/ti50emulator/spi.rs b/sw/host/opentitanlib/src/transport/ti50emulator/spi.rs
index 56ff9cc..f225d4e 100644
--- a/sw/host/opentitanlib/src/transport/ti50emulator/spi.rs
+++ b/sw/host/opentitanlib/src/transport/ti50emulator/spi.rs
@@ -3,8 +3,9 @@
 // SPDX-License-Identifier: Apache-2.0
 
 use anyhow::Result;
+use std::rc::Rc;
 
-use crate::io::spi::{SpiError, Target, Transfer, TransferMode};
+use crate::io::spi::{AssertChipSelect, SpiError, Target, Transfer, TransferMode};
 use crate::transport::TransportError;
 use crate::util::voltage::Voltage;
 
@@ -59,4 +60,10 @@
     fn run_transaction(&self, _transaction: &mut [Transfer]) -> Result<()> {
         Err(TransportError::UnsupportedOperation.into())
     }
+
+    /// Assert the CS signal.  Uses reference counting, will be deasserted when each and every
+    /// returned `AssertChipSelect` object have gone out of scope.
+    fn assert_cs(self: Rc<Self>) -> Result<AssertChipSelect> {
+        Err(TransportError::UnsupportedOperation.into())
+    }
 }
diff --git a/sw/host/opentitanlib/src/transport/ultradebug/spi.rs b/sw/host/opentitanlib/src/transport/ultradebug/spi.rs
index 1a0cd5f..96e3ea6 100644
--- a/sw/host/opentitanlib/src/transport/ultradebug/spi.rs
+++ b/sw/host/opentitanlib/src/transport/ultradebug/spi.rs
@@ -7,12 +7,15 @@
 use std::cell::RefCell;
 use std::rc::Rc;
 
-use crate::io::spi::{ClockPolarity, SpiError, Target, Transfer, TransferMode};
+use crate::io::spi::{
+    AssertChipSelect, ClockPolarity, SpiError, Target, TargetChipDeassert, Transfer, TransferMode,
+};
 use crate::transport::ultradebug::mpsse;
 use crate::transport::ultradebug::Ultradebug;
 
 struct Inner {
     mode: TransferMode,
+    cs_asserted_count: u32,
 }
 
 /// Represents the Ultradebug SPI device.
@@ -26,6 +29,7 @@
     pub const PIN_MOSI: u8 = 1;
     pub const PIN_MISO: u8 = 2;
     pub const PIN_CHIP_SELECT: u8 = 3;
+    pub const MASK_CHIP_SELECT: u8 = 1u8 << Self::PIN_CHIP_SELECT;
     pub const PIN_SPI_ZB: u8 = 4;
     pub fn open(ultradebug: &Ultradebug) -> Result<Self> {
         let mpsse = ultradebug.mpsse(ftdi::Interface::B)?;
@@ -42,9 +46,26 @@
             device: mpsse,
             inner: RefCell::new(Inner {
                 mode: TransferMode::Mode0,
+                cs_asserted_count: 0,
             }),
         })
     }
+
+    fn do_assert_cs(&self, assert: bool) -> Result<()> {
+        let device = self.device.borrow();
+        // Assert or deassert CS#
+        device
+            .execute(&mut [mpsse::Command::SetLowGpio(
+                device.gpio_direction,
+                if assert {
+                    device.gpio_value & !Self::MASK_CHIP_SELECT
+                } else {
+                    device.gpio_value | Self::MASK_CHIP_SELECT
+                },
+            )])
+            .context("FTDI error")?;
+        Ok(())
+    }
 }
 
 impl Target for UltradebugSpi {
@@ -96,12 +117,14 @@
 
         let mut command = Vec::new();
         let device = self.device.borrow();
-        let chip_select = 1u8 << UltradebugSpi::PIN_CHIP_SELECT;
-        // Assert CS# (drive low).
-        command.push(mpsse::Command::SetLowGpio(
-            device.gpio_direction,
-            device.gpio_value & !chip_select,
-        ));
+        let cs_not_already_asserted = self.inner.borrow().cs_asserted_count == 0;
+        if cs_not_already_asserted {
+            // Assert CS# (drive low).
+            command.push(mpsse::Command::SetLowGpio(
+                device.gpio_direction,
+                device.gpio_value & !Self::MASK_CHIP_SELECT,
+            ));
+        }
         // Translate SPI Read/Write Transactions into MPSSE Commands.
         for transfer in transaction.iter_mut() {
             command.push(match transfer {
@@ -137,12 +160,38 @@
                 ),
             });
         }
-        // Release CS# (allow to float high).
-        command.push(mpsse::Command::SetLowGpio(
-            device.gpio_direction,
-            device.gpio_value | chip_select,
-        ));
+        if cs_not_already_asserted {
+            // Release CS# (allow to float high).
+            command.push(mpsse::Command::SetLowGpio(
+                device.gpio_direction,
+                device.gpio_value | Self::MASK_CHIP_SELECT,
+            ));
+        }
         device.execute(&mut command).context("FTDI error")?;
         Ok(())
     }
+
+    fn assert_cs(self: Rc<Self>) -> Result<AssertChipSelect> {
+        {
+            let mut inner = self.inner.borrow_mut();
+            if inner.cs_asserted_count == 0 {
+                self.do_assert_cs(true)?;
+            }
+            inner.cs_asserted_count += 1;
+        }
+        Ok(AssertChipSelect::new(self))
+    }
+}
+
+impl TargetChipDeassert for UltradebugSpi {
+    fn deassert_cs(&self) {
+        let mut inner = self.inner.borrow_mut();
+        inner.cs_asserted_count -= 1;
+        if inner.cs_asserted_count == 0 {
+            // We cannot propagate errors through `Drop::drop()`, so panic on any error.  (Logging
+            // would be another option.)
+            self.do_assert_cs(false)
+                .expect("Error while deasserting CS");
+        }
+    }
 }