blob: 1437ebee23f90e0eaf616561eccaac24503ae5af [file] [log] [blame]
// Copyright lowRISC contributors.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
use anyhow::{bail, Result};
use mio::event::Event;
use mio::net::TcpListener;
use mio::net::TcpStream;
use mio::{Events, Interest, Poll, Token};
use mio_signals::{Signal, SignalSet, Signals};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;
use std::io::{ErrorKind, Read, Write};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::CommandHandler;
/// Number of bytes by which the receive buffer is grown before each socket read attempt.
const BUFFER_SIZE: usize = 8 * 1024;
/// Byte that terminates every JSON message on the wire, in both directions.
const EOL_CODE: u8 = 0x0A;
/// Hands out a fresh `Token` from a process-wide counter, so that every source registered
/// with a `Poll` instance (listener, signal handler, connections) gets a distinct identifier
/// (unique until the counter wraps around).
fn get_next_token() -> Token {
    static NEXT_TOKEN_ID: AtomicUsize = AtomicUsize::new(0);
    let id = NEXT_TOKEN_ID.fetch_add(1, Ordering::Relaxed);
    Token(id)
}
/// A TCP server that speaks newline-delimited JSON.
///
/// It accepts any number of concurrent connections on a listening socket. Each
/// newline-terminated line received on a connection is decoded as a `Msg` and handed to the
/// `CommandHandler`. The serialized response is queued and sent back as socket flow control
/// permits. The server never interprets `Msg` itself, so it is not tied to any particular
/// protocol.
pub struct JsonSocketServer<Msg, T>
where
    Msg: DeserializeOwned + Serialize,
    T: CommandHandler<Msg>,
{
    /// Executes each decoded request and produces the response to send back.
    command_handler: T,
    /// Readiness multiplexer for the listener, the signal handler and every connection.
    poll: Poll,
    /// Listening socket from which new connections are accepted.
    socket: TcpListener,
    /// Token under which `socket` is registered with `poll`.
    socket_token: Token,
    /// Source of the terminate/interrupt signals the server reacts to.
    signals: Signals,
    /// Token under which `signals` is registered with `poll`.
    signal_token: Token,
    /// Live connections, keyed by the token each one is registered with.
    connection_map: HashMap<Token, Connection>,
    /// Set when a terminate or interrupt signal arrives; makes `run_loop` return.
    exit_requested: bool,
    /// Ties the otherwise-unused `Msg` type parameter to this struct.
    phantom: PhantomData<Msg>,
}
impl<Msg: DeserializeOwned + Serialize, T: CommandHandler<Msg>> JsonSocketServer<Msg, T> {
pub fn new(command_handler: T, mut socket: TcpListener) -> Result<Self> {
let poll = Poll::new()?;
let socket_token = get_next_token();
poll.registry()
.register(&mut socket, socket_token, Interest::READABLE)?;
// Create a `Signals` instance that will catch given set of signals for us.
let signals: SignalSet = Signal::Terminate | Signal::Interrupt;
let mut signals = Signals::new(signals)?;
// And register it with our `Poll` instance.
let signal_token = get_next_token();
poll.registry()
.register(&mut signals, signal_token, Interest::READABLE)?;
Ok(Self {
command_handler,
poll,
socket,
socket_token,
signals,
signal_token,
connection_map: HashMap::new(),
exit_requested: false,
phantom: PhantomData,
})
}
pub fn run_loop(&mut self) -> Result<()> {
let mut events = Events::with_capacity(1024);
while !self.exit_requested {
match self.poll.poll(&mut events, None) {
Ok(()) => (),
Err(err) if err.kind() == ErrorKind::Interrupted => {
continue;
}
Err(err) => bail!("poll: {}", err),
}
for event in events.iter() {
if event.token() == self.socket_token {
self.process_new_connection()?;
} else if event.token() == self.signal_token {
self.process_signals()?;
} else {
match self.process_connection(event) {
Ok(shutdown) => {
if shutdown {
self.shutdown_connection(event)?;
}
}
Err(e) => {
log::warn!("Connection {:#X} error: {}", event.token().0, e,);
self.shutdown_connection(event)?;
}
}
}
}
}
Ok(())
}
/// Accept new socket connections, creating new Connection objects.
fn process_new_connection(&mut self) -> Result<()> {
loop {
match self.socket.accept() {
Ok((mut conn_socket, _addres)) => {
let token = get_next_token();
log::info!("New connection id:{:#X}", token.0);
match self.connection_map.entry(token) {
Vacant(entry) => {
self.poll.registry().register(
&mut conn_socket,
token,
Interest::READABLE | Interest::WRITABLE,
)?;
entry.insert(Connection::new(conn_socket));
}
Occupied(_) => {
panic!("JsonSocketServer error: token colision");
}
};
}
Err(err) if err.kind() == ErrorKind::WouldBlock => {
// No more connections ready to accept (or spurious poll event).
return Ok(());
}
Err(err) => bail!("Error accepting TCP connection: {}", err),
}
}
}
fn process_signals(&mut self) -> Result<()> {
loop {
match self.signals.receive()? {
Some(Signal::Interrupt) => {
log::info!("Got interrupt signal");
self.exit_requested = true;
}
Some(Signal::Terminate) => {
log::info!("Got terminate signal");
self.exit_requested = true;
}
Some(signal) => {
log::info!("Got unexpected signal: {:?}", signal);
}
None => return Ok(()),
}
}
}
/// Read and write as much as possible from one particular socket connection.
fn process_connection(&mut self, event: &Event) -> Result<bool> {
match self.connection_map.get_mut(&event.token()) {
Some(conn) => {
if event.is_writable() {
conn.write()?;
}
if event.is_readable() {
conn.read()?;
Self::process_any_requests(conn, &mut self.command_handler)?;
}
// Return whether this connection object should be dropped.
Ok((conn.rx_eof && (conn.tx_buf.is_empty())) || conn.broken)
}
None => bail!("Connection don't exist token:{:#X}", event.token().0),
}
}
/// Close a socket connection and remove it from the poll list.
fn shutdown_connection(&mut self, event: &Event) -> Result<()> {
log::info!("Closing connection id:{:#X}", event.token().0);
let mut conn = self
.connection_map
.remove(&event.token())
.expect("Missing connection this should never happend!!!");
self.poll.registry().deregister(&mut conn.socket)?;
// As `conn` runs out of scope here, its `drop()` method will close the OS handle, which
// in turn causes TCP/IP connection shutdown to be signalled to the remote end.
Ok(())
}
/// Check if the buffer contains at least one full JSON request. If so, remove it from the
/// buffer, decode and return it.
fn get_complete_request(conn: &mut Connection) -> Result<Option<Msg>> {
if let Some(n) = conn.rx_buf.iter().position(|c| *c == EOL_CODE) {
let res = serde_json::from_slice::<Msg>(&conn.rx_buf[..n])?;
if n + 1 < conn.rx_buf.len() {
// Shuffling bytes around in a Vec is expensive, but realistically, as the
// clients would be waiting for response to each request before sending the next
// request, this code will rarely if ever execute.
conn.rx_buf.rotate_left(n + 1);
}
conn.rx_buf.resize(conn.rx_buf.len() - n - 1, 0);
return Ok(Some(res));
}
Ok(None)
}
// Look for any completely received requests in the rx_buf, and handle them one by one.
fn process_any_requests(conn: &mut Connection, command_handler: &mut T) -> Result<()> {
while let Some(request) = Self::get_complete_request(conn)? {
// One complete request received, execute it.
let resp = command_handler.execute_cmd(&request)?;
// Encode response into tx_buf.
serde_json::to_writer(&mut conn.tx_buf, &resp)?;
conn.tx_buf.push(EOL_CODE);
// Transmit as much as possible without blocking, leaving any remnant in
// tx_buf. poll() will tell us when more can be written.
conn.write()?;
}
Ok(())
}
}
/// Represents one connection with a remote OpenTitan tool invocation.
struct Connection {
    socket: TcpStream,
    /// Outgoing data waiting to be written when the socket permits.
    tx_buf: Vec<u8>,
    /// Data received from the remote end, but not yet decoded into `Msg`.
    rx_buf: Vec<u8>,
    /// The remote end indicated end-of-stream. After processing any remaining data in
    /// `rx_buf`, this Connection should be gracefully shut down and dropped.
    rx_eof: bool,
    /// Some error happened during writing or reading from the socket, so processing cannot
    /// meaningfully continue, and the connection should be dropped as soon as possible.
    broken: bool,
}

impl Connection {
    /// Wraps an accepted socket, starting with empty buffers and no end-of-stream or error.
    fn new(soc: TcpStream) -> Self {
        Self {
            socket: soc,
            tx_buf: Vec::new(),
            rx_buf: Vec::new(),
            rx_eof: false,
            broken: false,
        }
    }

    /// Appends to `rx_buf` as much data as the socket delivers without blocking.
    ///
    /// How the read loop ends:
    /// - a zero-byte read sets `rx_eof`;
    /// - `WouldBlock` ends the loop without setting any flag;
    /// - `Interrupted` is retried;
    /// - any other I/O error sets `broken`.
    ///
    /// Failures are reported through those flags; this method always returns `Ok`.
    fn read(&mut self) -> Result<()> {
        let mut rx_buf_len: usize = self.rx_buf.len();
        loop {
            // Make sure there is BUFFER_SIZE bytes of scratch space past the received data.
            self.rx_buf.resize(rx_buf_len + BUFFER_SIZE, 0);
            match self.socket.read(&mut self.rx_buf[rx_buf_len..]) {
                Ok(0) => {
                    // A zero-byte read into a non-empty buffer signals end-of-stream.
                    self.rx_eof = true;
                    break;
                }
                Ok(n) => rx_buf_len += n,
                // Interrupted before any data was read; std documents this as retryable.
                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                Err(err) => {
                    if err.kind() != ErrorKind::WouldBlock {
                        self.broken = true;
                    }
                    break; // Break out of loop, also on expected WouldBlock
                }
            }
        }
        // Drop the unused tail of the scratch space.
        self.rx_buf.truncate(rx_buf_len);
        Ok(())
    }

    /// Transmits as much of `tx_buf` as the socket accepts without blocking.
    ///
    /// Written bytes are removed from the front of `tx_buf`, and whatever remains stays queued.
    /// `Interrupted` is retried. `WouldBlock` stops without setting any flag. Any other error,
    /// or a zero-byte write, sets `broken`. This method always returns `Ok`.
    fn write(&mut self) -> Result<()> {
        while !self.tx_buf.is_empty() {
            match self.socket.write(&self.tx_buf) {
                Ok(0) => {
                    // A zero-byte write means the socket can no longer accept data. Retrying
                    // would spin forever because `tx_buf` never shrinks, so give up on it.
                    self.broken = true;
                    break;
                }
                Ok(n) => {
                    // Partial writes are expected to be rare (clients wait for each response
                    // before sending the next request), so shifting the remnant is acceptable.
                    self.tx_buf.drain(..n);
                }
                // Interrupted before any data was written; std documents this as retryable.
                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                Err(err) => {
                    if err.kind() != ErrorKind::WouldBlock {
                        self.broken = true;
                    }
                    break;
                }
            }
        }
        Ok(())
    }
}