feat: implement CORS support with configuration options and middleware integration
This commit is contained in:
15
Cargo.lock
generated
15
Cargo.lock
generated
@@ -3975,6 +3975,20 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.6.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"bytes",
|
||||
"http",
|
||||
"pin-project-lite",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.3"
|
||||
@@ -4713,6 +4727,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"utoipa",
|
||||
|
||||
@@ -27,4 +27,5 @@ once_cell = { version = "1.21.3" }
|
||||
argon2 = { version = "0.5.3", features = ["std"] }
|
||||
jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] }
|
||||
uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] }
|
||||
tower-http = { version = "0.6.8", features = ["cors"] }
|
||||
|
||||
|
||||
@@ -88,8 +88,10 @@ pub async fn start_server() {
|
||||
|
||||
// build the axum app and run the server...
|
||||
info!("Starting application...");
|
||||
let mut app: Router =
|
||||
routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings)));
|
||||
let mut app: Router = routes::get_root_router(
|
||||
Arc::new(get_app_state(&db_connection, &settings)),
|
||||
Arc::new(settings.server.cors.clone()),
|
||||
);
|
||||
|
||||
if settings.server.serve_openapi {
|
||||
info!("Enabling OpenAPI documentation endpoint at /openapi.json");
|
||||
|
||||
@@ -4,6 +4,7 @@ pub(crate) const LOGGING_UTC_KEY: &str = "LOGGING.UTC";
|
||||
pub(crate) const SERVER_ADDRESS_KEY: &str = "SERVER.ADDRESS";
|
||||
pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT";
|
||||
pub(crate) const SERVER_SERVE_OPENAPI_KEY: &str = "SERVER.SERVE_OPENAPI";
|
||||
pub(crate) const SERVER_CORS_ALLOWED_ORIGINS_KEY: &str = "SERVER.CORS.ALLOWED_ORIGINS";
|
||||
//
|
||||
pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL";
|
||||
pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS";
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::net::IpAddr;
|
||||
use config::{Config, ConfigError};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::configs::key::SERVER_SERVE_OPENAPI_KEY;
|
||||
use crate::configs::key::{SERVER_CORS_ALLOWED_ORIGINS_KEY, SERVER_SERVE_OPENAPI_KEY};
|
||||
|
||||
use super::{
|
||||
FromConfig,
|
||||
@@ -15,6 +15,12 @@ pub struct ServerSettings {
|
||||
pub address: IpAddr,
|
||||
pub port: u16,
|
||||
pub serve_openapi: bool,
|
||||
pub cors: CORSSettings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CORSSettings {
|
||||
pub allowed_origins: Vec<String>,
|
||||
}
|
||||
|
||||
impl FromConfig for ServerSettings {
|
||||
@@ -57,6 +63,24 @@ impl FromConfig for ServerSettings {
|
||||
);
|
||||
DEFAULT_SERVE_OPENAPI
|
||||
}),
|
||||
|
||||
cors: CORSSettings {
|
||||
allowed_origins: _config
|
||||
.get_array(SERVER_CORS_ALLOWED_ORIGINS_KEY)
|
||||
.unwrap_or_else(|_| vec![])
|
||||
.into_iter()
|
||||
.filter_map(|val| match val.into_string() {
|
||||
Ok(s) => Some(s),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Invalid origin in {} configuration: {}",
|
||||
SERVER_CORS_ALLOWED_ORIGINS_KEY, e
|
||||
);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,25 +6,55 @@ use std::{sync::Arc, time::Duration};
|
||||
use axum::{
|
||||
BoxError, Router,
|
||||
error_handling::HandleErrorLayer,
|
||||
http::{Method, StatusCode, Uri},
|
||||
http::{HeaderValue, Method, StatusCode, Uri},
|
||||
};
|
||||
use tower::{ServiceBuilder, timeout::TimeoutLayer};
|
||||
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::routes::AppState;
|
||||
use crate::{configs::server::CORSSettings, routes::AppState};
|
||||
|
||||
pub const TIMEOUT_DURATION_SECS: u64 = 30;
|
||||
|
||||
pub fn apply_root_middleware(router: Router, _state: Arc<AppState>) -> Router {
|
||||
pub fn apply_root_middleware(
|
||||
router: Router,
|
||||
_state: Arc<AppState>,
|
||||
cors_settings: Arc<CORSSettings>,
|
||||
) -> Router {
|
||||
let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS));
|
||||
|
||||
let service_builder = ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(handle_timeout_error))
|
||||
.layer(timeout_layer);
|
||||
.layer(timeout_layer)
|
||||
.layer(get_cors_layer(cors_settings));
|
||||
|
||||
router.layer(service_builder)
|
||||
}
|
||||
|
||||
pub fn get_cors_layer(cors_settings: Arc<CORSSettings>) -> CorsLayer {
|
||||
let mut cors_layer = CorsLayer::new()
|
||||
.allow_credentials(true)
|
||||
.allow_headers(AllowHeaders::mirror_request());
|
||||
|
||||
let allowed_origins = &cors_settings.allowed_origins;
|
||||
if allowed_origins.contains(&"*".to_string()) {
|
||||
cors_layer = cors_layer.allow_origin(AllowOrigin::mirror_request());
|
||||
warn!(
|
||||
"Wildcard origin is found in allowed origins. CORS is configured to allow requests from any origin. Only use this setting in development or if you understand the security implications."
|
||||
);
|
||||
} else {
|
||||
for origin in allowed_origins {
|
||||
if let Ok(header_value) = HeaderValue::from_str(origin) {
|
||||
cors_layer = cors_layer.allow_origin(AllowOrigin::exact(header_value));
|
||||
} else {
|
||||
warn!("Invalid CORS origin: {}", origin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cors_layer
|
||||
}
|
||||
|
||||
pub async fn handle_timeout_error(
|
||||
method: Method,
|
||||
uri: Uri,
|
||||
|
||||
@@ -9,6 +9,7 @@ use axum::{Extension, Router};
|
||||
use migration::sea_orm::DatabaseConnection;
|
||||
|
||||
use crate::{
|
||||
configs::server::CORSSettings,
|
||||
middlewares,
|
||||
services::{
|
||||
auth::{
|
||||
@@ -46,7 +47,10 @@ pub struct AppService {
|
||||
pub server_state: ServiceState<dyn ServerStateStore>,
|
||||
}
|
||||
|
||||
pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
|
||||
pub fn get_root_router(
|
||||
state: impl Into<Arc<AppState>>,
|
||||
cors_settings: Arc<CORSSettings>,
|
||||
) -> Router {
|
||||
let mut router = Router::new();
|
||||
let state = state.into();
|
||||
|
||||
@@ -54,7 +58,7 @@ pub fn get_root_router(state: impl Into<Arc<AppState>>) -> Router {
|
||||
.nest("/api", api::get_api_router(state.clone()))
|
||||
.merge(view::get_view_router());
|
||||
|
||||
router = middlewares::apply_root_middleware(router, state.clone());
|
||||
router = middlewares::apply_root_middleware(router, state.clone(), cors_settings);
|
||||
|
||||
router = router.layer(Extension(state.clone()));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user