blob: c91d523be0a7c639b22c483c39ab67f50277d59a [file] [log] [blame]
// Copyright 2019 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::fmt;
use std::io::{self, Read, Write};
use std::result;
use crate::stream::StreamSocket;
// Upper bound on how many bytes a single forwarding step moves. The value is arbitrary:
// crosvm does not yet negotiate VIRTIO_NET_F_MTU, so there is no need to size for massive
// 65535-byte frames. 8 KiB also keeps the per-call stack buffer in `forward` modest.
const MAX_FRAME_SIZE: usize = 8 * 1024;
/// Failures a `ForwarderSession` can report while relaying traffic.
///
/// Every variant carries the `io::Error` produced by the stream operation that failed.
#[remain::sorted]
#[derive(Debug)]
pub enum ForwarderError {
    /// Reading from the source stream returned an `io::Error`.
    ReadFromStream(io::Error),
    /// Shutting down the write half of the destination stream returned an `io::Error`.
    ShutDownStream(io::Error),
    /// Writing to the destination stream returned an `io::Error`.
    WriteToStream(io::Error),
}

/// Module-local shorthand for results whose error type is `ForwarderError`.
type Result<T> = result::Result<T, ForwarderError>;
/// Human-readable description of a forwarding failure, including the underlying I/O error.
impl fmt::Display for ForwarderError {
    #[remain::check]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use self::ForwarderError::*;
        #[remain::sorted]
        match self {
            ReadFromStream(e) => write!(f, "failed to read from stream: {}", e),
            ShutDownStream(e) => write!(f, "failed to shut down stream: {}", e),
            WriteToStream(e) => write!(f, "failed to write to stream: {}", e),
        }
    }
}

/// Makes `ForwarderError` a standard error type, so callers can box it as `dyn Error` or
/// propagate it with `?` into generic error types.
///
/// The wrapped `io::Error` is already part of the `Display` output above. `source()` is
/// therefore left at its default (`None`), so error-chain reporters do not print the cause twice.
impl std::error::Error for ForwarderError {}
/// A ForwarderSession owns stream sockets that it forwards traffic between.
///
/// Traffic read from `local` is written to `remote` by `forward_from_local`, and traffic read
/// from `remote` is written to `local` by `forward_from_remote`.
pub struct ForwarderSession {
    // Local endpoint: read by `forward_from_local`, written by `forward_from_remote`.
    local: StreamSocket,
    // Remote endpoint: read by `forward_from_remote`, written by `forward_from_local`.
    remote: StreamSocket,
}
fn forward(from_stream: &mut StreamSocket, to_stream: &mut StreamSocket) -> Result<bool> {
let mut buf = [0u8; MAX_FRAME_SIZE];
let count = from_stream
.read(&mut buf)
.map_err(ForwarderError::ReadFromStream)?;
if count == 0 {
to_stream
.shut_down_write()
.map_err(ForwarderError::ShutDownStream)?;
return Ok(true);
}
to_stream
.write_all(&buf[..count])
.map_err(ForwarderError::WriteToStream)?;
Ok(false)
}
impl ForwarderSession {
    /// Builds a session that relays traffic between `local` and `remote`.
    pub fn new(local: StreamSocket, remote: StreamSocket) -> Self {
        Self { local, remote }
    }

    /// Runs one forwarding step in the local -> remote direction.
    ///
    /// Yields `Ok(true)` once the local socket has hit EOF, at which point the remote socket
    /// has been shut down for writing.
    pub fn forward_from_local(&mut self) -> Result<bool> {
        let ForwarderSession { local, remote } = self;
        forward(local, remote)
    }

    /// Runs one forwarding step in the remote -> local direction.
    ///
    /// Yields `Ok(true)` once the remote socket has hit EOF, at which point the local socket
    /// has been shut down for writing.
    pub fn forward_from_remote(&mut self) -> Result<bool> {
        let ForwarderSession { local, remote } = self;
        forward(remote, local)
    }

    /// Borrows the local-side stream socket.
    pub fn local_stream(&self) -> &StreamSocket {
        &self.local
    }

    /// Borrows the remote-side stream socket.
    pub fn remote_stream(&self) -> &StreamSocket {
        &self.remote
    }

    /// Reports whether both sockets are fully shut down, meaning the session may be dropped.
    ///
    /// The local socket is checked first; the remote one is only queried if the local one is
    /// already shut down.
    pub fn is_shut_down(&self) -> bool {
        [&self.local, &self.remote]
            .iter()
            .all(|stream| stream.is_shut_down())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::Shutdown;
    use std::os::unix::net::UnixStream;

    /// End-to-end check of a `ForwarderSession` bridging two Unix socket pairs: data flows in
    /// both directions, half-closes are propagated, and the session only reports itself shut
    /// down once both directions have reached EOF.
    #[test]
    fn forward_unix() {
        // Local streams.
        let (mut london, folkestone) = UnixStream::pair().unwrap();
        // Remote streams.
        let (coquelles, mut paris) = UnixStream::pair().unwrap();
        // Connect the local and remote sockets via the chunnel.
        let mut forwarder = ForwarderSession::new(folkestone.into(), coquelles.into());
        // Scratch buffer for reads at either end; large enough for both messages.
        let mut buf = [0u8; 8];

        // Put some traffic in from London. Use write_all (not write) so a short write cannot
        // go unnoticed and the whole message is queued before forwarding.
        let greeting = b"hello";
        london.write_all(greeting).unwrap();
        // Expect forwarding from the local end not to have reached EOF.
        assert!(!forwarder.forward_from_local().unwrap());
        let count = paris.read(&mut buf).unwrap();
        assert_eq!(greeting.len(), count);
        assert_eq!(greeting, &buf[..count]);

        // Shut the local socket down. The forwarder should detect this and perform a shutdown,
        // which will manifest as an EOF when reading.
        london.shutdown(Shutdown::Write).unwrap();
        assert!(forwarder.forward_from_local().unwrap());
        assert_eq!(paris.read(&mut buf).unwrap(), 0);
        // Don't consider the forwarder shut down until both ends are.
        assert!(!forwarder.is_shut_down());

        // Forward traffic from the remote end.
        let reply = b"bonjour";
        paris.write_all(reply).unwrap();
        // Expect forwarding from the remote end not to have reached EOF.
        assert!(!forwarder.forward_from_remote().unwrap());
        let count = london.read(&mut buf).unwrap();
        assert_eq!(reply.len(), count);
        assert_eq!(reply, &buf[..count]);

        // Shut the remote socket down. The forwarder should detect this and perform a shutdown,
        // which will manifest as an EOF when reading.
        paris.shutdown(Shutdown::Write).unwrap();
        assert!(forwarder.forward_from_remote().unwrap());
        assert_eq!(london.read(&mut buf).unwrap(), 0);
        // The forwarder should now be considered shut down.
        assert!(forwarder.is_shut_down());
    }
}