diff --git a/Cargo.toml b/Cargo.toml index f2f39d0..1b9b1e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vented" description = "Event driven encrypted tcp communicaton" -version = "0.9.1" +version = "0.9.2" authors = ["trivernis "] edition = "2018" readme = "README.md" diff --git a/src/server/data.rs b/src/server/data.rs index a4650c2..a01cb3c 100644 --- a/src/server/data.rs +++ b/src/server/data.rs @@ -7,6 +7,7 @@ use executors::crossbeam_workstealing_pool; use executors::parker::DynParker; use parking_lot::Mutex; use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use x25519_dalek::PublicKey; @@ -29,4 +30,5 @@ pub(crate) struct ServerConnectionContext { pub forwarded_connections: Arc>>>, pub pool: crossbeam_workstealing_pool::ThreadPool, pub redirect_handles: Arc>>>, + pub listener_count: Arc, } diff --git a/src/server/mod.rs b/src/server/mod.rs index 90f6fd7..15c3d77 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::net::{TcpListener, TcpStream}; +use std::net::{Shutdown, TcpListener, TcpStream}; use crypto_box::{PublicKey, SecretKey}; use executors::{crossbeam_workstealing_pool, Executor}; @@ -21,6 +21,7 @@ use parking_lot::Mutex; use sha2::Digest; use std::io::Write; use std::iter::FromIterator; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; use std::time::Duration; @@ -75,6 +76,8 @@ pub struct VentedServer { global_secret_key: SecretKey, node_id: String, redirect_handles: Arc>>>, + listener_count: Arc, + num_threads: usize, } impl VentedServer { @@ -89,6 +92,7 @@ impl VentedServer { ) -> Self { let mut server = Self { node_id, + num_threads, event_handler: Arc::new(Mutex::new(EventHandler::new())), pool: executors::crossbeam_workstealing_pool::pool_with_auto_parker(num_threads), connections: Arc::new(Mutex::new(HashMap::new())), @@ -98,6 +102,7 @@ impl VentedServer { nodes.iter().cloned().map(|node| (node.id.clone(), node)), ))), redirect_handles: Arc::new(Mutex::new(HashMap::new())), + listener_count: Arc::new(AtomicUsize::new(0)), }; server.register_events(); @@ -170,6 +175,8 @@ impl VentedServer { let context = self.get_server_context(); let wg = WaitGroup::new(); let wg2 = WaitGroup::clone(&wg); + let num_threads = self.num_threads; + let listener_count = Arc::clone(&self.listener_count); thread::spawn(move || match TcpListener::bind(&address) { Ok(listener) => { @@ -179,8 +186,17 @@ impl VentedServer { for connection in listener.incoming() { match connection { Ok(stream) => { - if let Err(e) = Self::handle_connection(context.clone(), stream) { - log::error!("Failed to handle connection: {}", e); + let listener_count = listener_count.load(Ordering::Relaxed); + + if listener_count >= num_threads { + log::warn!("Connection limit reached. Shutting down incoming connection..."); + if let Err(e) = stream.shutdown(Shutdown::Both) { + log::error!("Failed to shutdown connection: {}", e) + } + } else { + if let Err(e) = Self::handle_connection(context.clone(), stream) { + log::error!("Failed to handle connection: {}", e); + } } } Err(e) => log::trace!("Failed to establish connection: {}", e), @@ -208,6 +224,7 @@ impl VentedServer { pool: self.pool.clone(), forwarded_connections: Arc::clone(&self.forwarded_connections), redirect_handles: Arc::clone(&self.redirect_handles), + listener_count: Arc::clone(&self.listener_count), } } @@ -224,6 +241,7 @@ impl VentedServer { .filter(|node| node.address.is_some()) .cloned() .collect::>(); + for node in public_nodes { let payload = RedirectPayload::new( context.node_id.clone(), @@ -255,15 +273,16 @@ impl VentedServer { /// Handles a single connection by first performing a key exchange and /// then establishing an encrypted connection fn handle_connection(params: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> { - let pool = params.pool.clone(); let event_handler = Arc::clone(¶ms.event_handler); log::trace!( "Received connection from {}", stream.peer_addr().expect("Failed to get peer address") ); - pool.execute(move || { + thread::spawn(move || { let connections = Arc::clone(¶ms.connections); + let listener_count = Arc::clone(¶ms.listener_count); + listener_count.fetch_add(1, Ordering::Relaxed); let stream = match VentedServer::get_crypto_stream(params, stream) { Ok(stream) => stream, @@ -279,6 +298,7 @@ impl VentedServer { } connections.lock().remove(stream.receiver_node()); + listener_count.fetch_sub(1, Ordering::Relaxed); }); Ok(()) @@ -337,6 +357,9 @@ impl VentedServer { params: ServerConnectionContext, stream: TcpStream, ) -> VentedResult { + stream.set_read_timeout(Some(Duration::from_secs(10)))?; + stream.set_write_timeout(Some(Duration::from_secs(10)))?; + let (node_id, stream) = VentedServer::perform_key_exchange( params.is_server, stream, @@ -362,19 +385,21 @@ impl VentedServer { context.is_server = false; let connections = Arc::clone(&context.connections); - let pool = context.pool.clone(); let event_handler = Arc::clone(&context.event_handler); + let listener_count = Arc::clone(&context.listener_count); let stream = Self::get_crypto_stream(context, stream)?; - pool.execute({ + thread::spawn({ let stream = CryptoStream::clone(&stream); move || { + listener_count.fetch_add(1, Ordering::Relaxed); event_handler.lock().handle_event(Event::new(READY_EVENT)); if let Err(e) = Self::handle_read(event_handler, &stream) { log::error!("Connection aborted: {}", e); } connections.lock().remove(stream.receiver_node()); + listener_count.fetch_sub(1, Ordering::Relaxed); } }); diff --git a/tests/test_communication.rs b/tests/test_communication.rs index e665810..a2ceab3 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -43,7 +43,7 @@ fn test_server_communication() { trusted: false, }, ]; - let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 6); + let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 2); let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes.clone(), 3); let mut server_c = VentedServer::new("C".to_string(), global_secret_c, nodes, 3); let wg = server_a.listen("localhost:22222".to_string());