mirror of https://github.com/buster-so/buster.git
Merge branch 'evals' of https://github.com/buster-so/buster into evals
This commit is contained in:
commit
2ea633e561
|
@ -10,12 +10,13 @@ use tokio::sync::broadcast;
|
|||
use uuid::Uuid;
|
||||
|
||||
// Import the modes and necessary types
|
||||
use crate::agents::modes::{ // Assuming modes/mod.rs is one level up
|
||||
use crate::agents::modes::{
|
||||
// Assuming modes/mod.rs is one level up
|
||||
self, // Import the module itself for functions like determine_agent_state
|
||||
determine_agent_state,
|
||||
AgentState,
|
||||
ModeAgentData,
|
||||
ModeConfiguration,
|
||||
determine_agent_state,
|
||||
};
|
||||
|
||||
// Import Agent related types
|
||||
|
@ -50,14 +51,6 @@ pub struct BusterSuperAgentInput {
|
|||
pub message_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
// --- REMOVE State Management (Moved to modes/mod.rs) ---
|
||||
// #[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
// enum AgentState { ... }
|
||||
// fn determine_agent_state(state: &HashMap<String, Value>) -> AgentState { ... }
|
||||
|
||||
|
||||
// --- Mode Provider Implementation ---
|
||||
|
||||
// Create a struct to hold the data needed by the provider
|
||||
#[derive(Clone)]
|
||||
struct BusterModeProvider {
|
||||
|
@ -66,29 +59,34 @@ struct BusterModeProvider {
|
|||
|
||||
#[async_trait::async_trait]
|
||||
impl ModeProvider for BusterModeProvider {
|
||||
async fn get_configuration_for_state(&self, state: &HashMap<String, Value>) -> Result<ModeConfiguration> {
|
||||
async fn get_configuration_for_state(
|
||||
&self,
|
||||
state: &HashMap<String, Value>,
|
||||
) -> Result<ModeConfiguration> {
|
||||
let current_mode = determine_agent_state(state);
|
||||
|
||||
|
||||
// Call the appropriate get_configuration function based on the mode
|
||||
let mode_config = match current_mode {
|
||||
AgentState::Initializing => modes::initialization::get_configuration(&self.agent_data),
|
||||
AgentState::FollowUpInitialization => modes::follow_up_initialization::get_configuration(&self.agent_data),
|
||||
AgentState::DataCatalogSearch => modes::data_catalog_search::get_configuration(&self.agent_data),
|
||||
AgentState::FollowUpInitialization => {
|
||||
modes::follow_up_initialization::get_configuration(&self.agent_data)
|
||||
}
|
||||
AgentState::DataCatalogSearch => {
|
||||
modes::data_catalog_search::get_configuration(&self.agent_data)
|
||||
}
|
||||
AgentState::Planning => modes::planning::get_configuration(&self.agent_data),
|
||||
AgentState::AnalysisExecution => modes::analysis::get_configuration(&self.agent_data),
|
||||
AgentState::Review => modes::review::get_configuration(&self.agent_data),
|
||||
};
|
||||
|
||||
|
||||
Ok(mode_config)
|
||||
}
|
||||
}
|
||||
|
||||
// --- BusterMultiAgent ---
|
||||
// --- BusterMultiAgent ---
|
||||
|
||||
pub struct BusterMultiAgent {
|
||||
agent: Arc<Agent>,
|
||||
// REMOVED dataset_names: Vec<String>,
|
||||
// REMOVED todays_date: String,
|
||||
}
|
||||
|
||||
impl AgentExt for BusterMultiAgent {
|
||||
|
@ -99,11 +97,7 @@ impl AgentExt for BusterMultiAgent {
|
|||
}
|
||||
|
||||
impl BusterMultiAgent {
|
||||
pub async fn new(
|
||||
user_id: Uuid,
|
||||
session_id: Uuid,
|
||||
is_follow_up: bool,
|
||||
) -> Result<Self> {
|
||||
pub async fn new(user_id: Uuid, session_id: Uuid, is_follow_up: bool) -> Result<Self> {
|
||||
let organization_id = match get_user_organization_id(&user_id).await {
|
||||
Ok(Some(org_id)) => org_id,
|
||||
Ok(None) => return Err(anyhow::anyhow!("User does not belong to any organization")),
|
||||
|
@ -113,7 +107,7 @@ impl BusterMultiAgent {
|
|||
// Prepare data for modes
|
||||
let todays_date = Arc::new(Local::now().format("%Y-%m-%d").to_string());
|
||||
let dataset_names = Arc::new(get_dataset_names_for_organization(organization_id).await?);
|
||||
|
||||
|
||||
let agent_data = ModeAgentData {
|
||||
dataset_names,
|
||||
todays_date,
|
||||
|
@ -122,21 +116,15 @@ impl BusterMultiAgent {
|
|||
// Create the mode provider
|
||||
let mode_provider = Arc::new(BusterModeProvider { agent_data });
|
||||
|
||||
// REMOVE old hook logic
|
||||
// let agent_arc_for_hook = self.agent.clone(); // This was the error source
|
||||
// let hook_generator = || -> ... { ... };
|
||||
|
||||
// Create agent, passing the provider
|
||||
let agent = Arc::new(Agent::new(
|
||||
"o3-mini".to_string(), // Initial model (can be overridden by first mode)
|
||||
"o4-mini".to_string(), // Initial model (can be overridden by first mode)
|
||||
user_id,
|
||||
session_id,
|
||||
"buster_multi_agent".to_string(),
|
||||
None, // api_key
|
||||
None, // api_key
|
||||
None, // base_url
|
||||
mode_provider, // Pass the provider
|
||||
// REMOVED initial_default_prompt
|
||||
// REMOVED hook_generator()
|
||||
));
|
||||
|
||||
// Set the initial is_follow_up flag in state
|
||||
|
@ -146,33 +134,25 @@ impl BusterMultiAgent {
|
|||
|
||||
let buster_agent = Self {
|
||||
agent,
|
||||
// REMOVED dataset_names,
|
||||
// REMOVED todays_date,
|
||||
};
|
||||
|
||||
// REMOVE dynamic rules registration
|
||||
// buster_agent.register_dynamic_rules().await?;
|
||||
|
||||
Ok(buster_agent)
|
||||
}
|
||||
|
||||
// REMOVE register_dynamic_rules function
|
||||
// async fn register_dynamic_rules(&self) -> Result<()> { ... }
|
||||
|
||||
pub async fn run(
|
||||
self: &Arc<Self>, // Take Arc<Self> if AgentExt requires it for process_thread
|
||||
self: &Arc<Self>,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
|
||||
// Ensure the initial user prompt is in the state if available
|
||||
// This is important for the first call to determine_agent_state
|
||||
if let Some(user_prompt) = self.get_latest_user_message(thread) {
|
||||
self.agent // Use self.agent directly
|
||||
.set_state_value("user_prompt".to_string(), Value::String(user_prompt))
|
||||
.await;
|
||||
} else {
|
||||
// Handle case where there might not be a user message yet (e.g., agent starts convo?)
|
||||
self.agent.set_state_value("user_prompt".to_string(), Value::Null).await;
|
||||
self.agent
|
||||
.set_state_value("user_prompt".to_string(), Value::Null)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Mode configuration now happens inside Agent::process_thread_with_depth via the provider
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::tools::ToolExecutor;
|
||||
use crate::Agent; // For get_name()
|
||||
|
@ -15,15 +15,10 @@ use super::{ModeAgentData, ModeConfiguration};
|
|||
use crate::tools::{
|
||||
categories::{
|
||||
file_tools::{
|
||||
CreateDashboardFilesTool,
|
||||
CreateMetricFilesTool,
|
||||
ModifyDashboardFilesTool,
|
||||
CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
|
||||
ModifyMetricFilesTool,
|
||||
},
|
||||
response_tools::{
|
||||
Done,
|
||||
MessageUserClarifyingQuestion,
|
||||
},
|
||||
response_tools::{Done, MessageUserClarifyingQuestion},
|
||||
},
|
||||
IntoToolCallExecutor,
|
||||
};
|
||||
|
@ -31,77 +26,110 @@ use crate::tools::{
|
|||
// Function to get the configuration for the AnalysisExecution mode
|
||||
pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
||||
// 1. Get the prompt, formatted with current data
|
||||
let prompt = PROMPT
|
||||
.replace("{TODAYS_DATE}", &agent_data.todays_date);
|
||||
// Note: This prompt doesn't use {DATASETS}
|
||||
let prompt = PROMPT.replace("{TODAYS_DATE}", &agent_data.todays_date);
|
||||
// Note: This prompt doesn't use {DATASETS}
|
||||
|
||||
// 2. Define the model for this mode (Using default based on original MODEL = None)
|
||||
let model = "o3-mini".to_string();
|
||||
let model = "o4-mini".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
let tool_loader: Box<
|
||||
dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
|
||||
> = Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
|
||||
// Instantiate tools for this mode
|
||||
let create_metric_files_tool = CreateMetricFilesTool::new(agent_clone.clone());
|
||||
let modify_metric_files_tool = ModifyMetricFilesTool::new(agent_clone.clone());
|
||||
let create_dashboard_files_tool = CreateDashboardFilesTool::new(agent_clone.clone());
|
||||
let modify_dashboard_files_tool = ModifyDashboardFilesTool::new(agent_clone.clone());
|
||||
let message_user_clarifying_question_tool = MessageUserClarifyingQuestion::new();
|
||||
let done_tool = Done::new(agent_clone.clone());
|
||||
// Instantiate tools for this mode
|
||||
let create_metric_files_tool = CreateMetricFilesTool::new(agent_clone.clone());
|
||||
let modify_metric_files_tool = ModifyMetricFilesTool::new(agent_clone.clone());
|
||||
let create_dashboard_files_tool = CreateDashboardFilesTool::new(agent_clone.clone());
|
||||
let modify_dashboard_files_tool = ModifyDashboardFilesTool::new(agent_clone.clone());
|
||||
let done_tool = Done::new(agent_clone.clone());
|
||||
|
||||
// --- Define Conditions based on Agent State (as per original load_tools) ---
|
||||
// Base condition: Plan and context must exist (implicitly true if we are in this mode)
|
||||
let base_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && state.contains_key("plan_available")
|
||||
});
|
||||
let modify_metric_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && state.contains_key("plan_available") && state.contains_key("metrics_available")
|
||||
});
|
||||
let create_dashboard_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && state.contains_key("plan_available") && state.contains_key("metrics_available")
|
||||
});
|
||||
let modify_dashboard_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && state.contains_key("plan_available") && state.contains_key("dashboards_available")
|
||||
});
|
||||
let done_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
let review_needed = state.get("review_needed").and_then(Value::as_bool).unwrap_or(false);
|
||||
let all_todos_complete = state
|
||||
.get("todos") // Assuming plan execution updates 'todos'
|
||||
.and_then(Value::as_array)
|
||||
.map(|todos| {
|
||||
todos.iter().all(|todo| {
|
||||
todo.get("completed")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
// --- Define Conditions based on Agent State (as per original load_tools) ---
|
||||
// Base condition: Plan and context must exist (implicitly true if we are in this mode)
|
||||
let base_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && state.contains_key("plan_available")
|
||||
});
|
||||
let modify_metric_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context")
|
||||
&& state.contains_key("plan_available")
|
||||
&& state.contains_key("metrics_available")
|
||||
});
|
||||
let create_dashboard_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context")
|
||||
&& state.contains_key("plan_available")
|
||||
&& state.contains_key("metrics_available")
|
||||
});
|
||||
let modify_dashboard_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context")
|
||||
&& state.contains_key("plan_available")
|
||||
&& state.contains_key("dashboards_available")
|
||||
});
|
||||
let done_condition = Some(|state: &HashMap<String, Value>| -> bool {
|
||||
let review_needed = state
|
||||
.get("review_needed")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(false);
|
||||
let all_todos_complete = state
|
||||
.get("todos") // Assuming plan execution updates 'todos'
|
||||
.and_then(Value::as_array)
|
||||
.map(|todos| {
|
||||
todos.iter().all(|todo| {
|
||||
todo.get("completed")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.unwrap_or(false);
|
||||
review_needed || all_todos_complete
|
||||
});
|
||||
let always_available = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
})
|
||||
.unwrap_or(false);
|
||||
review_needed || all_todos_complete
|
||||
});
|
||||
|
||||
// Add tools to the agent with conditions
|
||||
agent_clone.add_tool(create_metric_files_tool.get_name(), create_metric_files_tool.into_tool_call_executor(), base_condition.clone()).await;
|
||||
agent_clone.add_tool(modify_metric_files_tool.get_name(), modify_metric_files_tool.into_tool_call_executor(), modify_metric_condition).await;
|
||||
agent_clone.add_tool(create_dashboard_files_tool.get_name(), create_dashboard_files_tool.into_tool_call_executor(), create_dashboard_condition).await;
|
||||
agent_clone.add_tool(modify_dashboard_files_tool.get_name(), modify_dashboard_files_tool.into_tool_call_executor(), modify_dashboard_condition).await;
|
||||
agent_clone.add_tool(message_user_clarifying_question_tool.get_name(), message_user_clarifying_question_tool.into_tool_call_executor(), always_available).await;
|
||||
agent_clone.add_tool(done_tool.get_name(), done_tool.into_tool_call_executor(), done_condition).await;
|
||||
// Add tools to the agent with conditions
|
||||
agent_clone
|
||||
.add_tool(
|
||||
create_metric_files_tool.get_name(),
|
||||
create_metric_files_tool.into_tool_call_executor(),
|
||||
base_condition.clone(),
|
||||
)
|
||||
.await;
|
||||
agent_clone
|
||||
.add_tool(
|
||||
modify_metric_files_tool.get_name(),
|
||||
modify_metric_files_tool.into_tool_call_executor(),
|
||||
modify_metric_condition,
|
||||
)
|
||||
.await;
|
||||
agent_clone
|
||||
.add_tool(
|
||||
create_dashboard_files_tool.get_name(),
|
||||
create_dashboard_files_tool.into_tool_call_executor(),
|
||||
create_dashboard_condition,
|
||||
)
|
||||
.await;
|
||||
agent_clone
|
||||
.add_tool(
|
||||
modify_dashboard_files_tool.get_name(),
|
||||
modify_dashboard_files_tool.into_tool_call_executor(),
|
||||
modify_dashboard_condition,
|
||||
)
|
||||
.await;
|
||||
agent_clone
|
||||
.add_tool(
|
||||
done_tool.get_name(),
|
||||
done_tool.into_tool_call_executor(),
|
||||
done_condition,
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
// 4. Define terminating tools for this mode (From original load_tools)
|
||||
let terminating_tools = vec![
|
||||
"message_user_clarifying_question".to_string(), // Hardcoded name
|
||||
"finish_and_respond".to_string(), // Hardcoded name for Done tool
|
||||
];
|
||||
let terminating_tools = vec![Done::get_name()];
|
||||
|
||||
// 5. Construct and return the ModeConfiguration
|
||||
ModeConfiguration {
|
||||
|
|
|
@ -42,7 +42,7 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
.replace("{TODAYS_DATE}", &agent_data.todays_date);
|
||||
|
||||
// 2. Define the model for this mode (Using a default, adjust if needed)
|
||||
let model = "o3-mini".to_string(); // Assuming default based on original MODEL = None
|
||||
let model = "o4-mini".to_string(); // Assuming default based on original MODEL = None
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
// Import necessary types from the parent module (modes/mod.rs)
|
||||
use super::{ModeAgentData, ModeConfiguration};
|
||||
|
@ -12,14 +12,11 @@ use crate::{Agent, ToolExecutor};
|
|||
// Import necessary tools for this mode
|
||||
use crate::tools::{
|
||||
categories::{
|
||||
file_tools::SearchDataCatalogTool,
|
||||
response_tools::MessageUserClarifyingQuestion,
|
||||
utility_tools::no_search_needed::NoSearchNeededTool,
|
||||
file_tools::SearchDataCatalogTool, response_tools::MessageUserClarifyingQuestion,
|
||||
},
|
||||
IntoToolCallExecutor,
|
||||
};
|
||||
|
||||
|
||||
// Function to get the configuration for the Initialization mode
|
||||
pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
||||
// 1. Get the prompt, formatted with current data
|
||||
|
@ -29,54 +26,48 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
|
||||
// 2. Define the model for this mode (Using a default, adjust if needed)
|
||||
// Since the original MODEL was None, we might use the agent's default
|
||||
// or specify a standard one like "o3-mini". Let's use "o3-mini".
|
||||
let model = "o3-mini".to_string();
|
||||
// or specify a standard one like "o4-mini". Let's use "o4-mini".
|
||||
let model = "o4-mini".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
let tool_loader: Box<
|
||||
dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
|
||||
> = Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
|
||||
// Instantiate tools for this mode
|
||||
let search_data_catalog_tool = SearchDataCatalogTool::new(agent_clone.clone());
|
||||
let no_search_needed_tool = NoSearchNeededTool::new(agent_clone.clone());
|
||||
let message_user_clarifying_question_tool = MessageUserClarifyingQuestion::new(); // No agent state needed
|
||||
// Instantiate tools for this mode
|
||||
let search_data_catalog_tool = SearchDataCatalogTool::new(agent_clone.clone());
|
||||
let message_user_clarifying_question_tool = MessageUserClarifyingQuestion::new(); // No agent state needed
|
||||
|
||||
// Condition (always true for this mode's tools)
|
||||
let condition = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
// Condition (always true for this mode's tools)
|
||||
let condition = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
|
||||
// Add tools to the agent
|
||||
agent_clone.add_tool(
|
||||
// Add tools to the agent
|
||||
agent_clone
|
||||
.add_tool(
|
||||
search_data_catalog_tool.get_name(),
|
||||
search_data_catalog_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
no_search_needed_tool.get_name(),
|
||||
no_search_needed_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
agent_clone
|
||||
.add_tool(
|
||||
message_user_clarifying_question_tool.get_name(),
|
||||
message_user_clarifying_question_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
// 4. Define terminating tools for this mode
|
||||
let terminating_tools = vec![
|
||||
// From original load_tools: only MessageUserClarifyingQuestion was registered
|
||||
MessageUserClarifyingQuestion::new().get_name(),
|
||||
// Add other terminating tools if needed for this mode
|
||||
];
|
||||
let terminating_tools = vec![MessageUserClarifyingQuestion::get_name()];
|
||||
|
||||
// 5. Construct and return the ModeConfiguration
|
||||
ModeConfiguration {
|
||||
|
@ -87,7 +78,6 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Keep the prompt constant, but it's no longer pub
|
||||
const INTIALIZATION_PROMPT: &str = r##"### Role & Task
|
||||
You are Buster, an AI assistant and expert in **data analytics, data science, and data engineering**. You operate within the **Buster platform**, the world's best BI tool, assisting non-technical users with their analytics tasks. Your capabilities include:
|
||||
|
|
|
@ -31,7 +31,7 @@ pub struct ModeAgentData {
|
|||
pub struct ModeConfiguration {
|
||||
/// The system prompt to use for the LLM call in this mode.
|
||||
pub prompt: String,
|
||||
/// The specific LLM model identifier (e.g., "o3-mini") to use for this mode.
|
||||
/// The specific LLM model identifier (e.g., "o4-mini") to use for this mode.
|
||||
pub model: String,
|
||||
/// An async function/closure responsible for clearing existing tools
|
||||
/// and loading the specific tools required for this mode onto the agent.
|
||||
|
@ -77,31 +77,47 @@ pub fn determine_agent_state(state: &HashMap<String, Value>) -> AgentState {
|
|||
let has_user_prompt = state.contains_key("user_prompt"); // Check if latest user prompt is stored
|
||||
|
||||
|
||||
// Initial state determination is tricky - depends when this is called relative to receiving the first user message.
|
||||
// Assuming this is called *after* the first user message is added to the state.
|
||||
if is_follow_up {
|
||||
// Follow-up specific flow
|
||||
if !has_user_prompt { // If this is the very start of a follow-up *before* user speaks
|
||||
return AgentState::FollowUpInitialization;
|
||||
}
|
||||
// Now assume user has spoken in this follow-up turn
|
||||
if needs_review { AgentState::Review }
|
||||
else if !searched_catalog && !has_data_context { AgentState::DataCatalogSearch } // Need to search if no context yet
|
||||
else if has_data_context && !has_plan { AgentState::Planning }
|
||||
else if has_data_context && has_plan { AgentState::AnalysisExecution }
|
||||
else { AgentState::FollowUpInitialization } // Fallback for follow-up if state is unclear
|
||||
// 1. Handle states before the user provides their first prompt in this turn/session
|
||||
if !has_user_prompt {
|
||||
return if is_follow_up {
|
||||
AgentState::FollowUpInitialization
|
||||
} else {
|
||||
AgentState::Initializing
|
||||
};
|
||||
}
|
||||
|
||||
// 2. Review always takes precedence after user speaks
|
||||
if needs_review {
|
||||
return AgentState::Review;
|
||||
}
|
||||
|
||||
// 3. If we haven't searched the catalog yet, do that now (initial or follow-up)
|
||||
// This is the key change: check this condition before others like has_data_context
|
||||
if !searched_catalog {
|
||||
return AgentState::DataCatalogSearch;
|
||||
}
|
||||
|
||||
// 4. If we have context but no plan, plan
|
||||
if has_data_context && !has_plan {
|
||||
return AgentState::Planning;
|
||||
}
|
||||
|
||||
// 5. If we have context and a plan, execute analysis
|
||||
if has_data_context && has_plan {
|
||||
return AgentState::AnalysisExecution;
|
||||
}
|
||||
|
||||
// 6. Fallback: If the state is ambiguous after searching and without needing review
|
||||
// (e.g., search happened but no context was added, or no plan needed).
|
||||
// Revert to an earlier appropriate state.
|
||||
if is_follow_up {
|
||||
// If it was a follow-up, perhaps return to follow-up init or planning?
|
||||
// Let's choose FollowUpInitialization as a safe default if planning/analysis aren't ready.
|
||||
AgentState::FollowUpInitialization
|
||||
} else {
|
||||
// Initial conversation flow
|
||||
if !has_user_prompt { // If this is the very start *before* user speaks
|
||||
return AgentState::Initializing;
|
||||
}
|
||||
// Now assume user has spoken
|
||||
if needs_review { AgentState::Review }
|
||||
else if !searched_catalog { AgentState::DataCatalogSearch }
|
||||
else if has_data_context && !has_plan { AgentState::Planning }
|
||||
else if has_data_context && has_plan { AgentState::AnalysisExecution }
|
||||
else { AgentState::Initializing } // Fallback for initial if state is unclear
|
||||
// If it was initial, perhaps return to init or planning?
|
||||
// Let's choose Initializing as a safe default if planning/analysis aren't ready.
|
||||
AgentState::Initializing
|
||||
}
|
||||
|
||||
// Original logic kept for reference:
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::Agent;
|
||||
use crate::tools::ToolExecutor; // For get_name()
|
||||
use crate::tools::ToolExecutor;
|
||||
use crate::Agent; // For get_name()
|
||||
|
||||
// Import necessary types from the parent module (modes/mod.rs)
|
||||
use super::{ModeAgentData, ModeConfiguration};
|
||||
|
@ -20,7 +20,6 @@ use crate::tools::{
|
|||
IntoToolCallExecutor,
|
||||
};
|
||||
|
||||
|
||||
// Function to get the configuration for the Planning mode
|
||||
pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
||||
// 1. Get the prompt, formatted with current data
|
||||
|
@ -29,70 +28,64 @@ pub fn get_configuration(agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
.replace("{DATASETS}", &agent_data.dataset_names.join(", "));
|
||||
|
||||
// 2. Define the model for this mode (Using default based on original MODEL = None)
|
||||
let model = "o3-mini".to_string();
|
||||
let model = "o4-mini".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
let tool_loader: Box<
|
||||
dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
|
||||
> = Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
|
||||
// Instantiate tools for this mode
|
||||
let create_plan_straightforward_tool = CreatePlanStraightforward::new(agent_clone.clone());
|
||||
let create_plan_investigative_tool = CreatePlanInvestigative::new(agent_clone.clone());
|
||||
let message_user_clarifying_question_tool = MessageUserClarifyingQuestion::new();
|
||||
let done_tool = Done::new(agent_clone.clone());
|
||||
// Instantiate tools for this mode
|
||||
let create_plan_straightforward_tool =
|
||||
CreatePlanStraightforward::new(agent_clone.clone());
|
||||
let create_plan_investigative_tool = CreatePlanInvestigative::new(agent_clone.clone());
|
||||
let done_tool = Done::new(agent_clone.clone());
|
||||
|
||||
// Condition (always true for this mode's tools)
|
||||
let condition = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
// Condition (always true for this mode's tools)
|
||||
let condition = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
|
||||
// Add tools to the agent
|
||||
agent_clone.add_tool(
|
||||
// Add tools to the agent
|
||||
agent_clone
|
||||
.add_tool(
|
||||
create_plan_straightforward_tool.get_name(),
|
||||
create_plan_straightforward_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
agent_clone
|
||||
.add_tool(
|
||||
create_plan_investigative_tool.get_name(),
|
||||
create_plan_investigative_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
message_user_clarifying_question_tool.get_name(),
|
||||
message_user_clarifying_question_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
agent_clone
|
||||
.add_tool(
|
||||
done_tool.get_name(),
|
||||
done_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
// 4. Define terminating tools for this mode (From original load_tools)
|
||||
let terminating_tools = vec![
|
||||
"message_user_clarifying_question".to_string(), // Hardcoded name
|
||||
"finish_and_respond".to_string(), // Hardcoded name for Done tool
|
||||
];
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
// 5. Construct and return the ModeConfiguration
|
||||
ModeConfiguration {
|
||||
prompt,
|
||||
model,
|
||||
tool_loader,
|
||||
terminating_tools,
|
||||
terminating_tools: vec![Done::get_name()],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Keep the prompt constant, but it's no longer pub
|
||||
const PLANNING_PROMPT: &str = r##"## Overview
|
||||
|
||||
|
@ -394,4 +387,3 @@ By following these guidelines, you can ensure that the visualizations you create
|
|||
##Available Datasets:
|
||||
{DATASETS}
|
||||
"##;
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::tools::ToolExecutor;
|
||||
use crate::Agent; // For get_name()
|
||||
|
@ -27,49 +27,44 @@ pub fn get_configuration(_agent_data: &ModeAgentData) -> ModeConfiguration {
|
|||
let model = "gemini-2.0-flash-001".to_string();
|
||||
|
||||
// 3. Define the tool loader closure
|
||||
let tool_loader: Box<dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync> =
|
||||
Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
let tool_loader: Box<
|
||||
dyn Fn(&Arc<Agent>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
|
||||
> = Box::new(|agent_arc: &Arc<Agent>| {
|
||||
let agent_clone = Arc::clone(agent_arc); // Clone Arc for the async block
|
||||
Box::pin(async move {
|
||||
// Clear existing tools before loading mode-specific ones
|
||||
agent_clone.clear_tools().await;
|
||||
|
||||
// Instantiate tools for this mode
|
||||
let review_tool = ReviewPlan::new(agent_clone.clone());
|
||||
let message_user_clarifying_question_tool = MessageUserClarifyingQuestion::new();
|
||||
let done_tool = Done::new(agent_clone.clone());
|
||||
// Instantiate tools for this mode
|
||||
let review_tool = ReviewPlan::new(agent_clone.clone());
|
||||
let done_tool = Done::new(agent_clone.clone());
|
||||
|
||||
// Condition (always true for this mode's tools)
|
||||
let condition = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
// Condition (always true for this mode's tools)
|
||||
let condition = Some(|_state: &HashMap<String, Value>| -> bool { true });
|
||||
|
||||
// Add tools to the agent
|
||||
agent_clone.add_tool(
|
||||
// Add tools to the agent
|
||||
agent_clone
|
||||
.add_tool(
|
||||
review_tool.get_name(),
|
||||
review_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
message_user_clarifying_question_tool.get_name(),
|
||||
message_user_clarifying_question_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
|
||||
agent_clone.add_tool(
|
||||
agent_clone
|
||||
.add_tool(
|
||||
done_tool.get_name(),
|
||||
done_tool.into_tool_call_executor(),
|
||||
condition.clone(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
// 4. Define terminating tools for this mode (From original load_tools)
|
||||
let terminating_tools = vec![
|
||||
"message_user_clarifying_question".to_string(), // Hardcoded name
|
||||
"finish_and_respond".to_string(), // Hardcoded name for Done tool
|
||||
];
|
||||
let terminating_tools = vec![Done::get_name()];
|
||||
|
||||
// 5. Construct and return the ModeConfiguration
|
||||
ModeConfiguration {
|
||||
|
|
|
@ -304,13 +304,15 @@ impl ToolExecutor for CreateDashboardFilesTool {
|
|||
|
||||
let duration = start_time.elapsed().as_millis() as i64;
|
||||
|
||||
self.agent
|
||||
.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
self.agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
if !created_files.is_empty() {
|
||||
self.agent
|
||||
.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
self.agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
}
|
||||
|
||||
// Set review_needed flag if execution was successful
|
||||
if failed_files.is_empty() {
|
||||
|
|
|
@ -235,14 +235,16 @@ impl ToolExecutor for CreateMetricFilesTool {
|
|||
|
||||
let duration = start_time.elapsed().as_millis() as i64;
|
||||
|
||||
self.agent
|
||||
.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
.await;
|
||||
if !created_files.is_empty() {
|
||||
self.agent
|
||||
.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
self.agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
}
|
||||
|
||||
self.agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
// Set review_needed flag if execution was successful
|
||||
if failed_files.is_empty() {
|
||||
self.agent
|
||||
|
|
|
@ -673,7 +673,7 @@ mod tests {
|
|||
fn test_tool_parameter_validation() {
|
||||
let tool = FilterDashboardsTool {
|
||||
agent: Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
"o4-mini".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
|
|
|
@ -25,15 +25,38 @@ pub async fn generate_todos_from_plan(
|
|||
|
||||
let prompt = format!(
|
||||
r#"
|
||||
Given the following plan, extract the main actionable steps and return them as a JSON list of concise todo strings. Focus on the core actions described in each step. Do not include any introductory text, summary, or review steps. Only include the main tasks to be performed.
|
||||
Given the following plan, identify the main high-level objects (e.g., dashboards, visualizations) being created or modified. Return these as a JSON list of descriptive todo strings. Each todo item should summarize the primary creation or modification goal for one object.
|
||||
|
||||
**IMPORTANT**: Do not include granular implementation steps (like adding specific filters or fields), review steps, verification steps, summarization steps, or steps about responding to the user. Focus solely on the final artifact being built or changed.
|
||||
|
||||
Plan:
|
||||
"""
|
||||
{}
|
||||
"""
|
||||
|
||||
Return ONLY a valid JSON array of strings, where each string is a short todo item corresponding to a main step in the plan.
|
||||
Example format: `["Create 11 visualizations", "Create dashboard"]`
|
||||
Return ONLY a valid JSON array of strings, where each string is a descriptive todo item corresponding to a main object being created or modified in the plan.
|
||||
|
||||
Example Plan:
|
||||
**Thought**
|
||||
The user wants to see the daily transaction volume trend over the past month.
|
||||
I'll sum the transaction amounts per day from the `transactions` dataset, filtering for the last 30 days.
|
||||
I will present this as a line chart.
|
||||
|
||||
**Step-by-Step Plan**
|
||||
1. **Create 1 Visualization**:
|
||||
- **Title**: "Daily Transaction Volume (Last 30 Days)"
|
||||
- **Type**: Line Chart
|
||||
- **Datasets**: transactions
|
||||
- **X-Axis**: Day (from transaction timestamp)
|
||||
- **Y-Axis**: Sum of transaction amount
|
||||
- **Filter**: Transaction timestamp within the last 30 days.
|
||||
- **Expected Output**: A line chart showing the total transaction volume for each day over the past 30 days.
|
||||
|
||||
2. **Review & Finish**:
|
||||
- Verify the date filter correctly captures the last 30 days.
|
||||
- Ensure the axes are labeled clearly.
|
||||
|
||||
Example Output for the above plan: `["Create line chart visualization 'Daily Transaction Volume (Last 30 Days)'"]`
|
||||
"#,
|
||||
plan
|
||||
);
|
||||
|
|
|
@ -26,6 +26,10 @@ impl Done {
|
|||
pub fn new(agent: Arc<Agent>) -> Self {
|
||||
Self { agent }
|
||||
}
|
||||
|
||||
pub fn get_name() -> String {
|
||||
"done".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
@ -43,7 +47,10 @@ impl ToolExecutor for Done {
|
|||
Some(Value::Array(arr)) => arr,
|
||||
_ => {
|
||||
// If no todos exist, just return success without a list
|
||||
return Ok(DoneOutput { success: true, todos: "No to-do list found.".to_string() });
|
||||
return Ok(DoneOutput {
|
||||
success: true,
|
||||
todos: "No to-do list found.".to_string(),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -52,28 +59,44 @@ impl ToolExecutor for Done {
|
|||
// Mark all remaining unfinished todos as complete
|
||||
for (idx, todo_val) in todos.iter_mut().enumerate() {
|
||||
if let Value::Object(map) = todo_val {
|
||||
let is_completed = map.get("completed").and_then(Value::as_bool).unwrap_or(false);
|
||||
let is_completed = map
|
||||
.get("completed")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(false);
|
||||
if !is_completed {
|
||||
map.insert("completed".to_string(), Value::Bool(true));
|
||||
marked_by_done.push(idx); // Track 0-based index
|
||||
}
|
||||
} else {
|
||||
// Handle invalid item format if necessary, maybe log a warning?
|
||||
eprintln!("Warning: Invalid todo item format at index {}", idx);
|
||||
// Handle invalid item format if necessary, maybe log a warning?
|
||||
eprintln!("Warning: Invalid todo item format at index {}", idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Save the updated todos back to state
|
||||
self.agent.set_state_value("todos".to_string(), Value::Array(todos.clone())).await; // Clone needed for iteration below
|
||||
|
||||
self.agent
|
||||
.set_state_value("todos".to_string(), Value::Array(todos.clone()))
|
||||
.await; // Clone needed for iteration below
|
||||
|
||||
// Format the output string, potentially noting items marked by 'done'
|
||||
let todos_string = todos.iter().enumerate()
|
||||
let todos_string = todos
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, todo_val)| {
|
||||
if let Value::Object(map) = todo_val {
|
||||
let completed = map.get("completed").and_then(Value::as_bool).unwrap_or(false); // Should always be true now
|
||||
let todo_text = map.get("todo").and_then(Value::as_str).unwrap_or("Invalid todo text");
|
||||
let annotation = if marked_by_done.contains(&idx) { " *Marked complete by calling the done tool" } else { "" };
|
||||
let completed = map
|
||||
.get("completed")
|
||||
.and_then(Value::as_bool)
|
||||
.unwrap_or(false); // Should always be true now
|
||||
let todo_text = map
|
||||
.get("todo")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("Invalid todo text");
|
||||
let annotation = if marked_by_done.contains(&idx) {
|
||||
" *Marked complete by calling the done tool"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("[x] {}{}", todo_text, annotation)
|
||||
} else {
|
||||
"Invalid todo item format".to_string()
|
||||
|
@ -82,16 +105,18 @@ impl ToolExecutor for Done {
|
|||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
|
||||
// This tool signals the end of the workflow and provides the final response.
|
||||
// The actual agent termination logic resides elsewhere.
|
||||
Ok(DoneOutput { success: true, todos: todos_string }) // Include todos in output
|
||||
Ok(DoneOutput {
|
||||
success: true,
|
||||
todos: todos_string,
|
||||
}) // Include todos in output
|
||||
}
|
||||
|
||||
async fn get_schema(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"name": self.get_name(),
|
||||
"description": "Marks all remaining unfinished tasks as complete, sends a final response to the user, and ends the workflow. Use this when the workflow is finished.",
|
||||
"description": "Marks all remaining unfinished tasks as complete, sends a final response to the user, and ends the workflow. Use this when the workflow is finished. This must be in markdown format and not use the '•' bullet character.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -100,7 +125,7 @@ impl ToolExecutor for Done {
|
|||
"properties": {
|
||||
"final_response": {
|
||||
"type": "string",
|
||||
"description": "The final response message to the user. **MUST** be formatted in Markdown. Use bullet points or other appropriate Markdown formatting. Do not include headers."
|
||||
"description": "The final response message to the user. **MUST** be formatted in Markdown. Use bullet points or other appropriate Markdown formatting. Do not include headers. Do not use the '•' bullet character."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
|
|
|
@ -22,6 +22,10 @@ impl MessageUserClarifyingQuestion {
|
|||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn get_name() -> String {
|
||||
"message_user_clarifying_question".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
Loading…
Reference in New Issue