ok just need to tie up the last few things

This commit is contained in:
dal 2025-02-24 06:20:16 -07:00
parent 18413f2f24
commit d8ee830c6a
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
8 changed files with 371 additions and 134 deletions

View File

@ -10,6 +10,7 @@ use handlers::threads::types::ThreadWithMessages;
use litellm::Message as AgentMessage;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::broadcast;
use uuid::Uuid;
use crate::routes::rest::ApiResponse;
@ -98,10 +99,12 @@ async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ThreadWi
// Get the receiver and collect all messages
let mut rx = agent.run(&mut thread).await?;
let mut messages = Vec::new();
while let Some(msg_result) = rx.recv().await {
match msg_result {
Ok(msg) => messages.push(msg),
Err(e) => return Err(e.into()),
loop {
match rx.recv().await {
Ok(Ok(msg)) => messages.push(msg),
Ok(Err(e)) => return Err(e.into()),
Err(broadcast::error::RecvError::Closed) => break,
Err(e) => return Err(anyhow!(e)),
}
}

View File

@ -5,13 +5,26 @@ use litellm::{
};
use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{mpsc, RwLock};
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
use crate::utils::tools::ToolExecutor;
use super::types::AgentThread;
#[derive(Debug, Clone)]
pub struct AgentError(pub String);
impl std::error::Error for AgentError {}
impl std::fmt::Display for AgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
type MessageResult = Result<Message, AgentError>;
/// A wrapper type that converts ToolCall parameters to Value before executing
struct ToolCallExecutor<T: ToolExecutor> {
inner: Box<T>,
@ -96,11 +109,13 @@ pub struct Agent {
/// The current thread being processed, if any
current_thread: Arc<RwLock<Option<AgentThread>>>,
/// Sender for streaming messages from this agent and sub-agents
stream_tx: Arc<RwLock<mpsc::Sender<Result<Message>>>>,
stream_tx: Arc<RwLock<broadcast::Sender<MessageResult>>>,
/// The user ID for the current thread
user_id: Uuid,
/// The session ID for the current thread
session_id: Uuid,
/// Shutdown signal sender
shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
}
impl Agent {
@ -116,8 +131,10 @@ impl Agent {
let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));
// Create a default channel that just drops messages
let (tx, _rx) = mpsc::channel(1);
// Create a broadcast channel with buffer size 1000
let (tx, _rx) = broadcast::channel(1000);
// Create shutdown channel with buffer size 1
let (shutdown_tx, _) = broadcast::channel(1);
Self {
llm_client,
@ -128,6 +145,7 @@ impl Agent {
stream_tx: Arc::new(RwLock::new(tx)),
user_id,
session_id,
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
}
}
@ -140,13 +158,14 @@ impl Agent {
Self {
llm_client,
tools: Arc::new(RwLock::new(HashMap::new())), // Start with empty tools
tools: Arc::new(RwLock::new(HashMap::new())),
model: existing_agent.model.clone(),
state: Arc::clone(&existing_agent.state),
current_thread: Arc::clone(&existing_agent.current_thread),
stream_tx: Arc::clone(&existing_agent.stream_tx),
user_id: existing_agent.user_id,
session_id: existing_agent.session_id,
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx),
}
}
@ -168,13 +187,13 @@ impl Agent {
enabled_tools
}
/// Update the stream sender for this agent
pub async fn set_stream_sender(&self, tx: mpsc::Sender<Result<Message>>) {
*self.stream_tx.write().await = tx;
/// Get a new receiver for the broadcast channel
pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> {
self.stream_tx.read().await.subscribe()
}
/// Get a clone of the current stream sender
pub async fn get_stream_sender(&self) -> mpsc::Sender<Result<Message>> {
pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> {
self.stream_tx.read().await.clone()
}
@ -276,7 +295,7 @@ impl Agent {
let mut rx = self.process_thread_streaming(thread).await?;
let mut final_message = None;
while let Some(msg) = rx.recv().await {
while let Ok(msg) = rx.recv().await {
final_message = Some(msg?);
}
@ -294,26 +313,37 @@ impl Agent {
pub async fn process_thread_streaming(
&self,
thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> {
// Create new channel for this processing session
let (tx, rx) = mpsc::channel(100);
self.set_stream_sender(tx).await;
) -> Result<broadcast::Receiver<MessageResult>> {
// Spawn the processing task
let agent_clone = self.clone();
let thread_clone = thread.clone();
// Get shutdown receiver
let mut shutdown_rx = self.get_shutdown_receiver().await;
tokio::spawn(async move {
if let Err(e) = agent_clone
.process_thread_with_depth(&thread_clone, 0)
.await
{
let err_msg = format!("Error processing thread: {:?}", e);
let _ = agent_clone.get_stream_sender().await.send(Err(e)).await;
tokio::select! {
result = agent_clone.process_thread_with_depth(&thread_clone, 0) => {
if let Err(e) = result {
let err_msg = format!("Error processing thread: {:?}", e);
let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
_ = shutdown_rx.recv() => {
let _ = agent_clone.get_stream_sender().await.send(
Ok(Message::assistant(
Some("shutdown_message".to_string()),
Some("Processing interrupted due to shutdown signal".to_string()),
None,
None,
None,
))
);
}
}
});
Ok(rx)
Ok(self.get_stream_receiver().await)
}
async fn process_thread_with_depth(
@ -335,7 +365,7 @@ impl Agent {
None,
None,
);
self.get_stream_sender().await.send(Ok(message)).await?;
self.get_stream_sender().await.send(Ok(message))?;
return Ok(());
}
@ -425,6 +455,23 @@ impl Agent {
Ok(())
}
}
/// Get a receiver for the shutdown signal
pub async fn get_shutdown_receiver(&self) -> broadcast::Receiver<()> {
self.shutdown_tx.read().await.subscribe()
}
/// Signal shutdown to all receivers
pub async fn shutdown(&self) -> Result<()> {
// Send shutdown signal
self.shutdown_tx.read().await.send(())?;
Ok(())
}
/// Get a reference to the tools map
pub async fn get_tools(&self) -> tokio::sync::RwLockReadGuard<'_, HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>> {
self.tools.read().await
}
}
#[derive(Debug, Default)]
@ -488,7 +535,7 @@ pub trait AgentExt {
async fn stream_process_thread(
&self,
thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> {
) -> Result<broadcast::Receiver<MessageResult>> {
(*self.get_agent()).process_thread_streaming(thread).await
}
@ -524,24 +571,27 @@ mod tests {
}
}
impl WeatherTool {
async fn send_progress(&self, content: String, tool_id: String, progress: MessageProgress) -> Result<()> {
let message = Message::tool(
None,
content,
tool_id,
Some(self.get_name()),
Some(progress),
);
self.agent.get_stream_sender().await.send(Ok(message))?;
Ok(())
}
}
#[async_trait]
impl ToolExecutor for WeatherTool {
type Output = Value;
type Params = Value;
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
// Send progress using agent's stream sender
self.agent
.get_stream_sender()
.await
.send(Ok(Message::tool(
None,
"Fetching weather data...".to_string(),
"123".to_string(),
Some(self.get_name()),
Some(MessageProgress::InProgress),
)))
.await?;
self.send_progress("Fetching weather data...".to_string(), "123".to_string(), MessageProgress::InProgress).await?;
// Simulate a delay
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
@ -551,18 +601,7 @@ mod tests {
"unit": "fahrenheit"
});
// Send completion message using agent's stream sender
self.agent
.get_stream_sender()
.await
.send(Ok(Message::tool(
None,
serde_json::to_string(&result)?,
"123".to_string(),
Some(self.get_name()),
Some(MessageProgress::Complete),
)))
.await?;
self.send_progress(serde_json::to_string(&result)?, "123".to_string(), MessageProgress::Complete).await?;
Ok(result)
}

View File

@ -5,7 +5,7 @@ use std::collections::HashMap;
use uuid::Uuid;
use crate::utils::{
agent::{Agent, AgentExt, AgentThread},
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
tools::{
agents_as_tools::dashboard_agent_tool::DashboardAgentOutput, file_tools::{
CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
@ -15,6 +15,7 @@ use crate::utils::{
};
use litellm::Message as AgentMessage;
use tokio::sync::broadcast;
pub struct DashboardAgent {
agent: Arc<Agent>,
@ -84,71 +85,75 @@ impl DashboardAgent {
if content == "AGENT_COMPLETE")
}
pub async fn run(&self, thread: &mut AgentThread) -> Result<DashboardAgentOutput> {
println!("Running dashboard agent");
println!("Setting developer message");
pub async fn run(&self, thread: &mut AgentThread) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(DASHBOARD_AGENT_PROMPT.to_string());
println!("Starting stream_process_thread");
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
let mut rx = self.stream_process_thread(thread).await?;
println!("Got receiver from stream_process_thread");
println!("Starting message processing loop");
let rx_return = rx.resubscribe();
// Process messages internally until we determine we're done
while let Some(msg_result) = rx.recv().await {
println!("Received message from channel");
match msg_result {
Ok(msg) => {
println!("Message content: {:?}", msg.get_content());
println!("Message has tool calls: {:?}", msg.get_tool_calls());
println!("Forwarding message to stream sender");
if let Err(e) = self.get_agent().get_stream_sender().await.send(Ok(msg.clone())).await {
println!("Error forwarding message: {:?}", e);
// Continue processing even if we fail to forward
continue;
}
if let Some(content) = msg.get_content() {
println!("Message has content: {}", content);
if content == "AGENT_COMPLETE" {
println!("Found completion signal, breaking loop");
break;
loop {
tokio::select! {
recv_result = rx.recv() => {
match recv_result {
Ok(msg_result) => {
match msg_result {
Ok(msg) => {
// Forward message to stream sender
let sender = self.get_agent().get_stream_sender().await;
if let Err(e) = sender.send(Ok(msg.clone())) {
let err_msg = format!("Error forwarding message: {:?}", e);
let _ = sender.send(Err(AgentError(err_msg)));
continue;
}
if let Some(content) = msg.get_content() {
if content == "AGENT_COMPLETE" {
return Ok(rx_return);
}
}
}
Err(e) => {
let err_msg = format!("Error processing message: {:?}", e);
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
let err_msg = format!("Error receiving message: {:?}", e);
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
println!("Error receiving message: {:?}", e);
println!("Error details: {:?}", e.to_string());
// Log error but continue processing instead of returning error
continue;
_ = shutdown_rx.recv() => {
// Handle shutdown gracefully
let tools = self.get_agent().get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = self.get_agent().get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = self.get_agent().get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Dashboard agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
return Ok(rx_return);
}
}
}
println!("Exited message processing loop");
println!("Creating completion signal");
let completion_msg = AgentMessage::assistant(
None,
Some("AGENT_COMPLETE".to_string()),
None,
None,
None,
);
println!("Sending completion signal");
self.get_agent()
.get_stream_sender()
.await
.send(Ok(completion_msg))
.await?;
println!("Sent completion signal, returning output");
Ok(DashboardAgentOutput {
message: "Dashboard processing complete".to_string(),
duration: 0,
files: vec![],
})
}
}

View File

@ -2,17 +2,15 @@ use std::sync::Arc;
use anyhow::Result;
use std::collections::HashMap;
use tokio::sync::mpsc::Receiver;
use uuid::Uuid;
use crate::utils::{
agent::{Agent, AgentExt, AgentThread},
tools::{
IntoValueTool, ToolExecutor,
},
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
tools::{IntoValueTool, ToolExecutor},
};
use litellm::Message as AgentMessage;
use tokio::sync::broadcast;
pub struct ExploratoryAgent {
agent: Arc<Agent>,
@ -46,11 +44,83 @@ impl ExploratoryAgent {
Ok(exploratory)
}
pub async fn run(&self, thread: &mut AgentThread) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
// Process using agent's streaming functionality
pub async fn run(
&self,
thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(EXPLORATORY_AGENT_PROMPT.to_string());
self.stream_process_thread(thread).await
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
let mut rx = self.stream_process_thread(thread).await?;
// Clone what we need for the processing task
let agent = Arc::clone(self.get_agent());
let rx_return = rx.resubscribe();
tokio::spawn(async move {
loop {
tokio::select! {
recv_result = rx.recv() => {
match recv_result {
Ok(msg_result) => {
match msg_result {
Ok(msg) => {
// Forward message to stream sender
let sender = agent.get_stream_sender().await;
if let Err(e) = sender.send(Ok(msg.clone())) {
let err_msg = format!("Error forwarding message: {:?}", e);
let _ = sender.send(Err(AgentError(err_msg)));
continue;
}
if let Some(content) = msg.get_content() {
if content == "AGENT_COMPLETE" {
break;
}
}
}
Err(e) => {
let err_msg = format!("Error processing message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
let err_msg = format!("Error receiving message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
_ = shutdown_rx.recv() => {
// Handle shutdown gracefully
let tools = agent.get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = agent.get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Exploratory agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
break;
}
}
}
});
Ok(rx_return)
}
}

View File

@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc::Receiver;
use tokio::sync::broadcast;
use uuid::Uuid;
use crate::utils::tools::agents_as_tools::{DashboardAgentTool, MetricAgentTool};
use crate::utils::tools::file_tools::SendAssetsToUserTool;
use crate::utils::{
agent::{Agent, AgentExt, AgentThread},
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
tools::{
agents_as_tools::ExploratoryAgentTool,
file_tools::{SearchDataCatalogTool, SearchFilesTool},
@ -123,13 +123,57 @@ impl ManagerAgent {
pub async fn run(
&self,
thread: &mut AgentThread,
) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());
// Use existing channel - important for sub-agents
let rx = self.get_agent().get_stream_receiver().await;
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
// Clone only what we need
let agent = Arc::clone(self.get_agent());
let thread = thread.clone();
tokio::spawn(async move {
tokio::select! {
result = agent.process_thread(&thread) => {
if let Err(e) = result {
let err_msg = format!("Manager agent processing failed: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
_ = shutdown_rx.recv() => {
// Shutdown all tools
let tools = agent.get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let mut rx = self.stream_process_thread(thread).await?;
let _ = agent.get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Manager agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
}
}
});
Ok(rx)
}
/// Shutdown the manager agent and all its tools
pub async fn shutdown(&self) -> Result<()> {
self.get_agent().shutdown().await
}
}
const MANAGER_AGENT_PROMPT: &str = r##"

View File

@ -2,11 +2,10 @@ use std::sync::Arc;
use anyhow::Result;
use std::collections::HashMap;
use tokio::sync::mpsc::Receiver;
use uuid::Uuid;
use crate::utils::{
agent::{Agent, AgentExt, AgentThread},
agent::{agent::AgentError, Agent, AgentExt, AgentThread},
tools::{
file_tools::{CreateMetricFilesTool, ModifyMetricFilesTool},
IntoValueTool, ToolExecutor,
@ -14,6 +13,7 @@ use crate::utils::{
};
use litellm::Message as AgentMessage;
use tokio::sync::broadcast;
pub struct MetricAgent {
agent: Arc<Agent>,
@ -67,10 +67,80 @@ impl MetricAgent {
pub async fn run(
&self,
thread: &mut AgentThread,
) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(METRIC_AGENT_PROMPT.to_string());
// Process using agent's streaming functionality
self.stream_process_thread(thread).await
// Get shutdown receiver
let mut shutdown_rx = self.get_agent().get_shutdown_receiver().await;
let mut rx = self.stream_process_thread(thread).await?;
// Clone what we need for the processing task
let agent = Arc::clone(self.get_agent());
let rx_return = rx.resubscribe();
tokio::spawn(async move {
loop {
tokio::select! {
recv_result = rx.recv() => {
match recv_result {
Ok(msg_result) => {
match msg_result {
Ok(msg) => {
// Forward message to stream sender
let sender = agent.get_stream_sender().await;
if let Err(e) = sender.send(Ok(msg.clone())) {
let err_msg = format!("Error forwarding message: {:?}", e);
let _ = sender.send(Err(AgentError(err_msg)));
continue;
}
if let Some(content) = msg.get_content() {
if content == "AGENT_COMPLETE" {
break;
}
}
}
Err(e) => {
let err_msg = format!("Error processing message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
Err(e) => {
let err_msg = format!("Error receiving message: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
continue;
}
}
}
_ = shutdown_rx.recv() => {
// Handle shutdown gracefully
let tools = agent.get_tools().await;
for (_, tool) in tools.iter() {
if let Err(e) = tool.handle_shutdown().await {
let err_msg = format!("Error shutting down tool: {:?}", e);
let _ = agent.get_stream_sender().await.send(Err(AgentError(err_msg)));
}
}
let _ = agent.get_stream_sender().await.send(
Ok(AgentMessage::assistant(
Some("shutdown_message".to_string()),
Some("Metric agent shutting down gracefully".to_string()),
None,
None,
None,
))
);
break;
}
}
}
});
Ok(rx_return)
}
}

View File

@ -74,7 +74,7 @@ impl ToolExecutor for DashboardAgentTool {
println!("DashboardAgentTool: Starting dashboard agent run");
// Run the dashboard agent and get the output
let output = dashboard_agent.run(&mut current_thread).await?;
let _receiver = dashboard_agent.run(&mut current_thread).await?;
println!("DashboardAgentTool: Dashboard agent run completed");
println!("DashboardAgentTool: Preparing success response");
@ -83,12 +83,12 @@ impl ToolExecutor for DashboardAgentTool {
.set_state_value(String::from("files_available"), Value::Bool(false))
.await;
// Return success with the output
// Return dummy data for testing
Ok(serde_json::json!({
"status": "success",
"message": output.message,
"duration": output.duration,
"files": output.files
"message": "Test dashboard creation",
"duration": 0,
"files": []
}))
}

View File

@ -15,7 +15,7 @@ pub mod interaction_tools;
/// A trait that defines how tools should be implemented.
/// Any struct that wants to be used as a tool must implement this trait.
/// Tools are constructed with a reference to their agent and can access its capabilities.
#[async_trait]
#[async_trait::async_trait]
pub trait ToolExecutor: Send + Sync {
/// The type of the output of the tool
type Output: Serialize + Send;
@ -34,21 +34,27 @@ pub trait ToolExecutor: Send + Sync {
/// Check if this tool is currently enabled
async fn is_enabled(&self) -> bool;
/// Handle shutdown signal. Default implementation does nothing.
/// Tools should override this if they need to perform cleanup on shutdown.
async fn handle_shutdown(&self) -> Result<()> {
Ok(())
}
}
/// A wrapper type that converts any ToolExecutor to one that outputs Value
pub struct ValueToolExecutor<T: ToolExecutor> {
pub struct ValueToolExecutor<T: ToolExecutor + Send + Sync> {
inner: T,
}
impl<T: ToolExecutor> ValueToolExecutor<T> {
impl<T: ToolExecutor + Send + Sync> ValueToolExecutor<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}
}
#[async_trait]
impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
#[async_trait::async_trait]
impl<T: ToolExecutor + Send + Sync> ToolExecutor for ValueToolExecutor<T> {
type Output = Value;
type Params = T::Params;