From ef99adfee183440dd66ddb157156ce952158dca0 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 26 Mar 2022 20:11:17 +0100 Subject: [PATCH] Fix issues with encryption writers Signed-off-by: trivernis --- Cargo.lock | 79 +++++++++++++ Cargo.toml | 2 + src/protocol/encrypted/crypt_handling.rs | 17 ++- src/protocol/encrypted/io_impl.rs | 138 +++++++++++++---------- src/protocol/encrypted/mod.rs | 16 +-- tests/test_encryption.rs | 54 +++++---- tests/utils/mod.rs | 17 ++- 7 files changed, 228 insertions(+), 95 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5e2f2b4..9deb4336 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "async-trait" version = "0.1.52" @@ -124,6 +133,7 @@ dependencies = [ "futures-core", "lazy_static", "num_enum", + "port_check", "postcard", "rand", "rand_core 0.6.3", @@ -134,6 +144,7 @@ dependencies = [ "thiserror", "tokio", "tracing", + "tracing-subscriber", "trait-bound-typemap", "x25519-dalek", ] @@ -749,6 +760,12 @@ dependencies = [ "syn", ] +[[package]] +name = "once_cell" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9" + [[package]] name = "oorandom" version = "11.1.3" @@ -812,6 +829,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "port_check" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6519412c9e0d4be579b9f0618364d19cb434b324fc6ddb1b27b1e682c7105ed" + [[package]] name = "postcard" version = "0.7.3" @@ -1104,12 +1127,27 @@ dependencies = [ "digest 0.10.3", ] +[[package]] +name = "sharded-slab" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +dependencies = [ + "lazy_static", +] + [[package]] name = "slab" version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" +[[package]] +name = "smallvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" + [[package]] name = "socket2" version = "0.4.4" @@ -1193,6 +1231,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +dependencies = [ + "once_cell", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -1270,6 +1317,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c" dependencies = [ "lazy_static", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e0ab7bdc962035a87fba73f3acca9b8a8d0034c2e6f60b84aeaaddddc155dce" +dependencies = [ + "ansi_term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", ] [[package]] @@ -1309,6 +1382,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcell" version = "0.1.3" diff --git a/Cargo.toml b/Cargo.toml index 1b3ebdf4..1c72b631 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,8 @@ features = ["alloc"] rmp-serde = "1.0.0" crossbeam-utils = "0.8.7" futures = "0.3.21" +tracing-subscriber = "0.3.9" +port_check = "0.1.5" [dev-dependencies.serde] version = "1.0.136" diff --git a/src/protocol/encrypted/crypt_handling.rs b/src/protocol/encrypted/crypt_handling.rs index 4bfe35c9..27ee21aa 100644 --- a/src/protocol/encrypted/crypt_handling.rs +++ b/src/protocol/encrypted/crypt_handling.rs @@ -6,7 +6,7 @@ use chacha20poly1305::aead::{Aead, NewAead}; use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce}; use rand::thread_rng; use rand_core::RngCore; -use sha2::{Digest, Sha224, Sha256}; +use sha2::{Digest, Sha256}; use std::io; use std::io::ErrorKind; use std::sync::atomic::{AtomicU64, Ordering}; @@ -38,6 +38,7 @@ impl CipherBox { } /// Encrypts the given message + #[tracing::instrument(level = "trace", skip_all)] pub fn encrypt(&self, data: Bytes) -> io::Result { self.cipher .encrypt(&self.en_nonce(), &data[..]) @@ -46,6 +47,7 @@ impl CipherBox { } /// Decrypts the given message + #[tracing::instrument(level = "trace", skip_all)] pub fn decrypt(&self, data: Bytes) -> io::Result { self.cipher .decrypt(&self.de_nonce(), &data[..]) @@ -56,6 +58,7 @@ impl CipherBox { /// Updates the stored key. /// This must be done simultaneously on server and client side /// to keep track of the nonce + #[tracing::instrument(level = "trace", skip_all)] pub fn update_key(&mut self, key: Bytes) { let key = Key::from_slice(&key[..]); self.cipher = ChaCha20Poly1305::new(key); @@ -64,6 +67,7 @@ impl CipherBox { /// Resets the nonce counters. /// This must be done simultaneously on server and client side. + #[tracing::instrument(level = "trace", skip_all)] pub fn reset_counters(&mut self) { self.de_count.store(0, Ordering::SeqCst); self.en_count.store(0, Ordering::SeqCst); @@ -71,22 +75,24 @@ impl CipherBox { fn en_nonce(&self) -> Nonce { let count = self.en_count.fetch_add(1, Ordering::SeqCst); + tracing::trace!("encrypted count {}", count); nonce_from_number(count) } fn de_nonce(&self) -> Nonce { let count = self.de_count.fetch_add(1, Ordering::SeqCst); + tracing::trace!("decrypted count {}", count); nonce_from_number(count) } } /// Generates a nonce from a given number -/// the nonce is passed through sha224 for pseudo-randomness +/// The given number is repeated to fit the nonce bytes fn nonce_from_number(number: u64) -> Nonce { let number_bytes: [u8; 8] = number.to_be_bytes(); - let sha_bytes = Sha224::digest(&number_bytes).to_vec(); + let num_vec = number_bytes.repeat(2); let mut nonce_bytes = [0u8; 12]; - nonce_bytes.copy_from_slice(&sha_bytes[..12]); + nonce_bytes.copy_from_slice(&num_vec[..12]); nonce_bytes.into() } @@ -131,6 +137,7 @@ impl EncryptedStream { } } +#[tracing::instrument(level = "debug", skip_all)] async fn receive_public_key(stream: &mut T) -> IPCResult { let mut pk_buf = [0u8; 32]; stream.read_exact(&mut pk_buf).await?; @@ -138,6 +145,7 @@ async fn receive_public_key(stream: &mut T) -> IPCResult Ok(PublicKey::from(pk_buf)) } +#[tracing::instrument(level = "debug", skip_all)] async fn send_public_key( stream: &mut T, secret: &StaticSecret, @@ -149,6 +157,7 @@ async fn send_public_key( Ok(()) } +#[tracing::instrument(level = "trace", skip_all)] fn generate_secret() -> Vec { let mut rng = thread_rng(); let mut buf = vec![0u8; 32]; diff --git a/src/protocol/encrypted/io_impl.rs b/src/protocol/encrypted/io_impl.rs index 94360e38..9182e21b 100644 --- a/src/protocol/encrypted/io_impl.rs +++ b/src/protocol/encrypted/io_impl.rs @@ -52,54 +52,46 @@ impl AsyncRead for EncryptedReadSt buf: &mut ReadBuf<'_>, ) -> Poll> { if self.fut.is_none() { - let max_copy = min(buf.remaining(), self.remaining.len()); - let bytes = self.remaining.copy_to_bytes(max_copy); - buf.put_slice(&bytes); + if self.remaining.len() > 0 { + let max_copy = min(buf.remaining(), self.remaining.len()); + let bytes = self.remaining.copy_to_bytes(max_copy); + buf.put_slice(&bytes); + tracing::trace!("{} bytes read from buffer", bytes.len()); + } if buf.remaining() > 0 { - let mut reader = self.inner.take().unwrap(); + tracing::trace!("{} bytes remaining to read", buf.remaining()); + let reader = self.inner.take().unwrap(); let cipher = self.cipher.take().unwrap(); - self.fut = Some(Box::pin(async move { - let package = match EncryptedPackage::from_async_read(&mut reader).await { - Ok(p) => p, - Err(e) => { - return (Err(e), reader, cipher); - } - }; - match cipher.decrypt(package.into_inner()) { - Ok(bytes) => (Ok(bytes), reader, cipher), - Err(e) => (Err(e), reader, cipher), - } - })); + self.fut = Some(Box::pin(async move { read_bytes(reader, cipher).await })); + } else { + return Poll::Ready(Ok(())); } } - if self.fut.is_some() { - match self.fut.as_mut().unwrap().as_mut().poll(cx) { - Poll::Ready((result, reader, cipher)) => { - self.inner = Some(reader); - self.cipher = Some(cipher); - match result { - Ok(bytes) => { - self.remaining.put(bytes); - let max_copy = min(self.remaining.len(), buf.remaining()); - let bytes = self.remaining.copy_to_bytes(max_copy); - self.fut = None; - buf.put_slice(&bytes); - - if buf.remaining() == 0 { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } + match self.fut.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, reader, cipher)) => { + self.inner = Some(reader); + self.cipher = Some(cipher); + match result { + Ok(bytes) => { + self.fut = None; + self.remaining.put(bytes); + let max_copy = min(self.remaining.len(), buf.remaining()); + let bytes = self.remaining.copy_to_bytes(max_copy); + buf.put_slice(&bytes); + tracing::trace!("{} bytes read from buffer", bytes.len()); + + if buf.remaining() == 0 { + Poll::Ready(Ok(())) + } else { + Poll::Pending } - Err(e) => Poll::Ready(Err(e)), } + Err(e) => Poll::Ready(Err(e)), } - Poll::Pending => Poll::Pending, } - } else { - Poll::Ready(Ok(())) + Poll::Pending => Poll::Pending, } } } @@ -112,11 +104,13 @@ impl AsyncWrite for EncryptedWrit cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if buf.remaining() > 0 { - let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) }; - self.buffer.put(Bytes::from(buf)); + let written_length = buf.len(); + + if self.fut_write.is_none() { + self.buffer.put(Bytes::from(buf.to_vec())); - if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE { + if self.buffer.len() >= WRITE_BUF_SIZE { + tracing::trace!("buffer has reached sending size: {}", self.buffer.len()); let buffer_len = self.buffer.len(); let max_copy = min(u32::MAX as usize, buffer_len); let plaintext = self.buffer.copy_to_bytes(max_copy); @@ -124,38 +118,42 @@ impl AsyncWrite for EncryptedWrit let cipher = self.cipher.take().unwrap(); self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher))) + } else { + return Poll::Ready(Ok(written_length)); } } - if self.fut_write.is_some() { - match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { - Poll::Ready((result, writer, cipher)) => { - self.inner = Some(writer); - self.cipher = Some(cipher); - self.fut_write = None; - - Poll::Ready(result.map(|_| buf.len())) - } - Poll::Pending => Poll::Pending, + + match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, writer, cipher)) => { + self.inner = Some(writer); + self.cipher = Some(cipher); + self.fut_write = None; + + Poll::Ready(result.map(|_| written_length)) } - } else { - Poll::Ready(Ok(buf.len())) + Poll::Pending => Poll::Pending, } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let buffer_len = self.buffer.len(); - if !self.buffer.is_empty() && self.fut_flush.is_none() { + if self.fut_flush.is_none() { let max_copy = min(u32::MAX as usize, buffer_len); let plaintext = self.buffer.copy_to_bytes(max_copy); let writer = self.inner.take().unwrap(); let cipher = self.cipher.take().unwrap(); self.fut_flush = Some(Box::pin(async move { - let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await; - if result.is_err() { - return (result, writer, cipher); - } + let (mut writer, cipher) = if plaintext.len() > 0 { + let (result, writer, cipher) = write_bytes(plaintext, writer, cipher).await; + if result.is_err() { + return (result, writer, cipher); + } + (writer, cipher) + } else { + (writer, cipher) + }; if let Err(e) = writer.flush().await { (Err(e), writer, cipher) } else { @@ -200,11 +198,13 @@ impl AsyncWrite for EncryptedWrit } } +#[tracing::instrument(level = "trace", skip_all)] async fn write_bytes( bytes: Bytes, mut writer: T, cipher: CipherBox, ) -> (io::Result<()>, T, CipherBox) { + tracing::trace!("plaintext size: {}", bytes.len()); let encrypted_bytes = match cipher.encrypt(bytes) { Ok(b) => b, Err(e) => { @@ -212,9 +212,29 @@ async fn write_bytes( } }; let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes(); + tracing::trace!("encrypted size: {}", package_bytes.len()); if let Err(e) = writer.write_all(&package_bytes[..]).await { return (Err(e), writer, cipher); } + tracing::trace!("everything sent"); (Ok(()), writer, cipher) } + +#[tracing::instrument(level = "trace", skip_all)] +async fn read_bytes( + mut reader: T, + cipher: CipherBox, +) -> (io::Result, T, CipherBox) { + let package = match EncryptedPackage::from_async_read(&mut reader).await { + Ok(p) => p, + Err(e) => { + return (Err(e), reader, cipher); + } + }; + tracing::trace!("received {} bytes", package.bytes.len()); + match cipher.decrypt(package.into_inner()) { + Ok(bytes) => (Ok(bytes), reader, cipher), + Err(e) => (Err(e), reader, cipher), + } +} diff --git a/src/protocol/encrypted/mod.rs b/src/protocol/encrypted/mod.rs index 7126dcdc..dc130ec8 100644 --- a/src/protocol/encrypted/mod.rs +++ b/src/protocol/encrypted/mod.rs @@ -9,7 +9,7 @@ use rand_core::RngCore; use std::future::Future; use std::io; use std::pin::Pin; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use x25519_dalek::{SharedSecret, StaticSecret}; use crate::prelude::encrypted::crypt_handling::CipherBox; @@ -80,8 +80,8 @@ impl EncryptedStream { } pub struct EncryptedReadStream { - inner: Option>, - fut: OptionalFuture<(io::Result, BufReader, CipherBox)>, + inner: Option, + fut: OptionalFuture<(io::Result, T, CipherBox)>, remaining: BytesMut, cipher: Option, } @@ -89,7 +89,7 @@ pub struct EncryptedReadStream { impl EncryptedReadStream { pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { Self { - inner: Some(BufReader::new(inner)), + inner: Some(inner), fut: None, remaining: BytesMut::new(), cipher: Some(cipher), @@ -98,18 +98,18 @@ impl EncryptedReadStream { } pub struct EncryptedWriteStream { - inner: Option>, + inner: Option, cipher: Option, buffer: BytesMut, - fut_write: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, - fut_flush: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, + fut_write: OptionalFuture<(io::Result<()>, T, CipherBox)>, + fut_flush: OptionalFuture<(io::Result<()>, T, CipherBox)>, fut_shutdown: OptionalFuture>, } impl EncryptedWriteStream { pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { Self { - inner: Some(BufWriter::new(inner)), + inner: Some(inner), cipher: Some(cipher), buffer: BytesMut::with_capacity(1024), fut_write: None, diff --git a/tests/test_encryption.rs b/tests/test_encryption.rs index 3ab90988..24eeddea 100644 --- a/tests/test_encryption.rs +++ b/tests/test_encryption.rs @@ -14,25 +14,13 @@ use std::time::Duration; mod utils; #[tokio::test] -async fn it_sends_and_receives() { - let ctx = get_client_with_server().await; - let mut rng = rand::thread_rng(); - let mut buffer = vec![0u8; 140]; - rng.fill_bytes(&mut buffer); - - let mut stream = ctx - .emit("bytes", BytePayload::new(buffer.clone())) - .stream_replies() - .await - .unwrap(); - let mut count = 0; +async fn it_sends_and_receives_smaller_packages() { + send_and_receive_bytes(140).await.unwrap(); +} - while let Some(Ok(response)) = stream.next().await { - let bytes = response.payload::().unwrap(); - assert_eq!(bytes.into_inner(), buffer); - count += 1; - } - assert_eq!(count, 100) +#[tokio::test] +async fn it_sends_and_receives_larger_packages() { + send_and_receive_bytes(1024 * 32).await.unwrap(); } #[tokio::test] @@ -48,6 +36,28 @@ async fn it_sends_and_receives_strings() { assert_eq!(&response_string, "Hello World") } +async fn send_and_receive_bytes(byte_size: usize) -> IPCResult<()> { + let ctx = get_client_with_server().await; + let mut rng = rand::thread_rng(); + let mut buffer = vec![0u8; byte_size]; + rng.fill_bytes(&mut buffer); + + let mut stream = ctx + .emit("bytes", BytePayload::new(buffer.clone())) + .stream_replies() + .await?; + let mut count = 0; + + while let Some(response) = stream.next().await { + let bytes = response.unwrap().payload::()?; + assert_eq!(bytes.into_inner(), buffer); + count += 1; + } + assert_eq!(count, 100); + + Ok(()) +} + async fn get_client_with_server() -> Context { let port = get_free_port(); @@ -59,18 +69,18 @@ fn get_builder(port: u8) -> IPCBuilder> .address(port) .on("bytes", callback!(handle_bytes)) .on("string", callback!(handle_string)) - .timeout(Duration::from_millis(100)) + .timeout(Duration::from_secs(10)) } async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; - let bytes = event.payload::()?.into_bytes(); + let bytes = event.payload::()?.into_inner(); for _ in 0u8..99 { - ctx.emit("bytes", BytePayload::from(bytes.clone())).await?; + ctx.emit("bytes", BytePayload::new(bytes.clone())).await?; } - ctx.response(BytePayload::from(bytes)) + ctx.response(BytePayload::new(bytes)) } async fn handle_string(ctx: &Context, event: Event) -> IPCResult { diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 3b69a301..d38564ef 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,20 +1,32 @@ +#![allow(unused)] use bromine::context::Context; use bromine::protocol::AsyncStreamProtocolListener; use bromine::IPCBuilder; use call_counter::*; use lazy_static::lazy_static; -use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use std::sync::Arc; use tokio::sync::oneshot::channel; pub mod call_counter; pub mod protocol; +pub fn setup() { + lazy_static! { + static ref SETUP_DONE: Arc = Default::default(); + } + if !SETUP_DONE.swap(true, Ordering::SeqCst) { + tracing_subscriber::fmt::init(); + } +} + pub fn get_free_port() -> u8 { lazy_static! { static ref PORT_COUNTER: Arc = Arc::new(AtomicU8::new(0)); } - PORT_COUNTER.fetch_add(1, Ordering::Relaxed) + let count = PORT_COUNTER.fetch_add(1, Ordering::Relaxed); + + count } pub async fn start_server_and_client< @@ -23,6 +35,7 @@ pub async fn start_server_and_client< >( builder_fn: F, ) -> Context { + setup(); let counters = CallCounter::default(); let (sender, receiver) = channel::<()>(); let client_builder = builder_fn().insert::(counters.clone());