|
|
|
@ -52,40 +52,35 @@ impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadSt
|
|
|
|
|
buf: &mut ReadBuf<'_>,
|
|
|
|
|
) -> Poll<std::io::Result<()>> {
|
|
|
|
|
if self.fut.is_none() {
|
|
|
|
|
if self.remaining.len() > 0 {
|
|
|
|
|
let max_copy = min(buf.remaining(), self.remaining.len());
|
|
|
|
|
let bytes = self.remaining.copy_to_bytes(max_copy);
|
|
|
|
|
buf.put_slice(&bytes);
|
|
|
|
|
tracing::trace!("{} bytes read from buffer", bytes.len());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if buf.remaining() > 0 {
|
|
|
|
|
let mut reader = self.inner.take().unwrap();
|
|
|
|
|
tracing::trace!("{} bytes remaining to read", buf.remaining());
|
|
|
|
|
let reader = self.inner.take().unwrap();
|
|
|
|
|
let cipher = self.cipher.take().unwrap();
|
|
|
|
|
|
|
|
|
|
self.fut = Some(Box::pin(async move {
|
|
|
|
|
let package = match EncryptedPackage::from_async_read(&mut reader).await {
|
|
|
|
|
Ok(p) => p,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
return (Err(e), reader, cipher);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
match cipher.decrypt(package.into_inner()) {
|
|
|
|
|
Ok(bytes) => (Ok(bytes), reader, cipher),
|
|
|
|
|
Err(e) => (Err(e), reader, cipher),
|
|
|
|
|
}
|
|
|
|
|
}));
|
|
|
|
|
self.fut = Some(Box::pin(async move { read_bytes(reader, cipher).await }));
|
|
|
|
|
} else {
|
|
|
|
|
return Poll::Ready(Ok(()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if self.fut.is_some() {
|
|
|
|
|
match self.fut.as_mut().unwrap().as_mut().poll(cx) {
|
|
|
|
|
Poll::Ready((result, reader, cipher)) => {
|
|
|
|
|
self.inner = Some(reader);
|
|
|
|
|
self.cipher = Some(cipher);
|
|
|
|
|
match result {
|
|
|
|
|
Ok(bytes) => {
|
|
|
|
|
self.fut = None;
|
|
|
|
|
self.remaining.put(bytes);
|
|
|
|
|
let max_copy = min(self.remaining.len(), buf.remaining());
|
|
|
|
|
let bytes = self.remaining.copy_to_bytes(max_copy);
|
|
|
|
|
self.fut = None;
|
|
|
|
|
buf.put_slice(&bytes);
|
|
|
|
|
tracing::trace!("{} bytes read from buffer", bytes.len());
|
|
|
|
|
|
|
|
|
|
if buf.remaining() == 0 {
|
|
|
|
|
Poll::Ready(Ok(()))
|
|
|
|
@ -98,9 +93,6 @@ impl<T: 'static + AsyncRead + Send + Sync + Unpin> AsyncRead for EncryptedReadSt
|
|
|
|
|
}
|
|
|
|
|
Poll::Pending => Poll::Pending,
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
Poll::Ready(Ok(()))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -112,11 +104,13 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
|
|
|
|
|
cx: &mut Context<'_>,
|
|
|
|
|
buf: &[u8],
|
|
|
|
|
) -> Poll<Result<usize, Error>> {
|
|
|
|
|
if buf.remaining() > 0 {
|
|
|
|
|
let buf = unsafe { std::mem::transmute::<_, &'static [u8]>(buf) };
|
|
|
|
|
self.buffer.put(Bytes::from(buf));
|
|
|
|
|
let written_length = buf.len();
|
|
|
|
|
|
|
|
|
|
if self.fut_write.is_none() {
|
|
|
|
|
self.buffer.put(Bytes::from(buf.to_vec()));
|
|
|
|
|
|
|
|
|
|
if self.fut_write.is_none() && self.buffer.len() >= WRITE_BUF_SIZE {
|
|
|
|
|
if self.buffer.len() >= WRITE_BUF_SIZE {
|
|
|
|
|
tracing::trace!("buffer has reached sending size: {}", self.buffer.len());
|
|
|
|
|
let buffer_len = self.buffer.len();
|
|
|
|
|
let max_copy = min(u32::MAX as usize, buffer_len);
|
|
|
|
|
let plaintext = self.buffer.copy_to_bytes(max_copy);
|
|
|
|
@ -124,38 +118,42 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
|
|
|
|
|
let cipher = self.cipher.take().unwrap();
|
|
|
|
|
|
|
|
|
|
self.fut_write = Some(Box::pin(write_bytes(plaintext, writer, cipher)))
|
|
|
|
|
} else {
|
|
|
|
|
return Poll::Ready(Ok(written_length));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if self.fut_write.is_some() {
|
|
|
|
|
|
|
|
|
|
match self.fut_write.as_mut().unwrap().as_mut().poll(cx) {
|
|
|
|
|
Poll::Ready((result, writer, cipher)) => {
|
|
|
|
|
self.inner = Some(writer);
|
|
|
|
|
self.cipher = Some(cipher);
|
|
|
|
|
self.fut_write = None;
|
|
|
|
|
|
|
|
|
|
Poll::Ready(result.map(|_| buf.len()))
|
|
|
|
|
Poll::Ready(result.map(|_| written_length))
|
|
|
|
|
}
|
|
|
|
|
Poll::Pending => Poll::Pending,
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
Poll::Ready(Ok(buf.len()))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
|
|
|
|
let buffer_len = self.buffer.len();
|
|
|
|
|
|
|
|
|
|
if !self.buffer.is_empty() && self.fut_flush.is_none() {
|
|
|
|
|
if self.fut_flush.is_none() {
|
|
|
|
|
let max_copy = min(u32::MAX as usize, buffer_len);
|
|
|
|
|
let plaintext = self.buffer.copy_to_bytes(max_copy);
|
|
|
|
|
let writer = self.inner.take().unwrap();
|
|
|
|
|
let cipher = self.cipher.take().unwrap();
|
|
|
|
|
|
|
|
|
|
self.fut_flush = Some(Box::pin(async move {
|
|
|
|
|
let (result, mut writer, cipher) = write_bytes(plaintext, writer, cipher).await;
|
|
|
|
|
let (mut writer, cipher) = if plaintext.len() > 0 {
|
|
|
|
|
let (result, writer, cipher) = write_bytes(plaintext, writer, cipher).await;
|
|
|
|
|
if result.is_err() {
|
|
|
|
|
return (result, writer, cipher);
|
|
|
|
|
}
|
|
|
|
|
(writer, cipher)
|
|
|
|
|
} else {
|
|
|
|
|
(writer, cipher)
|
|
|
|
|
};
|
|
|
|
|
if let Err(e) = writer.flush().await {
|
|
|
|
|
(Err(e), writer, cipher)
|
|
|
|
|
} else {
|
|
|
|
@ -200,11 +198,13 @@ impl<T: 'static + AsyncWrite + Unpin + Send + Sync> AsyncWrite for EncryptedWrit
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "trace", skip_all)]
|
|
|
|
|
async fn write_bytes<T: AsyncWrite + Unpin>(
|
|
|
|
|
bytes: Bytes,
|
|
|
|
|
mut writer: T,
|
|
|
|
|
cipher: CipherBox,
|
|
|
|
|
) -> (io::Result<()>, T, CipherBox) {
|
|
|
|
|
tracing::trace!("plaintext size: {}", bytes.len());
|
|
|
|
|
let encrypted_bytes = match cipher.encrypt(bytes) {
|
|
|
|
|
Ok(b) => b,
|
|
|
|
|
Err(e) => {
|
|
|
|
@ -212,9 +212,29 @@ async fn write_bytes<T: AsyncWrite + Unpin>(
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
let package_bytes = EncryptedPackage::new(encrypted_bytes).into_bytes();
|
|
|
|
|
tracing::trace!("encrypted size: {}", package_bytes.len());
|
|
|
|
|
if let Err(e) = writer.write_all(&package_bytes[..]).await {
|
|
|
|
|
return (Err(e), writer, cipher);
|
|
|
|
|
}
|
|
|
|
|
tracing::trace!("everything sent");
|
|
|
|
|
|
|
|
|
|
(Ok(()), writer, cipher)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[tracing::instrument(level = "trace", skip_all)]
|
|
|
|
|
async fn read_bytes<T: AsyncRead + Unpin>(
|
|
|
|
|
mut reader: T,
|
|
|
|
|
cipher: CipherBox,
|
|
|
|
|
) -> (io::Result<Bytes>, T, CipherBox) {
|
|
|
|
|
let package = match EncryptedPackage::from_async_read(&mut reader).await {
|
|
|
|
|
Ok(p) => p,
|
|
|
|
|
Err(e) => {
|
|
|
|
|
return (Err(e), reader, cipher);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
tracing::trace!("received {} bytes", package.bytes.len());
|
|
|
|
|
match cipher.decrypt(package.into_inner()) {
|
|
|
|
|
Ok(bytes) => (Ok(bytes), reader, cipher),
|
|
|
|
|
Err(e) => (Err(e), reader, cipher),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|