From ca264abae81be61a6a64608cee322e6f66e4b848 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 26 Mar 2022 10:53:33 +0100 Subject: [PATCH 1/4] Improve test protocol Signed-off-by: trivernis --- tests/utils/protocol.rs | 88 ++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 46 deletions(-) diff --git a/tests/utils/protocol.rs b/tests/utils/protocol.rs index 192aec51..329645de 100644 --- a/tests/utils/protocol.rs +++ b/tests/utils/protocol.rs @@ -104,7 +104,7 @@ pub struct TestProtocolStream { impl TestProtocolStream { /// Read from the receiver and remaining buffer async fn read_from_receiver( - buf: &mut ReadBuf<'static>, + buf: &mut ReadBuf<'_>, receiver: Arc>>>, remaining_buf: Arc>>, ) { @@ -133,7 +133,7 @@ impl TestProtocolStream { /// Read from the remaining buffer returning a boolean if the /// read buffer has been filled async fn read_from_remaining_buffer( - buf: &mut ReadBuf<'static>, + buf: &mut ReadBuf<'_>, remaining_buf: &mut Vec, ) -> bool { if remaining_buf.len() < buf.capacity() { @@ -180,70 +180,66 @@ impl AsyncProtocolStream for TestProtocolStream { impl AsyncRead for TestProtocolStream { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - unsafe { - // we need a mutable reference to access the inner future - let stream = self.get_unchecked_mut(); + if self.future.is_none() { + // we need to change the lifetime to be able to use the read buffer in the read future + let buf: &mut ReadBuf<'static> = unsafe { + // SAFETY: idk tbh + mem::transmute(buf) + }; + let receiver = Arc::clone(&self.receiver); + let remaining_buf = Arc::clone(&self.remaining_buf); - if stream.future.is_none() { - // we need to change the lifetime to be able to use the read buffer in the read future - let buf: &mut ReadBuf<'static> = mem::transmute(buf); - let receiver = Arc::clone(&stream.receiver); - let remaining_buf = Arc::clone(&stream.remaining_buf); - - let future = TestProtocolStream::read_from_receiver(buf, receiver, remaining_buf); - stream.future = Some(Box::pin(future)); - } - if let Some(future) = &mut stream.future { - match future.as_mut().poll(cx) { - Poll::Ready(_) => { - stream.future = None; - Poll::Ready(Ok(())) - } - Poll::Pending => Poll::Pending, + let future = TestProtocolStream::read_from_receiver(buf, receiver, remaining_buf); + self.future = Some(Box::pin(future)); + } + if let Some(future) = &mut self.future { + match future.as_mut().poll(cx) { + Poll::Ready(_) => { + self.future = None; + Poll::Ready(Ok(())) } - } else { - Poll::Pending + Poll::Pending => Poll::Pending, } + } else { + Poll::Pending } } } +impl Unpin for TestProtocolStream {} + impl AsyncWrite for TestProtocolStream { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let write_len = buf.len(); - unsafe { - // we need a mutable reference to access the inner future - let stream = self.get_unchecked_mut(); - if stream.future.is_none() { - // we take ownership here so that we don't need to change lifetimes here - let buf = buf.to_vec(); - let sender = stream.sender.clone(); + if self.future.is_none() { + // we take ownership here so that we don't need to change lifetimes here + let buf = buf.to_vec(); + let sender = self.sender.clone(); - let future = async move { - sender.send(buf).await.unwrap(); - }; - stream.future = Some(Box::pin(future)); - } - if let Some(future) = &mut stream.future { - match future.as_mut().poll(cx) { - Poll::Ready(_) => { - stream.future = None; - Poll::Ready(Ok(write_len)) - } - Poll::Pending => Poll::Pending, + let future = async move { + sender.send(buf).await.unwrap(); + }; + self.future = Some(Box::pin(future)); + } + if let Some(future) = &mut self.future { + match future.as_mut().poll(cx) { + Poll::Ready(_) => { + self.future = None; + Poll::Ready(Ok(write_len)) } - } else { - Poll::Pending + Poll::Pending => Poll::Pending, } + } else { + Poll::Pending } } From ac471d296e007118c64ec7b3aea364adf0e5dde0 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 26 Mar 2022 12:12:48 +0100 Subject: [PATCH 2/4] Change internal bytes representation to Bytes object from bytes crate Signed-off-by: trivernis --- Cargo.lock | 3 +- Cargo.toml | 3 +- benches/deserialization_benchmark.rs | 15 ++-- benches/serialization_benchmark.rs | 9 ++- src/events/error_event.rs | 16 ++-- src/events/event.rs | 67 +++++++++-------- src/events/event_handler.rs | 7 +- src/events/payload.rs | 73 +++++++++++-------- src/events/payload_serializer/mod.rs | 5 +- .../payload_serializer/serialize_bincode.rs | 5 +- .../payload_serializer/serialize_json.rs | 5 +- .../payload_serializer/serialize_postcard.rs | 5 +- .../payload_serializer/serialize_rmp.rs | 5 +- src/ipc/stream_emitter/event_metadata.rs | 2 +- src/lib.rs | 7 ++ tests/test_event_streams.rs | 9 ++- tests/test_events_with_payload.rs | 19 ++--- tests/test_raw_events.rs | 5 +- 18 files changed, 148 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 49e57c63..c21e5985 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,11 +93,12 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bromine" -version = "0.19.0" +version = "0.20.0" dependencies = [ "async-trait", "bincode", "byteorder", + "bytes", "criterion", "crossbeam-utils", "futures", diff --git a/Cargo.toml b/Cargo.toml index a197f693..e9e82cd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bromine" -version = "0.19.0" +version = "0.20.0" authors = ["trivernis "] edition = "2018" readme = "README.md" @@ -31,6 +31,7 @@ trait-bound-typemap = "0.3.3" rmp-serde = { version = "1.0.0", optional = true } bincode = { version = "1.3.3", optional = true } serde_json = { version = "1.0.79", optional = true } +bytes = "1.1.0" [dependencies.serde] optional = true diff --git a/benches/deserialization_benchmark.rs b/benches/deserialization_benchmark.rs index c79889f1..76c2ee02 100644 --- a/benches/deserialization_benchmark.rs +++ b/benches/deserialization_benchmark.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use criterion::{black_box, BenchmarkId, Throughput}; use criterion::{criterion_group, criterion_main}; use criterion::{BatchSize, Criterion}; @@ -9,17 +10,21 @@ use tokio::runtime::Runtime; pub const EVENT_NAME: &str = "bench_event"; fn create_event_bytes_reader(data_size: usize) -> Cursor> { - let bytes = Event::initiator(None, EVENT_NAME.to_string(), vec![0u8; data_size]) - .into_bytes() - .unwrap(); - Cursor::new(bytes) + let bytes = Event::initiator( + None, + EVENT_NAME.to_string(), + Bytes::from(vec![0u8; data_size]), + ) + .into_bytes() + .unwrap(); + Cursor::new(bytes.to_vec()) } fn event_deserialization(c: &mut Criterion) { let runtime = Runtime::new().unwrap(); let mut group = c.benchmark_group("event_deserialization"); - for size in (0..10) + for size in (0..16) .step_by(2) .map(|i| 1024 * 2u32.pow(i as u32) as usize) { diff --git a/benches/serialization_benchmark.rs b/benches/serialization_benchmark.rs index 2d25643b..f99a31f0 100644 --- a/benches/serialization_benchmark.rs +++ b/benches/serialization_benchmark.rs @@ -1,4 +1,5 @@ use bromine::event::Event; +use bytes::Bytes; use criterion::{ black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput, }; @@ -6,13 +7,17 @@ use criterion::{ pub const EVENT_NAME: &str = "bench_event"; fn create_event(data_size: usize) -> Event { - Event::initiator(None, EVENT_NAME.to_string(), vec![0u8; data_size]) + Event::initiator( + None, + EVENT_NAME.to_string(), + Bytes::from(vec![0u8; data_size]), + ) } fn event_serialization(c: &mut Criterion) { let mut group = c.benchmark_group("event_serialization"); - for size in (0..10) + for size in (0..16) .step_by(2) .map(|i| 1024 * 2u32.pow(i as u32) as usize) { diff --git a/src/events/error_event.rs b/src/events/error_event.rs index bff310a1..90c78e34 100644 --- a/src/events/error_event.rs +++ b/src/events/error_event.rs @@ -3,6 +3,7 @@ use crate::error::Result; use crate::payload::{FromPayload, IntoPayload}; use crate::prelude::{IPCError, IPCResult}; use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, Bytes, BytesMut}; use std::error::Error; use std::fmt::{Display, Formatter}; use std::io::Read; @@ -29,14 +30,13 @@ impl Display for ErrorEventData { } impl IntoPayload for ErrorEventData { - fn into_payload(self, _: &Context) -> IPCResult> { - let mut buf = Vec::new(); - buf.append(&mut self.code.to_be_bytes().to_vec()); - let message_len = self.message.len() as u32; - buf.append(&mut message_len.to_be_bytes().to_vec()); - buf.append(&mut self.message.into_bytes()); - - Ok(buf) + fn into_payload(self, _: &Context) -> IPCResult { + let mut buf = BytesMut::new(); + buf.put_u16(self.code); + buf.put_u32(self.message.len() as u32); + buf.put(Bytes::from(self.message)); + + Ok(buf.freeze()) } } diff --git a/src/events/event.rs b/src/events/event.rs index c6e33ddc..59475f31 100644 --- a/src/events/event.rs +++ b/src/events/event.rs @@ -2,6 +2,7 @@ use crate::error::{Error, Result}; use crate::events::generate_event_id; use crate::events::payload::FromPayload; use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, Bytes, BytesMut}; use num_enum::{IntoPrimitive, TryFromPrimitive}; use std::convert::TryFrom; use std::fmt::Debug; @@ -16,7 +17,7 @@ pub const FORMAT_VERSION: [u8; 3] = [0, 9, 0]; #[derive(Debug)] pub struct Event { header: EventHeader, - data: Vec, + data: Bytes, } #[derive(Debug)] @@ -41,21 +42,21 @@ impl Event { /// Creates a new event that acts as an initiator for further response events #[tracing::instrument(level = "trace", skip(data))] #[inline] - pub fn initiator(namespace: Option, name: String, data: Vec) -> Self { + pub fn initiator(namespace: Option, name: String, data: Bytes) -> Self { Self::new(namespace, name, data, None, EventType::Initiator) } /// Creates a new event that is a response to a previous event #[tracing::instrument(level = "trace", skip(data))] #[inline] - pub fn response(namespace: Option, name: String, data: Vec, ref_id: u64) -> Self { + pub fn response(namespace: Option, name: String, data: Bytes, ref_id: u64) -> Self { Self::new(namespace, name, data, Some(ref_id), EventType::Response) } /// Creates a new error event as a response to a previous event #[tracing::instrument(level = "trace", skip(data))] #[inline] - pub fn error(namespace: Option, name: String, data: Vec, ref_id: u64) -> Self { + pub fn error(namespace: Option, name: String, data: Bytes, ref_id: u64) -> Self { Self::new(namespace, name, data, Some(ref_id), EventType::Error) } @@ -63,7 +64,7 @@ impl Event { /// and might contain a final response payload #[tracing::instrument(level = "trace", skip(data))] #[inline] - pub fn end(namespace: Option, name: String, data: Vec, ref_id: u64) -> Self { + pub fn end(namespace: Option, name: String, data: Bytes, ref_id: u64) -> Self { Self::new(namespace, name, data, Some(ref_id), EventType::Response) } @@ -72,7 +73,7 @@ impl Event { pub(crate) fn new( namespace: Option, name: String, - data: Vec, + data: Bytes, ref_id: Option, event_type: EventType, ) -> Self { @@ -145,57 +146,59 @@ impl Event { // additional header fields can be added a the end because when reading they will just be ignored let header: EventHeader = EventHeader::from_read(&mut Cursor::new(header_bytes))?; - let mut data = vec![0u8; data_length as usize]; - reader.read_exact(&mut data).await?; - let event = Event { header, data }; + let mut buf = vec![0u8; data_length as usize]; + reader.read_exact(&mut buf).await?; + let event = Event { + header, + data: Bytes::from(buf), + }; Ok(event) } /// Encodes the event into bytes #[tracing::instrument(level = "trace", skip(self))] - pub fn into_bytes(mut self) -> Result> { - let mut header_bytes = self.header.into_bytes(); + pub fn into_bytes(self) -> Result { + let header_bytes = self.header.into_bytes(); let header_length = header_bytes.len() as u16; let data_length = self.data.len(); let total_length = header_length as u64 + data_length as u64; tracing::trace!(total_length, header_length, data_length); - let mut buf = Vec::with_capacity(total_length as usize); - buf.append(&mut total_length.to_be_bytes().to_vec()); - buf.append(&mut header_length.to_be_bytes().to_vec()); - buf.append(&mut header_bytes); - buf.append(&mut self.data); + let mut buf = BytesMut::with_capacity(total_length as usize); + buf.put_u64(total_length); + buf.put_u16(header_length); + buf.put(header_bytes); + buf.put(self.data); - Ok(buf) + Ok(buf.freeze()) } } impl EventHeader { /// Serializes the event header into bytes - pub fn into_bytes(self) -> Vec { - let mut buf = FORMAT_VERSION.to_vec(); - buf.append(&mut self.id.to_be_bytes().to_vec()); - buf.push(self.event_type.into()); + pub fn into_bytes(self) -> Bytes { + let mut buf = BytesMut::with_capacity(256); + buf.put_slice(&FORMAT_VERSION); + buf.put_u64(self.id); + buf.put_u8(u8::from(self.event_type)); if let Some(ref_id) = self.ref_id { - buf.push(0xFF); - buf.append(&mut ref_id.to_be_bytes().to_vec()); + buf.put_u8(0xFF); + buf.put_u64(ref_id); } else { - buf.push(0x00); + buf.put_u8(0x00); } if let Some(namespace) = self.namespace { - let namespace_len = namespace.len() as u16; - buf.append(&mut namespace_len.to_be_bytes().to_vec()); - buf.append(&mut namespace.into_bytes()); + buf.put_u16(namespace.len() as u16); + buf.put(Bytes::from(namespace)); } else { - buf.append(&mut 0u16.to_be_bytes().to_vec()); + buf.put_u16(0); } - let name_len = self.name.len() as u16; - buf.append(&mut name_len.to_be_bytes().to_vec()); - buf.append(&mut self.name.into_bytes()); + buf.put_u16(self.name.len() as u16); + buf.put(Bytes::from(self.name)); - buf + buf.freeze() } /// Parses an event header from an async reader diff --git a/src/events/event_handler.rs b/src/events/event_handler.rs index 8b87f603..16c09fe4 100644 --- a/src/events/event_handler.rs +++ b/src/events/event_handler.rs @@ -2,13 +2,14 @@ use crate::error::Result; use crate::events::event::Event; use crate::ipc::context::Context; use crate::payload::{BytePayload, IntoPayload}; +use bytes::Bytes; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; -pub struct Response(Vec); +pub struct Response(Bytes); impl Response { /// Creates a new response with a given payload @@ -20,11 +21,11 @@ impl Response { /// Creates an empty response pub fn empty() -> Self { - Self(vec![]) + Self(Bytes::new()) } pub(crate) fn into_byte_payload(self) -> BytePayload { - BytePayload::new(self.0) + BytePayload::from(self.0) } } diff --git a/src/events/payload.rs b/src/events/payload.rs index 2d7dbd21..bad2d26f 100644 --- a/src/events/payload.rs +++ b/src/events/payload.rs @@ -1,5 +1,6 @@ use crate::prelude::IPCResult; use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, Bytes, BytesMut}; use std::io::Read; #[cfg(feature = "serialize")] @@ -7,18 +8,13 @@ pub use super::payload_serializer::*; /// Trait that serializes a type into bytes and can fail pub trait TryIntoBytes { - fn try_into_bytes(self) -> IPCResult>; -} - -/// Trait that serializes a type into bytes and never fails -pub trait IntoBytes { - fn into_bytes(self) -> Vec; + fn try_into_bytes(self) -> IPCResult; } /// Trait to convert event data into sending bytes /// It is implemented for all types that implement Serialize pub trait IntoPayload { - fn into_payload(self, ctx: &Context) -> IPCResult>; + fn into_payload(self, ctx: &Context) -> IPCResult; } /// Trait to get the event data from receiving bytes. @@ -31,25 +27,39 @@ pub trait FromPayload: Sized { /// serializing them #[derive(Clone)] pub struct BytePayload { - bytes: Vec, + bytes: Bytes, } impl BytePayload { #[inline] pub fn new(bytes: Vec) -> Self { - Self { bytes } + Self { + bytes: Bytes::from(bytes), + } } - /// Returns the bytes of the payload + /// Returns the bytes as a `Vec` of the payload #[inline] pub fn into_inner(self) -> Vec { + self.bytes.to_vec() + } + + /// Returns the bytes struct of the payload + #[inline] + pub fn into_bytes(self) -> Bytes { self.bytes } } +impl From for BytePayload { + fn from(bytes: Bytes) -> Self { + Self { bytes } + } +} + impl IntoPayload for BytePayload { #[inline] - fn into_payload(self, _: &Context) -> IPCResult> { + fn into_payload(self, _: &Context) -> IPCResult { Ok(self.bytes) } } @@ -87,20 +97,18 @@ impl TandemPayload { } impl IntoPayload for TandemPayload { - fn into_payload(self, ctx: &Context) -> IPCResult> { - let mut p1_bytes = self.load1.into_payload(&ctx)?; - let mut p2_bytes = self.load2.into_payload(&ctx)?; + fn into_payload(self, ctx: &Context) -> IPCResult { + let p1_bytes = self.load1.into_payload(&ctx)?; + let p2_bytes = self.load2.into_payload(&ctx)?; - let mut p1_length_bytes = (p1_bytes.len() as u64).to_be_bytes().to_vec(); - let mut p2_length_bytes = (p2_bytes.len() as u64).to_be_bytes().to_vec(); + let mut bytes = BytesMut::with_capacity(p1_bytes.len() + p2_bytes.len() + 16); - let mut bytes = Vec::new(); - bytes.append(&mut p1_length_bytes); - bytes.append(&mut p1_bytes); - bytes.append(&mut p2_length_bytes); - bytes.append(&mut p2_bytes); + bytes.put_u64(p1_bytes.len() as u64); + bytes.put(p1_bytes); + bytes.put_u64(p2_bytes.len() as u64); + bytes.put(p2_bytes); - Ok(bytes) + Ok(bytes.freeze()) } } @@ -123,8 +131,8 @@ impl FromPayload for TandemPayload { #[cfg(not(feature = "serialize"))] impl IntoPayload for () { - fn into_payload(self, _: &Context) -> IPCResult> { - Ok(vec![]) + fn into_payload(self, _: &Context) -> IPCResult { + Ok(Bytes::new()) } } @@ -135,6 +143,7 @@ mod serde_payload { use crate::payload::{FromPayload, TryIntoBytes}; use crate::prelude::{IPCResult, IntoPayload}; use byteorder::ReadBytesExt; + use bytes::{BufMut, Bytes, BytesMut}; use serde::de::DeserializeOwned; use serde::Serialize; use std::io::Read; @@ -168,20 +177,20 @@ mod serde_payload { } impl TryIntoBytes for SerdePayload { - fn try_into_bytes(self) -> IPCResult> { - let mut buf = Vec::new(); - let mut data_bytes = self.serializer.serialize(self.data)?; + fn try_into_bytes(self) -> IPCResult { + let mut buf = BytesMut::new(); + let data_bytes = self.serializer.serialize(self.data)?; let format_id = self.serializer as u8; - buf.push(format_id); - buf.append(&mut data_bytes); + buf.put_u8(format_id); + buf.put(data_bytes); - Ok(buf) + Ok(buf.freeze()) } } impl IntoPayload for SerdePayload { #[inline] - fn into_payload(self, _: &Context) -> IPCResult> { + fn into_payload(self, _: &Context) -> IPCResult { self.try_into_bytes() } } @@ -198,7 +207,7 @@ mod serde_payload { impl IntoPayload for T { #[inline] - fn into_payload(self, ctx: &Context) -> IPCResult> { + fn into_payload(self, ctx: &Context) -> IPCResult { ctx.create_serde_payload(self).into_payload(&ctx) } } diff --git a/src/events/payload_serializer/mod.rs b/src/events/payload_serializer/mod.rs index 7dab9643..c96923c6 100644 --- a/src/events/payload_serializer/mod.rs +++ b/src/events/payload_serializer/mod.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use serde::de::DeserializeOwned; use serde::Serialize; use std::io::Read; @@ -49,7 +50,7 @@ pub enum SerializationError { UnknownFormat(usize), } -#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Ord, PartialOrd, Eq, PartialEq)] pub enum DynamicSerializer { Messagepack, Bincode, @@ -109,7 +110,7 @@ impl DynamicSerializer { } } - pub fn serialize(&self, data: T) -> SerializationResult> { + pub fn serialize(&self, data: T) -> SerializationResult { match self { #[cfg(feature = "serialize_rmp")] DynamicSerializer::Messagepack => serialize_rmp::serialize(data), diff --git a/src/events/payload_serializer/serialize_bincode.rs b/src/events/payload_serializer/serialize_bincode.rs index 9567327b..afbc75b9 100644 --- a/src/events/payload_serializer/serialize_bincode.rs +++ b/src/events/payload_serializer/serialize_bincode.rs @@ -1,13 +1,14 @@ use crate::payload::SerializationResult; +use bytes::Bytes; use serde::de::DeserializeOwned; use serde::Serialize; use std::io::Read; #[inline] -pub fn serialize(data: T) -> SerializationResult> { +pub fn serialize(data: T) -> SerializationResult { let bytes = bincode::serialize(&data)?; - Ok(bytes) + Ok(Bytes::from(bytes)) } #[inline] diff --git a/src/events/payload_serializer/serialize_json.rs b/src/events/payload_serializer/serialize_json.rs index c35af059..5f41495c 100644 --- a/src/events/payload_serializer/serialize_json.rs +++ b/src/events/payload_serializer/serialize_json.rs @@ -1,13 +1,14 @@ use crate::payload::SerializationResult; +use bytes::Bytes; use serde::de::DeserializeOwned; use serde::Serialize; use std::io::Read; #[inline] -pub fn serialize(data: T) -> SerializationResult> { +pub fn serialize(data: T) -> SerializationResult { let bytes = serde_json::to_vec(&data)?; - Ok(bytes) + Ok(Bytes::from(bytes)) } #[inline] diff --git a/src/events/payload_serializer/serialize_postcard.rs b/src/events/payload_serializer/serialize_postcard.rs index 75237289..ab50d081 100644 --- a/src/events/payload_serializer/serialize_postcard.rs +++ b/src/events/payload_serializer/serialize_postcard.rs @@ -1,13 +1,14 @@ use crate::payload::SerializationResult; +use bytes::Bytes; use serde::de::DeserializeOwned; use serde::Serialize; use std::io::Read; #[inline] -pub fn serialize(data: T) -> SerializationResult> { +pub fn serialize(data: T) -> SerializationResult { let bytes = postcard::to_allocvec(&data)?.to_vec(); - Ok(bytes) + Ok(Bytes::from(bytes)) } #[inline] diff --git a/src/events/payload_serializer/serialize_rmp.rs b/src/events/payload_serializer/serialize_rmp.rs index ec9dfcc6..13e081af 100644 --- a/src/events/payload_serializer/serialize_rmp.rs +++ b/src/events/payload_serializer/serialize_rmp.rs @@ -1,13 +1,14 @@ use crate::payload::SerializationResult; +use bytes::Bytes; use serde::de::DeserializeOwned; use serde::Serialize; use std::io::Read; #[inline] -pub fn serialize(data: T) -> SerializationResult> { +pub fn serialize(data: T) -> SerializationResult { let bytes = rmp_serde::to_vec(&data)?; - Ok(bytes) + Ok(Bytes::from(bytes)) } #[inline] diff --git a/src/ipc/stream_emitter/event_metadata.rs b/src/ipc/stream_emitter/event_metadata.rs index c9a834f3..4ed9a57c 100644 --- a/src/ipc/stream_emitter/event_metadata.rs +++ b/src/ipc/stream_emitter/event_metadata.rs @@ -37,7 +37,7 @@ impl EventMetadata

{ let payload = self.payload.take().ok_or(Error::InvalidState)?; let res_id = self.res_id.take().ok_or(Error::InvalidState)?; let event_type = self.event_type.take().ok_or(Error::InvalidState)?; - let payload_bytes = payload.into_payload(&ctx)?; + let payload_bytes = payload.into_payload(&ctx)?.into(); let event = Event::new( namespace, diff --git a/src/lib.rs b/src/lib.rs index d15bc06e..b3cc6fe8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,6 +119,12 @@ mod macros; mod namespaces; pub mod protocol; +/// Reexported for usage in payload implementations +pub use bytes; + +/// Reexported for sharing data in context +pub use trait_bound_typemap; + pub use events::error_event; pub use events::event; pub use events::event_handler; @@ -146,4 +152,5 @@ pub mod prelude { pub use crate::payload::*; pub use crate::protocol::*; pub use crate::*; + pub use trait_bound_typemap::TypeMap; } diff --git a/tests/test_event_streams.rs b/tests/test_event_streams.rs index 5a8f2edb..dfdbbd51 100644 --- a/tests/test_event_streams.rs +++ b/tests/test_event_streams.rs @@ -3,6 +3,7 @@ use crate::utils::protocol::TestProtocolListener; use crate::utils::{get_free_port, start_server_and_client}; use bromine::prelude::*; use byteorder::ReadBytesExt; +use bytes::Bytes; use futures::StreamExt; use std::io::Read; use std::time::Duration; @@ -66,16 +67,16 @@ async fn handle_stream_event(ctx: &Context, event: Event) -> IPCResult pub struct EmptyPayload; impl IntoPayload for EmptyPayload { - fn into_payload(self, _: &Context) -> IPCResult> { - Ok(vec![]) + fn into_payload(self, _: &Context) -> IPCResult { + Ok(Bytes::new()) } } pub struct NumberPayload(u8); impl IntoPayload for NumberPayload { - fn into_payload(self, _: &Context) -> IPCResult> { - Ok(vec![self.0]) + fn into_payload(self, _: &Context) -> IPCResult { + Ok(Bytes::from(vec![self.0])) } } diff --git a/tests/test_events_with_payload.rs b/tests/test_events_with_payload.rs index 7c14ad7f..1e2be84c 100644 --- a/tests/test_events_with_payload.rs +++ b/tests/test_events_with_payload.rs @@ -91,6 +91,7 @@ mod payload_impl { use bromine::payload::{FromPayload, IntoPayload}; use bromine::prelude::IPCResult; use byteorder::{BigEndian, ReadBytesExt}; + use bytes::{BufMut, Bytes, BytesMut}; use std::io::Read; pub struct SimplePayload { @@ -99,17 +100,13 @@ mod payload_impl { } impl IntoPayload for SimplePayload { - fn into_payload(self, _: &Context) -> IPCResult> { - let mut buf = Vec::new(); - let string_length = self.string.len() as u16; - let string_length_bytes = string_length.to_be_bytes(); - buf.append(&mut string_length_bytes.to_vec()); - let mut string_bytes = self.string.into_bytes(); - buf.append(&mut string_bytes); - let num_bytes = self.number.to_be_bytes(); - buf.append(&mut num_bytes.to_vec()); - - Ok(buf) + fn into_payload(self, _: &Context) -> IPCResult { + let mut buf = BytesMut::new(); + buf.put_u16(self.string.len() as u16); + buf.put(Bytes::from(self.string)); + buf.put_u32(self.number); + + Ok(buf.freeze()) } } diff --git a/tests/test_raw_events.rs b/tests/test_raw_events.rs index c690f869..dfa9e7d3 100644 --- a/tests/test_raw_events.rs +++ b/tests/test_raw_events.rs @@ -2,6 +2,7 @@ mod utils; use crate::utils::start_server_and_client; use bromine::prelude::*; +use bytes::Bytes; use std::time::Duration; use utils::call_counter::*; use utils::get_free_port; @@ -132,7 +133,7 @@ async fn handle_error_event(ctx: &Context, event: Event) -> IPCResult pub struct EmptyPayload; impl IntoPayload for EmptyPayload { - fn into_payload(self, _: &Context) -> IPCResult> { - Ok(vec![]) + fn into_payload(self, _: &Context) -> IPCResult { + Ok(Bytes::new()) } } From fe7dc970083cc619f83fd988d302a4ef576d1b19 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 26 Mar 2022 18:39:15 +0100 Subject: [PATCH 3/4] Add encrypted wrapper protocol implementation Signed-off-by: trivernis --- Cargo.lock | 299 ++++++++++++++++++++++- Cargo.toml | 5 + src/ipc/builder.rs | 34 ++- src/ipc/client.rs | 5 +- src/ipc/server.rs | 5 +- src/ipc/stream_emitter/emit_metadata.rs | 1 + src/lib.rs | 1 + src/protocol/encrypted/crypt_handling.rs | 158 ++++++++++++ src/protocol/encrypted/io_impl.rs | 220 +++++++++++++++++ src/protocol/encrypted/mod.rs | 152 ++++++++++++ src/protocol/encrypted/protocol_impl.rs | 55 +++++ src/protocol/mod.rs | 21 +- src/protocol/tcp.rs | 12 +- src/protocol/unix_socket.rs | 9 +- tests/test_encryption.rs | 101 ++++++++ tests/utils/call_counter.rs | 1 + tests/utils/protocol.rs | 6 +- 17 files changed, 1065 insertions(+), 20 deletions(-) create mode 100644 src/protocol/encrypted/crypt_handling.rs create mode 100644 src/protocol/encrypted/io_impl.rs create mode 100644 src/protocol/encrypted/mod.rs create mode 100644 src/protocol/encrypted/protocol_impl.rs create mode 100644 tests/test_encryption.rs diff --git a/Cargo.lock b/Cargo.lock index c21e5985..b5e2f2b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aead" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" +dependencies = [ + "generic-array", +] + [[package]] name = "aho-corasick" version = "0.7.18" @@ -91,6 +100,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +dependencies = [ + "generic-array", +] + [[package]] name = "bromine" version = "0.20.0" @@ -99,6 +117,7 @@ dependencies = [ "bincode", "byteorder", "bytes", + "chacha20poly1305", "criterion", "crossbeam-utils", "futures", @@ -106,13 +125,17 @@ dependencies = [ "lazy_static", "num_enum", "postcard", + "rand", + "rand_core 0.6.3", "rmp-serde", "serde", "serde_json", + "sha2", "thiserror", "tokio", "tracing", "trait-bound-typemap", + "x25519-dalek", ] [[package]] @@ -160,6 +183,40 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b72a433d0cf2aef113ba70f62634c56fddb0f244e6377185c56a7cadbd8f91" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", + "zeroize", +] + +[[package]] +name = "chacha20poly1305" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b84ed6d1d5f7aa9bdde921a5090e0ca4d934d250ea3b402a5fab3a994e28a2a" +dependencies = [ + "aead", + "chacha20", + "cipher", + "poly1305", + "zeroize", +] + +[[package]] +name = "cipher" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +dependencies = [ + "generic-array", +] + [[package]] name = "clap" version = "2.34.0" @@ -183,6 +240,15 @@ dependencies = [ "volatile-register", ] +[[package]] +name = "cpufeatures" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.3.5" @@ -277,6 +343,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crypto-common" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.1.6" @@ -299,6 +375,38 @@ dependencies = [ "memchr", ] +[[package]] +name = "curve25519-dalek" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f9d052967f590a76e62eb387bd0bbb1b000182c3cefe5364db6b7211651bc0" +dependencies = [ + "byteorder", + "digest 0.9.0", + "rand_core 0.5.1", + "subtle", + "zeroize", +] + +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + +[[package]] +name = "digest" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "either" version = "1.6.1" @@ -404,6 +512,38 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.10.2+wasi-snapshot-preview1", +] + [[package]] name = "half" version = "1.8.2" @@ -526,7 +666,7 @@ dependencies = [ "log", "miow", "ntapi", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "winapi", ] @@ -615,6 +755,12 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "pin-project-lite" version = "0.2.8" @@ -655,6 +801,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly1305" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" +dependencies = [ + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "postcard" version = "0.7.3" @@ -672,6 +829,12 @@ version = "0.1.5-pre" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c68cb38ed13fd7bc9dd5db8f165b7c8d9c1a315104083a2b10f11354c2af97f" +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + [[package]] name = "proc-macro-crate" version = "1.1.3" @@ -700,6 +863,45 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core 0.6.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.3", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", +] + +[[package]] +name = "rand_core" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +dependencies = [ + "getrandom 0.2.5", +] + [[package]] name = "rayon" version = "1.5.1" @@ -891,6 +1093,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "slab" version = "0.4.5" @@ -922,6 +1135,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + [[package]] name = "syn" version = "1.0.86" @@ -933,6 +1152,18 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -1050,6 +1281,12 @@ dependencies = [ "multi-trait-object", ] +[[package]] +name = "typenum" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" + [[package]] name = "unicode-width" version = "0.1.9" @@ -1062,12 +1299,28 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +[[package]] +name = "universal-hash" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f214e8f697e925001e66ec2c6e37a4ef93f0f78c2eed7814394e10c62025b05" +dependencies = [ + "generic-array", + "subtle", +] + [[package]] name = "vcell" version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77439c1b53d2303b20d9459b1ade71a83c716e3f9c34f3228c00e6f185d6c002" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "void" version = "1.0.2" @@ -1094,6 +1347,18 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + +[[package]] +name = "wasi" +version = "0.10.2+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1194,3 +1459,35 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "x25519-dalek" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2392b6b94a576b4e2bf3c5b2757d63f10ada8020a2e4d08ac849ebcf6ea8e077" +dependencies = [ + "curve25519-dalek", + "rand_core 0.5.1", + "zeroize", +] + +[[package]] +name = "zeroize" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4756f7db3f7b5574938c3eb1c117038b8e07f95ee6718c0efad4ac21508f1efd" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f8f187641dad4f680d25c4bfc4225b418165984179f26ca76ec4fb6441d3a17" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] diff --git a/Cargo.toml b/Cargo.toml index e9e82cd2..1b3ebdf4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,11 @@ rmp-serde = { version = "1.0.0", optional = true } bincode = { version = "1.3.3", optional = true } serde_json = { version = "1.0.79", optional = true } bytes = "1.1.0" +chacha20poly1305 = "0.9.0" +x25519-dalek = "1.2.0" +rand = "0.8.5" +rand_core = "0.6.3" +sha2 = "0.10.2" [dependencies.serde] optional = true diff --git a/src/ipc/builder.rs b/src/ipc/builder.rs index 4183f05e..83cba3ce 100644 --- a/src/ipc/builder.rs +++ b/src/ipc/builder.rs @@ -11,6 +11,7 @@ use crate::namespaces::builder::NamespaceBuilder; use crate::namespaces::namespace::Namespace; #[cfg(feature = "serialize")] use crate::payload::DynamicSerializer; +use crate::prelude::AsyncProtocolStream; use crate::protocol::AsyncStreamProtocolListener; use std::collections::HashMap; use std::future::Future; @@ -64,6 +65,8 @@ pub struct IPCBuilder { timeout: Duration, #[cfg(feature = "serialize")] default_serializer: DynamicSerializer, + listener_options: L::ListenerOptions, + stream_options: ::StreamOptions, } impl IPCBuilder @@ -89,6 +92,8 @@ where timeout: Duration::from_secs(60), #[cfg(feature = "serialize")] default_serializer: DynamicSerializer::first_available(), + listener_options: Default::default(), + stream_options: Default::default(), } } @@ -163,6 +168,23 @@ where self } + /// Sets the options for the given protocol listener + pub fn server_options(mut self, options: L::ListenerOptions) -> Self { + self.listener_options = options; + + self + } + + /// Sets the options for the given protocol stream + pub fn client_options( + mut self, + options: ::StreamOptions, + ) -> Self { + self.stream_options = options; + + self + } + /// Builds an ipc server #[tracing::instrument(skip(self))] pub async fn build_server(self) -> Result<()> { @@ -176,7 +198,9 @@ where #[cfg(feature = "serialize")] default_serializer: self.default_serializer, }; - server.start::(self.address.unwrap()).await?; + server + .start::(self.address.unwrap(), self.listener_options) + .await?; Ok(()) } @@ -198,7 +222,9 @@ where default_serializer: self.default_serializer, }; - let ctx = client.connect::(self.address.unwrap()).await?; + let ctx = client + .connect::(self.address.unwrap(), self.stream_options.clone()) + .await?; Ok(ctx) } @@ -230,7 +256,9 @@ where default_serializer: self.default_serializer.clone(), }; - let ctx = client.connect::(address.clone()).await?; + let ctx = client + .connect::(address.clone(), self.stream_options.clone()) + .await?; contexts.push(ctx); } diff --git a/src/ipc/client.rs b/src/ipc/client.rs index ae7f8600..aa490a73 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -33,12 +33,13 @@ pub struct IPCClient { impl IPCClient { /// Connects to a given address and returns an emitter for events to that address. /// Invoked by [IPCBuilder::build_client](crate::builder::IPCBuilder::build_client) - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self, options))] pub async fn connect( self, address: S::AddressType, + options: S::StreamOptions, ) -> Result { - let stream = S::protocol_connect(address).await?; + let stream = S::protocol_connect(address, options).await?; let (read_half, write_half) = stream.protocol_into_split(); let emitter = StreamEmitter::new::(write_half); diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 9a71a66c..c88d238a 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -31,12 +31,13 @@ pub struct IPCServer { impl IPCServer { /// Starts the IPC Server. /// Invoked by [IPCBuilder::build_server](crate::builder::IPCBuilder::build_server) - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self, options))] pub async fn start( self, address: L::AddressType, + options: L::ListenerOptions, ) -> Result<()> { - let listener = L::protocol_bind(address.clone()).await?; + let listener = L::protocol_bind(address.clone(), options).await?; let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); let data = Arc::new(RwLock::new(self.data)); diff --git a/src/ipc/stream_emitter/emit_metadata.rs b/src/ipc/stream_emitter/emit_metadata.rs index 57d77fab..b5a3e85b 100644 --- a/src/ipc/stream_emitter/emit_metadata.rs +++ b/src/ipc/stream_emitter/emit_metadata.rs @@ -91,6 +91,7 @@ impl Future for EmitMetadata

{ let event_bytes = event.into_bytes()?; let mut stream = stream.lock().await; stream.deref_mut().write_all(&event_bytes[..]).await?; + stream.deref_mut().flush().await?; tracing::trace!(bytes_len = event_bytes.len()); Ok(event_id) diff --git a/src/lib.rs b/src/lib.rs index b3cc6fe8..6eba0549 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,7 @@ //! # } //! ``` +extern crate core; #[cfg(all( feature = "serialize", not(any( diff --git a/src/protocol/encrypted/crypt_handling.rs b/src/protocol/encrypted/crypt_handling.rs new file mode 100644 index 00000000..4bfe35c9 --- /dev/null +++ b/src/protocol/encrypted/crypt_handling.rs @@ -0,0 +1,158 @@ +use crate::prelude::encrypted::EncryptedStream; +use crate::prelude::IPCResult; +use crate::protocol::AsyncProtocolStream; +use bytes::Bytes; +use chacha20poly1305::aead::{Aead, NewAead}; +use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce}; +use rand::thread_rng; +use rand_core::RngCore; +use sha2::{Digest, Sha224, Sha256}; +use std::io; +use std::io::ErrorKind; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use x25519_dalek::{PublicKey, StaticSecret}; + +/// A structure used for encryption. +/// It holds the cipher initialized with the given key +/// and two counters for both encryption and decryption +/// count which are used to keep track of the nonce. +#[derive(Clone)] +pub(crate) struct CipherBox { + cipher: ChaCha20Poly1305, + en_count: Arc, + de_count: Arc, +} + +impl CipherBox { + pub fn new(key: Bytes) -> Self { + let key = Key::from_slice(&key[..]); + let cipher = ChaCha20Poly1305::new(key); + + Self { + cipher, + en_count: Arc::new(AtomicU64::new(0)), + de_count: Arc::new(AtomicU64::new(0)), + } + } + + /// Encrypts the given message + pub fn encrypt(&self, data: Bytes) -> io::Result { + self.cipher + .encrypt(&self.en_nonce(), &data[..]) + .map(Bytes::from) + .map_err(|_| io::Error::from(ErrorKind::InvalidData)) + } + + /// Decrypts the given message + pub fn decrypt(&self, data: Bytes) -> io::Result { + self.cipher + .decrypt(&self.de_nonce(), &data[..]) + .map(Bytes::from) + .map_err(|_| io::Error::from(ErrorKind::InvalidData)) + } + + /// Updates the stored key. + /// This must be done simultaneously on server and client side + /// to keep track of the nonce + pub fn update_key(&mut self, key: Bytes) { + let key = Key::from_slice(&key[..]); + self.cipher = ChaCha20Poly1305::new(key); + self.reset_counters(); + } + + /// Resets the nonce counters. + /// This must be done simultaneously on server and client side. + pub fn reset_counters(&mut self) { + self.de_count.store(0, Ordering::SeqCst); + self.en_count.store(0, Ordering::SeqCst); + } + + fn en_nonce(&self) -> Nonce { + let count = self.en_count.fetch_add(1, Ordering::SeqCst); + nonce_from_number(count) + } + + fn de_nonce(&self) -> Nonce { + let count = self.de_count.fetch_add(1, Ordering::SeqCst); + nonce_from_number(count) + } +} + +/// Generates a nonce from a given number +/// the nonce is passed through sha224 for pseudo-randomness +fn nonce_from_number(number: u64) -> Nonce { + let number_bytes: [u8; 8] = number.to_be_bytes(); + let sha_bytes = Sha224::digest(&number_bytes).to_vec(); + let mut nonce_bytes = [0u8; 12]; + nonce_bytes.copy_from_slice(&sha_bytes[..12]); + + nonce_bytes.into() +} + +impl EncryptedStream { + /// Does a server-client key exchange. + /// 1. The server receives the public key of the client + /// 2. The server sends its own public key + /// 3. The server creates an intermediary encrypted connection + /// 4. The server generates a new secret + /// 5. The server sends the secret to the client + /// 6. The connection is upgraded with the new shared key + pub async fn from_server_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult { + let other_pub = receive_public_key(&mut inner).await?; + send_public_key(&mut inner, &secret).await?; + let shared_secret = secret.diffie_hellman(&other_pub); + let mut stream = Self::new(inner, shared_secret); + let permanent_secret = generate_secret(); + stream.write_all(&permanent_secret).await?; + stream.flush().await?; + stream.update_key(permanent_secret.into()); + + Ok(stream) + } + + /// Does a client-server key exchange. + /// 1. The client sends its public key to the server + /// 2. The client receives the servers public key + /// 3. The client creates an intermediary encrypted connection + /// 4. The client receives the new key from the server + /// 5. The connection is upgraded with the new shared key + pub async fn from_client_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult { + send_public_key(&mut inner, &secret).await?; + let other_pub = receive_public_key(&mut inner).await?; + let shared_secret = secret.diffie_hellman(&other_pub); + let mut stream = Self::new(inner, shared_secret); + let mut key_buf = vec![0u8; 32]; + stream.read_exact(&mut key_buf).await?; + stream.update_key(key_buf.into()); + + Ok(stream) + } +} + +async fn receive_public_key(stream: &mut T) -> IPCResult { + let mut pk_buf = [0u8; 32]; + stream.read_exact(&mut pk_buf).await?; + + Ok(PublicKey::from(pk_buf)) +} + +async fn send_public_key( + stream: &mut T, + secret: &StaticSecret, +) -> IPCResult<()> { + let own_pk = PublicKey::from(secret); + stream.write_all(own_pk.as_bytes()).await?; + stream.flush().await?; + + Ok(()) +} + +fn generate_secret() -> Vec { + let mut rng = thread_rng(); + let mut buf = vec![0u8; 32]; + rng.fill_bytes(&mut buf); + + Sha256::digest(&buf).to_vec() +} diff --git a/src/protocol/encrypted/io_impl.rs b/src/protocol/encrypted/io_impl.rs new file mode 100644 index 00000000..94360e38 --- /dev/null +++ b/src/protocol/encrypted/io_impl.rs @@ -0,0 +1,220 @@ +use crate::prelude::encrypted::{ + EncryptedPackage, EncryptedReadStream, EncryptedStream, EncryptedWriteStream, +}; +use crate::prelude::AsyncProtocolStream; +use crate::protocol::encrypted::crypt_handling::CipherBox; +use bytes::{Buf, BufMut, Bytes}; +use std::cmp::min; +use std::io; +use std::io::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; + +const WRITE_BUF_SIZE: usize = 1024; + +impl Unpin for EncryptedStream {} + +impl AsyncWrite for EncryptedStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.write_half).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.write_half).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.write_half).poll_shutdown(cx) + } +} + +impl AsyncRead for EncryptedStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.read_half).poll_read(cx, buf) + } +} + +impl Unpin for EncryptedReadStream {} + +impl AsyncRead for EncryptedReadStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.fut.is_none() { + let max_copy = min(buf.remaining(), self.remaining.len()); + let bytes = self.remaining.copy_to_bytes(max_copy); + buf.put_slice(&bytes); + + if buf.remaining() > 0 { + let mut reader = self.inner.take().unwrap(); + let cipher = self.cipher.take().unwrap(); + + self.fut = Some(Box::pin(async move { + let package = match EncryptedPackage::from_async_read(&mut reader).await { + Ok(p) => p, + Err(e) => { + return (Err(e), reader, cipher); + } + }; + match cipher.decrypt(package.into_inner()) { + Ok(bytes) => (Ok(bytes), reader, cipher), + Err(e) => (Err(e), reader, cipher), + } + })); + } + } + if self.fut.is_some() { + match self.fut.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, reader, cipher)) => { + self.inner = Some(reader); + self.cipher = Some(cipher); + match result { + Ok(bytes) => { + self.remaining.put(bytes); + let max_copy = min(self.remaining.len(), buf.remaining()); + let bytes = self.remaining.copy_to_bytes(max_copy); + self.fut = None; + buf.put_slice(&bytes); + + if buf.remaining() == 0 { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + Err(e) => Poll::Ready(Err(e)), + } + } + Poll::Pending => Poll::Pending, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl Unpin for EncryptedWriteStream {} + +impl AsyncWrite for EncryptedWriteStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.remaining() > 0 { + let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) }; + self.buffer.put(Bytes::from(buf)); + + if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE { + let buffer_len = self.buffer.len(); + let max_copy = min(u32::MAX as usize, buffer_len); + let plaintext = self.buffer.copy_to_bytes(max_copy); + let writer = self.inner.take().unwrap(); + let cipher = self.cipher.take().unwrap(); + + self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher))) + } + } + if self.fut_write.is_some() { + match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, writer, cipher)) => { + self.inner = Some(writer); + self.cipher = Some(cipher); + self.fut_write = None; + + Poll::Ready(result.map(|_| buf.len())) + } + Poll::Pending => Poll::Pending, + } + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let buffer_len = self.buffer.len(); + + if !self.buffer.is_empty() && self.fut_flush.is_none() { + let max_copy = min(u32::MAX as usize, buffer_len); + let plaintext = self.buffer.copy_to_bytes(max_copy); + let writer = self.inner.take().unwrap(); + let cipher = self.cipher.take().unwrap(); + + self.fut_flush = Some(Box::pin(async move { + let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await; + if result.is_err() { + return (result, writer, cipher); + } + if let Err(e) = writer.flush().await { + (Err(e), writer, cipher) + } else { + (Ok(()), writer, cipher) + } + })) + } + match self.fut_flush.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, writer, cipher)) => { + self.inner = Some(writer); + self.cipher = Some(cipher); + self.fut_flush = None; + + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.fut_shutdown.is_none() { + match self.as_mut().poll_flush(cx) { + Poll::Ready(result) => match result { + Ok(_) => { + let mut writer = self.inner.take().unwrap(); + self.fut_shutdown = Some(Box::pin(async move { writer.shutdown().await })); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + }, + Poll::Pending => Poll::Pending, + } + } else { + match self.fut_shutdown.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready(result) => { + self.fut_shutdown = None; + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + } +} + +async fn write_bytes( + bytes: Bytes, + mut writer: T, + cipher: CipherBox, +) -> (io::Result<()>, T, CipherBox) { + let encrypted_bytes = match cipher.encrypt(bytes) { + Ok(b) => b, + Err(e) => { + return (Err(e), writer, cipher); + } + }; + let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes(); + if let Err(e) = writer.write_all(&package_bytes[..]).await { + return (Err(e), writer, cipher); + } + + (Ok(()), writer, cipher) +} diff --git a/src/protocol/encrypted/mod.rs b/src/protocol/encrypted/mod.rs new file mode 100644 index 00000000..7126dcdc --- /dev/null +++ b/src/protocol/encrypted/mod.rs @@ -0,0 +1,152 @@ +mod crypt_handling; +mod io_impl; +mod protocol_impl; + +use bytes::{BufMut, Bytes, BytesMut}; +pub use io_impl::*; +pub use protocol_impl::*; +use rand_core::RngCore; +use std::future::Future; +use std::io; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter}; +use x25519_dalek::{SharedSecret, StaticSecret}; + +use crate::prelude::encrypted::crypt_handling::CipherBox; +use crate::prelude::{AsyncProtocolStream, AsyncStreamProtocolListener}; + +pub type OptionalFuture = Option + Send + Sync>>>; + +#[derive(Clone)] +pub struct EncryptionOptions { + pub inner_options: T, + pub secret: StaticSecret, +} + +impl Default for EncryptionOptions { + fn default() -> Self { + let mut rng = rand::thread_rng(); + let mut secret = [0u8; 32]; + rng.fill_bytes(&mut secret); + + Self { + secret: StaticSecret::from(secret), + inner_options: T::default(), + } + } +} + +pub struct EncryptedListener { + inner: T, + secret: StaticSecret, +} + +impl EncryptedListener { + pub fn new(inner: T, secret: StaticSecret) -> Self { + Self { inner, secret } + } +} + +pub struct EncryptedStream { + read_half: EncryptedReadStream, + write_half: EncryptedWriteStream, +} + +impl EncryptedStream { + pub fn new(inner: T, secret: SharedSecret) -> Self { + let cipher_box = CipherBox::new(Bytes::from(secret.to_bytes().to_vec())); + let (read, write) = inner.protocol_into_split(); + let read_half = EncryptedReadStream::new(read, cipher_box.clone()); + let write_half = EncryptedWriteStream::new(write, cipher_box); + + Self { + read_half, + write_half, + } + } + + pub fn update_key(&mut self, key: Bytes) { + self.write_half + .cipher + .as_mut() + .unwrap() + .update_key(key.clone()); + self.read_half + .cipher + .as_mut() + .unwrap() + .update_key(key.clone()); + } +} + +pub struct EncryptedReadStream { + inner: Option>, + fut: OptionalFuture<(io::Result, BufReader, CipherBox)>, + remaining: BytesMut, + cipher: Option, +} + +impl EncryptedReadStream { + pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { + Self { + inner: Some(BufReader::new(inner)), + fut: None, + remaining: BytesMut::new(), + cipher: Some(cipher), + } + } +} + +pub struct EncryptedWriteStream { + inner: Option>, + cipher: Option, + buffer: BytesMut, + fut_write: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, + fut_flush: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, + fut_shutdown: OptionalFuture>, +} + +impl EncryptedWriteStream { + pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { + Self { + inner: Some(BufWriter::new(inner)), + cipher: Some(cipher), + buffer: BytesMut::with_capacity(1024), + fut_write: None, + fut_flush: None, + fut_shutdown: None, + } + } +} + +pub(crate) struct EncryptedPackage { + bytes: Bytes, +} + +impl EncryptedPackage { + pub fn new(bytes: Bytes) -> Self { + Self { bytes } + } + + pub fn into_bytes(self) -> Bytes { + let mut buf = BytesMut::with_capacity(4 + self.bytes.len()); + buf.put_u32(self.bytes.len() as u32); + buf.put(self.bytes); + + buf.freeze() + } + + pub async fn from_async_read(reader: &mut R) -> io::Result { + let length = reader.read_u32().await?; + let mut bytes_buf = vec![0u8; length as usize]; + reader.read_exact(&mut bytes_buf).await?; + + Ok(Self { + bytes: Bytes::from(bytes_buf), + }) + } + + pub fn into_inner(self) -> Bytes { + self.bytes + } +} diff --git a/src/protocol/encrypted/protocol_impl.rs b/src/protocol/encrypted/protocol_impl.rs new file mode 100644 index 00000000..53d664d2 --- /dev/null +++ b/src/protocol/encrypted/protocol_impl.rs @@ -0,0 +1,55 @@ +use crate::error::Result; +use crate::prelude::encrypted::{EncryptedReadStream, EncryptedWriteStream}; +use crate::prelude::{AsyncProtocolStreamSplit, IPCResult}; +use crate::protocol::encrypted::{EncryptedListener, EncryptedStream, EncryptionOptions}; +use crate::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener}; +use async_trait::async_trait; + +#[async_trait] +impl AsyncStreamProtocolListener for EncryptedListener { + type AddressType = T::AddressType; + type RemoteAddressType = T::RemoteAddressType; + type Stream = EncryptedStream; + type ListenerOptions = EncryptionOptions; + + async fn protocol_bind( + address: Self::AddressType, + options: Self::ListenerOptions, + ) -> IPCResult { + let inner = T::protocol_bind(address, options.inner_options).await?; + + Ok(EncryptedListener::new(inner, options.secret)) + } + + async fn protocol_accept(&self) -> IPCResult<(Self::Stream, Self::RemoteAddressType)> { + let (inner_stream, remote_addr) = self.inner.protocol_accept().await?; + let stream = + Self::Stream::from_server_key_exchange(inner_stream, self.secret.clone()).await?; + + Ok((stream, remote_addr)) + } +} + +#[async_trait] +impl AsyncProtocolStream for EncryptedStream { + type AddressType = T::AddressType; + type StreamOptions = EncryptionOptions; + + async fn protocol_connect( + address: Self::AddressType, + options: Self::StreamOptions, + ) -> Result { + let inner = T::protocol_connect(address, options.inner_options).await?; + EncryptedStream::from_client_key_exchange(inner, options.secret).await + } +} + +#[async_trait] +impl AsyncProtocolStreamSplit for EncryptedStream { + type OwnedSplitReadHalf = EncryptedReadStream; + type OwnedSplitWriteHalf = EncryptedWriteStream; + + fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) { + (self.read_half, self.write_half) + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 6fa9c8b2..2792fb6f 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,5 +1,6 @@ pub mod tcp; +pub mod encrypted; #[cfg(unix)] pub mod unix_socket; @@ -12,25 +13,33 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub trait AsyncStreamProtocolListener: Sized + Send + Sync { type AddressType: Clone + Debug + Send + Sync; type RemoteAddressType: Debug + Send + Sync; - type Stream: 'static + AsyncProtocolStream + Send + Sync; + type Stream: 'static + AsyncProtocolStream; + type ListenerOptions: Clone + Default + Send + Sync; - async fn protocol_bind(address: Self::AddressType) -> IPCResult; + async fn protocol_bind( + address: Self::AddressType, + options: Self::ListenerOptions, + ) -> IPCResult; async fn protocol_accept(&self) -> IPCResult<(Self::Stream, Self::RemoteAddressType)>; } pub trait AsyncProtocolStreamSplit { - type OwnedSplitReadHalf: AsyncRead + Send + Sync + Unpin; - type OwnedSplitWriteHalf: AsyncWrite + Send + Sync + Unpin; + type OwnedSplitReadHalf: 'static + AsyncRead + Send + Sync + Unpin; + type OwnedSplitWriteHalf: 'static + AsyncWrite + Send + Sync + Unpin; fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf); } #[async_trait] pub trait AsyncProtocolStream: - AsyncRead + AsyncWrite + Send + Sync + AsyncProtocolStreamSplit + Sized + AsyncRead + AsyncWrite + Send + Sync + AsyncProtocolStreamSplit + Sized + Unpin { type AddressType: Clone + Debug + Send + Sync; + type StreamOptions: Clone + Default + Send + Sync; - async fn protocol_connect(address: Self::AddressType) -> IPCResult; + async fn protocol_connect( + address: Self::AddressType, + options: Self::StreamOptions, + ) -> IPCResult; } diff --git a/src/protocol/tcp.rs b/src/protocol/tcp.rs index abe43d4f..f9369b4f 100644 --- a/src/protocol/tcp.rs +++ b/src/protocol/tcp.rs @@ -10,8 +10,12 @@ impl AsyncStreamProtocolListener for TcpListener { type AddressType = SocketAddr; type RemoteAddressType = SocketAddr; type Stream = TcpStream; + type ListenerOptions = (); - async fn protocol_bind(address: Self::AddressType) -> IPCResult { + async fn protocol_bind( + address: Self::AddressType, + _: Self::ListenerOptions, + ) -> IPCResult { let listener = TcpListener::bind(address).await?; Ok(listener) @@ -36,8 +40,12 @@ impl AsyncProtocolStreamSplit for TcpStream { #[async_trait] impl AsyncProtocolStream for TcpStream { type AddressType = SocketAddr; + type StreamOptions = (); - async fn protocol_connect(address: Self::AddressType) -> IPCResult { + async fn protocol_connect( + address: Self::AddressType, + _: Self::StreamOptions, + ) -> IPCResult { let stream = TcpStream::connect(address).await?; Ok(stream) diff --git a/src/protocol/unix_socket.rs b/src/protocol/unix_socket.rs index 91809083..2cbf16f0 100644 --- a/src/protocol/unix_socket.rs +++ b/src/protocol/unix_socket.rs @@ -13,8 +13,9 @@ impl AsyncStreamProtocolListener for UnixListener { type AddressType = PathBuf; type RemoteAddressType = SocketAddr; type Stream = UnixStream; + type ListenerOptions = (); - async fn protocol_bind(address: Self::AddressType) -> Result { + async fn protocol_bind(address: Self::AddressType, _: Self::ListenerOptions) -> Result { let listener = UnixListener::bind(address)?; Ok(listener) @@ -39,8 +40,12 @@ impl AsyncProtocolStreamSplit for UnixStream { #[async_trait] impl AsyncProtocolStream for UnixStream { type AddressType = PathBuf; + type StreamOptions = (); - async fn protocol_connect(address: Self::AddressType) -> IPCResult { + async fn protocol_connect( + address: Self::AddressType, + _: Self::StreamOptions, + ) -> IPCResult { let stream = UnixStream::connect(address).await?; stream .ready(Interest::READABLE | Interest::WRITABLE) diff --git a/tests/test_encryption.rs b/tests/test_encryption.rs new file mode 100644 index 00000000..3ab90988 --- /dev/null +++ b/tests/test_encryption.rs @@ -0,0 +1,101 @@ +use crate::utils::call_counter::increment_counter_for_event; +use crate::utils::protocol::TestProtocolListener; +use crate::utils::{get_free_port, start_server_and_client}; +use bromine::prelude::encrypted::EncryptedListener; +use bromine::prelude::*; +use bromine::IPCBuilder; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::StreamExt; +use rand_core::RngCore; +use std::io::Read; +use std::time::Duration; + +mod utils; + +#[tokio::test] +async fn it_sends_and_receives() { + let ctx = get_client_with_server().await; + let mut rng = rand::thread_rng(); + let mut buffer = vec![0u8; 140]; + rng.fill_bytes(&mut buffer); + + let mut stream = ctx + .emit("bytes", BytePayload::new(buffer.clone())) + .stream_replies() + .await + .unwrap(); + let mut count = 0; + + while let Some(Ok(response)) = stream.next().await { + let bytes = response.payload::().unwrap(); + assert_eq!(bytes.into_inner(), buffer); + count += 1; + } + assert_eq!(count, 100) +} + +#[tokio::test] +async fn it_sends_and_receives_strings() { + let ctx = get_client_with_server().await; + let response = ctx + .emit("string", StringPayload(String::from("Hello World"))) + .await_reply() + .await + .unwrap(); + let response_string = response.payload::().unwrap().0; + + assert_eq!(&response_string, "Hello World") +} + +async fn get_client_with_server() -> Context { + let port = get_free_port(); + + start_server_and_client(move || get_builder(port)).await +} + +fn get_builder(port: u8) -> IPCBuilder> { + IPCBuilder::new() + .address(port) + .on("bytes", callback!(handle_bytes)) + .on("string", callback!(handle_string)) + .timeout(Duration::from_millis(100)) +} + +async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult { + increment_counter_for_event(ctx, &event).await; + let bytes = event.payload::()?.into_bytes(); + + for _ in 0u8..99 { + ctx.emit("bytes", BytePayload::from(bytes.clone())).await?; + } + + ctx.response(BytePayload::from(bytes)) +} + +async fn handle_string(ctx: &Context, event: Event) -> IPCResult { + ctx.response(event.payload::()?) +} + +pub struct StringPayload(String); + +impl IntoPayload for StringPayload { + fn into_payload(self, _: &Context) -> IPCResult { + let mut buf = BytesMut::with_capacity(self.0.len() + 4); + buf.put_u32(self.0.len() as u32); + buf.put(Bytes::from(self.0)); + + Ok(buf.freeze()) + } +} + +impl FromPayload for StringPayload { + fn from_payload(mut reader: R) -> IPCResult { + let len = reader.read_u32::()?; + let mut buf = vec![0u8; len as usize]; + reader.read_exact(&mut buf)?; + let string = String::from_utf8(buf).map_err(|_| IPCError::from("not a string"))?; + + Ok(StringPayload(string)) + } +} diff --git a/tests/utils/call_counter.rs b/tests/utils/call_counter.rs index a9d61dd9..beeec79f 100644 --- a/tests/utils/call_counter.rs +++ b/tests/utils/call_counter.rs @@ -1,3 +1,4 @@ +#![allow(unused)] use bromine::context::Context; use bromine::event::Event; use std::collections::HashMap; diff --git a/tests/utils/protocol.rs b/tests/utils/protocol.rs index 329645de..f95c18df 100644 --- a/tests/utils/protocol.rs +++ b/tests/utils/protocol.rs @@ -62,8 +62,9 @@ impl AsyncStreamProtocolListener for TestProtocolListener { type AddressType = u8; type RemoteAddressType = u8; type Stream = TestProtocolStream; + type ListenerOptions = (); - async fn protocol_bind(address: Self::AddressType) -> Result { + async fn protocol_bind(address: Self::AddressType, _: Self::ListenerOptions) -> Result { let (sender, receiver) = channel(1); add_port(address, sender).await; @@ -170,8 +171,9 @@ impl AsyncProtocolStreamSplit for TestProtocolStream { #[async_trait] impl AsyncProtocolStream for TestProtocolStream { type AddressType = u8; + type StreamOptions = (); - async fn protocol_connect(address: Self::AddressType) -> Result { + async fn protocol_connect(address: Self::AddressType, _: Self::StreamOptions) -> Result { get_port(address) .await .ok_or_else(|| IPCError::from("Failed to connect")) From ef99adfee183440dd66ddb157156ce952158dca0 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 26 Mar 2022 20:11:17 +0100 Subject: [PATCH 4/4] Fix issues with encryption writers Signed-off-by: trivernis --- Cargo.lock | 79 +++++++++++++ Cargo.toml | 2 + src/protocol/encrypted/crypt_handling.rs | 17 ++- src/protocol/encrypted/io_impl.rs | 138 +++++++++++++---------- src/protocol/encrypted/mod.rs | 16 +-- tests/test_encryption.rs | 54 +++++---- tests/utils/mod.rs | 17 ++- 7 files changed, 228 insertions(+), 95 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5e2f2b4..9deb4336 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "async-trait" version = "0.1.52" @@ -124,6 +133,7 @@ dependencies = [ "futures-core", "lazy_static", "num_enum", + "port_check", "postcard", "rand", "rand_core 0.6.3", @@ -134,6 +144,7 @@ dependencies = [ "thiserror", "tokio", "tracing", + "tracing-subscriber", "trait-bound-typemap", "x25519-dalek", ] @@ -749,6 +760,12 @@ dependencies = [ "syn", ] +[[package]] +name = "once_cell" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9" + [[package]] name = "oorandom" version = "11.1.3" @@ -812,6 +829,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "port_check" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6519412c9e0d4be579b9f0618364d19cb434b324fc6ddb1b27b1e682c7105ed" + [[package]] name = "postcard" version = "0.7.3" @@ -1104,12 +1127,27 @@ dependencies = [ "digest 0.10.3", ] +[[package]] +name = "sharded-slab" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +dependencies = [ + "lazy_static", +] + [[package]] name = "slab" version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" +[[package]] +name = "smallvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" + [[package]] name = "socket2" version = "0.4.4" @@ -1193,6 +1231,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +dependencies = [ + "once_cell", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -1270,6 +1317,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c" dependencies = [ "lazy_static", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e0ab7bdc962035a87fba73f3acca9b8a8d0034c2e6f60b84aeaaddddc155dce" +dependencies = [ + "ansi_term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -1309,6 +1382,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcell" version = "0.1.3" diff --git a/Cargo.toml b/Cargo.toml index 1b3ebdf4..1c72b631 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,8 @@ features = ["alloc"] rmp-serde = "1.0.0" crossbeam-utils = "0.8.7" futures = "0.3.21" +tracing-subscriber = "0.3.9" +port_check = "0.1.5" [dev-dependencies.serde] version = "1.0.136" diff --git a/src/protocol/encrypted/crypt_handling.rs b/src/protocol/encrypted/crypt_handling.rs index 4bfe35c9..27ee21aa 100644 --- a/src/protocol/encrypted/crypt_handling.rs +++ b/src/protocol/encrypted/crypt_handling.rs @@ -6,7 +6,7 @@ use chacha20poly1305::aead::{Aead, NewAead}; use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce}; use rand::thread_rng; use rand_core::RngCore; -use sha2::{Digest, Sha224, Sha256}; +use sha2::{Digest, Sha256}; use std::io; use std::io::ErrorKind; use std::sync::atomic::{AtomicU64, Ordering}; @@ -38,6 +38,7 @@ impl CipherBox { } /// Encrypts the given message + #[tracing::instrument(level = "trace", skip_all)] pub fn encrypt(&self, data: Bytes) -> io::Result { self.cipher .encrypt(&self.en_nonce(), &data[..]) @@ -46,6 +47,7 @@ impl CipherBox { } /// Decrypts the given message + #[tracing::instrument(level = "trace", skip_all)] pub fn decrypt(&self, data: Bytes) -> io::Result { self.cipher .decrypt(&self.de_nonce(), &data[..]) @@ -56,6 +58,7 @@ impl CipherBox { /// Updates the stored key. /// This must be done simultaneously on server and client side /// to keep track of the nonce + #[tracing::instrument(level = "trace", skip_all)] pub fn update_key(&mut self, key: Bytes) { let key = Key::from_slice(&key[..]); self.cipher = ChaCha20Poly1305::new(key); @@ -64,6 +67,7 @@ impl CipherBox { /// Resets the nonce counters. /// This must be done simultaneously on server and client side. + #[tracing::instrument(level = "trace", skip_all)] pub fn reset_counters(&mut self) { self.de_count.store(0, Ordering::SeqCst); self.en_count.store(0, Ordering::SeqCst); @@ -71,22 +75,24 @@ impl CipherBox { fn en_nonce(&self) -> Nonce { let count = self.en_count.fetch_add(1, Ordering::SeqCst); + tracing::trace!("encrypted count {}", count); nonce_from_number(count) } fn de_nonce(&self) -> Nonce { let count = self.de_count.fetch_add(1, Ordering::SeqCst); + tracing::trace!("decrypted count {}", count); nonce_from_number(count) } } /// Generates a nonce from a given number -/// the nonce is passed through sha224 for pseudo-randomness +/// The given number is repeated to fit the nonce bytes fn nonce_from_number(number: u64) -> Nonce { let number_bytes: [u8; 8] = number.to_be_bytes(); - let sha_bytes = Sha224::digest(&number_bytes).to_vec(); + let num_vec = number_bytes.repeat(2); let mut nonce_bytes = [0u8; 12]; - nonce_bytes.copy_from_slice(&sha_bytes[..12]); + nonce_bytes.copy_from_slice(&num_vec[..12]); nonce_bytes.into() } @@ -131,6 +137,7 @@ impl EncryptedStream { } } +#[tracing::instrument(level = "debug", skip_all)] async fn receive_public_key(stream: &mut T) -> IPCResult { let mut pk_buf = [0u8; 32]; stream.read_exact(&mut pk_buf).await?; @@ -138,6 +145,7 @@ async fn receive_public_key(stream: &mut T) -> IPCResult Ok(PublicKey::from(pk_buf)) } +#[tracing::instrument(level = "debug", skip_all)] async fn send_public_key( stream: &mut T, secret: &StaticSecret, @@ -149,6 +157,7 @@ async fn send_public_key( Ok(()) } +#[tracing::instrument(level = "trace", skip_all)] fn generate_secret() -> Vec { let mut rng = thread_rng(); let mut buf = vec![0u8; 32]; diff --git a/src/protocol/encrypted/io_impl.rs b/src/protocol/encrypted/io_impl.rs index 94360e38..9182e21b 100644 --- a/src/protocol/encrypted/io_impl.rs +++ b/src/protocol/encrypted/io_impl.rs @@ -52,54 +52,46 @@ impl AsyncRead for EncryptedReadSt buf: &mut ReadBuf<'_>, ) -> Poll> { if self.fut.is_none() { - let max_copy = min(buf.remaining(), self.remaining.len()); - let bytes = self.remaining.copy_to_bytes(max_copy); - buf.put_slice(&bytes); + if self.remaining.len() > 0 { + let max_copy = min(buf.remaining(), self.remaining.len()); + let bytes = self.remaining.copy_to_bytes(max_copy); + buf.put_slice(&bytes); + tracing::trace!("{} bytes read from buffer", bytes.len()); + } if buf.remaining() > 0 { - let mut reader = self.inner.take().unwrap(); + tracing::trace!("{} bytes remaining to read", buf.remaining()); + let reader = self.inner.take().unwrap(); let cipher = self.cipher.take().unwrap(); - self.fut = Some(Box::pin(async move { - let package = match EncryptedPackage::from_async_read(&mut reader).await { - Ok(p) => p, - Err(e) => { - return (Err(e), reader, cipher); - } - }; - match cipher.decrypt(package.into_inner()) { - Ok(bytes) => (Ok(bytes), reader, cipher), - Err(e) => (Err(e), reader, cipher), - } - })); + self.fut = Some(Box::pin(async move { read_bytes(reader, cipher).await })); + } else { + return Poll::Ready(Ok(())); } } - if self.fut.is_some() { - match self.fut.as_mut().unwrap().as_mut().poll(cx) { - Poll::Ready((result, reader, cipher)) => { - self.inner = Some(reader); - self.cipher = Some(cipher); - match result { - Ok(bytes) => { - self.remaining.put(bytes); - let max_copy = min(self.remaining.len(), buf.remaining()); - let bytes = self.remaining.copy_to_bytes(max_copy); - self.fut = None; - buf.put_slice(&bytes); - - if buf.remaining() == 0 { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } + match self.fut.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, reader, cipher)) => { + self.inner = Some(reader); + self.cipher = Some(cipher); + match result { + Ok(bytes) => { + self.fut = None; + self.remaining.put(bytes); + let max_copy = min(self.remaining.len(), buf.remaining()); + let bytes = self.remaining.copy_to_bytes(max_copy); + buf.put_slice(&bytes); + tracing::trace!("{} bytes read from buffer", bytes.len()); + + if buf.remaining() == 0 { + Poll::Ready(Ok(())) + } else { + Poll::Pending } - Err(e) => Poll::Ready(Err(e)), } + Err(e) => Poll::Ready(Err(e)), } - Poll::Pending => Poll::Pending, } - } else { - Poll::Ready(Ok(())) + Poll::Pending => Poll::Pending, } } } @@ -112,11 +104,13 @@ impl AsyncWrite for EncryptedWrit cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if buf.remaining() > 0 { - let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) }; - self.buffer.put(Bytes::from(buf)); + let written_length = buf.len(); + + if self.fut_write.is_none() { + self.buffer.put(Bytes::from(buf.to_vec())); - if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE { + if self.buffer.len() >= WRITE_BUF_SIZE { + tracing::trace!("buffer has reached sending size: {}", self.buffer.len()); let buffer_len = self.buffer.len(); let max_copy = min(u32::MAX as usize, buffer_len); let plaintext = self.buffer.copy_to_bytes(max_copy); @@ -124,38 +118,42 @@ impl AsyncWrite for EncryptedWrit let cipher = self.cipher.take().unwrap(); self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher))) + } else { + return Poll::Ready(Ok(written_length)); } } - if self.fut_write.is_some() { - match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { - Poll::Ready((result, writer, cipher)) => { - self.inner = Some(writer); - self.cipher = Some(cipher); - self.fut_write = None; - - Poll::Ready(result.map(|_| buf.len())) - } - Poll::Pending => Poll::Pending, + + match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, writer, cipher)) => { + self.inner = Some(writer); + self.cipher = Some(cipher); + self.fut_write = None; + + Poll::Ready(result.map(|_| written_length)) } - } else { - Poll::Ready(Ok(buf.len())) + Poll::Pending => Poll::Pending, } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let buffer_len = self.buffer.len(); - if !self.buffer.is_empty() && self.fut_flush.is_none() { + if self.fut_flush.is_none() { let max_copy = min(u32::MAX as usize, buffer_len); let plaintext = self.buffer.copy_to_bytes(max_copy); let writer = self.inner.take().unwrap(); let cipher = self.cipher.take().unwrap(); self.fut_flush = Some(Box::pin(async move { - let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await; - if result.is_err() { - return (result, writer, cipher); - } + let (mut writer, cipher) = if plaintext.len() > 0 { + let (result, writer, cipher) = write_bytes(plaintext, writer, cipher).await; + if result.is_err() { + return (result, writer, cipher); + } + (writer, cipher) + } else { + (writer, cipher) + }; if let Err(e) = writer.flush().await { (Err(e), writer, cipher) } else { @@ -200,11 +198,13 @@ impl AsyncWrite for EncryptedWrit } } +#[tracing::instrument(level = "trace", skip_all)] async fn write_bytes( bytes: Bytes, mut writer: T, cipher: CipherBox, ) -> (io::Result<()>, T, CipherBox) { + tracing::trace!("plaintext size: {}", bytes.len()); let encrypted_bytes = match cipher.encrypt(bytes) { Ok(b) => b, Err(e) => { @@ -212,9 +212,29 @@ async fn write_bytes( } }; let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes(); + tracing::trace!("encrypted size: {}", package_bytes.len()); if let Err(e) = writer.write_all(&package_bytes[..]).await { return (Err(e), writer, cipher); } + tracing::trace!("everything sent"); (Ok(()), writer, cipher) } + +#[tracing::instrument(level = "trace", skip_all)] +async fn read_bytes( + mut reader: T, + cipher: CipherBox, +) -> (io::Result, T, CipherBox) { + let package = match EncryptedPackage::from_async_read(&mut reader).await { + Ok(p) => p, + Err(e) => { + return (Err(e), reader, cipher); + } + }; + tracing::trace!("received {} bytes", package.bytes.len()); + match cipher.decrypt(package.into_inner()) { + Ok(bytes) => (Ok(bytes), reader, cipher), + Err(e) => (Err(e), reader, cipher), + } +} diff --git a/src/protocol/encrypted/mod.rs b/src/protocol/encrypted/mod.rs index 7126dcdc..dc130ec8 100644 --- a/src/protocol/encrypted/mod.rs +++ b/src/protocol/encrypted/mod.rs @@ -9,7 +9,7 @@ use rand_core::RngCore; use std::future::Future; use std::io; use std::pin::Pin; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use x25519_dalek::{SharedSecret, StaticSecret}; use crate::prelude::encrypted::crypt_handling::CipherBox; @@ -80,8 +80,8 @@ impl EncryptedStream { } pub struct EncryptedReadStream { - inner: Option>, - fut: OptionalFuture<(io::Result, BufReader, CipherBox)>, + inner: Option, + fut: OptionalFuture<(io::Result, T, CipherBox)>, remaining: BytesMut, cipher: Option, } @@ -89,7 +89,7 @@ pub struct EncryptedReadStream { impl EncryptedReadStream { pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { Self { - inner: Some(BufReader::new(inner)), + inner: Some(inner), fut: None, remaining: BytesMut::new(), cipher: Some(cipher), @@ -98,18 +98,18 @@ impl EncryptedReadStream { } pub struct EncryptedWriteStream { - inner: Option>, + inner: Option, cipher: Option, buffer: BytesMut, - fut_write: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, - fut_flush: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, + fut_write: OptionalFuture<(io::Result<()>, T, CipherBox)>, + fut_flush: OptionalFuture<(io::Result<()>, T, CipherBox)>, fut_shutdown: OptionalFuture>, } impl EncryptedWriteStream { pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { Self { - inner: Some(BufWriter::new(inner)), + inner: Some(inner), cipher: Some(cipher), buffer: BytesMut::with_capacity(1024), fut_write: None, diff --git a/tests/test_encryption.rs b/tests/test_encryption.rs index 3ab90988..24eeddea 100644 --- a/tests/test_encryption.rs +++ b/tests/test_encryption.rs @@ -14,25 +14,13 @@ use std::time::Duration; mod utils; #[tokio::test] -async fn it_sends_and_receives() { - let ctx = get_client_with_server().await; - let mut rng = rand::thread_rng(); - let mut buffer = vec![0u8; 140]; - rng.fill_bytes(&mut buffer); - - let mut stream = ctx - .emit("bytes", BytePayload::new(buffer.clone())) - .stream_replies() - .await - .unwrap(); - let mut count = 0; +async fn it_sends_and_receives_smaller_packages() { + send_and_receive_bytes(140).await.unwrap(); +} - while let Some(Ok(response)) = stream.next().await { - let bytes = response.payload::().unwrap(); - assert_eq!(bytes.into_inner(), buffer); - count += 1; - } - assert_eq!(count, 100) +#[tokio::test] +async fn it_sends_and_receives_larger_packages() { + send_and_receive_bytes(1024 * 32).await.unwrap(); } #[tokio::test] @@ -48,6 +36,28 @@ async fn it_sends_and_receives_strings() { assert_eq!(&response_string, "Hello World") } +async fn send_and_receive_bytes(byte_size: usize) -> IPCResult<()> { + let ctx = get_client_with_server().await; + let mut rng = rand::thread_rng(); + let mut buffer = vec![0u8; byte_size]; + rng.fill_bytes(&mut buffer); + + let mut stream = ctx + .emit("bytes", BytePayload::new(buffer.clone())) + .stream_replies() + .await?; + let mut count = 0; + + while let Some(response) = stream.next().await { + let bytes = response.unwrap().payload::()?; + assert_eq!(bytes.into_inner(), buffer); + count += 1; + } + assert_eq!(count, 100); + + Ok(()) +} + async fn get_client_with_server() -> Context { let port = get_free_port(); @@ -59,18 +69,18 @@ fn get_builder(port: u8) -> IPCBuilder> .address(port) .on("bytes", callback!(handle_bytes)) .on("string", callback!(handle_string)) - .timeout(Duration::from_millis(100)) + .timeout(Duration::from_secs(10)) } async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; - let bytes = event.payload::()?.into_bytes(); + let bytes = event.payload::()?.into_inner(); for _ in 0u8..99 { - ctx.emit("bytes", BytePayload::from(bytes.clone())).await?; + ctx.emit("bytes", BytePayload::new(bytes.clone())).await?; } - ctx.response(BytePayload::from(bytes)) + ctx.response(BytePayload::new(bytes)) } async fn handle_string(ctx: &Context, event: Event) -> IPCResult { diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 3b69a301..d38564ef 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,20 +1,32 @@ +#![allow(unused)] use bromine::context::Context; use bromine::protocol::AsyncStreamProtocolListener; use bromine::IPCBuilder; use call_counter::*; use lazy_static::lazy_static; -use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use std::sync::Arc; use tokio::sync::oneshot::channel; pub mod call_counter; pub mod protocol; +pub fn setup() { + lazy_static! { + static ref SETUP_DONE: Arc = Default::default(); + } + if !SETUP_DONE.swap(true, Ordering::SeqCst) { + tracing_subscriber::fmt::init(); + } +} + pub fn get_free_port() -> u8 { lazy_static! { static ref PORT_COUNTER: Arc = Arc::new(AtomicU8::new(0)); } - PORT_COUNTER.fetch_add(1, Ordering::Relaxed) + let count = PORT_COUNTER.fetch_add(1, Ordering::Relaxed); + + count } pub async fn start_server_and_client< @@ -23,6 +35,7 @@ pub async fn start_server_and_client< >( builder_fn: F, ) -> Context { + setup(); let counters = CallCounter::default(); let (sender, receiver) = channel::<()>(); let client_builder = builder_fn().insert::(counters.clone());