mirror of https://github.com/buster-so/buster.git
conditional prompt switching
This commit is contained in:
parent
e147711a23
commit
f46376eac0
|
@ -132,6 +132,12 @@ struct RegisteredTool {
|
||||||
enablement_condition: Option<Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>>,
|
enablement_condition: Option<Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper struct for dynamic prompt rules
|
||||||
|
struct DynamicPromptRule {
|
||||||
|
condition: Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>,
|
||||||
|
prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
// Update the ToolRegistry type alias is no longer needed, but we need the new type for the map
|
// Update the ToolRegistry type alias is no longer needed, but we need the new type for the map
|
||||||
type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
|
type ToolsMap = Arc<RwLock<HashMap<String, RegisteredTool>>>;
|
||||||
|
|
||||||
|
@ -161,18 +167,22 @@ pub struct Agent {
|
||||||
name: String,
|
name: String,
|
||||||
/// Shutdown signal sender
|
/// Shutdown signal sender
|
||||||
shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
|
shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
|
||||||
|
/// Default system prompt if no dynamic rules match
|
||||||
|
default_prompt: String,
|
||||||
|
/// Ordered rules for dynamically selecting system prompts
|
||||||
|
dynamic_prompt_rules: Arc<RwLock<Vec<DynamicPromptRule>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Agent {
|
impl Agent {
|
||||||
/// Create a new Agent instance with a specific LLM client and model
|
/// Create a new Agent instance with a specific LLM client and model
|
||||||
pub fn new(
|
pub fn new(
|
||||||
model: String,
|
model: String,
|
||||||
// Note: tools argument is removed as they are added via add_tool now
|
|
||||||
user_id: Uuid,
|
user_id: Uuid,
|
||||||
session_id: Uuid,
|
session_id: Uuid,
|
||||||
name: String,
|
name: String,
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
base_url: Option<String>,
|
base_url: Option<String>,
|
||||||
|
default_prompt: String,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let llm_client = LiteLLMClient::new(api_key, base_url);
|
let llm_client = LiteLLMClient::new(api_key, base_url);
|
||||||
|
|
||||||
|
@ -192,6 +202,8 @@ impl Agent {
|
||||||
session_id,
|
session_id,
|
||||||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||||
name,
|
name,
|
||||||
|
default_prompt,
|
||||||
|
dynamic_prompt_rules: Arc::new(RwLock::new(Vec::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,6 +225,8 @@ impl Agent {
|
||||||
session_id: existing_agent.session_id,
|
session_id: existing_agent.session_id,
|
||||||
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown
|
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx), // Shared shutdown
|
||||||
name,
|
name,
|
||||||
|
default_prompt: existing_agent.default_prompt.clone(),
|
||||||
|
dynamic_prompt_rules: Arc::new(RwLock::new(Vec::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -538,6 +552,19 @@ impl Agent {
|
||||||
return Ok(()); // Don't return error, just stop processing
|
return Ok(()); // Don't return error, just stop processing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Dynamic Prompt Selection ---
|
||||||
|
let current_system_prompt = self.get_current_prompt().await;
|
||||||
|
let system_message = AgentMessage::developer(current_system_prompt);
|
||||||
|
|
||||||
|
// Prepare messages for LLM: Inject current system prompt and filter out old ones
|
||||||
|
let mut llm_messages = vec![system_message];
|
||||||
|
llm_messages.extend(
|
||||||
|
thread.messages.iter()
|
||||||
|
.filter(|msg| !matches!(msg, AgentMessage::Developer { .. }))
|
||||||
|
.cloned()
|
||||||
|
);
|
||||||
|
// --- End Dynamic Prompt Selection ---
|
||||||
|
|
||||||
// Collect all enabled tools and their schemas
|
// Collect all enabled tools and their schemas
|
||||||
let tools = self.get_enabled_tools().await; // Now uses the new logic
|
let tools = self.get_enabled_tools().await; // Now uses the new logic
|
||||||
|
|
||||||
|
@ -549,7 +576,7 @@ impl Agent {
|
||||||
// Create the tool-enabled request
|
// Create the tool-enabled request
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: self.model.clone(),
|
model: self.model.clone(),
|
||||||
messages: thread.messages.clone(),
|
messages: llm_messages, // Use the dynamically prepared messages list
|
||||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
stream: Some(true), // Enable streaming
|
stream: Some(true), // Enable streaming
|
||||||
|
@ -982,6 +1009,37 @@ impl Agent {
|
||||||
let mut tx = self.stream_tx.write().await;
|
let mut tx = self.stream_tx.write().await;
|
||||||
*tx = None;
|
*tx = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add a rule for dynamically selecting a system prompt.
|
||||||
|
/// Rules are checked in the order they are added. The first matching rule's prompt is used.
|
||||||
|
pub async fn add_dynamic_prompt_rule<F>(
|
||||||
|
&self,
|
||||||
|
condition: F,
|
||||||
|
prompt: String,
|
||||||
|
)
|
||||||
|
where
|
||||||
|
F: Fn(&HashMap<String, Value>) -> bool + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let rule = DynamicPromptRule {
|
||||||
|
condition: Box::new(condition),
|
||||||
|
prompt,
|
||||||
|
};
|
||||||
|
self.dynamic_prompt_rules.write().await.push(rule);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gets the system prompt based on the current agent state and dynamic rules.
|
||||||
|
async fn get_current_prompt(&self) -> String {
|
||||||
|
let rules = self.dynamic_prompt_rules.read().await;
|
||||||
|
let state = self.state.read().await;
|
||||||
|
|
||||||
|
for rule in rules.iter() {
|
||||||
|
if (rule.condition)(&state) {
|
||||||
|
return rule.prompt.clone(); // Return the first matching rule's prompt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.default_prompt.clone() // Fallback to default prompt if no rules match
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Default, Clone)]
|
||||||
|
@ -1178,6 +1236,7 @@ mod tests {
|
||||||
"test_agent_no_tools".to_string(),
|
"test_agent_no_tools".to_string(),
|
||||||
env::var("LLM_API_KEY").ok(),
|
env::var("LLM_API_KEY").ok(),
|
||||||
env::var("LLM_BASE_URL").ok(),
|
env::var("LLM_BASE_URL").ok(),
|
||||||
|
"".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
let thread = AgentThread::new(
|
let thread = AgentThread::new(
|
||||||
|
@ -1207,6 +1266,7 @@ mod tests {
|
||||||
"test_agent_with_tools".to_string(),
|
"test_agent_with_tools".to_string(),
|
||||||
env::var("LLM_API_KEY").ok(),
|
env::var("LLM_API_KEY").ok(),
|
||||||
env::var("LLM_BASE_URL").ok(),
|
env::var("LLM_BASE_URL").ok(),
|
||||||
|
"".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Create weather tool with reference to agent
|
// Create weather tool with reference to agent
|
||||||
|
@ -1246,6 +1306,7 @@ mod tests {
|
||||||
"test_agent_multi_step".to_string(),
|
"test_agent_multi_step".to_string(),
|
||||||
env::var("LLM_API_KEY").ok(),
|
env::var("LLM_API_KEY").ok(),
|
||||||
env::var("LLM_BASE_URL").ok(),
|
env::var("LLM_BASE_URL").ok(),
|
||||||
|
"".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
let weather_tool = WeatherTool::new(Arc::clone(&agent));
|
let weather_tool = WeatherTool::new(Arc::clone(&agent));
|
||||||
|
@ -1284,6 +1345,7 @@ mod tests {
|
||||||
"test_agent_disabled".to_string(),
|
"test_agent_disabled".to_string(),
|
||||||
env::var("LLM_API_KEY").ok(),
|
env::var("LLM_API_KEY").ok(),
|
||||||
env::var("LLM_BASE_URL").ok(),
|
env::var("LLM_BASE_URL").ok(),
|
||||||
|
"".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Create weather tool
|
// Create weather tool
|
||||||
|
@ -1348,6 +1410,7 @@ mod tests {
|
||||||
"test_agent_state".to_string(),
|
"test_agent_state".to_string(),
|
||||||
env::var("LLM_API_KEY").ok(),
|
env::var("LLM_API_KEY").ok(),
|
||||||
env::var("LLM_BASE_URL").ok(),
|
env::var("LLM_BASE_URL").ok(),
|
||||||
|
"".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Test setting single values
|
// Test setting single values
|
||||||
|
|
|
@ -22,6 +22,9 @@ use crate::{
|
||||||
|
|
||||||
use litellm::AgentMessage;
|
use litellm::AgentMessage;
|
||||||
|
|
||||||
|
// Type alias for the enablement condition closure for tools
|
||||||
|
type ToolEnablementCondition = Box<dyn Fn(&HashMap<String, Value>) -> bool + Send + Sync>;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct BusterSuperAgentOutput {
|
pub struct BusterSuperAgentOutput {
|
||||||
pub message: String,
|
pub message: String,
|
||||||
|
@ -128,7 +131,7 @@ impl BusterMultiAgent {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
|
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
|
||||||
// Create agent (Agent::new no longer takes tools directly)
|
// Create agent, passing the initialization prompt as default
|
||||||
let agent = Arc::new(Agent::new(
|
let agent = Arc::new(Agent::new(
|
||||||
"o3-mini".to_string(),
|
"o3-mini".to_string(),
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -136,8 +139,26 @@ impl BusterMultiAgent {
|
||||||
"buster_super_agent".to_string(),
|
"buster_super_agent".to_string(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
INTIALIZATION_PROMPT.to_string(), // Default prompt
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Define prompt switching conditions
|
||||||
|
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
|
||||||
|
state.contains_key("data_context") && !state.contains_key("plan_available")
|
||||||
|
};
|
||||||
|
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
|
||||||
|
// Example: Trigger analysis prompt once plan is available and metrics/dashboards are not yet available
|
||||||
|
state.contains_key("plan_available")
|
||||||
|
&& !state.contains_key("metrics_available")
|
||||||
|
&& !state.contains_key("dashboards_available")
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add prompt rules (order matters)
|
||||||
|
// The agent will use the prompt associated with the first condition that evaluates to true.
|
||||||
|
// If none match, it uses the default (INITIALIZATION_PROMPT).
|
||||||
|
agent.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()).await;
|
||||||
|
agent.add_dynamic_prompt_rule(needs_analysis_condition, ANALYSIS_PROMPT.to_string()).await;
|
||||||
|
|
||||||
let manager = Self { agent };
|
let manager = Self { agent };
|
||||||
manager.load_tools().await?; // Load tools with conditions
|
manager.load_tools().await?; // Load tools with conditions
|
||||||
Ok(manager)
|
Ok(manager)
|
||||||
|
@ -149,6 +170,20 @@ impl BusterMultiAgent {
|
||||||
existing_agent,
|
existing_agent,
|
||||||
"buster_super_agent".to_string(),
|
"buster_super_agent".to_string(),
|
||||||
));
|
));
|
||||||
|
|
||||||
|
// Re-apply prompt rules for the new agent instance if necessary
|
||||||
|
// (Currently Agent::from_existing copies the default prompt but not rules)
|
||||||
|
let needs_plan_condition = |state: &HashMap<String, Value>| -> bool {
|
||||||
|
state.contains_key("data_context") && !state.contains_key("plan_available")
|
||||||
|
};
|
||||||
|
let needs_analysis_condition = |state: &HashMap<String, Value>| -> bool {
|
||||||
|
state.contains_key("plan_available")
|
||||||
|
&& !state.contains_key("metrics_available")
|
||||||
|
&& !state.contains_key("dashboards_available")
|
||||||
|
};
|
||||||
|
agent.add_dynamic_prompt_rule(needs_plan_condition, CREATE_PLAN_PROMPT.to_string()).await;
|
||||||
|
agent.add_dynamic_prompt_rule(needs_analysis_condition, ANALYSIS_PROMPT.to_string()).await;
|
||||||
|
|
||||||
let manager = Self { agent };
|
let manager = Self { agent };
|
||||||
manager.load_tools().await?; // Load tools with conditions for the new agent instance
|
manager.load_tools().await?; // Load tools with conditions for the new agent instance
|
||||||
Ok(manager)
|
Ok(manager)
|
||||||
|
@ -158,9 +193,10 @@ impl BusterMultiAgent {
|
||||||
&self,
|
&self,
|
||||||
thread: &mut AgentThread,
|
thread: &mut AgentThread,
|
||||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||||
thread.set_developer_message(INTIALIZATION_PROMPT.to_string());
|
// Remove the explicit setting of the developer message here
|
||||||
|
// thread.set_developer_message(INTIALIZATION_PROMPT.to_string());
|
||||||
|
|
||||||
// Get shutdown receiver
|
// Start processing (prompt is handled dynamically within process_thread_with_depth)
|
||||||
let rx = self.stream_process_thread(thread).await?;
|
let rx = self.stream_process_thread(thread).await?;
|
||||||
|
|
||||||
Ok(rx)
|
Ok(rx)
|
||||||
|
|
Loading…
Reference in New Issue