Add new thread events for chat generation progress tracking

This commit is contained in:
dal 2025-02-11 11:21:57 -07:00
parent e7b96d9bd5
commit bf05c7f06b
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
4 changed files with 45 additions and 13 deletions

View File

@ -6,6 +6,7 @@ use serde::Serialize;
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use crate::routes::ws::threads_and_messages::threads_router::ThreadEvent;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall}; use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
use crate::utils::tools::file_tools::create_files::CreateFilesOutput; use crate::utils::tools::file_tools::create_files::CreateFilesOutput;
@ -257,7 +258,7 @@ pub struct BusterFileLine {
pub text: String, pub text: String,
} }
pub fn transform_message(message: Message) -> Result<Vec<BusterThreadMessage>> { pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>, ThreadEvent)> {
println!("transform_message: {:?}", message); println!("transform_message: {:?}", message);
match message { match message {
@ -270,11 +271,31 @@ pub fn transform_message(message: Message) -> Result<Vec<BusterThreadMessage>> {
initial, initial,
} => { } => {
if let Some(content) = content { if let Some(content) = content {
return transform_text_message(id, content, progress); let messages = match transform_text_message(id, content, progress) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingResponseMessage,
));
} }
if let Some(tool_calls) = tool_calls { if let Some(tool_calls) = tool_calls {
return transform_assistant_tool_message(id, tool_calls, progress, initial); let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingReasoningMessage,
));
} }
Err(anyhow::anyhow!("Assistant message missing required fields")) Err(anyhow::anyhow!("Assistant message missing required fields"))
@ -287,7 +308,17 @@ pub fn transform_message(message: Message) -> Result<Vec<BusterThreadMessage>> {
progress, progress,
} => { } => {
if let Some(name) = name { if let Some(name) = name {
return transform_tool_message(id, name, content, progress); let messages = match transform_tool_message(id, name, content, progress) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingReasoningMessage,
));
} }
Err(anyhow::anyhow!("Tool message missing name field")) Err(anyhow::anyhow!("Tool message missing name field"))
@ -902,11 +933,7 @@ fn tool_create_file(
let mut messages = Vec::new(); let mut messages = Vec::new();
for file in create_files_result.files { for file in create_files_result.files {
let (name, file_type, content) = ( let (name, file_type, content) = (file.name, file.file_type, file.yml_content);
file.name,
file.file_type,
file.yml_content
);
let mut current_lines = Vec::new(); let mut current_lines = Vec::new();
for (i, line) in content.lines().enumerate() { for (i, line) in content.lines().enumerate() {

View File

@ -134,7 +134,7 @@ impl AgentThreadHandler {
let response = WsResponseMessage::new_no_user( let response = WsResponseMessage::new_no_user(
WsRoutes::Threads(ThreadRoute::Post), WsRoutes::Threads(ThreadRoute::Post),
WsEvent::Threads(ThreadEvent::PostThread), WsEvent::Threads(ThreadEvent::InitializeChat),
init_response, init_response,
None, None,
WsSendMethod::All, WsSendMethod::All,
@ -175,11 +175,11 @@ impl AgentThreadHandler {
while let Some(msg_result) = rx.recv().await { while let Some(msg_result) = rx.recv().await {
if let Ok(msg) = msg_result { if let Ok(msg) = msg_result {
match transform_message(msg) { match transform_message(msg) {
Ok(transformed_messages) => { Ok((transformed_messages, event)) => {
for transformed in transformed_messages { for transformed in transformed_messages {
let response = WsResponseMessage::new_no_user( let response = WsResponseMessage::new_no_user(
WsRoutes::Threads(ThreadRoute::Post), WsRoutes::Threads(ThreadRoute::Post),
WsEvent::Threads(ThreadEvent::PostThread), WsEvent::Threads(event.clone()),
transformed, transformed,
None, None,
WsSendMethod::All, WsSendMethod::All,

View File

@ -70,6 +70,11 @@ pub enum ThreadEvent {
Unsubscribed, Unsubscribed,
DuplicateThread, DuplicateThread,
SqlEvaluation, SqlEvaluation,
InitializeChat,
GeneratingTitle,
GeneratingResponseMessage,
GeneratingReasoningMessage,
Complete,
} }
pub async fn threads_router( pub async fn threads_router(

2
cli/.gitignore vendored
View File

@ -16,4 +16,4 @@ Cargo.lock
Makefile Makefile
.vscode/ .vscode/
/prd /prds