diff --git a/Cargo.toml b/Cargo.toml index dd94391..9b18ca6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vented" description = "Event driven encrypted tcp communicaton" -version = "0.11.4" +version = "0.11.5" authors = ["trivernis "] edition = "2018" readme = "README.md" diff --git a/src/server/mod.rs b/src/server/mod.rs index c99c725..93fb349 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -10,8 +10,8 @@ use std::iter::FromIterator; use std::sync::Arc; use std::time::{Duration, Instant}; +use async_std::sync::Mutex; use crypto_box::{PublicKey, SecretKey}; -use parking_lot::Mutex; use sha2::Digest; use x25519_dalek::StaticSecret; @@ -122,8 +122,7 @@ impl VentedServer { /// Returns the nodes known to the server pub fn nodes(&self) -> Vec { - self.known_nodes - .lock() + task::block_on(self.known_nodes.lock()) .values() .cloned() .map(Node::from) @@ -240,7 +239,7 @@ impl VentedServer { NodeState::Dead(Instant::now()) }; - if let Some(node) = self.known_nodes.lock().get_mut(target) { + if let Some(node) = self.known_nodes.lock().await.get_mut(target) { node.set_node_state(node_state); } @@ -252,6 +251,7 @@ impl VentedServer { let connected_nodes = self .known_nodes .lock() + .await .values() .filter(|node| node.is_alive()) .cloned() @@ -267,6 +267,7 @@ impl VentedServer { let mut value = AsyncValue::new(); self.redirect_handles .lock() + .await .insert(payload.id, AsyncValue::clone(&value)); if let Ok(mut stream) = self.get_connection(&node.node().id).await { @@ -307,6 +308,7 @@ impl VentedServer { log::trace!("Secure connection established."); self.connections .lock() + .await .insert(stream.receiver_node().clone(), stream.clone()); self.event_handler .handle_event(Event::new(READY_EVENT)) @@ -342,7 +344,7 @@ impl VentedServer { stream.receiver_node(), e ); - stream.shutdown().expect("Failed to shutdown stream"); + stream.shutdown().await.expect("Failed to shutdown stream"); } } }); @@ -357,15 +359,15 @@ impl VentedServer { } } } - connections.lock().remove(stream.receiver_node()); - stream.shutdown().expect("Failed to shutdown stream"); + connections.lock().await.remove(stream.receiver_node()); + stream.shutdown().await.expect("Failed to shutdown stream"); } /// Takes three attempts to retrieve a connection for the given node. /// First it tries to use the already established connection stored in the shared connections vector. /// If that fails it tries to establish a new connection to the node by using the known address async fn get_connection(&self, target: &String) -> VentedResult { - if let Some(stream) = self.connections.lock().get(target) { + if let Some(stream) = self.connections.lock().await.get(target) { log::trace!("Reusing existing connection."); return Ok(stream.clone()); } @@ -373,6 +375,7 @@ impl VentedServer { let target_node = self .known_nodes .lock() + .await .get(target) .cloned() .ok_or(VentedError::UnknownNode(target.clone()))?; @@ -386,6 +389,7 @@ impl VentedServer { log::error!("Failed to connect to node {}'s address: {}", target, e); self.known_nodes .lock() + .await .get_mut(target) .unwrap() .node_mut() @@ -405,6 +409,7 @@ impl VentedServer { let stream = self.perform_client_key_exchange(stream).await?; self.connections .lock() + .await .insert(stream.receiver_node().clone(), stream.clone()); task::spawn(Self::read_stream( stream.clone(), @@ -462,7 +467,7 @@ impl VentedServer { let public_key = PublicKey::from(public_key); - let node_data = if let Some(data) = self.known_nodes.lock().get(&node_id) { + let node_data = if let Some(data) = self.known_nodes.lock().await.get(&node_id) { data.clone() } else { stream.write(&Event::new(REJECT_EVENT).as_bytes()).await?; @@ -486,7 +491,7 @@ impl VentedServer { let final_secret = Self::generate_final_secret(pre_secret.to_bytes().to_vec(), key_a, key_b); let final_public = final_secret.public_key(); - stream.update_key(&final_secret, &final_public); + stream.update_key(&final_secret, &final_public).await; Ok(stream) } @@ -524,7 +529,7 @@ impl VentedServer { } let public_key = PublicKey::from(public_key); - let data_options = self.known_nodes.lock().get(&node_id).cloned(); + let data_options = self.known_nodes.lock().await.get(&node_id).cloned(); let node_data = if let Some(data) = data_options { data } else { @@ -564,7 +569,7 @@ impl VentedServer { let final_secret = Self::generate_final_secret(pre_secret.to_bytes().to_vec(), key_a, key_b); let final_public = final_secret.public_key(); - stream.update_key(&final_secret, &final_public); + stream.update_key(&final_secret, &final_public).await; Ok(stream) } diff --git a/src/server/server_events.rs b/src/server/server_events.rs index 0a40f76..05bb279 100644 --- a/src/server/server_events.rs +++ b/src/server/server_events.rs @@ -113,7 +113,7 @@ impl VentedServer { let redirect_handles = Arc::clone(&redirect_handles); Box::pin(async move { let payload = event.get_payload::().ok()?; - let mut value = redirect_handles.lock().remove(&payload.id)?; + let mut value = redirect_handles.lock().await.remove(&payload.id)?; value.resolve(()); None }) @@ -125,7 +125,7 @@ impl VentedServer { let redirect_handles = Arc::clone(&redirect_handles); Box::pin(async move { let payload = event.get_payload::().ok()?; - let mut value = redirect_handles.lock().remove(&payload.id)?; + let mut value = redirect_handles.lock().await.remove(&payload.id)?; value.reject(VentedError::Rejected); None @@ -146,7 +146,7 @@ impl VentedServer { payload.proxy, payload.target ); - let opt_stream = connections.lock().get(&payload.target).cloned(); + let opt_stream = connections.lock().await.get(&payload.target).cloned(); if let Some(mut stream) = opt_stream { if let Ok(_) = stream .send(Event::with_payload(REDIRECT_REDIRECTED_EVENT, &payload)) @@ -202,15 +202,15 @@ impl VentedServer { ) }) .collect::>(); - let opt_stream = connections.lock().get(&origin).cloned(); + let opt_stream = connections.lock().await.get(&origin).cloned(); log::trace!("Sending responses..."); if let Some(mut stream) = opt_stream { for response in responses { if let Err(e) = stream.send(response).await { log::error!("Failed to send response events: {}", e); - connections.lock().remove(stream.receiver_node()); - stream.shutdown().expect("Failed to shutdown stream"); + connections.lock().await.remove(stream.receiver_node()); + stream.shutdown().await.expect("Failed to shutdown stream"); } } } @@ -230,7 +230,7 @@ impl VentedServer { let own_node_id = own_node_id.clone(); Box::pin(async move { let list = event.get_payload::().ok()?; - let mut own_nodes = node_list.lock(); + let mut own_nodes = node_list.lock().await; let origin = event.origin?; if !own_nodes.get(&origin)?.node().trusted { @@ -269,6 +269,7 @@ impl VentedServer { let sender_id = event.origin?; let nodes = node_list .lock() + .await .values() .filter(|node| node.node().id != sender_id) .map(|node| NodeListElement { diff --git a/src/stream/cryptostream.rs b/src/stream/cryptostream.rs index b7d32d1..5e72fcb 100644 --- a/src/stream/cryptostream.rs +++ b/src/stream/cryptostream.rs @@ -6,11 +6,11 @@ use async_std::prelude::*; +use async_std::sync::Mutex; use byteorder::{BigEndian, ByteOrder}; use crypto_box::aead::{Aead, Payload}; use crypto_box::{ChaChaBox, SecretKey}; use generic_array::GenericArray; -use parking_lot::Mutex; use sha2::Digest; use std::sync::Arc; use typenum::*; @@ -24,7 +24,8 @@ use async_std::net::{Shutdown, TcpStream}; #[derive(Clone)] pub struct CryptoStream { recv_node_id: String, - stream: TcpStream, + read_stream: Arc>, + write_stream: Arc>, send_secret: Arc>>, recv_secret: Arc>>, } @@ -42,7 +43,8 @@ impl CryptoStream { Ok(Self { recv_node_id: node_id, - stream: inner, + read_stream: Arc::new(Mutex::new(inner.clone())), + write_stream: Arc::new(Mutex::new(inner)), send_secret: Arc::new(Mutex::new(send_box)), recv_secret: Arc::new(Mutex::new(recv_box)), }) @@ -53,15 +55,16 @@ impl CryptoStream { /// length: u64 /// data: length pub async fn send(&mut self, mut event: Event) -> VentedResult<()> { - let ciphertext = self.send_secret.lock().encrypt(&event.as_bytes())?; + let ciphertext = self.send_secret.lock().await.encrypt(&event.as_bytes())?; let mut length_raw = [0u8; 8]; BigEndian::write_u64(&mut length_raw, ciphertext.len() as u64); log::trace!("Encoded event '{}' to raw message", event.name); - self.stream.write(&length_raw).await?; - self.stream.write(&ciphertext).await?; - self.stream.flush().await?; + let mut stream = self.write_stream.lock().await; + stream.write(&length_raw).await?; + stream.write(&ciphertext).await?; + stream.flush().await?; log::trace!("Event sent"); @@ -71,14 +74,15 @@ impl CryptoStream { /// Reads an event from the stream. Blocks until data is received pub async fn read(&mut self) -> VentedResult { let mut length_raw = [0u8; 8]; - self.stream.read_exact(&mut length_raw).await?; + let mut stream = self.read_stream.lock().await; + stream.read_exact(&mut length_raw).await?; let length = BigEndian::read_u64(&length_raw); let mut ciphertext = vec![0u8; length as usize]; - self.stream.read(&mut ciphertext).await?; + stream.read(&mut ciphertext).await?; log::trace!("Received raw message"); - let plaintext = self.recv_secret.lock().decrypt(&ciphertext)?; + let plaintext = self.recv_secret.lock().await.decrypt(&ciphertext)?; let event = Event::from(&mut &plaintext[..])?; log::trace!("Decoded message to event '{}'", event.name); @@ -87,11 +91,11 @@ impl CryptoStream { } /// Updates the keys in the inner encryption box - pub fn update_key(&self, secret_key: &SecretKey, public_key: &PublicKey) { + pub async fn update_key(&self, secret_key: &SecretKey, public_key: &PublicKey) { let send_box = ChaChaBox::new(public_key, secret_key); let recv_box = ChaChaBox::new(public_key, secret_key); - self.send_secret.lock().swap_box(send_box); - self.recv_secret.lock().swap_box(recv_box); + self.send_secret.lock().await.swap_box(send_box); + self.recv_secret.lock().await.swap_box(recv_box); log::trace!("Updated secret"); } @@ -100,8 +104,8 @@ impl CryptoStream { } /// Closes both streams - pub fn shutdown(&mut self) -> VentedResult<()> { - self.stream.shutdown(Shutdown::Both)?; + pub async fn shutdown(&mut self) -> VentedResult<()> { + self.read_stream.lock().await.shutdown(Shutdown::Both)?; Ok(()) } diff --git a/tests/test_communication.rs b/tests/test_communication.rs index c2bedf4..084f649 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -1,21 +1,21 @@ +use async_std::sync::Mutex; use async_std::task; use crypto_box::SecretKey; use log::LevelFilter; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use vented::event::Event; use vented::server::data::{Node, ServerTimeouts}; use vented::server::server_events::NODE_LIST_REQUEST_EVENT; use vented::server::VentedServer; fn setup() { - simple_logger::SimpleLogger::new() + let _ = simple_logger::SimpleLogger::new() .with_module_level("async_std", LevelFilter::Warn) .with_module_level("async_io", LevelFilter::Warn) .with_module_level("polling", LevelFilter::Warn) - .init() - .unwrap(); + .init(); } #[test] @@ -147,3 +147,100 @@ fn test_server_communication() { assert_eq!(pong_count.load(Ordering::SeqCst), 10); assert!(c_pinged.load(Ordering::SeqCst)); } + +const COUNT: usize = 20000; + +#[test] +fn test_high_traffic() { + setup(); + let ping_count = Arc::new(AtomicUsize::new(0)); + let pong_count = Arc::new(AtomicUsize::new(0)); + let last_pong = Arc::new(Mutex::new(Instant::now())); + let running = Arc::new(AtomicBool::new(true)); + let mut rng = rand::thread_rng(); + let global_secret_a = SecretKey::generate(&mut rng); + let global_secret_b = SecretKey::generate(&mut rng); + + let nodes = vec![ + Node { + id: "A".to_string(), + addresses: vec!["localhost:22223".to_string()], + public_key: global_secret_a.public_key(), + trusted: true, + }, + Node { + id: "B".to_string(), + addresses: vec![], + public_key: global_secret_b.public_key(), + trusted: false, + }, + ]; + + task::block_on(async { + let mut server_a = VentedServer::new( + "A".to_string(), + global_secret_a, + nodes.clone(), + ServerTimeouts::default(), + ); + let mut server_b = VentedServer::new( + "B".to_string(), + global_secret_b, + nodes, + ServerTimeouts::default(), + ); + server_a.listen("localhost:22223".to_string()); + task::sleep(Duration::from_millis(10)).await; + + server_a.on("ping", { + let ping_count = Arc::clone(&ping_count); + move |_| { + let ping_count = Arc::clone(&ping_count); + Box::pin(async move { + ping_count.fetch_add(1, Ordering::Relaxed); + + Some(Event::new("pong".to_string())) + }) + } + }); + server_b.on("pong", { + let pong_count = Arc::clone(&pong_count); + let running = Arc::clone(&running); + let last_pong = Arc::clone(&last_pong); + + move |_| { + let pong_count = Arc::clone(&pong_count); + let running = Arc::clone(&running); + let last_pong = Arc::clone(&last_pong); + + Box::pin(async move { + *last_pong.lock().await = Instant::now(); + let num = pong_count.fetch_add(1, Ordering::Relaxed); + log::info!("Received pong nr. {}", num); + if num >= COUNT - (COUNT / 100) { + running.store(true, Ordering::Relaxed); + None + } else { + Some(Event::new("ping".to_string())) + } + }) + } + }); + let mut promises = Vec::new(); + + for _ in 0..COUNT / 100 { + promises.push(server_b.emit("A", Event::new("ping"))); + } + futures::future::join_all(promises).await; + + while running.load(Ordering::Relaxed) + && last_pong.lock().await.elapsed() < Duration::from_secs(2) + { + task::sleep(Duration::from_micros(10)).await; + } + }); + // wait one second to make sure the servers were able to process the events + + assert_eq!(ping_count.load(Ordering::SeqCst), COUNT); + assert_eq!(pong_count.load(Ordering::SeqCst), COUNT); +}