Improve test protocol by using some unsafe magic

Signed-off-by: trivernis <trivernis@protonmail.com>
pull/25/head
trivernis 3 years ago
parent 9e7cd26f6a
commit f70563d099
Signed by: Trivernis
GPG Key ID: DFFFCC2C7A02DB45

@ -3,43 +3,47 @@ use bromine::error::Result;
use bromine::prelude::{AsyncProtocolStreamSplit, IPCError};
use bromine::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
use lazy_static::lazy_static;
use std::cmp::min;
use std::collections::HashMap;
use std::future::Future;
use std::io::Error;
use std::mem;
use std::pin::Pin;
use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use std::sync::Arc;
use std::sync::Mutex;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::mpsc::{
channel as async_channel, Receiver as AsyncReceiver, Sender as AsyncSender,
};
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
lazy_static! {
static ref LISTENERS_REF: Arc<AsyncMutex<HashMap<u8, AsyncSender<TestProtocolStream>>>> =
Arc::new(AsyncMutex::new(HashMap::new()));
static ref LISTENERS_REF: Arc<Mutex<HashMap<u8, Sender<TestProtocolStream>>>> =
Arc::new(Mutex::new(HashMap::new()));
}
/// Adds a channel that receives streams to handle
async fn add_port(number: u8, sender: tokio::sync::mpsc::Sender<TestProtocolStream>) {
let mut listeners = LISTENERS_REF.lock().await;
listeners.insert(number, sender);
}
/// Returns a stream for the given port connecting with the server via channels
async fn get_port(number: u8) -> Option<TestProtocolStream> {
let mut listeners = LISTENERS_REF.lock().await;
if let Some(sender) = listeners.get_mut(&number) {
let (s1, r1) = mpsc::channel();
let (s2, r2) = mpsc::channel();
let (s1, r1) = channel(2);
let (s2, r2) = channel(2);
let stream_1 = TestProtocolStream {
sender: Arc::new(Mutex::new(s1)),
sender: s1,
receiver: Arc::new(Mutex::new(r2)),
future: None,
remaining_buf: Default::default(),
};
let stream_2 = TestProtocolStream {
sender: Arc::new(Mutex::new(s2)),
sender: s2,
receiver: Arc::new(Mutex::new(r1)),
future: None,
remaining_buf: Default::default(),
};
sender.send(stream_2).await.ok();
@ -50,7 +54,7 @@ async fn get_port(number: u8) -> Option<TestProtocolStream> {
}
pub struct TestProtocolListener {
receiver: Arc<AsyncMutex<AsyncReceiver<TestProtocolStream>>>,
receiver: Arc<Mutex<Receiver<TestProtocolStream>>>,
}
#[async_trait]
@ -60,11 +64,11 @@ impl AsyncStreamProtocolListener for TestProtocolListener {
type Stream = TestProtocolStream;
async fn protocol_bind(address: Self::AddressType) -> Result<Self> {
let (sender, receiver) = async_channel(1);
let (sender, receiver) = channel(1);
add_port(address, sender).await;
Ok(Self {
receiver: Arc::new(AsyncMutex::new(receiver)),
receiver: Arc::new(Mutex::new(receiver)),
})
}
@ -79,10 +83,79 @@ impl AsyncStreamProtocolListener for TestProtocolListener {
}
}
#[derive(Clone)]
impl Clone for TestProtocolStream {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
receiver: Arc::clone(&self.receiver),
future: None,
remaining_buf: Default::default(),
}
}
}
pub struct TestProtocolStream {
sender: Arc<Mutex<Sender<Vec<u8>>>>,
sender: Sender<Vec<u8>>,
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
future: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
remaining_buf: Arc<Mutex<Vec<u8>>>,
}
impl TestProtocolStream {
/// Read from the receiver and remaining buffer
async fn read_from_receiver(
buf: &mut ReadBuf<'static>,
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
remaining_buf: Arc<Mutex<Vec<u8>>>,
) {
{
let mut remaining_buf = remaining_buf.lock().await;
if !remaining_buf.is_empty() {
if Self::read_from_remaining_buffer(buf, &mut remaining_buf).await {
return;
}
}
}
let mut receiver = receiver.lock().await;
if let Some(mut bytes) = receiver.recv().await {
let slice_len = min(bytes.len(), buf.capacity());
buf.put_slice(&bytes[0..slice_len]);
bytes.reverse();
bytes.truncate(bytes.len() - slice_len);
bytes.reverse();
let mut remaining_buf = remaining_buf.lock().await;
remaining_buf.append(&mut bytes);
}
}
/// Read from the remaining buffer returning a boolean if the
/// read buffer has been filled
async fn read_from_remaining_buffer(
buf: &mut ReadBuf<'static>,
remaining_buf: &mut Vec<u8>,
) -> bool {
if remaining_buf.len() < buf.capacity() {
buf.put_slice(&remaining_buf);
remaining_buf.clear();
false
} else if remaining_buf.len() == buf.capacity() {
buf.put_slice(&remaining_buf);
remaining_buf.clear();
true
} else {
let slice_len = buf.capacity();
let remaining_len = remaining_buf.len();
buf.put_slice(&remaining_buf[0..slice_len]);
remaining_buf.reverse();
remaining_buf.truncate(remaining_len - slice_len);
remaining_buf.reverse();
true
}
}
}
impl AsyncProtocolStreamSplit for TestProtocolStream {
@ -108,15 +181,33 @@ impl AsyncProtocolStream for TestProtocolStream {
impl AsyncRead for TestProtocolStream {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let receiver = self.receiver.lock().unwrap();
if let Ok(b) = receiver.recv() {
buf.put_slice(&b);
Poll::Ready(Ok(()))
} else {
Poll::Pending
unsafe {
// we need a mutable reference to access the inner future
let stream = self.get_unchecked_mut();
if stream.future.is_none() {
// we need to change the lifetime to be able to use the read buffer in the read future
let buf: &mut ReadBuf<'static> = mem::transmute(buf);
let receiver = Arc::clone(&stream.receiver);
let remaining_buf = Arc::clone(&stream.remaining_buf);
let future = TestProtocolStream::read_from_receiver(buf, receiver, remaining_buf);
stream.future = Some(Box::pin(future));
}
if let Some(future) = &mut stream.future {
match future.as_mut().poll(cx) {
Poll::Ready(_) => {
stream.future = None;
Poll::Ready(Ok(()))
}
Poll::Pending => Poll::Pending,
}
} else {
Poll::Pending
}
}
}
}
@ -124,15 +215,36 @@ impl AsyncRead for TestProtocolStream {
impl AsyncWrite for TestProtocolStream {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::prelude::rust_2015::Result<usize, Error>> {
let sender = self.sender.lock().unwrap();
let vec_buf = buf.to_vec();
let buf_len = vec_buf.len();
sender.send(vec_buf).unwrap();
let write_len = buf.len();
unsafe {
// we need a mutable reference to access the inner future
let stream = self.get_unchecked_mut();
if stream.future.is_none() {
// we take ownership here so that we don't need to change lifetimes here
let buf = buf.to_vec();
let sender = stream.sender.clone();
Poll::Ready(Ok(buf_len))
let future = async move {
sender.send(buf).await.unwrap();
};
stream.future = Some(Box::pin(future));
}
if let Some(future) = &mut stream.future {
match future.as_mut().poll(cx) {
Poll::Ready(_) => {
stream.future = None;
Poll::Ready(Ok(write_len))
}
Poll::Pending => Poll::Pending,
}
} else {
Poll::Pending
}
}
}
fn poll_flush(

@ -41,5 +41,11 @@ async fn it_passes_events() {
tokio::task::spawn(async { get_builder(0).build_server().await.unwrap() });
tokio::time::sleep(Duration::from_millis(100)).await;
let ctx = get_builder(0).build_client().await.unwrap();
ctx.emitter.emit("ping", ()).await.unwrap(); // todo fix reply deadlock
ctx.emitter
.emit("ping", ())
.await
.unwrap()
.await_reply(&ctx)
.await
.unwrap();
}

Loading…
Cancel
Save