mirror of https://github.com/buster-so/buster.git
ok reverting to old commit will have to come back to streaming problems later.
This commit is contained in:
parent
bfdf06be16
commit
e917dacafe
|
@ -6,7 +6,6 @@ members = [
|
|||
"libs/database",
|
||||
"libs/agents",
|
||||
"libs/query_engine",
|
||||
"libs/streaming",
|
||||
]
|
||||
|
||||
# Define shared dependencies for all workspace members
|
||||
|
|
|
@ -15,7 +15,6 @@ uuid = { workspace = true }
|
|||
litellm = { path = "../litellm" }
|
||||
database = { path = "../database" }
|
||||
query_engine = { path = "../query_engine" }
|
||||
middleware = { path = "../middleware" }
|
||||
serde_json = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
use crate::tools::{IntoToolCallExecutor, ToolExecutor};
|
||||
use anyhow::Result;
|
||||
use litellm::{
|
||||
LiteLlmMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
AgentMessage, ChatCompletionRequest, DeltaToolCall, FunctionCall, LiteLLMClient,
|
||||
MessageProgress, Metadata, Tool, ToolCall, ToolChoice,
|
||||
};
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde_json::Value;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use uuid::Uuid;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::models::AgentThread;
|
||||
|
||||
|
@ -24,7 +23,7 @@ impl std::fmt::Display for AgentError {
|
|||
}
|
||||
}
|
||||
|
||||
type MessageResult = Result<LiteLlmMessage, AgentError>;
|
||||
type MessageResult = Result<AgentMessage, AgentError>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MessageBuffer {
|
||||
|
@ -35,6 +34,7 @@ struct MessageBuffer {
|
|||
first_message_sent: bool,
|
||||
}
|
||||
|
||||
|
||||
impl MessageBuffer {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
|
@ -78,13 +78,9 @@ impl MessageBuffer {
|
|||
};
|
||||
|
||||
// Create and send the message
|
||||
let message = LiteLlmMessage::assistant(
|
||||
let message = AgentMessage::assistant(
|
||||
self.message_id.clone(),
|
||||
if self.content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.content.clone())
|
||||
},
|
||||
if self.content.is_empty() { None } else { Some(self.content.clone()) },
|
||||
tool_calls,
|
||||
MessageProgress::InProgress,
|
||||
Some(!self.first_message_sent),
|
||||
|
@ -92,7 +88,7 @@ impl MessageBuffer {
|
|||
);
|
||||
|
||||
agent.get_stream_sender().await.send(Ok(message))?;
|
||||
|
||||
|
||||
// Update state
|
||||
self.first_message_sent = true;
|
||||
self.last_flush = Instant::now();
|
||||
|
@ -102,6 +98,8 @@ impl MessageBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
/// The Agent struct is responsible for managing conversations with the LLM
|
||||
/// and coordinating tool executions. It maintains a registry of available tools
|
||||
|
@ -124,7 +122,7 @@ pub struct Agent {
|
|||
/// Sender for streaming messages from this agent and sub-agents
|
||||
stream_tx: Arc<RwLock<Option<broadcast::Sender<MessageResult>>>>,
|
||||
/// The user ID for the current thread
|
||||
user: AuthenticatedUser,
|
||||
user_id: Uuid,
|
||||
/// The session ID for the current thread
|
||||
session_id: Uuid,
|
||||
/// Agent name
|
||||
|
@ -138,7 +136,7 @@ impl Agent {
|
|||
pub fn new(
|
||||
model: String,
|
||||
tools: HashMap<String, Box<dyn ToolExecutor<Output = Value, Params = Value> + Send + Sync>>,
|
||||
user: AuthenticatedUser,
|
||||
user_id: Uuid,
|
||||
session_id: Uuid,
|
||||
name: String,
|
||||
) -> Self {
|
||||
|
@ -157,7 +155,7 @@ impl Agent {
|
|||
state: Arc::new(RwLock::new(HashMap::new())),
|
||||
current_thread: Arc::new(RwLock::new(None)),
|
||||
stream_tx: Arc::new(RwLock::new(Some(tx))),
|
||||
user,
|
||||
user_id,
|
||||
session_id,
|
||||
shutdown_tx: Arc::new(RwLock::new(shutdown_tx)),
|
||||
name,
|
||||
|
@ -178,7 +176,7 @@ impl Agent {
|
|||
state: Arc::clone(&existing_agent.state),
|
||||
current_thread: Arc::clone(&existing_agent.current_thread),
|
||||
stream_tx: Arc::clone(&existing_agent.stream_tx),
|
||||
user: existing_agent.user.clone(),
|
||||
user_id: existing_agent.user_id,
|
||||
session_id: existing_agent.session_id,
|
||||
shutdown_tx: Arc::clone(&existing_agent.shutdown_tx),
|
||||
name,
|
||||
|
@ -243,11 +241,7 @@ impl Agent {
|
|||
}
|
||||
|
||||
pub fn get_user_id(&self) -> Uuid {
|
||||
self.user.id
|
||||
}
|
||||
|
||||
pub fn get_user(&self) -> AuthenticatedUser {
|
||||
self.user.clone()
|
||||
self.user_id
|
||||
}
|
||||
|
||||
pub fn get_session_id(&self) -> Uuid {
|
||||
|
@ -259,7 +253,7 @@ impl Agent {
|
|||
}
|
||||
|
||||
/// Get the complete conversation history of the current thread
|
||||
pub async fn get_conversation_history(&self) -> Option<Vec<LiteLlmMessage>> {
|
||||
pub async fn get_conversation_history(&self) -> Option<Vec<AgentMessage>> {
|
||||
self.current_thread
|
||||
.read()
|
||||
.await
|
||||
|
@ -268,7 +262,7 @@ impl Agent {
|
|||
}
|
||||
|
||||
/// Update the current thread with a new message
|
||||
async fn update_current_thread(&self, message: LiteLlmMessage) -> Result<()> {
|
||||
async fn update_current_thread(&self, message: AgentMessage) -> Result<()> {
|
||||
let mut thread_lock = self.current_thread.write().await;
|
||||
if let Some(thread) = thread_lock.as_mut() {
|
||||
thread.messages.push(message);
|
||||
|
@ -322,7 +316,7 @@ impl Agent {
|
|||
///
|
||||
/// # Returns
|
||||
/// * A Result containing the final Message from the assistant
|
||||
pub async fn process_thread(&self, thread: &AgentThread) -> Result<LiteLlmMessage> {
|
||||
pub async fn process_thread(&self, thread: &AgentThread) -> Result<AgentMessage> {
|
||||
let mut rx = self.process_thread_streaming(thread).await?;
|
||||
|
||||
let mut final_message = None;
|
||||
|
@ -359,13 +353,13 @@ impl Agent {
|
|||
let err_msg = format!("Error processing thread: {:?}", e);
|
||||
let _ = agent_clone.get_stream_sender().await.send(Err(AgentError(err_msg)));
|
||||
// Send Done message after error
|
||||
let _ = agent_clone.get_stream_sender().await.send(Ok(LiteLlmMessage::Done));
|
||||
let _ = agent_clone.get_stream_sender().await.send(Ok(AgentMessage::Done));
|
||||
}
|
||||
},
|
||||
_ = shutdown_rx.recv() => {
|
||||
// Send shutdown notification
|
||||
let _ = agent_clone.get_stream_sender().await.send(
|
||||
Ok(LiteLlmMessage::assistant(
|
||||
Ok(AgentMessage::assistant(
|
||||
Some("shutdown_message".to_string()),
|
||||
Some("Processing interrupted due to shutdown signal".to_string()),
|
||||
None,
|
||||
|
@ -375,7 +369,7 @@ impl Agent {
|
|||
))
|
||||
);
|
||||
// Send Done message after shutdown
|
||||
let _ = agent_clone.get_stream_sender().await.send(Ok(LiteLlmMessage::Done));
|
||||
let _ = agent_clone.get_stream_sender().await.send(Ok(AgentMessage::Done));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -395,7 +389,7 @@ impl Agent {
|
|||
}
|
||||
|
||||
if recursion_depth >= 30 {
|
||||
let message = LiteLlmMessage::assistant(
|
||||
let message = AgentMessage::assistant(
|
||||
Some("max_recursion_depth_message".to_string()),
|
||||
Some("I apologize, but I've reached the maximum number of actions (30). Please try breaking your request into smaller parts.".to_string()),
|
||||
None,
|
||||
|
@ -456,8 +450,7 @@ impl Agent {
|
|||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone().unwrap_or_else(|| {
|
||||
buffer
|
||||
.tool_calls
|
||||
buffer.tool_calls
|
||||
.keys()
|
||||
.next()
|
||||
.map(|s| s.clone())
|
||||
|
@ -465,8 +458,7 @@ impl Agent {
|
|||
});
|
||||
|
||||
// Get or create the pending tool call
|
||||
let pending_call = buffer
|
||||
.tool_calls
|
||||
let pending_call = buffer.tool_calls
|
||||
.entry(id.clone())
|
||||
.or_insert_with(PendingToolCall::new);
|
||||
|
||||
|
@ -492,8 +484,7 @@ impl Agent {
|
|||
// Create and send the final message
|
||||
let final_tool_calls: Option<Vec<ToolCall>> = if !buffer.tool_calls.is_empty() {
|
||||
Some(
|
||||
buffer
|
||||
.tool_calls
|
||||
buffer.tool_calls
|
||||
.values()
|
||||
.map(|p| p.clone().into_tool_call())
|
||||
.collect(),
|
||||
|
@ -502,13 +493,9 @@ impl Agent {
|
|||
None
|
||||
};
|
||||
|
||||
let final_message = LiteLlmMessage::assistant(
|
||||
let final_message = AgentMessage::assistant(
|
||||
buffer.message_id,
|
||||
if buffer.content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.content)
|
||||
},
|
||||
if buffer.content.is_empty() { None } else { Some(buffer.content) },
|
||||
final_tool_calls.clone(),
|
||||
MessageProgress::Complete,
|
||||
Some(false),
|
||||
|
@ -528,7 +515,7 @@ impl Agent {
|
|||
// Send Done message and return
|
||||
self.get_stream_sender()
|
||||
.await
|
||||
.send(Ok(LiteLlmMessage::Done))?;
|
||||
.send(Ok(AgentMessage::Done))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -540,9 +527,9 @@ impl Agent {
|
|||
for tool_call in tool_calls {
|
||||
if let Some(tool) = self.tools.read().await.get(&tool_call.function.name) {
|
||||
let params: Value = serde_json::from_str(&tool_call.function.arguments)?;
|
||||
let result = tool.execute(params, tool_call.id.clone(), self.get_user()).await?;
|
||||
let result = tool.execute(params, tool_call.id.clone()).await?;
|
||||
let result_str = serde_json::to_string(&result)?;
|
||||
let tool_message = LiteLlmMessage::tool(
|
||||
let tool_message = AgentMessage::tool(
|
||||
None,
|
||||
result_str,
|
||||
tool_call.id.clone(),
|
||||
|
@ -571,7 +558,7 @@ impl Agent {
|
|||
// Send Done message and return
|
||||
self.get_stream_sender()
|
||||
.await
|
||||
.send(Ok(LiteLlmMessage::Done))?;
|
||||
.send(Ok(AgentMessage::Done))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -670,7 +657,7 @@ pub trait AgentExt {
|
|||
(*self.get_agent()).process_thread_streaming(thread).await
|
||||
}
|
||||
|
||||
async fn process_thread(&self, thread: &AgentThread) -> Result<LiteLlmMessage> {
|
||||
async fn process_thread(&self, thread: &AgentThread) -> Result<AgentMessage> {
|
||||
(*self.get_agent()).process_thread(thread).await
|
||||
}
|
||||
|
||||
|
@ -684,32 +671,12 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::tools::ToolExecutor;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Utc};
|
||||
use litellm::MessageProgress;
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
use middleware::types::AuthenticatedUser;
|
||||
|
||||
fn setup() {
|
||||
dotenv::dotenv().ok();
|
||||
std::env::set_var("LLM_API_KEY", "test_key");
|
||||
std::env::set_var("LLM_BASE_URL", "http://localhost:8000");
|
||||
}
|
||||
|
||||
// Create a mock AuthenticatedUser for testing
|
||||
fn create_test_user() -> AuthenticatedUser {
|
||||
AuthenticatedUser {
|
||||
id: Uuid::new_v4(),
|
||||
email: "test@example.com".to_string(),
|
||||
name: Some("Test User".to_string()),
|
||||
config: json!({}),
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
attributes: json!({}),
|
||||
avatar_url: None,
|
||||
organizations: vec![],
|
||||
teams: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
struct WeatherTool {
|
||||
|
@ -729,8 +696,13 @@ mod tests {
|
|||
tool_id: String,
|
||||
progress: MessageProgress,
|
||||
) -> Result<()> {
|
||||
let message =
|
||||
LiteLlmMessage::tool(None, content, tool_id, Some(self.get_name()), progress);
|
||||
let message = AgentMessage::tool(
|
||||
None,
|
||||
content,
|
||||
tool_id,
|
||||
Some(self.get_name()),
|
||||
progress,
|
||||
);
|
||||
self.agent.get_stream_sender().await.send(Ok(message))?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -741,12 +713,7 @@ mod tests {
|
|||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
params: Self::Params,
|
||||
tool_call_id: String,
|
||||
user: AuthenticatedUser,
|
||||
) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
self.send_progress(
|
||||
"Fetching weather data...".to_string(),
|
||||
"123".to_string(),
|
||||
|
@ -811,15 +778,15 @@ mod tests {
|
|||
let agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
create_test_user().id,
|
||||
vec![LiteLlmMessage::user("Hello, world!".to_string())],
|
||||
Uuid::new_v4(),
|
||||
vec![AgentMessage::user("Hello, world!".to_string())],
|
||||
);
|
||||
|
||||
let response = match agent.process_thread(&thread).await {
|
||||
|
@ -836,7 +803,7 @@ mod tests {
|
|||
let mut agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
@ -849,8 +816,8 @@ mod tests {
|
|||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
create_test_user().id,
|
||||
vec![LiteLlmMessage::user(
|
||||
Uuid::new_v4(),
|
||||
vec![AgentMessage::user(
|
||||
"What is the weather in vineyard ut?".to_string(),
|
||||
)],
|
||||
);
|
||||
|
@ -869,7 +836,7 @@ mod tests {
|
|||
let mut agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
@ -880,8 +847,8 @@ mod tests {
|
|||
|
||||
let thread = AgentThread::new(
|
||||
None,
|
||||
create_test_user().id,
|
||||
vec![LiteLlmMessage::user(
|
||||
Uuid::new_v4(),
|
||||
vec![AgentMessage::user(
|
||||
"What is the weather in vineyard ut and san francisco?".to_string(),
|
||||
)],
|
||||
);
|
||||
|
@ -900,7 +867,7 @@ mod tests {
|
|||
let agent = Agent::new(
|
||||
"o1".to_string(),
|
||||
HashMap::new(),
|
||||
create_test_user(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
);
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
use anyhow::Result;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
@ -20,14 +19,14 @@ use crate::{
|
|||
Agent, AgentError, AgentExt, AgentThread,
|
||||
};
|
||||
|
||||
use litellm::LiteLlmMessage;
|
||||
use litellm::AgentMessage as AgentMessage;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BusterSuperAgentOutput {
|
||||
pub message: String,
|
||||
pub duration: i64,
|
||||
pub thread_id: Uuid,
|
||||
pub messages: Vec<LiteLlmMessage>,
|
||||
pub messages: Vec<AgentMessage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
@ -98,12 +97,12 @@ impl BusterSuperAgent {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn new(user: AuthenticatedUser, session_id: Uuid) -> Result<Self> {
|
||||
pub async fn new(user_id: Uuid, session_id: Uuid) -> Result<Self> {
|
||||
// Create agent with empty tools map
|
||||
let agent = Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
HashMap::new(),
|
||||
user,
|
||||
user_id,
|
||||
session_id,
|
||||
"buster_super_agent".to_string(),
|
||||
));
|
||||
|
@ -127,7 +126,7 @@ impl BusterSuperAgent {
|
|||
pub async fn run(
|
||||
&self,
|
||||
thread: &mut AgentThread,
|
||||
) -> Result<broadcast::Receiver<Result<LiteLlmMessage, AgentError>>> {
|
||||
) -> Result<broadcast::Receiver<Result<AgentMessage, AgentError>>> {
|
||||
thread.set_developer_message(BUSTER_SUPER_AGENT_PROMPT.to_string());
|
||||
|
||||
// Get shutdown receiver
|
||||
|
|
|
@ -16,4 +16,4 @@ pub use models::*;
|
|||
pub use tools::ToolExecutor;
|
||||
|
||||
// Re-export types from dependencies that are part of our public API
|
||||
pub use litellm::LiteLlmMessage;
|
||||
pub use litellm::AgentMessage;
|
|
@ -1,7 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use litellm::LiteLlmMessage;
|
||||
use litellm::AgentMessage;
|
||||
|
||||
/// A Thread represents a conversation between a user and the AI agent.
|
||||
/// It contains a sequence of messages in chronological order.
|
||||
|
@ -11,11 +11,11 @@ pub struct AgentThread {
|
|||
pub id: Uuid,
|
||||
pub user_id: Uuid,
|
||||
/// Ordered sequence of messages in the conversation
|
||||
pub messages: Vec<LiteLlmMessage>,
|
||||
pub messages: Vec<AgentMessage>,
|
||||
}
|
||||
|
||||
impl AgentThread {
|
||||
pub fn new(id: Option<Uuid>, user_id: Uuid, messages: Vec<LiteLlmMessage>) -> Self {
|
||||
pub fn new(id: Option<Uuid>, user_id: Uuid, messages: Vec<AgentMessage>) -> Self {
|
||||
Self {
|
||||
id: id.unwrap_or(Uuid::new_v4()),
|
||||
user_id,
|
||||
|
@ -29,13 +29,13 @@ impl AgentThread {
|
|||
if let Some(pos) = self
|
||||
.messages
|
||||
.iter()
|
||||
.position(|msg| matches!(msg, LiteLlmMessage::Developer { .. }))
|
||||
.position(|msg| matches!(msg, AgentMessage::Developer { .. }))
|
||||
{
|
||||
// Update existing developer message
|
||||
self.messages[pos] = LiteLlmMessage::developer(message);
|
||||
self.messages[pos] = AgentMessage::developer(message);
|
||||
} else {
|
||||
// Insert new developer message at the start
|
||||
self.messages.insert(0, LiteLlmMessage::developer(message));
|
||||
self.messages.insert(0, AgentMessage::developer(message));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ impl AgentThread {
|
|||
if let Some(pos) = self
|
||||
.messages
|
||||
.iter()
|
||||
.rposition(|msg| matches!(msg, LiteLlmMessage::Assistant { .. }))
|
||||
.rposition(|msg| matches!(msg, AgentMessage::Assistant { .. }))
|
||||
{
|
||||
self.messages.remove(pos);
|
||||
}
|
||||
|
@ -52,6 +52,6 @@ impl AgentThread {
|
|||
|
||||
/// Add a user message to the thread
|
||||
pub fn add_user_message(&mut self, content: String) {
|
||||
self.messages.push(LiteLlmMessage::user(content));
|
||||
self.messages.push(AgentMessage::user(content));
|
||||
}
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -345,6 +345,7 @@ definitions:
|
|||
required:
|
||||
- x
|
||||
- y
|
||||
- category
|
||||
bar_layout:
|
||||
type: string
|
||||
enum: ["horizontal", "vertical"]
|
||||
|
|
|
@ -11,7 +11,6 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::{self, json, Value};
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
use middleware::AuthenticatedUser;
|
||||
|
||||
use crate::{
|
||||
agent::Agent,
|
||||
|
@ -132,7 +131,7 @@ impl ToolExecutor for CreateDashboardFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
let files = params.files;
|
||||
|
|
|
@ -12,7 +12,6 @@ use serde::{Deserialize, Serialize};
|
|||
use serde_json::Value;
|
||||
use tracing::debug;
|
||||
use uuid::Uuid;
|
||||
use middleware::AuthenticatedUser;
|
||||
|
||||
use crate::{
|
||||
agent::Agent,
|
||||
|
@ -82,7 +81,7 @@ impl ToolExecutor for CreateMetricFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
let files = params.files;
|
||||
|
|
|
@ -12,7 +12,6 @@ use indexmap::IndexMap;
|
|||
use query_engine::data_types::DataType;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, error, info};
|
||||
use middleware::AuthenticatedUser;
|
||||
|
||||
use super::{
|
||||
common::{
|
||||
|
@ -68,7 +67,7 @@ impl ToolExecutor for ModifyDashboardFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
debug!("Starting file modification execution");
|
||||
|
@ -216,4 +215,128 @@ impl ToolExecutor for ModifyDashboardFilesTool {
|
|||
"description": "Makes content-based modifications to one or more existing dashboard YAML files in a single call. Each modification specifies the exact content to replace and its replacement. If you need to update chart config or other sections within a file, use this. Guard Rail: Do not execute any file creation or modifications until a thorough data catalog search has been completed and reviewed."
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use crate::tools::categories::file_tools::common::{
|
||||
apply_modifications_to_content, Modification, ModificationResult,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[test]
|
||||
fn test_apply_modifications_to_content() {
|
||||
let original_content =
|
||||
"name: test_dashboard\ntype: dashboard\ndescription: A test dashboard";
|
||||
|
||||
// Test single modification
|
||||
let mods1 = vec![Modification {
|
||||
content_to_replace: "type: dashboard".to_string(),
|
||||
new_content: "type: custom_dashboard".to_string(),
|
||||
}];
|
||||
let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result1,
|
||||
"name: test_dashboard\ntype: custom_dashboard\ndescription: A test dashboard"
|
||||
);
|
||||
|
||||
// Test multiple non-overlapping modifications
|
||||
let mods2 = vec![
|
||||
Modification {
|
||||
content_to_replace: "test_dashboard".to_string(),
|
||||
new_content: "new_dashboard".to_string(),
|
||||
},
|
||||
Modification {
|
||||
content_to_replace: "A test dashboard".to_string(),
|
||||
new_content: "An updated dashboard".to_string(),
|
||||
},
|
||||
];
|
||||
let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result2,
|
||||
"name: new_dashboard\ntype: dashboard\ndescription: An updated dashboard"
|
||||
);
|
||||
|
||||
// Test content not found
|
||||
let mods3 = vec![Modification {
|
||||
content_to_replace: "nonexistent content".to_string(),
|
||||
new_content: "new content".to_string(),
|
||||
}];
|
||||
let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml");
|
||||
assert!(result3.is_err());
|
||||
assert!(result3
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Content to replace not found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modification_result_tracking() {
|
||||
let result = ModificationResult {
|
||||
file_id: Uuid::new_v4(),
|
||||
file_name: "test.yml".to_string(),
|
||||
success: true,
|
||||
error: None,
|
||||
modification_type: "content".to_string(),
|
||||
timestamp: Utc::now(),
|
||||
duration: 0,
|
||||
};
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.error.is_none());
|
||||
|
||||
let error_result = ModificationResult {
|
||||
success: false,
|
||||
error: Some("Failed to parse YAML".to_string()),
|
||||
..result
|
||||
};
|
||||
assert!(!error_result.success);
|
||||
assert!(error_result.error.is_some());
|
||||
assert_eq!(error_result.error.unwrap(), "Failed to parse YAML");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_parameter_validation() {
|
||||
let tool = ModifyDashboardFilesTool {
|
||||
agent: Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
)),
|
||||
};
|
||||
|
||||
// Test valid parameters
|
||||
let valid_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml",
|
||||
"modifications": [{
|
||||
"content_to_replace": "old content",
|
||||
"new_content": "new content"
|
||||
}]
|
||||
}]
|
||||
});
|
||||
let valid_args = serde_json::to_string(&valid_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&valid_args);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test missing required fields
|
||||
let missing_fields_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml"
|
||||
// missing modifications
|
||||
}]
|
||||
});
|
||||
let missing_args = serde_json::to_string(&missing_fields_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&missing_args);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ use database::{enums::Verification, models::MetricFile, pool::get_pg_pool, schem
|
|||
use diesel::{upsert::excluded, ExpressionMethods, QueryDsl};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use indexmap::IndexMap;
|
||||
use middleware::AuthenticatedUser;
|
||||
use query_engine::data_types::DataType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
@ -66,7 +65,7 @@ impl ToolExecutor for ModifyMetricFilesTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
debug!("Starting file modification execution");
|
||||
|
@ -275,3 +274,122 @@ impl ToolExecutor for ModifyMetricFilesTool {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_apply_modifications_to_content() {
|
||||
let original_content = "name: test_metric\ntype: counter\ndescription: A test metric";
|
||||
|
||||
// Test single modification
|
||||
let mods1 = vec![Modification {
|
||||
content_to_replace: "type: counter".to_string(),
|
||||
new_content: "type: gauge".to_string(),
|
||||
}];
|
||||
let result1 = apply_modifications_to_content(original_content, &mods1, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result1,
|
||||
"name: test_metric\ntype: gauge\ndescription: A test metric"
|
||||
);
|
||||
|
||||
// Test multiple non-overlapping modifications
|
||||
let mods2 = vec![
|
||||
Modification {
|
||||
content_to_replace: "test_metric".to_string(),
|
||||
new_content: "new_metric".to_string(),
|
||||
},
|
||||
Modification {
|
||||
content_to_replace: "A test metric".to_string(),
|
||||
new_content: "An updated metric".to_string(),
|
||||
},
|
||||
];
|
||||
let result2 = apply_modifications_to_content(original_content, &mods2, "test.yml").unwrap();
|
||||
assert_eq!(
|
||||
result2,
|
||||
"name: new_metric\ntype: counter\ndescription: An updated metric"
|
||||
);
|
||||
|
||||
// Test content not found
|
||||
let mods3 = vec![Modification {
|
||||
content_to_replace: "nonexistent content".to_string(),
|
||||
new_content: "new content".to_string(),
|
||||
}];
|
||||
let result3 = apply_modifications_to_content(original_content, &mods3, "test.yml");
|
||||
assert!(result3.is_err());
|
||||
assert!(result3
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Content to replace not found"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_modification_result_tracking() {
|
||||
let result = ModificationResult {
|
||||
file_id: Uuid::new_v4(),
|
||||
file_name: "test.yml".to_string(),
|
||||
success: true,
|
||||
error: None,
|
||||
modification_type: "content".to_string(),
|
||||
timestamp: Utc::now(),
|
||||
duration: 0,
|
||||
};
|
||||
|
||||
assert!(result.success);
|
||||
assert!(result.error.is_none());
|
||||
|
||||
let error_result = ModificationResult {
|
||||
success: false,
|
||||
error: Some("Failed to parse YAML".to_string()),
|
||||
..result
|
||||
};
|
||||
assert!(!error_result.success);
|
||||
assert!(error_result.error.is_some());
|
||||
assert_eq!(error_result.error.unwrap(), "Failed to parse YAML");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_parameter_validation() {
|
||||
let tool = ModifyMetricFilesTool {
|
||||
agent: Arc::new(Agent::new(
|
||||
"o3-mini".to_string(),
|
||||
HashMap::new(),
|
||||
Uuid::new_v4(),
|
||||
Uuid::new_v4(),
|
||||
"test_agent".to_string(),
|
||||
)),
|
||||
};
|
||||
|
||||
// Test valid parameters
|
||||
let valid_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml",
|
||||
"modifications": [{
|
||||
"content_to_replace": "old content",
|
||||
"new_content": "new content"
|
||||
}]
|
||||
}]
|
||||
});
|
||||
let valid_args = serde_json::to_string(&valid_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&valid_args);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test missing required fields
|
||||
let missing_fields_params = json!({
|
||||
"files": [{
|
||||
"id": Uuid::new_v4().to_string(),
|
||||
"file_name": "test.yml"
|
||||
// missing modifications
|
||||
}]
|
||||
});
|
||||
let missing_args = serde_json::to_string(&missing_fields_params).unwrap();
|
||||
let result = serde_json::from_str::<ModifyFilesParams>(&missing_args);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ use chrono::{DateTime, Utc};
|
|||
use database::{pool::get_pg_pool, schema::datasets};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tracing::{debug, error, warn};
|
||||
|
@ -15,7 +14,7 @@ use uuid::Uuid;
|
|||
|
||||
use crate::{agent::Agent, tools::ToolExecutor};
|
||||
|
||||
use litellm::{ChatCompletionRequest, LiteLLMClient, LiteLlmMessage, Metadata, ResponseFormat};
|
||||
use litellm::{ChatCompletionRequest, LiteLLMClient, AgentMessage, Metadata, ResponseFormat};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SearchDataCatalogParams {
|
||||
|
@ -128,7 +127,7 @@ impl SearchDataCatalogTool {
|
|||
while retry_count < MAX_RETRIES {
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o3-mini".to_string(),
|
||||
messages: vec![LiteLlmMessage::User {
|
||||
messages: vec![AgentMessage::User {
|
||||
id: None,
|
||||
content: current_prompt.clone(),
|
||||
name: None,
|
||||
|
@ -160,7 +159,7 @@ impl SearchDataCatalogTool {
|
|||
|
||||
// Parse LLM response
|
||||
let content = match &response.choices[0].message {
|
||||
LiteLlmMessage::Assistant {
|
||||
AgentMessage::Assistant {
|
||||
content: Some(content),
|
||||
..
|
||||
} => content,
|
||||
|
@ -261,7 +260,7 @@ impl ToolExecutor for SearchDataCatalogTool {
|
|||
type Output = SearchDataCatalogOutput;
|
||||
type Params = SearchDataCatalogParams;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Fetch all non-deleted datasets
|
||||
|
|
|
@ -7,5 +7,6 @@
|
|||
//! - interaction_tools: Tools for user interaction and UI manipulation
|
||||
//! - planning_tools: Tools for planning and scheduling
|
||||
|
||||
pub mod agents_as_tools;
|
||||
pub mod file_tools;
|
||||
pub mod planning_tools;
|
|
@ -1,6 +1,5 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
@ -37,7 +36,7 @@ impl ToolExecutor for CreatePlan {
|
|||
"create_plan".to_string()
|
||||
}
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
self.agent
|
||||
.set_state_value(String::from("plan_available"), Value::Bool(true))
|
||||
.await;
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// A trait that defines how tools should be implemented.
|
||||
/// Any struct that wants to be used as a tool must implement this trait.
|
||||
|
@ -15,7 +13,7 @@ pub trait ToolExecutor: Send + Sync {
|
|||
type Params: DeserializeOwned + Send;
|
||||
|
||||
/// Execute the tool with the given parameters and tool call ID.
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user_id: AuthenticatedUser) -> Result<Self::Output>;
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output>;
|
||||
|
||||
/// Get the JSON schema for this tool
|
||||
fn get_schema(&self) -> Value;
|
||||
|
@ -55,9 +53,9 @@ where
|
|||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
let params = serde_json::from_value(params)?;
|
||||
let result = self.inner.execute(params, tool_call_id, user).await?;
|
||||
let result = self.inner.execute(params, tool_call_id).await?;
|
||||
Ok(serde_json::to_value(result)?)
|
||||
}
|
||||
|
||||
|
@ -80,8 +78,8 @@ impl<T: ToolExecutor<Output = Value, Params = Value> + Send + Sync> ToolExecutor
|
|||
type Output = Value;
|
||||
type Params = Value;
|
||||
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String, user: AuthenticatedUser) -> Result<Self::Output> {
|
||||
(**self).execute(params, tool_call_id, user).await
|
||||
async fn execute(&self, params: Self::Params, tool_call_id: String) -> Result<Self::Output> {
|
||||
(**self).execute(params, tool_call_id).await
|
||||
}
|
||||
|
||||
fn get_schema(&self) -> Value {
|
||||
|
|
|
@ -10,4 +10,5 @@ pub use executor::{ToolExecutor, ToolCallExecutor, IntoToolCallExecutor};
|
|||
|
||||
// Re-export commonly used tool categories
|
||||
pub use categories::file_tools;
|
||||
pub use categories::planning_tools;
|
||||
pub use categories::planning_tools;
|
||||
pub use categories::agents_as_tools;
|
|
@ -31,7 +31,6 @@ litellm = { path = "../litellm" }
|
|||
query_engine = { path = "../query_engine" }
|
||||
middleware = { path = "../middleware" }
|
||||
sharing = { path = "../sharing" }
|
||||
streaming = { path = "../streaming" }
|
||||
|
||||
# Add any handler-specific dependencies here
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ use database::{
|
|||
};
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use agents::{Agent, LiteLlmMessage};
|
||||
use agents::{Agent, AgentMessage};
|
||||
use middleware::AuthenticatedUser;
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
@ -26,8 +26,8 @@ impl ChatContextLoader {
|
|||
}
|
||||
|
||||
// Helper function to check for tool usage and set appropriate context
|
||||
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &LiteLlmMessage) {
|
||||
if let LiteLlmMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
|
||||
async fn update_context_from_tool_calls(agent: &Arc<Agent>, message: &AgentMessage) {
|
||||
if let AgentMessage::Assistant { tool_calls: Some(tool_calls), .. } = message {
|
||||
for tool_call in tool_calls {
|
||||
match tool_call.function.name.as_str() {
|
||||
"search_data_catalog" => {
|
||||
|
@ -55,7 +55,7 @@ impl ChatContextLoader {
|
|||
|
||||
#[async_trait]
|
||||
impl ContextLoader for ChatContextLoader {
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<LiteLlmMessage>> {
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await?;
|
||||
|
||||
// First verify the chat exists and user has access
|
||||
|
@ -78,7 +78,7 @@ impl ContextLoader for ChatContextLoader {
|
|||
let mut agent_messages = Vec::new();
|
||||
|
||||
// Process only the most recent message's raw LLM messages
|
||||
if let Ok(raw_messages) = serde_json::from_value::<Vec<LiteLlmMessage>>(message.raw_llm_messages) {
|
||||
if let Ok(raw_messages) = serde_json::from_value::<Vec<AgentMessage>>(message.raw_llm_messages) {
|
||||
// Check each message for tool calls and update context
|
||||
for agent_message in &raw_messages {
|
||||
Self::update_context_from_tool_calls(agent, agent_message).await;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use agents::{Agent, LiteLlmMessage};
|
||||
use agents::{Agent, AgentMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use database::{
|
||||
|
@ -33,7 +33,7 @@ impl ContextLoader for DashboardContextLoader {
|
|||
&self,
|
||||
user: &AuthenticatedUser,
|
||||
agent: &Arc<Agent>,
|
||||
) -> Result<Vec<LiteLlmMessage>> {
|
||||
) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to get database connection for dashboard context loading: {}",
|
||||
|
@ -172,7 +172,7 @@ impl ContextLoader for DashboardContextLoader {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(vec![LiteLlmMessage::Assistant {
|
||||
Ok(vec![AgentMessage::Assistant {
|
||||
id: None,
|
||||
content: Some(context_message),
|
||||
name: None,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use agents::{Agent, LiteLlmMessage};
|
||||
use agents::{Agent, AgentMessage};
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use database::{
|
||||
|
@ -27,7 +27,7 @@ impl MetricContextLoader {
|
|||
|
||||
#[async_trait]
|
||||
impl ContextLoader for MetricContextLoader {
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<LiteLlmMessage>> {
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>> {
|
||||
let mut conn = get_pg_pool().get().await.map_err(|e| {
|
||||
anyhow!(
|
||||
"Failed to get database connection for metric context loading: {}",
|
||||
|
@ -107,7 +107,7 @@ impl ContextLoader for MetricContextLoader {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(vec![LiteLlmMessage::Assistant {
|
||||
Ok(vec![AgentMessage::Assistant {
|
||||
id: None,
|
||||
content: Some(context_message),
|
||||
name: None,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use agents::LiteLlmMessage;
|
||||
use agents::AgentMessage;
|
||||
use middleware::AuthenticatedUser;
|
||||
use std::sync::Arc;
|
||||
use agents::Agent;
|
||||
|
@ -15,7 +15,7 @@ pub use dashboard_context::DashboardContextLoader;
|
|||
|
||||
#[async_trait]
|
||||
pub trait ContextLoader {
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<LiteLlmMessage>>;
|
||||
async fn load_context(&self, user: &AuthenticatedUser, agent: &Arc<Agent>) -> Result<Vec<AgentMessage>>;
|
||||
}
|
||||
|
||||
// Validate that only one context type is provided
|
||||
|
|
|
@ -3,9 +3,9 @@ pub mod post_chat_handler;
|
|||
pub mod update_chats_handler;
|
||||
pub mod delete_chats_handler;
|
||||
pub mod types;
|
||||
pub mod streaming_parser;
|
||||
pub mod context_loaders;
|
||||
pub mod list_chats_handler;
|
||||
pub mod helpers;
|
||||
|
||||
pub use get_chat_handler::get_chat_handler;
|
||||
pub use post_chat_handler::post_chat_handler;
|
||||
|
@ -13,3 +13,4 @@ pub use update_chats_handler::update_chats_handler;
|
|||
pub use delete_chats_handler::delete_chats_handler;
|
||||
pub use list_chats_handler::list_chats_handler;
|
||||
pub use types::*;
|
||||
pub use streaming_parser::StreamingParser;
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,307 @@
|
|||
use agents::tools::categories::file_tools::common::generate_deterministic_uuid;
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::post_chat_handler::{
|
||||
BusterFile, BusterFileContent, BusterReasoningFile, BusterReasoningMessage,
|
||||
BusterReasoningPill, BusterReasoningText, BusterThoughtPill, BusterThoughtPillContainer,
|
||||
};
|
||||
|
||||
pub struct StreamingParser {
|
||||
buffer: String,
|
||||
yml_content_regex: regex::Regex,
|
||||
}
|
||||
|
||||
impl StreamingParser {
|
||||
pub fn new() -> Self {
|
||||
StreamingParser {
|
||||
buffer: String::new(),
|
||||
yml_content_regex: regex::Regex::new(
|
||||
r#""yml_content":\s*"((?:[^"\\]|\\.|[\r\n])*?)(?:"|$)"#,
|
||||
)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the buffer - useful when reusing the parser for different content formats
|
||||
pub fn clear_buffer(&mut self) {
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
// Process chunks meant for plan creation
|
||||
pub fn process_plan_chunk(
|
||||
&mut self,
|
||||
id: String,
|
||||
chunk: &str,
|
||||
) -> Result<Option<BusterReasoningMessage>> {
|
||||
// Clear buffer and add new chunk
|
||||
self.clear_buffer();
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
// Complete any incomplete JSON structure
|
||||
let processed_json = self.complete_json_structure(self.buffer.clone());
|
||||
|
||||
// Try to parse the JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&processed_json) {
|
||||
// Check if it's a plan structure (has plan_markdown key)
|
||||
if let Some(plan_markdown) = value.get("plan_markdown").and_then(Value::as_str) {
|
||||
// Return the plan as a BusterReasoningText
|
||||
return Ok(Some(BusterReasoningMessage::Text(BusterReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Creating a plan...".to_string(),
|
||||
secondary_title: String::from(""),
|
||||
message: None,
|
||||
message_chunk: Some(plan_markdown.to_string()),
|
||||
status: Some("loading".to_string()),
|
||||
})));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Process chunks meant for search data catalog
|
||||
pub fn process_search_data_catalog_chunk(
|
||||
&mut self,
|
||||
id: String,
|
||||
chunk: &str,
|
||||
) -> Result<Option<BusterReasoningMessage>> {
|
||||
// Clear buffer and add new chunk
|
||||
self.clear_buffer();
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
// Complete any incomplete JSON structure
|
||||
let processed_json = self.complete_json_structure(self.buffer.clone());
|
||||
|
||||
// Try to parse the JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&processed_json) {
|
||||
// Check if it's a search requirements structure
|
||||
if let Some(search_requirements) =
|
||||
value.get("search_requirements").and_then(Value::as_str)
|
||||
{
|
||||
// Return the search requirements as a BusterReasoningText
|
||||
return Ok(Some(BusterReasoningMessage::Text(BusterReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Searching your data catalog...".to_string(),
|
||||
secondary_title: String::from(""),
|
||||
message: None,
|
||||
message_chunk: Some(search_requirements.to_string()),
|
||||
status: Some("loading".to_string()),
|
||||
})));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Process chunks meant for metric files
|
||||
pub fn process_metric_chunk(
|
||||
&mut self,
|
||||
id: String,
|
||||
chunk: &str,
|
||||
) -> Result<Option<BusterReasoningMessage>> {
|
||||
// Clear buffer and add new chunk
|
||||
self.clear_buffer();
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
// Process the buffer with metric file type
|
||||
self.process_file_data(id.clone(), "metric".to_string())
|
||||
}
|
||||
|
||||
// Process chunks meant for dashboard files
|
||||
pub fn process_dashboard_chunk(
|
||||
&mut self,
|
||||
id: String,
|
||||
chunk: &str,
|
||||
) -> Result<Option<BusterReasoningMessage>> {
|
||||
// Clear buffer and add new chunk
|
||||
self.clear_buffer();
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
// Process the buffer with dashboard file type
|
||||
self.process_file_data(id.clone(), "dashboard".to_string())
|
||||
}
|
||||
|
||||
// Internal function to process file data (shared by metric and dashboard processing)
|
||||
pub fn process_file_data(
|
||||
&mut self,
|
||||
id: String,
|
||||
file_type: String,
|
||||
) -> Result<Option<BusterReasoningMessage>> {
|
||||
// Extract and replace yml_content with placeholders
|
||||
let mut yml_contents = Vec::new();
|
||||
let mut positions = Vec::new();
|
||||
let mut processed_json = self.buffer.clone();
|
||||
|
||||
// Find all yml_content matches and store them with their positions
|
||||
for captures in self.yml_content_regex.captures_iter(&self.buffer) {
|
||||
if let Some(content_match) = captures.get(1) {
|
||||
yml_contents.push(content_match.as_str().to_string());
|
||||
positions.push((
|
||||
captures.get(0).unwrap().start(),
|
||||
captures.get(0).unwrap().end(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort positions from last to first to maintain correct indices when replacing
|
||||
let mut position_indices: Vec<usize> = (0..positions.len()).collect();
|
||||
position_indices.sort_by_key(|&i| std::cmp::Reverse(positions[i].0));
|
||||
|
||||
// Replace matches with placeholders in reverse order
|
||||
for i in position_indices {
|
||||
let (start, end) = positions[i];
|
||||
let placeholder = format!(r#""yml_content":"YML_CONTENT_{i}""#);
|
||||
processed_json.replace_range(start..end, &placeholder);
|
||||
}
|
||||
|
||||
// Complete any incomplete JSON structure
|
||||
processed_json = self.complete_json_structure(processed_json);
|
||||
|
||||
// Try to parse the completed JSON
|
||||
if let Ok(mut value) = serde_json::from_str::<Value>(&processed_json) {
|
||||
// Put back the yml_content and process escapes first
|
||||
if let Some(obj) = value.as_object_mut() {
|
||||
if let Some(files) = obj.get_mut("files").and_then(|v| v.as_array_mut()) {
|
||||
for (i, file) in files.iter_mut().enumerate() {
|
||||
if let Some(file_obj) = file.as_object_mut() {
|
||||
if let Some(yml_content) = yml_contents.get(i) {
|
||||
// Process escaped characters
|
||||
let processed_content =
|
||||
serde_json::from_str::<String>(&format!("\"{}\"", yml_content))
|
||||
.unwrap_or_else(|_| yml_content.clone());
|
||||
|
||||
file_obj.insert(
|
||||
"yml_content".to_string(),
|
||||
Value::String(processed_content),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now check the structure after modifications
|
||||
return self.convert_file_to_message(id, value, file_type);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// Helper method to complete JSON structure (shared functionality)
|
||||
fn complete_json_structure(&self, json: String) -> String {
|
||||
let mut processed = String::with_capacity(json.len());
|
||||
let mut nesting_stack = Vec::new();
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
// Process each character and track structure
|
||||
for c in json.chars() {
|
||||
processed.push(c);
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match c {
|
||||
'\\' => escape_next = true,
|
||||
'"' if !escape_next => in_string = !in_string,
|
||||
'{' | '[' if !in_string => nesting_stack.push(c),
|
||||
'}' if !in_string => {
|
||||
if nesting_stack.last() == Some(&'{') {
|
||||
nesting_stack.pop();
|
||||
}
|
||||
}
|
||||
']' if !in_string => {
|
||||
if nesting_stack.last() == Some(&'[') {
|
||||
nesting_stack.pop();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Close any unclosed strings
|
||||
if in_string {
|
||||
processed.push('"');
|
||||
}
|
||||
|
||||
// Close structures in reverse order of opening
|
||||
while let Some(c) = nesting_stack.pop() {
|
||||
match c {
|
||||
'{' => processed.push('}'),
|
||||
'[' => processed.push(']'),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
processed
|
||||
}
|
||||
|
||||
// Helper method to convert file JSON to message
|
||||
fn convert_file_to_message(
|
||||
&self,
|
||||
id: String,
|
||||
value: Value,
|
||||
file_type: String,
|
||||
) -> Result<Option<BusterReasoningMessage>> {
|
||||
if let Some(files) = value.get("files").and_then(Value::as_array) {
|
||||
let mut files_map = std::collections::HashMap::new();
|
||||
let mut file_ids = Vec::new();
|
||||
|
||||
for file in files {
|
||||
if let Some(file_obj) = file.as_object() {
|
||||
let has_name = file_obj.get("name").and_then(Value::as_str).is_some();
|
||||
let has_yml_content = file_obj.get("yml_content").is_some();
|
||||
|
||||
if has_name && has_yml_content {
|
||||
let name = file_obj.get("name").and_then(Value::as_str).unwrap_or("");
|
||||
let yml_content = file_obj
|
||||
.get("yml_content")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("");
|
||||
|
||||
// Generate deterministic UUID based on tool call ID, file name, and type
|
||||
let file_id = generate_deterministic_uuid(&id, name, &file_type)?;
|
||||
|
||||
let buster_file = BusterFile {
|
||||
id: file_id.to_string(),
|
||||
file_type: file_type.clone(),
|
||||
file_name: name.to_string(),
|
||||
version_number: 1,
|
||||
version_id: String::from("0203f597-5ec5-4fd8-86e2-8587fe1c23b6"),
|
||||
status: "loading".to_string(),
|
||||
file: BusterFileContent {
|
||||
text: None,
|
||||
text_chunk: Some(yml_content.to_string()),
|
||||
modifided: None,
|
||||
},
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
file_ids.push(file_id.to_string());
|
||||
files_map.insert(file_id.to_string(), buster_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !files_map.is_empty() {
|
||||
return Ok(Some(BusterReasoningMessage::File(BusterReasoningFile {
|
||||
id,
|
||||
message_type: "files".to_string(),
|
||||
title: format!("Creating {} files...", file_type),
|
||||
secondary_title: String::new(),
|
||||
status: "loading".to_string(),
|
||||
file_ids,
|
||||
files: files_map,
|
||||
})));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
|
@ -76,7 +76,7 @@ impl LiteLLMClient {
|
|||
let response: ChatCompletionResponse = serde_json::from_str(&response_text)?;
|
||||
|
||||
// Print tool calls if present
|
||||
if let Some(LiteLlmMessage::Assistant {
|
||||
if let Some(AgentMessage::Assistant {
|
||||
tool_calls: Some(tool_calls),
|
||||
..
|
||||
}) = response.choices.first().map(|c| &c.message)
|
||||
|
@ -221,8 +221,8 @@ mod tests {
|
|||
(api_key, base_url)
|
||||
}
|
||||
|
||||
fn create_test_message() -> LiteLlmMessage {
|
||||
LiteLlmMessage::user("Hello".to_string())
|
||||
fn create_test_message() -> AgentMessage {
|
||||
AgentMessage::user("Hello".to_string())
|
||||
}
|
||||
|
||||
fn create_test_request() -> ChatCompletionRequest {
|
||||
|
@ -281,7 +281,7 @@ mod tests {
|
|||
|
||||
let response = client.chat_completion(request).await.unwrap();
|
||||
assert_eq!(response.id, "test-id");
|
||||
if let LiteLlmMessage::Assistant { content, .. } = response.choices[0].message.clone() {
|
||||
if let AgentMessage::Assistant { content, .. } = response.choices[0].message.clone() {
|
||||
assert_eq!(content.unwrap(), "Hello there!");
|
||||
} else {
|
||||
panic!("Expected assistant message");
|
||||
|
@ -416,7 +416,7 @@ mod tests {
|
|||
|
||||
let response = client.chat_completion(request).await.unwrap();
|
||||
assert_eq!(response.id, "test-id");
|
||||
if let LiteLlmMessage::Assistant {
|
||||
if let AgentMessage::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
|
@ -471,7 +471,7 @@ mod tests {
|
|||
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o1".to_string(),
|
||||
messages: vec![LiteLlmMessage::user("Hello, world!".to_string())],
|
||||
messages: vec![AgentMessage::user("Hello, world!".to_string())],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::collections::HashMap;
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<LiteLlmMessage>,
|
||||
pub messages: Vec<AgentMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub store: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
@ -109,7 +109,7 @@ impl Default for MessageProgress {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[serde(tag = "role")]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum LiteLlmMessage {
|
||||
pub enum AgentMessage {
|
||||
#[serde(alias = "system")]
|
||||
Developer {
|
||||
#[serde(skip)]
|
||||
|
@ -154,7 +154,7 @@ pub enum LiteLlmMessage {
|
|||
|
||||
// Helper methods for Message
|
||||
// Intentionally leaving out name for now.
|
||||
impl LiteLlmMessage {
|
||||
impl AgentMessage {
|
||||
pub fn developer(content: impl Into<String>) -> Self {
|
||||
Self::Developer {
|
||||
id: None,
|
||||
|
@ -382,7 +382,7 @@ pub struct ChatCompletionResponse {
|
|||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Choice {
|
||||
pub index: i32,
|
||||
pub message: LiteLlmMessage,
|
||||
pub message: AgentMessage,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub delta: Option<Delta>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
|
@ -469,8 +469,8 @@ mod tests {
|
|||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
messages: vec![
|
||||
LiteLlmMessage::developer("You are a helpful assistant."),
|
||||
LiteLlmMessage::user("Hello!"),
|
||||
AgentMessage::developer("You are a helpful assistant."),
|
||||
AgentMessage::user("Hello!"),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
@ -492,7 +492,7 @@ mod tests {
|
|||
|
||||
// Check first message (developer)
|
||||
match &deserialized.messages[0] {
|
||||
LiteLlmMessage::Developer { content, .. } => {
|
||||
AgentMessage::Developer { content, .. } => {
|
||||
assert_eq!(content, "You are a helpful assistant.");
|
||||
}
|
||||
_ => panic!("First message should be developer role"),
|
||||
|
@ -500,7 +500,7 @@ mod tests {
|
|||
|
||||
// Check second message (user)
|
||||
match &deserialized.messages[1] {
|
||||
LiteLlmMessage::User { content, .. } => {
|
||||
AgentMessage::User { content, .. } => {
|
||||
assert_eq!(content, "Hello!");
|
||||
}
|
||||
_ => panic!("Second message should be user role"),
|
||||
|
@ -517,7 +517,7 @@ mod tests {
|
|||
system_fingerprint: Some("fp_44709d6fcb".to_string()),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: LiteLlmMessage::assistant(
|
||||
message: AgentMessage::assistant(
|
||||
Some("\n\nHello there, how may I assist you today?".to_string()),
|
||||
None,
|
||||
None,
|
||||
|
@ -567,7 +567,7 @@ mod tests {
|
|||
|
||||
// Verify message
|
||||
match &choice.message {
|
||||
LiteLlmMessage::Assistant {
|
||||
AgentMessage::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
|
@ -598,7 +598,7 @@ mod tests {
|
|||
async fn test_chat_completion_request_with_tools() {
|
||||
let request = ChatCompletionRequest {
|
||||
model: "o1".to_string(),
|
||||
messages: vec![LiteLlmMessage::user(
|
||||
messages: vec![AgentMessage::user(
|
||||
"Hello whats the weather in vineyard ut!",
|
||||
)],
|
||||
max_completion_tokens: Some(100),
|
||||
|
@ -641,7 +641,7 @@ mod tests {
|
|||
// Verify message
|
||||
assert_eq!(deserialized.messages.len(), 1);
|
||||
match &deserialized.messages[0] {
|
||||
LiteLlmMessage::User { content, .. } => {
|
||||
AgentMessage::User { content, .. } => {
|
||||
assert_eq!(content, "Hello whats the weather in vineyard ut!");
|
||||
}
|
||||
_ => panic!("Expected user message"),
|
||||
|
@ -667,7 +667,7 @@ mod tests {
|
|||
choices: vec![Choice {
|
||||
finish_reason: Some("length".to_string()),
|
||||
index: 0,
|
||||
message: LiteLlmMessage::assistant(
|
||||
message: AgentMessage::assistant(
|
||||
Some("".to_string()),
|
||||
None,
|
||||
None,
|
||||
|
@ -714,7 +714,7 @@ mod tests {
|
|||
|
||||
// Verify message is empty
|
||||
match &choice.message {
|
||||
LiteLlmMessage::Assistant {
|
||||
AgentMessage::Assistant {
|
||||
content,
|
||||
tool_calls,
|
||||
..
|
||||
|
@ -742,8 +742,8 @@ mod tests {
|
|||
let request = ChatCompletionRequest {
|
||||
model: "o1".to_string(),
|
||||
messages: vec![
|
||||
LiteLlmMessage::developer("You are a helpful assistant."),
|
||||
LiteLlmMessage::user("Hello!"),
|
||||
AgentMessage::developer("You are a helpful assistant."),
|
||||
AgentMessage::user("Hello!"),
|
||||
],
|
||||
stream: Some(true),
|
||||
..Default::default()
|
||||
|
@ -763,13 +763,13 @@ mod tests {
|
|||
// Verify messages
|
||||
assert_eq!(deserialized.messages.len(), 2);
|
||||
match &deserialized.messages[0] {
|
||||
LiteLlmMessage::Developer { content, .. } => {
|
||||
AgentMessage::Developer { content, .. } => {
|
||||
assert_eq!(content, "You are a helpful assistant.");
|
||||
}
|
||||
_ => panic!("First message should be developer role"),
|
||||
}
|
||||
match &deserialized.messages[1] {
|
||||
LiteLlmMessage::User { content, .. } => {
|
||||
AgentMessage::User { content, .. } => {
|
||||
assert_eq!(content, "Hello!");
|
||||
}
|
||||
_ => panic!("Second message should be user role"),
|
||||
|
@ -893,7 +893,7 @@ mod tests {
|
|||
// Test request with function tool
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
messages: vec![LiteLlmMessage::user(
|
||||
messages: vec![AgentMessage::user(
|
||||
"What's the weather like in Boston today?",
|
||||
)],
|
||||
tools: Some(vec![Tool {
|
||||
|
@ -931,7 +931,7 @@ mod tests {
|
|||
// Verify request fields
|
||||
assert_eq!(deserialized_req.model, "gpt-4o");
|
||||
match &deserialized_req.messages[0] {
|
||||
LiteLlmMessage::User { content, .. } => {
|
||||
AgentMessage::User { content, .. } => {
|
||||
assert_eq!(content, "What's the weather like in Boston today?");
|
||||
}
|
||||
_ => panic!("Expected user message"),
|
||||
|
@ -952,7 +952,7 @@ mod tests {
|
|||
model: "gpt-4o-mini".to_string(),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: LiteLlmMessage::assistant(
|
||||
message: AgentMessage::assistant(
|
||||
None,
|
||||
None,
|
||||
Some(vec![ToolCall {
|
||||
|
@ -1001,7 +1001,7 @@ mod tests {
|
|||
assert_eq!(choice.finish_reason, Some("tool_calls".to_string()));
|
||||
|
||||
match &choice.message {
|
||||
LiteLlmMessage::Assistant {
|
||||
AgentMessage::Assistant {
|
||||
id,
|
||||
content,
|
||||
tool_calls,
|
||||
|
|
|
@ -12,6 +12,8 @@ serde_json = { workspace = true }
|
|||
regex = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
litellm = { path = "../litellm" }
|
||||
|
||||
# Development dependencies
|
||||
[dev-dependencies]
|
||||
|
|
|
@ -2,13 +2,18 @@
|
|||
//!
|
||||
//! This library provides functionality for parsing incomplete JSON streams
|
||||
//! and processing them through specialized processors.
|
||||
//!
|
||||
//! The library now supports ID-based processors, allowing multiple tool calls
|
||||
//! of the same type to be processed simultaneously without interference.
|
||||
//! It also handles caching and chunk tracking for each individual tool call.
|
||||
|
||||
pub mod parser;
|
||||
pub mod processor;
|
||||
pub mod types;
|
||||
pub mod processors;
|
||||
|
||||
// Re-exports for convenient access
|
||||
// Re-export the main types
|
||||
pub use parser::StreamingParser;
|
||||
pub use processor::{Processor, ProcessorRegistry};
|
||||
pub use types::ProcessedOutput;
|
||||
pub use types::{ProcessedOutput, ProcessorType, MessageType, ToolCallInfo, ToolCallState, ProcessedMessage};
|
||||
pub use processors::*;
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
use anyhow::Result;
|
||||
use anyhow::{anyhow, Result};
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
use chrono::Utc;
|
||||
use litellm::{LiteLlmMessage, ToolCall};
|
||||
|
||||
use crate::processor::ProcessorRegistry;
|
||||
use crate::types::{File, FileContent, ProcessedOutput, ReasoningFile};
|
||||
use crate::types::{
|
||||
File, FileContent, MessageType, ProcessedMessage, ProcessedOutput,
|
||||
ReasoningFile, ReasoningText, ToolCallInfo, ToolCallState
|
||||
};
|
||||
|
||||
/// StreamingParser handles parsing of incomplete JSON streams
|
||||
/// StreamingParser handles parsing of incomplete JSON streams and LiteLlmMessage processing
|
||||
pub struct StreamingParser {
|
||||
/// Buffer to accumulate chunks of data
|
||||
buffer: String,
|
||||
|
@ -15,6 +20,12 @@ pub struct StreamingParser {
|
|||
processors: ProcessorRegistry,
|
||||
/// Regex for extracting YAML content
|
||||
yml_content_regex: Regex,
|
||||
/// Map of tool call IDs to their information
|
||||
tool_calls: HashMap<String, ToolCallInfo>,
|
||||
/// List of reasoning messages (tool calls and outputs)
|
||||
reasoning_messages: Vec<ProcessedMessage>,
|
||||
/// List of response messages
|
||||
response_messages: Vec<String>,
|
||||
}
|
||||
|
||||
impl StreamingParser {
|
||||
|
@ -25,6 +36,9 @@ impl StreamingParser {
|
|||
processors: ProcessorRegistry::new(),
|
||||
yml_content_regex: Regex::new(r#""yml_content":\s*"((?:[^"\\]|\\.|[\r\n])*?)(?:"|$)"#)
|
||||
.unwrap(),
|
||||
tool_calls: HashMap::new(),
|
||||
reasoning_messages: Vec::new(),
|
||||
response_messages: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -35,6 +49,9 @@ impl StreamingParser {
|
|||
processors,
|
||||
yml_content_regex: Regex::new(r#""yml_content":\s*"((?:[^"\\]|\\.|[\r\n])*?)(?:"|$)"#)
|
||||
.unwrap(),
|
||||
tool_calls: HashMap::new(),
|
||||
reasoning_messages: Vec::new(),
|
||||
response_messages: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,6 +60,15 @@ impl StreamingParser {
|
|||
self.processors.register(processor);
|
||||
}
|
||||
|
||||
/// Registers a processor with a specific ID
|
||||
pub fn register_processor_with_id(
|
||||
&mut self,
|
||||
id: String,
|
||||
processor: Box<dyn crate::processor::Processor>,
|
||||
) {
|
||||
self.processors.register_with_id(id, processor);
|
||||
}
|
||||
|
||||
/// Clear the buffer - useful when reusing the parser for different content formats
|
||||
pub fn clear_buffer(&mut self) {
|
||||
self.buffer.clear();
|
||||
|
@ -55,14 +81,220 @@ impl StreamingParser {
|
|||
chunk: &str,
|
||||
processor_type: &str,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Add new chunk to buffer
|
||||
self.buffer.push_str(chunk);
|
||||
// Get or create buffer for this ID
|
||||
let buffer = self.processors.get_chunk_buffer(&id).unwrap_or_default();
|
||||
let updated_buffer = buffer + chunk;
|
||||
|
||||
println!("Updated buffer: {}", updated_buffer);
|
||||
|
||||
// Store updated buffer
|
||||
self.processors
|
||||
.update_chunk_buffer(id.clone(), updated_buffer.clone());
|
||||
|
||||
// Get the previously cached output, if any
|
||||
let previous_output = self.processors.get_cached_output(&id).cloned();
|
||||
|
||||
println!("Previous output: {:#?}", previous_output);
|
||||
|
||||
// Complete any incomplete JSON structure
|
||||
let processed_json = self.complete_json_structure(self.buffer.clone());
|
||||
let processed_json = self.complete_json_structure(updated_buffer.clone());
|
||||
|
||||
// If we don't have a processor registered with this ID yet, find one by type and register it
|
||||
if !self.processors.has_processor_with_id(&id) {
|
||||
for (_, (type_str, processor)) in self.processors.get_processors() {
|
||||
if type_str == processor_type && processor.can_process(&processed_json) {
|
||||
// Clone the processor and register it with this ID
|
||||
let processor_clone = processor.clone_box();
|
||||
self.processors.register_with_id(id.clone(), processor_clone);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process with the appropriate processor
|
||||
self.processors.process(id, &processed_json, processor_type)
|
||||
// Process with the appropriate processor, passing the previous output
|
||||
let result = self
|
||||
.processors
|
||||
.process_by_id_with_context(id.clone(), &processed_json, processor_type, previous_output);
|
||||
|
||||
println!("Result: {:#?}", result);
|
||||
|
||||
// If processing succeeded, cache the result
|
||||
if let Ok(Some(output)) = &result {
|
||||
self.processors.cache_output(id, output.clone());
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Process a LiteLlmMessage
|
||||
pub fn process_message(&mut self, message: &LiteLlmMessage) -> Result<Option<ProcessedOutput>> {
|
||||
match message {
|
||||
LiteLlmMessage::Assistant { tool_calls: Some(tool_calls), id, .. } => {
|
||||
self.process_assistant_tool_call(message, tool_calls, id.clone())
|
||||
},
|
||||
LiteLlmMessage::Assistant { content: Some(content), id, tool_calls: None, .. } => {
|
||||
self.process_assistant_response(message, content, id.clone())
|
||||
},
|
||||
LiteLlmMessage::Tool { content, tool_call_id, id, .. } => {
|
||||
self.process_tool_output(message, tool_call_id, content, id.clone())
|
||||
},
|
||||
_ => Ok(None), // Ignore other message types
|
||||
}
|
||||
}
|
||||
|
||||
/// Process an Assistant message with tool calls
|
||||
fn process_assistant_tool_call(
|
||||
&mut self,
|
||||
_message: &LiteLlmMessage,
|
||||
tool_calls: &[ToolCall],
|
||||
id: Option<String>
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
for tool_call in tool_calls {
|
||||
let tool_id = tool_call.id.clone();
|
||||
let name = tool_call.function.name.clone();
|
||||
let arguments = tool_call.function.arguments.clone();
|
||||
|
||||
// Parse arguments as JSON
|
||||
let input = serde_json::from_str::<Value>(&arguments)
|
||||
.unwrap_or_else(|_| serde_json::json!({"raw": arguments}));
|
||||
|
||||
// Register or update tool call
|
||||
if let Some(existing_tool_call) = self.tool_calls.get_mut(&tool_id) {
|
||||
// Update existing tool call with new chunks
|
||||
existing_tool_call.chunks.push(arguments.clone());
|
||||
existing_tool_call.input = input.clone();
|
||||
if existing_tool_call.state == ToolCallState::InProgress {
|
||||
existing_tool_call.state = ToolCallState::Complete;
|
||||
}
|
||||
} else {
|
||||
// Register new tool call
|
||||
self.tool_calls.insert(tool_id.clone(), ToolCallInfo {
|
||||
id: tool_id.clone(),
|
||||
name: name.clone(),
|
||||
input: input.clone(),
|
||||
output: None,
|
||||
timestamp: Utc::now(),
|
||||
state: ToolCallState::Complete,
|
||||
chunks: vec![arguments.clone()],
|
||||
});
|
||||
}
|
||||
|
||||
// Process with appropriate processor
|
||||
if let Some(processor) = self.processors.get_processor_for_tool(&name) {
|
||||
let processed = processor.process(tool_id.clone(), &serde_json::to_string(&input)?)?;
|
||||
|
||||
if let Some(output) = processed.clone() {
|
||||
// Store as reasoning message
|
||||
self.add_reasoning_message(
|
||||
tool_id.clone(),
|
||||
MessageType::AssistantToolCall,
|
||||
output.clone()
|
||||
);
|
||||
|
||||
return Ok(Some(output));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process an Assistant message with content (text response)
|
||||
fn process_assistant_response(
|
||||
&mut self,
|
||||
_message: &LiteLlmMessage,
|
||||
content: &str,
|
||||
id: Option<String>
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// For response messages, we just store the text
|
||||
self.response_messages.push(content.to_string());
|
||||
|
||||
// Create a simple processed output
|
||||
let processed = ProcessedOutput::Text(ReasoningText {
|
||||
id: id.clone().unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
reasoning_type: "response".to_string(),
|
||||
title: "Assistant Response".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some(content.to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("complete".to_string()),
|
||||
});
|
||||
|
||||
// Add to reasoning messages
|
||||
self.add_reasoning_message(
|
||||
id.unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
MessageType::AssistantResponse,
|
||||
processed.clone()
|
||||
);
|
||||
|
||||
Ok(Some(processed))
|
||||
}
|
||||
|
||||
/// Process a Tool message (output from executed tool call)
|
||||
fn process_tool_output(
|
||||
&mut self,
|
||||
_message: &LiteLlmMessage,
|
||||
tool_call_id: &str,
|
||||
content: &str,
|
||||
_id: Option<String>
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Parse content as JSON if possible
|
||||
let output = serde_json::from_str::<Value>(content)
|
||||
.unwrap_or_else(|_| serde_json::json!({"text": content}));
|
||||
|
||||
// Update tool call with output
|
||||
if let Some(tool_call) = self.tool_calls.get_mut(tool_call_id) {
|
||||
tool_call.output = Some(output.clone());
|
||||
tool_call.state = ToolCallState::HasOutput;
|
||||
|
||||
// Get the tool name
|
||||
let name = tool_call.name.clone();
|
||||
|
||||
// Process with appropriate processor
|
||||
if self.processors.has_processor_for_tool(&name) {
|
||||
if let Ok(Some(processed)) = self.processors.process_with_tool(
|
||||
&name,
|
||||
tool_call_id.to_string(),
|
||||
&serde_json::to_string(&output)?
|
||||
) {
|
||||
// Store as reasoning message
|
||||
self.add_reasoning_message(
|
||||
tool_call_id.to_string(),
|
||||
MessageType::ToolOutput,
|
||||
processed.clone()
|
||||
);
|
||||
|
||||
return Ok(Some(processed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Gets the cached output for the given ID
|
||||
pub fn get_cached_output(&self, id: &str) -> Option<&ProcessedOutput> {
|
||||
self.processors.get_cached_output(id)
|
||||
}
|
||||
|
||||
/// Caches the output for the given ID
|
||||
pub fn cache_output(&mut self, id: String, output: ProcessedOutput) {
|
||||
self.processors.cache_output(id, output);
|
||||
}
|
||||
|
||||
/// Clears the cache for the given ID
|
||||
pub fn clear_cache(&mut self, id: &str) {
|
||||
self.processors.clear_cache(id);
|
||||
}
|
||||
|
||||
/// Clears all caches
|
||||
pub fn clear_all_caches(&mut self) {
|
||||
self.processors.clear_all_caches();
|
||||
}
|
||||
|
||||
/// Clears the chunk buffer for the given ID
|
||||
pub fn clear_chunk_buffer(&mut self, id: &str) {
|
||||
self.processors.clear_chunk_buffer(id);
|
||||
}
|
||||
|
||||
/// Process YAML content in JSON
|
||||
|
@ -130,57 +362,6 @@ impl StreamingParser {
|
|||
processed_json
|
||||
}
|
||||
|
||||
/// Complete JSON structure by adding missing brackets and braces
|
||||
fn complete_json_structure(&self, json: String) -> String {
|
||||
let mut processed = String::with_capacity(json.len());
|
||||
let mut nesting_stack = Vec::new();
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
// Process each character and track structure
|
||||
for c in json.chars() {
|
||||
processed.push(c);
|
||||
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match c {
|
||||
'\\' => escape_next = true,
|
||||
'"' if !escape_next => in_string = !in_string,
|
||||
'{' | '[' if !in_string => nesting_stack.push(c),
|
||||
'}' if !in_string => {
|
||||
if nesting_stack.last() == Some(&'{') {
|
||||
nesting_stack.pop();
|
||||
}
|
||||
}
|
||||
']' if !in_string => {
|
||||
if nesting_stack.last() == Some(&'[') {
|
||||
nesting_stack.pop();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Close any unclosed strings
|
||||
if in_string {
|
||||
processed.push('"');
|
||||
}
|
||||
|
||||
// Close structures in reverse order of opening
|
||||
while let Some(c) = nesting_stack.pop() {
|
||||
match c {
|
||||
'{' => processed.push('}'),
|
||||
'[' => processed.push(']'),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
processed
|
||||
}
|
||||
|
||||
/// Process file data for metric and dashboard files
|
||||
pub fn process_file_data(
|
||||
&mut self,
|
||||
|
@ -252,7 +433,7 @@ impl StreamingParser {
|
|||
let file_content = FileContent {
|
||||
text: Some(yml_content),
|
||||
text_chunk: None,
|
||||
modifided: None,
|
||||
modified: None,
|
||||
};
|
||||
|
||||
// Create file
|
||||
|
@ -318,6 +499,335 @@ impl StreamingParser {
|
|||
|
||||
Ok(Uuid::from_bytes(bytes))
|
||||
}
|
||||
|
||||
/// Completes any incomplete JSON structure by adding missing closing brackets
|
||||
fn complete_json_structure(&self, json: String) -> String {
|
||||
let mut stack = Vec::new();
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for c in json.chars() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match c {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' => in_string = !in_string,
|
||||
'{' | '[' if !in_string => stack.push(c),
|
||||
'}' if !in_string => {
|
||||
if let Some('{') = stack.last() {
|
||||
stack.pop();
|
||||
}
|
||||
}
|
||||
']' if !in_string => {
|
||||
if let Some('[') = stack.last() {
|
||||
stack.pop();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// If we have an incomplete JSON, add the missing closing brackets
|
||||
let mut completed_json = json;
|
||||
while let Some(c) = stack.pop() {
|
||||
match c {
|
||||
'{' => completed_json.push('}'),
|
||||
'[' => completed_json.push(']'),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
completed_json
|
||||
}
|
||||
|
||||
/// Adds a reasoning message
|
||||
fn add_reasoning_message(&mut self, id: String, message_type: MessageType, content: ProcessedOutput) {
|
||||
self.reasoning_messages.push(ProcessedMessage {
|
||||
id,
|
||||
message_type,
|
||||
content,
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Gets all reasoning messages
|
||||
pub fn get_reasoning_messages(&self) -> &[ProcessedMessage] {
|
||||
&self.reasoning_messages
|
||||
}
|
||||
|
||||
/// Gets all response messages
|
||||
pub fn get_response_messages(&self) -> &[String] {
|
||||
&self.response_messages
|
||||
}
|
||||
|
||||
/// Gets all tool calls
|
||||
pub fn get_tool_calls(&self) -> &HashMap<String, ToolCallInfo> {
|
||||
&self.tool_calls
|
||||
}
|
||||
|
||||
/// Gets a specific tool call by ID
|
||||
pub fn get_tool_call(&self, id: &str) -> Option<&ToolCallInfo> {
|
||||
self.tool_calls.get(id)
|
||||
}
|
||||
|
||||
/// Registers a processor for a specific tool
|
||||
pub fn register_tool_processor(&mut self, name: &str, processor: Box<dyn crate::processor::Processor>) {
|
||||
self.processors.register_tool_processor(name, processor);
|
||||
}
|
||||
|
||||
/// Process a streaming chunk for a tool call
|
||||
pub fn process_tool_call_chunk(
|
||||
&mut self,
|
||||
tool_id: &str,
|
||||
tool_name: &str,
|
||||
chunk: &str
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Update or create tool call info
|
||||
if let Some(tool_call) = self.tool_calls.get_mut(tool_id) {
|
||||
// Update existing tool call
|
||||
tool_call.chunks.push(chunk.to_string());
|
||||
|
||||
// Update chunk buffer
|
||||
self.processors.update_tool_chunk_buffer(tool_id, chunk);
|
||||
|
||||
// Get the complete buffer - clone it to end the immutable borrow
|
||||
let buffer = match self.processors.get_tool_chunk_buffer(tool_id) {
|
||||
Some(buffer) => buffer.clone(),
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
// Try to parse as JSON
|
||||
if let Ok(input) = serde_json::from_str::<Value>(&buffer) {
|
||||
// Update the tool call input
|
||||
tool_call.input = input.clone();
|
||||
|
||||
// Check if we have a processor for this tool - store result to end immutable borrow
|
||||
let has_processor = self.processors.has_processor_for_tool(tool_name);
|
||||
|
||||
// Process with appropriate processor
|
||||
if has_processor {
|
||||
if let Ok(Some(processed)) = self.processors.process_and_cache_tool_output(
|
||||
tool_name,
|
||||
tool_id.to_string(),
|
||||
&buffer
|
||||
) {
|
||||
// Update the tool call state
|
||||
tool_call.state = ToolCallState::Complete;
|
||||
|
||||
// Store as reasoning message
|
||||
self.add_reasoning_message(
|
||||
tool_id.to_string(),
|
||||
MessageType::AssistantToolCall,
|
||||
processed.clone()
|
||||
);
|
||||
|
||||
return Ok(Some(processed));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Create new tool call
|
||||
self.tool_calls.insert(tool_id.to_string(), ToolCallInfo {
|
||||
id: tool_id.to_string(),
|
||||
name: tool_name.to_string(),
|
||||
input: serde_json::json!({}),
|
||||
output: None,
|
||||
timestamp: Utc::now(),
|
||||
state: ToolCallState::InProgress,
|
||||
chunks: vec![chunk.to_string()],
|
||||
});
|
||||
|
||||
// Initialize chunk buffer
|
||||
self.processors.update_tool_chunk_buffer(tool_id, chunk);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process a streaming chunk for a tool output
|
||||
pub fn process_tool_output_chunk(
|
||||
&mut self,
|
||||
tool_id: &str,
|
||||
chunk: &str
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Update chunk buffer
|
||||
self.processors.update_tool_chunk_buffer(tool_id, chunk);
|
||||
|
||||
// Get the tool call
|
||||
if let Some(tool_call) = self.tool_calls.get_mut(tool_id) {
|
||||
// Get the tool name
|
||||
let tool_name = tool_call.name.clone();
|
||||
|
||||
// Get the complete buffer - clone it to end the immutable borrow
|
||||
let buffer = match self.processors.get_tool_chunk_buffer(tool_id) {
|
||||
Some(buffer_ref) => buffer_ref.clone(),
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
// Try to parse as JSON
|
||||
let output = serde_json::from_str::<Value>(&buffer)
|
||||
.unwrap_or_else(|_| serde_json::json!({"text": buffer.clone()}));
|
||||
|
||||
// Update the tool call output
|
||||
tool_call.output = Some(output.clone());
|
||||
tool_call.state = ToolCallState::HasOutput;
|
||||
|
||||
// Process with appropriate processor
|
||||
if self.processors.has_processor_for_tool(&tool_name) {
|
||||
// Get previous output if available
|
||||
let previous_output = self.processors.get_cached_output(tool_id).cloned();
|
||||
|
||||
if let Ok(Some(processed)) = self.processors.process_and_cache_tool_output_with_context(
|
||||
&tool_name,
|
||||
tool_id.to_string(),
|
||||
&buffer,
|
||||
previous_output
|
||||
) {
|
||||
// Store as reasoning message
|
||||
self.add_reasoning_message(
|
||||
tool_id.to_string(),
|
||||
MessageType::ToolOutput,
|
||||
processed.clone()
|
||||
);
|
||||
|
||||
return Ok(Some(processed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process a streaming chunk for an assistant response
|
||||
pub fn process_response_chunk(
|
||||
&mut self,
|
||||
id: &str,
|
||||
chunk: &str
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Update chunk buffer
|
||||
self.processors.update_chunk_buffer(id.to_string(), chunk.to_string());
|
||||
|
||||
// Get the complete buffer
|
||||
if let Some(buffer) = self.processors.get_chunk_buffer(id) {
|
||||
// Add to response messages if not already present
|
||||
if !self.response_messages.contains(&buffer) {
|
||||
self.response_messages.push(buffer.clone());
|
||||
}
|
||||
|
||||
// Create a simple processed output
|
||||
let processed = ProcessedOutput::Text(ReasoningText {
|
||||
id: id.to_string(),
|
||||
reasoning_type: "response".to_string(),
|
||||
title: "Assistant Response".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some(buffer),
|
||||
message_chunk: Some(chunk.to_string()),
|
||||
status: Some("streaming".to_string()),
|
||||
});
|
||||
|
||||
// Add to reasoning messages
|
||||
self.add_reasoning_message(
|
||||
id.to_string(),
|
||||
MessageType::AssistantResponse,
|
||||
processed.clone()
|
||||
);
|
||||
|
||||
return Ok(Some(processed));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Clears all tool calls and their associated data
|
||||
pub fn clear_tool_calls(&mut self) {
|
||||
self.tool_calls.clear();
|
||||
self.reasoning_messages.retain(|msg| {
|
||||
msg.message_type != MessageType::AssistantToolCall &&
|
||||
msg.message_type != MessageType::ToolOutput
|
||||
});
|
||||
}
|
||||
|
||||
/// Clears a specific tool call and its associated data
|
||||
pub fn clear_tool_call(&mut self, tool_id: &str) {
|
||||
self.tool_calls.remove(tool_id);
|
||||
self.reasoning_messages.retain(|msg| msg.id != tool_id);
|
||||
self.processors.clear_tool_chunk_buffer(tool_id);
|
||||
self.processors.clear_cache(tool_id);
|
||||
}
|
||||
|
||||
/// Gets all tool calls that are in the specified state
|
||||
pub fn get_tool_calls_by_state(&self, state: ToolCallState) -> Vec<&ToolCallInfo> {
|
||||
self.tool_calls
|
||||
.values()
|
||||
.filter(|call| call.state == state)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets all completed tool calls (those with state Complete or HasOutput)
|
||||
pub fn get_completed_tool_calls(&self) -> Vec<&ToolCallInfo> {
|
||||
self.tool_calls
|
||||
.values()
|
||||
.filter(|call| call.state == ToolCallState::Complete || call.state == ToolCallState::HasOutput)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets all tool calls for a specific tool name
|
||||
pub fn get_tool_calls_by_name(&self, name: &str) -> Vec<&ToolCallInfo> {
|
||||
self.tool_calls
|
||||
.values()
|
||||
.filter(|call| call.name == name)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets the most recent tool call for a specific tool name
|
||||
pub fn get_latest_tool_call_by_name(&self, name: &str) -> Option<&ToolCallInfo> {
|
||||
self.tool_calls
|
||||
.values()
|
||||
.filter(|call| call.name == name)
|
||||
.max_by_key(|call| call.timestamp)
|
||||
}
|
||||
|
||||
/// Gets the reasoning messages for a specific tool call
|
||||
pub fn get_reasoning_messages_for_tool(&self, tool_id: &str) -> Vec<&ProcessedMessage> {
|
||||
self.reasoning_messages
|
||||
.iter()
|
||||
.filter(|msg| msg.id == tool_id)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Gets the combined input and output for a specific tool call
|
||||
pub fn get_tool_call_with_output(&self, tool_id: &str) -> Option<(Value, Option<Value>)> {
|
||||
self.tool_calls.get(tool_id).map(|call| (call.input.clone(), call.output.clone()))
|
||||
}
|
||||
|
||||
/// Exports all tool calls and their outputs as a JSON object
|
||||
pub fn export_tool_calls_as_json(&self) -> Value {
|
||||
let mut result = serde_json::json!({});
|
||||
|
||||
for (id, call) in &self.tool_calls {
|
||||
let mut call_data = serde_json::json!({
|
||||
"name": call.name,
|
||||
"input": call.input,
|
||||
"state": match call.state {
|
||||
ToolCallState::InProgress => "in_progress",
|
||||
ToolCallState::Complete => "complete",
|
||||
ToolCallState::HasOutput => "has_output",
|
||||
},
|
||||
"timestamp": call.timestamp.to_rfc3339(),
|
||||
});
|
||||
|
||||
if let Some(output) = &call.output {
|
||||
call_data["output"] = output.clone();
|
||||
}
|
||||
|
||||
result[id] = call_data;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -333,113 +843,198 @@ mod tests {
|
|||
ProcessorType::Custom("test".to_string())
|
||||
}
|
||||
|
||||
fn can_process(&self, json: &str) -> bool {
|
||||
json.contains("test_key")
|
||||
fn can_process(&self, _json: &str) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
if self.can_process(json) {
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Test".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some("Test message".to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("completed".to_string()),
|
||||
})))
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Test".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some(json.to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("completed".to_string()),
|
||||
})))
|
||||
}
|
||||
|
||||
fn process_with_context(&self, id: String, json: &str, previous_output: Option<ProcessedOutput>) -> Result<Option<ProcessedOutput>> {
|
||||
// Get the previously processed content
|
||||
let previous_content = if let Some(ProcessedOutput::Text(text)) = previous_output {
|
||||
text.message.clone().unwrap_or_default()
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Calculate the new content (what wasn't in the previous content)
|
||||
let new_content = if json.len() > previous_content.len() {
|
||||
json[previous_content.len()..].to_string()
|
||||
} else {
|
||||
// If for some reason the new content is shorter, just use the whole thing
|
||||
json.to_string()
|
||||
};
|
||||
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Test".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some(json.to_string()),
|
||||
message_chunk: if new_content.is_empty() { None } else { Some(new_content) },
|
||||
status: Some("loading".to_string()),
|
||||
})))
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(TestProcessor)
|
||||
}
|
||||
}
|
||||
|
||||
struct MockProcessor {
|
||||
processor_type: ProcessorType,
|
||||
}
|
||||
|
||||
impl Processor for MockProcessor {
|
||||
fn processor_type(&self) -> ProcessorType {
|
||||
self.processor_type.clone()
|
||||
}
|
||||
|
||||
fn can_process(&self, _json: &str) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn process(&self, id: String, _json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Mock".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some("Mock message".to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("completed".to_string()),
|
||||
})))
|
||||
}
|
||||
|
||||
fn process_with_context(
|
||||
&self,
|
||||
id: String,
|
||||
_json: &str,
|
||||
_previous_output: Option<ProcessedOutput>,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Mock".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some("Mock message".to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("completed".to_string()),
|
||||
})))
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(MockProcessor {
|
||||
processor_type: self.processor_type.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_complete_json_structure() {
|
||||
let parser = StreamingParser::new();
|
||||
|
||||
// Test basic completion
|
||||
let incomplete = r#"{"key": "value"#;
|
||||
let completed = parser.complete_json_structure(incomplete.to_string());
|
||||
|
||||
// Parse the completed JSON to verify it's valid
|
||||
let parsed_json: serde_json::Value = serde_json::from_str(&completed).unwrap();
|
||||
assert_eq!(parsed_json["key"], "value");
|
||||
|
||||
// Test escaped quotes in string
|
||||
let incomplete = r#"{"key": "value with \"quotes\""#;
|
||||
let completed = parser.complete_json_structure(incomplete.to_string());
|
||||
|
||||
println!("Completed JSON: {}", completed);
|
||||
|
||||
// Parse the completed JSON to verify it's valid
|
||||
let parsed_json: serde_json::Value = serde_json::from_str(&completed).unwrap();
|
||||
assert_eq!(
|
||||
parsed_json["key"].as_str().unwrap(),
|
||||
r#"value with "quotes""#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_chunk() {
|
||||
fn test_streaming_parser_process_chunk() {
|
||||
let mut parser = StreamingParser::new();
|
||||
parser.register_processor(Box::new(TestProcessor));
|
||||
|
||||
// Test with valid data for the processor
|
||||
let result =
|
||||
parser.process_chunk("test_id".to_string(), r#"{"test_key": "value"}"#, "test");
|
||||
// Process a chunk
|
||||
let result = parser.process_chunk("id1".to_string(), r#"{"test": "value"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_some());
|
||||
|
||||
// Test with invalid data for the processor
|
||||
parser.clear_buffer();
|
||||
let result =
|
||||
parser.process_chunk("test_id".to_string(), r#"{"other_key": "value"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
|
||||
// Test with non-existent processor
|
||||
parser.clear_buffer();
|
||||
let result = parser.process_chunk(
|
||||
"test_id".to_string(),
|
||||
r#"{"test_key": "value"}"#,
|
||||
"non_existent",
|
||||
);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_yml_content() {
|
||||
let parser = StreamingParser::new();
|
||||
fn test_streaming_parser_caching() {
|
||||
let mut parser = StreamingParser::new();
|
||||
parser.register_processor(Box::new(TestProcessor));
|
||||
|
||||
// Test with yml_content
|
||||
let json = r#"{"files":[{"name":"test.yml","yml_content":"key: value\nlist:\n - item1\n - item2"}]}"#;
|
||||
let processed = parser.process_yml_content(json.to_string());
|
||||
// Process a chunk
|
||||
let id = "cache_test_id".to_string();
|
||||
let result = parser.process_chunk(id.clone(), r#"{"test": "value"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap().unwrap();
|
||||
|
||||
// Parse the processed JSON to verify it's valid
|
||||
let value: Value = serde_json::from_str(&processed).unwrap();
|
||||
// Cache the output
|
||||
parser.cache_output(id.clone(), output.clone());
|
||||
|
||||
// Check that the yml_content was properly processed
|
||||
let yml_content = value["files"][0]["yml_content"].as_str().unwrap();
|
||||
assert!(yml_content.contains("key: value"));
|
||||
assert!(yml_content.contains("list:"));
|
||||
// Check if cached
|
||||
let cached = parser.get_cached_output(&id);
|
||||
assert!(cached.is_some());
|
||||
assert_eq!(cached.unwrap(), &output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_yml_content_with_invalid_json() {
|
||||
let parser = StreamingParser::new();
|
||||
fn test_streaming_parser_multiple_ids() {
|
||||
let mut parser = StreamingParser::new();
|
||||
parser.register_processor(Box::new(TestProcessor));
|
||||
|
||||
// Test with invalid JSON
|
||||
let json = r#"{"files":[{"name":"test.yml","yml_content":"key: value\nlist:\n - item1\n - item2"}]"#;
|
||||
let processed = parser.process_yml_content(json.to_string());
|
||||
// Process with first ID
|
||||
let result1 = parser.process_chunk("id1".to_string(), r#"{"test": "value1"}"#, "test");
|
||||
assert!(result1.is_ok());
|
||||
let output1 = result1.unwrap().unwrap();
|
||||
|
||||
// Parse the processed JSON to verify it's valid
|
||||
let value: Value = serde_json::from_str(&processed).unwrap();
|
||||
// Process with second ID
|
||||
let result2 = parser.process_chunk("id2".to_string(), r#"{"test": "value2"}"#, "test");
|
||||
assert!(result2.is_ok());
|
||||
let output2 = result2.unwrap().unwrap();
|
||||
|
||||
// Check that the yml_content was properly processed
|
||||
let yml_content = value["files"][0]["yml_content"].as_str().unwrap();
|
||||
assert!(yml_content.contains("key: value"));
|
||||
assert!(yml_content.contains("list:"));
|
||||
// Verify they have different content
|
||||
match (output1, output2) {
|
||||
(ProcessedOutput::Text(text1), ProcessedOutput::Text(text2)) => {
|
||||
assert_eq!(text1.id, "id1");
|
||||
assert_eq!(text2.id, "id2");
|
||||
assert_ne!(text1.message, text2.message);
|
||||
}
|
||||
_ => panic!("Expected Text outputs"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_streaming_parser_incomplete_json() {
|
||||
let mut parser = StreamingParser::new();
|
||||
parser.register_processor(Box::new(TestProcessor));
|
||||
|
||||
// Process incomplete JSON
|
||||
let id = "incomplete_json_id".to_string();
|
||||
let result = parser.process_chunk(id.clone(), r#"{"test": "value"#, "test");
|
||||
|
||||
// The TestProcessor should process this and include the chunk
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
|
||||
// Verify the output has the correct chunk
|
||||
if let Some(ProcessedOutput::Text(text)) = output {
|
||||
assert_eq!(text.message, Some(r#"{"test": "value}"#.to_string()));
|
||||
assert_eq!(text.message_chunk, Some(r#"{"test": "value}"#.to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
|
||||
// Complete the JSON
|
||||
let result = parser.process_chunk(id.clone(), r#"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
|
||||
// Verify the output has the correct chunk - should only contain the new part
|
||||
if let Some(ProcessedOutput::Text(text)) = output {
|
||||
assert_eq!(text.message, Some(r#"{"test": "value}}"#.to_string()));
|
||||
assert_eq!(text.message_chunk, Some(r#"}"#.to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
|
||||
// Check the cached output
|
||||
let cached = parser.get_cached_output(&id);
|
||||
assert!(cached.is_some());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,20 +4,37 @@ use std::collections::HashMap;
|
|||
use crate::types::{ProcessedOutput, ProcessorType};
|
||||
|
||||
/// Trait defining the interface for streaming processors
|
||||
pub trait Processor {
|
||||
pub trait Processor: Send + Sync {
|
||||
/// Returns the type of processor
|
||||
fn processor_type(&self) -> ProcessorType;
|
||||
|
||||
|
||||
/// Checks if this processor can handle the given JSON data
|
||||
fn can_process(&self, json: &str) -> bool;
|
||||
|
||||
|
||||
/// Processes the JSON data and returns a processed output
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>>;
|
||||
|
||||
/// Processes the JSON data with context from a previous output and returns a processed output
|
||||
fn process_with_context(
|
||||
&self,
|
||||
id: String,
|
||||
json: &str,
|
||||
previous_output: Option<ProcessedOutput>,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Default implementation just calls process without using the context
|
||||
self.process(id, json)
|
||||
}
|
||||
|
||||
/// Creates a clone of this processor
|
||||
fn clone_box(&self) -> Box<dyn Processor>;
|
||||
}
|
||||
|
||||
/// Registry for managing processors
|
||||
pub struct ProcessorRegistry {
|
||||
processors: HashMap<String, Box<dyn Processor>>,
|
||||
processors: HashMap<String, (String, Box<dyn Processor>)>, // (id, (type, processor))
|
||||
tool_processors: HashMap<String, Box<dyn Processor>>, // (tool_name, processor)
|
||||
output_cache: HashMap<String, ProcessedOutput>, // (id, output)
|
||||
chunk_buffers: HashMap<String, String>, // (id, buffer)
|
||||
}
|
||||
|
||||
impl ProcessorRegistry {
|
||||
|
@ -25,110 +42,394 @@ impl ProcessorRegistry {
|
|||
pub fn new() -> Self {
|
||||
ProcessorRegistry {
|
||||
processors: HashMap::new(),
|
||||
tool_processors: HashMap::new(),
|
||||
output_cache: HashMap::new(),
|
||||
chunk_buffers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a processor with the registry
|
||||
|
||||
/// Registers a processor
|
||||
pub fn register(&mut self, processor: Box<dyn Processor>) {
|
||||
let processor_type = processor.processor_type().as_str().to_string();
|
||||
self.processors.insert(processor_type, processor);
|
||||
self.processors
|
||||
.insert(processor_type.clone(), (processor_type, processor));
|
||||
}
|
||||
|
||||
/// Processes data using the appropriate processor
|
||||
pub fn process(&self, id: String, json: &str, processor_type: &str) -> Result<Option<ProcessedOutput>> {
|
||||
if let Some(processor) = self.processors.get(processor_type) {
|
||||
|
||||
/// Registers a processor with a specific ID
|
||||
pub fn register_with_id(&mut self, id: String, processor: Box<dyn Processor>) {
|
||||
let processor_type = processor.processor_type().as_str().to_string();
|
||||
self.processors.insert(id, (processor_type, processor));
|
||||
}
|
||||
|
||||
/// Registers a processor for a specific tool
|
||||
pub fn register_tool_processor(&mut self, tool_name: &str, processor: Box<dyn Processor>) {
|
||||
self.tool_processors.insert(tool_name.to_string(), processor);
|
||||
}
|
||||
|
||||
/// Returns true if the registry has a processor for the given type
|
||||
pub fn has_processor(&self, processor_type: &str) -> bool {
|
||||
self.processors
|
||||
.values()
|
||||
.any(|(type_str, _)| type_str == processor_type)
|
||||
}
|
||||
|
||||
/// Returns true if the registry has a processor for the given ID
|
||||
pub fn has_processor_with_id(&self, id: &str) -> bool {
|
||||
self.processors.contains_key(id)
|
||||
}
|
||||
|
||||
/// Returns true if the registry has a processor for the given tool
|
||||
pub fn has_processor_for_tool(&self, tool_name: &str) -> bool {
|
||||
self.tool_processors.contains_key(tool_name)
|
||||
}
|
||||
|
||||
/// Returns a reference to all processors
|
||||
pub fn get_processors(&self) -> &HashMap<String, (String, Box<dyn Processor>)> {
|
||||
&self.processors
|
||||
}
|
||||
|
||||
/// Returns a processor for a specific tool
|
||||
pub fn get_processor_for_tool(&self, tool_name: &str) -> Option<&Box<dyn Processor>> {
|
||||
self.tool_processors.get(tool_name)
|
||||
}
|
||||
|
||||
/// Processes the given JSON with the appropriate processor
|
||||
pub fn process(
|
||||
&self,
|
||||
id: String,
|
||||
json: &str,
|
||||
processor_type: &str,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Check if we have a processor registered with this ID
|
||||
if let Some((_, processor)) = self.processors.get(&id) {
|
||||
if processor.can_process(json) {
|
||||
return processor.process(id, json);
|
||||
}
|
||||
}
|
||||
|
||||
// If not, find a processor by type
|
||||
for (_, (type_str, processor)) in &self.processors {
|
||||
if type_str == processor_type && processor.can_process(json) {
|
||||
// Create a new processor instance with this ID for future use
|
||||
let result = processor.process(id.clone(), json);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Processes the given JSON with the processor registered with the given ID
|
||||
pub fn process_by_id(
|
||||
&self,
|
||||
id: String,
|
||||
json: &str,
|
||||
processor_type: &str,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Check if we have a processor registered with this ID
|
||||
if let Some((_, processor)) = self.processors.get(&id) {
|
||||
if processor.can_process(json) {
|
||||
return processor.process(id, json);
|
||||
}
|
||||
}
|
||||
|
||||
// If not, find a processor by type and register it with this ID
|
||||
for (_, (type_str, processor)) in &self.processors {
|
||||
if type_str == processor_type && processor.can_process(json) {
|
||||
return processor.process(id, json);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process a JSON string with a processor by ID and processor type, providing previous output context
|
||||
pub fn process_by_id_with_context(
|
||||
&self,
|
||||
id: String,
|
||||
json: &str,
|
||||
processor_type: &str,
|
||||
previous_output: Option<ProcessedOutput>,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
println!("Processor ID: {}", id);
|
||||
// Check if we have a processor registered with this ID
|
||||
if let Some((_, processor)) = self.processors.get(&id) {
|
||||
println!("Processor exists");
|
||||
return processor.process_with_context(id, json, previous_output);
|
||||
}
|
||||
|
||||
// If not, find a processor by type and register it with this ID
|
||||
for (_, (type_str, processor)) in &self.processors {
|
||||
if type_str == processor_type && processor.can_process(json) {
|
||||
println!("Processor does not exist");
|
||||
return processor.process_with_context(id, json, previous_output);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process a JSON string with a processor for a specific tool
|
||||
pub fn process_with_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
id: String,
|
||||
json: &str,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
if let Some(processor) = self.get_processor_for_tool(tool_name) {
|
||||
return processor.process(id, json);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Checks if a processor of the given type is registered
|
||||
pub fn has_processor(&self, processor_type: &str) -> bool {
|
||||
self.processors.contains_key(processor_type)
|
||||
|
||||
/// Process a JSON string with a processor for a specific tool, providing previous output context
|
||||
pub fn process_with_tool_and_context(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
id: String,
|
||||
json: &str,
|
||||
previous_output: Option<ProcessedOutput>,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
if let Some(processor) = self.get_processor_for_tool(tool_name) {
|
||||
return processor.process_with_context(id, json, previous_output);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process a JSON string with a processor for a specific tool and cache the output
|
||||
pub fn process_and_cache_tool_output(
|
||||
&mut self,
|
||||
tool_name: &str,
|
||||
id: String,
|
||||
json: &str,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
let result = self.process_with_tool(tool_name, id.clone(), json)?;
|
||||
|
||||
if let Some(output) = &result {
|
||||
self.output_cache.insert(id, output.clone());
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a list of registered processor types
|
||||
pub fn processor_types(&self) -> Vec<String> {
|
||||
self.processors.keys().cloned().collect()
|
||||
/// Process a JSON string with a processor for a specific tool, providing previous output context,
|
||||
/// and cache the result
|
||||
pub fn process_and_cache_tool_output_with_context(
|
||||
&mut self,
|
||||
tool_name: &str,
|
||||
id: String,
|
||||
json: &str,
|
||||
previous_output: Option<ProcessedOutput>,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
let result = self.process_with_tool_and_context(tool_name, id.clone(), json, previous_output)?;
|
||||
|
||||
if let Some(output) = &result {
|
||||
self.output_cache.insert(id, output.clone());
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Updates the chunk buffer for a specific tool call
|
||||
pub fn update_tool_chunk_buffer(&mut self, id: &str, chunk: &str) {
|
||||
if let Some(buffer) = self.chunk_buffers.get_mut(id) {
|
||||
buffer.push_str(chunk);
|
||||
} else {
|
||||
self.chunk_buffers.insert(id.to_string(), chunk.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the chunk buffer for a specific tool call
|
||||
pub fn get_tool_chunk_buffer(&self, id: &str) -> Option<&String> {
|
||||
self.chunk_buffers.get(id)
|
||||
}
|
||||
|
||||
/// Clears the chunk buffer for a specific tool call
|
||||
pub fn clear_tool_chunk_buffer(&mut self, id: &str) {
|
||||
self.chunk_buffers.remove(id);
|
||||
}
|
||||
|
||||
/// Gets the chunk buffer for the given ID
|
||||
pub fn get_chunk_buffer(&self, id: &str) -> Option<String> {
|
||||
self.chunk_buffers.get(id).cloned()
|
||||
}
|
||||
|
||||
/// Updates the chunk buffer for the given ID
|
||||
pub fn update_chunk_buffer(&mut self, id: String, buffer: String) {
|
||||
self.chunk_buffers.insert(id, buffer);
|
||||
}
|
||||
|
||||
/// Clears the chunk buffer for the given ID
|
||||
pub fn clear_chunk_buffer(&mut self, id: &str) {
|
||||
self.chunk_buffers.remove(id);
|
||||
}
|
||||
|
||||
/// Caches the processed output for the given ID
|
||||
pub fn cache_output(&mut self, id: String, output: ProcessedOutput) {
|
||||
self.output_cache.insert(id, output);
|
||||
}
|
||||
|
||||
/// Gets the cached output for the given ID
|
||||
pub fn get_cached_output(&self, id: &str) -> Option<&ProcessedOutput> {
|
||||
self.output_cache.get(id)
|
||||
}
|
||||
|
||||
/// Clears the cache for the given ID
|
||||
pub fn clear_cache(&mut self, id: &str) {
|
||||
self.output_cache.remove(id);
|
||||
}
|
||||
|
||||
/// Clears all caches
|
||||
pub fn clear_all_caches(&mut self) {
|
||||
self.output_cache.clear();
|
||||
self.chunk_buffers.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use crate::types::{ProcessedOutput, ReasoningText};
|
||||
|
||||
struct TestProcessor;
|
||||
|
||||
|
||||
impl Processor for TestProcessor {
|
||||
fn processor_type(&self) -> ProcessorType {
|
||||
ProcessorType::Custom("test".to_string())
|
||||
}
|
||||
|
||||
fn can_process(&self, json: &str) -> bool {
|
||||
json.contains("test_key")
|
||||
|
||||
fn can_process(&self, _json: &str) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
if self.can_process(json) {
|
||||
Ok(Some(ProcessedOutput::Text(crate::types::ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Test".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some("Test message".to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("completed".to_string()),
|
||||
})))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Test".to_string(),
|
||||
secondary_title: "Test".to_string(),
|
||||
message: Some(json.to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("loading".to_string()),
|
||||
})))
|
||||
}
|
||||
|
||||
fn process_with_context(
|
||||
&self,
|
||||
id: String,
|
||||
json: &str,
|
||||
_previous_output: Option<ProcessedOutput>,
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Test".to_string(),
|
||||
secondary_title: "Test".to_string(),
|
||||
message: Some(json.to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("loading".to_string()),
|
||||
})))
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(TestProcessor)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_processor_registry() {
|
||||
let mut registry = ProcessorRegistry::new();
|
||||
|
||||
// Test empty registry
|
||||
assert!(!registry.has_processor("test"));
|
||||
assert_eq!(registry.processor_types().len(), 0);
|
||||
|
||||
// Register a processor
|
||||
registry.register(Box::new(TestProcessor));
|
||||
|
||||
// Test registry with processor
|
||||
assert!(registry.has_processor("test"));
|
||||
assert_eq!(registry.processor_types().len(), 1);
|
||||
assert_eq!(registry.processor_types()[0], "test");
|
||||
|
||||
// Test processing with valid data
|
||||
let result = registry.process(
|
||||
"test_id".to_string(),
|
||||
r#"{"test_key": "value"}"#,
|
||||
"test",
|
||||
);
|
||||
|
||||
let result = registry.process("id1".to_string(), r#"{"test": "value"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_some());
|
||||
|
||||
// Test processing with invalid data
|
||||
let result = registry.process(
|
||||
"test_id".to_string(),
|
||||
r#"{"other_key": "value"}"#,
|
||||
"test",
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_processor_registry_with_id() {
|
||||
let mut registry = ProcessorRegistry::new();
|
||||
registry.register_with_id("custom_id".to_string(), Box::new(TestProcessor));
|
||||
|
||||
let result =
|
||||
registry.process_by_id("custom_id".to_string(), r#"{"test": "value"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
|
||||
// Test processing with non-existent processor
|
||||
let result = registry.process(
|
||||
"test_id".to_string(),
|
||||
r#"{"test_key": "value"}"#,
|
||||
"non_existent",
|
||||
);
|
||||
assert!(result.unwrap().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_processor_registry_caching() {
|
||||
let mut registry = ProcessorRegistry::new();
|
||||
registry.register(Box::new(TestProcessor));
|
||||
|
||||
// Process once
|
||||
let id = "cache_test_id".to_string();
|
||||
let result = registry.process(id.clone(), r#"{"test": "value"}"#, "test");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
let output = result.unwrap().unwrap();
|
||||
|
||||
// Cache the output
|
||||
registry.cache_output(id.clone(), output.clone());
|
||||
|
||||
// Check if cached
|
||||
let cached = registry.get_cached_output(&id);
|
||||
assert!(cached.is_some());
|
||||
assert_eq!(cached.unwrap(), &output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_processor_registry_chunk_buffers() {
|
||||
let mut registry = ProcessorRegistry::new();
|
||||
|
||||
// Set chunk buffer
|
||||
let id = "buffer_test_id".to_string();
|
||||
registry.update_chunk_buffer(id.clone(), "partial json".to_string());
|
||||
|
||||
// Get chunk buffer
|
||||
let buffer = registry.get_chunk_buffer(&id);
|
||||
assert!(buffer.is_some());
|
||||
assert_eq!(buffer.unwrap(), "partial json");
|
||||
|
||||
// Update chunk buffer
|
||||
registry.update_chunk_buffer(id.clone(), "complete json".to_string());
|
||||
let updated_buffer = registry.get_chunk_buffer(&id);
|
||||
assert!(updated_buffer.is_some());
|
||||
assert_eq!(updated_buffer.unwrap(), "complete json");
|
||||
|
||||
// Clear chunk buffer
|
||||
registry.clear_chunk_buffer(&id);
|
||||
let cleared_buffer = registry.get_chunk_buffer(&id);
|
||||
assert!(cleared_buffer.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_processors_same_type() {
|
||||
let mut registry = ProcessorRegistry::new();
|
||||
|
||||
// Register two processors with the same type but different IDs
|
||||
registry.register_with_id("id1".to_string(), Box::new(TestProcessor));
|
||||
registry.register_with_id("id2".to_string(), Box::new(TestProcessor));
|
||||
|
||||
// Process with first ID
|
||||
let result1 = registry.process_by_id("id1".to_string(), r#"{"test": "value1"}"#, "test");
|
||||
assert!(result1.is_ok());
|
||||
let output1 = result1.unwrap().unwrap();
|
||||
|
||||
// Process with second ID
|
||||
let result2 = registry.process_by_id("id2".to_string(), r#"{"test": "value2"}"#, "test");
|
||||
assert!(result2.is_ok());
|
||||
let output2 = result2.unwrap().unwrap();
|
||||
|
||||
// Verify they have different content
|
||||
match (output1, output2) {
|
||||
(ProcessedOutput::Text(text1), ProcessedOutput::Text(text2)) => {
|
||||
assert_eq!(text1.id, "id1");
|
||||
assert_eq!(text2.id, "id2");
|
||||
assert_ne!(text1.message, text2.message);
|
||||
}
|
||||
_ => panic!("Expected Text outputs"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,12 +58,24 @@ impl Processor for CreateDashboardsProcessor {
|
|||
}
|
||||
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
self.process_with_context(id, json, None)
|
||||
}
|
||||
|
||||
fn process_with_context(&self, id: String, json: &str, previous_output: Option<ProcessedOutput>) -> Result<Option<ProcessedOutput>> {
|
||||
// Try to parse the JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json) {
|
||||
// Check if it's a dashboard file structure
|
||||
if let Some(files) = value.get("files").and_then(Value::as_array) {
|
||||
let mut files_map = HashMap::new();
|
||||
let mut file_ids = Vec::new();
|
||||
|
||||
// Get previous files if they exist
|
||||
let previous_files = if let Some(ProcessedOutput::File(file_output)) = &previous_output {
|
||||
&file_output.files
|
||||
} else {
|
||||
&HashMap::new()
|
||||
};
|
||||
|
||||
for file in files {
|
||||
if let Some(file_obj) = file.as_object() {
|
||||
let has_name = file_obj.get("name").and_then(Value::as_str).is_some();
|
||||
|
@ -78,9 +90,25 @@ impl Processor for CreateDashboardsProcessor {
|
|||
|
||||
// Generate deterministic UUID based on tool call ID, file name, and type
|
||||
let file_id = self.generate_deterministic_uuid(&id, name, "dashboard")?;
|
||||
let file_id_str = file_id.to_string();
|
||||
|
||||
// Get the previously processed content for this file
|
||||
let previous_content = if let Some(prev_file) = previous_files.get(&file_id_str) {
|
||||
prev_file.file.text_chunk.clone().unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Calculate the new content (what wasn't in the previous content)
|
||||
let new_content = if yml_content.len() > previous_content.len() {
|
||||
yml_content[previous_content.len()..].to_string()
|
||||
} else {
|
||||
// If for some reason the new content is shorter, just use the whole thing
|
||||
yml_content.to_string()
|
||||
};
|
||||
|
||||
let file = File {
|
||||
id: file_id.to_string(),
|
||||
id: file_id_str.clone(),
|
||||
file_type: "dashboard".to_string(),
|
||||
file_name: name.to_string(),
|
||||
version_number: 1,
|
||||
|
@ -88,14 +116,14 @@ impl Processor for CreateDashboardsProcessor {
|
|||
status: "loading".to_string(),
|
||||
file: FileContent {
|
||||
text: None,
|
||||
text_chunk: Some(yml_content.to_string()),
|
||||
modifided: None,
|
||||
text_chunk: if new_content.is_empty() { None } else { Some(new_content) },
|
||||
modified: None,
|
||||
},
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
file_ids.push(file_id.to_string());
|
||||
files_map.insert(file_id.to_string(), file);
|
||||
file_ids.push(file_id_str.clone());
|
||||
files_map.insert(file_id_str, file);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -116,11 +144,16 @@ impl Processor for CreateDashboardsProcessor {
|
|||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(CreateDashboardsProcessor::new())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::ProcessedOutput;
|
||||
|
||||
#[test]
|
||||
fn test_can_process() {
|
||||
|
@ -156,32 +189,76 @@ mod tests {
|
|||
let json = r#"{"files":[{"name":"test_dashboard.yml","yml_content":"name: Test Dashboard\ndescription: A test dashboard"}]}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file)) = output {
|
||||
assert_eq!(file.id, id);
|
||||
assert_eq!(file.title, "Creating dashboard files...");
|
||||
assert_eq!(file.file_ids.len(), 1);
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = output {
|
||||
assert_eq!(file_output.id, id);
|
||||
assert_eq!(file_output.title, "Creating dashboard files...");
|
||||
assert_eq!(file_output.files.len(), 1);
|
||||
|
||||
// Check that the file was created with the correct properties
|
||||
let file_id = &file.file_ids[0];
|
||||
let dashboard_file = file.files.get(file_id).unwrap();
|
||||
assert_eq!(dashboard_file.file_type, "dashboard");
|
||||
assert_eq!(dashboard_file.file_name, "test_dashboard.yml");
|
||||
assert_eq!(dashboard_file.status, "loading");
|
||||
|
||||
// Check the file content
|
||||
assert!(dashboard_file.file.text_chunk.as_ref().unwrap().contains("name: Test Dashboard"));
|
||||
// Check the first file
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file_name, "test_dashboard.yml");
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, Some("name: Test Dashboard\ndescription: A test dashboard".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_with_context_streaming() {
|
||||
let processor = CreateDashboardsProcessor::new();
|
||||
let id = "test_id".to_string();
|
||||
|
||||
// First chunk
|
||||
let json1 = r#"{"files":[{"name":"test_dashboard.yml","yml_content":""}]}"#;
|
||||
let result1 = processor.process_with_context(id.clone(), json1, None);
|
||||
assert!(result1.is_ok());
|
||||
let output1 = result1.unwrap();
|
||||
assert!(output1.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = &output1 {
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, None); // Empty string, so no chunk
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
|
||||
// Test with invalid data
|
||||
let json = r#"{"other_key":"value"}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
// Second chunk
|
||||
let json2 = r#"{"files":[{"name":"test_dashboard.yml","yml_content":"name: Test Dashboard\n"}]}"#;
|
||||
let result2 = processor.process_with_context(id.clone(), json2, output1);
|
||||
assert!(result2.is_ok());
|
||||
let output2 = result2.unwrap();
|
||||
assert!(output2.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = &output2 {
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, Some("name: Test Dashboard\n".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
|
||||
// Third chunk
|
||||
let json3 = r#"{"files":[{"name":"test_dashboard.yml","yml_content":"name: Test Dashboard\ndescription: A test dashboard"}]}"#;
|
||||
let result3 = processor.process_with_context(id.clone(), json3, output2);
|
||||
assert!(result3.is_ok());
|
||||
let output3 = result3.unwrap();
|
||||
assert!(output3.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = &output3 {
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, Some("description: A test dashboard".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,49 +58,77 @@ impl Processor for CreateMetricsProcessor {
|
|||
}
|
||||
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
self.process_with_context(id, json, None)
|
||||
}
|
||||
|
||||
fn process_with_context(&self, id: String, json: &str, previous_output: Option<ProcessedOutput>) -> Result<Option<ProcessedOutput>> {
|
||||
// Try to parse the JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json) {
|
||||
// Check if it's a metric file structure
|
||||
if let Some(files) = value.get("files").and_then(Value::as_array) {
|
||||
let mut files_map = HashMap::new();
|
||||
let mut file_ids = Vec::new();
|
||||
|
||||
// Get the previously processed files
|
||||
let previous_files = if let Some(ProcessedOutput::File(output)) = &previous_output {
|
||||
output.files.clone()
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
// Process each file
|
||||
for file in files {
|
||||
if let Some(file_obj) = file.as_object() {
|
||||
let has_name = file_obj.get("name").and_then(Value::as_str).is_some();
|
||||
let has_yml_content = file_obj.get("yml_content").is_some();
|
||||
|
||||
if has_name && has_yml_content {
|
||||
let name = file_obj.get("name").and_then(Value::as_str).unwrap_or("");
|
||||
let yml_content = file_obj
|
||||
.get("yml_content")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or("");
|
||||
|
||||
// Generate deterministic UUID based on tool call ID, file name, and type
|
||||
// Check if the file has a name and yml_content
|
||||
if let (Some(name), Some(yml_content)) = (
|
||||
file.get("name").and_then(Value::as_str),
|
||||
file.get("yml_content").and_then(Value::as_str),
|
||||
) {
|
||||
// Only process files that end with .yml
|
||||
if name.ends_with(".yml") {
|
||||
// Generate a deterministic UUID for this file
|
||||
let file_id = self.generate_deterministic_uuid(&id, name, "metric")?;
|
||||
let file_id_str = file_id.to_string();
|
||||
|
||||
let file = File {
|
||||
id: file_id.to_string(),
|
||||
file_type: "metric".to_string(),
|
||||
file_name: name.to_string(),
|
||||
version_number: 1,
|
||||
version_id: String::from("0203f597-5ec5-4fd8-86e2-8587fe1c23b6"),
|
||||
status: "loading".to_string(),
|
||||
file: FileContent {
|
||||
text: None,
|
||||
text_chunk: Some(yml_content.to_string()),
|
||||
modifided: None,
|
||||
},
|
||||
metadata: None,
|
||||
// Get the previously processed content for this file
|
||||
let previous_content = if let Some(prev_file) = previous_files.get(&file_id_str) {
|
||||
prev_file.file.text_chunk.clone().unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
file_ids.push(file_id.to_string());
|
||||
files_map.insert(file_id.to_string(), file);
|
||||
// Calculate the new content (what wasn't in the previous content)
|
||||
let new_content = if yml_content.len() > previous_content.len() {
|
||||
yml_content[previous_content.len()..].to_string()
|
||||
} else {
|
||||
// If for some reason the new content is shorter, just use the whole thing
|
||||
yml_content.to_string()
|
||||
};
|
||||
|
||||
// Add the file to the output
|
||||
files_map.insert(
|
||||
file_id_str.clone(),
|
||||
File {
|
||||
id: file_id_str.clone(),
|
||||
file_type: "metric".to_string(),
|
||||
file_name: name.to_string(),
|
||||
version_number: 1,
|
||||
version_id: String::from("0203f597-5ec5-4fd8-86e2-8587fe1c23b6"),
|
||||
status: "loading".to_string(),
|
||||
file: FileContent {
|
||||
text: None,
|
||||
text_chunk: if new_content.is_empty() { None } else { Some(new_content) },
|
||||
modified: None,
|
||||
},
|
||||
metadata: None,
|
||||
},
|
||||
);
|
||||
file_ids.push(file_id_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !files_map.is_empty() {
|
||||
// Only return the output if we have files
|
||||
if !file_ids.is_empty() {
|
||||
return Ok(Some(ProcessedOutput::File(ReasoningFile {
|
||||
id,
|
||||
message_type: "files".to_string(),
|
||||
|
@ -116,11 +144,16 @@ impl Processor for CreateMetricsProcessor {
|
|||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(CreateMetricsProcessor::new())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::ProcessedOutput;
|
||||
|
||||
#[test]
|
||||
fn test_can_process() {
|
||||
|
@ -156,32 +189,76 @@ mod tests {
|
|||
let json = r#"{"files":[{"name":"test_metric.yml","yml_content":"name: Test Metric\ndescription: A test metric"}]}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file)) = output {
|
||||
assert_eq!(file.id, id);
|
||||
assert_eq!(file.title, "Creating metric files...");
|
||||
assert_eq!(file.file_ids.len(), 1);
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = output {
|
||||
assert_eq!(file_output.id, id);
|
||||
assert_eq!(file_output.title, "Creating metric files...");
|
||||
assert_eq!(file_output.files.len(), 1);
|
||||
|
||||
// Check that the file was created with the correct properties
|
||||
let file_id = &file.file_ids[0];
|
||||
let metric_file = file.files.get(file_id).unwrap();
|
||||
assert_eq!(metric_file.file_type, "metric");
|
||||
assert_eq!(metric_file.file_name, "test_metric.yml");
|
||||
assert_eq!(metric_file.status, "loading");
|
||||
|
||||
// Check the file content
|
||||
assert!(metric_file.file.text_chunk.as_ref().unwrap().contains("name: Test Metric"));
|
||||
// Check the first file
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file_name, "test_metric.yml");
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, Some("name: Test Metric\ndescription: A test metric".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_with_context_streaming() {
|
||||
let processor = CreateMetricsProcessor::new();
|
||||
let id = "test_id".to_string();
|
||||
|
||||
// First chunk
|
||||
let json1 = r#"{"files":[{"name":"test_metric.yml","yml_content":""}]}"#;
|
||||
let result1 = processor.process_with_context(id.clone(), json1, None);
|
||||
assert!(result1.is_ok());
|
||||
let output1 = result1.unwrap();
|
||||
assert!(output1.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = &output1 {
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, None); // Empty string, so no chunk
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
|
||||
// Test with invalid data
|
||||
let json = r#"{"other_key":"value"}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
// Second chunk
|
||||
let json2 = r#"{"files":[{"name":"test_metric.yml","yml_content":"name: Test Metric\n"}]}"#;
|
||||
let result2 = processor.process_with_context(id.clone(), json2, output1);
|
||||
assert!(result2.is_ok());
|
||||
let output2 = result2.unwrap();
|
||||
assert!(output2.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = &output2 {
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, Some("name: Test Metric\n".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
|
||||
// Third chunk
|
||||
let json3 = r#"{"files":[{"name":"test_metric.yml","yml_content":"name: Test Metric\ndescription: A test metric"}]}"#;
|
||||
let result3 = processor.process_with_context(id.clone(), json3, output2);
|
||||
assert!(result3.is_ok());
|
||||
let output3 = result3.unwrap();
|
||||
assert!(output3.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::File(file_output)) = &output3 {
|
||||
let file_id = &file_output.file_ids[0];
|
||||
let file = file_output.files.get(file_id).unwrap();
|
||||
assert_eq!(file.file.text, None);
|
||||
assert_eq!(file.file.text_chunk, Some("description: A test metric".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::File");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,18 +27,37 @@ impl Processor for CreatePlanProcessor {
|
|||
}
|
||||
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
self.process_with_context(id, json, None)
|
||||
}
|
||||
|
||||
fn process_with_context(&self, id: String, json: &str, previous_output: Option<ProcessedOutput>) -> Result<Option<ProcessedOutput>> {
|
||||
// Try to parse the JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json) {
|
||||
// Check if it's a plan structure (has plan_markdown key)
|
||||
// Check if it's a plan markdown structure
|
||||
if let Some(plan_markdown) = value.get("plan_markdown").and_then(Value::as_str) {
|
||||
// Return the plan as a ReasoningText
|
||||
// Get the previously processed content
|
||||
let previous_content = if let Some(ProcessedOutput::Text(text)) = previous_output {
|
||||
text.message_chunk.clone().unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Calculate the new content (what wasn't in the previous content)
|
||||
let new_content = if plan_markdown.len() > previous_content.len() {
|
||||
plan_markdown[previous_content.len()..].to_string()
|
||||
} else {
|
||||
// If for some reason the new content is shorter, just use the whole thing
|
||||
plan_markdown.to_string()
|
||||
};
|
||||
|
||||
// Return the plan markdown as a ReasoningText
|
||||
return Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
reasoning_type: "text".to_string(),
|
||||
title: "Creating a plan...".to_string(),
|
||||
secondary_title: String::from(""),
|
||||
message: None,
|
||||
message_chunk: Some(plan_markdown.to_string()),
|
||||
message_chunk: if new_content.is_empty() { None } else { Some(new_content) },
|
||||
status: Some("loading".to_string()),
|
||||
})));
|
||||
}
|
||||
|
@ -46,26 +65,31 @@ impl Processor for CreatePlanProcessor {
|
|||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(CreatePlanProcessor::new())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::ProcessedOutput;
|
||||
|
||||
#[test]
|
||||
fn test_can_process() {
|
||||
let processor = CreatePlanProcessor::new();
|
||||
|
||||
// Test with valid plan data
|
||||
let json = r#"{"plan_markdown": "This is a plan"}"#;
|
||||
let json = r#"{"plan_markdown":"This is a plan"}"#;
|
||||
assert!(processor.can_process(json));
|
||||
|
||||
// Test with invalid data
|
||||
let json = r#"{"other_key": "value"}"#;
|
||||
let json = r#"{"other_key":"value"}"#;
|
||||
assert!(!processor.can_process(json));
|
||||
|
||||
// Test with malformed JSON
|
||||
let json = r#"{"plan_markdown": "This is a plan"#;
|
||||
let json = r#"{"plan_markdown":"This is a plan"#;
|
||||
assert!(!processor.can_process(json));
|
||||
}
|
||||
|
||||
|
@ -75,25 +99,65 @@ mod tests {
|
|||
let id = "test_id".to_string();
|
||||
|
||||
// Test with valid plan data
|
||||
let json = r#"{"plan_markdown": "This is a plan"}"#;
|
||||
let json = r#"{"plan_markdown":"This is a plan"}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = output {
|
||||
assert_eq!(text.id, id);
|
||||
assert_eq!(text.title, "Creating a plan...");
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, Some("This is a plan".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
}
|
||||
|
||||
// Test with invalid data
|
||||
let json = r#"{"other_key": "value"}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
#[test]
|
||||
fn test_process_with_context_streaming() {
|
||||
let processor = CreatePlanProcessor::new();
|
||||
let id = "test_id".to_string();
|
||||
|
||||
// First chunk
|
||||
let json1 = r#"{"plan_markdown":""}"#;
|
||||
let result1 = processor.process_with_context(id.clone(), json1, None);
|
||||
assert!(result1.is_ok());
|
||||
let output1 = result1.unwrap();
|
||||
assert!(output1.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = &output1 {
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, None); // Empty string, so no chunk
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
|
||||
// Second chunk
|
||||
let json2 = r#"{"plan_markdown":"Objective:\n"}"#;
|
||||
let result2 = processor.process_with_context(id.clone(), json2, output1);
|
||||
assert!(result2.is_ok());
|
||||
let output2 = result2.unwrap();
|
||||
assert!(output2.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = &output2 {
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, Some("Objective:\n".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
|
||||
// Third chunk
|
||||
let json3 = r#"{"plan_markdown":"Objective:\nOur goal is to "}"#;
|
||||
let result3 = processor.process_with_context(id.clone(), json3, output2);
|
||||
assert!(result3.is_ok());
|
||||
let output3 = result3.unwrap();
|
||||
assert!(output3.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = &output3 {
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, Some("Our goal is to ".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,10 +27,29 @@ impl Processor for SearchDataCatalogProcessor {
|
|||
}
|
||||
|
||||
fn process(&self, id: String, json: &str) -> Result<Option<ProcessedOutput>> {
|
||||
self.process_with_context(id, json, None)
|
||||
}
|
||||
|
||||
fn process_with_context(&self, id: String, json: &str, previous_output: Option<ProcessedOutput>) -> Result<Option<ProcessedOutput>> {
|
||||
// Try to parse the JSON
|
||||
if let Ok(value) = serde_json::from_str::<Value>(json) {
|
||||
// Check if it's a search requirements structure
|
||||
if let Some(search_requirements) = value.get("search_requirements").and_then(Value::as_str) {
|
||||
// Get the previously processed content
|
||||
let previous_content = if let Some(ProcessedOutput::Text(text)) = previous_output {
|
||||
text.message_chunk.clone().unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Calculate the new content (what wasn't in the previous content)
|
||||
let new_content = if search_requirements.len() > previous_content.len() {
|
||||
search_requirements[previous_content.len()..].to_string()
|
||||
} else {
|
||||
// If for some reason the new content is shorter, just use the whole thing
|
||||
search_requirements.to_string()
|
||||
};
|
||||
|
||||
// Return the search requirements as a ReasoningText
|
||||
return Ok(Some(ProcessedOutput::Text(ReasoningText {
|
||||
id,
|
||||
|
@ -38,7 +57,7 @@ impl Processor for SearchDataCatalogProcessor {
|
|||
title: "Searching your data catalog...".to_string(),
|
||||
secondary_title: String::from(""),
|
||||
message: None,
|
||||
message_chunk: Some(search_requirements.to_string()),
|
||||
message_chunk: if new_content.is_empty() { None } else { Some(new_content) },
|
||||
status: Some("loading".to_string()),
|
||||
})));
|
||||
}
|
||||
|
@ -46,26 +65,31 @@ impl Processor for SearchDataCatalogProcessor {
|
|||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn Processor> {
|
||||
Box::new(SearchDataCatalogProcessor::new())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::ProcessedOutput;
|
||||
|
||||
#[test]
|
||||
fn test_can_process() {
|
||||
let processor = SearchDataCatalogProcessor::new();
|
||||
|
||||
// Test with valid search data
|
||||
let json = r#"{"search_requirements": "Find metrics related to user engagement"}"#;
|
||||
let json = r#"{"search_requirements":"Find data about sales"}"#;
|
||||
assert!(processor.can_process(json));
|
||||
|
||||
// Test with invalid data
|
||||
let json = r#"{"other_key": "value"}"#;
|
||||
let json = r#"{"other_key":"value"}"#;
|
||||
assert!(!processor.can_process(json));
|
||||
|
||||
// Test with malformed JSON
|
||||
let json = r#"{"search_requirements": "Find metrics"#;
|
||||
let json = r#"{"search_requirements":"Find data about sales"#;
|
||||
assert!(!processor.can_process(json));
|
||||
}
|
||||
|
||||
|
@ -75,25 +99,65 @@ mod tests {
|
|||
let id = "test_id".to_string();
|
||||
|
||||
// Test with valid search data
|
||||
let json = r#"{"search_requirements": "Find metrics related to user engagement"}"#;
|
||||
let json = r#"{"search_requirements":"Find data about sales"}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = output {
|
||||
assert_eq!(text.id, id);
|
||||
assert_eq!(text.title, "Searching your data catalog...");
|
||||
assert_eq!(text.message_chunk, Some("Find metrics related to user engagement".to_string()));
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, Some("Find data about sales".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_with_context_streaming() {
|
||||
let processor = SearchDataCatalogProcessor::new();
|
||||
let id = "test_id".to_string();
|
||||
|
||||
// First chunk
|
||||
let json1 = r#"{"search_requirements":""}"#;
|
||||
let result1 = processor.process_with_context(id.clone(), json1, None);
|
||||
assert!(result1.is_ok());
|
||||
let output1 = result1.unwrap();
|
||||
assert!(output1.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = &output1 {
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, None); // Empty string, so no chunk
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
|
||||
// Test with invalid data
|
||||
let json = r#"{"other_key": "value"}"#;
|
||||
let result = processor.process(id.clone(), json);
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
// Second chunk
|
||||
let json2 = r#"{"search_requirements":"Find data "}"#;
|
||||
let result2 = processor.process_with_context(id.clone(), json2, output1);
|
||||
assert!(result2.is_ok());
|
||||
let output2 = result2.unwrap();
|
||||
assert!(output2.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = &output2 {
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, Some("Find data ".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
|
||||
// Third chunk
|
||||
let json3 = r#"{"search_requirements":"Find data about sales"}"#;
|
||||
let result3 = processor.process_with_context(id.clone(), json3, output2);
|
||||
assert!(result3.is_ok());
|
||||
let output3 = result3.unwrap();
|
||||
assert!(output3.is_some());
|
||||
|
||||
if let Some(ProcessedOutput::Text(text)) = &output3 {
|
||||
assert_eq!(text.message, None);
|
||||
assert_eq!(text.message_chunk, Some("about sales".to_string()));
|
||||
} else {
|
||||
panic!("Expected ProcessedOutput::Text");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use chrono::{DateTime, Utc};
|
||||
use uuid::Uuid;
|
||||
use serde_json::Value;
|
||||
|
||||
/// The main output type for processors
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProcessedOutput {
|
||||
/// A text-based reasoning message
|
||||
|
@ -14,7 +17,7 @@ pub enum ProcessedOutput {
|
|||
}
|
||||
|
||||
/// Represents a text-based reasoning message
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ReasoningText {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
|
@ -27,7 +30,7 @@ pub struct ReasoningText {
|
|||
}
|
||||
|
||||
/// Represents a file-based reasoning message
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ReasoningFile {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
|
@ -40,7 +43,7 @@ pub struct ReasoningFile {
|
|||
}
|
||||
|
||||
/// Represents a reasoning pill message
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ReasoningPill {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
|
@ -52,14 +55,14 @@ pub struct ReasoningPill {
|
|||
}
|
||||
|
||||
/// Represents a container for thought pills
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ThoughtPillContainer {
|
||||
pub title: String,
|
||||
pub pills: Vec<ThoughtPill>,
|
||||
}
|
||||
|
||||
/// Represents an individual thought pill
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ThoughtPill {
|
||||
pub id: String,
|
||||
pub text: String,
|
||||
|
@ -68,7 +71,7 @@ pub struct ThoughtPill {
|
|||
}
|
||||
|
||||
/// Represents a file in the system
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct File {
|
||||
pub id: String,
|
||||
pub file_type: String,
|
||||
|
@ -81,15 +84,15 @@ pub struct File {
|
|||
}
|
||||
|
||||
/// Represents the content of a file
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct FileContent {
|
||||
pub text: Option<String>,
|
||||
pub text_chunk: Option<String>,
|
||||
pub modifided: Option<Vec<(i32, i32)>>,
|
||||
pub modified: Option<Vec<(i32, i32)>>,
|
||||
}
|
||||
|
||||
/// Represents metadata for a file
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct FileMetadata {
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
|
@ -128,3 +131,57 @@ impl From<&str> for ProcessorType {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the different types of LiteLlmMessages that can be processed
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub enum MessageType {
|
||||
/// An Assistant message with tool calls (not null)
|
||||
AssistantToolCall,
|
||||
/// An Assistant message with content (text response)
|
||||
AssistantResponse,
|
||||
/// A Tool message (output from executed tool call)
|
||||
ToolOutput,
|
||||
}
|
||||
|
||||
/// A tool call with its associated information
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ToolCallInfo {
|
||||
/// The ID of the tool call
|
||||
pub id: String,
|
||||
/// The name of the tool
|
||||
pub name: String,
|
||||
/// The input parameters
|
||||
pub input: Value,
|
||||
/// The output content (if available)
|
||||
pub output: Option<Value>,
|
||||
/// The timestamp when the tool call was created
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// The current state of the tool call
|
||||
pub state: ToolCallState,
|
||||
/// The chunks received so far for this tool call
|
||||
pub chunks: Vec<String>,
|
||||
}
|
||||
|
||||
/// The state of a tool call
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub enum ToolCallState {
|
||||
/// The tool call is in progress
|
||||
InProgress,
|
||||
/// The tool call is complete
|
||||
Complete,
|
||||
/// The tool call has an output
|
||||
HasOutput,
|
||||
}
|
||||
|
||||
/// A processed message
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
pub struct ProcessedMessage {
|
||||
/// The ID of the message
|
||||
pub id: String,
|
||||
/// The type of the message
|
||||
pub message_type: MessageType,
|
||||
/// The processed content
|
||||
pub content: ProcessedOutput,
|
||||
/// The timestamp when the message was created
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
|
|
@ -1,347 +0,0 @@
|
|||
# PRD: Integrating ChunkTracker with StreamingParser
|
||||
|
||||
## Problem Statement ✅
|
||||
|
||||
The current implementation of streaming content processing in the Buster API has two separate components that handle related functionality:
|
||||
|
||||
1. The `ChunkTracker` in `libs/handlers/src/chats/post_chat_handler.rs` - Tracks and manages streaming content chunks, calculating deltas between chunks and maintaining complete text.
|
||||
2. The `StreamingParser` in `libs/streaming/src/parser.rs` - Handles parsing of incomplete JSON streams and processes them through specialized processors.
|
||||
|
||||
This separation creates several issues:
|
||||
|
||||
- **Code duplication**: Similar streaming functionality exists in two places
|
||||
- **Inconsistent handling**: Different approaches to handling streaming content
|
||||
- **Maintenance overhead**: Changes to streaming logic need to be made in multiple places
|
||||
- **Limited reusability**: The `ChunkTracker` is only available in the chat handler context
|
||||
|
||||
### Current Limitations
|
||||
- `ChunkTracker` is implemented as a singleton with global state in the chat handler
|
||||
- `StreamingParser` doesn't have built-in tracking for incremental content changes
|
||||
- No unified approach to handling streaming content across the application
|
||||
- Difficult to reuse the chunk tracking functionality in other contexts
|
||||
|
||||
### Impact
|
||||
- **Developer Experience**: Developers need to understand and maintain two separate systems
|
||||
- **Code Quality**: Duplicated logic increases the chance of bugs and inconsistencies
|
||||
- **Maintainability**: Changes to streaming behavior require updates in multiple places
|
||||
- **Feature Development**: New streaming features are harder to implement consistently
|
||||
|
||||
## Requirements
|
||||
|
||||
### Functional Requirements ✅
|
||||
|
||||
#### Core Functionality
|
||||
- Integrate `ChunkTracker` functionality directly into the `StreamingParser`
|
||||
- Details: Move the chunk tracking logic from the chat handler to the streaming library
|
||||
- Acceptance Criteria: All existing functionality preserved with a unified API
|
||||
- Dependencies: None
|
||||
|
||||
- Support calculating deltas between chunks
|
||||
- Details: Maintain the ability to identify only the new content in each chunk
|
||||
- Acceptance Criteria: Delta calculation works identically to current implementation
|
||||
- Dependencies: None
|
||||
|
||||
- Maintain complete text tracking
|
||||
- Details: Keep track of the complete accumulated text for each chunk ID
|
||||
- Acceptance Criteria: Complete text retrieval works identically to current implementation
|
||||
- Dependencies: None
|
||||
|
||||
- Support clearing tracked chunks
|
||||
- Details: Allow clearing tracked chunks when they're no longer needed
|
||||
- Acceptance Criteria: Chunk clearing works identically to current implementation
|
||||
- Dependencies: None
|
||||
|
||||
#### API Design
|
||||
- Provide a clean, intuitive API for chunk tracking
|
||||
- Details: Design methods that are easy to use and understand
|
||||
- Acceptance Criteria: API follows Rust best practices and is well-documented
|
||||
- Dependencies: None
|
||||
|
||||
- Support both raw text and JSON processing
|
||||
- Details: Handle both raw text chunks and JSON-structured content
|
||||
- Acceptance Criteria: Both types of content can be processed with appropriate methods
|
||||
- Dependencies: None
|
||||
|
||||
### Non-Functional Requirements ✅
|
||||
|
||||
- Performance Requirements
|
||||
- Maintain or improve current performance characteristics
|
||||
- Efficient memory usage for long-running streams
|
||||
- Thread-safe implementation for concurrent access
|
||||
|
||||
- Maintainability Requirements
|
||||
- Well-documented code with clear comments
|
||||
- Comprehensive unit tests for all functionality
|
||||
- Clear separation of concerns within the implementation
|
||||
|
||||
## Technical Design ✅
|
||||
|
||||
### System Architecture
|
||||
|
||||
The integration will enhance the StreamingParser to include chunk tracking functionality, creating a unified streaming content processing system:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Client] -->|Streaming Content| B[API Endpoint]
|
||||
B -->|Process Chunks| C[StreamingParser]
|
||||
C -->|Track Chunks| D[Internal ChunkTracker]
|
||||
C -->|Process Content| E[Processor Registry]
|
||||
E -->|Process with| F[Specialized Processors]
|
||||
C -->|Return Processed Output| B
|
||||
B -->|Stream Response| A
|
||||
```
|
||||
|
||||
### Core Components ✅
|
||||
|
||||
#### Enhanced StreamingParser
|
||||
|
||||
```rust
|
||||
pub struct StreamingParser {
|
||||
/// Buffer to accumulate chunks of data
|
||||
buffer: String,
|
||||
/// Registry of processors for different types of content
|
||||
processors: ProcessorRegistry,
|
||||
/// Regex for extracting YAML content
|
||||
yml_content_regex: Regex,
|
||||
/// Tracks chunks and their state
|
||||
chunk_tracker: ChunkTracker,
|
||||
}
|
||||
|
||||
impl StreamingParser {
|
||||
/// Creates a new StreamingParser with an empty processor registry
|
||||
pub fn new() -> Self {
|
||||
StreamingParser {
|
||||
buffer: String::new(),
|
||||
processors: ProcessorRegistry::new(),
|
||||
yml_content_regex: Regex::new(
|
||||
r#""yml_content":\s*"((?:[^"\\]|\\.|[\r\n])*?)(?:"|$)"#,
|
||||
)
|
||||
.unwrap(),
|
||||
chunk_tracker: ChunkTracker::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// Existing methods...
|
||||
|
||||
/// Process a chunk of data and track changes
|
||||
pub fn process_chunk_with_tracking(
|
||||
&mut self,
|
||||
chunk_id: String,
|
||||
chunk: &str,
|
||||
processor_type: &str,
|
||||
) -> Result<(Option<ProcessedOutput>, String)> {
|
||||
// Calculate delta using chunk tracker
|
||||
let delta = self.chunk_tracker.add_chunk(chunk_id.clone(), chunk.to_string());
|
||||
|
||||
// Process with the appropriate processor
|
||||
let processed = self.process_chunk(chunk_id, chunk, processor_type)?;
|
||||
|
||||
// Return both the processed output and the delta
|
||||
Ok((processed, delta))
|
||||
}
|
||||
|
||||
/// Get complete text for a chunk ID
|
||||
pub fn get_complete_text(&self, chunk_id: String) -> Option<String> {
|
||||
self.chunk_tracker.get_complete_text(chunk_id)
|
||||
}
|
||||
|
||||
/// Clear tracking for a chunk ID
|
||||
pub fn clear_chunk(&mut self, chunk_id: String) {
|
||||
self.chunk_tracker.clear_chunk(chunk_id)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Internal ChunkTracker Component
|
||||
|
||||
```rust
|
||||
/// Tracks and manages streaming content chunks
|
||||
struct ChunkTracker {
|
||||
chunks: Mutex<HashMap<String, ChunkState>>,
|
||||
}
|
||||
|
||||
struct ChunkState {
|
||||
complete_text: String,
|
||||
last_seen_content: String,
|
||||
}
|
||||
|
||||
impl ChunkTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
chunks: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_chunk(&self, chunk_id: String, new_chunk: String) -> String {
|
||||
if let Ok(mut chunks) = self.chunks.lock() {
|
||||
let state = chunks.entry(chunk_id).or_insert(ChunkState {
|
||||
complete_text: String::new(),
|
||||
last_seen_content: String::new(),
|
||||
});
|
||||
|
||||
// Calculate the delta by finding what's new since last_seen_content
|
||||
let delta = if state.last_seen_content.is_empty() {
|
||||
// First chunk, use it as is
|
||||
new_chunk.clone()
|
||||
} else if new_chunk.starts_with(&state.last_seen_content) {
|
||||
// New chunk contains all previous content at the start, extract only the new part
|
||||
new_chunk[state.last_seen_content.len()..].to_string()
|
||||
} else {
|
||||
// If we can't find the previous content, try to find where the new content starts
|
||||
match new_chunk.find(&state.last_seen_content) {
|
||||
Some(pos) => new_chunk[pos + state.last_seen_content.len()..].to_string(),
|
||||
None => {
|
||||
// If we can't find any overlap, this might be completely new content
|
||||
new_chunk.clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Update tracking state only if we found new content
|
||||
if !delta.is_empty() {
|
||||
state.complete_text.push_str(&delta);
|
||||
state.last_seen_content = new_chunk;
|
||||
}
|
||||
|
||||
delta
|
||||
} else {
|
||||
new_chunk
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_complete_text(&self, chunk_id: String) -> Option<String> {
|
||||
self.chunks.lock().ok().and_then(|chunks| {
|
||||
chunks
|
||||
.get(&chunk_id)
|
||||
.map(|state| state.complete_text.clone())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn clear_chunk(&self, chunk_id: String) {
|
||||
if let Ok(mut chunks) = self.chunks.lock() {
|
||||
chunks.remove(&chunk_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### File Changes ✅
|
||||
|
||||
#### Modified Files
|
||||
|
||||
- `libs/streaming/src/parser.rs`
|
||||
- Purpose: Enhance the StreamingParser to include chunk tracking functionality
|
||||
- Key changes:
|
||||
- Add ChunkTracker as an internal component
|
||||
- Add methods for chunk tracking and delta calculation
|
||||
- Integrate chunk tracking with existing processing logic
|
||||
|
||||
- `libs/streaming/src/lib.rs`
|
||||
- Purpose: Update exports to include new functionality
|
||||
- Key changes:
|
||||
- Re-export new chunk tracking methods
|
||||
|
||||
- `libs/handlers/src/chats/post_chat_handler.rs`
|
||||
- Purpose: Update to use the enhanced StreamingParser
|
||||
- Key changes:
|
||||
- Remove the existing ChunkTracker implementation
|
||||
- Update code to use the StreamingParser's chunk tracking functionality
|
||||
|
||||
## Implementation Plan ✅
|
||||
|
||||
### Phase 1: Core Implementation ⏳
|
||||
|
||||
1. Enhance StreamingParser with ChunkTracker functionality
|
||||
- [ ] Add ChunkTracker as an internal component in StreamingParser
|
||||
- [ ] Implement chunk tracking methods in StreamingParser
|
||||
- [ ] Add unit tests for new functionality
|
||||
|
||||
2. Update exports in streaming library
|
||||
- [ ] Update lib.rs to export new functionality
|
||||
- [ ] Ensure backward compatibility
|
||||
|
||||
### Phase 2: Integration with Existing Code 🔜
|
||||
|
||||
1. Update post_chat_handler.rs
|
||||
- [ ] Remove existing ChunkTracker implementation
|
||||
- [ ] Update code to use StreamingParser's chunk tracking
|
||||
- [ ] Test to ensure functionality is preserved
|
||||
|
||||
2. Comprehensive testing
|
||||
- [ ] End-to-end testing of streaming functionality
|
||||
- [ ] Performance testing to ensure no regressions
|
||||
|
||||
### Phase 3: Documentation and Cleanup 🔜
|
||||
|
||||
1. Documentation
|
||||
- [ ] Add comprehensive documentation for new functionality
|
||||
- [ ] Update existing documentation to reflect changes
|
||||
|
||||
2. Code cleanup
|
||||
- [ ] Remove any redundant code
|
||||
- [ ] Address any technical debt identified during implementation
|
||||
|
||||
## Testing Strategy ✅
|
||||
|
||||
### Unit Tests
|
||||
|
||||
- Test chunk tracking functionality
|
||||
- Test adding first chunk
|
||||
- Test adding subsequent chunks with various overlap patterns
|
||||
- Test retrieving complete text
|
||||
- Test clearing chunks
|
||||
|
||||
- Test integration with existing StreamingParser
|
||||
- Test processing chunks with tracking
|
||||
- Test combined functionality of processing and tracking
|
||||
|
||||
### Integration Tests
|
||||
|
||||
- Test end-to-end streaming scenarios
|
||||
- Test with various types of streaming content
|
||||
- Test with different chunk sizes and patterns
|
||||
- Test with concurrent access
|
||||
|
||||
### Performance Tests
|
||||
|
||||
- Benchmark memory usage
|
||||
- Compare before and after implementation
|
||||
- Test with large streams
|
||||
|
||||
- Benchmark processing time
|
||||
- Compare before and after implementation
|
||||
- Test with various content types and sizes
|
||||
|
||||
## Success Criteria ✅
|
||||
|
||||
- All existing functionality is preserved
|
||||
- Code is more maintainable and reusable
|
||||
- No performance regressions
|
||||
- Comprehensive test coverage
|
||||
- Clear, well-documented API
|
||||
|
||||
## Dependencies ✅
|
||||
|
||||
- `std::collections::HashMap`
|
||||
- `std::sync::Mutex`
|
||||
- Existing StreamingParser implementation
|
||||
- Existing processor infrastructure
|
||||
|
||||
## Security Considerations ✅
|
||||
|
||||
- Thread safety for concurrent access
|
||||
- Memory management for large streams
|
||||
- Proper error handling
|
||||
|
||||
## Rollback Plan ✅
|
||||
|
||||
If issues are encountered:
|
||||
1. Revert changes to StreamingParser
|
||||
2. Restore original ChunkTracker implementation in post_chat_handler.rs
|
||||
3. Update any dependent code to use the original implementation
|
||||
|
||||
## Monitoring and Metrics ✅
|
||||
|
||||
- Track memory usage during streaming
|
||||
- Monitor processing time for chunks
|
||||
- Track error rates related to streaming functionality
|
|
@ -0,0 +1,506 @@
|
|||
# Streaming Library Enhancement for Agent Integration
|
||||
|
||||
## Problem Statement
|
||||
|
||||
The current implementation of the chat handling system in `libs/handlers/src/chats/post_chat_handler.rs` lacks clean abstraction for managing streaming logic for agent messages. The post chat handler is responsible for too many concerns, including agent instantiation, chat context loading, message processing, and streaming management. This has led to:
|
||||
|
||||
1. Difficulty in maintaining and extending the streaming functionality
|
||||
2. Lack of clear separation between different message types (assistant text, assistant tool, tool output)
|
||||
3. Inconsistent chunk tracking across message types
|
||||
4. Complex code that is difficult to test and debug
|
||||
|
||||
### Current Limitations
|
||||
|
||||
- The streaming library does not handle caching and chunk tracking for all tool calls consistently
|
||||
- There's no clear distinction between assistant tool calls and tool outputs
|
||||
- The chunk tracking mechanism is not centralized, making it difficult to ensure all message chunks are properly tracked
|
||||
- The `post_chat_handler.rs` has too many responsibilities, making it hard to maintain and extend
|
||||
- No consistent way to collect and store reasoning messages for later display
|
||||
|
||||
### Impact
|
||||
|
||||
- **User Impact**: Inconsistent user experience when streaming messages, especially with complex tool calls
|
||||
- **System Impact**: Increased complexity and potential for bugs in message handling
|
||||
- **Business Impact**: Slower development velocity due to complex codebase, increased maintenance cost
|
||||
|
||||
## Requirements
|
||||
|
||||
### Functional Requirements
|
||||
|
||||
#### Core Functionality
|
||||
|
||||
- The streaming library must handle all streaming logic for agent messages
|
||||
- It must support different LiteLlmMessage types (Assistant with tool calls, Assistant with content, Tool messages)
|
||||
- It must track and cache all tool calls and their outputs using shared IDs
|
||||
- It must store messages for redisplay
|
||||
- It must support both reasoning and response messages
|
||||
|
||||
#### Message Handling
|
||||
|
||||
- **Message Parsing**: The StreamingParser must parse and process LiteLlmMessage types
|
||||
- Acceptance Criteria: Successfully parse and process all LiteLlmMessage variants
|
||||
- Dependencies: Current StreamingParser implementation
|
||||
|
||||
- **Tool Call Tracking**: The system must track tool calls and their outputs using shared IDs
|
||||
- Acceptance Criteria: Successfully associate tool calls with their outputs
|
||||
- Dependencies: LiteLlmMessage format from Agent
|
||||
|
||||
- **Message Storage**: The system must store messages for future reference
|
||||
- Acceptance Criteria: Successfully retrieve stored messages
|
||||
- Dependencies: StreamingParser implementation
|
||||
|
||||
#### Post Chat Handler Integration
|
||||
|
||||
- **Simplified Interface**: The post chat handler must have a clean interface for streaming
|
||||
- Acceptance Criteria: Post chat handler code is significantly simplified
|
||||
- Dependencies: Current post_chat_handler implementation
|
||||
|
||||
### Non-Functional Requirements
|
||||
|
||||
- **Performance Requirements**
|
||||
- Message processing must not introduce significant latency (<50ms per message)
|
||||
- Memory usage should be optimized for large message streams
|
||||
|
||||
- **Maintainability Requirements**
|
||||
- Clear separation of concerns between components
|
||||
- Comprehensive test coverage (>80%) for all new components
|
||||
- Well-documented public API
|
||||
|
||||
- **Compatibility Requirements**
|
||||
- Must be backward compatible with existing processors and message formats
|
||||
- Must integrate seamlessly with the current agent implementation
|
||||
|
||||
## Technical Design
|
||||
|
||||
### System Architecture
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Post Chat Handler] --> B[Enhanced StreamingParser]
|
||||
D[Agent] -- "LiteLlmMessages" --> A
|
||||
B -- "Store Messages" --> B
|
||||
```
|
||||
|
||||
### Core Components
|
||||
|
||||
#### Component 1: Message Types
|
||||
|
||||
```rust
|
||||
/// Represents the different types of LiteLlmMessages that can be processed
|
||||
pub enum MessageType {
|
||||
/// An Assistant message with tool calls (not null)
|
||||
AssistantToolCall,
|
||||
/// An Assistant message with content (text response)
|
||||
AssistantResponse,
|
||||
/// A Tool message (output from executed tool call)
|
||||
ToolOutput,
|
||||
}
|
||||
|
||||
/// A tool call with its associated information
|
||||
pub struct ToolCallInfo {
|
||||
/// The ID of the tool call
|
||||
pub id: String,
|
||||
/// The name of the tool
|
||||
pub name: String,
|
||||
/// The input parameters
|
||||
pub input: Value,
|
||||
/// The output content (if available)
|
||||
pub output: Option<Value>,
|
||||
/// The timestamp when the tool call was created
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// The current state of the tool call
|
||||
pub state: ToolCallState,
|
||||
/// The chunks received so far for this tool call
|
||||
pub chunks: Vec<String>,
|
||||
}
|
||||
|
||||
/// The state of a tool call
|
||||
pub enum ToolCallState {
|
||||
/// The tool call is in progress
|
||||
InProgress,
|
||||
/// The tool call is complete
|
||||
Complete,
|
||||
/// The tool call has an output
|
||||
HasOutput,
|
||||
}
|
||||
|
||||
/// A processed message
|
||||
pub struct ProcessedMessage {
|
||||
/// The ID of the message
|
||||
pub id: String,
|
||||
/// The type of the message
|
||||
pub message_type: MessageType,
|
||||
/// The processed content
|
||||
pub content: ProcessedOutput,
|
||||
/// The timestamp when the message was created
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
```
|
||||
|
||||
#### Component 2: Enhanced StreamingParser
|
||||
|
||||
```rust
|
||||
/// Enhanced StreamingParser with support for LiteLlmMessage types
|
||||
pub struct StreamingParser {
|
||||
/// Buffer for incomplete JSON
|
||||
buffer: String,
|
||||
/// Registry of processors for different message types
|
||||
processors: ProcessorRegistry,
|
||||
/// Map of tool call IDs to their information
|
||||
tool_calls: HashMap<String, ToolCallInfo>,
|
||||
/// List of reasoning messages (tool calls and outputs)
|
||||
reasoning_messages: Vec<ProcessedMessage>,
|
||||
/// List of response messages
|
||||
response_messages: Vec<String>,
|
||||
}
|
||||
|
||||
impl StreamingParser {
|
||||
/// Creates a new StreamingParser
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
processors: ProcessorRegistry::new(),
|
||||
tool_calls: HashMap::new(),
|
||||
reasoning_messages: Vec::new(),
|
||||
response_messages: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a LiteLlmMessage
|
||||
pub fn process_message(&mut self, message: &LiteLlmMessage) -> Result<Option<ProcessedOutput>> {
|
||||
match message {
|
||||
LiteLlmMessage::Assistant { tool_calls: Some(tool_calls), .. } => {
|
||||
self.process_assistant_tool_call(message, tool_calls)
|
||||
},
|
||||
LiteLlmMessage::Assistant { content: Some(content), tool_calls: None, .. } => {
|
||||
self.process_assistant_response(message, content)
|
||||
},
|
||||
LiteLlmMessage::Tool { content, tool_call_id, .. } => {
|
||||
self.process_tool_output(message, tool_call_id, content)
|
||||
},
|
||||
_ => Ok(None), // Ignore other message types
|
||||
}
|
||||
}
|
||||
|
||||
/// Process an Assistant message with tool calls
|
||||
fn process_assistant_tool_call(
|
||||
&mut self,
|
||||
message: &LiteLlmMessage,
|
||||
tool_calls: &[ToolCall]
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
for tool_call in tool_calls {
|
||||
let id = tool_call.id.clone();
|
||||
let name = tool_call.function.name.clone();
|
||||
let arguments = tool_call.function.arguments.clone();
|
||||
|
||||
// Parse arguments as JSON
|
||||
let input = serde_json::from_str::<Value>(&arguments)
|
||||
.unwrap_or_else(|_| serde_json::json!({"raw": arguments}));
|
||||
|
||||
// Register or update tool call
|
||||
if let Some(existing_tool_call) = self.tool_calls.get_mut(&id) {
|
||||
// Update existing tool call with new chunks
|
||||
existing_tool_call.chunks.push(arguments.clone());
|
||||
existing_tool_call.input = input.clone();
|
||||
if existing_tool_call.state == ToolCallState::InProgress {
|
||||
existing_tool_call.state = ToolCallState::Complete;
|
||||
}
|
||||
} else {
|
||||
// Register new tool call
|
||||
self.tool_calls.insert(id.clone(), ToolCallInfo {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: input.clone(),
|
||||
output: None,
|
||||
timestamp: Utc::now(),
|
||||
state: ToolCallState::Complete,
|
||||
chunks: vec![arguments.clone()],
|
||||
});
|
||||
}
|
||||
|
||||
// Process with appropriate processor
|
||||
if let Some(processor) = self.processors.get_processor_for_tool(&name) {
|
||||
let processed = processor.process(&input)?;
|
||||
|
||||
// Store as reasoning message
|
||||
self.add_reasoning_message(id.clone(), MessageType::AssistantToolCall, processed.clone());
|
||||
|
||||
return Ok(Some(processed));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Process an Assistant message with content (text response)
|
||||
fn process_assistant_response(
|
||||
&mut self,
|
||||
message: &LiteLlmMessage,
|
||||
content: &str
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// For response messages, we just store the text
|
||||
self.response_messages.push(content.to_string());
|
||||
|
||||
// Create a simple processed output
|
||||
let processed = ProcessedOutput::Text(ReasoningText {
|
||||
id: message.get_id().unwrap_or_else(|| Uuid::new_v4().to_string()),
|
||||
reasoning_type: "response".to_string(),
|
||||
title: "Assistant Response".to_string(),
|
||||
secondary_title: "".to_string(),
|
||||
message: Some(content.to_string()),
|
||||
message_chunk: None,
|
||||
status: Some("complete".to_string()),
|
||||
});
|
||||
|
||||
Ok(Some(processed))
|
||||
}
|
||||
|
||||
/// Process a Tool message (output from executed tool call)
|
||||
fn process_tool_output(
|
||||
&mut self,
|
||||
message: &LiteLlmMessage,
|
||||
tool_call_id: &str,
|
||||
content: &str
|
||||
) -> Result<Option<ProcessedOutput>> {
|
||||
// Parse content as JSON if possible
|
||||
let output = serde_json::from_str::<Value>(content)
|
||||
.unwrap_or_else(|_| serde_json::json!({"text": content}));
|
||||
|
||||
// Update tool call with output
|
||||
if let Some(tool_call) = self.tool_calls.get_mut(tool_call_id) {
|
||||
tool_call.output = Some(output.clone());
|
||||
tool_call.state = ToolCallState::HasOutput;
|
||||
|
||||
// Get the tool name
|
||||
let name = tool_call.name.clone();
|
||||
|
||||
// Process with appropriate processor
|
||||
if let Some(processor) = self.processors.get_processor_for_tool(&name) {
|
||||
let processed = processor.process_output(&output)?;
|
||||
|
||||
// Store as reasoning message
|
||||
self.add_reasoning_message(
|
||||
tool_call_id.to_string(),
|
||||
MessageType::ToolOutput,
|
||||
processed.clone()
|
||||
);
|
||||
|
||||
return Ok(Some(processed));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Adds a reasoning message
|
||||
fn add_reasoning_message(&mut self, id: String, message_type: MessageType, content: ProcessedOutput) {
|
||||
self.reasoning_messages.push(ProcessedMessage {
|
||||
id,
|
||||
message_type,
|
||||
content,
|
||||
timestamp: Utc::now(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Gets all reasoning messages
|
||||
pub fn get_reasoning_messages(&self) -> &[ProcessedMessage] {
|
||||
&self.reasoning_messages
|
||||
}
|
||||
|
||||
/// Gets all response messages
|
||||
pub fn get_response_messages(&self) -> &[String] {
|
||||
&self.response_messages
|
||||
}
|
||||
|
||||
/// Gets all tool calls
|
||||
pub fn get_tool_calls(&self) -> &HashMap<String, ToolCallInfo> {
|
||||
&self.tool_calls
|
||||
}
|
||||
|
||||
/// Gets a specific tool call by ID
|
||||
pub fn get_tool_call(&self, id: &str) -> Option<&ToolCallInfo> {
|
||||
self.tool_calls.get(id)
|
||||
}
|
||||
|
||||
/// Registers a processor for a specific tool
|
||||
pub fn register_processor(&mut self, name: &str, processor: Box<dyn Processor>) {
|
||||
self.processors.register_tool_processor(name, processor);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Data Flow
|
||||
|
||||
The streaming library enhancement will follow this simplified data flow:
|
||||
|
||||
1. **Agent Execution**: The agent runs and produces LiteLlmMessage objects (Assistant with tool calls, Assistant with content, Tool messages)
|
||||
2. **Message Streaming**: LiteLlmMessages stream into the post_chat_handler
|
||||
3. **Parsing and Processing**: Messages are passed to the StreamingParser, which:
|
||||
- Identifies the message type (AssistantToolCall, AssistantResponse, ToolOutput)
|
||||
- Processes messages through appropriate processors
|
||||
- Tracks tool calls and their outputs using shared IDs
|
||||
- Stores messages internally for later retrieval
|
||||
4. **Message Storage**: After the agent finishes execution, the collected messages are retrieved from the StreamingParser for storage and display
|
||||
|
||||
This flow ensures a clean separation of concerns while minimizing the number of components:
|
||||
|
||||
- StreamingParser handles both real-time processing and message storage
|
||||
- Tool calls and their outputs are linked using shared IDs
|
||||
- All chunk tracking is handled within the StreamingParser
|
||||
|
||||
### Example Usage in post_chat_handler
|
||||
|
||||
```rust
|
||||
pub async fn post_chat_handler(
|
||||
request: ChatCreateNewChat,
|
||||
user: AuthenticatedUser,
|
||||
tx: Option<mpsc::Sender<Result<(BusterContainer, ThreadEvent)>>>,
|
||||
) -> Result<ChatWithMessages> {
|
||||
// Initialize enhanced StreamingParser
|
||||
let mut streaming_parser = StreamingParser::new();
|
||||
|
||||
// Register processors for different tool types
|
||||
streaming_parser.register_processor("create_plan", Box::new(CreatePlanProcessor::new()));
|
||||
streaming_parser.register_processor("create_metrics", Box::new(CreateMetricsProcessor::new()));
|
||||
// ... register other processors
|
||||
|
||||
// Initialize agent and get stream receiver
|
||||
let mut agent = BusterSuperAgent::new(user.clone(), chat_id).await?;
|
||||
let mut chat = AgentThread::new(Some(chat_id), user.id, initial_messages);
|
||||
let mut rx = agent.run(&mut chat).await?;
|
||||
|
||||
// Process streaming messages
|
||||
while let Ok(message_result) = rx.recv().await {
|
||||
match message_result {
|
||||
Ok(message) => {
|
||||
// Process the LiteLlmMessage with the StreamingParser
|
||||
if let Some(processed) = streaming_parser.process_message(&message)? {
|
||||
// Send to client if tx is available
|
||||
if let Some(tx) = &tx {
|
||||
match message {
|
||||
LiteLlmMessage::Assistant { tool_calls: Some(_), .. } => {
|
||||
let event = ThreadEvent::ReasoningMessage {
|
||||
id: message.get_id().unwrap_or_default(),
|
||||
content: processed,
|
||||
};
|
||||
tx.send(Ok((BusterContainer::new(), event))).await?;
|
||||
},
|
||||
LiteLlmMessage::Assistant { content: Some(_), .. } => {
|
||||
let event = ThreadEvent::ResponseMessage {
|
||||
content: processed,
|
||||
};
|
||||
tx.send(Ok((BusterContainer::new(), event))).await?;
|
||||
},
|
||||
LiteLlmMessage::Tool { .. } => {
|
||||
let event = ThreadEvent::ReasoningMessage {
|
||||
id: message.get_id().unwrap_or_default(),
|
||||
content: processed,
|
||||
};
|
||||
tx.send(Ok((BusterContainer::new(), event))).await?;
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::error!("Error receiving message: {}", e);
|
||||
return Err(anyhow!("Error receiving message: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After agent execution, collect all messages
|
||||
let reasoning_messages = streaming_parser.get_reasoning_messages();
|
||||
let response_messages = streaming_parser.get_response_messages();
|
||||
|
||||
// Create chat with messages
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Phase 1: Core Functionality
|
||||
|
||||
1. ✅ **Update Message Types**
|
||||
- ✅ Define MessageType enum for different message types
|
||||
- ✅ Create ToolCallInfo struct for managing tool calls
|
||||
- ✅ Define ToolCallState enum for tracking tool call state
|
||||
- ✅ Create ProcessedMessage struct for encapsulating processed messages
|
||||
|
||||
2. ✅ **Enhance StreamingParser**
|
||||
- ✅ Update the StreamingParser to handle LiteLlmMessage types
|
||||
- ✅ Implement methods for processing different message types
|
||||
- ✅ Add storage for tool calls, reasoning messages, and response messages
|
||||
|
||||
### Phase 2: Integration
|
||||
|
||||
1. ✅ **Update ProcessorRegistry**
|
||||
- ✅ Enhance the ProcessorRegistry to support tool-specific processors
|
||||
- ✅ Add methods for retrieving processors by tool name
|
||||
|
||||
2. **Integrate with post_chat_handler**
|
||||
- Update post_chat_handler to use the enhanced StreamingParser
|
||||
- Simplify the message handling logic in post_chat_handler
|
||||
|
||||
3. **Phase 3: Testing and Validation**
|
||||
- Write unit tests for the enhanced StreamingParser
|
||||
- Write integration tests for the entire flow
|
||||
- Validate that all requirements are met
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Testing
|
||||
|
||||
1. **StreamingParser Tests**
|
||||
- Test processing of different LiteLlmMessage types
|
||||
- Test tool call tracking and association with outputs
|
||||
- Test message storage and retrieval
|
||||
|
||||
2. **ProcessorRegistry Tests**
|
||||
- Test registration and retrieval of processors
|
||||
- Test processor selection based on tool name
|
||||
|
||||
### Integration Testing
|
||||
|
||||
1. **End-to-End Flow Tests**
|
||||
- Test the entire flow from agent execution to message display
|
||||
- Verify that all message types are correctly processed and stored
|
||||
|
||||
2. **Performance Tests**
|
||||
- Test with large message streams to ensure performance requirements are met
|
||||
- Measure latency and memory usage
|
||||
|
||||
## Success Criteria
|
||||
|
||||
1. **Functional Success**
|
||||
- All LiteLlmMessage types are correctly processed
|
||||
- Tool calls and their outputs are correctly associated using shared IDs
|
||||
- Messages are stored and can be retrieved for display
|
||||
|
||||
2. **Non-Functional Success**
|
||||
- Performance requirements are met
|
||||
- Code is maintainable and well-documented
|
||||
- Test coverage is at least 80%
|
||||
|
||||
## Risks and Mitigations
|
||||
|
||||
1. **Risk**: Incompatibility with existing processors
|
||||
- **Mitigation**: Ensure backward compatibility by maintaining the existing processor interface
|
||||
|
||||
2. **Risk**: Performance degradation with large message streams
|
||||
- **Mitigation**: Implement efficient caching and chunk tracking mechanisms
|
||||
|
||||
3. **Risk**: Incomplete or malformed messages in the stream
|
||||
- **Mitigation**: Implement robust error handling and recovery mechanisms
|
||||
|
||||
## Appendix
|
||||
|
||||
### Glossary
|
||||
|
||||
- **LiteLlmMessage**: A message from the LiteLLM library, which can be an Assistant message with tool calls, an Assistant message with content, or a Tool message
|
||||
- **Tool Call**: A request from the assistant to execute a tool
|
||||
- **Tool Output**: The result of executing a tool
|
||||
- **StreamingParser**: The component responsible for parsing and processing messages
|
||||
- **ProcessorRegistry**: A registry of processors for different message types
|
|
@ -74,8 +74,21 @@ async fn get_asset_access_handler(
|
|||
.first::<(Uuid, bool, bool, Option<DateTime<Utc>>)>(&mut conn)
|
||||
.await?;
|
||||
|
||||
let user_permission = {
|
||||
let pg_pool = pg_pool.clone();
|
||||
let user_id = user.id.clone();
|
||||
let asset_id = asset_id.clone();
|
||||
tokio::spawn(async move {
|
||||
get_user_dashboard_permission(&pg_pool, &user_id, &asset_id).await
|
||||
})
|
||||
};
|
||||
|
||||
(dashboard_info, Some(AssetPermissionRole::Owner))
|
||||
let user_permission = user_permission
|
||||
.await
|
||||
.map_err(|_| anyhow!("Failed to join task"))? // Changed to discard error details
|
||||
.unwrap_or(None); // Use None for both error and no permission cases
|
||||
|
||||
(dashboard_info, user_permission)
|
||||
}
|
||||
AssetType::Thread => {
|
||||
let mut conn = pg_pool.get().await?;
|
||||
|
|
Loading…
Reference in New Issue