feat: implement CORS support with configuration options and middleware integration
This commit is contained in:
@@ -27,4 +27,5 @@ once_cell = { version = "1.21.3" }
|
||||
argon2 = { version = "0.5.3", features = ["std"] }
|
||||
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
|
||||
uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] }
|
||||
tower-http = { version = "0.6.8", features = ["cors"] }
|
||||
|
||||
|
||||
@@ -88,8 +88,10 @@ pub async fn start_server() {
|
||||
|
||||
// build the axum app and run the server...
|
||||
info!("Starting application...");
|
||||
let mut app: Router =
|
||||
routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings)));
|
||||
let mut app: Router = routes::get_root_router(
|
||||
Arc::new(get_app_state(&db_connection, &settings)),
|
||||
Arc::new(settings.server.cors.clone()),
|
||||
);
|
||||
|
||||
if settings.server.serve_openapi {
|
||||
info!("Enabling OpenAPI documentation endpoint at /openapi.json");
|
||||
|
||||
@@ -4,6 +4,7 @@ pub(crate) const LOGGING_UTC_KEY: &str = "LOGGING.UTC";
|
||||
pub(crate) const SERVER_ADDRESS_KEY: &str = "SERVER.ADDRESS";
|
||||
pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT";
|
||||
pub(crate) const SERVER_SERVE_OPENAPI_KEY: &str = "SERVER.SERVE_OPENAPI";
|
||||
pub(crate) const SERVER_CORS_ALLOWED_ORIGINS_KEY: &str = "SERVER.CORS.ALLOWED_ORIGINS";
|
||||
//
|
||||
pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL";
|
||||
pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS";
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::net::IpAddr;
|
||||
use config::{Config, ConfigError};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::configs::key::SERVER_SERVE_OPENAPI_KEY;
|
||||
use crate::configs::key::{SERVER_CORS_ALLOWED_ORIGINS_KEY, SERVER_SERVE_OPENAPI_KEY};
|
||||
|
||||
use super::{
|
||||
FromConfig,
|
||||
@@ -15,6 +15,12 @@ pub struct ServerSettings {
|
||||
pub address: IpAddr,
|
||||
pub port: u16,
|
||||
pub serve_openapi: bool,
|
||||
pub cors: CORSSettings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CORSSettings {
|
||||
pub allowed_origins: Vec<String>,
|
||||
}
|
||||
|
||||
impl FromConfig for ServerSettings {
|
||||
@@ -57,6 +63,24 @@ impl FromConfig for ServerSettings {
|
||||
);
|
||||
DEFAULT_SERVE_OPENAPI
|
||||
}),
|
||||
|
||||
cors: CORSSettings {
|
||||
allowed_origins: _config
|
||||
.get_array(SERVER_CORS_ALLOWED_ORIGINS_KEY)
|
||||
.unwrap_or_else(|_| vec![])
|
||||
.into_iter()
|
||||
.filter_map(|val| match val.into_string() {
|
||||
Ok(s) => Some(s),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Invalid origin in {} configuration: {}",
|
||||
SERVER_CORS_ALLOWED_ORIGINS_KEY, e
|
||||
);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,25 +6,55 @@ use std::{sync::Arc, time::Duration};
|
||||
use axum::{
|
||||
BoxError, Router,
|
||||
error_handling::HandleErrorLayer,
|
||||
http::{Method, StatusCode, Uri},
|
||||
http::{HeaderValue, Method, StatusCode, Uri},
|
||||
};
|
||||
use tower::{ServiceBuilder, timeout::TimeoutLayer};
|
||||
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::routes::AppState;
|
||||
use crate::{configs::server::CORSSettings, routes::AppState};
|
||||
|
||||
pub const TIMEOUT_DURATION_SECS: u64 = 30;
|
||||
|
||||
pub fn apply_root_middleware(router: Router, _state: Arc<AppState>) -> Router {
|
||||
pub fn apply_root_middleware(
|
||||
router: Router,
|
||||
_state: Arc<AppState>,
|
||||
cors_settings: Arc<CORSSettings>,
|
||||
) -> Router {
|
||||
let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS));
|
||||
|
||||
let service_builder = ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(handle_timeout_error))
|
||||
.layer(timeout_layer);
|
||||
.layer(timeout_layer)
|
||||
.layer(get_cors_layer(cors_settings));
|
||||
|
||||
router.layer(service_builder)
|
||||
}
|
||||
|
||||
pub fn get_cors_layer(cors_settings: Arc<CORSSettings>) -> CorsLayer {
|
||||
let mut cors_layer = CorsLayer::new()
|
||||
.allow_credentials(true)
|
||||
.allow_headers(AllowHeaders::mirror_request());
|
||||
|
||||
let allowed_origins = &cors_settings.allowed_origins;
|
||||
if allowed_origins.contains(&"*".to_string()) {
|
||||
cors_layer = cors_layer.allow_origin(AllowOrigin::mirror_request());
|
||||
warn!(
|
||||
"Wildcard origin is found in allowed origins. CORS is configured to allow requests from any origin. Only use this setting in development or if you understand the security implications."
|
||||
);
|
||||
} else {
|
||||
for origin in allowed_origins {
|
||||
if let Ok(header_value) = HeaderValue::from_str(origin) {
|
||||
cors_layer = cors_layer.allow_origin(AllowOrigin::exact(header_value));
|
||||
} else {
|
||||
warn!("Invalid CORS origin: {}", origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cors_layer
|
||||
}
|
||||
|
||||
pub async fn handle_timeout_error(
|
||||
method: Method,
|
||||
uri: Uri,
|
||||
|
||||
@@ -9,6 +9,7 @@ use axum::{Extension, Router};
|
||||
use migration::sea_orm::DatabaseConnection;
|
||||
|
||||
use crate::{
|
||||
configs::server::CORSSettings,
|
||||
middlewares,
|
||||
services::{
|
||||
auth::{
|
||||
@@ -46,7 +47,10 @@ pub struct AppService {
|
||||
pub server_state: ServiceState<dyn ServerStateStore>,
|
||||
}
|
||||
|
||||
pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
|
||||
pub fn get_root_router(
|
||||
state: impl Into<Arc<AppState>>,
|
||||
cors_settings: Arc<CORSSettings>,
|
||||
) -> Router {
|
||||
let mut router = Router::new();
|
||||
let state = state.into();
|
||||
|
||||
@@ -54,7 +58,7 @@ pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
|
||||
.nest("/api", api::get_api_router(state.clone()))
|
||||
.merge(view::get_view_router());
|
||||
|
||||
router = middlewares::apply_root_middleware(router, state.clone());
|
||||
router = middlewares::apply_root_middleware(router, state.clone(), cors_settings);
|
||||
|
||||
router = router.layer(Extension(state.clone()));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user