Add new thread events for chat generation progress tracking

This commit is contained in:
dal 2025-02-11 11:21:57 -07:00
parent e7b96d9bd5
commit bf05c7f06b
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
4 changed files with 45 additions and 13 deletions

View File

@ -6,6 +6,7 @@ use serde::Serialize;
use serde_json::Value;
use uuid::Uuid;
use crate::routes::ws::threads_and_messages::threads_router::ThreadEvent;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
use crate::utils::tools::file_tools::create_files::CreateFilesOutput;
@ -257,7 +258,7 @@ pub struct BusterFileLine {
pub text: String,
}
pub fn transform_message(message: Message) -> Result<Vec<BusterThreadMessage>> {
pub fn transform_message(message: Message) -> Result<(Vec<BusterThreadMessage>, ThreadEvent)> {
println!("transform_message: {:?}", message);
match message {
@ -270,11 +271,31 @@ pub fn transform_message(message: Message) -> Result<Vec<BusterThreadMessage>> {
initial,
} => {
if let Some(content) = content {
return transform_text_message(id, content, progress);
let messages = match transform_text_message(id, content, progress) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingResponseMessage,
));
}
if let Some(tool_calls) = tool_calls {
return transform_assistant_tool_message(id, tool_calls, progress, initial);
let messages = match transform_assistant_tool_message(id, tool_calls, progress, initial) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingReasoningMessage,
));
}
Err(anyhow::anyhow!("Assistant message missing required fields"))
@ -287,7 +308,17 @@ pub fn transform_message(message: Message) -> Result<Vec<BusterThreadMessage>> {
progress,
} => {
if let Some(name) = name {
return transform_tool_message(id, name, content, progress);
let messages = match transform_tool_message(id, name, content, progress) {
Ok(messages) => messages,
Err(e) => {
return Err(e);
}
};
return Ok((
messages,
ThreadEvent::GeneratingReasoningMessage,
));
}
Err(anyhow::anyhow!("Tool message missing name field"))
@ -902,11 +933,7 @@ fn tool_create_file(
let mut messages = Vec::new();
for file in create_files_result.files {
let (name, file_type, content) = (
file.name,
file.file_type,
file.yml_content
);
let (name, file_type, content) = (file.name, file.file_type, file.yml_content);
let mut current_lines = Vec::new();
for (i, line) in content.lines().enumerate() {

View File

@ -134,7 +134,7 @@ impl AgentThreadHandler {
let response = WsResponseMessage::new_no_user(
WsRoutes::Threads(ThreadRoute::Post),
WsEvent::Threads(ThreadEvent::PostThread),
WsEvent::Threads(ThreadEvent::InitializeChat),
init_response,
None,
WsSendMethod::All,
@ -175,11 +175,11 @@ impl AgentThreadHandler {
while let Some(msg_result) = rx.recv().await {
if let Ok(msg) = msg_result {
match transform_message(msg) {
Ok(transformed_messages) => {
Ok((transformed_messages, event)) => {
for transformed in transformed_messages {
let response = WsResponseMessage::new_no_user(
WsRoutes::Threads(ThreadRoute::Post),
WsEvent::Threads(ThreadEvent::PostThread),
WsEvent::Threads(event.clone()),
transformed,
None,
WsSendMethod::All,

View File

@ -70,6 +70,11 @@ pub enum ThreadEvent {
Unsubscribed,
DuplicateThread,
SqlEvaluation,
InitializeChat,
GeneratingTitle,
GeneratingResponseMessage,
GeneratingReasoningMessage,
Complete,
}
pub async fn threads_router(

2
cli/.gitignore vendored
View File

@ -16,4 +16,4 @@ Cargo.lock
Makefile
.vscode/
/prd
/prds