mirror of https://github.com/Trivernis/bromine.git
Merge pull request #25 from Trivernis/develop
More serialization formats and change in feature namespull/32/head
commit
ddf6e03ba2
@ -0,0 +1,23 @@
|
||||
#[cfg(feature = "serialize_rmp")]
|
||||
mod serialize_rmp;
|
||||
|
||||
#[cfg(feature = "serialize_rmp")]
|
||||
pub use serialize_rmp::*;
|
||||
|
||||
#[cfg(feature = "serialize_bincode")]
|
||||
mod serialize_bincode;
|
||||
|
||||
#[cfg(feature = "serialize_bincode")]
|
||||
pub use serialize_bincode::*;
|
||||
|
||||
#[cfg(feature = "serialize_postcard")]
|
||||
mod serialize_postcard;
|
||||
|
||||
#[cfg(feature = "serialize_postcard")]
|
||||
pub use serialize_postcard::*;
|
||||
|
||||
#[cfg(feature = "serialize_json")]
|
||||
mod serialize_json;
|
||||
|
||||
#[cfg(feature = "serialize_json")]
|
||||
pub use serialize_json::*;
|
@ -0,0 +1,28 @@
|
||||
use crate::payload::{EventReceivePayload, EventSendPayload};
|
||||
use crate::prelude::IPCResult;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::io::Read;
|
||||
|
||||
pub type SerializationError = bincode::Error;
|
||||
|
||||
impl<T> EventSendPayload for T
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
|
||||
let bytes = bincode::serialize(&self)?;
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EventReceivePayload for T
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
|
||||
let type_data = bincode::deserialize_from(reader)?;
|
||||
Ok(type_data)
|
||||
}
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
use crate::payload::{EventReceivePayload, EventSendPayload};
|
||||
use crate::prelude::IPCResult;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::io::Read;
|
||||
|
||||
pub type SerializationError = serde_json::Error;
|
||||
|
||||
impl<T> EventSendPayload for T
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
|
||||
let bytes = serde_json::to_vec(&self)?;
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EventReceivePayload for T
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
|
||||
let type_data = serde_json::from_reader(reader)?;
|
||||
|
||||
Ok(type_data)
|
||||
}
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
use crate::payload::{EventReceivePayload, EventSendPayload};
|
||||
use crate::prelude::IPCResult;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::io::Read;
|
||||
|
||||
pub type SerializationError = postcard::Error;
|
||||
|
||||
impl<T> EventSendPayload for T
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
|
||||
let bytes = postcard::to_allocvec(&self)?.to_vec();
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EventReceivePayload for T
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
fn from_payload_bytes<R: Read>(mut reader: R) -> IPCResult<Self> {
|
||||
let mut buf = Vec::new();
|
||||
// reading to end means reading the full size of the provided data
|
||||
reader.read_to_end(&mut buf)?;
|
||||
let type_data = postcard::from_bytes(&buf)?;
|
||||
|
||||
Ok(type_data)
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
use crate::payload::{EventReceivePayload, EventSendPayload};
|
||||
use crate::prelude::{IPCError, IPCResult};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::io::Read;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SerializationError {
|
||||
#[error("failed to serialize with rmp: {0}")]
|
||||
Serialize(#[from] rmp_serde::encode::Error),
|
||||
|
||||
#[error("failed to deserialize with rmp: {0}")]
|
||||
Deserialize(#[from] rmp_serde::decode::Error),
|
||||
}
|
||||
|
||||
impl From<rmp_serde::decode::Error> for IPCError {
|
||||
fn from(e: rmp_serde::decode::Error) -> Self {
|
||||
IPCError::Serialization(SerializationError::Deserialize(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<rmp_serde::encode::Error> for IPCError {
|
||||
fn from(e: rmp_serde::encode::Error) -> Self {
|
||||
IPCError::Serialization(SerializationError::Serialize(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EventSendPayload for T
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
|
||||
let bytes = rmp_serde::to_vec(&self)?;
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> EventReceivePayload for T
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
fn from_payload_bytes<R: Read>(reader: R) -> IPCResult<Self> {
|
||||
let type_data = rmp_serde::from_read(reader)?;
|
||||
Ok(type_data)
|
||||
}
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
use crate::events::generate_event_id;
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[test]
|
||||
fn event_ids_work() {
|
||||
let mut ids = HashSet::new();
|
||||
|
||||
// simple collision test
|
||||
for _ in 0..100000 {
|
||||
assert!(ids.insert(generate_event_id()))
|
||||
}
|
||||
}
|
@ -1,235 +0,0 @@
|
||||
use super::utils::PingEventData;
|
||||
use crate::prelude::*;
|
||||
use crate::tests::utils::start_test_server;
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::net::TcpListener;
|
||||
use typemap_rev::TypeMapKey;
|
||||
|
||||
async fn handle_ping_event(ctx: &Context, e: Event) -> IPCResult<()> {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let mut ping_data = e.data::<PingEventData>()?;
|
||||
ping_data.time = SystemTime::now();
|
||||
ping_data.ttl -= 1;
|
||||
|
||||
if ping_data.ttl > 0 {
|
||||
ctx.emitter.emit_response(e.id(), "pong", ping_data).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_builder_with_ping<L: AsyncStreamProtocolListener>(address: L::AddressType) -> IPCBuilder<L> {
|
||||
IPCBuilder::new()
|
||||
.on("ping", |ctx, e| Box::pin(handle_ping_event(ctx, e)))
|
||||
.timeout(Duration::from_secs(10))
|
||||
.address(address)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_receives_tcp_events() {
|
||||
let socket_address = "127.0.0.1:8281".to_socket_addrs().unwrap().next().unwrap();
|
||||
it_receives_events::<TcpListener>(socket_address).await;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tokio::test]
|
||||
async fn it_receives_unix_socket_events() {
|
||||
let socket_path = PathBuf::from("/tmp/test_socket");
|
||||
if socket_path.exists() {
|
||||
std::fs::remove_file(&socket_path).unwrap();
|
||||
}
|
||||
it_receives_events::<tokio::net::UnixListener>(socket_path).await;
|
||||
}
|
||||
|
||||
async fn it_receives_events<L: 'static + AsyncStreamProtocolListener>(address: L::AddressType) {
|
||||
let builder = get_builder_with_ping::<L>(address.clone());
|
||||
let server_running = Arc::new(AtomicBool::new(false));
|
||||
|
||||
tokio::spawn({
|
||||
let server_running = Arc::clone(&server_running);
|
||||
let builder = get_builder_with_ping::<L>(address);
|
||||
async move {
|
||||
server_running.store(true, Ordering::SeqCst);
|
||||
builder.build_server().await.unwrap();
|
||||
}
|
||||
});
|
||||
while !server_running.load(Ordering::Relaxed) {
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
let pool = builder.build_pooled_client(8).await.unwrap();
|
||||
let reply = pool
|
||||
.acquire()
|
||||
.emitter
|
||||
.emit(
|
||||
"ping",
|
||||
PingEventData {
|
||||
ttl: 16,
|
||||
time: SystemTime::now(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&pool.acquire())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(reply.name(), "pong");
|
||||
}
|
||||
|
||||
fn get_builder_with_ping_namespace(address: &str) -> IPCBuilder<TcpListener> {
|
||||
IPCBuilder::new()
|
||||
.namespace("mainspace")
|
||||
.on("ping", callback!(handle_ping_event))
|
||||
.build()
|
||||
.address(address.to_socket_addrs().unwrap().next().unwrap())
|
||||
}
|
||||
|
||||
pub struct TestNamespace;
|
||||
|
||||
impl TestNamespace {
|
||||
async fn ping(_c: &Context, _e: Event) -> IPCResult<()> {
|
||||
println!("Ping received");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl NamespaceProvider for TestNamespace {
|
||||
fn name() -> &'static str {
|
||||
"Test"
|
||||
}
|
||||
|
||||
fn register(handler: &mut EventHandler) {
|
||||
events!(handler,
|
||||
"ping" => Self::ping,
|
||||
"ping2" => Self::ping
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_receives_namespaced_events() {
|
||||
let builder = get_builder_with_ping_namespace("127.0.0.1:8282");
|
||||
let server_running = Arc::new(AtomicBool::new(false));
|
||||
tokio::spawn({
|
||||
let server_running = Arc::clone(&server_running);
|
||||
let builder = get_builder_with_ping_namespace("127.0.0.1:8282");
|
||||
async move {
|
||||
server_running.store(true, Ordering::SeqCst);
|
||||
builder.build_server().await.unwrap();
|
||||
}
|
||||
});
|
||||
while !server_running.load(Ordering::Relaxed) {
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
let ctx = builder
|
||||
.add_namespace(namespace!(TestNamespace))
|
||||
.build_client()
|
||||
.await
|
||||
.unwrap();
|
||||
let reply = ctx
|
||||
.emitter
|
||||
.emit_to(
|
||||
"mainspace",
|
||||
"ping",
|
||||
PingEventData {
|
||||
ttl: 16,
|
||||
time: SystemTime::now(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(reply.name(), "pong");
|
||||
}
|
||||
|
||||
struct ErrorOccurredKey;
|
||||
|
||||
impl TypeMapKey for ErrorOccurredKey {
|
||||
type Value = Arc<AtomicBool>;
|
||||
}
|
||||
|
||||
fn get_builder_with_error_handling(
|
||||
error_occurred: Arc<AtomicBool>,
|
||||
address: &str,
|
||||
) -> IPCBuilder<TcpListener> {
|
||||
IPCBuilder::new()
|
||||
.insert::<ErrorOccurredKey>(error_occurred)
|
||||
.on("ping", move |_, _| {
|
||||
Box::pin(async move { Err(IPCError::from("ERRROROROROR")) })
|
||||
})
|
||||
.on(
|
||||
"error",
|
||||
callback!(ctx, event, async move {
|
||||
let error = event.data::<error_event::ErrorEventData>()?;
|
||||
assert!(error.message.len() > 0);
|
||||
assert_eq!(error.code, 500);
|
||||
{
|
||||
let data = ctx.data.read().await;
|
||||
let error_occurred = data.get::<ErrorOccurredKey>().unwrap();
|
||||
error_occurred.store(true, Ordering::SeqCst);
|
||||
}
|
||||
Ok(())
|
||||
}),
|
||||
)
|
||||
.address(address.to_socket_addrs().unwrap().next().unwrap())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_handles_errors() {
|
||||
let error_occurred = Arc::new(AtomicBool::new(false));
|
||||
let builder = get_builder_with_error_handling(Arc::clone(&error_occurred), "127.0.0.1:8283");
|
||||
let server_running = Arc::new(AtomicBool::new(false));
|
||||
|
||||
tokio::spawn({
|
||||
let server_running = Arc::clone(&server_running);
|
||||
let error_occurred = Arc::clone(&error_occurred);
|
||||
let builder = get_builder_with_error_handling(error_occurred, "127.0.0.1:8283");
|
||||
async move {
|
||||
server_running.store(true, Ordering::SeqCst);
|
||||
builder.build_server().await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
while !server_running.load(Ordering::Relaxed) {
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
let ctx = builder.build_client().await.unwrap();
|
||||
ctx.emitter.emit("ping", ()).await.unwrap();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
assert!(error_occurred.load(Ordering::SeqCst));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_responses() {
|
||||
static ADDRESS: &str = "127.0.0.1:8284";
|
||||
start_test_server(ADDRESS).await.unwrap();
|
||||
let ctx = IPCBuilder::<TcpListener>::new()
|
||||
.address(ADDRESS.to_socket_addrs().unwrap().next().unwrap())
|
||||
.build_client()
|
||||
.await
|
||||
.unwrap();
|
||||
let reply = ctx
|
||||
.emitter
|
||||
.emit("ping", ())
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(reply.name(), "pong");
|
||||
|
||||
let reply = ctx
|
||||
.emitter
|
||||
.emit("trigger_error", ())
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&ctx)
|
||||
.await;
|
||||
assert!(reply.is_err());
|
||||
}
|
@ -1,3 +0,0 @@
|
||||
mod event_tests;
|
||||
mod ipc_tests;
|
||||
mod utils;
|
@ -1,37 +0,0 @@
|
||||
use crate::error::Error;
|
||||
use crate::IPCBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::time::SystemTime;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct PingEventData {
|
||||
pub time: SystemTime,
|
||||
pub ttl: u8,
|
||||
}
|
||||
|
||||
/// Starts a test IPC server
|
||||
pub fn start_test_server(address: &'static str) -> oneshot::Receiver<bool> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
tokio::task::spawn(async move {
|
||||
tx.send(true).unwrap();
|
||||
IPCBuilder::<TcpListener>::new()
|
||||
.address(address.to_socket_addrs().unwrap().next().unwrap())
|
||||
.on("ping", |ctx, event| {
|
||||
Box::pin(async move {
|
||||
ctx.emitter.emit_response(event.id(), "pong", ()).await?;
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
.on("trigger_error", |_, _| {
|
||||
Box::pin(async move { Err(Error::from("An error occurred.")) })
|
||||
})
|
||||
.build_server()
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
rx
|
||||
}
|
@ -0,0 +1,141 @@
|
||||
mod utils;
|
||||
|
||||
use crate::utils::start_server_and_client;
|
||||
use bromine::prelude::*;
|
||||
use payload_impl::SimplePayload;
|
||||
use std::time::Duration;
|
||||
use utils::call_counter::*;
|
||||
use utils::get_free_port;
|
||||
use utils::protocol::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_sends_payloads() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
|
||||
ctx.emitter
|
||||
.emit(
|
||||
"ping",
|
||||
SimplePayload {
|
||||
number: 0,
|
||||
string: String::from("Hello World"),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// wait for the event to be handled
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
|
||||
let counters = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert_eq!(counters.get("ping").await, 1);
|
||||
assert_eq!(counters.get("pong").await, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_receives_payloads() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
let reply = ctx
|
||||
.emitter
|
||||
.emit(
|
||||
"ping",
|
||||
SimplePayload {
|
||||
number: 0,
|
||||
string: String::from("Hello World"),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
let reply_payload = reply.data::<SimplePayload>().unwrap();
|
||||
let counters = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert_eq!(counters.get("ping").await, 1);
|
||||
assert_eq!(reply_payload.string, String::from("Hello World"));
|
||||
assert_eq!(reply_payload.number, 0);
|
||||
}
|
||||
|
||||
async fn get_client_with_server(port: u8) -> Context {
|
||||
start_server_and_client(move || get_builder(port)).await
|
||||
}
|
||||
|
||||
fn get_builder(port: u8) -> IPCBuilder<TestProtocolListener> {
|
||||
IPCBuilder::new()
|
||||
.address(port)
|
||||
.on("ping", callback!(handle_ping_event))
|
||||
.on("pong", callback!(handle_pong_event))
|
||||
.timeout(Duration::from_millis(10))
|
||||
}
|
||||
|
||||
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
let payload = event.data::<SimplePayload>()?;
|
||||
ctx.emitter
|
||||
.emit_response(event.id(), "pong", payload)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
let _payload = event.data::<SimplePayload>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "serialize")]
|
||||
mod payload_impl {
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SimplePayload {
|
||||
pub string: String,
|
||||
pub number: u32,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "serialize"))]
|
||||
mod payload_impl {
|
||||
use bromine::error::Result;
|
||||
use bromine::payload::{EventReceivePayload, EventSendPayload};
|
||||
use bromine::prelude::IPCResult;
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use std::io::Read;
|
||||
|
||||
pub struct SimplePayload {
|
||||
pub string: String,
|
||||
pub number: u32,
|
||||
}
|
||||
|
||||
impl EventSendPayload for SimplePayload {
|
||||
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
|
||||
let mut buf = Vec::new();
|
||||
let string_length = self.string.len() as u16;
|
||||
let string_length_bytes = string_length.to_be_bytes();
|
||||
buf.append(&mut string_length_bytes.to_vec());
|
||||
let mut string_bytes = self.string.into_bytes();
|
||||
buf.append(&mut string_bytes);
|
||||
let num_bytes = self.number.to_be_bytes();
|
||||
buf.append(&mut num_bytes.to_vec());
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl EventReceivePayload for SimplePayload {
|
||||
fn from_payload_bytes<R: Read>(mut reader: R) -> Result<Self> {
|
||||
let string_length = reader.read_u16::<BigEndian>()?;
|
||||
let mut string_buf = vec![0u8; string_length as usize];
|
||||
reader.read_exact(&mut string_buf)?;
|
||||
let string = String::from_utf8(string_buf).unwrap();
|
||||
let number = reader.read_u32::<BigEndian>()?;
|
||||
|
||||
Ok(Self { string, number })
|
||||
}
|
||||
}
|
||||
}
|
@ -1,151 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use bromine::error::Result;
|
||||
use bromine::prelude::{AsyncProtocolStreamSplit, IPCError};
|
||||
use bromine::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
|
||||
use lazy_static::lazy_static;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Error;
|
||||
use std::pin::Pin;
|
||||
use std::sync::mpsc;
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::sync::mpsc::{
|
||||
channel as async_channel, Receiver as AsyncReceiver, Sender as AsyncSender,
|
||||
};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
|
||||
lazy_static! {
|
||||
static ref LISTENERS_REF: Arc<AsyncMutex<HashMap<u8, AsyncSender<TestProtocolStream>>>> =
|
||||
Arc::new(AsyncMutex::new(HashMap::new()));
|
||||
}
|
||||
|
||||
async fn add_port(number: u8, sender: tokio::sync::mpsc::Sender<TestProtocolStream>) {
|
||||
let mut listeners = LISTENERS_REF.lock().await;
|
||||
listeners.insert(number, sender);
|
||||
}
|
||||
|
||||
async fn get_port(number: u8) -> Option<TestProtocolStream> {
|
||||
let mut listeners = LISTENERS_REF.lock().await;
|
||||
|
||||
if let Some(sender) = listeners.get_mut(&number) {
|
||||
let (s1, r1) = mpsc::channel();
|
||||
let (s2, r2) = mpsc::channel();
|
||||
let stream_1 = TestProtocolStream {
|
||||
sender: Arc::new(Mutex::new(s1)),
|
||||
receiver: Arc::new(Mutex::new(r2)),
|
||||
};
|
||||
let stream_2 = TestProtocolStream {
|
||||
sender: Arc::new(Mutex::new(s2)),
|
||||
receiver: Arc::new(Mutex::new(r1)),
|
||||
};
|
||||
sender.send(stream_2).await.ok();
|
||||
|
||||
Some(stream_1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestProtocolListener {
|
||||
receiver: Arc<AsyncMutex<AsyncReceiver<TestProtocolStream>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncStreamProtocolListener for TestProtocolListener {
|
||||
type AddressType = u8;
|
||||
type RemoteAddressType = u8;
|
||||
type Stream = TestProtocolStream;
|
||||
|
||||
async fn protocol_bind(address: Self::AddressType) -> Result<Self> {
|
||||
let (sender, receiver) = async_channel(1);
|
||||
add_port(address, sender).await;
|
||||
|
||||
Ok(Self {
|
||||
receiver: Arc::new(AsyncMutex::new(receiver)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn protocol_accept(&self) -> Result<(Self::Stream, Self::RemoteAddressType)> {
|
||||
self.receiver
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.map(|r| (r, 0u8))
|
||||
.ok_or_else(|| IPCError::from("Failed to accept"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TestProtocolStream {
|
||||
sender: Arc<Mutex<Sender<Vec<u8>>>>,
|
||||
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
}
|
||||
|
||||
impl AsyncProtocolStreamSplit for TestProtocolStream {
|
||||
type OwnedSplitReadHalf = Self;
|
||||
type OwnedSplitWriteHalf = Self;
|
||||
|
||||
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
|
||||
(self.clone(), self)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncProtocolStream for TestProtocolStream {
|
||||
type AddressType = u8;
|
||||
|
||||
async fn protocol_connect(address: Self::AddressType) -> Result<Self> {
|
||||
get_port(address)
|
||||
.await
|
||||
.ok_or_else(|| IPCError::from("Failed to connect"))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TestProtocolStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let receiver = self.receiver.lock().unwrap();
|
||||
if let Ok(b) = receiver.recv() {
|
||||
buf.put_slice(&b);
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestProtocolStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::prelude::rust_2015::Result<usize, Error>> {
|
||||
let sender = self.sender.lock().unwrap();
|
||||
let vec_buf = buf.to_vec();
|
||||
let buf_len = vec_buf.len();
|
||||
sender.send(vec_buf).unwrap();
|
||||
|
||||
Poll::Ready(Ok(buf_len))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
@ -1,45 +1,158 @@
|
||||
mod test_protocol;
|
||||
mod utils;
|
||||
|
||||
use crate::utils::start_server_and_client;
|
||||
use bromine::prelude::*;
|
||||
use std::time::Duration;
|
||||
use test_protocol::*;
|
||||
use utils::call_counter::*;
|
||||
use utils::get_free_port;
|
||||
use utils::protocol::*;
|
||||
|
||||
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
ctx.emitter.emit_response(event.id(), "pong", ()).await?;
|
||||
/// Simple events are passed from the client to the server and responses
|
||||
/// are emitted back to the client. Both will have received an event.
|
||||
#[tokio::test]
|
||||
async fn it_sends_events() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
ctx.emitter.emit("ping", EmptyPayload).await.unwrap();
|
||||
|
||||
Ok(())
|
||||
// allow the event to be processed
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
let counter = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert_eq!(counter.get("ping").await, 1);
|
||||
assert_eq!(counter.get("pong").await, 1);
|
||||
}
|
||||
|
||||
async fn handle_pong_event(_ctx: &Context, _event: Event) -> IPCResult<()> {
|
||||
Ok(())
|
||||
/// Events sent to a specific namespace are handled by the namespace event handler
|
||||
#[tokio::test]
|
||||
async fn it_sends_namespaced_events() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
ctx.emitter
|
||||
.emit_to("test", "ping", EmptyPayload)
|
||||
.await
|
||||
.unwrap();
|
||||
ctx.emitter
|
||||
.emit_to("test", "pong", EmptyPayload)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// allow the event to be processed
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
let counter = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert_eq!(counter.get("test:ping").await, 1);
|
||||
assert_eq!(counter.get("test:pong").await, 1);
|
||||
}
|
||||
|
||||
/// When awaiting the reply to an event the handler for the event doesn't get called.
|
||||
/// Therefore we expect it to have a call count of 0.
|
||||
#[tokio::test]
|
||||
async fn it_receives_responses() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
let reply = ctx
|
||||
.emitter
|
||||
.emit("ping", EmptyPayload)
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&ctx)
|
||||
.await
|
||||
.unwrap();
|
||||
let counter = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert_eq!(reply.name(), "pong");
|
||||
assert_eq!(counter.get("ping").await, 1);
|
||||
assert_eq!(counter.get("pong").await, 0);
|
||||
}
|
||||
|
||||
/// When emitting errors from handlers the client should receive an error event
|
||||
/// with the error that occurred on the server.
|
||||
#[tokio::test]
|
||||
async fn it_handles_errors() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
ctx.emitter
|
||||
.emit("create_error", EmptyPayload)
|
||||
.await
|
||||
.unwrap();
|
||||
// allow the event to be processed
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
let counter = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert_eq!(counter.get("error").await, 1);
|
||||
}
|
||||
|
||||
/// When waiting for the reply to an event and an error occurs, the error should
|
||||
/// bypass the handler and be passed as the Err variant on the await reply instead.
|
||||
#[tokio::test]
|
||||
async fn it_receives_error_responses() {
|
||||
let port = get_free_port();
|
||||
let ctx = get_client_with_server(port).await;
|
||||
let result = ctx
|
||||
.emitter
|
||||
.emit("create_error", EmptyPayload)
|
||||
.await
|
||||
.unwrap()
|
||||
.await_reply(&ctx)
|
||||
.await;
|
||||
|
||||
let counter = get_counter_from_context(&ctx).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(counter.get("error").await, 0);
|
||||
}
|
||||
|
||||
async fn get_client_with_server(port: u8) -> Context {
|
||||
start_server_and_client(move || get_builder(port)).await
|
||||
}
|
||||
|
||||
fn get_builder(port: u8) -> IPCBuilder<TestProtocolListener> {
|
||||
IPCBuilder::new()
|
||||
.address(port)
|
||||
.on(
|
||||
"ping",
|
||||
callback!(
|
||||
ctx,
|
||||
event,
|
||||
async move { handle_ping_event(ctx, event).await }
|
||||
),
|
||||
)
|
||||
.timeout(Duration::from_millis(100))
|
||||
.on(
|
||||
"pong",
|
||||
callback!(
|
||||
ctx,
|
||||
event,
|
||||
async move { handle_pong_event(ctx, event).await }
|
||||
),
|
||||
)
|
||||
.on("ping", callback!(handle_ping_event))
|
||||
.on("pong", callback!(handle_pong_event))
|
||||
.on("create_error", callback!(handle_create_error_event))
|
||||
.on("error", callback!(handle_error_event))
|
||||
.namespace("test")
|
||||
.on("ping", callback!(handle_ping_event))
|
||||
.on("pong", callback!(handle_pong_event))
|
||||
.on("create_error", callback!(handle_create_error_event))
|
||||
.build()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_passes_events() {
|
||||
tokio::task::spawn(async { get_builder(0).build_server().await.unwrap() });
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
let ctx = get_builder(0).build_client().await.unwrap();
|
||||
ctx.emitter.emit("ping", ()).await.unwrap(); // todo fix reply deadlock
|
||||
async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
ctx.emitter
|
||||
.emit_response(event.id(), "pong", EmptyPayload)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_create_error_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
|
||||
Err(IPCError::from("Test Error"))
|
||||
}
|
||||
|
||||
async fn handle_error_event(ctx: &Context, event: Event) -> IPCResult<()> {
|
||||
increment_counter_for_event(ctx, &event).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct EmptyPayload;
|
||||
|
||||
impl EventSendPayload for EmptyPayload {
|
||||
fn to_payload_bytes(self) -> IPCResult<Vec<u8>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,61 @@
|
||||
use bromine::context::Context;
|
||||
use bromine::event::Event;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use typemap_rev::TypeMapKey;
|
||||
|
||||
pub async fn get_counter_from_context(ctx: &Context) -> CallCounter {
|
||||
let data = ctx.data.read().await;
|
||||
|
||||
data.get::<CallCounterKey>().unwrap().clone()
|
||||
}
|
||||
|
||||
pub async fn increment_counter_for_event(ctx: &Context, event: &Event) {
|
||||
let data = ctx.data.read().await;
|
||||
|
||||
let key_name = if let Some(namespace) = event.namespace() {
|
||||
format!("{}:{}", namespace, event.name())
|
||||
} else {
|
||||
event.name().to_string()
|
||||
};
|
||||
|
||||
data.get::<CallCounterKey>().unwrap().incr(&key_name).await;
|
||||
}
|
||||
|
||||
pub struct CallCounterKey;
|
||||
|
||||
impl TypeMapKey for CallCounterKey {
|
||||
type Value = CallCounter;
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct CallCounter {
|
||||
inner: Arc<RwLock<HashMap<String, AtomicUsize>>>,
|
||||
}
|
||||
|
||||
impl CallCounter {
|
||||
pub async fn incr(&self, name: &str) {
|
||||
{
|
||||
let calls = self.inner.read().await;
|
||||
if let Some(call) = calls.get(name) {
|
||||
call.fetch_add(1, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut calls = self.inner.write().await;
|
||||
calls.insert(name.to_string(), AtomicUsize::new(1));
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get(&self, name: &str) -> usize {
|
||||
let calls = self.inner.read().await;
|
||||
|
||||
calls
|
||||
.get(name)
|
||||
.map(|n| n.load(Ordering::SeqCst))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
use bromine::context::Context;
|
||||
use bromine::protocol::AsyncStreamProtocolListener;
|
||||
use bromine::IPCBuilder;
|
||||
use call_counter::*;
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::atomic::{AtomicU8, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::oneshot::channel;
|
||||
|
||||
pub mod call_counter;
|
||||
pub mod protocol;
|
||||
|
||||
pub fn get_free_port() -> u8 {
|
||||
lazy_static! {
|
||||
static ref PORT_COUNTER: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
|
||||
}
|
||||
PORT_COUNTER.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub async fn start_server_and_client<
|
||||
F: Fn() -> IPCBuilder<L> + Send + Sync + 'static,
|
||||
L: AsyncStreamProtocolListener,
|
||||
>(
|
||||
builder_fn: F,
|
||||
) -> Context {
|
||||
let counters = CallCounter::default();
|
||||
let (sender, receiver) = channel::<()>();
|
||||
let client_builder = builder_fn().insert::<CallCounterKey>(counters.clone());
|
||||
|
||||
tokio::task::spawn({
|
||||
async move {
|
||||
sender.send(()).unwrap();
|
||||
builder_fn()
|
||||
.insert::<CallCounterKey>(counters)
|
||||
.build_server()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
});
|
||||
receiver.await.unwrap();
|
||||
|
||||
let ctx = client_builder.build_client().await.unwrap();
|
||||
|
||||
ctx
|
||||
}
|
@ -0,0 +1,263 @@
|
||||
use async_trait::async_trait;
|
||||
use bromine::error::Result;
|
||||
use bromine::prelude::{AsyncProtocolStreamSplit, IPCError};
|
||||
use bromine::protocol::{AsyncProtocolStream, AsyncStreamProtocolListener};
|
||||
use lazy_static::lazy_static;
|
||||
use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::io::Error;
|
||||
use std::mem;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
lazy_static! {
|
||||
static ref LISTENERS_REF: Arc<Mutex<HashMap<u8, Sender<TestProtocolStream>>>> =
|
||||
Arc::new(Mutex::new(HashMap::new()));
|
||||
}
|
||||
|
||||
/// Adds a channel that receives streams to handle
|
||||
async fn add_port(number: u8, sender: tokio::sync::mpsc::Sender<TestProtocolStream>) {
|
||||
let mut listeners = LISTENERS_REF.lock().await;
|
||||
listeners.insert(number, sender);
|
||||
}
|
||||
|
||||
/// Returns a stream for the given port connecting with the server via channels
|
||||
async fn get_port(number: u8) -> Option<TestProtocolStream> {
|
||||
let mut listeners = LISTENERS_REF.lock().await;
|
||||
|
||||
if let Some(sender) = listeners.get_mut(&number) {
|
||||
let (s1, r1) = channel(2);
|
||||
let (s2, r2) = channel(2);
|
||||
let stream_1 = TestProtocolStream {
|
||||
sender: s1,
|
||||
receiver: Arc::new(Mutex::new(r2)),
|
||||
future: None,
|
||||
remaining_buf: Default::default(),
|
||||
};
|
||||
let stream_2 = TestProtocolStream {
|
||||
sender: s2,
|
||||
receiver: Arc::new(Mutex::new(r1)),
|
||||
future: None,
|
||||
remaining_buf: Default::default(),
|
||||
};
|
||||
sender.send(stream_2).await.ok();
|
||||
|
||||
Some(stream_1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestProtocolListener {
|
||||
receiver: Arc<Mutex<Receiver<TestProtocolStream>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncStreamProtocolListener for TestProtocolListener {
|
||||
type AddressType = u8;
|
||||
type RemoteAddressType = u8;
|
||||
type Stream = TestProtocolStream;
|
||||
|
||||
async fn protocol_bind(address: Self::AddressType) -> Result<Self> {
|
||||
let (sender, receiver) = channel(1);
|
||||
add_port(address, sender).await;
|
||||
|
||||
Ok(Self {
|
||||
receiver: Arc::new(Mutex::new(receiver)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn protocol_accept(&self) -> Result<(Self::Stream, Self::RemoteAddressType)> {
|
||||
self.receiver
|
||||
.lock()
|
||||
.await
|
||||
.recv()
|
||||
.await
|
||||
.map(|r| (r, 0u8))
|
||||
.ok_or_else(|| IPCError::from("Failed to accept"))
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TestProtocolStream {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
sender: self.sender.clone(),
|
||||
receiver: Arc::clone(&self.receiver),
|
||||
future: None,
|
||||
remaining_buf: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TestProtocolStream {
|
||||
sender: Sender<Vec<u8>>,
|
||||
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
future: Option<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
|
||||
remaining_buf: Arc<Mutex<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl TestProtocolStream {
|
||||
/// Read from the receiver and remaining buffer
|
||||
async fn read_from_receiver(
|
||||
buf: &mut ReadBuf<'static>,
|
||||
receiver: Arc<Mutex<Receiver<Vec<u8>>>>,
|
||||
remaining_buf: Arc<Mutex<Vec<u8>>>,
|
||||
) {
|
||||
{
|
||||
let mut remaining_buf = remaining_buf.lock().await;
|
||||
if !remaining_buf.is_empty() {
|
||||
if Self::read_from_remaining_buffer(buf, &mut remaining_buf).await {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut receiver = receiver.lock().await;
|
||||
|
||||
if let Some(mut bytes) = receiver.recv().await {
|
||||
let slice_len = min(bytes.len(), buf.capacity());
|
||||
|
||||
buf.put_slice(&bytes[0..slice_len]);
|
||||
bytes.reverse();
|
||||
bytes.truncate(bytes.len() - slice_len);
|
||||
bytes.reverse();
|
||||
let mut remaining_buf = remaining_buf.lock().await;
|
||||
remaining_buf.append(&mut bytes);
|
||||
}
|
||||
}
|
||||
|
||||
/// Read from the remaining buffer returning a boolean if the
|
||||
/// read buffer has been filled
|
||||
async fn read_from_remaining_buffer(
|
||||
buf: &mut ReadBuf<'static>,
|
||||
remaining_buf: &mut Vec<u8>,
|
||||
) -> bool {
|
||||
if remaining_buf.len() < buf.capacity() {
|
||||
buf.put_slice(&remaining_buf);
|
||||
remaining_buf.clear();
|
||||
|
||||
false
|
||||
} else if remaining_buf.len() == buf.capacity() {
|
||||
buf.put_slice(&remaining_buf);
|
||||
remaining_buf.clear();
|
||||
|
||||
true
|
||||
} else {
|
||||
let slice_len = buf.capacity();
|
||||
let remaining_len = remaining_buf.len();
|
||||
buf.put_slice(&remaining_buf[0..slice_len]);
|
||||
remaining_buf.reverse();
|
||||
remaining_buf.truncate(remaining_len - slice_len);
|
||||
remaining_buf.reverse();
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncProtocolStreamSplit for TestProtocolStream {
|
||||
type OwnedSplitReadHalf = Self;
|
||||
type OwnedSplitWriteHalf = Self;
|
||||
|
||||
fn protocol_into_split(self) -> (Self::OwnedSplitReadHalf, Self::OwnedSplitWriteHalf) {
|
||||
(self.clone(), self)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncProtocolStream for TestProtocolStream {
|
||||
type AddressType = u8;
|
||||
|
||||
async fn protocol_connect(address: Self::AddressType) -> Result<Self> {
|
||||
get_port(address)
|
||||
.await
|
||||
.ok_or_else(|| IPCError::from("Failed to connect"))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TestProtocolStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
unsafe {
|
||||
// we need a mutable reference to access the inner future
|
||||
let stream = self.get_unchecked_mut();
|
||||
|
||||
if stream.future.is_none() {
|
||||
// we need to change the lifetime to be able to use the read buffer in the read future
|
||||
let buf: &mut ReadBuf<'static> = mem::transmute(buf);
|
||||
let receiver = Arc::clone(&stream.receiver);
|
||||
let remaining_buf = Arc::clone(&stream.remaining_buf);
|
||||
|
||||
let future = TestProtocolStream::read_from_receiver(buf, receiver, remaining_buf);
|
||||
stream.future = Some(Box::pin(future));
|
||||
}
|
||||
if let Some(future) = &mut stream.future {
|
||||
match future.as_mut().poll(cx) {
|
||||
Poll::Ready(_) => {
|
||||
stream.future = None;
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestProtocolStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::prelude::rust_2015::Result<usize, Error>> {
|
||||
let write_len = buf.len();
|
||||
unsafe {
|
||||
// we need a mutable reference to access the inner future
|
||||
let stream = self.get_unchecked_mut();
|
||||
|
||||
if stream.future.is_none() {
|
||||
// we take ownership here so that we don't need to change lifetimes here
|
||||
let buf = buf.to_vec();
|
||||
let sender = stream.sender.clone();
|
||||
|
||||
let future = async move {
|
||||
sender.send(buf).await.unwrap();
|
||||
};
|
||||
stream.future = Some(Box::pin(future));
|
||||
}
|
||||
if let Some(future) = &mut stream.future {
|
||||
match future.as_mut().poll(cx) {
|
||||
Poll::Ready(_) => {
|
||||
stream.future = None;
|
||||
Poll::Ready(Ok(write_len))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::prelude::rust_2015::Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue