119 lines
4.3 KiB
Rust
119 lines
4.3 KiB
Rust
use std::sync::Arc;
|
|
|
|
use nxmesh_proto::{
|
|
agent_service_server::AgentServiceServer,
|
|
auth::ssh_auth::{CertificateValidationProvider, create_ssh_auth_interceptor},
|
|
};
|
|
use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter};
|
|
use tonic::transport::Server;
|
|
use tracing::info;
|
|
|
|
use crate::{db::entities::public_key_revocations, service::agent::AgentServerService};
|
|
|
|
use super::AgentConnectorTrait;
|
|
|
|
/// Maximum number of client certificates from a single request that are
/// considered by the revocation lookup in `is_authorized`.
const MAX_CERTS_TO_CHECK: usize = 50;
|
|
|
|
pub struct SshAgentConnector {
|
|
// router: Router<Stack<AsyncInterceptorLayer<SshAuthInterceptor>, Identity>>,
|
|
settings: Arc<crate::config::settings::Settings>,
|
|
}
|
|
|
|
impl SshAgentConnector {
|
|
pub fn new(
|
|
settings: impl Into<Arc<crate::config::settings::Settings>>,
|
|
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
|
Ok(Self {
|
|
settings: settings.into(),
|
|
})
|
|
}
|
|
|
|
async fn get_tls_config(
|
|
cert_service: Arc<dyn crate::service::certificate::CertificateService>,
|
|
) -> Result<tonic::transport::ServerTlsConfig, Box<dyn std::error::Error + Send + Sync>> {
|
|
let (san_ips, san_dns) =
|
|
cert_service.get_sans(crate::service::certificate::ConnectionType::GRPC);
|
|
let (cert_pem, key_pem) = cert_service
|
|
.generate_pub_cert_pair(san_ips, san_dns)
|
|
.await?;
|
|
let (ca_cert_path, _) = cert_service.get_ca_cert().await?;
|
|
let ca_cert_pem = std::fs::read_to_string(&ca_cert_path)?;
|
|
|
|
let tls_config = tonic::transport::ServerTlsConfig::new()
|
|
.identity(tonic::transport::Identity::from_pem(cert_pem, key_pem))
|
|
.client_ca_root(tonic::transport::Certificate::from_pem(ca_cert_pem));
|
|
Ok(tls_config)
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl AgentConnectorTrait for SshAgentConnector {
|
|
async fn start_server(
|
|
&mut self,
|
|
settings: &crate::config::settings::Settings,
|
|
cert_service: Arc<dyn crate::service::certificate::CertificateService>,
|
|
connection: DatabaseConnection,
|
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|
let addr = settings.grpc.bind_address.clone().parse()?;
|
|
let port = settings.grpc.port;
|
|
let addr = std::net::SocketAddr::new(addr, port);
|
|
|
|
// Create the gRPC server
|
|
let cert_validation_provider = Arc::new(CertificateValidationProviderImpl::new(connection));
|
|
let ssh_interceptor = create_ssh_auth_interceptor(cert_validation_provider);
|
|
let agent_server_service = AgentServiceServer::new(AgentServerService::default());
|
|
|
|
let tls_config = Self::get_tls_config(cert_service.clone()).await?;
|
|
|
|
let router = Server::builder()
|
|
.tls_config(tls_config)?
|
|
.layer(ssh_interceptor)
|
|
.add_service(agent_server_service);
|
|
|
|
info!("SSH Agent gRPC server is listening on {}", addr);
|
|
router
|
|
.serve(addr)
|
|
.await
|
|
.inspect(|_| info!("SSH Agent gRPC server stopped gracefully."))
|
|
.inspect_err(|e| {
|
|
tracing::error!("SSH Agent gRPC server failed: {}", e);
|
|
})?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
struct CertificateValidationProviderImpl {
|
|
connection: DatabaseConnection,
|
|
}
|
|
|
|
impl CertificateValidationProviderImpl {
|
|
pub fn new(connection: DatabaseConnection) -> Self {
|
|
CertificateValidationProviderImpl { connection }
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl CertificateValidationProvider for CertificateValidationProviderImpl {
|
|
async fn is_authorized(
|
|
&self,
|
|
certs: &Arc<Vec<tonic::transport::CertificateDer<'_>>>,
|
|
) -> Result<bool, tonic::Status> {
|
|
// check if the certificate's public key matches any agent's public key in the database
|
|
let found = public_key_revocations::Entity::find()
|
|
.filter(public_key_revocations::Column::PublicKeyHash.is_in(
|
|
certs.iter().take(MAX_CERTS_TO_CHECK).map(|cert| {
|
|
use sha2::{Digest, Sha256};
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(cert.as_ref());
|
|
hex::encode(hasher.finalize())
|
|
}),
|
|
))
|
|
.one(&self.connection)
|
|
.await
|
|
.map_err(|e| tonic::Status::internal(format!("Database query failed: {}", e)))?
|
|
.is_some();
|
|
|
|
Ok(!found)
|
|
}
|
|
}
|