mirror of https://github.com/buster-so/buster.git
ids and initial message repeat handled
This commit is contained in:
parent
973e9b41ce
commit
2376153459
|
@ -8,12 +8,12 @@ use uuid::Uuid;
|
||||||
|
|
||||||
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
|
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
|
||||||
|
|
||||||
|
use crate::utils::tools::file_tools::create_files::CreateFilesOutput;
|
||||||
use crate::utils::tools::file_tools::file_types::file::FileEnum;
|
use crate::utils::tools::file_tools::file_types::file::FileEnum;
|
||||||
use crate::utils::tools::file_tools::modify_files::ModifyFilesParams;
|
use crate::utils::tools::file_tools::modify_files::ModifyFilesParams;
|
||||||
use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
|
use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
|
||||||
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
|
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
|
||||||
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
|
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
|
||||||
use crate::utils::tools::file_tools::create_files::CreateFilesOutput;
|
|
||||||
|
|
||||||
struct StreamingParser {
|
struct StreamingParser {
|
||||||
buffer: String,
|
buffer: String,
|
||||||
|
@ -310,12 +310,14 @@ fn transform_text_message(
|
||||||
message_chunk: Some(content),
|
message_chunk: Some(content),
|
||||||
})])
|
})])
|
||||||
}
|
}
|
||||||
MessageProgress::Complete => Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
MessageProgress::Complete => {
|
||||||
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
|
||||||
message_type: "text".to_string(),
|
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||||
message: Some(content),
|
message_type: "text".to_string(),
|
||||||
message_chunk: None,
|
message: Some(content),
|
||||||
})]),
|
message_chunk: None,
|
||||||
|
})])
|
||||||
|
}
|
||||||
_ => Err(anyhow::anyhow!("Unsupported message progress")),
|
_ => Err(anyhow::anyhow!("Unsupported message progress")),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -352,12 +354,12 @@ fn transform_assistant_tool_message(
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterThreadMessage>> {
|
||||||
if let Some(tool_call) = tool_calls.first() {
|
if let Some(tool_call) = tool_calls.first() {
|
||||||
match tool_call.function.name.as_str() {
|
match tool_call.function.name.as_str() {
|
||||||
"search_data_catalog" => assistant_data_catalog_search(id, progress),
|
"search_data_catalog" => assistant_data_catalog_search(id, tool_calls, progress),
|
||||||
"stored_values_search" => assistant_stored_values_search(id, progress),
|
"stored_values_search" => assistant_stored_values_search(id, tool_calls, progress),
|
||||||
"search_files" => assistant_file_search(id, progress),
|
"search_files" => assistant_file_search(id, tool_calls, progress),
|
||||||
"create_files" => assistant_create_file(id, tool_calls, progress),
|
"create_files" => assistant_create_file(id, tool_calls, progress),
|
||||||
"modify_files" => assistant_modify_file(id, tool_calls, progress),
|
"modify_files" => assistant_modify_file(id, tool_calls, progress),
|
||||||
"open_files" => assistant_open_files(id, progress),
|
"open_files" => assistant_open_files(id, tool_calls, progress),
|
||||||
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
_ => Err(anyhow::anyhow!("Unsupported tool name")),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -367,6 +369,7 @@ fn transform_assistant_tool_message(
|
||||||
|
|
||||||
fn assistant_data_catalog_search(
|
fn assistant_data_catalog_search(
|
||||||
id: Option<String>,
|
id: Option<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterThreadMessage>> {
|
||||||
if let Some(progress) = progress {
|
if let Some(progress) = progress {
|
||||||
|
@ -510,6 +513,7 @@ fn proccess_data_catalog_search_results(
|
||||||
|
|
||||||
fn assistant_stored_values_search(
|
fn assistant_stored_values_search(
|
||||||
id: Option<String>,
|
id: Option<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterThreadMessage>> {
|
||||||
if let Some(progress) = progress {
|
if let Some(progress) = progress {
|
||||||
|
@ -562,6 +566,7 @@ fn tool_stored_values_search(
|
||||||
|
|
||||||
fn assistant_file_search(
|
fn assistant_file_search(
|
||||||
id: Option<String>,
|
id: Option<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterThreadMessage>> {
|
||||||
if let Some(progress) = progress {
|
if let Some(progress) = progress {
|
||||||
|
@ -685,6 +690,7 @@ fn process_file_search_results(
|
||||||
|
|
||||||
fn assistant_open_files(
|
fn assistant_open_files(
|
||||||
id: Option<String>,
|
id: Option<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
progress: Option<MessageProgress>,
|
progress: Option<MessageProgress>,
|
||||||
) -> Result<Vec<BusterThreadMessage>> {
|
) -> Result<Vec<BusterThreadMessage>> {
|
||||||
if let Some(progress) = progress {
|
if let Some(progress) = progress {
|
||||||
|
|
|
@ -253,95 +253,91 @@ impl Agent {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// First, make request with tool_choice set to none
|
let mut tool_thread = thread.clone();
|
||||||
let initial_request = ChatCompletionRequest {
|
|
||||||
model: model.to_string(),
|
|
||||||
messages: thread.messages.clone(),
|
|
||||||
tools: if tools.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(tools.clone())
|
|
||||||
},
|
|
||||||
tool_choice: Some(ToolChoice::None("none".to_string())),
|
|
||||||
stream: Some(true),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get streaming response for initial thoughts
|
// Only do initial message phase if this is the first call (recursion_depth = 0)
|
||||||
let mut initial_stream = llm_client.stream_chat_completion(initial_request).await?;
|
if recursion_depth == 0 {
|
||||||
let mut initial_message = Message::assistant(
|
// First, make request with tool_choice set to none
|
||||||
None,
|
let initial_request = ChatCompletionRequest {
|
||||||
Some(String::new()),
|
model: model.to_string(),
|
||||||
None,
|
messages: thread.messages.clone(),
|
||||||
None,
|
tools: if tools.is_empty() {
|
||||||
);
|
None
|
||||||
|
} else {
|
||||||
|
Some(tools.clone())
|
||||||
|
},
|
||||||
|
tool_choice: Some(ToolChoice::None("none".to_string())),
|
||||||
|
stream: Some(true),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
// Process initial stream chunks
|
// Get streaming response for initial thoughts
|
||||||
while let Some(chunk_result) = initial_stream.recv().await {
|
let mut initial_stream = llm_client.stream_chat_completion(initial_request).await?;
|
||||||
match chunk_result {
|
let mut initial_message = Message::assistant(
|
||||||
Ok(chunk) => {
|
None,
|
||||||
initial_message.set_id(chunk.id.clone());
|
Some(String::new()),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let delta = &chunk.choices[0].delta;
|
// Process initial stream chunks
|
||||||
|
while let Some(chunk_result) = initial_stream.recv().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
initial_message.set_id(chunk.id.clone());
|
||||||
|
|
||||||
// Handle content updates - send delta directly
|
let delta = &chunk.choices[0].delta;
|
||||||
if let Some(content) = &delta.content {
|
|
||||||
// Send the delta chunk immediately with InProgress
|
|
||||||
let _ = tx
|
|
||||||
.send(Ok(Message::assistant(
|
|
||||||
Some("initial_message".to_string()),
|
|
||||||
Some(content.clone()),
|
|
||||||
None,
|
|
||||||
Some(MessageProgress::InProgress),
|
|
||||||
)))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// Also accumulate for our thread history
|
// Handle content updates - send delta directly
|
||||||
if let Message::Assistant {
|
if let Some(content) = &delta.content {
|
||||||
content: msg_content,
|
// Send the delta chunk immediately with InProgress
|
||||||
..
|
let _ = tx
|
||||||
} = &mut initial_message
|
.send(Ok(Message::assistant(
|
||||||
{
|
Some("initial_message".to_string()),
|
||||||
if let Some(existing) = msg_content {
|
Some(content.clone()),
|
||||||
existing.push_str(content);
|
None,
|
||||||
|
Some(MessageProgress::InProgress),
|
||||||
|
)))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Also accumulate for our thread history
|
||||||
|
if let Message::Assistant {
|
||||||
|
content: msg_content,
|
||||||
|
..
|
||||||
|
} = &mut initial_message
|
||||||
|
{
|
||||||
|
if let Some(existing) = msg_content {
|
||||||
|
existing.push_str(content);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
Err(e) => {
|
||||||
Err(e) => {
|
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
||||||
let _ = tx.send(Err(anyhow::Error::from(e))).await;
|
return Ok(());
|
||||||
return Ok(());
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure we have content in the initial message
|
||||||
|
let initial_content = match &initial_message {
|
||||||
|
Message::Assistant { content, .. } => content.clone().unwrap_or_default(),
|
||||||
|
_ => String::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Send the complete message with accumulated content
|
||||||
|
if !initial_content.trim().is_empty() {
|
||||||
|
let _ = tx
|
||||||
|
.send(Ok(Message::assistant(
|
||||||
|
Some("initial_message".to_string()),
|
||||||
|
Some(initial_content.clone()),
|
||||||
|
None,
|
||||||
|
Some(MessageProgress::Complete),
|
||||||
|
)))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have content in the initial message
|
|
||||||
let initial_content = match &initial_message {
|
|
||||||
Message::Assistant { content, .. } => content.clone().unwrap_or_default(),
|
|
||||||
_ => String::new(),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Send the complete message with accumulated content
|
|
||||||
if !initial_content.trim().is_empty() {
|
|
||||||
let _ = tx
|
|
||||||
.send(Ok(Message::assistant(
|
|
||||||
Some("initial_message".to_string()),
|
|
||||||
Some(initial_content.clone()),
|
|
||||||
None,
|
|
||||||
Some(MessageProgress::Complete),
|
|
||||||
)))
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new thread with the initial response
|
|
||||||
let mut tool_thread = thread.clone();
|
|
||||||
tool_thread.messages.push(Message::assistant(
|
|
||||||
Some("initial_message".to_string()),
|
|
||||||
Some(initial_content.clone()),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
|
|
||||||
// Create the tool-enabled request
|
// Create the tool-enabled request
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
|
|
Loading…
Reference in New Issue