ids and initial message repeat handled

This commit is contained in:
dal 2025-02-11 09:36:28 -07:00
parent 973e9b41ce
commit 2376153459
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 94 additions and 92 deletions

View File

@ -8,12 +8,12 @@ use uuid::Uuid;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
use crate::utils::tools::file_tools::create_files::CreateFilesOutput;
use crate::utils::tools::file_tools::file_types::file::FileEnum;
use crate::utils::tools::file_tools::modify_files::ModifyFilesParams;
use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
use crate::utils::tools::file_tools::create_files::CreateFilesOutput;
struct StreamingParser {
buffer: String,
@ -310,12 +310,14 @@ fn transform_text_message(
message_chunk: Some(content),
})])
}
MessageProgress::Complete => Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
MessageProgress::Complete => {
Ok(vec![BusterThreadMessage::ChatMessage(BusterChatMessage {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
message_type: "text".to_string(),
message: Some(content),
message_chunk: None,
})]),
})])
}
_ => Err(anyhow::anyhow!("Unsupported message progress")),
}
} else {
@ -352,12 +354,12 @@ fn transform_assistant_tool_message(
) -> Result<Vec<BusterThreadMessage>> {
if let Some(tool_call) = tool_calls.first() {
match tool_call.function.name.as_str() {
"search_data_catalog" => assistant_data_catalog_search(id, progress),
"stored_values_search" => assistant_stored_values_search(id, progress),
"search_files" => assistant_file_search(id, progress),
"search_data_catalog" => assistant_data_catalog_search(id, tool_calls, progress),
"stored_values_search" => assistant_stored_values_search(id, tool_calls, progress),
"search_files" => assistant_file_search(id, tool_calls, progress),
"create_files" => assistant_create_file(id, tool_calls, progress),
"modify_files" => assistant_modify_file(id, tool_calls, progress),
"open_files" => assistant_open_files(id, progress),
"open_files" => assistant_open_files(id, tool_calls, progress),
_ => Err(anyhow::anyhow!("Unsupported tool name")),
}
} else {
@ -367,6 +369,7 @@ fn transform_assistant_tool_message(
fn assistant_data_catalog_search(
id: Option<String>,
tool_calls: Vec<ToolCall>,
progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
if let Some(progress) = progress {
@ -510,6 +513,7 @@ fn proccess_data_catalog_search_results(
fn assistant_stored_values_search(
id: Option<String>,
tool_calls: Vec<ToolCall>,
progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
if let Some(progress) = progress {
@ -562,6 +566,7 @@ fn tool_stored_values_search(
fn assistant_file_search(
id: Option<String>,
tool_calls: Vec<ToolCall>,
progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
if let Some(progress) = progress {
@ -685,6 +690,7 @@ fn process_file_search_results(
fn assistant_open_files(
id: Option<String>,
tool_calls: Vec<ToolCall>,
progress: Option<MessageProgress>,
) -> Result<Vec<BusterThreadMessage>> {
if let Some(progress) = progress {

View File

@ -253,6 +253,10 @@ impl Agent {
})
.collect();
let mut tool_thread = thread.clone();
// Only do initial message phase if this is the first call (recursion_depth = 0)
if recursion_depth == 0 {
// First, make request with tool_choice set to none
let initial_request = ChatCompletionRequest {
model: model.to_string(),
@ -332,15 +336,7 @@ impl Agent {
)))
.await;
}
// Create a new thread with the initial response
let mut tool_thread = thread.clone();
tool_thread.messages.push(Message::assistant(
Some("initial_message".to_string()),
Some(initial_content.clone()),
None,
None,
));
}
// Create the tool-enabled request
let request = ChatCompletionRequest {