diff --git a/Cargo.lock b/Cargo.lock index acdf68c9..a7557c4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -500,7 +500,7 @@ dependencies = [ [[package]] name = "rmp-ipc" -version = "0.7.2" +version = "0.8.1" dependencies = [ "criterion", "lazy_static", diff --git a/Cargo.toml b/Cargo.toml index 780c7c3d..a3a4a6fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rmp-ipc" -version = "0.7.2" +version = "0.8.1" authors = ["trivernis "] edition = "2018" readme = "README.md" diff --git a/src/error.rs b/src/error.rs index 933b39af..22d01973 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use crate::error_event::ErrorEventData; use thiserror::Error; use tokio::sync::oneshot; @@ -25,6 +26,9 @@ pub enum Error { #[error("Send Error")] SendError, + + #[error("Error response: {0}")] + ErrorEvent(#[from] ErrorEventData), } impl From for Error { diff --git a/src/events/error_event.rs b/src/events/error_event.rs index e11ae703..00a471c5 100644 --- a/src/events/error_event.rs +++ b/src/events/error_event.rs @@ -1,4 +1,6 @@ use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt::{Display, Formatter}; pub static ERROR_EVENT_NAME: &str = "error"; @@ -11,3 +13,11 @@ pub struct ErrorEventData { pub code: u16, pub message: String, } + +impl Error for ErrorEventData {} + +impl Display for ErrorEventData { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "IPC Code {}: '{}'", self.code, self.message) + } +} diff --git a/src/ipc/builder.rs b/src/ipc/builder.rs index cff2d4b3..fffb5e27 100644 --- a/src/ipc/builder.rs +++ b/src/ipc/builder.rs @@ -3,13 +3,15 @@ use crate::events::error_event::{ErrorEventData, ERROR_EVENT_NAME}; use crate::events::event::Event; use crate::events::event_handler::EventHandler; use crate::ipc::client::IPCClient; -use crate::ipc::context::Context; +use crate::ipc::context::{Context, PooledContext, ReplyListeners}; use crate::ipc::server::IPCServer; use crate::namespaces::builder::NamespaceBuilder; use crate::namespaces::namespace::Namespace; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; +use tokio::sync::RwLock; use typemap_rev::{TypeMap, TypeMapKey}; /// A builder for the IPC server or client. @@ -131,10 +133,13 @@ impl IPCBuilder { #[tracing::instrument(skip(self))] pub async fn build_client(self) -> Result { self.validate()?; + let data = Arc::new(RwLock::new(self.data)); + let reply_listeners = ReplyListeners::default(); let client = IPCClient { namespaces: self.namespaces, handler: self.handler, - data: self.data, + data, + reply_listeners, }; let ctx = client.connect(&self.address.unwrap()).await?; @@ -142,6 +147,36 @@ impl IPCBuilder { Ok(ctx) } + /// Builds a pooled IPC client + /// This causes the builder to actually create `pool_size` clients and + /// return a [crate::context::PooledContext] that allows one to [crate::context::PooledContext::acquire] a single context + /// to emit events. + #[tracing::instrument(skip(self))] + pub async fn build_pooled_client(self, pool_size: usize) -> Result { + if pool_size == 0 { + Error::BuildError("Pool size must be greater than 0".to_string()); + } + self.validate()?; + let data = Arc::new(RwLock::new(self.data)); + let mut contexts = Vec::new(); + let address = self.address.unwrap(); + let reply_listeners = ReplyListeners::default(); + + for _ in 0..pool_size { + let client = IPCClient { + namespaces: self.namespaces.clone(), + handler: self.handler.clone(), + data: Arc::clone(&data), + reply_listeners: Arc::clone(&reply_listeners), + }; + + let ctx = client.connect(&address).await?; + contexts.push(ctx); + } + + Ok(PooledContext::new(contexts)) + } + /// Validates that all required fields have been provided #[tracing::instrument(skip(self))] fn validate(&self) -> Result<()> { diff --git a/src/ipc/client.rs b/src/ipc/client.rs index ba929738..1546f1e2 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,7 +1,7 @@ use super::handle_connection; use crate::error::Result; use crate::events::event_handler::EventHandler; -use crate::ipc::context::Context; +use crate::ipc::context::{Context, ReplyListeners}; use crate::ipc::stream_emitter::StreamEmitter; use crate::namespaces::namespace::Namespace; use std::collections::HashMap; @@ -14,10 +14,12 @@ use typemap_rev::TypeMap; /// The IPC Client to connect to an IPC Server. /// Use the [IPCBuilder](crate::builder::IPCBuilder) to create the client. /// Usually one does not need to use the IPCClient object directly. +#[derive(Clone)] pub struct IPCClient { pub(crate) handler: EventHandler, pub(crate) namespaces: HashMap, - pub(crate) data: TypeMap, + pub(crate) data: Arc>, + pub(crate) reply_listeners: ReplyListeners, } impl IPCClient { @@ -31,8 +33,9 @@ impl IPCClient { let (tx, rx) = oneshot::channel(); let ctx = Context::new( StreamEmitter::clone(&emitter), - Arc::new(RwLock::new(self.data)), + self.data, Some(tx), + self.reply_listeners, ); let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); diff --git a/src/ipc/context.rs b/src/ipc/context.rs index 81914c83..beff8e25 100644 --- a/src/ipc/context.rs +++ b/src/ipc/context.rs @@ -3,11 +3,15 @@ use crate::event::Event; use crate::ipc::stream_emitter::StreamEmitter; use std::collections::HashMap; use std::mem; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::oneshot::Sender; use tokio::sync::{oneshot, Mutex, RwLock}; use typemap_rev::TypeMap; +pub(crate) type ReplyListeners = Arc>>>; + /// An object provided to each callback function. /// Currently it only holds the event emitter to emit response events in event callbacks. /// ```rust @@ -30,7 +34,7 @@ pub struct Context { stop_sender: Arc>>>, - reply_listeners: Arc>>>, + reply_listeners: ReplyListeners, } impl Context { @@ -38,10 +42,11 @@ impl Context { emitter: StreamEmitter, data: Arc>, stop_sender: Option>, + reply_listeners: ReplyListeners, ) -> Self { Self { emitter, - reply_listeners: Arc::new(Mutex::new(HashMap::new())), + reply_listeners, data, stop_sender: Arc::new(Mutex::new(stop_sender)), } @@ -77,3 +82,109 @@ impl Context { listeners.remove(&ref_id) } } + +#[derive(Clone)] +pub struct PooledContext { + contexts: Vec>, +} + +pub struct PoolGuard +where + T: Clone, +{ + inner: T, + count: Arc, +} + +impl Deref for PoolGuard +where + T: Clone, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for PoolGuard +where + T: Clone, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl Clone for PoolGuard +where + T: Clone, +{ + fn clone(&self) -> Self { + self.acquire(); + + Self { + inner: self.inner.clone(), + count: Arc::clone(&self.count), + } + } +} + +impl Drop for PoolGuard +where + T: Clone, +{ + fn drop(&mut self) { + self.release(); + } +} + +impl PoolGuard +where + T: Clone, +{ + pub(crate) fn new(inner: T) -> Self { + Self { + inner, + count: Arc::new(AtomicUsize::new(0)), + } + } + + /// Acquires the context by adding 1 to the count + #[tracing::instrument(level = "trace", skip_all)] + pub(crate) fn acquire(&self) { + let count = self.count.fetch_add(1, Ordering::Relaxed); + tracing::trace!(count); + } + + /// Releases the connection by subtracting from the stored count + #[tracing::instrument(level = "trace", skip_all)] + pub(crate) fn release(&self) { + let count = self.count.fetch_sub(1, Ordering::Relaxed); + tracing::trace!(count); + } + + pub(crate) fn count(&self) -> usize { + self.count.load(Ordering::Relaxed) + } +} + +impl PooledContext { + /// Creates a new pooled context from a list of contexts + pub(crate) fn new(contexts: Vec) -> Self { + Self { + contexts: contexts.into_iter().map(PoolGuard::new).collect(), + } + } + + /// Acquires a context from the pool + /// It always chooses the one that is used the least + #[tracing::instrument(level = "trace", skip_all)] + pub fn acquire(&self) -> PoolGuard { + self.contexts + .iter() + .min_by_key(|c| c.count()) + .unwrap() + .clone() + } +} diff --git a/src/ipc/server.rs b/src/ipc/server.rs index d1ed1113..ff0d53e4 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -1,7 +1,7 @@ use super::handle_connection; use crate::error::Result; use crate::events::event_handler::EventHandler; -use crate::ipc::context::Context; +use crate::ipc::context::{Context, ReplyListeners}; use crate::ipc::stream_emitter::StreamEmitter; use crate::namespaces::namespace::Namespace; use std::collections::HashMap; @@ -40,7 +40,8 @@ impl IPCServer { tokio::spawn(async { let (read_half, write_half) = stream.into_split(); let emitter = StreamEmitter::new(write_half); - let ctx = Context::new(StreamEmitter::clone(&emitter), data, None); + let reply_listeners = ReplyListeners::default(); + let ctx = Context::new(StreamEmitter::clone(&emitter), data, None, reply_listeners); handle_connection(namespaces, handler, read_half, ctx).await; }); diff --git a/src/ipc/stream_emitter.rs b/src/ipc/stream_emitter.rs index 13b549fd..996fdb88 100644 --- a/src/ipc/stream_emitter.rs +++ b/src/ipc/stream_emitter.rs @@ -1,4 +1,5 @@ use crate::error::Result; +use crate::error_event::{ErrorEventData, ERROR_EVENT_NAME}; use crate::events::event::Event; use crate::events::payload::EventSendPayload; use crate::ipc::context::Context; @@ -119,6 +120,10 @@ impl EmitMetadata { #[tracing::instrument(skip(self, ctx), fields(self.message_id))] pub async fn await_reply(&self, ctx: &Context) -> Result { let reply = ctx.await_reply(self.message_id).await?; - Ok(reply) + if reply.name() == ERROR_EVENT_NAME { + Err(reply.data::()?.into()) + } else { + Ok(reply) + } } } diff --git a/src/lib.rs b/src/lib.rs index f053f100..5a647a1f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,6 +111,7 @@ pub use events::event; pub use events::event_handler; pub use events::payload; pub use ipc::builder::IPCBuilder; +pub use ipc::context; pub use macros::*; pub use namespaces::builder::NamespaceBuilder; pub use namespaces::namespace; @@ -122,6 +123,7 @@ pub mod prelude { pub use crate::event::Event; pub use crate::event_handler::EventHandler; pub use crate::ipc::context::Context; + pub use crate::ipc::context::{PoolGuard, PooledContext}; pub use crate::ipc::*; pub use crate::macros::*; pub use crate::namespace::Namespace; diff --git a/src/tests/ipc_tests.rs b/src/tests/ipc_tests.rs index c738a339..142566f4 100644 --- a/src/tests/ipc_tests.rs +++ b/src/tests/ipc_tests.rs @@ -39,8 +39,9 @@ async fn it_receives_events() { while !server_running.load(Ordering::Relaxed) { tokio::time::sleep(Duration::from_millis(10)).await; } - let ctx = builder.build_client().await.unwrap(); - let reply = ctx + let pool = builder.build_pooled_client(8).await.unwrap(); + let reply = pool + .acquire() .emitter .emit( "ping", @@ -51,7 +52,7 @@ async fn it_receives_events() { ) .await .unwrap() - .await_reply(&ctx) + .await_reply(&pool.acquire()) .await .unwrap(); assert_eq!(reply.name(), "pong"); @@ -205,7 +206,6 @@ async fn test_error_responses() { .await .unwrap() .await_reply(&ctx) - .await - .unwrap(); - assert_eq!(reply.name(), "error"); + .await; + assert!(reply.is_err()); }