simplify agents pulling off shutdown function

This commit is contained in:
dal 2025-02-24 07:01:25 -07:00
parent 01a3915a4f
commit 84bada0dd1
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 6 additions and 178 deletions

View File

@ -51,76 +51,9 @@ impl ExploratoryAgent {
thread.set_developer_message(EXPLORATORY_AGENT_PROMPT.to_string());
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
let mut rx = self.stream_process_thread(thread).await?;
let rx = self.stream_process_thread(thread).await?;
// Clone what we need for the processing task
let agent = Arc::clone(self.get_agent());
let rx_return = rx.resubscribe();
tokio::spawn(async move {
loop {
tokio::select! {
recv_result = rx.recv() => {
match recv_result {
Ok(msg_result) => {
match msg_result {
Ok(msg) => {
// Forward message to stream sender
let sender = agent.get_stream_sender().await;
if let Err(e) = sender.send(Ok(msg.clone())) {
let err_msg = format!("Error forwarding message: {:?}", e);
let _ = sender.send(Err(AgentError(err_msg)));
continue;
}
if let Some(content) = msg.get_content() {
if content == "AGENT_COMPLETE" {
break;
}
}
}
Err(e) => {
let err_msg = format!("Error processing message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
let err_msg = format!("Error receiving message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
_ = shutdown_rx.recv() => {
// Handle shutdown gracefully
let tools = agent.get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = agent.get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Exploratory agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
break;
}
}
}
});
Ok(rx_return)
Ok(rx)
}
}

View File

@ -125,47 +125,9 @@ impl ManagerAgent {
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());
// Use existing channel - important for sub-agents
let rx = self.get_agent().get_stream_receiver().await;
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
// Clone only what we need
let agent = Arc::clone(self.get_agent());
let thread = thread.clone();
tokio::spawn(async move {
tokio::select! {
result = agent.process_thread(&thread) => {
if let Err(e) = result {
let err_msg = format!("Manager agent processing failed: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
_ = shutdown_rx.recv() => {
// Shutdown all tools
let tools = agent.get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = agent.get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Manager agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
}
}
});
// Get shutdown receiver
let rx = self.stream_process_thread(thread).await?;
Ok(rx)
}

View File

@ -71,76 +71,9 @@ impl MetricAgent {
thread.set_developer_message(METRIC_AGENT_PROMPT.to_string());
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
let mut rx = self.stream_process_thread(thread).await?;
let rx = self.stream_process_thread(thread).await?;
// Clone what we need for the processing task
let agent = Arc::clone(self.get_agent());
let rx_return = rx.resubscribe();
tokio::spawn(async move {
loop {
tokio::select! {
recv_result = rx.recv() => {
match recv_result {
Ok(msg_result) => {
match msg_result {
Ok(msg) => {
// Forward message to stream sender
let sender = agent.get_stream_sender().await;
if let Err(e) = sender.send(Ok(msg.clone())) {
let err_msg = format!("Error forwarding message: {:?}", e);
let _ = sender.send(Err(AgentError(err_msg)));
continue;
}
if let Some(content) = msg.get_content() {
if content == "AGENT_COMPLETE" {
break;
}
}
}
Err(e) => {
let err_msg = format!("Error processing message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
let err_msg = format!("Error receiving message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
_ = shutdown_rx.recv() => {
// Handle shutdown gracefully
let tools = agent.get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = agent.get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Metric agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
break;
}
}
}
});
Ok(rx_return)
Ok(rx)
}
}