use crate::error::{Error, Result}; use crate::event::Event; use crate::ipc::stream_emitter::{EmitMetadata, StreamEmitter}; use futures::future; use futures::future::Either; use std::collections::HashMap; use std::mem; use std::ops::{Deref, DerefMut}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::oneshot::Sender; use tokio::sync::{oneshot, Mutex, RwLock}; use tokio::time::Duration; use typemap_rev::TypeMap; use crate::payload::IntoPayload; #[cfg(feature = "serialize")] use crate::payload::{DynamicSerializer, SerdePayload}; pub(crate) type ReplyListeners = Arc>>>; /// An object provided to each callback function. /// Currently it only holds the event emitter to emit response events in event callbacks. /// ```rust /// use bromine::prelude::*; /// /// async fn my_callback(ctx: &Context, _event: Event) -> IPCResult<()> { /// // use the emitter on the context object to emit events /// // inside callbacks /// ctx.emit("ping", ()).await?; /// Ok(()) /// } /// ``` #[derive(Clone)] pub struct Context { /// The event emitter emitter: StreamEmitter, /// Field to store additional context data pub data: Arc>, stop_sender: Arc>>>, reply_listeners: ReplyListeners, reply_timeout: Duration, ref_id: Option, #[cfg(feature = "serialize")] pub default_serializer: DynamicSerializer, } impl Context { pub(crate) fn new( emitter: StreamEmitter, data: Arc>, stop_sender: Option>, reply_listeners: ReplyListeners, reply_timeout: Duration, #[cfg(feature = "serialize")] default_serializer: DynamicSerializer, ) -> Self { Self { emitter, reply_listeners, data, stop_sender: Arc::new(Mutex::new(stop_sender)), reply_timeout, #[cfg(feature = "serialize")] default_serializer, ref_id: None, } } /// Emits an event with a given payload that can be serialized into bytes pub async fn emit, P: IntoPayload>( &self, name: S, payload: P, ) -> Result { let payload_bytes = payload.into_payload(&self)?; if let Some(ref_id) = &self.ref_id { self.emitter .emit_response(*ref_id, name, payload_bytes) .await } else { self.emitter.emit(name, payload_bytes).await } } /// Emits an event to a specific namespace pub async fn emit_to, S2: AsRef, P: IntoPayload>( &self, namespace: S1, name: S2, payload: P, ) -> Result { let payload_bytes = payload.into_payload(&self)?; if let Some(ref_id) = &self.ref_id { self.emitter .emit_response_to(*ref_id, namespace, name, payload_bytes) .await } else { self.emitter.emit_to(namespace, name, payload_bytes).await } } /// Waits for a reply to the given message ID #[tracing::instrument(level = "debug", skip(self))] pub async fn await_reply(&self, message_id: u64) -> Result { let (rx, tx) = oneshot::channel(); { let mut listeners = self.reply_listeners.lock().await; listeners.insert(message_id, rx); } let result = future::select( Box::pin(tx), Box::pin(tokio::time::sleep(self.reply_timeout)), ) .await; let event = match result { Either::Left((tx_result, _)) => Ok(tx_result?), Either::Right(_) => { let mut listeners = self.reply_listeners.lock().await; listeners.remove(&message_id); Err(Error::Timeout) } }?; Ok(event) } /// Stops the listener and closes the connection #[tracing::instrument(level = "debug", skip(self))] pub async fn stop(self) -> Result<()> { let mut sender = self.stop_sender.lock().await; if let Some(sender) = mem::take(&mut *sender) { sender.send(()).map_err(|_| Error::SendError)?; } Ok(()) } #[cfg(feature = "serialize")] pub fn create_serde_payload(&self, data: T) -> SerdePayload { SerdePayload::new(self.default_serializer.clone(), data) } /// Returns the channel for a reply to the given message id pub(crate) async fn get_reply_sender(&self, ref_id: u64) -> Option> { let mut listeners = self.reply_listeners.lock().await; listeners.remove(&ref_id) } pub(crate) fn set_ref_id(&mut self, id: Option) { self.ref_id = id; } } pub struct PooledContext { contexts: Vec>, } pub struct PoolGuard where T: Clone, { inner: T, count: Arc, } impl Deref for PoolGuard where T: Clone, { type Target = T; fn deref(&self) -> &Self::Target { &self.inner } } impl DerefMut for PoolGuard where T: Clone, { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.inner } } impl Clone for PoolGuard where T: Clone, { fn clone(&self) -> Self { self.acquire(); Self { inner: self.inner.clone(), count: Arc::clone(&self.count), } } } impl Drop for PoolGuard where T: Clone, { fn drop(&mut self) { self.release(); } } impl PoolGuard where T: Clone, { pub(crate) fn new(inner: T) -> Self { Self { inner, count: Arc::new(AtomicUsize::new(0)), } } /// Acquires the context by adding 1 to the count #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn acquire(&self) { let count = self.count.fetch_add(1, Ordering::Relaxed); tracing::trace!(count); } /// Releases the connection by subtracting from the stored count #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn release(&self) { let count = self.count.fetch_sub(1, Ordering::Relaxed); tracing::trace!(count); } pub(crate) fn count(&self) -> usize { self.count.load(Ordering::Relaxed) } } impl PooledContext { /// Creates a new pooled context from a list of contexts pub(crate) fn new(contexts: Vec) -> Self { Self { contexts: contexts.into_iter().map(PoolGuard::new).collect(), } } /// Acquires a context from the pool /// It always chooses the one that is used the least #[tracing::instrument(level = "trace", skip_all)] pub fn acquire(&self) -> PoolGuard { self.contexts .iter() .min_by_key(|c| c.count()) .unwrap() .clone() } }