mirror of https://github.com/buster-so/buster.git
fix on redo message and the chat context
This commit is contained in:
parent
c19a5512fe
commit
96553aa2e0
|
@ -1,6 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use agents::{Agent, AgentMessage};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use database::{
|
||||
|
@ -9,7 +10,6 @@ use database::{
|
|||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use agents::{Agent, AgentMessage};
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
@ -28,36 +28,56 @@ impl ChatContextLoader {
|
|||
// Helper function to check for tool usage and set appropriate context
|
||||
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
|
||||
// Handle tool calls from assistant messages
|
||||
if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
|
||||
if let AgentMessage::Assistant {
|
||||
tool_calls: Some(tool_calls),
|
||||
..
|
||||
} = message
|
||||
{
|
||||
for tool_call in tool_calls {
|
||||
match tool_call.function.name.as_str() {
|
||||
"search_data_catalog" => {
|
||||
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
}
|
||||
"create_metrics" | "update_metrics" => {
|
||||
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
}
|
||||
"create_dashboards" | "update_dashboards" => {
|
||||
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(
|
||||
String::from("dashboards_available"),
|
||||
Value::Bool(true),
|
||||
)
|
||||
.await;
|
||||
},
|
||||
}
|
||||
"import_assets" => {
|
||||
// When we see import_assets, we need to check the content in the corresponding tool response
|
||||
// This will be handled separately when processing tool messages
|
||||
},
|
||||
name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => {
|
||||
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
}
|
||||
name if name.contains("file")
|
||||
|| name.contains("read")
|
||||
|| name.contains("write")
|
||||
|| name.contains("edit") =>
|
||||
{
|
||||
agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
},
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool responses - important for import_assets
|
||||
if let AgentMessage::Tool { name: Some(tool_name), content, .. } = message {
|
||||
if let AgentMessage::Tool {
|
||||
name: Some(tool_name),
|
||||
content,
|
||||
..
|
||||
} = message
|
||||
{
|
||||
if tool_name == "import_assets" {
|
||||
// Parse the tool response to see what was imported
|
||||
if let Ok(import_result) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
|
@ -65,7 +85,8 @@ impl ChatContextLoader {
|
|||
if let Some(files) = import_result.get("files").and_then(|f| f.as_array()) {
|
||||
if !files.is_empty() {
|
||||
// Set files_available for any imported files
|
||||
agent.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(String::from("files_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
||||
// Check each file to determine its type
|
||||
|
@ -75,22 +96,31 @@ impl ChatContextLoader {
|
|||
|
||||
for file in files {
|
||||
// Check file_type/asset_type to determine what kind of asset this is
|
||||
let file_type = file.get("file_type").and_then(|ft| ft.as_str())
|
||||
let file_type = file
|
||||
.get("file_type")
|
||||
.and_then(|ft| ft.as_str())
|
||||
.or_else(|| file.get("asset_type").and_then(|at| at.as_str()));
|
||||
|
||||
tracing::debug!("Processing imported file with type: {:?}", file_type);
|
||||
tracing::debug!(
|
||||
"Processing imported file with type: {:?}",
|
||||
file_type
|
||||
);
|
||||
|
||||
match file_type {
|
||||
Some("metric") => {
|
||||
has_metrics = true;
|
||||
|
||||
// Check if the metric has dataset references
|
||||
if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) {
|
||||
if yml_content.contains("dataset") || yml_content.contains("datasetIds") {
|
||||
if let Some(yml_content) =
|
||||
file.get("yml_content").and_then(|y| y.as_str())
|
||||
{
|
||||
if yml_content.contains("dataset")
|
||||
|| yml_content.contains("datasetIds")
|
||||
{
|
||||
has_datasets = true;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
Some("dashboard") => {
|
||||
has_dashboards = true;
|
||||
|
||||
|
@ -98,14 +128,21 @@ impl ChatContextLoader {
|
|||
has_metrics = true;
|
||||
|
||||
// Check if the dashboard has dataset references via metrics
|
||||
if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) {
|
||||
if yml_content.contains("dataset") || yml_content.contains("datasetIds") {
|
||||
if let Some(yml_content) =
|
||||
file.get("yml_content").and_then(|y| y.as_str())
|
||||
{
|
||||
if yml_content.contains("dataset")
|
||||
|| yml_content.contains("datasetIds")
|
||||
{
|
||||
has_datasets = true;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!("Unknown file type in import_assets: {:?}", file_type);
|
||||
tracing::debug!(
|
||||
"Unknown file type in import_assets: {:?}",
|
||||
file_type
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -113,17 +150,29 @@ impl ChatContextLoader {
|
|||
// Set appropriate state values based on what we found
|
||||
if has_metrics {
|
||||
tracing::debug!("Setting metrics_available state to true");
|
||||
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(
|
||||
String::from("metrics_available"),
|
||||
Value::Bool(true),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
if has_dashboards {
|
||||
tracing::debug!("Setting dashboards_available state to true");
|
||||
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(
|
||||
String::from("dashboards_available"),
|
||||
Value::Bool(true),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
if has_datasets {
|
||||
tracing::debug!("Setting data_context state to true");
|
||||
agent.set_state_value(String::from("data_context"), Value::Bool(true))
|
||||
agent
|
||||
.set_state_value(
|
||||
String::from("data_context"),
|
||||
Value::Bool(true),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
@ -136,22 +185,33 @@ impl ChatContextLoader {
|
|||
|
||||
#[async_trait]
|
||||
impl ContextLoader for ChatContextLoader {
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||
async fn load_context(
|
||||
&self,
|
||||
user: &AuthenticatedUser,
|
||||
agent: &Arc<Agent>,
|
||||
) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
// First verify the chat exists and user has access
|
||||
let chat = chats::table
|
||||
.filter(chats::id.eq(self.chat_id))
|
||||
.filter(chats::created_by.eq(&user.id))
|
||||
.filter(chats::deleted_at.is_null())
|
||||
.first::<database::models::Chat>(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Get only the most recent message for the chat
|
||||
let message = messages::table
|
||||
let message = match messages::table
|
||||
.filter(messages::chat_id.eq(chat.id))
|
||||
.filter(messages::deleted_at.is_null())
|
||||
.order_by(messages::created_at.desc())
|
||||
.first::<database::models::Message>(&mut conn)
|
||||
.await?;
|
||||
.await
|
||||
{
|
||||
Ok(message) => message,
|
||||
Err(diesel::NotFound) => return Ok(vec![]),
|
||||
Err(e) => return Err(anyhow::anyhow!("Failed to get message: {}", e)),
|
||||
};
|
||||
|
||||
// Track seen message IDs
|
||||
let mut seen_ids = HashSet::new();
|
||||
|
@ -159,7 +219,9 @@ impl ContextLoader for ChatContextLoader {
|
|||
let mut agent_messages = Vec::new();
|
||||
|
||||
// Process only the most recent message's raw LLM messages
|
||||
if let Ok(raw_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages) {
|
||||
if let Ok(raw_messages) =
|
||||
serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages)
|
||||
{
|
||||
// Check each message for tool calls and update context
|
||||
for agent_message in &raw_messages {
|
||||
Self::update_context_from_tool_calls(agent, agent_message).await;
|
||||
|
|
|
@ -228,8 +228,6 @@ pub async fn post_chat_handler(
|
|||
|
||||
let messages = generate_asset_messages(asset_id_value, asset_type_value, &user).await?;
|
||||
|
||||
println!("messages: {:?}", messages);
|
||||
|
||||
// Add messages to chat and associate with chat_id
|
||||
let mut updated_messages = Vec::new();
|
||||
for mut message in messages {
|
||||
|
|
Loading…
Reference in New Issue