mirror of https://github.com/buster-so/buster.git
moved agent into own lib
This commit is contained in:
parent 6604b9789e
commit ce1fb87b19
@@ -1,6 +1,7 @@
---
description: This is helpful for building libs for our web server to interact with.
globs: libs/*
globs: */libs/*
alwaysApply: false
---

# Library Construction Guide
@@ -23,37 +24,39 @@ libs/
[package]
name = "my_lib"
version = "0.1.0"
edition.workspace = true
authors.workspace = true
license.workspace = true
edition = "2021"

# Inherit workspace dependencies
# This ensures consistent versions across the project
# Dependencies should be inherited from workspace
[dependencies]
serde.workspace = true      # If defined in workspace
tokio.workspace = true      # If defined in workspace
thiserror.workspace = true  # If defined in workspace

# Library-specific dependencies (not in workspace)
some-specific-dep = "1.0"
# Use workspace dependencies
anyhow = { workspace = true }
chrono = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
diesel = { workspace = true }
diesel-async = { workspace = true }
# Add other workspace dependencies as needed

# Development dependencies
[dev-dependencies]
tokio-test.workspace = true      # If defined in workspace
assert_matches.workspace = true  # If defined in workspace
tokio-test = { workspace = true }
# Add other workspace dev dependencies as needed

# Feature flags - can inherit from workspace or be lib-specific
# Feature flags
[features]
default = []
async = ["tokio"]  # Example of a library-specific feature
# Define library-specific features here
```

## Best Practices

### 1. Workspace Integration
- Use `.workspace = true` for common fields and dependencies
- Only specify library-specific versions for unique dependencies
- Inherit common development dependencies from workspace
- Use `{ workspace = true }` for common dependencies
- Never specify library-specific versions for dependencies that exist in the workspace
- All dependencies should be managed by the workspace
- Keep feature flags modular and specific to the library's needs

### 2. Library Structure

@@ -125,7 +128,7 @@ pub enum Error {
- Use workspace-level CI/CD pipelines

### 8. Dependencies
- Prefer workspace-level dependencies
- Only add library-specific dependencies when necessary
- All dependencies should be inherited from the workspace
- Never add library-specific dependency versions
- Keep dependencies minimal and focused
- Document any deviations from workspace versions
- The workspace will manage all dependency versions
@@ -3,7 +3,8 @@ members = [
    ".",
    "libs/handlers",
    "libs/litellm",
    "libs/database"
    "libs/database",
    "libs/agents"
]

# Define shared dependencies for all workspace members

@@ -19,6 +20,22 @@ uuid = { version = "1.8", features = ["serde", "v4"] }
diesel = { version = "2", features = ["uuid", "chrono", "serde_json", "postgres"] }
diesel-async = { version = "0.5.2", features = ["postgres", "bb8"] }
futures = "0.3.30"
async-trait = "0.1.85"
thiserror = "1.0.58"
tokio-test = "0.4.3"
futures-util = "0.3"
reqwest = { version = "0.12.4", features = ["json", "stream"] }
dotenv = "0.15.0"
mockito = "1.2.0"
bb8-redis = "0.18.0"
indexmap = { version = "2.2.6", features = ["serde"] }
once_cell = "1.20.2"
rustls = { version = "0.23", features = ["ring"] }
rustls-native-certs = "0.8"
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "chrono", "json"] }
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.13"
regex = "1.10.6"

[package]
name = "bi_api"
@@ -39,50 +56,49 @@ tracing = { workspace = true }
uuid = { workspace = true }
diesel = { workspace = true }
diesel-async = { workspace = true }
futures = { workspace = true }
async-trait = { workspace = true }
futures-util = { workspace = true }
reqwest = { workspace = true }
dotenv = { workspace = true }
bb8-redis = { workspace = true }
indexmap = { workspace = true }
once_cell = { workspace = true }
rustls = { workspace = true }
rustls-native-certs = { workspace = true }
sqlx = { workspace = true }
tokio-postgres = { workspace = true }
tokio-postgres-rustls = { workspace = true }
serde_yaml = { workspace = true }
regex = { workspace = true }

# Local dependencies
handlers = { path = "libs/handlers" }
litellm = { path = "libs/litellm" }
database = { path = "libs/database" }
agents = { path = "libs/agents" }

# Other dependencies specific to the main app
arrow = { version = "54.0.0", features = ["json"] }
async-compression = { version = "0.4.11", features = ["tokio"] }
axum = { version = "0.7.5", features = ["ws"] }
base64 = "0.21"
bb8-redis = "0.18.0"
cohere-rust = "0.6.0"
dotenv = "0.15.0"
futures = "0.3.30"
gcp-bigquery-client = "0.24.1"
indexmap = { version = "2.2.6", features = ["serde"] }
jsonwebtoken = "9.3.0"
lazy_static = "1.4.0"
num-traits = "0.2.19"
once_cell = "1.20.2"
rand = "0.8.5"
redis = { version = "0.27.5", features = [
    "tokio-comp",
    "tokio-rustls-comp",
    "tls-rustls-webpki-roots",
] }
regex = "1.10.6"
reqwest = { version = "0.12.4", features = ["json", "stream"] }
resend-rs = "0.10.0"
sentry = { version = "0.35.0", features = ["tokio", "sentry-tracing"] }
serde_urlencoded = "0.7.1"
snowflake-api = "0.11.0"
sqlparser = { version = "0.53.0", features = ["visitor"] }
sqlx = { version = "0.8", features = [
    "runtime-tokio",
    "tls-rustls",
    "postgres",
    "mysql",
    "macros",
    "uuid",
    "chrono",
    "bigdecimal",
] }
tempfile = "3.10.1"
tiberius = { version = "0.12.2", default-features = false, features = [
    "chrono",

@@ -102,21 +118,14 @@ tower-http = { version = "0.6.2", features = [
] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
url = "2.5.1"
rustls = { version = "0.23", features = ["ring"] }
rustls-native-certs = "0.8"
tokio-postgres-rustls = "0.13"
tokio-postgres = "0.7"
futures-util = "0.3"
rayon = "1.10.0"
diesel_migrations = "2.0.0"
serde_yaml = "0.9.34"
html-escape = "0.2.13"
async-trait = "0.1.85"

[dev-dependencies]
mockito = "1.2.0"
async-trait = "0.1.77"
tokio = { version = "1.0", features = ["full", "test-util"] }
mockito = { workspace = true }
tokio-test = { workspace = true }
async-trait = { workspace = true }

[profile.release]
debug = false
@@ -0,0 +1,27 @@
[package]
name = "agents"
version = "0.1.0"
edition = "2021"

# Inherit workspace dependencies
[dependencies]
# Use workspace dependencies
anyhow = { workspace = true }
serde = { workspace = true }
tokio = { workspace = true }
thiserror = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
litellm = { path = "../litellm" }
serde_json = { workspace = true }
futures = { workspace = true }
futures-util = { workspace = true }

# Development dependencies
[dev-dependencies]
tokio-test = { workspace = true }
mockito = { workspace = true }
dotenv = { workspace = true }

[features]
default = []
@@ -0,0 +1,809 @@
use anyhow::Result;
use litellm::{
    ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient, Message, Metadata, Tool,
    ToolCall, ToolChoice,
};
use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use uuid::Uuid;

use crate::utils::tools::ToolExecutor;
use crate::models::AgentThread;

#[derive(Debug, Clone)]
pub struct AgentError(pub String);

impl std::error::Error for AgentError {}

impl std::fmt::Display for AgentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

type MessageResult = Result<Message, AgentError>;

/// A wrapper type that converts ToolCall parameters to Value before executing
struct ToolCallExecutor<T: ToolExecutor> {
    inner: Box<T>,
}

impl<T: ToolExecutor> ToolCallExecutor<T> {
    fn new(inner: T) -> Self {
        Self {
            inner: Box::new(inner),
        }
    }
}

#[async_trait::async_trait]
impl<T: ToolExecutor + Send + Sync> ToolExecutor for ToolCallExecutor<T>
where
    T::Params: serde::de::DeserializeOwned,
    T::Output: serde::Serialize,
{
    type Output = Value;
    type Params = Value;

    async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
        let params = serde_json::from_value(params)?;
        let result = self.inner.execute(params).await?;
        Ok(serde_json::to_value(result)?)
    }

    fn get_schema(&self) -> Value {
        self.inner.get_schema()
    }

    fn get_name(&self) -> String {
        self.inner.get_name()
    }

    async fn is_enabled(&self) -> bool {
        self.inner.is_enabled().await
    }
}

// Add this near the top of the file, with other trait implementations
#[async_trait::async_trait]
impl<T: ToolExecutor<Output = Value, Params = Value> + Send + Sync> ToolExecutor for Box<T> {
    type Output = Value;
    type Params = Value;

    async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
        (**self).execute(params).await
    }

    fn get_schema(&self) -> Value {
        (**self).get_schema()
    }

    fn get_name(&self) -> String {
        (**self).get_name()
    }

    async fn is_enabled(&self) -> bool {
        (**self).is_enabled().await
    }
}

#[derive(Clone)]
/// The Agent struct is responsible for managing conversations with the LLM
/// and coordinating tool executions. It maintains a registry of available tools
/// and handles the recursive nature of tool calls.
pub struct Agent {
    /// Client for communicating with the LLM provider
    llm_client: LiteLLMClient,
    /// Registry of available tools, mapped by their names
    tools: Arc<
        RwLock<
            HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
        >,
    >,
    /// The model identifier to use (e.g., "gpt-4")
    model: String,
    /// Flexible state storage for maintaining memory across interactions
    state: Arc<RwLock<HashMap<String, Value>>>,
    /// The current thread being processed, if any
    current_thread: Arc<RwLock<Option<AgentThread>>>,
    /// Sender for streaming messages from this agent and sub-agents
    stream_tx: Arc<RwLock<broadcast::Sender<MessageResult>>>,
    /// The user ID for the current thread
    user_id: Uuid,
    /// The session ID for the current thread
    session_id: Uuid,
    /// Agent name
    name: String,
    /// Shutdown signal sender
    shutdown_tx: Arc<RwLock<broadcast::Sender<()>>>,
}
impl Agent {
    /// Create a new Agent instance with a specific LLM client and model
    pub fn new(
        model: String,
        tools: HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
        user_id: Uuid,
        session_id: Uuid,
        name: String,
    ) -> Self {
        let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
        let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set");

        let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));

        // Create a broadcast channel with buffer size 1000
        let (tx, _rx) = broadcast::channel(1000);
        // Create shutdown channel with buffer size 1
        let (shutdown_tx, _) = broadcast::channel(1);

        Self {
            llm_client,
            tools: Arc::new(RwLock::new(tools)),
            model,
            state: Arc::new(RwLock::new(HashMap::new())),
            current_thread: Arc::new(RwLock::new(None)),
            stream_tx: Arc::new(RwLock::new(tx)),
            user_id,
            session_id,
            shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
            name,
        }
    }

    /// Create a new Agent that shares state and stream with an existing agent
    pub fn from_existing(existing_agent: &Agent, name: String) -> Self {
        let llm_api_key = env::var("LLM_API_KEY").expect("LLM_API_KEY must be set");
        let llm_base_url = env::var("LLM_BASE_URL").expect("LLM_BASE_URL must be set");

        let llm_client = LiteLLMClient::new(Some(llm_api_key), Some(llm_base_url));

        Self {
            llm_client,
            tools: Arc::new(RwLock::new(HashMap::new())),
            model: existing_agent.model.clone(),
            state: Arc::clone(&existing_agent.state),
            current_thread: Arc::clone(&existing_agent.current_thread),
            stream_tx: Arc::clone(&existing_agent.stream_tx),
            user_id: existing_agent.user_id,
            session_id: existing_agent.session_id,
            shutdown_tx: Arc::clone(&existing_agent.shutdown_tx),
            name,
        }
    }
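
    // Illustrative sketch (hypothetical, not part of the original commit):
    // `from_existing` clones the Arcs for the state map, current thread,
    // stream, and shutdown channel, so a sub-agent and its parent observe
    // the same state. Assumes LLM_API_KEY and LLM_BASE_URL are set.
    #[allow(dead_code)]
    async fn example_shared_state(parent: &Agent) {
        let sub_agent = Agent::from_existing(parent, "sub_agent".to_string());
        sub_agent
            .set_state_value("step".to_string(), serde_json::json!("planning"))
            .await;
        // The parent sees the value because both agents share the same map.
        assert_eq!(
            parent.get_state_value("step").await,
            Some(serde_json::json!("planning"))
        );
    }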
    pub async fn get_enabled_tools(&self) -> Vec<Tool> {
        // Collect all registered tools and their schemas
        let tools = self.tools.read().await;

        let mut enabled_tools = Vec::new();

        for (_, tool) in tools.iter() {
            if tool.is_enabled().await {
                enabled_tools.push(Tool {
                    tool_type: "function".to_string(),
                    function: tool.get_schema(),
                });
            }
        }

        enabled_tools
    }

    /// Get a new receiver for the broadcast channel
    pub async fn get_stream_receiver(&self) -> broadcast::Receiver<MessageResult> {
        self.stream_tx.read().await.subscribe()
    }

    /// Get a clone of the current stream sender
    pub async fn get_stream_sender(&self) -> broadcast::Sender<MessageResult> {
        self.stream_tx.read().await.clone()
    }

    /// Get a value from the agent's state by key
    pub async fn get_state_value(&self, key: &str) -> Option<Value> {
        self.state.read().await.get(key).cloned()
    }

    /// Set a value in the agent's state
    pub async fn set_state_value(&self, key: String, value: Value) {
        self.state.write().await.insert(key, value);
    }

    /// Update multiple state values at once using a closure
    pub async fn update_state<F>(&self, f: F)
    where
        F: FnOnce(&mut HashMap<String, Value>),
    {
        let mut state = self.state.write().await;
        f(&mut state);
    }

    /// Clear all state values
    pub async fn clear_state(&self) {
        self.state.write().await.clear();
    }

    /// Get the current thread being processed, if any
    pub async fn get_current_thread(&self) -> Option<AgentThread> {
        self.current_thread.read().await.clone()
    }

    pub fn get_user_id(&self) -> Uuid {
        self.user_id
    }

    pub fn get_session_id(&self) -> Uuid {
        self.session_id
    }

    pub fn get_model_name(&self) -> &str {
        &self.model
    }

    /// Get the complete conversation history of the current thread
    pub async fn get_conversation_history(&self) -> Option<Vec<Message>> {
        self.current_thread
            .read()
            .await
            .as_ref()
            .map(|thread| thread.messages.clone())
    }

    /// Update the current thread with a new message
    async fn update_current_thread(&self, message: Message) -> Result<()> {
        let mut thread_lock = self.current_thread.write().await;
        if let Some(thread) = thread_lock.as_mut() {
            thread.messages.push(message);
        }
        Ok(())
    }

    /// Add a new tool to the agent
    ///
    /// # Arguments
    /// * `name` - The name of the tool, used to identify it in tool calls
    /// * `tool` - The tool implementation that will be executed
    pub async fn add_tool(&self, name: String, tool: impl ToolExecutor<Output = Value> + 'static) {
        let mut tools = self.tools.write().await;
        tools.insert(name, Box::new(ToolCallExecutor::new(tool)));
    }

    /// Add multiple tools to the agent at once
    ///
    /// # Arguments
    /// * `tools` - HashMap of tool names and their implementations
    pub async fn add_tools<E: ToolExecutor<Output = Value> + 'static>(
        &self,
        tools: HashMap<String, E>,
    ) {
        let mut tools_map = self.tools.write().await;
        for (name, tool) in tools {
            tools_map.insert(name, Box::new(ToolCallExecutor::new(tool)));
        }
    }

    /// Process a thread of conversation, potentially executing tools and continuing
    /// the conversation recursively until a final response is reached.
    ///
    /// This is a convenience wrapper around process_thread_streaming that collects
    /// all streamed messages into a final response.
    ///
    /// # Arguments
    /// * `thread` - The conversation thread to process
    ///
    /// # Returns
    /// * A Result containing the final Message from the assistant
    pub async fn process_thread(&self, thread: &AgentThread) -> Result<Message> {
        let mut rx = self.process_thread_streaming(thread).await?;

        let mut final_message = None;
        while let Ok(msg) = rx.recv().await {
            final_message = Some(msg?);
        }

        final_message.ok_or_else(|| anyhow::anyhow!("No messages received from processing"))
    }

    /// Process a thread of conversation with streaming responses. This is the primary
    /// interface for processing conversations.
    ///
    /// # Arguments
    /// * `thread` - The conversation thread to process
    ///
    /// # Returns
    /// * A Result containing a receiver for streamed messages
    pub async fn process_thread_streaming(
        &self,
        thread: &AgentThread,
    ) -> Result<broadcast::Receiver<MessageResult>> {
        // Spawn the processing task
        let agent_clone = self.clone();
        let thread_clone = thread.clone();

        // Get shutdown receiver
        let mut shutdown_rx = self.get_shutdown_receiver().await;

        tokio::spawn(async move {
            tokio::select! {
                result = agent_clone.process_thread_with_depth(&thread_clone, 0) => {
                    if let Err(e) = result {
                        let err_msg = format!("Error processing thread: {:?}", e);
                        let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg)));
                    }
                },
                _ = shutdown_rx.recv() => {
                    let _ = agent_clone.get_stream_sender().await.send(
                        Ok(Message::assistant(
                            Some("shutdown_message".to_string()),
                            Some("Processing interrupted due to shutdown signal".to_string()),
                            None,
                            None,
                            None,
                            Some(agent_clone.name.clone()),
                        ))
                    );
                }
            }
        });

        Ok(self.get_stream_receiver().await)
    }
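
    // Illustrative sketch (hypothetical caller, not part of the original
    // commit): draining the receiver returned by `process_thread_streaming`.
    // `recv()` returns Err once every sender is dropped, ending the loop.
    #[allow(dead_code)]
    async fn example_drain_stream(agent: &Agent, thread: &AgentThread) -> Result<()> {
        let mut rx = agent.process_thread_streaming(thread).await?;
        while let Ok(msg) = rx.recv().await {
            match msg {
                Ok(message) => println!("streamed: {:?}", message),
                Err(AgentError(err)) => eprintln!("agent error: {}", err),
            }
        }
        Ok(())
    }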
    async fn process_thread_with_depth(
        &self,
        thread: &AgentThread,
        recursion_depth: u32,
    ) -> Result<()> {
        // Set the initial thread
        {
            let mut current = self.current_thread.write().await;
            *current = Some(thread.clone());
        }

        if recursion_depth >= 30 {
            let message = Message::assistant(
                Some("max_recursion_depth_message".to_string()),
                Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()),
                None,
                None,
                None,
                Some(self.name.clone()),
            );
            self.get_stream_sender().await.send(Ok(message))?;
            return Ok(());
        }

        // Collect all registered tools and their schemas
        let tools = self.get_enabled_tools().await;

        // Create the tool-enabled request
        let request = ChatCompletionRequest {
            model: self.model.clone(),
            messages: thread.messages.clone(),
            tools: if tools.is_empty() { None } else { Some(tools) },
            tool_choice: Some(ToolChoice::Required),
            metadata: Some(Metadata {
                generation_name: "agent".to_string(),
                user_id: thread.user_id.to_string(),
                session_id: thread.id.to_string(),
                trace_id: thread.id.to_string(),
            }),
            store: Some(true),
            ..Default::default()
        };

        // Get the response from the LLM
        let response = match self.llm_client.chat_completion(request).await {
            Ok(response) => response,
            Err(e) => return Err(anyhow::anyhow!("Error processing thread: {:?}", e)),
        };

        let llm_message = &response.choices[0].message;

        // Create the assistant message
        let message = match llm_message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => Message::assistant(None, content.clone(), tool_calls.clone(), None, None, Some(self.name.clone())),
            _ => return Err(anyhow::anyhow!("Expected assistant message from LLM")),
        };

        // Broadcast the assistant message as soon as we receive it
        self.get_stream_sender().await.send(Ok(message.clone()))?;

        // Update thread with assistant message
        self.update_current_thread(message.clone()).await?;

        // If this is an auto response without tool calls, it means we're done
        if let Message::Assistant {
            tool_calls: None, ..
        } = &llm_message
        {
            return Ok(());
        }

        // If the LLM wants to use tools, execute them and continue
        if let Message::Assistant {
            tool_calls: Some(tool_calls),
            ..
        } = &llm_message
        {
            let mut results = Vec::new();

            // Execute each requested tool
            for tool_call in tool_calls {
                if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
                    let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
                    let result = tool.execute(params).await?;
                    println!("Tool Call result: {:?}", result);
                    let result_str = serde_json::to_string(&result)?;
                    let tool_message = Message::tool(
                        None,
                        result_str,
                        tool_call.id.clone(),
                        Some(tool_call.function.name.clone()),
                        None,
                    );

                    // Broadcast the tool message as soon as we receive it
                    self.get_stream_sender()
                        .await
                        .send(Ok(tool_message.clone()))?;

                    // Update thread with tool response
                    self.update_current_thread(tool_message.clone()).await?;
                    results.push(tool_message);
                }
            }

            // Create a new thread with the tool results and continue recursively
            let mut new_thread = thread.clone();
            new_thread.messages.push(message);
            new_thread.messages.extend(results);

            Box::pin(self.process_thread_with_depth(&new_thread, recursion_depth + 1)).await
        } else {
            Ok(())
        }
    }

    /// Get a receiver for the shutdown signal
    pub async fn get_shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.read().await.subscribe()
    }

    /// Signal shutdown to all receivers
    pub async fn shutdown(&self) -> Result<()> {
        // Send shutdown signal
        self.shutdown_tx.read().await.send(())?;
        Ok(())
    }

    /// Get a reference to the tools map
    pub async fn get_tools(
        &self,
    ) -> tokio::sync::RwLockReadGuard<
        '_,
        HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
    > {
        self.tools.read().await
    }
}
#[derive(Debug, Default)]
struct PendingToolCall {
    id: Option<String>,
    call_type: Option<String>,
    function_name: Option<String>,
    arguments: String,
    code_interpreter: Option<Value>,
    retrieval: Option<Value>,
}

impl PendingToolCall {
    fn new() -> Self {
        Self::default()
    }

    fn update_from_delta(&mut self, tool_call: &DeltaToolCall) {
        if let Some(id) = &tool_call.id {
            self.id = Some(id.clone());
        }
        if let Some(call_type) = &tool_call.call_type {
            self.call_type = Some(call_type.clone());
        }
        if let Some(function) = &tool_call.function {
            if let Some(name) = &function.name {
                self.function_name = Some(name.clone());
            }
            if let Some(args) = &function.arguments {
                self.arguments.push_str(args);
            }
        }
        // code_interpreter and retrieval deltas are not accumulated;
        // they are reset to None whenever a delta carries them.
        if tool_call.code_interpreter.is_some() {
            self.code_interpreter = None;
        }
        if tool_call.retrieval.is_some() {
            self.retrieval = None;
        }
    }

    fn into_tool_call(self) -> ToolCall {
        ToolCall {
            id: self.id.unwrap_or_default(),
            function: FunctionCall {
                name: self.function_name.unwrap_or_default(),
                arguments: self.arguments,
            },
            call_type: self.call_type.unwrap_or_default(),
            code_interpreter: None,
            retrieval: None,
        }
    }
}
/// A trait that provides convenient access to Agent functionality
/// when the agent is stored behind an Arc
#[async_trait::async_trait]
pub trait AgentExt {
    fn get_agent(&self) -> &Arc<Agent>;

    async fn stream_process_thread(
        &self,
        thread: &AgentThread,
    ) -> Result<broadcast::Receiver<MessageResult>> {
        (*self.get_agent()).process_thread_streaming(thread).await
    }

    async fn process_thread(&self, thread: &AgentThread) -> Result<Message> {
        (*self.get_agent()).process_thread(thread).await
    }

    async fn get_current_thread(&self) -> Option<AgentThread> {
        (*self.get_agent()).get_current_thread().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use litellm::MessageProgress;
    use serde_json::{json, Value};
    use uuid::Uuid;

    fn setup() {
        dotenv::dotenv().ok();
    }

    struct WeatherTool {
        agent: Arc<Agent>,
    }

    impl WeatherTool {
        fn new(agent: Arc<Agent>) -> Self {
            Self { agent }
        }
    }

    impl WeatherTool {
        async fn send_progress(
            &self,
            content: String,
            tool_id: String,
            progress: MessageProgress,
        ) -> Result<()> {
            let message = Message::tool(
                None,
                content,
                tool_id,
                Some(self.get_name()),
                Some(progress),
            );
            self.agent.get_stream_sender().await.send(Ok(message))?;
            Ok(())
        }
    }

    #[async_trait]
    impl ToolExecutor for WeatherTool {
        type Output = Value;
        type Params = Value;

        async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
            self.send_progress(
                "Fetching weather data...".to_string(),
                "123".to_string(),
                MessageProgress::InProgress,
            )
            .await?;

            // Simulate a delay
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

            let result = json!({
                "temperature": 20,
                "unit": "fahrenheit"
            });

            self.send_progress(
                serde_json::to_string(&result)?,
                "123".to_string(),
                MessageProgress::Complete,
            )
            .await?;

            Ok(result)
        }

        async fn is_enabled(&self) -> bool {
            true
        }

        fn get_schema(&self) -> Value {
            json!({
                "name": "get_weather",
                "description": "Get current weather information for a specific location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use"
                        }
                    },
                    "required": ["location"]
                }
            })
        }

        fn get_name(&self) -> String {
            "get_weather".to_string()
        }
    }
    #[tokio::test]
    async fn test_agent_convo_no_tools() {
        setup();

        // Create LLM client and agent
        let agent = Agent::new(
            "o1".to_string(),
            HashMap::new(),
            Uuid::new_v4(),
            Uuid::new_v4(),
            "test_agent".to_string(),
        );

        let thread = AgentThread::new(
            None,
            Uuid::new_v4(),
            vec![Message::user("Hello, world!".to_string())],
        );

        let response = match agent.process_thread(&thread).await {
            Ok(response) => response,
            Err(e) => panic!("Error processing thread: {:?}", e),
        };

        println!("Response: {:?}", response);
    }

    #[tokio::test]
    async fn test_agent_convo_with_tools() {
        setup();

        // Create agent first
        let agent = Agent::new(
            "o1".to_string(),
            HashMap::new(),
            Uuid::new_v4(),
            Uuid::new_v4(),
            "test_agent".to_string(),
        );

        // Create weather tool with reference to agent
        let weather_tool = WeatherTool::new(Arc::new(agent.clone()));

        // Add tool to agent (add_tool is async, so it must be awaited)
        agent.add_tool(weather_tool.get_name(), weather_tool).await;

        let thread = AgentThread::new(
            None,
            Uuid::new_v4(),
            vec![Message::user(
                "What is the weather in vineyard ut?".to_string(),
            )],
        );

        let response = match agent.process_thread(&thread).await {
            Ok(response) => response,
            Err(e) => panic!("Error processing thread: {:?}", e),
        };

        println!("Response: {:?}", response);
    }

    #[tokio::test]
    async fn test_agent_with_multiple_steps() {
        setup();

        // Create LLM client and agent
        let agent = Agent::new(
            "o1".to_string(),
            HashMap::new(),
            Uuid::new_v4(),
            Uuid::new_v4(),
            "test_agent".to_string(),
        );

        let weather_tool = WeatherTool::new(Arc::new(agent.clone()));

        agent.add_tool(weather_tool.get_name(), weather_tool).await;

        let thread = AgentThread::new(
            None,
            Uuid::new_v4(),
            vec![Message::user(
                "What is the weather in vineyard ut and san francisco?".to_string(),
            )],
        );

        let response = match agent.process_thread(&thread).await {
            Ok(response) => response,
            Err(e) => panic!("Error processing thread: {:?}", e),
        };

        println!("Response: {:?}", response);
    }

    #[tokio::test]
    async fn test_agent_state_management() {
        setup();

        // Create agent
        let agent = Agent::new(
            "o1".to_string(),
            HashMap::new(),
            Uuid::new_v4(),
            Uuid::new_v4(),
            "test_agent".to_string(),
        );

        // Test setting single values
        agent
            .set_state_value("test_key".to_string(), json!("test_value"))
            .await;
        let value = agent.get_state_value("test_key").await;
        assert_eq!(value, Some(json!("test_value")));

        // Test updating multiple values
        agent
            .update_state(|state| {
                state.insert("key1".to_string(), json!(1));
                state.insert("key2".to_string(), json!({"nested": "value"}));
            })
            .await;

        assert_eq!(agent.get_state_value("key1").await, Some(json!(1)));
        assert_eq!(
            agent.get_state_value("key2").await,
            Some(json!({"nested": "value"}))
        );

        // Test clearing state
        agent.clear_state().await;
        assert_eq!(agent.get_state_value("test_key").await, None);
        assert_eq!(agent.get_state_value("key1").await, None);
        assert_eq!(agent.get_state_value("key2").await, None);
    }
}
@@ -0,0 +1,273 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast;
use uuid::Uuid;

// These imports need to be updated once we migrate the tools
// For now, we'll keep them as is and they'll need to be updated later
use crate::utils::tools::agents_as_tools::{DashboardAgentTool, MetricAgentTool};
use crate::utils::tools::file_tools::{
    CreateDashboardFilesTool, CreateMetricFilesTool, ModifyDashboardFilesTool,
    ModifyMetricFilesTool, SendAssetsToUserTool,
};
use crate::utils::tools::planning_tools::{CreatePlan, ReviewPlan};
use crate::{
    Agent, AgentError, AgentExt, AgentThread,
    utils::tools::{
        file_tools::{SearchDataCatalogTool, SearchFilesTool},
        IntoValueTool, ToolExecutor,
    },
};

use litellm::Message as AgentMessage;

#[derive(Debug, Serialize, Deserialize)]
pub struct BusterSuperAgentOutput {
    pub message: String,
    pub duration: i64,
    pub thread_id: Uuid,
    pub messages: Vec<AgentMessage>,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct BusterSuperAgentInput {
    pub prompt: String,
    pub thread_id: Option<Uuid>,
    pub message_id: Option<Uuid>,
}

pub struct BusterSuperAgent {
    agent: Arc<Agent>,
}

impl AgentExt for BusterSuperAgent {
    fn get_agent(&self) -> &Arc<Agent> {
        &self.agent
    }
}

impl BusterSuperAgent {
    async fn load_tools(&self) -> Result<()> {
        // Create tools using the shared Arc
        let search_data_catalog_tool = SearchDataCatalogTool::new(Arc::clone(&self.agent));
        let create_plan_tool = CreatePlan::new(Arc::clone(&self.agent));
        let create_metric_files_tool = CreateMetricFilesTool::new(Arc::clone(&self.agent));
        let modify_metric_files_tool = ModifyMetricFilesTool::new(Arc::clone(&self.agent));
        let create_dashboard_files_tool = CreateDashboardFilesTool::new(Arc::clone(&self.agent));
        let modify_dashboard_files_tool = ModifyDashboardFilesTool::new(Arc::clone(&self.agent));

        // Add tools to the agent
        self.agent
            .add_tool(
                search_data_catalog_tool.get_name(),
                search_data_catalog_tool.into_value_tool(),
            )
            .await;
        self.agent
            .add_tool(
                create_metric_files_tool.get_name(),
                create_metric_files_tool.into_value_tool(),
            )
            .await;
        self.agent
            .add_tool(
                modify_metric_files_tool.get_name(),
                modify_metric_files_tool.into_value_tool(),
            )
            .await;
        self.agent
            .add_tool(
                create_dashboard_files_tool.get_name(),
                create_dashboard_files_tool.into_value_tool(),
            )
            .await;
        self.agent
            .add_tool(
                modify_dashboard_files_tool.get_name(),
                modify_dashboard_files_tool.into_value_tool(),
            )
            .await;
        self.agent
            .add_tool(
                create_plan_tool.get_name(),
                create_plan_tool.into_value_tool(),
            )
            .await;

        Ok(())
    }

    pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
        // Create agent with empty tools map
        let agent = Arc::new(Agent::new(
            "o3-mini".to_string(),
            HashMap::new(),
            user_id,
            session_id,
            "manager_agent".to_string(),
        ));

        let manager = Self { agent };
        manager.load_tools().await?;
        Ok(manager)
    }

    pub async fn from_existing(existing_agent: &Arc<Agent>) -> Result<Self> {
        // Create a new agent with the same core properties and shared state/stream
        let agent = Arc::new(Agent::from_existing(
            existing_agent,
            "manager_agent".to_string(),
        ));
        let manager = Self { agent };
        manager.load_tools().await?;
        Ok(manager)
    }

    pub async fn run(
        &self,
        thread: &mut AgentThread,
    ) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
        thread.set_developer_message(MANAGER_AGENT_PROMPT.to_string());

        // Start processing and get a receiver for the streamed messages
        let rx = self.stream_process_thread(thread).await?;

        Ok(rx)
    }

    /// Shutdown the manager agent and all its tools
    pub async fn shutdown(&self) -> Result<()> {
        self.get_agent().shutdown().await
    }
}
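
// Illustrative sketch (hypothetical usage, not part of the original commit):
// constructing the manager agent, streaming its run, and shutting it down.
#[allow(dead_code)]
async fn example_run_manager(user_id: Uuid, session_id: Uuid) -> Result<()> {
    let manager = BusterSuperAgent::new(user_id, session_id).await?;
    let mut thread = AgentThread::new(
        None,
        user_id,
        vec![AgentMessage::user("Show me monthly sales by rep".to_string())],
    );
    let mut rx = manager.run(&mut thread).await?;
    while let Ok(msg) = rx.recv().await {
        // Each streamed item is either an assistant/tool message or an AgentError.
        println!("{:?}", msg);
    }
    manager.shutdown().await
}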
const MANAGER_AGENT_PROMPT: &str = r##"### Role & Task
You are Buster, an expert analytics and data engineer. Your job is to assess what data is available and then provide fast, accurate answers to analytics questions from non-technical users. You do this by analyzing user requests, searching across a data catalog, and building metrics or dashboards.
---
### Actions Available (Tools)
*All actions will become available once the environment is ready and dependencies are met.*
- **search_data_catalog**
  - *Purpose:* Find what data is available for analysis (returns metadata, relevant datasets, documentation, and column details).
  - *When to use:* Before any analysis is performed or whenever you need context about the available data.
  - *Dependencies:* This action is always available.
- **create_plan**
  - *Purpose:* Define the goal and outline a plan for analysis.
  - *When to use:* Before starting any analysis.
  - *Dependencies:* This action will only be available after the `search_data_catalog` action has been called at least once.
- **create_metrics**
  - *Purpose:* Create new metrics.
  - *When to use:* For creating individual visualizations. These visualizations can either be returned to the user directly, or added to a dashboard that gets returned to the user. This tool is capable of writing SQL statements and building visualizations.
  - *Dependencies:* This action will only be available after the `search_data_catalog` and `create_plan` actions have been called.
- **update_metrics**
  - *Purpose:* Update or modify existing metrics/visualizations.
  - *When to use:* For updating or modifying visualizations. This tool is capable of editing SQL statements and modifying visualization configurations.
  - *Dependencies:* This action will only be available after the `search_data_catalog` and `create_plan` actions have been called, and at least one metric has been created (i.e., after `create_metrics` has been called at least once).
- **create_dashboards**
  - *Purpose:* Create dashboards and display multiple metrics in one cohesive view.
  - *When to use:* For creating a new dashboard and adding multiple visualizations to it. For organizing several metrics together. Dashboards are sent directly to the user upon completion. You need to use `create_metrics` before you can save metrics to a dashboard.
  - *Dependencies:* This action will only be available after the `search_data_catalog` and `create_plan` actions have been called, and at least one metric has been created (i.e., after `create_metrics` has been called at least once).
- **update_dashboards**
  - *Purpose:* Update or modify existing dashboards.
  - *When to use:* For updating or modifying a dashboard. For rearranging the visualizations, editing the display, or adding/removing visualizations from the dashboard. This tool is not capable of updating the SQL or styling characteristics of individual metrics (even if they are saved to the dashboard).
  - *Dependencies:* This action will only be available after the `search_data_catalog` and `create_plan` actions have been called, and at least one dashboard has been created (i.e., after `create_dashboards` has been called at least once).
---
### Key Workflow Reminders
1. **Checking the data catalog first**
   - You cannot assume that any form or type of data exists prior to searching the data catalog.
   - Prior to creating a plan or doing any kind of task/workflow, you must search the catalog to have sufficient context about the datasets you can query.
   - If you have sufficient context (i.e. you searched the data catalog in a previous workflow) you do not need to search the data catalog again.
2. **Answering questions about available data**
   - Sometimes users will ask things like "What kinds of reports can you build me?" or "What metrics can you get me about {topic_or_item}?" or "What data do you have access to?" or "How can you help me understand {topic_or_item}?" In these types of scenarios, you should search the data catalog, assess the available data, and then respond to the user.
   - Your response should be simple, clear, and offer the user a suggestion for how you can help them or proceed.
3. **Assessing search results from the data catalog**
   - Before creating a plan, you should always assess the search results from the data catalog. If the data catalog doesn't contain relevant or adequate data to answer the user request, you should respond and inform the user.
4. **Explaining if something is impossible or not supported**
   - If a user requests any of the following, briefly address it and let them know that you cannot:
     - *Write Operations:* You can only perform read operations on the database or warehouse. You cannot perform write operations. You are only able to query existing models/tables/datasets/views.
     - *Forecasting & Python Analysis:* You are not currently capable of using Python or R (i.e. analyses like modeling, what-if analysis, hypothetical scenario analysis, predictive forecasting, etc). You are only capable of querying historical data using SQL. These capabilities are currently in a beta state and will be generally available in the coming months.
     - *Unsupported Chart Types:* You are only capable of building the following visualizations: table, line, multi-axis combo, bar, histogram, pie/donut, number cards, and scatter plot. Other chart types are not currently supported.
     - *Unspecified Actions:* You cannot perform any actions outside your specified capabilities (e.g. you are unable to send emails, schedule reports, integrate with other applications, update data pipelines, etc).
     - *Web App Actions:* You are operating as a feature within a web app. You cannot control other features or aspects of the web application (i.e. adding users to the workspace, sharing things, exporting things, creating or adding metrics/dashboards to collections or folders, searching across previously built metrics/dashboards/chats/etc). Users will need to do these kinds of actions manually through the UI. Inform them of this and let them know that they can contact our team, contact their system admin, or read our docs for additional help.
     - *Non-data related requests:* You should not answer requests that aren't specifically related to data analysis. Do not address requests that are non-data related.
   - You should finish your response to these types of requests with an open-ended offer of something that you can do to help them.
   - If part of a request is doable, but another part is not (i.e. build a dashboard and send it to another user) you should perform the analysis/workflow, then address the aspects of the user request that you weren't able to perform in your final response (after the analysis is completed).
5. **Starting tasks right away**
   - If you're going to take any action (searching the data catalog, creating a plan, building metrics or dashboards, or modifying metrics/dashboards), begin immediately without messaging the user first.
   - Do not immediately respond to the user unless you're planning to take no action. You should never preface your workflow with a response or sending a message to the user.
   - Oftentimes, you must begin your workflow by searching the data catalog to have sufficient context. Once this is accomplished, you will have access to other actions (like creating a plan).
6. **Handling vague, nuanced, or broad requests**
   - The user may send requests that are extremely broad, vague, or nuanced. These are some examples of vague or broad requests you might get from users...
     - who are our top customers
     - how does our performance look lately
     - what kind of things should we be monitoring
     - build a report of important stuff
     - etc
   - In these types of vague or nuanced scenarios, you should attempt to build a dashboard of available data. You should not respond to the user immediately. Instead, your workflow should be: search the data catalog, assess the available data, and then create a plan for your analysis.
   - You should **never ask the user to clarify** things before doing your analysis.
7. **Handling goal, KPI or initiative focused requests**
   - The user may send requests that want you to help them accomplish a goal, hit a KPI, or improve in some sort of initiative. These are some examples of initiative focused requests you might get from users...
     - how can we improve our business
     - i want to improve X, how do I do it?
     - what can I do to hit X goal
     - we are trying to hit this KPI, how do we do it?
     - i want to increase Y, how do we do it?
     - etc
   - In these types of initiative focused scenarios, you should attempt to build a dashboard of available data. You should not respond to the user immediately. Instead, your workflow should be: search the data catalog, assess the available data, and then create a plan for your analysis.
   - You should **never ask the user to clarify** things before doing your analysis.
---
### Understanding What Gets Sent to the User
- **Real-Time Visibility**: The user can observe your actions as they happen, such as searching the data catalog or creating a plan.
- **Final Output**: When you complete your task, the user will receive the metrics or dashboards you create, presented based on the following rules:
#### For Metrics Not Added to a Dashboard
- **Single Metric**: If you create or update just one metric and do not add it to a dashboard, the user will see that metric as a standalone chart.
- **Multiple Metrics**: If you create or update multiple metrics without adding them to a dashboard, each metric will be returned as an individual chart. The user can view these charts one at a time (e.g., by navigating through a list), with the most recently created or updated chart displayed first by default.
#### For Dashboards
- **New or Updated Dashboard**: If you create or update a dashboard, the user will see the entire dashboard, which displays all the metrics you've added to it in a unified view.
- **Updates to Dashboard Metrics**: If you update metrics that are already part of a dashboard, the user will see the dashboard with those metrics automatically reflecting the updates.
---
### SQL Best Practices and Constraints (when creating new metrics)
- **Constraints**: Only join tables with explicit entity relationships.
- **SQL Requirements**:
  - Use schema-qualified table names (`<SCHEMA_NAME>.<TABLE_NAME>`).
  - Select specific columns (avoid `SELECT *` or `COUNT(*)`).
  - Use CTEs instead of subqueries, and use snake_case for naming them.
  - Use `DISTINCT` (not `DISTINCT ON`) with matching `GROUP BY`/`SORT BY` clauses.
  - Show entity names rather than just IDs.
  - Handle date conversions appropriately.
  - Order dates in ascending order.
  - Reference database identifiers for cross-database queries.
  - Format output for the specified visualization type.
  - Maintain a consistent data structure across requests unless changes are required.
  - Use explicit ordering for custom buckets or categories.
---
### Response Guidelines and Format
- Answer in simple, clear language for non-technical users, avoiding tech terms.
- Don't mention tools, actions, or technical details in responses.
- Explain how you completed the task after finishing.
- Your responses should be very simple.
- Your tone should not be formal.
- Do not include yml or reference file names directly.
- Do not include any SQL, Python, or other code in your final responses.
- Never ask the user to clarify anything.
- Your response should be in markdown and can use bullets or numbered lists whenever necessary (but you should never use headers or sub-headers).
- Respond in the first person.
- As an expert analytics and data engineer, you are capable of giving direct advice based on the analysis you perform.
### Example of a Good Response
[A single metric was created]
This line chart displays the monthly sales for each sales rep. Here's a breakdown of how this is being calculated:
1. I searched through your data catalog and found a dataset that has a log of orders. It also includes a column for the sales rep that closed the order.
2. I took the sum of revenue generated by all of your orders from the last 12 months.
3. I filtered the revenue by sales rep.
It looks like Nate Kelley is one of your standout sales reps. He is consistently closing more revenue than other sales reps in most months of the year.
---
### Summary & Additional Info
- If you're going to take action, begin immediately. Never respond to the user until you have completed your workflow.
- Search the data catalog first, unless you have context.
- **Never ask clarifying questions**
- Any assets created, modified, or referenced will automatically be shown to the user.
- Under the hood, you use state-of-the-art encryption and have rigorous security protocols and policies in place.
- Currently, you are not able to do things that require Python. You are only capable of querying historical data using SQL statements.
- Keep final responses clear, simple and concise, focusing on what was accomplished.
- You cannot assume that any form of data exists prior to searching the data catalog."##;
@@ -0,0 +1,3 @@
pub mod buster_super_agent;

pub use buster_super_agent::BusterSuperAgent;
@@ -0,0 +1,15 @@
//! Agents Library
//!
//! This library provides agent functionality for interacting with LLMs.

mod agent;
mod agents;
mod models;

// Re-export public API
pub use agent::{Agent, AgentError, AgentExt};
pub use agents::*;
pub use models::*;

// Re-export types from dependencies that are part of our public API
pub use litellm::Message;
@@ -0,0 +1,3 @@
mod types;

pub use types::*;
@@ -0,0 +1,57 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use litellm::Message;

/// A Thread represents a conversation between a user and the AI agent.
/// It contains a sequence of messages in chronological order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentThread {
    /// Unique identifier for the thread
    pub id: Uuid,
    pub user_id: Uuid,
    /// Ordered sequence of messages in the conversation
    pub messages: Vec<Message>,
}

impl AgentThread {
    pub fn new(id: Option<Uuid>, user_id: Uuid, messages: Vec<Message>) -> Self {
        Self {
            id: id.unwrap_or_else(Uuid::new_v4),
            user_id,
            messages,
        }
    }

    /// Set the developer message in the thread
    pub fn set_developer_message(&mut self, message: String) {
        // Look for an existing developer message
        if let Some(pos) = self
            .messages
            .iter()
            .position(|msg| matches!(msg, Message::Developer { .. }))
        {
            // Update existing developer message
            self.messages[pos] = Message::developer(message);
        } else {
            // Insert new developer message at the start
            self.messages.insert(0, Message::developer(message));
        }
    }

    /// Remove the most recent assistant message from the thread
    pub fn remove_last_assistant_message(&mut self) {
        if let Some(pos) = self
            .messages
            .iter()
            .rposition(|msg| matches!(msg, Message::Assistant { .. }))
        {
            self.messages.remove(pos);
        }
    }

    /// Add a user message to the thread
    pub fn add_user_message(&mut self, content: String) {
        self.messages.push(Message::user(content));
    }
}
@@ -0,0 +1 @@
pub mod tools;
@@ -0,0 +1,33 @@
use anyhow::Result;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;

/// A trait that defines how tools should be implemented.
/// Any struct that wants to be used as a tool must implement this trait.
/// Tools are constructed with a reference to their agent and can access its capabilities.
#[async_trait::async_trait]
pub trait ToolExecutor: Send + Sync {
    /// The type of the output of the tool
    type Output: Serialize + Send;

    /// The type of the parameters for this tool
    type Params: DeserializeOwned + Send;

    /// Execute the tool with the given parameters.
    async fn execute(&self, params: Self::Params) -> Result<Self::Output>;

    /// Get the JSON schema for this tool
    fn get_schema(&self) -> Value;

    /// Get the name of this tool
    fn get_name(&self) -> String;

    /// Check if this tool is currently enabled
    async fn is_enabled(&self) -> bool;

    /// Handle shutdown signal. Default implementation does nothing.
    /// Tools should override this if they need to perform cleanup on shutdown.
    async fn handle_shutdown(&self) -> Result<()> {
        Ok(())
    }
}
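
// Illustrative sketch (hypothetical example, not part of the original
// commit): a minimal ToolExecutor implementation that echoes its
// parameters back, useful as a template for real tools.
#[allow(dead_code)]
struct EchoTool;

#[async_trait::async_trait]
impl ToolExecutor for EchoTool {
    type Output = Value;
    type Params = Value;

    async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
        // Return the parsed arguments unchanged.
        Ok(params)
    }

    fn get_schema(&self) -> Value {
        serde_json::json!({
            "name": "echo",
            "description": "Returns its input unchanged",
            "parameters": { "type": "object", "properties": {} }
        })
    }

    fn get_name(&self) -> String {
        "echo".to_string()
    }

    async fn is_enabled(&self) -> bool {
        true
    }
}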

@ -14,15 +14,16 @@ serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }
indexmap = { workspace = true }
once_cell = { workspace = true }
rustls = { workspace = true }
rustls-native-certs = { workspace = true }
sqlx = { workspace = true }
tokio-postgres = { workspace = true }
tokio-postgres-rustls = { workspace = true }
bb8-redis = { workspace = true }

# Additional dependencies
async-trait = "0.1"
bb8-redis = "0.18.0"
futures = "0.3"
indexmap = { version = "2.2", features = ["serde"] }
once_cell = "1.20"
rustls = "0.23"
rustls-native-certs = "0.8"
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "chrono", "json"] }
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.13"
[dev-dependencies]
tokio-test = { workspace = true }
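
Each `{ workspace = true }` entry above resolves against the root manifest's `[workspace.dependencies]` table. A sketch of the corresponding entries, with versions inferred from the pinned lines this hunk removes (the actual root manifest may differ):

```toml
[workspace.dependencies]
async-trait = "0.1"
bb8-redis = "0.18.0"
futures = "0.3"
indexmap = { version = "2.2", features = ["serde"] }
once_cell = "1.20"
rustls = "0.23"
rustls-native-certs = "0.8"
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid", "chrono", "json"] }
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.13"
```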

@ -20,4 +20,7 @@ futures = { workspace = true }
# Local dependencies
database = { path = "../database" }

# Add any handler-specific dependencies here

[dev-dependencies]
tokio-test = { workspace = true }

@ -6,12 +6,11 @@ use diesel_async::RunQueryDsl;
use serde_json::Value;
use tokio;
use uuid::Uuid;
use serde_yaml;

use crate::messages::types::ThreadMessage;
use crate::chats::types::ThreadWithMessages;
use crate::chats::types::ChatWithMessages;
use crate::messages::types::ChatMessage;
use database::pool::get_pg_pool;
use database::schema::{messages, threads, users};
use database::schema::{chats, messages, users};

#[derive(Queryable)]
pub struct ThreadWithUser {

@ -36,7 +35,7 @@ pub struct MessageWithUser {
    pub user_attributes: Value,
}

pub async fn get_thread(thread_id: &Uuid, user_id: &Uuid) -> Result<ThreadWithMessages> {
pub async fn get_chat(chat_id: &Uuid, user_id: &Uuid) -> Result<ChatWithMessages> {
    // Run thread and messages queries concurrently
    let thread_future = {
        let mut conn = match get_pg_pool().get().await {

@ -44,20 +43,20 @@ pub async fn get_thread(thread_id: &Uuid, user_id: &Uuid) -> Result<ThreadWithMe
            Err(e) => return Err(anyhow!("Failed to get database connection: {}", e)),
        };

        let thread_id = thread_id.clone();
        let chat_id = chat_id.clone();
        let user_id = user_id.clone();

        tokio::spawn(async move {
            threads::table
                .inner_join(users::table.on(threads::created_by.eq(users::id)))
                .filter(threads::id.eq(thread_id))
                .filter(threads::created_by.eq(user_id))
                .filter(threads::deleted_at.is_null())
            chats::table
                .inner_join(users::table.on(chats::created_by.eq(users::id)))
                .filter(chats::id.eq(chat_id))
                .filter(chats::created_by.eq(user_id))
                .filter(chats::deleted_at.is_null())
                .select((
                    threads::id,
                    threads::title,
                    threads::created_at,
                    threads::updated_at,
                    chats::id,
                    chats::title,
                    chats::created_at,
                    chats::updated_at,
                    users::id,
                    users::name.nullable(),
                    users::email,

@ -74,12 +73,12 @@ pub async fn get_thread(thread_id: &Uuid, user_id: &Uuid) -> Result<ThreadWithMe
            Err(e) => return Err(anyhow!("Failed to get database connection: {}", e)),
        };

        let thread_id = thread_id.clone();
        let chat_id = chat_id.clone();

        tokio::spawn(async move {
            messages::table
                .inner_join(users::table.on(messages::created_by.eq(users::id)))
                .filter(messages::thread_id.eq(thread_id))
                .filter(messages::chat_id.eq(chat_id))
                .filter(messages::deleted_at.is_null())
                .order_by(messages::created_at.desc())
                .select((

@ -144,7 +143,7 @@ pub async fn get_thread(thread_id: &Uuid, user_id: &Uuid) -> Result<ThreadWithMe
                })
                .unwrap_or_default();

            ThreadMessage {
            ChatMessage {
                id: msg.id,
                request_message: crate::messages::types::ThreadUserMessage {
                    request: msg.request,

@ -167,7 +166,7 @@ pub async fn get_thread(thread_id: &Uuid, user_id: &Uuid) -> Result<ThreadWithMe
        .map(String::from);

    // Construct and return the ThreadWithMessages
    Ok(ThreadWithMessages {
    Ok(ChatWithMessages {
        id: thread.id,
        title: thread.title,
        is_favorited: false, // Not implemented in current schema
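
The shape of `get_chat` across these hunks: the chat row and its messages are fetched concurrently by spawning both diesel queries and awaiting the join handles. A hedged sketch of the pattern (the `load_chat` / `load_messages` helpers are illustrative stand-ins for the spawned blocks above, not code from this commit):

```rust
// Sketch only.
let chat_handle = tokio::spawn(async move { load_chat(chat_id, user_id).await });
let messages_handle = tokio::spawn(async move { load_messages(chat_id).await });

// The first `?` surfaces a panicked or cancelled task, the second a query error.
let chat = chat_handle.await??;
let messages = messages_handle.await??;
```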

@ -1 +1 @@
pub mod get_thread;
pub mod get_chat;

@ -1,14 +1,14 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::messages::types::ThreadMessage;
use crate::messages::types::ChatMessage;

#[derive(Debug, Serialize, Deserialize)]
pub struct ThreadWithMessages {
pub struct ChatWithMessages {
    pub id: Uuid,
    pub title: String,
    pub is_favorited: bool,
    pub messages: Vec<ThreadMessage>,
    pub messages: Vec<ChatMessage>,
    pub created_at: String,
    pub updated_at: String,
    pub created_by: String,

@ -3,7 +3,7 @@ use serde_json::Value;
use uuid::Uuid;

#[derive(Debug, Serialize, Deserialize)]
pub struct ThreadMessage {
pub struct ChatMessage {
    pub id: Uuid,
    pub request_message: ThreadUserMessage,
    pub response_messages: Vec<Value>,

@ -10,14 +10,12 @@ serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

# Library specific dependencies
reqwest = { version = "0.12.4", features = ["json", "stream"] }
async-trait = "0.1.85"
futures = "0.3.30"
futures-util = "0.3"
dotenv = "0.15.0"
async-trait = { workspace = true }
futures = { workspace = true }
futures-util = { workspace = true }
reqwest = { workspace = true }
dotenv = { workspace = true }

[dev-dependencies]
mockito = "1.2.0"
tokio = { version = "1.38.0", features = ["full", "test-util"] }
mockito = { workspace = true }
tokio-test = { workspace = true }

@ -9,7 +9,6 @@ use crate::routes::ws::threads_and_messages::threads_router::ThreadEvent;
use litellm::{Message, MessageProgress, ToolCall};

use crate::utils::tools::file_tools::file_types::file::FileEnum;
use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
use crate::utils::tools::interaction_tools::send_message_to_user::{

@ -256,9 +255,8 @@ fn transform_tool_message(
        "search_data_catalog" => tool_data_catalog_search(id, content, progress),
        "stored_values_search" => tool_stored_values_search(id, content, progress),
        "search_files" => tool_file_search(id, content, progress),
        // "create_files" => tool_create_file(id, content, progress),
        "create_files" => tool_create_file(id, content, progress),
        "modify_files" => tool_modify_file(id, content, progress),
        "open_files" => tool_open_files(id, content, progress),
        "send_message_to_user" => tool_send_message_to_user(id, content, progress),
        _ => Err(anyhow::anyhow!("Unsupported tool name")),
    }?;

@ -291,8 +289,7 @@ fn transform_assistant_tool_message(
        "stored_values_search" => assistant_stored_values_search(id, progress, initial),
        "search_files" => assistant_file_search(id, progress, initial),
        "create_files" => assistant_create_file(id, tool_calls, progress),
        // "modify_files" => assistant_modify_file(id, tool_calls, progress),
        "open_files" => assistant_open_files(id, progress, initial),
        "modify_files" => assistant_modify_file(id, tool_calls, progress),
        "send_message_to_user" => assistant_send_message_to_user(id, tool_calls, progress),
        _ => Err(anyhow::anyhow!("Unsupported tool name")),
    }?;

@ -637,97 +634,6 @@ fn process_file_search_results(
    Ok(buster_thought_pill_containers)
}

fn assistant_open_files(
    id: Option<String>,
    progress: Option<MessageProgress>,
    initial: bool,
) -> Result<Vec<BusterThreadMessage>> {
    if let Some(progress) = progress {
        if initial {
            match progress {
                MessageProgress::InProgress => {
                    Ok(vec![BusterThreadMessage::Thought(BusterThought {
                        id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
                        thought_type: "thought".to_string(),
                        thought_title: "Looking through assets...".to_string(),
                        thought_secondary_title: "".to_string(),
                        thoughts: None,
                        status: "loading".to_string(),
                    })])
                }
                _ => Err(anyhow::anyhow!(
                    "Assistant file search only supports in progress."
                )),
            }
        } else {
            Err(anyhow::anyhow!(
                "Assistant file search only supports initial."
            ))
        }
    } else {
        Err(anyhow::anyhow!("Assistant file search requires progress."))
    }
}

fn tool_open_files(
    id: Option<String>,
    content: String,
    progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
    if let Some(progress) = progress {
        let open_files_result = match serde_json::from_str::<OpenFilesOutput>(&content) {
            Ok(result) => result,
            Err(_) => return Ok(vec![]), // Silently ignore parsing errors
        };

        let duration = (open_files_result.duration as f64 / 1000.0 * 10.0).round() / 10.0;
        let result_count = open_files_result.results.len();

        let mut file_results: HashMap<String, Vec<BusterThoughtPill>> = HashMap::new();

        for result in open_files_result.results {
            let file_type = match result {
                FileEnum::Dashboard(_) => "dashboard",
                FileEnum::Metric(_) => "metric",
            }
            .to_string();

            file_results
                .entry(file_type.clone())
                .or_insert_with(Vec::new)
                .push(BusterThoughtPill {
                    id: Uuid::new_v4().to_string(),
                    text: open_files_result.message.clone(),
                    thought_file_type: file_type,
                });
        }

        let thought_pill_containers = file_results
            .into_iter()
            .map(|(title, thought_pills)| BusterThoughtPillContainer {
                title: title.chars().next().unwrap().to_uppercase().to_string() + &title[1..],
                thought_pills,
            })
            .collect::<Vec<_>>();

        let buster_thought = BusterThreadMessage::Thought(BusterThought {
            id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
            thought_type: "thought".to_string(),
            thought_title: format!("Looked through {} assets", result_count),
            thought_secondary_title: format!("{} seconds", duration),
            thoughts: Some(thought_pill_containers),
            status: "completed".to_string(),
        });

        match progress {
            MessageProgress::Complete => Ok(vec![buster_thought]),
            _ => Err(anyhow::anyhow!("Tool open file only supports complete.")),
        }
    } else {
        Err(anyhow::anyhow!("Tool open file requires progress."))
    }
}

fn assistant_create_file(
    id: Option<String>,
    tool_calls: Vec<ToolCall>,

@ -3,15 +3,15 @@ use crate::routes::rest::ApiResponse;
use axum::extract::Path;
use axum::http::StatusCode;
use axum::Extension;
use handlers::thread_types::ThreadWithMessages;
use handlers::chats::helpers::get_thread::get_thread;
use handlers::chats::helpers::get_chat::get_chat;
use handlers::chats::types::ChatWithMessages;
use uuid::Uuid;

pub async fn get_chat_rest_handler(
    Extension(user): Extension<User>,
    Path(id): Path<Uuid>,
) -> Result<ApiResponse<ThreadWithMessages>, (StatusCode, &'static str)> {
    let thread_with_messages = match get_thread(&id, &user.id).await {
) -> Result<ApiResponse<ChatWithMessages>, (StatusCode, &'static str)> {
    let thread_with_messages = match get_chat(&id, &user.id).await {
        Ok(response) => response,
        Err(e) => {
            tracing::error!("Error getting chat: {}", e);
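
For context, wiring this handler into an axum router might look like the following; the route path and middleware are assumptions, not taken from this commit:

```rust
use axum::{routing::get, Router};

// Hypothetical registration; the real path and the auth layer that supplies
// Extension<User> live elsewhere in the server.
let app: Router = Router::new().route("/chats/:id", get(get_chat_rest_handler));
```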

@ -5,8 +5,8 @@ use axum::{response::IntoResponse, Json};
use chrono::Utc;
use diesel::{insert_into, ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use handlers::messages::types::{ThreadMessage, ThreadUserMessage};
use handlers::chats::types::ThreadWithMessages;
use handlers::messages::types::{ChatMessage, ThreadUserMessage};
use handlers::chats::types::ChatWithMessages;
use litellm::Message as AgentMessage;
use serde::{Deserialize, Serialize};
use serde_json::Value;

@ -14,7 +14,7 @@ use tokio::sync::broadcast;
use uuid::Uuid;

use crate::routes::rest::ApiResponse;
use crate::utils::agent::AgentThread;
use agents::AgentThread;
use crate::{
    database_dep::{
        enums::Verification,

@ -22,8 +22,8 @@ use crate::{
        models::{DashboardFile, Message, MessageToFile, MetricFile, Thread, User},
        schema::{dashboard_files, messages, messages_to_files, metric_files, threads},
    },
    utils::agent::manager_agent::{ManagerAgent, ManagerAgentInput},
};
use agents::agents::buster_super_agent::{BusterSuperAgent, BusterSuperAgentInput};

use super::agent_message_transformer::{transform_message, BusterContainer, ReasoningMessage};

@ -34,7 +34,7 @@ pub struct ChatCreateNewChat {
    pub message_id: Option<Uuid>,
}

async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ThreadWithMessages> {
async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ChatWithMessages> {
    let chat_id = request.chat_id.unwrap_or_else(|| Uuid::new_v4());
    let message_id = request.message_id.unwrap_or_else(|| Uuid::new_v4());

@ -57,11 +57,11 @@ async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ThreadWi
        deleted_at: None,
    };

    let mut thread_with_messages = ThreadWithMessages {
    let mut thread_with_messages = ChatWithMessages {
        id: chat_id,
        title: request.prompt.clone(),
        is_favorited: false,
        messages: vec![ThreadMessage {
        messages: vec![ChatMessage {
            id: message_id,
            request_message: ThreadUserMessage {
                request: request.prompt.clone(),

@ -89,7 +89,7 @@ async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ThreadWi
        .await?;

    // Initialize agent and process request
    let agent = ManagerAgent::new(user.id, chat_id).await?;
    let agent = BusterSuperAgent::new(user.id, chat_id).await?;
    let mut thread = AgentThread::new(
        Some(chat_id),
        user.id,

@ -203,7 +203,7 @@ async fn process_chat(request: ChatCreateNewChat, user: User) -> Result<ThreadWi
pub async fn create_chat(
    Extension(user): Extension<User>,
    Json(request): Json<ChatCreateNewChat>,
) -> Result<ApiResponse<ThreadWithMessages>, (StatusCode, &'static str)> {
) -> Result<ApiResponse<ChatWithMessages>, (StatusCode, &'static str)> {
    match process_chat(request, user).await {
        Ok(response) => Ok(ApiResponse::JsonData(response)),
        Err(e) => {

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use handlers::chats::helpers::get_thread;
use handlers::chats::helpers::get_chat;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

@ -60,7 +60,7 @@ pub async fn get_thread_ws(
        Err(e) => return Err(anyhow!("Error subscribing to thread: {}", e)),
    };

    let thread = get_thread::get_thread(&req.id, &user.id).await?;
    let thread = get_chat::get_chat(&req.id, &user.id).await?;

    let get_thread_ws_message = WsResponseMessage::new(
        WsRoutes::Threads(ThreadRoute::Get),

@ -31,10 +31,10 @@ use crate::{
        ws_router::WsRoutes,
        ws_utils::send_ws_message,
    },
    utils::{
        agent::agents::manager_agent::{ManagerAgent, ManagerAgentInput},
        agent::{AgentExt, AgentThread},
    },
};
use agents::{
    agents::manager_agent::{ManagerAgent, ManagerAgentInput},
    AgentExt, AgentThread,
};

use super::agent_message_transformer::transform_message;

@ -28,8 +28,8 @@ use crate::{
        ws_router::WsRoutes,
        ws_utils::send_ws_message,
    },
    utils::agent::manager_agent::{ManagerAgent, ManagerAgentInput},
};
use agents::agents::manager_agent::{ManagerAgent, ManagerAgentInput};

/// This creates a new thread for a user. It follows these steps:
///

@ -5,7 +5,6 @@ pub mod file_types;
// pub mod filter_dashboards;
pub mod modify_dashboard_files;
pub mod modify_metric_files;
pub mod open_files;
pub mod search_data_catalog;
pub mod search_files;
pub mod send_assets_to_user;

@ -15,7 +14,6 @@ pub use create_metric_files::CreateMetricFilesTool;
// pub use filter_dashboards::FilterDashboardsTool;
pub use modify_dashboard_files::ModifyDashboardFilesTool;
pub use modify_metric_files::ModifyMetricFilesTool;
pub use open_files::OpenFilesTool;
pub use search_data_catalog::SearchDataCatalogTool;
pub use search_files::SearchFilesTool;
pub use send_assets_to_user::SendAssetsToUserTool;

@ -1,629 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
    time::Instant,
};
use tracing::{debug, error, info, warn};
use uuid::Uuid;

use super::FileModificationTool;
use crate::{
    database_dep::{
        lib::get_pg_pool,
        models::{DashboardFile, MetricFile},
        schema::{dashboard_files, metric_files},
    },
    utils::{
        agent::Agent,
        tools::{
            file_tools::file_types::{
                dashboard_yml::DashboardYml, file::FileEnum, metric_yml::MetricYml,
            },
            ToolExecutor,
        },
    },
};

use litellm::ToolCall;

#[derive(Debug, Serialize, Deserialize)]
struct FileRequest {
    id: String,
    file_type: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct OpenFilesParams {
    files: Vec<FileRequest>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct OpenFilesOutput {
    pub message: String,
    pub duration: i64,
    pub results: Vec<FileEnum>,
}

pub struct OpenFilesTool {
    agent: Arc<Agent>,
}

impl OpenFilesTool {
    pub fn new(agent: Arc<Agent>) -> Self {
        Self { agent }
    }
}

#[async_trait]
impl ToolExecutor for OpenFilesTool {
    type Output = OpenFilesOutput;
    type Params = OpenFilesParams;

    fn get_name(&self) -> String {
        "open_files".to_string()
    }

    async fn is_enabled(&self) -> bool {
        match self.agent.get_state_value("files_available").await {
            Some(_) => true,
            None => false,
        }
    }

    async fn execute(&self, params: Self::Params) -> Result<Self::Output> {
        let start_time = Instant::now();

        // No need for agent/thread context as this is just opening files
        let mut results = Vec::new();
        let mut error_messages = Vec::new();

        // Track requested IDs by type for later comparison
        let mut requested_ids: HashMap<String, HashSet<Uuid>> = HashMap::new();
        let mut found_ids: HashMap<String, HashSet<Uuid>> = HashMap::new();

        // Group requests by file type and track requested IDs
        let grouped_requests = params
            .files
            .into_iter()
            .filter_map(|req| match Uuid::parse_str(&req.id) {
                Ok(id) => {
                    requested_ids
                        .entry(req.file_type.clone())
                        .or_default()
                        .insert(id);
                    Some((req.file_type, id))
                }
                Err(_) => {
                    warn!(invalid_id = %req.id, "Invalid UUID format");
                    error_messages.push(format!("Invalid UUID format for id: {}", req.id));
                    None
                }
            })
            .fold(HashMap::new(), |mut acc, (file_type, id)| {
                acc.entry(file_type).or_insert_with(Vec::new).push(id);
                acc
            });

        // Get database connection
        let mut conn = match get_pg_pool().get().await {
            Ok(conn) => conn,
            Err(e) => {
                let duration = start_time.elapsed().as_millis() as i64;
                return Ok(OpenFilesOutput {
                    message: format!("Failed to connect to database: {}", e),
                    results: Vec::new(),
                    duration,
                });
            }
        };

        // Process metric files
        if let Some(metric_ids) = grouped_requests.get("metric") {
            use crate::database_dep::schema::metric_files::dsl::*;
            match metric_files
                .filter(id.eq_any(metric_ids))
                .filter(deleted_at.is_null())
                .load::<MetricFile>(&mut conn)
                .await
            {
                Ok(files) => {
                    for file in files {
                        found_ids
                            .entry("metric".to_string())
                            .or_default()
                            .insert(file.id);

                        match serde_json::from_value::<MetricYml>(file.content.clone()) {
                            Ok(yml) => {
                                results.push(FileEnum::Metric(yml));
                            }
                            Err(e) => {
                                error_messages.push(format!(
                                    "Failed to parse metric file {}: {}",
                                    file.id, e
                                ));
                            }
                        }
                    }
                }
                Err(e) => {
                    error_messages.push(format!("Failed to fetch metric files: {}", e));
                }
            }
        }

        // Process dashboard files
        if let Some(dashboard_ids) = grouped_requests.get("dashboard") {
            use crate::database_dep::schema::dashboard_files::dsl::*;
            match dashboard_files
                .filter(id.eq_any(dashboard_ids))
                .filter(deleted_at.is_null())
                .load::<DashboardFile>(&mut conn)
                .await
            {
                Ok(files) => {
                    for file in files {
                        found_ids
                            .entry("dashboard".to_string())
                            .or_default()
                            .insert(file.id);

                        match serde_json::from_value::<DashboardYml>(file.content.clone()) {
                            Ok(yml) => {
                                results.push(FileEnum::Dashboard(yml));
                            }
                            Err(e) => {
                                error_messages.push(format!(
                                    "Failed to parse dashboard file {}: {}",
                                    file.id, e
                                ));
                            }
                        }
                    }
                }
                Err(e) => {
                    error_messages.push(format!("Failed to fetch dashboard files: {}", e));
                }
            }
        }

        // Check for missing files
        for (file_type, ids) in requested_ids.iter() {
            let found_set = found_ids.get(file_type).unwrap();
            let missing: Vec<_> = ids.difference(found_set).collect();
            if !missing.is_empty() {
                error_messages.push(format!(
                    "Could not find {} files with IDs: {:?}",
                    file_type, missing
                ));
            }
        }

        let duration = start_time.elapsed().as_millis() as i64;
        let message = if error_messages.is_empty() {
            format!("Successfully opened {} files", results.len())
        } else {
            let success_msg = if !results.is_empty() {
                format!("Successfully opened {} files. ", results.len())
            } else {
                String::new()
            };
            format!(
                "{}Errors occurred: {}",
                success_msg,
                error_messages.join("; ")
            )
        };

        Ok(OpenFilesOutput {
            message,
            duration,
            results,
        })
    }

    fn get_schema(&self) -> Value {
        serde_json::json!({
            "name": "open_files",
            "strict": true,
            "parameters": {
                "type": "object",
                "required": ["files"],
                "properties": {
                    "files": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "required": ["id", "file_type"],
                            "properties": {
                                "id": {
                                    "type": "string",
                                    "description": "The UUID of the file from search results"
                                },
                                "file_type": {
                                    "type": "string",
                                    "enum": ["metric", "dashboard"],
                                    "description": "The type of file identified in search results"
                                }
                            },
                            "additionalProperties": false
                        },
                        "description": "Array of files from search results to open"
                    }
                },
                "additionalProperties": false
            },
            "description": "Opens files that were found in search results. Use this to view the content of files after finding them through search. Each file requires its UUID and type from the search results."
        })
    }
}

async fn get_dashboard_files(ids: &[Uuid]) -> Result<Vec<(DashboardYml, Uuid, String)>> {
    debug!(dashboard_ids = ?ids, "Fetching dashboard files");
    let mut conn = get_pg_pool().get().await.map_err(|e| {
        error!(error = %e, "Failed to get database connection");
        anyhow::anyhow!("Failed to get database connection: {}", e)
    })?;

    let files = match dashboard_files::table
        .filter(dashboard_files::id.eq_any(ids))
        .filter(dashboard_files::deleted_at.is_null())
        .load::<DashboardFile>(&mut conn)
        .await
    {
        Ok(files) => {
            debug!(
                count = files.len(),
                "Successfully loaded dashboard files from database"
            );
            files
        }
        Err(e) => {
            error!(error = %e, "Failed to load dashboard files from database");
            return Err(anyhow::anyhow!(
                "Error loading dashboard files from database: {}",
                e
            ));
        }
    };

    let mut results = Vec::new();
    for file in files {
        match serde_json::from_value(file.content.clone()) {
            Ok(dashboard_yml) => {
                debug!(dashboard_id = %file.id, "Successfully parsed dashboard YAML");
                results.push((dashboard_yml, file.id, file.updated_at.to_string()));
            }
            Err(e) => {
                warn!(
                    error = %e,
                    dashboard_id = %file.id,
                    "Failed to parse dashboard YAML"
                );
            }
        }
    }

    info!(
        requested_count = ids.len(),
        found_count = results.len(),
        "Completed dashboard files retrieval"
    );
    Ok(results)
}

async fn get_metric_files(ids: &[Uuid]) -> Result<Vec<(MetricYml, Uuid, String)>> {
    debug!(metric_ids = ?ids, "Fetching metric files");
    let mut conn = get_pg_pool().get().await.map_err(|e| {
        error!(error = %e, "Failed to get database connection");
        anyhow::anyhow!("Failed to get database connection: {}", e)
    })?;

    let files = match metric_files::table
        .filter(metric_files::id.eq_any(ids))
        .filter(metric_files::deleted_at.is_null())
        .load::<MetricFile>(&mut conn)
        .await
    {
        Ok(files) => {
            debug!(
                count = files.len(),
                "Successfully loaded metric files from database"
            );
            files
        }
        Err(e) => {
            error!(error = %e, "Failed to load metric files from database");
            return Err(anyhow::anyhow!(
                "Error loading metric files from database: {}",
                e
            ));
        }
    };

    let mut results = Vec::new();
    for file in files {
        match serde_json::from_value(file.content.clone()) {
            Ok(metric_yml) => {
                debug!(metric_id = %file.id, "Successfully parsed metric YAML");
                results.push((metric_yml, file.id, file.updated_at.to_string()));
            }
            Err(e) => {
                warn!(
                    error = %e,
                    metric_id = %file.id,
                    "Failed to parse metric YAML"
                );
            }
        }
    }

    info!(
        requested_count = ids.len(),
        found_count = results.len(),
        "Completed metric files retrieval"
    );
    Ok(results)
}

fn build_status_message(
    results: &[FileEnum],
    missing_files: &[String],
    error_messages: &[String],
) -> String {
    let mut parts = Vec::new();

    // Add success message if any files were found
    if !results.is_empty() {
        parts.push(format!("Successfully opened {} files", results.len()));
    }

    // Add missing files information
    if !missing_files.is_empty() {
        parts.push(format!(
            "Could not find the following files: {}",
            missing_files.join("; ")
        ));
    }

    // Add any error messages
    if !error_messages.is_empty() {
        parts.push(format!(
            "Encountered the following issues: {}",
            error_messages.join("; ")
        ));
    }

    // If everything is empty, provide a clear message
    if parts.is_empty() {
        "No files were processed due to invalid input".to_string()
    } else {
        parts.join(". ")
    }
}

#[cfg(test)]
mod tests {
    use crate::utils::tools::file_tools::file_types::metric_yml::{
        BarAndLineAxis, BarLineChartConfig, BaseChartConfig, ChartConfig, DataMetadata,
    };

    use super::*;
    use chrono::Utc;
    use serde_json::json;

    fn create_test_dashboard() -> DashboardYml {
        DashboardYml {
            id: Some(Uuid::new_v4()),
            updated_at: Some(Utc::now()),
            name: "Test Dashboard".to_string(),
            rows: vec![],
        }
    }

    fn create_test_metric() -> MetricYml {
        MetricYml {
            id: Some(Uuid::new_v4()),
            updated_at: Some(Utc::now()),
            title: "Test Metric".to_string(),
            description: Some("Test Description".to_string()),
            sql: "SELECT * FROM test_table".to_string(),
            chart_config: ChartConfig::Bar(BarLineChartConfig {
                base: BaseChartConfig {
                    column_label_formats: HashMap::new(),
                    column_settings: None,
                    colors: None,
                    show_legend: None,
                    grid_lines: None,
                    show_legend_headline: None,
                    goal_lines: None,
                    trendlines: None,
                    disable_tooltip: None,
                    y_axis_config: None,
                    x_axis_config: None,
                    category_axis_style_config: None,
                    y2_axis_config: None,
                },
                bar_and_line_axis: BarAndLineAxis {
                    x: vec![],
                    y: vec![],
                    category: None,
                    tooltip: None,
                },
                bar_layout: None,
                bar_sort_by: None,
                bar_group_type: None,
                bar_show_total_at_top: None,
                line_group_type: None,
            }),
            data_metadata: Some(vec![
                DataMetadata {
                    name: "id".to_string(),
                    data_type: "number".to_string(),
                },
                DataMetadata {
                    name: "value".to_string(),
                    data_type: "string".to_string(),
                },
            ]),
            dataset_ids: vec![],
        }
    }

    #[test]
    fn test_build_status_message_all_success() {
        let results = vec![
            FileEnum::Dashboard(create_test_dashboard()),
            FileEnum::Metric(create_test_metric()),
        ];
        let missing_files = vec![];
        let error_messages = vec![];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(message, "Successfully opened 2 files");
    }

    #[test]
    fn test_build_status_message_with_missing() {
        let results = vec![
            FileEnum::Dashboard(create_test_dashboard()),
            FileEnum::Metric(create_test_metric()),
        ];
        let missing_files = vec![
            "1 dashboard: abc-123".to_string(),
            "2 metrics: def-456, ghi-789".to_string(),
        ];
        let error_messages = vec![];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(
            message,
            "Successfully opened 2 files. Could not find the following files: 1 dashboard: abc-123; 2 metrics: def-456, ghi-789"
        );
    }

    #[test]
    fn test_build_status_message_with_errors() {
        let results = vec![];
        let missing_files = vec![];
        let error_messages = vec![
            "Invalid UUID format for id: xyz".to_string(),
            "Error processing metric files: connection failed".to_string(),
        ];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(
            message,
            "Encountered the following issues: Invalid UUID format for id: xyz; Error processing metric files: connection failed"
        );
    }

    #[test]
    fn test_build_status_message_mixed_results() {
        let results = vec![FileEnum::Metric(create_test_metric())];
        let missing_files = vec!["1 dashboard: abc-123".to_string()];
        let error_messages = vec!["Invalid UUID format for id: xyz".to_string()];

        let message = build_status_message(&results, &missing_files, &error_messages);
        assert_eq!(
            message,
            "Successfully opened 1 files. Could not find the following files: 1 dashboard: abc-123. Encountered the following issues: Invalid UUID format for id: xyz"
        );
    }

    #[test]
    fn test_parse_valid_params() {
        let params_json = json!({
            "files": [
                {"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "dashboard"},
                {"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "metric"}
            ]
        });

        let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
        assert_eq!(params.files.len(), 2);
        assert_eq!(params.files[0].file_type, "dashboard");
        assert_eq!(params.files[1].file_type, "metric");
    }

    #[test]
    fn test_parse_invalid_uuid() {
        let params_json = json!({
            "files": [
                {"id": "not-a-uuid", "file_type": "dashboard"},
                {"id": "also-not-a-uuid", "file_type": "metric"}
            ]
        });

        let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
        for file in &params.files {
            let uuid_result = Uuid::parse_str(&file.id);
            assert!(uuid_result.is_err());
        }
    }

    #[test]
    fn test_parse_invalid_file_type() {
        let params_json = json!({
            "files": [
                {"id": "550e8400-e29b-41d4-a716-446655440000", "file_type": "invalid"},
                {"id": "550e8400-e29b-41d4-a716-446655440001", "file_type": "unknown"}
            ]
        });

        let params: OpenFilesParams = serde_json::from_value(params_json).unwrap();
        for file in &params.files {
            assert!(file.file_type != "dashboard" && file.file_type != "metric");
        }
    }

    // Mock tests for file retrieval
    #[tokio::test]
    async fn test_get_dashboard_files() {
        let test_id = Uuid::new_v4();
        let dashboard = create_test_dashboard();
        let test_files = vec![DashboardFile {
            id: test_id,
            name: dashboard.name.clone(),
            file_name: "test.yml".to_string(),
            content: serde_json::to_value(&dashboard).unwrap(),
            filter: None,
            organization_id: Uuid::new_v4(),
            created_by: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }];

        // TODO: Mock database connection and return test_files
    }

    #[tokio::test]
    async fn test_get_metric_files() {
        let test_id = Uuid::new_v4();
        let metric = create_test_metric();
        let test_files = vec![MetricFile {
            id: test_id,
            name: metric.title.clone(),
            file_name: "test.yml".to_string(),
            content: serde_json::to_value(&metric).unwrap(),
            verification: crate::database_dep::enums::Verification::NotRequested,
            evaluation_obj: None,
            evaluation_summary: None,
            evaluation_score: None,
            organization_id: Uuid::new_v4(),
            created_by: Uuid::new_v4(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
            deleted_at: None,
        }];

        // TODO: Mock database connection and return test_files
    }
}