diff --git a/Cargo.lock b/Cargo.lock index 06944b1..1f96fe9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -224,6 +224,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbfe9f610fe4e99cf0cfcd03ccf8c63c28c616fe714d80475ef731f3b13dd21b" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-macros" version = "0.5.0" @@ -607,6 +629,17 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -4665,6 +4698,7 @@ dependencies = [ "argon2", "async-trait", "axum", + "axum-extra", "chrono", "clap", "config", diff --git a/apps/api/Cargo.toml b/apps/api/Cargo.toml index 3ec1a3d..8e400da 100644 --- a/apps/api/Cargo.toml +++ b/apps/api/Cargo.toml @@ -7,7 +7,8 @@ edition = "2024" database = { path = "../../public/database" } migration = { path = "../../public/migration" } -axum = { version = "0.8.7", features = ["form", "http1", "http2", "json", "matched-path", "original-uri", "query", "tokio", "tower-log", "tracing", "macros"]} +axum = { version = "0.8.7", features = ["form", "http1", "http2", "json", "matched-path", "original-uri", "query", "tokio", "tower-log", "tracing", "macros"] } +axum-extra = { version = "0.12.2", features = ["cookie"] } async-trait = { version = "0.1.89" } chrono = { version = "0.4.42", features = ["clock", "std", "oldtime", "wasmbind", "serde"] } config = { version = "0.15.19", features = ["toml", "json", "yaml", "ini", "ron", "json5", "convert-case", "async"] } @@ -26,3 +27,4 @@ once_cell = { version = "1.21.3" } argon2 = { version = "0.5.3", features = ["std"] } jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] } uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] } + diff --git a/apps/api/src/cmd/start_server.rs b/apps/api/src/cmd/start_server.rs index d1d1ce8..e602258 100644 --- a/apps/api/src/cmd/start_server.rs +++ b/apps/api/src/cmd/start_server.rs @@ -13,7 +13,10 @@ use crate::{ log, routes::{self, AppService, AppState}, services::{ - auth::{authentication::AuthenticationServiceImpl, user::UserServiceImpl}, + auth::{ + authentication::{AuthenticationServiceImpl, strategies::password::PasswordStrategy}, + user::UserServiceImpl, + }, settings::SettingsService, }, tasks, @@ -129,9 +132,15 @@ fn get_app_state( database_connection: db_connection.clone(), service: Arc::new(AppService { settings: Arc::new(SettingsService::new(db_connection.clone())), - authentication: Arc::new(AuthenticationServiceImpl::new( - settings.auth.jwt_secret.clone(), - )), + auth_state: routes::AuthState { + strategy: routes::AuthStrategy { + password: Arc::new(PasswordStrategy::new(db_connection.clone())), + }, + authentication: Arc::new(AuthenticationServiceImpl::new( + settings.auth.jwt_secret.clone(), + )), + user: Arc::new(UserServiceImpl::new(db_connection.clone())), + }, user: Arc::new(UserServiceImpl::new(db_connection.clone())), }), } diff --git a/apps/api/src/helpers/constants.rs b/apps/api/src/helpers/constants.rs index a36bf1c..05b6e2f 100644 --- a/apps/api/src/helpers/constants.rs +++ b/apps/api/src/helpers/constants.rs @@ -1 +1,3 @@ pub const ADMIN_INIT_SECRET_KEY: &str = "admin_init_secret"; +// +pub const JWT_COOKIE_NAME: &str = "session_jwt"; diff --git a/apps/api/src/main.rs b/apps/api/src/main.rs index 3eef8e1..4d6bbdb 100644 --- a/apps/api/src/main.rs +++ b/apps/api/src/main.rs @@ -1,3 +1,5 @@ +#![forbid(unsafe_code)] + mod cmd; mod configs; mod errors; diff --git a/apps/api/src/middlewares.rs b/apps/api/src/middlewares.rs index f47e3de..69935b1 100644 --- a/apps/api/src/middlewares.rs +++ b/apps/api/src/middlewares.rs @@ -1,16 +1,21 @@ +pub mod request_info; +pub mod require_auth; + +use std::{sync::Arc, time::Duration}; + use axum::{ BoxError, Router, error_handling::HandleErrorLayer, http::{Method, StatusCode, Uri}, }; -use std::time::Duration; use tower::{ServiceBuilder, timeout::TimeoutLayer}; - use tracing::warn; +use crate::routes::AppState; + pub const TIMEOUT_DURATION_SECS: u64 = 30; -pub fn apply_root_middleware(router: Router) -> Router { +pub fn apply_root_middleware(router: Router, _state: Arc) -> Router { let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS)); let service_builder = ServiceBuilder::new() diff --git a/apps/api/src/middlewares/request_info.rs b/apps/api/src/middlewares/request_info.rs new file mode 100644 index 0000000..fb44b20 --- /dev/null +++ b/apps/api/src/middlewares/request_info.rs @@ -0,0 +1,6 @@ +use uuid::Uuid; + +#[derive(Clone, Debug)] +pub struct RequestInfo { + pub user_id: Option, +} diff --git a/apps/api/src/middlewares/require_auth.rs b/apps/api/src/middlewares/require_auth.rs new file mode 100644 index 0000000..b504210 --- /dev/null +++ b/apps/api/src/middlewares/require_auth.rs @@ -0,0 +1,68 @@ +use std::sync::Arc; + +use axum::{ + extract::State, + http::{Request, StatusCode}, + middleware::Next, + response::Response, +}; +use axum_extra::extract::cookie::CookieJar; +use uuid::Uuid; + +use crate::{ + errors::service_error::ServiceError, helpers::constants::JWT_COOKIE_NAME, + middlewares::request_info::RequestInfo, routes::AppState, +}; + +pub async fn require_auth( + cookies: CookieJar, + State(state): State>, + req: Request, + next: Next, +) -> Result { + // get jwt from cookies + let auth_service = &state.service.auth_state.authentication; + let token = if let Some(cookie) = cookies.get(JWT_COOKIE_NAME) { + cookie.value().to_string() + } else { + return handle_unauthenticated().await; + }; + + // validate jwt + let is_valid = auth_service.is_valid_jwt(&token, None).await; + let user_id = match is_valid { + Ok(Some(claims)) => claims + .sub + .parse::() + .map_err(|_| StatusCode::UNAUTHORIZED)?, + Ok(None) => return handle_unauthenticated().await, + Err(err) => { + tracing::error!("Error validating JWT: {}", err); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + }; + + // ensure user exists + if let Err(err) = state.service.user.get_user_by_id(user_id, None).await { + match err { + ServiceError::NotFound(_) => return handle_unauthenticated().await, + _ => { + tracing::error!("Error fetching user by ID: {}", err); + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + } + } + + let mut req = req; + let user = req + .extensions_mut() + .get_or_insert_with(|| RequestInfo { user_id: None }); + user.user_id = Some(user_id); + + Ok(next.run(req).await) +} + +async fn handle_unauthenticated() -> Result { + // TODO: log unauthenticated access attempts + Err(StatusCode::UNAUTHORIZED) +} diff --git a/apps/api/src/routes.rs b/apps/api/src/routes.rs index 59e181e..7abbf97 100644 --- a/apps/api/src/routes.rs +++ b/apps/api/src/routes.rs @@ -11,7 +11,10 @@ use migration::sea_orm::DatabaseConnection; use crate::{ middlewares, services::{ - auth::{authentication::AuthenticationService, user::UserService}, + auth::{ + authentication::{AuthenticationService, strategies::password::PasswordStrategy}, + user::UserService, + }, settings::SettingsStore, }, }; @@ -28,25 +31,35 @@ pub struct AppState { pub type ServiceState = Arc; -pub struct AppService { - #[allow(dead_code)] // TODO: remove when used - pub settings: ServiceState, - #[allow(dead_code)] // TODO: remove when used +pub struct AuthStrategy { + pub password: ServiceState, +} + +pub struct AuthState { + pub strategy: AuthStrategy, pub authentication: ServiceState, - #[allow(dead_code)] // TODO: remove when used + pub user: ServiceState, +} + +pub struct AppService { + // #[allow(dead_code)] // TODO: remove when used + pub settings: ServiceState, + pub auth_state: AuthState, + // #[allow(dead_code)] // TODO: remove when used pub user: ServiceState, } pub fn get_root_router(state: impl Into>) -> Router { let mut router = Router::new(); + let state = state.into(); router = router - .nest("/api", api::get_api_router()) + .nest("/api", api::get_api_router(state.clone())) .merge(view::get_view_router()); - router = middlewares::apply_root_middleware(router); + router = middlewares::apply_root_middleware(router, state.clone()); - router = router.layer(Extension(state.into())); + router = router.layer(Extension(state.clone())); router } diff --git a/apps/api/src/routes/api.rs b/apps/api/src/routes/api.rs index 68bbdcc..8fbf5d0 100644 --- a/apps/api/src/routes/api.rs +++ b/apps/api/src/routes/api.rs @@ -1,13 +1,21 @@ +mod auth; mod health; mod openapi; +mod restricted; + +use std::sync::Arc; + +use crate::routes::AppState; pub use self::openapi::ApiDoc; use axum::{Router, response::IntoResponse, routing::any}; -pub fn get_api_router() -> Router { +pub fn get_api_router(state: Arc) -> Router { Router::new() .nest("/health", health::get_health_router()) + .merge(auth::get_basic_auth_router(state.clone())) + .merge(restricted::get_restricted_router(state.clone())) // explicit fallback for unmatched API routes .route("/{*wildcard}", any(api_fallback_handler)) } diff --git a/apps/api/src/routes/api/auth.rs b/apps/api/src/routes/api/auth.rs new file mode 100644 index 0000000..98d87d9 --- /dev/null +++ b/apps/api/src/routes/api/auth.rs @@ -0,0 +1,16 @@ +pub mod login; + +use std::sync::Arc; + +use axum::{ + Router, + routing::{get, post}, +}; + +use crate::routes::AppState; + +pub fn get_basic_auth_router(state: Arc) -> Router { + Router::new() + .route("/login", post(login::login)) + .with_state(state) +} diff --git a/apps/api/src/routes/api/auth/login.rs b/apps/api/src/routes/api/auth/login.rs new file mode 100644 index 0000000..a2abf66 --- /dev/null +++ b/apps/api/src/routes/api/auth/login.rs @@ -0,0 +1,98 @@ +use std::sync::Arc; + +use axum::{ + Json, + body::Body, + extract::State, + http::{StatusCode, header::SET_COOKIE}, + response::{IntoResponse, Response}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, from_value}; +use tracing::{error, warn}; + +use crate::routes::{AppState, api::openapi::tag::AUTH_TAG}; + +/// Login request payload +#[derive(Serialize, Deserialize, utoipa::ToSchema)] +pub struct LoginRequest { + username: String, + password: String, +} + +/// Login endpoint +/// +/// Authenticates a user and returns a JWT in an HttpOnly cookie. +#[utoipa::path( + post, + path = "/api/auth/login", + request_body = LoginRequest, + responses( + (status = 200, description = "User authenticated successfully", body = ()), + (status = 401, description = "Authentication failed"), + (status = 500, description = "Internal server error"), + ), + tag = AUTH_TAG, + )] +pub async fn login(State(state): State>, Json(payload): Json) -> Response { + let login_request: LoginRequest = match from_value(payload) { + Ok(req) => req, + Err(e) => { + warn!("Invalid login request: {}", e); + return (StatusCode::BAD_REQUEST).into_response(); + } + }; + + let user_id = match state + .service + .auth_state + .strategy + .password + .authenticate(&login_request.username, &login_request.password, None) + .await + { + Ok(user_id) => user_id, + Err(e) => { + warn!( + "Authentication failed for user {}: {}", + login_request.username, e + ); + return (StatusCode::UNAUTHORIZED).into_response(); + } + }; + + let (jwt, claims) = match state + .service + .auth_state + .authentication + .generate_jwt(user_id, 3600) + .await + { + Ok(token) => token, + Err(e) => { + error!("Error generating JWT for user {}: {}", user_id, e); + return (StatusCode::INTERNAL_SERVER_ERROR).into_response(); + } + }; + + let response_builder = Response::builder() + .status(StatusCode::OK) + // add jwt as cookie + .header( + SET_COOKIE, + format!( + "token={}; HttpOnly; Path=/; Max-Age={}; SameSite=Strict;", + jwt, + claims.exp - claims.iat + ), + ) + .body(Body::from(())); + + match response_builder { + Ok(resp) => resp, + Err(e) => { + error!("Error building response: {}", e); + return (StatusCode::INTERNAL_SERVER_ERROR).into_response(); + } + } +} diff --git a/apps/api/src/routes/api/openapi.rs b/apps/api/src/routes/api/openapi.rs index 6b39737..b8ae85b 100644 --- a/apps/api/src/routes/api/openapi.rs +++ b/apps/api/src/routes/api/openapi.rs @@ -1,18 +1,22 @@ pub mod tag { /// Health tag constant pub const HEALTH_TAG: &str = "Health"; + pub const AUTH_TAG: &str = "Authentication"; } #[derive(utoipa::OpenApi)] #[openapi( paths( - crate::routes::api::health::info::get_health_info + crate::routes::api::health::info::get_health_info, + crate::routes::api::auth::login::login ), components( - schemas(crate::routes::api::health::info::HealthInfo) // Register any schemas used in your paths + schemas(crate::routes::api::health::info::HealthInfo), // Register any schemas used in your paths + schemas(crate::routes::api::auth::login::LoginRequest) ), tags( - (name = tag::HEALTH_TAG, description = "Health information API") + (name = tag::HEALTH_TAG, description = "Health information API"), + (name = tag::AUTH_TAG, description = "Authentication API") ) )] pub struct ApiDoc; diff --git a/apps/api/src/routes/api/restricted.rs b/apps/api/src/routes/api/restricted.rs new file mode 100644 index 0000000..2372184 --- /dev/null +++ b/apps/api/src/routes/api/restricted.rs @@ -0,0 +1,15 @@ +use std::sync::Arc; + +use axum::{Router, routing::get}; + +use crate::{middlewares::require_auth::require_auth, routes::AppState}; + +pub fn get_restricted_router(state: Arc) -> Router { + Router::new() + // + // + .layer(axum::middleware::from_fn_with_state( + state.clone(), + require_auth, + )) +} diff --git a/apps/api/src/services/auth/authentication.rs b/apps/api/src/services/auth/authentication.rs index 2bbaae1..a7c1096 100644 --- a/apps/api/src/services/auth/authentication.rs +++ b/apps/api/src/services/auth/authentication.rs @@ -28,13 +28,16 @@ pub struct Claims { #[async_trait::async_trait] pub trait AuthenticationService: Send + Sync { - async fn generate_jwt(&self, user_id: Uuid, duration_secs: u64) - -> Result; + async fn generate_jwt( + &self, + user_id: Uuid, + duration_secs: u64, + ) -> Result<(String, Claims), ServiceError>; async fn is_valid_jwt( &self, token: &str, target_sub: Option, - ) -> Result; + ) -> Result, ServiceError>; async fn parse_jwt(&self, token: &str) -> Result; async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>; async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result; @@ -76,7 +79,7 @@ impl AuthenticationService for AuthenticationServiceImpl { &self, user_id: Uuid, duration_secs: u64, - ) -> Result { + ) -> Result<(String, Claims), ServiceError> { let header = Header::default(); let expiration = chrono::Utc::now() .checked_add_signed(chrono::Duration::seconds(duration_secs as i64)) @@ -95,23 +98,23 @@ impl AuthenticationService for AuthenticationServiceImpl { &EncodingKey::from_secret(self.secret.as_ref()), ) .map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?; - Ok(token) + Ok((token, claims)) } async fn is_valid_jwt( &self, token: &str, target_sub: Option, - ) -> Result { + ) -> Result, ServiceError> { let mut validation = Validation::default(); if let Some(expected_sub) = target_sub { validation.sub = Some(expected_sub); } let decoding_key = DecodingKey::from_secret(self.secret.as_ref()); match decode::(token, &decoding_key, &validation) { - Ok(_) => Ok(true), + Ok(data) => Ok(Some(data.claims)), Err(err) => match *err.kind() { - InvalidToken | InvalidSubject | ExpiredSignature => Ok(false), + InvalidToken | InvalidSubject | ExpiredSignature => Ok(None), _ => Err(ServiceError::InternalError(format!( "JWT validation error: {}", err @@ -156,7 +159,7 @@ impl AuthenticationService for AuthenticationServiceImpl { let user_id = Uuid::parse_str(&claims.sub).map_err(|e| { ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e)) })?; - let new_token = self.generate_jwt(user_id, duration_secs).await?; + let (new_token, _) = self.generate_jwt(user_id, duration_secs).await?; Ok(new_token) } @@ -181,7 +184,7 @@ mod tests { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); - let token = service + let (token, _) = service .generate_jwt(user_id, 60) .await .expect("generate jwt"); @@ -190,8 +193,7 @@ mod tests { .is_valid_jwt(&token, None) .await .expect("validate jwt"); - assert!(valid, "Generated token should be valid"); - + assert!(valid.is_some(), "Generated token should be valid"); let claims = service.parse_jwt(&token).await.expect("parse jwt"); assert_eq!(claims.sub, user_id.to_string()); } @@ -201,11 +203,14 @@ mod tests { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); - let token = service.generate_jwt(user_id, 60).await.unwrap(); + let (token, _) = service.generate_jwt(user_id, 60).await.unwrap(); let other_sub = Uuid::new_v4().to_string(); let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap(); - assert!(!valid, "Token should be invalid for a different subject"); + assert!( + valid.is_none(), + "Token should be invalid for a different subject" + ); } #[tokio::test] @@ -221,7 +226,7 @@ mod tests { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); - let token = service.generate_jwt(user_id, 60).await.unwrap(); + let (token, _) = service.generate_jwt(user_id, 60).await.unwrap(); let new_token = service.refresh_jwt(&token, 120).await.unwrap(); let claims = service.parse_jwt(&new_token).await.unwrap(); @@ -234,11 +239,11 @@ mod tests { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); - let token = service.generate_jwt(user_id, 1).await.unwrap(); + let (token, _) = service.generate_jwt(user_id, 1).await.unwrap(); sleep(Duration::from_secs(2)).await; let valid = service.is_valid_jwt(&token, None).await.unwrap(); - assert!(!valid, "Token should be expired and thus invalid"); + assert!(valid.is_none(), "Token should be expired and thus invalid"); } #[tokio::test] @@ -246,7 +251,7 @@ mod tests { let service = AuthenticationServiceImpl::new(Some("secret".to_string())); let user_id = Uuid::new_v4(); - let token = service.generate_jwt(user_id, 1).await.unwrap(); + let (token, _) = service.generate_jwt(user_id, 1).await.unwrap(); service.invalidate_jwt(&token).await.unwrap();