diff --git a/Cargo.lock b/Cargo.lock index c21e5985..b5e2f2b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aead" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b613b8e1e3cf911a086f53f03bf286f52fd7a7258e4fa606f0ef220d39d8877" +dependencies = [ + "generic-array", +] + [[package]] name = "aho-corasick" version = "0.7.18" @@ -91,6 +100,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +dependencies = [ + "generic-array", +] + [[package]] name = "bromine" version = "0.20.0" @@ -99,6 +117,7 @@ dependencies = [ "bincode", "byteorder", "bytes", + "chacha20poly1305", "criterion", "crossbeam-utils", "futures", @@ -106,13 +125,17 @@ dependencies = [ "lazy_static", "num_enum", "postcard", + "rand", + "rand_core 0.6.3", "rmp-serde", "serde", "serde_json", + "sha2", "thiserror", "tokio", "tracing", "trait-bound-typemap", + "x25519-dalek", ] [[package]] @@ -160,6 +183,40 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chacha20" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b72a433d0cf2aef113ba70f62634c56fddb0f244e6377185c56a7cadbd8f91" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", + "zeroize", +] + +[[package]] +name = "chacha20poly1305" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b84ed6d1d5f7aa9bdde921a5090e0ca4d934d250ea3b402a5fab3a994e28a2a" +dependencies = [ + "aead", + "chacha20", + "cipher", + "poly1305", + "zeroize", +] + +[[package]] +name = "cipher" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +dependencies = [ + "generic-array", +] + [[package]] name = "clap" version = "2.34.0" @@ -183,6 +240,15 @@ dependencies = [ "volatile-register", ] +[[package]] +name = "cpufeatures" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59a6001667ab124aebae2a495118e11d30984c3a653e99d86d58971708cf5e4b" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.3.5" @@ -277,6 +343,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "crypto-common" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.1.6" @@ -299,6 +375,38 @@ dependencies = [ "memchr", ] +[[package]] +name = "curve25519-dalek" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f9d052967f590a76e62eb387bd0bbb1b000182c3cefe5364db6b7211651bc0" +dependencies = [ + "byteorder", + "digest 0.9.0", + "rand_core 0.5.1", + "subtle", + "zeroize", +] + +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + +[[package]] +name = "digest" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "either" version = "1.6.1" @@ -404,6 +512,38 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.10.2+wasi-snapshot-preview1", +] + [[package]] name = "half" version = "1.8.2" @@ -526,7 +666,7 @@ dependencies = [ "log", "miow", "ntapi", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "winapi", ] @@ -615,6 +755,12 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "pin-project-lite" version = "0.2.8" @@ -655,6 +801,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly1305" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" +dependencies = [ + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "postcard" version = "0.7.3" @@ -672,6 +829,12 @@ version = "0.1.5-pre" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c68cb38ed13fd7bc9dd5db8f165b7c8d9c1a315104083a2b10f11354c2af97f" +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + [[package]] name = "proc-macro-crate" version = "1.1.3" @@ -700,6 +863,45 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core 0.6.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.3", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", +] + +[[package]] +name = "rand_core" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +dependencies = [ + "getrandom 0.2.5", +] + [[package]] name = "rayon" version = "1.5.1" @@ -891,6 +1093,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest 0.10.3", +] + [[package]] name = "slab" version = "0.4.5" @@ -922,6 +1135,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + [[package]] name = "syn" version = "1.0.86" @@ -933,6 +1152,18 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -1050,6 +1281,12 @@ dependencies = [ "multi-trait-object", ] +[[package]] +name = "typenum" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" + [[package]] name = "unicode-width" version = "0.1.9" @@ -1062,12 +1299,28 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +[[package]] +name = "universal-hash" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f214e8f697e925001e66ec2c6e37a4ef93f0f78c2eed7814394e10c62025b05" +dependencies = [ + "generic-array", + "subtle", +] + [[package]] name = "vcell" version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77439c1b53d2303b20d9459b1ade71a83c716e3f9c34f3228c00e6f185d6c002" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "void" version = "1.0.2" @@ -1094,6 +1347,18 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + +[[package]] +name = "wasi" +version = "0.10.2+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1194,3 +1459,35 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "x25519-dalek" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2392b6b94a576b4e2bf3c5b2757d63f10ada8020a2e4d08ac849ebcf6ea8e077" +dependencies = [ + "curve25519-dalek", + "rand_core 0.5.1", + "zeroize", +] + +[[package]] +name = "zeroize" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4756f7db3f7b5574938c3eb1c117038b8e07f95ee6718c0efad4ac21508f1efd" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f8f187641dad4f680d25c4bfc4225b418165984179f26ca76ec4fb6441d3a17" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] diff --git a/Cargo.toml b/Cargo.toml index e9e82cd2..1b3ebdf4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,11 @@ rmp-serde = { version = "1.0.0", optional = true } bincode = { version = "1.3.3", optional = true } serde_json = { version = "1.0.79", optional = true } bytes = "1.1.0" +chacha20poly1305 = "0.9.0" +x25519-dalek = "1.2.0" +rand = "0.8.5" +rand_core = "0.6.3" +sha2 = "0.10.2" [dependencies.serde] optional = true diff --git a/src/ipc/builder.rs b/src/ipc/builder.rs index 4183f05e..83cba3ce 100644 --- a/src/ipc/builder.rs +++ b/src/ipc/builder.rs @@ -11,6 +11,7 @@ use crate::namespaces::builder::NamespaceBuilder; use crate::namespaces::namespace::Namespace; #[cfg(feature = "serialize")] use crate::payload::DynamicSerializer; +use crate::prelude::AsyncProtocolStream; use crate::protocol::AsyncStreamProtocolListener; use std::collections::HashMap; use std::future::Future; @@ -64,6 +65,8 @@ pub struct IPCBuilder { timeout: Duration, #[cfg(feature = "serialize")] default_serializer: DynamicSerializer, + listener_options: L::ListenerOptions, + stream_options: ::StreamOptions, } impl IPCBuilder @@ -89,6 +92,8 @@ where timeout: Duration::from_secs(60), #[cfg(feature = "serialize")] default_serializer: DynamicSerializer::first_available(), + listener_options: Default::default(), + stream_options: Default::default(), } } @@ -163,6 +168,23 @@ where self } + /// Sets the options for the given protocol listener + pub fn server_options(mut self, options: L::ListenerOptions) -> Self { + self.listener_options = options; + + self + } + + /// Sets the options for the given protocol stream + pub fn client_options( + mut self, + options: ::StreamOptions, + ) -> Self { + self.stream_options = options; + + self + } + /// Builds an ipc server #[tracing::instrument(skip(self))] pub async fn build_server(self) -> Result<()> { @@ -176,7 +198,9 @@ where #[cfg(feature = "serialize")] default_serializer: self.default_serializer, }; - server.start::(self.address.unwrap()).await?; + server + .start::(self.address.unwrap(), self.listener_options) + .await?; Ok(()) } @@ -198,7 +222,9 @@ where default_serializer: self.default_serializer, }; - let ctx = client.connect::(self.address.unwrap()).await?; + let ctx = client + .connect::(self.address.unwrap(), self.stream_options.clone()) + .await?; Ok(ctx) } @@ -230,7 +256,9 @@ where default_serializer: self.default_serializer.clone(), }; - let ctx = client.connect::(address.clone()).await?; + let ctx = client + .connect::(address.clone(), self.stream_options.clone()) + .await?; contexts.push(ctx); } diff --git a/src/ipc/client.rs b/src/ipc/client.rs index ae7f8600..aa490a73 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -33,12 +33,13 @@ pub struct IPCClient { impl IPCClient { /// Connects to a given address and returns an emitter for events to that address. /// Invoked by [IPCBuilder::build_client](crate::builder::IPCBuilder::build_client) - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self, options))] pub async fn connect( self, address: S::AddressType, + options: S::StreamOptions, ) -> Result { - let stream = S::protocol_connect(address).await?; + let stream = S::protocol_connect(address, options).await?; let (read_half, write_half) = stream.protocol_into_split(); let emitter = StreamEmitter::new::(write_half); diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 9a71a66c..c88d238a 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -31,12 +31,13 @@ pub struct IPCServer { impl IPCServer { /// Starts the IPC Server. /// Invoked by [IPCBuilder::build_server](crate::builder::IPCBuilder::build_server) - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self, options))] pub async fn start( self, address: L::AddressType, + options: L::ListenerOptions, ) -> Result<()> { - let listener = L::protocol_bind(address.clone()).await?; + let listener = L::protocol_bind(address.clone(), options).await?; let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); let data = Arc::new(RwLock::new(self.data)); diff --git a/src/ipc/stream_emitter/emit_metadata.rs b/src/ipc/stream_emitter/emit_metadata.rs index 57d77fab..b5a3e85b 100644 --- a/src/ipc/stream_emitter/emit_metadata.rs +++ b/src/ipc/stream_emitter/emit_metadata.rs @@ -91,6 +91,7 @@ impl Future for EmitMetadata

{ let event_bytes = event.into_bytes()?; let mut stream = stream.lock().await; stream.deref_mut().write_all(&event_bytes[..]).await?; + stream.deref_mut().flush().await?; tracing::trace!(bytes_len = event_bytes.len()); Ok(event_id) diff --git a/src/lib.rs b/src/lib.rs index b3cc6fe8..6eba0549 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,6 +101,7 @@ //! # } //! ``` +extern crate core; #[cfg(all( feature = "serialize", not(any( diff --git a/src/protocol/encrypted/crypt_handling.rs b/src/protocol/encrypted/crypt_handling.rs new file mode 100644 index 00000000..4bfe35c9 --- /dev/null +++ b/src/protocol/encrypted/crypt_handling.rs @@ -0,0 +1,158 @@ +use crate::prelude::encrypted::EncryptedStream; +use crate::prelude::IPCResult; +use crate::protocol::AsyncProtocolStream; +use bytes::Bytes; +use chacha20poly1305::aead::{Aead, NewAead}; +use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce}; +use rand::thread_rng; +use rand_core::RngCore; +use sha2::{Digest, Sha224, Sha256}; +use std::io; +use std::io::ErrorKind; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use x25519_dalek::{PublicKey, StaticSecret}; + +/// A structure used for encryption. +/// It holds the cipher initialized with the given key +/// and two counters for both encryption and decryption +/// count which are used to keep track of the nonce. +#[derive(Clone)] +pub(crate) struct CipherBox { + cipher: ChaCha20Poly1305, + en_count: Arc, + de_count: Arc, +} + +impl CipherBox { + pub fn new(key: Bytes) -> Self { + let key = Key::from_slice(&key[..]); + let cipher = ChaCha20Poly1305::new(key); + + Self { + cipher, + en_count: Arc::new(AtomicU64::new(0)), + de_count: Arc::new(AtomicU64::new(0)), + } + } + + /// Encrypts the given message + pub fn encrypt(&self, data: Bytes) -> io::Result { + self.cipher + .encrypt(&self.en_nonce(), &data[..]) + .map(Bytes::from) + .map_err(|_| io::Error::from(ErrorKind::InvalidData)) + } + + /// Decrypts the given message + pub fn decrypt(&self, data: Bytes) -> io::Result { + self.cipher + .decrypt(&self.de_nonce(), &data[..]) + .map(Bytes::from) + .map_err(|_| io::Error::from(ErrorKind::InvalidData)) + } + + /// Updates the stored key. + /// This must be done simultaneously on server and client side + /// to keep track of the nonce + pub fn update_key(&mut self, key: Bytes) { + let key = Key::from_slice(&key[..]); + self.cipher = ChaCha20Poly1305::new(key); + self.reset_counters(); + } + + /// Resets the nonce counters. + /// This must be done simultaneously on server and client side. + pub fn reset_counters(&mut self) { + self.de_count.store(0, Ordering::SeqCst); + self.en_count.store(0, Ordering::SeqCst); + } + + fn en_nonce(&self) -> Nonce { + let count = self.en_count.fetch_add(1, Ordering::SeqCst); + nonce_from_number(count) + } + + fn de_nonce(&self) -> Nonce { + let count = self.de_count.fetch_add(1, Ordering::SeqCst); + nonce_from_number(count) + } +} + +/// Generates a nonce from a given number +/// the nonce is passed through sha224 for pseudo-randomness +fn nonce_from_number(number: u64) -> Nonce { + let number_bytes: [u8; 8] = number.to_be_bytes(); + let sha_bytes = Sha224::digest(&number_bytes).to_vec(); + let mut nonce_bytes = [0u8; 12]; + nonce_bytes.copy_from_slice(&sha_bytes[..12]); + + nonce_bytes.into() +} + +impl EncryptedStream { + /// Does a server-client key exchange. + /// 1. The server receives the public key of the client + /// 2. The server sends its own public key + /// 3. The server creates an intermediary encrypted connection + /// 4. The server generates a new secret + /// 5. The server sends the secret to the client + /// 6. The connection is upgraded with the new shared key + pub async fn from_server_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult { + let other_pub = receive_public_key(&mut inner).await?; + send_public_key(&mut inner, &secret).await?; + let shared_secret = secret.diffie_hellman(&other_pub); + let mut stream = Self::new(inner, shared_secret); + let permanent_secret = generate_secret(); + stream.write_all(&permanent_secret).await?; + stream.flush().await?; + stream.update_key(permanent_secret.into()); + + Ok(stream) + } + + /// Does a client-server key exchange. + /// 1. The client sends its public key to the server + /// 2. The client receives the servers public key + /// 3. The client creates an intermediary encrypted connection + /// 4. The client receives the new key from the server + /// 5. The connection is upgraded with the new shared key + pub async fn from_client_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult { + send_public_key(&mut inner, &secret).await?; + let other_pub = receive_public_key(&mut inner).await?; + let shared_secret = secret.diffie_hellman(&other_pub); + let mut stream = Self::new(inner, shared_secret); + let mut key_buf = vec![0u8; 32]; + stream.read_exact(&mut key_buf).await?; + stream.update_key(key_buf.into()); + + Ok(stream) + } +} + +async fn receive_public_key(stream: &mut T) -> IPCResult { + let mut pk_buf = [0u8; 32]; + stream.read_exact(&mut pk_buf).await?; + + Ok(PublicKey::from(pk_buf)) +} + +async fn send_public_key( + stream: &mut T, + secret: &StaticSecret, +) -> IPCResult<()> { + let own_pk = PublicKey::from(secret); + stream.write_all(own_pk.as_bytes()).await?; + stream.flush().await?; + + Ok(()) +} + +fn generate_secret() -> Vec { + let mut rng = thread_rng(); + let mut buf = vec![0u8; 32]; + rng.fill_bytes(&mut buf); + + Sha256::digest(&buf).to_vec() +} diff --git a/src/protocol/encrypted/io_impl.rs b/src/protocol/encrypted/io_impl.rs new file mode 100644 index 00000000..94360e38 --- /dev/null +++ b/src/protocol/encrypted/io_impl.rs @@ -0,0 +1,220 @@ +use crate::prelude::encrypted::{ + EncryptedPackage, EncryptedReadStream, EncryptedStream, EncryptedWriteStream, +}; +use crate::prelude::AsyncProtocolStream; +use crate::protocol::encrypted::crypt_handling::CipherBox; +use bytes::{Buf, BufMut, Bytes}; +use std::cmp::min; +use std::io; +use std::io::Error; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; + +const WRITE_BUF_SIZE: usize = 1024; + +impl Unpin for EncryptedStream {} + +impl AsyncWrite for EncryptedStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.write_half).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.write_half).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.write_half).poll_shutdown(cx) + } +} + +impl AsyncRead for EncryptedStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.read_half).poll_read(cx, buf) + } +} + +impl Unpin for EncryptedReadStream {} + +impl AsyncRead for EncryptedReadStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.fut.is_none() { + let max_copy = min(buf.remaining(), self.remaining.len()); + let bytes = self.remaining.copy_to_bytes(max_copy); + buf.put_slice(&bytes); + + if buf.remaining() > 0 { + let mut reader = self.inner.take().unwrap(); + let cipher = self.cipher.take().unwrap(); + + self.fut = Some(Box::pin(async move { + let package = match EncryptedPackage::from_async_read(&mut reader).await { + Ok(p) => p, + Err(e) => { + return (Err(e), reader, cipher); + } + }; + match cipher.decrypt(package.into_inner()) { + Ok(bytes) => (Ok(bytes), reader, cipher), + Err(e) => (Err(e), reader, cipher), + } + })); + } + } + if self.fut.is_some() { + match self.fut.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, reader, cipher)) => { + self.inner = Some(reader); + self.cipher = Some(cipher); + match result { + Ok(bytes) => { + self.remaining.put(bytes); + let max_copy = min(self.remaining.len(), buf.remaining()); + let bytes = self.remaining.copy_to_bytes(max_copy); + self.fut = None; + buf.put_slice(&bytes); + + if buf.remaining() == 0 { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + Err(e) => Poll::Ready(Err(e)), + } + } + Poll::Pending => Poll::Pending, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl Unpin for EncryptedWriteStream {} + +impl AsyncWrite for EncryptedWriteStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if buf.remaining() > 0 { + let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) }; + self.buffer.put(Bytes::from(buf)); + + if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE { + let buffer_len = self.buffer.len(); + let max_copy = min(u32::MAX as usize, buffer_len); + let plaintext = self.buffer.copy_to_bytes(max_copy); + let writer = self.inner.take().unwrap(); + let cipher = self.cipher.take().unwrap(); + + self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher))) + } + } + if self.fut_write.is_some() { + match self.fut_write.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, writer, cipher)) => { + self.inner = Some(writer); + self.cipher = Some(cipher); + self.fut_write = None; + + Poll::Ready(result.map(|_| buf.len())) + } + Poll::Pending => Poll::Pending, + } + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let buffer_len = self.buffer.len(); + + if !self.buffer.is_empty() && self.fut_flush.is_none() { + let max_copy = min(u32::MAX as usize, buffer_len); + let plaintext = self.buffer.copy_to_bytes(max_copy); + let writer = self.inner.take().unwrap(); + let cipher = self.cipher.take().unwrap(); + + self.fut_flush = Some(Box::pin(async move { + let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await; + if result.is_err() { + return (result, writer, cipher); + } + if let Err(e) = writer.flush().await { + (Err(e), writer, cipher) + } else { + (Ok(()), writer, cipher) + } + })) + } + match self.fut_flush.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready((result, writer, cipher)) => { + self.inner = Some(writer); + self.cipher = Some(cipher); + self.fut_flush = None; + + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.fut_shutdown.is_none() { + match self.as_mut().poll_flush(cx) { + Poll::Ready(result) => match result { + Ok(_) => { + let mut writer = self.inner.take().unwrap(); + self.fut_shutdown = Some(Box::pin(async move { writer.shutdown().await })); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + }, + Poll::Pending => Poll::Pending, + } + } else { + match self.fut_shutdown.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready(result) => { + self.fut_shutdown = None; + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + } +} + +async fn write_bytes( + bytes: Bytes, + mut writer: T, + cipher: CipherBox, +) -> (io::Result<()>, T, CipherBox) { + let encrypted_bytes = match cipher.encrypt(bytes) { + Ok(b) => b, + Err(e) => { + return (Err(e), writer, cipher); + } + }; + let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes(); + if let Err(e) = writer.write_all(&package_bytes[..]).await { + return (Err(e), writer, cipher); + } + + (Ok(()), writer, cipher) +} diff --git a/src/protocol/encrypted/mod.rs b/src/protocol/encrypted/mod.rs new file mode 100644 index 00000000..7126dcdc --- /dev/null +++ b/src/protocol/encrypted/mod.rs @@ -0,0 +1,152 @@ +mod crypt_handling; +mod io_impl; +mod protocol_impl; + +use bytes::{BufMut, Bytes, BytesMut}; +pub use io_impl::*; +pub use protocol_impl::*; +use rand_core::RngCore; +use std::future::Future; +use std::io; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter}; +use x25519_dalek::{SharedSecret, StaticSecret}; + +use crate::prelude::encrypted::crypt_handling::CipherBox; +use crate::prelude::{AsyncProtocolStream, AsyncStreamProtocolListener}; + +pub type OptionalFuture = Option + Send + Sync>>>; + +#[derive(Clone)] +pub struct EncryptionOptions { + pub inner_options: T, + pub secret: StaticSecret, +} + +impl Default for EncryptionOptions { + fn default() -> Self { + let mut rng = rand::thread_rng(); + let mut secret = [0u8; 32]; + rng.fill_bytes(&mut secret); + + Self { + secret: StaticSecret::from(secret), + inner_options: T::default(), + } + } +} + +pub struct EncryptedListener { + inner: T, + secret: StaticSecret, +} + +impl EncryptedListener { + pub fn new(inner: T, secret: StaticSecret) -> Self { + Self { inner, secret } + } +} + +pub struct EncryptedStream { + read_half: EncryptedReadStream, + write_half: EncryptedWriteStream, +} + +impl EncryptedStream { + pub fn new(inner: T, secret: SharedSecret) -> Self { + let cipher_box = CipherBox::new(Bytes::from(secret.to_bytes().to_vec())); + let (read, write) = inner.protocol_into_split(); + let read_half = EncryptedReadStream::new(read, cipher_box.clone()); + let write_half = EncryptedWriteStream::new(write, cipher_box); + + Self { + read_half, + write_half, + } + } + + pub fn update_key(&mut self, key: Bytes) { + self.write_half + .cipher + .as_mut() + .unwrap() + .update_key(key.clone()); + self.read_half + .cipher + .as_mut() + .unwrap() + .update_key(key.clone()); + } +} + +pub struct EncryptedReadStream { + inner: Option>, + fut: OptionalFuture<(io::Result, BufReader, CipherBox)>, + remaining: BytesMut, + cipher: Option, +} + +impl EncryptedReadStream { + pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { + Self { + inner: Some(BufReader::new(inner)), + fut: None, + remaining: BytesMut::new(), + cipher: Some(cipher), + } + } +} + +pub struct EncryptedWriteStream { + inner: Option>, + cipher: Option, + buffer: BytesMut, + fut_write: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, + fut_flush: OptionalFuture<(io::Result<()>, BufWriter, CipherBox)>, + fut_shutdown: OptionalFuture>, +} + +impl EncryptedWriteStream { + pub(crate) fn new(inner: T, cipher: CipherBox) -> Self { + Self { + inner: Some(BufWriter::new(inner)), + cipher: Some(cipher), + buffer: BytesMut::with_capacity(1024), + fut_write: None, + fut_flush: None, + fut_shutdown: None, + } + } +} + +pub(crate) struct EncryptedPackage { + bytes: Bytes, +} + +impl EncryptedPackage { + pub fn new(bytes: Bytes) -> Self { + Self { bytes } + } + + pub fn into_bytes(self) -> Bytes { + let mut buf = BytesMut::with_capacity(4 + self.bytes.len()); + buf.put_u32(self.bytes.len() as u32); + buf.put(self.bytes); + + buf.freeze() + } + + pub async fn from_async_read(reader: &mut R) -> io::Result { + let length = reader.read_u32().await?; + let mut bytes_buf = vec![0u8; length as usize]; + reader.read_exact(&mut bytes_buf).await?; + + Ok(Self { + bytes: Bytes::from(bytes_buf), + }) + } + + pub fn into_inner(self) -> Bytes { + self.bytes + } +} diff --git a/src/protocol/encrypted/protocol_impl.rs b/src/protocol/encrypted/protocol_impl.rs new file mode 100644 index 00000000..53d664d2 --- /dev/null +++ b/src/protocol/encrypted/protocol_impl.rs @@ -0,0 +1,55 @@ +use crate::error::Result; +use crate::prelude::encrypted::{EncryptedReadStream, EncryptedWriteStream}; +use crate::prelude::{AsyncProtocolStreamSplit, IPCResult}; +use crate::protocol::encrypted::{EncryptedListener, EncryptedStream, EncryptionOptions}; +use crate::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener}; +use async_trait::async_trait; + +#[async_trait] +impl AsyncStreamProtocolListener for EncryptedListener { + type AddressType = T::AddressType; + type RemoteAddressType = T::RemoteAddressType; + type Stream = EncryptedStream; + type ListenerOptions = EncryptionOptions; + + async fn protocol_bind( + address: Self::AddressType, + options: Self::ListenerOptions, + ) -> IPCResult { + let inner = T::protocol_bind(address, options.inner_options).await?; + + Ok(EncryptedListener::new(inner, options.secret)) + } + + async fn protocol_accept(&self) -> IPCResult<(Self::Stream, Self::RemoteAddressType)> { + let (inner_stream, remote_addr) = self.inner.protocol_accept().await?; + let stream = + Self::Stream::from_server_key_exchange(inner_stream, self.secret.clone()).await?; + + Ok((stream, remote_addr)) + } +} + +#[async_trait] +impl AsyncProtocolStream for EncryptedStream { + type AddressType = T::AddressType; + type StreamOptions = EncryptionOptions; + + async fn protocol_connect( + address: Self::AddressType, + options: Self::StreamOptions, + ) -> Result { + let inner = T::protocol_connect(address, options.inner_options).await?; + EncryptedStream::from_client_key_exchange(inner, options.secret).await + } +} + +#[async_trait] +impl AsyncProtocolStreamSplit for EncryptedStream { + type OwnedSplitReadHalf = EncryptedReadStream; + type OwnedSplitWriteHalf = EncryptedWriteStream; + + fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) { + (self.read_half, self.write_half) + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 6fa9c8b2..2792fb6f 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,5 +1,6 @@ pub mod tcp; +pub mod encrypted; #[cfg(unix)] pub mod unix_socket; @@ -12,25 +13,33 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub trait AsyncStreamProtocolListener: Sized + Send + Sync { type AddressType: Clone + Debug + Send + Sync; type RemoteAddressType: Debug + Send + Sync; - type Stream: 'static + AsyncProtocolStream + Send + Sync; + type Stream: 'static + AsyncProtocolStream; + type ListenerOptions: Clone + Default + Send + Sync; - async fn protocol_bind(address: Self::AddressType) -> IPCResult; + async fn protocol_bind( + address: Self::AddressType, + options: Self::ListenerOptions, + ) -> IPCResult; async fn protocol_accept(&self) -> IPCResult<(Self::Stream, Self::RemoteAddressType)>; } pub trait AsyncProtocolStreamSplit { - type OwnedSplitReadHalf: AsyncRead + Send + Sync + Unpin; - type OwnedSplitWriteHalf: AsyncWrite + Send + Sync + Unpin; + type OwnedSplitReadHalf: 'static + AsyncRead + Send + Sync + Unpin; + type OwnedSplitWriteHalf: 'static + AsyncWrite + Send + Sync + Unpin; fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf); } #[async_trait] pub trait AsyncProtocolStream: - AsyncRead + AsyncWrite + Send + Sync + AsyncProtocolStreamSplit + Sized + AsyncRead + AsyncWrite + Send + Sync + AsyncProtocolStreamSplit + Sized + Unpin { type AddressType: Clone + Debug + Send + Sync; + type StreamOptions: Clone + Default + Send + Sync; - async fn protocol_connect(address: Self::AddressType) -> IPCResult; + async fn protocol_connect( + address: Self::AddressType, + options: Self::StreamOptions, + ) -> IPCResult; } diff --git a/src/protocol/tcp.rs b/src/protocol/tcp.rs index abe43d4f..f9369b4f 100644 --- a/src/protocol/tcp.rs +++ b/src/protocol/tcp.rs @@ -10,8 +10,12 @@ impl AsyncStreamProtocolListener for TcpListener { type AddressType = SocketAddr; type RemoteAddressType = SocketAddr; type Stream = TcpStream; + type ListenerOptions = (); - async fn protocol_bind(address: Self::AddressType) -> IPCResult { + async fn protocol_bind( + address: Self::AddressType, + _: Self::ListenerOptions, + ) -> IPCResult { let listener = TcpListener::bind(address).await?; Ok(listener) @@ -36,8 +40,12 @@ impl AsyncProtocolStreamSplit for TcpStream { #[async_trait] impl AsyncProtocolStream for TcpStream { type AddressType = SocketAddr; + type StreamOptions = (); - async fn protocol_connect(address: Self::AddressType) -> IPCResult { + async fn protocol_connect( + address: Self::AddressType, + _: Self::StreamOptions, + ) -> IPCResult { let stream = TcpStream::connect(address).await?; Ok(stream) diff --git a/src/protocol/unix_socket.rs b/src/protocol/unix_socket.rs index 91809083..2cbf16f0 100644 --- a/src/protocol/unix_socket.rs +++ b/src/protocol/unix_socket.rs @@ -13,8 +13,9 @@ impl AsyncStreamProtocolListener for UnixListener { type AddressType = PathBuf; type RemoteAddressType = SocketAddr; type Stream = UnixStream; + type ListenerOptions = (); - async fn protocol_bind(address: Self::AddressType) -> Result { + async fn protocol_bind(address: Self::AddressType, _: Self::ListenerOptions) -> Result { let listener = UnixListener::bind(address)?; Ok(listener) @@ -39,8 +40,12 @@ impl AsyncProtocolStreamSplit for UnixStream { #[async_trait] impl AsyncProtocolStream for UnixStream { type AddressType = PathBuf; + type StreamOptions = (); - async fn protocol_connect(address: Self::AddressType) -> IPCResult { + async fn protocol_connect( + address: Self::AddressType, + _: Self::StreamOptions, + ) -> IPCResult { let stream = UnixStream::connect(address).await?; stream .ready(Interest::READABLE | Interest::WRITABLE) diff --git a/tests/test_encryption.rs b/tests/test_encryption.rs new file mode 100644 index 00000000..3ab90988 --- /dev/null +++ b/tests/test_encryption.rs @@ -0,0 +1,101 @@ +use crate::utils::call_counter::increment_counter_for_event; +use crate::utils::protocol::TestProtocolListener; +use crate::utils::{get_free_port, start_server_and_client}; +use bromine::prelude::encrypted::EncryptedListener; +use bromine::prelude::*; +use bromine::IPCBuilder; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, Bytes, BytesMut}; +use futures::StreamExt; +use rand_core::RngCore; +use std::io::Read; +use std::time::Duration; + +mod utils; + +#[tokio::test] +async fn it_sends_and_receives() { + let ctx = get_client_with_server().await; + let mut rng = rand::thread_rng(); + let mut buffer = vec![0u8; 140]; + rng.fill_bytes(&mut buffer); + + let mut stream = ctx + .emit("bytes", BytePayload::new(buffer.clone())) + .stream_replies() + .await + .unwrap(); + let mut count = 0; + + while let Some(Ok(response)) = stream.next().await { + let bytes = response.payload::().unwrap(); + assert_eq!(bytes.into_inner(), buffer); + count += 1; + } + assert_eq!(count, 100) +} + +#[tokio::test] +async fn it_sends_and_receives_strings() { + let ctx = get_client_with_server().await; + let response = ctx + .emit("string", StringPayload(String::from("Hello World"))) + .await_reply() + .await + .unwrap(); + let response_string = response.payload::().unwrap().0; + + assert_eq!(&response_string, "Hello World") +} + +async fn get_client_with_server() -> Context { + let port = get_free_port(); + + start_server_and_client(move || get_builder(port)).await +} + +fn get_builder(port: u8) -> IPCBuilder> { + IPCBuilder::new() + .address(port) + .on("bytes", callback!(handle_bytes)) + .on("string", callback!(handle_string)) + .timeout(Duration::from_millis(100)) +} + +async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult { + increment_counter_for_event(ctx, &event).await; + let bytes = event.payload::()?.into_bytes(); + + for _ in 0u8..99 { + ctx.emit("bytes", BytePayload::from(bytes.clone())).await?; + } + + ctx.response(BytePayload::from(bytes)) +} + +async fn handle_string(ctx: &Context, event: Event) -> IPCResult { + ctx.response(event.payload::()?) +} + +pub struct StringPayload(String); + +impl IntoPayload for StringPayload { + fn into_payload(self, _: &Context) -> IPCResult { + let mut buf = BytesMut::with_capacity(self.0.len() + 4); + buf.put_u32(self.0.len() as u32); + buf.put(Bytes::from(self.0)); + + Ok(buf.freeze()) + } +} + +impl FromPayload for StringPayload { + fn from_payload(mut reader: R) -> IPCResult { + let len = reader.read_u32::()?; + let mut buf = vec![0u8; len as usize]; + reader.read_exact(&mut buf)?; + let string = String::from_utf8(buf).map_err(|_| IPCError::from("not a string"))?; + + Ok(StringPayload(string)) + } +} diff --git a/tests/utils/call_counter.rs b/tests/utils/call_counter.rs index a9d61dd9..beeec79f 100644 --- a/tests/utils/call_counter.rs +++ b/tests/utils/call_counter.rs @@ -1,3 +1,4 @@ +#![allow(unused)] use bromine::context::Context; use bromine::event::Event; use std::collections::HashMap; diff --git a/tests/utils/protocol.rs b/tests/utils/protocol.rs index 329645de..f95c18df 100644 --- a/tests/utils/protocol.rs +++ b/tests/utils/protocol.rs @@ -62,8 +62,9 @@ impl AsyncStreamProtocolListener for TestProtocolListener { type AddressType = u8; type RemoteAddressType = u8; type Stream = TestProtocolStream; + type ListenerOptions = (); - async fn protocol_bind(address: Self::AddressType) -> Result { + async fn protocol_bind(address: Self::AddressType, _: Self::ListenerOptions) -> Result { let (sender, receiver) = channel(1); add_port(address, sender).await; @@ -170,8 +171,9 @@ impl AsyncProtocolStreamSplit for TestProtocolStream { #[async_trait] impl AsyncProtocolStream for TestProtocolStream { type AddressType = u8; + type StreamOptions = (); - async fn protocol_connect(address: Self::AddressType) -> Result { + async fn protocol_connect(address: Self::AddressType, _: Self::StreamOptions) -> Result { get_port(address) .await .ok_or_else(|| IPCError::from("Failed to connect"))