blob: a4d5fc4ea64ab4c4716da90e3e8ac1bc009af209 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
/// Structs to supplement std::net.
use std::io;
use std::mem::{self, size_of};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use libc::{
c_int, in6_addr, in_addr, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET,
AF_INET6, SOCK_CLOEXEC, SOCK_STREAM,
};
/// Assist in handling both IP version 4 and IP version 6.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
V4,
V6,
}
impl InetVersion {
pub fn from_sockaddr(s: &SocketAddr) -> Self {
match s {
SocketAddr::V4(_) => InetVersion::V4,
SocketAddr::V6(_) => InetVersion::V6,
}
}
}
impl Into<sa_family_t> for InetVersion {
fn into(self) -> sa_family_t {
match self {
InetVersion::V4 => AF_INET as sa_family_t,
InetVersion::V6 => AF_INET6 as sa_family_t,
}
}
}
fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
sockaddr_in {
sin_family: AF_INET as sa_family_t,
sin_port: s.port().to_be(),
sin_addr: in_addr {
s_addr: u32::from_ne_bytes(s.ip().octets()),
},
sin_zero: [0; 8],
}
}
fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
sockaddr_in6 {
sin6_family: AF_INET6 as sa_family_t,
sin6_port: s.port().to_be(),
sin6_flowinfo: 0,
sin6_addr: in6_addr {
s6_addr: s.ip().octets(),
},
sin6_scope_id: 0,
}
}
/// A TCP socket.
///
/// Do not use this class unless you need to change socket options or query the
/// state of the socket prior to calling listen or connect. Instead use either TcpStream or
/// TcpListener.
#[derive(Debug)]
pub struct TcpSocket {
inet_version: InetVersion,
fd: RawFd,
}
impl TcpSocket {
pub fn new(inet_version: InetVersion) -> io::Result<Self> {
let fd = unsafe {
libc::socket(
Into::<sa_family_t>::into(inet_version) as c_int,
SOCK_STREAM | SOCK_CLOEXEC,
0,
)
};
if fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(TcpSocket { inet_version, fd })
}
}
pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
let sockaddr = addr
.to_socket_addrs()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
.next()
.unwrap();
let ret = match sockaddr {
SocketAddr::V4(a) => {
let sin = sockaddrv4_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::bind(
self.fd,
&sin as *const sockaddr_in as *const sockaddr,
size_of::<sockaddr_in>() as socklen_t,
)
}
}
SocketAddr::V6(a) => {
let sin6 = sockaddrv6_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::bind(
self.fd,
&sin6 as *const sockaddr_in6 as *const sockaddr,
size_of::<sockaddr_in6>() as socklen_t,
)
}
}
};
if ret < 0 {
let bind_err = io::Error::last_os_error();
Err(bind_err)
} else {
Ok(())
}
}
pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
let sockaddr = addr
.to_socket_addrs()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
.next()
.unwrap();
let ret = match sockaddr {
SocketAddr::V4(a) => {
let sin = sockaddrv4_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::connect(
self.fd,
&sin as *const sockaddr_in as *const sockaddr,
size_of::<sockaddr_in>() as socklen_t,
)
}
}
SocketAddr::V6(a) => {
let sin6 = sockaddrv6_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::connect(
self.fd,
&sin6 as *const sockaddr_in6 as *const sockaddr,
size_of::<sockaddr_in>() as socklen_t,
)
}
}
};
if ret < 0 {
let connect_err = io::Error::last_os_error();
Err(connect_err)
} else {
// Safe because the ownership of the raw fd is released from self and taken over by the
// new TcpStream.
Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
}
}
pub fn listen(self) -> io::Result<TcpListener> {
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe { libc::listen(self.fd, 1) };
if ret < 0 {
let listen_err = io::Error::last_os_error();
Err(listen_err)
} else {
// Safe because the ownership of the raw fd is released from self and taken over by the
// new TcpListener.
Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
}
}
/// Returns the port that this socket is bound to. This can only succeed after bind is called.
pub fn local_port(&self) -> io::Result<u16> {
match self.inet_version {
InetVersion::V4 => {
let mut sin = sockaddr_in {
sin_family: 0,
sin_port: 0,
sin_addr: in_addr { s_addr: 0 },
sin_zero: [0; 8],
};
// Safe because we give a valid pointer for addrlen and check the length.
let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
let ret = unsafe {
// Get the socket address that was actually bound.
libc::getsockname(
self.fd,
&mut sin as *mut sockaddr_in as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
// If this doesn't match, it's not safe to get the port out of the sockaddr.
assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
Ok(u16::from_be(sin.sin_port))
}
}
InetVersion::V6 => {
let mut sin6 = sockaddr_in6 {
sin6_family: 0,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_addr: in6_addr { s6_addr: [0; 16] },
sin6_scope_id: 0,
};
// Safe because we give a valid pointer for addrlen and check the length.
let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
let ret = unsafe {
// Get the socket address that was actually bound.
libc::getsockname(
self.fd,
&mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
// If this doesn't match, it's not safe to get the port out of the sockaddr.
assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
Ok(u16::from_be(sin6.sin6_port))
}
}
}
}
}
impl IntoRawFd for TcpSocket {
fn into_raw_fd(self) -> RawFd {
let fd = self.fd;
mem::forget(self);
fd
}
}
impl AsRawFd for TcpSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for TcpSocket {
fn drop(&mut self) {
// Safe because this doesn't modify any memory and we are the only
// owner of the file descriptor.
unsafe { libc::close(self.fd) };
}
}