better use of user throughout agents and tools

This commit is contained in:
dal 2025-03-12 08:27:59 -06:00
parent 9577fe99ba
commit 4432574086
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
14 changed files with 109 additions and 304 deletions

View File

@ -15,6 +15,7 @@ uuid = { workspace = true }
litellm = { path = "../litellm" }
database = { path = "../database" }
query_engine = { path = "../query_engine" }
middleware = { path = "../middleware" }
serde_json = { workspace = true }
futures = { workspace = true }
futures-util = { workspace = true }

View File

@ -4,11 +4,12 @@ use litellm::{
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
};
use middleware::AuthenticatedUser;
use serde_json::Value;
use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;
use std::time::{Duration, Instant};
use crate::models::AgentThread;
@ -34,7 +35,6 @@ struct MessageBuffer {
first_message_sent: bool,
}
impl MessageBuffer {
fn new() -> Self {
Self {
@ -80,7 +80,11 @@ impl MessageBuffer {
// Create and send the message
let message = AgentMessage::assistant(
self.message_id.clone(),
if self.content.is_empty() { None } else { Some(self.content.clone()) },
if self.content.is_empty() {
None
} else {
Some(self.content.clone())
},
tool_calls,
MessageProgress::InProgress,
Some(!self.first_message_sent),
@ -98,8 +102,6 @@ impl MessageBuffer {
}
}
#[derive(Clone)]
/// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools
@ -122,7 +124,7 @@ pub struct Agent {
/// Sender for streaming messages from this agent and sub-agents
stream_tx: Arc<RwLock<Option<broadcast::Sender<MessageResult>>>>,
/// The authenticated user for the current thread
user_id: Uuid,
user: AuthenticatedUser,
/// The session ID for the current thread
session_id: Uuid,
/// Agent name
@ -136,7 +138,7 @@ impl Agent {
pub fn new(
model: String,
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
user_id: Uuid,
user: AuthenticatedUser,
session_id: Uuid,
name: String,
) -> Self {
@ -155,7 +157,7 @@ impl Agent {
state: Arc::new(RwLock::new(HashMap::new())),
current_thread: Arc::new(RwLock::new(None)),
stream_tx: Arc::new(RwLock::new(Some(tx))),
user_id,
user,
session_id,
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
name,
@ -176,7 +178,7 @@ impl Agent {
state: Arc::clone(&existing_agent.state),
current_thread: Arc::clone(&existing_agent.current_thread),
stream_tx: Arc::clone(&existing_agent.stream_tx),
user_id: existing_agent.user_id,
user: existing_agent.user.clone(),
session_id: existing_agent.session_id,
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx),
name,
@ -241,7 +243,11 @@ impl Agent {
}
pub fn get_user_id(&self) -> Uuid {
self.user_id
self.user.id
}
/// Returns a clone of the agent's `user` field.
pub fn get_user(&self) -> AuthenticatedUser {
    let user = &self.user;
    user.clone()
}
pub fn get_session_id(&self) -> Uuid {
@ -450,7 +456,8 @@ impl Agent {
if let Some(tool_calls) = &delta.tool_calls {
for tool_call in tool_calls {
let id = tool_call.id.clone().unwrap_or_else(|| {
buffer.tool_calls
buffer
.tool_calls
.keys()
.next()
.map(|s| s.clone())
@ -458,7 +465,8 @@ impl Agent {
});
// Get or create the pending tool call
let pending_call = buffer.tool_calls
let pending_call = buffer
.tool_calls
.entry(id.clone())
.or_insert_with(PendingToolCall::new);
@ -484,7 +492,8 @@ impl Agent {
// Create and send the final message
let final_tool_calls: Option<Vec<ToolCall>> = if !buffer.tool_calls.is_empty() {
Some(
buffer.tool_calls
buffer
.tool_calls
.values()
.map(|p| p.clone().into_tool_call())
.collect(),
@ -495,7 +504,11 @@ impl Agent {
let final_message = AgentMessage::assistant(
buffer.message_id,
if buffer.content.is_empty() { None } else { Some(buffer.content) },
if buffer.content.is_empty() {
None
} else {
Some(buffer.content)
},
final_tool_calls.clone(),
MessageProgress::Complete,
Some(false),
@ -527,7 +540,7 @@ impl Agent {
for tool_call in tool_calls {
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
let result = tool.execute(params, tool_call.id.clone()).await?;
let result = tool.execute(params, tool_call.id.clone(), self.get_user()).await?;
let result_str = serde_json::to_string(&result)?;
let tool_message = AgentMessage::tool(
None,
@ -671,12 +684,32 @@ mod tests {
use super::*;
use crate::tools::ToolExecutor;
use async_trait::async_trait;
use chrono::{Utc};
use litellm::MessageProgress;
use serde_json::{json, Value};
use uuid::Uuid;
use middleware::types::AuthenticatedUser;
/// Test bootstrap: attempts to load a `.env` file, then sets fixed values for
/// the `LLM_API_KEY` and `LLM_BASE_URL` environment variables.
fn setup() {
    // The result is discarded on purpose: a failed `.env` load is not an error here.
    let _ = dotenv::dotenv();

    for (key, value) in [
        ("LLM_API_KEY", "test_key"),
        ("LLM_BASE_URL", "http://localhost:8000"),
    ] {
        std::env::set_var(key, value);
    }
}
/// Builds an `AuthenticatedUser` for tests: a freshly generated id, a fixed
/// email and name, empty JSON objects for `config` and `attributes`, current
/// timestamps, no avatar, and empty organization/team lists.
fn create_test_user() -> AuthenticatedUser {
    let id = Uuid::new_v4();
    let email = String::from("test@example.com");
    let name = Some(String::from("Test User"));
    let config = json!({});
    // Two separate clock reads, one per timestamp field.
    let created_at = Utc::now();
    let updated_at = Utc::now();
    let attributes = json!({});

    AuthenticatedUser {
        id,
        email,
        name,
        config,
        created_at,
        updated_at,
        attributes,
        avatar_url: None,
        organizations: Vec::new(),
        teams: Vec::new(),
    }
}
struct WeatherTool {
@ -696,13 +729,8 @@ mod tests {
tool_id: String,
progress: MessageProgress,
) -> Result<()> {
let message = AgentMessage::tool(
None,
content,
tool_id,
Some(self.get_name()),
progress,
);
let message =
AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress);
self.agent.get_stream_sender().await.send(Ok(message))?;
Ok(())
}
@ -713,7 +741,12 @@ mod tests {
type Output = Value;
type Params = Value;
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(
&self,
params: Self::Params,
tool_call_id: String,
user: AuthenticatedUser,
) -> Result<Self::Output> {
self.send_progress(
"Fetching weather data...".to_string(),
"123".to_string(),
@ -778,14 +811,14 @@ mod tests {
let agent = Agent::new(
"o1".to_string(),
HashMap::new(),
Uuid::new_v4(),
create_test_user(),
Uuid::new_v4(),
"test_agent".to_string(),
);
let thread = AgentThread::new(
None,
Uuid::new_v4(),
create_test_user().id,
vec![AgentMessage::user("Hello, world!".to_string())],
);
@ -803,7 +836,7 @@ mod tests {
let mut agent = Agent::new(
"o1".to_string(),
HashMap::new(),
Uuid::new_v4(),
create_test_user(),
Uuid::new_v4(),
"test_agent".to_string(),
);
@ -816,7 +849,7 @@ mod tests {
let thread = AgentThread::new(
None,
Uuid::new_v4(),
create_test_user().id,
vec![AgentMessage::user(
"What is the weather in vineyard ut?".to_string(),
)],
@ -836,7 +869,7 @@ mod tests {
let mut agent = Agent::new(
"o1".to_string(),
HashMap::new(),
Uuid::new_v4(),
create_test_user(),
Uuid::new_v4(),
"test_agent".to_string(),
);
@ -847,7 +880,7 @@ mod tests {
let thread = AgentThread::new(
None,
Uuid::new_v4(),
create_test_user().id,
vec![AgentMessage::user(
"What is the weather in vineyard ut and san francisco?".to_string(),
)],
@ -867,7 +900,7 @@ mod tests {
let agent = Agent::new(
"o1".to_string(),
HashMap::new(),
Uuid::new_v4(),
create_test_user(),
Uuid::new_v4(),
"test_agent".to_string(),
);

View File

@ -1,4 +1,5 @@
use anyhow::Result;
use middleware::AuthenticatedUser;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
@ -97,12 +98,12 @@ impl BusterSuperAgent {
Ok(())
}
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
pub async fn new(user: AuthenticatedUser, session_id: Uuid) -> Result<Self> {
// Create agent with empty tools map
let agent = Arc::new(Agent::new(
"o3-mini".to_string(),
HashMap::new(),
user_id,
user,
session_id,
"buster_super_agent".to_string(),
));

View File

@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{self, json, Value};
use tracing::debug;
use uuid::Uuid;
use middleware::AuthenticatedUser;
use crate::{
agent::Agent,
@ -131,7 +132,7 @@ impl ToolExecutor for CreateDashboardFilesTool {
}
}
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
let start_time = Instant::now();
let files = params.files;

View File

@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::debug;
use uuid::Uuid;
use middleware::AuthenticatedUser;
use crate::{
agent::Agent,
@ -81,7 +82,7 @@ impl ToolExecutor for CreateMetricFilesTool {
}
}
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
let start_time = Instant::now();
let files = params.files;

View File

@ -12,6 +12,7 @@ use indexmap::IndexMap;
use query_engine::data_types::DataType;
use serde_json::Value;
use tracing::{debug, error, info};
use middleware::AuthenticatedUser;
use super::{
common::{
@ -67,7 +68,7 @@ impl ToolExecutor for ModifyDashboardFilesTool {
}
}
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
let start_time = Instant::now();
debug!("Starting file modification execution");
@ -216,127 +217,3 @@ impl ToolExecutor for ModifyDashboardFilesTool {
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::categories::file_tools::common::{
        apply_modifications_to_content, Modification, ModificationResult,
    };
    use chrono::Utc;
    use serde_json::json;
    use uuid::Uuid;

    /// Exercises `apply_modifications_to_content` with a single replacement,
    /// two non-overlapping replacements, and a replacement whose target text
    /// does not occur in the content.
    #[test]
    fn test_apply_modifications_to_content() {
        let original_content =
            "name: test_dashboard\ntype: dashboard\ndescription: A test dashboard";

        // Test single modification
        let mods1 = vec![Modification {
            content_to_replace: "type: dashboard".to_string(),
            new_content: "type: custom_dashboard".to_string(),
        }];
        let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap();
        assert_eq!(
            result1,
            "name: test_dashboard\ntype: custom_dashboard\ndescription: A test dashboard"
        );

        // Test multiple non-overlapping modifications
        let mods2 = vec![
            Modification {
                content_to_replace: "test_dashboard".to_string(),
                new_content: "new_dashboard".to_string(),
            },
            Modification {
                content_to_replace: "A test dashboard".to_string(),
                new_content: "An updated dashboard".to_string(),
            },
        ];
        let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap();
        assert_eq!(
            result2,
            "name: new_dashboard\ntype: dashboard\ndescription: An updated dashboard"
        );

        // Test content not found
        let mods3 = vec![Modification {
            content_to_replace: "nonexistent content".to_string(),
            new_content: "new content".to_string(),
        }];
        let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml");
        assert!(result3.is_err());
        assert!(result3
            .unwrap_err()
            .to_string()
            .contains("Content to replace not found"));
    }

    /// Builds a successful `ModificationResult`, then derives a failed one
    /// from it via struct-update syntax and checks the `success`/`error`
    /// fields of each.
    #[test]
    fn test_modification_result_tracking() {
        let result = ModificationResult {
            file_id: Uuid::new_v4(),
            file_name: "test.yml".to_string(),
            success: true,
            error: None,
            modification_type: "content".to_string(),
            timestamp: Utc::now(),
            duration: 0,
        };
        assert!(result.success);
        assert!(result.error.is_none());

        let error_result = ModificationResult {
            success: false,
            error: Some("Failed to parse YAML".to_string()),
            ..result
        };
        assert!(!error_result.success);
        assert!(error_result.error.is_some());
        assert_eq!(error_result.error.unwrap(), "Failed to parse YAML");
    }

    /// Checks that `ModifyFilesParams` deserializes from a well-formed payload
    /// and rejects a payload whose file entry has no `modifications` key.
    ///
    /// Only serde parsing is exercised, so no tool or agent is constructed
    /// here; that keeps the test independent of the `Agent::new` signature.
    #[test]
    fn test_tool_parameter_validation() {
        // Test valid parameters
        let valid_params = json!({
            "files": [{
                "id": Uuid::new_v4().to_string(),
                "file_name": "test.yml",
                "modifications": [{
                    "content_to_replace": "old content",
                    "new_content": "new content"
                }]
            }]
        });
        let valid_args = serde_json::to_string(&valid_params).unwrap();
        let result = serde_json::from_str::<ModifyFilesParams>(&valid_args);
        assert!(result.is_ok());

        // Test missing required fields
        let missing_fields_params = json!({
            "files": [{
                "id": Uuid::new_v4().to_string(),
                "file_name": "test.yml"
                // missing modifications
            }]
        });
        let missing_args = serde_json::to_string(&missing_fields_params).unwrap();
        let result = serde_json::from_str::<ModifyFilesParams>(&missing_args);
        assert!(result.is_err());
    }
}

View File

@ -8,6 +8,7 @@ use database::{enums::Verification, models::MetricFile, pool::get_pg_pool, schem
use diesel::{upsert::excluded, ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use indexmap::IndexMap;
use middleware::AuthenticatedUser;
use query_engine::data_types::DataType;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@ -65,7 +66,7 @@ impl ToolExecutor for ModifyMetricFilesTool {
}
}
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
let start_time = Instant::now();
debug!("Starting file modification execution");
@ -274,122 +275,3 @@ impl ToolExecutor for ModifyMetricFilesTool {
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use serde_json::json;

    /// Exercises `apply_modifications_to_content` with a single replacement,
    /// two non-overlapping replacements, and a replacement whose target text
    /// does not occur in the content.
    #[test]
    fn test_apply_modifications_to_content() {
        let original_content = "name: test_metric\ntype: counter\ndescription: A test metric";

        // Test single modification
        let mods1 = vec![Modification {
            content_to_replace: "type: counter".to_string(),
            new_content: "type: gauge".to_string(),
        }];
        let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap();
        assert_eq!(
            result1,
            "name: test_metric\ntype: gauge\ndescription: A test metric"
        );

        // Test multiple non-overlapping modifications
        let mods2 = vec![
            Modification {
                content_to_replace: "test_metric".to_string(),
                new_content: "new_metric".to_string(),
            },
            Modification {
                content_to_replace: "A test metric".to_string(),
                new_content: "An updated metric".to_string(),
            },
        ];
        let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap();
        assert_eq!(
            result2,
            "name: new_metric\ntype: counter\ndescription: An updated metric"
        );

        // Test content not found
        let mods3 = vec![Modification {
            content_to_replace: "nonexistent content".to_string(),
            new_content: "new content".to_string(),
        }];
        let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml");
        assert!(result3.is_err());
        assert!(result3
            .unwrap_err()
            .to_string()
            .contains("Content to replace not found"));
    }

    /// Builds a successful `ModificationResult`, then derives a failed one
    /// from it via struct-update syntax and checks the `success`/`error`
    /// fields of each.
    #[test]
    fn test_modification_result_tracking() {
        let result = ModificationResult {
            file_id: Uuid::new_v4(),
            file_name: "test.yml".to_string(),
            success: true,
            error: None,
            modification_type: "content".to_string(),
            timestamp: Utc::now(),
            duration: 0,
        };
        assert!(result.success);
        assert!(result.error.is_none());

        let error_result = ModificationResult {
            success: false,
            error: Some("Failed to parse YAML".to_string()),
            ..result
        };
        assert!(!error_result.success);
        assert!(error_result.error.is_some());
        assert_eq!(error_result.error.unwrap(), "Failed to parse YAML");
    }

    /// Checks that `ModifyFilesParams` deserializes from a well-formed payload
    /// and rejects a payload whose file entry has no `modifications` key.
    ///
    /// Only serde parsing is exercised, so no tool or agent is constructed
    /// here; that keeps the test independent of the `Agent::new` signature.
    #[test]
    fn test_tool_parameter_validation() {
        // Test valid parameters
        let valid_params = json!({
            "files": [{
                "id": Uuid::new_v4().to_string(),
                "file_name": "test.yml",
                "modifications": [{
                    "content_to_replace": "old content",
                    "new_content": "new content"
                }]
            }]
        });
        let valid_args = serde_json::to_string(&valid_params).unwrap();
        let result = serde_json::from_str::<ModifyFilesParams>(&valid_args);
        assert!(result.is_ok());

        // Test missing required fields
        let missing_fields_params = json!({
            "files": [{
                "id": Uuid::new_v4().to_string(),
                "file_name": "test.yml"
                // missing modifications
            }]
        });
        let missing_args = serde_json::to_string(&missing_fields_params).unwrap();
        let result = serde_json::from_str::<ModifyFilesParams>(&missing_args);
        assert!(result.is_err());
    }
}

View File

@ -7,6 +7,7 @@ use chrono::{DateTime, Utc};
use database::{pool::get_pg_pool, schema::datasets};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use middleware::AuthenticatedUser;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use tracing::{debug, error, warn};
@ -260,7 +261,7 @@ impl ToolExecutor for SearchDataCatalogTool {
type Output = SearchDataCatalogOutput;
type Params = SearchDataCatalogParams;
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
let start_time = Instant::now();
// Fetch all non-deleted datasets

View File

@ -7,6 +7,5 @@
//! - interaction_tools: Tools for user interaction and UI manipulation
//! - planning_tools: Tools for planning and scheduling
pub mod agents_as_tools;
pub mod file_tools;
pub mod planning_tools;

View File

@ -1,5 +1,6 @@
use anyhow::Result;
use async_trait::async_trait;
use middleware::AuthenticatedUser;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
@ -36,7 +37,7 @@ impl ToolExecutor for CreatePlan {
"create_plan".to_string()
}
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
self.agent
.set_state_value(String::from("plan_available"), Value::Bool(true))
.await;

View File

@ -1,6 +1,8 @@
use anyhow::Result;
use middleware::AuthenticatedUser;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use uuid::Uuid;
/// A trait that defines how tools should be implemented.
/// Any struct that wants to be used as a tool must implement this trait.
@ -13,7 +15,7 @@ pub trait ToolExecutor: Send + Sync {
type Params: DeserializeOwned + Send;
/// Execute the tool with the given parameters and tool call ID.
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output>;
/// `user` is the authenticated user supplied by the caller of the tool.
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output>;
/// Get the JSON schema for this tool
fn get_schema(&self) -> Value;
@ -53,9 +55,9 @@ where
type Output = Value;
type Params = Value;
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
let params = serde_json::from_value(params)?;
let result = self.inner.execute(params, tool_call_id).await?;
let result = self.inner.execute(params, tool_call_id, user).await?;
Ok(serde_json::to_value(result)?)
}
@ -78,8 +80,8 @@ impl<T: ToolExecutor<Output = Value, Params = Value> + Send + Sync> ToolExecutor
type Output = Value;
type Params = Value;
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
(**self).execute(params, tool_call_id).await
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
(**self).execute(params, tool_call_id, user).await
}
fn get_schema(&self) -> Value {

View File

@ -11,4 +11,3 @@ pub use executor::{ToolExecutor, ToolCallExecutor, IntoToolCallExecutor};
// Re-export commonly used tool categories
pub use categories::file_tools;
pub use categories::planning_tools;
pub use categories::agents_as_tools;

View File

@ -3,10 +3,14 @@ use once_cell::sync::OnceCell;
use std::{collections::HashMap, sync::Mutex, time::Instant};
use agents::{
tools::{file_tools::{
common::ModifyFilesOutput, create_dashboard_files::CreateDashboardFilesOutput,
create_metric_files::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
}, planning_tools::CreatePlanOutput},
tools::{
file_tools::{
common::ModifyFilesOutput, create_dashboard_files::CreateDashboardFilesOutput,
create_metric_files::CreateMetricFilesOutput,
search_data_catalog::SearchDataCatalogOutput,
},
planning_tools::CreatePlanOutput,
},
AgentExt, AgentMessage, AgentThread, BusterSuperAgent,
};
@ -175,7 +179,7 @@ pub async fn post_chat_handler(
let mut initial_messages = vec![];
// Initialize agent to add context
let agent = BusterSuperAgent::new(user.id, chat_id).await?;
let agent = BusterSuperAgent::new(user.clone(), chat_id).await?;
// Load context if provided
if let Some(existing_chat_id) = request.chat_id {
@ -416,7 +420,8 @@ pub async fn post_chat_handler(
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
let mut response_messages = Vec::new();
// Use a Vec to maintain order, with a HashMap to track latest version of each message
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> = std::collections::HashMap::new();
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> =
std::collections::HashMap::new();
let mut reasoning_order = Vec::new();
for container in containers {
@ -1429,7 +1434,8 @@ fn transform_assistant_tool_message(
let mut updated_files = std::collections::HashMap::new();
for (file_id, file_content) in file.files.iter() {
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
let chunk_id =
format!("{}_{}", file.id, file_content.file_name);
let complete_text = tracker
.get_complete_text(chunk_id.clone())
.unwrap_or_else(|| {
@ -1456,10 +1462,12 @@ fn transform_assistant_tool_message(
let mut updated_files = std::collections::HashMap::new();
for (file_id, file_content) in file.files.iter() {
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
let chunk_id =
format!("{}_{}", file.id, file_content.file_name);
if let Some(chunk) = &file_content.file.text_chunk {
let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone());
let delta =
tracker.add_chunk(chunk_id.clone(), chunk.clone());
if !delta.is_empty() {
let mut updated_content = file_content.clone();