diff --git a/src/error.rs b/src/error.rs index 403a19a7..ab3cb2ce 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,5 @@ use thiserror::Error; +use tokio::sync::oneshot; pub type Result = std::result::Result; @@ -18,6 +19,9 @@ pub enum Error { #[error("{0}")] Message(String), + + #[error("Channel Error: {0}")] + ReceiveError(#[from] oneshot::error::RecvError), } impl From for Error { diff --git a/src/events/event.rs b/src/events/event.rs index 26fab4bb..1dd00791 100644 --- a/src/events/event.rs +++ b/src/events/event.rs @@ -65,4 +65,15 @@ impl Event { Ok(event_bytes) } + + /// The identifier of the message + pub fn id(&self) -> u64 { + self.id + } + + /// The ID of the message referenced by this message. + /// It represents the message that is replied to and can be None. + pub fn reference_id(&self) -> Option { + self.ref_id.clone() + } } diff --git a/src/ipc/builder.rs b/src/ipc/builder.rs index 7039f8d2..7171f40a 100644 --- a/src/ipc/builder.rs +++ b/src/ipc/builder.rs @@ -5,7 +5,6 @@ use crate::events::event_handler::EventHandler; use crate::ipc::client::IPCClient; use crate::ipc::context::Context; use crate::ipc::server::IPCServer; -use crate::ipc::stream_emitter::StreamEmitter; use std::future::Future; use std::pin::Pin; @@ -85,15 +84,15 @@ impl IPCBuilder { } /// Builds an ipc client - pub async fn build_client(self) -> Result { + pub async fn build_client(self) -> Result { self.validate()?; let client = IPCClient { handler: self.handler, }; - let emitter = client.connect(&self.address.unwrap()).await?; + let ctx = client.connect(&self.address.unwrap()).await?; - Ok(emitter) + Ok(ctx) } /// Validates that all required fields have been provided diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 3a8af336..5e55892c 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -16,17 +16,20 @@ pub struct IPCClient { impl IPCClient { /// Connects to a given address and returns an emitter for events to that address. /// Invoked by [IPCBuilder::build_client](crate::builder::IPCBuilder::build_client) - pub async fn connect(self, address: &str) -> Result { + pub async fn connect(self, address: &str) -> Result { let stream = TcpStream::connect(address).await?; let (read_half, write_half) = stream.into_split(); let emitter = StreamEmitter::new(write_half); let ctx = Context::new(StreamEmitter::clone(&emitter)); let handler = Arc::new(self.handler); - tokio::spawn(async move { - handle_connection(handler, read_half, ctx).await; + tokio::spawn({ + let ctx = Context::clone(&ctx); + async move { + handle_connection(handler, read_half, ctx).await; + } }); - Ok(emitter) + Ok(ctx) } } diff --git a/src/ipc/context.rs b/src/ipc/context.rs index 53ec5ab3..77e29dcb 100644 --- a/src/ipc/context.rs +++ b/src/ipc/context.rs @@ -1,4 +1,9 @@ +use crate::error::Result; use crate::ipc::stream_emitter::StreamEmitter; +use crate::Event; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{oneshot, Mutex}; /// An object provided to each callback function. /// Currently it only holds the event emitter to emit response events in event callbacks. @@ -18,10 +23,33 @@ use crate::ipc::stream_emitter::StreamEmitter; pub struct Context { /// The event emitter pub emitter: StreamEmitter, + + reply_listeners: Arc>>>, } impl Context { pub(crate) fn new(emitter: StreamEmitter) -> Self { - Self { emitter } + Self { + emitter, + reply_listeners: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Waits for a reply to the given message ID + pub async fn await_reply(&self, message_id: u64) -> Result { + let (rx, tx) = oneshot::channel(); + { + let mut listeners = self.reply_listeners.lock().await; + listeners.insert(message_id, rx); + } + let event = tx.await?; + + Ok(event) + } + + /// Returns the channel for a reply to the given message id + pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option> { + let mut listeners = self.reply_listeners.lock().await; + listeners.remove(&ref_id) } } diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index c3761594..5507c23f 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -14,28 +14,41 @@ pub mod stream_emitter; /// Handles listening to a connection and triggering the corresponding event functions async fn handle_connection(handler: Arc, mut read_half: OwnedReadHalf, ctx: Context) { while let Ok(event) = Event::from_async_read(&mut read_half).await { - let ctx = Context::clone(&ctx); - let handler = Arc::clone(&handler); - - tokio::spawn(async move { - if let Err(e) = handler.handle_event(&ctx, event).await { - // emit an error event - if let Err(e) = ctx - .emitter - .emit( - ERROR_EVENT_NAME, - ErrorEventData { - message: format!("{:?}", e), - code: 500, - }, - ) - .await - { - log::error!("Error occurred when sending error response: {:?}", e); + // check if the event is a reply + if let Some(ref_id) = event.reference_id() { + // get the listener for replies + if let Some(sender) = ctx.get_reply_sender(ref_id).await { + // try sending the event to the listener for replies + if let Err(event) = sender.send(event) { + handle_event(Context::clone(&ctx), Arc::clone(&handler), event); } - log::error!("Failed to handle event: {:?}", e); + continue; } - }); + } + handle_event(Context::clone(&ctx), Arc::clone(&handler), event); } log::debug!("Connection closed."); } + +/// Handles a single event in a different tokio context +fn handle_event(ctx: Context, handler: Arc, event: Event) { + tokio::spawn(async move { + if let Err(e) = handler.handle_event(&ctx, event).await { + // emit an error event + if let Err(e) = ctx + .emitter + .emit( + ERROR_EVENT_NAME, + ErrorEventData { + message: format!("{:?}", e), + code: 500, + }, + ) + .await + { + log::error!("Error occurred when sending error response: {:?}", e); + } + log::error!("Failed to handle event: {:?}", e); + } + }); +} diff --git a/src/ipc/stream_emitter.rs b/src/ipc/stream_emitter.rs index 8c2d4e5c..2e637d29 100644 --- a/src/ipc/stream_emitter.rs +++ b/src/ipc/stream_emitter.rs @@ -1,3 +1,4 @@ +use crate::context::Context; use crate::error::Result; use crate::events::event::Event; use serde::Serialize; @@ -26,7 +27,7 @@ impl StreamEmitter { event: &str, data: T, res_id: Option, - ) -> Result<()> { + ) -> Result { let data_bytes = rmp_serde::to_vec(&data)?; let event = Event::new(event.to_string(), data_bytes, res_id); let event_bytes = event.to_bytes()?; @@ -35,14 +36,14 @@ impl StreamEmitter { (*stream).write_all(&event_bytes[..]).await?; } - Ok(()) + Ok(EmitMetadata::new(event.id())) } /// Emits an event - pub async fn emit(&self, event: &str, data: T) -> Result<()> { - self._emit(event, data, None).await?; + pub async fn emit(&self, event: &str, data: T) -> Result { + let metadata = self._emit(event, data, None).await?; - Ok(()) + Ok(metadata) } /// Emits a response to an event @@ -51,9 +52,32 @@ impl StreamEmitter { event_id: u64, event: &str, data: T, - ) -> Result<()> { - self._emit(event, data, Some(event_id)).await?; + ) -> Result { + let metadata = self._emit(event, data, Some(event_id)).await?; - Ok(()) + Ok(metadata) + } +} + +/// A metadata object returned after emitting an event. +/// This object can be used to wait for a response to an event. +pub struct EmitMetadata { + message_id: u64, +} + +impl EmitMetadata { + pub(crate) fn new(message_id: u64) -> Self { + Self { message_id } + } + + /// The ID of the emitted message + pub fn message_id(&self) -> u64 { + self.message_id + } + + /// Waits for a reply to the given message. + pub async fn await_reply(&self, ctx: &Context) -> Result { + let reply = ctx.await_reply(self.message_id).await?; + Ok(reply) } } diff --git a/src/lib.rs b/src/lib.rs index 6d394495..caae50a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,17 +6,19 @@ //! // create the client //! # async fn a() { //! -//! let emitter = IPCBuilder::new() +//! let ctx = IPCBuilder::new() //! .address("127.0.0.1:2020") //! // register callback -//! .on("ping", |_ctx, _event| Box::pin(async move { +//! .on("ping", |ctx, event| Box::pin(async move { //! println!("Received ping event."); +//! ctx.emitter.emit_response(event.id(), "pong", ()).await?; //! Ok(()) //! })) //! .build_client().await.unwrap(); //! //! // emit an initial event -//! emitter.emit("ping", ()).await.unwrap(); +//! let response = ctx.emitter.emit("ping", ()).await.unwrap().await_reply(&ctx).await.unwrap(); +//! assert_eq!(response.name(), "pong"); //! # } //! ``` //! @@ -28,8 +30,9 @@ //! IPCBuilder::new() //! .address("127.0.0.1:2020") //! // register callback -//! .on("ping", |_ctx, _event| Box::pin(async move { +//! .on("ping", |ctx, event| Box::pin(async move { //! println!("Received ping event."); +//! ctx.emitter.emit_response(event.id(), "pong", ()).await?; //! Ok(()) //! })) //! .build_server().await.unwrap(); diff --git a/src/tests/ipc_tests.rs b/src/tests/ipc_tests.rs index 823dfda5..6afe5872 100644 --- a/src/tests/ipc_tests.rs +++ b/src/tests/ipc_tests.rs @@ -2,26 +2,22 @@ use self::super::utils::PingEventData; use crate::error::Error; use crate::events::error_event::ErrorEventData; use crate::IPCBuilder; -use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; #[tokio::test] async fn it_receives_events() { - let ctr = Arc::new(AtomicU8::new(0)); let builder = IPCBuilder::new() .on("ping", { - let ctr = Arc::clone(&ctr); move |ctx, e| { - let ctr = Arc::clone(&ctr); Box::pin(async move { - ctr.fetch_add(1, Ordering::Relaxed); let mut ping_data = e.data::()?; ping_data.time = SystemTime::now(); ping_data.ttl -= 1; if ping_data.ttl > 0 { - ctx.emitter.emit("ping", ping_data).await?; + ctx.emitter.emit_response(e.id(), "pong", ping_data).await?; } Ok(()) @@ -41,8 +37,9 @@ async fn it_receives_events() { while !server_running.load(Ordering::Relaxed) { tokio::time::sleep(Duration::from_millis(10)).await; } - let client = builder.build_client().await.unwrap(); - client + let ctx = builder.build_client().await.unwrap(); + let reply = ctx + .emitter .emit( "ping", PingEventData { @@ -51,9 +48,11 @@ async fn it_receives_events() { }, ) .await + .unwrap() + .await_reply(&ctx) + .await .unwrap(); - tokio::time::sleep(Duration::from_secs(1)).await; - assert_eq!(ctr.load(Ordering::SeqCst), 16); + assert_eq!(reply.name(), "pong"); } #[tokio::test] @@ -91,8 +90,8 @@ async fn it_handles_errors() { while !server_running.load(Ordering::Relaxed) { tokio::time::sleep(Duration::from_millis(10)).await; } - let client = builder.build_client().await.unwrap(); - client.emit("ping", ()).await.unwrap(); + let ctx = builder.build_client().await.unwrap(); + ctx.emitter.emit("ping", ()).await.unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; assert!(error_occurred.load(Ordering::SeqCst));