From 8899fb8549e340e614ee31eff91f1f3ae7e96870 Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 1 Apr 2025 11:26:29 -0600 Subject: [PATCH] timeout and file message --- .../handlers/src/chats/post_chat_handler.rs | 66 ++++++++++--------- api/src/routes/ws/ws.rs | 2 +- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index ff8633d38..4e0df3cf9 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -34,8 +34,9 @@ use uuid::Uuid; use crate::chats::{ asset_messages::{create_message_file_association, generate_asset_messages}, context_loaders::{ - chat_context::ChatContextLoader, create_asset_context_loader, dashboard_context::DashboardContextLoader, - fetch_asset_details, metric_context::MetricContextLoader, validate_context_request, ContextLoader, + chat_context::ChatContextLoader, create_asset_context_loader, + dashboard_context::DashboardContextLoader, fetch_asset_details, + metric_context::MetricContextLoader, validate_context_request, ContextLoader, }, get_chat_handler, streaming_parser::StreamingParser, @@ -156,12 +157,18 @@ pub async fn post_chat_handler( // Create a request-local chunk tracker instance instead of using global static let chunk_tracker = ChunkTracker::new(); let reasoning_duration = Instant::now(); - + // Normalize request to use asset_id/asset_type if legacy fields are provided let (asset_id, asset_type) = normalize_asset_fields(&request); - + // Validate that only one context type is provided - validate_context_request(request.chat_id, asset_id, asset_type, request.metric_id, request.dashboard_id)?; + validate_context_request( + request.chat_id, + asset_id, + asset_type, + request.metric_id, + request.dashboard_id, + )?; let user_org_id = match user.attributes.get("organization_id") { Some(Value::String(org_id)) => Uuid::parse_str(org_id).unwrap_or_default(), @@ -193,18 +200,14 @@ pub async fn post_chat_handler( if request.prompt.is_none() && asset_id.is_some() && asset_type.is_some() { let asset_id_value = asset_id.unwrap(); let asset_type_value = asset_type.unwrap(); - - let messages = generate_asset_messages( - asset_id_value, - asset_type_value, - &user, - ).await?; - + + let messages = generate_asset_messages(asset_id_value, asset_type_value, &user).await?; + // Add messages to chat and associate with chat_id let mut updated_messages = Vec::new(); for mut message in messages { message.chat_id = chat_id; - + // If this is a file message, create file association if message.response_messages.is_array() { let response_arr = message.response_messages.as_array().unwrap(); @@ -216,23 +219,24 @@ pub async fn post_chat_handler( message.id, asset_id_value, asset_type_value, - ).await; + ) + .await; } } } } - + // Insert message into database let mut conn = get_pg_pool().get().await?; insert_into(database::schema::messages::table) .values(&message) .execute(&mut conn) .await?; - + // Add to updated messages for the response updated_messages.push(message); } - + // Transform DB messages to ChatMessage format for response for message in updated_messages { let chat_message = ChatMessage::new_with_messages( @@ -249,10 +253,10 @@ pub async fn post_chat_handler( None, message.created_at, ); - + chat_with_messages.add_message(chat_message); } - + // Return early with auto-generated messages - no need for agent processing return Ok(chat_with_messages); } @@ -296,7 +300,9 @@ pub async fn post_chat_handler( } // Add the new user message (now with unwrap_or_default for optional prompt) - initial_messages.push(AgentMessage::user(request.prompt.clone().unwrap_or_default())); + initial_messages.push(AgentMessage::user( + request.prompt.clone().unwrap_or_default(), + )); // Initialize raw_llm_messages with initial_messages let mut raw_llm_messages = initial_messages.clone(); @@ -945,8 +951,8 @@ pub async fn transform_message( metadata: Some(vec![BusterChatResponseFileMetadata { status: "completed".to_string(), message: format!( - "File {} completed", - file_content.file_name + "Created new {}", + file_content.file_type ), timestamp: Some(Utc::now().timestamp()), }]), @@ -1031,8 +1037,8 @@ pub async fn transform_message( metadata: Some(vec![BusterChatResponseFileMetadata { status: "completed".to_string(), message: format!( - "File {} completed", - file_content.file_name + "Created new {}", + file_content.file_type ), timestamp: Some(Utc::now().timestamp()), }]), @@ -1836,11 +1842,11 @@ pub struct BusterGeneratingTitle { } /// Helper function to normalize legacy and new asset fields -/// +/// /// This function converts legacy asset fields (metric_id, dashboard_id) to the new /// generic asset_id/asset_type format. It ensures backward compatibility while /// using a single code path for processing assets. -/// +/// /// Returns a tuple of (Option, Option) representing the normalized /// asset reference. pub fn normalize_asset_fields(request: &ChatCreateNewChat) -> (Option, Option) { @@ -1848,16 +1854,16 @@ pub fn normalize_asset_fields(request: &ChatCreateNewChat) -> (Option, Opt if request.asset_id.is_some() && request.asset_type.is_some() { return (request.asset_id, request.asset_type); } - + // If legacy fields are provided, convert them to the new format if let Some(metric_id) = request.metric_id { return (Some(metric_id), Some(AssetType::MetricFile)); } - + if let Some(dashboard_id) = request.dashboard_id { return (Some(dashboard_id), Some(AssetType::DashboardFile)); } - + // No asset references (None, None) } @@ -1978,7 +1984,7 @@ async fn initialize_chat( user_org_id: Uuid, ) -> Result<(Uuid, Uuid, ChatWithMessages)> { let message_id = request.message_id.unwrap_or_else(Uuid::new_v4); - + // Get a default title for chats with optional prompt let default_title = match request.prompt { Some(ref prompt) => prompt.clone(), diff --git a/api/src/routes/ws/ws.rs b/api/src/routes/ws/ws.rs index 40905ddc4..f218728a8 100644 --- a/api/src/routes/ws/ws.rs +++ b/api/src/routes/ws/ws.rs @@ -36,7 +36,7 @@ use super::{ collections::collections_router::CollectionEvent, dashboards::dashboards_router::DashboardEvent, data_sources::data_sources_router::DataSourceEvent, datasets::datasets_router::DatasetEvent, metrics::MetricEvent, organizations::organization_router::OrganizationEvent, permissions::permissions_router::PermissionEvent, search::search_router::SearchEvent, sql::sql_router::SqlEvent, teams::teams_routes::TeamEvent, terms::terms_router::TermEvent, threads_and_messages::threads_router::ThreadEvent, users::users_router::UserEvent, ws_router::{ws_router, WsRoutes}, ws_utils::{subscribe_to_stream, unsubscribe_from_stream} }; -const CLIENT_TIMEOUT: Duration = Duration::from_secs(300); +const CLIENT_TIMEOUT: Duration = Duration::from_secs(900); const PING_INTERVAL: Duration = Duration::from_secs(15); const PING_TIMEOUT: Duration = Duration::from_secs(5);