diff --git a/Cargo.lock b/Cargo.lock index 13e16b3..129944e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -153,6 +153,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "dashmap" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8858831f7781322e539ea39e72449c46b059638250c14344fec8d0aa6e539c" +dependencies = [ + "cfg-if", + "num_cpus", + "parking_lot", +] + [[package]] name = "digest" version = "0.9.0" @@ -358,6 +369,15 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "http" version = "0.2.6" @@ -514,6 +534,16 @@ version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259" +[[package]] +name = "lock_api" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.16" @@ -630,6 +660,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.10.0" @@ -675,6 +715,29 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "parking_lot" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "995f667a6c822200b0433ac218e05582f0e2efa1b922a3fd2fbaadc5f87bab37" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "percent-encoding" version = "2.1.0" @@ -910,6 +973,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + [[package]] name = "sct" version = "0.6.1" @@ -1024,6 +1093,7 @@ dependencies = [ name = "serenity-rich-interaction" version = "0.2.6" dependencies = [ + "dashmap", "futures", "serde_json", "serenity", @@ -1051,6 +1121,12 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" +[[package]] +name = "smallvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" + [[package]] name = "socket2" version = "0.4.4" @@ -1515,6 +1591,49 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5acdd78cb4ba54c0045ac14f62d8f94a03d10047904ae2a40afa1e99d8f70825" +dependencies = [ + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_msvc" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17cffbe740121affb56fad0fc0e421804adf0ae00891205213b5cecd30db881d" + +[[package]] +name = "windows_i686_gnu" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2564fde759adb79129d9b4f54be42b32c89970c18ebf93124ca8870a498688ed" + +[[package]] +name = "windows_i686_msvc" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cd9d32ba70453522332c14d38814bceeb747d80b3958676007acadd7e166956" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfce6deae227ee8d356d19effc141a509cc503dfd1f850622ec4b0f84428e1f4" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d19538ccc21819d01deaf88d6a17eae6596a12e9aafdbb97916fb49896d89de9" + [[package]] name = "winreg" version = "0.10.1" diff --git a/Cargo.toml b/Cargo.toml index 16a3efc..5bc2ff4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ thiserror = "1.0.30" tracing= "0.1.33" futures = "0.3.21" serde_json = "1.0.79" +dashmap = "5.2.0" [dependencies.serenity] version = "0.10.10" diff --git a/src/core.rs b/src/core.rs index e12ab55..dc71b15 100644 --- a/src/core.rs +++ b/src/core.rs @@ -2,14 +2,14 @@ use crate::error::Result; use crate::events::RichEventHandler; use crate::menu::traits::EventDrivenMessage; use crate::menu::EventDrivenMessageContainer; +use dashmap::DashMap; use serenity::client::ClientBuilder; use serenity::http::Http; use serenity::model::channel::Message; use serenity::model::id::{ChannelId, MessageId}; -use std::collections::HashMap; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Duration; -use tokio::sync::Mutex; pub static SHORT_TIMEOUT: Duration = Duration::from_secs(5); pub static MEDIUM_TIMEOUT: Duration = Duration::from_secs(20); @@ -18,6 +18,28 @@ pub static EXTRA_LONG_TIMEOUT: Duration = Duration::from_secs(600); pub type BoxedEventDrivenMessage = Box; +pub struct BoxedMessage(pub BoxedEventDrivenMessage); + +impl From> for BoxedMessage { + fn from(m: Box) -> Self { + Self(m as Box) + } +} + +impl Deref for BoxedMessage { + type Target = BoxedEventDrivenMessage; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for BoxedMessage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + #[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash)] pub struct MessageHandle { pub channel_id: u64, @@ -61,7 +83,7 @@ impl<'a> RegisterRichInteractions for ClientBuilder<'a> { /// Registers the rich interactions with a custom rich event handler fn register_rich_interactions_with(self, rich_handler: RichEventHandler) -> Self { - self.type_map_insert::(Arc::new(Mutex::new(HashMap::new()))) + self.type_map_insert::(Arc::new(DashMap::new())) .raw_event_handler(rich_handler) } } diff --git a/src/error.rs b/src/error.rs index 27c460c..3f099d8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,9 @@ pub enum Error { #[error("Serenity Rich Interaction is not fully initialized")] Uninitialized, + #[error("the cache is not available, therefore some required data is missing")] + NoCache, + #[error("{0}")] Msg(String), } diff --git a/src/events/event_callbacks.rs b/src/events/event_callbacks.rs index 60fec82..2a5d8f0 100644 --- a/src/events/event_callbacks.rs +++ b/src/events/event_callbacks.rs @@ -1,5 +1,5 @@ use crate::core::MessageHandle; -use crate::menu::{get_listeners_from_context, MessageRef}; +use crate::menu::get_listeners_from_context; use crate::Result; use serenity::client::Context; use serenity::model::channel::Reaction; @@ -21,30 +21,19 @@ pub async fn start_update_loop(ctx: &Context) -> Result<()> { loop { { tracing::debug!("Updating messages..."); - let messages = { - let msgs_lock = event_messages.lock().await; - - msgs_lock - .iter() - .map(|(k, v)| (*k, v.clone())) - .collect::>() - }; let mut frozen_messages = Vec::new(); - for (key, msg) in messages { - let mut msg = msg.lock().await; + for entry in event_messages.iter() { + let mut msg = entry.value().lock().await; if let Err(e) = msg.update(&http).await { tracing::error!("Failed to update message: {:?}", e); } if msg.is_frozen() { - frozen_messages.push(key); + frozen_messages.push(*entry.key()); } } - { - let mut msgs_lock = event_messages.lock().await; - for key in frozen_messages { - msgs_lock.remove(&key); - } + for key in frozen_messages { + event_messages.remove(&key); } tracing::debug!("Messages updated"); } @@ -63,16 +52,14 @@ pub async fn handle_message_delete( message_id: MessageId, ) -> Result<()> { let mut affected_messages = Vec::new(); - { - let listeners = get_listeners_from_context(ctx).await?; - let mut listeners_lock = listeners.lock().await; - - let handle = MessageHandle::new(channel_id, message_id); - if let Some(msg) = listeners_lock.get(&handle) { - affected_messages.push(Arc::clone(msg)); - listeners_lock.remove(&handle); - } + let listeners = get_listeners_from_context(ctx).await?; + let handle = MessageHandle::new(channel_id, message_id); + + if let Some(msg) = listeners.get(&handle) { + affected_messages.push(msg.value().clone()); + listeners.remove(&handle); } + for msg in affected_messages { let mut msg = msg.lock().await; msg.on_deleted(ctx).await?; @@ -89,18 +76,17 @@ pub async fn handle_message_delete_bulk( message_ids: &Vec, ) -> Result<()> { let mut affected_messages = Vec::new(); - { - let listeners = get_listeners_from_context(ctx).await?; - let mut listeners_lock = listeners.lock().await; - - for message_id in message_ids { - let handle = MessageHandle::new(channel_id, *message_id); - if let Some(msg) = listeners_lock.get_mut(&handle) { - affected_messages.push(Arc::clone(msg)); - listeners_lock.remove(&handle); - } + + let listeners = get_listeners_from_context(ctx).await?; + + for message_id in message_ids { + let handle = MessageHandle::new(channel_id, *message_id); + if let Some(msg) = listeners.get(&handle) { + affected_messages.push(msg.value().clone()); + listeners.remove(&handle); } } + for msg in affected_messages { let mut msg = msg.lock().await; msg.on_deleted(ctx).await?; @@ -112,17 +98,14 @@ pub async fn handle_message_delete_bulk( /// Fired when a reaction was added to a message #[tracing::instrument(level = "debug", skip(ctx))] pub async fn handle_reaction_add(ctx: &Context, reaction: &Reaction) -> Result<()> { - let mut affected_messages = Vec::new(); - { - let listeners = get_listeners_from_context(ctx).await?; - let mut listeners_lock = listeners.lock().await; + let listeners = get_listeners_from_context(ctx).await?; + let handle = MessageHandle::new(reaction.channel_id, reaction.message_id); - let handle = MessageHandle::new(reaction.channel_id, reaction.message_id); - - if let Some(msg) = listeners_lock.get_mut(&handle) { - affected_messages.push(Arc::clone(&msg)); - } + let mut affected_messages = Vec::new(); + if let Some(msg) = listeners.get(&handle) { + affected_messages.push(msg.value().clone()); } + for msg in affected_messages { let mut msg = msg.lock().await; msg.on_reaction_add(ctx, reaction.clone()).await?; @@ -134,17 +117,14 @@ pub async fn handle_reaction_add(ctx: &Context, reaction: &Reaction) -> Result<( /// Fired when a reaction was added to a message #[tracing::instrument(level = "debug", skip(ctx))] pub async fn handle_reaction_remove(ctx: &Context, reaction: &Reaction) -> Result<()> { - let mut affected_messages = Vec::new(); - { - let listeners = get_listeners_from_context(ctx).await?; - let mut listeners_lock = listeners.lock().await; + let listeners = get_listeners_from_context(ctx).await?; + let handle = MessageHandle::new(reaction.channel_id, reaction.message_id); - let handle = MessageHandle::new(reaction.channel_id, reaction.message_id); - - if let Some(msg) = listeners_lock.get_mut(&handle) { - affected_messages.push(Arc::clone(&msg)); - } + let mut affected_messages = Vec::new(); + if let Some(msg) = listeners.get(&handle) { + affected_messages.push(msg.value().clone()); } + for msg in affected_messages { let mut msg = msg.lock().await; msg.on_reaction_remove(ctx, reaction.clone()).await?; diff --git a/src/events/handler.rs b/src/events/handler.rs index 40634b2..d21f880 100644 --- a/src/events/handler.rs +++ b/src/events/handler.rs @@ -1,5 +1,6 @@ use crate::events::event_callbacks; use crate::Result; +use futures::future; use serenity::async_trait; use serenity::client::{Context, RawEventHandler}; use serenity::model::event; @@ -51,20 +52,18 @@ impl RichEventHandler { /// Handles a generic event #[tracing::instrument(level = "debug", skip_all)] async fn handle_event(&self, ctx: Context, value: T) { - let callbacks = self.callbacks.clone(); - - tokio::spawn(async move { - let value = value; - if let Some(callbacks) = callbacks.get(&TypeId::of::()) { - for callback in callbacks { - if let Some(cb) = callback.downcast_ref::>() { - if let Err(e) = cb.run(&ctx, &value).await { - tracing::error!("Error in event callback: {:?}", e); - } - } - } - } - }); + let value = value; + if let Some(callbacks) = self.callbacks.get(&TypeId::of::()) { + let futures = callbacks + .iter() + .filter_map(|cb| cb.downcast_ref::>()) + .map(|cb| cb.run(&ctx, &value)); + future::join_all(futures) + .await + .into_iter() + .filter_map(Result::err) + .for_each(|e| tracing::error!("Error in event callback: {:?}", e)); + } } pub fn add_event(&mut self, cb: F) -> &mut Self diff --git a/src/menu/container.rs b/src/menu/container.rs index 9cd473d..5d59384 100644 --- a/src/menu/container.rs +++ b/src/menu/container.rs @@ -1,16 +1,16 @@ -use crate::core::{BoxedEventDrivenMessage, MessageHandle}; +use crate::core::{BoxedMessage, MessageHandle}; use crate::Error; use crate::Result; +use dashmap::DashMap; use serenity::client::Context; use serenity::prelude::TypeMapKey; -use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; /// Container to store event driven messages in the serenity context data pub struct EventDrivenMessageContainer; -pub type MessageRef = Arc>; -pub type EventDrivenMessagesRef = Arc>>; +pub type MessageRef = Arc>; +pub type EventDrivenMessagesRef = Arc>; impl TypeMapKey for EventDrivenMessageContainer { type Value = EventDrivenMessagesRef; diff --git a/src/menu/controls.rs b/src/menu/controls.rs index 84508ea..57d9f8d 100644 --- a/src/menu/controls.rs +++ b/src/menu/controls.rs @@ -35,10 +35,9 @@ pub async fn previous_page(ctx: &Context, menu: &mut Menu<'_>, _: Reaction) -> R #[tracing::instrument(level = "debug", skip_all)] pub async fn close_menu(ctx: &Context, menu: &mut Menu<'_>, _: Reaction) -> Result<()> { menu.close(ctx.http()).await?; - let listeners = get_listeners_from_context(&ctx).await?; - let mut listeners_lock = listeners.lock().await; let message = menu.message.read().await; - listeners_lock.remove(&*message); + let listeners = get_listeners_from_context(&ctx).await?; + listeners.remove(&*message); Ok(()) } diff --git a/src/menu/menu.rs b/src/menu/menu.rs index 95b468a..38947f2 100644 --- a/src/menu/menu.rs +++ b/src/menu/menu.rs @@ -66,20 +66,27 @@ impl ActionContainer { /// A menu message pub struct Menu<'a> { - pub message: Arc>, + pub(crate) message: Arc>, pub pages: Vec>, pub current_page: usize, - pub controls: HashMap, + pub(crate) controls: HashMap, pub timeout: Instant, pub sticky: bool, pub data: TypeMap, - pub help_entries: HashMap, + pub(crate) help_entries: HashMap, owner: Option, closed: bool, listeners: EventDrivenMessagesRef, } -impl Menu<'_> { +impl<'a> Menu<'a> { + /// Returns the current page of the menu + pub fn get_current_page(&self) -> Result<&Page<'a>> { + self.pages + .get(self.current_page) + .ok_or(Error::PageNotFound(self.current_page)) + } + /// Removes all reactions from the menu #[tracing::instrument(level = "debug", skip_all)] pub(crate) async fn close(&mut self, http: &Http) -> Result<()> { @@ -107,13 +114,7 @@ impl Menu<'_> { let handle = self.message.read().await; (*handle).clone() }; - let current_page = self - .pages - .get(self.current_page) - .cloned() - .ok_or(Error::PageNotFound(self.current_page))? - .get() - .await?; + let current_page = self.get_current_page()?.get().await?; let message = http .send_message( @@ -121,18 +122,12 @@ impl Menu<'_> { &serde_json::to_value(current_page.0).unwrap(), ) .await?; - let mut controls = self - .controls - .clone() - .into_iter() - .collect::>(); - controls.sort_by_key(|(_, a)| a.position); - for emoji in controls.into_iter().map(|(e, _)| e) { + for control in &self.controls { http.create_reaction( message.channel_id.0, message.id.0, - &ReactionType::Unicode(emoji.clone()), + &ReactionType::Unicode(control.0.clone()), ) .await?; } @@ -144,9 +139,8 @@ impl Menu<'_> { }; { tracing::debug!("Changing key of message"); - let mut listeners_lock = self.listeners.lock().await; - let menu = listeners_lock.remove(&old_handle).unwrap(); - listeners_lock.insert(new_handle, menu); + let menu = self.listeners.remove(&old_handle).unwrap(); + self.listeners.insert(new_handle, menu.1); } tracing::debug!("Deleting original message"); http.delete_message(old_handle.channel_id, old_handle.message_id) @@ -165,16 +159,17 @@ impl<'a> EventDrivenMessage for Menu<'a> { #[tracing::instrument(level = "debug", skip_all)] async fn update(&mut self, http: &Http) -> Result<()> { tracing::trace!("Checking for menu timeout"); + if Instant::now() >= self.timeout { tracing::debug!("Menu timout reached. Closing menu."); self.close(http).await?; } else if self.sticky { tracing::debug!("Message is sticky. Checking for new messages in channel..."); + let handle = { let handle = self.message.read().await; (*handle).clone() }; - let channel_id = ChannelId(handle.channel_id); let messages = channel_id .messages(http, |p| p.after(handle.message_id).limit(1)) @@ -191,8 +186,9 @@ impl<'a> EventDrivenMessage for Menu<'a> { #[tracing::instrument(level = "debug", skip_all)] async fn on_reaction_add(&mut self, ctx: &Context, reaction: Reaction) -> Result<()> { let current_user = ctx.http.get_current_user().await?; + let reaction_user_id = reaction.user_id.ok_or_else(|| Error::NoCache)?; - if reaction.user_id.unwrap().0 == current_user.id.0 { + if reaction_user_id.0 == current_user.id.0 { tracing::debug!("Reaction is from current user."); return Ok(()); } @@ -200,8 +196,9 @@ impl<'a> EventDrivenMessage for Menu<'a> { tracing::debug!("Deleting user reaction."); reaction.delete(ctx).await?; + if let Some(owner) = self.owner { - if owner != reaction.user_id.unwrap() { + if owner != reaction_user_id { tracing::debug!( "Menu has an owner and the reaction is not from the owner of the menu" ); @@ -403,7 +400,7 @@ impl MenuBuilder { .await?; let message = channel_id.send_message(ctx, |_| &mut current_page).await?; - let listeners = get_listeners_from_context(ctx).await?; + tracing::debug!("Sorting controls..."); let mut controls = self .controls @@ -415,6 +412,7 @@ impl MenuBuilder { tracing::debug!("Creating menu..."); let message_handle = MessageHandle::new(message.channel_id, message.id); let handle_lock = Arc::new(RwLock::new(message_handle)); + let listeners = get_listeners_from_context(ctx).await?; let menu = Menu { message: Arc::clone(&handle_lock), @@ -431,11 +429,7 @@ impl MenuBuilder { }; tracing::debug!("Storing menu to listeners..."); - { - let mut listeners_lock = listeners.lock().await; - tracing::trace!("Listeners locked."); - listeners_lock.insert(message_handle, Arc::new(Mutex::new(Box::new(menu)))); - } + listeners.insert(message_handle, Arc::new(Mutex::new(Box::new(menu).into()))); tracing::debug!("Adding controls..."); for (emoji, _) in controls {