Refactor connection function

Signed-off-by: trivernis <trivernis@protonmail.com>
pull/1/head
trivernis 4 years ago
parent 2c05e2736f
commit 32f15a2c89
Signed by: Trivernis
GPG Key ID: DFFFCC2C7A02DB45

@ -1,9 +1,11 @@
use crate::crypto::CryptoStream;
use crate::event_handler::EventHandler;
use crate::WaitGroup;
use crypto_box::SecretKey;
use parking_lot::Mutex;
use scheduled_thread_pool::ScheduledThreadPool;
use std::collections::HashMap;
use std::mem;
use std::sync::Arc;
use x25519_dalek::PublicKey;
@ -22,5 +24,45 @@ pub(crate) struct ServerConnectionContext {
pub known_nodes: Arc<Mutex<Vec<Node>>>,
pub event_handler: Arc<Mutex<EventHandler>>,
pub connections: Arc<Mutex<HashMap<String, CryptoStream>>>,
pub forwarded_connections: Arc<Mutex<HashMap<(String, String), Future<CryptoStream>>>>,
pub listener_pool: Arc<Mutex<ScheduledThreadPool>>,
}
#[derive(Clone)]
pub(crate) struct Future<T> {
value: Arc<Mutex<Option<T>>>,
wg: Option<WaitGroup>,
}
impl<T> Future<T> {
/// Creates the future with no value
pub fn new() -> Self {
Self {
value: Arc::new(Mutex::new(None)),
wg: Some(WaitGroup::new()),
}
}
/// Creates the future with an already resolved value
pub fn with_value(value: T) -> Self {
Self {
value: Arc::new(Mutex::new(Some(value))),
wg: None,
}
}
/// Sets the value of the future consuming the wait group
pub fn set_value(&mut self, value: T) {
self.value.lock().replace(value);
mem::take(&mut self.wg);
}
/// Returns the value of the future after it has been set.
/// This call blocks
pub fn get_value(&mut self) -> T {
if let Some(wg) = mem::take(&mut self.wg) {
wg.wait();
}
self.value.lock().take().unwrap()
}
}

@ -9,7 +9,7 @@ use crate::event::Event;
use crate::event_handler::EventHandler;
use crate::result::VentedError::UnknownNode;
use crate::result::{VentedError, VentedResult};
use crate::server::data::{Node, ServerConnectionContext};
use crate::server::data::{Future, Node, ServerConnectionContext};
use crate::server::server_events::{
AuthPayload, ChallengePayload, NodeInformationPayload, VersionMismatchPayload, ACCEPT_EVENT,
AUTH_EVENT, CHALLENGE_EVENT, CONNECT_EVENT, MISMATCH_EVENT, READY_EVENT, REJECT_EVENT,
@ -27,6 +27,9 @@ pub mod server_events;
pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
type ForwardFutureVector = Arc<Mutex<HashMap<(String, String), Future<CryptoStream>>>>;
type CryptoStreamMap = Arc<Mutex<HashMap<String, CryptoStream>>>;
/// The vented server that provides parallel handling of connections
/// Usage:
/// ```rust
@ -58,7 +61,8 @@ pub(crate) const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
/// server.emit("B".to_string(), Event::new("ping".to_string())).unwrap();
/// ```
pub struct VentedServer {
connections: Arc<Mutex<HashMap<String, CryptoStream>>>,
connections: CryptoStreamMap,
forwarded_connections: ForwardFutureVector,
known_nodes: Arc<Mutex<Vec<Node>>>,
listener_pool: Arc<Mutex<ScheduledThreadPool>>,
sender_pool: Arc<Mutex<ScheduledThreadPool>>,
@ -90,6 +94,7 @@ impl VentedServer {
num_threads,
))),
connections: Arc::new(Mutex::new(HashMap::new())),
forwarded_connections: Arc::new(Mutex::new(HashMap::new())),
global_secret_key: secret_key,
known_nodes: Arc::new(Mutex::new(nodes)),
}
@ -109,42 +114,22 @@ impl VentedServer {
/// The actual writing is done in a separate thread from the thread pool.
/// With the returned wait group one can wait for the event to be written.
pub fn emit(&self, node_id: String, event: Event) -> VentedResult<WaitGroup> {
let handler = self.connections.lock().get(&node_id).cloned();
let wg = WaitGroup::new();
let wg2 = WaitGroup::clone(&wg);
let stream = self.get_connection(node_id)?;
if let Some(handler) = handler {
self.sender_pool.lock().execute({
let wg = WaitGroup::clone(&wg);
let connections = Arc::clone(&self.connections);
self.sender_pool.lock().execute(move || {
if let Err(e) = handler.send(event) {
move || {
if let Err(e) = stream.send(event) {
log::error!("Failed to send event: {}", e);
connections.lock().remove(handler.receiver_node());
connections.lock().remove(stream.receiver_node());
}
std::mem::drop(wg);
});
Ok(wg2)
} else {
let found_node = self
.known_nodes
.lock()
.iter()
.find(|n| n.id == node_id)
.cloned();
if let Some(node) = found_node {
if let Some(address) = &node.address {
let handler = self.connect(address.clone())?;
self.sender_pool.lock().execute(move || {
handler.send(event).expect("Failed to send event");
std::mem::drop(wg);
});
Ok(wg2)
} else {
Err(VentedError::NotAServer(node_id))
}
} else {
Err(VentedError::UnknownNode(node_id))
}
}
});
Ok(wg)
}
/// Adds a handler for the given event.
@ -171,6 +156,7 @@ impl VentedServer {
Ok(listener) => {
log::info!("Listener running on {}", address);
std::mem::drop(wg);
for connection in listener.incoming() {
match connection {
Ok(stream) => {
@ -201,6 +187,7 @@ impl VentedServer {
connections: Arc::clone(&self.connections),
event_handler: Arc::clone(&self.event_handler),
listener_pool: Arc::clone(&self.listener_pool),
forwarded_connections: Arc::clone(&self.forwarded_connections),
}
}
@ -209,9 +196,14 @@ impl VentedServer {
fn handle_connection(params: ServerConnectionContext, stream: TcpStream) -> VentedResult<()> {
let pool = Arc::clone(&params.listener_pool);
let event_handler = Arc::clone(&params.event_handler);
log::trace!(
"Received connection from {}",
stream.peer_addr().expect("Failed to get peer address")
);
pool.lock().execute(move || {
let connections = Arc::clone(&params.connections);
let stream = match VentedServer::get_crypto_stream(params, stream) {
Ok(stream) => stream,
Err(e) => {
@ -219,23 +211,65 @@ impl VentedServer {
return;
}
};
event_handler
.lock()
.handle_event(Event::new(READY_EVENT.to_string()));
while let Ok(event) = stream.read() {
if let Some(response) = event_handler.lock().handle_event(event) {
if let Err(e) = stream.send(response) {
log::error!("Failed to send response event: {}", e);
break;
}
}
log::trace!("Secure connection established.");
event_handler.lock().handle_event(Event::new(READY_EVENT));
if let Err(e) = Self::handle_read(event_handler, &stream) {
log::error!("Connection aborted: {}", e);
}
connections.lock().remove(stream.receiver_node());
});
Ok(())
}
/// Handler for reading after the connection is established
fn handle_read(
event_handler: Arc<Mutex<EventHandler>>,
stream: &CryptoStream,
) -> VentedResult<()> {
while let Ok(event) = stream.read() {
if let Some(response) = event_handler.lock().handle_event(event) {
stream.send(response)?
}
}
Ok(())
}
/// Takes three attempts to retrieve a connection for the given node.
/// First it tries to use the already established connection stored in the shared connections vector.
/// If that fails it tries to establish a new connection to the node by using the known address
fn get_connection(&self, target: String) -> VentedResult<CryptoStream> {
log::trace!("Trying to connect to {}", target);
if let Some(stream) = self.connections.lock().get(&target) {
log::trace!("Reusing existing connection.");
return Ok(CryptoStream::clone(stream));
}
let target_node = {
self.known_nodes
.lock()
.iter()
.find(|node| node.id == target)
.cloned()
.ok_or(VentedError::UnknownNode(target.clone()))?
};
if let Some(address) = target_node.address {
log::trace!("Connecting to known address");
match self.connect(address) {
Ok(stream) => {
return Ok(stream);
}
Err(e) => log::error!("Failed to connect to node '{}': {}", target, e),
}
}
log::debug!("All connection attempts to {} failed!", target);
Err(VentedError::NotAServer(target))
}
/// Establishes a crypto stream for the given stream
fn get_crypto_stream(
params: ServerConnectionContext,
@ -264,25 +298,20 @@ impl VentedServer {
context.is_server = false;
let connections = Arc::clone(&context.connections);
let stream = Self::get_crypto_stream(context, stream)?;
let stream = Self::get_crypto_stream(context.clone(), stream)?;
self.listener_pool.lock().execute({
let stream = CryptoStream::clone(&stream);
let event_handler = Arc::clone(&self.event_handler);
event_handler.lock().handle_event(Event::new(READY_EVENT));
move || {
while let Ok(event) = stream.read() {
if let Some(response) = event_handler.lock().handle_event(event) {
if let Err(e) = stream.send(response) {
log::error!("Failed to send response event: {}", e);
break;
}
}
if let Err(e) = Self::handle_read(event_handler, &stream) {
log::error!("Connection aborted: {}", e);
}
connections.lock().remove(stream.receiver_node());
}
});
self.event_handler
.lock()
.handle_event(Event::new(READY_EVENT.to_string()));
Ok(stream)
}
@ -336,6 +365,7 @@ impl VentedServer {
)?;
stream.flush()?;
let event = Event::from_bytes(&mut stream)?;
if event.name != CONNECT_EVENT {
return Err(VentedError::UnexpectedEvent(event.name));
}

@ -6,8 +6,7 @@ pub(crate) const CHALLENGE_EVENT: &str = "conn:challenge";
pub(crate) const ACCEPT_EVENT: &str = "conn:accept";
pub(crate) const REJECT_EVENT: &str = "conn:reject";
pub(crate) const MISMATCH_EVENT: &str = "conn:reject_version_mismatch";
pub const READY_EVENT: &str = "connection:ready";
pub const READY_EVENT: &str = "server:ready";
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct NodeInformationPayload {

@ -16,11 +16,14 @@ fn setup() {
fn test_server_communication() {
setup();
let ping_count = Arc::new(AtomicUsize::new(0));
let ping_c_count = Arc::new(AtomicUsize::new(0));
let pong_count = Arc::new(AtomicUsize::new(0));
let ready_count = Arc::new(AtomicUsize::new(0));
let mut rng = rand::thread_rng();
let global_secret_a = SecretKey::generate(&mut rng);
let global_secret_b = SecretKey::generate(&mut rng);
let global_secret_c = SecretKey::generate(&mut rng);
let nodes = vec![
Node {
id: "A".to_string(),
@ -32,9 +35,15 @@ fn test_server_communication() {
address: None,
public_key: global_secret_b.public_key(),
},
Node {
id: "C".to_string(),
address: None,
public_key: global_secret_c.public_key(),
},
];
let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 4);
let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes, 4);
let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 6);
let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes.clone(), 3);
let mut server_c = VentedServer::new("C".to_string(), global_secret_c, nodes, 3);
let wg = server_a.listen("localhost:22222".to_string());
wg.wait();
@ -67,7 +76,19 @@ fn test_server_communication() {
None
}
});
for _ in 0..10 {
server_c.on("ping", {
let ping_c_count = Arc::clone(&ping_c_count);
move |_| {
ping_c_count.fetch_add(1, Ordering::Relaxed);
println!("C RECEIVED A PING!");
None
}
});
let wg = server_c
.emit("A".to_string(), Event::new("ping".to_string()))
.unwrap();
wg.wait();
for _ in 0..9 {
let wg = server_b
.emit("A".to_string(), Event::new("ping".to_string()))
.unwrap();
@ -77,13 +98,17 @@ fn test_server_communication() {
.emit("B".to_string(), Event::new("pong".to_string()))
.unwrap();
wg.wait();
assert!(server_b
.emit("C".to_string(), Event::new("ping".to_string()))
.is_err());
// wait one second to make sure the servers were able to process the events
for _ in 0..100 {
thread::sleep(Duration::from_millis(10));
}
assert_eq!(ready_count.load(Ordering::SeqCst), 2);
assert_eq!(ping_c_count.load(Ordering::SeqCst), 0);
assert_eq!(ready_count.load(Ordering::SeqCst), 3);
assert_eq!(ping_count.load(Ordering::SeqCst), 10);
assert_eq!(pong_count.load(Ordering::SeqCst), 11);
assert_eq!(pong_count.load(Ordering::SeqCst), 10);
}

Loading…
Cancel
Save