diff --git a/src/database/roles.rs b/src/database/roles.rs index bc39ebc..e1b33be 100644 --- a/src/database/roles.rs +++ b/src/database/roles.rs @@ -1,3 +1,4 @@ +use crate::database::models::Role; use crate::database::role_permissions::RolePermissions; use crate::database::{DatabaseResult, RedisConnection, Table}; use crate::utils::error::DBError; @@ -41,3 +42,45 @@ impl Table for Roles { .map_err(DBError::from) } } + +impl Roles { + pub fn create_role( + &self, + name: String, + description: Option, + permissions: Vec, + ) -> DatabaseResult { + let mut connection = self.database_connection.lock().unwrap(); + let exists = connection.query_opt("SELECT id FROM roles WHERE name = $1", &[&name])?; + + if exists.is_some() { + return Err(DBError::RecordExists); + } + log::trace!("Preparing transaction"); + let mut transaction = connection.transaction()?; + let result: DatabaseResult = { + let row = transaction.query_one( + "INSERT INTO roles (name, description) VALUES ($1, $2) RETURNING *", + &[&name, &description], + )?; + let role: Role = serde_postgres::from_row(&row)?; + for permission in permissions { + transaction.execute( + "INSERT INTO role_permissions (role_id, permission_id) VALUES ($1, $2);", + &[&role.id, &permission], + )?; + } + + Ok(role) + }; + if let Err(_) = result { + log::trace!("Rollback"); + transaction.rollback()?; + } else { + log::trace!("Commit"); + transaction.commit()?; + } + + result + } +} diff --git a/src/server/messages.rs b/src/server/messages.rs index ff5ffd3..7263a8d 100644 --- a/src/server/messages.rs +++ b/src/server/messages.rs @@ -61,3 +61,10 @@ impl InfoEntry { pub struct GetPermissionsRequest { pub role_ids: Vec, } + +#[derive(Deserialize)] +pub struct CreateRoleRequest { + pub name: String, + pub description: Option, + pub permission: Vec, +} diff --git a/src/server/rpc_methods.rs b/src/server/rpc_methods.rs index 71687be..c1f49e8 100644 --- a/src/server/rpc_methods.rs +++ b/src/server/rpc_methods.rs @@ -5,3 +5,4 @@ pub(crate) const INFO: [u8; 4] = [0x49, 0x4e, 0x46, 0x4f]; pub(crate) const VALIDATE_TOKEN: [u8; 4] = [0x56, 0x41, 0x4c, 0x49]; pub(crate) const GET_ROLES: [u8; 4] = [0x52, 0x4f, 0x4c, 0x45]; pub(crate) const GET_ROLE_PERMISSIONS: [u8; 4] = [0x50, 0x45, 0x52, 0x4d]; +pub(crate) const CREATE_ROLE: [u8; 4] = [0x43, 0x52, 0x4f, 0x4c]; diff --git a/src/server/user_rpc.rs b/src/server/user_rpc.rs index 217b02e..07df352 100644 --- a/src/server/user_rpc.rs +++ b/src/server/user_rpc.rs @@ -1,6 +1,8 @@ use super::rpc_methods::*; use crate::database::Database; -use crate::server::messages::{ErrorMessage, GetPermissionsRequest, InfoEntry, TokenRequest}; +use crate::server::messages::{ + CreateRoleRequest, ErrorMessage, GetPermissionsRequest, InfoEntry, TokenRequest, +}; use crate::utils::get_user_id_from_token; use msgrpc::message::Message; use msgrpc::server::RpcServer; @@ -41,6 +43,7 @@ impl UserRpcServer { GET_ROLES => self.handle_get_roles(&handler.message.data), VALIDATE_TOKEN => self.handle_validate_token(&handler.message.data), GET_ROLE_PERMISSIONS => self.handle_get_permissions(&handler.message.data), + CREATE_ROLE => self.handle_create_role(&handler.message.data), _ => Err(ErrorMessage::new("Invalid Method".to_string())), } .unwrap_or_else(|e| Message::new_with_serialize(ERROR, e)); @@ -88,6 +91,12 @@ impl UserRpcServer { "Returns all permissions the given roles are assigned to", "{role_ids: [i32]}", ), + InfoEntry::new( + "create role", + CREATE_ROLE, + "Creates a new role with the given permissions", + "{name: String, description: String, permissions: [i32]}", + ), ], )) } @@ -127,4 +136,17 @@ impl UserRpcServer { Ok(Message::new_with_serialize(GET_ROLES, response_data)) } + + fn handle_create_role(&self, data: &Vec) -> RpcResult { + log::trace!("Create Role"); + let message = CreateRoleRequest::deserialize(&mut Deserializer::new(&mut data.as_slice())) + .map_err(|e| ErrorMessage::new(e.to_string()))?; + let role = self.database.roles.create_role( + message.name, + message.description, + message.permission, + )?; + + Ok(Message::new_with_serialize(CREATE_ROLE, role)) + } } diff --git a/src/utils/error.rs b/src/utils/error.rs index 83d3e5e..71590d2 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -21,6 +21,19 @@ impl Display for DBError { impl error::Error for DBError {} +impl DBError { + pub fn to_string(&self) -> String { + match self { + DBError::GenericError(g) => g.clone(), + DBError::RecordExists => "Record Exists".to_string(), + DBError::Postgres(p) => p.to_string(), + DBError::Redis(r) => r.to_string(), + DBError::DeserializeError(de) => de.to_string(), + DBError::ScryptError => "sCrypt Hash creation error".to_string(), + } + } +} + pub type DatabaseResult = Result; impl From for DBError {