feat: implement authentication module with JWT support and user management
This commit is contained in:
@@ -24,5 +24,5 @@ utoipa = { version = "5.4.0", features = ["macros", "axum_extras", "chrono", "de
|
||||
clap = { version = "4.5.53" }
|
||||
once_cell = { version = "1.21.3" }
|
||||
argon2 = { version = "0.5.3", features = ["std"] }
|
||||
jsonwebtoken = { version = "10.2.0" }
|
||||
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
|
||||
uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] }
|
||||
|
||||
@@ -12,7 +12,10 @@ use crate::{
|
||||
configs::{ProgramSettings, get_program_settings, logging::LoggingSettings},
|
||||
log,
|
||||
routes::{self, AppService, AppState},
|
||||
services::settings::SettingsService,
|
||||
services::{
|
||||
auth::{authentication::AuthenticationServiceImpl, user::UserServiceImpl},
|
||||
settings::SettingsService,
|
||||
},
|
||||
tasks,
|
||||
};
|
||||
|
||||
@@ -58,6 +61,9 @@ pub async fn start_server() {
|
||||
|
||||
tasks::startup::run_startup_tasks(&settings)
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
tracing::error!("Failed to run startup tasks: {}", err);
|
||||
})
|
||||
.expect("Failed to run startup tasks");
|
||||
|
||||
// setup database connection pool
|
||||
@@ -78,7 +84,7 @@ pub async fn start_server() {
|
||||
|
||||
// build the axum app and run the server...
|
||||
info!("Starting application...");
|
||||
let app: Router = routes::get_root_router(Arc::new(get_app_state(&db_connection)));
|
||||
let app: Router = routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings)));
|
||||
|
||||
let address = format!("{}:{}", settings.server.address, settings.server.port);
|
||||
info!("Starting server at http://{}", address);
|
||||
@@ -115,11 +121,18 @@ fn get_global_tracing_subscriber_builder(
|
||||
}
|
||||
}
|
||||
|
||||
fn get_app_state(db_connection: &Arc<sea_orm::DatabaseConnection>) -> AppState {
|
||||
fn get_app_state(
|
||||
db_connection: &Arc<sea_orm::DatabaseConnection>,
|
||||
settings: &ProgramSettings,
|
||||
) -> AppState {
|
||||
AppState {
|
||||
database_connection: db_connection.clone(),
|
||||
service: Arc::new(AppService {
|
||||
settings: Arc::new(SettingsService::new(db_connection.clone())),
|
||||
authentication: Arc::new(AuthenticationServiceImpl::new(
|
||||
settings.auth.jwt_secret.clone(),
|
||||
)),
|
||||
user: Arc::new(UserServiceImpl::new(db_connection.clone())),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod auth;
|
||||
pub mod database;
|
||||
pub mod logging;
|
||||
pub mod server;
|
||||
@@ -17,6 +18,7 @@ pub struct ProgramSettings {
|
||||
pub logging: logging::LoggingSettings,
|
||||
pub database: database::DatabaseSettings,
|
||||
pub server: server::ServerSettings,
|
||||
pub auth: auth::AuthSettings,
|
||||
}
|
||||
|
||||
impl FromConfig for ProgramSettings {
|
||||
@@ -25,6 +27,7 @@ impl FromConfig for ProgramSettings {
|
||||
logging: logging::LoggingSettings::from_config(_config)?,
|
||||
database: database::DatabaseSettings::from_config(_config)?,
|
||||
server: server::ServerSettings::from_config(_config)?,
|
||||
auth: auth::AuthSettings::from_config(_config)?,
|
||||
};
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
@@ -34,6 +37,7 @@ impl FromConfig for ProgramSettings {
|
||||
self.logging.validate()?;
|
||||
self.database.validate()?;
|
||||
self.server.validate()?;
|
||||
self.auth.validate()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
51
apps/api/src/configs/auth.rs
Normal file
51
apps/api/src/configs/auth.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use config::{Config, ConfigError};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::configs::key::{
|
||||
AUTH_DEFAULT_ADMIN_PASSWORD_KEY, AUTH_DEFAULT_ADMIN_USERNAME_KEY, AUTH_JWT_SECRET_KEY,
|
||||
};
|
||||
|
||||
use super::FromConfig;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthSettings {
|
||||
pub jwt_secret: Option<String>,
|
||||
pub default_admin_username: Option<String>,
|
||||
pub default_admin_password: Option<String>,
|
||||
}
|
||||
|
||||
impl FromConfig for AuthSettings {
|
||||
fn from_config(_config: &Config) -> Result<Self, String> {
|
||||
Ok(AuthSettings {
|
||||
jwt_secret: _config
|
||||
.get_string(AUTH_JWT_SECRET_KEY)
|
||||
.inspect_err(|err| {
|
||||
match err {
|
||||
ConfigError::NotFound(_) => {
|
||||
warn!(
|
||||
"{} not found in configuration, A random secret will be generated at runtime.",
|
||||
AUTH_JWT_SECRET_KEY
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
warn!(
|
||||
"Failed to read {} from configuration, A random secret will be generated at runtime: {}",
|
||||
AUTH_JWT_SECRET_KEY, err
|
||||
);
|
||||
}
|
||||
};
|
||||
})
|
||||
.ok(),
|
||||
default_admin_username: _config
|
||||
.get_string(AUTH_DEFAULT_ADMIN_USERNAME_KEY)
|
||||
.ok(),
|
||||
default_admin_password: _config
|
||||
.get_string(AUTH_DEFAULT_ADMIN_PASSWORD_KEY)
|
||||
.ok(),
|
||||
})
|
||||
}
|
||||
|
||||
fn validate(&self) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -7,3 +7,7 @@ pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT";
|
||||
pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL";
|
||||
pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS";
|
||||
pub(crate) const DATABASE_MIGRATE_ON_STARTUP_KEY: &str = "DATABASE.MIGRATION.MIGRATE_ON_STARTUP";
|
||||
//
|
||||
pub(crate) const AUTH_JWT_SECRET_KEY: &str = "AUTH.JWT_SECRET";
|
||||
pub(crate) const AUTH_DEFAULT_ADMIN_USERNAME_KEY: &str = "AUTH.DEFAULT_ADMIN_USERNAME";
|
||||
pub(crate) const AUTH_DEFAULT_ADMIN_PASSWORD_KEY: &str = "AUTH.DEFAULT_ADMIN_PASSWORD";
|
||||
|
||||
@@ -8,7 +8,13 @@ use std::sync::Arc;
|
||||
use axum::{Extension, Router};
|
||||
use migration::sea_orm::DatabaseConnection;
|
||||
|
||||
use crate::{middlewares, services::settings::SettingsStore};
|
||||
use crate::{
|
||||
middlewares,
|
||||
services::{
|
||||
auth::{authentication::AuthenticationService, user::UserService},
|
||||
settings::SettingsStore,
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@@ -25,6 +31,10 @@ pub type ServiceState<T> = Arc<T>;
|
||||
pub struct AppService {
|
||||
#[allow(dead_code)] // TODO: remove when used
|
||||
pub settings: ServiceState<dyn SettingsStore>,
|
||||
#[allow(dead_code)] // TODO: remove when used
|
||||
pub authentication: ServiceState<dyn AuthenticationService>,
|
||||
#[allow(dead_code)] // TODO: remove when used
|
||||
pub user: ServiceState<dyn UserService>,
|
||||
}
|
||||
|
||||
pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
pub mod auth;
|
||||
pub mod settings;
|
||||
|
||||
2
apps/api/src/services/auth.rs
Normal file
2
apps/api/src/services/auth.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod authentication;
|
||||
pub mod user;
|
||||
269
apps/api/src/services/auth/authentication.rs
Normal file
269
apps/api/src/services/auth/authentication.rs
Normal file
@@ -0,0 +1,269 @@
|
||||
pub mod strategies;
|
||||
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use argon2::password_hash::{SaltString, rand_core::OsRng};
|
||||
use jsonwebtoken::{
|
||||
DecodingKey, EncodingKey, Header, Validation, decode, encode,
|
||||
errors::ErrorKind::{ExpiredSignature, InvalidSubject, InvalidToken},
|
||||
};
|
||||
use sea_orm::prelude::Uuid;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::errors::service_error::ServiceError;
|
||||
|
||||
// Number of requests between invalidation cache cleanups
|
||||
const INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS: usize = 100; // Cleanup every 100 for invalidation checks
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
// Subject - user ID
|
||||
pub sub: String,
|
||||
// Issued at as UNIX timestamp
|
||||
pub iat: u64,
|
||||
// Expiration time as UNIX timestamp
|
||||
pub exp: u64,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait AuthenticationService: Send + Sync {
|
||||
async fn generate_jwt(&self, user_id: Uuid, duration_secs: u64)
|
||||
-> Result<String, ServiceError>;
|
||||
async fn is_valid_jwt(
|
||||
&self,
|
||||
token: &str,
|
||||
target_sub: Option<String>,
|
||||
) -> Result<bool, ServiceError>;
|
||||
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError>;
|
||||
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>;
|
||||
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError>;
|
||||
async fn logout(&self, token: &str) -> Result<(), ServiceError>;
|
||||
async fn cleanup_invalidation_cache(&self);
|
||||
}
|
||||
|
||||
#[derive(Eq, Hash, PartialEq)]
|
||||
struct InvalidationEntry {
|
||||
token: String,
|
||||
invalidated_at: u64,
|
||||
valid_until: u64,
|
||||
}
|
||||
|
||||
pub struct AuthenticationServiceImpl {
|
||||
secret: String,
|
||||
invalidation_cache: Arc<RwLock<HashSet<InvalidationEntry>>>,
|
||||
cache_cleanup_counter: Arc<RwLock<usize>>,
|
||||
}
|
||||
|
||||
impl AuthenticationServiceImpl {
|
||||
pub fn new(secret: Option<String>) -> Self {
|
||||
let secret = secret.unwrap_or_else(|| {
|
||||
// generate a random secret if none is provided
|
||||
SaltString::generate(&mut OsRng).as_str().to_owned()
|
||||
});
|
||||
|
||||
Self {
|
||||
secret,
|
||||
invalidation_cache: Arc::new(RwLock::new(HashSet::new())),
|
||||
cache_cleanup_counter: Arc::new(RwLock::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AuthenticationService for AuthenticationServiceImpl {
|
||||
async fn generate_jwt(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
duration_secs: u64,
|
||||
) -> Result<String, ServiceError> {
|
||||
let header = Header::default();
|
||||
let expiration = chrono::Utc::now()
|
||||
.checked_add_signed(chrono::Duration::seconds(duration_secs as i64))
|
||||
.ok_or(ServiceError::InternalError(
|
||||
"Invalid expiration time".into(),
|
||||
))?
|
||||
.timestamp() as u64;
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
iat: chrono::Utc::now().timestamp() as u64,
|
||||
exp: expiration,
|
||||
};
|
||||
let token = encode(
|
||||
&header,
|
||||
&claims,
|
||||
&EncodingKey::from_secret(self.secret.as_ref()),
|
||||
)
|
||||
.map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
async fn is_valid_jwt(
|
||||
&self,
|
||||
token: &str,
|
||||
target_sub: Option<String>,
|
||||
) -> Result<bool, ServiceError> {
|
||||
let mut validation = Validation::default();
|
||||
if let Some(expected_sub) = target_sub {
|
||||
validation.sub = Some(expected_sub);
|
||||
}
|
||||
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
|
||||
match decode::<Claims>(token, &decoding_key, &validation) {
|
||||
Ok(_) => Ok(true),
|
||||
Err(err) => match *err.kind() {
|
||||
InvalidToken | InvalidSubject | ExpiredSignature => Ok(false),
|
||||
_ => Err(ServiceError::InternalError(format!(
|
||||
"JWT validation error: {}",
|
||||
err
|
||||
))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError> {
|
||||
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
|
||||
let token_data = decode::<Claims>(token, &decoding_key, &Validation::default())
|
||||
.map_err(|e| ServiceError::InternalError(format!("JWT parsing error: {}", e)))?;
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError> {
|
||||
let claims = self.parse_jwt(token).await?;
|
||||
let valid_until = claims.exp;
|
||||
let invalidated_at = chrono::Utc::now().timestamp() as u64;
|
||||
let entry = InvalidationEntry {
|
||||
token: token.to_string(),
|
||||
invalidated_at,
|
||||
valid_until,
|
||||
};
|
||||
|
||||
{
|
||||
self.invalidation_cache.write().await.insert(entry);
|
||||
}
|
||||
//
|
||||
if self.cache_cleanup_counter.read().await.wrapping_add(1)
|
||||
% INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS
|
||||
== 0
|
||||
{
|
||||
self.cleanup_invalidation_cache().await;
|
||||
}
|
||||
//
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError> {
|
||||
let claims = self.parse_jwt(token).await?;
|
||||
let user_id = Uuid::parse_str(&claims.sub).map_err(|e| {
|
||||
ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e))
|
||||
})?;
|
||||
let new_token = self.generate_jwt(user_id, duration_secs).await?;
|
||||
Ok(new_token)
|
||||
}
|
||||
|
||||
async fn logout(&self, token: &str) -> Result<(), ServiceError> {
|
||||
self.invalidate_jwt(token).await
|
||||
}
|
||||
|
||||
async fn cleanup_invalidation_cache(&self) {
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let mut cache = self.invalidation_cache.write().await;
|
||||
cache.retain(|entry| entry.valid_until > now);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::time::{Duration, sleep};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_jwt_generation_and_validation() {
|
||||
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
|
||||
|
||||
let user_id = Uuid::new_v4();
|
||||
let token = service
|
||||
.generate_jwt(user_id, 60)
|
||||
.await
|
||||
.expect("generate jwt");
|
||||
|
||||
let valid = service
|
||||
.is_valid_jwt(&token, None)
|
||||
.await
|
||||
.expect("validate jwt");
|
||||
assert!(valid, "Generated token should be valid");
|
||||
|
||||
let claims = service.parse_jwt(&token).await.expect("parse jwt");
|
||||
assert_eq!(claims.sub, user_id.to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_jwt_validation_with_wrong_subject() {
|
||||
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
|
||||
|
||||
let user_id = Uuid::new_v4();
|
||||
let token = service.generate_jwt(user_id, 60).await.unwrap();
|
||||
|
||||
let other_sub = Uuid::new_v4().to_string();
|
||||
let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap();
|
||||
assert!(!valid, "Token should be invalid for a different subject");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_jwt_invalid_token() {
|
||||
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
|
||||
|
||||
let res = service.parse_jwt("not_a_token").await;
|
||||
assert!(matches!(res, Err(ServiceError::InternalError(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_jwt() {
|
||||
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
|
||||
|
||||
let user_id = Uuid::new_v4();
|
||||
let token = service.generate_jwt(user_id, 60).await.unwrap();
|
||||
let new_token = service.refresh_jwt(&token, 120).await.unwrap();
|
||||
|
||||
let claims = service.parse_jwt(&new_token).await.unwrap();
|
||||
assert_eq!(claims.sub, user_id.to_string());
|
||||
assert_eq!(claims.exp - claims.iat, 120);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_is_valid_jwt_expired() {
|
||||
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
|
||||
|
||||
let user_id = Uuid::new_v4();
|
||||
let token = service.generate_jwt(user_id, 1).await.unwrap();
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let valid = service.is_valid_jwt(&token, None).await.unwrap();
|
||||
assert!(!valid, "Token should be expired and thus invalid");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalidate_and_cleanup() {
|
||||
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
|
||||
|
||||
let user_id = Uuid::new_v4();
|
||||
let token = service.generate_jwt(user_id, 1).await.unwrap();
|
||||
|
||||
service.invalidate_jwt(&token).await.unwrap();
|
||||
|
||||
// ensure entry is present
|
||||
{
|
||||
let cache = service.invalidation_cache.read().await;
|
||||
assert!(cache.iter().any(|e| e.token == token));
|
||||
}
|
||||
|
||||
// wait until token validity ends and cleanup
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
service.cleanup_invalidation_cache().await;
|
||||
|
||||
let cache = service.invalidation_cache.read().await;
|
||||
assert!(
|
||||
cache.is_empty(),
|
||||
"Cleanup should remove expired invalidation entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
1
apps/api/src/services/auth/authentication/strategies.rs
Normal file
1
apps/api/src/services/auth/authentication/strategies.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod password;
|
||||
453
apps/api/src/services/auth/authentication/strategies/password.rs
Normal file
453
apps/api/src/services/auth/authentication/strategies/password.rs
Normal file
@@ -0,0 +1,453 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{errors::service_error::ServiceError, with_conn};
|
||||
use argon2::{
|
||||
Argon2,
|
||||
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
|
||||
};
|
||||
use database::generated::entities::{user, user_identity};
|
||||
use sea_orm::{
|
||||
ColumnTrait, DatabaseConnection, DatabaseTransaction, EntityTrait, IntoActiveModel,
|
||||
QueryFilter, prelude::Uuid,
|
||||
};
|
||||
|
||||
pub struct PasswordStrategy {
|
||||
connection: Arc<DatabaseConnection>,
|
||||
}
|
||||
|
||||
const MAX_PASSWORD_LENGTH: usize = 32;
|
||||
const PASSWORD_PROVIDER: &str = "password";
|
||||
|
||||
impl PasswordStrategy {
|
||||
pub fn new(connection: Arc<DatabaseConnection>) -> Self {
|
||||
Self { connection }
|
||||
}
|
||||
|
||||
pub async fn authenticate(
|
||||
&self,
|
||||
username: &str,
|
||||
password: &str,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<Uuid, ServiceError> {
|
||||
// Find user by username
|
||||
let user = with_conn!(&*self.connection, tx, conn, {
|
||||
user::Entity::find()
|
||||
.filter(user::Column::Name.eq(username))
|
||||
.one(*conn)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ServiceError::Unauthorized("Invalid username or password".to_string())
|
||||
})?
|
||||
});
|
||||
// Get user's identity
|
||||
let identity = with_conn!(&*self.connection, tx, conn, {
|
||||
user_identity::Entity::find()
|
||||
.filter(user_identity::Column::UserId.eq(user.id))
|
||||
.one(*conn)
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
ServiceError::Unauthorized("Invalid username or password".to_string())
|
||||
})?
|
||||
});
|
||||
|
||||
// Check if revoked
|
||||
if identity.is_revoked {
|
||||
return Err(ServiceError::Unauthorized("Account is revoked".to_string()));
|
||||
}
|
||||
|
||||
// Verify password
|
||||
let password_hash = identity
|
||||
.password_hash
|
||||
.ok_or_else(|| ServiceError::InternalError("Invalid password hash".to_string()))?;
|
||||
let parsed_hash = PasswordHash::new(&password_hash)
|
||||
.map_err(|_| ServiceError::InternalError("Invalid password hash".to_string()))?;
|
||||
|
||||
Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.map_err(|_| ServiceError::Unauthorized("Invalid username or password".to_string()))?;
|
||||
|
||||
Ok(user.id)
|
||||
}
|
||||
|
||||
pub async fn revoke_identity(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<(), ServiceError> {
|
||||
let mut identity = with_conn!(&*self.connection, tx, conn, {
|
||||
user_identity::Entity::find()
|
||||
.filter(user_identity::Column::UserId.eq(user_id))
|
||||
.one(*conn)
|
||||
.await?
|
||||
.ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))?
|
||||
});
|
||||
|
||||
identity.is_revoked = true;
|
||||
|
||||
with_conn!(&*self.connection, tx, conn, {
|
||||
user_identity::Entity::update(identity.into_active_model())
|
||||
.exec(*conn)
|
||||
.await
|
||||
.map_err(ServiceError::from)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_identity(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
password: &str,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<(), ServiceError> {
|
||||
Self::is_valid_password(password).map_err(ServiceError::BadRequest)?;
|
||||
|
||||
let password_hash = Argon2::default()
|
||||
.hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))
|
||||
.map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))?
|
||||
.to_string();
|
||||
|
||||
let new_identity = user_identity::ActiveModel {
|
||||
user_id: sea_orm::ActiveValue::Set(user_id),
|
||||
provider: sea_orm::ActiveValue::Set(PASSWORD_PROVIDER.to_string()),
|
||||
password_hash: sea_orm::ActiveValue::Set(Some(password_hash)),
|
||||
metadata: sea_orm::ActiveValue::Set(None),
|
||||
is_revoked: sea_orm::ActiveValue::Set(false),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
with_conn!(&*self.connection, tx, conn, {
|
||||
user_identity::Entity::insert(new_identity)
|
||||
.exec(*conn)
|
||||
.await
|
||||
.map_err(ServiceError::from)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_password(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
new_password: &str,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<(), ServiceError> {
|
||||
Self::is_valid_password(new_password).map_err(ServiceError::BadRequest)?;
|
||||
|
||||
let password_hash = Argon2::default()
|
||||
.hash_password(new_password.as_bytes(), &SaltString::generate(&mut OsRng))
|
||||
.map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))?
|
||||
.to_string();
|
||||
|
||||
let mut identity = with_conn!(&*self.connection, tx, conn, {
|
||||
user_identity::Entity::find()
|
||||
.filter(user_identity::Column::UserId.eq(user_id))
|
||||
.one(*conn)
|
||||
.await?
|
||||
.ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))?
|
||||
});
|
||||
|
||||
identity.password_hash = Some(password_hash);
|
||||
identity.password_changed_at = Some(chrono::Utc::now());
|
||||
|
||||
with_conn!(&*self.connection, tx, conn, {
|
||||
user_identity::Entity::update(identity.into_active_model())
|
||||
.exec(*conn)
|
||||
.await
|
||||
.map_err(ServiceError::from)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_valid_password(password: &str) -> Result<(), String> {
|
||||
if password.is_empty() {
|
||||
return Err("Password cannot be empty".to_string());
|
||||
}
|
||||
if password.len() > MAX_PASSWORD_LENGTH {
|
||||
return Err(format!(
|
||||
"Password cannot be longer than {} characters",
|
||||
MAX_PASSWORD_LENGTH
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use database::generated::entities::{user, user_identity};
|
||||
use sea_orm::MockDatabase;
|
||||
|
||||
#[test]
|
||||
fn ensure_send_sync() {
|
||||
fn assert_send_sync<T: Send + Sync>() {}
|
||||
assert_send_sync::<PasswordStrategy>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn password_validation() {
|
||||
let valid_password = "ValidPassword123!";
|
||||
let long_password = "a".repeat(129);
|
||||
|
||||
assert!(PasswordStrategy::is_valid_password(valid_password).is_ok());
|
||||
assert!(PasswordStrategy::is_valid_password(long_password.as_str()).is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn authenticate_user_not_found() {
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy
|
||||
.authenticate("nonexistent_user", "password", None)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ServiceError::Unauthorized(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn authenticate_invalid_password() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let password_hash = Argon2::default()
|
||||
.hash_password(
|
||||
"CorrectPassword".as_bytes(),
|
||||
&SaltString::generate(&mut OsRng),
|
||||
)
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![vec![user::Model {
|
||||
id: user_id,
|
||||
name: "test_user".to_string(),
|
||||
is_active: true,
|
||||
is_admin: false,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
deleted_at: None,
|
||||
last_login_at: None,
|
||||
}]])
|
||||
.append_query_results(vec![vec![user_identity::Model {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
email: None,
|
||||
provider: PASSWORD_PROVIDER.to_string(),
|
||||
password_hash: Some(password_hash),
|
||||
metadata: None,
|
||||
is_revoked: false,
|
||||
revoked_at: None,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
password_changed_at: None,
|
||||
}]])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy
|
||||
.authenticate("test_user", "InvalidPassword", None)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ServiceError::Unauthorized(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn authenticate_success() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let password_hash = Argon2::default()
|
||||
.hash_password(
|
||||
"CorrectPassword".as_bytes(),
|
||||
&SaltString::generate(&mut OsRng),
|
||||
)
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![vec![user::Model {
|
||||
id: user_id,
|
||||
name: "test_user".to_string(),
|
||||
is_active: true,
|
||||
is_admin: false,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
deleted_at: None,
|
||||
last_login_at: None,
|
||||
}]])
|
||||
.append_query_results(vec![vec![user_identity::Model {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
email: None,
|
||||
provider: PASSWORD_PROVIDER.to_string(),
|
||||
password_hash: Some(password_hash),
|
||||
metadata: None,
|
||||
is_revoked: false,
|
||||
revoked_at: None,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
password_changed_at: None,
|
||||
}]])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy
|
||||
.authenticate("test_user", "CorrectPassword", None)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Ok(id) if id == user_id));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn revoke_identity_not_found() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy.revoke_identity(user_id, None).await;
|
||||
|
||||
assert!(matches!(result, Err(ServiceError::NotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn revoke_identity_success() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let identity = user_identity::Model {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
email: None,
|
||||
provider: PASSWORD_PROVIDER.to_string(),
|
||||
password_hash: None,
|
||||
metadata: None,
|
||||
is_revoked: false,
|
||||
revoked_at: None,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
password_changed_at: None,
|
||||
};
|
||||
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![
|
||||
vec![identity.clone()],
|
||||
vec![user_identity::Model {
|
||||
is_revoked: true,
|
||||
..identity
|
||||
}],
|
||||
])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy.revoke_identity(user_id, None).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_identity_invalid_password() {
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite).into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy.create_identity(Uuid::new_v4(), "", None).await;
|
||||
|
||||
assert!(matches!(result, Err(ServiceError::BadRequest(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_identity_success() {
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![vec![user_identity::Model {
|
||||
id: Uuid::new_v4(),
|
||||
user_id: Uuid::new_v4(),
|
||||
email: None,
|
||||
provider: PASSWORD_PROVIDER.to_string(),
|
||||
password_hash: Some("somehash".to_string()),
|
||||
metadata: None,
|
||||
is_revoked: false,
|
||||
revoked_at: None,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
password_changed_at: None,
|
||||
}]])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy
|
||||
.create_identity(Uuid::new_v4(), "ValidPass1!", None)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Failed to create identity, error: {:?}",
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_password_not_found() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy.update_password(user_id, "NewPass1!", None).await;
|
||||
|
||||
assert!(matches!(result, Err(ServiceError::NotFound(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_password_success() {
|
||||
let user_id = Uuid::new_v4();
|
||||
let identity = user_identity::Model {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
email: None,
|
||||
provider: PASSWORD_PROVIDER.to_string(),
|
||||
password_hash: Some("oldhash".to_string()),
|
||||
metadata: None,
|
||||
is_revoked: false,
|
||||
revoked_at: None,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
password_changed_at: None,
|
||||
};
|
||||
|
||||
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
|
||||
.append_query_results(vec![
|
||||
vec![identity],
|
||||
vec![user_identity::Model {
|
||||
id: Uuid::new_v4(),
|
||||
user_id,
|
||||
email: None,
|
||||
provider: PASSWORD_PROVIDER.to_string(),
|
||||
password_hash: Some("newhash".to_string()),
|
||||
metadata: None,
|
||||
is_revoked: false,
|
||||
revoked_at: None,
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
password_changed_at: None,
|
||||
}],
|
||||
])
|
||||
.into_connection();
|
||||
|
||||
let strategy = PasswordStrategy::new(Arc::new(db));
|
||||
|
||||
let result = strategy.update_password(user_id, "NewPass1!", None).await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Failed to update password, error: {:?}",
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
}
|
||||
208
apps/api/src/services/auth/user.rs
Normal file
208
apps/api/src/services/auth/user.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use database::generated::entities::user::{
|
||||
self, ActiveModel as UserActiveModel, Model as UserModel,
|
||||
};
|
||||
use sea_orm::{
|
||||
ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, DatabaseTransaction, DbErr,
|
||||
EntityTrait, IntoActiveModel, QueryFilter, prelude::Uuid,
|
||||
};
|
||||
|
||||
use crate::{errors::service_error::ServiceError, with_conn};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait UserService: Send + Sync {
|
||||
async fn get_user_by_id(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<User, ServiceError>;
|
||||
async fn is_admin(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<bool, ServiceError>;
|
||||
async fn user_exists(
|
||||
&self,
|
||||
username: &str,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<bool, ServiceError>;
|
||||
async fn create_user(
|
||||
&self,
|
||||
user: NewUser,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<User, ServiceError>;
|
||||
async fn update_user(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
user: UpdateUser,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<User, ServiceError>;
|
||||
async fn delete_user(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<(), ServiceError>;
|
||||
}
|
||||
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub is_admin: bool,
|
||||
}
|
||||
|
||||
impl From<UserModel> for User {
|
||||
fn from(model: UserModel) -> Self {
|
||||
Self {
|
||||
id: model.id,
|
||||
username: model.name,
|
||||
is_admin: model.is_admin,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NewUser {
|
||||
pub username: String,
|
||||
pub is_admin: bool,
|
||||
}
|
||||
|
||||
pub struct UpdateUser {
|
||||
pub username: Option<String>,
|
||||
pub is_admin: Option<bool>,
|
||||
pub is_active: Option<bool>,
|
||||
}
|
||||
|
||||
impl UpdateUser {
|
||||
fn apply_to_active_model(&self, model: &mut UserActiveModel) {
|
||||
if let Some(username) = &self.username {
|
||||
model.name = ActiveValue::Set(username.clone());
|
||||
}
|
||||
if let Some(is_admin) = self.is_admin {
|
||||
model.is_admin = ActiveValue::Set(is_admin);
|
||||
}
|
||||
if let Some(is_active) = self.is_active {
|
||||
model.is_active = ActiveValue::Set(is_active);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UserServiceImpl {
|
||||
connection: Arc<DatabaseConnection>,
|
||||
}
|
||||
|
||||
impl UserServiceImpl {
|
||||
pub fn new(connection: Arc<DatabaseConnection>) -> Self {
|
||||
Self { connection }
|
||||
}
|
||||
|
||||
async fn get_user_by_id_from_db(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<UserModel, ServiceError> {
|
||||
let user = with_conn!(&*self.connection, tx, conn, {
|
||||
user::Entity::find_by_id(user_id).one(*conn).await
|
||||
});
|
||||
|
||||
match user {
|
||||
Err(err) => Err(ServiceError::from(err)),
|
||||
Ok(None) => Err(ServiceError::NotFound(format!(
|
||||
"User with id '{}' not found",
|
||||
user_id
|
||||
))),
|
||||
Ok(Some(record)) => Ok(record),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UserService for UserServiceImpl {
|
||||
async fn get_user_by_id(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<User, ServiceError> {
|
||||
let user = self.get_user_by_id_from_db(user_id, tx).await?;
|
||||
Ok(User::from(user))
|
||||
}
|
||||
|
||||
async fn is_admin(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<bool, ServiceError> {
|
||||
let user = self.get_user_by_id(user_id, tx).await?;
|
||||
Ok(user.is_admin)
|
||||
}
|
||||
|
||||
async fn user_exists(
|
||||
&self,
|
||||
username: &str,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<bool, ServiceError> {
|
||||
let user = with_conn!(&*self.connection, tx, conn, {
|
||||
user::Entity::find()
|
||||
.filter(user::Column::Name.eq(username))
|
||||
.one(*conn)
|
||||
.await
|
||||
});
|
||||
|
||||
match user {
|
||||
Err(err) => match err {
|
||||
DbErr::RecordNotFound(_) => Ok(false),
|
||||
_ => Err(ServiceError::from(err)),
|
||||
},
|
||||
Ok(None) => Ok(false),
|
||||
Ok(Some(_)) => Ok(true),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_user(
|
||||
&self,
|
||||
user: NewUser,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<User, ServiceError> {
|
||||
let user_active_model = UserActiveModel {
|
||||
id: ActiveValue::NotSet,
|
||||
name: ActiveValue::Set(user.username),
|
||||
is_admin: ActiveValue::Set(user.is_admin),
|
||||
is_active: ActiveValue::Set(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let user_model = with_conn!(&*self.connection, tx, conn, {
|
||||
user_active_model.insert(*conn).await
|
||||
})?;
|
||||
|
||||
Ok(User::from(user_model))
|
||||
}
|
||||
|
||||
async fn update_user(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
update_user: UpdateUser,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<User, ServiceError> {
|
||||
let existing_user = self.get_user_by_id_from_db(user_id, tx).await?;
|
||||
|
||||
let mut user_active_model = existing_user.into_active_model();
|
||||
update_user.apply_to_active_model(&mut user_active_model);
|
||||
|
||||
let user_model = user_active_model.update(&*self.connection).await?;
|
||||
|
||||
Ok(User::from(user_model))
|
||||
}
|
||||
|
||||
async fn delete_user(
|
||||
&self,
|
||||
user_id: Uuid,
|
||||
tx: Option<&mut DatabaseTransaction>,
|
||||
) -> Result<(), ServiceError> {
|
||||
let user = self.get_user_by_id_from_db(user_id, tx).await?;
|
||||
|
||||
let user_active_model = user.into_active_model();
|
||||
user_active_model.delete(&*self.connection).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user