diff --git a/Cargo.lock b/Cargo.lock index 02abeb37..795eed0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -148,7 +148,7 @@ dependencies = [ [[package]] name = "rmp-ipc" -version = "0.4.2" +version = "0.4.3" dependencies = [ "lazy_static", "log", diff --git a/Cargo.toml b/Cargo.toml index 06c40691..614bd7ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rmp-ipc" -version = "0.4.2" +version = "0.4.3" authors = ["trivernis "] edition = "2018" readme = "README.md" diff --git a/src/error.rs b/src/error.rs index ab3cb2ce..933b39af 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,6 +22,9 @@ pub enum Error { #[error("Channel Error: {0}")] ReceiveError(#[from] oneshot::error::RecvError), + + #[error("Send Error")] + SendError, } impl From for Error { diff --git a/src/ipc/client.rs b/src/ipc/client.rs index b1795984..adaae1f2 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -7,6 +7,7 @@ use crate::namespaces::namespace::Namespace; use std::collections::HashMap; use std::sync::Arc; use tokio::net::TcpStream; +use tokio::sync::oneshot; use tokio::sync::RwLock; use typemap_rev::TypeMap; @@ -26,20 +27,26 @@ impl IPCClient { let stream = TcpStream::connect(address).await?; let (read_half, write_half) = stream.into_split(); let emitter = StreamEmitter::new(write_half); + let (tx, rx) = oneshot::channel(); let ctx = Context::new( StreamEmitter::clone(&emitter), Arc::new(RwLock::new(self.data)), + Some(tx), ); let handler = Arc::new(self.handler); let namespaces = Arc::new(self.namespaces); log::debug!("IPC client connected to {}", address); - tokio::spawn({ + let handle = tokio::spawn({ let ctx = Context::clone(&ctx); async move { handle_connection(namespaces, handler, read_half, ctx).await; } }); + tokio::spawn(async move { + let _ = rx.await; + handle.abort(); + }); Ok(ctx) } diff --git a/src/ipc/context.rs b/src/ipc/context.rs index f942dcf4..4cd10e7f 100644 --- a/src/ipc/context.rs +++ b/src/ipc/context.rs @@ -1,8 +1,10 @@ -use crate::error::Result; +use crate::error::{Error, Result}; use crate::ipc::stream_emitter::StreamEmitter; use crate::Event; use std::collections::HashMap; +use std::mem; use std::sync::Arc; +use tokio::sync::oneshot::Sender; use tokio::sync::{oneshot, Mutex, RwLock}; use typemap_rev::TypeMap; @@ -28,15 +30,22 @@ pub struct Context { /// Field to store additional context data pub data: Arc>, + stop_sender: Arc>>>, + reply_listeners: Arc>>>, } impl Context { - pub(crate) fn new(emitter: StreamEmitter, data: Arc>) -> Self { + pub(crate) fn new( + emitter: StreamEmitter, + data: Arc>, + stop_sender: Option>, + ) -> Self { Self { emitter, reply_listeners: Arc::new(Mutex::new(HashMap::new())), data, + stop_sender: Arc::new(Mutex::new(stop_sender)), } } @@ -52,6 +61,16 @@ impl Context { Ok(event) } + /// Stops the listener and closes the connection + pub async fn stop(self) -> Result<()> { + let mut sender = self.stop_sender.lock().await; + if let Some(sender) = mem::take(&mut *sender) { + sender.send(()).map_err(|_| Error::SendError)?; + } + + Ok(()) + } + /// Returns the channel for a reply to the given message id pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option> { let mut listeners = self.reply_listeners.lock().await; diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 1fa3e5ca..6926eb60 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -37,7 +37,7 @@ impl IPCServer { tokio::spawn(async { let (read_half, write_half) = stream.into_split(); let emitter = StreamEmitter::new(write_half); - let ctx = Context::new(StreamEmitter::clone(&emitter), data); + let ctx = Context::new(StreamEmitter::clone(&emitter), data, None); handle_connection(namespaces, handler, read_half, ctx).await; });