mirror of https://github.com/buster-so/buster.git
fix on redo message and the chat context
This commit is contained in:
parent
c19a5512fe
commit
96553aa2e0
|
@ -1,6 +1,7 @@
|
||||||
use std::sync::Arc;
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use agents::{Agent, AgentMessage};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use database::{
|
use database::{
|
||||||
|
@ -9,7 +10,6 @@ use database::{
|
||||||
};
|
};
|
||||||
use diesel::prelude::*;
|
use diesel::prelude::*;
|
||||||
use diesel_async::RunQueryDsl;
|
use diesel_async::RunQueryDsl;
|
||||||
use agents::{Agent, AgentMessage};
|
|
||||||
use middleware::AuthenticatedUser;
|
use middleware::AuthenticatedUser;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
@ -28,36 +28,56 @@ impl ChatContextLoader {
|
||||||
// Helper function to check for tool usage and set appropriate context
|
// Helper function to check for tool usage and set appropriate context
|
||||||
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
|
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
|
||||||
// Handle tool calls from assistant messages
|
// Handle tool calls from assistant messages
|
||||||
if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
|
if let AgentMessage::Assistant {
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
..
|
||||||
|
} = message
|
||||||
|
{
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
match tool_call.function.name.as_str() {
|
match tool_call.function.name.as_str() {
|
||||||
"search_data_catalog" => {
|
"search_data_catalog" => {
|
||||||
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||||
.await;
|
.await;
|
||||||
},
|
}
|
||||||
"create_metrics" | "update_metrics" => {
|
"create_metrics" | "update_metrics" => {
|
||||||
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||||
.await;
|
.await;
|
||||||
},
|
}
|
||||||
"create_dashboards" | "update_dashboards" => {
|
"create_dashboards" | "update_dashboards" => {
|
||||||
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(
|
||||||
|
String::from("dashboards_available"),
|
||||||
|
Value::Bool(true),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
},
|
}
|
||||||
"import_assets" => {
|
"import_assets" => {
|
||||||
// When we see import_assets, we need to check the content in the corresponding tool response
|
// When we see import_assets, we need to check the content in the corresponding tool response
|
||||||
// This will be handled separately when processing tool messages
|
// This will be handled separately when processing tool messages
|
||||||
},
|
}
|
||||||
name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => {
|
name if name.contains("file")
|
||||||
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
|| name.contains("read")
|
||||||
|
|| name.contains("write")
|
||||||
|
|| name.contains("edit") =>
|
||||||
|
{
|
||||||
|
agent
|
||||||
|
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||||
.await;
|
.await;
|
||||||
},
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle tool responses - important for import_assets
|
// Handle tool responses - important for import_assets
|
||||||
if let AgentMessage::Tool { name: Some(tool_name), content, .. } = message {
|
if let AgentMessage::Tool {
|
||||||
|
name: Some(tool_name),
|
||||||
|
content,
|
||||||
|
..
|
||||||
|
} = message
|
||||||
|
{
|
||||||
if tool_name == "import_assets" {
|
if tool_name == "import_assets" {
|
||||||
// Parse the tool response to see what was imported
|
// Parse the tool response to see what was imported
|
||||||
if let Ok(import_result) = serde_json::from_str::<serde_json::Value>(content) {
|
if let Ok(import_result) = serde_json::from_str::<serde_json::Value>(content) {
|
||||||
|
@ -65,65 +85,94 @@ impl ChatContextLoader {
|
||||||
if let Some(files) = import_result.get("files").and_then(|f| f.as_array()) {
|
if let Some(files) = import_result.get("files").and_then(|f| f.as_array()) {
|
||||||
if !files.is_empty() {
|
if !files.is_empty() {
|
||||||
// Set files_available for any imported files
|
// Set files_available for any imported files
|
||||||
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Check each file to determine its type
|
// Check each file to determine its type
|
||||||
let mut has_metrics = false;
|
let mut has_metrics = false;
|
||||||
let mut has_dashboards = false;
|
let mut has_dashboards = false;
|
||||||
let mut has_datasets = false;
|
let mut has_datasets = false;
|
||||||
|
|
||||||
for file in files {
|
for file in files {
|
||||||
// Check file_type/asset_type to determine what kind of asset this is
|
// Check file_type/asset_type to determine what kind of asset this is
|
||||||
let file_type = file.get("file_type").and_then(|ft| ft.as_str())
|
let file_type = file
|
||||||
|
.get("file_type")
|
||||||
|
.and_then(|ft| ft.as_str())
|
||||||
.or_else(|| file.get("asset_type").and_then(|at| at.as_str()));
|
.or_else(|| file.get("asset_type").and_then(|at| at.as_str()));
|
||||||
|
|
||||||
tracing::debug!("Processing imported file with type: {:?}", file_type);
|
tracing::debug!(
|
||||||
|
"Processing imported file with type: {:?}",
|
||||||
|
file_type
|
||||||
|
);
|
||||||
|
|
||||||
match file_type {
|
match file_type {
|
||||||
Some("metric") => {
|
Some("metric") => {
|
||||||
has_metrics = true;
|
has_metrics = true;
|
||||||
|
|
||||||
// Check if the metric has dataset references
|
// Check if the metric has dataset references
|
||||||
if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) {
|
if let Some(yml_content) =
|
||||||
if yml_content.contains("dataset") || yml_content.contains("datasetIds") {
|
file.get("yml_content").and_then(|y| y.as_str())
|
||||||
|
{
|
||||||
|
if yml_content.contains("dataset")
|
||||||
|
|| yml_content.contains("datasetIds")
|
||||||
|
{
|
||||||
has_datasets = true;
|
has_datasets = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Some("dashboard") => {
|
Some("dashboard") => {
|
||||||
has_dashboards = true;
|
has_dashboards = true;
|
||||||
|
|
||||||
// Dashboards often reference metrics too
|
// Dashboards often reference metrics too
|
||||||
has_metrics = true;
|
has_metrics = true;
|
||||||
|
|
||||||
// Check if the dashboard has dataset references via metrics
|
// Check if the dashboard has dataset references via metrics
|
||||||
if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) {
|
if let Some(yml_content) =
|
||||||
if yml_content.contains("dataset") || yml_content.contains("datasetIds") {
|
file.get("yml_content").and_then(|y| y.as_str())
|
||||||
|
{
|
||||||
|
if yml_content.contains("dataset")
|
||||||
|
|| yml_content.contains("datasetIds")
|
||||||
|
{
|
||||||
has_datasets = true;
|
has_datasets = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
_ => {
|
_ => {
|
||||||
tracing::debug!("Unknown file type in import_assets: {:?}", file_type);
|
tracing::debug!(
|
||||||
|
"Unknown file type in import_assets: {:?}",
|
||||||
|
file_type
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set appropriate state values based on what we found
|
// Set appropriate state values based on what we found
|
||||||
if has_metrics {
|
if has_metrics {
|
||||||
tracing::debug!("Setting metrics_available state to true");
|
tracing::debug!("Setting metrics_available state to true");
|
||||||
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(
|
||||||
|
String::from("metrics_available"),
|
||||||
|
Value::Bool(true),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
if has_dashboards {
|
if has_dashboards {
|
||||||
tracing::debug!("Setting dashboards_available state to true");
|
tracing::debug!("Setting dashboards_available state to true");
|
||||||
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(
|
||||||
|
String::from("dashboards_available"),
|
||||||
|
Value::Bool(true),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
if has_datasets {
|
if has_datasets {
|
||||||
tracing::debug!("Setting data_context state to true");
|
tracing::debug!("Setting data_context state to true");
|
||||||
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
agent
|
||||||
|
.set_state_value(
|
||||||
|
String::from("data_context"),
|
||||||
|
Value::Bool(true),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -136,34 +185,47 @@ impl ChatContextLoader {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl ContextLoader for ChatContextLoader {
|
impl ContextLoader for ChatContextLoader {
|
||||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
async fn load_context(
|
||||||
|
&self,
|
||||||
|
user: &AuthenticatedUser,
|
||||||
|
agent: &Arc<Agent>,
|
||||||
|
) -> Result<Vec<AgentMessage>> {
|
||||||
let mut conn = get_pg_pool().get().await?;
|
let mut conn = get_pg_pool().get().await?;
|
||||||
|
|
||||||
// First verify the chat exists and user has access
|
// First verify the chat exists and user has access
|
||||||
let chat = chats::table
|
let chat = chats::table
|
||||||
.filter(chats::id.eq(self.chat_id))
|
.filter(chats::id.eq(self.chat_id))
|
||||||
.filter(chats::created_by.eq(&user.id))
|
.filter(chats::created_by.eq(&user.id))
|
||||||
|
.filter(chats::deleted_at.is_null())
|
||||||
.first::<database::models::Chat>(&mut conn)
|
.first::<database::models::Chat>(&mut conn)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Get only the most recent message for the chat
|
// Get only the most recent message for the chat
|
||||||
let message = messages::table
|
let message = match messages::table
|
||||||
.filter(messages::chat_id.eq(chat.id))
|
.filter(messages::chat_id.eq(chat.id))
|
||||||
|
.filter(messages::deleted_at.is_null())
|
||||||
.order_by(messages::created_at.desc())
|
.order_by(messages::created_at.desc())
|
||||||
.first::<database::models::Message>(&mut conn)
|
.first::<database::models::Message>(&mut conn)
|
||||||
.await?;
|
.await
|
||||||
|
{
|
||||||
|
Ok(message) => message,
|
||||||
|
Err(diesel::NotFound) => return Ok(vec![]),
|
||||||
|
Err(e) => return Err(anyhow::anyhow!("Failed to get message: {}", e)),
|
||||||
|
};
|
||||||
|
|
||||||
// Track seen message IDs
|
// Track seen message IDs
|
||||||
let mut seen_ids = HashSet::new();
|
let mut seen_ids = HashSet::new();
|
||||||
// Convert messages to AgentMessages
|
// Convert messages to AgentMessages
|
||||||
let mut agent_messages = Vec::new();
|
let mut agent_messages = Vec::new();
|
||||||
|
|
||||||
// Process only the most recent message's raw LLM messages
|
// Process only the most recent message's raw LLM messages
|
||||||
if let Ok(raw_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages) {
|
if let Ok(raw_messages) =
|
||||||
|
serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages)
|
||||||
|
{
|
||||||
// Check each message for tool calls and update context
|
// Check each message for tool calls and update context
|
||||||
for agent_message in &raw_messages {
|
for agent_message in &raw_messages {
|
||||||
Self::update_context_from_tool_calls(agent, agent_message).await;
|
Self::update_context_from_tool_calls(agent, agent_message).await;
|
||||||
|
|
||||||
// Only add messages with new IDs
|
// Only add messages with new IDs
|
||||||
if let Some(id) = agent_message.get_id() {
|
if let Some(id) = agent_message.get_id() {
|
||||||
if seen_ids.insert(id.to_string()) {
|
if seen_ids.insert(id.to_string()) {
|
||||||
|
@ -178,4 +240,4 @@ impl ContextLoader for ChatContextLoader {
|
||||||
|
|
||||||
Ok(agent_messages)
|
Ok(agent_messages)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -228,8 +228,6 @@ pub async fn post_chat_handler(
|
||||||
|
|
||||||
let messages = generate_asset_messages(asset_id_value, asset_type_value, &user).await?;
|
let messages = generate_asset_messages(asset_id_value, asset_type_value, &user).await?;
|
||||||
|
|
||||||
println!("messages: {:?}", messages);
|
|
||||||
|
|
||||||
// Add messages to chat and associate with chat_id
|
// Add messages to chat and associate with chat_id
|
||||||
let mut updated_messages = Vec::new();
|
let mut updated_messages = Vec::new();
|
||||||
for mut message in messages {
|
for mut message in messages {
|
||||||
|
|
Loading…
Reference in New Issue