diff --git a/Cargo.toml b/Cargo.toml index 1b9b1e7..db3e83b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vented" description = "Event driven encrypted tcp communicaton" -version = "0.9.2" +version = "0.10.0" authors = ["trivernis "] edition = "2018" readme = "README.md" @@ -17,7 +17,7 @@ rmp-serde = "0.14.4" serde = { version = "1.0.117", features = ["serde_derive"] } byteorder = "1.3.4" parking_lot = "0.11.0" -executors = "0.8.0" +scheduled-thread-pool = "0.2.5" log = "0.4.11" crypto_box = "0.5.0" rand = "0.7.3" @@ -26,6 +26,7 @@ generic-array = "0.14.4" typenum = "1.12.0" x25519-dalek = "1.1.0" crossbeam-utils = "0.8.0" +crossbeam-channel = "0.5.0" [dev-dependencies] simple_logger = "1.11.0" \ No newline at end of file diff --git a/src/event/mod.rs b/src/event/mod.rs index a9ea2ac..540a660 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -1,10 +1,11 @@ use std::io::Read; -use crate::utils::result::{VentedError, VentedResult}; use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use crate::utils::result::{VentedError, VentedResult}; + pub trait GenericEvent {} #[cfg(test)] diff --git a/src/event_handler/mod.rs b/src/event_handler/mod.rs index cf58e20..0531bec 100644 --- a/src/event_handler/mod.rs +++ b/src/event_handler/mod.rs @@ -20,8 +20,8 @@ impl EventHandler { /// Adds a handler for the given event pub fn on(&mut self, event_name: &str, handler: F) - where - F: Fn(Event) -> Option + Send + Sync, + where + F: Fn(Event) -> Option + Send + Sync, { match self.event_handlers.get_mut(event_name) { Some(handlers) => handlers.push(Box::new(handler)), diff --git a/src/event_handler/tests.rs b/src/event_handler/tests.rs index 7845141..601f5ab 100644 --- a/src/event_handler/tests.rs +++ b/src/event_handler/tests.rs @@ -1,5 +1,5 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use crate::event::Event; use crate::event_handler::EventHandler; diff --git a/src/lib.rs b/src/lib.rs index f7a3996..154ecb4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ -pub mod crypto; +pub use crossbeam_utils::sync::WaitGroup; + pub mod event; pub mod event_handler; pub mod server; +pub mod stream; pub mod utils; -pub use crossbeam_utils::sync::WaitGroup; diff --git a/src/server/data.rs b/src/server/data.rs index a01cb3c..660288f 100644 --- a/src/server/data.rs +++ b/src/server/data.rs @@ -1,16 +1,17 @@ -use crate::crypto::CryptoStream; -use crate::event_handler::EventHandler; -use crate::utils::result::VentedError; -use crate::utils::sync::AsyncValue; -use crypto_box::SecretKey; -use executors::crossbeam_workstealing_pool; -use executors::parker::DynParker; -use parking_lot::Mutex; use std::collections::HashMap; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; + +use crypto_box::SecretKey; +use parking_lot::Mutex; +use scheduled_thread_pool::ScheduledThreadPool; use x25519_dalek::PublicKey; +use crate::event_handler::EventHandler; +use crate::stream::cryptostream::CryptoStream; +use crate::stream::manager::ConcurrentStreamManager; +use crate::utils::result::VentedError; +use crate::utils::sync::AsyncValue; + #[derive(Clone, Debug)] pub struct Node { pub id: String, @@ -26,9 +27,8 @@ pub(crate) struct ServerConnectionContext { pub global_secret: SecretKey, pub known_nodes: Arc>>, pub event_handler: Arc>, - pub connections: Arc>>, pub forwarded_connections: Arc>>>, - pub pool: crossbeam_workstealing_pool::ThreadPool, + pub pool: Arc>, pub redirect_handles: Arc>>>, - pub listener_count: Arc, + pub manager: ConcurrentStreamManager, } diff --git a/src/server/mod.rs b/src/server/mod.rs index 15c3d77..cdfcd6e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,10 +1,19 @@ use std::collections::HashMap; -use std::net::{Shutdown, TcpListener, TcpStream}; +use std::io::Write; +use std::iter::FromIterator; +use std::mem; +use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; +use crossbeam_utils::sync::WaitGroup; use crypto_box::{PublicKey, SecretKey}; -use executors::{crossbeam_workstealing_pool, Executor}; +use parking_lot::Mutex; +use scheduled_thread_pool::ScheduledThreadPool; +use sha2::Digest; +use x25519_dalek::StaticSecret; -use crate::crypto::CryptoStream; use crate::event::Event; use crate::event_handler::EventHandler; use crate::server::data::{Node, ServerConnectionContext}; @@ -13,19 +22,10 @@ use crate::server::server_events::{ ACCEPT_EVENT, AUTH_EVENT, CHALLENGE_EVENT, CONNECT_EVENT, MISMATCH_EVENT, READY_EVENT, REDIRECT_EVENT, REJECT_EVENT, }; +use crate::stream::cryptostream::CryptoStream; +use crate::stream::manager::ConcurrentStreamManager; use crate::utils::result::{VentedError, VentedResult}; use crate::utils::sync::AsyncValue; -use crossbeam_utils::sync::WaitGroup; -use executors::parker::DynParker; -use parking_lot::Mutex; -use sha2::Digest; -use std::io::Write; -use std::iter::FromIterator; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::thread; -use std::time::Duration; -use x25519_dalek::StaticSecret; pub mod data; pub mod server_events; @@ -33,14 +33,13 @@ pub mod server_events; pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION"); type ForwardFutureVector = Arc>>>; -type CryptoStreamMap = Arc>>; /// The vented server that provides parallel handling of connections /// Usage: /// ```rust /// use vented::server::VentedServer; /// use vented::server::data::Node; -/// use vented::crypto::SecretKey; +/// use vented::stream::SecretKey; /// use rand::thread_rng; /// use vented::event::Event; /// @@ -56,7 +55,7 @@ type CryptoStreamMap = Arc>>; /// // in a real world example the secret key needs to be loaded from somewhere because connections /// // with unknown keys are not accepted. /// let global_secret = SecretKey::generate(&mut thread_rng()); -/// let mut server = VentedServer::new("A".to_string(), global_secret, nodes.clone(), 4); +/// let mut server = VentedServer::new("A".to_string(), global_secret, nodes.clone(), 4, 100); /// /// /// server.listen("localhost:20000".to_string()); @@ -65,19 +64,17 @@ type CryptoStreamMap = Arc>>; /// /// None // the return value is the response event Option /// }); -/// assert!(server.emit("B".to_string(), Event::new("ping".to_string())).get_value().is_err()) // this won't work without a known node B +/// assert!(server.emit("B", Event::new("ping".to_string())).get_value().is_err()) // this won't work without a known node B /// ``` pub struct VentedServer { - connections: CryptoStreamMap, forwarded_connections: ForwardFutureVector, known_nodes: Arc>>, - pool: crossbeam_workstealing_pool::ThreadPool, event_handler: Arc>, global_secret_key: SecretKey, node_id: String, redirect_handles: Arc>>>, - listener_count: Arc, - num_threads: usize, + manager: ConcurrentStreamManager, + pool: Arc>, } impl VentedServer { @@ -89,22 +86,22 @@ impl VentedServer { secret_key: SecretKey, nodes: Vec, num_threads: usize, + max_threads: usize, ) -> Self { let mut server = Self { node_id, - num_threads, + manager: ConcurrentStreamManager::new(max_threads), event_handler: Arc::new(Mutex::new(EventHandler::new())), - pool: executors::crossbeam_workstealing_pool::pool_with_auto_parker(num_threads), - connections: Arc::new(Mutex::new(HashMap::new())), forwarded_connections: Arc::new(Mutex::new(HashMap::new())), global_secret_key: secret_key, known_nodes: Arc::new(Mutex::new(HashMap::from_iter( nodes.iter().cloned().map(|node| (node.id.clone(), node)), ))), redirect_handles: Arc::new(Mutex::new(HashMap::new())), - listener_count: Arc::new(AtomicUsize::new(0)), + pool: Arc::new(Mutex::new(ScheduledThreadPool::new(num_threads))), }; server.register_events(); + server.start_event_listener(); server } @@ -126,35 +123,9 @@ impl VentedServer { /// Emits an event to the specified Node /// The actual writing is done in a separate thread from the thread pool. - /// With the returned wait group one can wait for the event to be written. - pub fn emit(&self, node_id: String, event: Event) -> AsyncValue<(), VentedError> { - let future = AsyncValue::new(); - - self.pool.execute({ - let mut future = AsyncValue::clone(&future); - let context = self.get_server_context(); - move || { - - if let Ok(stream) = Self::get_connection(context.clone(), &node_id) { - if let Err(e) = stream.send(event) { - log::error!("Failed to send event: {}", e); - context.connections.lock().remove(stream.receiver_node()); - - future.reject(e); - } else { - future.resolve(()); - } - } else { - log::trace!( - "Trying to redirect the event to a different node to be sent to target node..." - ); - let result = Self::send_event_redirected(context.clone(), node_id, event); - future.result(result); - } - } - }); - - future + /// For that reason an Async value is returned to use it to wait for the result + pub fn emit(&self, node_id: S, event: Event) -> AsyncValue<(), VentedError> { + Self::send_event(self.get_server_context(), &node_id.to_string(), event, true) } /// Adds a handler for the given event. @@ -175,8 +146,6 @@ impl VentedServer { let context = self.get_server_context(); let wg = WaitGroup::new(); let wg2 = WaitGroup::clone(&wg); - let num_threads = self.num_threads; - let listener_count = Arc::clone(&self.listener_count); thread::spawn(move || match TcpListener::bind(&address) { Ok(listener) => { @@ -186,17 +155,8 @@ impl VentedServer { for connection in listener.incoming() { match connection { Ok(stream) => { - let listener_count = listener_count.load(Ordering::Relaxed); - - if listener_count >= num_threads { - log::warn!("Connection limit reached. Shutting down incoming connection..."); - if let Err(e) = stream.shutdown(Shutdown::Both) { - log::error!("Failed to shutdown connection: {}", e) - } - } else { - if let Err(e) = Self::handle_connection(context.clone(), stream) { - log::error!("Failed to handle connection: {}", e); - } + if let Err(e) = Self::handle_connection(context.clone(), stream) { + log::error!("Failed to handle connection: {}", e); } } Err(e) => log::trace!("Failed to establish connection: {}", e), @@ -219,19 +179,87 @@ impl VentedServer { node_id: self.node_id.clone(), global_secret: self.global_secret_key.clone(), known_nodes: Arc::clone(&self.known_nodes), - connections: Arc::clone(&self.connections), event_handler: Arc::clone(&self.event_handler), - pool: self.pool.clone(), + pool: Arc::clone(&self.pool), forwarded_connections: Arc::clone(&self.forwarded_connections), redirect_handles: Arc::clone(&self.redirect_handles), - listener_count: Arc::clone(&self.listener_count), + manager: self.manager.clone(), + } + } + + /// Starts the event listener thread + fn start_event_listener(&self) { + let receiver = self.manager.receiver(); + let event_handler = Arc::clone(&self.event_handler); + let context = self.get_server_context(); + let wg = WaitGroup::new(); + + thread::spawn({ + let wg = WaitGroup::clone(&wg); + move || { + mem::drop(wg); + while let Ok((origin, event)) = receiver.recv() { + let responses = event_handler.lock().handle_event(event); + + for response in responses { + Self::send_event(context.clone(), &origin, response, true); + } + } + log::warn!("Event listener stopped!"); + } + }); + wg.wait(); + } + + /// Sends an event asynchronously to a node + /// The redirect flag is used to determine if it should be tried to redirect an event after + /// a direct sending attempt failed + fn send_event( + context: ServerConnectionContext, + target: &String, + event: Event, + redirect: bool, + ) -> AsyncValue<(), VentedError> { + if context.manager.has_connection(target) { + context.manager.send(target, event) + } else { + let future = AsyncValue::new(); + + context.pool.lock().execute({ + let mut future = AsyncValue::clone(&future); + let node_id = target.clone(); + let context = context.clone(); + + move || { + log::trace!( + "Trying to redirect the event to a different node to be sent to target node..." + ); + if let Ok(connection) = Self::get_connection(context.clone(), &node_id) { + if let Err(e) = context.manager.add_connection(connection) { + future.reject(e); + return; + } + log::trace!("Established new connection."); + let result = context.manager.send(&node_id, event).get_value(); + future.result(result); + } else if redirect { + log::trace!("Trying to send event redirected"); + let result = Self::send_event_redirected(context, &node_id, event); + future.result(result); + } else { + future.reject(VentedError::UnreachableNode(node_id)) + } + } + }); + + future } } /// Tries to send an event redirected by emitting a redirect event to all public nodes fn send_event_redirected( context: ServerConnectionContext, - target: String, + target: &String, event: Event, ) -> VentedResult<()> { let public_nodes = context @@ -255,70 +283,58 @@ impl VentedServer { .lock() .insert(payload.id, AsyncValue::clone(&future)); - if let Ok(stream) = Self::get_connection(context.clone(), &node.id) { - if let Err(e) = stream.send(Event::with_payload(REDIRECT_EVENT, &payload)) { - log::error!("Failed to send event: {}", e); - context.connections.lock().remove(stream.receiver_node()); - } + if let Err(e) = Self::send_event( + context.clone(), + &node.id, + Event::with_payload(REDIRECT_EVENT, &payload), + false, + ) + .get_value() + { + log::error!("Failed to redirect via {}: {}", node.id, e); } - if let Some(Ok(_)) = future.get_value_with_timeout(Duration::from_secs(1)) { + if let Some(Ok(_)) = future.get_value_with_timeout(Duration::from_secs(10)) { return Ok(()); } } - Err(VentedError::UnreachableNode(target)) + Err(VentedError::UnreachableNode(target.clone())) } /// Handles a single connection by first performing a key exchange and /// then establishing an encrypted connection - fn handle_connection(params: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> { - let event_handler = Arc::clone(¶ms.event_handler); + fn handle_connection(context: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> { + let event_handler = Arc::clone(&context.event_handler); log::trace!( "Received connection from {}", stream.peer_addr().expect("Failed to get peer address") ); - thread::spawn(move || { - let connections = Arc::clone(¶ms.connections); - let listener_count = Arc::clone(¶ms.listener_count); - listener_count.fetch_add(1, Ordering::Relaxed); + context.pool.lock().execute({ + let context = context.clone(); + move || { + let manager = context.manager.clone(); - let stream = match VentedServer::get_crypto_stream(params, stream) { - Ok(stream) => stream, - Err(e) => { - log::error!("Failed to establish encrypted connection: {}", e); + let stream = match VentedServer::get_crypto_stream(context, stream) { + Ok(stream) => stream, + Err(e) => { + log::error!("Failed to establish encrypted connection: {}", e); + return; + } + }; + log::trace!("Secure connection established."); + if let Err(e) = manager.add_connection(stream) { + log::trace!("Failed to add connection to manager: {}", e); return; } - }; - log::trace!("Secure connection established."); - event_handler.lock().handle_event(Event::new(READY_EVENT)); - if let Err(e) = Self::handle_read(event_handler, &stream) { - log::error!("Connection aborted: {}", e); + event_handler.lock().handle_event(Event::new(READY_EVENT)); } - - connections.lock().remove(stream.receiver_node()); - listener_count.fetch_sub(1, Ordering::Relaxed); }); Ok(()) } - /// Handler for reading after the connection is established - fn handle_read( - event_handler: Arc>, - stream: &CryptoStream, - ) -> VentedResult<()> { - while let Ok(mut event) = stream.read() { - event.origin = Some(stream.receiver_node().clone()); - for response in event_handler.lock().handle_event(event) { - stream.send(response)? - } - } - - Ok(()) - } - /// Takes three attempts to retrieve a connection for the given node. /// First it tries to use the already established connection stored in the shared connections vector. /// If that fails it tries to establish a new connection to the node by using the known address @@ -326,14 +342,6 @@ impl VentedServer { context: ServerConnectionContext, target: &String, ) -> VentedResult { - log::trace!("Trying to connect to {}", target); - - if let Some(stream) = context.connections.lock().get(target) { - log::trace!("Reusing existing connection."); - - return Ok(CryptoStream::clone(stream)); - } - let target_node = context .known_nodes .lock() @@ -354,25 +362,20 @@ impl VentedServer { /// Establishes a crypto stream for the given stream fn get_crypto_stream( - params: ServerConnectionContext, + context: ServerConnectionContext, stream: TcpStream, ) -> VentedResult { stream.set_read_timeout(Some(Duration::from_secs(10)))?; stream.set_write_timeout(Some(Duration::from_secs(10)))?; - let (node_id, stream) = VentedServer::perform_key_exchange( - params.is_server, + let (_, stream) = VentedServer::perform_key_exchange( + context.is_server, stream, - params.node_id.clone(), - params.global_secret, - params.known_nodes, + context.node_id.clone(), + context.global_secret, + context.known_nodes, )?; - params - .connections - .lock() - .insert(node_id, CryptoStream::clone(&stream)); - Ok(stream) } @@ -383,26 +386,8 @@ impl VentedServer { ) -> VentedResult { let stream = TcpStream::connect(address)?; context.is_server = false; - - let connections = Arc::clone(&context.connections); - let event_handler = Arc::clone(&context.event_handler); - let listener_count = Arc::clone(&context.listener_count); let stream = Self::get_crypto_stream(context, stream)?; - thread::spawn({ - let stream = CryptoStream::clone(&stream); - - move || { - listener_count.fetch_add(1, Ordering::Relaxed); - event_handler.lock().handle_event(Event::new(READY_EVENT)); - if let Err(e) = Self::handle_read(event_handler, &stream) { - log::error!("Connection aborted: {}", e); - } - connections.lock().remove(stream.receiver_node()); - listener_count.fetch_sub(1, Ordering::Relaxed); - } - }); - Ok(stream) } diff --git a/src/server/server_events.rs b/src/server/server_events.rs index 4f5f569..af9b813 100644 --- a/src/server/server_events.rs +++ b/src/server/server_events.rs @@ -1,12 +1,13 @@ +use std::sync::Arc; + +use rand::{thread_rng, RngCore}; +use serde::{Deserialize, Serialize}; +use x25519_dalek::PublicKey; + use crate::event::Event; use crate::server::data::Node; use crate::server::VentedServer; use crate::utils::result::VentedError; -use executors::Executor; -use rand::{thread_rng, RngCore}; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use x25519_dalek::PublicKey; pub(crate) const CONNECT_EVENT: &str = "conn:connect"; pub(crate) const AUTH_EVENT: &str = "conn:authenticate"; @@ -21,7 +22,7 @@ pub(crate) const REDIRECT_REDIRECTED_EVENT: &str = "conn:redirect_redirected"; pub const NODE_LIST_REQUEST_EVENT: &str = "conn:node_list_request"; pub const NODE_LIST_EVENT: &str = "conn:node_list"; -pub const READY_EVENT: &str = "server:ready"; +pub(crate) const READY_EVENT: &str = "server:ready"; #[derive(Serialize, Deserialize, Debug)] pub(crate) struct NodeInformationPayload { @@ -120,59 +121,73 @@ impl VentedServer { } }); self.on(REDIRECT_EVENT, { - let connections = Arc::clone(&self.connections); + let manager = self.manager.clone(); + let pool = Arc::clone(&self.pool); move |event| { let payload = event.get_payload::().ok()?; - let stream = connections.lock().get(&payload.target)?.clone(); - if stream - .send(Event::with_payload(REDIRECT_REDIRECTED_EVENT, &payload)) - .is_ok() - { - Some(Event::with_payload( - REDIRECT_CONFIRM_EVENT, - &RedirectResponsePayload { id: payload.id }, - )) - } else { - Some(Event::with_payload( - REDIRECT_FAIL_EVENT, - &RedirectResponsePayload { id: payload.id }, - )) - } + let origin = event.origin?; + let manager = manager.clone(); + + pool.lock().execute(move || { + let response = if manager + .send( + &payload.target, + Event::with_payload(REDIRECT_REDIRECTED_EVENT, &payload), + ) + .get_value() + .is_ok() + { + Event::with_payload( + REDIRECT_CONFIRM_EVENT, + &RedirectResponsePayload { id: payload.id }, + ) + } else { + Event::with_payload( + REDIRECT_FAIL_EVENT, + &RedirectResponsePayload { id: payload.id }, + ) + }; + manager.send(&origin, response); + }); + + None } }); self.on(REDIRECT_REDIRECTED_EVENT, { let event_handler = Arc::clone(&self.event_handler); - let connections = Arc::clone(&self.connections); + let manager = self.manager.clone(); let pool = self.pool.clone(); let known_nodes = Arc::clone(&self.known_nodes); move |event| { let payload = event.get_payload::().ok()?; let event = Event::from_bytes(&mut &payload.content[..]).ok()?; - let proxy_stream = connections.lock().get(&payload.proxy)?.clone(); if known_nodes.lock().contains_key(&payload.source) { - pool.execute({ + pool.lock().execute({ let event_handler = Arc::clone(&event_handler); + let manager = manager.clone(); move || { - let response = event_handler.lock().handle_event(event); - let event = response.first().cloned().map(|mut value| { - Event::with_payload( - REDIRECT_EVENT, - &RedirectPayload::new( - payload.target, - payload.proxy, - payload.source, - value.as_bytes(), - ), - ) - }); - if let Some(event) = event { - proxy_stream - .send(event) - .expect("Failed to respond to redirected event."); - } + let responses = event_handler.lock().handle_event(event); + responses + .iter() + .cloned() + .map(|mut value| { + let payload = payload.clone(); + Event::with_payload( + REDIRECT_EVENT, + &RedirectPayload::new( + payload.target, + payload.proxy, + payload.source, + value.as_bytes(), + ), + ) + }) + .for_each(|event| { + manager.send(&payload.proxy, event); + }); } }); } diff --git a/src/crypto/mod.rs b/src/stream/cryptostream.rs similarity index 96% rename from src/crypto/mod.rs rename to src/stream/cryptostream.rs index 520868e..2575171 100644 --- a/src/crypto/mod.rs +++ b/src/stream/cryptostream.rs @@ -3,18 +3,16 @@ use std::net::TcpStream; use std::sync::Arc; use byteorder::{BigEndian, ByteOrder}; +use crypto_box::{ChaChaBox, SecretKey}; use crypto_box::aead::{Aead, Payload}; +use generic_array::GenericArray; use parking_lot::Mutex; -use sha2::digest::generic_array::GenericArray; use sha2::Digest; -use typenum::U24; +use typenum::*; +use x25519_dalek::PublicKey; use crate::event::Event; - use crate::utils::result::VentedResult; -use crypto_box::ChaChaBox; -pub use crypto_box::PublicKey; -pub use crypto_box::SecretKey; /// A cryptographical stream object that handles encryption and decryption of streams #[derive(Clone)] @@ -104,16 +102,16 @@ impl CryptoStream { } pub struct EncryptionBox -where - T: Aead, + where + T: Aead, { inner: T, counter: u128, } impl EncryptionBox -where - T: Aead, + where + T: Aead, { /// Creates a new encryption box with the given inner value pub fn new(inner: T) -> Self { diff --git a/src/stream/manager.rs b/src/stream/manager.rs new file mode 100644 index 0000000..8b0171a --- /dev/null +++ b/src/stream/manager.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; +use std::mem; +use std::sync::Arc; +use std::thread; +use std::thread::{JoinHandle, ThreadId}; +use std::time::Duration; + +use crossbeam_channel::{Receiver, Sender}; +use parking_lot::Mutex; + +use crate::event::Event; +use crate::stream::cryptostream::CryptoStream; +use crate::utils::result::{VentedError, VentedResult}; +use crate::utils::sync::AsyncValue; +use crate::WaitGroup; + +const MAX_ENQUEUED_EVENTS: usize = 50; +const SEND_TIMEOUT_SECONDS: u64 = 60; + +#[derive(Clone, Debug)] +pub struct ConcurrentStreamManager { + max_threads: usize, + threads: Arc>>>, + emitters: Arc)>>>>, + event_receiver: Receiver<(String, Event)>, + listener_sender: Sender<(String, Event)>, +} + +impl ConcurrentStreamManager { + pub fn new(max_threads: usize) -> Self { + let (sender, receiver) = crossbeam_channel::unbounded(); + + Self { + max_threads, + threads: Arc::new(Mutex::new(HashMap::new())), + emitters: Arc::new(Mutex::new(HashMap::new())), + event_receiver: receiver, + listener_sender: sender, + } + } + + /// Returns if the manager has a connection to the given node + pub fn has_connection(&self, node: &String) -> bool { + self.emitters.lock().contains_key(node) + } + + /// Returns the receiver for events + pub fn receiver(&self) -> Receiver<(String, Event)> { + self.event_receiver.clone() + } + + /// Sends an event and returns an async value with the result + pub fn send(&self, target: &String, event: Event) -> AsyncValue<(), VentedError> { + let mut value = AsyncValue::new(); + if let Some(emitter) = self.emitters.lock().get(target) { + if let Err(_) = emitter.send_timeout( + (event, value.clone()), + Duration::from_secs(SEND_TIMEOUT_SECONDS), + ) { + value.reject(VentedError::UnreachableNode(target.clone())); + } + } else { + value.reject(VentedError::UnknownNode(target.clone())) + } + + value + } + + /// Adds a connection to the manager causing it to start two new threads + /// This call blocks until the two threads are started up + pub fn add_connection(&self, stream: CryptoStream) -> VentedResult<()> { + if self.threads.lock().len() > self.max_threads { + return Err(VentedError::TooManyThreads); + } + let sender = self.listener_sender.clone(); + let recv_id = stream.receiver_node().clone(); + let (emitter, receiver) = crossbeam_channel::bounded(MAX_ENQUEUED_EVENTS); + self.emitters.lock().insert(recv_id.clone(), emitter); + let wg = WaitGroup::new(); + + let sender_thread = thread::Builder::new() + .name(format!("sender-{}", stream.receiver_node())) + .spawn({ + let stream = stream.clone(); + let recv_id = recv_id.clone(); + let emitters = Arc::clone(&self.emitters); + let threads = Arc::clone(&self.threads); + let wg = WaitGroup::clone(&wg); + + move || { + mem::drop(wg); + while let Ok((event, mut future)) = receiver.recv() { + if let Err(e) = stream.send(event) { + log::debug!("Failed to send event to {}: {}", recv_id, e); + future.reject(e); + break; + } + future.resolve(()); + } + emitters.lock().remove(&recv_id); + threads.lock().remove(&thread::current().id()); + } + })?; + self.threads + .lock() + .insert(sender_thread.thread().id(), sender_thread); + + let receiver_thread = thread::Builder::new() + .name(format!("receiver-{}", stream.receiver_node())) + .spawn({ + let threads = Arc::clone(&self.threads); + let wg = WaitGroup::clone(&wg); + move || { + mem::drop(wg); + while let Ok(mut event) = stream.read() { + event.origin = Some(stream.receiver_node().clone()); + + if let Err(e) = sender.send((stream.receiver_node().clone(), event)) { + log::trace!("Failed to get event from {}: {}", recv_id, e); + break; + } + } + threads.lock().remove(&thread::current().id()); + } + })?; + self.threads + .lock() + .insert(receiver_thread.thread().id(), receiver_thread); + wg.wait(); + + Ok(()) + } +} diff --git a/src/stream/mod.rs b/src/stream/mod.rs new file mode 100644 index 0000000..383559f --- /dev/null +++ b/src/stream/mod.rs @@ -0,0 +1,5 @@ +pub use crypto_box::PublicKey; +pub use crypto_box::SecretKey; + +pub mod cryptostream; +pub mod manager; diff --git a/src/utils/result.rs b/src/utils/result.rs index d716b9e..dc6bc98 100644 --- a/src/utils/result.rs +++ b/src/utils/result.rs @@ -1,7 +1,8 @@ -use crate::server::CRATE_VERSION; use std::error::Error; use std::{fmt, io}; +use crate::server::CRATE_VERSION; + pub type VentedResult = Result; #[derive(Debug)] @@ -18,6 +19,7 @@ pub enum VentedError { Rejected, AuthFailed, VersionMismatch(String), + TooManyThreads, } impl fmt::Display for VentedError { @@ -39,6 +41,7 @@ impl fmt::Display for VentedError { CRATE_VERSION, version ), Self::UnreachableNode(node) => write!(f, "Node {} can't be reached", node), + Self::TooManyThreads => write!(f, "Could not start threads. Thread limit reached."), } } } diff --git a/src/utils/sync.rs b/src/utils/sync.rs index a77399b..080ce6d 100644 --- a/src/utils/sync.rs +++ b/src/utils/sync.rs @@ -1,8 +1,10 @@ -use crate::WaitGroup; -use parking_lot::Mutex; +use std::{mem, thread}; use std::sync::Arc; use std::time::{Duration, Instant}; -use std::{mem, thread}; + +use parking_lot::Mutex; + +use crate::WaitGroup; pub struct AsyncValue { value: Arc>>, @@ -13,8 +15,8 @@ pub struct AsyncValue { } impl AsyncValue -where - E: std::fmt::Display, + where + E: std::fmt::Display, { /// Creates the future with no value pub fn new() -> Self { @@ -49,8 +51,8 @@ where } pub fn on_error(&mut self, cb: F) -> &mut Self - where - F: FnOnce(&E) -> () + Send + Sync + 'static, + where + F: FnOnce(&E) -> () + Send + Sync + 'static, { self.err_cb.lock().replace(Box::new(cb)); @@ -58,8 +60,8 @@ where } pub fn on_success(&mut self, cb: F) -> &mut Self - where - F: FnOnce(&V) -> () + Send + Sync + 'static, + where + F: FnOnce(&V) -> () + Send + Sync + 'static, { self.ok_cb.lock().replace(Box::new(cb)); diff --git a/tests/test_communication.rs b/tests/test_communication.rs index a2ceab3..721427f 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -5,7 +5,7 @@ use std::thread; use std::time::Duration; use vented::event::Event; use vented::server::data::Node; -use vented::server::server_events::{NODE_LIST_REQUEST_EVENT, READY_EVENT}; +use vented::server::server_events::NODE_LIST_REQUEST_EVENT; use vented::server::VentedServer; fn setup() { @@ -17,7 +17,6 @@ fn test_server_communication() { setup(); let ping_count = Arc::new(AtomicUsize::new(0)); let pong_count = Arc::new(AtomicUsize::new(0)); - let ready_count = Arc::new(AtomicUsize::new(0)); let mut rng = rand::thread_rng(); let global_secret_a = SecretKey::generate(&mut rng); let global_secret_b = SecretKey::generate(&mut rng); @@ -43,9 +42,9 @@ fn test_server_communication() { trusted: false, }, ]; - let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 2); - let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes.clone(), 3); - let mut server_c = VentedServer::new("C".to_string(), global_secret_c, nodes, 3); + let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 2, 100); + let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes.clone(), 3, 100); + let server_c = VentedServer::new("C".to_string(), global_secret_c, nodes, 3, 100); let wg = server_a.listen("localhost:22222".to_string()); wg.wait(); @@ -64,48 +63,23 @@ fn test_server_communication() { None } }); - server_a.on(READY_EVENT, { - let ready_count = Arc::clone(&ready_count); - - move |_| { - println!("Server A ready"); - ready_count.fetch_add(1, Ordering::Relaxed); - None - } - }); - server_b.on(READY_EVENT, { - let ready_count = Arc::clone(&ready_count); - move |_| { - println!("Server B ready"); - ready_count.fetch_add(1, Ordering::Relaxed); - None - } - }); - server_c.on(READY_EVENT, { - let ready_count = Arc::clone(&ready_count); - move |_| { - println!("Server C ready"); - ready_count.fetch_add(1, Ordering::Relaxed); - None - } - }); server_b - .emit("A".to_string(), Event::new(NODE_LIST_REQUEST_EVENT)) + .emit("A", Event::new(NODE_LIST_REQUEST_EVENT)) .on_success(|_| println!("Success")) .block_unwrap(); server_c - .emit("A".to_string(), Event::new("ping".to_string())) + .emit("A", Event::new("ping".to_string())) .block_unwrap(); for _ in 0..9 { server_b - .emit("A".to_string(), Event::new("ping".to_string())) + .emit("A", Event::new("ping".to_string())) .block_unwrap(); } server_a - .emit("B".to_string(), Event::new("pong".to_string())) + .emit("B", Event::new("pong".to_string())) .block_unwrap(); server_b - .emit("C".to_string(), Event::new("ping".to_string())) + .emit("C", Event::new("ping".to_string())) .block_unwrap(); // wait one second to make sure the servers were able to process the events @@ -113,7 +87,6 @@ fn test_server_communication() { thread::sleep(Duration::from_millis(10)); } - assert_eq!(ready_count.load(Ordering::SeqCst), 4); assert_eq!(ping_count.load(Ordering::SeqCst), 10); assert_eq!(pong_count.load(Ordering::SeqCst), 10); }