mirror of https://github.com/Trivernis/bromine.git
Add encrypted wrapper protocol implementation
Signed-off-by: trivernis <trivernis@protonmail.com>pull/38/head
parent
ac471d296e
commit
fe7dc97008
@ -0,0 +1,158 @@
|
|||||||
|
use crate::prelude::encrypted::EncryptedStream;
|
||||||
|
use crate::prelude::IPCResult;
|
||||||
|
use crate::protocol::AsyncProtocolStream;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use chacha20poly1305::aead::{Aead, NewAead};
|
||||||
|
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
|
||||||
|
use rand::thread_rng;
|
||||||
|
use rand_core::RngCore;
|
||||||
|
use sha2::{Digest, Sha224, Sha256};
|
||||||
|
use std::io;
|
||||||
|
use std::io::ErrorKind;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use x25519_dalek::{PublicKey, StaticSecret};
|
||||||
|
|
||||||
|
/// A structure used for encryption.
|
||||||
|
/// It holds the cipher initialized with the given key
|
||||||
|
/// and two counters for both encryption and decryption
|
||||||
|
/// count which are used to keep track of the nonce.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub(crate) struct CipherBox {
|
||||||
|
cipher: ChaCha20Poly1305,
|
||||||
|
en_count: Arc<AtomicU64>,
|
||||||
|
de_count: Arc<AtomicU64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CipherBox {
|
||||||
|
pub fn new(key: Bytes) -> Self {
|
||||||
|
let key = Key::from_slice(&key[..]);
|
||||||
|
let cipher = ChaCha20Poly1305::new(key);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
cipher,
|
||||||
|
en_count: Arc::new(AtomicU64::new(0)),
|
||||||
|
de_count: Arc::new(AtomicU64::new(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encrypts the given message
|
||||||
|
pub fn encrypt(&self, data: Bytes) -> io::Result<Bytes> {
|
||||||
|
self.cipher
|
||||||
|
.encrypt(&self.en_nonce(), &data[..])
|
||||||
|
.map(Bytes::from)
|
||||||
|
.map_err(|_| io::Error::from(ErrorKind::InvalidData))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypts the given message
|
||||||
|
pub fn decrypt(&self, data: Bytes) -> io::Result<Bytes> {
|
||||||
|
self.cipher
|
||||||
|
.decrypt(&self.de_nonce(), &data[..])
|
||||||
|
.map(Bytes::from)
|
||||||
|
.map_err(|_| io::Error::from(ErrorKind::InvalidData))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates the stored key.
|
||||||
|
/// This must be done simultaneously on server and client side
|
||||||
|
/// to keep track of the nonce
|
||||||
|
pub fn update_key(&mut self, key: Bytes) {
|
||||||
|
let key = Key::from_slice(&key[..]);
|
||||||
|
self.cipher = ChaCha20Poly1305::new(key);
|
||||||
|
self.reset_counters();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resets the nonce counters.
|
||||||
|
/// This must be done simultaneously on server and client side.
|
||||||
|
pub fn reset_counters(&mut self) {
|
||||||
|
self.de_count.store(0, Ordering::SeqCst);
|
||||||
|
self.en_count.store(0, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn en_nonce(&self) -> Nonce {
|
||||||
|
let count = self.en_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
nonce_from_number(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn de_nonce(&self) -> Nonce {
|
||||||
|
let count = self.de_count.fetch_add(1, Ordering::SeqCst);
|
||||||
|
nonce_from_number(count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a nonce from a given number
|
||||||
|
/// the nonce is passed through sha224 for pseudo-randomness
|
||||||
|
fn nonce_from_number(number: u64) -> Nonce {
|
||||||
|
let number_bytes: [u8; 8] = number.to_be_bytes();
|
||||||
|
let sha_bytes = Sha224::digest(&number_bytes).to_vec();
|
||||||
|
let mut nonce_bytes = [0u8; 12];
|
||||||
|
nonce_bytes.copy_from_slice(&sha_bytes[..12]);
|
||||||
|
|
||||||
|
nonce_bytes.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncProtocolStream> EncryptedStream<T> {
|
||||||
|
/// Does a server-client key exchange.
|
||||||
|
/// 1. The server receives the public key of the client
|
||||||
|
/// 2. The server sends its own public key
|
||||||
|
/// 3. The server creates an intermediary encrypted connection
|
||||||
|
/// 4. The server generates a new secret
|
||||||
|
/// 5. The server sends the secret to the client
|
||||||
|
/// 6. The connection is upgraded with the new shared key
|
||||||
|
pub async fn from_server_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult<Self> {
|
||||||
|
let other_pub = receive_public_key(&mut inner).await?;
|
||||||
|
send_public_key(&mut inner, &secret).await?;
|
||||||
|
let shared_secret = secret.diffie_hellman(&other_pub);
|
||||||
|
let mut stream = Self::new(inner, shared_secret);
|
||||||
|
let permanent_secret = generate_secret();
|
||||||
|
stream.write_all(&permanent_secret).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
stream.update_key(permanent_secret.into());
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Does a client-server key exchange.
|
||||||
|
/// 1. The client sends its public key to the server
|
||||||
|
/// 2. The client receives the servers public key
|
||||||
|
/// 3. The client creates an intermediary encrypted connection
|
||||||
|
/// 4. The client receives the new key from the server
|
||||||
|
/// 5. The connection is upgraded with the new shared key
|
||||||
|
pub async fn from_client_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult<Self> {
|
||||||
|
send_public_key(&mut inner, &secret).await?;
|
||||||
|
let other_pub = receive_public_key(&mut inner).await?;
|
||||||
|
let shared_secret = secret.diffie_hellman(&other_pub);
|
||||||
|
let mut stream = Self::new(inner, shared_secret);
|
||||||
|
let mut key_buf = vec![0u8; 32];
|
||||||
|
stream.read_exact(&mut key_buf).await?;
|
||||||
|
stream.update_key(key_buf.into());
|
||||||
|
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult<PublicKey> {
|
||||||
|
let mut pk_buf = [0u8; 32];
|
||||||
|
stream.read_exact(&mut pk_buf).await?;
|
||||||
|
|
||||||
|
Ok(PublicKey::from(pk_buf))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_public_key<T: AsyncProtocolStream>(
|
||||||
|
stream: &mut T,
|
||||||
|
secret: &StaticSecret,
|
||||||
|
) -> IPCResult<()> {
|
||||||
|
let own_pk = PublicKey::from(secret);
|
||||||
|
stream.write_all(own_pk.as_bytes()).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_secret() -> Vec<u8> {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
let mut buf = vec![0u8; 32];
|
||||||
|
rng.fill_bytes(&mut buf);
|
||||||
|
|
||||||
|
Sha256::digest(&buf).to_vec()
|
||||||
|
}
|
@ -0,0 +1,220 @@
|
|||||||
|
use crate::prelude::encrypted::{
|
||||||
|
EncryptedPackage, EncryptedReadStream, EncryptedStream, EncryptedWriteStream,
|
||||||
|
};
|
||||||
|
use crate::prelude::AsyncProtocolStream;
|
||||||
|
use crate::protocol::encrypted::crypt_handling::CipherBox;
|
||||||
|
use bytes::{Buf, BufMut, Bytes};
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::io;
|
||||||
|
use std::io::Error;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
|
|
||||||
|
const WRITE_BUF_SIZE: usize = 1024;
|
||||||
|
|
||||||
|
impl<T: AsyncProtocolStream> Unpin for EncryptedStream<T> {}
|
||||||
|
|
||||||
|
impl<T: AsyncProtocolStream> AsyncWrite for EncryptedStream<T> {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, Error>> {
|
||||||
|
Pin::new(&mut self.write_half).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
Pin::new(&mut self.write_half).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
Pin::new(&mut self.write_half).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncProtocolStream> AsyncRead for EncryptedStream<T> {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.read_half).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: 'static + AsyncRead + Unpin + Send + Sync> Unpin for EncryptedReadStream<T> {}
|
||||||
|
|
||||||
|
impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadStream<T> {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
if self.fut.is_none() {
|
||||||
|
let max_copy = min(buf.remaining(), self.remaining.len());
|
||||||
|
let bytes = self.remaining.copy_to_bytes(max_copy);
|
||||||
|
buf.put_slice(&bytes);
|
||||||
|
|
||||||
|
if buf.remaining() > 0 {
|
||||||
|
let mut reader = self.inner.take().unwrap();
|
||||||
|
let cipher = self.cipher.take().unwrap();
|
||||||
|
|
||||||
|
self.fut = Some(Box::pin(async move {
|
||||||
|
let package = match EncryptedPackage::from_async_read(&mut reader).await {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return (Err(e), reader, cipher);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match cipher.decrypt(package.into_inner()) {
|
||||||
|
Ok(bytes) => (Ok(bytes), reader, cipher),
|
||||||
|
Err(e) => (Err(e), reader, cipher),
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.fut.is_some() {
|
||||||
|
match self.fut.as_mut().unwrap().as_mut().poll(cx) {
|
||||||
|
Poll::Ready((result, reader, cipher)) => {
|
||||||
|
self.inner = Some(reader);
|
||||||
|
self.cipher = Some(cipher);
|
||||||
|
match result {
|
||||||
|
Ok(bytes) => {
|
||||||
|
self.remaining.put(bytes);
|
||||||
|
let max_copy = min(self.remaining.len(), buf.remaining());
|
||||||
|
let bytes = self.remaining.copy_to_bytes(max_copy);
|
||||||
|
self.fut = None;
|
||||||
|
buf.put_slice(&bytes);
|
||||||
|
|
||||||
|
if buf.remaining() == 0 {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Poll::Ready(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> Unpin for EncryptedWriteStream<T> {}
|
||||||
|
|
||||||
|
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWriteStream<T> {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, Error>> {
|
||||||
|
if buf.remaining() > 0 {
|
||||||
|
let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) };
|
||||||
|
self.buffer.put(Bytes::from(buf));
|
||||||
|
|
||||||
|
if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE {
|
||||||
|
let buffer_len = self.buffer.len();
|
||||||
|
let max_copy = min(u32::MAX as usize, buffer_len);
|
||||||
|
let plaintext = self.buffer.copy_to_bytes(max_copy);
|
||||||
|
let writer = self.inner.take().unwrap();
|
||||||
|
let cipher = self.cipher.take().unwrap();
|
||||||
|
|
||||||
|
self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if self.fut_write.is_some() {
|
||||||
|
match self.fut_write.as_mut().unwrap().as_mut().poll(cx) {
|
||||||
|
Poll::Ready((result, writer, cipher)) => {
|
||||||
|
self.inner = Some(writer);
|
||||||
|
self.cipher = Some(cipher);
|
||||||
|
self.fut_write = None;
|
||||||
|
|
||||||
|
Poll::Ready(result.map(|_| buf.len()))
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Poll::Ready(Ok(buf.len()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
let buffer_len = self.buffer.len();
|
||||||
|
|
||||||
|
if !self.buffer.is_empty() && self.fut_flush.is_none() {
|
||||||
|
let max_copy = min(u32::MAX as usize, buffer_len);
|
||||||
|
let plaintext = self.buffer.copy_to_bytes(max_copy);
|
||||||
|
let writer = self.inner.take().unwrap();
|
||||||
|
let cipher = self.cipher.take().unwrap();
|
||||||
|
|
||||||
|
self.fut_flush = Some(Box::pin(async move {
|
||||||
|
let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await;
|
||||||
|
if result.is_err() {
|
||||||
|
return (result, writer, cipher);
|
||||||
|
}
|
||||||
|
if let Err(e) = writer.flush().await {
|
||||||
|
(Err(e), writer, cipher)
|
||||||
|
} else {
|
||||||
|
(Ok(()), writer, cipher)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
match self.fut_flush.as_mut().unwrap().as_mut().poll(cx) {
|
||||||
|
Poll::Ready((result, writer, cipher)) => {
|
||||||
|
self.inner = Some(writer);
|
||||||
|
self.cipher = Some(cipher);
|
||||||
|
self.fut_flush = None;
|
||||||
|
|
||||||
|
Poll::Ready(result)
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
if self.fut_shutdown.is_none() {
|
||||||
|
match self.as_mut().poll_flush(cx) {
|
||||||
|
Poll::Ready(result) => match result {
|
||||||
|
Ok(_) => {
|
||||||
|
let mut writer = self.inner.take().unwrap();
|
||||||
|
self.fut_shutdown = Some(Box::pin(async move { writer.shutdown().await }));
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
Err(e) => Poll::Ready(Err(e)),
|
||||||
|
},
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match self.fut_shutdown.as_mut().unwrap().as_mut().poll(cx) {
|
||||||
|
Poll::Ready(result) => {
|
||||||
|
self.fut_shutdown = None;
|
||||||
|
Poll::Ready(result)
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_bytes<T: AsyncWrite + Unpin>(
|
||||||
|
bytes: Bytes,
|
||||||
|
mut writer: T,
|
||||||
|
cipher: CipherBox,
|
||||||
|
) -> (io::Result<()>, T, CipherBox) {
|
||||||
|
let encrypted_bytes = match cipher.encrypt(bytes) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
return (Err(e), writer, cipher);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes();
|
||||||
|
if let Err(e) = writer.write_all(&package_bytes[..]).await {
|
||||||
|
return (Err(e), writer, cipher);
|
||||||
|
}
|
||||||
|
|
||||||
|
(Ok(()), writer, cipher)
|
||||||
|
}
|
@ -0,0 +1,152 @@
|
|||||||
|
mod crypt_handling;
|
||||||
|
mod io_impl;
|
||||||
|
mod protocol_impl;
|
||||||
|
|
||||||
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
pub use io_impl::*;
|
||||||
|
pub use protocol_impl::*;
|
||||||
|
use rand_core::RngCore;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::io;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufReader, BufWriter};
|
||||||
|
use x25519_dalek::{SharedSecret, StaticSecret};
|
||||||
|
|
||||||
|
use crate::prelude::encrypted::crypt_handling::CipherBox;
|
||||||
|
use crate::prelude::{AsyncProtocolStream, AsyncStreamProtocolListener};
|
||||||
|
|
||||||
|
pub type OptionalFuture<T> = Option<Pin<Box<dyn Future<Output = T> + Send + Sync>>>;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct EncryptionOptions<T: Clone + Default> {
|
||||||
|
pub inner_options: T,
|
||||||
|
pub secret: StaticSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Clone + Default> Default for EncryptionOptions<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let mut secret = [0u8; 32];
|
||||||
|
rng.fill_bytes(&mut secret);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
secret: StaticSecret::from(secret),
|
||||||
|
inner_options: T::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EncryptedListener<T: AsyncStreamProtocolListener> {
|
||||||
|
inner: T,
|
||||||
|
secret: StaticSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncStreamProtocolListener> EncryptedListener<T> {
|
||||||
|
pub fn new(inner: T, secret: StaticSecret) -> Self {
|
||||||
|
Self { inner, secret }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EncryptedStream<T: AsyncProtocolStream> {
|
||||||
|
read_half: EncryptedReadStream<T::OwnedSplitReadHalf>,
|
||||||
|
write_half: EncryptedWriteStream<T::OwnedSplitWriteHalf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncProtocolStream> EncryptedStream<T> {
|
||||||
|
pub fn new(inner: T, secret: SharedSecret) -> Self {
|
||||||
|
let cipher_box = CipherBox::new(Bytes::from(secret.to_bytes().to_vec()));
|
||||||
|
let (read, write) = inner.protocol_into_split();
|
||||||
|
let read_half = EncryptedReadStream::new(read, cipher_box.clone());
|
||||||
|
let write_half = EncryptedWriteStream::new(write, cipher_box);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
read_half,
|
||||||
|
write_half,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_key(&mut self, key: Bytes) {
|
||||||
|
self.write_half
|
||||||
|
.cipher
|
||||||
|
.as_mut()
|
||||||
|
.unwrap()
|
||||||
|
.update_key(key.clone());
|
||||||
|
self.read_half
|
||||||
|
.cipher
|
||||||
|
.as_mut()
|
||||||
|
.unwrap()
|
||||||
|
.update_key(key.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EncryptedReadStream<T: AsyncRead> {
|
||||||
|
inner: Option<BufReader<T>>,
|
||||||
|
fut: OptionalFuture<(io::Result<Bytes>, BufReader<T>, CipherBox)>,
|
||||||
|
remaining: BytesMut,
|
||||||
|
cipher: Option<CipherBox>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> {
|
||||||
|
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Some(BufReader::new(inner)),
|
||||||
|
fut: None,
|
||||||
|
remaining: BytesMut::new(),
|
||||||
|
cipher: Some(cipher),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EncryptedWriteStream<T: 'static + AsyncWrite + Unpin + Send + Sync> {
|
||||||
|
inner: Option<BufWriter<T>>,
|
||||||
|
cipher: Option<CipherBox>,
|
||||||
|
buffer: BytesMut,
|
||||||
|
fut_write: OptionalFuture<(io::Result<()>, BufWriter<T>, CipherBox)>,
|
||||||
|
fut_flush: OptionalFuture<(io::Result<()>, BufWriter<T>, CipherBox)>,
|
||||||
|
fut_shutdown: OptionalFuture<io::Result<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> EncryptedWriteStream<T> {
|
||||||
|
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: Some(BufWriter::new(inner)),
|
||||||
|
cipher: Some(cipher),
|
||||||
|
buffer: BytesMut::with_capacity(1024),
|
||||||
|
fut_write: None,
|
||||||
|
fut_flush: None,
|
||||||
|
fut_shutdown: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct EncryptedPackage {
|
||||||
|
bytes: Bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncryptedPackage {
|
||||||
|
pub fn new(bytes: Bytes) -> Self {
|
||||||
|
Self { bytes }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_bytes(self) -> Bytes {
|
||||||
|
let mut buf = BytesMut::with_capacity(4 + self.bytes.len());
|
||||||
|
buf.put_u32(self.bytes.len() as u32);
|
||||||
|
buf.put(self.bytes);
|
||||||
|
|
||||||
|
buf.freeze()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
|
||||||
|
let length = reader.read_u32().await?;
|
||||||
|
let mut bytes_buf = vec![0u8; length as usize];
|
||||||
|
reader.read_exact(&mut bytes_buf).await?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
bytes: Bytes::from(bytes_buf),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_inner(self) -> Bytes {
|
||||||
|
self.bytes
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,55 @@
|
|||||||
|
use crate::error::Result;
|
||||||
|
use crate::prelude::encrypted::{EncryptedReadStream, EncryptedWriteStream};
|
||||||
|
use crate::prelude::{AsyncProtocolStreamSplit, IPCResult};
|
||||||
|
use crate::protocol::encrypted::{EncryptedListener, EncryptedStream, EncryptionOptions};
|
||||||
|
use crate::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T: AsyncStreamProtocolListener> AsyncStreamProtocolListener for EncryptedListener<T> {
|
||||||
|
type AddressType = T::AddressType;
|
||||||
|
type RemoteAddressType = T::RemoteAddressType;
|
||||||
|
type Stream = EncryptedStream<T::Stream>;
|
||||||
|
type ListenerOptions = EncryptionOptions<T::ListenerOptions>;
|
||||||
|
|
||||||
|
async fn protocol_bind(
|
||||||
|
address: Self::AddressType,
|
||||||
|
options: Self::ListenerOptions,
|
||||||
|
) -> IPCResult<Self> {
|
||||||
|
let inner = T::protocol_bind(address, options.inner_options).await?;
|
||||||
|
|
||||||
|
Ok(EncryptedListener::new(inner, options.secret))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn protocol_accept(&self) -> IPCResult<(Self::Stream, Self::RemoteAddressType)> {
|
||||||
|
let (inner_stream, remote_addr) = self.inner.protocol_accept().await?;
|
||||||
|
let stream =
|
||||||
|
Self::Stream::from_server_key_exchange(inner_stream, self.secret.clone()).await?;
|
||||||
|
|
||||||
|
Ok((stream, remote_addr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T: AsyncProtocolStream> AsyncProtocolStream for EncryptedStream<T> {
|
||||||
|
type AddressType = T::AddressType;
|
||||||
|
type StreamOptions = EncryptionOptions<T::StreamOptions>;
|
||||||
|
|
||||||
|
async fn protocol_connect(
|
||||||
|
address: Self::AddressType,
|
||||||
|
options: Self::StreamOptions,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let inner = T::protocol_connect(address, options.inner_options).await?;
|
||||||
|
EncryptedStream::from_client_key_exchange(inner, options.secret).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<T: AsyncProtocolStream> AsyncProtocolStreamSplit for EncryptedStream<T> {
|
||||||
|
type OwnedSplitReadHalf = EncryptedReadStream<T::OwnedSplitReadHalf>;
|
||||||
|
type OwnedSplitWriteHalf = EncryptedWriteStream<T::OwnedSplitWriteHalf>;
|
||||||
|
|
||||||
|
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
|
||||||
|
(self.read_half, self.write_half)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
use crate::utils::call_counter::increment_counter_for_event;
|
||||||
|
use crate::utils::protocol::TestProtocolListener;
|
||||||
|
use crate::utils::{get_free_port, start_server_and_client};
|
||||||
|
use bromine::prelude::encrypted::EncryptedListener;
|
||||||
|
use bromine::prelude::*;
|
||||||
|
use bromine::IPCBuilder;
|
||||||
|
use byteorder::{BigEndian, ReadBytesExt};
|
||||||
|
use bytes::{BufMut, Bytes, BytesMut};
|
||||||
|
use futures::StreamExt;
|
||||||
|
use rand_core::RngCore;
|
||||||
|
use std::io::Read;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn it_sends_and_receives() {
|
||||||
|
let ctx = get_client_with_server().await;
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let mut buffer = vec![0u8; 140];
|
||||||
|
rng.fill_bytes(&mut buffer);
|
||||||
|
|
||||||
|
let mut stream = ctx
|
||||||
|
.emit("bytes", BytePayload::new(buffer.clone()))
|
||||||
|
.stream_replies()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let mut count = 0;
|
||||||
|
|
||||||
|
while let Some(Ok(response)) = stream.next().await {
|
||||||
|
let bytes = response.payload::<BytePayload>().unwrap();
|
||||||
|
assert_eq!(bytes.into_inner(), buffer);
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
assert_eq!(count, 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn it_sends_and_receives_strings() {
|
||||||
|
let ctx = get_client_with_server().await;
|
||||||
|
let response = ctx
|
||||||
|
.emit("string", StringPayload(String::from("Hello World")))
|
||||||
|
.await_reply()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let response_string = response.payload::<StringPayload>().unwrap().0;
|
||||||
|
|
||||||
|
assert_eq!(&response_string, "Hello World")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_client_with_server() -> Context {
|
||||||
|
let port = get_free_port();
|
||||||
|
|
||||||
|
start_server_and_client(move || get_builder(port)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_builder(port: u8) -> IPCBuilder<EncryptedListener<TestProtocolListener>> {
|
||||||
|
IPCBuilder::new()
|
||||||
|
.address(port)
|
||||||
|
.on("bytes", callback!(handle_bytes))
|
||||||
|
.on("string", callback!(handle_string))
|
||||||
|
.timeout(Duration::from_millis(100))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult<Response> {
|
||||||
|
increment_counter_for_event(ctx, &event).await;
|
||||||
|
let bytes = event.payload::<BytePayload>()?.into_bytes();
|
||||||
|
|
||||||
|
for _ in 0u8..99 {
|
||||||
|
ctx.emit("bytes", BytePayload::from(bytes.clone())).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.response(BytePayload::from(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_string(ctx: &Context, event: Event) -> IPCResult<Response> {
|
||||||
|
ctx.response(event.payload::<StringPayload>()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct StringPayload(String);
|
||||||
|
|
||||||
|
impl IntoPayload for StringPayload {
|
||||||
|
fn into_payload(self, _: &Context) -> IPCResult<Bytes> {
|
||||||
|
let mut buf = BytesMut::with_capacity(self.0.len() + 4);
|
||||||
|
buf.put_u32(self.0.len() as u32);
|
||||||
|
buf.put(Bytes::from(self.0));
|
||||||
|
|
||||||
|
Ok(buf.freeze())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromPayload for StringPayload {
|
||||||
|
fn from_payload<R: Read>(mut reader: R) -> IPCResult<Self> {
|
||||||
|
let len = reader.read_u32::<BigEndian>()?;
|
||||||
|
let mut buf = vec![0u8; len as usize];
|
||||||
|
reader.read_exact(&mut buf)?;
|
||||||
|
let string = String::from_utf8(buf).map_err(|_| IPCError::from("not a string"))?;
|
||||||
|
|
||||||
|
Ok(StringPayload(string))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue