mirror of https://github.com/Trivernis/bromine.git
commit
c4baf40e65
@ -0,0 +1,167 @@
|
||||
use crate::prelude::encrypted::EncryptedStream;
|
||||
use crate::prelude::IPCResult;
|
||||
use crate::protocol::AsyncProtocolStream;
|
||||
use bytes::Bytes;
|
||||
use chacha20poly1305::aead::{Aead, NewAead};
|
||||
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
|
||||
use rand::thread_rng;
|
||||
use rand_core::RngCore;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
/// A structure used for encryption.
|
||||
/// It holds the cipher initialized with the given key
|
||||
/// and two counters for both encryption and decryption
|
||||
/// count which are used to keep track of the nonce.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct CipherBox {
|
||||
cipher: ChaCha20Poly1305,
|
||||
en_count: Arc<AtomicU64>,
|
||||
de_count: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl CipherBox {
|
||||
pub fn new(key: Bytes) -> Self {
|
||||
let key = Key::from_slice(&key[..]);
|
||||
let cipher = ChaCha20Poly1305::new(key);
|
||||
|
||||
Self {
|
||||
cipher,
|
||||
en_count: Arc::new(AtomicU64::new(0)),
|
||||
de_count: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypts the given message
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
pub fn encrypt(&self, data: Bytes) -> io::Result<Bytes> {
|
||||
self.cipher
|
||||
.encrypt(&self.en_nonce(), &data[..])
|
||||
.map(Bytes::from)
|
||||
.map_err(|_| io::Error::from(ErrorKind::InvalidData))
|
||||
}
|
||||
|
||||
/// Decrypts the given message
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
pub fn decrypt(&self, data: Bytes) -> io::Result<Bytes> {
|
||||
self.cipher
|
||||
.decrypt(&self.de_nonce(), &data[..])
|
||||
.map(Bytes::from)
|
||||
.map_err(|_| io::Error::from(ErrorKind::InvalidData))
|
||||
}
|
||||
|
||||
/// Updates the stored key.
|
||||
/// This must be done simultaneously on server and client side
|
||||
/// to keep track of the nonce
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
pub fn update_key(&mut self, key: Bytes) {
|
||||
let key = Key::from_slice(&key[..]);
|
||||
self.cipher = ChaCha20Poly1305::new(key);
|
||||
self.reset_counters();
|
||||
}
|
||||
|
||||
/// Resets the nonce counters.
|
||||
/// This must be done simultaneously on server and client side.
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
pub fn reset_counters(&mut self) {
|
||||
self.de_count.store(0, Ordering::SeqCst);
|
||||
self.en_count.store(0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
fn en_nonce(&self) -> Nonce {
|
||||
let count = self.en_count.fetch_add(1, Ordering::SeqCst);
|
||||
tracing::trace!("encrypted count {}", count);
|
||||
nonce_from_number(count)
|
||||
}
|
||||
|
||||
fn de_nonce(&self) -> Nonce {
|
||||
let count = self.de_count.fetch_add(1, Ordering::SeqCst);
|
||||
tracing::trace!("decrypted count {}", count);
|
||||
nonce_from_number(count)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a nonce from a given number
|
||||
/// The given number is repeated to fit the nonce bytes
|
||||
fn nonce_from_number(number: u64) -> Nonce {
|
||||
let number_bytes: [u8; 8] = number.to_be_bytes();
|
||||
let num_vec = number_bytes.repeat(2);
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
nonce_bytes.copy_from_slice(&num_vec[..12]);
|
||||
|
||||
nonce_bytes.into()
|
||||
}
|
||||
|
||||
impl<T: AsyncProtocolStream> EncryptedStream<T> {
|
||||
/// Does a server-client key exchange.
|
||||
/// 1. The server receives the public key of the client
|
||||
/// 2. The server sends its own public key
|
||||
/// 3. The server creates an intermediary encrypted connection
|
||||
/// 4. The server generates a new secret
|
||||
/// 5. The server sends the secret to the client
|
||||
/// 6. The connection is upgraded with the new shared key
|
||||
pub async fn from_server_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult<Self> {
|
||||
let other_pub = receive_public_key(&mut inner).await?;
|
||||
send_public_key(&mut inner, &secret).await?;
|
||||
let shared_secret = secret.diffie_hellman(&other_pub);
|
||||
let mut stream = Self::new(inner, shared_secret);
|
||||
let permanent_secret = generate_secret();
|
||||
stream.write_all(&permanent_secret).await?;
|
||||
stream.flush().await?;
|
||||
stream.update_key(permanent_secret.into());
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
/// Does a client-server key exchange.
|
||||
/// 1. The client sends its public key to the server
|
||||
/// 2. The client receives the servers public key
|
||||
/// 3. The client creates an intermediary encrypted connection
|
||||
/// 4. The client receives the new key from the server
|
||||
/// 5. The connection is upgraded with the new shared key
|
||||
pub async fn from_client_key_exchange(mut inner: T, secret: StaticSecret) -> IPCResult<Self> {
|
||||
send_public_key(&mut inner, &secret).await?;
|
||||
let other_pub = receive_public_key(&mut inner).await?;
|
||||
let shared_secret = secret.diffie_hellman(&other_pub);
|
||||
let mut stream = Self::new(inner, shared_secret);
|
||||
let mut key_buf = vec![0u8; 32];
|
||||
stream.read_exact(&mut key_buf).await?;
|
||||
stream.update_key(key_buf.into());
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
async fn receive_public_key<T: AsyncProtocolStream>(stream: &mut T) -> IPCResult<PublicKey> {
|
||||
let mut pk_buf = [0u8; 32];
|
||||
stream.read_exact(&mut pk_buf).await?;
|
||||
|
||||
Ok(PublicKey::from(pk_buf))
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip_all)]
|
||||
async fn send_public_key<T: AsyncProtocolStream>(
|
||||
stream: &mut T,
|
||||
secret: &StaticSecret,
|
||||
) -> IPCResult<()> {
|
||||
let own_pk = PublicKey::from(secret);
|
||||
stream.write_all(own_pk.as_bytes()).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
fn generate_secret() -> Vec<u8> {
|
||||
let mut rng = thread_rng();
|
||||
let mut buf = vec![0u8; 32];
|
||||
rng.fill_bytes(&mut buf);
|
||||
|
||||
Sha256::digest(&buf).to_vec()
|
||||
}
|
@ -0,0 +1,240 @@
|
||||
use crate::prelude::encrypted::{
|
||||
EncryptedPackage, EncryptedReadStream, EncryptedStream, EncryptedWriteStream,
|
||||
};
|
||||
use crate::prelude::AsyncProtocolStream;
|
||||
use crate::protocol::encrypted::crypt_handling::CipherBox;
|
||||
use bytes::{Buf, BufMut, Bytes};
|
||||
use std::cmp::min;
|
||||
use std::io;
|
||||
use std::io::Error;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
|
||||
const WRITE_BUF_SIZE: usize = 1024;
|
||||
|
||||
impl<T: AsyncProtocolStream> Unpin for EncryptedStream<T> {}
|
||||
|
||||
impl<T: AsyncProtocolStream> AsyncWrite for EncryptedStream<T> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
Pin::new(&mut self.write_half).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Pin::new(&mut self.write_half).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Pin::new(&mut self.write_half).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncProtocolStream> AsyncRead for EncryptedStream<T> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.read_half).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static + AsyncRead + Unpin + Send + Sync> Unpin for EncryptedReadStream<T> {}
|
||||
|
||||
impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadStream<T> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
if self.fut.is_none() {
|
||||
if self.remaining.len() > 0 {
|
||||
let max_copy = min(buf.remaining(), self.remaining.len());
|
||||
let bytes = self.remaining.copy_to_bytes(max_copy);
|
||||
buf.put_slice(&bytes);
|
||||
tracing::trace!("{} bytes read from buffer", bytes.len());
|
||||
}
|
||||
|
||||
if buf.remaining() > 0 {
|
||||
tracing::trace!("{} bytes remaining to read", buf.remaining());
|
||||
let reader = self.inner.take().unwrap();
|
||||
let cipher = self.cipher.take().unwrap();
|
||||
|
||||
self.fut = Some(Box::pin(async move { read_bytes(reader, cipher).await }));
|
||||
} else {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
match self.fut.as_mut().unwrap().as_mut().poll(cx) {
|
||||
Poll::Ready((result, reader, cipher)) => {
|
||||
self.inner = Some(reader);
|
||||
self.cipher = Some(cipher);
|
||||
match result {
|
||||
Ok(bytes) => {
|
||||
self.fut = None;
|
||||
self.remaining.put(bytes);
|
||||
let max_copy = min(self.remaining.len(), buf.remaining());
|
||||
let bytes = self.remaining.copy_to_bytes(max_copy);
|
||||
buf.put_slice(&bytes);
|
||||
tracing::trace!("{} bytes read from buffer", bytes.len());
|
||||
|
||||
if buf.remaining() == 0 {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> Unpin for EncryptedWriteStream<T> {}
|
||||
|
||||
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWriteStream<T> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
let written_length = buf.len();
|
||||
|
||||
if self.fut_write.is_none() {
|
||||
self.buffer.put(Bytes::from(buf.to_vec()));
|
||||
|
||||
if self.buffer.len() >= WRITE_BUF_SIZE {
|
||||
tracing::trace!("buffer has reached sending size: {}", self.buffer.len());
|
||||
let buffer_len = self.buffer.len();
|
||||
let max_copy = min(u32::MAX as usize, buffer_len);
|
||||
let plaintext = self.buffer.copy_to_bytes(max_copy);
|
||||
let writer = self.inner.take().unwrap();
|
||||
let cipher = self.cipher.take().unwrap();
|
||||
|
||||
self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher)))
|
||||
} else {
|
||||
return Poll::Ready(Ok(written_length));
|
||||
}
|
||||
}
|
||||
|
||||
match self.fut_write.as_mut().unwrap().as_mut().poll(cx) {
|
||||
Poll::Ready((result, writer, cipher)) => {
|
||||
self.inner = Some(writer);
|
||||
self.cipher = Some(cipher);
|
||||
self.fut_write = None;
|
||||
|
||||
Poll::Ready(result.map(|_| written_length))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
let buffer_len = self.buffer.len();
|
||||
|
||||
if self.fut_flush.is_none() {
|
||||
let max_copy = min(u32::MAX as usize, buffer_len);
|
||||
let plaintext = self.buffer.copy_to_bytes(max_copy);
|
||||
let writer = self.inner.take().unwrap();
|
||||
let cipher = self.cipher.take().unwrap();
|
||||
|
||||
self.fut_flush = Some(Box::pin(async move {
|
||||
let (mut writer, cipher) = if plaintext.len() > 0 {
|
||||
let (result, writer, cipher) = write_bytes(plaintext, writer, cipher).await;
|
||||
if result.is_err() {
|
||||
return (result, writer, cipher);
|
||||
}
|
||||
(writer, cipher)
|
||||
} else {
|
||||
(writer, cipher)
|
||||
};
|
||||
if let Err(e) = writer.flush().await {
|
||||
(Err(e), writer, cipher)
|
||||
} else {
|
||||
(Ok(()), writer, cipher)
|
||||
}
|
||||
}))
|
||||
}
|
||||
match self.fut_flush.as_mut().unwrap().as_mut().poll(cx) {
|
||||
Poll::Ready((result, writer, cipher)) => {
|
||||
self.inner = Some(writer);
|
||||
self.cipher = Some(cipher);
|
||||
self.fut_flush = None;
|
||||
|
||||
Poll::Ready(result)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if self.fut_shutdown.is_none() {
|
||||
match self.as_mut().poll_flush(cx) {
|
||||
Poll::Ready(result) => match result {
|
||||
Ok(_) => {
|
||||
let mut writer = self.inner.take().unwrap();
|
||||
self.fut_shutdown = Some(Box::pin(async move { writer.shutdown().await }));
|
||||
Poll::Pending
|
||||
}
|
||||
Err(e) => Poll::Ready(Err(e)),
|
||||
},
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
} else {
|
||||
match self.fut_shutdown.as_mut().unwrap().as_mut().poll(cx) {
|
||||
Poll::Ready(result) => {
|
||||
self.fut_shutdown = None;
|
||||
Poll::Ready(result)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
async fn write_bytes<T: AsyncWrite + Unpin>(
|
||||
bytes: Bytes,
|
||||
mut writer: T,
|
||||
cipher: CipherBox,
|
||||
) -> (io::Result<()>, T, CipherBox) {
|
||||
tracing::trace!("plaintext size: {}", bytes.len());
|
||||
let encrypted_bytes = match cipher.encrypt(bytes) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
return (Err(e), writer, cipher);
|
||||
}
|
||||
};
|
||||
let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes();
|
||||
tracing::trace!("encrypted size: {}", package_bytes.len());
|
||||
if let Err(e) = writer.write_all(&package_bytes[..]).await {
|
||||
return (Err(e), writer, cipher);
|
||||
}
|
||||
tracing::trace!("everything sent");
|
||||
|
||||
(Ok(()), writer, cipher)
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", skip_all)]
|
||||
async fn read_bytes<T: AsyncRead + Unpin>(
|
||||
mut reader: T,
|
||||
cipher: CipherBox,
|
||||
) -> (io::Result<Bytes>, T, CipherBox) {
|
||||
let package = match EncryptedPackage::from_async_read(&mut reader).await {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return (Err(e), reader, cipher);
|
||||
}
|
||||
};
|
||||
tracing::trace!("received {} bytes", package.bytes.len());
|
||||
match cipher.decrypt(package.into_inner()) {
|
||||
Ok(bytes) => (Ok(bytes), reader, cipher),
|
||||
Err(e) => (Err(e), reader, cipher),
|
||||
}
|
||||
}
|
@ -0,0 +1,152 @@
|
||||
mod crypt_handling;
|
||||
mod io_impl;
|
||||
mod protocol_impl;
|
||||
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
pub use io_impl::*;
|
||||
pub use protocol_impl::*;
|
||||
use rand_core::RngCore;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use x25519_dalek::{SharedSecret, StaticSecret};
|
||||
|
||||
use crate::prelude::encrypted::crypt_handling::CipherBox;
|
||||
use crate::prelude::{AsyncProtocolStream, AsyncStreamProtocolListener};
|
||||
|
||||
pub type OptionalFuture<T> = Option<Pin<Box<dyn Future<Output = T> + Send + Sync>>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EncryptionOptions<T: Clone + Default> {
|
||||
pub inner_options: T,
|
||||
pub secret: StaticSecret,
|
||||
}
|
||||
|
||||
impl<T: Clone + Default> Default for EncryptionOptions<T> {
|
||||
fn default() -> Self {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut secret = [0u8; 32];
|
||||
rng.fill_bytes(&mut secret);
|
||||
|
||||
Self {
|
||||
secret: StaticSecret::from(secret),
|
||||
inner_options: T::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EncryptedListener<T: AsyncStreamProtocolListener> {
|
||||
inner: T,
|
||||
secret: StaticSecret,
|
||||
}
|
||||
|
||||
impl<T: AsyncStreamProtocolListener> EncryptedListener<T> {
|
||||
pub fn new(inner: T, secret: StaticSecret) -> Self {
|
||||
Self { inner, secret }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EncryptedStream<T: AsyncProtocolStream> {
|
||||
read_half: EncryptedReadStream<T::OwnedSplitReadHalf>,
|
||||
write_half: EncryptedWriteStream<T::OwnedSplitWriteHalf>,
|
||||
}
|
||||
|
||||
impl<T: AsyncProtocolStream> EncryptedStream<T> {
|
||||
pub fn new(inner: T, secret: SharedSecret) -> Self {
|
||||
let cipher_box = CipherBox::new(Bytes::from(secret.to_bytes().to_vec()));
|
||||
let (read, write) = inner.protocol_into_split();
|
||||
let read_half = EncryptedReadStream::new(read, cipher_box.clone());
|
||||
let write_half = EncryptedWriteStream::new(write, cipher_box);
|
||||
|
||||
Self {
|
||||
read_half,
|
||||
write_half,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_key(&mut self, key: Bytes) {
|
||||
self.write_half
|
||||
.cipher
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.update_key(key.clone());
|
||||
self.read_half
|
||||
.cipher
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.update_key(key.clone());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EncryptedReadStream<T: AsyncRead> {
|
||||
inner: Option<T>,
|
||||
fut: OptionalFuture<(io::Result<Bytes>, T, CipherBox)>,
|
||||
remaining: BytesMut,
|
||||
cipher: Option<CipherBox>,
|
||||
}
|
||||
|
||||
impl<T: 'static + AsyncRead + Unpin + Send + Sync> EncryptedReadStream<T> {
|
||||
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
|
||||
Self {
|
||||
inner: Some(inner),
|
||||
fut: None,
|
||||
remaining: BytesMut::new(),
|
||||
cipher: Some(cipher),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EncryptedWriteStream<T: 'static + AsyncWrite + Unpin + Send + Sync> {
|
||||
inner: Option<T>,
|
||||
cipher: Option<CipherBox>,
|
||||
buffer: BytesMut,
|
||||
fut_write: OptionalFuture<(io::Result<()>, T, CipherBox)>,
|
||||
fut_flush: OptionalFuture<(io::Result<()>, T, CipherBox)>,
|
||||
fut_shutdown: OptionalFuture<io::Result<()>>,
|
||||
}
|
||||
|
||||
impl<T: 'static + AsyncWrite + Unpin + Send + Sync> EncryptedWriteStream<T> {
|
||||
pub(crate) fn new(inner: T, cipher: CipherBox) -> Self {
|
||||
Self {
|
||||
inner: Some(inner),
|
||||
cipher: Some(cipher),
|
||||
buffer: BytesMut::with_capacity(1024),
|
||||
fut_write: None,
|
||||
fut_flush: None,
|
||||
fut_shutdown: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct EncryptedPackage {
|
||||
bytes: Bytes,
|
||||
}
|
||||
|
||||
impl EncryptedPackage {
|
||||
pub fn new(bytes: Bytes) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
|
||||
pub fn into_bytes(self) -> Bytes {
|
||||
let mut buf = BytesMut::with_capacity(4 + self.bytes.len());
|
||||
buf.put_u32(self.bytes.len() as u32);
|
||||
buf.put(self.bytes);
|
||||
|
||||
buf.freeze()
|
||||
}
|
||||
|
||||
pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
|
||||
let length = reader.read_u32().await?;
|
||||
let mut bytes_buf = vec![0u8; length as usize];
|
||||
reader.read_exact(&mut bytes_buf).await?;
|
||||
|
||||
Ok(Self {
|
||||
bytes: Bytes::from(bytes_buf),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Bytes {
|
||||
self.bytes
|
||||
}
|
||||
}
|
@ -0,0 +1,55 @@
|
||||
use crate::error::Result;
|
||||
use crate::prelude::encrypted::{EncryptedReadStream, EncryptedWriteStream};
|
||||
use crate::prelude::{AsyncProtocolStreamSplit, IPCResult};
|
||||
use crate::protocol::encrypted::{EncryptedListener, EncryptedStream, EncryptionOptions};
|
||||
use crate::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
impl<T: AsyncStreamProtocolListener> AsyncStreamProtocolListener for EncryptedListener<T> {
|
||||
type AddressType = T::AddressType;
|
||||
type RemoteAddressType = T::RemoteAddressType;
|
||||
type Stream = EncryptedStream<T::Stream>;
|
||||
type ListenerOptions = EncryptionOptions<T::ListenerOptions>;
|
||||
|
||||
async fn protocol_bind(
|
||||
address: Self::AddressType,
|
||||
options: Self::ListenerOptions,
|
||||
) -> IPCResult<Self> {
|
||||
let inner = T::protocol_bind(address, options.inner_options).await?;
|
||||
|
||||
Ok(EncryptedListener::new(inner, options.secret))
|
||||
}
|
||||
|
||||
async fn protocol_accept(&self) -> IPCResult<(Self::Stream, Self::RemoteAddressType)> {
|
||||
let (inner_stream, remote_addr) = self.inner.protocol_accept().await?;
|
||||
let stream =
|
||||
Self::Stream::from_server_key_exchange(inner_stream, self.secret.clone()).await?;
|
||||
|
||||
Ok((stream, remote_addr))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: AsyncProtocolStream> AsyncProtocolStream for EncryptedStream<T> {
|
||||
type AddressType = T::AddressType;
|
||||
type StreamOptions = EncryptionOptions<T::StreamOptions>;
|
||||
|
||||
async fn protocol_connect(
|
||||
address: Self::AddressType,
|
||||
options: Self::StreamOptions,
|
||||
) -> Result<Self> {
|
||||
let inner = T::protocol_connect(address, options.inner_options).await?;
|
||||
EncryptedStream::from_client_key_exchange(inner, options.secret).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: AsyncProtocolStream> AsyncProtocolStreamSplit for EncryptedStream<T> {
|
||||
type OwnedSplitReadHalf = EncryptedReadStream<T::OwnedSplitReadHalf>;
|
||||
type OwnedSplitWriteHalf = EncryptedWriteStream<T::OwnedSplitWriteHalf>;
|
||||
|
||||
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
|
||||
(self.read_half, self.write_half)
|
||||
}
|
||||
}
|
@ -0,0 +1,111 @@
|
||||
use crate::utils::call_counter::increment_counter_for_event;
|
||||
use crate::utils::protocol::TestProtocolListener;
|
||||
use crate::utils::{get_free_port, start_server_and_client};
|
||||
use bromine::prelude::encrypted::EncryptedListener;
|
||||
use bromine::prelude::*;
|
||||
use bromine::IPCBuilder;
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use futures::StreamExt;
|
||||
use rand_core::RngCore;
|
||||
use std::io::Read;
|
||||
use std::time::Duration;
|
||||
|
||||
mod utils;
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_sends_and_receives_smaller_packages() {
|
||||
send_and_receive_bytes(140).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_sends_and_receives_larger_packages() {
|
||||
send_and_receive_bytes(1024 * 32).await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_sends_and_receives_strings() {
|
||||
let ctx = get_client_with_server().await;
|
||||
let response = ctx
|
||||
.emit("string", StringPayload(String::from("Hello World")))
|
||||
.await_reply()
|
||||
.await
|
||||
.unwrap();
|
||||
let response_string = response.payload::<StringPayload>().unwrap().0;
|
||||
|
||||
assert_eq!(&response_string, "Hello World")
|
||||
}
|
||||
|
||||
async fn send_and_receive_bytes(byte_size: usize) -> IPCResult<()> {
|
||||
let ctx = get_client_with_server().await;
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut buffer = vec![0u8; byte_size];
|
||||
rng.fill_bytes(&mut buffer);
|
||||
|
||||
let mut stream = ctx
|
||||
.emit("bytes", BytePayload::new(buffer.clone()))
|
||||
.stream_replies()
|
||||
.await?;
|
||||
let mut count = 0;
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let bytes = response.unwrap().payload::<BytePayload>()?;
|
||||
assert_eq!(bytes.into_inner(), buffer);
|
||||
count += 1;
|
||||
}
|
||||
assert_eq!(count, 100);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_client_with_server() -> Context {
|
||||
let port = get_free_port();
|
||||
|
||||
start_server_and_client(move || get_builder(port)).await
|
||||
}
|
||||
|
||||
fn get_builder(port: u8) -> IPCBuilder<EncryptedListener<TestProtocolListener>> {
|
||||
IPCBuilder::new()
|
||||
.address(port)
|
||||
.on("bytes", callback!(handle_bytes))
|
||||
.on("string", callback!(handle_string))
|
||||
.timeout(Duration::from_secs(10))
|
||||
}
|
||||
|
||||
async fn handle_bytes(ctx: &Context, event: Event) -> IPCResult<Response> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
let bytes = event.payload::<BytePayload>()?.into_inner();
|
||||
|
||||
for _ in 0u8..99 {
|
||||
ctx.emit("bytes", BytePayload::new(bytes.clone())).await?;
|
||||
}
|
||||
|
||||
ctx.response(BytePayload::new(bytes))
|
||||
}
|
||||
|
||||
async fn handle_string(ctx: &Context, event: Event) -> IPCResult<Response> {
|
||||
ctx.response(event.payload::<StringPayload>()?)
|
||||
}
|
||||
|
||||
pub struct StringPayload(String);
|
||||
|
||||
impl IntoPayload for StringPayload {
|
||||
fn into_payload(self, _: &Context) -> IPCResult<Bytes> {
|
||||
let mut buf = BytesMut::with_capacity(self.0.len() + 4);
|
||||
buf.put_u32(self.0.len() as u32);
|
||||
buf.put(Bytes::from(self.0));
|
||||
|
||||
Ok(buf.freeze())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromPayload for StringPayload {
|
||||
fn from_payload<R: Read>(mut reader: R) -> IPCResult<Self> {
|
||||
let len = reader.read_u32::<BigEndian>()?;
|
||||
let mut buf = vec![0u8; len as usize];
|
||||
reader.read_exact(&mut buf)?;
|
||||
let string = String::from_utf8(buf).map_err(|_| IPCError::from("not a string"))?;
|
||||
|
||||
Ok(StringPayload(string))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue