use crate::error::{Error, Result}; use crate::event::Event; use crate::ipc::stream_emitter::StreamEmitter; use crate::protocol::AsyncProtocolStream; use std::collections::HashMap; use std::mem; use std::ops::{Deref, DerefMut}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::oneshot::Sender; use tokio::sync::{oneshot, Mutex, RwLock}; use typemap_rev::TypeMap; pub(crate) type ReplyListeners = Arc>>>; /// An object provided to each callback function. /// Currently it only holds the event emitter to emit response events in event callbacks. /// ```rust /// use rmp_ipc::prelude::*; /// /// async fn my_callback(ctx: &Context, _event: Event) -> IPCResult<()> { /// // use the emitter on the context object to emit events /// // inside callbacks /// ctx.emitter.emit("ping", ()).await?; /// Ok(()) /// } /// ``` pub struct Context { /// The event emitter pub emitter: StreamEmitter, /// Field to store additional context data pub data: Arc>, stop_sender: Arc>>>, reply_listeners: ReplyListeners, } impl Clone for Context where S: AsyncProtocolStream, { fn clone(&self) -> Self { Self { emitter: self.emitter.clone(), data: Arc::clone(&self.data), stop_sender: Arc::clone(&self.stop_sender), reply_listeners: Arc::clone(&self.reply_listeners), } } } impl

Context

where P: AsyncProtocolStream, { pub(crate) fn new( emitter: StreamEmitter

, data: Arc>, stop_sender: Option>, reply_listeners: ReplyListeners, ) -> Self { Self { emitter, reply_listeners, data, stop_sender: Arc::new(Mutex::new(stop_sender)), } } /// Waits for a reply to the given message ID #[tracing::instrument(level = "debug", skip(self))] pub async fn await_reply(&self, message_id: u64) -> Result { let (rx, tx) = oneshot::channel(); { let mut listeners = self.reply_listeners.lock().await; listeners.insert(message_id, rx); } let event = tx.await?; Ok(event) } /// Stops the listener and closes the connection #[tracing::instrument(level = "debug", skip(self))] pub async fn stop(self) -> Result<()> { let mut sender = self.stop_sender.lock().await; if let Some(sender) = mem::take(&mut *sender) { sender.send(()).map_err(|_| Error::SendError)?; } Ok(()) } /// Returns the channel for a reply to the given message id pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option> { let mut listeners = self.reply_listeners.lock().await; listeners.remove(&ref_id) } } pub struct PooledContext { contexts: Vec>>, } impl Clone for PooledContext where S: AsyncProtocolStream, { fn clone(&self) -> Self { Self { contexts: self.contexts.clone(), } } } pub struct PoolGuard where T: Clone, { inner: T, count: Arc, } impl Deref for PoolGuard where T: Clone, { type Target = T; fn deref(&self) -> &Self::Target { &self.inner } } impl DerefMut for PoolGuard where T: Clone, { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.inner } } impl Clone for PoolGuard where T: Clone, { fn clone(&self) -> Self { self.acquire(); Self { inner: self.inner.clone(), count: Arc::clone(&self.count), } } } impl Drop for PoolGuard where T: Clone, { fn drop(&mut self) { self.release(); } } impl PoolGuard where T: Clone, { pub(crate) fn new(inner: T) -> Self { Self { inner, count: Arc::new(AtomicUsize::new(0)), } } /// Acquires the context by adding 1 to the count #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn acquire(&self) { let count = self.count.fetch_add(1, Ordering::Relaxed); tracing::trace!(count); } /// Releases the connection by subtracting from the stored count #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn release(&self) { let count = self.count.fetch_sub(1, Ordering::Relaxed); tracing::trace!(count); } pub(crate) fn count(&self) -> usize { self.count.load(Ordering::Relaxed) } } impl

PooledContext

where P: AsyncProtocolStream, { /// Creates a new pooled context from a list of contexts pub(crate) fn new(contexts: Vec>) -> Self { Self { contexts: contexts.into_iter().map(PoolGuard::new).collect(), } } /// Acquires a context from the pool /// It always chooses the one that is used the least #[tracing::instrument(level = "trace", skip_all)] pub fn acquire(&self) -> PoolGuard> { self.contexts .iter() .min_by_key(|c| c.count()) .unwrap() .clone() } }