From 9cc7d1ffe8ad00e1bba41d0e86f3aa06ea61ca15 Mon Sep 17 00:00:00 2001 From: trivernis Date: Sun, 6 Feb 2022 15:03:29 +0100 Subject: [PATCH] Add asynchronous response streams Signed-off-by: trivernis --- Cargo.lock | 1 + Cargo.toml | 1 + src/ipc/context.rs | 20 +-- src/ipc/mod.rs | 4 +- src/ipc/stream_emitter/emit_metadata.rs | 9 ++ .../emit_metadata_with_response.rs | 26 +-- .../emit_metadata_with_response_stream.rs | 150 ++++++++++++++++++ src/ipc/stream_emitter/mod.rs | 17 +- src/lib.rs | 1 + tests/test_event_streams.rs | 88 ++++++++++ 10 files changed, 282 insertions(+), 35 deletions(-) create mode 100644 src/ipc/stream_emitter/emit_metadata_with_response_stream.rs create mode 100644 tests/test_event_streams.rs diff --git a/Cargo.lock b/Cargo.lock index 347cf69e..9951cb54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,7 @@ dependencies = [ "criterion", "crossbeam-utils", "futures", + "futures-core", "lazy_static", "num_enum", "postcard", diff --git a/Cargo.toml b/Cargo.toml index 5f58824f..68dcbaa9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ byteorder = "1.4.3" async-trait = "0.1.52" futures = "0.3.19" num_enum = "0.5.6" +futures-core = "0.3.19" rmp-serde = {version = "0.15.5", optional = true} bincode = {version = "1.3.3", optional = true} serde_json = {version = "1.0.73", optional = true} diff --git a/src/ipc/context.rs b/src/ipc/context.rs index b7f8b51a..1f0e081e 100644 --- a/src/ipc/context.rs +++ b/src/ipc/context.rs @@ -4,21 +4,21 @@ use std::ops::{Deref, DerefMut}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use tokio::sync::oneshot::{Receiver, Sender}; -use tokio::sync::{Mutex, oneshot, RwLock}; +use tokio::sync::mpsc::Receiver; +use tokio::sync::{mpsc, oneshot, Mutex, RwLock}; use tokio::time::Duration; use typemap_rev::TypeMap; use crate::error::{Error, Result}; use crate::event::{Event, EventType}; -use crate::ipc::stream_emitter::StreamEmitter; use crate::ipc::stream_emitter::emit_metadata::EmitMetadata; +use crate::ipc::stream_emitter::StreamEmitter; use crate::payload::IntoPayload; #[cfg(feature = "serialize")] use crate::payload::{DynamicSerializer, SerdePayload}; use crate::prelude::Response; -pub(crate) type ReplyListeners = Arc>>>; +pub(crate) type ReplyListeners = Arc>>>; /// An object provided to each callback function. /// Currently it only holds the event emitter to emit response events in event callbacks. @@ -40,7 +40,7 @@ pub struct Context { /// Field to store additional context data pub data: Arc>, - stop_sender: Arc>>>, + stop_sender: Arc>>>, pub(crate) reply_listeners: ReplyListeners, @@ -56,7 +56,7 @@ impl Context { pub(crate) fn new( emitter: StreamEmitter, data: Arc>, - stop_sender: Option>, + stop_sender: Option>, reply_listeners: ReplyListeners, reply_timeout: Duration, #[cfg(feature = "serialize")] default_serializer: DynamicSerializer, @@ -125,7 +125,7 @@ impl Context { #[inline] #[tracing::instrument(level = "debug", skip(self))] pub(crate) async fn register_reply_listener(&self, event_id: u64) -> Result> { - let (rx, tx) = oneshot::channel(); + let (rx, tx) = mpsc::channel(8); { let mut listeners = self.reply_listeners.lock().await; listeners.insert(event_id, rx); @@ -153,9 +153,9 @@ impl Context { /// Returns the channel for a reply to the given message id #[inline] - pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option> { - let mut listeners = self.reply_listeners.lock().await; - listeners.remove(&ref_id) + pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option> { + let listeners = self.reply_listeners.lock().await; + listeners.get(&ref_id).cloned() } #[inline] diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index db9f456b..39b4ca5d 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -34,8 +34,8 @@ async fn handle_connection( // get the listener for replies if let Some(sender) = ctx.get_reply_sender(ref_id).await { // try sending the event to the listener for replies - if let Err(event) = sender.send(event) { - handle_event(Context::clone(&ctx), Arc::clone(&handler), event); + if let Err(event) = sender.send(event).await { + handle_event(Context::clone(&ctx), Arc::clone(&handler), event.0); } continue; } diff --git a/src/ipc/stream_emitter/emit_metadata.rs b/src/ipc/stream_emitter/emit_metadata.rs index ebc4f559..57d77fab 100644 --- a/src/ipc/stream_emitter/emit_metadata.rs +++ b/src/ipc/stream_emitter/emit_metadata.rs @@ -2,6 +2,7 @@ use crate::context::Context; use crate::error::Error; use crate::event::EventType; use crate::ipc::stream_emitter::emit_metadata_with_response::EmitMetadataWithResponse; +use crate::ipc::stream_emitter::emit_metadata_with_response_stream::EmitMetadataWithResponseStream; use crate::ipc::stream_emitter::event_metadata::EventMetadata; use crate::ipc::stream_emitter::SendStream; use crate::payload::IntoPayload; @@ -57,6 +58,14 @@ impl EmitMetadata

{ emit_metadata: Some(self), } } + + pub fn stream_replies(self) -> EmitMetadataWithResponseStream

{ + EmitMetadataWithResponseStream { + timeout: None, + fut: None, + emit_metadata: Some(self), + } + } } impl Unpin for EmitMetadata

{} diff --git a/src/ipc/stream_emitter/emit_metadata_with_response.rs b/src/ipc/stream_emitter/emit_metadata_with_response.rs index cdbf4d54..5f6c6c1a 100644 --- a/src/ipc/stream_emitter/emit_metadata_with_response.rs +++ b/src/ipc/stream_emitter/emit_metadata_with_response.rs @@ -1,11 +1,10 @@ +use crate::context::Context; use crate::error::Error; use crate::error_event::ErrorEventData; use crate::event::{Event, EventType}; use crate::ipc::stream_emitter::emit_metadata::EmitMetadata; use crate::payload::IntoPayload; use crate::{error, poll_unwrap}; -use futures::future; -use futures::future::Either; use std::future::Future; use std::pin::Pin; use std::task::Poll; @@ -54,20 +53,20 @@ impl Future for EmitMetadataWithResponse }; self.fut = Some(Box::pin(async move { - let tx = ctx.register_reply_listener(event_id).await?; + let mut tx = ctx.register_reply_listener(event_id).await?; emit_metadata.await?; - let result = - future::select(Box::pin(tx), Box::pin(tokio::time::sleep(timeout))).await; - - let reply = match result { - Either::Left((tx_result, _)) => Ok(tx_result?), - Either::Right(_) => { - let mut listeners = ctx.reply_listeners.lock().await; - listeners.remove(&event_id); + let reply = tokio::select! { + tx_result = tx.recv() => { + Ok(tx_result.ok_or_else(|| Error::SendError)?) + } + _ = tokio::time::sleep(timeout) => { Err(Error::Timeout) } }?; + + remove_reply_listener(&ctx, event_id).await; + if reply.event_type() == EventType::Error { Err(reply.payload::()?.into()) } else { @@ -78,3 +77,8 @@ impl Future for EmitMetadataWithResponse self.fut.as_mut().unwrap().as_mut().poll(cx) } } + +pub(crate) async fn remove_reply_listener(ctx: &Context, event_id: u64) { + let mut listeners = ctx.reply_listeners.lock().await; + listeners.remove(&event_id); +} diff --git a/src/ipc/stream_emitter/emit_metadata_with_response_stream.rs b/src/ipc/stream_emitter/emit_metadata_with_response_stream.rs new file mode 100644 index 00000000..8ef0de3d --- /dev/null +++ b/src/ipc/stream_emitter/emit_metadata_with_response_stream.rs @@ -0,0 +1,150 @@ +use crate::context::Context; +use crate::error::{Error, Result}; +use crate::event::{Event, EventType}; +use crate::ipc::stream_emitter::emit_metadata::EmitMetadata; +use crate::ipc::stream_emitter::emit_metadata_with_response::remove_reply_listener; +use crate::payload::IntoPayload; +use crate::poll_unwrap; +use futures_core::Stream; +use std::future::Future; +use std::pin::Pin; +use std::task::Poll; +use std::time::Duration; +use tokio::sync::mpsc::Receiver; + +/// A metadata object returned after waiting for a reply to an event +/// This object needs to be awaited for to get the actual reply +pub struct EmitMetadataWithResponseStream { + pub(crate) timeout: Option, + pub(crate) fut: Option> + Send + Sync>>>, + pub(crate) emit_metadata: Option>, +} + +pub struct ResponseStream { + event_id: u64, + ctx: Option, + receiver: Option>, + timeout: Duration, + fut: Option, Context, Receiver)>>>>>, +} + +impl ResponseStream { + pub(crate) fn new( + event_id: u64, + timeout: Duration, + ctx: Context, + receiver: Receiver, + ) -> Self { + Self { + event_id, + ctx: Some(ctx), + receiver: Some(receiver), + timeout, + fut: None, + } + } +} + +impl Unpin for EmitMetadataWithResponseStream

{} + +impl EmitMetadataWithResponseStream

{ + /// Sets a timeout for awaiting replies to this emitted event + #[inline] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + + self + } +} + +impl Future for EmitMetadataWithResponseStream

{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + if self.fut.is_none() { + let mut emit_metadata = poll_unwrap!(self.emit_metadata.take()); + let ctx = poll_unwrap!(emit_metadata + .event_metadata + .as_ref() + .and_then(|m| m.ctx.clone())); + let timeout = self + .timeout + .clone() + .unwrap_or(ctx.default_reply_timeout.clone()); + + let event_id = match poll_unwrap!(emit_metadata.event_metadata.as_mut()).get_event() { + Ok(e) => e.id(), + Err(e) => { + return Poll::Ready(Err(e)); + } + }; + + self.fut = Some(Box::pin(async move { + let tx = ctx.register_reply_listener(event_id).await?; + emit_metadata.await?; + + Ok(ResponseStream::new(event_id, timeout, ctx, tx)) + })) + } + self.fut.as_mut().unwrap().as_mut().poll(cx) + } +} + +impl Unpin for ResponseStream {} + +impl Stream for ResponseStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if self.fut.is_none() { + if self.ctx.is_none() || self.receiver.is_none() { + return Poll::Ready(None); + } + let ctx = self.ctx.take().unwrap(); + let mut receiver = self.receiver.take().unwrap(); + let timeout = self.timeout; + let event_id = self.event_id; + + self.fut = Some(Box::pin(async move { + let event: Option = tokio::select! { + tx_result = receiver.recv() => { + Ok(tx_result) + } + _ = tokio::time::sleep(timeout) => { + Err(Error::Timeout) + } + }?; + + if event.is_none() || event.as_ref().unwrap().event_type() == EventType::End { + remove_reply_listener(&ctx, event_id).await; + } + + Ok((event, ctx, receiver)) + })); + } + + match self.fut.as_mut().unwrap().as_mut().poll(cx) { + Poll::Ready(r) => match r { + Ok((event, ctx, tx)) => { + self.fut = None; + + if let Some(event) = event { + if event.event_type() != EventType::End { + self.ctx = Some(ctx); + self.receiver = Some(tx); + } + + Poll::Ready(Some(Ok(event))) + } else { + Poll::Ready(None) + } + } + Err(e) => Poll::Ready(Some(Err(e))), + }, + Poll::Pending => Poll::Pending, + } + } +} diff --git a/src/ipc/stream_emitter/mod.rs b/src/ipc/stream_emitter/mod.rs index ded33cd5..bc0f1297 100644 --- a/src/ipc/stream_emitter/mod.rs +++ b/src/ipc/stream_emitter/mod.rs @@ -1,29 +1,22 @@ pub mod emit_metadata; pub mod emit_metadata_with_response; +pub mod emit_metadata_with_response_stream; mod event_metadata; - - - - use std::sync::Arc; - - -use emit_metadata::EmitMetadata; - - - -use tokio::io::{AsyncWrite}; +use tokio::io::AsyncWrite; use tokio::sync::Mutex; use tracing; - use crate::event::EventType; use crate::ipc::context::Context; use crate::payload::IntoPayload; use crate::protocol::AsyncProtocolStream; +pub use emit_metadata_with_response_stream::ResponseStream; +use crate::prelude::emit_metadata::EmitMetadata; + #[macro_export] macro_rules! poll_unwrap { ($val:expr) => { diff --git a/src/lib.rs b/src/lib.rs index 0b898607..6ba5e22b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,6 +137,7 @@ pub mod prelude { pub use crate::event_handler::{EventHandler, Response}; pub use crate::ipc::context::Context; pub use crate::ipc::context::{PoolGuard, PooledContext}; + pub use crate::ipc::stream_emitter::*; pub use crate::ipc::*; pub use crate::macros::*; pub use crate::namespace::Namespace; diff --git a/tests/test_event_streams.rs b/tests/test_event_streams.rs new file mode 100644 index 00000000..5a8f2edb --- /dev/null +++ b/tests/test_event_streams.rs @@ -0,0 +1,88 @@ +use crate::utils::call_counter::{get_counter_from_context, increment_counter_for_event}; +use crate::utils::protocol::TestProtocolListener; +use crate::utils::{get_free_port, start_server_and_client}; +use bromine::prelude::*; +use byteorder::ReadBytesExt; +use futures::StreamExt; +use std::io::Read; +use std::time::Duration; + +mod utils; + +/// When awaiting the reply to an event the handler for the event doesn't get called. +/// Therefore we expect it to have a call count of 0. +#[tokio::test] +async fn it_receives_responses() { + let port = get_free_port(); + let ctx = get_client_with_server(port).await; + let mut reply_stream = ctx + .emit("stream", EmptyPayload) + .stream_replies() + .await + .unwrap(); + + let mut reply_stream_2 = ctx + .emit("stream", EmptyPayload) + .stream_replies() + .await + .unwrap(); + + for i in 0u8..=100 { + if let Some(Ok(event)) = reply_stream.next().await { + assert_eq!(event.payload::().unwrap().0, i) + } else { + panic!("stream 1 has no value {}", i); + } + if let Some(Ok(event)) = reply_stream_2.next().await { + assert_eq!(event.payload::().unwrap().0, i) + } else { + panic!("stream 2 has no value {}", i); + } + } + let counter = get_counter_from_context(&ctx).await; + assert_eq!(counter.get("stream").await, 2); +} + +async fn get_client_with_server(port: u8) -> Context { + start_server_and_client(move || get_builder(port)).await +} + +fn get_builder(port: u8) -> IPCBuilder { + IPCBuilder::new() + .address(port) + .timeout(Duration::from_millis(100)) + .on("stream", callback!(handle_stream_event)) +} + +async fn handle_stream_event(ctx: &Context, event: Event) -> IPCResult { + increment_counter_for_event(ctx, &event).await; + for i in 0u8..=99 { + ctx.emit("number", NumberPayload(i)).await?; + } + + ctx.response(NumberPayload(100)) +} + +pub struct EmptyPayload; + +impl IntoPayload for EmptyPayload { + fn into_payload(self, _: &Context) -> IPCResult> { + Ok(vec![]) + } +} + +pub struct NumberPayload(u8); + +impl IntoPayload for NumberPayload { + fn into_payload(self, _: &Context) -> IPCResult> { + Ok(vec![self.0]) + } +} + +impl FromPayload for NumberPayload { + fn from_payload(mut reader: R) -> IPCResult { + let num = reader.read_u8()?; + + Ok(NumberPayload(num)) + } +}