blob: 106de2a63f2e3ef09263cbb4dbe1628bb86d1575 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Runs a [9P] server.
//!
//! [9P]: http://man.cat-v.org/plan_9/5/0intro
extern crate getopts;
extern crate libc;
extern crate libchromeos;
#[macro_use]
extern crate log;
extern crate p9;
use libc::gid_t;
use std::ffi::{CStr, CString};
use std::fmt;
use std::fs::{remove_file, File};
use std::io::{self, BufReader, BufWriter};
use std::net;
use std::num::ParseIntError;
use std::os::raw::c_uint;
use std::os::unix::fs::FileTypeExt;
use std::os::unix::fs::PermissionsExt;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::{SocketAddr, UnixListener};
use std::path::{Path, PathBuf};
use std::result;
use std::str::FromStr;
use std::string;
use std::sync::Arc;
use std::thread;
use libchromeos::syslog;
use libchromeos::vsock::*;
/// Capacity, in bytes, of the `BufReader`/`BufWriter` wrapped around each client connection.
const DEFAULT_BUFFER_SIZE: usize = 8 * 1024;

/// Listen-address prefix selecting a vsock listener: `vsock:<port>`.
const VSOCK: &str = "vsock:";
/// Listen-address prefix selecting a unix socket bound at a path: `unix:<path>`.
const UNIX: &str = "unix:";
/// Listen-address prefix selecting an already-open unix socket descriptor: `unix-fd:<fd>`.
const UNIX_FD: &str = "unix-fd:";

/// Synopsis line printed at the top of the help output.
const USAGE: &str = "9s [options] {vsock:<port>|unix:<path>|unix-fd:<fd>|<ip>:<port>}";

/// Program name handed to syslog; it ends in a NUL byte so it can be viewed as a `CStr`.
const IDENT: &[u8] = b"9s\0";
/// Where the server accepts client connections, as parsed from the positional
/// command-line argument.
enum ListenAddress {
    /// `vsock:<port>`: listen on this vsock port.
    Vsock(c_uint),
    /// `unix:<path>`: bind a unix socket at this filesystem path.
    Unix(String),
    /// `unix-fd:<fd>`: serve on an already-open unix socket descriptor.
    UnixFd(RawFd),
    /// `<ip>:<port>`: a network socket address.
    Net(net::SocketAddr),
}
/// Reasons the listen-address argument can fail to parse.
#[derive(Debug)]
enum ParseAddressError {
    /// `vsock:` was given with nothing after it.
    MissingVsockPort,
    /// The text after `vsock:` is not a valid port number.
    Vsock(ParseIntError),
    /// `unix:` was given with nothing after it.
    MissingUnixPath,
    /// Parsing the text after `unix:` as a `String` failed (kept for symmetry; that parse
    /// cannot actually fail).
    Unix(string::ParseError),
    /// `unix-fd:` was given with nothing after it.
    MissingUnixFd,
    /// The text after `unix-fd:` is not a valid file descriptor number.
    UnixFd(ParseIntError),
    /// No known prefix matched and the text is not a valid socket address.
    Net(net::AddrParseError),
}
impl fmt::Display for ParseAddressError {
    /// Describes which part of the listen address was missing or malformed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            ParseAddressError::MissingVsockPort => f.write_str("missing vsock port number"),
            ParseAddressError::Vsock(ref e) => write!(f, "invalid vsock port number: {}", e),
            ParseAddressError::MissingUnixPath => f.write_str("missing unix path"),
            ParseAddressError::Unix(ref e) => write!(f, "invalid unix path: {}", e),
            ParseAddressError::MissingUnixFd => f.write_str("missing unix file descriptor"),
            ParseAddressError::UnixFd(ref e) => write!(f, "invalid file descriptor: {}", e),
            // Plain network addresses have no prefix, so the parser's own message is
            // forwarded unchanged.
            ParseAddressError::Net(ref e) => fmt::Display::fmt(e, f),
        }
    }
}
impl FromStr for ListenAddress {
    type Err = ParseAddressError;

    /// Parses `vsock:<port>`, `unix:<path>`, `unix-fd:<fd>`, or, when none of those
    /// prefixes is present, a plain `<ip>:<port>` socket address.
    ///
    /// A recognized prefix with nothing after it yields the matching `Missing*` error.
    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
        if s.starts_with(VSOCK) {
            let port = &s[VSOCK.len()..];
            if port.is_empty() {
                return Err(ParseAddressError::MissingVsockPort);
            }
            return port
                .parse()
                .map(ListenAddress::Vsock)
                .map_err(ParseAddressError::Vsock);
        }

        if s.starts_with(UNIX) {
            let path = &s[UNIX.len()..];
            if path.is_empty() {
                return Err(ParseAddressError::MissingUnixPath);
            }
            return path
                .parse()
                .map(ListenAddress::Unix)
                .map_err(ParseAddressError::Unix);
        }

        if s.starts_with(UNIX_FD) {
            let fd = &s[UNIX_FD.len()..];
            if fd.is_empty() {
                return Err(ParseAddressError::MissingUnixFd);
            }
            return fd
                .parse()
                .map(ListenAddress::UnixFd)
                .map_err(ParseAddressError::UnixFd);
        }

        // No known prefix: treat the whole argument as a network socket address.
        s.parse()
            .map(ListenAddress::Net)
            .map_err(ParseAddressError::Net)
    }
}
/// Errors that can end the program, either during setup or while serving.
#[derive(Debug)]
enum Error {
    /// The positional listen address could not be parsed.
    Address(ParseAddressError),
    /// `getopts` rejected the command-line options.
    Argument(getopts::Fail),
    /// The `--accept_cid` value is not a valid vsock context id.
    Cid(ParseIntError),
    /// An id-map entry did not consist of exactly two `:`-separated parts.
    IdMapParse(String),
    /// The host (left-hand) id of an id-map entry failed to parse.
    IdMapConvertHost(String),
    /// The client (right-hand) id of an id-map entry failed to parse.
    IdMapConvertClient(String),
    /// An id-map entry repeats a host id that is already mapped.
    IdMapDuplicate(String),
    /// An I/O error from setting up or running the server.
    IO(io::Error),
    /// A vsock address was given without `--accept_cid`.
    MissingAcceptCid,
    /// The `--socket_gid` value is not a valid group id.
    SocketGid(ParseIntError),
    /// The `unix:` socket path is not absolute.
    SocketPathNotAbsolute(PathBuf),
    /// Initializing the syslog logger failed.
    Syslog(log::SetLoggerError),
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Address(ref e) => e.fmt(f),
Error::Argument(ref e) => e.fmt(f),
Error::Cid(ref e) => write!(f, "invalid cid value: {}", e),
Error::IdMapConvertClient(ref s) => {
write!(f, "malformed client portion of id map ({})", s)
}
Error::IdMapConvertHost(ref s) => write!(f, "malformed host portion of id map ({})", s),
Error::IdMapDuplicate(ref s) => write!(f, "duplicate mapping for host id {}", s),
Error::IdMapParse(ref s) => write!(
f,
"id map must have exactly 2 components: <host_id>:<client_id> ({})",
s
),
Error::IO(ref e) => e.fmt(f),
Error::MissingAcceptCid => write!(f, "`accept_cid` is required for vsock servers"),
Error::SocketGid(ref e) => write!(f, "invalid gid value: {}", e),
Error::SocketPathNotAbsolute(ref p) => {
write!(f, "unix socket path must be absolute: {:?}", p)
}
Error::Syslog(ref e) => write!(f, "failed to initialize syslog: {}", e),
}
}
}
/// Newtype giving a unix-socket peer address a `Display` form for log messages.
struct UnixSocketAddr(SocketAddr);

impl fmt::Display for UnixSocketAddr {
    /// Prints the socket's filesystem path. Prints a placeholder when the address has no
    /// pathname, or when the path is not valid UTF-8.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.0.as_pathname() {
            None => write!(f, "<unnamed or abstract socket>"),
            Some(path) => {
                let text = match path.to_str() {
                    Some(utf8) => utf8,
                    None => "<malformed path>",
                };
                write!(f, "{}", text)
            }
        }
    }
}
type Result<T> = result::Result<T, Error>;
#[derive(Clone)]
struct ServerParams {
root: String,
uid_map: p9::ServerUidMap,
gid_map: p9::ServerGidMap,
}
fn handle_client<R: io::Read, W: io::Write>(
server_params: Arc<ServerParams>,
mut reader: R,
mut writer: W,
) -> io::Result<()> {
let params: ServerParams = (*server_params).clone();
let mut server = p9::Server::new(PathBuf::from(&params.root), params.uid_map, params.gid_map)?;
loop {
server.handle_message(&mut reader, &mut writer)?;
}
}
fn spawn_server_thread<
R: 'static + io::Read + Send,
W: 'static + io::Write + Send,
D: 'static + fmt::Display + Send,
>(
server_params: &Arc<ServerParams>,
reader: R,
writer: W,
peer: D,
) {
let reader = BufReader::with_capacity(DEFAULT_BUFFER_SIZE, reader);
let writer = BufWriter::with_capacity(DEFAULT_BUFFER_SIZE, writer);
let params = server_params.clone();
thread::spawn(move || {
if let Err(e) = handle_client(params, reader, writer) {
error!("error while handling client {}: {}", peer, e);
}
});
}
/// Listens on vsock `port` (bound with `VsockCid::Any`) and serves only peers whose CID
/// equals `accept_cid`.
///
/// Connections from any other CID are logged and their stream is dropped. This function
/// returns only if binding, accepting, or duplicating an accepted stream fails.
fn run_vsock_server(
    server_params: Arc<ServerParams>,
    port: c_uint,
    accept_cid: VsockCid,
) -> io::Result<()> {
    let listener = VsockListener::bind((VsockCid::Any, port))?;
    loop {
        let (stream, peer) = listener.accept()?;
        if accept_cid != peer.cid {
            warn!("ignoring connection from {}", peer);
        } else {
            info!("accepted connection from {}", peer);
            // One handle is used for reading and the other for writing.
            let read_half = stream.try_clone()?;
            spawn_server_thread(&server_params, read_half, stream, peer);
        }
    }
}
/// Hands group ownership of the unix socket at `path` to `gid` and sets its mode to
/// `0o660` (owner and group read/write, no access for others).
///
/// The owning user is left unchanged.
///
/// # Errors
///
/// Returns the OS error if `chown` fails. Also returns the I/O error if the socket's
/// metadata cannot be read or its permissions cannot be updated.
///
/// # Panics
///
/// Panics if `path` is not valid UTF-8 or contains an interior NUL byte.
fn adjust_socket_ownership(path: &Path, gid: gid_t) -> io::Result<()> {
    // At this point we expect valid path since we supposedly created
    // the socket, so any failure in transforming path is _really_ unexpected.
    let path_str = path.as_os_str().to_str().expect("invalid unix socket path");
    let path_cstr = CString::new(path_str).expect("malformed unix socket path");
    // Safe as kernel only reads from the path and we know it is properly
    // formed and we check the result for errors.
    // Note: calling chown with uid -1 will preserve current user ownership.
    let res = unsafe { libc::chown(path_cstr.as_ptr(), libc::uid_t::max_value(), gid) };
    if res < 0 {
        return Err(io::Error::last_os_error());
    }
    // Allow both owner and group read/write access to the socket, and
    // deny access to the rest of the world. `set_mode` only edits the in-memory
    // `Permissions` value, so it must be written back with `set_permissions`;
    // otherwise the socket keeps whatever mode it was created with.
    let mut permissions = path.metadata()?.permissions();
    permissions.set_mode(0o660);
    std::fs::set_permissions(path, permissions)?;
    Ok(())
}
/// Accepts connections on `listener` forever and serves each one on its own thread.
///
/// Returns only when `accept` fails or an accepted stream cannot be duplicated.
fn run_unix_server(server_params: Arc<ServerParams>, listener: UnixListener) -> io::Result<()> {
    loop {
        let (stream, raw_peer) = listener.accept()?;
        let peer = UnixSocketAddr(raw_peer);
        info!("accepted connection from {}", peer);
        // One handle is used for reading and the other for writing.
        let read_half = stream.try_clone()?;
        spawn_server_thread(&server_params, read_half, stream, peer);
    }
}
/// Binds a unix socket at `path` and serves clients on it.
///
/// An existing socket file at `path` is deleted first so a new socket can be bound
/// there. If `path` exists but is not a socket, nothing is removed and an
/// `InvalidInput` error is returned. When `socket_gid` is set, the new socket's group
/// and mode are adjusted before any connection is accepted.
fn run_unix_server_with_path(
    server_params: Arc<ServerParams>,
    path: &Path,
    socket_gid: Option<gid_t>,
) -> io::Result<()> {
    let already_present = path.exists();
    if already_present && !path.metadata()?.file_type().is_socket() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Requested socket path points to existing non-socket object",
        ));
    }
    if already_present {
        remove_file(path)?;
    }

    let listener = UnixListener::bind(path)?;
    socket_gid.map_or(Ok(()), |gid| adjust_socket_ownership(path, gid))?;
    run_unix_server(server_params, listener)
}
/// Serves clients on the inherited socket descriptor `fd`, after putting it into the
/// listening state.
///
/// Takes ownership of `fd`. If `fd` is not a socket or `listen` fails, it is closed when
/// the wrapping `File` is dropped on the error return.
fn run_unix_server_with_fd(server_params: Arc<ServerParams>, fd: RawFd) -> io::Result<()> {
    // SAFETY: this function takes ownership of `fd`. `main` passes a freshly duplicated
    // descriptor, so closing it here cannot close one still used elsewhere.
    let owned = unsafe { File::from_raw_fd(fd) };
    let is_socket = owned.metadata()?.file_type().is_socket();
    if !is_socket {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Supplied file descriptor is not a socket",
        ));
    }
    // SAFETY: `owned` is a valid socket descriptor, and the return value is checked.
    // The second argument is the backlog of pending connections.
    if unsafe { libc::listen(owned.as_raw_fd(), 128) } < 0 {
        return Err(io::Error::last_os_error());
    }
    // SAFETY: ownership of the now-listening socket moves from `owned` to the listener.
    let listener = unsafe { UnixListener::from_raw_fd(owned.into_raw_fd()) };
    run_unix_server(server_params, listener)
}
fn add_id_mapping<T: Clone + FromStr + Ord>(s: &str, map: &mut p9::ServerIdMap<T>) -> Result<()> {
let components: Vec<&str> = s.split(':').collect();
if components.len() != 2 {
return Err(Error::IdMapParse(s.to_owned()));
}
let host_id = components[0]
.parse::<T>()
.map_err(|_| Error::IdMapConvertHost(components[0].to_owned()))?;
let client_id = components[1]
.parse::<T>()
.map_err(|_| Error::IdMapConvertClient(components[1].to_owned()))?;
if map.contains_key(&host_id) {
return Err(Error::IdMapDuplicate(components[0].to_owned()));
}
map.insert(host_id, client_id);
Ok(())
}
fn main() -> Result<()> {
let mut opts = getopts::Options::new();
opts.optopt(
"",
"accept_cid",
"only accept connections from this vsock context id",
"CID",
);
opts.optopt(
"r",
"root",
"root directory for clients (default is \"/\")",
"PATH",
);
opts.optopt(
"",
"socket_gid",
"change socket group ownership to the specified ID",
"GID",
);
opts.optmulti(
"",
"uid_map",
"translate uids from host to client",
"UID:UID",
);
opts.optmulti(
"",
"gid_map",
"translate gids from host to client",
"GID:GID",
);
opts.optflag("h", "help", "print this help menu");
let matches = opts
.parse(std::env::args_os().skip(1))
.map_err(Error::Argument)?;
if matches.opt_present("h") || matches.free.is_empty() {
print!("{}", opts.usage(USAGE));
return Ok(());
}
let mut uid_map: p9::ServerUidMap = Default::default();
matches
.opt_strs("uid_map")
.iter()
.try_for_each(|s| add_id_mapping(s, &mut uid_map))?;
let mut gid_map: p9::ServerGidMap = Default::default();
matches
.opt_strs("gid_map")
.iter()
.try_for_each(|s| add_id_mapping(s, &mut gid_map))?;
let server_params = Arc::from(ServerParams {
root: matches.opt_str("r").unwrap_or_else(|| "/".into()),
uid_map,
gid_map,
});
// Safe because this string is defined above in this file and it contains exactly
// one nul byte, which appears at the end.
let ident = CStr::from_bytes_with_nul(IDENT).unwrap();
syslog::init(ident).map_err(Error::Syslog)?;
// We already checked that |matches.free| has at least one item.
match matches.free[0]
.parse::<ListenAddress>()
.map_err(Error::Address)?
{
ListenAddress::Vsock(port) => {
let accept_cid = if let Some(cid) = matches.opt_str("accept_cid") {
cid.parse::<VsockCid>().map_err(Error::Cid)
} else {
Err(Error::MissingAcceptCid)
}?;
run_vsock_server(server_params, port, accept_cid).map_err(Error::IO)?;
}
ListenAddress::Net(_) => {
error!("Network server unimplemented");
}
ListenAddress::Unix(path) => {
let path = Path::new(&path);
if !path.is_absolute() {
return Err(Error::SocketPathNotAbsolute(path.to_owned()));
}
let socket_gid = matches
.opt_get::<gid_t>("socket_gid")
.map_err(Error::SocketGid)?;
run_unix_server_with_path(server_params, path, socket_gid).map_err(Error::IO)?;
}
ListenAddress::UnixFd(fd) => {
// Try duplicating the fd to verify that it is a valid file descriptor. It will also
// ensure that we will not accidentally close file descriptor used by something else.
// Safe because this doesn't modify any memory and we check the return value.
let fd = unsafe { libc::fcntl(fd, libc::F_DUPFD_CLOEXEC, 0) };
if fd < 0 {
return Err(Error::IO(io::Error::last_os_error()));
}
run_unix_server_with_fd(server_params, fd).map_err(Error::IO)?;
}
}
Ok(())
}