its working, but not sure if it's how I want it...

This commit is contained in:
dal 2025-02-10 17:09:01 -07:00
parent 61153020ba
commit 43e2cf44f4
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
1 changed files with 367 additions and 29 deletions

View File

@ -2,6 +2,7 @@ use std::collections::HashMap;
use anyhow::Result;
use serde::Serialize;
use serde_json::Value;
use uuid::Uuid;
use crate::utils::clients::ai::litellm::{Message, MessageProgress, ToolCall};
@ -13,11 +14,344 @@ use crate::utils::tools::file_tools::open_files::OpenFilesOutput;
use crate::utils::tools::file_tools::search_data_catalog::SearchDataCatalogOutput;
use crate::utils::tools::file_tools::search_files::SearchFilesOutput;
#[derive(Clone)]
struct StreamingFileState {
id: String,
file_type: String,
file_name: String,
version_id: String,
current_lines: Vec<FileContent>,
line_buffer: String,
next_line_number: usize,
has_metadata: bool,
status: String,
}
enum ParsingState {
WaitingForMetadata,
StreamingFiles {
files: Vec<StreamingFileState>,
},
Complete,
}
struct StreamingParser {
state: ParsingState,
buffer: String,
}
impl StreamingParser {
fn new() -> Self {
Self {
state: ParsingState::WaitingForMetadata,
buffer: String::new(),
}
}
fn complete_json(&self, partial: &str) -> String {
let mut json = partial.to_string();
let mut result = String::with_capacity(json.len() * 2);
let mut brace_count = 0;
let mut bracket_count = 0;
let mut in_string = false;
let mut escape_next = false;
let mut in_yml_content = false;
let mut yml_content_start = 0;
// First pass: track state and identify yml_content
let chars: Vec<char> = json.chars().collect();
let mut i = 0;
while i < chars.len() {
let c = chars[i];
if escape_next {
result.push(c);
escape_next = false;
i += 1;
continue;
}
match c {
'{' if !in_string => {
brace_count += 1;
result.push(c);
}
'}' if !in_string => {
brace_count -= 1;
result.push(c);
}
'[' if !in_string => {
bracket_count += 1;
result.push(c);
}
']' if !in_string => {
bracket_count -= 1;
result.push(c);
}
'"' => {
if !escape_next {
// Check if we're entering yml_content
if !in_string && i >= 13 {
let prev = &chars[i-13..i];
let prev_str: String = prev.iter().collect();
if prev_str == "\"yml_content\":" {
in_yml_content = true;
yml_content_start = result.len() + 1;
}
}
// Check if we're exiting yml_content
if in_string && in_yml_content {
// Look ahead to see if this is really the end
if i + 1 < chars.len() {
match chars[i + 1] {
',' | '}' => in_yml_content = false,
_ => {} // Not the end, keep going
}
}
}
in_string = !in_string;
}
result.push(c);
}
'\\' => {
escape_next = true;
result.push(c);
}
_ => {
result.push(c);
}
}
i += 1;
}
// Second pass: complete any unclosed structures
if in_string {
// If we're in yml_content, we need to be careful about how we close it
if in_yml_content {
// Only add closing quote if we don't have an odd number of unescaped quotes
let yml_part = &result[yml_content_start..];
let mut quote_count = 0;
let mut was_escape = false;
for c in yml_part.chars() {
match c {
'"' if !was_escape => quote_count += 1,
'\\' => was_escape = !was_escape,
_ => was_escape = false
}
}
if quote_count % 2 == 0 {
result.push('"');
}
} else {
result.push('"');
}
}
// Close any unclosed arrays/objects
while bracket_count > 0 {
result.push(']');
bracket_count -= 1;
}
while brace_count > 0 {
result.push('}');
brace_count -= 1;
}
result
}
fn process_chunk(&mut self, chunk: &str) -> Result<Option<BusterThreadMessage>> {
self.buffer.push_str(chunk);
let completed_json = self.complete_json(&self.buffer);
println!("completed_json: {:?}", completed_json);
match &mut self.state {
ParsingState::WaitingForMetadata => {
if let Ok(partial) = serde_json::from_str::<Value>(&completed_json) {
if let Some(files_array) = partial.get("files").and_then(|f| f.as_array()) {
let mut streaming_files = Vec::new();
for file in files_array {
if let (Some(name), Some(file_type)) = (
file.get("name").and_then(|n| n.as_str()),
file.get("file_type").and_then(|t| t.as_str()),
) {
let mut file_state = StreamingFileState {
id: Uuid::new_v4().to_string(),
file_type: file_type.to_string(),
file_name: name.to_string(),
version_id: Uuid::new_v4().to_string(),
current_lines: Vec::new(),
line_buffer: String::new(),
next_line_number: 1,
has_metadata: true,
status: "loading".to_string(),
};
// Process any initial content that's available
if let Some(yml_content) = file.get("yml_content").and_then(|c| c.as_str()) {
if !yml_content.is_empty() {
file_state.line_buffer.push_str(yml_content);
// Process complete lines
let mut new_lines = Vec::new();
for line in file_state.line_buffer.lines() {
new_lines.push(FileContent {
text: line.to_string(),
line_number: file_state.next_line_number,
modified: true,
});
file_state.next_line_number += 1;
}
// Handle partial lines
if !file_state.line_buffer.ends_with('\n') {
if let Some(last_newline) = file_state.line_buffer.rfind('\n') {
file_state.line_buffer = file_state.line_buffer[last_newline + 1..].to_string();
}
} else {
file_state.line_buffer.clear();
}
file_state.current_lines.extend(new_lines);
}
}
streaming_files.push(file_state);
// Transition to StreamingFiles state as soon as we have metadata
if file == files_array.last().unwrap() {
let last_file = streaming_files.last().unwrap().clone();
self.state = ParsingState::StreamingFiles {
files: streaming_files,
};
return Ok(Some(BusterThreadMessage::File(BusterFileMessage {
id: last_file.id,
message_type: "file".to_string(),
file_type: last_file.file_type,
file_name: last_file.file_name,
version_number: 1,
version_id: last_file.version_id,
status: "loading".to_string(),
file: Some(last_file.current_lines),
})));
}
}
}
}
}
Ok(None)
}
ParsingState::StreamingFiles { files } => {
if let Ok(partial) = serde_json::from_str::<Value>(&completed_json) {
if let Some(files_array) = partial.get("files").and_then(|f| f.as_array()) {
// Process content for current file
if let Some(current_file) = files.last_mut() {
if let Some(file_data) = files_array.last() {
if let Some(yml_content) = file_data.get("yml_content").and_then(|c| c.as_str()) {
if yml_content.len() > current_file.line_buffer.len() {
let new_content = &yml_content[current_file.line_buffer.len()..];
current_file.line_buffer.push_str(new_content);
// Process complete lines
let mut new_lines = Vec::new();
for line in current_file.line_buffer.lines() {
new_lines.push(FileContent {
text: line.to_string(),
line_number: current_file.next_line_number,
modified: true,
});
current_file.next_line_number += 1;
}
// Handle partial lines
if !current_file.line_buffer.ends_with('\n') {
if let Some(last_newline) = current_file.line_buffer.rfind('\n') {
current_file.line_buffer = current_file.line_buffer[last_newline + 1..].to_string();
}
} else {
current_file.line_buffer.clear();
}
current_file.current_lines.extend(new_lines);
return Ok(Some(BusterThreadMessage::File(BusterFileMessage {
id: current_file.id.clone(),
message_type: "file".to_string(),
file_type: current_file.file_type.clone(),
file_name: current_file.file_name.clone(),
version_number: 1,
version_id: current_file.version_id.clone(),
status: "loading".to_string(),
file: Some(current_file.current_lines.clone()),
})));
}
}
}
}
// Check for new files
if files_array.len() > files.len() {
// Complete the current file if it exists
if let Some(current_file) = files.last_mut() {
current_file.status = "completed".to_string();
// Emit completion message for current file
let completion_message = BusterThreadMessage::File(BusterFileMessage {
id: current_file.id.clone(),
message_type: "file".to_string(),
file_type: current_file.file_type.clone(),
file_name: current_file.file_name.clone(),
version_number: 1,
version_id: current_file.version_id.clone(),
status: "completed".to_string(),
file: Some(current_file.current_lines.clone()),
});
// Add new file to state if we have its metadata
if let Some(new_file) = files_array.last() {
if let (Some(name), Some(file_type)) = (
new_file.get("name").and_then(|n| n.as_str()),
new_file.get("file_type").and_then(|t| t.as_str()),
) {
files.push(StreamingFileState {
id: Uuid::new_v4().to_string(),
file_type: file_type.to_string(),
file_name: name.to_string(),
version_id: Uuid::new_v4().to_string(),
current_lines: Vec::new(),
line_buffer: String::new(),
next_line_number: 1,
has_metadata: true,
status: "loading".to_string(),
});
}
}
return Ok(Some(completion_message));
}
}
}
}
Ok(None)
}
ParsingState::Complete => Ok(None),
}
}
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum BusterThreadMessage {
ChatMessage(BusterChatMessage),
Thought(BusterThought),
File(BusterFileMessage),
}
#[derive(Debug, Serialize)]
@ -54,6 +388,26 @@ pub struct BusterThoughtPill {
pub thought_file_type: String,
}
#[derive(Debug, Serialize)]
pub struct BusterFileMessage {
pub id: String,
#[serde(rename = "type")]
pub message_type: String,
pub file_type: String,
pub file_name: String,
pub version_number: i32,
pub version_id: String,
pub status: String,
pub file: Option<Vec<FileContent>>,
}
#[derive(Debug, Serialize, Clone)]
pub struct FileContent {
pub text: String,
pub line_number: usize,
pub modified: bool,
}
pub fn transform_message(message: Message) -> Result<BusterThreadMessage> {
println!("transform_message: {:?}", message);
@ -576,35 +930,9 @@ fn assistant_create_file(
MessageProgress::InProgress => {
// Try to parse the tool call arguments to get file metadata
if let Some(tool_call) = tool_calls.first() {
if let Ok(params) =
serde_json::from_str::<CreateFilesParams>(&tool_call.function.arguments)
{
if let Some(file) = params.files.first() {
return Ok(BusterThreadMessage::Thought(BusterThought {
id: id.unwrap_or_else(|| Uuid::new_v4().to_string()),
thought_type: "thought".to_string(),
thought_title: format!(
"Creating {} file '{}'...",
file.file_type, file.name
),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "loading".to_string(),
}));
}
}
return process_assistant_create_file(tool_call);
}
// Fall back to generic message if we can't parse the metadata
let id = id.unwrap_or_else(|| Uuid::new_v4().to_string());
Ok(BusterThreadMessage::Thought(BusterThought {
id,
thought_type: "thought".to_string(),
thought_title: "Creating a new file...".to_string(),
thought_secondary_title: "".to_string(),
thought_pills: None,
status: "loading".to_string(),
}))
Err(anyhow::anyhow!("No tool call found"))
}
_ => Err(anyhow::anyhow!(
"Assistant create file only supports in progress."
@ -615,7 +943,17 @@ fn assistant_create_file(
}
}
fn process_assistant_create_file() -> Result<BusterThreadMessage> {}
fn process_assistant_create_file(tool_call: &ToolCall) -> Result<BusterThreadMessage> {
let mut parser = StreamingParser::new();
// Process the arguments from the tool call
if let Some(message) = parser.process_chunk(&tool_call.function.arguments)? {
return Ok(message);
}
// Return None by returning Ok(None) wrapped in a Result
Err(anyhow::anyhow!("Still waiting for file data"))
}
fn assistant_modify_file(
id: Option<String>,