fix on redo message and the chat context

This commit is contained in:
dal 2025-04-16 17:41:40 -06:00
parent c19a5512fe
commit 96553aa2e0
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 105 additions and 45 deletions

View File

@ -1,6 +1,7 @@
use std::sync::Arc;
use std::collections::HashSet;
use std::sync::Arc;
use agents::{Agent, AgentMessage};
use anyhow::Result;
use async_trait::async_trait;
use database::{
@ -9,7 +10,6 @@ use database::{
};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use agents::{Agent, AgentMessage};
use middleware::AuthenticatedUser;
use serde_json::Value;
use uuid::Uuid;
@ -28,36 +28,56 @@ impl ChatContextLoader {
// Helper function to check for tool usage and set appropriate context
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
// Handle tool calls from assistant messages
if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
if let AgentMessage::Assistant {
tool_calls: Some(tool_calls),
..
} = message
{
for tool_call in tool_calls {
match tool_call.function.name.as_str() {
"search_data_catalog" => {
agent.set_state_value(String::from("data_context"), Value::Bool(true))
agent
.set_state_value(String::from("data_context"), Value::Bool(true))
.await;
},
}
"create_metrics" | "update_metrics" => {
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
agent
.set_state_value(String::from("metrics_available"), Value::Bool(true))
.await;
},
}
"create_dashboards" | "update_dashboards" => {
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
agent
.set_state_value(
String::from("dashboards_available"),
Value::Bool(true),
)
.await;
},
}
"import_assets" => {
// When we see import_assets, we need to check the content in the corresponding tool response
// This will be handled separately when processing tool messages
},
name if name.contains("file") || name.contains("read") || name.contains("write") || name.contains("edit") => {
agent.set_state_value(String::from("files_available"), Value::Bool(true))
}
name if name.contains("file")
|| name.contains("read")
|| name.contains("write")
|| name.contains("edit") =>
{
agent
.set_state_value(String::from("files_available"), Value::Bool(true))
.await;
},
}
_ => {}
}
}
}
// Handle tool responses - important for import_assets
if let AgentMessage::Tool { name: Some(tool_name), content, .. } = message {
if let AgentMessage::Tool {
name: Some(tool_name),
content,
..
} = message
{
if tool_name == "import_assets" {
// Parse the tool response to see what was imported
if let Ok(import_result) = serde_json::from_str::<serde_json::Value>(content) {
@ -65,7 +85,8 @@ impl ChatContextLoader {
if let Some(files) = import_result.get("files").and_then(|f| f.as_array()) {
if !files.is_empty() {
// Set files_available for any imported files
agent.set_state_value(String::from("files_available"), Value::Bool(true))
agent
.set_state_value(String::from("files_available"), Value::Bool(true))
.await;
// Check each file to determine its type
@ -75,22 +96,31 @@ impl ChatContextLoader {
for file in files {
// Check file_type/asset_type to determine what kind of asset this is
let file_type = file.get("file_type").and_then(|ft| ft.as_str())
let file_type = file
.get("file_type")
.and_then(|ft| ft.as_str())
.or_else(|| file.get("asset_type").and_then(|at| at.as_str()));
tracing::debug!("Processing imported file with type: {:?}", file_type);
tracing::debug!(
"Processing imported file with type: {:?}",
file_type
);
match file_type {
Some("metric") => {
has_metrics = true;
// Check if the metric has dataset references
if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) {
if yml_content.contains("dataset") || yml_content.contains("datasetIds") {
if let Some(yml_content) =
file.get("yml_content").and_then(|y| y.as_str())
{
if yml_content.contains("dataset")
|| yml_content.contains("datasetIds")
{
has_datasets = true;
}
}
},
}
Some("dashboard") => {
has_dashboards = true;
@ -98,14 +128,21 @@ impl ChatContextLoader {
has_metrics = true;
// Check if the dashboard has dataset references via metrics
if let Some(yml_content) = file.get("yml_content").and_then(|y| y.as_str()) {
if yml_content.contains("dataset") || yml_content.contains("datasetIds") {
if let Some(yml_content) =
file.get("yml_content").and_then(|y| y.as_str())
{
if yml_content.contains("dataset")
|| yml_content.contains("datasetIds")
{
has_datasets = true;
}
}
},
}
_ => {
tracing::debug!("Unknown file type in import_assets: {:?}", file_type);
tracing::debug!(
"Unknown file type in import_assets: {:?}",
file_type
);
}
}
}
@ -113,17 +150,29 @@ impl ChatContextLoader {
// Set appropriate state values based on what we found
if has_metrics {
tracing::debug!("Setting metrics_available state to true");
agent.set_state_value(String::from("metrics_available"), Value::Bool(true))
agent
.set_state_value(
String::from("metrics_available"),
Value::Bool(true),
)
.await;
}
if has_dashboards {
tracing::debug!("Setting dashboards_available state to true");
agent.set_state_value(String::from("dashboards_available"), Value::Bool(true))
agent
.set_state_value(
String::from("dashboards_available"),
Value::Bool(true),
)
.await;
}
if has_datasets {
tracing::debug!("Setting data_context state to true");
agent.set_state_value(String::from("data_context"), Value::Bool(true))
agent
.set_state_value(
String::from("data_context"),
Value::Bool(true),
)
.await;
}
}
@ -136,22 +185,33 @@ impl ChatContextLoader {
#[async_trait]
impl ContextLoader for ChatContextLoader {
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
async fn load_context(
&self,
user: &AuthenticatedUser,
agent: &Arc<Agent>,
) -> Result<Vec<AgentMessage>> {
let mut conn = get_pg_pool().get().await?;
// First verify the chat exists and user has access
let chat = chats::table
.filter(chats::id.eq(self.chat_id))
.filter(chats::created_by.eq(&user.id))
.filter(chats::deleted_at.is_null())
.first::<database::models::Chat>(&mut conn)
.await?;
// Get only the most recent message for the chat
let message = messages::table
let message = match messages::table
.filter(messages::chat_id.eq(chat.id))
.filter(messages::deleted_at.is_null())
.order_by(messages::created_at.desc())
.first::<database::models::Message>(&mut conn)
.await?;
.await
{
Ok(message) => message,
Err(diesel::NotFound) => return Ok(vec![]),
Err(e) => return Err(anyhow::anyhow!("Failed to get message: {}", e)),
};
// Track seen message IDs
let mut seen_ids = HashSet::new();
@ -159,7 +219,9 @@ impl ContextLoader for ChatContextLoader {
let mut agent_messages = Vec::new();
// Process only the most recent message's raw LLM messages
if let Ok(raw_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages) {
if let Ok(raw_messages) =
serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages)
{
// Check each message for tool calls and update context
for agent_message in &raw_messages {
Self::update_context_from_tool_calls(agent, agent_message).await;

View File

@ -228,8 +228,6 @@ pub async fn post_chat_handler(
let messages = generate_asset_messages(asset_id_value, asset_type_value, &user).await?;
println!("messages: {:?}", messages);
// Add messages to chat and associate with chat_id
let mut updated_messages = Vec::new();
for mut message in messages {