conditional prompt switching

This commit is contained in:
dal 2025-04-10 13:28:03 -06:00
parent e147711a23
commit f46376eac0
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 104 additions and 5 deletions

View File

@ -132,6 +132,12 @@ struct RegisteredTool {
enablement_condition: Option<Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>>, enablement_condition: Option<Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>>,
} }
// Helper struct for dynamic prompt rules
struct DynamicPromptRule {
condition: Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>,
prompt: String,
}
// Update the ToolRegistry type alias is no longer needed, but we need the new type for the map // Update the ToolRegistry type alias is no longer needed, but we need the new type for the map
type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>; type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
@ -161,18 +167,22 @@ pub struct Agent {
name: String, name: String,
/// Shutdown signal sender /// Shutdown signal sender
shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>, shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
/// Default system prompt if no dynamic rules match
default_prompt: String,
/// Ordered rules for dynamically selecting system prompts
dynamic_prompt_rules: Arc<RwLock<Vec<DynamicPromptRule>>>,
} }
impl Agent { impl Agent {
/// Create a new Agent instance with a specific LLM client and model /// Create a new Agent instance with a specific LLM client and model
pub fn new( pub fn new(
model: String, model: String,
// Note: tools argument is removed as they are added via add_tool now
user_id: Uuid, user_id: Uuid,
session_id: Uuid, session_id: Uuid,
name: String, name: String,
api_key: Option<String>, api_key: Option<String>,
base_url: Option<String>, base_url: Option<String>,
default_prompt: String,
) -> Self { ) -> Self {
let llm_client = LiteLLMClient::new(api_key, base_url); let llm_client = LiteLLMClient::new(api_key, base_url);
@ -192,6 +202,8 @@ impl Agent {
session_id, session_id,
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)), shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
name, name,
default_prompt,
dynamic_prompt_rules: Arc::new(RwLock::new(Vec::new())),
} }
} }
@ -213,6 +225,8 @@ impl Agent {
session_id: existing_agent.session_id, session_id: existing_agent.session_id,
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown
name, name,
default_prompt: existing_agent.default_prompt.clone(),
dynamic_prompt_rules: Arc::new(RwLock::new(Vec::new())),
} }
} }
@ -538,6 +552,19 @@ impl Agent {
return Ok(()); // Don't return error, just stop processing return Ok(()); // Don't return error, just stop processing
} }
// --- Dynamic Prompt Selection ---
let current_system_prompt = self.get_current_prompt().await;
let system_message = AgentMessage::developer(current_system_prompt);
// Prepare messages for LLM: Inject current system prompt and filter out old ones
let mut llm_messages = vec![system_message];
llm_messages.extend(
thread.messages.iter()
.filter(|msg| !matches!(msg, AgentMessage::Developer { .. }))
.cloned()
);
// --- End Dynamic Prompt Selection ---
// Collect all enabled tools and their schemas // Collect all enabled tools and their schemas
let tools = self.get_enabled_tools().await; // Now uses the new logic let tools = self.get_enabled_tools().await; // Now uses the new logic
@ -549,7 +576,7 @@ impl Agent {
// Create the tool-enabled request // Create the tool-enabled request
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: self.model.clone(), model: self.model.clone(),
messages: thread.messages.clone(), messages: llm_messages, // Use the dynamically prepared messages list
tools: if tools.is_empty() { None } else { Some(tools) }, tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
stream: Some(true), // Enable streaming stream: Some(true), // Enable streaming
@ -982,6 +1009,37 @@ impl Agent {
let mut tx = self.stream_tx.write().await; let mut tx = self.stream_tx.write().await;
*tx = None; *tx = None;
} }
/// Add a rule for dynamically selecting a system prompt.
/// Rules are checked in the order they are added. The first matching rule's prompt is used.
pub async fn add_dynamic_prompt_rule<F>(
&self,
condition: F,
prompt: String,
)
where
F: Fn(&HashMap<String, Value>) -> bool + Send + Sync + 'static,
{
let rule = DynamicPromptRule {
condition: Box::new(condition),
prompt,
};
self.dynamic_prompt_rules.write().await.push(rule);
}
/// Gets the system prompt based on the current agent state and dynamic rules.
async fn get_current_prompt(&self) -> String {
let rules = self.dynamic_prompt_rules.read().await;
let state = self.state.read().await;
for rule in rules.iter() {
if (rule.condition)(&state) {
return rule.prompt.clone(); // Return the first matching rule's prompt
}
}
self.default_prompt.clone() // Fallback to default prompt if no rules match
}
} }
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
@ -1178,6 +1236,7 @@ mod tests {
"test_agent_no_tools".to_string(), "test_agent_no_tools".to_string(),
env::var("LLM_API_KEY").ok(), env::var("LLM_API_KEY").ok(),
env::var("LLM_BASE_URL").ok(), env::var("LLM_BASE_URL").ok(),
"".to_string(),
)); ));
let thread = AgentThread::new( let thread = AgentThread::new(
@ -1207,6 +1266,7 @@ mod tests {
"test_agent_with_tools".to_string(), "test_agent_with_tools".to_string(),
env::var("LLM_API_KEY").ok(), env::var("LLM_API_KEY").ok(),
env::var("LLM_BASE_URL").ok(), env::var("LLM_BASE_URL").ok(),
"".to_string(),
)); ));
// Create weather tool with reference to agent // Create weather tool with reference to agent
@ -1246,6 +1306,7 @@ mod tests {
"test_agent_multi_step".to_string(), "test_agent_multi_step".to_string(),
env::var("LLM_API_KEY").ok(), env::var("LLM_API_KEY").ok(),
env::var("LLM_BASE_URL").ok(), env::var("LLM_BASE_URL").ok(),
"".to_string(),
)); ));
let weather_tool = WeatherTool::new(Arc::clone(&agent)); let weather_tool = WeatherTool::new(Arc::clone(&agent));
@ -1284,6 +1345,7 @@ mod tests {
"test_agent_disabled".to_string(), "test_agent_disabled".to_string(),
env::var("LLM_API_KEY").ok(), env::var("LLM_API_KEY").ok(),
env::var("LLM_BASE_URL").ok(), env::var("LLM_BASE_URL").ok(),
"".to_string(),
)); ));
// Create weather tool // Create weather tool
@ -1348,6 +1410,7 @@ mod tests {
"test_agent_state".to_string(), "test_agent_state".to_string(),
env::var("LLM_API_KEY").ok(), env::var("LLM_API_KEY").ok(),
env::var("LLM_BASE_URL").ok(), env::var("LLM_BASE_URL").ok(),
"".to_string(),
)); ));
// Test setting single values // Test setting single values

View File

@ -22,6 +22,9 @@ use crate::{
use litellm::AgentMessage; use litellm::AgentMessage;
// Type alias for the enablement condition closure for tools
type ToolEnablementCondition = Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct BusterSuperAgentOutput { pub struct BusterSuperAgentOutput {
pub message: String, pub message: String,
@ -128,7 +131,7 @@ impl BusterMultiAgent {
} }
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> { pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
// Create agent (Agent::new no longer takes tools directly) // Create agent, passing the initialization prompt as default
let agent = Arc::new(Agent::new( let agent = Arc::new(Agent::new(
"o3-mini".to_string(), "o3-mini".to_string(),
user_id, user_id,
@ -136,8 +139,26 @@ impl BusterMultiAgent {
"buster_super_agent".to_string(), "buster_super_agent".to_string(),
None, None,
None, None,
INTIALIZATION_PROMPT.to_string(), // Default prompt
)); ));
// Define prompt switching conditions
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
state.contains_key("data_context") && !state.contains_key("plan_available")
};
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
// Example: Trigger analysis prompt once plan is available and metrics/dashboards are not yet available
state.contains_key("plan_available")
&& !state.contains_key("metrics_available")
&& !state.contains_key("dashboards_available")
};
// Add prompt rules (order matters)
// The agent will use the prompt associated with the first condition that evaluates to true.
// If none match, it uses the default (INITIALIZATION_PROMPT).
agent.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()).await;
agent.add_dynamic_prompt_rule(needs_analysis_condition, ANALYSIS_PROMPT.to_string()).await;
let manager = Self { agent }; let manager = Self { agent };
manager.load_tools().await?; // Load tools with conditions manager.load_tools().await?; // Load tools with conditions
Ok(manager) Ok(manager)
@ -149,6 +170,20 @@ impl BusterMultiAgent {
existing_agent, existing_agent,
"buster_super_agent".to_string(), "buster_super_agent".to_string(),
)); ));
// Re-apply prompt rules for the new agent instance if necessary
// (Currently Agent::from_existing copies the default prompt but not rules)
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
state.contains_key("data_context") && !state.contains_key("plan_available")
};
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
state.contains_key("plan_available")
&& !state.contains_key("metrics_available")
&& !state.contains_key("dashboards_available")
};
agent.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()).await;
agent.add_dynamic_prompt_rule(needs_analysis_condition, ANALYSIS_PROMPT.to_string()).await;
let manager = Self { agent }; let manager = Self { agent };
manager.load_tools().await?; // Load tools with conditions for the new agent instance manager.load_tools().await?; // Load tools with conditions for the new agent instance
Ok(manager) Ok(manager)
@ -158,9 +193,10 @@ impl BusterMultiAgent {
&self, &self,
thread: &mut AgentThread, thread: &mut AgentThread,
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> { ) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
thread.set_developer_message(INTIALIZATION_PROMPT.to_string()); // Remove the explicit setting of the developer message here
// thread.set_developer_message(INTIALIZATION_PROMPT.to_string());
// Get shutdown receiver // Start processing (prompt is handled dynamically within process_thread_with_depth)
let rx = self.stream_process_thread(thread).await?; let rx = self.stream_process_thread(thread).await?;
Ok(rx) Ok(rx)