pub mod request_info; pub mod require_auth; use std::{sync::Arc, time::Duration}; use axum::{ BoxError, Router, error_handling::HandleErrorLayer, http::{HeaderValue, Method, StatusCode, Uri}, }; use tower::{ServiceBuilder, timeout::TimeoutLayer}; use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer}; use tracing::warn; use crate::{configs::server::CORSSettings, routes::AppState}; pub const TIMEOUT_DURATION_SECS: u64 = 30; pub fn apply_root_middleware( router: Router, _state: Arc, cors_settings: Arc, ) -> Router { let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS)); let service_builder = ServiceBuilder::new() .layer(HandleErrorLayer::new(handle_timeout_error)) .layer(timeout_layer) .layer(get_cors_layer(cors_settings)); router.layer(service_builder) } pub fn get_cors_layer(cors_settings: Arc) -> CorsLayer { let mut cors_layer = CorsLayer::new() .allow_credentials(true) .allow_headers(AllowHeaders::mirror_request()); let allowed_origins = &cors_settings.allowed_origins; if allowed_origins.contains(&"*".to_string()) { cors_layer = cors_layer.allow_origin(AllowOrigin::mirror_request()); warn!( "Wildcard origin is found in allowed origins. CORS is configured to allow requests from any origin. Only use this setting in development or if you understand the security implications." ); } else { for origin in allowed_origins { if let Ok(header_value) = HeaderValue::from_str(origin) { cors_layer = cors_layer.allow_origin(AllowOrigin::exact(header_value)); } else { warn!("Invalid CORS origin: {}", origin); } } } cors_layer } pub async fn handle_timeout_error( method: Method, uri: Uri, // err: BoxError, ) -> (StatusCode, String) { warn!("`{method} {uri}` failed with {err}"); ( StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string(), ) }