consistent message id for text stream

This commit is contained in:
dal 2025-02-11 08:10:58 -07:00
parent bcf1ac1a65
commit 8de08323fa
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
3 changed files with 52 additions and 32 deletions

View File

@ -10,8 +10,8 @@ use crate::{
database::models::User,
routes::ws::{
threads_and_messages::{
threads_router::{ThreadEvent, ThreadRoute},
post_thread::agent_message_transformer::transform_message,
threads_router::{ThreadEvent, ThreadRoute},
},
ws::{WsEvent, WsResponseMessage, WsSendMethod},
ws_router::WsRoutes,
@ -115,8 +115,11 @@ impl AgentThreadHandler {
) {
let subscription = user_id.to_string();
let message_id = Uuid::new_v4().to_string();
while let Some(msg_result) = rx.recv().await {
if let Ok(msg) = msg_result {
if let Ok(mut msg) = msg_result {
msg.set_id(message_id.clone());
match transform_message(msg) {
Ok(transformed_messages) => {
for transformed in transformed_messages {

View File

@ -218,6 +218,15 @@ impl Message {
_ => None,
}
}
pub fn set_id(&mut self, new_id: String) {
match self {
Self::Assistant { id, .. } => *id = Some(new_id.clone()),
Self::Tool { id, .. } => *id = Some(new_id.clone()),
Self::Developer { id, .. } => *id = Some(new_id.clone()),
Self::User { id, .. } => *id = Some(new_id),
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]

View File

@ -4,10 +4,14 @@ use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::{HashMap, HashSet}, time::Instant};
use std::{
collections::{HashMap, HashSet},
time::Instant,
};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use super::FileModificationTool;
use crate::{
database::{
lib::get_pg_pool,
@ -22,7 +26,6 @@ use crate::{
tools::ToolExecutor,
},
};
use super::FileModificationTool;
#[derive(Debug, Serialize, Deserialize)]
struct FileRequest {
@ -64,12 +67,11 @@ impl ToolExecutor for OpenFilesTool {
let start_time = Instant::now();
debug!("Starting file open operation");
let params: OpenFilesParams =
serde_json::from_str(&tool_call.function.arguments.clone())
.map_err(|e| {
error!(error = %e, "Failed to parse tool parameters");
anyhow::anyhow!("Failed to parse tool parameters: {}", e)
})?;
let params: OpenFilesParams = serde_json::from_str(&tool_call.function.arguments.clone())
.map_err(|e| {
error!(error = %e, "Failed to parse tool parameters");
anyhow::anyhow!("Failed to parse tool parameters: {}", e)
})?;
let mut results = Vec::new();
let mut error_messages = Vec::new();
@ -174,7 +176,11 @@ impl ToolExecutor for OpenFilesTool {
let duration = start_time.elapsed().as_millis();
Ok(OpenFilesOutput { message, duration: duration as i64, results })
Ok(OpenFilesOutput {
message,
duration: duration as i64,
results,
})
}
fn get_schema(&self) -> Value {
@ -212,13 +218,10 @@ impl ToolExecutor for OpenFilesTool {
async fn get_dashboard_files(ids: &[Uuid]) -> Result<Vec<(DashboardYml, Uuid, String)>> {
debug!(dashboard_ids = ?ids, "Fetching dashboard files");
let mut conn = get_pg_pool()
.get()
.await
.map_err(|e| {
error!(error = %e, "Failed to get database connection");
anyhow::anyhow!("Failed to get database connection: {}", e)
})?;
let mut conn = get_pg_pool().get().await.map_err(|e| {
error!(error = %e, "Failed to get database connection");
anyhow::anyhow!("Failed to get database connection: {}", e)
})?;
let files = match dashboard_files::table
.filter(dashboard_files::id.eq_any(ids))
@ -227,7 +230,10 @@ async fn get_dashboard_files(ids: &[Uuid]) -> Result<Vec<(DashboardYml, Uuid, St
.await
{
Ok(files) => {
debug!(count = files.len(), "Successfully loaded dashboard files from database");
debug!(
count = files.len(),
"Successfully loaded dashboard files from database"
);
files
}
Err(e) => {
@ -266,13 +272,10 @@ async fn get_dashboard_files(ids: &[Uuid]) -> Result<Vec<(DashboardYml, Uuid, St
async fn get_metric_files(ids: &[Uuid]) -> Result<Vec<(MetricYml, Uuid, String)>> {
debug!(metric_ids = ?ids, "Fetching metric files");
let mut conn = get_pg_pool()
.get()
.await
.map_err(|e| {
error!(error = %e, "Failed to get database connection");
anyhow::anyhow!("Failed to get database connection: {}", e)
})?;
let mut conn = get_pg_pool().get().await.map_err(|e| {
error!(error = %e, "Failed to get database connection");
anyhow::anyhow!("Failed to get database connection: {}", e)
})?;
let files = match metric_files::table
.filter(metric_files::id.eq_any(ids))
@ -281,7 +284,10 @@ async fn get_metric_files(ids: &[Uuid]) -> Result<Vec<(MetricYml, Uuid, String)>
.await
{
Ok(files) => {
debug!(count = files.len(), "Successfully loaded metric files from database");
debug!(
count = files.len(),
"Successfully loaded metric files from database"
);
files
}
Err(e) => {
@ -356,7 +362,9 @@ fn build_status_message(
#[cfg(test)]
mod tests {
use crate::utils::tools::file_tools::file_types::metric_yml::{BarLineChartConfig, BaseChartConfig, BarAndLineAxis, ChartConfig, DataMetadata};
use crate::utils::tools::file_tools::file_types::metric_yml::{
BarAndLineAxis, BarLineChartConfig, BaseChartConfig, ChartConfig, DataMetadata,
};
use super::*;
use chrono::Utc;
@ -366,7 +374,7 @@ mod tests {
DashboardYml {
id: Some(Uuid::new_v4()),
updated_at: Some(Utc::now()),
name: Some("Test Dashboard".to_string()),
name: "Test Dashboard".to_string(),
rows: vec![],
}
}
@ -414,7 +422,7 @@ mod tests {
DataMetadata {
name: "value".to_string(),
data_type: "string".to_string(),
}
},
]),
}
}
@ -533,7 +541,7 @@ mod tests {
let dashboard = create_test_dashboard();
let test_files = vec![DashboardFile {
id: test_id,
name: dashboard.name.clone().unwrap_or_default(),
name: dashboard.name.clone(),
file_name: "test.yml".to_string(),
content: serde_json::to_value(&dashboard).unwrap(),
filter: None,