From bf05c7f06b0902a7def442038638a1bedaae5804 Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 11 Feb 2025 11:21:57 -0700 Subject: [PATCH] Add new thread events for chat generation progress tracking --- .../post_thread/agent_message_transformer.rs | 45 +++++++++++++++---- .../post_thread/agent_thread.rs | 6 +-- .../ws/threads_and_messages/threads_router.rs | 5 +++ cli/.gitignore | 2 +- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs index 0a24e19b6..e73b91251 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_message_transformer.rs @@ -6,6 +6,7 @@ use serde::Serialize; use serde_json::Value; use uuid::Uuid; +use crate::routes::ws::threads_and_messages::threads_router::ThreadEvent; use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall}; use crate::utils::tools::file_tools::create_files::CreateFilesOutput; @@ -257,7 +258,7 @@ pub struct BusterFileLine { pub text: String, } -pub fn transform_message(message: Message) -> Result> { +pub fn transform_message(message: Message) -> Result<(Vec, ThreadEvent)> { println!("transform_message: {:?}", message); match message { @@ -270,11 +271,31 @@ pub fn transform_message(message: Message) -> Result> { initial, } => { if let Some(content) = content { - return transform_text_message(id, content, progress); + let messages = match transform_text_message(id, content, progress) { + Ok(messages) => messages, + Err(e) => { + return Err(e); + } + }; + + return Ok(( + messages, + ThreadEvent::GeneratingResponseMessage, + )); } if let Some(tool_calls) = tool_calls { - return transform_assistant_tool_message(id, tool_calls, progress, initial); + let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) { + Ok(messages) => messages, + Err(e) => { + return Err(e); + } + }; + + return Ok(( + messages, + ThreadEvent::GeneratingReasoningMessage, + )); } Err(anyhow::anyhow!("Assistant message missing required fields")) @@ -287,7 +308,17 @@ pub fn transform_message(message: Message) -> Result> { progress, } => { if let Some(name) = name { - return transform_tool_message(id, name, content, progress); + let messages = match transform_tool_message(id, name, content, progress) { + Ok(messages) => messages, + Err(e) => { + return Err(e); + } + }; + + return Ok(( + messages, + ThreadEvent::GeneratingReasoningMessage, + )); } Err(anyhow::anyhow!("Tool message missing name field")) @@ -902,11 +933,7 @@ fn tool_create_file( let mut messages = Vec::new(); for file in create_files_result.files { - let (name, file_type, content) = ( - file.name, - file.file_type, - file.yml_content - ); + let (name, file_type, content) = (file.name, file.file_type, file.yml_content); let mut current_lines = Vec::new(); for (i, line) in content.lines().enumerate() { diff --git a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs index 4a3068833..158422dab 100644 --- a/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs +++ b/api/src/routes/ws/threads_and_messages/post_thread/agent_thread.rs @@ -134,7 +134,7 @@ impl AgentThreadHandler { let response = WsResponseMessage::new_no_user( WsRoutes::Threads(ThreadRoute::Post), - WsEvent::Threads(ThreadEvent::PostThread), + WsEvent::Threads(ThreadEvent::InitializeChat), init_response, None, WsSendMethod::All, @@ -175,11 +175,11 @@ impl AgentThreadHandler { while let Some(msg_result) = rx.recv().await { if let Ok(msg) = msg_result { match transform_message(msg) { - Ok(transformed_messages) => { + Ok((transformed_messages, event)) => { for transformed in transformed_messages { let response = WsResponseMessage::new_no_user( WsRoutes::Threads(ThreadRoute::Post), - WsEvent::Threads(ThreadEvent::PostThread), + WsEvent::Threads(event.clone()), transformed, None, WsSendMethod::All, diff --git a/api/src/routes/ws/threads_and_messages/threads_router.rs b/api/src/routes/ws/threads_and_messages/threads_router.rs index 0fb700006..99f86e2a7 100644 --- a/api/src/routes/ws/threads_and_messages/threads_router.rs +++ b/api/src/routes/ws/threads_and_messages/threads_router.rs @@ -70,6 +70,11 @@ pub enum ThreadEvent { Unsubscribed, DuplicateThread, SqlEvaluation, + InitializeChat, + GeneratingTitle, + GeneratingResponseMessage, + GeneratingReasoningMessage, + Complete, } pub async fn threads_router( diff --git a/cli/.gitignore b/cli/.gitignore index 3be7153a3..6566ae633 100644 --- a/cli/.gitignore +++ b/cli/.gitignore @@ -16,4 +16,4 @@ Cargo.lock Makefile .vscode/ -/prd \ No newline at end of file +/prds \ No newline at end of file