290 lines
9.8 KiB
Rust
290 lines
9.8 KiB
Rust
pub mod strategies;
|
|
|
|
use std::{collections::HashSet, sync::Arc};
|
|
|
|
use argon2::password_hash::{SaltString, rand_core::OsRng};
|
|
use jsonwebtoken::{
|
|
DecodingKey, EncodingKey, Header, Validation, decode, encode,
|
|
errors::ErrorKind::{ExpiredSignature, InvalidSignature, InvalidSubject, InvalidToken},
|
|
};
|
|
use sea_orm::prelude::Uuid;
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::errors::service_error::ServiceError;
|
|
|
|
/// Intended number of token invalidations between sweeps that drop expired
/// entries from the invalidation cache.
#[allow(dead_code)] // TODO: remove when used
const INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS: usize = 100;
|
|
|
|
#[derive(Serialize, Deserialize, Clone)]
|
|
pub struct Claims {
|
|
// Subject - user ID
|
|
pub sub: String,
|
|
// Issued at as UNIX timestamp
|
|
pub iat: u64,
|
|
// Expiration time as UNIX timestamp
|
|
pub exp: u64,
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
pub trait AuthenticationService: Send + Sync {
|
|
async fn generate_jwt(
|
|
&self,
|
|
user_id: Uuid,
|
|
duration_secs: u64,
|
|
) -> Result<(String, Claims), ServiceError>;
|
|
async fn is_valid_jwt(
|
|
&self,
|
|
token: &str,
|
|
target_sub: Option<String>,
|
|
) -> Result<Option<Claims>, ServiceError>;
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError>;
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>;
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError>;
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
async fn logout(&self, token: &str) -> Result<(), ServiceError>;
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
async fn cleanup_invalidation_cache(&self);
|
|
}
|
|
|
|
/// A record of a token that has been explicitly invalidated.
///
/// Equality and hashing take all three fields into account.
#[derive(PartialEq, Eq, Hash)]
struct InvalidationEntry {
    /// The raw encoded token string.
    token: String,
    /// UNIX timestamp (seconds) at which the token was invalidated.
    invalidated_at: u64,
    /// The token's own `exp`; once it has passed the record can be dropped.
    valid_until: u64,
}
|
|
|
|
pub struct AuthenticationServiceImpl {
|
|
secret: String,
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
invalidation_cache: Arc<RwLock<HashSet<InvalidationEntry>>>,
|
|
#[allow(dead_code)] // TODO: remove when used
|
|
cache_cleanup_counter: Arc<RwLock<usize>>,
|
|
}
|
|
|
|
impl AuthenticationServiceImpl {
|
|
pub fn new(secret: Option<String>) -> Self {
|
|
let secret = secret.unwrap_or_else(|| {
|
|
// generate a random secret if none is provided
|
|
SaltString::generate(&mut OsRng).as_str().to_owned()
|
|
});
|
|
|
|
Self {
|
|
secret,
|
|
invalidation_cache: Arc::new(RwLock::new(HashSet::new())),
|
|
cache_cleanup_counter: Arc::new(RwLock::new(0)),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl AuthenticationService for AuthenticationServiceImpl {
|
|
async fn generate_jwt(
|
|
&self,
|
|
user_id: Uuid,
|
|
duration_secs: u64,
|
|
) -> Result<(String, Claims), ServiceError> {
|
|
let header = Header::default();
|
|
let expiration = chrono::Utc::now()
|
|
.checked_add_signed(chrono::Duration::seconds(duration_secs as i64))
|
|
.ok_or(ServiceError::InternalError(
|
|
"Invalid expiration time".into(),
|
|
))?
|
|
.timestamp() as u64;
|
|
let claims = Claims {
|
|
sub: user_id.to_string(),
|
|
iat: chrono::Utc::now().timestamp() as u64,
|
|
exp: expiration,
|
|
};
|
|
let token = encode(
|
|
&header,
|
|
&claims,
|
|
&EncodingKey::from_secret(self.secret.as_ref()),
|
|
)
|
|
.map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?;
|
|
Ok((token, claims))
|
|
}
|
|
|
|
async fn is_valid_jwt(
|
|
&self,
|
|
token: &str,
|
|
target_sub: Option<String>,
|
|
) -> Result<Option<Claims>, ServiceError> {
|
|
let mut validation = Validation::default();
|
|
// disable leeway for strict expiration checking
|
|
validation.leeway = 0;
|
|
if let Some(expected_sub) = target_sub {
|
|
validation.sub = Some(expected_sub);
|
|
}
|
|
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
|
|
match decode::<Claims>(token, &decoding_key, &validation) {
|
|
Ok(data) => Ok(Some(data.claims)),
|
|
Err(err) => match *err.kind() {
|
|
InvalidToken | InvalidSubject | ExpiredSignature | InvalidSignature => Ok(None),
|
|
_ => Err(ServiceError::InternalError(format!(
|
|
"JWT validation error: {}",
|
|
err
|
|
))),
|
|
},
|
|
}
|
|
}
|
|
|
|
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError> {
|
|
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
|
|
let token_data = decode::<Claims>(token, &decoding_key, &Validation::default())
|
|
.map_err(|e| ServiceError::InternalError(format!("JWT parsing error: {}", e)))?;
|
|
Ok(token_data.claims)
|
|
}
|
|
|
|
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError> {
|
|
let claims = self.parse_jwt(token).await?;
|
|
let valid_until = claims.exp;
|
|
let invalidated_at = chrono::Utc::now().timestamp() as u64;
|
|
let entry = InvalidationEntry {
|
|
token: token.to_string(),
|
|
invalidated_at,
|
|
valid_until,
|
|
};
|
|
|
|
{
|
|
self.invalidation_cache.write().await.insert(entry);
|
|
}
|
|
//
|
|
if self.cache_cleanup_counter.read().await.wrapping_add(1)
|
|
% INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS
|
|
== 0
|
|
{
|
|
self.cleanup_invalidation_cache().await;
|
|
}
|
|
//
|
|
Ok(())
|
|
}
|
|
|
|
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError> {
|
|
let claims = self.parse_jwt(token).await?;
|
|
let user_id = Uuid::parse_str(&claims.sub).map_err(|e| {
|
|
ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e))
|
|
})?;
|
|
let (new_token, _) = self.generate_jwt(user_id, duration_secs).await?;
|
|
Ok(new_token)
|
|
}
|
|
|
|
async fn logout(&self, token: &str) -> Result<(), ServiceError> {
|
|
self.invalidate_jwt(token).await
|
|
}
|
|
|
|
async fn cleanup_invalidation_cache(&self) {
|
|
let now = chrono::Utc::now().timestamp() as u64;
|
|
let mut cache = self.invalidation_cache.write().await;
|
|
cache.retain(|entry| entry.valid_until > now);
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    /// Builds a service with a fixed secret so every test signs with the same key.
    fn service() -> AuthenticationServiceImpl {
        AuthenticationServiceImpl::new(Some("secret".to_string()))
    }

    /// A freshly generated token validates and parses back to the same subject.
    #[tokio::test]
    async fn test_jwt_generation_and_validation() {
        let auth = service();
        let user = Uuid::new_v4();

        let (jwt, _) = auth.generate_jwt(user, 60).await.expect("generate jwt");

        let accepted = auth.is_valid_jwt(&jwt, None).await.expect("validate jwt");
        assert!(accepted.is_some(), "Generated token should be valid");

        let parsed = auth.parse_jwt(&jwt).await.expect("parse jwt");
        assert_eq!(parsed.sub, user.to_string());
    }

    /// Requiring a different subject makes validation reject the token.
    #[tokio::test]
    async fn test_jwt_validation_with_wrong_subject() {
        let auth = service();
        let (jwt, _) = auth.generate_jwt(Uuid::new_v4(), 60).await.unwrap();

        let stranger = Uuid::new_v4().to_string();
        let outcome = auth.is_valid_jwt(&jwt, Some(stranger)).await.unwrap();
        assert!(
            outcome.is_none(),
            "Token should be invalid for a different subject"
        );
    }

    /// Garbage input is reported as an internal parsing error.
    #[tokio::test]
    async fn test_parse_jwt_invalid_token() {
        let outcome = service().parse_jwt("not_a_token").await;
        assert!(matches!(outcome, Err(ServiceError::InternalError(_))));
    }

    /// Refreshing keeps the subject and applies the new lifetime.
    #[tokio::test]
    async fn test_refresh_jwt() {
        let auth = service();
        let user = Uuid::new_v4();

        let (original, _) = auth.generate_jwt(user, 60).await.unwrap();
        let refreshed = auth.refresh_jwt(&original, 120).await.unwrap();

        let Claims { sub, iat, exp } = auth.parse_jwt(&refreshed).await.unwrap();
        assert_eq!(sub, user.to_string());
        assert_eq!(exp - iat, 120);
    }

    /// Without leeway, a token is rejected once its lifetime has elapsed.
    #[tokio::test]
    async fn test_is_valid_jwt_expired() {
        let auth = service();
        let (jwt, claims) = auth.generate_jwt(Uuid::new_v4(), 1).await.unwrap();
        sleep(Duration::from_secs(2)).await;

        let outcome = auth.is_valid_jwt(&jwt, None).await.unwrap();
        assert!(
            outcome.is_none(),
            "Token should be expired and thus invalid. Current time: {:?}. Diff: {}",
            chrono::Utc::now(),
            chrono::Utc::now().timestamp() - claims.exp as i64
        );
    }

    /// Invalidation records the token, and cleanup drops it after expiry.
    #[tokio::test]
    async fn test_invalidate_and_cleanup() {
        let auth = service();
        let (jwt, _) = auth.generate_jwt(Uuid::new_v4(), 1).await.unwrap();

        auth.invalidate_jwt(&jwt).await.unwrap();

        // The read guard is a temporary of this statement and is released right away.
        let recorded = auth
            .invalidation_cache
            .read()
            .await
            .iter()
            .any(|entry| entry.token == jwt);
        assert!(recorded);

        // Once the token's lifetime has passed, a sweep must remove its record.
        sleep(Duration::from_secs(2)).await;
        auth.cleanup_invalidation_cache().await;

        let cache = auth.invalidation_cache.read().await;
        assert!(
            cache.is_empty(),
            "Cleanup should remove expired invalidation entries"
        );
    }
}
|