diff --git a/Cargo.lock b/Cargo.lock index c85cdfd..6a0439e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2517,6 +2517,7 @@ name = "nxmesh-proto" version = "0.1.0" dependencies = [ "prost", + "tokio-test", "tonic", "tonic-async-interceptor", "tonic-prost", diff --git a/crates/nxmesh-proto/Cargo.toml b/crates/nxmesh-proto/Cargo.toml index 2503d86..3dd8cf7 100644 --- a/crates/nxmesh-proto/Cargo.toml +++ b/crates/nxmesh-proto/Cargo.toml @@ -25,3 +25,6 @@ client = [] [build-dependencies] tonic-prost-build.workspace = true + +[dev-dependencies] +tokio-test.workspace = true diff --git a/crates/nxmesh-proto/src/auth/ssh_auth.rs b/crates/nxmesh-proto/src/auth/ssh_auth.rs index a3c74ac..5cefa08 100644 --- a/crates/nxmesh-proto/src/auth/ssh_auth.rs +++ b/crates/nxmesh-proto/src/auth/ssh_auth.rs @@ -38,12 +38,162 @@ impl SshAuthInterceptor { async fn authenticate(&self, req: Request<()>) -> Result, Status> { let certs = req.peer_certs().ok_or(Status::unauthenticated("No cert"))?; - let is_authorized = self.certificate_provider.is_authorized(&certs).await?; + self.validate_certs(&certs).await?; + Ok(req) + } + + async fn validate_certs(&self, certs: &Arc>>) -> Result<(), Status> { + let is_authorized = self.certificate_provider.is_authorized(certs).await?; if is_authorized { - Ok(req) + Ok(()) } else { Err(Status::permission_denied("Blocked")) } } } + +#[cfg(test)] +mod tests { + use std::sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }; + + use tonic::{Request, Status, transport::CertificateDer}; + use tonic_async_interceptor::AsyncInterceptor; + + use super::{CertificateValidationProvider, SshAuthInterceptor, create_ssh_auth_interceptor}; + + #[derive(Clone, Copy)] + enum ProviderMode { + Allow, + Deny, + Error, + } + + struct TestCertificateProvider { + mode: ProviderMode, + calls: Arc, + cert_count_seen: Arc>>, + } + + #[tonic::async_trait] + impl CertificateValidationProvider for TestCertificateProvider { + async fn is_authorized( + &self, + certs: &Arc>>, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + + let lock = self.cert_count_seen.lock(); + assert!(lock.is_ok()); + let mut lock = lock.unwrap_or_else(|_| unreachable!()); + *lock = Some(certs.len()); + + match self.mode { + ProviderMode::Allow => Ok(true), + ProviderMode::Deny => Ok(false), + ProviderMode::Error => Err(Status::internal("provider failed")), + } + } + } + + fn build_provider( + mode: ProviderMode, + ) -> ( + Arc, + Arc, + Arc>>, + ) { + let calls = Arc::new(AtomicUsize::new(0)); + let cert_count_seen = Arc::new(Mutex::new(None)); + let provider = Arc::new(TestCertificateProvider { + mode, + calls: calls.clone(), + cert_count_seen: cert_count_seen.clone(), + }); + (provider, calls, cert_count_seen) + } + + fn sample_certs() -> Arc>> { + Arc::new(vec![ + CertificateDer::from(vec![1, 2, 3]), + CertificateDer::from(vec![4, 5, 6]), + ]) + } + + #[test] + fn create_ssh_auth_interceptor_builds_layer() { + let (provider, _, _) = build_provider(ProviderMode::Allow); + let _ = create_ssh_auth_interceptor(provider); + } + + #[test] + fn authenticate_fails_when_no_peer_certificates_exist() { + let (provider, calls, _) = build_provider(ProviderMode::Allow); + let interceptor = SshAuthInterceptor::new(provider); + + let result = tokio_test::block_on(interceptor.authenticate(Request::new(()))); + assert!(result.is_err()); + let err = result.err().unwrap_or_else(|| unreachable!()); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[test] + fn validate_certs_succeeds_when_provider_allows() { + let (provider, calls, cert_count_seen) = build_provider(ProviderMode::Allow); + let interceptor = SshAuthInterceptor::new(provider); + let certs = sample_certs(); + + let result = tokio_test::block_on(interceptor.validate_certs(&certs)); + assert!(result.is_ok()); + assert_eq!(calls.load(Ordering::SeqCst), 1); + + let seen = cert_count_seen.lock(); + assert!(seen.is_ok()); + let seen = seen.unwrap_or_else(|_| unreachable!()); + assert_eq!(*seen, Some(2)); + } + + #[test] + fn validate_certs_returns_permission_denied_when_provider_denies() { + let (provider, calls, _) = build_provider(ProviderMode::Deny); + let interceptor = SshAuthInterceptor::new(provider); + let certs = sample_certs(); + + let result = tokio_test::block_on(interceptor.validate_certs(&certs)); + assert!(result.is_err()); + let err = result.err().unwrap_or_else(|| unreachable!()); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + assert_eq!(err.message(), "Blocked"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[test] + fn validate_certs_propagates_provider_errors() { + let (provider, calls, _) = build_provider(ProviderMode::Error); + let interceptor = SshAuthInterceptor::new(provider); + let certs = sample_certs(); + + let result = tokio_test::block_on(interceptor.validate_certs(&certs)); + assert!(result.is_err()); + let err = result.err().unwrap_or_else(|| unreachable!()); + assert_eq!(err.code(), tonic::Code::Internal); + assert_eq!(err.message(), "provider failed"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[test] + fn async_interceptor_call_delegates_to_authenticate() { + let (provider, calls, _) = build_provider(ProviderMode::Allow); + let mut interceptor = SshAuthInterceptor::new(provider); + + let result = tokio_test::block_on(interceptor.call(Request::new(()))); + assert!(result.is_err()); + let err = result.err().unwrap_or_else(|| unreachable!()); + assert_eq!(err.code(), tonic::Code::Unauthenticated); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } +} diff --git a/crates/nxmesh-proto/src/lib.rs b/crates/nxmesh-proto/src/lib.rs index 37c6dd2..c718240 100644 --- a/crates/nxmesh-proto/src/lib.rs +++ b/crates/nxmesh-proto/src/lib.rs @@ -11,3 +11,114 @@ pub mod agent { pub use agent::*; pub mod auth; pub use tonic_async_interceptor::*; + +#[cfg(test)] +mod tests { + use prost::Message; + + use crate::agent::{ + AgentMessage, ConfigApplyStatus, ConfigStatus, DeploymentMode, Error, MasterMessage, + MetricType, RegistrationRequest, agent_message, master_message, + }; + + #[test] + fn agent_message_round_trip_with_registration_payload() { + let msg = AgentMessage { + agent_id: "agent-1".to_string(), + timestamp: 123, + payload: Some(agent_message::Payload::Registration(RegistrationRequest { + hostname: "node-1".to_string(), + ip_address: "127.0.0.1".to_string(), + version: "1.0.0".to_string(), + capabilities: vec!["reload".to_string(), "metrics".to_string()], + labels: std::collections::HashMap::from([ + ("region".to_string(), "dev".to_string()), + ("tier".to_string(), "edge".to_string()), + ]), + deployment_mode: DeploymentMode::Standalone as i32, + })), + }; + + let encoded = msg.encode_to_vec(); + let decoded = AgentMessage::decode(encoded.as_slice()); + assert!(decoded.is_ok()); + let decoded = decoded.unwrap_or_else(|_| unreachable!()); + + assert_eq!(decoded.agent_id, "agent-1"); + assert_eq!(decoded.timestamp, 123); + + match decoded.payload { + Some(agent_message::Payload::Registration(payload)) => { + assert_eq!(payload.hostname, "node-1"); + assert_eq!(payload.ip_address, "127.0.0.1"); + assert_eq!(payload.version, "1.0.0"); + assert_eq!(payload.capabilities.len(), 2); + assert_eq!(payload.labels.get("region"), Some(&"dev".to_string())); + assert_eq!(payload.deployment_mode, DeploymentMode::Standalone as i32); + } + _ => unreachable!(), + } + } + + #[test] + fn master_message_round_trip_with_error_payload() { + let msg = MasterMessage { + timestamp: 999, + payload: Some(master_message::Payload::Error(Error { + code: "E_CONFIG_INVALID".to_string(), + message: "invalid config".to_string(), + details: std::collections::HashMap::from([ + ("file".to_string(), "site.conf".to_string()), + ("line".to_string(), "42".to_string()), + ]), + })), + }; + + let encoded = msg.encode_to_vec(); + let decoded = MasterMessage::decode(encoded.as_slice()); + assert!(decoded.is_ok()); + let decoded = decoded.unwrap_or_else(|_| unreachable!()); + + assert_eq!(decoded.timestamp, 999); + match decoded.payload { + Some(master_message::Payload::Error(err)) => { + assert_eq!(err.code, "E_CONFIG_INVALID"); + assert_eq!(err.message, "invalid config"); + assert_eq!(err.details.get("line"), Some(&"42".to_string())); + } + _ => unreachable!(), + } + } + + #[test] + fn enum_integer_mappings_are_stable() { + assert_eq!(DeploymentMode::Unspecified as i32, 0); + assert_eq!(DeploymentMode::DockerSidecar as i32, 1); + assert_eq!(DeploymentMode::KubernetesSidecar as i32, 2); + assert_eq!(DeploymentMode::Standalone as i32, 3); + + assert_eq!(ConfigApplyStatus::Unspecified as i32, 0); + assert_eq!(ConfigApplyStatus::Pending as i32, 1); + assert_eq!(ConfigApplyStatus::Validating as i32, 2); + assert_eq!(ConfigApplyStatus::Applying as i32, 3); + assert_eq!(ConfigApplyStatus::Success as i32, 4); + assert_eq!(ConfigApplyStatus::Failed as i32, 5); + assert_eq!(ConfigApplyStatus::RolledBack as i32, 6); + + assert_eq!(MetricType::Unspecified as i32, 0); + assert_eq!(MetricType::Gauge as i32, 1); + assert_eq!(MetricType::Counter as i32, 2); + assert_eq!(MetricType::Histogram as i32, 3); + } + + #[test] + fn config_status_defaults_are_proto3_zero_values() { + let status = ConfigStatus::default(); + + assert_eq!(status.config_id, ""); + assert_eq!(status.version, 0); + assert_eq!(status.status, ConfigApplyStatus::Unspecified as i32); + assert_eq!(status.error_message, ""); + assert_eq!(status.applied_at, 0); + } +}