Fix issues with encryption writers

Signed-off-by: trivernis <trivernis@protonmail.com>
pull/38/head
trivernis 3 years ago
parent fe7dc97008
commit ef99adfee1
Signed by: Trivernis
GPG Key ID: DFFFCC2C7A02DB45

79
Cargo.lock generated

@ -20,6 +20,15 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "ansi_term"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
dependencies = [
"winapi",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.52" version = "0.1.52"
@ -124,6 +133,7 @@ dependencies = [
"futures-core", "futures-core",
"lazy_static", "lazy_static",
"num_enum", "num_enum",
"port_check",
"postcard", "postcard",
"rand", "rand",
"rand_core 0.6.3", "rand_core 0.6.3",
@ -134,6 +144,7 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber",
"trait-bound-typemap", "trait-bound-typemap",
"x25519-dalek", "x25519-dalek",
] ]
@ -749,6 +760,12 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "once_cell"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9"
[[package]] [[package]]
name = "oorandom" name = "oorandom"
version = "11.1.3" version = "11.1.3"
@ -812,6 +829,12 @@ dependencies = [
"universal-hash", "universal-hash",
] ]
[[package]]
name = "port_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6519412c9e0d4be579b9f0618364d19cb434b324fc6ddb1b27b1e682c7105ed"
[[package]] [[package]]
name = "postcard" name = "postcard"
version = "0.7.3" version = "0.7.3"
@ -1104,12 +1127,27 @@ dependencies = [
"digest 0.10.3", "digest 0.10.3",
] ]
[[package]]
name = "sharded-slab"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31"
dependencies = [
"lazy_static",
]
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.5" version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5"
[[package]]
name = "smallvec"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
[[package]] [[package]]
name = "socket2" name = "socket2"
version = "0.4.4" version = "0.4.4"
@ -1193,6 +1231,15 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "thread_local"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180"
dependencies = [
"once_cell",
]
[[package]] [[package]]
name = "tinytemplate" name = "tinytemplate"
version = "1.2.1" version = "1.2.1"
@ -1270,6 +1317,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c" checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c"
dependencies = [ dependencies = [
"lazy_static", "lazy_static",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3"
dependencies = [
"lazy_static",
"log",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e0ab7bdc962035a87fba73f3acca9b8a8d0034c2e6f60b84aeaaddddc155dce"
dependencies = [
"ansi_term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
] ]
[[package]] [[package]]
@ -1309,6 +1382,12 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]] [[package]]
name = "vcell" name = "vcell"
version = "0.1.3" version = "0.1.3"

@ -56,6 +56,8 @@ features = ["alloc"]
rmp-serde = "1.0.0" rmp-serde = "1.0.0"
crossbeam-utils = "0.8.7" crossbeam-utils = "0.8.7"
futures = "0.3.21" futures = "0.3.21"
tracing-subscriber = "0.3.9"
port_check = "0.1.5"
[dev-dependencies.serde] [dev-dependencies.serde]
version = "1.0.136" version = "1.0.136"

@ -6,7 +6,7 @@ use chacha20poly1305::aead::{Aead, NewAead};
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce}; use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
use rand::thread_rng; use rand::thread_rng;
use rand_core::RngCore; use rand_core::RngCore;
use sha2::{Digest, Sha224, Sha256}; use sha2::{Digest, Sha256};
use std::io; use std::io;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
@ -38,6 +38,7 @@ impl CipherBox {
} }
/// Encrypts the given message /// Encrypts the given message
#[tracing::instrument(level = "trace", skip_all)]
pub fn encrypt(&self, data: Bytes) -> io::Result<Bytes> { pub fn encrypt(&self, data: Bytes) -> io::Result<Bytes> {
self.cipher self.cipher
.encrypt(&self.en_nonce(), &data[..]) .encrypt(&self.en_nonce(), &data[..])
@ -46,6 +47,7 @@ impl CipherBox {
} }
/// Decrypts the given message /// Decrypts the given message
#[tracing::instrument(level = "trace", skip_all)]
pub fn decrypt(&self, data: Bytes) -> io::Result<Bytes> { pub fn decrypt(&self, data: Bytes) -> io::Result<Bytes> {
self.cipher self.cipher
.decrypt(&self.de_nonce(), &data[..]) .decrypt(&self.de_nonce(), &data[..])
@ -56,6 +58,7 @@ impl CipherBox {
/// Updates the stored key. /// Updates the stored key.
/// This must be done simultaneously on server and client side /// This must be done simultaneously on server and client side
/// to keep track of the nonce /// to keep track of the nonce
#[tracing::instrument(level = "trace", skip_all)]
pub fn update_key(&mut self, key: Bytes) { pub fn update_key(&mut self, key: Bytes) {
let key = Key::from_slice(&key[..]); let key = Key::from_slice(&key[..]);
self.cipher = ChaCha20Poly1305::new(key); self.cipher = ChaCha20Poly1305::new(key);
@ -64,6 +67,7 @@ impl CipherBox {
/// Resets the nonce counters. /// Resets the nonce counters.
/// This must be done simultaneously on server and client side. /// This must be done simultaneously on server and client side.
#[tracing::instrument(level = "trace", skip_all)]
pub fn reset_counters(&mut self) { pub fn reset_counters(&mut self) {
self.de_count.store(0, Ordering::SeqCst); self.de_count.store(0, Ordering::SeqCst);
self.en_count.store(0, Ordering::SeqCst); self.en_count.store(0, Ordering::SeqCst);
@ -71,22 +75,24 @@ impl CipherBox {
fn en_nonce(&self) -> Nonce { fn en_nonce(&self) -> Nonce {
let count = self.en_count.fetch_add(1, Ordering::SeqCst); let count = self.en_count.fetch_add(1, Ordering::SeqCst);
tracing::trace!("encrypted count {}", count);
nonce_from_number(count) nonce_from_number(count)
} }
fn de_nonce(&self) -> Nonce { fn de_nonce(&self) -> Nonce {
let count = self.de_count.fetch_add(1, Ordering::SeqCst); let count = self.de_count.fetch_add(1, Ordering::SeqCst);
tracing::trace!("decrypted count {}", count);
nonce_from_number(count) nonce_from_number(count)
} }
} }
/// Generates a nonce from a given number /// Generates a nonce from a given number
/// the nonce is passed through sha224 for pseudo-randomness /// The given number is repeated to fit the nonce bytes
fn nonce_from_number(number: u64) -> Nonce { fn nonce_from_number(number: u64) -> Nonce {
let number_bytes: [u8; 8] = number.to_be_bytes(); let number_bytes: [u8; 8] = number.to_be_bytes();
let sha_bytes = Sha224::digest(&number_bytes).to_vec(); let num_vec = number_bytes.repeat(2);
let mut nonce_bytes = [0u8; 12]; let mut nonce_bytes = [0u8; 12];
nonce_bytes.copy_from_slice(&sha_bytes[..12]); nonce_bytes.copy_from_slice(&num_vec[..12]);
nonce_bytes.into() nonce_bytes.into()
} }
@ -131,6 +137,7 @@ impl<T: AsyncProtocolStream> EncryptedStream<T> {
} }
} }
#[tracing::instrument(level = "debug", skip_all)]
async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult<PublicKey> { async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult<PublicKey> {
let mut pk_buf = [0u8; 32]; let mut pk_buf = [0u8; 32];
stream.read_exact(&mut pk_buf).await?; stream.read_exact(&mut pk_buf).await?;
@ -138,6 +145,7 @@ async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult
Ok(PublicKey::from(pk_buf)) Ok(PublicKey::from(pk_buf))
} }
#[tracing::instrument(level = "debug", skip_all)]
async fn send_public_key<T: AsyncProtocolStream>( async fn send_public_key<T: AsyncProtocolStream>(
stream: &mut T, stream: &mut T,
secret: &StaticSecret, secret: &StaticSecret,
@ -149,6 +157,7 @@ async fn send_public_key<T: AsyncProtocolStream>(
Ok(()) Ok(())
} }
#[tracing::instrument(level = "trace", skip_all)]
fn generate_secret() -> Vec<u8> { fn generate_secret() -> Vec<u8> {
let mut rng = thread_rng(); let mut rng = thread_rng();
let mut buf = vec![0u8; 32]; let mut buf = vec![0u8; 32];

@ -52,54 +52,46 @@ impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadSt
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> { ) -> Poll<std::io::Result<()>> {
if self.fut.is_none() { if self.fut.is_none() {
let max_copy = min(buf.remaining(), self.remaining.len()); if self.remaining.len() > 0 {
let bytes = self.remaining.copy_to_bytes(max_copy); let max_copy = min(buf.remaining(), self.remaining.len());
buf.put_slice(&bytes); let bytes = self.remaining.copy_to_bytes(max_copy);
buf.put_slice(&bytes);
tracing::trace!("{} bytes read from buffer", bytes.len());
}
if buf.remaining() > 0 { if buf.remaining() > 0 {
let mut reader = self.inner.take().unwrap(); tracing::trace!("{} bytes remaining to read", buf.remaining());
let reader = self.inner.take().unwrap();
let cipher = self.cipher.take().unwrap(); let cipher = self.cipher.take().unwrap();
self.fut = Some(Box::pin(async move { self.fut = Some(Box::pin(async move { read_bytes(reader, cipher).await }));
let package = match EncryptedPackage::from_async_read(&mut reader).await { } else {
Ok(p) => p, return Poll::Ready(Ok(()));
Err(e) => {
return (Err(e), reader, cipher);
}
};
match cipher.decrypt(package.into_inner()) {
Ok(bytes) => (Ok(bytes), reader, cipher),
Err(e) => (Err(e), reader, cipher),
}
}));
} }
} }
if self.fut.is_some() { match self.fut.as_mut().unwrap().as_mut().poll(cx) {
match self.fut.as_mut().unwrap().as_mut().poll(cx) { Poll::Ready((result, reader, cipher)) => {
Poll::Ready((result, reader, cipher)) => { self.inner = Some(reader);
self.inner = Some(reader); self.cipher = Some(cipher);
self.cipher = Some(cipher); match result {
match result { Ok(bytes) => {
Ok(bytes) => { self.fut = None;
self.remaining.put(bytes); self.remaining.put(bytes);
let max_copy = min(self.remaining.len(), buf.remaining()); let max_copy = min(self.remaining.len(), buf.remaining());
let bytes = self.remaining.copy_to_bytes(max_copy); let bytes = self.remaining.copy_to_bytes(max_copy);
self.fut = None; buf.put_slice(&bytes);
buf.put_slice(&bytes); tracing::trace!("{} bytes read from buffer", bytes.len());
if buf.remaining() == 0 { if buf.remaining() == 0 {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { } else {
Poll::Pending Poll::Pending
}
} }
Err(e) => Poll::Ready(Err(e)),
} }
Err(e) => Poll::Ready(Err(e)),
} }
Poll::Pending => Poll::Pending,
} }
} else { Poll::Pending => Poll::Pending,
Poll::Ready(Ok(()))
} }
} }
} }
@ -112,11 +104,13 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<Result<usize, Error>> { ) -> Poll<Result<usize, Error>> {
if buf.remaining() > 0 { let written_length = buf.len();
let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) };
self.buffer.put(Bytes::from(buf)); if self.fut_write.is_none() {
self.buffer.put(Bytes::from(buf.to_vec()));
if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE { if self.buffer.len() >= WRITE_BUF_SIZE {
tracing::trace!("buffer has reached sending size: {}", self.buffer.len());
let buffer_len = self.buffer.len(); let buffer_len = self.buffer.len();
let max_copy = min(u32::MAX as usize, buffer_len); let max_copy = min(u32::MAX as usize, buffer_len);
let plaintext = self.buffer.copy_to_bytes(max_copy); let plaintext = self.buffer.copy_to_bytes(max_copy);
@ -124,38 +118,42 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
let cipher = self.cipher.take().unwrap(); let cipher = self.cipher.take().unwrap();
self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher))) self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher)))
} else {
return Poll::Ready(Ok(written_length));
} }
} }
if self.fut_write.is_some() {
match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { match self.fut_write.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready((result, writer, cipher)) => { Poll::Ready((result, writer, cipher)) => {
self.inner = Some(writer); self.inner = Some(writer);
self.cipher = Some(cipher); self.cipher = Some(cipher);
self.fut_write = None; self.fut_write = None;
Poll::Ready(result.map(|_| buf.len())) Poll::Ready(result.map(|_| written_length))
}
Poll::Pending => Poll::Pending,
} }
} else { Poll::Pending => Poll::Pending,
Poll::Ready(Ok(buf.len()))
} }
} }
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let buffer_len = self.buffer.len(); let buffer_len = self.buffer.len();
if !self.buffer.is_empty() && self.fut_flush.is_none() { if self.fut_flush.is_none() {
let max_copy = min(u32::MAX as usize, buffer_len); let max_copy = min(u32::MAX as usize, buffer_len);
let plaintext = self.buffer.copy_to_bytes(max_copy); let plaintext = self.buffer.copy_to_bytes(max_copy);
let writer = self.inner.take().unwrap(); let writer = self.inner.take().unwrap();
let cipher = self.cipher.take().unwrap(); let cipher = self.cipher.take().unwrap();
self.fut_flush = Some(Box::pin(async move { self.fut_flush = Some(Box::pin(async move {
let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await; let (mut writer, cipher) = if plaintext.len() > 0 {
if result.is_err() { let (result, writer, cipher) = write_bytes(plaintext, writer, cipher).await;
return (result, writer, cipher); if result.is_err() {
} return (result, writer, cipher);
}
(writer, cipher)
} else {
(writer, cipher)
};
if let Err(e) = writer.flush().await { if let Err(e) = writer.flush().await {
(Err(e), writer, cipher) (Err(e), writer, cipher)
} else { } else {
@ -200,11 +198,13 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
} }
} }
#[tracing::instrument(level = "trace", skip_all)]
async fn write_bytes<T: AsyncWrite + Unpin>( async fn write_bytes<T: AsyncWrite + Unpin>(
bytes: Bytes, bytes: Bytes,
mut writer: T, mut writer: T,
cipher: CipherBox, cipher: CipherBox,
) -> (io::Result<()>, T, CipherBox) { ) -> (io::Result<()>, T, CipherBox) {
tracing::trace!("plaintext size: {}", bytes.len());
let encrypted_bytes = match cipher.encrypt(bytes) { let encrypted_bytes = match cipher.encrypt(bytes) {
Ok(b) => b, Ok(b) => b,
Err(e) => { Err(e) => {
@ -212,9 +212,29 @@ async fn write_bytes<T: AsyncWrite + Unpin>(
} }
}; };
let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes(); let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes();
tracing::trace!("encrypted size: {}", package_bytes.len());
if let Err(e) = writer.write_all(&package_bytes[..]).await { if let Err(e) = writer.write_all(&package_bytes[..]).await {
return (Err(e), writer, cipher); return (Err(e), writer, cipher);
} }
tracing::trace!("everything sent");
(Ok(()), writer, cipher) (Ok(()), writer, cipher)
} }
#[tracing::instrument(level = "trace", skip_all)]
async fn read_bytes<T: AsyncRead + Unpin>(
mut reader: T,
cipher: CipherBox,
) -> (io::Result<Bytes>, T, CipherBox) {
let package = match EncryptedPackage::from_async_read(&mut reader).await {
Ok(p) => p,
Err(e) => {
return (Err(e), reader, cipher);
}
};
tracing::trace!("received {} bytes", package.bytes.len());
match cipher.decrypt(package.into_inner()) {
Ok(bytes) => (Ok(bytes), reader, cipher),
Err(e) => (Err(e), reader, cipher),
}
}

@ -9,7 +9,7 @@ use rand_core::RngCore;
use std::future::Future; use std::future::Future;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use x25519_dalek::{SharedSecret, StaticSecret}; use x25519_dalek::{SharedSecret, StaticSecret};
use crate::prelude::encrypted::crypt_handling::CipherBox; use crate::prelude::encrypted::crypt_handling::CipherBox;
@ -80,8 +80,8 @@ impl<T: AsyncProtocolStream> EncryptedStream<T> {
} }
pub struct EncryptedReadStream<T: AsyncRead> { pub struct EncryptedReadStream<T: AsyncRead> {
inner: Option<BufReader<T>>, inner: Option<T>,
fut: OptionalFuture<(io::Result<Bytes>, BufReader<T>, CipherBox)>, fut: OptionalFuture<(io::Result<Bytes>, T, CipherBox)>,
remaining: BytesMut, remaining: BytesMut,
cipher: Option<CipherBox>, cipher: Option<CipherBox>,
} }
@ -89,7 +89,7 @@ pub struct EncryptedReadStream<T: AsyncRead> {
impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> { impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> {
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
Self { Self {
inner: Some(BufReader::new(inner)), inner: Some(inner),
fut: None, fut: None,
remaining: BytesMut::new(), remaining: BytesMut::new(),
cipher: Some(cipher), cipher: Some(cipher),
@ -98,18 +98,18 @@ impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> {
} }
pub struct EncryptedWriteStream<T: 'static + AsyncWrite + Unpin + Send + Sync> { pub struct EncryptedWriteStream<T: 'static + AsyncWrite + Unpin + Send + Sync> {
inner: Option<BufWriter<T>>, inner: Option<T>,
cipher: Option<CipherBox>, cipher: Option<CipherBox>,
buffer: BytesMut, buffer: BytesMut,
fut_write: OptionalFuture<(io::Result<()>, BufWriter<T>, CipherBox)>, fut_write: OptionalFuture<(io::Result<()>, T, CipherBox)>,
fut_flush: OptionalFuture<(io::Result<()>, BufWriter<T>, CipherBox)>, fut_flush: OptionalFuture<(io::Result<()>, T, CipherBox)>,
fut_shutdown: OptionalFuture<io::Result<()>>, fut_shutdown: OptionalFuture<io::Result<()>>,
} }
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> EncryptedWriteStream<T> { impl<T: 'static + AsyncWrite + Unpin + Send + Sync> EncryptedWriteStream<T> {
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
Self { Self {
inner: Some(BufWriter::new(inner)), inner: Some(inner),
cipher: Some(cipher), cipher: Some(cipher),
buffer: BytesMut::with_capacity(1024), buffer: BytesMut::with_capacity(1024),
fut_write: None, fut_write: None,

@ -14,25 +14,13 @@ use std::time::Duration;
mod utils; mod utils;
#[tokio::test] #[tokio::test]
async fn it_sends_and_receives() { async fn it_sends_and_receives_smaller_packages() {
let ctx = get_client_with_server().await; send_and_receive_bytes(140).await.unwrap();
let mut rng = rand::thread_rng(); }
let mut buffer = vec![0u8; 140];
rng.fill_bytes(&mut buffer);
let mut stream = ctx
.emit("bytes", BytePayload::new(buffer.clone()))
.stream_replies()
.await
.unwrap();
let mut count = 0;
while let Some(Ok(response)) = stream.next().await { #[tokio::test]
let bytes = response.payload::<BytePayload>().unwrap(); async fn it_sends_and_receives_larger_packages() {
assert_eq!(bytes.into_inner(), buffer); send_and_receive_bytes(1024 * 32).await.unwrap();
count += 1;
}
assert_eq!(count, 100)
} }
#[tokio::test] #[tokio::test]
@ -48,6 +36,28 @@ async fn it_sends_and_receives_strings() {
assert_eq!(&response_string, "Hello World") assert_eq!(&response_string, "Hello World")
} }
async fn send_and_receive_bytes(byte_size: usize) -> IPCResult<()> {
let ctx = get_client_with_server().await;
let mut rng = rand::thread_rng();
let mut buffer = vec![0u8; byte_size];
rng.fill_bytes(&mut buffer);
let mut stream = ctx
.emit("bytes", BytePayload::new(buffer.clone()))
.stream_replies()
.await?;
let mut count = 0;
while let Some(response) = stream.next().await {
let bytes = response.unwrap().payload::<BytePayload>()?;
assert_eq!(bytes.into_inner(), buffer);
count += 1;
}
assert_eq!(count, 100);
Ok(())
}
async fn get_client_with_server() -> Context { async fn get_client_with_server() -> Context {
let port = get_free_port(); let port = get_free_port();
@ -59,18 +69,18 @@ fn get_builder(port: u8) -> IPCBuilder<EncryptedListener<TestProtocolListener>>
.address(port) .address(port)
.on("bytes", callback!(handle_bytes)) .on("bytes", callback!(handle_bytes))
.on("string", callback!(handle_string)) .on("string", callback!(handle_string))
.timeout(Duration::from_millis(100)) .timeout(Duration::from_secs(10))
} }
async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult<Response> { async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult<Response> {
increment_counter_for_event(ctx, &event).await; increment_counter_for_event(ctx, &event).await;
let bytes = event.payload::<BytePayload>()?.into_bytes(); let bytes = event.payload::<BytePayload>()?.into_inner();
for _ in 0u8..99 { for _ in 0u8..99 {
ctx.emit("bytes", BytePayload::from(bytes.clone())).await?; ctx.emit("bytes", BytePayload::new(bytes.clone())).await?;
} }
ctx.response(BytePayload::from(bytes)) ctx.response(BytePayload::new(bytes))
} }
async fn handle_string(ctx: &Context, event: Event) -> IPCResult<Response> { async fn handle_string(ctx: &Context, event: Event) -> IPCResult<Response> {

@ -1,20 +1,32 @@
#![allow(unused)]
use bromine::context::Context; use bromine::context::Context;
use bromine::protocol::AsyncStreamProtocolListener; use bromine::protocol::AsyncStreamProtocolListener;
use bromine::IPCBuilder; use bromine::IPCBuilder;
use call_counter::*; use call_counter::*;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::oneshot::channel; use tokio::sync::oneshot::channel;
pub mod call_counter; pub mod call_counter;
pub mod protocol; pub mod protocol;
pub fn setup() {
lazy_static! {
static ref SETUP_DONE: Arc<AtomicBool> = Default::default();
}
if !SETUP_DONE.swap(true, Ordering::SeqCst) {
tracing_subscriber::fmt::init();
}
}
pub fn get_free_port() -> u8 { pub fn get_free_port() -> u8 {
lazy_static! { lazy_static! {
static ref PORT_COUNTER: Arc<AtomicU8> = Arc::new(AtomicU8::new(0)); static ref PORT_COUNTER: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
} }
PORT_COUNTER.fetch_add(1, Ordering::Relaxed) let count = PORT_COUNTER.fetch_add(1, Ordering::Relaxed);
count
} }
pub async fn start_server_and_client< pub async fn start_server_and_client<
@ -23,6 +35,7 @@ pub async fn start_server_and_client<
>( >(
builder_fn: F, builder_fn: F,
) -> Context { ) -> Context {
setup();
let counters = CallCounter::default(); let counters = CallCounter::default();
let (sender, receiver) = channel::<()>(); let (sender, receiver) = channel::<()>();
let client_builder = builder_fn().insert::<CallCounterKey>(counters.clone()); let client_builder = builder_fn().insert::<CallCounterKey>(counters.clone());

Loading…
Cancel
Save