feat: implement authentication module with JWT support and user management

This commit is contained in:
GW_MC
2025-12-15 15:51:45 +08:00
parent 1233f3b736
commit 3354154b87
13 changed files with 1232 additions and 5 deletions

View File

@@ -24,5 +24,5 @@ utoipa = { version = "5.4.0", features = ["macros", "axum_extras", "chrono", "de
clap = { version = "4.5.53" }
once_cell = { version = "1.21.3" }
argon2 = { version = "0.5.3", features = ["std"] }
jsonwebtoken = { version = "10.2.0" }
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] }

View File

@@ -12,7 +12,10 @@ use crate::{
configs::{ProgramSettings, get_program_settings, logging::LoggingSettings},
log,
routes::{self, AppService, AppState},
services::settings::SettingsService,
services::{
auth::{authentication::AuthenticationServiceImpl, user::UserServiceImpl},
settings::SettingsService,
},
tasks,
};
@@ -58,6 +61,9 @@ pub async fn start_server() {
tasks::startup::run_startup_tasks(&settings)
.await
.inspect_err(|err| {
tracing::error!("Failed to run startup tasks: {}", err);
})
.expect("Failed to run startup tasks");
// setup database connection pool
@@ -78,7 +84,7 @@ pub async fn start_server() {
// build the axum app and run the server...
info!("Starting application...");
let app: Router = routes::get_root_router(Arc::new(get_app_state(&db_connection)));
let app: Router = routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings)));
let address = format!("{}:{}", settings.server.address, settings.server.port);
info!("Starting server at http://{}", address);
@@ -115,11 +121,18 @@ fn get_global_tracing_subscriber_builder(
}
}
fn get_app_state(db_connection: &Arc<sea_orm::DatabaseConnection>) -> AppState {
fn get_app_state(
db_connection: &Arc<sea_orm::DatabaseConnection>,
settings: &ProgramSettings,
) -> AppState {
AppState {
database_connection: db_connection.clone(),
service: Arc::new(AppService {
settings: Arc::new(SettingsService::new(db_connection.clone())),
authentication: Arc::new(AuthenticationServiceImpl::new(
settings.auth.jwt_secret.clone(),
)),
user: Arc::new(UserServiceImpl::new(db_connection.clone())),
}),
}
}

View File

@@ -1,3 +1,4 @@
pub mod auth;
pub mod database;
pub mod logging;
pub mod server;
@@ -17,6 +18,7 @@ pub struct ProgramSettings {
pub logging: logging::LoggingSettings,
pub database: database::DatabaseSettings,
pub server: server::ServerSettings,
pub auth: auth::AuthSettings,
}
impl FromConfig for ProgramSettings {
@@ -25,6 +27,7 @@ impl FromConfig for ProgramSettings {
logging: logging::LoggingSettings::from_config(_config)?,
database: database::DatabaseSettings::from_config(_config)?,
server: server::ServerSettings::from_config(_config)?,
auth: auth::AuthSettings::from_config(_config)?,
};
config.validate()?;
Ok(config)
@@ -34,6 +37,7 @@ impl FromConfig for ProgramSettings {
self.logging.validate()?;
self.database.validate()?;
self.server.validate()?;
self.auth.validate()?;
Ok(())
}
}

View File

@@ -0,0 +1,51 @@
use config::{Config, ConfigError};
use tracing::warn;
use crate::configs::key::{
AUTH_DEFAULT_ADMIN_PASSWORD_KEY, AUTH_DEFAULT_ADMIN_USERNAME_KEY, AUTH_JWT_SECRET_KEY,
};
use super::FromConfig;
#[derive(Debug, Clone)]
pub struct AuthSettings {
pub jwt_secret: Option<String>,
pub default_admin_username: Option<String>,
pub default_admin_password: Option<String>,
}
impl FromConfig for AuthSettings {
fn from_config(_config: &Config) -> Result<Self, String> {
Ok(AuthSettings {
jwt_secret: _config
.get_string(AUTH_JWT_SECRET_KEY)
.inspect_err(|err| {
match err {
ConfigError::NotFound(_) => {
warn!(
"{} not found in configuration, A random secret will be generated at runtime.",
AUTH_JWT_SECRET_KEY
);
}
_ => {
warn!(
"Failed to read {} from configuration, A random secret will be generated at runtime: {}",
AUTH_JWT_SECRET_KEY, err
);
}
};
})
.ok(),
default_admin_username: _config
.get_string(AUTH_DEFAULT_ADMIN_USERNAME_KEY)
.ok(),
default_admin_password: _config
.get_string(AUTH_DEFAULT_ADMIN_PASSWORD_KEY)
.ok(),
})
}
fn validate(&self) -> Result<(), String> {
Ok(())
}
}

View File

@@ -7,3 +7,7 @@ pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT";
pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL";
pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS";
pub(crate) const DATABASE_MIGRATE_ON_STARTUP_KEY: &str = "DATABASE.MIGRATION.MIGRATE_ON_STARTUP";
//
pub(crate) const AUTH_JWT_SECRET_KEY: &str = "AUTH.JWT_SECRET";
pub(crate) const AUTH_DEFAULT_ADMIN_USERNAME_KEY: &str = "AUTH.DEFAULT_ADMIN_USERNAME";
pub(crate) const AUTH_DEFAULT_ADMIN_PASSWORD_KEY: &str = "AUTH.DEFAULT_ADMIN_PASSWORD";

View File

@@ -8,7 +8,13 @@ use std::sync::Arc;
use axum::{Extension, Router};
use migration::sea_orm::DatabaseConnection;
use crate::{middlewares, services::settings::SettingsStore};
use crate::{
middlewares,
services::{
auth::{authentication::AuthenticationService, user::UserService},
settings::SettingsStore,
},
};
#[derive(Clone)]
pub struct AppState {
@@ -25,6 +31,10 @@ pub type ServiceState<T> = Arc<T>;
pub struct AppService {
#[allow(dead_code)] // TODO: remove when used
pub settings: ServiceState<dyn SettingsStore>,
#[allow(dead_code)] // TODO: remove when used
pub authentication: ServiceState<dyn AuthenticationService>,
#[allow(dead_code)] // TODO: remove when used
pub user: ServiceState<dyn UserService>,
}
pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {

View File

@@ -1 +1,2 @@
pub mod auth;
pub mod settings;

View File

@@ -0,0 +1,2 @@
pub mod authentication;
pub mod user;

View File

@@ -0,0 +1,269 @@
pub mod strategies;
use std::{collections::HashSet, sync::Arc};
use argon2::password_hash::{SaltString, rand_core::OsRng};
use jsonwebtoken::{
DecodingKey, EncodingKey, Header, Validation, decode, encode,
errors::ErrorKind::{ExpiredSignature, InvalidSubject, InvalidToken},
};
use sea_orm::prelude::Uuid;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::errors::service_error::ServiceError;
// Number of requests between invalidation cache cleanups
const INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS: usize = 100; // Cleanup every 100 for invalidation checks
#[derive(Serialize, Deserialize, Clone)]
pub struct Claims {
// Subject - user ID
pub sub: String,
// Issued at as UNIX timestamp
pub iat: u64,
// Expiration time as UNIX timestamp
pub exp: u64,
}
#[async_trait::async_trait]
pub trait AuthenticationService: Send + Sync {
async fn generate_jwt(&self, user_id: Uuid, duration_secs: u64)
-> Result<String, ServiceError>;
async fn is_valid_jwt(
&self,
token: &str,
target_sub: Option<String>,
) -> Result<bool, ServiceError>;
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError>;
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>;
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError>;
async fn logout(&self, token: &str) -> Result<(), ServiceError>;
async fn cleanup_invalidation_cache(&self);
}
#[derive(Eq, Hash, PartialEq)]
struct InvalidationEntry {
token: String,
invalidated_at: u64,
valid_until: u64,
}
pub struct AuthenticationServiceImpl {
secret: String,
invalidation_cache: Arc<RwLock<HashSet<InvalidationEntry>>>,
cache_cleanup_counter: Arc<RwLock<usize>>,
}
impl AuthenticationServiceImpl {
pub fn new(secret: Option<String>) -> Self {
let secret = secret.unwrap_or_else(|| {
// generate a random secret if none is provided
SaltString::generate(&mut OsRng).as_str().to_owned()
});
Self {
secret,
invalidation_cache: Arc::new(RwLock::new(HashSet::new())),
cache_cleanup_counter: Arc::new(RwLock::new(0)),
}
}
}
#[async_trait::async_trait]
impl AuthenticationService for AuthenticationServiceImpl {
async fn generate_jwt(
&self,
user_id: Uuid,
duration_secs: u64,
) -> Result<String, ServiceError> {
let header = Header::default();
let expiration = chrono::Utc::now()
.checked_add_signed(chrono::Duration::seconds(duration_secs as i64))
.ok_or(ServiceError::InternalError(
"Invalid expiration time".into(),
))?
.timestamp() as u64;
let claims = Claims {
sub: user_id.to_string(),
iat: chrono::Utc::now().timestamp() as u64,
exp: expiration,
};
let token = encode(
&header,
&claims,
&EncodingKey::from_secret(self.secret.as_ref()),
)
.map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?;
Ok(token)
}
async fn is_valid_jwt(
&self,
token: &str,
target_sub: Option<String>,
) -> Result<bool, ServiceError> {
let mut validation = Validation::default();
if let Some(expected_sub) = target_sub {
validation.sub = Some(expected_sub);
}
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
match decode::<Claims>(token, &decoding_key, &validation) {
Ok(_) => Ok(true),
Err(err) => match *err.kind() {
InvalidToken | InvalidSubject | ExpiredSignature => Ok(false),
_ => Err(ServiceError::InternalError(format!(
"JWT validation error: {}",
err
))),
},
}
}
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError> {
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
let token_data = decode::<Claims>(token, &decoding_key, &Validation::default())
.map_err(|e| ServiceError::InternalError(format!("JWT parsing error: {}", e)))?;
Ok(token_data.claims)
}
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError> {
let claims = self.parse_jwt(token).await?;
let valid_until = claims.exp;
let invalidated_at = chrono::Utc::now().timestamp() as u64;
let entry = InvalidationEntry {
token: token.to_string(),
invalidated_at,
valid_until,
};
{
self.invalidation_cache.write().await.insert(entry);
}
//
if self.cache_cleanup_counter.read().await.wrapping_add(1)
% INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS
== 0
{
self.cleanup_invalidation_cache().await;
}
//
Ok(())
}
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError> {
let claims = self.parse_jwt(token).await?;
let user_id = Uuid::parse_str(&claims.sub).map_err(|e| {
ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e))
})?;
let new_token = self.generate_jwt(user_id, duration_secs).await?;
Ok(new_token)
}
async fn logout(&self, token: &str) -> Result<(), ServiceError> {
self.invalidate_jwt(token).await
}
async fn cleanup_invalidation_cache(&self) {
let now = chrono::Utc::now().timestamp() as u64;
let mut cache = self.invalidation_cache.write().await;
cache.retain(|entry| entry.valid_until > now);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::time::{Duration, sleep};
#[tokio::test]
async fn test_jwt_generation_and_validation() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service
.generate_jwt(user_id, 60)
.await
.expect("generate jwt");
let valid = service
.is_valid_jwt(&token, None)
.await
.expect("validate jwt");
assert!(valid, "Generated token should be valid");
let claims = service.parse_jwt(&token).await.expect("parse jwt");
assert_eq!(claims.sub, user_id.to_string());
}
#[tokio::test]
async fn test_jwt_validation_with_wrong_subject() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 60).await.unwrap();
let other_sub = Uuid::new_v4().to_string();
let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap();
assert!(!valid, "Token should be invalid for a different subject");
}
#[tokio::test]
async fn test_parse_jwt_invalid_token() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let res = service.parse_jwt("not_a_token").await;
assert!(matches!(res, Err(ServiceError::InternalError(_))));
}
#[tokio::test]
async fn test_refresh_jwt() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 60).await.unwrap();
let new_token = service.refresh_jwt(&token, 120).await.unwrap();
let claims = service.parse_jwt(&new_token).await.unwrap();
assert_eq!(claims.sub, user_id.to_string());
assert_eq!(claims.exp - claims.iat, 120);
}
#[tokio::test]
async fn test_is_valid_jwt_expired() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 1).await.unwrap();
sleep(Duration::from_secs(2)).await;
let valid = service.is_valid_jwt(&token, None).await.unwrap();
assert!(!valid, "Token should be expired and thus invalid");
}
#[tokio::test]
async fn test_invalidate_and_cleanup() {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 1).await.unwrap();
service.invalidate_jwt(&token).await.unwrap();
// ensure entry is present
{
let cache = service.invalidation_cache.read().await;
assert!(cache.iter().any(|e| e.token == token));
}
// wait until token validity ends and cleanup
sleep(Duration::from_secs(2)).await;
service.cleanup_invalidation_cache().await;
let cache = service.invalidation_cache.read().await;
assert!(
cache.is_empty(),
"Cleanup should remove expired invalidation entries"
);
}
}

View File

@@ -0,0 +1 @@
pub mod password;

View File

@@ -0,0 +1,453 @@
use std::sync::Arc;
use crate::{errors::service_error::ServiceError, with_conn};
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng},
};
use database::generated::entities::{user, user_identity};
use sea_orm::{
ColumnTrait, DatabaseConnection, DatabaseTransaction, EntityTrait, IntoActiveModel,
QueryFilter, prelude::Uuid,
};
pub struct PasswordStrategy {
connection: Arc<DatabaseConnection>,
}
const MAX_PASSWORD_LENGTH: usize = 32;
const PASSWORD_PROVIDER: &str = "password";
impl PasswordStrategy {
pub fn new(connection: Arc<DatabaseConnection>) -> Self {
Self { connection }
}
pub async fn authenticate(
&self,
username: &str,
password: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<Uuid, ServiceError> {
// Find user by username
let user = with_conn!(&*self.connection, tx, conn, {
user::Entity::find()
.filter(user::Column::Name.eq(username))
.one(*conn)
.await?
.ok_or_else(|| {
ServiceError::Unauthorized("Invalid username or password".to_string())
})?
});
// Get user's identity
let identity = with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::find()
.filter(user_identity::Column::UserId.eq(user.id))
.one(*conn)
.await?
.ok_or_else(|| {
ServiceError::Unauthorized("Invalid username or password".to_string())
})?
});
// Check if revoked
if identity.is_revoked {
return Err(ServiceError::Unauthorized("Account is revoked".to_string()));
}
// Verify password
let password_hash = identity
.password_hash
.ok_or_else(|| ServiceError::InternalError("Invalid password hash".to_string()))?;
let parsed_hash = PasswordHash::new(&password_hash)
.map_err(|_| ServiceError::InternalError("Invalid password hash".to_string()))?;
Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.map_err(|_| ServiceError::Unauthorized("Invalid username or password".to_string()))?;
Ok(user.id)
}
pub async fn revoke_identity(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
let mut identity = with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::find()
.filter(user_identity::Column::UserId.eq(user_id))
.one(*conn)
.await?
.ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))?
});
identity.is_revoked = true;
with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::update(identity.into_active_model())
.exec(*conn)
.await
.map_err(ServiceError::from)
})?;
Ok(())
}
pub async fn create_identity(
&self,
user_id: Uuid,
password: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
Self::is_valid_password(password).map_err(ServiceError::BadRequest)?;
let password_hash = Argon2::default()
.hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))
.map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))?
.to_string();
let new_identity = user_identity::ActiveModel {
user_id: sea_orm::ActiveValue::Set(user_id),
provider: sea_orm::ActiveValue::Set(PASSWORD_PROVIDER.to_string()),
password_hash: sea_orm::ActiveValue::Set(Some(password_hash)),
metadata: sea_orm::ActiveValue::Set(None),
is_revoked: sea_orm::ActiveValue::Set(false),
..Default::default()
};
with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::insert(new_identity)
.exec(*conn)
.await
.map_err(ServiceError::from)
})?;
Ok(())
}
pub async fn update_password(
&self,
user_id: Uuid,
new_password: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
Self::is_valid_password(new_password).map_err(ServiceError::BadRequest)?;
let password_hash = Argon2::default()
.hash_password(new_password.as_bytes(), &SaltString::generate(&mut OsRng))
.map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))?
.to_string();
let mut identity = with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::find()
.filter(user_identity::Column::UserId.eq(user_id))
.one(*conn)
.await?
.ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))?
});
identity.password_hash = Some(password_hash);
identity.password_changed_at = Some(chrono::Utc::now());
with_conn!(&*self.connection, tx, conn, {
user_identity::Entity::update(identity.into_active_model())
.exec(*conn)
.await
.map_err(ServiceError::from)
})?;
Ok(())
}
fn is_valid_password(password: &str) -> Result<(), String> {
if password.is_empty() {
return Err("Password cannot be empty".to_string());
}
if password.len() > MAX_PASSWORD_LENGTH {
return Err(format!(
"Password cannot be longer than {} characters",
MAX_PASSWORD_LENGTH
));
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::*;
use database::generated::entities::{user, user_identity};
use sea_orm::MockDatabase;
#[test]
fn ensure_send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<PasswordStrategy>();
}
#[test]
fn password_validation() {
let valid_password = "ValidPassword123!";
let long_password = "a".repeat(129);
assert!(PasswordStrategy::is_valid_password(valid_password).is_ok());
assert!(PasswordStrategy::is_valid_password(long_password.as_str()).is_err());
}
#[tokio::test]
async fn authenticate_user_not_found() {
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.authenticate("nonexistent_user", "password", None)
.await;
assert!(matches!(result, Err(ServiceError::Unauthorized(_))));
}
#[tokio::test]
async fn authenticate_invalid_password() {
let user_id = Uuid::new_v4();
let password_hash = Argon2::default()
.hash_password(
"CorrectPassword".as_bytes(),
&SaltString::generate(&mut OsRng),
)
.unwrap()
.to_string();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![vec![user::Model {
id: user_id,
name: "test_user".to_string(),
is_active: true,
is_admin: false,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
deleted_at: None,
last_login_at: None,
}]])
.append_query_results(vec![vec![user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some(password_hash),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}]])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.authenticate("test_user", "InvalidPassword", None)
.await;
assert!(matches!(result, Err(ServiceError::Unauthorized(_))));
}
#[tokio::test]
async fn authenticate_success() {
let user_id = Uuid::new_v4();
let password_hash = Argon2::default()
.hash_password(
"CorrectPassword".as_bytes(),
&SaltString::generate(&mut OsRng),
)
.unwrap()
.to_string();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![vec![user::Model {
id: user_id,
name: "test_user".to_string(),
is_active: true,
is_admin: false,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
deleted_at: None,
last_login_at: None,
}]])
.append_query_results(vec![vec![user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some(password_hash),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}]])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.authenticate("test_user", "CorrectPassword", None)
.await;
assert!(matches!(result, Ok(id) if id == user_id));
}
#[tokio::test]
async fn revoke_identity_not_found() {
let user_id = Uuid::new_v4();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.revoke_identity(user_id, None).await;
assert!(matches!(result, Err(ServiceError::NotFound(_))));
}
#[tokio::test]
async fn revoke_identity_success() {
let user_id = Uuid::new_v4();
let identity = user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: None,
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
};
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![
vec![identity.clone()],
vec![user_identity::Model {
is_revoked: true,
..identity
}],
])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.revoke_identity(user_id, None).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn create_identity_invalid_password() {
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite).into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.create_identity(Uuid::new_v4(), "", None).await;
assert!(matches!(result, Err(ServiceError::BadRequest(_))));
}
#[tokio::test]
async fn create_identity_success() {
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![vec![user_identity::Model {
id: Uuid::new_v4(),
user_id: Uuid::new_v4(),
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some("somehash".to_string()),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}]])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy
.create_identity(Uuid::new_v4(), "ValidPass1!", None)
.await;
assert!(
result.is_ok(),
"Failed to create identity, error: {:?}",
result.err()
);
}
#[tokio::test]
async fn update_password_not_found() {
let user_id = Uuid::new_v4();
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![Vec::<sea_orm::MockRow>::new()])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.update_password(user_id, "NewPass1!", None).await;
assert!(matches!(result, Err(ServiceError::NotFound(_))));
}
#[tokio::test]
async fn update_password_success() {
let user_id = Uuid::new_v4();
let identity = user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some("oldhash".to_string()),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
};
let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite)
.append_query_results(vec![
vec![identity],
vec![user_identity::Model {
id: Uuid::new_v4(),
user_id,
email: None,
provider: PASSWORD_PROVIDER.to_string(),
password_hash: Some("newhash".to_string()),
metadata: None,
is_revoked: false,
revoked_at: None,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
password_changed_at: None,
}],
])
.into_connection();
let strategy = PasswordStrategy::new(Arc::new(db));
let result = strategy.update_password(user_id, "NewPass1!", None).await;
assert!(
result.is_ok(),
"Failed to update password, error: {:?}",
result.err()
);
}
}

View File

@@ -0,0 +1,208 @@
use std::sync::Arc;
use database::generated::entities::user::{
self, ActiveModel as UserActiveModel, Model as UserModel,
};
use sea_orm::{
ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, DatabaseTransaction, DbErr,
EntityTrait, IntoActiveModel, QueryFilter, prelude::Uuid,
};
use crate::{errors::service_error::ServiceError, with_conn};
#[async_trait::async_trait]
pub trait UserService: Send + Sync {
async fn get_user_by_id(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError>;
async fn is_admin(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError>;
async fn user_exists(
&self,
username: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError>;
async fn create_user(
&self,
user: NewUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError>;
async fn update_user(
&self,
user_id: Uuid,
user: UpdateUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError>;
async fn delete_user(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError>;
}
pub struct User {
pub id: Uuid,
pub username: String,
pub is_admin: bool,
}
impl From<UserModel> for User {
fn from(model: UserModel) -> Self {
Self {
id: model.id,
username: model.name,
is_admin: model.is_admin,
}
}
}
pub struct NewUser {
pub username: String,
pub is_admin: bool,
}
pub struct UpdateUser {
pub username: Option<String>,
pub is_admin: Option<bool>,
pub is_active: Option<bool>,
}
impl UpdateUser {
fn apply_to_active_model(&self, model: &mut UserActiveModel) {
if let Some(username) = &self.username {
model.name = ActiveValue::Set(username.clone());
}
if let Some(is_admin) = self.is_admin {
model.is_admin = ActiveValue::Set(is_admin);
}
if let Some(is_active) = self.is_active {
model.is_active = ActiveValue::Set(is_active);
}
}
}
pub struct UserServiceImpl {
connection: Arc<DatabaseConnection>,
}
impl UserServiceImpl {
pub fn new(connection: Arc<DatabaseConnection>) -> Self {
Self { connection }
}
async fn get_user_by_id_from_db(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<UserModel, ServiceError> {
let user = with_conn!(&*self.connection, tx, conn, {
user::Entity::find_by_id(user_id).one(*conn).await
});
match user {
Err(err) => Err(ServiceError::from(err)),
Ok(None) => Err(ServiceError::NotFound(format!(
"User with id '{}' not found",
user_id
))),
Ok(Some(record)) => Ok(record),
}
}
}
#[async_trait::async_trait]
impl UserService for UserServiceImpl {
async fn get_user_by_id(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError> {
let user = self.get_user_by_id_from_db(user_id, tx).await?;
Ok(User::from(user))
}
async fn is_admin(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError> {
let user = self.get_user_by_id(user_id, tx).await?;
Ok(user.is_admin)
}
async fn user_exists(
&self,
username: &str,
tx: Option<&mut DatabaseTransaction>,
) -> Result<bool, ServiceError> {
let user = with_conn!(&*self.connection, tx, conn, {
user::Entity::find()
.filter(user::Column::Name.eq(username))
.one(*conn)
.await
});
match user {
Err(err) => match err {
DbErr::RecordNotFound(_) => Ok(false),
_ => Err(ServiceError::from(err)),
},
Ok(None) => Ok(false),
Ok(Some(_)) => Ok(true),
}
}
async fn create_user(
&self,
user: NewUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError> {
let user_active_model = UserActiveModel {
id: ActiveValue::NotSet,
name: ActiveValue::Set(user.username),
is_admin: ActiveValue::Set(user.is_admin),
is_active: ActiveValue::Set(true),
..Default::default()
};
let user_model = with_conn!(&*self.connection, tx, conn, {
user_active_model.insert(*conn).await
})?;
Ok(User::from(user_model))
}
async fn update_user(
&self,
user_id: Uuid,
update_user: UpdateUser,
tx: Option<&mut DatabaseTransaction>,
) -> Result<User, ServiceError> {
let existing_user = self.get_user_by_id_from_db(user_id, tx).await?;
let mut user_active_model = existing_user.into_active_model();
update_user.apply_to_active_model(&mut user_active_model);
let user_model = user_active_model.update(&*self.connection).await?;
Ok(User::from(user_model))
}
async fn delete_user(
&self,
user_id: Uuid,
tx: Option<&mut DatabaseTransaction>,
) -> Result<(), ServiceError> {
let user = self.get_user_by_id_from_db(user_id, tx).await?;
let user_active_model = user.into_active_model();
user_active_model.delete(&*self.connection).await?;
Ok(())
}
}