refactor: enhance handler traits to support dynamic sizing and improve cancellation handling
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -2589,6 +2589,7 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tokio-test",
|
"tokio-test",
|
||||||
|
"tokio-util",
|
||||||
"toml",
|
"toml",
|
||||||
"tonic",
|
"tonic",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -2613,6 +2614,7 @@ dependencies = [
|
|||||||
name = "nxmesh-master"
|
name = "nxmesh-master"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
"argon2",
|
"argon2",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@@ -2641,6 +2643,7 @@ dependencies = [
|
|||||||
"thiserror",
|
"thiserror",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
"tokio-test",
|
"tokio-test",
|
||||||
"toml",
|
"toml",
|
||||||
"tonic",
|
"tonic",
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ tonic.workspace = true
|
|||||||
async-trait.workspace = true
|
async-trait.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
tokio-stream.workspace = true
|
tokio-stream.workspace = true
|
||||||
|
tokio-util = "0.7"
|
||||||
|
|
||||||
# Config
|
# Config
|
||||||
config.workspace = true
|
config.workspace = true
|
||||||
|
|||||||
@@ -17,14 +17,14 @@ pub trait OnConfigUpdateHandler: Send + Sync + 'static {
|
|||||||
|
|
||||||
pub struct HandlerImpl<OCH>
|
pub struct HandlerImpl<OCH>
|
||||||
where
|
where
|
||||||
OCH: OnConfigUpdateHandler,
|
OCH: OnConfigUpdateHandler + ?Sized,
|
||||||
{
|
{
|
||||||
on_config_update_handler: Arc<OCH>,
|
on_config_update_handler: Arc<OCH>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<OCH> HandlerImpl<OCH>
|
impl<OCH> HandlerImpl<OCH>
|
||||||
where
|
where
|
||||||
OCH: OnConfigUpdateHandler,
|
OCH: OnConfigUpdateHandler + ?Sized,
|
||||||
{
|
{
|
||||||
pub fn new(on_config_update_handler: Arc<OCH>) -> Self {
|
pub fn new(on_config_update_handler: Arc<OCH>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -36,7 +36,7 @@ where
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl<OCH> MasterMessageHandler for HandlerImpl<OCH>
|
impl<OCH> MasterMessageHandler for HandlerImpl<OCH>
|
||||||
where
|
where
|
||||||
OCH: OnConfigUpdateHandler,
|
OCH: OnConfigUpdateHandler + ?Sized,
|
||||||
{
|
{
|
||||||
async fn handle_master_message(&self, message: MasterMessage) -> MessageResult<()> {
|
async fn handle_master_message(&self, message: MasterMessage) -> MessageResult<()> {
|
||||||
match message.payload {
|
match message.payload {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||||||
use nxmesh_proto::AgentMessage;
|
use nxmesh_proto::AgentMessage;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
use tokio_stream::wrappers::ReceiverStream;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -35,13 +36,14 @@ pub trait MasterHandler: Send + Sync + 'static {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct MessageHandleInfo {
|
struct MessageHandleInfo {
|
||||||
handle: tokio::task::JoinHandle<()>,
|
|
||||||
tx: mpsc::Sender<AgentMessage>,
|
tx: mpsc::Sender<AgentMessage>,
|
||||||
|
// used to signal the running handler/connection to stop
|
||||||
|
cancel: CancellationToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct MasterHandlerImpl<MMH>
|
pub struct MasterHandlerImpl<MMH>
|
||||||
where
|
where
|
||||||
MMH: MasterMessageHandler,
|
MMH: MasterMessageHandler + ?Sized,
|
||||||
{
|
{
|
||||||
connector: Arc<MasterConnector>,
|
connector: Arc<MasterConnector>,
|
||||||
message_handler: Arc<MMH>,
|
message_handler: Arc<MMH>,
|
||||||
@@ -50,7 +52,7 @@ where
|
|||||||
|
|
||||||
impl<MMH> MasterHandlerImpl<MMH>
|
impl<MMH> MasterHandlerImpl<MMH>
|
||||||
where
|
where
|
||||||
MMH: MasterMessageHandler,
|
MMH: MasterMessageHandler + ?Sized,
|
||||||
{
|
{
|
||||||
pub fn new(connector: Arc<MasterConnector>, message_handler: Arc<MMH>) -> Self {
|
pub fn new(connector: Arc<MasterConnector>, message_handler: Arc<MMH>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -64,29 +66,35 @@ where
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl<MMH> MasterHandler for MasterHandlerImpl<MMH>
|
impl<MMH> MasterHandler for MasterHandlerImpl<MMH>
|
||||||
where
|
where
|
||||||
MMH: MasterMessageHandler,
|
MMH: MasterMessageHandler + ?Sized,
|
||||||
{
|
{
|
||||||
async fn start_handle_master_message(&self) -> MessageResult<()> {
|
async fn start_handle_master_message(&self) -> MessageResult<()> {
|
||||||
|
info!("Starting master message handler...");
|
||||||
let mut client = self.connector.get_client();
|
let mut client = self.connector.get_client();
|
||||||
|
|
||||||
'connection_loop: loop {
|
// ensure only one caller can start the handler
|
||||||
let guard_result = self.message_handle_lock.try_write();
|
// create the cancel token for the lifetime of this handler invocation
|
||||||
let mut guard = match guard_result {
|
let cancel_token = CancellationToken::new();
|
||||||
Ok(g) if g.is_none() => g,
|
{
|
||||||
Ok(_) => {
|
let mut guard = self.message_handle_lock.write().await;
|
||||||
|
if guard.is_some() {
|
||||||
warn!("Master message handler is already running");
|
warn!("Master message handler is already running");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
Err(e) => {
|
// placeholder tx; will be replaced per-connection
|
||||||
return Err(MasterHandlerError::MessageHandlingError(format!(
|
let (tx, _rx) = mpsc::channel(1);
|
||||||
"Failed to acquire lock for message handler: {}",
|
*guard = Some(MessageHandleInfo {
|
||||||
e
|
tx,
|
||||||
)));
|
cancel: cancel_token.clone(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
'connection_loop: loop {
|
||||||
|
// fresh outbound channel per connection
|
||||||
let (tx, rx) = mpsc::channel(32);
|
let (tx, rx) = mpsc::channel(32);
|
||||||
let outbound_stream = ReceiverStream::new(rx);
|
let outbound_stream = ReceiverStream::new(rx);
|
||||||
// 2. Connect to the master and start the bi-directional streaming RPC
|
|
||||||
|
// try to connect
|
||||||
let mut stream = match client.stream(outbound_stream).await {
|
let mut stream = match client.stream(outbound_stream).await {
|
||||||
Ok(s) => s.into_inner(),
|
Ok(s) => s.into_inner(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -94,50 +102,101 @@ where
|
|||||||
"Failed to connect to master: {}. Retrying in 5 seconds...",
|
"Failed to connect to master: {}. Retrying in 5 seconds...",
|
||||||
e
|
e
|
||||||
);
|
);
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
// update stored sender so any callers see the current tx
|
||||||
continue 'connection_loop;
|
{
|
||||||
|
let mut guard = self.message_handle_lock.write().await;
|
||||||
|
if let Some(info) = guard.as_mut() {
|
||||||
|
info.tx = tx.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let conn_token = cancel_token.child_token();
|
||||||
|
tokio::select! {
|
||||||
|
_ = conn_token.cancelled() => break 'connection_loop,
|
||||||
|
_ = tokio::time::sleep(std::time::Duration::from_secs(5)) => continue 'connection_loop,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
// store current tx so senders can use it
|
||||||
// 3. Spawn a task to handle incoming messages from the master
|
{
|
||||||
let message_handler = self.message_handler.clone();
|
let mut guard = self.message_handle_lock.write().await;
|
||||||
let handle = tokio::spawn(async move {
|
if let Some(info) = guard.as_mut() {
|
||||||
'message_loop: loop {
|
info.tx = tx.clone();
|
||||||
let message = stream.message().await;
|
}
|
||||||
|
}
|
||||||
|
// connection-level token to observe stop requests
|
||||||
|
let conn_token = cancel_token.child_token();
|
||||||
|
info!("Connected to master, starting to receive messages...");
|
||||||
|
// process messages inline so we can clear the slot on exit
|
||||||
|
'message_processing: loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = conn_token.cancelled() => {
|
||||||
|
info!("Stop requested for master handler");
|
||||||
|
break 'connection_loop;
|
||||||
|
}
|
||||||
|
message = stream.message() => {
|
||||||
match message {
|
match message {
|
||||||
Ok(Some(msg)) => {
|
Ok(Some(msg)) => {
|
||||||
if let Err(e) = message_handler.handle_master_message(msg).await {
|
if let Err(e) = self.message_handler.handle_master_message(msg).await {
|
||||||
error!("Failed to handle master message: {:?}", e);
|
error!("Failed to handle master message: {:?}", e);
|
||||||
}
|
}
|
||||||
continue 'message_loop;
|
continue;
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
warn!("Master closed the connection");
|
warn!("Master closed the connection");
|
||||||
return;
|
break 'message_processing;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Error receiving message from master: {:?}", e);
|
error!("Error receiving message from master: {:?}", e);
|
||||||
return;
|
break 'message_processing;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
|
||||||
*guard = Some(MessageHandleInfo { handle, tx });
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// connection ended — clear stored info
|
||||||
|
{
|
||||||
|
let mut guard = self.message_handle_lock.write().await;
|
||||||
|
guard.take();
|
||||||
|
}
|
||||||
|
|
||||||
|
// if stop requested, exit
|
||||||
|
if cancel_token.is_cancelled() {
|
||||||
|
break 'connection_loop;
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise reconnect after backoff
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// final cleanup
|
||||||
|
let mut guard = self.message_handle_lock.write().await;
|
||||||
|
guard.take();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn stop_handle_master_message(&self) -> MessageResult<()> {
|
async fn stop_handle_master_message(&self) -> MessageResult<()> {
|
||||||
// 1. signal the task to stop (e.g. using a cancellation token or channel)
|
// Signal the running handler to stop and wait for it to clear
|
||||||
// 2. wait for the task to finish and handle any errors
|
let mut maybe_cancel = None;
|
||||||
// 3. set handle_master_message_task to None
|
{
|
||||||
|
|
||||||
let mut guard = self.message_handle_lock.write().await;
|
let mut guard = self.message_handle_lock.write().await;
|
||||||
if let Some(handle_info) = guard.take() {
|
if let Some(info) = guard.take() {
|
||||||
handle_info.handle.abort();
|
maybe_cancel = Some(info.cancel);
|
||||||
match handle_info.handle.await {
|
|
||||||
Ok(_) => info!("Master message handler task stopped successfully"),
|
|
||||||
Err(e) => error!("Failed to stop master message handler task: {:?}", e),
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(cancel) = maybe_cancel {
|
||||||
|
cancel.cancel();
|
||||||
|
|
||||||
|
// wait for the handler to clear (with timeout)
|
||||||
|
for _ in 0..50 {
|
||||||
|
if self.message_handle_lock.read().await.is_none() {
|
||||||
|
info!("Master message handler task stopped successfully");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
warn!("Timed out waiting for master message handler to stop");
|
||||||
} else {
|
} else {
|
||||||
warn!("Master message handler is not running");
|
warn!("Master message handler is not running");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user