Fix issues with encryption writers

Signed-off-by: trivernis <trivernis@protonmail.com>
pull/38/head
trivernis 3 years ago
parent fe7dc97008
commit ef99adfee1
Signed by: Trivernis
GPG Key ID: DFFFCC2C7A02DB45

79
Cargo.lock generated

@ -20,6 +20,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "ansi_term"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
dependencies = [
"winapi",
]
[[package]]
name = "async-trait"
version = "0.1.52"
@ -124,6 +133,7 @@ dependencies = [
"futures-core",
"lazy_static",
"num_enum",
"port_check",
"postcard",
"rand",
"rand_core 0.6.3",
@ -134,6 +144,7 @@ dependencies = [
"thiserror",
"tokio",
"tracing",
"tracing-subscriber",
"trait-bound-typemap",
"x25519-dalek",
]
@ -749,6 +760,12 @@ dependencies = [
"syn",
]
[[package]]
name = "once_cell"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9"
[[package]]
name = "oorandom"
version = "11.1.3"
@ -812,6 +829,12 @@ dependencies = [
"universal-hash",
]
[[package]]
name = "port_check"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6519412c9e0d4be579b9f0618364d19cb434b324fc6ddb1b27b1e682c7105ed"
[[package]]
name = "postcard"
version = "0.7.3"
@ -1104,12 +1127,27 @@ dependencies = [
"digest 0.10.3",
]
[[package]]
name = "sharded-slab"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31"
dependencies = [
"lazy_static",
]
[[package]]
name = "slab"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5"
[[package]]
name = "smallvec"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
[[package]]
name = "socket2"
version = "0.4.4"
@ -1193,6 +1231,15 @@ dependencies = [
"syn",
]
[[package]]
name = "thread_local"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180"
dependencies = [
"once_cell",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
@ -1270,6 +1317,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c"
dependencies = [
"lazy_static",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3"
dependencies = [
"lazy_static",
"log",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e0ab7bdc962035a87fba73f3acca9b8a8d0034c2e6f60b84aeaaddddc155dce"
dependencies = [
"ansi_term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]]
@ -1309,6 +1382,12 @@ dependencies = [
"subtle",
]
[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "vcell"
version = "0.1.3"

@ -56,6 +56,8 @@ features = ["alloc"]
rmp-serde = "1.0.0"
crossbeam-utils = "0.8.7"
futures = "0.3.21"
tracing-subscriber = "0.3.9"
port_check = "0.1.5"
[dev-dependencies.serde]
version = "1.0.136"

@ -6,7 +6,7 @@ use chacha20poly1305::aead::{Aead, NewAead};
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
use rand::thread_rng;
use rand_core::RngCore;
use sha2::{Digest, Sha224, Sha256};
use sha2::{Digest, Sha256};
use std::io;
use std::io::ErrorKind;
use std::sync::atomic::{AtomicU64, Ordering};
@ -38,6 +38,7 @@ impl CipherBox {
}
/// Encrypts the given message
#[tracing::instrument(level = "trace", skip_all)]
pub fn encrypt(&self, data: Bytes) -> io::Result<Bytes> {
self.cipher
.encrypt(&self.en_nonce(), &data[..])
@ -46,6 +47,7 @@ impl CipherBox {
}
/// Decrypts the given message
#[tracing::instrument(level = "trace", skip_all)]
pub fn decrypt(&self, data: Bytes) -> io::Result<Bytes> {
self.cipher
.decrypt(&self.de_nonce(), &data[..])
@ -56,6 +58,7 @@ impl CipherBox {
/// Updates the stored key.
/// This must be done simultaneously on server and client side
/// to keep track of the nonce
#[tracing::instrument(level = "trace", skip_all)]
pub fn update_key(&mut self, key: Bytes) {
let key = Key::from_slice(&key[..]);
self.cipher = ChaCha20Poly1305::new(key);
@ -64,6 +67,7 @@ impl CipherBox {
/// Resets the nonce counters.
/// This must be done simultaneously on server and client side.
#[tracing::instrument(level = "trace", skip_all)]
pub fn reset_counters(&mut self) {
self.de_count.store(0, Ordering::SeqCst);
self.en_count.store(0, Ordering::SeqCst);
@ -71,22 +75,24 @@ impl CipherBox {
fn en_nonce(&self) -> Nonce {
let count = self.en_count.fetch_add(1, Ordering::SeqCst);
tracing::trace!("encrypted count {}", count);
nonce_from_number(count)
}
fn de_nonce(&self) -> Nonce {
let count = self.de_count.fetch_add(1, Ordering::SeqCst);
tracing::trace!("decrypted count {}", count);
nonce_from_number(count)
}
}
/// Generates a nonce from a given number
/// the nonce is passed through sha224 for pseudo-randomness
/// The given number is repeated to fit the nonce bytes
fn nonce_from_number(number: u64) -> Nonce {
let number_bytes: [u8; 8] = number.to_be_bytes();
let sha_bytes = Sha224::digest(&number_bytes).to_vec();
let num_vec = number_bytes.repeat(2);
let mut nonce_bytes = [0u8; 12];
nonce_bytes.copy_from_slice(&sha_bytes[..12]);
nonce_bytes.copy_from_slice(&num_vec[..12]);
nonce_bytes.into()
}
@ -131,6 +137,7 @@ impl<T: AsyncProtocolStream> EncryptedStream<T> {
}
}
#[tracing::instrument(level = "debug", skip_all)]
async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult<PublicKey> {
let mut pk_buf = [0u8; 32];
stream.read_exact(&mut pk_buf).await?;
@ -138,6 +145,7 @@ async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult
Ok(PublicKey::from(pk_buf))
}
#[tracing::instrument(level = "debug", skip_all)]
async fn send_public_key<T: AsyncProtocolStream>(
stream: &mut T,
secret: &StaticSecret,
@ -149,6 +157,7 @@ async fn send_public_key<T: AsyncProtocolStream>(
Ok(())
}
#[tracing::instrument(level = "trace", skip_all)]
fn generate_secret() -> Vec<u8> {
let mut rng = thread_rng();
let mut buf = vec![0u8; 32];

@ -52,40 +52,35 @@ impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadSt
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.fut.is_none() {
if self.remaining.len() > 0 {
let max_copy = min(buf.remaining(), self.remaining.len());
let bytes = self.remaining.copy_to_bytes(max_copy);
buf.put_slice(&bytes);
tracing::trace!("{} bytes read from buffer", bytes.len());
}
if buf.remaining() > 0 {
let mut reader = self.inner.take().unwrap();
tracing::trace!("{} bytes remaining to read", buf.remaining());
let reader = self.inner.take().unwrap();
let cipher = self.cipher.take().unwrap();
self.fut = Some(Box::pin(async move {
let package = match EncryptedPackage::from_async_read(&mut reader).await {
Ok(p) => p,
Err(e) => {
return (Err(e), reader, cipher);
}
};
match cipher.decrypt(package.into_inner()) {
Ok(bytes) => (Ok(bytes), reader, cipher),
Err(e) => (Err(e), reader, cipher),
}
}));
self.fut = Some(Box::pin(async move { read_bytes(reader, cipher).await }));
} else {
return Poll::Ready(Ok(()));
}
}
if self.fut.is_some() {
match self.fut.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready((result, reader, cipher)) => {
self.inner = Some(reader);
self.cipher = Some(cipher);
match result {
Ok(bytes) => {
self.fut = None;
self.remaining.put(bytes);
let max_copy = min(self.remaining.len(), buf.remaining());
let bytes = self.remaining.copy_to_bytes(max_copy);
self.fut = None;
buf.put_slice(&bytes);
tracing::trace!("{} bytes read from buffer", bytes.len());
if buf.remaining() == 0 {
Poll::Ready(Ok(()))
@ -98,9 +93,6 @@ impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadSt
}
Poll::Pending => Poll::Pending,
}
} else {
Poll::Ready(Ok(()))
}
}
}
@ -112,11 +104,13 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
if buf.remaining() > 0 {
let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) };
self.buffer.put(Bytes::from(buf));
let written_length = buf.len();
if self.fut_write.is_none() {
self.buffer.put(Bytes::from(buf.to_vec()));
if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE {
if self.buffer.len() >= WRITE_BUF_SIZE {
tracing::trace!("buffer has reached sending size: {}", self.buffer.len());
let buffer_len = self.buffer.len();
let max_copy = min(u32::MAX as usize, buffer_len);
let plaintext = self.buffer.copy_to_bytes(max_copy);
@ -124,38 +118,42 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
let cipher = self.cipher.take().unwrap();
self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher)))
} else {
return Poll::Ready(Ok(written_length));
}
}
if self.fut_write.is_some() {
match self.fut_write.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready((result, writer, cipher)) => {
self.inner = Some(writer);
self.cipher = Some(cipher);
self.fut_write = None;
Poll::Ready(result.map(|_| buf.len()))
Poll::Ready(result.map(|_| written_length))
}
Poll::Pending => Poll::Pending,
}
} else {
Poll::Ready(Ok(buf.len()))
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let buffer_len = self.buffer.len();
if !self.buffer.is_empty() && self.fut_flush.is_none() {
if self.fut_flush.is_none() {
let max_copy = min(u32::MAX as usize, buffer_len);
let plaintext = self.buffer.copy_to_bytes(max_copy);
let writer = self.inner.take().unwrap();
let cipher = self.cipher.take().unwrap();
self.fut_flush = Some(Box::pin(async move {
let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await;
let (mut writer, cipher) = if plaintext.len() > 0 {
let (result, writer, cipher) = write_bytes(plaintext, writer, cipher).await;
if result.is_err() {
return (result, writer, cipher);
}
(writer, cipher)
} else {
(writer, cipher)
};
if let Err(e) = writer.flush().await {
(Err(e), writer, cipher)
} else {
@ -200,11 +198,13 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
}
}
#[tracing::instrument(level = "trace", skip_all)]
async fn write_bytes<T: AsyncWrite + Unpin>(
bytes: Bytes,
mut writer: T,
cipher: CipherBox,
) -> (io::Result<()>, T, CipherBox) {
tracing::trace!("plaintext size: {}", bytes.len());
let encrypted_bytes = match cipher.encrypt(bytes) {
Ok(b) => b,
Err(e) => {
@ -212,9 +212,29 @@ async fn write_bytes<T: AsyncWrite + Unpin>(
}
};
let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes();
tracing::trace!("encrypted size: {}", package_bytes.len());
if let Err(e) = writer.write_all(&package_bytes[..]).await {
return (Err(e), writer, cipher);
}
tracing::trace!("everything sent");
(Ok(()), writer, cipher)
}
#[tracing::instrument(level = "trace", skip_all)]
async fn read_bytes<T: AsyncRead + Unpin>(
mut reader: T,
cipher: CipherBox,
) -> (io::Result<Bytes>, T, CipherBox) {
let package = match EncryptedPackage::from_async_read(&mut reader).await {
Ok(p) => p,
Err(e) => {
return (Err(e), reader, cipher);
}
};
tracing::trace!("received {} bytes", package.bytes.len());
match cipher.decrypt(package.into_inner()) {
Ok(bytes) => (Ok(bytes), reader, cipher),
Err(e) => (Err(e), reader, cipher),
}
}

@ -9,7 +9,7 @@ use rand_core::RngCore;
use std::future::Future;
use std::io;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use x25519_dalek::{SharedSecret, StaticSecret};
use crate::prelude::encrypted::crypt_handling::CipherBox;
@ -80,8 +80,8 @@ impl<T: AsyncProtocolStream> EncryptedStream<T> {
}
pub struct EncryptedReadStream<T: AsyncRead> {
inner: Option<BufReader<T>>,
fut: OptionalFuture<(io::Result<Bytes>, BufReader<T>, CipherBox)>,
inner: Option<T>,
fut: OptionalFuture<(io::Result<Bytes>, T, CipherBox)>,
remaining: BytesMut,
cipher: Option<CipherBox>,
}
@ -89,7 +89,7 @@ pub struct EncryptedReadStream<T: AsyncRead> {
impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> {
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
Self {
inner: Some(BufReader::new(inner)),
inner: Some(inner),
fut: None,
remaining: BytesMut::new(),
cipher: Some(cipher),
@ -98,18 +98,18 @@ impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> {
}
pub struct EncryptedWriteStream<T: 'static + AsyncWrite + Unpin + Send + Sync> {
inner: Option<BufWriter<T>>,
inner: Option<T>,
cipher: Option<CipherBox>,
buffer: BytesMut,
fut_write: OptionalFuture<(io::Result<()>, BufWriter<T>, CipherBox)>,
fut_flush: OptionalFuture<(io::Result<()>, BufWriter<T>, CipherBox)>,
fut_write: OptionalFuture<(io::Result<()>, T, CipherBox)>,
fut_flush: OptionalFuture<(io::Result<()>, T, CipherBox)>,
fut_shutdown: OptionalFuture<io::Result<()>>,
}
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> EncryptedWriteStream<T> {
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
Self {
inner: Some(BufWriter::new(inner)),
inner: Some(inner),
cipher: Some(cipher),
buffer: BytesMut::with_capacity(1024),
fut_write: None,

@ -14,25 +14,13 @@ use std::time::Duration;
mod utils;
#[tokio::test]
async fn it_sends_and_receives() {
let ctx = get_client_with_server().await;
let mut rng = rand::thread_rng();
let mut buffer = vec![0u8; 140];
rng.fill_bytes(&mut buffer);
let mut stream = ctx
.emit("bytes", BytePayload::new(buffer.clone()))
.stream_replies()
.await
.unwrap();
let mut count = 0;
while let Some(Ok(response)) = stream.next().await {
let bytes = response.payload::<BytePayload>().unwrap();
assert_eq!(bytes.into_inner(), buffer);
count += 1;
async fn it_sends_and_receives_smaller_packages() {
send_and_receive_bytes(140).await.unwrap();
}
assert_eq!(count, 100)
#[tokio::test]
async fn it_sends_and_receives_larger_packages() {
send_and_receive_bytes(1024 * 32).await.unwrap();
}
#[tokio::test]
@ -48,6 +36,28 @@ async fn it_sends_and_receives_strings() {
assert_eq!(&response_string, "Hello World")
}
async fn send_and_receive_bytes(byte_size: usize) -> IPCResult<()> {
let ctx = get_client_with_server().await;
let mut rng = rand::thread_rng();
let mut buffer = vec![0u8; byte_size];
rng.fill_bytes(&mut buffer);
let mut stream = ctx
.emit("bytes", BytePayload::new(buffer.clone()))
.stream_replies()
.await?;
let mut count = 0;
while let Some(response) = stream.next().await {
let bytes = response.unwrap().payload::<BytePayload>()?;
assert_eq!(bytes.into_inner(), buffer);
count += 1;
}
assert_eq!(count, 100);
Ok(())
}
async fn get_client_with_server() -> Context {
let port = get_free_port();
@ -59,18 +69,18 @@ fn get_builder(port: u8) -> IPCBuilder<EncryptedListener<TestProtocolListener>>
.address(port)
.on("bytes", callback!(handle_bytes))
.on("string", callback!(handle_string))
.timeout(Duration::from_millis(100))
.timeout(Duration::from_secs(10))
}
async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult<Response> {
increment_counter_for_event(ctx, &event).await;
let bytes = event.payload::<BytePayload>()?.into_bytes();
let bytes = event.payload::<BytePayload>()?.into_inner();
for _ in 0u8..99 {
ctx.emit("bytes", BytePayload::from(bytes.clone())).await?;
ctx.emit("bytes", BytePayload::new(bytes.clone())).await?;
}
ctx.response(BytePayload::from(bytes))
ctx.response(BytePayload::new(bytes))
}
async fn handle_string(ctx: &Context, event: Event) -> IPCResult<Response> {

@ -1,20 +1,32 @@
#![allow(unused)]
use bromine::context::Context;
use bromine::protocol::AsyncStreamProtocolListener;
use bromine::IPCBuilder;
use call_counter::*;
use lazy_static::lazy_static;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
use std::sync::Arc;
use tokio::sync::oneshot::channel;
pub mod call_counter;
pub mod protocol;
pub fn setup() {
lazy_static! {
static ref SETUP_DONE: Arc<AtomicBool> = Default::default();
}
if !SETUP_DONE.swap(true, Ordering::SeqCst) {
tracing_subscriber::fmt::init();
}
}
pub fn get_free_port() -> u8 {
lazy_static! {
static ref PORT_COUNTER: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
}
PORT_COUNTER.fetch_add(1, Ordering::Relaxed)
let count = PORT_COUNTER.fetch_add(1, Ordering::Relaxed);
count
}
pub async fn start_server_and_client<
@ -23,6 +35,7 @@ pub async fn start_server_and_client<
>(
builder_fn: F,
) -> Context {
setup();
let counters = CallCounter::default();
let (sender, receiver) = channel::<()>();
let client_builder = builder_fn().insert::<CallCounterKey>(counters.clone());

Loading…
Cancel
Save