110 lines
3.2 KiB
Rust
110 lines
3.2 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::{
|
|
extract::State,
|
|
http::{Request, StatusCode},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use axum_extra::extract::cookie::CookieJar;
|
|
use tracing::debug;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{
|
|
errors::service_error::ServiceError, helpers::constants::JWT_COOKIE_NAME,
|
|
middlewares::request_info::RequestInfo, routes::AppState,
|
|
};
|
|
|
|
pub async fn require_auth(
|
|
cookies: CookieJar,
|
|
State(state): State<Arc<AppState>>,
|
|
req: Request<axum::body::Body>,
|
|
next: Next,
|
|
) -> Result<Response, StatusCode> {
|
|
// get jwt from cookies
|
|
let auth_service = &state.service.auth_state.authentication;
|
|
let token = if let Some(cookie) = cookies.get(JWT_COOKIE_NAME) {
|
|
cookie.value().to_string()
|
|
} else {
|
|
debug!("No JWT cookie found. cookies: {:?}", cookies);
|
|
return handle_unauthenticated().await;
|
|
};
|
|
|
|
// validate jwt
|
|
let is_valid = auth_service.is_valid_jwt(&token, None).await;
|
|
let user_id = match is_valid {
|
|
Ok(Some(claims)) => claims
|
|
.sub
|
|
.parse::<Uuid>()
|
|
.map_err(|_| StatusCode::UNAUTHORIZED)?,
|
|
Ok(None) => return handle_unauthenticated().await,
|
|
Err(err) => {
|
|
tracing::error!("Error validating JWT: {}", err);
|
|
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
|
}
|
|
};
|
|
|
|
// ensure user exists
|
|
if let Err(err) = state.service.user.get_user_by_id(user_id, None).await {
|
|
match err {
|
|
ServiceError::NotFound(_) => return handle_unauthenticated().await,
|
|
_ => {
|
|
tracing::error!("Error fetching user by ID: {}", err);
|
|
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut req = req;
|
|
let user = req
|
|
.extensions_mut()
|
|
.get_or_insert_with(|| RequestInfo { user_id: None });
|
|
user.user_id = Some(user_id);
|
|
|
|
Ok(next.run(req).await)
|
|
}
|
|
|
|
async fn handle_unauthenticated() -> Result<Response, StatusCode> {
|
|
// TODO: log unauthenticated access attempts
|
|
Err(StatusCode::UNAUTHORIZED)
|
|
}
|
|
|
|
/// Test-only replacement for `require_auth` that trusts request headers instead of a JWT.
#[cfg(test)]
pub mod mock {

    use super::*;

    /// Header carrying the UUID that the mock treats as the authenticated user.
    pub const REQUEST_AUTH_USER_ID_HEADER: &str = "x-mock-authenticated-user-id";
    /// Header whose mere presence makes the mock reject the request.
    pub const REQUEST_AUTH_USER_INVALID_HEADER: &str = "x-mock-authenticated-invalid";

    /// Mock auth middleware for tests.
    ///
    /// * If `REQUEST_AUTH_USER_INVALID_HEADER` is present, the request is rejected through
    ///   `handle_unauthenticated`.
    /// * Otherwise the user id comes from `REQUEST_AUTH_USER_ID_HEADER`, or is a random v4
    ///   UUID when that header is absent. The id is stored in the request's [`RequestInfo`]
    ///   extension (inserted if missing) before calling `next`.
    ///
    /// # Errors
    ///
    /// Returns `401 Unauthorized` when the invalid header is present, or when the user-id header
    /// is not visible ASCII or not a parseable UUID.
    pub async fn mock_require_auth(
        mut req: Request<axum::body::Body>,
        next: Next,
    ) -> Result<Response, StatusCode> {
        let headers = req.headers();

        if headers.contains_key(REQUEST_AUTH_USER_INVALID_HEADER) {
            return handle_unauthenticated().await;
        }

        let user_id = match headers.get(REQUEST_AUTH_USER_ID_HEADER) {
            Some(raw) => {
                let text = raw.to_str().map_err(|_| StatusCode::UNAUTHORIZED)?;
                Uuid::parse_str(text).map_err(|_| StatusCode::UNAUTHORIZED)?
            }
            None => Uuid::new_v4(),
        };

        req.extensions_mut()
            .get_or_insert_with(|| RequestInfo { user_id: None })
            .user_id = Some(user_id);

        Ok(next.run(req).await)
    }
}
|