70 lines
2.1 KiB
Rust
70 lines
2.1 KiB
Rust
pub mod request_info;
|
|
pub mod require_auth;
|
|
|
|
use std::{sync::Arc, time::Duration};
|
|
|
|
use axum::{
|
|
BoxError, Router,
|
|
error_handling::HandleErrorLayer,
|
|
http::{HeaderValue, Method, StatusCode, Uri},
|
|
};
|
|
use tower::{ServiceBuilder, timeout::TimeoutLayer};
|
|
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
|
|
use tracing::warn;
|
|
|
|
use crate::{configs::server::CORSSettings, routes::AppState};
|
|
|
|
pub const TIMEOUT_DURATION_SECS: u64 = 30;
|
|
|
|
pub fn apply_root_middleware(
|
|
router: Router,
|
|
_state: Arc<AppState>,
|
|
cors_settings: Arc<CORSSettings>,
|
|
) -> Router {
|
|
let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS));
|
|
|
|
let service_builder = ServiceBuilder::new()
|
|
.layer(HandleErrorLayer::new(handle_timeout_error))
|
|
.layer(timeout_layer)
|
|
.layer(get_cors_layer(cors_settings));
|
|
|
|
router.layer(service_builder)
|
|
}
|
|
|
|
pub fn get_cors_layer(cors_settings: Arc<CORSSettings>) -> CorsLayer {
|
|
let mut cors_layer = CorsLayer::new()
|
|
.allow_credentials(true)
|
|
.allow_headers(AllowHeaders::mirror_request());
|
|
|
|
let allowed_origins = &cors_settings.allowed_origins;
|
|
if allowed_origins.contains(&"*".to_string()) {
|
|
cors_layer = cors_layer.allow_origin(AllowOrigin::mirror_request());
|
|
warn!(
|
|
"Wildcard origin is found in allowed origins. CORS is configured to allow requests from any origin. Only use this setting in development or if you understand the security implications."
|
|
);
|
|
} else {
|
|
for origin in allowed_origins {
|
|
if let Ok(header_value) = HeaderValue::from_str(origin) {
|
|
cors_layer = cors_layer.allow_origin(AllowOrigin::exact(header_value));
|
|
} else {
|
|
warn!("Invalid CORS origin: {}", origin);
|
|
}
|
|
}
|
|
}
|
|
|
|
cors_layer
|
|
}
|
|
|
|
pub async fn handle_timeout_error(
|
|
method: Method,
|
|
uri: Uri,
|
|
//
|
|
err: BoxError,
|
|
) -> (StatusCode, String) {
|
|
warn!("`{method} {uri}` failed with {err}");
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
"Internal server error".to_string(),
|
|
)
|
|
}
|