diff --git a/src/database/user_roles.rs b/src/database/user_roles.rs index e8916a7..07ad3ef 100644 --- a/src/database/user_roles.rs +++ b/src/database/user_roles.rs @@ -5,6 +5,8 @@ use crate::database::models::Role; use crate::database::{DatabaseResult, PostgresPool, Table}; use crate::utils::error::DBError; +use std::collections::HashSet; +use std::iter::FromIterator; /// A table that stores the relation between users and roles #[derive(Clone)] @@ -43,4 +45,37 @@ impl UserRoles { serde_postgres::from_rows(&rows).map_err(DBError::from) } + + pub fn update_roles(&self, user_id: i32, roles: Vec) -> DatabaseResult> { + let mut connection = self.pool.get()?; + let mut transaction = connection.transaction()?; + let role_ids_result = transaction.query( + "SELECT roles.id FROM roles WHERE roles.name = ANY ($1)", + &[&roles], + )?; + let role_ids: Vec = serde_postgres::from_rows(role_ids_result.iter())?; + let role_ids: HashSet = HashSet::from_iter(role_ids.into_iter()); + let role_result = transaction.query("SELECT roles.id FROM roles, user_roles WHERE roles.id = user_roles.role_id AND user_roles.user_id = $1", &[&user_id])?; + let current_roles: Vec = serde_postgres::from_rows(role_result.iter())?; + + let current_roles = HashSet::from_iter(current_roles.into_iter()); + let added_roles: HashSet<&i32> = role_ids.difference(¤t_roles).collect(); + let removed_roles: HashSet<&i32> = current_roles.difference(&role_ids).collect(); + + for role in removed_roles { + transaction.query( + "DELETE FROM user_roles WHERE role_id = $1 AND user_id = $2", + &[role, &user_id], + )?; + } + for role in added_roles { + transaction.query( + "INSERT INTO user_roles (user_id, role_id) VALUES ($1, $2)", + &[&user_id, role], + )?; + } + transaction.commit()?; + + Ok(self.by_user(user_id)?) + } } diff --git a/src/server/http_server.rs b/src/server/http_server.rs index 0d61eaf..a8850f3 100644 --- a/src/server/http_server.rs +++ b/src/server/http_server.rs @@ -439,8 +439,19 @@ impl UserHttpServer { &message.email.clone().unwrap_or(user_record.email), &message.password, )?; + let roles = if let Some(roles) = &message.roles { + require_permission!(database, request, USER_UPDATE_PERM); + database.user_roles.update_roles(record.id, roles.clone())? + } else { + database.user_roles.by_user(record.id)? + }; - Ok(Response::json(&record)) + Ok(Response::json(&UserFullInformation { + id: record.id, + email: record.email, + name: record.name, + roles, + })) } /// Deletes a user completely diff --git a/src/server/messages.rs b/src/server/messages.rs index 85bf895..c23d590 100644 --- a/src/server/messages.rs +++ b/src/server/messages.rs @@ -126,6 +126,7 @@ pub struct UpdateUserRequest { pub name: Option, pub email: Option, pub password: Option, + pub roles: Option>, pub own_password: String, }