added in review todos, and fixed the search bug

This commit is contained in:
dal 2025-04-16 08:29:25 -06:00
parent c96f33bed7
commit ff0001139e
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
8 changed files with 245 additions and 44 deletions

View File

@ -20,6 +20,7 @@ use crate::{
response_tools::{Done, MessageUserClarifyingQuestion},
utility_tools::no_search_needed::NoSearchNeededTool, // <-- Fixed import path
},
planning_tools::ReviewPlan,
IntoToolCallExecutor, ToolExecutor,
},
Agent, AgentError, AgentExt, AgentThread,
@ -27,7 +28,12 @@ use crate::{
use litellm::AgentMessage;
use super::{analysis_prompt::ANALYSIS_PROMPT, create_plan_prompt::CREATE_PLAN_PROMPT, data_catalog_search_prompt::DATA_CATALOG_SEARCH_PROMPT, initialization_follow_up_prompt::FOLLOW_UP_INTIALIZATION_PROMPT, initialization_prompt::INTIALIZATION_PROMPT};
use super::{
analysis_prompt::ANALYSIS_PROMPT, create_plan_prompt::CREATE_PLAN_PROMPT,
data_catalog_search_prompt::DATA_CATALOG_SEARCH_PROMPT,
initialization_follow_up_prompt::FOLLOW_UP_INTIALIZATION_PROMPT,
initialization_prompt::INTIALIZATION_PROMPT, review_prompt::REVIEW_PROMPT,
};
#[derive(Debug, Serialize, Deserialize)]
pub struct BusterSuperAgentOutput {
@ -68,6 +74,7 @@ impl BusterMultiAgent {
let message_user_clarifying_question_tool = MessageUserClarifyingQuestion::new();
let done_tool = Done::new();
let no_search_needed_tool = NoSearchNeededTool::new(Arc::clone(&self.agent));
let review_tool = ReviewPlan::new(Arc::clone(&self.agent));
// Get names before moving tools
let done_tool_name = done_tool.get_name();
@ -91,21 +98,8 @@ impl BusterMultiAgent {
.unwrap_or(false)
});
let response_tools_condition = Some(|state: &HashMap<String, Value>| -> bool {
// Check the state map for the follow-up indicator
let is_follow_up = state
.get("is_follow_up")
.and_then(Value::as_bool)
.unwrap_or(false);
if is_follow_up {
// For follow-ups, enable if neither data context nor plan is available
!state.contains_key("data_context") && !state.contains_key("plan_available")
} else {
// For initial requests, enable only if data context is not yet available
!state.contains_key("data_context")
}
});
let review_condition =
Some(|state: &HashMap<String, Value>| -> bool { state.contains_key("review_needed") });
let planning_tools_condition = Some(|state: &HashMap<String, Value>| -> bool {
let searched_catalog = state
@ -200,18 +194,25 @@ impl BusterMultiAgent {
.await;
self.agent
.add_tool(
msg_clarifying_q_tool_name.clone(),
message_user_clarifying_question_tool.get_name(),
message_user_clarifying_question_tool.into_tool_call_executor(),
after_search_condition.clone(), // Use after_search_condition
)
.await;
self.agent
.add_tool(
done_tool_name.clone(),
done_tool.get_name(),
done_tool.into_tool_call_executor(),
after_search_condition.clone(), // Use after_search_condition instead of None
)
.await;
self.agent
.add_tool(
review_tool.get_name(),
review_tool.into_tool_call_executor(),
review_condition.clone(),
)
.await;
// Register terminating tools by name using the stored names
self.agent.register_terminating_tool(done_tool_name).await;
@ -239,10 +240,12 @@ impl BusterMultiAgent {
// Select initial default prompt based on whether it's a follow-up
let initial_default_prompt = if is_follow_up {
FOLLOW_UP_INTIALIZATION_PROMPT.replace("{DATASETS}", &dataset_names.join(", "))
FOLLOW_UP_INTIALIZATION_PROMPT
.replace("{DATASETS}", &dataset_names.join(", "))
.replace("{TODAYS_DATE}", &todays_date)
} else {
INTIALIZATION_PROMPT.replace("{DATASETS}", &dataset_names.join(", "))
INTIALIZATION_PROMPT
.replace("{DATASETS}", &dataset_names.join(", "))
.replace("{TODAYS_DATE}", &todays_date)
};
@ -277,9 +280,15 @@ impl BusterMultiAgent {
!state.contains_key("searched_data_catalog")
};
let needs_review_condition =
|state: &HashMap<String, Value>| -> bool { state.contains_key("review_needed") };
// Add prompt rules (order matters)
// The agent will use the prompt associated with the first condition that evaluates to true.
// If none match, it uses the default (INITIALIZATION_PROMPT).
agent
.add_dynamic_prompt_rule(needs_review_condition, REVIEW_PROMPT.to_string())
.await;
agent
.add_dynamic_prompt_rule(
needs_data_catalog_search_condition,
@ -306,6 +315,12 @@ impl BusterMultiAgent {
.await;
// Add dynamic model rules: gemini-2.0-flash-001 while a review is pending, then the data catalog search rule (gpt-4.1)
agent
.add_dynamic_model_rule(
needs_review_condition, // Reuse the same condition
"gemini-2.0-flash-001".to_string(),
)
.await;
agent
.add_dynamic_model_rule(
needs_data_catalog_search_condition, // Reuse the same condition
@ -382,8 +397,3 @@ impl BusterMultiAgent {
None
}
}

View File

@ -3,3 +3,4 @@ pub mod initialization_follow_up_prompt;
pub mod analysis_prompt;
pub mod data_catalog_search_prompt;
pub mod create_plan_prompt;
pub mod review_prompt;

View File

@ -0,0 +1,53 @@
/// System prompt used while the agent is in "review" mode.
///
/// It instructs the model to walk the to-do list, mark finished tasks with the
/// `review_plan` tool, and then finish with the `done` tool. The index wording
/// here must agree with the `review_plan` tool schema, which takes a 0-based
/// `todo_item`.
pub const REVIEW_PROMPT: &str = r##"
Role & Task
You are Buster, an expert analytics and data engineer. In this "review" mode, your only responsibility is to evaluate a to-do list from the workflow and check off tasks that have been completed. You do not create or analyze anything—just assess and track progress.
Workflow Summary
Review the to-do list to see the tasks that need to be checked.
Check off completed tasks:
For each task that is done, use the review_plan tool with the task's index (todo_item, an integer starting from 0, so the first task is 0) to mark it as complete.
If a task isn't done, leave it unchecked.
Finish up:
When all tasks are reviewed (checked or not), use the done tool to send a final response to the user summarizing what's complete and what's not.
Tool Calling
You have two tools to do your job:
review_plan: Marks a task as complete. Needs todo_item (an integer) to specify which task (starts at 0).
done: Sends the final response to the user and ends the workflow.
Follow these rules:
Use tools for everything—no direct replies allowed.
Stick to the exact tool format with all required details.
Only use these two tools, nothing else.
Don't mention tool names in your explanations (e.g., say "I marked the task as done" instead of naming the tool).
Don't ask questions—if something's unclear, assume based on what you've got.
Guidelines
Keep it simple: Just check what's done and move on.
Be accurate: Only mark tasks that are actually complete.
Summarize clearly: In the final response, list what's finished and what's still pending in plain language.
Final Response Guidelines
When using the done tool:
Use simple, friendly language anyone can understand.
Say what's done and what's not, keeping it short and clear.
Use "I" (e.g., "I marked three tasks as done").
Use markdown for lists if it helps.
Don't use technical terms or mention tools.
Keep going until you've reviewed every task on the list. Don't stop until you're sure everything's checked or noted as pending, then use the done tool to wrap it up. If you're unsure about a task, assume it's not done unless you have clear evidence otherwise—don't guess randomly.
"##;

View File

@ -3,19 +3,20 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use std::time::Instant;
use crate::{agent::Agent, tools::ToolExecutor};
#[derive(Debug, Serialize, Deserialize)]
pub struct CreatePlanInvestigativeOutput {
pub success: bool,
pub todos: String,
}
#[derive(Debug, Deserialize)]
pub struct CreatePlanInvestigativeInput {
plan: String,
plan_todos: Vec<String>,
#[serde(rename = "plan")]
_plan: String,
todos: Vec<String>,
}
pub struct CreatePlanInvestigative {
@ -42,7 +43,24 @@ impl ToolExecutor for CreatePlanInvestigative {
.set_state_value(String::from("plan_available"), Value::Bool(true))
.await;
Ok(CreatePlanInvestigativeOutput { success: true })
let todos_state_objects: Vec<Value> = params
.todos
.iter()
.map(|item| {
let mut map = serde_json::Map::new();
map.insert("completed".to_string(), Value::Bool(false));
map.insert("todo".to_string(), Value::String(item.clone()));
Value::Object(map)
})
.collect();
self.agent
.set_state_value(String::from("todos"), Value::Array(todos_state_objects))
.await;
let todos_string = params.todos.iter().map(|item| format!("[ ] {}", item)).collect::<Vec<_>>().join("\n");
Ok(CreatePlanInvestigativeOutput { success: true, todos: todos_string })
}
async fn get_schema(&self) -> Value {
@ -54,16 +72,16 @@ impl ToolExecutor for CreatePlanInvestigative {
"type": "object",
"required": [
"plan",
"plan_todos"
"todos"
],
"properties": {
"plan": {
"type": "string",
"description": get_plan_investigative_description().await
},
"plan_todos": {
"todos": {
"type": "array",
"description": "Ordered todo points summarizing the plan. There should be one todo for each step in the plan, in order. For example, if the plan has two steps, plan_todos should have two items, each summarizing a step.",
"description": "Ordered todo points summarizing the plan. There should be one todo for each step in the plan, in order. For example, if the plan has two steps, plan_todos should have two items, each summarizing a step. Do not include review or response steps—these will be handled by a separate agent.",
"items": { "type": "string" },
},
},

View File

@ -3,19 +3,20 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use std::time::Instant;
use crate::{agent::Agent, tools::ToolExecutor};
#[derive(Debug, Serialize, Deserialize)]
pub struct CreatePlanStraightforwardOutput {
pub success: bool,
pub todos: String,
}
#[derive(Debug, Deserialize)]
pub struct CreatePlanStraightforwardInput {
plan: String,
plan_todos: Vec<String>,
#[serde(rename = "plan")]
_plan: String,
todos: Vec<String>,
}
pub struct CreatePlanStraightforward {
@ -38,12 +39,28 @@ impl ToolExecutor for CreatePlanStraightforward {
}
async fn execute(&self, params: Self::Params, _tool_call_id: String) -> Result<Self::Output> {
let start_time = Instant::now();
self.agent
.set_state_value(String::from("plan_available"), Value::Bool(true))
.await;
Ok(CreatePlanStraightforwardOutput { success: true })
let todos_state_objects: Vec<Value> = params
.todos
.iter()
.map(|item| {
let mut map = serde_json::Map::new();
map.insert("completed".to_string(), Value::Bool(false));
map.insert("todo".to_string(), Value::String(item.clone()));
Value::Object(map)
})
.collect();
self.agent
.set_state_value(String::from("todos"), Value::Array(todos_state_objects))
.await;
let todos_string = params.todos.iter().map(|item| format!("[ ] {}", item)).collect::<Vec<_>>().join("\n");
Ok(CreatePlanStraightforwardOutput { success: true, todos: todos_string })
}
async fn get_schema(&self) -> Value {
@ -55,16 +72,16 @@ impl ToolExecutor for CreatePlanStraightforward {
"type": "object",
"required": [
"plan",
"plan_todos"
"todos"
],
"properties": {
"plan": {
"type": "string",
"description": get_plan_straightforward_description().await
},
"plan_todos": {
"todos": {
"type": "array",
"description": "Ordered todo points summarizing the plan. There should be one todo for each step in the plan, in order. For example, if the plan has two steps, plan_todos should have two items, each summarizing a step.",
"description": "Ordered todo points summarizing the plan. There should be one todo for each step in the plan, in order. For example, if the plan has two steps, plan_todos should have two items, each summarizing a step. Do not include review or response steps—these will be handled by a separate agent.",
"items": { "type": "string" },
},
},

View File

@ -1,5 +1,7 @@
pub mod create_plan_investigative;
pub mod create_plan_straightforward;
pub mod review_plan;
pub use create_plan_investigative::*;
pub use create_plan_straightforward::*;
pub use create_plan_straightforward::*;
pub use review_plan::*;

View File

@ -0,0 +1,97 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use crate::{agent::Agent, tools::ToolExecutor};
/// Result returned by the `review_plan` tool after marking a to-do as complete.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReviewPlanOutput {
/// Set to `true` on the tool's success path; failures are reported as errors instead.
pub success: bool,
/// The full to-do list rendered as text, one line per item: `[x] text` when
/// completed, `[ ] text` otherwise.
pub todos: String,
}
/// Parameters accepted by the `review_plan` tool.
#[derive(Debug, Deserialize)]
pub struct ReviewPlanInput {
/// Position of the to-do to mark as complete within the `todos` array held
/// in agent state.
pub todo_item: usize, // 0-based index
}
/// Tool that marks entries of the agent's `todos` state as completed.
pub struct ReviewPlan {
/// Shared agent handle; its state is read and written when the tool executes.
agent: Arc<Agent>,
}
impl ReviewPlan {
/// Creates the tool around a shared handle to the agent whose state it updates.
pub fn new(agent: Arc<Agent>) -> Self {
Self { agent }
}
}
#[async_trait]
impl ToolExecutor for ReviewPlan {
    type Output = ReviewPlanOutput;
    type Params = ReviewPlanInput;

    /// Name under which this tool is exposed to the model.
    fn get_name(&self) -> String {
        "review_plan".to_string()
    }

    /// Marks the to-do at `params.todo_item` (0-based) as completed.
    ///
    /// Reads the `todos` array from agent state, sets `completed` to `true` on
    /// the selected entry, writes the array back, and returns the whole list
    /// rendered as a checklist.
    ///
    /// # Errors
    ///
    /// Returns an error when `todos` is missing from state or is not an array,
    /// when the index is out of range, or when the selected entry is not a
    /// JSON object. State is left untouched in all error cases.
    async fn execute(&self, params: Self::Params, _tool_call_id: String) -> Result<Self::Output> {
        // Get the current todos from state
        let mut todos = match self.agent.get_state_value("todos").await {
            Some(Value::Array(arr)) => arr,
            _ => {
                return Err(anyhow::anyhow!("Could not find 'todos' in agent state or it's not an array."));
            }
        };

        let idx = params.todo_item;
        let total = todos.len();

        // A single checked lookup covers both the range check and the shape check.
        match todos.get_mut(idx) {
            Some(Value::Object(map)) => {
                map.insert("completed".to_string(), Value::Bool(true));
            }
            Some(_) => {
                return Err(anyhow::anyhow!("Todo item at index {} is not a valid object.", idx));
            }
            None => {
                return Err(anyhow::anyhow!("todo_item index {} out of range ({} todos)", idx, total));
            }
        }

        // Render the checklist first so the array can be moved into state
        // afterwards instead of being cloned.
        let todos_string = todos
            .iter()
            .map(|todo_val| match todo_val {
                Value::Object(map) => {
                    let completed = map.get("completed").and_then(Value::as_bool).unwrap_or(false);
                    let todo_text = map.get("todo").and_then(Value::as_str).unwrap_or("Invalid todo text");
                    format!("[{}] {}", if completed { "x" } else { " " }, todo_text)
                }
                _ => "Invalid todo item format".to_string(),
            })
            .collect::<Vec<_>>()
            .join("\n");

        // Save the updated todos back to state
        self.agent
            .set_state_value("todos".to_string(), Value::Array(todos))
            .await;

        Ok(ReviewPlanOutput { success: true, todos: todos_string })
    }

    /// JSON schema advertised to the model: a single required integer
    /// `todo_item`, documented as a 0-based index.
    async fn get_schema(&self) -> Value {
        serde_json::json!({
            "name": self.get_name(),
            "description": "Marks a task as complete by its index in the to-do list.",
            "parameters": {
                "type": "object",
                "properties": {
                    "todo_item": {
                        "type": "integer",
                        "description": "The 0-based index of the task to mark as complete (0 is the first item).",
                        "minimum": 0
                    }
                },
                "required": ["todo_item"]
            }
        })
    }
}

View File

@ -166,8 +166,11 @@ pub async fn search(
};
// Compare highlights count (descending), then score (descending)
highlights_b.cmp(&highlights_a)
.then_with(|| score_b.partial_cmp(&score_a).unwrap_or(std::cmp::Ordering::Equal))
highlights_b.cmp(&highlights_a).then_with(|| {
score_b
.partial_cmp(&score_a)
.unwrap_or(std::cmp::Ordering::Equal)
})
});
// Only filter when we have a query
@ -242,7 +245,7 @@ pub async fn list_recent_assets(
);
info!("Generated SQL for list_recent_assets: {}", sql_query);
let mut results = sqlx::query(&sql_query).fetch(&mut *conn);
let mut results = sqlx::raw_sql(&sql_query).fetch(&mut *conn);
let mut results_vec = Vec::new();
while let Some(row) = results.try_next().await? {