diff --git a/src/events/error_event.rs b/src/events/error_event.rs index 2b6e160e..bff310a1 100644 --- a/src/events/error_event.rs +++ b/src/events/error_event.rs @@ -8,6 +8,7 @@ use std::fmt::{Display, Formatter}; use std::io::Read; pub static ERROR_EVENT_NAME: &str = "error"; +pub static END_EVENT_NAME: &str = "end"; /// Data returned on error event. /// The error event has a default handler that just logs that diff --git a/src/events/event_handler.rs b/src/events/event_handler.rs index 2ef18062..8b87f603 100644 --- a/src/events/event_handler.rs +++ b/src/events/event_handler.rs @@ -1,14 +1,38 @@ use crate::error::Result; use crate::events::event::Event; use crate::ipc::context::Context; +use crate::payload::{BytePayload, IntoPayload}; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; +pub struct Response(Vec); + +impl Response { + /// Creates a new response with a given payload + pub fn payload(ctx: &Context, payload: P) -> Result { + let bytes = payload.into_payload(ctx)?; + + Ok(Self(bytes)) + } + + /// Creates an empty response + pub fn empty() -> Self { + Self(vec![]) + } + + pub(crate) fn into_byte_payload(self) -> BytePayload { + BytePayload::new(self.0) + } +} + type EventCallback = Arc< - dyn for<'a> Fn(&'a Context, Event) -> Pin> + Send + 'a)>> + dyn for<'a> Fn( + &'a Context, + Event, + ) -> Pin> + Send + 'a)>> + Send + Sync, >; @@ -46,7 +70,7 @@ impl EventHandler { F: for<'a> Fn( &'a Context, Event, - ) -> Pin> + Send + 'a)>> + ) -> Pin> + Send + 'a)>> + Send + Sync, { @@ -56,11 +80,11 @@ impl EventHandler { /// Handles a received event #[inline] #[tracing::instrument(level = "debug", skip(self, ctx, event))] - pub async fn handle_event(&self, ctx: &Context, event: Event) -> Result<()> { + pub async fn handle_event(&self, ctx: &Context, event: Event) -> Result { if let Some(cb) = self.callbacks.get(event.name()) { - cb.as_ref()(ctx, event).await?; + cb.as_ref()(ctx, event).await + } else { + Ok(Response::empty()) } - - Ok(()) } } diff --git a/src/ipc/builder.rs b/src/ipc/builder.rs index 58810081..0019dc21 100644 --- a/src/ipc/builder.rs +++ b/src/ipc/builder.rs @@ -1,5 +1,7 @@ use crate::error::{Error, Result}; -use crate::events::error_event::{ErrorEventData, ERROR_EVENT_NAME}; +use crate::error_event::ErrorEventData; +use crate::event_handler::Response; +use crate::events::error_event::ERROR_EVENT_NAME; use crate::events::event::Event; use crate::events::event_handler::EventHandler; use crate::ipc::client::IPCClient; @@ -24,6 +26,7 @@ use typemap_rev::{TypeMap, TypeMapKey}; /// use typemap_rev::TypeMapKey; /// use bromine::IPCBuilder; /// use tokio::net::TcpListener; +/// use bromine::prelude::Response; /// /// struct CustomKey; /// @@ -37,13 +40,13 @@ use typemap_rev::{TypeMap, TypeMapKey}; /// // register callback /// .on("ping", |_ctx, _event| Box::pin(async move { /// println!("Received ping event."); -/// Ok(()) +/// Ok(Response::empty()) /// })) /// // register a namespace /// .namespace("namespace") /// .on("namespace-event", |_ctx, _event| Box::pin(async move { /// println!("Namespace event."); -/// Ok(()) +/// Ok(Response::empty()) /// })) /// .build() /// // add context shared data @@ -75,7 +78,7 @@ where tracing::warn!(error_data.code); tracing::warn!("error_data.message = '{}'", error_data.message); - Ok(()) + Ok(Response::empty()) }) }); Self { @@ -102,7 +105,7 @@ where F: for<'a> Fn( &'a Context, Event, - ) -> Pin> + Send + 'a)>> + ) -> Pin> + Send + 'a)>> + Send + Sync, { diff --git a/src/ipc/context.rs b/src/ipc/context.rs index c939a376..658e0c15 100644 --- a/src/ipc/context.rs +++ b/src/ipc/context.rs @@ -15,6 +15,7 @@ use crate::ipc::stream_emitter::{EmitMetadata, StreamEmitter}; use crate::payload::IntoPayload; #[cfg(feature = "serialize")] use crate::payload::{DynamicSerializer, SerdePayload}; +use crate::prelude::Response; pub(crate) type ReplyListeners = Arc>>>; @@ -114,6 +115,11 @@ impl Context { } } + /// Ends the event flow by creating a final response + pub fn response(&self, payload: P) -> Result { + Response::payload(self, payload) + } + /// Registers a reply listener for a given event #[inline] #[tracing::instrument(level = "debug", skip(self))] diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 0319f6aa..db9f456b 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::error_event::{ErrorEventData, ERROR_EVENT_NAME}; +use crate::error_event::{ErrorEventData, END_EVENT_NAME, ERROR_EVENT_NAME}; use crate::event::EventType; use crate::events::event_handler::EventHandler; use crate::namespaces::namespace::Namespace; @@ -41,6 +41,12 @@ async fn handle_connection( } tracing::trace!("No response listener found for event. Passing to regular listener."); } + + if event.event_type() == EventType::End { + tracing::debug!("Received dangling end event with no listener"); + continue; + } + if let Some(namespace) = event.namespace().clone().and_then(|n| namespaces.get(&n)) { tracing::trace!("Passing event to namespace listener"); let handler = Arc::clone(&namespace.handler); @@ -58,24 +64,36 @@ fn handle_event(mut ctx: Context, handler: Arc, event: Event) { ctx.set_ref_id(Some(event.id())); tokio::spawn(async move { - if let Err(e) = handler.handle_event(&ctx, event).await { - // emit an error event - if let Err(e) = ctx - .emit_raw( - ERROR_EVENT_NAME, - None, - EventType::Error, - ErrorEventData { - message: format!("{:?}", e), - code: 500, - }, - ) - .await - { - tracing::error!("Error occurred when sending error response: {:?}", e); + match handler.handle_event(&ctx, event).await { + Ok(r) => { + // emit the response under a unique name to prevent it being interpreted as a new + // event initiator + if let Err(e) = ctx + .emit_raw(END_EVENT_NAME, None, EventType::End, r.into_byte_payload()) + .await + { + tracing::error!("Error occurred when sending error response: {:?}", e); + } } + Err(e) => { + // emit an error event + if let Err(e) = ctx + .emit_raw( + ERROR_EVENT_NAME, + None, + EventType::Error, + ErrorEventData { + message: format!("{:?}", e), + code: 500, + }, + ) + .await + { + tracing::error!("Error occurred when sending error response: {:?}", e); + } - tracing::error!("Failed to handle event: {:?}", e); + tracing::error!("Failed to handle event: {:?}", e); + } } }); } diff --git a/src/lib.rs b/src/lib.rs index a251ef06..6c824af0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ pub mod prelude { pub use crate::error::Error as IPCError; pub use crate::error::Result as IPCResult; pub use crate::event::Event; - pub use crate::event_handler::EventHandler; + pub use crate::event_handler::{EventHandler, Response}; pub use crate::ipc::context::Context; pub use crate::ipc::context::{PoolGuard, PooledContext}; pub use crate::ipc::*; diff --git a/src/macros.rs b/src/macros.rs index 0107c03f..ff15f867 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -30,4 +30,4 @@ macro_rules! events{ $handler.on($name, callback!($cb)); )* } -} \ No newline at end of file +} diff --git a/src/namespaces/builder.rs b/src/namespaces/builder.rs index a048a053..aa5b0ebe 100644 --- a/src/namespaces/builder.rs +++ b/src/namespaces/builder.rs @@ -1,5 +1,6 @@ use crate::error::Result; use crate::event::Event; +use crate::event_handler::Response; use crate::events::event_handler::EventHandler; use crate::ipc::context::Context; use crate::namespaces::namespace::Namespace; @@ -32,7 +33,7 @@ where F: for<'a> Fn( &'a Context, Event, - ) -> Pin> + Send + 'a)>> + ) -> Pin> + Send + 'a)>> + Send + Sync, { diff --git a/tests/test_events_with_payload.rs b/tests/test_events_with_payload.rs index a00f8601..7c14ad7f 100644 --- a/tests/test_events_with_payload.rs +++ b/tests/test_events_with_payload.rs @@ -36,11 +36,7 @@ async fn it_receives_payloads() { number: 0, string: String::from("Hello World"), }; - let reply = ctx - .emit("ping", payload) - .await_reply() - .await - .unwrap(); + let reply = ctx.emit("ping", payload).await_reply().await.unwrap(); let reply_payload = reply.payload::().unwrap(); let counters = get_counter_from_context(&ctx).await; @@ -62,19 +58,19 @@ fn get_builder(port: u8) -> IPCBuilder { .timeout(Duration::from_millis(10)) } -async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> { +async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; let payload = event.payload::()?; ctx.emit("pong", payload).await?; - Ok(()) + Ok(Response::empty()) } -async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult<()> { +async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; let _payload = event.payload::()?; - Ok(()) + Ok(Response::empty()) } #[cfg(feature = "serialize")] diff --git a/tests/test_raw_events.rs b/tests/test_raw_events.rs index e885b20d..c690f869 100644 --- a/tests/test_raw_events.rs +++ b/tests/test_raw_events.rs @@ -45,11 +45,7 @@ async fn it_sends_namespaced_events() { async fn it_receives_responses() { let port = get_free_port(); let ctx = get_client_with_server(port).await; - let reply = ctx - .emit("ping", EmptyPayload) - .await_reply() - .await - .unwrap(); + let reply = ctx.emit("ping", EmptyPayload).await_reply().await.unwrap(); let counter = get_counter_from_context(&ctx).await; assert_eq!(reply.name(), "pong"); @@ -108,29 +104,29 @@ fn get_builder(port: u8) -> IPCBuilder { .build() } -async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult<()> { +async fn handle_ping_event(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; ctx.emit("pong", EmptyPayload).await?; - Ok(()) + Ok(Response::empty()) } -async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult<()> { +async fn handle_pong_event(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; - Ok(()) + Ok(Response::empty()) } -async fn handle_create_error_event(ctx: &Context, event: Event) -> IPCResult<()> { +async fn handle_create_error_event(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; Err(IPCError::from("Test Error")) } -async fn handle_error_event(ctx: &Context, event: Event) -> IPCResult<()> { +async fn handle_error_event(ctx: &Context, event: Event) -> IPCResult { increment_counter_for_event(ctx, &event).await; - Ok(()) + Ok(Response::empty()) } pub struct EmptyPayload;