diff --git a/Cargo.lock b/Cargo.lock index 175f8f4f..6022a5c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,7 +148,7 @@ dependencies = [ [[package]] name = "rmp-ipc" -version = "0.3.0" +version = "0.4.0" dependencies = [ "lazy_static", "log", @@ -156,6 +156,7 @@ dependencies = [ "serde", "thiserror", "tokio", + "typemap_rev", ] [[package]] @@ -248,6 +249,12 @@ dependencies = [ "syn", ] +[[package]] +name = "typemap_rev" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5b74f0a24b5454580a79abb6994393b09adf0ab8070f15827cb666255de155" + [[package]] name = "unicode-xid" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index dd423320..4907a95d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rmp-ipc" -version = "0.3.0" +version = "0.4.0" authors = ["trivernis "] edition = "2018" readme = "README.md" @@ -15,6 +15,7 @@ thiserror = "1.0.30" rmp-serde = "0.15.4" log = "0.4.14" lazy_static = "1.4.0" +typemap_rev = "0.1.5" [dependencies.serde] version = "1.0.130" diff --git a/src/ipc/builder.rs b/src/ipc/builder.rs index a02bc36e..516349b9 100644 --- a/src/ipc/builder.rs +++ b/src/ipc/builder.rs @@ -10,11 +10,19 @@ use crate::namespaces::namespace::Namespace; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; +use typemap_rev::{TypeMap, TypeMapKey}; -#[derive(Clone)] /// A builder for the IPC server or client. /// ```no_run -///use rmp_ipc::IPCBuilder; +///use typemap_rev::TypeMapKey; +/// use rmp_ipc::IPCBuilder; +/// +/// struct CustomKey; +/// +/// impl TypeMapKey for CustomKey { +/// type Value = String; +/// } +/// ///# async fn a() { /// IPCBuilder::new() /// .address("127.0.0.1:2020") @@ -23,6 +31,15 @@ use std::pin::Pin; /// println!("Received ping event."); /// Ok(()) /// })) +/// // register a namespace +/// .namespace("namespace") +/// .on("namespace-event", |_ctx, _event| Box::pin(async move { +/// println!("Namespace event."); +/// Ok(()) +/// })) +/// .build() +/// // add context shared data +/// .insert::("Hello World".to_string()) /// // can also be build_client which would return an emitter for events /// .build_server().await.unwrap(); ///# } @@ -31,6 +48,7 @@ pub struct IPCBuilder { handler: EventHandler, address: Option, namespaces: HashMap, + data: TypeMap, } impl IPCBuilder { @@ -52,9 +70,17 @@ impl IPCBuilder { handler, address: None, namespaces: HashMap::new(), + data: TypeMap::new(), } } + /// Adds globally shared data + pub fn insert(mut self, value: K::Value) -> Self { + self.data.insert::(value); + + self + } + /// Adds an event callback pub fn on(mut self, event: &str, callback: F) -> Self where @@ -96,6 +122,7 @@ impl IPCBuilder { let server = IPCServer { namespaces: self.namespaces, handler: self.handler, + data: self.data, }; server.start(&self.address.unwrap()).await?; @@ -108,6 +135,7 @@ impl IPCBuilder { let client = IPCClient { namespaces: self.namespaces, handler: self.handler, + data: self.data, }; let ctx = client.connect(&self.address.unwrap()).await?; diff --git a/src/ipc/client.rs b/src/ipc/client.rs index a23d9792..c9071566 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -7,6 +7,8 @@ use crate::namespaces::namespace::Namespace; use std::collections::HashMap; use std::sync::Arc; use tokio::net::TcpStream; +use tokio::sync::RwLock; +use typemap_rev::TypeMap; /// The IPC Client to connect to an IPC Server. /// Use the [IPCBuilder](crate::builder::IPCBuilder) to create the client. @@ -14,6 +16,7 @@ use tokio::net::TcpStream; pub struct IPCClient { pub(crate) handler: EventHandler, pub(crate) namespaces: HashMap, + pub(crate) data: TypeMap, } impl IPCClient { @@ -23,7 +26,10 @@ impl IPCClient { let stream = TcpStream::connect(address).await?; let (read_half, write_half) = stream.into_split(); let emitter = StreamEmitter::new(write_half); - let ctx = Context::new(StreamEmitter::clone(&emitter)); + let ctx = Context::new( + StreamEmitter::clone(&emitter), + Arc::new(RwLock::new(self.data)), + ); let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); diff --git a/src/ipc/context.rs b/src/ipc/context.rs index 77e29dcb..f942dcf4 100644 --- a/src/ipc/context.rs +++ b/src/ipc/context.rs @@ -3,7 +3,8 @@ use crate::ipc::stream_emitter::StreamEmitter; use crate::Event; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::{oneshot, Mutex}; +use tokio::sync::{oneshot, Mutex, RwLock}; +use typemap_rev::TypeMap; /// An object provided to each callback function. /// Currently it only holds the event emitter to emit response events in event callbacks. @@ -24,14 +25,18 @@ pub struct Context { /// The event emitter pub emitter: StreamEmitter, + /// Field to store additional context data + pub data: Arc>, + reply_listeners: Arc>>>, } impl Context { - pub(crate) fn new(emitter: StreamEmitter) -> Self { + pub(crate) fn new(emitter: StreamEmitter, data: Arc>) -> Self { Self { emitter, reply_listeners: Arc::new(Mutex::new(HashMap::new())), + data, } } diff --git a/src/ipc/server.rs b/src/ipc/server.rs index c7ad1564..5de5f9b2 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -7,6 +7,8 @@ use crate::namespaces::namespace::Namespace; use std::collections::HashMap; use std::sync::Arc; use tokio::net::TcpListener; +use tokio::sync::RwLock; +use typemap_rev::TypeMap; /// The IPC Server listening for connections. /// Use the [IPCBuilder](crate::builder::IPCBuilder) to create a server. @@ -14,6 +16,7 @@ use tokio::net::TcpListener; pub struct IPCServer { pub(crate) handler: EventHandler, pub(crate) namespaces: HashMap, + pub(crate) data: TypeMap, } impl IPCServer { @@ -23,15 +26,17 @@ impl IPCServer { let listener = TcpListener::bind(address).await?; let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); + let data = Arc::new(RwLock::new(self.data)); while let Ok((stream, _)) = listener.accept().await { let handler = Arc::clone(&handler); let namespaces = Arc::clone(&namespaces); + let data = Arc::clone(&data); tokio::spawn(async { let (read_half, write_half) = stream.into_split(); let emitter = StreamEmitter::new(write_half); - let ctx = Context::new(StreamEmitter::clone(&emitter)); + let ctx = Context::new(StreamEmitter::clone(&emitter), data); handle_connection(namespaces, handler, read_half, ctx).await; }); diff --git a/src/lib.rs b/src/lib.rs index 68d6f6e5..277d68e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,7 +31,15 @@ //! //! Server Example: //! ```no_run +//! use typemap_rev::TypeMapKey; //! use rmp_ipc::IPCBuilder; +//! +//! struct MyKey; +//! +//! impl TypeMapKey for MyKey { +//! type Value = u32; +//! } +//! //! // create the server //!# async fn a() { //! IPCBuilder::new() @@ -45,10 +53,18 @@ //! .namespace("mainspace-server") //! .on("do-something", |ctx, event| Box::pin(async move { //! println!("Doing something"); +//! { +//! // access data +//! let mut data = ctx.data.write().await; +//! let mut my_key = data.get_mut::().unwrap(); +//! *my_key += 1; +//! } //! ctx.emitter.emit_response_to(event.id(), "mainspace-client", "something", ()).await?; //! Ok(()) //! })) //! .build() +//! // store additional data +//! .insert::(3) //! .build_server().await.unwrap(); //! # } //! ``` diff --git a/src/tests/ipc_tests.rs b/src/tests/ipc_tests.rs index c2e98b2a..22b67fc7 100644 --- a/src/tests/ipc_tests.rs +++ b/src/tests/ipc_tests.rs @@ -8,6 +8,7 @@ use crate::{Event, IPCBuilder}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime}; +use typemap_rev::TypeMapKey; async fn handle_ping_event(ctx: &Context, e: Event) -> Result<()> { let mut ping_data = e.data::()?; @@ -21,15 +22,19 @@ async fn handle_ping_event(ctx: &Context, e: Event) -> Result<()> { Ok(()) } +fn get_builder_with_ping(address: &str) -> IPCBuilder { + IPCBuilder::new() + .on("ping", |ctx, e| Box::pin(handle_ping_event(ctx, e))) + .address(address) +} + #[tokio::test] async fn it_receives_events() { - let builder = IPCBuilder::new() - .on("ping", |ctx, e| Box::pin(handle_ping_event(ctx, e))) - .address("127.0.0.1:8281"); + let builder = get_builder_with_ping("127.0.0.1:8281"); let server_running = Arc::new(AtomicBool::new(false)); tokio::spawn({ let server_running = Arc::clone(&server_running); - let builder = builder.clone(); + let builder = get_builder_with_ping("127.0.0.1:8281"); async move { server_running.store(true, Ordering::SeqCst); builder.build_server().await.unwrap(); @@ -56,17 +61,21 @@ async fn it_receives_events() { assert_eq!(reply.name(), "pong"); } -#[tokio::test] -async fn it_receives_namespaced_events() { - let builder = IPCBuilder::new() +fn get_builder_with_ping_mainspace(address: &str) -> IPCBuilder { + IPCBuilder::new() .namespace("mainspace") .on("ping", |ctx, e| Box::pin(handle_ping_event(ctx, e))) .build() - .address("127.0.0.1:8282"); + .address(address) +} + +#[tokio::test] +async fn it_receives_namespaced_events() { + let builder = get_builder_with_ping_mainspace("127.0.0.1:8282"); let server_running = Arc::new(AtomicBool::new(false)); tokio::spawn({ let server_running = Arc::clone(&server_running); - let builder = builder.clone(); + let builder = get_builder_with_ping_mainspace("127.0.0.1:8282"); async move { server_running.store(true, Ordering::SeqCst); builder.build_server().await.unwrap(); @@ -94,32 +103,46 @@ async fn it_receives_namespaced_events() { assert_eq!(reply.name(), "pong"); } -#[tokio::test] -async fn it_handles_errors() { - let error_occurred = Arc::new(AtomicBool::new(false)); - let builder = IPCBuilder::new() +struct ErrorOccurredKey; + +impl TypeMapKey for ErrorOccurredKey { + type Value = Arc; +} + +fn get_builder_with_error_handling(error_occurred: Arc, address: &str) -> IPCBuilder { + IPCBuilder::new() + .insert::(error_occurred) .on("ping", move |_, _| { Box::pin(async move { Err(Error::from("ERRROROROROR")) }) }) .on("error", { - let error_occurred = Arc::clone(&error_occurred); - move |_, e| { - let error_occurred = Arc::clone(&error_occurred); + move |ctx, e| { Box::pin(async move { let error = e.data::()?; assert!(error.message.len() > 0); assert_eq!(error.code, 500); - error_occurred.store(true, Ordering::SeqCst); + { + let data = ctx.data.read().await; + let error_occurred = data.get::().unwrap(); + error_occurred.store(true, Ordering::SeqCst); + } Ok(()) }) } }) - .address("127.0.0.1:8283"); + .address(address) +} + +#[tokio::test] +async fn it_handles_errors() { + let error_occurred = Arc::new(AtomicBool::new(false)); + let builder = get_builder_with_error_handling(Arc::clone(&error_occurred), "127.0.0.1:8283"); let server_running = Arc::new(AtomicBool::new(false)); tokio::spawn({ let server_running = Arc::clone(&server_running); - let builder = builder.clone(); + let error_occurred = Arc::clone(&error_occurred); + let builder = get_builder_with_error_handling(error_occurred, "127.0.0.1:8283"); async move { server_running.store(true, Ordering::SeqCst); builder.build_server().await.unwrap();