opt out working now

This commit is contained in:
dal 2025-02-24 06:59:02 -07:00
parent d8ee830c6a
commit 01a3915a4f
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
6 changed files with 119 additions and 95 deletions

View File

@ -404,6 +404,9 @@ impl Agent {
_ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
};
// Broadcast the assistant message as soon as we receive it
self.get_stream_sender().await.send(Ok(message.clone()))?;
// Update thread with assistant message
self.update_current_thread(message.clone()).await?;
@ -435,10 +438,12 @@ impl Agent {
result_str,
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
// TODO: need the progress for streaming
None,
);
// Broadcast the tool message as soon as we receive it
self.get_stream_sender().await.send(Ok(tool_message.clone()))?;
// Update thread with tool response
self.update_current_thread(tool_message.clone()).await?;
results.push(tool_message);

View File

@ -7,10 +7,12 @@ use uuid::Uuid;
use crate::utils::{
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
tools::{
agents_as_tools::dashboard_agent_tool::DashboardAgentOutput, file_tools::{
agents_as_tools::dashboard_agent_tool::DashboardAgentOutput,
file_tools::{
CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
ModifyMetricFilesTool,
}, IntoValueTool, ToolExecutor
},
IntoValueTool, ToolExecutor,
},
};
@ -80,80 +82,16 @@ impl DashboardAgent {
Ok(dashboard)
}
fn is_completion_signal(msg: &AgentMessage) -> bool {
matches!(msg, AgentMessage::Assistant { content: Some(content), tool_calls: None, .. }
if content == "AGENT_COMPLETE")
}
pub async fn run(&self, thread: &mut AgentThread) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
pub async fn run(
&self,
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(DASHBOARD_AGENT_PROMPT.to_string());
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
let mut rx = self.stream_process_thread(thread).await?;
let rx = self.stream_process_thread(thread).await?;
let rx_return = rx.resubscribe();
// Process messages internally until we determine we're done
loop {
tokio::select! {
recv_result = rx.recv() => {
match recv_result {
Ok(msg_result) => {
match msg_result {
Ok(msg) => {
// Forward message to stream sender
let sender = self.get_agent().get_stream_sender().await;
if let Err(e) = sender.send(Ok(msg.clone())) {
let err_msg = format!("Error forwarding message: {:?}", e);
let _ = sender.send(Err(AgentError(err_msg)));
continue;
}
if let Some(content) = msg.get_content() {
if content == "AGENT_COMPLETE" {
return Ok(rx_return);
}
}
}
Err(e) => {
let err_msg = format!("Error processing message: {:?}", e);
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
let err_msg = format!("Error receiving message: {:?}", e);
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
_ = shutdown_rx.recv() => {
// Handle shutdown gracefully
let tools = self.get_agent().get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = self.get_agent().get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Dashboard agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
return Ok(rx_return);
}
}
}
Ok(rx)
}
}

View File

@ -2,12 +2,8 @@ mod agent;
mod agents;
mod types;
pub use agent::AgentError;
pub use agent::Agent;
pub use agent::AgentExt;
pub use agents::*;
pub use types::*;
use anyhow::Result;
use litellm::Message;
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
pub use types::*;

View File

@ -4,9 +4,10 @@ use litellm::Message as AgentMessage;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::utils::{
agent::{Agent, DashboardAgent},
agent::{Agent, AgentError, DashboardAgent},
tools::{file_tools::file_types::file::FileEnum, ToolExecutor},
};
@ -74,19 +75,19 @@ impl ToolExecutor for DashboardAgentTool {
println!("DashboardAgentTool: Starting dashboard agent run");
// Run the dashboard agent and get the output
let _receiver = dashboard_agent.run(&mut current_thread).await?;
let rx = dashboard_agent.run(&mut current_thread).await?;
println!("DashboardAgentTool: Dashboard agent run completed");
println!("DashboardAgentTool: Preparing success response");
process_agent_output(rx).await?;
self.agent
.set_state_value(String::from("files_available"), Value::Bool(false))
.await;
// Return dummy data for testing
// Return success response
Ok(serde_json::json!({
"status": "success",
"message": "Test dashboard creation",
"message": "Dashboard agent completed successfully",
"duration": 0,
"files": []
}))
@ -113,3 +114,30 @@ impl ToolExecutor for DashboardAgentTool {
})
}
}
async fn process_agent_output(
mut rx: broadcast::Receiver<Result<AgentMessage, AgentError>>,
) -> Result<()> {
while let Ok(msg_result) = rx.recv().await {
match msg_result {
Ok(msg) => {
println!("Agent message: {:?}", msg);
if let AgentMessage::Assistant {
content: Some(_), ..
} = msg
{
return Ok(());
}
}
Err(e) => {
println!("Agent error: {:?}", e);
return Err(e.into());
}
}
}
// If we get here without finding a completion message, return an error
Err(anyhow::anyhow!(
"Agent communication ended without completion message"
))
}

View File

@ -1,11 +1,13 @@
use anyhow::Result;
use async_trait::async_trait;
use litellm::Message as AgentMessage;
use serde::Deserialize;
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::utils::{
agent::{Agent, ExploratoryAgent},
agent::{Agent, AgentError, ExploratoryAgent},
tools::ToolExecutor,
};
@ -54,12 +56,18 @@ impl ToolExecutor for ExploratoryAgentTool {
current_thread.add_user_message(params.ticket_description);
// Run the exploratory agent and get the receiver
let _rx = exploratory_agent.run(&mut current_thread).await?;
let rx = exploratory_agent.run(&mut current_thread).await?;
// Return immediately with status
process_agent_output(rx).await?;
self.agent
.set_state_value(String::from("files_available"), Value::Bool(false))
.await;
// Return success response
Ok(serde_json::json!({
"status": "running",
"message": "Exploratory agent started successfully"
"status": "success",
"message": "Exploratory agent completed successfully"
}))
}
@ -84,3 +92,27 @@ impl ToolExecutor for ExploratoryAgentTool {
})
}
}
async fn process_agent_output(
mut rx: broadcast::Receiver<Result<AgentMessage, AgentError>>,
) -> Result<()> {
while let Ok(msg_result) = rx.recv().await {
match msg_result {
Ok(msg) => {
println!("Agent message: {:?}", msg);
if let AgentMessage::Assistant { content: Some(_), .. } = msg {
return Ok(());
}
}
Err(e) => {
println!("Agent error: {:?}", e);
return Err(e.into());
}
}
}
// If we get here without finding a completion message, return an error
Err(anyhow::anyhow!(
"Agent communication ended without completion message"
))
}

View File

@ -3,11 +3,12 @@ use async_trait::async_trait;
use litellm::Message as AgentMessage;
use serde::Deserialize;
use serde_json::Value;
use tokio::sync::broadcast;
use std::sync::Arc;
use uuid::Uuid;
use crate::utils::{
agent::{Agent, MetricAgent},
agent::{Agent, AgentError, MetricAgent},
tools::ToolExecutor,
};
@ -54,20 +55,21 @@ impl ToolExecutor for MetricAgentTool {
.ok_or_else(|| anyhow::anyhow!("No current thread"))?;
current_thread.remove_last_assistant_message();
current_thread.add_user_message(params.ticket_description);
// Run the metric agent and get the receiver
let _rx = metric_agent.run(&mut current_thread).await?;
let rx = metric_agent.run(&mut current_thread).await?;
// Wait for completion message
let result = process_agent_output(rx).await?;
self.agent
.set_state_value(String::from("files_available"), Value::Bool(false))
.await;
// Return immediately with status
Ok(serde_json::json!({
"status": "running",
"message": "Metric agent started successfully"
"status": "complete",
"message": "Metric agent completed successfully"
}))
}
@ -92,3 +94,26 @@ impl ToolExecutor for MetricAgentTool {
})
}
}
async fn process_agent_output(
mut rx: broadcast::Receiver<Result<AgentMessage, AgentError>>,
) -> Result<()> {
while let Ok(msg_result) = rx.recv().await {
match msg_result {
Ok(msg) => {
println!("Agent message: {:?}", msg);
if let AgentMessage::Assistant { content: Some(_), .. } = msg {
return Ok(());
}
}
Err(e) => {
println!("Agent error: {:?}", e);
return Err(e.into());
}
}
}
// If we get here without finding a completion message, return an error
Err(anyhow::anyhow!(
"Agent communication ended without completion message"
))
}