mirror of https://github.com/buster-so/buster.git
ok just need to tie up the last few things
This commit is contained in:
parent
18413f2f24
commit
d8ee830c6a
|
@ -10,6 +10,7 @@ use handlers::threads::types::ThreadWithMessages;
|
|||
use litellm::Message as AgentMessage;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::routes::rest::ApiResponse;
|
||||
|
@ -98,10 +99,12 @@ async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ThreadWi
|
|||
// Get the receiver and collect all messages
|
||||
let mut rx = agent.run(&mut thread).await?;
|
||||
let mut messages = Vec::new();
|
||||
while let Some(msg_result) = rx.recv().await {
|
||||
match msg_result {
|
||||
Ok(msg) => messages.push(msg),
|
||||
Err(e) => return Err(e.into()),
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(Ok(msg)) => messages.push(msg),
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
Err(e) => return Err(anyhow!(e)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,13 +5,26 @@ use litellm::{
|
|||
};
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::tools::ToolExecutor;
|
||||
|
||||
use super::types::AgentThread;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentError(pub String);
|
||||
|
||||
impl std::error::Error for AgentError {}
|
||||
|
||||
impl std::fmt::Display for AgentError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
type MessageResult = Result<Message, AgentError>;
|
||||
|
||||
/// A wrapper type that converts ToolCall parameters to Value before executing
|
||||
struct ToolCallExecutor<T: ToolExecutor> {
|
||||
inner: Box<T>,
|
||||
|
@ -96,11 +109,13 @@ pub struct Agent {
|
|||
/// The current thread being processed, if any
|
||||
current_thread: Arc<RwLock<Option<AgentThread>>>,
|
||||
/// Sender for streaming messages from this agent and sub-agents
|
||||
stream_tx: Arc<RwLock<mpsc::Sender<Result<Message>>>>,
|
||||
stream_tx: Arc<RwLock<broadcast::Sender<MessageResult>>>,
|
||||
/// The user ID for the current thread
|
||||
user_id: Uuid,
|
||||
/// The session ID for the current thread
|
||||
session_id: Uuid,
|
||||
/// Shutdown signal sender
|
||||
shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
|
@ -116,8 +131,10 @@ impl Agent {
|
|||
|
||||
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
|
||||
|
||||
// Create a default channel that just drops messages
|
||||
let (tx, _rx) = mpsc::channel(1);
|
||||
// Create a broadcast channel with buffer size 1000
|
||||
let (tx, _rx) = broadcast::channel(1000);
|
||||
// Create shutdown channel with buffer size 1
|
||||
let (shutdown_tx, _) = broadcast::channel(1);
|
||||
|
||||
Self {
|
||||
llm_client,
|
||||
|
@ -128,6 +145,7 @@ impl Agent {
|
|||
stream_tx: Arc::new(RwLock::new(tx)),
|
||||
user_id,
|
||||
session_id,
|
||||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,13 +158,14 @@ impl Agent {
|
|||
|
||||
Self {
|
||||
llm_client,
|
||||
tools: Arc::new(RwLock::new(HashMap::new())), // Start with empty tools
|
||||
tools: Arc::new(RwLock::new(HashMap::new())),
|
||||
model: existing_agent.model.clone(),
|
||||
state: Arc::clone(&existing_agent.state),
|
||||
current_thread: Arc::clone(&existing_agent.current_thread),
|
||||
stream_tx: Arc::clone(&existing_agent.stream_tx),
|
||||
user_id: existing_agent.user_id,
|
||||
session_id: existing_agent.session_id,
|
||||
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -168,13 +187,13 @@ impl Agent {
|
|||
enabled_tools
|
||||
}
|
||||
|
||||
/// Update the stream sender for this agent
|
||||
pub async fn set_stream_sender(&self, tx: mpsc::Sender<Result<Message>>) {
|
||||
*self.stream_tx.write().await = tx;
|
||||
/// Get a new receiver for the broadcast channel
|
||||
pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> {
|
||||
self.stream_tx.read().await.subscribe()
|
||||
}
|
||||
|
||||
/// Get a clone of the current stream sender
|
||||
pub async fn get_stream_sender(&self) -> mpsc::Sender<Result<Message>> {
|
||||
pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> {
|
||||
self.stream_tx.read().await.clone()
|
||||
}
|
||||
|
||||
|
@ -276,7 +295,7 @@ impl Agent {
|
|||
let mut rx = self.process_thread_streaming(thread).await?;
|
||||
|
||||
let mut final_message = None;
|
||||
while let Some(msg) = rx.recv().await {
|
||||
while let Ok(msg) = rx.recv().await {
|
||||
final_message = Some(msg?);
|
||||
}
|
||||
|
||||
|
@ -294,26 +313,37 @@ impl Agent {
|
|||
pub async fn process_thread_streaming(
|
||||
&self,
|
||||
thread: &AgentThread,
|
||||
) -> Result<mpsc::Receiver<Result<Message>>> {
|
||||
// Create new channel for this processing session
|
||||
let (tx, rx) = mpsc::channel(100);
|
||||
self.set_stream_sender(tx).await;
|
||||
|
||||
) -> Result<broadcast::Receiver<MessageResult>> {
|
||||
// Spawn the processing task
|
||||
let agent_clone = self.clone();
|
||||
let thread_clone = thread.clone();
|
||||
|
||||
// Get shutdown receiver
|
||||
let mut shutdown_rx = self.get_shutdown_receiver().await;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = agent_clone
|
||||
.process_thread_with_depth(&thread_clone, 0)
|
||||
.await
|
||||
{
|
||||
let err_msg = format!("Error processing thread: {:?}", e);
|
||||
let _ = agent_clone.get_stream_sender().await.send(Err(e)).await;
|
||||
tokio::select! {
|
||||
result = agent_clone.process_thread_with_depth(&thread_clone, 0) => {
|
||||
if let Err(e) = result {
|
||||
let err_msg = format!("Error processing thread: {:?}", e);
|
||||
let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
let _ = agent_clone.get_stream_sender().await.send(
|
||||
Ok(Message::assistant(
|
||||
Some("shutdown_message".to_string()),
|
||||
Some("Processing interrupted due to shutdown signal".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
Ok(self.get_stream_receiver().await)
|
||||
}
|
||||
|
||||
async fn process_thread_with_depth(
|
||||
|
@ -335,7 +365,7 @@ impl Agent {
|
|||
None,
|
||||
None,
|
||||
);
|
||||
self.get_stream_sender().await.send(Ok(message)).await?;
|
||||
self.get_stream_sender().await.send(Ok(message))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -425,6 +455,23 @@ impl Agent {
|
|||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a receiver for the shutdown signal
|
||||
pub async fn get_shutdown_receiver(&self) -> broadcast::Receiver<()> {
|
||||
self.shutdown_tx.read().await.subscribe()
|
||||
}
|
||||
|
||||
/// Signal shutdown to all receivers
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
// Send shutdown signal
|
||||
self.shutdown_tx.read().await.send(())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a reference to the tools map
|
||||
pub async fn get_tools(&self) -> tokio::sync::RwLockReadGuard<'_, HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>> {
|
||||
self.tools.read().await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
@ -488,7 +535,7 @@ pub trait AgentExt {
|
|||
async fn stream_process_thread(
|
||||
&self,
|
||||
thread: &AgentThread,
|
||||
) -> Result<mpsc::Receiver<Result<Message>>> {
|
||||
) -> Result<broadcast::Receiver<MessageResult>> {
|
||||
(*self.get_agent()).process_thread_streaming(thread).await
|
||||
}
|
||||
|
||||
|
@ -524,24 +571,27 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
impl WeatherTool {
|
||||
async fn send_progress(&self, content: String, tool_id: String, progress: MessageProgress) -> Result<()> {
|
||||
let message = Message::tool(
|
||||
None,
|
||||
content,
|
||||
tool_id,
|
||||
Some(self.get_name()),
|
||||
Some(progress),
|
||||
);
|
||||
self.agent.get_stream_sender().await.send(Ok(message))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolExecutor for WeatherTool {
|
||||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
|
||||
// Send progress using agent's stream sender
|
||||
self.agent
|
||||
.get_stream_sender()
|
||||
.await
|
||||
.send(Ok(Message::tool(
|
||||
None,
|
||||
"Fetching weather data...".to_string(),
|
||||
"123".to_string(),
|
||||
Some(self.get_name()),
|
||||
Some(MessageProgress::InProgress),
|
||||
)))
|
||||
.await?;
|
||||
self.send_progress("Fetching weather data...".to_string(), "123".to_string(), MessageProgress::InProgress).await?;
|
||||
|
||||
// Simulate a delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
@ -551,18 +601,7 @@ mod tests {
|
|||
"unit": "fahrenheit"
|
||||
});
|
||||
|
||||
// Send completion message using agent's stream sender
|
||||
self.agent
|
||||
.get_stream_sender()
|
||||
.await
|
||||
.send(Ok(Message::tool(
|
||||
None,
|
||||
serde_json::to_string(&result)?,
|
||||
"123".to_string(),
|
||||
Some(self.get_name()),
|
||||
Some(MessageProgress::Complete),
|
||||
)))
|
||||
.await?;
|
||||
self.send_progress(serde_json::to_string(&result)?, "123".to_string(), MessageProgress::Complete).await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::collections::HashMap;
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{
|
||||
agent::{Agent, AgentExt, AgentThread},
|
||||
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
|
||||
tools::{
|
||||
agents_as_tools::dashboard_agent_tool::DashboardAgentOutput, file_tools::{
|
||||
CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
|
||||
|
@ -15,6 +15,7 @@ use crate::utils::{
|
|||
};
|
||||
|
||||
use litellm::Message as AgentMessage;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
pub struct DashboardAgent {
|
||||
agent: Arc<Agent>,
|
||||
|
@ -84,71 +85,75 @@ impl DashboardAgent {
|
|||
if content == "AGENT_COMPLETE")
|
||||
}
|
||||
|
||||
pub async fn run(&self, thread: &mut AgentThread) -> Result<DashboardAgentOutput> {
|
||||
println!("Running dashboard agent");
|
||||
println!("Setting developer message");
|
||||
pub async fn run(&self, thread: &mut AgentThread) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(DASHBOARD_AGENT_PROMPT.to_string());
|
||||
|
||||
println!("Starting stream_process_thread");
|
||||
// Get shutdown receiver
|
||||
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
|
||||
let mut rx = self.stream_process_thread(thread).await?;
|
||||
println!("Got receiver from stream_process_thread");
|
||||
|
||||
println!("Starting message processing loop");
|
||||
let rx_return = rx.resubscribe();
|
||||
|
||||
// Process messages internally until we determine we're done
|
||||
while let Some(msg_result) = rx.recv().await {
|
||||
println!("Received message from channel");
|
||||
match msg_result {
|
||||
Ok(msg) => {
|
||||
println!("Message content: {:?}", msg.get_content());
|
||||
println!("Message has tool calls: {:?}", msg.get_tool_calls());
|
||||
|
||||
println!("Forwarding message to stream sender");
|
||||
if let Err(e) = self.get_agent().get_stream_sender().await.send(Ok(msg.clone())).await {
|
||||
println!("Error forwarding message: {:?}", e);
|
||||
// Continue processing even if we fail to forward
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(content) = msg.get_content() {
|
||||
println!("Message has content: {}", content);
|
||||
if content == "AGENT_COMPLETE" {
|
||||
println!("Found completion signal, breaking loop");
|
||||
break;
|
||||
loop {
|
||||
tokio::select! {
|
||||
recv_result = rx.recv() => {
|
||||
match recv_result {
|
||||
Ok(msg_result) => {
|
||||
match msg_result {
|
||||
Ok(msg) => {
|
||||
// Forward message to stream sender
|
||||
let sender = self.get_agent().get_stream_sender().await;
|
||||
if let Err(e) = sender.send(Ok(msg.clone())) {
|
||||
let err_msg = format!("Error forwarding message: {:?}", e);
|
||||
let _ = sender.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(content) = msg.get_content() {
|
||||
if content == "AGENT_COMPLETE" {
|
||||
return Ok(rx_return);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error processing message: {:?}", e);
|
||||
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error receiving message: {:?}", e);
|
||||
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Error receiving message: {:?}", e);
|
||||
println!("Error details: {:?}", e.to_string());
|
||||
// Log error but continue processing instead of returning error
|
||||
continue;
|
||||
_ = shutdown_rx.recv() => {
|
||||
// Handle shutdown gracefully
|
||||
let tools = self.get_agent().get_tools().await;
|
||||
for (_, tool) in tools.iter() {
|
||||
if let Err(e) = tool.handle_shutdown().await {
|
||||
let err_msg = format!("Error shutting down tool: {:?}", e);
|
||||
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
}
|
||||
}
|
||||
|
||||
let _ = self.get_agent().get_stream_sender().await.send(
|
||||
Ok(AgentMessage::assistant(
|
||||
Some("shutdown_message".to_string()),
|
||||
Some("Dashboard agent shutting down gracefully".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
);
|
||||
|
||||
return Ok(rx_return);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("Exited message processing loop");
|
||||
|
||||
println!("Creating completion signal");
|
||||
let completion_msg = AgentMessage::assistant(
|
||||
None,
|
||||
Some("AGENT_COMPLETE".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
println!("Sending completion signal");
|
||||
self.get_agent()
|
||||
.get_stream_sender()
|
||||
.await
|
||||
.send(Ok(completion_msg))
|
||||
.await?;
|
||||
|
||||
println!("Sent completion signal, returning output");
|
||||
Ok(DashboardAgentOutput {
|
||||
message: "Dashboard processing complete".to_string(),
|
||||
duration: 0,
|
||||
files: vec![],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,17 +2,15 @@ use std::sync::Arc;
|
|||
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{
|
||||
agent::{Agent, AgentExt, AgentThread},
|
||||
tools::{
|
||||
IntoValueTool, ToolExecutor,
|
||||
},
|
||||
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
|
||||
tools::{IntoValueTool, ToolExecutor},
|
||||
};
|
||||
|
||||
use litellm::Message as AgentMessage;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
pub struct ExploratoryAgent {
|
||||
agent: Arc<Agent>,
|
||||
|
@ -46,11 +44,83 @@ impl ExploratoryAgent {
|
|||
Ok(exploratory)
|
||||
}
|
||||
|
||||
pub async fn run(&self, thread: &mut AgentThread) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
|
||||
// Process using agent's streaming functionality
|
||||
pub async fn run(
|
||||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(EXPLORATORY_AGENT_PROMPT.to_string());
|
||||
|
||||
self.stream_process_thread(thread).await
|
||||
// Get shutdown receiver
|
||||
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
|
||||
let mut rx = self.stream_process_thread(thread).await?;
|
||||
|
||||
// Clone what we need for the processing task
|
||||
let agent = Arc::clone(self.get_agent());
|
||||
|
||||
let rx_return = rx.resubscribe();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
recv_result = rx.recv() => {
|
||||
match recv_result {
|
||||
Ok(msg_result) => {
|
||||
match msg_result {
|
||||
Ok(msg) => {
|
||||
// Forward message to stream sender
|
||||
let sender = agent.get_stream_sender().await;
|
||||
if let Err(e) = sender.send(Ok(msg.clone())) {
|
||||
let err_msg = format!("Error forwarding message: {:?}", e);
|
||||
let _ = sender.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(content) = msg.get_content() {
|
||||
if content == "AGENT_COMPLETE" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error processing message: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error receiving message: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
// Handle shutdown gracefully
|
||||
let tools = agent.get_tools().await;
|
||||
for (_, tool) in tools.iter() {
|
||||
if let Err(e) = tool.handle_shutdown().await {
|
||||
let err_msg = format!("Error shutting down tool: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
}
|
||||
}
|
||||
|
||||
let _ = agent.get_stream_sender().await.send(
|
||||
Ok(AgentMessage::assistant(
|
||||
Some("shutdown_message".to_string()),
|
||||
Some("Exploratory agent shutting down gracefully".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx_return)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
use tokio::sync::broadcast;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::tools::agents_as_tools::{DashboardAgentTool, MetricAgentTool};
|
||||
use crate::utils::tools::file_tools::SendAssetsToUserTool;
|
||||
use crate::utils::{
|
||||
agent::{Agent, AgentExt, AgentThread},
|
||||
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
|
||||
tools::{
|
||||
agents_as_tools::ExploratoryAgentTool,
|
||||
file_tools::{SearchDataCatalogTool, SearchFilesTool},
|
||||
|
@ -123,13 +123,57 @@ impl ManagerAgent {
|
|||
pub async fn run(
|
||||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());
|
||||
|
||||
// Use existing channel - important for sub-agents
|
||||
let rx = self.get_agent().get_stream_receiver().await;
|
||||
|
||||
// Get shutdown receiver
|
||||
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
|
||||
|
||||
// Clone only what we need
|
||||
let agent = Arc::clone(self.get_agent());
|
||||
let thread = thread.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
result = agent.process_thread(&thread) => {
|
||||
if let Err(e) = result {
|
||||
let err_msg = format!("Manager agent processing failed: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
// Shutdown all tools
|
||||
let tools = agent.get_tools().await;
|
||||
for (_, tool) in tools.iter() {
|
||||
if let Err(e) = tool.handle_shutdown().await {
|
||||
let err_msg = format!("Error shutting down tool: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut rx = self.stream_process_thread(thread).await?;
|
||||
let _ = agent.get_stream_sender().await.send(
|
||||
Ok(AgentMessage::assistant(
|
||||
Some("shutdown_message".to_string()),
|
||||
Some("Manager agent shutting down gracefully".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
|
||||
/// Shutdown the manager agent and all its tools
|
||||
pub async fn shutdown(&self) -> Result<()> {
|
||||
self.get_agent().shutdown().await
|
||||
}
|
||||
}
|
||||
|
||||
const MANAGER_AGENT_PROMPT: &str = r##"
|
||||
|
|
|
@ -2,11 +2,10 @@ use std::sync::Arc;
|
|||
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{
|
||||
agent::{Agent, AgentExt, AgentThread},
|
||||
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
|
||||
tools::{
|
||||
file_tools::{CreateMetricFilesTool, ModifyMetricFilesTool},
|
||||
IntoValueTool, ToolExecutor,
|
||||
|
@ -14,6 +13,7 @@ use crate::utils::{
|
|||
};
|
||||
|
||||
use litellm::Message as AgentMessage;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
pub struct MetricAgent {
|
||||
agent: Arc<Agent>,
|
||||
|
@ -67,10 +67,80 @@ impl MetricAgent {
|
|||
pub async fn run(
|
||||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(METRIC_AGENT_PROMPT.to_string());
|
||||
// Process using agent's streaming functionality
|
||||
self.stream_process_thread(thread).await
|
||||
|
||||
// Get shutdown receiver
|
||||
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
|
||||
let mut rx = self.stream_process_thread(thread).await?;
|
||||
|
||||
// Clone what we need for the processing task
|
||||
let agent = Arc::clone(self.get_agent());
|
||||
|
||||
let rx_return = rx.resubscribe();
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
recv_result = rx.recv() => {
|
||||
match recv_result {
|
||||
Ok(msg_result) => {
|
||||
match msg_result {
|
||||
Ok(msg) => {
|
||||
// Forward message to stream sender
|
||||
let sender = agent.get_stream_sender().await;
|
||||
if let Err(e) = sender.send(Ok(msg.clone())) {
|
||||
let err_msg = format!("Error forwarding message: {:?}", e);
|
||||
let _ = sender.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(content) = msg.get_content() {
|
||||
if content == "AGENT_COMPLETE" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error processing message: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!("Error receiving message: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.recv() => {
|
||||
// Handle shutdown gracefully
|
||||
let tools = agent.get_tools().await;
|
||||
for (_, tool) in tools.iter() {
|
||||
if let Err(e) = tool.handle_shutdown().await {
|
||||
let err_msg = format!("Error shutting down tool: {:?}", e);
|
||||
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
}
|
||||
}
|
||||
|
||||
let _ = agent.get_stream_sender().await.send(
|
||||
Ok(AgentMessage::assistant(
|
||||
Some("shutdown_message".to_string()),
|
||||
Some("Metric agent shutting down gracefully".to_string()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx_return)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -74,7 +74,7 @@ impl ToolExecutor for DashboardAgentTool {
|
|||
|
||||
println!("DashboardAgentTool: Starting dashboard agent run");
|
||||
// Run the dashboard agent and get the output
|
||||
let output = dashboard_agent.run(&mut current_thread).await?;
|
||||
let _receiver = dashboard_agent.run(&mut current_thread).await?;
|
||||
println!("DashboardAgentTool: Dashboard agent run completed");
|
||||
|
||||
println!("DashboardAgentTool: Preparing success response");
|
||||
|
@ -83,12 +83,12 @@ impl ToolExecutor for DashboardAgentTool {
|
|||
.set_state_value(String::from("files_available"), Value::Bool(false))
|
||||
.await;
|
||||
|
||||
// Return success with the output
|
||||
// Return dummy data for testing
|
||||
Ok(serde_json::json!({
|
||||
"status": "success",
|
||||
"message": output.message,
|
||||
"duration": output.duration,
|
||||
"files": output.files
|
||||
"message": "Test dashboard creation",
|
||||
"duration": 0,
|
||||
"files": []
|
||||
}))
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ pub mod interaction_tools;
|
|||
/// A trait that defines how tools should be implemented.
|
||||
/// Any struct that wants to be used as a tool must implement this trait.
|
||||
/// Tools are constructed with a reference to their agent and can access its capabilities.
|
||||
#[async_trait]
|
||||
#[async_trait::async_trait]
|
||||
pub trait ToolExecutor: Send + Sync {
|
||||
/// The type of the output of the tool
|
||||
type Output: Serialize + Send;
|
||||
|
@ -34,21 +34,27 @@ pub trait ToolExecutor: Send + Sync {
|
|||
|
||||
/// Check if this tool is currently enabled
|
||||
async fn is_enabled(&self) -> bool;
|
||||
|
||||
/// Handle shutdown signal. Default implementation does nothing.
|
||||
/// Tools should override this if they need to perform cleanup on shutdown.
|
||||
async fn handle_shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper type that converts any ToolExecutor to one that outputs Value
|
||||
pub struct ValueToolExecutor<T: ToolExecutor> {
|
||||
pub struct ValueToolExecutor<T: ToolExecutor + Send + Sync> {
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T: ToolExecutor> ValueToolExecutor<T> {
|
||||
impl<T: ToolExecutor + Send + Sync> ValueToolExecutor<T> {
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
|
||||
#[async_trait::async_trait]
|
||||
impl<T: ToolExecutor + Send + Sync> ToolExecutor for ValueToolExecutor<T> {
|
||||
type Output = Value;
|
||||
type Params = T::Params;
|
||||
|
||||
|
|
Loading…
Reference in New Issue