| // Copyright 2019 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| use std::fmt; |
| use std::io; |
| use std::net::TcpStream; |
| use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; |
| use std::os::unix::net::UnixStream; |
| use std::result; |
| |
| use libc::{self, c_void, shutdown, EPIPE, SHUT_WR}; |
| |
| use libchromeos::vsock::VsockStream; |
| |
/// StreamSocket provides a generic abstraction around any connection-oriented stream socket
/// (vsock, unix, or TCP).
///
/// The `StreamSocket` owns its file descriptor and closes it when dropped. Writes to the
/// socket can also be shut down manually with `shut_down_write`, which uses
/// `shutdown(2)` with `SHUT_WR` and so leaves the read half open.
#[derive(Debug)]
pub struct StreamSocket {
    /// The owned socket file descriptor.
    fd: RawFd,
    /// Set once writes have been shut down, either explicitly via `shut_down_write` or
    /// because a write failed with `EPIPE`.
    shut_down: bool,
}
| |
| impl StreamSocket { |
| /// Connects to the given socket address. Supported socket types are vsock, unix, and TCP. |
| pub fn connect(sockaddr: &str) -> result::Result<StreamSocket, StreamSocketError> { |
| const UNIX_PREFIX: &str = "unix:"; |
| const VSOCK_PREFIX: &str = "vsock:"; |
| |
| if sockaddr.starts_with(VSOCK_PREFIX) { |
| let vsock_stream = VsockStream::connect(sockaddr) |
| .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?; |
| Ok(vsock_stream.into()) |
| } else if sockaddr.starts_with(UNIX_PREFIX) { |
| let (_prefix, sock_path) = sockaddr.split_at(UNIX_PREFIX.len()); |
| let unix_stream = UnixStream::connect(sock_path) |
| .map_err(|e| StreamSocketError::ConnectUnix(sockaddr.to_string(), e))?; |
| Ok(unix_stream.into()) |
| } else { |
| // Assume this is a TCP stream. |
| let tcp_stream = TcpStream::connect(sockaddr) |
| .map_err(|e| StreamSocketError::ConnectTcp(sockaddr.to_string(), e))?; |
| Ok(tcp_stream.into()) |
| } |
| } |
| |
| /// Shuts down writes to the socket using shutdown(2). |
| pub fn shut_down_write(&mut self) -> io::Result<()> { |
| // Safe because no memory is modified and the return value is checked. |
| let ret = unsafe { shutdown(self.fd, SHUT_WR) }; |
| if ret < 0 { |
| return Err(io::Error::last_os_error()); |
| } |
| |
| self.shut_down = true; |
| Ok(()) |
| } |
| |
| /// Returns true if the socket has been shut down for writes, false otherwise. |
| pub fn is_shut_down(&self) -> bool { |
| self.shut_down |
| } |
| } |
| |
| impl io::Read for StreamSocket { |
| fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
| // Safe because this will only modify the contents of |buf| and we check the return value. |
| let ret = unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) }; |
| if ret < 0 { |
| return Err(io::Error::last_os_error()); |
| } |
| |
| Ok(ret as usize) |
| } |
| } |
| |
| impl io::Write for StreamSocket { |
| fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
| // Safe because this doesn't modify any memory and we check the return value. |
| let ret = unsafe { libc::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) }; |
| if ret < 0 { |
| // If a write causes EPIPE then the socket is shut down for writes. |
| let err = io::Error::last_os_error(); |
| if let Some(errno) = err.raw_os_error() { |
| if errno == EPIPE { |
| self.shut_down = true |
| } |
| } |
| |
| return Err(err); |
| } |
| |
| Ok(ret as usize) |
| } |
| |
| fn flush(&mut self) -> io::Result<()> { |
| // No buffered data so nothing to do. |
| Ok(()) |
| } |
| } |
| |
| impl AsRawFd for StreamSocket { |
| fn as_raw_fd(&self) -> RawFd { |
| self.fd |
| } |
| } |
| |
| impl From<TcpStream> for StreamSocket { |
| fn from(stream: TcpStream) -> Self { |
| StreamSocket { |
| fd: stream.into_raw_fd(), |
| shut_down: false, |
| } |
| } |
| } |
| |
| impl From<UnixStream> for StreamSocket { |
| fn from(stream: UnixStream) -> Self { |
| StreamSocket { |
| fd: stream.into_raw_fd(), |
| shut_down: false, |
| } |
| } |
| } |
| |
| impl From<VsockStream> for StreamSocket { |
| fn from(stream: VsockStream) -> Self { |
| StreamSocket { |
| fd: stream.into_raw_fd(), |
| shut_down: false, |
| } |
| } |
| } |
| |
| impl FromRawFd for StreamSocket { |
| unsafe fn from_raw_fd(fd: RawFd) -> Self { |
| StreamSocket { |
| fd, |
| shut_down: false, |
| } |
| } |
| } |
| |
| impl Drop for StreamSocket { |
| fn drop(&mut self) { |
| // Safe because this doesn't modify any memory and we are the only |
| // owner of the file descriptor. |
| unsafe { libc::close(self.fd) }; |
| } |
| } |
| |
| /// Error enums for StreamSocket. |
| #[remain::sorted] |
| #[derive(Debug)] |
| pub enum StreamSocketError { |
| ConnectTcp(String, io::Error), |
| ConnectUnix(String, io::Error), |
| ConnectVsock(String, io::Error), |
| } |
| |
| impl fmt::Display for StreamSocketError { |
| #[remain::check] |
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| use self::StreamSocketError::*; |
| |
| #[remain::sorted] |
| match self { |
| ConnectTcp(sockaddr, e) => { |
| write!(f, "failed to connect to TCP sockaddr {}: {}", sockaddr, e) |
| } |
| ConnectUnix(sockaddr, e) => { |
| write!(f, "failed to connect to unix sockaddr {}: {}", sockaddr, e) |
| } |
| ConnectVsock(sockaddr, e) => { |
| write!(f, "failed to connect to vsock sockaddr {}: {}", sockaddr, e) |
| } |
| } |
| } |
| } |
| |
// Unit tests for StreamSocket: connecting over TCP/unix, write shutdown, and EOF on read.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::os::unix::net::{UnixListener, UnixStream};
    use tempfile::TempDir;

    #[test]
    fn sock_connect_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let sockaddr = format!("127.0.0.1:{}", listener.local_addr().unwrap().port());

        let _stream = StreamSocket::connect(&sockaddr).unwrap();
    }

    #[test]
    fn sock_connect_unix() {
        let tempdir = TempDir::new().unwrap();
        let path = tempdir.path().join("test.sock");
        let _listener = UnixListener::bind(&path).unwrap();

        let unix_addr = format!("unix:{}", path.to_str().unwrap());
        let _stream = StreamSocket::connect(&unix_addr).unwrap();
    }

    #[test]
    fn invalid_sockaddr() {
        assert!(StreamSocket::connect("this is not a valid sockaddr").is_err());
    }

    #[test]
    fn shut_down_write() {
        let (unix_stream, _peer) = UnixStream::pair().unwrap();
        let mut stream: StreamSocket = unix_stream.into();
        assert!(!stream.is_shut_down());

        // write_all checks that every byte was written. A bare write().unwrap() would
        // silently ignore short writes (clippy::unused_io_amount).
        stream.write_all(b"hello").unwrap();

        stream.shut_down_write().unwrap();

        assert!(stream.is_shut_down());
        assert!(stream.write(b"goodbye").is_err());
    }

    #[test]
    fn write_to_closed_peer_sets_shut_down() {
        let (unix_stream, peer) = UnixStream::pair().unwrap();
        let mut stream: StreamSocket = unix_stream.into();
        drop(peer);

        // With the peer gone, the write fails with EPIPE, and the socket should record
        // that it is shut down.
        let err = stream.write(b"hello").unwrap_err();
        assert_eq!(err.raw_os_error(), Some(libc::EPIPE));
        assert!(stream.is_shut_down());
    }

    #[test]
    fn read_from_shut_down_sock() {
        let (unix_stream1, unix_stream2) = UnixStream::pair().unwrap();
        let mut stream1: StreamSocket = unix_stream1.into();
        let mut stream2: StreamSocket = unix_stream2.into();

        stream1.shut_down_write().unwrap();

        // Reads from the other end of the socket should now return EOF.
        let mut buf = Vec::new();
        assert_eq!(stream2.read_to_end(&mut buf).unwrap(), 0);
    }
}