pub mod strategies; use std::{collections::HashSet, sync::Arc}; use argon2::password_hash::{SaltString, rand_core::OsRng}; use jsonwebtoken::{ DecodingKey, EncodingKey, Header, Validation, decode, encode, errors::ErrorKind::{ExpiredSignature, InvalidSignature, InvalidSubject, InvalidToken}, }; use sea_orm::prelude::Uuid; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; use crate::errors::service_error::ServiceError; // Number of requests between invalidation cache cleanups #[allow(dead_code)] // TODO: remove when used const INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS: usize = 100; // Cleanup every 100 for invalidation checks #[derive(Serialize, Deserialize, Clone)] pub struct Claims { // Subject - user ID pub sub: String, // Issued at as UNIX timestamp pub iat: u64, // Expiration time as UNIX timestamp pub exp: u64, } #[async_trait::async_trait] pub trait AuthenticationService: Send + Sync { async fn generate_jwt( &self, user_id: Uuid, duration_secs: u64, ) -> Result<(String, Claims), ServiceError>; async fn is_valid_jwt( &self, token: &str, target_sub: Option, ) -> Result, ServiceError>; #[allow(dead_code)] // TODO: remove when used async fn parse_jwt(&self, token: &str) -> Result; #[allow(dead_code)] // TODO: remove when used async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>; #[allow(dead_code)] // TODO: remove when used async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result; #[allow(dead_code)] // TODO: remove when used async fn logout(&self, token: &str) -> Result<(), ServiceError>; #[allow(dead_code)] // TODO: remove when used async fn cleanup_invalidation_cache(&self); } #[derive(Eq, Hash, PartialEq)] struct InvalidationEntry { token: String, invalidated_at: u64, valid_until: u64, } pub struct AuthenticationServiceImpl { secret: String, #[allow(dead_code)] // TODO: remove when used invalidation_cache: Arc>>, #[allow(dead_code)] // TODO: remove when used cache_cleanup_counter: Arc>, } impl AuthenticationServiceImpl { pub fn new(secret: Option) -> Self { let secret = secret.unwrap_or_else(|| { // generate a random secret if none is provided SaltString::generate(&mut OsRng).as_str().to_owned() }); Self { secret, invalidation_cache: Arc::new(RwLock::new(HashSet::new())), cache_cleanup_counter: Arc::new(RwLock::new(0)), } } } #[async_trait::async_trait] impl AuthenticationService for AuthenticationServiceImpl { async fn generate_jwt( &self, user_id: Uuid, duration_secs: u64, ) -> Result<(String, Claims), ServiceError> { let header = Header::default(); let expiration = chrono::Utc::now() .checked_add_signed(chrono::Duration::seconds(duration_secs as i64)) .ok_or(ServiceError::InternalError( "Invalid expiration time".into(), ))? .timestamp() as u64; let claims = Claims { sub: user_id.to_string(), iat: chrono::Utc::now().timestamp() as u64, exp: expiration, }; let token = encode( &header, &claims, &EncodingKey::from_secret(self.secret.as_ref()), ) .map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?; Ok((token, claims)) } async fn is_valid_jwt( &self, token: &str, target_sub: Option, ) -> Result, ServiceError> { let mut validation = Validation::default(); // disable leeway for strict expiration checking validation.leeway = 0; if let Some(expected_sub) = target_sub { validation.sub = Some(expected_sub); } let decoding_key = DecodingKey::from_secret(self.secret.as_ref()); match decode::(token, &decoding_key, &validation) { Ok(data) => Ok(Some(data.claims)), Err(err) => match *err.kind() { InvalidToken | InvalidSubject | ExpiredSignature | InvalidSignature => Ok(None), _ => Err(ServiceError::InternalError(format!( "JWT validation error: {}", err ))), }, } } async fn parse_jwt(&self, token: &str) -> Result { let decoding_key = DecodingKey::from_secret(self.secret.as_ref()); let token_data = decode::(token, &decoding_key, &Validation::default()) .map_err(|e| ServiceError::InternalError(format!("JWT parsing error: {}", e)))?; Ok(token_data.claims) } async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError> { let claims = self.parse_jwt(token).await?; let valid_until = claims.exp; let invalidated_at = chrono::Utc::now().timestamp() as u64; let entry = InvalidationEntry { token: token.to_string(), invalidated_at, valid_until, }; { self.invalidation_cache.write().await.insert(entry); } // if self.cache_cleanup_counter.read().await.wrapping_add(1) % INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS == 0 { self.cleanup_invalidation_cache().await; } // Ok(()) } async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result { let claims = self.parse_jwt(token).await?; let user_id = Uuid::parse_str(&claims.sub).map_err(|e| { ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e)) })?; let (new_token, _) = self.generate_jwt(user_id, duration_secs).await?; Ok(new_token) } async fn logout(&self, token: &str) -> Result<(), ServiceError> { self.invalidate_jwt(token).await } async fn cleanup_invalidation_cache(&self) { let now = chrono::Utc::now().timestamp() as u64; let mut cache = self.invalidation_cache.write().await; cache.retain(|entry| entry.valid_until > now); } } #[cfg(test)] mod tests { use super::*; use tokio::time::{Duration, sleep}; #[tokio::test] async fn test_jwt_generation_and_validation() { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); let (token, _) = service .generate_jwt(user_id, 60) .await .expect("generate jwt"); let valid = service .is_valid_jwt(&token, None) .await .expect("validate jwt"); assert!(valid.is_some(), "Generated token should be valid"); let claims = service.parse_jwt(&token).await.expect("parse jwt"); assert_eq!(claims.sub, user_id.to_string()); } #[tokio::test] async fn test_jwt_validation_with_wrong_subject() { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); let (token, _) = service.generate_jwt(user_id, 60).await.unwrap(); let other_sub = Uuid::new_v4().to_string(); let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap(); assert!( valid.is_none(), "Token should be invalid for a different subject" ); } #[tokio::test] async fn test_parse_jwt_invalid_token() { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let res = service.parse_jwt("not_a_token").await; assert!(matches!(res, Err(ServiceError::InternalError(_)))); } #[tokio::test] async fn test_refresh_jwt() { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); let (token, _) = service.generate_jwt(user_id, 60).await.unwrap(); let new_token = service.refresh_jwt(&token, 120).await.unwrap(); let claims = service.parse_jwt(&new_token).await.unwrap(); assert_eq!(claims.sub, user_id.to_string()); assert_eq!(claims.exp - claims.iat, 120); } #[tokio::test] async fn test_is_valid_jwt_expired() { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); let (token, claims) = service.generate_jwt(user_id, 1).await.unwrap(); sleep(Duration::from_secs(2)).await; let valid = service.is_valid_jwt(&token, None).await.unwrap(); assert!( valid.is_none(), "Token should be expired and thus invalid. Current time: {:?}. Diff: {}", chrono::Utc::now(), chrono::Utc::now().timestamp() - claims.exp as i64 ); } #[tokio::test] async fn test_invalidate_and_cleanup() { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); let (token, _) = service.generate_jwt(user_id, 1).await.unwrap(); service.invalidate_jwt(&token).await.unwrap(); // ensure entry is present { let cache = service.invalidation_cache.read().await; assert!(cache.iter().any(|e| e.token == token)); } // wait until token validity ends and cleanup sleep(Duration::from_secs(2)).await; service.cleanup_invalidation_cache().await; let cache = service.invalidation_cache.read().await; assert!( cache.is_empty(), "Cleanup should remove expired invalidation entries" ); } }