From 32f15a2c89a10465fd9b9614ba87ac7b30a784f6 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 7 Nov 2020 21:06:58 +0100 Subject: [PATCH] Refactor connection function Signed-off-by: trivernis --- src/server/data.rs | 42 +++++++++++ src/server/mod.rs | 134 ++++++++++++++++++++++-------------- src/server/router.rs | 0 src/server/server_events.rs | 3 +- tests/test_communication.rs | 35 ++++++++-- 5 files changed, 155 insertions(+), 59 deletions(-) create mode 100644 src/server/router.rs diff --git a/src/server/data.rs b/src/server/data.rs index a47315a..a02f1f0 100644 --- a/src/server/data.rs +++ b/src/server/data.rs @@ -1,9 +1,11 @@ use crate::crypto::CryptoStream; use crate::event_handler::EventHandler; +use crate::WaitGroup; use crypto_box::SecretKey; use parking_lot::Mutex; use scheduled_thread_pool::ScheduledThreadPool; use std::collections::HashMap; +use std::mem; use std::sync::Arc; use x25519_dalek::PublicKey; @@ -22,5 +24,45 @@ pub(crate) struct ServerConnectionContext { pub known_nodes: Arc>>, pub event_handler: Arc>, pub connections: Arc>>, + pub forwarded_connections: Arc>>>, pub listener_pool: Arc>, } + +#[derive(Clone)] +pub(crate) struct Future { + value: Arc>>, + wg: Option, +} + +impl Future { + /// Creates the future with no value + pub fn new() -> Self { + Self { + value: Arc::new(Mutex::new(None)), + wg: Some(WaitGroup::new()), + } + } + + /// Creates the future with an already resolved value + pub fn with_value(value: T) -> Self { + Self { + value: Arc::new(Mutex::new(Some(value))), + wg: None, + } + } + + /// Sets the value of the future consuming the wait group + pub fn set_value(&mut self, value: T) { + self.value.lock().replace(value); + mem::take(&mut self.wg); + } + + /// Returns the value of the future after it has been set. + /// This call blocks + pub fn get_value(&mut self) -> T { + if let Some(wg) = mem::take(&mut self.wg) { + wg.wait(); + } + self.value.lock().take().unwrap() + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 85b53f9..6c215bc 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -9,7 +9,7 @@ use crate::event::Event; use crate::event_handler::EventHandler; use crate::result::VentedError::UnknownNode; use crate::result::{VentedError, VentedResult}; -use crate::server::data::{Node, ServerConnectionContext}; +use crate::server::data::{Future, Node, ServerConnectionContext}; use crate::server::server_events::{ AuthPayload, ChallengePayload, NodeInformationPayload, VersionMismatchPayload, ACCEPT_EVENT, AUTH_EVENT, CHALLENGE_EVENT, CONNECT_EVENT, MISMATCH_EVENT, READY_EVENT, REJECT_EVENT, @@ -27,6 +27,9 @@ pub mod server_events; pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION"); +type ForwardFutureVector = Arc>>>; +type CryptoStreamMap = Arc>>; + /// The vented server that provides parallel handling of connections /// Usage: /// ```rust @@ -58,7 +61,8 @@ pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION"); /// server.emit("B".to_string(), Event::new("ping".to_string())).unwrap(); /// ``` pub struct VentedServer { - connections: Arc>>, + connections: CryptoStreamMap, + forwarded_connections: ForwardFutureVector, known_nodes: Arc>>, listener_pool: Arc>, sender_pool: Arc>, @@ -90,6 +94,7 @@ impl VentedServer { num_threads, ))), connections: Arc::new(Mutex::new(HashMap::new())), + forwarded_connections: Arc::new(Mutex::new(HashMap::new())), global_secret_key: secret_key, known_nodes: Arc::new(Mutex::new(nodes)), } @@ -109,42 +114,22 @@ impl VentedServer { /// The actual writing is done in a separate thread from the thread pool. /// With the returned wait group one can wait for the event to be written. pub fn emit(&self, node_id: String, event: Event) -> VentedResult { - let handler = self.connections.lock().get(&node_id).cloned(); let wg = WaitGroup::new(); - let wg2 = WaitGroup::clone(&wg); + let stream = self.get_connection(node_id)?; - if let Some(handler) = handler { + self.sender_pool.lock().execute({ + let wg = WaitGroup::clone(&wg); let connections = Arc::clone(&self.connections); - self.sender_pool.lock().execute(move || { - if let Err(e) = handler.send(event) { + move || { + if let Err(e) = stream.send(event) { log::error!("Failed to send event: {}", e); - connections.lock().remove(handler.receiver_node()); + connections.lock().remove(stream.receiver_node()); } std::mem::drop(wg); - }); - Ok(wg2) - } else { - let found_node = self - .known_nodes - .lock() - .iter() - .find(|n| n.id == node_id) - .cloned(); - if let Some(node) = found_node { - if let Some(address) = &node.address { - let handler = self.connect(address.clone())?; - self.sender_pool.lock().execute(move || { - handler.send(event).expect("Failed to send event"); - std::mem::drop(wg); - }); - Ok(wg2) - } else { - Err(VentedError::NotAServer(node_id)) - } - } else { - Err(VentedError::UnknownNode(node_id)) } - } + }); + + Ok(wg) } /// Adds a handler for the given event. @@ -171,6 +156,7 @@ impl VentedServer { Ok(listener) => { log::info!("Listener running on {}", address); std::mem::drop(wg); + for connection in listener.incoming() { match connection { Ok(stream) => { @@ -201,6 +187,7 @@ impl VentedServer { connections: Arc::clone(&self.connections), event_handler: Arc::clone(&self.event_handler), listener_pool: Arc::clone(&self.listener_pool), + forwarded_connections: Arc::clone(&self.forwarded_connections), } } @@ -209,9 +196,14 @@ impl VentedServer { fn handle_connection(params: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> { let pool = Arc::clone(¶ms.listener_pool); let event_handler = Arc::clone(¶ms.event_handler); + log::trace!( + "Received connection from {}", + stream.peer_addr().expect("Failed to get peer address") + ); pool.lock().execute(move || { let connections = Arc::clone(¶ms.connections); + let stream = match VentedServer::get_crypto_stream(params, stream) { Ok(stream) => stream, Err(e) => { @@ -219,23 +211,65 @@ impl VentedServer { return; } }; - event_handler - .lock() - .handle_event(Event::new(READY_EVENT.to_string())); - while let Ok(event) = stream.read() { - if let Some(response) = event_handler.lock().handle_event(event) { - if let Err(e) = stream.send(response) { - log::error!("Failed to send response event: {}", e); - break; - } - } + log::trace!("Secure connection established."); + event_handler.lock().handle_event(Event::new(READY_EVENT)); + if let Err(e) = Self::handle_read(event_handler, &stream) { + log::error!("Connection aborted: {}", e); } + connections.lock().remove(stream.receiver_node()); }); Ok(()) } + /// Handler for reading after the connection is established + fn handle_read( + event_handler: Arc>, + stream: &CryptoStream, + ) -> VentedResult<()> { + while let Ok(event) = stream.read() { + if let Some(response) = event_handler.lock().handle_event(event) { + stream.send(response)? + } + } + + Ok(()) + } + + /// Takes three attempts to retrieve a connection for the given node. + /// First it tries to use the already established connection stored in the shared connections vector. + /// If that fails it tries to establish a new connection to the node by using the known address + fn get_connection(&self, target: String) -> VentedResult { + log::trace!("Trying to connect to {}", target); + if let Some(stream) = self.connections.lock().get(&target) { + log::trace!("Reusing existing connection."); + return Ok(CryptoStream::clone(stream)); + } + + let target_node = { + self.known_nodes + .lock() + .iter() + .find(|node| node.id == target) + .cloned() + .ok_or(VentedError::UnknownNode(target.clone()))? + }; + if let Some(address) = target_node.address { + log::trace!("Connecting to known address"); + match self.connect(address) { + Ok(stream) => { + return Ok(stream); + } + Err(e) => log::error!("Failed to connect to node '{}': {}", target, e), + } + } + + log::debug!("All connection attempts to {} failed!", target); + + Err(VentedError::NotAServer(target)) + } + /// Establishes a crypto stream for the given stream fn get_crypto_stream( params: ServerConnectionContext, @@ -264,25 +298,20 @@ impl VentedServer { context.is_server = false; let connections = Arc::clone(&context.connections); - let stream = Self::get_crypto_stream(context, stream)?; + let stream = Self::get_crypto_stream(context.clone(), stream)?; + self.listener_pool.lock().execute({ let stream = CryptoStream::clone(&stream); let event_handler = Arc::clone(&self.event_handler); + event_handler.lock().handle_event(Event::new(READY_EVENT)); + move || { - while let Ok(event) = stream.read() { - if let Some(response) = event_handler.lock().handle_event(event) { - if let Err(e) = stream.send(response) { - log::error!("Failed to send response event: {}", e); - break; - } - } + if let Err(e) = Self::handle_read(event_handler, &stream) { + log::error!("Connection aborted: {}", e); } connections.lock().remove(stream.receiver_node()); } }); - self.event_handler - .lock() - .handle_event(Event::new(READY_EVENT.to_string())); Ok(stream) } @@ -336,6 +365,7 @@ impl VentedServer { )?; stream.flush()?; let event = Event::from_bytes(&mut stream)?; + if event.name != CONNECT_EVENT { return Err(VentedError::UnexpectedEvent(event.name)); } diff --git a/src/server/router.rs b/src/server/router.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/server/server_events.rs b/src/server/server_events.rs index c8cd847..9b7d390 100644 --- a/src/server/server_events.rs +++ b/src/server/server_events.rs @@ -6,8 +6,7 @@ pub(crate) const CHALLENGE_EVENT: &str = "conn:challenge"; pub(crate) const ACCEPT_EVENT: &str = "conn:accept"; pub(crate) const REJECT_EVENT: &str = "conn:reject"; pub(crate) const MISMATCH_EVENT: &str = "conn:reject_version_mismatch"; - -pub const READY_EVENT: &str = "connection:ready"; +pub const READY_EVENT: &str = "server:ready"; #[derive(Serialize, Deserialize, Debug)] pub(crate) struct NodeInformationPayload { diff --git a/tests/test_communication.rs b/tests/test_communication.rs index 67d49aa..27dcb1a 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -16,11 +16,14 @@ fn setup() { fn test_server_communication() { setup(); let ping_count = Arc::new(AtomicUsize::new(0)); + let ping_c_count = Arc::new(AtomicUsize::new(0)); let pong_count = Arc::new(AtomicUsize::new(0)); let ready_count = Arc::new(AtomicUsize::new(0)); let mut rng = rand::thread_rng(); let global_secret_a = SecretKey::generate(&mut rng); let global_secret_b = SecretKey::generate(&mut rng); + let global_secret_c = SecretKey::generate(&mut rng); + let nodes = vec![ Node { id: "A".to_string(), @@ -32,9 +35,15 @@ fn test_server_communication() { address: None, public_key: global_secret_b.public_key(), }, + Node { + id: "C".to_string(), + address: None, + public_key: global_secret_c.public_key(), + }, ]; - let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 4); - let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes, 4); + let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 6); + let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes.clone(), 3); + let mut server_c = VentedServer::new("C".to_string(), global_secret_c, nodes, 3); let wg = server_a.listen("localhost:22222".to_string()); wg.wait(); @@ -67,7 +76,19 @@ fn test_server_communication() { None } }); - for _ in 0..10 { + server_c.on("ping", { + let ping_c_count = Arc::clone(&ping_c_count); + move |_| { + ping_c_count.fetch_add(1, Ordering::Relaxed); + println!("C RECEIVED A PING!"); + None + } + }); + let wg = server_c + .emit("A".to_string(), Event::new("ping".to_string())) + .unwrap(); + wg.wait(); + for _ in 0..9 { let wg = server_b .emit("A".to_string(), Event::new("ping".to_string())) .unwrap(); @@ -77,13 +98,17 @@ fn test_server_communication() { .emit("B".to_string(), Event::new("pong".to_string())) .unwrap(); wg.wait(); + assert!(server_b + .emit("C".to_string(), Event::new("ping".to_string())) + .is_err()); // wait one second to make sure the servers were able to process the events for _ in 0..100 { thread::sleep(Duration::from_millis(10)); } - assert_eq!(ready_count.load(Ordering::SeqCst), 2); + assert_eq!(ping_c_count.load(Ordering::SeqCst), 0); + assert_eq!(ready_count.load(Ordering::SeqCst), 3); assert_eq!(ping_count.load(Ordering::SeqCst), 10); - assert_eq!(pong_count.load(Ordering::SeqCst), 11); + assert_eq!(pong_count.load(Ordering::SeqCst), 10); }