use std::sync::Arc; use sea_orm::{ ActiveModelTrait, ColumnTrait, DatabaseConnection, DatabaseTransaction, EntityTrait, ExprTrait, FromQueryResult, ModelTrait, QueryFilter, QuerySelect, QueryTrait, TransactionTrait, }; use database::generated::entities::{upstream, upstream_target}; use crate::{ errors::service_error::ServiceError, helpers::database::PaginationFilter, services::nginx::{ builder::NginxConfigBuilder, info::{ upstream::{UpdateUpstreamInfo, UpstreamCreateInfo, UpstreamInfo}, upstream_target::{ UpdateUpstreamTargetInfo, UpstreamTargetCreateInfo, UpstreamTargetInfo, }, }, }, with_conn, }; #[async_trait::async_trait] pub trait UpstreamService: Send + Sync { async fn create_upstream( &self, create_info: UpstreamCreateInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result; async fn get_total_upstreams( &self, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result; async fn get_upstream( &self, upstream_id: uuid::Uuid, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result; async fn get_upstreams( &self, pagination: Option, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result, ServiceError>; async fn update_upstream( &self, id: uuid::Uuid, upstream: UpdateUpstreamInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result; async fn delete_upstream( &self, upstream_id: uuid::Uuid, tx: Option<&mut DatabaseTransaction>, ) -> Result<(), ServiceError>; async fn create_upstream_target( &self, create_info: UpstreamTargetCreateInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result; async fn get_upstream_target( &self, target_id: uuid::Uuid, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result; #[allow(dead_code)] async fn get_upstream_targets_by_upstream( &self, upstream_id: uuid::Uuid, tx: Option<&mut DatabaseTransaction>, ) -> Result, ServiceError>; async fn update_upstream_target( &self, id: uuid::Uuid, target: UpdateUpstreamTargetInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result; async fn delete_upstream_target( &self, target_id: uuid::Uuid, tx: Option<&mut DatabaseTransaction>, ) -> Result<(), ServiceError>; async fn generate_config( &self, builder: &mut NginxConfigBuilder, tx: &Option<&mut DatabaseTransaction>, ) -> Result<(), ServiceError>; } pub struct UpstreamServiceImpl { connection: Arc, } #[derive(Default)] pub struct GetUpstreamOptions { pub include_targets: bool, pub filter_by_enabled: bool, } #[allow(dead_code)] pub struct UpstreamTotalCountOptions {} #[derive(Default)] pub struct GetUpstreamTargetOptions { pub include_upstream: bool, } impl UpstreamServiceImpl { pub fn new(connection: Arc) -> Self { Self { connection } } } #[async_trait::async_trait] impl UpstreamService for UpstreamServiceImpl { async fn create_upstream( &self, create_info: UpstreamCreateInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result { let (upstream_model, upstream_target_models): ( upstream::ActiveModel, Vec, ) = create_info.into(); // If a transaction was provided use it, otherwise create and own one here. let mut maybe_owned_tx: Option = None; let tx_ref: Option<&mut DatabaseTransaction> = if let Some(tx) = tx { Some(tx) } else { maybe_owned_tx = Some(self.connection.begin().await?); maybe_owned_tx.as_mut() }; let r = with_conn!(&*self.connection, tx_ref, conn, { let created_upstream = upstream_model.insert(*conn).await?; let created_targets = upstream_target::Entity::insert_many( upstream_target_models .into_iter() .map(|mut model| { model.upstream_id = sea_orm::ActiveValue::Set(created_upstream.id); model }) .collect::>(), ) .exec_with_returning(*conn) .await?; (created_upstream, created_targets) }); // Commit only if we created the transaction here (we own it). if let Some(t) = maybe_owned_tx.take() { t.commit().await?; } Ok(r.into()) } async fn get_total_upstreams( &self, _options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result { #[derive(Debug, FromQueryResult)] struct CountResult { // The field name must match the column alias in the query count: i64, } let count_info = with_conn!(&*self.connection, tx, conn, { upstream::Entity::find() .select_only() .column_as(upstream::Column::Id.count(), "count") .into_model::() .one(*conn) .await? }); Ok(count_info.map_or(0, |c| c.count) as u64) } async fn get_upstream( &self, upstream_id: uuid::Uuid, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result { let concrete_options = options.unwrap_or_default(); let info: UpstreamInfo = if concrete_options.include_targets { let (up_model, targets) = with_conn!(&*self.connection, tx, conn, { let up = upstream::Entity::find_by_id(upstream_id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream with id {} not found", upstream_id )))?; let targets = upstream_target::Entity::find() .filter(upstream_target::Column::UpstreamId.eq(upstream_id)) .apply_if( concrete_options.filter_by_enabled.then_some(true), |query, _v| query.filter(upstream_target::Column::Enabled.eq(true)), ) .all(*conn) .await?; (up, targets) }); (up_model, targets).into() } else { with_conn!(&*self.connection, tx, conn, { upstream::Entity::find_by_id(upstream_id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream with id {} not found", upstream_id )))? }) .into() }; Ok(info) } async fn get_upstreams( &self, pagination: Option, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result, ServiceError> { let r = with_conn!(&*self.connection, tx, conn, { let find_query = upstream::Entity::find(); let find_query = if let Some(pagination) = pagination { let (offset, limit) = pagination.get_offset_limit(); find_query.offset(offset).limit(limit) } else { find_query }; let find_query = match options { Some(opts) => { if opts.include_targets && opts.filter_by_enabled { find_query.filter( upstream_target::Column::Enabled .eq(true) .or(upstream_target::Column::Id.is_null()), ) } else { find_query } } _ => find_query, }; find_query .find_with_related(upstream_target::Entity) .all(*conn) .await? }); Ok(r.into_iter().map(|m| m.into()).collect()) } async fn update_upstream( &self, id: uuid::Uuid, upstream: UpdateUpstreamInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result { // If a transaction was provided use it, otherwise create and own one here. let mut maybe_owned_tx: Option = None; let tx_ref: Option<&mut DatabaseTransaction> = if let Some(tx) = tx { Some(tx) } else { maybe_owned_tx = Some(self.connection.begin().await?); maybe_owned_tx.as_mut() }; let current_model = with_conn!(&*self.connection, tx_ref, conn, { upstream::Entity::find_by_id(id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream with id {} not found", id )))? }); let upstream_active_model = upstream.clone().apply_to_model(current_model); let r = with_conn!(&*self.connection, tx_ref, conn, { let updated_upstream_model = upstream_active_model.update(*conn).await?; // update upstream targets if any if let Some(targets) = upstream.upstream_targets { for (target_id, enabled) in targets.into_iter() { let target_model = upstream_target::Entity::find_by_id(target_id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream target with id {} not found", target_id )))?; let mut target_active_model: upstream_target::ActiveModel = target_model.into(); target_active_model.enabled = sea_orm::ActiveValue::Set(enabled); target_active_model.update(*conn).await?; Ok::<(), ServiceError>(())?; } } updated_upstream_model }); // Commit if let Some(t) = maybe_owned_tx.take() { t.commit().await?; } Ok(r.into()) } async fn delete_upstream( &self, upstream_id: uuid::Uuid, tx: Option<&mut DatabaseTransaction>, ) -> Result<(), ServiceError> { let model = with_conn!(&*self.connection, tx, conn, { upstream::Entity::find_by_id(upstream_id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream with id {} not found", upstream_id )))? }); with_conn!(&*self.connection, tx, conn, { // delete all targets belonging to the upstream upstream_target::Entity::delete_many() .filter(upstream_target::Column::UpstreamId.eq(upstream_id)) .exec(*conn) .await?; model.delete(*conn).await?; Ok(()) }) } // // async fn create_upstream_target( &self, create_info: UpstreamTargetCreateInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result { let model: upstream_target::ActiveModel = create_info.into(); let r = with_conn!(&*self.connection, tx, conn, { model.insert(*conn).await? }); Ok(r.into()) } async fn get_upstream_target( &self, target_id: uuid::Uuid, options: Option, tx: Option<&mut DatabaseTransaction>, ) -> Result { let concrete_options = options.unwrap_or_default(); let info: UpstreamTargetInfo = if concrete_options.include_upstream { match with_conn!(&*self.connection, tx, conn, { upstream_target::Entity::find_by_id(target_id) .find_also_related(upstream::Entity) .one(*conn) .await? }) { Some((target_model, Some(upstream_model))) => (target_model, upstream_model).into(), Some((_target_model, None)) => { return Err(ServiceError::InternalError(format!( "Inconsistent data: Upstream target with id {} has no associated upstream", target_id ))); } None => { return Err(ServiceError::NotFound(format!( "Upstream target with id {} not found", target_id ))); } } } else { with_conn!(&*self.connection, tx, conn, { upstream_target::Entity::find_by_id(target_id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream target with id {} not found", target_id )))? }) .into() }; Ok(info) } async fn get_upstream_targets_by_upstream( &self, upstream_id: uuid::Uuid, tx: Option<&mut DatabaseTransaction>, ) -> Result, ServiceError> { let r = with_conn!(&*self.connection, tx, conn, { upstream_target::Entity::find() .filter(upstream_target::Column::UpstreamId.eq(upstream_id)) .all(*conn) .await? }); Ok(r.into_iter().map(|m| m.into()).collect()) } async fn update_upstream_target( &self, id: uuid::Uuid, target: UpdateUpstreamTargetInfo, tx: Option<&mut DatabaseTransaction>, ) -> Result { let current_model = with_conn!(&*self.connection, tx, conn, { upstream_target::Entity::find_by_id(id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream target with id {} not found", id )))? }); let active_model = target.apply_to_model(current_model); let r = with_conn!(&*self.connection, tx, conn, { active_model.update(*conn).await? }); Ok(r.into()) } async fn delete_upstream_target( &self, target_id: uuid::Uuid, tx: Option<&mut DatabaseTransaction>, ) -> Result<(), ServiceError> { let model = with_conn!(&*self.connection, tx, conn, { upstream_target::Entity::find_by_id(target_id) .one(*conn) .await? .ok_or(ServiceError::NotFound(format!( "Upstream target with id {} not found", target_id )))? }); with_conn!(&*self.connection, tx, conn, { model.delete(*conn).await?; Ok(()) }) } async fn generate_config( &self, builder: &mut NginxConfigBuilder, tx: &Option<&mut DatabaseTransaction>, ) -> Result<(), ServiceError> { // get all upstreams and their targets let upstreams = with_conn!(&*self.connection, tx, conn, { upstream::Entity::find() .find_with_related(upstream_target::Entity) .all(*conn) .await? }); let upstreams_info = upstreams .into_iter() .map(|(up_model, target_models)| (up_model, target_models).into()) .collect::>(); builder.add_upstreams(upstreams_info); Ok(()) } } #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use sea_orm::MockExecResult; use sea_orm::{DatabaseBackend, MockDatabase}; use database::generated::entities::{upstream, upstream_target}; #[tokio::test] async fn create_upstream_returns_info() { let up_model = upstream::Model { id: uuid::Uuid::new_v4(), name: "test_upstream".to_string(), protocol: "http".to_string(), algorithm: "round_robin".to_string(), sticky_session: false, created_by: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![up_model.clone()]]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let create_info = crate::services::nginx::info::upstream::UpstreamCreateInfo { name: "test_upstream".to_string(), protocol: "http".to_string(), algorithm: "round_robin".to_string(), sticky_session: false, created_by: None, upstream_targets: Vec::new(), }; let res = svc.create_upstream(create_info, None).await; assert!(res.is_ok()); let info = res.expect("Failed to create upstream"); assert_eq!(info.name, "test_upstream"); } #[tokio::test] async fn get_upstream_with_targets_returns_targets() { let up_id = uuid::Uuid::new_v4(); let up_model = upstream::Model { id: up_id, name: "with_targets".to_string(), protocol: "http".to_string(), algorithm: "least_conn".to_string(), sticky_session: true, created_by: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let target_model = upstream_target::Model { id: uuid::Uuid::new_v4(), upstream_id: up_id, target_host: "127.0.0.1".to_string(), target_port: 8080, weight: 1, is_backup: false, enabled: true, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) // find_by_id -> returns upstream model .append_query_results(vec![vec![up_model.clone()]]) // find targets -> returns the target(s) .append_query_results(vec![vec![target_model.clone()]]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc .get_upstream( up_id, Some(GetUpstreamOptions { include_targets: true, filter_by_enabled: false, }), None, ) .await; assert!(res.is_ok()); let info = res.expect("Failed to get upstream with targets"); assert_eq!(info.id, up_id); assert_eq!(info.upstream_targets.len(), 1); assert_eq!(info.upstream_targets[0].target_host, "127.0.0.1"); } #[tokio::test] async fn get_upstream_not_found_returns_not_found() { let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![Vec::::new()]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc.get_upstream(uuid::Uuid::new_v4(), None, None).await; assert!(matches!(res, Err(ServiceError::NotFound(_)))); } #[tokio::test] async fn get_upstreams_returns_list() { let u1 = upstream::Model { id: uuid::Uuid::new_v4(), name: "u1".to_string(), protocol: "http".to_string(), algorithm: "rr".to_string(), sticky_session: false, created_by: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let u2 = upstream::Model { id: uuid::Uuid::new_v4(), name: "u2".to_string(), protocol: "http".to_string(), algorithm: "rr".to_string(), sticky_session: false, created_by: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![ (u1.clone(), None::), (u2.clone(), None::), ]]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc.get_upstreams(None, None, None).await; assert!(res.is_ok()); let list = res.expect("Failed to get upstreams"); assert_eq!(list.len(), 2); } #[tokio::test] async fn get_upstream_targets_by_upstream_returns_targets() { let up_id = uuid::Uuid::new_v4(); let t = upstream_target::Model { id: uuid::Uuid::new_v4(), upstream_id: up_id, target_host: "10.0.0.1".to_string(), target_port: 80, weight: 10, is_backup: false, enabled: true, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![t.clone()]]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc.get_upstream_targets_by_upstream(up_id, None).await; assert!(res.is_ok()); let targets = res.expect("Failed to get upstream targets"); assert_eq!(targets.len(), 1); assert_eq!(targets[0].target_host, "10.0.0.1"); } #[tokio::test] async fn update_upstream_success() { let id = uuid::Uuid::new_v4(); let existing = upstream::Model { id, name: "old".to_string(), protocol: "http".to_string(), algorithm: "rr".to_string(), sticky_session: false, created_by: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let updated = upstream::Model { id, name: "new".to_string(), protocol: "http".to_string(), algorithm: "rr".to_string(), sticky_session: false, created_by: None, created_at: existing.created_at, updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![existing.clone()]]) // find_by_id .append_query_results(vec![vec![updated.clone()]]) // update result .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let update_info = crate::services::nginx::info::upstream::UpdateUpstreamInfo { name: None, protocol: None, algorithm: None, sticky_session: None, upstream_targets: None, }; let res = svc.update_upstream(id, update_info, None).await; assert!(res.is_ok()); let got = res.expect("Failed to update upstream"); assert_eq!(got.name, "new"); } #[tokio::test] async fn update_upstream_not_found() { let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![Vec::::new()]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc .update_upstream( uuid::Uuid::new_v4(), crate::services::nginx::info::upstream::UpdateUpstreamInfo { name: None, protocol: None, algorithm: None, sticky_session: None, upstream_targets: None, }, None, ) .await; assert!(matches!(res, Err(ServiceError::NotFound(_)))); } #[tokio::test] async fn delete_upstream_success() { let id = uuid::Uuid::new_v4(); let existing = upstream::Model { id, name: "todelete".to_string(), protocol: "http".to_string(), algorithm: "rr".to_string(), sticky_session: false, created_by: None, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![existing.clone()]]) .append_exec_results(vec![ MockExecResult { rows_affected: 1, last_insert_id: 0, }, MockExecResult { rows_affected: 1, last_insert_id: 0, }, ]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc.delete_upstream(id, None).await; assert!(res.is_ok()); } #[tokio::test] async fn delete_upstream_not_found() { let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![Vec::::new()]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc.delete_upstream(uuid::Uuid::new_v4(), None).await; assert!(matches!(res, Err(ServiceError::NotFound(_)))); } #[tokio::test] async fn create_upstream_target_success() { let id = uuid::Uuid::new_v4(); let upstream_id = uuid::Uuid::new_v4(); let created = upstream_target::Model { id, upstream_id, target_host: "1.2.3.4".to_string(), target_port: 8080, weight: 5, is_backup: false, enabled: true, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![created.clone()]]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let create_info = crate::services::nginx::info::upstream_target::UpstreamTargetCreateInfo { target_host: "1.2.3.4".to_string(), target_port: 8080, weight: 5, is_backup: false, enabled: true, upstream_id, }; let res = svc.create_upstream_target(create_info, None).await; assert!(res.is_ok()); let t = res.expect("Failed to create target"); assert_eq!(t.target_host, "1.2.3.4"); } #[tokio::test] async fn update_upstream_target_success() { let id = uuid::Uuid::new_v4(); let existing = upstream_target::Model { id, upstream_id: uuid::Uuid::new_v4(), target_host: "old".to_string(), target_port: 80, weight: 1, is_backup: false, enabled: true, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let updated = upstream_target::Model { id, upstream_id: existing.upstream_id, target_host: "new".to_string(), target_port: 80, weight: 1, is_backup: false, enabled: true, created_at: existing.created_at, updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![existing.clone()]]) .append_query_results(vec![vec![updated.clone()]]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let update_info = crate::services::nginx::info::upstream_target::UpdateUpstreamTargetInfo { target_host: None, target_port: None, weight: None, is_backup: None, enabled: None, }; let res = svc.update_upstream_target(id, update_info, None).await; assert!(res.is_ok()); let got = res.expect("Failed to update target"); assert_eq!(got.target_host, "new"); } #[tokio::test] async fn delete_upstream_target_success() { let id = uuid::Uuid::new_v4(); let existing = upstream_target::Model { id, upstream_id: uuid::Uuid::new_v4(), target_host: "del".to_string(), target_port: 80, weight: 1, is_backup: false, enabled: true, created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), }; let db = MockDatabase::new(DatabaseBackend::Sqlite) .append_query_results(vec![vec![existing.clone()]]) .append_exec_results(vec![MockExecResult { rows_affected: 1, last_insert_id: 0, }]) .into_connection(); let svc = UpstreamServiceImpl::new(Arc::new(db)); let res = svc.delete_upstream_target(id, None).await; assert!(res.is_ok()); } }