fixes on issues

This commit is contained in:
dal 2025-04-09 11:04:40 -06:00
parent d4572ff7ce
commit 6fe50c78de
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 33 additions and 9 deletions

View File

@ -14,7 +14,7 @@ use uuid::Uuid;
use crate::{agent::Agent, tools::ToolExecutor}; use crate::{agent::Agent, tools::ToolExecutor};
use litellm::{ChatCompletionRequest, LiteLLMClient, AgentMessage, Metadata, ResponseFormat}; use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct SearchDataCatalogParams { pub struct SearchDataCatalogParams {
@ -90,13 +90,17 @@ impl SearchDataCatalogTool {
true true
} }
async fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result<String> { async fn format_search_prompt(
query_params: &[String],
datasets: &[DatasetRecord],
) -> Result<String> {
let datasets_json = datasets let datasets_json = datasets
.iter() .iter()
.map(|d| d.to_llm_format()) .map(|d| d.to_llm_format())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Ok(SearchDataCatalogTool::get_search_prompt().await Ok(SearchDataCatalogTool::get_search_prompt()
.await
.replace("{queries_joined_with_newlines}", &query_params.join("\n")) .replace("{queries_joined_with_newlines}", &query_params.join("\n"))
.replace( .replace(
"{datasets_array_as_json}", "{datasets_array_as_json}",
@ -137,7 +141,7 @@ impl SearchDataCatalogTool {
while retry_count < MAX_RETRIES { while retry_count < MAX_RETRIES {
let request = ChatCompletionRequest { let request = ChatCompletionRequest {
model: "o3-mini".to_string(), model: "gemini-2.0-flash-001".to_string(),
messages: vec![AgentMessage::User { messages: vec![AgentMessage::User {
id: None, id: None,
content: current_prompt.clone(), content: current_prompt.clone(),
@ -146,7 +150,27 @@ impl SearchDataCatalogTool {
stream: Some(false), stream: Some(false),
response_format: Some(ResponseFormat { response_format: Some(ResponseFormat {
type_: "json_object".to_string(), type_: "json_object".to_string(),
json_schema: None, json_schema: Some(json!({
"type": "object",
"properties": {
"results": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"format": "uuid"
}
},
"required": ["id"],
"additionalProperties": false
}
}
},
"required": ["results"],
"additionalProperties": false
})),
}), }),
metadata: Some(Metadata { metadata: Some(Metadata {
generation_name: "search_data_catalog".to_string(), generation_name: "search_data_catalog".to_string(),
@ -154,7 +178,6 @@ impl SearchDataCatalogTool {
session_id: session_id.to_string(), session_id: session_id.to_string(),
trace_id: session_id.to_string(), trace_id: session_id.to_string(),
}), }),
reasoning_effort: Some(String::from("low")),
max_completion_tokens: Some(8092), max_completion_tokens: Some(8092),
..Default::default() ..Default::default()
}; };
@ -288,7 +311,8 @@ impl ToolExecutor for SearchDataCatalogTool {
} }
// Format prompt and perform search // Format prompt and perform search
let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?; let prompt =
Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?;
let search_results = match Self::perform_llm_search( let search_results = match Self::perform_llm_search(
prompt, prompt,
&self.agent.get_user_id(), &self.agent.get_user_id(),

View File

@ -849,7 +849,7 @@ pub struct BusterChatResponseFileMetadata {
} }
#[derive(Debug, Serialize, Clone)] #[derive(Debug, Serialize, Clone)]
#[serde(tag = "type")] #[serde(tag = "type", rename_all = "snake_case")]
pub enum BusterChatMessage { pub enum BusterChatMessage {
Text { Text {
id: String, id: String,