From 19c1cdb649919db0a79ad613b8aaedf83547b307 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sat, 14 Nov 2020 18:26:24 +0100 Subject: [PATCH] Fix event redirect handling Signed-off-by: trivernis --- Cargo.toml | 2 +- src/server/mod.rs | 29 +++++++++------- src/server/server_events.rs | 68 ++++++++++++++++++++++++------------- tests/test_communication.rs | 26 ++++++++++---- 4 files changed, 82 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3e530b5..dd94391 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vented" description = "Event driven encrypted tcp communicaton" -version = "0.11.3" +version = "0.11.4" authors = ["trivernis "] edition = "2018" readme = "README.md" diff --git a/src/server/mod.rs b/src/server/mod.rs index b87f233..c99c725 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -325,23 +325,27 @@ impl VentedServer { async fn read_stream( mut stream: CryptoStream, connections: Arc>>, - mut handler: EventHandler, + handler: EventHandler, ) { loop { match stream.read().await { Ok(mut event) => { - event.origin = Some(stream.receiver_node().clone()); - let results = handler.handle_event(event).await; - for result in results { - if let Err(e) = stream.send(result).await { - log::error!( - "Failed to send event to {}: {}", - stream.receiver_node(), - e - ); - break; + let mut handler = handler.clone(); + let mut stream = stream.clone(); + task::spawn(async move { + event.origin = Some(stream.receiver_node().clone()); + let results = handler.handle_event(event).await; + for result in results { + if let Err(e) = stream.send(result).await { + log::error!( + "Failed to send event to {}: {}", + stream.receiver_node(), + e + ); + stream.shutdown().expect("Failed to shutdown stream"); + } } - } + }); } Err(e) => { log::error!( @@ -354,6 +358,7 @@ impl VentedServer { } } connections.lock().remove(stream.receiver_node()); + stream.shutdown().expect("Failed to shutdown stream"); } /// Takes three attempts to retrieve a connection for the given node. diff --git a/src/server/server_events.rs b/src/server/server_events.rs index 72672ef..0a40f76 100644 --- a/src/server/server_events.rs +++ b/src/server/server_events.rs @@ -6,6 +6,7 @@ use std::sync::Arc; +use async_std::task; use rand::{thread_rng, RngCore}; use serde::{Deserialize, Serialize}; use x25519_dalek::PublicKey; @@ -137,13 +138,21 @@ impl VentedServer { let connections = Arc::clone(&connections); Box::pin(async move { let payload = event.get_payload::().ok()?; + if payload.source == event.origin? { + log::trace!( + "Handling redirect from {} via {} to {}", + payload.source, + payload.proxy, + payload.target + ); let opt_stream = connections.lock().get(&payload.target).cloned(); if let Some(mut stream) = opt_stream { if let Ok(_) = stream .send(Event::with_payload(REDIRECT_REDIRECTED_EVENT, &payload)) .await { + log::trace!("Redirect succeeded"); return Some(Event::with_payload( REDIRECT_CONFIRM_EVENT, &RedirectResponsePayload { id: payload.id }, @@ -152,6 +161,7 @@ impl VentedServer { } } + log::trace!("Redirect failed"); Some(Event::with_payload( REDIRECT_FAIL_EVENT, &RedirectResponsePayload { id: payload.id }, @@ -168,32 +178,44 @@ impl VentedServer { let mut event_handler = event_handler.clone(); Box::pin(async move { let payload = event.get_payload::().ok()?; - let event = Event::from(&mut &payload.content[..]).ok()?; let origin = event.origin.clone()?; + let event = Event::from(&mut &payload.content[..]).ok()?; - let responses = event_handler.handle_event(event).await; - let responses = responses - .iter() - .cloned() - .map(|mut value| { - let payload = payload.clone(); - Event::with_payload( - REDIRECT_EVENT, - &RedirectPayload::new( - payload.target, - payload.proxy, - payload.source, - value.as_bytes(), - ), - ) - }) - .collect::>(); - let opt_stream = connections.lock().get(&origin).cloned(); - if let Some(mut stream) = opt_stream { - for response in responses { - stream.send(response).await.ok()?; + log::trace!("Spawning task to handle redirect responses"); + task::spawn(async move { + log::trace!("Emitting redirected event..."); + let responses = event_handler.handle_event(event).await; + log::trace!("Mapping responses..."); + let responses = responses + .iter() + .cloned() + .map(|mut value| { + let payload = payload.clone(); + Event::with_payload( + REDIRECT_EVENT, + &RedirectPayload::new( + payload.target, + payload.proxy, + payload.source, + value.as_bytes(), + ), + ) + }) + .collect::>(); + let opt_stream = connections.lock().get(&origin).cloned(); + + log::trace!("Sending responses..."); + if let Some(mut stream) = opt_stream { + for response in responses { + if let Err(e) = stream.send(response).await { + log::error!("Failed to send response events: {}", e); + connections.lock().remove(stream.receiver_node()); + stream.shutdown().expect("Failed to shutdown stream"); + } + } } - } + }); + log::trace!("Done"); None }) diff --git a/tests/test_communication.rs b/tests/test_communication.rs index 1776b15..c2bedf4 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -1,7 +1,7 @@ use async_std::task; use crypto_box::SecretKey; use log::LevelFilter; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; use vented::event::Event; @@ -23,6 +23,7 @@ fn test_server_communication() { setup(); let ping_count = Arc::new(AtomicUsize::new(0)); let pong_count = Arc::new(AtomicUsize::new(0)); + let c_pinged = Arc::new(AtomicBool::new(false)); let mut rng = rand::thread_rng(); let global_secret_a = SecretKey::generate(&mut rng); let global_secret_b = SecretKey::generate(&mut rng); @@ -71,7 +72,7 @@ fn test_server_communication() { nodes.clone(), ServerTimeouts::default(), ); - let server_c = VentedServer::new( + let mut server_c = VentedServer::new( "C".to_string(), global_secret_c, nodes, @@ -100,6 +101,16 @@ fn test_server_communication() { }) } }); + server_c.on("ping", { + let c_pinged = Arc::clone(&c_pinged); + move |_| { + let c_pinged = Arc::clone(&c_pinged); + Box::pin(async move { + c_pinged.store(true, Ordering::Relaxed); + None + }) + } + }); for i in 0..10 { assert!(server_a .emit(format!("Nodes-{}", i), Event::new("ping")) @@ -114,6 +125,10 @@ fn test_server_communication() { .emit("A", Event::new("ping".to_string())) .await .unwrap(); + server_b + .emit("C", Event::new("ping".to_string())) + .await + .unwrap(); for _ in 0..9 { server_b .emit("A", Event::new("ping".to_string())) @@ -124,14 +139,11 @@ fn test_server_communication() { .emit("B", Event::new("pong".to_string())) .await .unwrap(); - server_b - .emit("C", Event::new("ping".to_string())) - .await - .unwrap(); - task::sleep(Duration::from_secs(1)).await; + task::sleep(Duration::from_secs(2)).await; }); // wait one second to make sure the servers were able to process the events assert_eq!(ping_count.load(Ordering::SeqCst), 10); assert_eq!(pong_count.load(Ordering::SeqCst), 10); + assert!(c_pinged.load(Ordering::SeqCst)); }