added in conditional logic

This commit is contained in:
dal 2025-02-21 16:02:28 -07:00
parent 9f1d8eff7d
commit 242f648d85
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
21 changed files with 236 additions and 90 deletions

View File

@ -19,12 +19,14 @@ struct ToolCallExecutor<T: ToolExecutor> {
impl<T: ToolExecutor> ToolCallExecutor<T> {
fn new(inner: T) -> Self {
Self { inner: Box::new(inner) }
Self {
inner: Box::new(inner),
}
}
}
#[async_trait::async_trait]
impl<T: ToolExecutor + Send + Sync> ToolExecutor for ToolCallExecutor<T>
impl<T: ToolExecutor + Send + Sync> ToolExecutor for ToolCallExecutor<T>
where
T::Params: serde::de::DeserializeOwned,
T::Output: serde::Serialize,
@ -45,6 +47,10 @@ where
fn get_name(&self) -> String {
self.inner.get_name()
}
async fn is_enabled(&self) -> bool {
self.inner.is_enabled().await
}
}
// Add this near the top of the file, with other trait implementations
@ -64,6 +70,10 @@ impl<T: ToolExecutor<Output = Value, Params = Value> + Send + Sync> ToolExecutor
fn get_name(&self) -> String {
(**self).get_name()
}
async fn is_enabled(&self) -> bool {
(**self).is_enabled().await
}
}
#[derive(Clone)]
@ -74,7 +84,11 @@ pub struct Agent {
/// Client for communicating with the LLM provider
llm_client: LiteLLMClient,
/// Registry of available tools, mapped by their names
tools: Arc<RwLock<HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>>>,
tools: Arc<
RwLock<
HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
>,
>,
/// The model identifier to use (e.g., "gpt-4")
model: String,
/// Flexible state storage for maintaining memory across interactions
@ -136,6 +150,24 @@ impl Agent {
}
}
pub async fn get_enabled_tools(&self) -> Vec<Tool> {
// Collect all registered tools and their schemas
let tools = self.tools.read().await;
let mut enabled_tools = Vec::new();
for (_, tool) in tools.iter() {
if tool.is_enabled().await {
enabled_tools.push(Tool {
tool_type: "function".to_string(),
function: tool.get_schema(),
});
}
}
enabled_tools
}
/// Update the stream sender for this agent
pub async fn set_stream_sender(&self, tx: mpsc::Sender<Result<Message>>) {
*self.stream_tx.write().await = tx;
@ -308,16 +340,7 @@ impl Agent {
}
// Collect all registered tools and their schemas
let tools: Vec<Tool> = self
.tools
.read()
.await
.iter()
.map(|(name, tool)| Tool {
tool_type: "function".to_string(),
function: tool.get_schema(),
})
.collect();
let tools = self.get_enabled_tools().await;
// Create the tool-enabled request
let request = ChatCompletionRequest {
@ -382,6 +405,7 @@ impl Agent {
result_str,
tool_call.id.clone(),
Some(tool_call.function.name.clone()),
// TODO: need the progress for streaming
None,
);
@ -461,7 +485,10 @@ impl PendingToolCall {
pub trait AgentExt {
fn get_agent(&self) -> &Arc<Agent>;
async fn stream_process_thread(&self, thread: &AgentThread) -> Result<mpsc::Receiver<Result<Message>>> {
async fn stream_process_thread(
&self,
thread: &AgentThread,
) -> Result<mpsc::Receiver<Result<Message>>> {
(*self.get_agent()).process_thread_streaming(thread).await
}
@ -540,6 +567,10 @@ mod tests {
Ok(result)
}
async fn is_enabled(&self) -> bool {
true
}
fn get_schema(&self) -> Value {
json!({
"name": "get_weather",

View File

@ -1,27 +1,22 @@
use anyhow::{anyhow, Result};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::mpsc::Receiver;
use tracing::{debug, info};
use uuid::Uuid;
use crate::utils::tools::agents_as_tools::{DashboardAgentTool, MetricAgentTool};
use crate::utils::tools::file_tools::{send_assets_to_user, SendAssetsToUserTool};
use crate::utils::tools::file_tools::SendAssetsToUserTool;
use crate::utils::{
agent::{Agent, AgentExt, AgentThread},
tools::{
agents_as_tools::ExploratoryAgentTool,
file_tools::{
CreateFilesTool, ModifyFilesTool, OpenFilesTool, SearchDataCatalogTool, SearchFilesTool,
},
file_tools::{SearchDataCatalogTool, SearchFilesTool},
IntoValueTool, ToolExecutor,
},
};
use litellm::{Message as AgentMessage, ToolCall};
use super::MetricAgent;
use litellm::Message as AgentMessage;
#[derive(Debug, Serialize, Deserialize)]
pub struct ManagerAgentOutput {
@ -159,7 +154,21 @@ impl ManagerAgent {
) -> Result<Receiver<Result<AgentMessage, anyhow::Error>>> {
thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());
self.stream_process_thread(thread).await
let mut rx = self.stream_process_thread(thread).await?;
while let Some(message) = rx.recv().await {
let message = message?;
if let AgentMessage::Tool {
id,
content,
tool_call_id,
name,
progress,
} = message
{}
}
Ok(rx)
}
}

View File

@ -26,7 +26,11 @@ impl AgentThread {
/// Set the developer message in the thread
pub fn set_developer_message(&mut self, message: String) {
// Look for an existing developer message
if let Some(pos) = self.messages.iter().position(|msg| matches!(msg, Message::Developer { .. })) {
if let Some(pos) = self
.messages
.iter()
.position(|msg| matches!(msg, Message::Developer { .. }))
{
// Update existing developer message
self.messages[pos] = Message::developer(message);
} else {
@ -37,8 +41,17 @@ impl AgentThread {
/// Remove the most recent assistant message from the thread
pub fn remove_last_assistant_message(&mut self) {
if let Some(pos) = self.messages.iter().rposition(|msg| matches!(msg, Message::Assistant { .. })) {
if let Some(pos) = self
.messages
.iter()
.rposition(|msg| matches!(msg, Message::Assistant { .. }))
{
self.messages.remove(pos);
}
}
/// Add a user message to the thread
pub fn add_user_message(&mut self, content: String) {
self.messages.push(Message::user(content));
}
}

View File

@ -42,7 +42,14 @@ impl ToolExecutor for DashboardAgentTool {
type Params = DashboardAgentParams;
fn get_name(&self) -> String {
"create_or_modify_dashboard".to_string()
"create_or_modify_dashboards".to_string()
}
async fn is_enabled(&self) -> bool {
match self.agent.get_state_value("data_context").await {
Some(_) => true,
None => false,
}
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
@ -63,6 +70,8 @@ impl ToolExecutor for DashboardAgentTool {
current_thread.remove_last_assistant_message();
println!("DashboardAgentTool: Last assistant message removed");
current_thread.add_user_message(params.ticket_description);
println!("DashboardAgentTool: Starting dashboard agent run");
// Run the dashboard agent and get the output
let output = dashboard_agent.run(&mut current_thread).await?;

View File

@ -1,6 +1,5 @@
use anyhow::Result;
use async_trait::async_trait;
use litellm::Message as AgentMessage;
use serde::Deserialize;
use serde_json::Value;
use std::sync::Arc;
@ -33,6 +32,13 @@ impl ToolExecutor for ExploratoryAgentTool {
"explore_data".to_string()
}
async fn is_enabled(&self) -> bool {
match self.agent.get_state_value("data_context").await {
Some(_) => true,
None => false,
}
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
// Create and initialize the agent
let exploratory_agent = ExploratoryAgent::from_existing(&self.agent).await?;
@ -45,6 +51,8 @@ impl ToolExecutor for ExploratoryAgentTool {
current_thread.remove_last_assistant_message();
current_thread.add_user_message(params.ticket_description);
// Run the exploratory agent and get the receiver
let _rx = exploratory_agent.run(&mut current_thread).await?;

View File

@ -35,6 +35,13 @@ impl ToolExecutor for MetricAgentTool {
"create_or_modify_metrics".to_string()
}
async fn is_enabled(&self) -> bool {
match self.agent.get_state_value("data_context").await {
Some(_) => true,
None => false,
}
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
// Create and initialize the agent
let metric_agent = MetricAgent::from_existing(&self.agent).await?;
@ -46,13 +53,10 @@ impl ToolExecutor for MetricAgentTool {
.await
.ok_or_else(|| anyhow::anyhow!("No current thread"))?;
// Parse input parameters
let agent_input = MetricAgentInput {
ticket_description: params.ticket_description,
};
current_thread.remove_last_assistant_message();
current_thread.add_user_message(params.ticket_description);
// Run the metric agent and get the receiver
let _rx = metric_agent.run(&mut current_thread).await?;

View File

@ -2,13 +2,10 @@ use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use std::sync::Arc;
use uuid::Uuid;
use crate::utils::{
tools::ToolExecutor,
agent::Agent,
};
use crate::utils::{agent::Agent, tools::ToolExecutor};
use litellm::ToolCall;
#[derive(Debug, Serialize, Deserialize)]
@ -25,7 +22,7 @@ pub struct PlanInput {
}
pub struct CreatePlan {
agent: Arc<Agent>
agent: Arc<Agent>,
}
impl CreatePlan {
@ -45,7 +42,7 @@ impl ToolExecutor for CreatePlan {
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let input = params;
// TODO: Implement actual plan creation logic here
// This would typically involve:
// 1. Validating the markdown content
@ -59,6 +56,10 @@ impl ToolExecutor for CreatePlan {
})
}
async fn is_enabled(&self) -> bool {
true
}
fn get_schema(&self) -> Value {
serde_json::json!({
"name": "create_plan",
@ -79,4 +80,4 @@ impl ToolExecutor for CreatePlan {
}
})
}
}
}

View File

@ -47,6 +47,10 @@ impl ToolExecutor for ReviewPlan {
"review_plan".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let input = params;

View File

@ -57,6 +57,10 @@ impl ToolExecutor for SqlQuery {
"run_sql".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let input = params;
let mut results = Vec::new();

View File

@ -118,6 +118,10 @@ impl ToolExecutor for CreateDashboardFilesTool {
"create_dashboard_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();

View File

@ -167,6 +167,14 @@ impl ToolExecutor for CreateFilesTool {
type Output = CreateFilesOutput;
type Params = CreateFilesParams;
fn get_name(&self) -> String {
"create_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();
@ -314,10 +322,6 @@ impl ToolExecutor for CreateFilesTool {
})
}
fn get_name(&self) -> String {
"create_files".to_string()
}
fn get_schema(&self) -> Value {
serde_json::json!({
"name": "create_files",

View File

@ -113,6 +113,10 @@ impl ToolExecutor for CreateMetricFilesTool {
"create_metric_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();

View File

@ -236,6 +236,10 @@ impl ToolExecutor for ModifyDashboardFilesTool {
"modify_dashboard_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();

View File

@ -249,6 +249,14 @@ impl ToolExecutor for ModifyFilesTool {
type Output = ModifyFilesOutput;
type Params = ModifyFilesParams;
fn get_name(&self) -> String {
"modify_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();
@ -467,10 +475,6 @@ impl ToolExecutor for ModifyFilesTool {
Ok(output)
}
fn get_name(&self) -> String {
"modify_files".to_string()
}
fn get_schema(&self) -> Value {
serde_json::json!({
"name": "modify_files",

View File

@ -249,6 +249,10 @@ impl ToolExecutor for ModifyMetricFilesTool {
"modify_metric_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();

View File

@ -67,6 +67,14 @@ impl ToolExecutor for OpenFilesTool {
type Output = OpenFilesOutput;
type Params = OpenFilesParams;
fn get_name(&self) -> String {
"open_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();
@ -219,10 +227,6 @@ impl ToolExecutor for OpenFilesTool {
})
}
fn get_name(&self) -> String {
"open_files".to_string()
}
fn get_schema(&self) -> Value {
serde_json::json!({
"name": "open_files",

View File

@ -85,6 +85,10 @@ impl SearchDataCatalogTool {
Self { agent }
}
async fn is_enabled(&self) -> bool {
true
}
fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result<String> {
let datasets_json = datasets
.iter()
@ -125,6 +129,7 @@ impl SearchDataCatalogTool {
user_id: user_id.to_string(),
session_id: session_id.to_string(),
}),
reasoning_effort: Some("low".to_string()),
..Default::default()
};
@ -259,6 +264,10 @@ impl ToolExecutor for SearchDataCatalogTool {
})
}
async fn is_enabled(&self) -> bool {
true
}
fn get_name(&self) -> String {
"search_data_catalog".to_string()
}

View File

@ -90,6 +90,10 @@ impl SearchFilesTool {
Self { agent }
}
async fn is_enabled(&self) -> bool {
true
}
fn format_search_prompt(query_params: &[String], files_array: &[Value]) -> Result<String> {
let queries_joined = query_params.join("\n");
let files_json = serde_json::to_string_pretty(&files_array)?;
@ -166,6 +170,10 @@ impl ToolExecutor for SearchFilesTool {
"search_files".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
let start_time = Instant::now();

View File

@ -43,47 +43,54 @@ impl ToolExecutor for SendAssetsToUserTool {
"decide_assets_to_return".to_string()
}
async fn is_enabled(&self) -> bool {
match self.agent.get_state_value("files_created").await {
Some(_) => true,
None => false,
}
}
fn get_schema(&self) -> Value {
serde_json::json!({
"name": "decide_assets_to_return",
"description": "Use after you have created or modified any assets (metrics or dashboards) to specify exactly which assets to present in the final response. If you have not created or modified any assets, do not call this action.",
"strict": true,
"parameters": {
"type": "object",
"required": [
"assets_to_return",
"ticket_description"
],
"properties": {
"assets_to_return": {
"type": "array",
"description": "List of assets to present in the final response, each with an ID and a name",
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique identifier for the asset"
},
"name": {
"type": "string",
"description": "Name of the asset"
}
"name": "decide_assets_to_return",
"description": "Use after you have created or modified any assets (metrics or dashboards) to specify exactly which assets to present in the final response. If you have not created or modified any assets, do not call this action.",
"strict": true,
"parameters": {
"type": "object",
"required": [
"assets_to_return",
"ticket_description"
],
"properties": {
"assets_to_return": {
"type": "array",
"description": "List of assets to present in the final response, each with an ID and a name",
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique identifier for the asset"
},
"required": [
"id",
"name"
],
"additionalProperties": false
}
},
"ticket_description": {
"type": "string",
"description": "Description of the ticket related to the assets"
"name": {
"type": "string",
"description": "Name of the asset"
}
},
"required": [
"id",
"name"
],
"additionalProperties": false
}
},
"additionalProperties": false
}
})
"ticket_description": {
"type": "string",
"description": "Description of the ticket related to the assets"
}
},
"additionalProperties": false
}
})
}
}

View File

@ -55,4 +55,8 @@ impl ToolExecutor for SendMessageToUser {
fn get_name(&self) -> String {
"send_message_to_user".to_string()
}
async fn is_enabled(&self) -> bool {
true
}
}

View File

@ -31,6 +31,9 @@ pub trait ToolExecutor: Send + Sync {
/// Get the name of this tool
fn get_name(&self) -> String;
/// Check if this tool is currently enabled
async fn is_enabled(&self) -> bool;
}
/// A wrapper type that converts any ToolExecutor to one that outputs Value
@ -61,6 +64,10 @@ impl<T: ToolExecutor> ToolExecutor for ValueToolExecutor<T> {
fn get_name(&self) -> String {
self.inner.get_name()
}
async fn is_enabled(&self) -> bool {
self.inner.is_enabled().await
}
}
/// Extension trait to add value conversion methods to ToolExecutor