sirenia: Automatically use serde_bytes in RPC parameters.
Have sirenia-rpc-macros label Vec<u8> request parameters with
#[serde(with = "serde_bytes")] so they are serialized as byte fields.
By default serde treats Vec<u8> like any other sequence and serializes it
element by element, which is inefficient in many formats. serde_bytes marks
these fields as byte strings so serializers can encode them compactly.
Follow-up work is needed to apply serde_bytes to Vec<u8> values returned
by RPCs; response types are not labeled yet.
BUG=b:233694042,b:239114202
TEST=cargo test --workspace --all-features
Change-Id: I944cf5921d98e1fc083bc6fceed2084c117f69ed
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/3778699
Reviewed-by: Paramjit Oberoi <psoberoi@google.com>
Tested-by: Allen Webb <allenwebb@google.com>
Commit-Queue: Allen Webb <allenwebb@google.com>
diff --git a/sirenia/libsirenia/sirenia-rpc-macros/Cargo.toml b/sirenia/libsirenia/sirenia-rpc-macros/Cargo.toml
index e5fefed..a00a312 100644
--- a/sirenia/libsirenia/sirenia-rpc-macros/Cargo.toml
+++ b/sirenia/libsirenia/sirenia-rpc-macros/Cargo.toml
@@ -15,5 +15,7 @@
[dev-dependencies]
anyhow = "1.0.0"
assert_matches = "1.5.0"
+flexbuffers = "2.0.0"
libsirenia = { path = ".." } # provided by ebuild
serde = { version = "1.0.114", features = ["derive"] }
+serde_bytes = "0.10.0"
diff --git a/sirenia/libsirenia/sirenia-rpc-macros/src/lib.rs b/sirenia/libsirenia/sirenia-rpc-macros/src/lib.rs
index 8be57c7..a6a8930 100644
--- a/sirenia/libsirenia/sirenia-rpc-macros/src/lib.rs
+++ b/sirenia/libsirenia/sirenia-rpc-macros/src/lib.rs
@@ -287,7 +287,14 @@
fn get_request_enum_item(&self) -> TokenStream {
let enum_name = &self.enum_name;
- let args = &self.request_args;
+ let args: Vec<FnArg> = self.request_args.iter().map(|f|
+ // Vec<u8> does not always send efficiently so label it as a byte field.
+ if matches!(f, syn::FnArg::Typed(t) if t.ty == parse_quote!(Vec<u8>)) {
+ parse_quote!(#[serde(with = "serde_bytes")] #f)
+ } else {
+ f.clone()
+ }
+ ).collect();
quote! {
#enum_name{#(#args),*}
}
@@ -741,6 +748,7 @@
#[error()]
fn ping(&mut self, value: usize) -> Result<usize, E>;
fn terminate(&mut self) -> Result<(), E>;
+ fn echo(&mut self, value: Vec<u8>) -> Result<Vec<u8>, E>;
}
);
@@ -750,6 +758,7 @@
pub trait Nested<E>: other::Test<E> {
fn ping(&mut self, value: usize) -> Result<usize, E>;
fn terminate(&mut self) -> Result<(), E>;
+ fn echo(&mut self, value: Vec<u8>) -> Result<Vec<u8>, E>;
}
#[derive(::std::fmt::Debug, ::serde::Deserialize, ::serde::Serialize)]
@@ -757,6 +766,10 @@
Test(other::TestRequest),
Ping { value: usize },
Terminate {},
+ Echo {
+ #[serde(with = "serde_bytes")]
+ value: Vec<u8>
+ },
}
#[derive(::std::fmt::Debug, ::serde::Deserialize, ::serde::Serialize)]
@@ -764,6 +777,7 @@
Test(other::TestResponse),
Ping(usize),
Terminate(::std::result::Result<(), ()>),
+ Echo(::std::result::Result<Vec<u8>, ()>),
}
pub struct NestedClient {
@@ -815,6 +829,17 @@
Err(::libsirenia::rpc::Error::ResponseMismatch.into())
}
}
+
+ fn echo(&mut self, value: Vec<u8>) ->
+ ::std::result::Result<Vec<u8>, ::anyhow::Error> {
+ if let NestedResponse::Echo(response) =
+ NestedRpc::rpc(self, NestedRequest::Echo { value },)?
+ {
+ response.map_err(|x| x.into())
+ } else {
+ Err(::libsirenia::rpc::Error::ResponseMismatch.into())
+ }
+ }
}
impl<R: NestedRpc> other::TestRpc for R {
@@ -868,6 +893,17 @@
}
}
}
+ NestedRequest::Echo { value } => {
+ match self.echo(value) {
+ Ok(x) => Ok(NestedResponse::Echo(Ok(x))),
+ Err(err) => {
+ match err.downcast::<()>() {
+ Ok(err) => Ok(NestedResponse::Echo(Err(err))),
+ Err(err) => Err(err),
+ }
+ }
+ }
+ }
}
}
}
diff --git a/sirenia/libsirenia/sirenia-rpc-macros/tests/nested_rpc.rs b/sirenia/libsirenia/sirenia-rpc-macros/tests/nested_rpc.rs
index 1faaacd..13c0d8c 100644
--- a/sirenia/libsirenia/sirenia-rpc-macros/tests/nested_rpc.rs
+++ b/sirenia/libsirenia/sirenia-rpc-macros/tests/nested_rpc.rs
@@ -29,6 +29,7 @@
#[sirenia_rpc]
pub trait NestedRpc<E>: TestRpc<E> + OtherRpc<E> {
fn terminate(&mut self) -> Result<(), E>;
+ fn echo_bytes(&mut self, bytes: Vec<u8>) -> Result<Vec<u8>, E>;
}
#[derive(Clone)]
@@ -58,6 +59,10 @@
fn terminate(&mut self) -> Result<(), anyhow::Error> {
Err(anyhow!("Done"))
}
+
+ fn echo_bytes(&mut self, bytes: Vec<u8>) -> Result<Vec<u8>, anyhow::Error> {
+ Ok(bytes)
+ }
}
#[test]
diff --git a/sirenia/libsirenia/sirenia-rpc-macros/tests/smoke.rs b/sirenia/libsirenia/sirenia-rpc-macros/tests/smoke.rs
index 09062bc..d0cc9d0 100644
--- a/sirenia/libsirenia/sirenia-rpc-macros/tests/smoke.rs
+++ b/sirenia/libsirenia/sirenia-rpc-macros/tests/smoke.rs
@@ -9,6 +9,7 @@
use std::fmt;
use std::fmt::Display;
use std::fmt::Formatter;
+use std::iter;
use std::thread::spawn;
use anyhow::anyhow;
@@ -43,6 +44,7 @@
fn checked_add(&mut self, addend_a: i32, addend_b: i32) -> Result<Option<i32>, E>;
#[error()]
fn terminate(&mut self) -> Result<(), E>;
+ fn echo_bytes(&mut self, bytes: Vec<u8>) -> Result<Vec<u8>, E>;
}
#[derive(Clone)]
@@ -68,6 +70,10 @@
fn terminate(&mut self) -> Result<(), anyhow::Error> {
Err(anyhow!("Done"))
}
+
+ fn echo_bytes(&mut self, bytes: Vec<u8>) -> Result<Vec<u8>, anyhow::Error> {
+ Ok(bytes)
+ }
}
#[test]
@@ -108,3 +114,11 @@
client_thread.join().unwrap();
}
+
+#[test]
+fn byte_field_size_test() {
+ let bytes: Vec<u8> = iter::repeat(77u8).take(1024).collect();
+ let bytes_len = bytes.len();
+ let request = TestRpcRequest::EchoBytes { bytes };
+ assert!(flexbuffers::to_vec(request).unwrap().len() < bytes_len * 12 / 10);
+}