blob: 85e8cd80317a09a7e37c28671335dc6144f711d5 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Handles the communication abstraction for sirenia. Used both for
//! communication between dugong and trichechus and between TEEs and
//! trichechus.
pub mod persistence;
pub mod tee_api;
pub mod trichechus;
use std::array::TryFromSliceError;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::fmt;
use std::fmt::Debug;
use std::io;
use std::io::BufWriter;
use std::io::Read;
use std::io::Write;
use std::ops::Deref;
use std::result::Result as StdResult;
use flexbuffers::FlexbufferSerializer;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error as ThisError;
use crate::sys::eagain_is_ok;
pub const LENGTH_BYTE_SIZE: usize = 4;
#[derive(Debug, ThisError)]
pub enum Error {
#[error("failed to read: {0}")]
Read(#[source] io::Error),
#[error("no data to read from socket")]
EmptyRead,
#[error("failed to write: {0}")]
Write(#[source] io::Error),
#[error("error deserializing: {0}")]
Deserialize(#[source] flexbuffers::DeserializationError),
#[error("error serializing: {0}")]
Serialize(#[source] flexbuffers::SerializationError),
}
/// The result of an operation in this crate.
pub type Result<T> = StdResult<T, Error>;
#[derive(Debug)]
pub struct NonBlockingMessageReader {
read_buffer: Vec<u8>,
read_size: usize,
read_target: Option<usize>,
}
impl NonBlockingMessageReader {
pub fn partial_read(&self) -> bool {
self.read_size != 0 || self.read_target.is_some()
}
pub fn remaining(&self) -> Option<usize> {
self.read_target.map(|a| a - self.read_size)
}
pub fn clear(&mut self) {
self.read_buffer.clear();
self.read_buffer.resize(LENGTH_BYTE_SIZE, 0u8);
self.read_size = 0;
self.read_target = None;
}
fn try_read<R: Read>(&mut self, r: &mut R) -> Result<()> {
if let Some(size) =
eagain_is_ok(r.read(&mut self.read_buffer.as_mut_slice()[self.read_size..]))
.map_err(Error::Read)?
{
if size == 0 {
return Err(Error::EmptyRead);
}
self.read_size += size;
}
Ok(())
}
pub fn read_message<'de, R: Read, D: Deserialize<'de>>(
&'de mut self,
r: &mut R,
) -> Result<Option<D>> {
if self.read_target.is_none() {
if self.read_size == 0 {
// Perform a clear since self.read_buffer might have been borrowed.
self.read_buffer.clear();
self.read_buffer.resize(LENGTH_BYTE_SIZE, 0u8);
}
// Read the length of the serialized message first.
debug_assert_eq!(self.read_buffer.len(), LENGTH_BYTE_SIZE);
self.try_read(r)?;
if self.read_size < LENGTH_BYTE_SIZE {
return Ok(None);
}
let target =
u32::from_be_bytes(self.read_buffer.as_slice().try_into().unwrap()) as usize;
self.read_size = 0;
self.read_target = Some(target);
self.read_buffer.clear();
self.read_buffer.resize(target, 0);
}
// Read the serialized message.
let message_size = self.read_target.unwrap();
if message_size == 0 {
return Err(Error::EmptyRead);
}
debug_assert_eq!(self.read_buffer.len(), message_size);
self.try_read(r)?;
if self.read_size < message_size {
return Ok(None);
}
// Partial clear since self.read_buffer will be borrowed.
self.read_size = 0;
self.read_target = None;
let ret = flexbuffers::from_slice(self.read_buffer.as_slice())
.map(Some)
.map_err(Error::Deserialize);
ret
}
}
impl Default for NonBlockingMessageReader {
fn default() -> Self {
NonBlockingMessageReader {
read_buffer: vec![0u8; LENGTH_BYTE_SIZE],
read_size: 0,
read_target: None,
}
}
}
// Reads a message from the given Read. First reads a u32 that says the length
// of the serialized message, then reads the serialized message and
// deserializes it.
pub fn read_message<R: Read, D: DeserializeOwned>(r: &mut R) -> Result<D> {
// Read the length of the serialized message first
let mut buf = [0; LENGTH_BYTE_SIZE];
r.read_exact(&mut buf).map_err(Error::Read)?;
let message_size: u32 = u32::from_be_bytes(buf);
if message_size == 0 {
return Err(Error::EmptyRead);
}
// Read the actual serialized message
let mut ser_message = vec![0; message_size as usize];
r.read_exact(&mut ser_message).map_err(Error::Read)?;
flexbuffers::from_slice(&ser_message).map_err(Error::Deserialize)
}
// Writes the given message to the given Write. First writes the length of the
// serialized message then the serialized message itself.
pub fn write_message<W: Write, S: Serialize>(w: &mut W, m: S) -> Result<()> {
let mut writer = BufWriter::new(w);
// Serialize the message and calculate the length
let mut ser = FlexbufferSerializer::new();
m.serialize(&mut ser).map_err(Error::Serialize)?;
let len: u32 = ser.view().len() as u32;
let mut len_ser = FlexbufferSerializer::new();
len.serialize(&mut len_ser).map_err(Error::Serialize)?;
writer.write(&len.to_be_bytes()).map_err(Error::Write)?;
writer.write(ser.view()).map_err(Error::Write)?;
Ok(())
}
/// Types needed for trichechus RPC
pub const SHA256_SIZE: usize = 32;
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub struct Digest([u8; SHA256_SIZE]);
impl Deref for Digest {
type Target = [u8; SHA256_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<[u8; SHA256_SIZE]> for Digest {
fn from(value: [u8; SHA256_SIZE]) -> Self {
Digest(value)
}
}
impl TryFrom<&[u8]> for Digest {
type Error = TryFromSliceError;
fn try_from(value: &[u8]) -> StdResult<Self, Self::Error> {
Ok(Digest(value.try_into()?))
}
}
impl fmt::Display for Digest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.map(|x| format!("{:02x}", x)).join(""))
}
}
#[cfg(test)]
mod test {
use std::io::Cursor;
use std::io::Seek;
use std::mem::size_of;
use assert_matches::assert_matches;
use super::*;
const TEST_VALUE: u32 = 77;
#[test]
fn read_and_write_message() {
let mut channel = Cursor::new(Vec::<u8>::with_capacity(size_of::<u32>() * 2));
write_message(&mut channel, TEST_VALUE).unwrap();
channel.rewind().unwrap();
assert_matches!(read_message(&mut channel), Ok(TEST_VALUE));
channel.rewind().unwrap();
let mut reader = NonBlockingMessageReader::default();
let ret: Option<u32> = reader.read_message(&mut channel).unwrap();
assert_matches!(ret, Some(TEST_VALUE));
}
#[test]
fn readmessage_error_read() {
let mut channel = Cursor::new(Vec::<u8>::with_capacity(size_of::<u32>() * 2));
let ret: Result<u32> = read_message(&mut channel);
assert_matches!(ret, Err(Error::Read(_)));
}
#[test]
fn nonblockingmessagereader_error_emptyread() {
let mut channel = Cursor::new(Vec::<u8>::with_capacity(size_of::<u32>() * 2));
let mut reader = NonBlockingMessageReader::default();
let ret: Result<Option<u32>> = reader.read_message(&mut channel);
assert_matches!(ret, Err(Error::EmptyRead));
}
#[test]
fn nonblockingmessagereader_partial_read() {
let mut buf = Vec::<u8>::with_capacity(size_of::<u32>() * 2);
write_message(&mut Cursor::new(&mut buf), TEST_VALUE).unwrap();
let mut reader = NonBlockingMessageReader::default();
assert!(!reader.partial_read());
eprintln!("0 Reader: {:?}", &reader);
let mut offset = 0;
const TAKE: usize = 1;
let ret: Option<u32> = reader
.read_message(&mut Cursor::new(&mut buf[offset..(offset + TAKE)]))
.unwrap();
assert_matches!(ret, None);
eprintln!("1 Reader: {:?}", &reader);
offset += TAKE;
assert_eq!(reader.read_size, offset);
assert!(reader.partial_read());
assert_matches!(reader.remaining(), None);
let remaining: usize = size_of::<u32>() - TAKE;
let ret: Result<Option<u32>> =
reader.read_message(&mut Cursor::new(&mut buf[offset..(offset + remaining)]));
// This is an empty read because the end of the slice is reached trying to read the message.
assert_matches!(ret, Err(Error::EmptyRead));
eprintln!("2 Reader: {:?}", &reader);
offset += remaining;
assert!(reader.partial_read());
assert_eq!(reader.read_size, offset - size_of::<u32>());
assert_matches!(reader.remaining(), Some(v) if v == buf.len() - size_of::<u32>());
let ret: Option<u32> = reader
.read_message(&mut Cursor::new(&mut buf[offset..(offset + TAKE)]))
.unwrap();
assert_matches!(ret, None);
eprintln!("3 Reader: {:?}", &reader);
offset += TAKE;
assert!(reader.partial_read());
assert_matches!(reader.remaining(), Some(v) if v == buf.len() - size_of::<u32>() - TAKE);
let ret: Option<u32> = reader
.read_message(&mut Cursor::new(&mut buf[offset..]))
.unwrap();
assert_matches!(ret, Some(TEST_VALUE));
eprintln!("4 Reader: {:?}", &reader);
assert!(!reader.partial_read());
assert_matches!(reader.remaining(), None);
}
}