feat: implement CORS support with configuration options and middleware integration

This commit is contained in:
GW_MC
2025-12-19 21:34:12 +08:00
parent d861e0cd7d
commit b0b765b8fa
7 changed files with 86 additions and 9 deletions

15
Cargo.lock generated
View File

@@ -3975,6 +3975,20 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-http"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
dependencies = [
"bitflags 2.10.0",
"bytes",
"http",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
@@ -4713,6 +4727,7 @@ dependencies = [
"serde_json",
"tokio",
"tower",
"tower-http",
"tracing",
"tracing-subscriber",
"utoipa",

View File

@@ -27,4 +27,5 @@ once_cell = { version = "1.21.3" }
argon2 = { version = "0.5.3", features = ["std"] }
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] }
tower-http = { version = "0.6.8", features = ["cors"] }

View File

@@ -88,8 +88,10 @@ pub async fn start_server() {
// build the axum app and run the server...
info!("Starting application...");
let mut app: Router =
routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings)));
let mut app: Router = routes::get_root_router(
Arc::new(get_app_state(&db_connection, &settings)),
Arc::new(settings.server.cors.clone()),
);
if settings.server.serve_openapi {
info!("Enabling OpenAPI documentation endpoint at /openapi.json");

View File

@@ -4,6 +4,7 @@ pub(crate) const LOGGING_UTC_KEY: &str = "LOGGING.UTC";
pub(crate) const SERVER_ADDRESS_KEY: &str = "SERVER.ADDRESS";
pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT";
pub(crate) const SERVER_SERVE_OPENAPI_KEY: &str = "SERVER.SERVE_OPENAPI";
pub(crate) const SERVER_CORS_ALLOWED_ORIGINS_KEY: &str = "SERVER.CORS.ALLOWED_ORIGINS";
//
pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL";
pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS";

View File

@@ -3,7 +3,7 @@ use std::net::IpAddr;
use config::{Config, ConfigError};
use tracing::warn;
use crate::configs::key::SERVER_SERVE_OPENAPI_KEY;
use crate::configs::key::{SERVER_CORS_ALLOWED_ORIGINS_KEY, SERVER_SERVE_OPENAPI_KEY};
use super::{
FromConfig,
@@ -15,6 +15,12 @@ pub struct ServerSettings {
pub address: IpAddr,
pub port: u16,
pub serve_openapi: bool,
pub cors: CORSSettings,
}
#[derive(Debug, Clone)]
pub struct CORSSettings {
pub allowed_origins: Vec<String>,
}
impl FromConfig for ServerSettings {
@@ -57,6 +63,24 @@ impl FromConfig for ServerSettings {
);
DEFAULT_SERVE_OPENAPI
}),
cors: CORSSettings {
allowed_origins: _config
.get_array(SERVER_CORS_ALLOWED_ORIGINS_KEY)
.unwrap_or_else(|_| vec![])
.into_iter()
.filter_map(|val| match val.into_string() {
Ok(s) => Some(s),
Err(e) => {
warn!(
"Invalid origin in {} configuration: {}",
SERVER_CORS_ALLOWED_ORIGINS_KEY, e
);
None
}
})
.collect(),
},
})
}

View File

@@ -6,25 +6,55 @@ use std::{sync::Arc, time::Duration};
use axum::{
BoxError, Router,
error_handling::HandleErrorLayer,
http::{Method, StatusCode, Uri},
http::{HeaderValue, Method, StatusCode, Uri},
};
use tower::{ServiceBuilder, timeout::TimeoutLayer};
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
use tracing::warn;
use crate::routes::AppState;
use crate::{configs::server::CORSSettings, routes::AppState};
pub const TIMEOUT_DURATION_SECS: u64 = 30;
pub fn apply_root_middleware(router: Router, _state: Arc<AppState>) -> Router {
pub fn apply_root_middleware(
router: Router,
_state: Arc<AppState>,
cors_settings: Arc<CORSSettings>,
) -> Router {
let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS));
let service_builder = ServiceBuilder::new()
.layer(HandleErrorLayer::new(handle_timeout_error))
.layer(timeout_layer);
.layer(timeout_layer)
.layer(get_cors_layer(cors_settings));
router.layer(service_builder)
}
pub fn get_cors_layer(cors_settings: Arc<CORSSettings>) -> CorsLayer {
let mut cors_layer = CorsLayer::new()
.allow_credentials(true)
.allow_headers(AllowHeaders::mirror_request());
let allowed_origins = &cors_settings.allowed_origins;
if allowed_origins.contains(&"*".to_string()) {
cors_layer = cors_layer.allow_origin(AllowOrigin::mirror_request());
warn!(
"Wildcard origin is found in allowed origins. CORS is configured to allow requests from any origin. Only use this setting in development or if you understand the security implications."
);
} else {
for origin in allowed_origins {
if let Ok(header_value) = HeaderValue::from_str(origin) {
cors_layer = cors_layer.allow_origin(AllowOrigin::exact(header_value));
} else {
warn!("Invalid CORS origin: {}", origin);
}
}
}
cors_layer
}
pub async fn handle_timeout_error(
method: Method,
uri: Uri,

View File

@@ -9,6 +9,7 @@ use axum::{Extension, Router};
use migration::sea_orm::DatabaseConnection;
use crate::{
configs::server::CORSSettings,
middlewares,
services::{
auth::{
@@ -46,7 +47,10 @@ pub struct AppService {
pub server_state: ServiceState<dyn ServerStateStore>,
}
pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
pub fn get_root_router(
state: impl Into<Arc<AppState>>,
cors_settings: Arc<CORSSettings>,
) -> Router {
let mut router = Router::new();
let state = state.into();
@@ -54,7 +58,7 @@ pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
.nest("/api", api::get_api_router(state.clone()))
.merge(view::get_view_router());
router = middlewares::apply_root_middleware(router, state.clone());
router = middlewares::apply_root_middleware(router, state.clone(), cors_settings);
router = router.layer(Extension(state.clone()));