diff --git a/Cargo.toml b/Cargo.toml index 00b04a2..32538db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vented" description = "Event driven encrypted tcp communicaton" -version = "0.10.5" +version = "0.11.0" authors = ["trivernis "] edition = "2018" readme = "README.md" @@ -17,7 +17,6 @@ rmp-serde = "0.14.4" serde = { version = "1.0.117", features = ["serde_derive"] } byteorder = "1.3.4" parking_lot = "0.11.0" -scheduled-thread-pool = "0.2.5" log = "0.4.11" crypto_box = "0.5.0" rand = "0.7.3" @@ -26,7 +25,7 @@ generic-array = "0.14.4" typenum = "1.12.0" x25519-dalek = "1.1.0" crossbeam-utils = "0.8.0" -crossbeam-channel = "0.5.0" +async-std = "1.7.0" [dev-dependencies] simple_logger = "1.11.0" \ No newline at end of file diff --git a/README.md b/README.md index 5d67e11..1834bdb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Vented -Vented is an event based TCP server with encryption that uses message pack for payload data. +Vented is an event based asynchronous TCP server with encryption that uses message pack for payload data. ## Encryption @@ -15,31 +15,35 @@ The crate used for the key exchanges is [x25519-dalek](https://crates.io/crates/ ```rust use vented::server::VentedServer; -use vented::server::data::Node; -use vented::crypto::SecretKey; +use vented::server::data::{Node, ServerTimeouts}; +use vented::stream::SecretKey; use rand::thread_rng; use vented::event::Event; fn main() { + let global_secret_b = SecretKey::generate(&mut thread_rng()); let nodes = vec![ - Node { - id: "B".to_string(), - address: None, - public_key: global_secret_b.public_key() // load it from somewhere - }, + Node { + id: "B".to_string(), + addresses: vec![], + trusted: true, + public_key: global_secret_b.public_key() // load it from somewhere + }, ]; // in a real world example the secret key needs to be loaded from somewhere because connections // with unknown keys are not accepted. - let global_secret = SecretKey::new(&mut thread_rng()); - let mut server = VentedServer::new("A".to_string(), global_secret, nodes.clone(), 4); + let global_secret = SecretKey::generate(&mut thread_rng()); + let mut server = VentedServer::new("A".to_string(), global_secret, nodes.clone(), ServerTimeouts::default()); server.listen("localhost:20000".to_string()); server.on("pong", |_event| { - println!("Pong!"); - - None // the return value is the response event Option + Box::pin(async {println!("Pong!"); + + None + }) }); - server.emit("B".to_string(), Event::new("ping".to_string())).unwrap(); + assert!(async_std::task::block_on(server.emit("B", Event::new("ping".to_string()))).is_err()) // this won't work without a known node B + } } ``` \ No newline at end of file diff --git a/src/event/mod.rs b/src/event/mod.rs index 540a660..b90e5ca 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -1,10 +1,11 @@ -use std::io::Read; +use async_std::io::{Read, ReadExt}; use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use crate::utils::result::{VentedError, VentedResult}; +use async_std::net::TcpStream; pub trait GenericEvent {} @@ -72,17 +73,38 @@ impl Event { data } + pub fn from(buf: &mut R) -> VentedResult { + let name_length = buf.read_u16::()?; + let mut name_buf = vec![0u8; name_length as usize]; + buf.read_exact(&mut name_buf)?; + let event_name = String::from_utf8(name_buf).map_err(|_| VentedError::NameDecodingError)?; + + let payload_length = buf.read_u64::()?; + let mut payload = vec![0u8; payload_length as usize]; + buf.read_exact(&mut payload)?; + + Ok(Self { + name: event_name, + payload, + origin: None, + }) + } + /// Deserializes the message from bytes that can be read from the given reader /// The result will be the Message with the specific message payload type - pub fn from_bytes(bytes: &mut R) -> VentedResult { - let name_length = bytes.read_u16::()?; + pub async fn from_async_tcp(stream: &mut TcpStream) -> VentedResult { + let mut name_length_raw = [0u8; 2]; + stream.read_exact(&mut name_length_raw).await?; + let name_length = BigEndian::read_u16(&mut name_length_raw); let mut name_buf = vec![0u8; name_length as usize]; - bytes.read_exact(&mut name_buf)?; + stream.read_exact(&mut name_buf).await?; let event_name = String::from_utf8(name_buf).map_err(|_| VentedError::NameDecodingError)?; - let payload_length = bytes.read_u64::()?; + let mut payload_length_raw = [0u8; 8]; + stream.read_exact(&mut payload_length_raw).await?; + let payload_length = BigEndian::read_u64(&payload_length_raw); let mut payload = vec![0u8; payload_length as usize]; - bytes.read_exact(&mut payload)?; + stream.read_exact(&mut payload).await?; Ok(Self { name: event_name, diff --git a/src/event/tests.rs b/src/event/tests.rs index 6f6c9e9..3184499 100644 --- a/src/event/tests.rs +++ b/src/event/tests.rs @@ -34,7 +34,7 @@ fn it_deserializes_events() { let mut event = Event::with_payload("test".to_string(), &payload); let event_bytes = event.as_bytes(); - let deserialized_event = Event::from_bytes(&mut event_bytes.as_slice()).unwrap(); + let deserialized_event = Event::from(&mut event_bytes.as_slice()).unwrap(); assert_eq!(deserialized_event.name, "test".to_string()); assert_eq!( deserialized_event.get_payload::().unwrap(), diff --git a/src/event_handler/mod.rs b/src/event_handler/mod.rs index 0531bec..a4f0166 100644 --- a/src/event_handler/mod.rs +++ b/src/event_handler/mod.rs @@ -1,49 +1,77 @@ use std::collections::HashMap; use crate::event::Event; +use async_std::prelude::*; +use async_std::sync::Arc; +use async_std::task; +use parking_lot::Mutex; +use std::pin::Pin; #[cfg(test)] mod tests; +pub trait EventCallback: + Fn(Event) -> Pin>>> + Send + Sync +{ +} + /// A handler for events +#[derive(Clone)] pub struct EventHandler { - event_handlers: HashMap Option + Send + Sync>>>, + event_handlers: Arc< + Mutex< + HashMap< + String, + Vec< + Box< + dyn Fn(Event) -> Pin>>> + Send + Sync, + >, + >, + >, + >, + >, } impl EventHandler { /// Creates a new vented event_handler pub fn new() -> Self { Self { - event_handlers: HashMap::new(), + event_handlers: Arc::new(Mutex::new(HashMap::new())), } } /// Adds a handler for the given event pub fn on(&mut self, event_name: &str, handler: F) - where - F: Fn(Event) -> Option + Send + Sync, + where + F: Fn(Event) -> Pin>>> + Send + Sync, { - match self.event_handlers.get_mut(event_name) { + let mut handlers = self.event_handlers.lock(); + match handlers.get_mut(event_name) { Some(handlers) => handlers.push(Box::new(handler)), None => { - self.event_handlers - .insert(event_name.to_string(), vec![Box::new(handler)]); + handlers.insert(event_name.to_string(), vec![Box::new(handler)]); } } } /// Handles a single event - pub fn handle_event(&mut self, event: Event) -> Vec { - let mut response_events = Vec::new(); + pub async fn handle_event(&mut self, event: Event) -> Vec { + let mut response_events: Vec = Vec::new(); - if let Some(handlers) = self.event_handlers.get(&event.name) { + if let Some(handlers) = self.event_handlers.lock().get(&event.name) { for handler in handlers { - if let Some(e) = handler(event.clone()) { - response_events.push(e); - } + let result = handler(event.clone()); + task::block_on(async { + if let Some(e) = result.await { + response_events.push(e.clone()); + } + }) } } response_events } } + +unsafe impl Send for EventHandler {} +unsafe impl Sync for EventHandler {} diff --git a/src/event_handler/tests.rs b/src/event_handler/tests.rs index 601f5ab..6507b66 100644 --- a/src/event_handler/tests.rs +++ b/src/event_handler/tests.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use crate::event::Event; use crate::event_handler::EventHandler; +use async_std::task; #[test] fn it_handles_events() { @@ -11,39 +12,53 @@ fn it_handles_events() { { let call_count = Arc::clone(&call_count); handler.on("test", move |event| { - call_count.fetch_add(1, Ordering::Relaxed); + let call_count = Arc::clone(&call_count); + Box::pin(async move { + call_count.fetch_add(1, Ordering::Relaxed); - Some(event) + Some(event) + }) }); } { let call_count = Arc::clone(&call_count); handler.on("test", move |_event| { - call_count.fetch_add(1, Ordering::Relaxed); + let call_count = Arc::clone(&call_count); + Box::pin(async move { + call_count.fetch_add(1, Ordering::Relaxed); - None + None + }) }); } { let call_count = Arc::clone(&call_count); handler.on("test2", move |_event| { - call_count.fetch_add(1, Ordering::Relaxed); + let call_count = Arc::clone(&call_count); + Box::pin(async move { + call_count.fetch_add(1, Ordering::Relaxed); - None + None + }) }); } { let call_count = Arc::clone(&call_count); handler.on("test2", move |_event| { - call_count.fetch_add(1, Ordering::Relaxed); + let call_count = Arc::clone(&call_count); + Box::pin(async move { + call_count.fetch_add(1, Ordering::Relaxed); - None + None + }) }) } - handler.handle_event(Event::new("test".to_string())); - handler.handle_event(Event::new("test".to_string())); - handler.handle_event(Event::new("test2".to_string())); + task::block_on(async move { + handler.handle_event(Event::new("test".to_string())).await; + handler.handle_event(Event::new("test".to_string())).await; + handler.handle_event(Event::new("test2".to_string())).await; + }); assert_eq!(call_count.load(Ordering::Relaxed), 6) } diff --git a/src/lib.rs b/src/lib.rs index 154ecb4..e80517d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,9 @@ +#[macro_use] +pub mod utils; + pub use crossbeam_utils::sync::WaitGroup; pub mod event; pub mod event_handler; pub mod server; pub mod stream; -pub mod utils; - diff --git a/src/server/data.rs b/src/server/data.rs index 6f89df2..6c79824 100644 --- a/src/server/data.rs +++ b/src/server/data.rs @@ -1,18 +1,9 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use crypto_box::SecretKey; -use parking_lot::Mutex; -use scheduled_thread_pool::ScheduledThreadPool; use x25519_dalek::PublicKey; -use crate::event_handler::EventHandler; -use crate::stream::cryptostream::CryptoStream; -use crate::stream::manager::{ConcurrentStreamManager, CONNECTION_TIMEOUT_SECONDS}; -use crate::utils::result::VentedError; -use crate::utils::sync::AsyncValue; use std::time::{Duration, Instant}; +pub const CONNECTION_TIMEOUT_SECS: u64 = 10; + #[derive(Clone, Debug)] pub struct Node { pub id: String, @@ -34,21 +25,6 @@ pub enum NodeState { Unknown, } -#[derive(Clone)] -pub(crate) struct ServerConnectionContext { - pub is_server: bool, - pub node_id: String, - pub global_secret: SecretKey, - pub known_nodes: Arc>>, - pub event_handler: Arc>, - pub forwarded_connections: Arc>>>, - pub sender_pool: Arc>, - pub recv_pool: Arc>, - pub redirect_handles: Arc>>>, - pub manager: ConcurrentStreamManager, - pub timeouts: ServerTimeouts, -} - #[derive(Clone, Debug)] pub struct ServerTimeouts { pub send_timeout: Duration, @@ -58,8 +34,8 @@ pub struct ServerTimeouts { impl Default for ServerTimeouts { fn default() -> Self { Self { - send_timeout: Duration::from_secs(CONNECTION_TIMEOUT_SECONDS), - redirect_timeout: Duration::from_secs(CONNECTION_TIMEOUT_SECONDS * 2), + send_timeout: Duration::from_secs(CONNECTION_TIMEOUT_SECS), + redirect_timeout: Duration::from_secs(CONNECTION_TIMEOUT_SECS * 2), } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index a87b1ca..31618fa 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,32 +1,28 @@ +use async_std::net::{TcpListener, TcpStream}; use std::collections::HashMap; -use std::io::Write; use std::iter::FromIterator; -use std::mem; -use std::net::{TcpListener, TcpStream}; use std::sync::Arc; -use std::thread; use std::time::Instant; -use crossbeam_utils::sync::WaitGroup; use crypto_box::{PublicKey, SecretKey}; use parking_lot::Mutex; -use scheduled_thread_pool::ScheduledThreadPool; use sha2::Digest; use x25519_dalek::StaticSecret; use crate::event::Event; use crate::event_handler::EventHandler; -use crate::server::data::{Node, NodeData, NodeState, ServerConnectionContext, ServerTimeouts}; +use crate::server::data::{Node, NodeData, NodeState, ServerTimeouts}; use crate::server::server_events::{ AuthPayload, ChallengePayload, NodeInformationPayload, RedirectPayload, VersionMismatchPayload, ACCEPT_EVENT, AUTH_EVENT, CHALLENGE_EVENT, CONNECT_EVENT, MISMATCH_EVENT, READY_EVENT, REDIRECT_EVENT, REJECT_EVENT, }; use crate::stream::cryptostream::CryptoStream; -use crate::stream::manager::ConcurrentStreamManager; use crate::utils::result::{VentedError, VentedResult}; use crate::utils::sync::AsyncValue; -use std::cmp::max; +use async_std::prelude::*; +use async_std::task; +use std::pin::Pin; pub mod data; pub mod server_events; @@ -57,28 +53,28 @@ type ForwardFutureVector = Arc +/// None +/// }) /// }); -/// assert!(server.emit("B", Event::new("ping".to_string())).get_value().is_err()) // this won't work without a known node B +/// assert!(async_std::task::block_on(server.emit("B", Event::new("ping".to_string()))).is_err()) // this won't work without a known node B /// ``` +#[derive(Clone)] pub struct VentedServer { forwarded_connections: ForwardFutureVector, known_nodes: Arc>>, - event_handler: Arc>, + event_handler: EventHandler, global_secret_key: SecretKey, node_id: String, redirect_handles: Arc>>>, - manager: ConcurrentStreamManager, - sender_pool: Arc>, - receiver_pool: Arc>, timeouts: ServerTimeouts, + connections: Arc>>, } impl VentedServer { @@ -90,13 +86,11 @@ impl VentedServer { secret_key: SecretKey, nodes: Vec, timeouts: ServerTimeouts, - num_threads: usize, - max_threads: usize, ) -> Self { let mut server = Self { node_id, - manager: ConcurrentStreamManager::new(max_threads), - event_handler: Arc::new(Mutex::new(EventHandler::new())), + connections: Arc::new(Mutex::new(HashMap::new())), + event_handler: EventHandler::new(), forwarded_connections: Arc::new(Mutex::new(HashMap::new())), global_secret_key: secret_key, known_nodes: Arc::new(Mutex::new(HashMap::from_iter( @@ -106,18 +100,9 @@ impl VentedServer { .map(|node| (node.id.clone(), node.into())), ))), redirect_handles: Arc::new(Mutex::new(HashMap::new())), - sender_pool: Arc::new(Mutex::new(ScheduledThreadPool::new(max( - num_threads / 2, - 1, - )))), - receiver_pool: Arc::new(Mutex::new(ScheduledThreadPool::new(max( - num_threads / 2, - 1, - )))), timeouts, }; server.register_events(); - server.start_event_listener(); server } @@ -143,10 +128,9 @@ impl VentedServer { } /// Emits an event to the specified Node - /// The actual writing is done in a separate thread from the thread pool. - /// For that reason an Async value is returned to use it to wait for the result - pub fn emit(&self, node_id: S, event: Event) -> AsyncValue<(), VentedError> { - Self::send_event(self.get_server_context(), &node_id.to_string(), event, true) + #[inline] + pub async fn emit(&self, node_id: S, event: Event) -> VentedResult<()> { + self.send_event(&node_id.to_string(), event, true).await } /// Adds a handler for the given event. @@ -154,204 +138,129 @@ impl VentedServer { /// Multiple handlers can be registered for an event. pub fn on(&mut self, event_name: &str, handler: F) where - F: Fn(Event) -> Option + Send + Sync, + F: Fn(Event) -> Pin>>> + Send + Sync, { - self.event_handler.lock().on(event_name, handler); + self.event_handler.on(event_name, handler); } /// Starts listening on the specified address (with port!) /// This will cause a new thread to start up so that the method returns immediately /// With the returned wait group one can wait for the server to be ready. /// The method can be called multiple times to start listeners on multiple ports. - pub fn listen(&mut self, address: String) -> WaitGroup { - let context = self.get_server_context(); - let wg = WaitGroup::new(); - let wg2 = WaitGroup::clone(&wg); - - thread::spawn(move || match TcpListener::bind(&address) { - Ok(listener) => { - log::info!("Listener running on {}", address); - std::mem::drop(wg); - - for connection in listener.incoming() { - match connection { - Ok(stream) => { - if let Err(e) = Self::handle_connection(context.clone(), stream) { + pub fn listen(&self, address: String) { + let this = self.clone(); + task::spawn(async move { + let listener = match TcpListener::bind(&address).await { + Ok(l) => l, + Err(e) => { + log::error!("Failed to bind listener to address {}: {}", address, e); + return; + } + }; + log::info!("Listener running on {}", address); + while let Some(connection) = listener.incoming().next().await { + match connection { + Ok(stream) => { + let mut this = this.clone(); + task::spawn(async move { + if let Err(e) = this.handle_connection(stream).await { log::error!("Failed to handle connection: {}", e); } - } - Err(e) => log::trace!("Failed to establish connection: {}", e), - } - } - } - Err(e) => { - log::error!("Failed to bind listener: {}", e); - std::mem::drop(wg); - } - }); - - wg2 - } - - /// Returns a copy of the servers metadata - fn get_server_context(&self) -> ServerConnectionContext { - ServerConnectionContext { - is_server: true, - node_id: self.node_id.clone(), - global_secret: self.global_secret_key.clone(), - known_nodes: Arc::clone(&self.known_nodes), - event_handler: Arc::clone(&self.event_handler), - sender_pool: Arc::clone(&self.sender_pool), - forwarded_connections: Arc::clone(&self.forwarded_connections), - redirect_handles: Arc::clone(&self.redirect_handles), - manager: self.manager.clone(), - recv_pool: Arc::clone(&self.receiver_pool), - timeouts: self.timeouts.clone(), - } - } - - /// Starts the event listener thread - fn start_event_listener(&self) { - let receiver = self.manager.receiver(); - let event_handler = Arc::clone(&self.event_handler); - let context = self.get_server_context(); - let wg = WaitGroup::new(); - - thread::spawn({ - let wg = WaitGroup::clone(&wg); - move || { - mem::drop(wg); - while let Ok((origin, event)) = receiver.recv() { - if let Some(node) = context.known_nodes.lock().get_mut(&origin) { - node.set_node_state(NodeState::Alive(Instant::now())); + }); } - let responses = event_handler.lock().handle_event(event); - - for response in responses { - Self::send_event(context.clone(), &origin, response, true); + Err(e) => { + log::trace!("Failed to establish connection: {}", e); + continue; } } - log::warn!("Event listener stopped!"); } }); - wg.wait(); } /// Sends an event asynchronously to a node /// The redirect flag is used to determine if it should be tried to redirect an event after /// a direct sending attempt failed - fn send_event( - context: ServerConnectionContext, - target: &String, - event: Event, - redirect: bool, - ) -> AsyncValue<(), VentedError> { + async fn send_event(&self, target: &String, event: Event, redirect: bool) -> VentedResult<()> { log::trace!( "Emitting: '{}' from {} to {}", event.name, - context.node_id, + self.node_id, target ); - if context.manager.has_connection(target) { + let mut result = Ok(()); + let node_state = if let Ok(mut stream) = self.get_connection(target).await { log::trace!("Reusing existing connection."); - context.manager.send(target, event) - } else { - let future = AsyncValue::new(); - - context.sender_pool.lock().execute({ - let mut future = AsyncValue::clone(&future); - let node_id = target.clone(); - let context = context.clone(); - - move || { - log::trace!("Trying to establish connection..."); - let node_state = if let Ok(connection) = - Self::get_connection(context.clone(), &node_id) - { - if let Err(e) = context.manager.add_connection(connection) { - future.reject(e); - return; - } - log::trace!("Established new connection."); - let result = context.manager.send(&node_id, event).get_value(); - match result { - Ok(_) => { - future.resolve(()); - NodeState::Alive(Instant::now()) - } - Err(e) => { - future.reject(e); - NodeState::Dead(Instant::now()) - } - } - } else if redirect { - log::trace!("Trying to use a proxy node..."); - let result = Self::send_event_redirected(context.clone(), &node_id, event); - match result { - Ok(_) => { - future.resolve(()); - NodeState::Alive(Instant::now()) - } - Err(e) => { - future.reject(e); - NodeState::Dead(Instant::now()) - } - } - } else { - log::trace!("Failed to emit event to node {}", node_id); - future.reject(VentedError::UnreachableNode(node_id.clone())); - NodeState::Dead(Instant::now()) - }; - if let Some(node) = context.known_nodes.lock().get_mut(&node_id) { - node.set_node_state(node_state); - } + match stream.send(event).await { + Ok(_) => NodeState::Alive(Instant::now()), + Err(e) => { + result = Err(e); + NodeState::Dead(Instant::now()) + } + } + } else if redirect { + log::trace!("Trying to use a proxy node..."); + match self.send_event_redirected(&target, event).await { + Ok(_) => { + result = Ok(()); + NodeState::Alive(Instant::now()) + } + Err(e) => { + log::trace!("Failed to redirect: {}", e); + result = Err(e); + NodeState::Dead(Instant::now()) } - }); + } + } else { + log::trace!("Failed to emit event to node {}", target); + result = Err(VentedError::UnreachableNode(target.clone())); - future + NodeState::Dead(Instant::now()) + }; + + if let Some(node) = self.known_nodes.lock().get_mut(target) { + node.set_node_state(node_state); } + + result } /// Tries to send an event redirected by emitting a redirect event to all public nodes - fn send_event_redirected( - context: ServerConnectionContext, - target: &String, - event: Event, - ) -> VentedResult<()> { - let public_nodes = context + async fn send_event_redirected(&self, target: &String, event: Event) -> VentedResult<()> { + let connected_nodes = self .known_nodes .lock() .values() - .filter(|node| !node.node().addresses.is_empty() && node.is_alive()) + .filter(|node| node.is_alive()) .cloned() .collect::>(); - for node in public_nodes { + for node in connected_nodes { let payload = RedirectPayload::new( - context.node_id.clone(), + self.node_id.clone(), node.node().id.clone(), target.clone(), event.clone().as_bytes(), ); - let mut future = AsyncValue::new(); - context - .redirect_handles + let mut value = AsyncValue::new(); + self.redirect_handles .lock() - .insert(payload.id, AsyncValue::clone(&future)); - - if let Err(e) = Self::send_event( - context.clone(), - &node.node().id, - Event::with_payload(REDIRECT_EVENT, &payload), - false, - ) - .get_value() - { - log::error!("Failed to redirect via {}: {}", node.node().id, e); + .insert(payload.id, AsyncValue::clone(&value)); + + if let Ok(mut stream) = self.get_connection(&node.node().id).await { + if let Err(e) = stream + .send(Event::with_payload(REDIRECT_EVENT, &payload)) + .await + { + log::trace!("Failed to redirect via {}: {}", stream.receiver_node(), e); + continue; + } + } else { + continue; } - if let Some(Ok(_)) = - future.get_value_with_timeout(context.timeouts.redirect_timeout.clone()) + if let Some(Ok(_)) = value + .get_value_with_timeout_async(self.timeouts.redirect_timeout.clone()) + .await { return Ok(()); } else { @@ -364,46 +273,76 @@ impl VentedServer { /// Handles a single connection by first performing a key exchange and /// then establishing an encrypted connection - fn handle_connection(context: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> { - let event_handler = Arc::clone(&context.event_handler); - stream.set_write_timeout(Some(context.timeouts.send_timeout))?; + async fn handle_connection(&mut self, stream: TcpStream) -> VentedResult<()> { log::trace!( "Received connection from {}", stream.peer_addr().expect("Failed to get peer address") ); - context.recv_pool.lock().execute({ - let context = context.clone(); - move || { - let manager = context.manager.clone(); + let stream = self.perform_server_key_exchange(stream).await?; - let stream = match VentedServer::get_crypto_stream(context, stream) { - Ok(stream) => stream, - Err(e) => { - log::error!("Failed to establish encrypted connection: {}", e); - return; + log::trace!("Secure connection established."); + self.connections + .lock() + .insert(stream.receiver_node().clone(), stream.clone()); + self.event_handler + .handle_event(Event::new(READY_EVENT)) + .await; + Self::read_stream( + stream.clone(), + self.connections.clone(), + self.event_handler.clone(), + ) + .await; + + Ok(()) + } + + /// Reads events from the stream and removes it from the known connections when it's closed + async fn read_stream( + mut stream: CryptoStream, + connections: Arc>>, + mut handler: EventHandler, + ) { + loop { + match stream.read().await { + Ok(mut event) => { + event.origin = Some(stream.receiver_node().clone()); + let results = handler.handle_event(event).await; + for result in results { + if let Err(e) = stream.send(result).await { + log::error!( + "Failed to send event to {}: {}", + stream.receiver_node(), + e + ); + break; + } } - }; - log::trace!("Secure connection established."); - if let Err(e) = manager.add_connection(stream) { - log::trace!("Failed to add connection to manager: {}", e); - return; } - event_handler.lock().handle_event(Event::new(READY_EVENT)); + Err(e) => { + log::error!( + "Failed to read events from {}: {}", + stream.receiver_node(), + e + ); + break; + } } - }); - - Ok(()) + } + connections.lock().remove(stream.receiver_node()); } /// Takes three attempts to retrieve a connection for the given node. /// First it tries to use the already established connection stored in the shared connections vector. /// If that fails it tries to establish a new connection to the node by using the known address - fn get_connection( - context: ServerConnectionContext, - target: &String, - ) -> VentedResult { - let target_node = context + async fn get_connection(&self, target: &String) -> VentedResult { + if let Some(stream) = self.connections.lock().get(target) { + log::trace!("Reusing existing connection."); + return Ok(stream.clone()); + } + + let target_node = self .known_nodes .lock() .get(target) @@ -413,12 +352,11 @@ impl VentedServer { log::trace!("Connecting to known addresses"); for address in &target_node.node().addresses { - match Self::connect(context.clone(), address.clone()) { + match self.connect(address.clone()).await { Ok(stream) => return Ok(stream), Err(e) => { log::error!("Failed to connect to node {}'s address: {}", target, e); - context - .known_nodes + self.known_nodes .lock() .get_mut(target) .unwrap() @@ -433,84 +371,43 @@ impl VentedServer { Err(VentedError::UnreachableNode(target.clone())) } - /// Establishes a crypto stream for the given stream - fn get_crypto_stream( - context: ServerConnectionContext, - stream: TcpStream, - ) -> VentedResult { - let (_, stream) = VentedServer::perform_key_exchange( - context.is_server, - stream, - context.node_id.clone(), - context.global_secret, - context.known_nodes, - )?; - - Ok(stream) - } - /// Connects to the given address as a tcp client - fn connect( - mut context: ServerConnectionContext, - address: String, - ) -> VentedResult { - let stream = TcpStream::connect(address)?; - stream.set_write_timeout(Some(context.timeouts.send_timeout))?; - context.is_server = false; - let stream = Self::get_crypto_stream(context, stream)?; + async fn connect(&self, address: String) -> VentedResult { + let stream = TcpStream::connect(address).await?; + let stream = self.perform_client_key_exchange(stream).await?; + self.connections + .lock() + .insert(stream.receiver_node().clone(), stream.clone()); + task::spawn(Self::read_stream( + stream.clone(), + self.connections.clone(), + self.event_handler.clone(), + )); Ok(stream) } - /// Performs a key exchange - fn perform_key_exchange( - is_server: bool, - stream: TcpStream, - own_node_id: String, - global_secret: SecretKey, - known_nodes: Arc>>, - ) -> VentedResult<(String, CryptoStream)> { - let secret_key = SecretKey::generate(&mut rand::thread_rng()); - if is_server { - Self::perform_server_key_exchange( - stream, - &secret_key, - own_node_id, - global_secret, - known_nodes, - ) - } else { - Self::perform_client_key_exchange( - stream, - &secret_key, - own_node_id, - global_secret, - known_nodes, - ) - } - } - /// Performs the client side DH key exchange - fn perform_client_key_exchange( + async fn perform_client_key_exchange( + &self, mut stream: TcpStream, - secret_key: &SecretKey, - own_node_id: String, - global_secret: SecretKey, - known_nodes: Arc>>, - ) -> VentedResult<(String, CryptoStream)> { - stream.write( - &Event::with_payload( - CONNECT_EVENT, - &NodeInformationPayload { - public_key: secret_key.public_key().to_bytes(), - node_id: own_node_id, - vented_version: PROTOCOL_VERSION.to_string(), - }, + ) -> VentedResult { + let secret_key = SecretKey::generate(&mut rand::thread_rng()); + stream + .write( + &Event::with_payload( + CONNECT_EVENT, + &NodeInformationPayload { + public_key: secret_key.public_key().to_bytes(), + node_id: self.node_id.clone(), + vented_version: PROTOCOL_VERSION.to_string(), + }, + ) + .as_bytes(), ) - .as_bytes(), - )?; - stream.flush()?; - let event = Event::from_bytes(&mut stream)?; + .await?; + stream.flush().await?; + let event = Event::from_async_tcp(&mut stream).await?; if event.name != CONNECT_EVENT { return Err(VentedError::UnexpectedEvent(event.name)); @@ -522,34 +419,39 @@ impl VentedServer { } = event.get_payload::().unwrap(); if !Self::compare_version(&vented_version, PROTOCOL_VERSION) { - stream.write( - &Event::with_payload( - MISMATCH_EVENT, - &VersionMismatchPayload::new(PROTOCOL_VERSION, &vented_version), + stream + .write( + &Event::with_payload( + MISMATCH_EVENT, + &VersionMismatchPayload::new(PROTOCOL_VERSION, &vented_version), + ) + .as_bytes(), ) - .as_bytes(), - )?; - stream.flush()?; + .await?; + stream.flush().await?; return Err(VentedError::VersionMismatch(vented_version)); } let public_key = PublicKey::from(public_key); - let node_data = if let Some(data) = known_nodes.lock().get(&node_id) { + let node_data = if let Some(data) = self.known_nodes.lock().get(&node_id) { data.clone() } else { - stream.write(&Event::new(REJECT_EVENT).as_bytes())?; - stream.flush()?; + stream.write(&Event::new(REJECT_EVENT).as_bytes()).await?; + stream.flush().await?; return Err(VentedError::UnknownNode(node_id)); }; let mut stream = CryptoStream::new(node_id.clone(), stream, &public_key, &secret_key)?; log::trace!("Authenticating recipient..."); - let key_a = Self::authenticate_other(&mut stream, node_data.node().public_key)?; + let key_a = Self::authenticate_other(&mut stream, node_data.node().public_key).await?; log::trace!("Authenticating self..."); - let key_b = - Self::authenticate_self(&mut stream, StaticSecret::from(global_secret.to_bytes()))?; + let key_b = Self::authenticate_self( + &mut stream, + StaticSecret::from(self.global_secret_key.to_bytes()), + ) + .await?; log::trace!("Connection fully authenticated."); let pre_secret = StaticSecret::from(secret_key.to_bytes()).diffie_hellman(&public_key); @@ -558,19 +460,18 @@ impl VentedServer { let final_public = final_secret.public_key(); stream.update_key(&final_secret, &final_public); - Ok((node_id, stream)) + Ok(stream) } /// Performs a DH key exchange by using the crypto_box module and events /// On success it returns a secret box with the established secret and the node id of the client - fn perform_server_key_exchange( + async fn perform_server_key_exchange( + &self, mut stream: TcpStream, - secret_key: &SecretKey, - own_node_id: String, - global_secret: SecretKey, - known_nodes: Arc>>, - ) -> VentedResult<(String, CryptoStream)> { - let event = Event::from_bytes(&mut stream)?; + ) -> VentedResult { + let secret_key = SecretKey::generate(&mut rand::thread_rng()); + let event = Event::from_async_tcp(&mut stream).await?; + if event.name != CONNECT_EVENT { return Err(VentedError::UnexpectedEvent(event.name)); } @@ -581,46 +482,54 @@ impl VentedServer { } = event.get_payload::().unwrap(); if !Self::compare_version(&vented_version, PROTOCOL_VERSION) { - stream.write( - &Event::with_payload( - MISMATCH_EVENT, - &VersionMismatchPayload::new(PROTOCOL_VERSION, &vented_version), + stream + .write( + &Event::with_payload( + MISMATCH_EVENT, + &VersionMismatchPayload::new(PROTOCOL_VERSION, &vented_version), + ) + .as_bytes(), ) - .as_bytes(), - )?; - stream.flush()?; + .await?; + stream.flush().await?; return Err(VentedError::VersionMismatch(vented_version)); } let public_key = PublicKey::from(public_key); - let node_data = if let Some(data) = known_nodes.lock().get(&node_id) { - data.clone() + let data_options = self.known_nodes.lock().get(&node_id).cloned(); + let node_data = if let Some(data) = data_options { + data } else { - stream.write(&Event::new(REJECT_EVENT).as_bytes())?; - stream.flush()?; + stream.write(&Event::new(REJECT_EVENT).as_bytes()).await?; + stream.flush().await?; return Err(VentedError::UnknownNode(node_id)); }; - stream.write( - &Event::with_payload( - CONNECT_EVENT, - &NodeInformationPayload { - public_key: secret_key.public_key().to_bytes(), - node_id: own_node_id, - vented_version: PROTOCOL_VERSION.to_string(), - }, + stream + .write( + &Event::with_payload( + CONNECT_EVENT, + &NodeInformationPayload { + public_key: secret_key.public_key().to_bytes(), + node_id: self.node_id.clone(), + vented_version: PROTOCOL_VERSION.to_string(), + }, + ) + .as_bytes(), ) - .as_bytes(), - )?; - stream.flush()?; + .await?; + stream.flush().await?; let mut stream = CryptoStream::new(node_id.clone(), stream, &public_key, &secret_key)?; log::trace!("Authenticating self..."); - let key_a = - Self::authenticate_self(&mut stream, StaticSecret::from(global_secret.to_bytes()))?; + let key_a = Self::authenticate_self( + &mut stream, + StaticSecret::from(self.global_secret_key.to_bytes()), + ) + .await?; log::trace!("Authenticating recipient..."); - let key_b = Self::authenticate_other(&mut stream, node_data.node().public_key)?; + let key_b = Self::authenticate_other(&mut stream, node_data.node().public_key).await?; log::trace!("Connection fully authenticated."); let pre_secret = StaticSecret::from(secret_key.to_bytes()).diffie_hellman(&public_key); @@ -629,59 +538,63 @@ impl VentedServer { let final_public = final_secret.public_key(); stream.update_key(&final_secret, &final_public); - Ok((node_id, stream)) + Ok(stream) } /// Performs the challenged side of the authentication challenge - fn authenticate_self( - stream: &CryptoStream, + async fn authenticate_self( + stream: &mut CryptoStream, static_secret: StaticSecret, ) -> VentedResult> { - let challenge_event = stream.read()?; + let challenge_event = stream.read().await?; if challenge_event.name != CHALLENGE_EVENT { - stream.send(Event::new(REJECT_EVENT))?; + stream.send(Event::new(REJECT_EVENT)).await?; return Err(VentedError::UnexpectedEvent(challenge_event.name)); } let ChallengePayload { public_key } = challenge_event.get_payload()?; let auth_key = static_secret.diffie_hellman(&PublicKey::from(public_key)); - stream.send(Event::with_payload( - AUTH_EVENT, - &AuthPayload { - calculated_secret: auth_key.to_bytes(), - }, - ))?; + stream + .send(Event::with_payload( + AUTH_EVENT, + &AuthPayload { + calculated_secret: auth_key.to_bytes(), + }, + )) + .await?; - let response = stream.read()?; + let response = stream.read().await?; match response.name.as_str() { ACCEPT_EVENT => Ok(auth_key.to_bytes().to_vec()), REJECT_EVENT => Err(VentedError::Rejected), _ => { - stream.send(Event::new(REJECT_EVENT))?; + stream.send(Event::new(REJECT_EVENT)).await?; Err(VentedError::UnexpectedEvent(response.name)) } } } /// Authenticates the other party by using their stored public key and a generated secret - fn authenticate_other( - stream: &CryptoStream, + async fn authenticate_other( + stream: &mut CryptoStream, other_static_public: PublicKey, ) -> VentedResult> { let auth_secret = SecretKey::generate(&mut rand::thread_rng()); - stream.send(Event::with_payload( - CHALLENGE_EVENT, - &ChallengePayload { - public_key: auth_secret.public_key().to_bytes(), - }, - ))?; + stream + .send(Event::with_payload( + CHALLENGE_EVENT, + &ChallengePayload { + public_key: auth_secret.public_key().to_bytes(), + }, + )) + .await?; - let auth_event = stream.read()?; + let auth_event = stream.read().await?; if auth_event.name != AUTH_EVENT { - stream.send(Event::new(REJECT_EVENT))?; + stream.send(Event::new(REJECT_EVENT)).await?; return Err(VentedError::UnexpectedEvent(auth_event.name)); } let AuthPayload { calculated_secret } = auth_event.get_payload()?; @@ -689,10 +602,10 @@ impl VentedServer { StaticSecret::from(auth_secret.to_bytes()).diffie_hellman(&other_static_public); if expected_secret.to_bytes() != calculated_secret { - stream.send(Event::new(REJECT_EVENT))?; + stream.send(Event::new(REJECT_EVENT)).await?; Err(VentedError::AuthFailed) } else { - stream.send(Event::new(ACCEPT_EVENT))?; + stream.send(Event::new(ACCEPT_EVENT)).await?; Ok(calculated_secret.to_vec()) } } diff --git a/src/server/server_events.rs b/src/server/server_events.rs index ea71fe3..fdcc40f 100644 --- a/src/server/server_events.rs +++ b/src/server/server_events.rs @@ -103,153 +103,158 @@ impl VentedServer { self.on(REDIRECT_CONFIRM_EVENT, { let redirect_handles = Arc::clone(&self.redirect_handles); move |event| { - let payload = event.get_payload::().ok()?; - let mut future = redirect_handles.lock().remove(&payload.id)?; - future.resolve(()); - - None + let redirect_handles = Arc::clone(&redirect_handles); + Box::pin(async move { + let payload = event.get_payload::().ok()?; + let mut value = redirect_handles.lock().remove(&payload.id)?; + value.resolve(()); + None + }) } }); self.on(REDIRECT_FAIL_EVENT, { let redirect_handles = Arc::clone(&self.redirect_handles); move |event| { - let payload = event.get_payload::().ok()?; - let mut future = redirect_handles.lock().remove(&payload.id)?; - future.reject(VentedError::Rejected); + let redirect_handles = Arc::clone(&redirect_handles); + Box::pin(async move { + let payload = event.get_payload::().ok()?; + let mut value = redirect_handles.lock().remove(&payload.id)?; + value.reject(VentedError::Rejected); - None + None + }) } }); self.on(REDIRECT_EVENT, { - let manager = self.manager.clone(); - let pool = Arc::clone(&self.sender_pool); - + let connections = Arc::clone(&self.connections); move |event| { - let payload = event.get_payload::().ok()?; - let origin = event.origin?; - let manager = manager.clone(); - - pool.lock().execute(move || { - let response = if manager - .send( - &payload.target, - Event::with_payload(REDIRECT_REDIRECTED_EVENT, &payload), - ) - .get_value() - .is_ok() - { - Event::with_payload( - REDIRECT_CONFIRM_EVENT, - &RedirectResponsePayload { id: payload.id }, - ) - } else { - Event::with_payload( - REDIRECT_FAIL_EVENT, - &RedirectResponsePayload { id: payload.id }, - ) - }; - manager.send(&origin, response); - }); + let connections = Arc::clone(&connections); + Box::pin(async move { + let payload = event.get_payload::().ok()?; + if payload.source == event.origin? { + let opt_stream = connections.lock().get(&payload.target).cloned(); + if let Some(mut stream) = opt_stream { + if let Ok(_) = stream + .send(Event::with_payload(REDIRECT_REDIRECTED_EVENT, &payload)) + .await + { + return Some(Event::with_payload( + REDIRECT_CONFIRM_EVENT, + &RedirectResponsePayload { id: payload.id }, + )); + } + } + } - None + Some(Event::with_payload( + REDIRECT_FAIL_EVENT, + &RedirectResponsePayload { id: payload.id }, + )) + }) } }); self.on(REDIRECT_REDIRECTED_EVENT, { - let event_handler = Arc::clone(&self.event_handler); - let manager = self.manager.clone(); - let pool = self.sender_pool.clone(); - let known_nodes = Arc::clone(&self.known_nodes); + let event_handler = self.event_handler.clone(); + let connections = Arc::clone(&self.connections); move |event| { - let payload = event.get_payload::().ok()?; - let event = Event::from_bytes(&mut &payload.content[..]).ok()?; + let connections = Arc::clone(&connections); + let mut event_handler = event_handler.clone(); + Box::pin(async move { + let payload = event.get_payload::().ok()?; + let event = Event::from(&mut &payload.content[..]).ok()?; + let origin = event.origin.clone()?; - if known_nodes.lock().contains_key(&payload.source) { - pool.lock().execute({ - let event_handler = Arc::clone(&event_handler); - let manager = manager.clone(); - move || { - let responses = event_handler.lock().handle_event(event); - responses - .iter() - .cloned() - .map(|mut value| { - let payload = payload.clone(); - Event::with_payload( - REDIRECT_EVENT, - &RedirectPayload::new( - payload.target, - payload.proxy, - payload.source, - value.as_bytes(), - ), - ) - }) - .for_each(|event| { - manager.send(&payload.proxy, event); - }); + let responses = event_handler.handle_event(event).await; + let responses = responses + .iter() + .cloned() + .map(|mut value| { + let payload = payload.clone(); + Event::with_payload( + REDIRECT_EVENT, + &RedirectPayload::new( + payload.target, + payload.proxy, + payload.source, + value.as_bytes(), + ), + ) + }) + .collect::>(); + let opt_stream = connections.lock().get(&origin).cloned(); + if let Some(mut stream) = opt_stream { + for response in responses { + stream.send(response).await.ok()?; } - }); - } + } - None + None + }) } }); self.on(NODE_LIST_EVENT, { let node_list = Arc::clone(&self.known_nodes); - let own_id = self.node_id.clone(); + let own_node_id = self.node_id.clone(); move |event| { - let list = event.get_payload::().ok()?; - let mut own_nodes = node_list.lock(); - let origin = event.origin?; + let node_list = Arc::clone(&node_list); + let own_node_id = own_node_id.clone(); + Box::pin(async move { + let list = event.get_payload::().ok()?; + let mut own_nodes = node_list.lock(); + let origin = event.origin?; - if !own_nodes.get(&origin)?.node().trusted { - log::warn!("Untrusted node '{}' tried to send network update!", origin); - return None; - } + if !own_nodes.get(&origin)?.node().trusted { + log::warn!("Untrusted node '{}' tried to send network update!", origin); + return None; + } - let mut new_nodes = 0; - for node in list.nodes { - if !own_nodes.contains_key(&node.id) && node.id != own_id { - own_nodes.insert( - node.id.clone(), - Node { - id: node.id, - trusted: false, - public_key: PublicKey::from(node.public_key), - addresses: node.addresses, - } - .into(), - ); - new_nodes += 1; + let mut new_nodes = 0; + for node in list.nodes { + if !own_nodes.contains_key(&node.id) && node.id != own_node_id { + own_nodes.insert( + node.id.clone(), + Node { + id: node.id, + trusted: false, + public_key: PublicKey::from(node.public_key), + addresses: node.addresses, + } + .into(), + ); + new_nodes += 1; + } } - } - log::debug!("Updated node list: Added {} new nodes", new_nodes); + log::debug!("Updated node list: Added {} new nodes", new_nodes); - None + None + }) } }); self.on(NODE_LIST_REQUEST_EVENT, { let node_list = Arc::clone(&self.known_nodes); move |event| { - let sender_id = event.origin?; - let nodes = node_list - .lock() - .values() - .filter(|node| node.node().id != sender_id) - .map(|node| NodeListElement { - id: node.node().id.clone(), - addresses: node.node().addresses.clone(), - public_key: node.node().public_key.to_bytes(), - }) - .collect(); + let node_list = Arc::clone(&node_list); + Box::pin(async move { + let sender_id = event.origin?; + let nodes = node_list + .lock() + .values() + .filter(|node| node.node().id != sender_id) + .map(|node| NodeListElement { + id: node.node().id.clone(), + addresses: node.node().addresses.clone(), + public_key: node.node().public_key.to_bytes(), + }) + .collect(); - Some(Event::with_payload( - NODE_LIST_EVENT, - &NodeListPayload { nodes }, - )) + Some(Event::with_payload( + NODE_LIST_EVENT, + &NodeListPayload { nodes }, + )) + }) } }); } diff --git a/src/stream/cryptostream.rs b/src/stream/cryptostream.rs index 2ec4332..5fc0f7e 100644 --- a/src/stream/cryptostream.rs +++ b/src/stream/cryptostream.rs @@ -1,6 +1,4 @@ -use std::io::{Read, Write}; -use std::net::{Shutdown, TcpStream}; -use std::sync::Arc; +use async_std::prelude::*; use byteorder::{BigEndian, ByteOrder}; use crypto_box::aead::{Aead, Payload}; @@ -8,18 +6,19 @@ use crypto_box::{ChaChaBox, SecretKey}; use generic_array::GenericArray; use parking_lot::Mutex; use sha2::Digest; +use std::sync::Arc; use typenum::*; use x25519_dalek::PublicKey; use crate::event::Event; use crate::utils::result::VentedResult; +use async_std::net::{Shutdown, TcpStream}; /// A cryptographical stream object that handles encryption and decryption of streams #[derive(Clone)] pub struct CryptoStream { recv_node_id: String, - send_stream: Arc>, - recv_stream: Arc>, + stream: TcpStream, send_secret: Arc>>, recv_secret: Arc>>, } @@ -32,15 +31,12 @@ impl CryptoStream { public_key: &PublicKey, secret_key: &SecretKey, ) -> VentedResult { - let send_stream = Arc::new(Mutex::new(inner.try_clone()?)); - let recv_stream = Arc::new(Mutex::new(inner)); let send_box = EncryptionBox::new(ChaChaBox::new(public_key, secret_key)); let recv_box = EncryptionBox::new(ChaChaBox::new(public_key, secret_key)); Ok(Self { recv_node_id: node_id, - send_stream, - recv_stream, + stream: inner, send_secret: Arc::new(Mutex::new(send_box)), recv_secret: Arc::new(Mutex::new(recv_box)), }) @@ -50,17 +46,16 @@ impl CryptoStream { /// format: /// length: u64 /// data: length - pub fn send(&self, mut event: Event) -> VentedResult<()> { + pub async fn send(&mut self, mut event: Event) -> VentedResult<()> { let ciphertext = self.send_secret.lock().encrypt(&event.as_bytes())?; - let mut stream = self.send_stream.lock(); let mut length_raw = [0u8; 8]; BigEndian::write_u64(&mut length_raw, ciphertext.len() as u64); log::trace!("Encoded event '{}' to raw message", event.name); - stream.write(&length_raw)?; - stream.write(&ciphertext)?; - stream.flush()?; + self.stream.write(&length_raw).await?; + self.stream.write(&ciphertext).await?; + self.stream.flush().await?; log::trace!("Event sent"); @@ -68,19 +63,18 @@ impl CryptoStream { } /// Reads an event from the stream. Blocks until data is received - pub fn read(&self) -> VentedResult { - let mut stream = self.recv_stream.lock(); + pub async fn read(&mut self) -> VentedResult { let mut length_raw = [0u8; 8]; - stream.read_exact(&mut length_raw)?; + self.stream.read_exact(&mut length_raw).await?; let length = BigEndian::read_u64(&length_raw); let mut ciphertext = vec![0u8; length as usize]; - stream.read(&mut ciphertext)?; + self.stream.read(&mut ciphertext).await?; log::trace!("Received raw message"); let plaintext = self.recv_secret.lock().decrypt(&ciphertext)?; - let event = Event::from_bytes(&mut &plaintext[..])?; + let event = Event::from(&mut &plaintext[..])?; log::trace!("Decoded message to event '{}'", event.name); Ok(event) @@ -100,8 +94,8 @@ impl CryptoStream { } /// Closes both streams - pub fn shutdown(&self) -> VentedResult<()> { - self.send_stream.lock().shutdown(Shutdown::Both)?; + pub fn shutdown(&mut self) -> VentedResult<()> { + self.stream.shutdown(Shutdown::Both)?; Ok(()) } diff --git a/src/stream/manager.rs b/src/stream/manager.rs deleted file mode 100644 index e233c8f..0000000 --- a/src/stream/manager.rs +++ /dev/null @@ -1,148 +0,0 @@ -use std::collections::HashMap; -use std::mem; -use std::sync::Arc; -use std::thread; -use std::thread::{JoinHandle, ThreadId}; -use std::time::Duration; - -use crossbeam_channel::{Receiver, Sender}; -use parking_lot::Mutex; - -use crate::event::Event; -use crate::stream::cryptostream::CryptoStream; -use crate::utils::result::{VentedError, VentedResult}; -use crate::utils::sync::AsyncValue; -use crate::WaitGroup; - -const MAX_ENQUEUED_EVENTS: usize = 50; -pub const CONNECTION_TIMEOUT_SECONDS: u64 = 5; - -#[derive(Clone, Debug)] -pub struct ConcurrentStreamManager { - max_threads: usize, - threads: Arc>>>, - emitters: Arc)>>>>, - event_receiver: Receiver<(String, Event)>, - listener_sender: Sender<(String, Event)>, -} - -impl ConcurrentStreamManager { - pub fn new(max_threads: usize) -> Self { - let (sender, receiver) = crossbeam_channel::unbounded(); - - Self { - max_threads, - threads: Arc::new(Mutex::new(HashMap::new())), - emitters: Arc::new(Mutex::new(HashMap::new())), - event_receiver: receiver, - listener_sender: sender, - } - } - - /// Returns if the manager has a connection to the given node - pub fn has_connection(&self, node: &String) -> bool { - self.emitters.lock().contains_key(node) - } - - /// Returns the receiver for events - pub fn receiver(&self) -> Receiver<(String, Event)> { - self.event_receiver.clone() - } - - /// Sends an event and returns an async value with the result - pub fn send(&self, target: &String, event: Event) -> AsyncValue<(), VentedError> { - let mut value = AsyncValue::new(); - if let Some(emitter) = self.emitters.lock().get(target) { - if let Err(_) = emitter.send_timeout( - (event, value.clone()), - Duration::from_secs(CONNECTION_TIMEOUT_SECONDS), - ) { - value.reject(VentedError::UnreachableNode(target.clone())); - } - } else { - value.reject(VentedError::UnknownNode(target.clone())) - } - - value - } - - /// Adds a connection to the manager causing it to start two new threads - /// This call blocks until the two threads are started up - pub fn add_connection(&self, stream: CryptoStream) -> VentedResult<()> { - if self.threads.lock().len() > self.max_threads { - return Err(VentedError::TooManyThreads); - } - let sender = self.listener_sender.clone(); - let recv_id = stream.receiver_node().clone(); - let (emitter, receiver) = crossbeam_channel::bounded(MAX_ENQUEUED_EVENTS); - self.emitters.lock().insert(recv_id.clone(), emitter); - let wg = WaitGroup::new(); - - let sender_thread = thread::Builder::new() - .name(format!("sender-{}", stream.receiver_node())) - .spawn({ - let stream = stream.clone(); - let recv_id = recv_id.clone(); - let emitters = Arc::clone(&self.emitters); - let threads = Arc::clone(&self.threads); - let wg = WaitGroup::clone(&wg); - - move || { - mem::drop(wg); - while let Ok((event, mut future)) = receiver.recv() { - if let Err(e) = stream.send(event) { - log::debug!("Failed to send event to {}: {}", recv_id, e); - future.reject(e); - break; - } - future.resolve(()); - } - if let Err(e) = stream.shutdown() { - log::error!("Failed to shutdown stream: {}", e); - } - emitters.lock().remove(&recv_id); - threads.lock().remove(&thread::current().id()); - } - })?; - self.threads - .lock() - .insert(sender_thread.thread().id(), sender_thread); - - let receiver_thread = thread::Builder::new() - .name(format!("receiver-{}", stream.receiver_node())) - .spawn({ - let threads = Arc::clone(&self.threads); - let wg = WaitGroup::clone(&wg); - move || { - mem::drop(wg); - loop { - match stream.read() { - Ok(mut event) => { - event.origin = Some(stream.receiver_node().clone()); - - if let Err(e) = sender.send((stream.receiver_node().clone(), event)) - { - log::trace!("Failed to get event from {}: {}", recv_id, e); - break; - } - } - Err(e) => { - log::error!("Failed to send event: {}", e); - break; - } - } - } - if let Err(e) = stream.shutdown() { - log::error!("Failed to shutdown stream: {}", e); - } - threads.lock().remove(&thread::current().id()); - } - })?; - self.threads - .lock() - .insert(receiver_thread.thread().id(), receiver_thread); - wg.wait(); - - Ok(()) - } -} diff --git a/src/stream/mod.rs b/src/stream/mod.rs index 383559f..3dc9a7c 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -2,4 +2,3 @@ pub use crypto_box::PublicKey; pub use crypto_box::SecretKey; pub mod cryptostream; -pub mod manager; diff --git a/src/utils/sync.rs b/src/utils/sync.rs index 080ce6d..2daf2e6 100644 --- a/src/utils/sync.rs +++ b/src/utils/sync.rs @@ -1,6 +1,6 @@ -use std::{mem, thread}; use std::sync::Arc; use std::time::{Duration, Instant}; +use std::{mem}; use parking_lot::Mutex; @@ -15,8 +15,8 @@ pub struct AsyncValue { } impl AsyncValue - where - E: std::fmt::Display, +where + E: std::fmt::Display, { /// Creates the future with no value pub fn new() -> Self { @@ -51,8 +51,8 @@ impl AsyncValue } pub fn on_error(&mut self, cb: F) -> &mut Self - where - F: FnOnce(&E) -> () + Send + Sync + 'static, + where + F: FnOnce(&E) -> () + Send + Sync + 'static, { self.err_cb.lock().replace(Box::new(cb)); @@ -60,8 +60,8 @@ impl AsyncValue } pub fn on_success(&mut self, cb: F) -> &mut Self - where - F: FnOnce(&V) -> () + Send + Sync + 'static, + where + F: FnOnce(&V) -> () + Send + Sync + 'static, { self.ok_cb.lock().replace(Box::new(cb)); @@ -113,12 +113,32 @@ impl AsyncValue } } + /// Returns the value asynchronously + pub async fn get_value_async(&mut self) -> Result { + while self.value.lock().is_none() { + async_std::task::sleep(Duration::from_millis(1)).await; + } + if let Some(err) = self.error.lock().take() { + Err(err) + } else { + Ok(self.value.lock().take().unwrap()) + } + } + /// Returns the value of the future only blocking for the given timeout pub fn get_value_with_timeout(&mut self, timeout: Duration) -> Option> { + async_std::task::block_on(self.get_value_with_timeout_async(timeout)) + } + + /// Returns the value of the future asynchronous with a timeout after the given duration + pub async fn get_value_with_timeout_async( + &mut self, + timeout: Duration, + ) -> Option> { let start = Instant::now(); while self.value.lock().is_none() { - thread::sleep(Duration::from_millis(1)); + async_std::task::sleep(Duration::from_millis(1)).await; if start.elapsed() > timeout { break; } @@ -144,3 +164,6 @@ impl Clone for AsyncValue { } } } + +unsafe impl Sync for AsyncValue {} +unsafe impl Send for AsyncValue {} diff --git a/tests/test_communication.rs b/tests/test_communication.rs index 036dce6..1776b15 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -1,7 +1,8 @@ +use async_std::task; use crypto_box::SecretKey; +use log::LevelFilter; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use std::thread; use std::time::Duration; use vented::event::Event; use vented::server::data::{Node, ServerTimeouts}; @@ -9,7 +10,12 @@ use vented::server::server_events::NODE_LIST_REQUEST_EVENT; use vented::server::VentedServer; fn setup() { - simple_logger::SimpleLogger::new().init().unwrap(); + simple_logger::SimpleLogger::new() + .with_module_level("async_std", LevelFilter::Warn) + .with_module_level("async_io", LevelFilter::Warn) + .with_module_level("polling", LevelFilter::Warn) + .init() + .unwrap(); } #[test] @@ -51,74 +57,80 @@ fn test_server_communication() { trusted: false, }) } - let mut server_a = VentedServer::new( - "A".to_string(), - global_secret_a, - nodes_a, - ServerTimeouts::default(), - 20, - 100, - ); - let mut server_b = VentedServer::new( - "B".to_string(), - global_secret_b, - nodes.clone(), - ServerTimeouts::default(), - 3, - 100, - ); - let server_c = VentedServer::new( - "C".to_string(), - global_secret_c, - nodes, - ServerTimeouts::default(), - 3, - 100, - ); - let wg = server_a.listen("localhost:22222".to_string()); - wg.wait(); - server_a.on("ping", { - let ping_count = Arc::clone(&ping_count); - move |_| { - ping_count.fetch_add(1, Ordering::Relaxed); + task::block_on(async { + let mut server_a = VentedServer::new( + "A".to_string(), + global_secret_a, + nodes_a, + ServerTimeouts::default(), + ); + let mut server_b = VentedServer::new( + "B".to_string(), + global_secret_b, + nodes.clone(), + ServerTimeouts::default(), + ); + let server_c = VentedServer::new( + "C".to_string(), + global_secret_c, + nodes, + ServerTimeouts::default(), + ); + server_a.listen("localhost:22222".to_string()); - Some(Event::new("pong".to_string())) - } - }); - server_b.on("pong", { - let pong_count = Arc::clone(&pong_count); - move |_| { - pong_count.fetch_add(1, Ordering::Relaxed); - None + server_a.on("ping", { + let ping_count = Arc::clone(&ping_count); + move |_| { + let ping_count = Arc::clone(&ping_count); + Box::pin(async move { + ping_count.fetch_add(1, Ordering::Relaxed); + + Some(Event::new("pong".to_string())) + }) + } + }); + server_b.on("pong", { + let pong_count = Arc::clone(&pong_count); + move |_| { + let pong_count = Arc::clone(&pong_count); + Box::pin(async move { + pong_count.fetch_add(1, Ordering::Relaxed); + None + }) + } + }); + for i in 0..10 { + assert!(server_a + .emit(format!("Nodes-{}", i), Event::new("ping")) + .await + .is_err()); } - }); - for i in 0..10 { - server_a.emit(format!("Nodes-{}", i), Event::new("ping")); - } - server_b - .emit("A", Event::new(NODE_LIST_REQUEST_EVENT)) - .on_success(|_| println!("Success")) - .block_unwrap(); - server_c - .emit("A", Event::new("ping".to_string())) - .block_unwrap(); - for _ in 0..9 { server_b + .emit("A", Event::new(NODE_LIST_REQUEST_EVENT)) + .await + .unwrap(); + server_c .emit("A", Event::new("ping".to_string())) - .block_unwrap(); - } - server_a - .emit("B", Event::new("pong".to_string())) - .block_unwrap(); - server_b - .emit("C", Event::new("ping".to_string())) - .block_unwrap(); - + .await + .unwrap(); + for _ in 0..9 { + server_b + .emit("A", Event::new("ping".to_string())) + .await + .unwrap(); + } + server_a + .emit("B", Event::new("pong".to_string())) + .await + .unwrap(); + server_b + .emit("C", Event::new("ping".to_string())) + .await + .unwrap(); + task::sleep(Duration::from_secs(1)).await; + }); // wait one second to make sure the servers were able to process the events - for _ in 0..100 { - thread::sleep(Duration::from_millis(10)); - } assert_eq!(ping_count.load(Ordering::SeqCst), 10); assert_eq!(pong_count.load(Ordering::SeqCst), 10);