mirror of https://github.com/buster-so/buster.git
better use of user throughout agents and tools
This commit is contained in:
parent
9577fe99ba
commit
4432574086
|
@ -15,6 +15,7 @@ uuid = { workspace = true }
|
|||
litellm = { path = "../litellm" }
|
||||
database = { path = "../database" }
|
||||
query_engine = { path = "../query_engine" }
|
||||
middleware = { path = "../middleware" }
|
||||
serde_json = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
|
|
|
@ -4,11 +4,12 @@ use litellm::{
|
|||
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||
};
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde_json::Value;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use uuid::Uuid;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::models::AgentThread;
|
||||
|
||||
|
@ -34,7 +35,6 @@ struct MessageBuffer {
|
|||
first_message_sent: bool,
|
||||
}
|
||||
|
||||
|
||||
impl MessageBuffer {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
|
@ -80,7 +80,11 @@ impl MessageBuffer {
|
|||
// Create and send the message
|
||||
let message = AgentMessage::assistant(
|
||||
self.message_id.clone(),
|
||||
if self.content.is_empty() { None } else { Some(self.content.clone()) },
|
||||
if self.content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.content.clone())
|
||||
},
|
||||
tool_calls,
|
||||
MessageProgress::InProgress,
|
||||
Some(!self.first_message_sent),
|
||||
|
@ -98,8 +102,6 @@ impl MessageBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
/// The Agent struct is responsible for managing conversations with the LLM
|
||||
/// and coordinating tool executions. It maintains a registry of available tools
|
||||
|
@ -122,7 +124,7 @@ pub struct Agent {
|
|||
/// Sender for streaming messages from this agent and sub-agents
|
||||
stream_tx: Arc<RwLock<Option<broadcast::Sender<MessageResult>>>>,
|
||||
/// The user ID for the current thread
|
||||
user_id: Uuid,
|
||||
user: AuthenticatedUser,
|
||||
/// The session ID for the current thread
|
||||
session_id: Uuid,
|
||||
/// Agent name
|
||||
|
@ -136,7 +138,7 @@ impl Agent {
|
|||
pub fn new(
|
||||
model: String,
|
||||
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
|
||||
user_id: Uuid,
|
||||
user: AuthenticatedUser,
|
||||
session_id: Uuid,
|
||||
name: String,
|
||||
) -> Self {
|
||||
|
@ -155,7 +157,7 @@ impl Agent {
|
|||
state: Arc::new(RwLock::new(HashMap::new())),
|
||||
current_thread: Arc::new(RwLock::new(None)),
|
||||
stream_tx: Arc::new(RwLock::new(Some(tx))),
|
||||
user_id,
|
||||
user,
|
||||
session_id,
|
||||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||
name,
|
||||
|
@ -176,7 +178,7 @@ impl Agent {
|
|||
state: Arc::clone(&existing_agent.state),
|
||||
current_thread: Arc::clone(&existing_agent.current_thread),
|
||||
stream_tx: Arc::clone(&existing_agent.stream_tx),
|
||||
user_id: existing_agent.user_id,
|
||||
user: existing_agent.user.clone(),
|
||||
session_id: existing_agent.session_id,
|
||||
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx),
|
||||
name,
|
||||
|
@ -241,7 +243,11 @@ impl Agent {
|
|||
}
|
||||
|
||||
pub fn get_user_id(&self) -> Uuid {
|
||||
self.user_id
|
||||
self.user.id
|
||||
}
|
||||
|
||||
pub fn get_user(&self) -> AuthenticatedUser {
|
||||
self.user.clone()
|
||||
}
|
||||
|
||||
pub fn get_session_id(&self) -> Uuid {
|
||||
|
@ -450,7 +456,8 @@ impl Agent {
|
|||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
buffer.tool_calls
|
||||
buffer
|
||||
.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.map(|s| s.clone())
|
||||
|
@ -458,7 +465,8 @@ impl Agent {
|
|||
});
|
||||
|
||||
// Get or create the pending tool call
|
||||
let pending_call = buffer.tool_calls
|
||||
let pending_call = buffer
|
||||
.tool_calls
|
||||
.entry(id.clone())
|
||||
.or_insert_with(PendingToolCall::new);
|
||||
|
||||
|
@ -484,7 +492,8 @@ impl Agent {
|
|||
// Create and send the final message
|
||||
let final_tool_calls: Option<Vec<ToolCall>> = if !buffer.tool_calls.is_empty() {
|
||||
Some(
|
||||
buffer.tool_calls
|
||||
buffer
|
||||
.tool_calls
|
||||
.values()
|
||||
.map(|p| p.clone().into_tool_call())
|
||||
.collect(),
|
||||
|
@ -495,7 +504,11 @@ impl Agent {
|
|||
|
||||
let final_message = AgentMessage::assistant(
|
||||
buffer.message_id,
|
||||
if buffer.content.is_empty() { None } else { Some(buffer.content) },
|
||||
if buffer.content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.content)
|
||||
},
|
||||
final_tool_calls.clone(),
|
||||
MessageProgress::Complete,
|
||||
Some(false),
|
||||
|
@ -527,7 +540,7 @@ impl Agent {
|
|||
for tool_call in tool_calls {
|
||||
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
|
||||
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
|
||||
let result = tool.execute(params, tool_call.id.clone()).await?;
|
||||
let result = tool.execute(params, tool_call.id.clone(), self.get_user()).await?;
|
||||
let result_str = serde_json::to_string(&result)?;
|
||||
let tool_message = AgentMessage::tool(
|
||||
None,
|
||||
|
@ -671,12 +684,32 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::tools::ToolExecutor;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Utc};
|
||||
use litellm::MessageProgress;
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
use middleware::types::AuthenticatedUser;
|
||||
|
||||
fn setup() {
|
||||
dotenv::dotenv().ok();
|
||||
std::env::set_var("LLM_API_KEY", "test_key");
|
||||
std::env::set_var("LLM_BASE_URL", "http://localhost:8000");
|
||||
}
|
||||
|
||||
// Create a mock AuthenticatedUser for testing
|
||||
fn create_test_user() -> AuthenticatedUser {
|
||||
AuthenticatedUser {
|
||||
id: Uuid::new_v4(),
|
||||
email: "test@example.com".to_string(),
|
||||
name: Some("Test User".to_string()),
|
||||
config: json!({}),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
attributes: json!({}),
|
||||
avatar_url: None,
|
||||
organizations: vec![],
|
||||
teams: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
struct WeatherTool {
|
||||
|
@ -696,13 +729,8 @@ mod tests {
|
|||
tool_id: String,
|
||||
progress: MessageProgress,
|
||||
) -> Result<()> {
|
||||
let message = AgentMessage::tool(
|
||||
None,
|
||||
content,
|
||||
tool_id,
|
||||
Some(self.get_name()),
|
||||
progress,
|
||||
);
|
||||
let message =
|
||||
AgentMessage::tool(None, content, tool_id, Some(self.get_name()), progress);
|
||||
self.agent.get_stream_sender().await.send(Ok(message))?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -713,7 +741,12 @@ mod tests {
|
|||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(
|
||||
&self,
|
||||
params: Self::Params,
|
||||
tool_call_id: String,
|
||||
user: AuthenticatedUser,
|
||||
) -> Result<Self::Output> {
|
||||
self.send_progress(
|
||||
"Fetching weather data...".to_string(),
|
||||
"123".to_string(),
|
||||
|
@ -778,14 +811,14 @@ mod tests {
|
|||
let agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
Uuid::new_v4(),
|
||||
create_test_user().id,
|
||||
vec![AgentMessage::user("Hello, world!".to_string())],
|
||||
);
|
||||
|
||||
|
@ -803,7 +836,7 @@ mod tests {
|
|||
let mut agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
@ -816,7 +849,7 @@ mod tests {
|
|||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
Uuid::new_v4(),
|
||||
create_test_user().id,
|
||||
vec![AgentMessage::user(
|
||||
"What is the weather in vineyard ut?".to_string(),
|
||||
)],
|
||||
|
@ -836,7 +869,7 @@ mod tests {
|
|||
let mut agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
@ -847,7 +880,7 @@ mod tests {
|
|||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
Uuid::new_v4(),
|
||||
create_test_user().id,
|
||||
vec![AgentMessage::user(
|
||||
"What is the weather in vineyard ut and san francisco?".to_string(),
|
||||
)],
|
||||
|
@ -867,7 +900,7 @@ mod tests {
|
|||
let agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use anyhow::Result;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
@ -97,12 +98,12 @@ impl BusterSuperAgent {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
|
||||
pub async fn new(user: AuthenticatedUser, session_id: Uuid) -> Result<Self> {
|
||||
// Create agent with empty tools map
|
||||
let agent = Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
HashMap::new(),
|
||||
user_id,
|
||||
user,
|
||||
session_id,
|
||||
"buster_super_agent".to_string(),
|
||||
));
|
||||
|
|
|
@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{self, json, Value};
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
use middleware::AuthenticatedUser;
|
||||
|
||||
use crate::{
|
||||
agent::Agent,
|
||||
|
@ -131,7 +132,7 @@ impl ToolExecutor for CreateDashboardFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
let files = params.files;
|
||||
|
|
|
@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
use middleware::AuthenticatedUser;
|
||||
|
||||
use crate::{
|
||||
agent::Agent,
|
||||
|
@ -81,7 +82,7 @@ impl ToolExecutor for CreateMetricFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
let files = params.files;
|
||||
|
|
|
@ -12,6 +12,7 @@ use indexmap::IndexMap;
|
|||
use query_engine::data_types::DataType;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, error, info};
|
||||
use middleware::AuthenticatedUser;
|
||||
|
||||
use super::{
|
||||
common::{
|
||||
|
@ -67,7 +68,7 @@ impl ToolExecutor for ModifyDashboardFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
debug!("Starting file modification execution");
|
||||
|
@ -216,127 +217,3 @@ impl ToolExecutor for ModifyDashboardFilesTool {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use crate::tools::categories::file_tools::common::{
|
||||
apply_modifications_to_content, Modification, ModificationResult,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[test]
|
||||
fn test_apply_modifications_to_content() {
|
||||
let original_content =
|
||||
"name: test_dashboard\ntype: dashboard\ndescription: A test dashboard";
|
||||
|
||||
// Test single modification
|
||||
let mods1 = vec![Modification {
|
||||
content_to_replace: "type: dashboard".to_string(),
|
||||
new_content: "type: custom_dashboard".to_string(),
|
||||
}];
|
||||
let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result1,
|
||||
"name: test_dashboard\ntype: custom_dashboard\ndescription: A test dashboard"
|
||||
);
|
||||
|
||||
// Test multiple non-overlapping modifications
|
||||
let mods2 = vec![
|
||||
Modification {
|
||||
content_to_replace: "test_dashboard".to_string(),
|
||||
new_content: "new_dashboard".to_string(),
|
||||
},
|
||||
Modification {
|
||||
content_to_replace: "A test dashboard".to_string(),
|
||||
new_content: "An updated dashboard".to_string(),
|
||||
},
|
||||
];
|
||||
let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result2,
|
||||
"name: new_dashboard\ntype: dashboard\ndescription: An updated dashboard"
|
||||
);
|
||||
|
||||
// Test content not found
|
||||
let mods3 = vec![Modification {
|
||||
content_to_replace: "nonexistent content".to_string(),
|
||||
new_content: "new content".to_string(),
|
||||
}];
|
||||
let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml");
|
||||
assert!(result3.is_err());
|
||||
assert!(result3
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Content to replace not found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modification_result_tracking() {
|
||||
let result = ModificationResult {
|
||||
file_id: Uuid::new_v4(),
|
||||
file_name: "test.yml".to_string(),
|
||||
success: true,
|
||||
error: None,
|
||||
modification_type: "content".to_string(),
|
||||
timestamp: Utc::now(),
|
||||
duration: 0,
|
||||
};
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.error.is_none());
|
||||
|
||||
let error_result = ModificationResult {
|
||||
success: false,
|
||||
error: Some("Failed to parse YAML".to_string()),
|
||||
..result
|
||||
};
|
||||
assert!(!error_result.success);
|
||||
assert!(error_result.error.is_some());
|
||||
assert_eq!(error_result.error.unwrap(), "Failed to parse YAML");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_parameter_validation() {
|
||||
let tool = ModifyDashboardFilesTool {
|
||||
agent: Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
)),
|
||||
};
|
||||
|
||||
// Test valid parameters
|
||||
let valid_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml",
|
||||
"modifications": [{
|
||||
"content_to_replace": "old content",
|
||||
"new_content": "new content"
|
||||
}]
|
||||
}]
|
||||
});
|
||||
let valid_args = serde_json::to_string(&valid_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&valid_args);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test missing required fields
|
||||
let missing_fields_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml"
|
||||
// missing modifications
|
||||
}]
|
||||
});
|
||||
let missing_args = serde_json::to_string(&missing_fields_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&missing_args);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ use database::{enums::Verification, models::MetricFile, pool::get_pg_pool, schem
|
|||
use diesel::{upsert::excluded, ExpressionMethods, QueryDsl};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use indexmap::IndexMap;
|
||||
use middleware::AuthenticatedUser;
|
||||
use query_engine::data_types::DataType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
@ -65,7 +66,7 @@ impl ToolExecutor for ModifyMetricFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
debug!("Starting file modification execution");
|
||||
|
@ -274,122 +275,3 @@ impl ToolExecutor for ModifyMetricFilesTool {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_apply_modifications_to_content() {
|
||||
let original_content = "name: test_metric\ntype: counter\ndescription: A test metric";
|
||||
|
||||
// Test single modification
|
||||
let mods1 = vec![Modification {
|
||||
content_to_replace: "type: counter".to_string(),
|
||||
new_content: "type: gauge".to_string(),
|
||||
}];
|
||||
let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result1,
|
||||
"name: test_metric\ntype: gauge\ndescription: A test metric"
|
||||
);
|
||||
|
||||
// Test multiple non-overlapping modifications
|
||||
let mods2 = vec![
|
||||
Modification {
|
||||
content_to_replace: "test_metric".to_string(),
|
||||
new_content: "new_metric".to_string(),
|
||||
},
|
||||
Modification {
|
||||
content_to_replace: "A test metric".to_string(),
|
||||
new_content: "An updated metric".to_string(),
|
||||
},
|
||||
];
|
||||
let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result2,
|
||||
"name: new_metric\ntype: counter\ndescription: An updated metric"
|
||||
);
|
||||
|
||||
// Test content not found
|
||||
let mods3 = vec![Modification {
|
||||
content_to_replace: "nonexistent content".to_string(),
|
||||
new_content: "new content".to_string(),
|
||||
}];
|
||||
let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml");
|
||||
assert!(result3.is_err());
|
||||
assert!(result3
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Content to replace not found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modification_result_tracking() {
|
||||
let result = ModificationResult {
|
||||
file_id: Uuid::new_v4(),
|
||||
file_name: "test.yml".to_string(),
|
||||
success: true,
|
||||
error: None,
|
||||
modification_type: "content".to_string(),
|
||||
timestamp: Utc::now(),
|
||||
duration: 0,
|
||||
};
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.error.is_none());
|
||||
|
||||
let error_result = ModificationResult {
|
||||
success: false,
|
||||
error: Some("Failed to parse YAML".to_string()),
|
||||
..result
|
||||
};
|
||||
assert!(!error_result.success);
|
||||
assert!(error_result.error.is_some());
|
||||
assert_eq!(error_result.error.unwrap(), "Failed to parse YAML");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_parameter_validation() {
|
||||
let tool = ModifyMetricFilesTool {
|
||||
agent: Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
)),
|
||||
};
|
||||
|
||||
// Test valid parameters
|
||||
let valid_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml",
|
||||
"modifications": [{
|
||||
"content_to_replace": "old content",
|
||||
"new_content": "new content"
|
||||
}]
|
||||
}]
|
||||
});
|
||||
let valid_args = serde_json::to_string(&valid_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&valid_args);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test missing required fields
|
||||
let missing_fields_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml"
|
||||
// missing modifications
|
||||
}]
|
||||
});
|
||||
let missing_args = serde_json::to_string(&missing_fields_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&missing_args);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ use chrono::{DateTime, Utc};
|
|||
use database::{pool::get_pg_pool, schema::datasets};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{debug, error, warn};
|
||||
|
@ -260,7 +261,7 @@ impl ToolExecutor for SearchDataCatalogTool {
|
|||
type Output = SearchDataCatalogOutput;
|
||||
type Params = SearchDataCatalogParams;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Fetch all non-deleted datasets
|
||||
|
|
|
@ -7,6 +7,5 @@
|
|||
//! - interaction_tools: Tools for user interaction and UI manipulation
|
||||
//! - planning_tools: Tools for planning and scheduling
|
||||
|
||||
pub mod agents_as_tools;
|
||||
pub mod file_tools;
|
||||
pub mod planning_tools;
|
|
@ -1,5 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
@ -36,7 +37,7 @@ impl ToolExecutor for CreatePlan {
|
|||
"create_plan".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
self.agent
|
||||
.set_state_value(String::from("plan_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use anyhow::Result;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// A trait that defines how tools should be implemented.
|
||||
/// Any struct that wants to be used as a tool must implement this trait.
|
||||
|
@ -13,7 +15,7 @@ pub trait ToolExecutor: Send + Sync {
|
|||
type Params: DeserializeOwned + Send;
|
||||
|
||||
/// Execute the tool with the given parameters and tool call ID.
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output>;
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user_id: AuthenticatedUser) -> Result<Self::Output>;
|
||||
|
||||
/// Get the JSON schema for this tool
|
||||
fn get_schema(&self) -> Value;
|
||||
|
@ -53,9 +55,9 @@ where
|
|||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
let params = serde_json::from_value(params)?;
|
||||
let result = self.inner.execute(params, tool_call_id).await?;
|
||||
let result = self.inner.execute(params, tool_call_id, user).await?;
|
||||
Ok(serde_json::to_value(result)?)
|
||||
}
|
||||
|
||||
|
@ -78,8 +80,8 @@ impl<T: ToolExecutor<Output = Value, Params = Value> + Send + Sync> ToolExecutor
|
|||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
(**self).execute(params, tool_call_id).await
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
(**self).execute(params, tool_call_id, user).await
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
|
|
@ -11,4 +11,3 @@ pub use executor::{ToolExecutor, ToolCallExecutor, IntoToolCallExecutor};
|
|||
// Re-export commonly used tool categories
|
||||
pub use categories::file_tools;
|
||||
pub use categories::planning_tools;
|
||||
pub use categories::agents_as_tools;
|
|
@ -3,10 +3,14 @@ use once_cell::sync::OnceCell;
|
|||
use std::{collections::HashMap, sync::Mutex, time::Instant};
|
||||
|
||||
use agents::{
|
||||
tools::{file_tools::{
|
||||
common::ModifyFilesOutput, create_dashboard_files::CreateDashboardFilesOutput,
|
||||
create_metric_files::CreateMetricFilesOutput, search_data_catalog::SearchDataCatalogOutput,
|
||||
}, planning_tools::CreatePlanOutput},
|
||||
tools::{
|
||||
file_tools::{
|
||||
common::ModifyFilesOutput, create_dashboard_files::CreateDashboardFilesOutput,
|
||||
create_metric_files::CreateMetricFilesOutput,
|
||||
search_data_catalog::SearchDataCatalogOutput,
|
||||
},
|
||||
planning_tools::CreatePlanOutput,
|
||||
},
|
||||
AgentExt, AgentMessage, AgentThread, BusterSuperAgent,
|
||||
};
|
||||
|
||||
|
@ -175,7 +179,7 @@ pub async fn post_chat_handler(
|
|||
let mut initial_messages = vec![];
|
||||
|
||||
// Initialize agent to add context
|
||||
let agent = BusterSuperAgent::new(user.id, chat_id).await?;
|
||||
let agent = BusterSuperAgent::new(user.clone(), chat_id).await?;
|
||||
|
||||
// Load context if provided
|
||||
if let Some(existing_chat_id) = request.chat_id {
|
||||
|
@ -416,7 +420,8 @@ pub async fn post_chat_handler(
|
|||
fn prepare_final_message_state(containers: &[BusterContainer]) -> Result<(Vec<Value>, Vec<Value>)> {
|
||||
let mut response_messages = Vec::new();
|
||||
// Use a Vec to maintain order, with a HashMap to track latest version of each message
|
||||
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> = std::collections::HashMap::new();
|
||||
let mut reasoning_map: std::collections::HashMap<String, (usize, Value)> =
|
||||
std::collections::HashMap::new();
|
||||
let mut reasoning_order = Vec::new();
|
||||
|
||||
for container in containers {
|
||||
|
@ -1429,7 +1434,8 @@ fn transform_assistant_tool_message(
|
|||
let mut updated_files = std::collections::HashMap::new();
|
||||
|
||||
for (file_id, file_content) in file.files.iter() {
|
||||
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
|
||||
let chunk_id =
|
||||
format!("{}_{}", file.id, file_content.file_name);
|
||||
let complete_text = tracker
|
||||
.get_complete_text(chunk_id.clone())
|
||||
.unwrap_or_else(|| {
|
||||
|
@ -1456,10 +1462,12 @@ fn transform_assistant_tool_message(
|
|||
let mut updated_files = std::collections::HashMap::new();
|
||||
|
||||
for (file_id, file_content) in file.files.iter() {
|
||||
let chunk_id = format!("{}_{}", file.id, file_content.file_name);
|
||||
let chunk_id =
|
||||
format!("{}_{}", file.id, file_content.file_name);
|
||||
|
||||
if let Some(chunk) = &file_content.file.text_chunk {
|
||||
let delta = tracker.add_chunk(chunk_id.clone(), chunk.clone());
|
||||
let delta =
|
||||
tracker.add_chunk(chunk_id.clone(), chunk.clone());
|
||||
|
||||
if !delta.is_empty() {
|
||||
let mut updated_content = file_content.clone();
|
||||
|
|
Loading…
Reference in New Issue