feature/authentication service #9

Merged
GW_MC merged 19 commits from feature/authentication into master 2025-12-19 12:24:49 +08:00
15 changed files with 326 additions and 39 deletions
Showing only changes of commit ccd8bc7aa1 - Show all commits

34
Cargo.lock generated
View File

@@ -224,6 +224,28 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-extra"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbfe9f610fe4e99cf0cfcd03ccf8c63c28c616fe714d80475ef731f3b13dd21b"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-core",
"futures-util",
"http",
"http-body",
"http-body-util",
"mime",
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-macros"
version = "0.5.0"
@@ -607,6 +629,17 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"percent-encoding",
"time",
"version_check",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@@ -4665,6 +4698,7 @@ dependencies = [
"argon2",
"async-trait",
"axum",
"axum-extra",
"chrono",
"clap",
"config",

View File

@@ -7,7 +7,8 @@ edition = "2024"
database = { path = "../../public/database" }
migration = { path = "../../public/migration" }
axum = { version = "0.8.7", features = ["form", "http1", "http2", "json", "matched-path", "original-uri", "query", "tokio", "tower-log", "tracing", "macros"]}
axum = { version = "0.8.7", features = ["form", "http1", "http2", "json", "matched-path", "original-uri", "query", "tokio", "tower-log", "tracing", "macros"] }
axum-extra = { version = "0.12.2", features = ["cookie"] }
async-trait = { version = "0.1.89" }
chrono = { version = "0.4.42", features = ["clock", "std", "oldtime", "wasmbind", "serde"] }
config = { version = "0.15.19", features = ["toml", "json", "yaml", "ini", "ron", "json5", "convert-case", "async"] }
@@ -26,3 +27,4 @@ once_cell = { version = "1.21.3" }
argon2 = { version = "0.5.3", features = ["std"] }
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] }

View File

@@ -13,7 +13,10 @@ use crate::{
log,
routes::{self, AppService, AppState},
services::{
auth::{authentication::AuthenticationServiceImpl, user::UserServiceImpl},
auth::{
authentication::{AuthenticationServiceImpl, strategies::password::PasswordStrategy},
user::UserServiceImpl,
},
settings::SettingsService,
},
tasks,
@@ -129,10 +132,16 @@ fn get_app_state(
database_connection: db_connection.clone(),
service: Arc::new(AppService {
settings: Arc::new(SettingsService::new(db_connection.clone())),
auth_state: routes::AuthState {
strategy: routes::AuthStrategy {
password: Arc::new(PasswordStrategy::new(db_connection.clone())),
},
authentication: Arc::new(AuthenticationServiceImpl::new(
settings.auth.jwt_secret.clone(),
)),
user: Arc::new(UserServiceImpl::new(db_connection.clone())),
},
user: Arc::new(UserServiceImpl::new(db_connection.clone())),
}),
}
}

View File

@@ -1 +1,3 @@
pub const ADMIN_INIT_SECRET_KEY: &str = "admin_init_secret";
//
pub const JWT_COOKIE_NAME: &str = "session_jwt";

View File

@@ -1,3 +1,5 @@
#![forbid(unsafe_code)]
mod cmd;
mod configs;
mod errors;

View File

@@ -1,16 +1,21 @@
pub mod request_info;
pub mod require_auth;
use std::{sync::Arc, time::Duration};
use axum::{
BoxError, Router,
error_handling::HandleErrorLayer,
http::{Method, StatusCode, Uri},
};
use std::time::Duration;
use tower::{ServiceBuilder, timeout::TimeoutLayer};
use tracing::warn;
use crate::routes::AppState;
pub const TIMEOUT_DURATION_SECS: u64 = 30;
pub fn apply_root_middleware(router: Router) -> Router {
pub fn apply_root_middleware(router: Router, _state: Arc<AppState>) -> Router {
let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS));
let service_builder = ServiceBuilder::new()

View File

@@ -0,0 +1,6 @@
use uuid::Uuid;
#[derive(Clone, Debug)]
pub struct RequestInfo {
pub user_id: Option<Uuid>,
}

View File

@@ -0,0 +1,68 @@
use std::sync::Arc;
use axum::{
extract::State,
http::{Request, StatusCode},
middleware::Next,
response::Response,
};
use axum_extra::extract::cookie::CookieJar;
use uuid::Uuid;
use crate::{
errors::service_error::ServiceError, helpers::constants::JWT_COOKIE_NAME,
middlewares::request_info::RequestInfo, routes::AppState,
};
pub async fn require_auth(
cookies: CookieJar,
State(state): State<Arc<AppState>>,
req: Request<axum::body::Body>,
next: Next,
) -> Result<Response, StatusCode> {
// get jwt from cookies
let auth_service = &state.service.auth_state.authentication;
let token = if let Some(cookie) = cookies.get(JWT_COOKIE_NAME) {
cookie.value().to_string()
} else {
return handle_unauthenticated().await;
};
// validate jwt
let is_valid = auth_service.is_valid_jwt(&token, None).await;
let user_id = match is_valid {
Ok(Some(claims)) => claims
.sub
.parse::<Uuid>()
.map_err(|_| StatusCode::UNAUTHORIZED)?,
Ok(None) => return handle_unauthenticated().await,
Err(err) => {
tracing::error!("Error validating JWT: {}", err);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
};
// ensure user exists
if let Err(err) = state.service.user.get_user_by_id(user_id, None).await {
match err {
ServiceError::NotFound(_) => return handle_unauthenticated().await,
_ => {
tracing::error!("Error fetching user by ID: {}", err);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
}
let mut req = req;
let user = req
.extensions_mut()
.get_or_insert_with(|| RequestInfo { user_id: None });
user.user_id = Some(user_id);
Ok(next.run(req).await)
}
async fn handle_unauthenticated() -> Result<Response, StatusCode> {
// TODO: log unauthenticated access attempts
Err(StatusCode::UNAUTHORIZED)
}

View File

@@ -11,7 +11,10 @@ use migration::sea_orm::DatabaseConnection;
use crate::{
middlewares,
services::{
auth::{authentication::AuthenticationService, user::UserService},
auth::{
authentication::{AuthenticationService, strategies::password::PasswordStrategy},
user::UserService,
},
settings::SettingsStore,
},
};
@@ -28,25 +31,35 @@ pub struct AppState {
pub type ServiceState<T> = Arc<T>;
pub struct AppService {
#[allow(dead_code)] // TODO: remove when used
pub settings: ServiceState<dyn SettingsStore>,
#[allow(dead_code)] // TODO: remove when used
pub struct AuthStrategy {
pub password: ServiceState<PasswordStrategy>,
}
pub struct AuthState {
pub strategy: AuthStrategy,
pub authentication: ServiceState<dyn AuthenticationService>,
#[allow(dead_code)] // TODO: remove when used
pub user: ServiceState<dyn UserService>,
}
pub struct AppService {
// #[allow(dead_code)] // TODO: remove when used
pub settings: ServiceState<dyn SettingsStore>,
pub auth_state: AuthState,
// #[allow(dead_code)] // TODO: remove when used
pub user: ServiceState<dyn UserService>,
}
pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
let mut router = Router::new();
let state = state.into();
router = router
.nest("/api", api::get_api_router())
.nest("/api", api::get_api_router(state.clone()))
.merge(view::get_view_router());
router = middlewares::apply_root_middleware(router);
router = middlewares::apply_root_middleware(router, state.clone());
router = router.layer(Extension(state.into()));
router = router.layer(Extension(state.clone()));
router
}

View File

@@ -1,13 +1,21 @@
mod auth;
mod health;
mod openapi;
mod restricted;
use std::sync::Arc;
use crate::routes::AppState;
pub use self::openapi::ApiDoc;
use axum::{Router, response::IntoResponse, routing::any};
pub fn get_api_router() -> Router {
pub fn get_api_router(state: Arc<AppState>) -> Router {
Router::new()
.nest("/health", health::get_health_router())
.merge(auth::get_basic_auth_router(state.clone()))
.merge(restricted::get_restricted_router(state.clone()))
// explicit fallback for unmatched API routes
.route("/{*wildcard}", any(api_fallback_handler))
}

View File

@@ -0,0 +1,16 @@
pub mod login;
use std::sync::Arc;
use axum::{
Router,
routing::{get, post},
};
use crate::routes::AppState;
pub fn get_basic_auth_router(state: Arc<AppState>) -> Router {
Router::new()
.route("/login", post(login::login))
.with_state(state)
}

View File

@@ -0,0 +1,98 @@
use std::sync::Arc;
use axum::{
Json,
body::Body,
extract::State,
http::{StatusCode, header::SET_COOKIE},
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use serde_json::{Value, from_value};
use tracing::{error, warn};
use crate::routes::{AppState, api::openapi::tag::AUTH_TAG};
/// Login request payload
#[derive(Serialize, Deserialize, utoipa::ToSchema)]
pub struct LoginRequest {
username: String,
password: String,
}
/// Login endpoint
///
/// Authenticates a user and returns a JWT in an HttpOnly cookie.
#[utoipa::path(
post,
path = "/api/auth/login",
request_body = LoginRequest,
responses(
(status = 200, description = "User authenticated successfully", body = ()),
(status = 401, description = "Authentication failed"),
(status = 500, description = "Internal server error"),
),
tag = AUTH_TAG,
)]
pub async fn login(State(state): State<Arc<AppState>>, Json(payload): Json<Value>) -> Response {
let login_request: LoginRequest = match from_value(payload) {
Ok(req) => req,
Err(e) => {
warn!("Invalid login request: {}", e);
return (StatusCode::BAD_REQUEST).into_response();
}
};
let user_id = match state
.service
.auth_state
.strategy
.password
.authenticate(&login_request.username, &login_request.password, None)
.await
{
Ok(user_id) => user_id,
Err(e) => {
warn!(
"Authentication failed for user {}: {}",
login_request.username, e
);
return (StatusCode::UNAUTHORIZED).into_response();
}
};
let (jwt, claims) = match state
.service
.auth_state
.authentication
.generate_jwt(user_id, 3600)
.await
{
Ok(token) => token,
Err(e) => {
error!("Error generating JWT for user {}: {}", user_id, e);
return (StatusCode::INTERNAL_SERVER_ERROR).into_response();
}
};
let response_builder = Response::builder()
.status(StatusCode::OK)
// add jwt as cookie
.header(
SET_COOKIE,
format!(
"token={}; HttpOnly; Path=/; Max-Age={}; SameSite=Strict;",
jwt,
claims.exp - claims.iat
),
)
.body(Body::from(()));
match response_builder {
Ok(resp) => resp,
Err(e) => {
error!("Error building response: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR).into_response();
}
}
}

View File

@@ -1,18 +1,22 @@
pub mod tag {
/// Health tag constant
pub const HEALTH_TAG: &str = "Health";
pub const AUTH_TAG: &str = "Authentication";
}
#[derive(utoipa::OpenApi)]
#[openapi(
paths(
crate::routes::api::health::info::get_health_info
crate::routes::api::health::info::get_health_info,
crate::routes::api::auth::login::login
),
components(
schemas(crate::routes::api::health::info::HealthInfo) // Register any schemas used in your paths
schemas(crate::routes::api::health::info::HealthInfo), // Register any schemas used in your paths
schemas(crate::routes::api::auth::login::LoginRequest)
),
tags(
(name = tag::HEALTH_TAG, description = "Health information API")
(name = tag::HEALTH_TAG, description = "Health information API"),
(name = tag::AUTH_TAG, description = "Authentication API")
)
)]
pub struct ApiDoc;

View File

@@ -0,0 +1,15 @@
use std::sync::Arc;
use axum::{Router, routing::get};
use crate::{middlewares::require_auth::require_auth, routes::AppState};
pub fn get_restricted_router(state: Arc<AppState>) -> Router {
Router::new()
//
//
.layer(axum::middleware::from_fn_with_state(
state.clone(),
require_auth,
))
}

View File

@@ -28,13 +28,16 @@ pub struct Claims {
#[async_trait::async_trait]
pub trait AuthenticationService: Send + Sync {
async fn generate_jwt(&self, user_id: Uuid, duration_secs: u64)
-> Result<String, ServiceError>;
async fn generate_jwt(
&self,
user_id: Uuid,
duration_secs: u64,
) -> Result<(String, Claims), ServiceError>;
async fn is_valid_jwt(
&self,
token: &str,
target_sub: Option<String>,
) -> Result<bool, ServiceError>;
) -> Result<Option<Claims>, ServiceError>;
async fn parse_jwt(&self, token: &str) -> Result<Claims, ServiceError>;
async fn invalidate_jwt(&self, token: &str) -> Result<(), ServiceError>;
async fn refresh_jwt(&self, token: &str, duration_secs: u64) -> Result<String, ServiceError>;
@@ -76,7 +79,7 @@ impl AuthenticationService for AuthenticationServiceImpl {
&self,
user_id: Uuid,
duration_secs: u64,
) -> Result<String, ServiceError> {
) -> Result<(String, Claims), ServiceError> {
let header = Header::default();
let expiration = chrono::Utc::now()
.checked_add_signed(chrono::Duration::seconds(duration_secs as i64))
@@ -95,23 +98,23 @@ impl AuthenticationService for AuthenticationServiceImpl {
&EncodingKey::from_secret(self.secret.as_ref()),
)
.map_err(|e| ServiceError::InternalError(format!("JWT generation error: {}", e)))?;
Ok(token)
Ok((token, claims))
}
async fn is_valid_jwt(
&self,
token: &str,
target_sub: Option<String>,
) -> Result<bool, ServiceError> {
) -> Result<Option<Claims>, ServiceError> {
let mut validation = Validation::default();
if let Some(expected_sub) = target_sub {
validation.sub = Some(expected_sub);
}
let decoding_key = DecodingKey::from_secret(self.secret.as_ref());
match decode::<Claims>(token, &decoding_key, &validation) {
Ok(_) => Ok(true),
Ok(data) => Ok(Some(data.claims)),
Err(err) => match *err.kind() {
InvalidToken | InvalidSubject | ExpiredSignature => Ok(false),
InvalidToken | InvalidSubject | ExpiredSignature => Ok(None),
_ => Err(ServiceError::InternalError(format!(
"JWT validation error: {}",
err
@@ -156,7 +159,7 @@ impl AuthenticationService for AuthenticationServiceImpl {
let user_id = Uuid::parse_str(&claims.sub).map_err(|e| {
ServiceError::InternalError(format!("Invalid user ID in JWT claims: {}", e))
})?;
let new_token = self.generate_jwt(user_id, duration_secs).await?;
let (new_token, _) = self.generate_jwt(user_id, duration_secs).await?;
Ok(new_token)
}
@@ -181,7 +184,7 @@ mod tests {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service
let (token, _) = service
.generate_jwt(user_id, 60)
.await
.expect("generate jwt");
@@ -190,8 +193,7 @@ mod tests {
.is_valid_jwt(&token, None)
.await
.expect("validate jwt");
assert!(valid, "Generated token should be valid");
assert!(valid.is_some(), "Generated token should be valid");
let claims = service.parse_jwt(&token).await.expect("parse jwt");
assert_eq!(claims.sub, user_id.to_string());
}
@@ -201,11 +203,14 @@ mod tests {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 60).await.unwrap();
let (token, _) = service.generate_jwt(user_id, 60).await.unwrap();
let other_sub = Uuid::new_v4().to_string();
let valid = service.is_valid_jwt(&token, Some(other_sub)).await.unwrap();
assert!(!valid, "Token should be invalid for a different subject");
assert!(
valid.is_none(),
"Token should be invalid for a different subject"
);
}
#[tokio::test]
@@ -221,7 +226,7 @@ mod tests {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 60).await.unwrap();
let (token, _) = service.generate_jwt(user_id, 60).await.unwrap();
let new_token = service.refresh_jwt(&token, 120).await.unwrap();
let claims = service.parse_jwt(&new_token).await.unwrap();
@@ -234,11 +239,11 @@ mod tests {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 1).await.unwrap();
let (token, _) = service.generate_jwt(user_id, 1).await.unwrap();
sleep(Duration::from_secs(2)).await;
let valid = service.is_valid_jwt(&token, None).await.unwrap();
assert!(!valid, "Token should be expired and thus invalid");
assert!(valid.is_none(), "Token should be expired and thus invalid");
}
#[tokio::test]
@@ -246,7 +251,7 @@ mod tests {
let service = AuthenticationServiceImpl::new(Some("secret".to_string()));
let user_id = Uuid::new_v4();
let token = service.generate_jwt(user_id, 1).await.unwrap();
let (token, _) = service.generate_jwt(user_id, 1).await.unwrap();
service.invalidate_jwt(&token).await.unwrap();