[opentitantool] Add get_socket() to refactor some duplicate code in ti50emulator

Signed-off-by: Alphan Ulusoy <alphan@google.com>
diff --git a/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs b/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs
index 83a2d80..a8a3ba7 100644
--- a/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs
+++ b/sw/host/opentitanlib/src/transport/ti50emulator/uart.rs
@@ -4,7 +4,7 @@
 
 use anyhow::{bail, Context, Result};
 
-use std::cell::{Cell, RefCell};
+use std::cell::{Cell, RefCell, RefMut};
 use std::io::{Read, Write};
 use std::os::unix::net::UnixStream;
 use std::path::PathBuf;
@@ -59,6 +59,17 @@
         };
         return Ok(valid);
     }
+
+    pub fn get_socket(&self) -> Result<RefMut<UnixStream>> {
+        if self.check_state()? {
+            self.reconnect()?;
+            // Socket should be valid as long as the subprocess is running.
+            return Ok(RefMut::map(self.socket.borrow_mut(), |socket| {
+                socket.as_mut().unwrap()
+            }));
+        }
+        bail!(UartError::GenericError("Invalid socket".to_string()));
+    }
 }
 
 /// A trait which represents a UART.
@@ -79,42 +90,23 @@
     /// Reads UART receive data into `buf`, returning the number of bytes read.
     /// This function _may_ block.
     fn read(&self, buf: &mut [u8]) -> Result<usize> {
-        if self.check_state()? {
-            self.reconnect()?;
-            if let Some(ref mut fd) = *self.socket.borrow_mut() {
-                fd.set_read_timeout(None)?;
-                return Ok(fd.read(buf)?);
-            }
-            bail!(UartError::GenericError("Invalid socket".to_string()));
-        };
-        Ok(0)
+        let mut socket = self.get_socket()?;
+        socket.set_read_timeout(None)?;
+        return Ok(socket.read(buf)?);
     }
 
     /// Reads UART receive data into `buf`, returning the number of bytes read.
     /// The `timeout` may be used to specify a duration to wait for data.
     /// If timeout expires without any data arriving `Ok(0)` will be returned, never `Err(_)`.
     fn read_timeout(&self, buf: &mut [u8], timeout: Duration) -> Result<usize> {
-        if self.check_state()? {
-            self.reconnect()?;
-            if let Some(ref mut fd) = *self.socket.borrow_mut() {
-                fd.set_read_timeout(Some(timeout))?;
-                return Ok(fd.read(buf).context("UART read error")?);
-            }
-            bail!(UartError::GenericError("Invalid socket".to_string()));
-        };
-        Ok(0)
+        let mut socket = self.get_socket()?;
+        socket.set_read_timeout(Some(timeout))?;
+        return Ok(socket.read(buf).context("UART read error")?);
     }
 
     /// Writes data from `buf` to the UART.
     fn write(&self, buf: &[u8]) -> Result<()> {
-        if self.check_state()? {
-            self.reconnect()?;
-            if let Some(ref mut fd) = *self.socket.borrow_mut() {
-                fd.write(buf).context("UART read error")?;
-                return Ok(());
-            }
-            bail!(UartError::GenericError("Invalid socket".to_string()));
-        };
-        Ok(())
+        self.get_socket()?.write(buf).context("UART read error")?;
+        return Ok(());
     }
 }