diff --git a/src/database/tokens.rs b/src/database/tokens.rs index 137ea0f..d613ab7 100644 --- a/src/database/tokens.rs +++ b/src/database/tokens.rs @@ -109,9 +109,9 @@ impl TokenStoreEntry { /// If the token is expired -1 is returned. pub fn request_ttl(&self) -> i32 { max( - (self.request_ttl - self.ttl_start.elapsed().as_secs() as u32) as i32, + self.request_ttl as i64 - self.ttl_start.elapsed().as_secs() as i64, -1, - ) + ) as i32 } /// Returns the ttl for the refresh token @@ -119,9 +119,9 @@ impl TokenStoreEntry { /// If the token is expired -1 is returned. pub fn refresh_ttl(&self) -> i32 { max( - (self.refresh_ttl - self.ttl_start.elapsed().as_secs() as u32) as i32, + self.refresh_ttl as i64 - self.ttl_start.elapsed().as_secs() as i64, -1, - ) + ) as i32 } /// Returns the request token if it hasn't expired @@ -162,6 +162,13 @@ impl TokenStoreEntry { self.refresh_ttl = min(self.refresh_ttl(), 0) as u32; self.ttl_start = Instant::now(); } + + /// Invalidates the token entry which causes it to be deleted with the next + /// clearing of expired tokens by the token store. The + pub(crate) fn invalidate(&mut self) { + self.request_ttl = 0; + self.refresh_ttl = 0; + } } #[derive(Debug)] @@ -177,10 +184,10 @@ impl TokenStore { } /// Returns the token store entry for a given request token - pub fn get_by_request_token(&self, request_token: &String) -> Option<&TokenStoreEntry> { + pub fn get_by_request_token(&mut self, request_token: &String) -> Option<&mut TokenStoreEntry> { let user_id = get_user_id_from_token(&request_token)?; - if let Some(user_tokens) = self.tokens.get(&user_id) { - user_tokens.iter().find(|e| { + if let Some(user_tokens) = self.tokens.get_mut(&user_id) { + user_tokens.iter_mut().find(|e| { if let Some(token) = e.request_token() { &token == request_token } else { @@ -193,10 +200,10 @@ impl TokenStore { } /// Returns the token store entry by the given refresh token - pub fn get_by_refresh_token(&self, refresh_token: &String) -> Option<&TokenStoreEntry> { + pub fn get_by_refresh_token(&mut self, refresh_token: &String) -> Option<&mut TokenStoreEntry> { let user_id = get_user_id_from_token(&refresh_token)?; - if let Some(user_tokens) = self.tokens.get(&user_id) { - user_tokens.iter().find(|e| { + if let Some(user_tokens) = self.tokens.get_mut(&user_id) { + user_tokens.iter_mut().find(|e| { if let Some(token) = e.refresh_token() { &token == refresh_token } else { diff --git a/src/database/users.rs b/src/database/users.rs index fcb2f3d..69014d8 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -97,7 +97,7 @@ impl Users { /// Validates a request token and returns if it's valid and the /// ttl of the token pub fn validate_request_token(&self, token: &String) -> DatabaseResult<(bool, i32)> { - let store = self.token_store.lock(); + let mut store = self.token_store.lock(); let entry = store.get_by_request_token(&token); if let Some(entry) = entry { @@ -109,7 +109,7 @@ impl Users { /// Validates a refresh token and returns if it's valid and the ttl pub fn validate_refresh_token(&self, token: &String) -> DatabaseResult<(bool, i32)> { - let store = self.token_store.lock(); + let mut store = self.token_store.lock(); let entry = store.get_by_refresh_token(&token); if let Some(entry) = entry { @@ -134,6 +134,18 @@ impl Users { } } + pub fn delete_tokens(&self, request_token: &String) -> DatabaseResult { + let mut token_store = self.token_store.lock(); + let tokens = token_store.get_by_request_token(request_token); + if let Some(tokens) = tokens { + tokens.invalidate(); + + Ok(true) + } else { + Err(DBError::GenericError("Invalid request token!".to_string())) + } + } + /// Validates the login data of the user by creating the hash for the given password /// and comparing it with the database entry fn validate_login(&self, email: &String, password: &String) -> DatabaseResult { diff --git a/src/server/http_server.rs b/src/server/http_server.rs index 9d8ab71..be7e795 100644 --- a/src/server/http_server.rs +++ b/src/server/http_server.rs @@ -1,8 +1,9 @@ use crate::database::Database; -use crate::server::messages::{LoginMessage, RefreshMessage}; +use crate::server::messages::{LoginMessage, LogoutConfirmation, LogoutMessage, RefreshMessage}; use crate::utils::error::DBError; use rouille::{Request, Response, Server}; use serde::export::Formatter; +use serde::Serialize; use std::error::Error; use std::fmt::{self, Display}; use std::io::Read; @@ -16,10 +17,10 @@ pub struct UserHttpServer { database: Database, } -#[derive(Debug)] +#[derive(Debug, Serialize)] pub struct HTTPError { message: String, - code: u16, + error_code: u16, } impl Display for HTTPError { @@ -33,20 +34,23 @@ impl From for HTTPError { fn from(other: DBError) -> Self { Self { message: other.to_string(), - code: 400, + error_code: 400, } } } impl Into for HTTPError { fn into(self) -> Response { - Response::text(self.message).with_status_code(self.code) + Response::json(&self).with_status_code(self.error_code) } } impl HTTPError { pub fn new(message: String, code: u16) -> Self { - Self { message, code } + Self { + message, + error_code: code, + } } } @@ -69,10 +73,13 @@ impl UserHttpServer { let server = Server::new(&listen_address, move |request| { router!(request, (POST) (/login) => { - Self::login(&database, request).unwrap_or_else(|e|e.into()) + Self::login(&database, request).unwrap_or_else(HTTPError::into) }, (POST) (/new-token) => { - Self::new_token(&database, request).unwrap_or_else(|e|e.into()) + Self::new_token(&database, request).unwrap_or_else(HTTPError::into) + }, + (POST) (/logout) => { + Self::logout(&database, request).unwrap_or_else(HTTPError::into) }, _ => Response::empty_404() ) @@ -84,36 +91,44 @@ impl UserHttpServer { /// Handles the login part of the REST api fn login(database: &Database, request: &Request) -> HTTPResult { - if let Some(mut data) = request.data() { - let mut data_string = String::new(); - data.read_to_string(&mut data_string) - .map_err(|_| HTTPError::new("Failed to read request data".to_string(), 500))?; - let login_request: LoginMessage = serde_json::from_str(data_string.as_str()) + let login_request: LoginMessage = + serde_json::from_str(parse_string_body(request)?.as_str()) .map_err(|e| HTTPError::new(e.to_string(), 400))?; - let tokens = database - .users - .create_tokens(&login_request.email, &login_request.password)?; - Ok(Response::json(&tokens)) - } else { - Err(HTTPError::new("Missing Request Data".to_string(), 400)) - } + let tokens = database + .users + .create_tokens(&login_request.email, &login_request.password)?; + + Ok(Response::json(&tokens).with_status_code(201)) } /// Handles the new token part of the rest api fn new_token(database: &Database, request: &Request) -> HTTPResult { - if let Some(mut data) = request.data() { - let mut data_string = String::new(); - data.read_to_string(&mut data_string) - .map_err(|_| HTTPError::new("Failed to read request data".to_string(), 500))?; - let message: RefreshMessage = serde_json::from_str(data_string.as_str()) - .map_err(|e| HTTPError::new(e.to_string(), 400))?; + let message: RefreshMessage = serde_json::from_str(parse_string_body(request)?.as_str()) + .map_err(|e| HTTPError::new(e.to_string(), 400))?; - let tokens = database.users.refresh_tokens(&message.refresh_token)?; + let tokens = database.users.refresh_tokens(&message.refresh_token)?; - Ok(Response::json(&tokens)) - } else { - Err(HTTPError::new("Missing Request Data".to_string(), 400)) - } + Ok(Response::json(&tokens)) } + + fn logout(database: &Database, request: &Request) -> HTTPResult { + let message: LogoutMessage = serde_json::from_str(parse_string_body(request)?.as_str()) + .map_err(|e| HTTPError::new(e.to_string(), 400))?; + let success = database.users.delete_tokens(&message.request_token)?; + + Ok(Response::json(&LogoutConfirmation { success }).with_status_code(205)) + } +} + +/// Parses the body of a http request into a string representation +fn parse_string_body(request: &Request) -> HTTPResult { + let mut body = request + .data() + .ok_or(HTTPError::new("Missing request data!".to_string(), 400))?; + let mut string_body = String::new(); + body.read_to_string(&mut string_body) + .map_err(|e| HTTPError::new(format!("Failed to parse request data {}", e), 400))?; + + Ok(string_body) } diff --git a/src/server/messages.rs b/src/server/messages.rs index 91adea8..bc43962 100644 --- a/src/server/messages.rs +++ b/src/server/messages.rs @@ -88,3 +88,14 @@ pub struct LoginMessage { pub struct RefreshMessage { pub refresh_token: String, } + +#[derive(Deserialize, Zeroize)] +#[zeroize(drop)] +pub struct LogoutMessage { + pub request_token: String, +} + +#[derive(Serialize)] +pub struct LogoutConfirmation { + pub success: bool, +}