Refactor connection function

Signed-off-by: trivernis <trivernis@protonmail.com>
pull/1/head
trivernis 4 years ago
parent 2c05e2736f
commit 32f15a2c89
Signed by: Trivernis
GPG Key ID: DFFFCC2C7A02DB45

@ -1,9 +1,11 @@
use crate::crypto::CryptoStream; use crate::crypto::CryptoStream;
use crate::event_handler::EventHandler; use crate::event_handler::EventHandler;
use crate::WaitGroup;
use crypto_box::SecretKey; use crypto_box::SecretKey;
use parking_lot::Mutex; use parking_lot::Mutex;
use scheduled_thread_pool::ScheduledThreadPool; use scheduled_thread_pool::ScheduledThreadPool;
use std::collections::HashMap; use std::collections::HashMap;
use std::mem;
use std::sync::Arc; use std::sync::Arc;
use x25519_dalek::PublicKey; use x25519_dalek::PublicKey;
@ -22,5 +24,45 @@ pub(crate) struct ServerConnectionContext {
pub known_nodes: Arc<Mutex<Vec<Node>>>, pub known_nodes: Arc<Mutex<Vec<Node>>>,
pub event_handler: Arc<Mutex<EventHandler>>, pub event_handler: Arc<Mutex<EventHandler>>,
pub connections: Arc<Mutex<HashMap<String, CryptoStream>>>, pub connections: Arc<Mutex<HashMap<String, CryptoStream>>>,
pub forwarded_connections: Arc<Mutex<HashMap<(String, String), Future<CryptoStream>>>>,
pub listener_pool: Arc<Mutex<ScheduledThreadPool>>, pub listener_pool: Arc<Mutex<ScheduledThreadPool>>,
} }
#[derive(Clone)]
pub(crate) struct Future<T> {
value: Arc<Mutex<Option<T>>>,
wg: Option<WaitGroup>,
}
impl<T> Future<T> {
/// Creates the future with no value
pub fn new() -> Self {
Self {
value: Arc::new(Mutex::new(None)),
wg: Some(WaitGroup::new()),
}
}
/// Creates the future with an already resolved value
pub fn with_value(value: T) -> Self {
Self {
value: Arc::new(Mutex::new(Some(value))),
wg: None,
}
}
/// Sets the value of the future consuming the wait group
pub fn set_value(&mut self, value: T) {
self.value.lock().replace(value);
mem::take(&mut self.wg);
}
/// Returns the value of the future after it has been set.
/// This call blocks
pub fn get_value(&mut self) -> T {
if let Some(wg) = mem::take(&mut self.wg) {
wg.wait();
}
self.value.lock().take().unwrap()
}
}

@ -9,7 +9,7 @@ use crate::event::Event;
use crate::event_handler::EventHandler; use crate::event_handler::EventHandler;
use crate::result::VentedError::UnknownNode; use crate::result::VentedError::UnknownNode;
use crate::result::{VentedError, VentedResult}; use crate::result::{VentedError, VentedResult};
use crate::server::data::{Node, ServerConnectionContext}; use crate::server::data::{Future, Node, ServerConnectionContext};
use crate::server::server_events::{ use crate::server::server_events::{
AuthPayload, ChallengePayload, NodeInformationPayload, VersionMismatchPayload, ACCEPT_EVENT, AuthPayload, ChallengePayload, NodeInformationPayload, VersionMismatchPayload, ACCEPT_EVENT,
AUTH_EVENT, CHALLENGE_EVENT, CONNECT_EVENT, MISMATCH_EVENT, READY_EVENT, REJECT_EVENT, AUTH_EVENT, CHALLENGE_EVENT, CONNECT_EVENT, MISMATCH_EVENT, READY_EVENT, REJECT_EVENT,
@ -27,6 +27,9 @@ pub mod server_events;
pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION"); pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
type ForwardFutureVector = Arc<Mutex<HashMap<(String, String), Future<CryptoStream>>>>;
type CryptoStreamMap = Arc<Mutex<HashMap<String, CryptoStream>>>;
/// The vented server that provides parallel handling of connections /// The vented server that provides parallel handling of connections
/// Usage: /// Usage:
/// ```rust /// ```rust
@ -58,7 +61,8 @@ pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
/// server.emit("B".to_string(), Event::new("ping".to_string())).unwrap(); /// server.emit("B".to_string(), Event::new("ping".to_string())).unwrap();
/// ``` /// ```
pub struct VentedServer { pub struct VentedServer {
connections: Arc<Mutex<HashMap<String, CryptoStream>>>, connections: CryptoStreamMap,
forwarded_connections: ForwardFutureVector,
known_nodes: Arc<Mutex<Vec<Node>>>, known_nodes: Arc<Mutex<Vec<Node>>>,
listener_pool: Arc<Mutex<ScheduledThreadPool>>, listener_pool: Arc<Mutex<ScheduledThreadPool>>,
sender_pool: Arc<Mutex<ScheduledThreadPool>>, sender_pool: Arc<Mutex<ScheduledThreadPool>>,
@ -90,6 +94,7 @@ impl VentedServer {
num_threads, num_threads,
))), ))),
connections: Arc::new(Mutex::new(HashMap::new())), connections: Arc::new(Mutex::new(HashMap::new())),
forwarded_connections: Arc::new(Mutex::new(HashMap::new())),
global_secret_key: secret_key, global_secret_key: secret_key,
known_nodes: Arc::new(Mutex::new(nodes)), known_nodes: Arc::new(Mutex::new(nodes)),
} }
@ -109,42 +114,22 @@ impl VentedServer {
/// The actual writing is done in a separate thread from the thread pool. /// The actual writing is done in a separate thread from the thread pool.
/// With the returned wait group one can wait for the event to be written. /// With the returned wait group one can wait for the event to be written.
pub fn emit(&self, node_id: String, event: Event) -> VentedResult<WaitGroup> { pub fn emit(&self, node_id: String, event: Event) -> VentedResult<WaitGroup> {
let handler = self.connections.lock().get(&node_id).cloned();
let wg = WaitGroup::new(); let wg = WaitGroup::new();
let wg2 = WaitGroup::clone(&wg); let stream = self.get_connection(node_id)?;
if let Some(handler) = handler { self.sender_pool.lock().execute({
let wg = WaitGroup::clone(&wg);
let connections = Arc::clone(&self.connections); let connections = Arc::clone(&self.connections);
self.sender_pool.lock().execute(move || { move || {
if let Err(e) = handler.send(event) { if let Err(e) = stream.send(event) {
log::error!("Failed to send event: {}", e); log::error!("Failed to send event: {}", e);
connections.lock().remove(handler.receiver_node()); connections.lock().remove(stream.receiver_node());
} }
std::mem::drop(wg); std::mem::drop(wg);
});
Ok(wg2)
} else {
let found_node = self
.known_nodes
.lock()
.iter()
.find(|n| n.id == node_id)
.cloned();
if let Some(node) = found_node {
if let Some(address) = &node.address {
let handler = self.connect(address.clone())?;
self.sender_pool.lock().execute(move || {
handler.send(event).expect("Failed to send event");
std::mem::drop(wg);
});
Ok(wg2)
} else {
Err(VentedError::NotAServer(node_id))
}
} else {
Err(VentedError::UnknownNode(node_id))
} }
} });
Ok(wg)
} }
/// Adds a handler for the given event. /// Adds a handler for the given event.
@ -171,6 +156,7 @@ impl VentedServer {
Ok(listener) => { Ok(listener) => {
log::info!("Listener running on {}", address); log::info!("Listener running on {}", address);
std::mem::drop(wg); std::mem::drop(wg);
for connection in listener.incoming() { for connection in listener.incoming() {
match connection { match connection {
Ok(stream) => { Ok(stream) => {
@ -201,6 +187,7 @@ impl VentedServer {
connections: Arc::clone(&self.connections), connections: Arc::clone(&self.connections),
event_handler: Arc::clone(&self.event_handler), event_handler: Arc::clone(&self.event_handler),
listener_pool: Arc::clone(&self.listener_pool), listener_pool: Arc::clone(&self.listener_pool),
forwarded_connections: Arc::clone(&self.forwarded_connections),
} }
} }
@ -209,9 +196,14 @@ impl VentedServer {
fn handle_connection(params: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> { fn handle_connection(params: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> {
let pool = Arc::clone(&params.listener_pool); let pool = Arc::clone(&params.listener_pool);
let event_handler = Arc::clone(&params.event_handler); let event_handler = Arc::clone(&params.event_handler);
log::trace!(
"Received connection from {}",
stream.peer_addr().expect("Failed to get peer address")
);
pool.lock().execute(move || { pool.lock().execute(move || {
let connections = Arc::clone(&params.connections); let connections = Arc::clone(&params.connections);
let stream = match VentedServer::get_crypto_stream(params, stream) { let stream = match VentedServer::get_crypto_stream(params, stream) {
Ok(stream) => stream, Ok(stream) => stream,
Err(e) => { Err(e) => {
@ -219,23 +211,65 @@ impl VentedServer {
return; return;
} }
}; };
event_handler log::trace!("Secure connection established.");
.lock() event_handler.lock().handle_event(Event::new(READY_EVENT));
.handle_event(Event::new(READY_EVENT.to_string())); if let Err(e) = Self::handle_read(event_handler, &stream) {
while let Ok(event) = stream.read() { log::error!("Connection aborted: {}", e);
if let Some(response) = event_handler.lock().handle_event(event) {
if let Err(e) = stream.send(response) {
log::error!("Failed to send response event: {}", e);
break;
}
}
} }
connections.lock().remove(stream.receiver_node()); connections.lock().remove(stream.receiver_node());
}); });
Ok(()) Ok(())
} }
/// Handler for reading after the connection is established
fn handle_read(
event_handler: Arc<Mutex<EventHandler>>,
stream: &CryptoStream,
) -> VentedResult<()> {
while let Ok(event) = stream.read() {
if let Some(response) = event_handler.lock().handle_event(event) {
stream.send(response)?
}
}
Ok(())
}
/// Takes three attempts to retrieve a connection for the given node.
/// First it tries to use the already established connection stored in the shared connections vector.
/// If that fails it tries to establish a new connection to the node by using the known address
fn get_connection(&self, target: String) -> VentedResult<CryptoStream> {
log::trace!("Trying to connect to {}", target);
if let Some(stream) = self.connections.lock().get(&target) {
log::trace!("Reusing existing connection.");
return Ok(CryptoStream::clone(stream));
}
let target_node = {
self.known_nodes
.lock()
.iter()
.find(|node| node.id == target)
.cloned()
.ok_or(VentedError::UnknownNode(target.clone()))?
};
if let Some(address) = target_node.address {
log::trace!("Connecting to known address");
match self.connect(address) {
Ok(stream) => {
return Ok(stream);
}
Err(e) => log::error!("Failed to connect to node '{}': {}", target, e),
}
}
log::debug!("All connection attempts to {} failed!", target);
Err(VentedError::NotAServer(target))
}
/// Establishes a crypto stream for the given stream /// Establishes a crypto stream for the given stream
fn get_crypto_stream( fn get_crypto_stream(
params: ServerConnectionContext, params: ServerConnectionContext,
@ -264,25 +298,20 @@ impl VentedServer {
context.is_server = false; context.is_server = false;
let connections = Arc::clone(&context.connections); let connections = Arc::clone(&context.connections);
let stream = Self::get_crypto_stream(context, stream)?; let stream = Self::get_crypto_stream(context.clone(), stream)?;
self.listener_pool.lock().execute({ self.listener_pool.lock().execute({
let stream = CryptoStream::clone(&stream); let stream = CryptoStream::clone(&stream);
let event_handler = Arc::clone(&self.event_handler); let event_handler = Arc::clone(&self.event_handler);
event_handler.lock().handle_event(Event::new(READY_EVENT));
move || { move || {
while let Ok(event) = stream.read() { if let Err(e) = Self::handle_read(event_handler, &stream) {
if let Some(response) = event_handler.lock().handle_event(event) { log::error!("Connection aborted: {}", e);
if let Err(e) = stream.send(response) {
log::error!("Failed to send response event: {}", e);
break;
}
}
} }
connections.lock().remove(stream.receiver_node()); connections.lock().remove(stream.receiver_node());
} }
}); });
self.event_handler
.lock()
.handle_event(Event::new(READY_EVENT.to_string()));
Ok(stream) Ok(stream)
} }
@ -336,6 +365,7 @@ impl VentedServer {
)?; )?;
stream.flush()?; stream.flush()?;
let event = Event::from_bytes(&mut stream)?; let event = Event::from_bytes(&mut stream)?;
if event.name != CONNECT_EVENT { if event.name != CONNECT_EVENT {
return Err(VentedError::UnexpectedEvent(event.name)); return Err(VentedError::UnexpectedEvent(event.name));
} }

@ -6,8 +6,7 @@ pub(crate) const CHALLENGE_EVENT: &str = "conn:challenge";
pub(crate) const ACCEPT_EVENT: &str = "conn:accept"; pub(crate) const ACCEPT_EVENT: &str = "conn:accept";
pub(crate) const REJECT_EVENT: &str = "conn:reject"; pub(crate) const REJECT_EVENT: &str = "conn:reject";
pub(crate) const MISMATCH_EVENT: &str = "conn:reject_version_mismatch"; pub(crate) const MISMATCH_EVENT: &str = "conn:reject_version_mismatch";
pub const READY_EVENT: &str = "server:ready";
pub const READY_EVENT: &str = "connection:ready";
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub(crate) struct NodeInformationPayload { pub(crate) struct NodeInformationPayload {

@ -16,11 +16,14 @@ fn setup() {
fn test_server_communication() { fn test_server_communication() {
setup(); setup();
let ping_count = Arc::new(AtomicUsize::new(0)); let ping_count = Arc::new(AtomicUsize::new(0));
let ping_c_count = Arc::new(AtomicUsize::new(0));
let pong_count = Arc::new(AtomicUsize::new(0)); let pong_count = Arc::new(AtomicUsize::new(0));
let ready_count = Arc::new(AtomicUsize::new(0)); let ready_count = Arc::new(AtomicUsize::new(0));
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let global_secret_a = SecretKey::generate(&mut rng); let global_secret_a = SecretKey::generate(&mut rng);
let global_secret_b = SecretKey::generate(&mut rng); let global_secret_b = SecretKey::generate(&mut rng);
let global_secret_c = SecretKey::generate(&mut rng);
let nodes = vec![ let nodes = vec![
Node { Node {
id: "A".to_string(), id: "A".to_string(),
@ -32,9 +35,15 @@ fn test_server_communication() {
address: None, address: None,
public_key: global_secret_b.public_key(), public_key: global_secret_b.public_key(),
}, },
Node {
id: "C".to_string(),
address: None,
public_key: global_secret_c.public_key(),
},
]; ];
let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 4); let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 6);
let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes, 4); let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes.clone(), 3);
let mut server_c = VentedServer::new("C".to_string(), global_secret_c, nodes, 3);
let wg = server_a.listen("localhost:22222".to_string()); let wg = server_a.listen("localhost:22222".to_string());
wg.wait(); wg.wait();
@ -67,7 +76,19 @@ fn test_server_communication() {
None None
} }
}); });
for _ in 0..10 { server_c.on("ping", {
let ping_c_count = Arc::clone(&ping_c_count);
move |_| {
ping_c_count.fetch_add(1, Ordering::Relaxed);
println!("C RECEIVED A PING!");
None
}
});
let wg = server_c
.emit("A".to_string(), Event::new("ping".to_string()))
.unwrap();
wg.wait();
for _ in 0..9 {
let wg = server_b let wg = server_b
.emit("A".to_string(), Event::new("ping".to_string())) .emit("A".to_string(), Event::new("ping".to_string()))
.unwrap(); .unwrap();
@ -77,13 +98,17 @@ fn test_server_communication() {
.emit("B".to_string(), Event::new("pong".to_string())) .emit("B".to_string(), Event::new("pong".to_string()))
.unwrap(); .unwrap();
wg.wait(); wg.wait();
assert!(server_b
.emit("C".to_string(), Event::new("ping".to_string()))
.is_err());
// wait one second to make sure the servers were able to process the events // wait one second to make sure the servers were able to process the events
for _ in 0..100 { for _ in 0..100 {
thread::sleep(Duration::from_millis(10)); thread::sleep(Duration::from_millis(10));
} }
assert_eq!(ready_count.load(Ordering::SeqCst), 2); assert_eq!(ping_c_count.load(Ordering::SeqCst), 0);
assert_eq!(ready_count.load(Ordering::SeqCst), 3);
assert_eq!(ping_count.load(Ordering::SeqCst), 10); assert_eq!(ping_count.load(Ordering::SeqCst), 10);
assert_eq!(pong_count.load(Ordering::SeqCst), 11); assert_eq!(pong_count.load(Ordering::SeqCst), 10);
} }

Loading…
Cancel
Save