mirror of https://github.com/buster-so/buster.git
conditional prompt switching
This commit is contained in:
parent
e147711a23
commit
f46376eac0
|
@ -132,6 +132,12 @@ struct RegisteredTool {
|
|||
enablement_condition: Option<Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>>,
|
||||
}
|
||||
|
||||
// Helper struct for dynamic prompt rules
|
||||
struct DynamicPromptRule {
|
||||
condition: Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>,
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
// Update the ToolRegistry type alias is no longer needed, but we need the new type for the map
|
||||
type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
|
||||
|
||||
|
@ -161,18 +167,22 @@ pub struct Agent {
|
|||
name: String,
|
||||
/// Shutdown signal sender
|
||||
shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
|
||||
/// Default system prompt if no dynamic rules match
|
||||
default_prompt: String,
|
||||
/// Ordered rules for dynamically selecting system prompts
|
||||
dynamic_prompt_rules: Arc<RwLock<Vec<DynamicPromptRule>>>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
/// Create a new Agent instance with a specific LLM client and model
|
||||
pub fn new(
|
||||
model: String,
|
||||
// Note: tools argument is removed as they are added via add_tool now
|
||||
user_id: Uuid,
|
||||
session_id: Uuid,
|
||||
name: String,
|
||||
api_key: Option<String>,
|
||||
base_url: Option<String>,
|
||||
default_prompt: String,
|
||||
) -> Self {
|
||||
let llm_client = LiteLLMClient::new(api_key, base_url);
|
||||
|
||||
|
@ -192,6 +202,8 @@ impl Agent {
|
|||
session_id,
|
||||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||
name,
|
||||
default_prompt,
|
||||
dynamic_prompt_rules: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -213,6 +225,8 @@ impl Agent {
|
|||
session_id: existing_agent.session_id,
|
||||
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown
|
||||
name,
|
||||
default_prompt: existing_agent.default_prompt.clone(),
|
||||
dynamic_prompt_rules: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -538,6 +552,19 @@ impl Agent {
|
|||
return Ok(()); // Don't return error, just stop processing
|
||||
}
|
||||
|
||||
// --- Dynamic Prompt Selection ---
|
||||
let current_system_prompt = self.get_current_prompt().await;
|
||||
let system_message = AgentMessage::developer(current_system_prompt);
|
||||
|
||||
// Prepare messages for LLM: Inject current system prompt and filter out old ones
|
||||
let mut llm_messages = vec![system_message];
|
||||
llm_messages.extend(
|
||||
thread.messages.iter()
|
||||
.filter(|msg| !matches!(msg, AgentMessage::Developer { .. }))
|
||||
.cloned()
|
||||
);
|
||||
// --- End Dynamic Prompt Selection ---
|
||||
|
||||
// Collect all enabled tools and their schemas
|
||||
let tools = self.get_enabled_tools().await; // Now uses the new logic
|
||||
|
||||
|
@ -549,7 +576,7 @@ impl Agent {
|
|||
// Create the tool-enabled request
|
||||
let request = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: thread.messages.clone(),
|
||||
messages: llm_messages, // Use the dynamically prepared messages list
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
stream: Some(true), // Enable streaming
|
||||
|
@ -982,6 +1009,37 @@ impl Agent {
|
|||
let mut tx = self.stream_tx.write().await;
|
||||
*tx = None;
|
||||
}
|
||||
|
||||
/// Add a rule for dynamically selecting a system prompt.
|
||||
/// Rules are checked in the order they are added. The first matching rule's prompt is used.
|
||||
pub async fn add_dynamic_prompt_rule<F>(
|
||||
&self,
|
||||
condition: F,
|
||||
prompt: String,
|
||||
)
|
||||
where
|
||||
F: Fn(&HashMap<String, Value>) -> bool + Send + Sync + 'static,
|
||||
{
|
||||
let rule = DynamicPromptRule {
|
||||
condition: Box::new(condition),
|
||||
prompt,
|
||||
};
|
||||
self.dynamic_prompt_rules.write().await.push(rule);
|
||||
}
|
||||
|
||||
/// Gets the system prompt based on the current agent state and dynamic rules.
|
||||
async fn get_current_prompt(&self) -> String {
|
||||
let rules = self.dynamic_prompt_rules.read().await;
|
||||
let state = self.state.read().await;
|
||||
|
||||
for rule in rules.iter() {
|
||||
if (rule.condition)(&state) {
|
||||
return rule.prompt.clone(); // Return the first matching rule's prompt
|
||||
}
|
||||
}
|
||||
|
||||
self.default_prompt.clone() // Fallback to default prompt if no rules match
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
|
@ -1178,6 +1236,7 @@ mod tests {
|
|||
"test_agent_no_tools".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
"".to_string(),
|
||||
));
|
||||
|
||||
let thread = AgentThread::new(
|
||||
|
@ -1207,6 +1266,7 @@ mod tests {
|
|||
"test_agent_with_tools".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
"".to_string(),
|
||||
));
|
||||
|
||||
// Create weather tool with reference to agent
|
||||
|
@ -1246,6 +1306,7 @@ mod tests {
|
|||
"test_agent_multi_step".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
"".to_string(),
|
||||
));
|
||||
|
||||
let weather_tool = WeatherTool::new(Arc::clone(&agent));
|
||||
|
@ -1284,6 +1345,7 @@ mod tests {
|
|||
"test_agent_disabled".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
"".to_string(),
|
||||
));
|
||||
|
||||
// Create weather tool
|
||||
|
@ -1348,6 +1410,7 @@ mod tests {
|
|||
"test_agent_state".to_string(),
|
||||
env::var("LLM_API_KEY").ok(),
|
||||
env::var("LLM_BASE_URL").ok(),
|
||||
"".to_string(),
|
||||
));
|
||||
|
||||
// Test setting single values
|
||||
|
|
|
@ -22,6 +22,9 @@ use crate::{
|
|||
|
||||
use litellm::AgentMessage;
|
||||
|
||||
// Type alias for the enablement condition closure for tools
|
||||
type ToolEnablementCondition = Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BusterSuperAgentOutput {
|
||||
pub message: String,
|
||||
|
@ -128,7 +131,7 @@ impl BusterMultiAgent {
|
|||
}
|
||||
|
||||
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
|
||||
// Create agent (Agent::new no longer takes tools directly)
|
||||
// Create agent, passing the initialization prompt as default
|
||||
let agent = Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
user_id,
|
||||
|
@ -136,8 +139,26 @@ impl BusterMultiAgent {
|
|||
"buster_super_agent".to_string(),
|
||||
None,
|
||||
None,
|
||||
INTIALIZATION_PROMPT.to_string(), // Default prompt
|
||||
));
|
||||
|
||||
// Define prompt switching conditions
|
||||
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && !state.contains_key("plan_available")
|
||||
};
|
||||
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
|
||||
// Example: Trigger analysis prompt once plan is available and metrics/dashboards are not yet available
|
||||
state.contains_key("plan_available")
|
||||
&& !state.contains_key("metrics_available")
|
||||
&& !state.contains_key("dashboards_available")
|
||||
};
|
||||
|
||||
// Add prompt rules (order matters)
|
||||
// The agent will use the prompt associated with the first condition that evaluates to true.
|
||||
// If none match, it uses the default (INITIALIZATION_PROMPT).
|
||||
agent.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()).await;
|
||||
agent.add_dynamic_prompt_rule(needs_analysis_condition, ANALYSIS_PROMPT.to_string()).await;
|
||||
|
||||
let manager = Self { agent };
|
||||
manager.load_tools().await?; // Load tools with conditions
|
||||
Ok(manager)
|
||||
|
@ -149,6 +170,20 @@ impl BusterMultiAgent {
|
|||
existing_agent,
|
||||
"buster_super_agent".to_string(),
|
||||
));
|
||||
|
||||
// Re-apply prompt rules for the new agent instance if necessary
|
||||
// (Currently Agent::from_existing copies the default prompt but not rules)
|
||||
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("data_context") && !state.contains_key("plan_available")
|
||||
};
|
||||
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
|
||||
state.contains_key("plan_available")
|
||||
&& !state.contains_key("metrics_available")
|
||||
&& !state.contains_key("dashboards_available")
|
||||
};
|
||||
agent.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()).await;
|
||||
agent.add_dynamic_prompt_rule(needs_analysis_condition, ANALYSIS_PROMPT.to_string()).await;
|
||||
|
||||
let manager = Self { agent };
|
||||
manager.load_tools().await?; // Load tools with conditions for the new agent instance
|
||||
Ok(manager)
|
||||
|
@ -158,9 +193,10 @@ impl BusterMultiAgent {
|
|||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(INTIALIZATION_PROMPT.to_string());
|
||||
// Remove the explicit setting of the developer message here
|
||||
// thread.set_developer_message(INTIALIZATION_PROMPT.to_string());
|
||||
|
||||
// Get shutdown receiver
|
||||
// Start processing (prompt is handled dynamically within process_thread_with_depth)
|
||||
let rx = self.stream_process_thread(thread).await?;
|
||||
|
||||
Ok(rx)
|
||||
|
|
Loading…
Reference in New Issue