From 8de08323fa39d3325912facb049fc34b792b1d5b Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 11 Feb 2025 08:10:58 -0700 Subject: [PATCH] consistent message id for text stream --- .../post_thread/agent_thread.rs | 7 +- api/src/utils/clients/ai/litellm/types.rs | 9 +++ api/src/utils/tools/file_tools/open_files.rs | 68 +++++++++++-------- 3 files changed, 52 insertions(+), 32 deletions(-) diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs index 0909d73ab..343b1c928 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs @@ -10,8 +10,8 @@ use crate::{ database::models::User, routes::ws::{ threads_and_messages::{ - threads_router::{ThreadEvent, ThreadRoute}, post_thread::agent_message_transformer::transform_message, + threads_router::{ThreadEvent, ThreadRoute}, }, ws::{WsEvent, WsResponseMessage, WsSendMethod}, ws_router::WsRoutes, @@ -115,8 +115,11 @@ impl AgentThreadHandler { ) { let subscription = user_id.to_string(); + let message_id = Uuid::new_v4().to_string(); + while let Some(msg_result) = rx.recv().await { - if let Ok(msg) = msg_result { + if let Ok(mut msg) = msg_result { + msg.set_id(message_id.clone()); match transform_message(msg) { Ok(transformed_messages) => { for transformed in transformed_messages { diff --git a/api/src/utils/clients/ai/litellm/types.rs b/api/src/utils/clients/ai/litellm/types.rs index ce678d576..86a17ca43 100644 --- a/api/src/utils/clients/ai/litellm/types.rs +++ b/api/src/utils/clients/ai/litellm/types.rs @@ -218,6 +218,15 @@ impl Message { _ => None, } } + + pub fn set_id(&mut self, new_id: String) { + match self { + Self::Assistant { id, .. } => *id = Some(new_id.clone()), + Self::Tool { id, .. } => *id = Some(new_id.clone()), + Self::Developer { id, .. } => *id = Some(new_id.clone()), + Self::User { id, .. } => *id = Some(new_id), + } + } } #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/api/src/utils/tools/file_tools/open_files.rs b/api/src/utils/tools/file_tools/open_files.rs index b23a92215..ef6f600bd 100644 --- a/api/src/utils/tools/file_tools/open_files.rs +++ b/api/src/utils/tools/file_tools/open_files.rs @@ -4,10 +4,14 @@ use diesel::prelude::*; use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::{collections::{HashMap, HashSet}, time::Instant}; +use std::{ + collections::{HashMap, HashSet}, + time::Instant, +}; use tracing::{debug, error, info, warn}; use uuid::Uuid; +use super::FileModificationTool; use crate::{ database::{ lib::get_pg_pool, @@ -22,7 +26,6 @@ use crate::{ tools::ToolExecutor, }, }; -use super::FileModificationTool; #[derive(Debug, Serialize, Deserialize)] struct FileRequest { @@ -64,12 +67,11 @@ impl ToolExecutor for OpenFilesTool { let start_time = Instant::now(); debug!("Starting file open operation"); - let params: OpenFilesParams = - serde_json::from_str(&tool_call.function.arguments.clone()) - .map_err(|e| { - error!(error = %e, "Failed to parse tool parameters"); - anyhow::anyhow!("Failed to parse tool parameters: {}", e) - })?; + let params: OpenFilesParams = serde_json::from_str(&tool_call.function.arguments.clone()) + .map_err(|e| { + error!(error = %e, "Failed to parse tool parameters"); + anyhow::anyhow!("Failed to parse tool parameters: {}", e) + })?; let mut results = Vec::new(); let mut error_messages = Vec::new(); @@ -174,7 +176,11 @@ impl ToolExecutor for OpenFilesTool { let duration = start_time.elapsed().as_millis(); - Ok(OpenFilesOutput { message, duration: duration as i64, results }) + Ok(OpenFilesOutput { + message, + duration: duration as i64, + results, + }) } fn get_schema(&self) -> Value { @@ -212,13 +218,10 @@ impl ToolExecutor for OpenFilesTool { async fn get_dashboard_files(ids: &[Uuid]) -> Result> { debug!(dashboard_ids = ?ids, "Fetching dashboard files"); - let mut conn = get_pg_pool() - .get() - .await - .map_err(|e| { - error!(error = %e, "Failed to get database connection"); - anyhow::anyhow!("Failed to get database connection: {}", e) - })?; + let mut conn = get_pg_pool().get().await.map_err(|e| { + error!(error = %e, "Failed to get database connection"); + anyhow::anyhow!("Failed to get database connection: {}", e) + })?; let files = match dashboard_files::table .filter(dashboard_files::id.eq_any(ids)) @@ -227,7 +230,10 @@ async fn get_dashboard_files(ids: &[Uuid]) -> Result { - debug!(count = files.len(), "Successfully loaded dashboard files from database"); + debug!( + count = files.len(), + "Successfully loaded dashboard files from database" + ); files } Err(e) => { @@ -266,13 +272,10 @@ async fn get_dashboard_files(ids: &[Uuid]) -> Result Result> { debug!(metric_ids = ?ids, "Fetching metric files"); - let mut conn = get_pg_pool() - .get() - .await - .map_err(|e| { - error!(error = %e, "Failed to get database connection"); - anyhow::anyhow!("Failed to get database connection: {}", e) - })?; + let mut conn = get_pg_pool().get().await.map_err(|e| { + error!(error = %e, "Failed to get database connection"); + anyhow::anyhow!("Failed to get database connection: {}", e) + })?; let files = match metric_files::table .filter(metric_files::id.eq_any(ids)) @@ -281,7 +284,10 @@ async fn get_metric_files(ids: &[Uuid]) -> Result .await { Ok(files) => { - debug!(count = files.len(), "Successfully loaded metric files from database"); + debug!( + count = files.len(), + "Successfully loaded metric files from database" + ); files } Err(e) => { @@ -356,7 +362,9 @@ fn build_status_message( #[cfg(test)] mod tests { - use crate::utils::tools::file_tools::file_types::metric_yml::{BarLineChartConfig, BaseChartConfig, BarAndLineAxis, ChartConfig, DataMetadata}; + use crate::utils::tools::file_tools::file_types::metric_yml::{ + BarAndLineAxis, BarLineChartConfig, BaseChartConfig, ChartConfig, DataMetadata, + }; use super::*; use chrono::Utc; @@ -366,7 +374,7 @@ mod tests { DashboardYml { id: Some(Uuid::new_v4()), updated_at: Some(Utc::now()), - name: Some("Test Dashboard".to_string()), + name: "Test Dashboard".to_string(), rows: vec![], } } @@ -412,9 +420,9 @@ mod tests { data_type: "number".to_string(), }, DataMetadata { - name: "value".to_string(), + name: "value".to_string(), data_type: "string".to_string(), - } + }, ]), } } @@ -533,7 +541,7 @@ mod tests { let dashboard = create_test_dashboard(); let test_files = vec![DashboardFile { id: test_id, - name: dashboard.name.clone().unwrap_or_default(), + name: dashboard.name.clone(), file_name: "test.yml".to_string(), content: serde_json::to_value(&dashboard).unwrap(), filter: None,