feat: Add validation for SSH certificates and implement unit tests for SshAuthInterceptor
This commit is contained in:
@@ -38,12 +38,162 @@ impl SshAuthInterceptor {
|
||||
async fn authenticate(&self, req: Request<()>) -> Result<Request<()>, Status> {
|
||||
let certs = req.peer_certs().ok_or(Status::unauthenticated("No cert"))?;
|
||||
|
||||
let is_authorized = self.certificate_provider.is_authorized(&certs).await?;
|
||||
self.validate_certs(&certs).await?;
|
||||
Ok(req)
|
||||
}
|
||||
|
||||
async fn validate_certs(&self, certs: &Arc<Vec<CertificateDer<'_>>>) -> Result<(), Status> {
|
||||
let is_authorized = self.certificate_provider.is_authorized(certs).await?;
|
||||
|
||||
if is_authorized {
|
||||
Ok(req)
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Status::permission_denied("Blocked"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{
|
||||
Arc, Mutex,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
use tonic::{Request, Status, transport::CertificateDer};
|
||||
use tonic_async_interceptor::AsyncInterceptor;
|
||||
|
||||
use super::{CertificateValidationProvider, SshAuthInterceptor, create_ssh_auth_interceptor};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum ProviderMode {
|
||||
Allow,
|
||||
Deny,
|
||||
Error,
|
||||
}
|
||||
|
||||
struct TestCertificateProvider {
|
||||
mode: ProviderMode,
|
||||
calls: Arc<AtomicUsize>,
|
||||
cert_count_seen: Arc<Mutex<Option<usize>>>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl CertificateValidationProvider for TestCertificateProvider {
|
||||
async fn is_authorized(
|
||||
&self,
|
||||
certs: &Arc<Vec<CertificateDer<'_>>>,
|
||||
) -> Result<bool, Status> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
let lock = self.cert_count_seen.lock();
|
||||
assert!(lock.is_ok());
|
||||
let mut lock = lock.unwrap_or_else(|_| unreachable!());
|
||||
*lock = Some(certs.len());
|
||||
|
||||
match self.mode {
|
||||
ProviderMode::Allow => Ok(true),
|
||||
ProviderMode::Deny => Ok(false),
|
||||
ProviderMode::Error => Err(Status::internal("provider failed")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_provider(
|
||||
mode: ProviderMode,
|
||||
) -> (
|
||||
Arc<TestCertificateProvider>,
|
||||
Arc<AtomicUsize>,
|
||||
Arc<Mutex<Option<usize>>>,
|
||||
) {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let cert_count_seen = Arc::new(Mutex::new(None));
|
||||
let provider = Arc::new(TestCertificateProvider {
|
||||
mode,
|
||||
calls: calls.clone(),
|
||||
cert_count_seen: cert_count_seen.clone(),
|
||||
});
|
||||
(provider, calls, cert_count_seen)
|
||||
}
|
||||
|
||||
fn sample_certs() -> Arc<Vec<CertificateDer<'static>>> {
|
||||
Arc::new(vec![
|
||||
CertificateDer::from(vec![1, 2, 3]),
|
||||
CertificateDer::from(vec![4, 5, 6]),
|
||||
])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_ssh_auth_interceptor_builds_layer() {
|
||||
let (provider, _, _) = build_provider(ProviderMode::Allow);
|
||||
let _ = create_ssh_auth_interceptor(provider);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authenticate_fails_when_no_peer_certificates_exist() {
|
||||
let (provider, calls, _) = build_provider(ProviderMode::Allow);
|
||||
let interceptor = SshAuthInterceptor::new(provider);
|
||||
|
||||
let result = tokio_test::block_on(interceptor.authenticate(Request::new(())));
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap_or_else(|| unreachable!());
|
||||
assert_eq!(err.code(), tonic::Code::Unauthenticated);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_certs_succeeds_when_provider_allows() {
|
||||
let (provider, calls, cert_count_seen) = build_provider(ProviderMode::Allow);
|
||||
let interceptor = SshAuthInterceptor::new(provider);
|
||||
let certs = sample_certs();
|
||||
|
||||
let result = tokio_test::block_on(interceptor.validate_certs(&certs));
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
|
||||
let seen = cert_count_seen.lock();
|
||||
assert!(seen.is_ok());
|
||||
let seen = seen.unwrap_or_else(|_| unreachable!());
|
||||
assert_eq!(*seen, Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_certs_returns_permission_denied_when_provider_denies() {
|
||||
let (provider, calls, _) = build_provider(ProviderMode::Deny);
|
||||
let interceptor = SshAuthInterceptor::new(provider);
|
||||
let certs = sample_certs();
|
||||
|
||||
let result = tokio_test::block_on(interceptor.validate_certs(&certs));
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap_or_else(|| unreachable!());
|
||||
assert_eq!(err.code(), tonic::Code::PermissionDenied);
|
||||
assert_eq!(err.message(), "Blocked");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_certs_propagates_provider_errors() {
|
||||
let (provider, calls, _) = build_provider(ProviderMode::Error);
|
||||
let interceptor = SshAuthInterceptor::new(provider);
|
||||
let certs = sample_certs();
|
||||
|
||||
let result = tokio_test::block_on(interceptor.validate_certs(&certs));
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap_or_else(|| unreachable!());
|
||||
assert_eq!(err.code(), tonic::Code::Internal);
|
||||
assert_eq!(err.message(), "provider failed");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn async_interceptor_call_delegates_to_authenticate() {
|
||||
let (provider, calls, _) = build_provider(ProviderMode::Allow);
|
||||
let mut interceptor = SshAuthInterceptor::new(provider);
|
||||
|
||||
let result = tokio_test::block_on(interceptor.call(Request::new(())));
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap_or_else(|| unreachable!());
|
||||
assert_eq!(err.code(), tonic::Code::Unauthenticated);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user