From ce423d5c3da76f7336283f100605df330fc4f908 Mon Sep 17 00:00:00 2001 From: trivernis Date: Thu, 5 Nov 2020 17:26:39 +0100 Subject: [PATCH] Add key based authentication Signed-off-by: trivernis --- Cargo.toml | 3 +- src/server/data.rs | 6 +- src/server/mod.rs | 140 +++++++++++++++++++++++++++++------- src/server/server_events.rs | 9 +++ tests/test_communication.rs | 36 ++++++++-- 5 files changed, 161 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 62a07a1..af80508 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,4 +18,5 @@ crypto_box = "0.5.0" rand = "0.7.3" sha2 = "0.9.2" generic-array = "0.14.4" -typenum = "1.12.0" \ No newline at end of file +typenum = "1.12.0" +x25519-dalek = "1.1.0" \ No newline at end of file diff --git a/src/server/data.rs b/src/server/data.rs index 0b7a939..a47315a 100644 --- a/src/server/data.rs +++ b/src/server/data.rs @@ -5,18 +5,20 @@ use parking_lot::Mutex; use scheduled_thread_pool::ScheduledThreadPool; use std::collections::HashMap; use std::sync::Arc; +use x25519_dalek::PublicKey; #[derive(Clone, Debug)] pub struct Node { pub id: String, + pub public_key: PublicKey, pub address: Option, } #[derive(Clone)] pub(crate) struct ServerConnectionContext { pub is_server: bool, - pub secret_key: SecretKey, - pub own_node_id: String, + pub node_id: String, + pub global_secret: SecretKey, pub known_nodes: Arc>>, pub event_handler: Arc>, pub connections: Arc>>, diff --git a/src/server/mod.rs b/src/server/mod.rs index 9c2fe0e..c677e3b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -7,18 +7,21 @@ use scheduled_thread_pool::ScheduledThreadPool; use crate::crypto::CryptoStream; use crate::event::Event; use crate::event_handler::EventHandler; +use crate::result::VentedError::UnknownNode; use crate::result::{VentedError, VentedResult}; use crate::server::data::{Node, ServerConnectionContext}; use crate::server::server_events::{ - NodeInformationPayload, CONNECT_EVENT, CONN_ACCEPT_EVENT, CONN_REJECT_EVENT, + AuthPayload, NodeInformationPayload, AUTH_EVENT, CONNECT_EVENT, CONN_ACCEPT_EVENT, + CONN_CHALLENGE_EVENT, CONN_REJECT_EVENT, READY_EVENT, }; use parking_lot::Mutex; use std::io::Write; use std::sync::Arc; use std::thread; +use x25519_dalek::StaticSecret; pub mod data; -pub(crate) mod server_events; +pub mod server_events; /// The vented server that provides parallel handling of connections pub struct VentedServer { @@ -27,20 +30,24 @@ pub struct VentedServer { listener_pool: Arc>, sender_pool: Arc>, event_handler: Arc>, - secret_key: SecretKey, + global_secret_key: SecretKey, node_id: String, } impl VentedServer { - pub fn new(node_id: String, nodes: Vec, num_threads: usize) -> Self { - let mut rng = rand::thread_rng(); + pub fn new( + node_id: String, + secret_key: SecretKey, + nodes: Vec, + num_threads: usize, + ) -> Self { Self { node_id, event_handler: Arc::new(Mutex::new(EventHandler::new())), listener_pool: Arc::new(Mutex::new(ScheduledThreadPool::new(num_threads))), sender_pool: Arc::new(Mutex::new(ScheduledThreadPool::new(num_threads))), connections: Arc::new(Mutex::new(HashMap::new())), - secret_key: SecretKey::generate(&mut rng), + global_secret_key: secret_key, known_nodes: Arc::new(Mutex::new(nodes)), } } @@ -55,7 +62,13 @@ impl VentedServer { }); Ok(()) } else { - if let Some(node) = self.known_nodes.lock().iter().find(|n| n.id == node_id) { + let found_node = self + .known_nodes + .lock() + .iter() + .find(|n| n.id == node_id) + .cloned(); + if let Some(node) = found_node { if let Some(address) = &node.address { let handler = self.connect(address.clone())?; self.sender_pool.lock().execute(move || { @@ -104,8 +117,8 @@ impl VentedServer { fn get_server_context(&self) -> ServerConnectionContext { ServerConnectionContext { is_server: true, - own_node_id: self.node_id.clone(), - secret_key: self.secret_key.clone(), + node_id: self.node_id.clone(), + global_secret: self.global_secret_key.clone(), known_nodes: Arc::clone(&self.known_nodes), connections: Arc::clone(&self.connections), event_handler: Arc::clone(&self.event_handler), @@ -121,6 +134,9 @@ impl VentedServer { pool.lock().execute(move || { let stream = VentedServer::get_crypto_stream(params, stream).expect("Listener failed"); + event_handler + .lock() + .handle_event(Event::new(READY_EVENT.to_string())); while let Ok(event) = stream.read() { if let Some(response) = event_handler.lock().handle_event(event) { stream.send(response).expect("Failed to send response"); @@ -138,8 +154,8 @@ impl VentedServer { let (node_id, secret_box) = VentedServer::perform_key_exchange( params.is_server, &mut stream, - ¶ms.secret_key, - params.own_node_id, + params.node_id.clone(), + params.global_secret, params.known_nodes, )?; @@ -170,6 +186,9 @@ impl VentedServer { } } }); + self.event_handler + .lock() + .handle_event(Event::new(READY_EVENT.to_string())); Ok(stream) } @@ -178,14 +197,27 @@ impl VentedServer { fn perform_key_exchange( is_server: bool, stream: &mut TcpStream, - secret_key: &SecretKey, own_node_id: String, + global_secret: SecretKey, known_nodes: Arc>>, ) -> VentedResult<(String, ChaChaBox)> { + let secret_key = SecretKey::generate(&mut rand::thread_rng()); if is_server { - Self::perform_server_key_exchange(stream, secret_key, own_node_id, known_nodes) + Self::perform_server_key_exchange( + stream, + &secret_key, + own_node_id, + global_secret, + known_nodes, + ) } else { - Self::perform_client_key_exchange(stream, secret_key, own_node_id) + Self::perform_client_key_exchange( + stream, + &secret_key, + own_node_id, + global_secret, + known_nodes, + ) } } @@ -194,6 +226,8 @@ impl VentedServer { mut stream: &mut TcpStream, secret_key: &SecretKey, own_node_id: String, + global_secret: SecretKey, + known_nodes: Arc>>, ) -> VentedResult<(String, ChaChaBox)> { stream.write( &Event::with_payload( @@ -207,7 +241,7 @@ impl VentedServer { )?; stream.flush()?; let event = Event::from_bytes(&mut stream)?; - if event.name != CONN_ACCEPT_EVENT { + if event.name != CONN_CHALLENGE_EVENT { return Err(VentedError::UnknownNode(event.name)); } let NodeInformationPayload { @@ -215,6 +249,37 @@ impl VentedServer { node_id, } = event.get_payload::().unwrap(); let public_key = PublicKey::from(public_key); + let shared_auth_secret = + StaticSecret::from(global_secret.to_bytes()).diffie_hellman(&public_key); + + stream.write( + &Event::with_payload( + AUTH_EVENT.to_string(), + &AuthPayload { + calculated_secret: shared_auth_secret.to_bytes(), + }, + ) + .as_bytes(), + )?; + + let event = Event::from_bytes(&mut stream)?; + if event.name != CONN_ACCEPT_EVENT { + return Err(VentedError::UnknownNode(event.name)); + } + let known_nodes = known_nodes.lock(); + let node_static_info = event.get_payload::()?; + let node_data = if let Some(data) = known_nodes + .iter() + .find(|n| n.id == node_static_info.node_id) + { + data.clone() + } else { + return Err(UnknownNode(node_id)); + }; + if node_data.public_key.to_bytes() != node_static_info.public_key { + return Err(UnknownNode(node_id)); + } + let secret_box = ChaChaBox::new(&public_key, &secret_key); Ok((node_id, secret_box)) @@ -226,6 +291,7 @@ impl VentedServer { mut stream: &mut TcpStream, secret_key: &SecretKey, own_node_id: String, + global_secret: SecretKey, known_nodes: Arc>>, ) -> VentedResult<(String, ChaChaBox)> { let event = Event::from_bytes(&mut stream)?; @@ -238,29 +304,53 @@ impl VentedServer { } = event.get_payload::().unwrap(); let public_key = PublicKey::from(public_key); - if known_nodes - .lock() - .iter() - .find(|n| n.id == node_id) - .is_none() - { + let known_nodes = known_nodes.lock(); + let node_data = if let Some(data) = known_nodes.iter().find(|n| n.id == node_id) { + data.clone() + } else { stream.write(&Event::new(CONN_REJECT_EVENT.to_string()).as_bytes())?; stream.flush()?; - return Err(VentedError::UnknownNode(node_id)); - } + return Err(UnknownNode(node_id)); + }; let secret_box = ChaChaBox::new(&public_key, &secret_key); stream.write( &Event::with_payload( - CONN_ACCEPT_EVENT.to_string(), + CONN_CHALLENGE_EVENT.to_string(), &NodeInformationPayload { public_key: secret_key.public_key().to_bytes(), - node_id: own_node_id, + node_id: own_node_id.clone(), }, ) .as_bytes(), )?; stream.flush()?; + let auth_event = Event::from_bytes(&mut stream)?; + + if auth_event.name != AUTH_EVENT { + return Err(VentedError::UnexpectedEvent(auth_event.name)); + } + let AuthPayload { calculated_secret } = auth_event.get_payload::()?; + let expected_secret = + StaticSecret::from(secret_key.to_bytes()).diffie_hellman(&node_data.public_key); + + if expected_secret.to_bytes() != calculated_secret { + stream.write(&Event::new(CONN_REJECT_EVENT.to_string()).as_bytes())?; + stream.flush()?; + return Err(UnknownNode(node_id)); + } else { + stream.write( + &Event::with_payload( + CONN_ACCEPT_EVENT.to_string(), + &NodeInformationPayload { + node_id: own_node_id, + public_key: global_secret.public_key().to_bytes(), + }, + ) + .as_bytes(), + )?; + stream.flush()?; + } Ok((node_id, secret_box)) } diff --git a/src/server/server_events.rs b/src/server/server_events.rs index ddcd7fc..f816c82 100644 --- a/src/server/server_events.rs +++ b/src/server/server_events.rs @@ -1,11 +1,20 @@ use serde::{Deserialize, Serialize}; pub(crate) const CONNECT_EVENT: &str = "client:connect"; +pub(crate) const AUTH_EVENT: &str = "client:authenticate"; +pub(crate) const CONN_CHALLENGE_EVENT: &str = "server:conn_challenge"; pub(crate) const CONN_ACCEPT_EVENT: &str = "server:conn_accept"; pub(crate) const CONN_REJECT_EVENT: &str = "server:conn_reject"; +pub const READY_EVENT: &str = "connection:ready"; + #[derive(Serialize, Deserialize, Debug)] pub(crate) struct NodeInformationPayload { pub node_id: String, pub public_key: [u8; 32], } + +#[derive(Serialize, Deserialize, Debug)] +pub(crate) struct AuthPayload { + pub calculated_secret: [u8; 32], +} diff --git a/tests/test_communication.rs b/tests/test_communication.rs index a9ebdf8..9547167 100644 --- a/tests/test_communication.rs +++ b/tests/test_communication.rs @@ -1,27 +1,35 @@ +use crypto_box::SecretKey; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::thread; use std::time::Duration; use vented::event::Event; use vented::server::data::Node; +use vented::server::server_events::READY_EVENT; use vented::server::VentedServer; #[test] fn test_server_communication() { let ping_count = Arc::new(AtomicUsize::new(0)); let pong_count = Arc::new(AtomicUsize::new(0)); + let ready_count = Arc::new(AtomicUsize::new(0)); + let mut rng = rand::thread_rng(); + let global_secret_a = SecretKey::generate(&mut rng); + let global_secret_b = SecretKey::generate(&mut rng); let nodes = vec![ Node { id: "A".to_string(), address: Some("localhost:22222".to_string()), + public_key: global_secret_a.public_key(), }, Node { id: "B".to_string(), address: None, + public_key: global_secret_b.public_key(), }, ]; - let mut server_a = VentedServer::new("A".to_string(), nodes.clone(), 2); - let mut server_b = VentedServer::new("B".to_string(), nodes, 2); + let mut server_a = VentedServer::new("A".to_string(), global_secret_a, nodes.clone(), 4); + let mut server_b = VentedServer::new("B".to_string(), global_secret_b, nodes, 4); server_a.listen("localhost:22222".to_string()); thread::sleep(Duration::from_millis(10)); @@ -40,18 +48,36 @@ fn test_server_communication() { None } }); + server_a.on(READY_EVENT, { + let ready_count = Arc::clone(&ready_count); + move |_| { + ready_count.fetch_add(1, Ordering::Relaxed); + None + } + }); + server_b.on(READY_EVENT, { + let ready_count = Arc::clone(&ready_count); + move |_| { + ready_count.fetch_add(1, Ordering::Relaxed); + None + } + }); for _ in 0..10 { server_b .emit("A".to_string(), Event::new("ping".to_string())) .unwrap(); + thread::sleep(Duration::from_millis(10)); } server_a .emit("B".to_string(), Event::new("pong".to_string())) .unwrap(); // wait one second to make sure the servers were able to process the events - thread::sleep(Duration::from_secs(1)); + for _ in 0..100 { + thread::sleep(Duration::from_millis(10)); + } - assert_eq!(ping_count.load(Ordering::Relaxed), 10); - assert_eq!(pong_count.load(Ordering::Relaxed), 11); + assert_eq!(ready_count.load(Ordering::SeqCst), 2); + assert_eq!(ping_count.load(Ordering::SeqCst), 10); + assert_eq!(pong_count.load(Ordering::SeqCst), 11); }