Replace hashmaps with dashmaps

Signed-off-by: trivernis <trivernis@protonmail.com>
main
trivernis 2 years ago
parent 9d75afc8c1
commit 5818aa3133
Signed by: Trivernis
GPG Key ID: DFFFCC2C7A02DB45

119
Cargo.lock generated

@ -153,6 +153,17 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "dashmap"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c8858831f7781322e539ea39e72449c46b059638250c14344fec8d0aa6e539c"
dependencies = [
"cfg-if",
"num_cpus",
"parking_lot",
]
[[package]]
name = "digest"
version = "0.9.0"
@ -358,6 +369,15 @@ version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "http"
version = "0.2.6"
@ -514,6 +534,16 @@ version = "0.2.122"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259"
[[package]]
name = "lock_api"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.16"
@ -630,6 +660,16 @@ dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "once_cell"
version = "1.10.0"
@ -675,6 +715,29 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "parking_lot"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "995f667a6c822200b0433ac218e05582f0e2efa1b922a3fd2fbaadc5f87bab37"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-sys",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@ -910,6 +973,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "sct"
version = "0.6.1"
@ -1024,6 +1093,7 @@ dependencies = [
name = "serenity-rich-interaction"
version = "0.2.6"
dependencies = [
"dashmap",
"futures",
"serde_json",
"serenity",
@ -1051,6 +1121,12 @@ version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32"
[[package]]
name = "smallvec"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
[[package]]
name = "socket2"
version = "0.4.4"
@ -1515,6 +1591,49 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5acdd78cb4ba54c0045ac14f62d8f94a03d10047904ae2a40afa1e99d8f70825"
dependencies = [
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_msvc"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17cffbe740121affb56fad0fc0e421804adf0ae00891205213b5cecd30db881d"
[[package]]
name = "windows_i686_gnu"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2564fde759adb79129d9b4f54be42b32c89970c18ebf93124ca8870a498688ed"
[[package]]
name = "windows_i686_msvc"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cd9d32ba70453522332c14d38814bceeb747d80b3958676007acadd7e166956"
[[package]]
name = "windows_x86_64_gnu"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfce6deae227ee8d356d19effc141a509cc503dfd1f850622ec4b0f84428e1f4"
[[package]]
name = "windows_x86_64_msvc"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d19538ccc21819d01deaf88d6a17eae6596a12e9aafdbb97916fb49896d89de9"
[[package]]
name = "winreg"
version = "0.10.1"

@ -21,6 +21,7 @@ thiserror = "1.0.30"
tracing= "0.1.33"
futures = "0.3.21"
serde_json = "1.0.79"
dashmap = "5.2.0"
[dependencies.serenity]
version = "0.10.10"

@ -2,14 +2,14 @@ use crate::error::Result;
use crate::events::RichEventHandler;
use crate::menu::traits::EventDrivenMessage;
use crate::menu::EventDrivenMessageContainer;
use dashmap::DashMap;
use serenity::client::ClientBuilder;
use serenity::http::Http;
use serenity::model::channel::Message;
use serenity::model::id::{ChannelId, MessageId};
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
pub static SHORT_TIMEOUT: Duration = Duration::from_secs(5);
pub static MEDIUM_TIMEOUT: Duration = Duration::from_secs(20);
@ -18,6 +18,28 @@ pub static EXTRA_LONG_TIMEOUT: Duration = Duration::from_secs(600);
pub type BoxedEventDrivenMessage = Box<dyn EventDrivenMessage>;
pub struct BoxedMessage(pub BoxedEventDrivenMessage);
impl<T: EventDrivenMessage + 'static> From<Box<T>> for BoxedMessage {
fn from(m: Box<T>) -> Self {
Self(m as Box<dyn EventDrivenMessage>)
}
}
impl Deref for BoxedMessage {
type Target = BoxedEventDrivenMessage;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for BoxedMessage {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash)]
pub struct MessageHandle {
pub channel_id: u64,
@ -61,7 +83,7 @@ impl<'a> RegisterRichInteractions for ClientBuilder<'a> {
/// Registers the rich interactions with a custom rich event handler
fn register_rich_interactions_with(self, rich_handler: RichEventHandler) -> Self {
self.type_map_insert::<EventDrivenMessageContainer>(Arc::new(Mutex::new(HashMap::new())))
self.type_map_insert::<EventDrivenMessageContainer>(Arc::new(DashMap::new()))
.raw_event_handler(rich_handler)
}
}

@ -11,6 +11,9 @@ pub enum Error {
#[error("Serenity Rich Interaction is not fully initialized")]
Uninitialized,
#[error("the cache is not available, therefore some required data is missing")]
NoCache,
#[error("{0}")]
Msg(String),
}

@ -1,5 +1,5 @@
use crate::core::MessageHandle;
use crate::menu::{get_listeners_from_context, MessageRef};
use crate::menu::get_listeners_from_context;
use crate::Result;
use serenity::client::Context;
use serenity::model::channel::Reaction;
@ -21,30 +21,19 @@ pub async fn start_update_loop(ctx: &Context) -> Result<()> {
loop {
{
tracing::debug!("Updating messages...");
let messages = {
let msgs_lock = event_messages.lock().await;
msgs_lock
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect::<Vec<(MessageHandle, MessageRef)>>()
};
let mut frozen_messages = Vec::new();
for (key, msg) in messages {
let mut msg = msg.lock().await;
for entry in event_messages.iter() {
let mut msg = entry.value().lock().await;
if let Err(e) = msg.update(&http).await {
tracing::error!("Failed to update message: {:?}", e);
}
if msg.is_frozen() {
frozen_messages.push(key);
frozen_messages.push(*entry.key());
}
}
{
let mut msgs_lock = event_messages.lock().await;
for key in frozen_messages {
msgs_lock.remove(&key);
}
for key in frozen_messages {
event_messages.remove(&key);
}
tracing::debug!("Messages updated");
}
@ -63,16 +52,14 @@ pub async fn handle_message_delete(
message_id: MessageId,
) -> Result<()> {
let mut affected_messages = Vec::new();
{
let listeners = get_listeners_from_context(ctx).await?;
let mut listeners_lock = listeners.lock().await;
let handle = MessageHandle::new(channel_id, message_id);
if let Some(msg) = listeners_lock.get(&handle) {
affected_messages.push(Arc::clone(msg));
listeners_lock.remove(&handle);
}
let listeners = get_listeners_from_context(ctx).await?;
let handle = MessageHandle::new(channel_id, message_id);
if let Some(msg) = listeners.get(&handle) {
affected_messages.push(msg.value().clone());
listeners.remove(&handle);
}
for msg in affected_messages {
let mut msg = msg.lock().await;
msg.on_deleted(ctx).await?;
@ -89,18 +76,17 @@ pub async fn handle_message_delete_bulk(
message_ids: &Vec<MessageId>,
) -> Result<()> {
let mut affected_messages = Vec::new();
{
let listeners = get_listeners_from_context(ctx).await?;
let mut listeners_lock = listeners.lock().await;
for message_id in message_ids {
let handle = MessageHandle::new(channel_id, *message_id);
if let Some(msg) = listeners_lock.get_mut(&handle) {
affected_messages.push(Arc::clone(msg));
listeners_lock.remove(&handle);
}
let listeners = get_listeners_from_context(ctx).await?;
for message_id in message_ids {
let handle = MessageHandle::new(channel_id, *message_id);
if let Some(msg) = listeners.get(&handle) {
affected_messages.push(msg.value().clone());
listeners.remove(&handle);
}
}
for msg in affected_messages {
let mut msg = msg.lock().await;
msg.on_deleted(ctx).await?;
@ -112,17 +98,14 @@ pub async fn handle_message_delete_bulk(
/// Fired when a reaction was added to a message
#[tracing::instrument(level = "debug", skip(ctx))]
pub async fn handle_reaction_add(ctx: &Context, reaction: &Reaction) -> Result<()> {
let mut affected_messages = Vec::new();
{
let listeners = get_listeners_from_context(ctx).await?;
let mut listeners_lock = listeners.lock().await;
let listeners = get_listeners_from_context(ctx).await?;
let handle = MessageHandle::new(reaction.channel_id, reaction.message_id);
let handle = MessageHandle::new(reaction.channel_id, reaction.message_id);
if let Some(msg) = listeners_lock.get_mut(&handle) {
affected_messages.push(Arc::clone(&msg));
}
let mut affected_messages = Vec::new();
if let Some(msg) = listeners.get(&handle) {
affected_messages.push(msg.value().clone());
}
for msg in affected_messages {
let mut msg = msg.lock().await;
msg.on_reaction_add(ctx, reaction.clone()).await?;
@ -134,17 +117,14 @@ pub async fn handle_reaction_add(ctx: &Context, reaction: &Reaction) -> Result<(
/// Fired when a reaction was added to a message
#[tracing::instrument(level = "debug", skip(ctx))]
pub async fn handle_reaction_remove(ctx: &Context, reaction: &Reaction) -> Result<()> {
let mut affected_messages = Vec::new();
{
let listeners = get_listeners_from_context(ctx).await?;
let mut listeners_lock = listeners.lock().await;
let listeners = get_listeners_from_context(ctx).await?;
let handle = MessageHandle::new(reaction.channel_id, reaction.message_id);
let handle = MessageHandle::new(reaction.channel_id, reaction.message_id);
if let Some(msg) = listeners_lock.get_mut(&handle) {
affected_messages.push(Arc::clone(&msg));
}
let mut affected_messages = Vec::new();
if let Some(msg) = listeners.get(&handle) {
affected_messages.push(msg.value().clone());
}
for msg in affected_messages {
let mut msg = msg.lock().await;
msg.on_reaction_remove(ctx, reaction.clone()).await?;

@ -1,5 +1,6 @@
use crate::events::event_callbacks;
use crate::Result;
use futures::future;
use serenity::async_trait;
use serenity::client::{Context, RawEventHandler};
use serenity::model::event;
@ -51,20 +52,18 @@ impl RichEventHandler {
/// Handles a generic event
#[tracing::instrument(level = "debug", skip_all)]
async fn handle_event<T: 'static + Send + Sync>(&self, ctx: Context, value: T) {
let callbacks = self.callbacks.clone();
tokio::spawn(async move {
let value = value;
if let Some(callbacks) = callbacks.get(&TypeId::of::<T>()) {
for callback in callbacks {
if let Some(cb) = callback.downcast_ref::<EventCallback<T>>() {
if let Err(e) = cb.run(&ctx, &value).await {
tracing::error!("Error in event callback: {:?}", e);
}
}
}
}
});
let value = value;
if let Some(callbacks) = self.callbacks.get(&TypeId::of::<T>()) {
let futures = callbacks
.iter()
.filter_map(|cb| cb.downcast_ref::<EventCallback<T>>())
.map(|cb| cb.run(&ctx, &value));
future::join_all(futures)
.await
.into_iter()
.filter_map(Result::err)
.for_each(|e| tracing::error!("Error in event callback: {:?}", e));
}
}
pub fn add_event<T: 'static, F: 'static>(&mut self, cb: F) -> &mut Self

@ -1,16 +1,16 @@
use crate::core::{BoxedEventDrivenMessage, MessageHandle};
use crate::core::{BoxedMessage, MessageHandle};
use crate::Error;
use crate::Result;
use dashmap::DashMap;
use serenity::client::Context;
use serenity::prelude::TypeMapKey;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Container to store event driven messages in the serenity context data
pub struct EventDrivenMessageContainer;
pub type MessageRef = Arc<Mutex<BoxedEventDrivenMessage>>;
pub type EventDrivenMessagesRef = Arc<Mutex<HashMap<MessageHandle, MessageRef>>>;
pub type MessageRef = Arc<Mutex<BoxedMessage>>;
pub type EventDrivenMessagesRef = Arc<DashMap<MessageHandle, MessageRef>>;
impl TypeMapKey for EventDrivenMessageContainer {
type Value = EventDrivenMessagesRef;

@ -35,10 +35,9 @@ pub async fn previous_page(ctx: &Context, menu: &mut Menu<'_>, _: Reaction) -> R
#[tracing::instrument(level = "debug", skip_all)]
pub async fn close_menu(ctx: &Context, menu: &mut Menu<'_>, _: Reaction) -> Result<()> {
menu.close(ctx.http()).await?;
let listeners = get_listeners_from_context(&ctx).await?;
let mut listeners_lock = listeners.lock().await;
let message = menu.message.read().await;
listeners_lock.remove(&*message);
let listeners = get_listeners_from_context(&ctx).await?;
listeners.remove(&*message);
Ok(())
}

@ -66,20 +66,27 @@ impl ActionContainer {
/// A menu message
pub struct Menu<'a> {
pub message: Arc<RwLock<MessageHandle>>,
pub(crate) message: Arc<RwLock<MessageHandle>>,
pub pages: Vec<Page<'a>>,
pub current_page: usize,
pub controls: HashMap<String, ActionContainer>,
pub(crate) controls: HashMap<String, ActionContainer>,
pub timeout: Instant,
pub sticky: bool,
pub data: TypeMap,
pub help_entries: HashMap<String, String>,
pub(crate) help_entries: HashMap<String, String>,
owner: Option<UserId>,
closed: bool,
listeners: EventDrivenMessagesRef,
}
impl Menu<'_> {
impl<'a> Menu<'a> {
/// Returns the current page of the menu
pub fn get_current_page(&self) -> Result<&Page<'a>> {
self.pages
.get(self.current_page)
.ok_or(Error::PageNotFound(self.current_page))
}
/// Removes all reactions from the menu
#[tracing::instrument(level = "debug", skip_all)]
pub(crate) async fn close(&mut self, http: &Http) -> Result<()> {
@ -107,13 +114,7 @@ impl Menu<'_> {
let handle = self.message.read().await;
(*handle).clone()
};
let current_page = self
.pages
.get(self.current_page)
.cloned()
.ok_or(Error::PageNotFound(self.current_page))?
.get()
.await?;
let current_page = self.get_current_page()?.get().await?;
let message = http
.send_message(
@ -121,18 +122,12 @@ impl Menu<'_> {
&serde_json::to_value(current_page.0).unwrap(),
)
.await?;
let mut controls = self
.controls
.clone()
.into_iter()
.collect::<Vec<(String, ActionContainer)>>();
controls.sort_by_key(|(_, a)| a.position);
for emoji in controls.into_iter().map(|(e, _)| e) {
for control in &self.controls {
http.create_reaction(
message.channel_id.0,
message.id.0,
&ReactionType::Unicode(emoji.clone()),
&ReactionType::Unicode(control.0.clone()),
)
.await?;
}
@ -144,9 +139,8 @@ impl Menu<'_> {
};
{
tracing::debug!("Changing key of message");
let mut listeners_lock = self.listeners.lock().await;
let menu = listeners_lock.remove(&old_handle).unwrap();
listeners_lock.insert(new_handle, menu);
let menu = self.listeners.remove(&old_handle).unwrap();
self.listeners.insert(new_handle, menu.1);
}
tracing::debug!("Deleting original message");
http.delete_message(old_handle.channel_id, old_handle.message_id)
@ -165,16 +159,17 @@ impl<'a> EventDrivenMessage for Menu<'a> {
#[tracing::instrument(level = "debug", skip_all)]
async fn update(&mut self, http: &Http) -> Result<()> {
tracing::trace!("Checking for menu timeout");
if Instant::now() >= self.timeout {
tracing::debug!("Menu timout reached. Closing menu.");
self.close(http).await?;
} else if self.sticky {
tracing::debug!("Message is sticky. Checking for new messages in channel...");
let handle = {
let handle = self.message.read().await;
(*handle).clone()
};
let channel_id = ChannelId(handle.channel_id);
let messages = channel_id
.messages(http, |p| p.after(handle.message_id).limit(1))
@ -191,8 +186,9 @@ impl<'a> EventDrivenMessage for Menu<'a> {
#[tracing::instrument(level = "debug", skip_all)]
async fn on_reaction_add(&mut self, ctx: &Context, reaction: Reaction) -> Result<()> {
let current_user = ctx.http.get_current_user().await?;
let reaction_user_id = reaction.user_id.ok_or_else(|| Error::NoCache)?;
if reaction.user_id.unwrap().0 == current_user.id.0 {
if reaction_user_id.0 == current_user.id.0 {
tracing::debug!("Reaction is from current user.");
return Ok(());
}
@ -200,8 +196,9 @@ impl<'a> EventDrivenMessage for Menu<'a> {
tracing::debug!("Deleting user reaction.");
reaction.delete(ctx).await?;
if let Some(owner) = self.owner {
if owner != reaction.user_id.unwrap() {
if owner != reaction_user_id {
tracing::debug!(
"Menu has an owner and the reaction is not from the owner of the menu"
);
@ -403,7 +400,7 @@ impl MenuBuilder {
.await?;
let message = channel_id.send_message(ctx, |_| &mut current_page).await?;
let listeners = get_listeners_from_context(ctx).await?;
tracing::debug!("Sorting controls...");
let mut controls = self
.controls
@ -415,6 +412,7 @@ impl MenuBuilder {
tracing::debug!("Creating menu...");
let message_handle = MessageHandle::new(message.channel_id, message.id);
let handle_lock = Arc::new(RwLock::new(message_handle));
let listeners = get_listeners_from_context(ctx).await?;
let menu = Menu {
message: Arc::clone(&handle_lock),
@ -431,11 +429,7 @@ impl MenuBuilder {
};
tracing::debug!("Storing menu to listeners...");
{
let mut listeners_lock = listeners.lock().await;
tracing::trace!("Listeners locked.");
listeners_lock.insert(message_handle, Arc::new(Mutex::new(Box::new(menu))));
}
listeners.insert(message_handle, Arc::new(Mutex::new(Box::new(menu).into())));
tracing::debug!("Adding controls...");
for (emoji, _) in controls {

Loading…
Cancel
Save