diff --git a/Cargo.lock b/Cargo.lock index 1f96fe9..6966840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3975,6 +3975,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags 2.10.0", + "bytes", + "http", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -4713,6 +4727,7 @@ dependencies = [ "serde_json", "tokio", "tower", + "tower-http", "tracing", "tracing-subscriber", "utoipa", diff --git a/apps/api/Cargo.toml b/apps/api/Cargo.toml index 8e400da..593f5ec 100644 --- a/apps/api/Cargo.toml +++ b/apps/api/Cargo.toml @@ -27,4 +27,5 @@ once_cell = { version = "1.21.3" } argon2 = { version = "0.5.3", features = ["std"] } jsonwebtoken = { version = "10.2.0", features = ["rust_crypto"] } uuid = { version = "1.19.0", features = ["v4", "serde", "fast-rng"] } +tower-http = { version = "0.6.8", features = ["cors"] } diff --git a/apps/api/src/cmd/start_server.rs b/apps/api/src/cmd/start_server.rs index e49f5a8..7545156 100644 --- a/apps/api/src/cmd/start_server.rs +++ b/apps/api/src/cmd/start_server.rs @@ -88,8 +88,10 @@ pub async fn start_server() { // build the axum app and run the server... info!("Starting application..."); - let mut app: Router = - routes::get_root_router(Arc::new(get_app_state(&db_connection, &settings))); + let mut app: Router = routes::get_root_router( + Arc::new(get_app_state(&db_connection, &settings)), + Arc::new(settings.server.cors.clone()), + ); if settings.server.serve_openapi { info!("Enabling OpenAPI documentation endpoint at /openapi.json"); diff --git a/apps/api/src/configs/key.rs b/apps/api/src/configs/key.rs index 0bc3baf..204be32 100644 --- a/apps/api/src/configs/key.rs +++ b/apps/api/src/configs/key.rs @@ -4,6 +4,7 @@ pub(crate) const LOGGING_UTC_KEY: &str = "LOGGING.UTC"; pub(crate) const SERVER_ADDRESS_KEY: &str = "SERVER.ADDRESS"; pub(crate) const SERVER_PORT_KEY: &str = "SERVER.PORT"; pub(crate) const SERVER_SERVE_OPENAPI_KEY: &str = "SERVER.SERVE_OPENAPI"; +pub(crate) const SERVER_CORS_ALLOWED_ORIGINS_KEY: &str = "SERVER.CORS.ALLOWED_ORIGINS"; // pub(crate) const DATABASE_URL_KEY: &str = "DATABASE.URL"; pub(crate) const DATABASE_MAX_CONNECTIONS_KEY: &str = "DATABASE.MAX_CONNECTIONS"; diff --git a/apps/api/src/configs/server.rs b/apps/api/src/configs/server.rs index a79ce14..0ff7bae 100644 --- a/apps/api/src/configs/server.rs +++ b/apps/api/src/configs/server.rs @@ -3,7 +3,7 @@ use std::net::IpAddr; use config::{Config, ConfigError}; use tracing::warn; -use crate::configs::key::SERVER_SERVE_OPENAPI_KEY; +use crate::configs::key::{SERVER_CORS_ALLOWED_ORIGINS_KEY, SERVER_SERVE_OPENAPI_KEY}; use super::{ FromConfig, @@ -15,6 +15,12 @@ pub struct ServerSettings { pub address: IpAddr, pub port: u16, pub serve_openapi: bool, + pub cors: CORSSettings, +} + +#[derive(Debug, Clone)] +pub struct CORSSettings { + pub allowed_origins: Vec, } impl FromConfig for ServerSettings { @@ -57,6 +63,24 @@ impl FromConfig for ServerSettings { ); DEFAULT_SERVE_OPENAPI }), + + cors: CORSSettings { + allowed_origins: _config + .get_array(SERVER_CORS_ALLOWED_ORIGINS_KEY) + .unwrap_or_else(|_| vec![]) + .into_iter() + .filter_map(|val| match val.into_string() { + Ok(s) => Some(s), + Err(e) => { + warn!( + "Invalid origin in {} configuration: {}", + SERVER_CORS_ALLOWED_ORIGINS_KEY, e + ); + None + } + }) + .collect(), + }, }) } diff --git a/apps/api/src/middlewares.rs b/apps/api/src/middlewares.rs index 69935b1..8799c1d 100644 --- a/apps/api/src/middlewares.rs +++ b/apps/api/src/middlewares.rs @@ -6,25 +6,55 @@ use std::{sync::Arc, time::Duration}; use axum::{ BoxError, Router, error_handling::HandleErrorLayer, - http::{Method, StatusCode, Uri}, + http::{HeaderValue, Method, StatusCode, Uri}, }; use tower::{ServiceBuilder, timeout::TimeoutLayer}; +use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer}; use tracing::warn; -use crate::routes::AppState; +use crate::{configs::server::CORSSettings, routes::AppState}; pub const TIMEOUT_DURATION_SECS: u64 = 30; -pub fn apply_root_middleware(router: Router, _state: Arc) -> Router { +pub fn apply_root_middleware( + router: Router, + _state: Arc, + cors_settings: Arc, +) -> Router { let timeout_layer = TimeoutLayer::new(Duration::from_secs(TIMEOUT_DURATION_SECS)); let service_builder = ServiceBuilder::new() .layer(HandleErrorLayer::new(handle_timeout_error)) - .layer(timeout_layer); + .layer(timeout_layer) + .layer(get_cors_layer(cors_settings)); router.layer(service_builder) } +pub fn get_cors_layer(cors_settings: Arc) -> CorsLayer { + let mut cors_layer = CorsLayer::new() + .allow_credentials(true) + .allow_headers(AllowHeaders::mirror_request()); + + let allowed_origins = &cors_settings.allowed_origins; + if allowed_origins.contains(&"*".to_string()) { + cors_layer = cors_layer.allow_origin(AllowOrigin::mirror_request()); + warn!( + "Wildcard origin is found in allowed origins. CORS is configured to allow requests from any origin. Only use this setting in development or if you understand the security implications." + ); + } else { + for origin in allowed_origins { + if let Ok(header_value) = HeaderValue::from_str(origin) { + cors_layer = cors_layer.allow_origin(AllowOrigin::exact(header_value)); + } else { + warn!("Invalid CORS origin: {}", origin); + } + } + } + + cors_layer +} + pub async fn handle_timeout_error( method: Method, uri: Uri, diff --git a/apps/api/src/routes.rs b/apps/api/src/routes.rs index 41cc73b..37d1cdd 100644 --- a/apps/api/src/routes.rs +++ b/apps/api/src/routes.rs @@ -9,6 +9,7 @@ use axum::{Extension, Router}; use migration::sea_orm::DatabaseConnection; use crate::{ + configs::server::CORSSettings, middlewares, services::{ auth::{ @@ -46,7 +47,10 @@ pub struct AppService { pub server_state: ServiceState, } -pub fn get_root_router(state: impl Into>) -> Router { +pub fn get_root_router( + state: impl Into>, + cors_settings: Arc, +) -> Router { let mut router = Router::new(); let state = state.into(); @@ -54,7 +58,7 @@ pub fn get_root_router(state: impl Into>) -> Router { .nest("/api", api::get_api_router(state.clone())) .merge(view::get_view_router()); - router = middlewares::apply_root_middleware(router, state.clone()); + router = middlewares::apply_root_middleware(router, state.clone(), cors_settings); router = router.layer(Extension(state.clone()));