diff --git a/Cargo.lock b/Cargo.lock index 523ee09..7bcc2ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -379,6 +379,7 @@ dependencies = [ "log 0.4.14", "r2d2", "thiserror", + "tokio-diesel", ] [[package]] @@ -2161,6 +2162,18 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "tokio-diesel" +version = "0.3.0" +source = "git+https://github.com/Trivernis/tokio-diesel#f4af42558246ab323600622ba8d08803d3c18842" +dependencies = [ + "async-trait", + "diesel", + "futures", + "r2d2", + "tokio", +] + [[package]] name = "tokio-macros" version = "1.1.0" diff --git a/database/Cargo.lock b/database/Cargo.lock index 9e4766e..9a8d748 100644 --- a/database/Cargo.lock +++ b/database/Cargo.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. +[[package]] +name = "async-trait" +version = "0.1.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36ea56748e10732c49404c153638a15ec3d6211ec5ff35d9bb20e13b93576adf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.0.1" @@ -48,6 +59,7 @@ dependencies = [ "log", "r2d2", "thiserror", + "tokio-diesel", ] [[package]] @@ -91,6 +103,76 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +[[package]] +name = "futures" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d5813545e459ad3ca1bff9915e9ad7f1a47dc6a91b627ce321d5863b7dd253" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce79c6a52a299137a6013061e0cf0e688fce5d7f1bc60125f520912fdb29ec25" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "098cd1c6dda6ca01650f1a37a794245eb73181d0d4d4e955e2f3c37db7af1815" + +[[package]] +name = "futures-io" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "365a1a1fb30ea1c03a830fdb2158f5236833ac81fa0ad12fe35b29cddc35cb04" + +[[package]] +name = "futures-sink" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5629433c555de3d82861a7a4e3794a4c40040390907cfbfd7143a92a426c23" + +[[package]] +name = "futures-task" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba7aa51095076f3ba6d9a1f702f74bd05ec65f555d70d2033d55ba8d69f581bc" + +[[package]] +name = "futures-util" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c144ad54d60f23927f0a6b6d816e4271278b64f005ad65e4e35291d2de9c025" +dependencies = [ + "futures-core", + "futures-sink", + "futures-task", + "pin-project-lite", + "pin-utils", +] + +[[package]] +name = "hermit-abi" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c" +dependencies = [ + "libc", +] + [[package]] name = "instant" version = "0.1.9" @@ -164,6 +246,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "parking_lot" version = "0.11.1" @@ -189,6 +281,18 @@ dependencies = [ "winapi", ] +[[package]] +name = "pin-project-lite" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0e1f259c92177c30a4c9d177246edd0a3568b25756a977d0632cf8fa37e905" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pq-sys" version = "0.4.6" @@ -299,6 +403,29 @@ dependencies = [ "winapi", ] +[[package]] +name = "tokio" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134af885d758d645f0f0505c9a8b3f9bf8a348fd822e112ab5248138348f1722" +dependencies = [ + "autocfg", + "num_cpus", + "pin-project-lite", +] + +[[package]] +name = "tokio-diesel" +version = "0.3.0" +source = "git+https://github.com/Trivernis/tokio-diesel#f4af42558246ab323600622ba8d08803d3c18842" +dependencies = [ + "async-trait", + "diesel", + "futures", + "r2d2", + "tokio", +] + [[package]] name = "unicode-xid" version = "0.2.1" diff --git a/database/Cargo.toml b/database/Cargo.toml index 6f34ac2..095772d 100644 --- a/database/Cargo.toml +++ b/database/Cargo.toml @@ -13,4 +13,5 @@ thiserror = "1.0.24" diesel = {version="1.4.6", features=["postgres", "r2d2", "chrono"]} log = "0.4.14" diesel_migrations = "1.4.0" -r2d2 = "0.8.9" \ No newline at end of file +r2d2 = "0.8.9" +tokio-diesel = {git = "https://github.com/Trivernis/tokio-diesel"} \ No newline at end of file diff --git a/database/src/database.rs b/database/src/database.rs index 850ce63..27580c6 100644 --- a/database/src/database.rs +++ b/database/src/database.rs @@ -7,6 +7,7 @@ use diesel::{delete, insert_into}; use std::any; use std::fmt::Debug; use std::str::FromStr; +use tokio_diesel::*; #[derive(Clone)] pub struct Database { @@ -22,22 +23,22 @@ impl Database { } /// Returns a guild setting from the database - pub fn get_guild_setting( + pub async fn get_guild_setting( &self, guild_id: u64, - key: &str, + key: String, ) -> DatabaseResult> where T: FromStr, { use guild_settings::dsl; log::debug!("Retrieving setting '{}' for guild {}", key, guild_id); - let connection = self.pool.get()?; let entries: Vec = dsl::guild_settings .filter(dsl::guild_id.eq(guild_id as i64)) .filter(dsl::key.eq(key)) - .load::(&connection)?; + .load_async::(&self.pool) + .await?; log::trace!("Result is {:?}", entries); if let Some(first) = entries.first() { @@ -57,16 +58,20 @@ impl Database { } /// Upserting a guild setting - pub fn set_guild_setting(&self, guild_id: u64, key: &str, value: T) -> DatabaseResult<()> + pub async fn set_guild_setting( + &self, + guild_id: u64, + key: String, + value: T, + ) -> DatabaseResult<()> where T: ToString + Debug, { use guild_settings::dsl; log::debug!("Setting '{}' to '{:?}' for guild {}", key, value, guild_id); - let connection = self.pool.get()?; insert_into(dsl::guild_settings) - .values(&GuildSettingInsert { + .values(GuildSettingInsert { guild_id: guild_id as i64, key: key.to_string(), value: value.to_string(), @@ -74,69 +79,76 @@ impl Database { .on_conflict((dsl::guild_id, dsl::key)) .do_update() .set(dsl::value.eq(value.to_string())) - .execute(&connection)?; + .execute_async(&self.pool) + .await?; Ok(()) } /// Deletes a guild setting - pub fn delete_guild_setting(&self, guild_id: u64, key: &str) -> DatabaseResult<()> { + pub async fn delete_guild_setting(&self, guild_id: u64, key: String) -> DatabaseResult<()> { use guild_settings::dsl; - log::debug!("Deleting '{}' for guild {}", key, guild_id); - let connection = self.pool.get()?; delete(dsl::guild_settings) .filter(dsl::guild_id.eq(guild_id as i64)) .filter(dsl::key.eq(key)) - .execute(&connection)?; + .execute_async(&self.pool) + .await?; Ok(()) } /// Returns a list of all guild playlists - pub fn get_guild_playlists(&self, guild_id: u64) -> DatabaseResult> { + pub async fn get_guild_playlists(&self, guild_id: u64) -> DatabaseResult> { use guild_playlists::dsl; log::debug!("Retrieving guild playlists for guild {}", guild_id); - let connection = self.pool.get()?; + let playlists: Vec = dsl::guild_playlists .filter(dsl::guild_id.eq(guild_id as i64)) - .load::(&connection)?; + .load_async::(&self.pool) + .await?; Ok(playlists) } /// Returns a guild playlist by name - pub fn get_guild_playlist( + pub async fn get_guild_playlist( &self, guild_id: u64, - name: &str, + name: String, ) -> DatabaseResult> { use guild_playlists::dsl; log::debug!("Retriving guild playlist '{}' for guild {}", name, guild_id); - let connection = self.pool.get()?; let playlists: Vec = dsl::guild_playlists .filter(dsl::guild_id.eq(guild_id as i64)) .filter(dsl::name.eq(name)) - .load::(&connection)?; + .load_async::(&self.pool) + .await?; Ok(playlists.into_iter().next()) } /// Adds a new playlist to the database overwriting the old one - pub fn add_guild_playlist(&self, guild_id: u64, name: &str, url: &str) -> DatabaseResult<()> { + pub async fn add_guild_playlist( + &self, + guild_id: u64, + name: String, + url: String, + ) -> DatabaseResult<()> { use guild_playlists::dsl; log::debug!("Inserting guild playlist '{}' for guild {}", name, guild_id); - let connection = self.pool.get()?; + insert_into(dsl::guild_playlists) .values(GuildPlaylistInsert { guild_id: guild_id as i64, - name: name.to_string(), - url: url.to_string(), + name: name.clone(), + url: url.clone(), }) .on_conflict((dsl::guild_id, dsl::name)) .do_update() - .set(dsl::url.eq(url.to_string())) - .execute(&connection)?; + .set(dsl::url.eq(url)) + .execute_async(&self.pool) + .await?; Ok(()) } diff --git a/database/src/error.rs b/database/src/error.rs index a431aab..5ec1112 100644 --- a/database/src/error.rs +++ b/database/src/error.rs @@ -18,4 +18,7 @@ pub enum DatabaseError { #[error("Result Error: {0}")] ResultError(#[from] diesel::result::Error), + + #[error("AsyncError: {0}")] + AsyncError(#[from] tokio_diesel::AsyncError), } diff --git a/src/commands/music/mod.rs b/src/commands/music/mod.rs index b730534..cd09ad4 100644 --- a/src/commands/music/mod.rs +++ b/src/commands/music/mod.rs @@ -286,7 +286,9 @@ async fn get_songs_for_query(ctx: &Context, msg: &Message, query: &str) -> BotRe log::debug!("Query is a saved playlist"); let pl_name: &str = captures.get(1).unwrap().as_str(); log::trace!("Playlist name is {}", pl_name); - let playlist_opt = database.get_guild_playlist(guild_id.0, pl_name)?; + let playlist_opt = database + .get_guild_playlist(guild_id.0, pl_name.to_string()) + .await?; log::trace!("Playlist is {:?}", playlist_opt); if let Some(playlist) = playlist_opt { diff --git a/src/commands/music/playlists.rs b/src/commands/music/playlists.rs index ab7761b..d922438 100644 --- a/src/commands/music/playlists.rs +++ b/src/commands/music/playlists.rs @@ -14,7 +14,7 @@ async fn playlists(ctx: &Context, msg: &Message) -> CommandResult { log::debug!("Displaying playlists for guild {}", guild.id); let database = get_database_from_context(ctx).await; - let playlists = database.get_guild_playlists(guild.id.0)?; + let playlists = database.get_guild_playlists(guild.id.0).await?; msg.channel_id .send_message(ctx, |m| { m.embed(|e| { diff --git a/src/commands/music/save_playlist.rs b/src/commands/music/save_playlist.rs index 9da0566..19c142a 100644 --- a/src/commands/music/save_playlist.rs +++ b/src/commands/music/save_playlist.rs @@ -29,7 +29,9 @@ async fn save_playlist(ctx: &Context, msg: &Message, mut args: Args) -> CommandR ); let database = get_database_from_context(ctx).await; - database.add_guild_playlist(guild.id.0, &*name, url)?; + database + .add_guild_playlist(guild.id.0, name.clone(), url.to_string()) + .await?; msg.channel_id .say(ctx, format!("Playlist **{}** saved", name)) diff --git a/src/commands/settings/get.rs b/src/commands/settings/get.rs index 1143751..2effdec 100644 --- a/src/commands/settings/get.rs +++ b/src/commands/settings/get.rs @@ -21,7 +21,9 @@ async fn get(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { if let Some(key) = args.single::().ok() { log::debug!("Displaying guild setting of '{}'", key); - let setting = database.get_guild_setting::(guild.id.0, &key)?; + let setting = database + .get_guild_setting::(guild.id.0, key.clone()) + .await?; match setting { Some(value) => { @@ -43,7 +45,10 @@ async fn get(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let key = key.to_string(); { - match database.get_guild_setting::(guild.id.0, &key)? { + match database + .get_guild_setting::(guild.id.0, key.clone()) + .await? + { Some(value) => kv_pairs.push(format!("`{}` = `{}`", key, value)), None => kv_pairs.push(format!("`{}` not set", key)), } diff --git a/src/commands/settings/set.rs b/src/commands/settings/set.rs index f9fd2e8..f4b6f55 100644 --- a/src/commands/settings/set.rs +++ b/src/commands/settings/set.rs @@ -30,12 +30,16 @@ async fn set(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { let guild = msg.guild(&ctx.cache).await.unwrap(); if let Ok(value) = args.single::() { - database.set_guild_setting(guild.id.0, &key, value.clone())?; + database + .set_guild_setting(guild.id.0, key.clone(), value.clone()) + .await?; msg.channel_id .say(ctx, format!("Set `{}` to `{}`", key, value)) .await?; } else { - database.delete_guild_setting(guild.id.0, &key)?; + database + .delete_guild_setting(guild.id.0, key.clone()) + .await?; msg.channel_id .say(ctx, format!("Setting `{}` reset to default", key)) .await?; diff --git a/src/providers/settings.rs b/src/providers/settings.rs index 77369b4..7345610 100644 --- a/src/providers/settings.rs +++ b/src/providers/settings.rs @@ -36,6 +36,7 @@ pub async fn get_setting( let data = ctx.data.read().await; let database = data.get::().unwrap(); database - .get_guild_setting::(guild_id.0, &setting.to_string()) + .get_guild_setting::(guild_id.0, setting.to_string()) + .await .map_err(BotError::from) }