[opentitantool] Add get_socket() to refactor some duplicate code in ti50emulator
Signed-off-by: Alphan Ulusoy <alphan@google.com>
diff --git a/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs b/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs
index 83a2d80..a8a3ba7 100644
--- a/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs
+++ b/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs
@@ -4,7 +4,7 @@
use anyhow::{bail, Context, Result};
-use std::cell::{Cell, RefCell};
+use std::cell::{Cell, RefCell, RefMut};
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
@@ -59,6 +59,17 @@
};
return Ok(valid);
}
+
+ pub fn get_socket(&self) -> Result<RefMut<UnixStream>> {
+ if self.check_state()? {
+ self.reconnect()?;
+ // Socket should be valid as long as the subprocess is running.
+ return Ok(RefMut::map(self.socket.borrow_mut(), |socket| {
+ socket.as_mut().unwrap()
+ }));
+ }
+ bail!(UartError::GenericError("Invalid socket".to_string()));
+ }
}
/// A trait which represents a UART.
@@ -79,42 +90,23 @@
/// Reads UART receive data into `buf`, returning the number of bytes read.
/// This function _may_ block.
fn read(&self, buf: &mut [u8]) -> Result<usize> {
- if self.check_state()? {
- self.reconnect()?;
- if let Some(ref mut fd) = *self.socket.borrow_mut() {
- fd.set_read_timeout(None)?;
- return Ok(fd.read(buf)?);
- }
- bail!(UartError::GenericError("Invalid socket".to_string()));
- };
- Ok(0)
+ let mut socket = self.get_socket()?;
+ socket.set_read_timeout(None)?;
+ return Ok(socket.read(buf)?);
}
/// Reads UART receive data into `buf`, returning the number of bytes read.
/// The `timeout` may be used to specify a duration to wait for data.
/// If timeout expires without any data arriving `Ok(0)` will be returned, never `Err(_)`.
fn read_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<usize> {
- if self.check_state()? {
- self.reconnect()?;
- if let Some(ref mut fd) = *self.socket.borrow_mut() {
- fd.set_read_timeout(Some(timeout))?;
- return Ok(fd.read(buf).context("UART read error")?);
- }
- bail!(UartError::GenericError("Invalid socket".to_string()));
- };
- Ok(0)
+ let mut socket = self.get_socket()?;
+ socket.set_read_timeout(Some(timeout))?;
+ return Ok(socket.read(buf).context("UART read error")?);
}
/// Writes data from `buf` to the UART.
fn write(&self, buf: &[u8]) -> Result<()> {
- if self.check_state()? {
- self.reconnect()?;
- if let Some(ref mut fd) = *self.socket.borrow_mut() {
- fd.write(buf).context("UART read error")?;
- return Ok(());
- }
- bail!(UartError::GenericError("Invalid socket".to_string()));
- };
- Ok(())
+ self.get_socket()?.write(buf).context("UART read error")?;
+ return Ok(());
}
}