blob: 02b6839670e48666ebe9ed6f8627742ea7986e86 [file] [log] [blame]
// Copyright 2022 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! The hypervisor memory service for manaTEE.
#![deny(unsafe_op_in_unsafe_fn)]
use std::cell::RefCell;
use std::cmp;
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::convert::{TryFrom, TryInto};
use std::env;
use std::fmt::{Debug, Formatter};
use std::fs::File;
use std::io::{ErrorKind, Read, Write};
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::path::PathBuf;
use std::ptr::null_mut;
use std::rc::Rc;
use std::result::Result as StdResult;
use std::time::Duration;
use anyhow::{anyhow, bail, Context, Result};
use balloon_control::{BalloonStats, BalloonTubeCommand, BalloonTubeResult};
use data_model::DataInit;
use libc::{recvfrom, MSG_PEEK, MSG_TRUNC};
use libsirenia::{
build_info::BUILD_TIMESTAMP,
linux::events::{AddEventSourceMutator, EventMultiplexer, EventSource, Mutator},
sys,
transport::{Error as TransportError, Transport, UnixServerTransport},
};
use serde::{Deserialize, Serialize};
use sys_util::{
net::UnixSeqpacket,
{error, handle_eintr_errno, info, pagesize, round_up_to_page_size, syslog, warn},
};
/// Id of the CrOS guest VM. Its balloon client is registered first (in `main`),
/// and ids for sibling VMs are allocated starting at `CROS_GUEST_ID + 1`.
const CROS_GUEST_ID: u32 = 0;
/// Control message types, as carried in `MmsMessageHeader::msg_type`.
///
/// The comment on each variant gives its `request => response` signature. The
/// request and response payloads are JSON-encoded.
#[repr(u32)]
enum MessageId {
    // GetBalloonStats(array<u32 id>) => (array<TaggedBalloonStats>);
    GetBalloonStats = 1,
    // RebalanceMemory(array<BalloonDelta> deltas) => (array<ActualBalloonDelta> actual);
    RebalanceMemory = 2,
    // PrepareVm(u64 mem_size, u64 init_mem_size) => (i32 res, u32 id, u64 shortfall);
    PrepareVm = 3,
    // FinishAddVm(u32 id) => i32
    FinishAddVm = 4,
    // RemoveVm(u32 id) => i32
    RemoveVm = 5,
}
impl TryFrom<u32> for MessageId {
    type Error = anyhow::Error;

    /// Decodes a wire-format message type into a [`MessageId`].
    ///
    /// # Errors
    ///
    /// Returns an error if `v` does not match any known message id.
    fn try_from(v: u32) -> Result<MessageId> {
        use MessageId::*;
        // Compare against each discriminant so this stays in sync with the
        // `#[repr(u32)]` values declared on the enum.
        match v {
            v if v == GetBalloonStats as u32 => Ok(GetBalloonStats),
            v if v == RebalanceMemory as u32 => Ok(RebalanceMemory),
            v if v == PrepareVm as u32 => Ok(PrepareVm),
            v if v == FinishAddVm as u32 => Ok(FinishAddVm),
            v if v == RemoveVm as u32 => Ok(RemoveVm),
            // `anyhow!` formats its arguments itself; no intermediate `format!` needed.
            _ => Err(anyhow!("unknown message id {}", v)),
        }
    }
}
/// Fixed-size header that precedes every control message and every response.
///
/// `len` is the byte length of the JSON payload that follows the header, and
/// `msg_type` is a [`MessageId`] discriminant. Responses echo the request's
/// `msg_type`.
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct MmsMessageHeader {
    len: u32,
    msg_type: u32,
}
// SAFETY: MmsMessageHeader is `repr(C)` and only contains plain `u32` data, so it
// has no padding and every bit pattern is a valid value.
unsafe impl DataInit for MmsMessageHeader {}
/// Request payload for `MessageId::GetBalloonStats`.
#[derive(Deserialize)]
struct GetBalloonStatsMsg {
    /// Ids of the VMs whose balloon stats are requested.
    ids: Vec<u32>,
}
/// Balloon stats for a single VM, tagged with the VM's id.
#[derive(Serialize)]
struct TaggedBalloonStats {
    id: u32,
    /// Stats as returned by the VM's balloon device.
    stats: BalloonStats,
    /// `balloon_actual` from the same stats reply (presumably the current balloon
    /// size in bytes — TODO confirm against crosvm).
    balloon_actual: u64,
}
/// Response payload for `MessageId::GetBalloonStats`.
#[derive(Serialize)]
struct GetBalloonStatsResp {
    /// One entry per VM whose stats were fetched successfully; unknown ids and
    /// failed queries are omitted.
    all_stats: Vec<TaggedBalloonStats>,
}
/// A requested balloon size change for one VM.
#[derive(Deserialize, Debug)]
struct BalloonDelta {
    id: u32,
    /// Change in balloon size, in bytes. Positive inflates the balloon (reclaims
    /// memory from the VM); negative deflates it (gives memory to the VM). Sent
    /// as a JSON number.
    #[serde(with = "i64_from_double")]
    delta: i64,
}
/// Request payload for `MessageId::RebalanceMemory`.
#[derive(Deserialize)]
struct RebalanceMemoryMsg {
    /// Per-VM deltas; they must target distinct VMs and sum to zero.
    deltas: Vec<BalloonDelta>,
}
/// The balloon size change that was actually applied to one VM.
#[derive(Serialize)]
struct ActualBalloonDelta {
    id: u32,
    delta: i64,
}
/// Response payload for `MessageId::RebalanceMemory`.
#[derive(Serialize)]
struct RebalanceMemoryResp {
    actual_deltas: Vec<ActualBalloonDelta>,
}
/// Request payload for `MessageId::PrepareVm`.
#[derive(Deserialize, Debug)]
struct PrepareVmMsg {
    /// Total guest memory of the new VM, in bytes (must be page aligned).
    #[serde(with = "u64_from_double")]
    mem_size: u64,
    /// Memory the new VM starts with, in bytes (page aligned, `<= mem_size`). The
    /// remainder is held in the VM's balloon once it is added.
    #[serde(with = "u64_from_double")]
    init_mem_size: u64,
}
/// Response payload for `MessageId::PrepareVm`.
#[derive(Serialize)]
struct PrepareVmResp {
    /// 0 on success, otherwise a negative errno (`-ENOMEM` if not enough memory
    /// could be reserved yet).
    res: i32,
    /// Id assigned to the VM being prepared.
    id: u32,
    /// Bytes still missing from the reservation.
    shortfall: u64,
}
/// Builds a `PrepareVm` response that carries only the error code `res`, with a
/// zero id and zero shortfall.
fn error_prepare_vm_resp(res: i32) -> PrepareVmResp {
    let (id, shortfall) = (0, 0);
    PrepareVmResp { res, id, shortfall }
}
/// Request payload for `MessageId::FinishAddVm`.
#[derive(Deserialize)]
struct FinishAddVmMsg {
    /// Id returned by the preceding `PrepareVm`.
    id: u32,
}
/// Request payload for `MessageId::RemoveVm`.
#[derive(Deserialize)]
struct RemoveVmMsg {
    /// Id of a running or pending VM; the CrOS guest cannot be removed.
    id: u32,
}
/// Response carrying only a status code: 0 on success, otherwise a negative errno.
#[derive(Serialize)]
struct SimpleResp {
    res: i32,
}
// TODO(stevensd): use something other than json
// Generates a module named `$name` usable with `#[serde(with = "...")]` that
// deserializes a JSON number as an `f64` and converts it to `$dest_type` with an
// `as` cast.
// NOTE(review): a float-to-int `as` cast saturates out-of-range values, maps NaN
// to 0 and truncates any fractional part, so malformed input is silently coerced
// rather than rejected. Rejecting it here would make the request fail to parse,
// and then no response is sent at all (see `handle_message`).
macro_rules! from_double {
    ( $name:ident, $dest_type:ty ) => {
        mod $name {
            use serde::{Deserialize, Deserializer};
            pub fn deserialize<'de, D>(deserializer: D) -> Result<$dest_type, D::Error>
            where
                D: Deserializer<'de>,
            {
                Ok(f64::deserialize(deserializer)? as $dest_type)
            }
        }
    };
}
from_double!(u64_from_double, u64);
from_double!(i64_from_double, i64);
// In practice this won't overflow, since mem_size is checked to be less than the
// CrOS guest's total memory, so it will be significantly less than 2^64.
/// Returns the memory overhead, in bytes, of running a VM with `mem_size` bytes
/// of guest memory.
///
/// The overhead is a per-GB cost of 9.7MB, rounded up to a whole page, plus a
/// fixed 6MB for crosvm itself.
fn calculate_extra_bytes(mem_size: u64) -> u64 {
    // Per-GB overheads:
    //   3.2MB for shmem xarray
    //   2MB for EPT
    //   2MB for page tables
    //   2MB for kvm rmap
    //   .5MB for kvm gfn tracking
    // => 9.7MB/GB, i.e. a ratio of 9.7 / 1024 == 97 / 10240.
    // TODO(stevensd): uprev/backport removal of rmap/gfn tracking to hypervisor
    const CROSVM_FIXED_BYTES: u64 = 6 * 1024 * 1024;
    let scaled_bytes = mem_size as usize * 97 / 10240;
    let per_gb_bytes = round_up_to_page_size(scaled_bytes) as u64;
    per_gb_bytes + CROSVM_FIXED_BYTES
}
/// Reads exactly `size_of::<T>()` bytes from `connection` and reinterprets them
/// as a `T`.
///
/// Returns `Ok(None)` if the stream hits EOF before a full `T` has been read;
/// this includes EOF partway through the object.
///
/// # Errors
///
/// Fails on any other read error, or if the bytes cannot be converted to `T`.
fn read_obj<T: DataInit>(connection: &mut Transport) -> Result<Option<T>> {
    let mut buf = vec![0; mem::size_of::<T>()];
    if let Err(err) = connection.r.read_exact(&mut buf) {
        if err.kind() == ErrorKind::UnexpectedEof {
            return Ok(None);
        }
        return Err(err).context("failed to read bytes");
    }
    let obj = T::from_slice(&buf).context("failed to parse bytes")?;
    Ok(Some(*obj))
}
/// Sends `msg` over a crosvm balloon control connection and waits for the reply.
///
/// The command is written as one JSON-encoded message. The reply's size is found
/// by peeking at the next message with `MSG_TRUNC`, which assumes a
/// message-oriented (seqpacket) socket. The reply is then read in full and
/// decoded as a [`BalloonTubeResult`].
///
/// # Errors
///
/// Fails if the command cannot be serialized or written, if the reply size cannot
/// be determined, or if the reply cannot be read or parsed.
fn sync_balloon_command(
    conn: &mut Transport,
    msg: BalloonTubeCommand,
) -> Result<BalloonTubeResult> {
    let cmd = serde_json::ser::to_vec(&msg).context("failed to serialize command")?;
    // `write_all` reports a short write instead of silently dropping the rest of
    // the command, which `write` would do.
    conn.w
        .write_all(&cmd)
        .with_context(|| "failed to issue command")?;
    // SAFETY: recvfrom is given a socket fd owned by `conn`, a null buffer with
    // length 0 (so nothing is written), and null address/length pointers, which
    // recvfrom accepts. The return value is checked below.
    let ret = unsafe {
        handle_eintr_errno!(recvfrom(
            conn.r.as_raw_fd(),
            null_mut(),
            0,
            MSG_TRUNC | MSG_PEEK,
            null_mut(),
            null_mut(),
        ))
    };
    if ret < 0 {
        bail!("Failed to get message size: {}", sys::errno());
    }
    // With MSG_TRUNC, `ret` is the full length of the pending message.
    let mut resp = vec![0; ret as usize];
    conn.r
        .read_exact(&mut resp)
        .with_context(|| "failed to read response")?;
    serde_json::from_slice(&resp).with_context(|| "failed to parse response")
}
/// Asks `client`'s balloon to change size by `delta` bytes.
///
/// A positive `delta` inflates the balloon, reclaiming memory from the guest. A
/// negative `delta` deflates it, returning memory to the guest. The request is
/// sent with `allow_failure: true`, and `client.balloon_size` is updated to the
/// size crosvm reports back.
///
/// Returns the balloon size change the caller's memory accounting should assume.
/// On error this is a pessimistic estimate (see below).
fn adjust_balloon(client: &mut CrosVmClient, delta: i64) -> i64 {
    // `unsigned_abs` avoids the overflow `abs` hits on `i64::MIN`, and saturating
    // arithmetic keeps the target within `u64` in both directions.
    let target_size = if delta > 0 {
        client.balloon_size.saturating_add(delta.unsigned_abs())
    } else {
        client.balloon_size.saturating_sub(delta.unsigned_abs())
    };
    match sync_balloon_command(
        &mut client.client,
        BalloonTubeCommand::Adjust {
            num_bytes: target_size,
            allow_failure: true,
        },
    ) {
        Ok(BalloonTubeResult::Adjusted {
            num_bytes: actual_size,
        }) => {
            let actual_delta = (actual_size as i64) - (client.balloon_size as i64);
            client.balloon_size = actual_size;
            actual_delta
        }
        res => {
            error!("Error adjusting balloon {:?}", res);
            // Be pessimistic and assume the guest holds as much memory as it
            // possibly could:
            // - when inflating (reclaiming memory), assume nothing was reclaimed;
            // - when deflating (releasing memory), assume the balloon fully
            //   deflated to the target size.
            // If the sibling is dead, then things will be sorted out when the VM
            // is removed.
            if delta > 0 {
                0
            } else {
                // Report the clamped change rather than `delta`, so a request that
                // saturated at 0 does not overstate the memory released.
                let released = (target_size as i64) - (client.balloon_size as i64);
                client.balloon_size = target_size;
                released
            }
        }
    }
}
/// Returns the path of the Unix socket on which VM `id` connects its balloon
/// control channel to MMS.
fn get_control_server_path(id: u32) -> PathBuf {
    let mut path = PathBuf::from("/run");
    path.push(format!("mms_control_{}.sock", id));
    path
}
/// Blocks until the peer of `conn` hangs up.
///
/// Waits up to 10 seconds for `POLLHUP`, then logs a warning and keeps waiting
/// with no timeout. Failures are logged rather than returned, because callers
/// use this for best-effort cleanup.
fn wait_for_hangup(conn: &Transport) {
    let mut fds = libc::pollfd {
        fd: conn.r.as_raw_fd(),
        events: libc::POLLHUP,
        revents: 0,
    };
    // SAFETY: we give a valid pointer to a list of exactly 1 pollfd and check the
    // return value.
    let mut ret = unsafe { handle_eintr_errno!(libc::poll(&mut fds, 1, 10 * 1000)) };
    if ret == 0 {
        // Timed out. poll() clears `revents` in this case, so there is nothing to
        // inspect; keep waiting without a timeout.
        warn!("Long wait for client hangup");
        // SAFETY: we give a valid pointer to a list of exactly 1 pollfd and check
        // the return value.
        ret = unsafe { handle_eintr_errno!(libc::poll(&mut fds, 1, -1)) };
    }
    if ret == -1 || (fds.revents & libc::POLLHUP) == 0 {
        error!(
            "Error cleaning up stale clients {} {}",
            sys::errno(),
            fds.revents
        );
    }
}
fn cleanup_control_server(id: u32, server: UnixServerTransport) {
// Unlink the file to stop any new clients.
if let Err(e) = std::fs::remove_file(get_control_server_path(id)) {
warn!("Error unlinking control server {}: {:?}", id, e);
}
// Check if there is a pending client, and wait for the client to close if there is.
match server.accept_with_timeout(Duration::ZERO) {
Ok(conn) => {
wait_for_hangup(&conn);
}
Err(e) => {
if let TransportError::Accept(e) = e {
if e.kind() != ErrorKind::TimedOut {
warn!("Error checking for trailing clients {}: {:?}", id, e);
}
}
}
}
}
/// A VM whose balloon is managed by MMS (the CrOS guest or a sibling VM).
#[derive(Debug)]
struct CrosVmClient {
    /// Balloon control connection for the VM.
    client: Transport,
    /// Total guest memory of the VM, in bytes.
    mem_size: u64,
    /// Current balloon size in bytes, as last reported by (or assumed for) the
    /// VM's balloon.
    balloon_size: u64,
}
// About the flow for starting a new VM:
//
// MMS implements a small state machine for managing startup of new
// VMs. The states are as follows:
//
// - Idle: MMS is not starting a new VM
// - Pending: MMS is in the middle of starting a new VM
// - Failed: MMS failed to start a new VM
//
// The following messages affect the state machine:
//
// - PrepareVm: Reserves memory for the new VM. To allow the client to
// deal with failures to reserve enough memory, this message can be sent
// multiple times, and the reserved memory will continue to accumulate.
// - I -> P or P -> P
// - FinishAddVm: Finishes VM startup.
// - P -> I or P -> F
// - RemoveVm: Cleans up the starting VM, if the id matches the id
// of the starting VM.
// - P -> I or F -> I
// - Client crashes: Equivalent to removing the starting VM.
//
// Establishing a balloon control connection with a new VM is a two
// step process. When preparing for a new VM, MMS creates a new named
// domain socket, and when finishing adding a new VM, MMS accepts a connection
// from that socket. The connection-based approach allows MMS to wait for
// the client to finish shutting down when removing a VM, and it also gives
// MMS a clean, race-free way to prevent any new VMs from starting and using
// a socket (by unlinking the socket).
/// Bookkeeping for the VM currently being started (the Pending/Failed states).
struct StartingVmState {
    /// Id allocated to the new VM by `PrepareVm`.
    id: u32,
    /// Server listening on `get_control_server_path(id)` for the VM's balloon
    /// connection.
    server: UnixServerTransport,
    /// Requested total guest memory, in bytes.
    mem_size: u64,
    /// Requested initial guest memory, in bytes.
    init_mem_size: u64,
    /// Bytes reclaimed from the CrOS guest so far for this VM.
    reserved_mem: u64,
    /// Balloon connection kept after a failed `FinishAddVm`, so its hangup can be
    /// awaited during cleanup.
    client: Option<Transport>,
    /// Set when `FinishAddVm` fails; only removal clears this state.
    failed: bool,
}
impl Debug for StartingVmState {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StartingVmState")
.field("id", &self.id)
.finish()
}
}
/// Service state shared between the bridge and the control handler.
struct MmsState {
    /// Whether a CrOS control connection is currently being served.
    cros_ctrl_connected: bool,
    /// Control connections received while another one was active; the next one
    /// is started when the active connection hangs up.
    pending_ctrl_connections: VecDeque<File>,
    /// All fully added VMs keyed by id, including the CrOS guest at `CROS_GUEST_ID`.
    clients: BTreeMap<u32, CrosVmClient>,
    /// The VM currently being started, if any.
    starting_vm_state: Option<StartingVmState>,
    /// Next id to hand out from `PrepareVm`.
    next_id: u32,
}
/// Event source serving one CrOS control connection.
struct CtrlHandler {
    /// The control connection; requests are read from it and responses written to it.
    connection: Transport,
    /// State shared with the bridge and any later handlers.
    state: Rc<RefCell<MmsState>>,
}
// Deserializes `$data` as JSON into the argument type of handler method `$fn`,
// calls `$self.$fn(&arg)`, and serializes the handler's response back to JSON.
// Evaluates to `Result<Vec<u8>>`. A parse failure returns early from the
// enclosing function via `?`.
macro_rules! dispatch_message {
    ($self: ident, $fn: ident, $data: expr) => {
        serde_json::to_vec(
            &$self.$fn(&serde_json::from_slice(&$data).with_context(|| "failed to parse")?),
        )
        .with_context(|| "failed to serialize response")
    };
}
impl CtrlHandler {
fn new(connection: Transport, state: Rc<RefCell<MmsState>>) -> Self {
CtrlHandler { connection, state }
}
fn handle_balloon_stats(
&mut self,
GetBalloonStatsMsg { ids }: &GetBalloonStatsMsg,
) -> GetBalloonStatsResp {
let mut state = self.state.borrow_mut();
let mut all_stats = Vec::new();
for id in ids {
let client = match state.clients.get_mut(id) {
Some(client) => client,
None => {
warn!("Missing client for {}", id);
continue;
}
};
match sync_balloon_command(&mut client.client, BalloonTubeCommand::Stats { id: 0 }) {
Ok(BalloonTubeResult::Stats {
stats,
balloon_actual,
..
}) => {
all_stats.push(TaggedBalloonStats {
id: *id,
stats,
balloon_actual,
});
}
Ok(resp) => error!("Unexpected response {:?}", resp),
Err(e) => error!("Error fetching stats {} {}", id, e),
};
}
GetBalloonStatsResp { all_stats }
}
fn validate_rebalance_deltas(&self, deltas: &[BalloonDelta]) -> Result<()> {
let state = self.state.borrow();
let mut ids = BTreeSet::new();
let mut total_delta = 0;
let pagesize = pagesize();
for delta in deltas {
if delta.delta % (pagesize as i64) != 0 {
bail!("invalid balloon config {:?}", delta);
}
if !ids.insert(delta.id) {
bail!("duplicate id {}", delta.id);
}
let client = state
.clients
.get(&delta.id)
.with_context(|| format!("unknown target id {}", delta.id))?;
let new_size = if delta.delta > 0 {
client.balloon_size.checked_add(delta.delta as u64)
} else {
delta
.delta
.checked_abs()
.and_then(|d| client.balloon_size.checked_sub(d as u64))
}
// Also catches underflow
.with_context(|| format!("balloon overflow {} {}", client.balloon_size, delta.delta))?;
if new_size > client.mem_size {
bail!("overinflate balloon {} {}", new_size, client.mem_size);
}
total_delta += delta.delta;
}
if total_delta != 0 {
bail!("unbalanced config {}", total_delta);
}
Ok(())
}
fn handle_rebalance_memory(
&mut self,
RebalanceMemoryMsg { deltas }: &RebalanceMemoryMsg,
) -> RebalanceMemoryResp {
if let Err(err) = self.validate_rebalance_deltas(deltas) {
error!("Invalid rebalance: {:?}", err);
return RebalanceMemoryResp {
actual_deltas: deltas
.iter()
.map(|delta| ActualBalloonDelta {
id: delta.id,
delta: 0,
})
.collect(),
};
}
let mut state = self.state.borrow_mut();
let mut slack: i64 = 0;
let mut actual_deltas = Vec::new();
// Inflate balloons to reclaim their memory.
for delta in deltas {
if delta.delta <= 0 {
continue;
}
let client = state.clients.get_mut(&delta.id).unwrap();
let actual_delta = adjust_balloon(client, delta.delta);
if actual_delta != delta.delta {
info!(
"balloon inflate mismatch id={} expected={} actual={}",
delta.id, delta.delta, actual_delta
);
}
slack += actual_delta;
actual_deltas.push(ActualBalloonDelta {
id: delta.id,
delta: actual_delta,
});
}
// Deflate balloons to give reclaimed memory to other VMs
for delta in deltas {
if delta.delta >= 0 {
continue;
}
let client = state.clients.get_mut(&delta.id).unwrap();
let adjusted_delta = -cmp::min(delta.delta.abs(), slack);
let actual_delta = adjust_balloon(client, adjusted_delta);
if adjusted_delta != actual_delta {
warn!(
"balloon deflate mismatch id={} expected={} actual={}",
delta.id, adjusted_delta, actual_delta
);
}
slack += actual_delta;
actual_deltas.push(ActualBalloonDelta {
id: delta.id,
delta: actual_delta,
});
}
if slack != 0 {
// This should not happen. It either require that a balloon over-inflates
// in the first stage, or that a balloon fails to deflate as requested in
// the second stage. Neither should be possible.
error!("non-zero slack remaining: {}", slack);
}
RebalanceMemoryResp { actual_deltas }
}
fn prepare_vm(&mut self, msg: &PrepareVmMsg) -> PrepareVmResp {
let mut state = self.state.borrow_mut();
let already_reserved = match state.starting_vm_state.as_mut() {
Some(vm_state) => {
if vm_state.mem_size != msg.mem_size || vm_state.init_mem_size != msg.init_mem_size
{
error!(
"prepare_vm mismatch with pending request {:?} {} {}",
msg, vm_state.mem_size, vm_state.init_mem_size
);
return error_prepare_vm_resp(-libc::EINVAL);
}
if vm_state.failed {
error!("prepare_vm mismatch with failed request {}", vm_state.id);
return error_prepare_vm_resp(-libc::EINVAL);
}
vm_state.reserved_mem
}
None => {
if msg.mem_size % (pagesize() as u64) != 0
|| msg.init_mem_size % (pagesize() as u64) != 0
|| msg.init_mem_size > msg.mem_size
{
error!("invalid prepare VM request {:?}", msg);
return error_prepare_vm_resp(-libc::EINVAL);
}
let crosvm_client = state.clients.get(&CROS_GUEST_ID).unwrap();
if msg.mem_size >= crosvm_client.mem_size {
error!("Oversized guest {:?}", msg);
return error_prepare_vm_resp(-libc::EINVAL);
}
// Just panic on overflow - 2^32 VMs should be enough.
let id = state.next_id;
state.next_id = state.next_id.checked_add(1).unwrap();
let path = get_control_server_path(id);
let server = match UnixServerTransport::new(&path) {
Ok(server) => server,
Err(e) => {
error!("failed to create server for {}: {:?}", id, e);
return error_prepare_vm_resp(-libc::EIO);
}
};
state.starting_vm_state = Some(StartingVmState {
id,
server,
mem_size: msg.mem_size,
init_mem_size: msg.init_mem_size,
reserved_mem: 0,
client: None,
failed: false,
});
0
}
};
let required_mem = msg.init_mem_size + calculate_extra_bytes(msg.mem_size);
let crosvm_client = state.clients.get_mut(&CROS_GUEST_ID).unwrap();
let new_reserved =
adjust_balloon(crosvm_client, (required_mem - already_reserved) as i64) as u64;
// starting_vm_state cannot be None here
let vm_state = state.starting_vm_state.as_mut().unwrap();
vm_state.reserved_mem += new_reserved;
PrepareVmResp {
res: if vm_state.reserved_mem == required_mem {
0
} else {
-libc::ENOMEM
},
id: vm_state.id,
shortfall: required_mem - vm_state.reserved_mem,
}
}
fn finish_add_vm(&mut self, FinishAddVmMsg { id }: &FinishAddVmMsg) -> SimpleResp {
let mut state = self.state.borrow_mut();
let pending_id = match state.starting_vm_state.as_ref() {
Some(vm_state) => {
if vm_state.failed {
error!("pending failed vm in finish_add_vm {}", vm_state.id);
return SimpleResp { res: -libc::EINVAL };
}
Some(vm_state.id)
}
None => None,
};
if Some(*id) != pending_id {
error!("id mismatch in finish_add_vm {} {:?}", id, pending_id);
return SimpleResp { res: -libc::EINVAL };
}
// starting_vm_state cannot be None here
let vm_state = state.starting_vm_state.as_mut().unwrap();
let balloon_size = (vm_state.mem_size - vm_state.init_mem_size) as i64;
let required_mem = vm_state.init_mem_size + calculate_extra_bytes(vm_state.mem_size);
let client = if required_mem == vm_state.reserved_mem {
match vm_state.server.accept_with_timeout(Duration::from_secs(10)) {
Ok(client) => {
let mut client = CrosVmClient {
client,
mem_size: vm_state.mem_size,
balloon_size: 0,
};
if adjust_balloon(&mut client, balloon_size) != balloon_size {
error!("Failed to inflate new client balloon");
Err((-libc::ENOMEM, Some(client.client)))
} else {
Ok(client)
}
}
Err(msg) => {
error!("Failed to connect to vm: {:?}", msg);
Err((-libc::EIO, None))
}
}
} else {
error!(
"Mismatch memory for added VM: required={} reserved={}",
required_mem, vm_state.reserved_mem
);
Err((-libc::ENOMEM, None))
};
match client {
Ok(client) => {
let vm_state = state.starting_vm_state.take().unwrap();
state.clients.insert(vm_state.id, client);
cleanup_control_server(vm_state.id, vm_state.server);
SimpleResp { res: 0 }
}
Err((res, client)) => {
vm_state.failed = true;
vm_state.client = client;
SimpleResp { res }
}
}
}
fn remove_vm(&mut self, RemoveVmMsg { id }: &RemoveVmMsg) -> SimpleResp {
if *id == CROS_GUEST_ID {
error!("Invalid id in remove_vm {}", CROS_GUEST_ID);
return SimpleResp { res: -libc::EINVAL };
}
let mut state = self.state.borrow_mut();
let released_mem = match state.clients.remove(id) {
None => {
let pending_id = state.starting_vm_state.as_ref().map(|state| state.id);
if Some(*id) != pending_id {
error!("Unknown id in remove_vm {}", *id);
return SimpleResp { res: -libc::EINVAL };
}
let vm_state = state.starting_vm_state.take().unwrap();
cleanup_control_server(vm_state.id, vm_state.server);
if let Some(client) = vm_state.client {
wait_for_hangup(&client);
}
vm_state.reserved_mem
}
Some(client) => {
wait_for_hangup(&client.client);
client.mem_size - client.balloon_size + calculate_extra_bytes(client.mem_size)
}
};
let crosvm_client = state.clients.get_mut(&CROS_GUEST_ID).unwrap();
let actual_released = adjust_balloon(crosvm_client, -(released_mem as i64));
if actual_released != -(released_mem as i64) {
error!(
"Failed to release sibling memory back to CrOS guest {} {}",
-(released_mem as i64),
actual_released
);
}
SimpleResp { res: 0 }
}
fn handle_message(&mut self) -> Result<()> {
let header = match read_obj::<MmsMessageHeader>(&mut self.connection)
.context("failed to read header")?
{
Some(header) => header,
None => return Ok(()),
};
let mut bytes = vec![0; header.len as usize];
self.connection
.r
.read_exact(&mut bytes)
.with_context(|| "failed to read ctl message")?;
let msg = match header.msg_type.try_into()? {
MessageId::GetBalloonStats => dispatch_message!(self, handle_balloon_stats, bytes),
MessageId::RebalanceMemory => dispatch_message!(self, handle_rebalance_memory, bytes),
MessageId::PrepareVm => dispatch_message!(self, prepare_vm, bytes),
MessageId::FinishAddVm => dispatch_message!(self, finish_add_vm, bytes),
MessageId::RemoveVm => dispatch_message!(self, remove_vm, bytes),
}?;
let mut resp_bytes = Vec::new();
let resp_header = MmsMessageHeader {
len: msg.len() as u32,
msg_type: header.msg_type,
};
resp_bytes.extend_from_slice(resp_header.as_slice());
resp_bytes.extend_from_slice(&msg);
self.connection
.w
.write_all(&resp_bytes)
.with_context(|| "failed writing response")?;
Ok(())
}
}
impl Debug for CtrlHandler {
    /// Formats as the bare struct name; the connection and shared state are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CtrlHandler");
        builder.finish()
    }
}
impl EventSource for CtrlHandler {
fn on_event(&mut self) -> StdResult<Option<Box<dyn Mutator>>, String> {
if let Err(msg) = self.handle_message() {
error!("Error processing message {}", msg);
};
Ok(None)
}
fn on_hangup(&mut self) -> std::result::Result<Option<Box<dyn Mutator>>, String> {
let mut state = self.state.borrow_mut();
// Wait for all old non-cros VMs to go away
let mut released_mem = 0;
info!("waiting for stale crosvm clients to exit");
for (id, client) in &state.clients {
if *id == CROS_GUEST_ID {
continue;
}
wait_for_hangup(&client.client);
released_mem += (client.mem_size - client.balloon_size
+ calculate_extra_bytes(client.mem_size)) as i64;
}
state.clients.retain(|k, _| *k == CROS_GUEST_ID);
if let Some(vm_state) = state.starting_vm_state.take() {
info!("cleaning up control server");
cleanup_control_server(vm_state.id, vm_state.server);
if let Some(client) = vm_state.client {
info!("cleaning up client");
wait_for_hangup(&client);
}
released_mem += vm_state.reserved_mem as i64;
}
info!("all stale crosvm clients exited");
let crosvm_client = state.clients.get_mut(&CROS_GUEST_ID).unwrap();
let actual_released = adjust_balloon(crosvm_client, -released_mem);
if actual_released != -released_mem {
error!(
"Failed to release sibling memory back to CrOS guest {} {}",
-released_mem, actual_released
);
}
state.cros_ctrl_connected = false;
if let Some(ctrl_file) = state.pending_ctrl_connections.pop_front() {
process_new_ctrl_connection(&mut state, self.state.clone(), ctrl_file)
} else {
Ok(None)
}
}
}
impl AsRawFd for CtrlHandler {
    /// Returns the fd of the control connection (presumably what the event
    /// multiplexer polls for this source).
    fn as_raw_fd(&self) -> RawFd {
        let connection = &self.connection;
        connection.as_raw_fd()
    }
}
/// Event source for the bridge socket, over which new CrOS control connections
/// arrive as file descriptors.
struct MmsBridge {
    /// Seqpacket socket connected to the MMS bridge.
    bridge: UnixSeqpacket,
    /// State shared with the control handlers.
    state: Rc<RefCell<MmsState>>,
}
impl MmsBridge {
fn new(bridge: UnixSeqpacket, state: Rc<RefCell<MmsState>>) -> Self {
MmsBridge { bridge, state }
}
}
/// Starts serving a newly received CrOS control connection.
///
/// Builds a [`Transport`] from `ctrl_file` and a duplicate of its fd, marks the
/// control connection as active, and returns a mutator that registers a
/// [`CtrlHandler`] for it.
///
/// Fails unless exactly one client (presumably the CrOS guest) is registered, or
/// if the fd cannot be duplicated.
fn process_new_ctrl_connection(
    state: &mut MmsState,
    state_rc: Rc<RefCell<MmsState>>,
    ctrl_file: File,
) -> StdResult<Option<Box<dyn Mutator>>, String> {
    if state.clients.len() != 1 {
        return Err("unknown crosvm clients".to_string());
    }
    let ctrl_file_dup = match ctrl_file.try_clone() {
        Ok(file) => file,
        Err(e) => return Err(format!("Clone error {:?}", e)),
    };
    let handler = CtrlHandler::new(Transport::from_files(ctrl_file, ctrl_file_dup), state_rc);
    state.cros_ctrl_connected = true;
    let mutator: Box<dyn Mutator> = Box::new(AddEventSourceMutator::from(handler));
    Ok(Some(mutator))
}
impl EventSource for MmsBridge {
    /// Receives a CrOS control connection fd from the bridge and starts serving it.
    ///
    /// If a control connection is still active, the new one is queued and started
    /// once the active connection hangs up.
    ///
    /// # Errors
    ///
    /// Fails if the bridge message cannot be received or carries no fd.
    fn on_event(&mut self) -> StdResult<Option<Box<dyn Mutator>>, String> {
        let mut state = self.state.borrow_mut();
        let fds = match self.bridge.recv_as_vec_with_fds() {
            Ok((_, fds)) => fds,
            Err(err) => {
                return Err(format!(
                    "Error receiving ctrl socket from bridge: {:?}",
                    err
                ))
            }
        };
        // Take ownership of every received fd up front, so any unexpected extra
        // fds are closed when `files` is dropped instead of leaking.
        let mut files: Vec<File> = fds
            .into_iter()
            // SAFETY: the fds were just received from the bridge and nothing else
            // owns them.
            .map(|fd| unsafe { File::from_raw_fd(fd) })
            .collect();
        if files.is_empty() {
            // Previously `fd[0]` would panic here and take the whole service down.
            return Err("Bridge message did not carry a ctrl socket".to_string());
        }
        let ctrl_file = files.remove(0);
        if !files.is_empty() {
            warn!("Ignoring {} extra fds from bridge", files.len());
        }
        drop(files);
        // Although the previous connection should generally be torn down before the
        // the new connection, it's possible that the teardown gets delayed. In particular,
        // we need to handle the case where the executor processes the hangup and new
        // connection in a single iteration - when that happens, the new connection is
        // processed before the hangup.
        if state.cros_ctrl_connected {
            state.pending_ctrl_connections.push_back(ctrl_file);
            warn!(
                "Duplicate control connection. Pending count is {}",
                state.pending_ctrl_connections.len()
            );
            return Ok(None);
        }
        process_new_ctrl_connection(&mut state, self.state.clone(), ctrl_file)
    }
}
impl AsRawFd for MmsBridge {
    /// Returns the fd of the bridge socket (presumably what the event multiplexer
    /// polls for this source).
    fn as_raw_fd(&self) -> RawFd {
        let bridge = &self.bridge;
        bridge.as_raw_fd()
    }
}
impl Debug for MmsBridge {
    /// Formats as the bare struct name; the socket and shared state are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MmsBridge");
        builder.finish()
    }
}
fn main() {
if let Err(e) = syslog::init() {
eprintln!("Failed to initialize syslog: {}", e);
return;
}
info!("starting ManaTEE memory service: {}", BUILD_TIMESTAMP);
let args: Vec<String> = env::args().collect();
if args.len() != 3 {
error!("Usage: manatee_memory_service <CrOS guest memory in MiB> <MMS bridge socket path>");
return;
}
let cros_mem = match args[1].parse::<u64>() {
Ok(cros_mem) => match cros_mem.checked_mul(1024 * 1024) {
Some(cros_mem) => cros_mem,
None => {
error!("Cros memory size overflow: {}", cros_mem);
return;
}
},
Err(e) => {
error!("Error parsing cros memory size: {:?}", e);
return;
}
};
let bridge = match UnixSeqpacket::connect(PathBuf::from(&args[2])) {
Ok(bridge) => bridge,
Err(e) => {
error!("Error connecting to MMS bridge {:?}", e);
return;
}
};
let crosvm_server = match UnixServerTransport::new(&get_control_server_path(CROS_GUEST_ID)) {
Ok(server) => server,
Err(e) => {
error!("Failed to start cros guest server {:?}", e);
return;
}
};
let crosvm_client = match crosvm_server.accept_with_timeout(Duration::MAX) {
Ok(client) => client,
Err(e) => {
error!("Failed to connect to cros guest balloon {:?}", e);
return;
}
};
cleanup_control_server(CROS_GUEST_ID, crosvm_server);
let mut clients = BTreeMap::new();
clients.insert(
CROS_GUEST_ID,
CrosVmClient {
client: crosvm_client,
mem_size: cros_mem,
balloon_size: 0,
},
);
let state = Rc::new(RefCell::new(MmsState {
cros_ctrl_connected: false,
pending_ctrl_connections: VecDeque::new(),
clients,
starting_vm_state: None,
next_id: CROS_GUEST_ID + 1,
}));
let mms_bridge = MmsBridge::new(bridge, state);
let mut ctx = EventMultiplexer::new().unwrap();
ctx.add_event(Box::new(mms_bridge)).unwrap();
while !ctx.is_empty() {
if let Err(e) = ctx.run_once() {
error!("{}", e);
};
}
info!("ManaTEE memory service exiting");
}