diff --git a/Cargo.lock b/Cargo.lock index 63168a3..06944b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -235,6 +235,12 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.21.7" @@ -672,6 +678,18 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -682,6 +700,33 @@ dependencies = [ "typenum", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "digest", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.110", +] + [[package]] name = "darling" version = "0.20.11" @@ -862,6 +907,44 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + +[[package]] +name = "ed25519" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" +dependencies = [ + "pkcs8", + "signature", +] + +[[package]] +name = "ed25519-dalek" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" +dependencies = [ + "curve25519-dalek", + "ed25519", + "serde", + "sha2", + "subtle", + "zeroize", +] + [[package]] name = "either" version = "1.15.0" @@ -871,6 +954,27 @@ dependencies = [ "serde", ] +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest", + "ff", + "generic-array", + "group", + "hkdf", + "pem-rfc7468", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + [[package]] name = "encoding_rs" version = "0.8.35" @@ -946,6 +1050,22 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "filetime" version = "0.2.26" @@ -1125,6 +1245,7 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", ] [[package]] @@ -1156,6 +1277,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "h2" version = "0.4.12" @@ -1640,11 +1772,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e" dependencies = [ "base64 0.22.1", + "ed25519-dalek", "getrandom 0.2.16", + "hmac", "js-sys", + "p256", + "p384", "pem", + "rand 0.8.5", + "rsa", "serde", "serde_json", + "sha2", "signature", "simple_asn1", ] @@ -1974,6 +2113,30 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "p256" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + +[[package]] +name = "p384" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" +dependencies = [ + "ecdsa", + "elliptic-curve", + "primeorder", + "sha2", +] + [[package]] name = "parking" version = "2.2.1" @@ -2211,6 +2374,15 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve", +] + [[package]] name = "proc-macro-crate" version = "3.4.0" @@ -2440,6 +2612,16 @@ dependencies = [ "bytecheck", ] +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + [[package]] name = "ring" version = "0.17.14" @@ -2543,6 +2725,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.1.2" @@ -2852,6 +3043,20 @@ version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + [[package]] name = "security-framework" version = "2.11.1" @@ -2888,6 +3093,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" diff --git a/apps/api/Cargo.toml b/apps/api/Cargo.toml index a86ee30..3ec1a3d 100644 --- a/apps/api/Cargo.toml +++ b/apps/api/Cargo.toml @@ -24,5 +24,5 @@ utoipa = { version = "5.4.0", features = ["macros", "axum_extras", "chrono", "de clap = { version = "4.5.53" } once_cell = { version = "1.21.3" } argon2 = { version = "0.5.3", features = ["std"] } -jsonwebtoken = { version = "10.2.0" } +jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] } uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] } diff --git a/apps/api/src/cmd/start_server.rs b/apps/api/src/cmd/start_server.rs index 461c0b3..d1d1ce8 100644 --- a/apps/api/src/cmd/start_server.rs +++ b/apps/api/src/cmd/start_server.rs @@ -12,7 +12,10 @@ use crate::{ configs::{ProgramSettings, get_program_settings, logging::LoggingSettings}, log, routes::{self, AppService, AppState}, - services::settings::SettingsService, + services::{ + auth::{authentication::AuthenticationServiceImpl, user::UserServiceImpl}, + settings::SettingsService, + }, tasks, }; @@ -58,6 +61,9 @@ pub async fn start_server() { tasks::startup::run_startup_tasks(&settings) .await + .inspect_err(|err| { + tracing::error!("Failed to run startup tasks: {}", err); + }) .expect("Failed to run startup tasks"); // setup database connection pool @@ -78,7 +84,7 @@ pub async fn start_server() { // build the axum app and run the server... info!("Starting application..."); - let app: Router = routes::get_root_router(Arc::new(get_app_state(&db_connection))); + let app: Router = routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings))); let address = format!("{}:{}", settings.server.address, settings.server.port); info!("Starting server at http://{}", address); @@ -115,11 +121,18 @@ fn get_global_tracing_subscriber_builder( } } -fn get_app_state(db_connection: &Arc) -> AppState { +fn get_app_state( + db_connection: &Arc, + settings: &ProgramSettings, +) -> AppState { AppState { database_connection: db_connection.clone(), service: Arc::new(AppService { settings: Arc::new(SettingsService::new(db_connection.clone())), + authentication: Arc::new(AuthenticationServiceImpl::new( + settings.auth.jwt_secret.clone(), + )), + user: Arc::new(UserServiceImpl::new(db_connection.clone())), }), } } diff --git a/apps/api/src/configs.rs b/apps/api/src/configs.rs index cae85de..6c274d6 100644 --- a/apps/api/src/configs.rs +++ b/apps/api/src/configs.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod database; pub mod logging; pub mod server; @@ -17,6 +18,7 @@ pub struct ProgramSettings { pub logging: logging::LoggingSettings, pub database: database::DatabaseSettings, pub server: server::ServerSettings, + pub auth: auth::AuthSettings, } impl FromConfig for ProgramSettings { @@ -25,6 +27,7 @@ impl FromConfig for ProgramSettings { logging: logging::LoggingSettings::from_config(_config)?, database: database::DatabaseSettings::from_config(_config)?, server: server::ServerSettings::from_config(_config)?, + auth: auth::AuthSettings::from_config(_config)?, }; config.validate()?; Ok(config) @@ -34,6 +37,7 @@ impl FromConfig for ProgramSettings { self.logging.validate()?; self.database.validate()?; self.server.validate()?; + self.auth.validate()?; Ok(()) } } diff --git a/apps/api/src/configs/auth.rs b/apps/api/src/configs/auth.rs new file mode 100644 index 0000000..4041092 --- /dev/null +++ b/apps/api/src/configs/auth.rs @@ -0,0 +1,51 @@ +use config::{Config, ConfigError}; +use tracing::warn; + +use crate::configs::key::{ + AUTH_DEFAULT_ADMIN_PASSWORD_KEY, AUTH_DEFAULT_ADMIN_USERNAME_KEY, AUTH_JWT_SECRET_KEY, +}; + +use super::FromConfig; + +#[derive(Debug, Clone)] +pub struct AuthSettings { + pub jwt_secret: Option, + pub default_admin_username: Option, + pub default_admin_password: Option, +} + +impl FromConfig for AuthSettings { + fn from_config(_config: &Config) -> Result { + Ok(AuthSettings { + jwt_secret: _config + .get_string(AUTH_JWT_SECRET_KEY) + .inspect_err(|err| { + match err { + ConfigError::NotFound(_) => { + warn!( + "{} not found in configuration, A random secret will be generated at runtime.", + AUTH_JWT_SECRET_KEY + ); + } + _ => { + warn!( + "Failed to read {} from configuration, A random secret will be generated at runtime: {}", + AUTH_JWT_SECRET_KEY, err + ); + } + }; + }) + .ok(), + default_admin_username: _config + .get_string(AUTH_DEFAULT_ADMIN_USERNAME_KEY) + .ok(), + default_admin_password: _config + .get_string(AUTH_DEFAULT_ADMIN_PASSWORD_KEY) + .ok(), + }) + } + + fn validate(&self) -> Result<(), String> { + Ok(()) + } +} diff --git a/apps/api/src/configs/key.rs b/apps/api/src/configs/key.rs index dd31902..c4a0ee9 100644 --- a/apps/api/src/configs/key.rs +++ b/apps/api/src/configs/key.rs @@ -7,3 +7,7 @@ pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT"; pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL"; pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS"; pub(crate) const DATABASE_MIGRATE_ON_STARTUP_KEY: &str = "DATABASE.MIGRATION.MIGRATE_ON_STARTUP"; +// +pub(crate) const AUTH_JWT_SECRET_KEY: &str = "AUTH.JWT_SECRET"; +pub(crate) const AUTH_DEFAULT_ADMIN_USERNAME_KEY: &str = "AUTH.DEFAULT_ADMIN_USERNAME"; +pub(crate) const AUTH_DEFAULT_ADMIN_PASSWORD_KEY: &str = "AUTH.DEFAULT_ADMIN_PASSWORD"; diff --git a/apps/api/src/routes.rs b/apps/api/src/routes.rs index f257f1b..59e181e 100644 --- a/apps/api/src/routes.rs +++ b/apps/api/src/routes.rs @@ -8,7 +8,13 @@ use std::sync::Arc; use axum::{Extension, Router}; use migration::sea_orm::DatabaseConnection; -use crate::{middlewares, services::settings::SettingsStore}; +use crate::{ + middlewares, + services::{ + auth::{authentication::AuthenticationService, user::UserService}, + settings::SettingsStore, + }, +}; #[derive(Clone)] pub struct AppState { @@ -25,6 +31,10 @@ pub type ServiceState = Arc; pub struct AppService { #[allow(dead_code)] // TODO: remove when used pub settings: ServiceState, + #[allow(dead_code)] // TODO: remove when used + pub authentication: ServiceState, + #[allow(dead_code)] // TODO: remove when used + pub user: ServiceState, } pub fn get_root_router(state: impl Into>) -> Router { diff --git a/apps/api/src/services.rs b/apps/api/src/services.rs index 6e98cef..f7da917 100644 --- a/apps/api/src/services.rs +++ b/apps/api/src/services.rs @@ -1 +1,2 @@ +pub mod auth; pub mod settings; diff --git a/apps/api/src/services/auth.rs b/apps/api/src/services/auth.rs new file mode 100644 index 0000000..0e84b15 --- /dev/null +++ b/apps/api/src/services/auth.rs @@ -0,0 +1,2 @@ +pub mod authentication; +pub mod user; diff --git a/apps/api/src/services/auth/authentication.rs b/apps/api/src/services/auth/authentication.rs new file mode 100644 index 0000000..2bbaae1 --- /dev/null +++ b/apps/api/src/services/auth/authentication.rs @@ -0,0 +1,269 @@ +pub mod strategies; + +use std::{collections::HashSet, sync::Arc}; + +use argon2::password_hash::{SaltString, rand_core::OsRng}; +use jsonwebtoken::{ + DecodingKey, EncodingKey, Header, Validation, decode, encode, + errors::ErrorKind::{ExpiredSignature, InvalidSubject, InvalidToken}, +}; +use sea_orm::prelude::Uuid; +use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; + +use crate::errors::service_error::ServiceError; + +// Number of requests between invalidation cache cleanups +const INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS: usize = 100; // Cleanup every 100 for invalidation checks + +#[derive(Serialize, Deserialize, Clone)] +pub struct Claims { + // Subject - user ID + pub sub: String, + // Issued at as UNIX timestamp + pub iat: u64, + // Expiration time as UNIX timestamp + pub exp: u64, +} + +#[async_trait::async_trait] +pub trait AuthenticationService: Send + Sync { + async fn generate_jwt(&self, user_id: Uuid, duration_secs: u64) + -> Result; + async fn is_valid_jwt( + &self, + token: &str, + target_sub: Option, + ) -> Result; + async fn parse_jwt(&self, token: &str) -> Result; + async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>; + async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result; + async fn logout(&self, token: &str) -> Result<(), ServiceError>; + async fn cleanup_invalidation_cache(&self); +} + +#[derive(Eq, Hash, PartialEq)] +struct InvalidationEntry { + token: String, + invalidated_at: u64, + valid_until: u64, +} + +pub struct AuthenticationServiceImpl { + secret: String, + invalidation_cache: Arc>>, + cache_cleanup_counter: Arc>, +} + +impl AuthenticationServiceImpl { + pub fn new(secret: Option) -> Self { + let secret = secret.unwrap_or_else(|| { + // generate a random secret if none is provided + SaltString::generate(&mut OsRng).as_str().to_owned() + }); + + Self { + secret, + invalidation_cache: Arc::new(RwLock::new(HashSet::new())), + cache_cleanup_counter: Arc::new(RwLock::new(0)), + } + } +} + +#[async_trait::async_trait] +impl AuthenticationService for AuthenticationServiceImpl { + async fn generate_jwt( + &self, + user_id: Uuid, + duration_secs: u64, + ) -> Result { + let header = Header::default(); + let expiration = chrono::Utc::now() + .checked_add_signed(chrono::Duration::seconds(duration_secs as i64)) + .ok_or(ServiceError::InternalError( + "Invalid expiration time".into(), + ))? + .timestamp() as u64; + let claims = Claims { + sub: user_id.to_string(), + iat: chrono::Utc::now().timestamp() as u64, + exp: expiration, + }; + let token = encode( + &header, + &claims, + &EncodingKey::from_secret(self.secret.as_ref()), + ) + .map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?; + Ok(token) + } + + async fn is_valid_jwt( + &self, + token: &str, + target_sub: Option, + ) -> Result { + let mut validation = Validation::default(); + if let Some(expected_sub) = target_sub { + validation.sub = Some(expected_sub); + } + let decoding_key = DecodingKey::from_secret(self.secret.as_ref()); + match decode::(token, &decoding_key, &validation) { + Ok(_) => Ok(true), + Err(err) => match *err.kind() { + InvalidToken | InvalidSubject | ExpiredSignature => Ok(false), + _ => Err(ServiceError::InternalError(format!( + "JWT validation error: {}", + err + ))), + }, + } + } + + async fn parse_jwt(&self, token: &str) -> Result { + let decoding_key = DecodingKey::from_secret(self.secret.as_ref()); + let token_data = decode::(token, &decoding_key, &Validation::default()) + .map_err(|e| ServiceError::InternalError(format!("JWT parsing error: {}", e)))?; + Ok(token_data.claims) + } + + async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError> { + let claims = self.parse_jwt(token).await?; + let valid_until = claims.exp; + let invalidated_at = chrono::Utc::now().timestamp() as u64; + let entry = InvalidationEntry { + token: token.to_string(), + invalidated_at, + valid_until, + }; + + { + self.invalidation_cache.write().await.insert(entry); + } + // + if self.cache_cleanup_counter.read().await.wrapping_add(1) + % INVALIDATE_CACHE_CLEANUP_INTERVAL_REQUESTS + == 0 + { + self.cleanup_invalidation_cache().await; + } + // + Ok(()) + } + + async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result { + let claims = self.parse_jwt(token).await?; + let user_id = Uuid::parse_str(&claims.sub).map_err(|e| { + ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e)) + })?; + let new_token = self.generate_jwt(user_id, duration_secs).await?; + Ok(new_token) + } + + async fn logout(&self, token: &str) -> Result<(), ServiceError> { + self.invalidate_jwt(token).await + } + + async fn cleanup_invalidation_cache(&self) { + let now = chrono::Utc::now().timestamp() as u64; + let mut cache = self.invalidation_cache.write().await; + cache.retain(|entry| entry.valid_until > now); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::{Duration, sleep}; + + #[tokio::test] + async fn test_jwt_generation_and_validation() { + let service = AuthenticationServiceImpl::new(Some("secret".to_string())); + + let user_id = Uuid::new_v4(); + let token = service + .generate_jwt(user_id, 60) + .await + .expect("generate jwt"); + + let valid = service + .is_valid_jwt(&token, None) + .await + .expect("validate jwt"); + assert!(valid, "Generated token should be valid"); + + let claims = service.parse_jwt(&token).await.expect("parse jwt"); + assert_eq!(claims.sub, user_id.to_string()); + } + + #[tokio::test] + async fn test_jwt_validation_with_wrong_subject() { + let service = AuthenticationServiceImpl::new(Some("secret".to_string())); + + let user_id = Uuid::new_v4(); + let token = service.generate_jwt(user_id, 60).await.unwrap(); + + let other_sub = Uuid::new_v4().to_string(); + let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap(); + assert!(!valid, "Token should be invalid for a different subject"); + } + + #[tokio::test] + async fn test_parse_jwt_invalid_token() { + let service = AuthenticationServiceImpl::new(Some("secret".to_string())); + + let res = service.parse_jwt("not_a_token").await; + assert!(matches!(res, Err(ServiceError::InternalError(_)))); + } + + #[tokio::test] + async fn test_refresh_jwt() { + let service = AuthenticationServiceImpl::new(Some("secret".to_string())); + + let user_id = Uuid::new_v4(); + let token = service.generate_jwt(user_id, 60).await.unwrap(); + let new_token = service.refresh_jwt(&token, 120).await.unwrap(); + + let claims = service.parse_jwt(&new_token).await.unwrap(); + assert_eq!(claims.sub, user_id.to_string()); + assert_eq!(claims.exp - claims.iat, 120); + } + + #[tokio::test] + async fn test_is_valid_jwt_expired() { + let service = AuthenticationServiceImpl::new(Some("secret".to_string())); + + let user_id = Uuid::new_v4(); + let token = service.generate_jwt(user_id, 1).await.unwrap(); + sleep(Duration::from_secs(2)).await; + + let valid = service.is_valid_jwt(&token, None).await.unwrap(); + assert!(!valid, "Token should be expired and thus invalid"); + } + + #[tokio::test] + async fn test_invalidate_and_cleanup() { + let service = AuthenticationServiceImpl::new(Some("secret".to_string())); + + let user_id = Uuid::new_v4(); + let token = service.generate_jwt(user_id, 1).await.unwrap(); + + service.invalidate_jwt(&token).await.unwrap(); + + // ensure entry is present + { + let cache = service.invalidation_cache.read().await; + assert!(cache.iter().any(|e| e.token == token)); + } + + // wait until token validity ends and cleanup + sleep(Duration::from_secs(2)).await; + service.cleanup_invalidation_cache().await; + + let cache = service.invalidation_cache.read().await; + assert!( + cache.is_empty(), + "Cleanup should remove expired invalidation entries" + ); + } +} diff --git a/apps/api/src/services/auth/authentication/strategies.rs b/apps/api/src/services/auth/authentication/strategies.rs new file mode 100644 index 0000000..c72e4b9 --- /dev/null +++ b/apps/api/src/services/auth/authentication/strategies.rs @@ -0,0 +1 @@ +pub mod password; diff --git a/apps/api/src/services/auth/authentication/strategies/password.rs b/apps/api/src/services/auth/authentication/strategies/password.rs new file mode 100644 index 0000000..396fbbb --- /dev/null +++ b/apps/api/src/services/auth/authentication/strategies/password.rs @@ -0,0 +1,453 @@ +use std::sync::Arc; + +use crate::{errors::service_error::ServiceError, with_conn}; +use argon2::{ + Argon2, + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString, rand_core::OsRng}, +}; +use database::generated::entities::{user, user_identity}; +use sea_orm::{ + ColumnTrait, DatabaseConnection, DatabaseTransaction, EntityTrait, IntoActiveModel, + QueryFilter, prelude::Uuid, +}; + +pub struct PasswordStrategy { + connection: Arc, +} + +const MAX_PASSWORD_LENGTH: usize = 32; +const PASSWORD_PROVIDER: &str = "password"; + +impl PasswordStrategy { + pub fn new(connection: Arc) -> Self { + Self { connection } + } + + pub async fn authenticate( + &self, + username: &str, + password: &str, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + // Find user by username + let user = with_conn!(&*self.connection, tx, conn, { + user::Entity::find() + .filter(user::Column::Name.eq(username)) + .one(*conn) + .await? + .ok_or_else(|| { + ServiceError::Unauthorized("Invalid username or password".to_string()) + })? + }); + // Get user's identity + let identity = with_conn!(&*self.connection, tx, conn, { + user_identity::Entity::find() + .filter(user_identity::Column::UserId.eq(user.id)) + .one(*conn) + .await? + .ok_or_else(|| { + ServiceError::Unauthorized("Invalid username or password".to_string()) + })? + }); + + // Check if revoked + if identity.is_revoked { + return Err(ServiceError::Unauthorized("Account is revoked".to_string())); + } + + // Verify password + let password_hash = identity + .password_hash + .ok_or_else(|| ServiceError::InternalError("Invalid password hash".to_string()))?; + let parsed_hash = PasswordHash::new(&password_hash) + .map_err(|_| ServiceError::InternalError("Invalid password hash".to_string()))?; + + Argon2::default() + .verify_password(password.as_bytes(), &parsed_hash) + .map_err(|_| ServiceError::Unauthorized("Invalid username or password".to_string()))?; + + Ok(user.id) + } + + pub async fn revoke_identity( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result<(), ServiceError> { + let mut identity = with_conn!(&*self.connection, tx, conn, { + user_identity::Entity::find() + .filter(user_identity::Column::UserId.eq(user_id)) + .one(*conn) + .await? + .ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))? + }); + + identity.is_revoked = true; + + with_conn!(&*self.connection, tx, conn, { + user_identity::Entity::update(identity.into_active_model()) + .exec(*conn) + .await + .map_err(ServiceError::from) + })?; + + Ok(()) + } + + pub async fn create_identity( + &self, + user_id: Uuid, + password: &str, + tx: Option<&mut DatabaseTransaction>, + ) -> Result<(), ServiceError> { + Self::is_valid_password(password).map_err(ServiceError::BadRequest)?; + + let password_hash = Argon2::default() + .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng)) + .map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))? + .to_string(); + + let new_identity = user_identity::ActiveModel { + user_id: sea_orm::ActiveValue::Set(user_id), + provider: sea_orm::ActiveValue::Set(PASSWORD_PROVIDER.to_string()), + password_hash: sea_orm::ActiveValue::Set(Some(password_hash)), + metadata: sea_orm::ActiveValue::Set(None), + is_revoked: sea_orm::ActiveValue::Set(false), + ..Default::default() + }; + + with_conn!(&*self.connection, tx, conn, { + user_identity::Entity::insert(new_identity) + .exec(*conn) + .await + .map_err(ServiceError::from) + })?; + + Ok(()) + } + + pub async fn update_password( + &self, + user_id: Uuid, + new_password: &str, + tx: Option<&mut DatabaseTransaction>, + ) -> Result<(), ServiceError> { + Self::is_valid_password(new_password).map_err(ServiceError::BadRequest)?; + + let password_hash = Argon2::default() + .hash_password(new_password.as_bytes(), &SaltString::generate(&mut OsRng)) + .map_err(|_| ServiceError::InternalError("Failed to hash password".to_string()))? + .to_string(); + + let mut identity = with_conn!(&*self.connection, tx, conn, { + user_identity::Entity::find() + .filter(user_identity::Column::UserId.eq(user_id)) + .one(*conn) + .await? + .ok_or_else(|| ServiceError::NotFound("User identity not found".to_string()))? + }); + + identity.password_hash = Some(password_hash); + identity.password_changed_at = Some(chrono::Utc::now()); + + with_conn!(&*self.connection, tx, conn, { + user_identity::Entity::update(identity.into_active_model()) + .exec(*conn) + .await + .map_err(ServiceError::from) + })?; + + Ok(()) + } + + fn is_valid_password(password: &str) -> Result<(), String> { + if password.is_empty() { + return Err("Password cannot be empty".to_string()); + } + if password.len() > MAX_PASSWORD_LENGTH { + return Err(format!( + "Password cannot be longer than {} characters", + MAX_PASSWORD_LENGTH + )); + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use database::generated::entities::{user, user_identity}; + use sea_orm::MockDatabase; + + #[test] + fn ensure_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } + + #[test] + fn password_validation() { + let valid_password = "ValidPassword123!"; + let long_password = "a".repeat(129); + + assert!(PasswordStrategy::is_valid_password(valid_password).is_ok()); + assert!(PasswordStrategy::is_valid_password(long_password.as_str()).is_err()); + } + + #[tokio::test] + async fn authenticate_user_not_found() { + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![Vec::::new()]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy + .authenticate("nonexistent_user", "password", None) + .await; + + assert!(matches!(result, Err(ServiceError::Unauthorized(_)))); + } + + #[tokio::test] + async fn authenticate_invalid_password() { + let user_id = Uuid::new_v4(); + let password_hash = Argon2::default() + .hash_password( + "CorrectPassword".as_bytes(), + &SaltString::generate(&mut OsRng), + ) + .unwrap() + .to_string(); + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![vec![user::Model { + id: user_id, + name: "test_user".to_string(), + is_active: true, + is_admin: false, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + deleted_at: None, + last_login_at: None, + }]]) + .append_query_results(vec![vec![user_identity::Model { + id: Uuid::new_v4(), + user_id, + email: None, + provider: PASSWORD_PROVIDER.to_string(), + password_hash: Some(password_hash), + metadata: None, + is_revoked: false, + revoked_at: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + password_changed_at: None, + }]]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy + .authenticate("test_user", "InvalidPassword", None) + .await; + + assert!(matches!(result, Err(ServiceError::Unauthorized(_)))); + } + + #[tokio::test] + async fn authenticate_success() { + let user_id = Uuid::new_v4(); + let password_hash = Argon2::default() + .hash_password( + "CorrectPassword".as_bytes(), + &SaltString::generate(&mut OsRng), + ) + .unwrap() + .to_string(); + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![vec![user::Model { + id: user_id, + name: "test_user".to_string(), + is_active: true, + is_admin: false, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + deleted_at: None, + last_login_at: None, + }]]) + .append_query_results(vec![vec![user_identity::Model { + id: Uuid::new_v4(), + user_id, + email: None, + provider: PASSWORD_PROVIDER.to_string(), + password_hash: Some(password_hash), + metadata: None, + is_revoked: false, + revoked_at: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + password_changed_at: None, + }]]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy + .authenticate("test_user", "CorrectPassword", None) + .await; + + assert!(matches!(result, Ok(id) if id == user_id)); + } + + #[tokio::test] + async fn revoke_identity_not_found() { + let user_id = Uuid::new_v4(); + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![Vec::::new()]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy.revoke_identity(user_id, None).await; + + assert!(matches!(result, Err(ServiceError::NotFound(_)))); + } + + #[tokio::test] + async fn revoke_identity_success() { + let user_id = Uuid::new_v4(); + let identity = user_identity::Model { + id: Uuid::new_v4(), + user_id, + email: None, + provider: PASSWORD_PROVIDER.to_string(), + password_hash: None, + metadata: None, + is_revoked: false, + revoked_at: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + password_changed_at: None, + }; + + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![ + vec![identity.clone()], + vec![user_identity::Model { + is_revoked: true, + ..identity + }], + ]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy.revoke_identity(user_id, None).await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn create_identity_invalid_password() { + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite).into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy.create_identity(Uuid::new_v4(), "", None).await; + + assert!(matches!(result, Err(ServiceError::BadRequest(_)))); + } + + #[tokio::test] + async fn create_identity_success() { + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![vec![user_identity::Model { + id: Uuid::new_v4(), + user_id: Uuid::new_v4(), + email: None, + provider: PASSWORD_PROVIDER.to_string(), + password_hash: Some("somehash".to_string()), + metadata: None, + is_revoked: false, + revoked_at: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + password_changed_at: None, + }]]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy + .create_identity(Uuid::new_v4(), "ValidPass1!", None) + .await; + + assert!( + result.is_ok(), + "Failed to create identity, error: {:?}", + result.err() + ); + } + + #[tokio::test] + async fn update_password_not_found() { + let user_id = Uuid::new_v4(); + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![Vec::::new()]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy.update_password(user_id, "NewPass1!", None).await; + + assert!(matches!(result, Err(ServiceError::NotFound(_)))); + } + + #[tokio::test] + async fn update_password_success() { + let user_id = Uuid::new_v4(); + let identity = user_identity::Model { + id: Uuid::new_v4(), + user_id, + email: None, + provider: PASSWORD_PROVIDER.to_string(), + password_hash: Some("oldhash".to_string()), + metadata: None, + is_revoked: false, + revoked_at: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + password_changed_at: None, + }; + + let db = MockDatabase::new(sea_orm::DatabaseBackend::Sqlite) + .append_query_results(vec![ + vec![identity], + vec![user_identity::Model { + id: Uuid::new_v4(), + user_id, + email: None, + provider: PASSWORD_PROVIDER.to_string(), + password_hash: Some("newhash".to_string()), + metadata: None, + is_revoked: false, + revoked_at: None, + created_at: chrono::Utc::now(), + updated_at: chrono::Utc::now(), + password_changed_at: None, + }], + ]) + .into_connection(); + + let strategy = PasswordStrategy::new(Arc::new(db)); + + let result = strategy.update_password(user_id, "NewPass1!", None).await; + + assert!( + result.is_ok(), + "Failed to update password, error: {:?}", + result.err() + ); + } +} diff --git a/apps/api/src/services/auth/user.rs b/apps/api/src/services/auth/user.rs new file mode 100644 index 0000000..58abbfa --- /dev/null +++ b/apps/api/src/services/auth/user.rs @@ -0,0 +1,208 @@ +use std::sync::Arc; + +use database::generated::entities::user::{ + self, ActiveModel as UserActiveModel, Model as UserModel, +}; +use sea_orm::{ + ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, DatabaseTransaction, DbErr, + EntityTrait, IntoActiveModel, QueryFilter, prelude::Uuid, +}; + +use crate::{errors::service_error::ServiceError, with_conn}; + +#[async_trait::async_trait] +pub trait UserService: Send + Sync { + async fn get_user_by_id( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result; + async fn is_admin( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result; + async fn user_exists( + &self, + username: &str, + tx: Option<&mut DatabaseTransaction>, + ) -> Result; + async fn create_user( + &self, + user: NewUser, + tx: Option<&mut DatabaseTransaction>, + ) -> Result; + async fn update_user( + &self, + user_id: Uuid, + user: UpdateUser, + tx: Option<&mut DatabaseTransaction>, + ) -> Result; + async fn delete_user( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result<(), ServiceError>; +} + +pub struct User { + pub id: Uuid, + pub username: String, + pub is_admin: bool, +} + +impl From for User { + fn from(model: UserModel) -> Self { + Self { + id: model.id, + username: model.name, + is_admin: model.is_admin, + } + } +} + +pub struct NewUser { + pub username: String, + pub is_admin: bool, +} + +pub struct UpdateUser { + pub username: Option, + pub is_admin: Option, + pub is_active: Option, +} + +impl UpdateUser { + fn apply_to_active_model(&self, model: &mut UserActiveModel) { + if let Some(username) = &self.username { + model.name = ActiveValue::Set(username.clone()); + } + if let Some(is_admin) = self.is_admin { + model.is_admin = ActiveValue::Set(is_admin); + } + if let Some(is_active) = self.is_active { + model.is_active = ActiveValue::Set(is_active); + } + } +} + +pub struct UserServiceImpl { + connection: Arc, +} + +impl UserServiceImpl { + pub fn new(connection: Arc) -> Self { + Self { connection } + } + + async fn get_user_by_id_from_db( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + let user = with_conn!(&*self.connection, tx, conn, { + user::Entity::find_by_id(user_id).one(*conn).await + }); + + match user { + Err(err) => Err(ServiceError::from(err)), + Ok(None) => Err(ServiceError::NotFound(format!( + "User with id '{}' not found", + user_id + ))), + Ok(Some(record)) => Ok(record), + } + } +} + +#[async_trait::async_trait] +impl UserService for UserServiceImpl { + async fn get_user_by_id( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + let user = self.get_user_by_id_from_db(user_id, tx).await?; + Ok(User::from(user)) + } + + async fn is_admin( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + let user = self.get_user_by_id(user_id, tx).await?; + Ok(user.is_admin) + } + + async fn user_exists( + &self, + username: &str, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + let user = with_conn!(&*self.connection, tx, conn, { + user::Entity::find() + .filter(user::Column::Name.eq(username)) + .one(*conn) + .await + }); + + match user { + Err(err) => match err { + DbErr::RecordNotFound(_) => Ok(false), + _ => Err(ServiceError::from(err)), + }, + Ok(None) => Ok(false), + Ok(Some(_)) => Ok(true), + } + } + + async fn create_user( + &self, + user: NewUser, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + let user_active_model = UserActiveModel { + id: ActiveValue::NotSet, + name: ActiveValue::Set(user.username), + is_admin: ActiveValue::Set(user.is_admin), + is_active: ActiveValue::Set(true), + ..Default::default() + }; + + let user_model = with_conn!(&*self.connection, tx, conn, { + user_active_model.insert(*conn).await + })?; + + Ok(User::from(user_model)) + } + + async fn update_user( + &self, + user_id: Uuid, + update_user: UpdateUser, + tx: Option<&mut DatabaseTransaction>, + ) -> Result { + let existing_user = self.get_user_by_id_from_db(user_id, tx).await?; + + let mut user_active_model = existing_user.into_active_model(); + update_user.apply_to_active_model(&mut user_active_model); + + let user_model = user_active_model.update(&*self.connection).await?; + + Ok(User::from(user_model)) + } + + async fn delete_user( + &self, + user_id: Uuid, + tx: Option<&mut DatabaseTransaction>, + ) -> Result<(), ServiceError> { + let user = self.get_user_by_id_from_db(user_id, tx).await?; + + let user_active_model = user.into_active_model(); + user_active_model.delete(&*self.connection).await?; + + Ok(()) + } +}