diff --git a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs index 693d8e68f..6be6225f6 100644 --- a/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs +++ b/api/libs/agents/src/tools/categories/file_tools/search_data_catalog.rs @@ -14,7 +14,7 @@ use uuid::Uuid; use crate::{agent::Agent, tools::ToolExecutor}; -use litellm::{ChatCompletionRequest, LiteLLMClient, AgentMessage, Metadata, ResponseFormat}; +use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat}; #[derive(Debug, Serialize, Deserialize)] pub struct SearchDataCatalogParams { @@ -90,13 +90,17 @@ impl SearchDataCatalogTool { true } - async fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result { + async fn format_search_prompt( + query_params: &[String], + datasets: &[DatasetRecord], + ) -> Result { let datasets_json = datasets .iter() .map(|d| d.to_llm_format()) .collect::>(); - Ok(SearchDataCatalogTool::get_search_prompt().await + Ok(SearchDataCatalogTool::get_search_prompt() + .await .replace("{queries_joined_with_newlines}", &query_params.join("\n")) .replace( "{datasets_array_as_json}", @@ -108,7 +112,7 @@ impl SearchDataCatalogTool { if env::var("USE_BRAINTRUST_PROMPTS").is_err() { return CATALOG_SEARCH_PROMPT.to_string(); } - + let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap(); match get_prompt_system_message(&client, "812b3f76-20d5-49e3-884c-2c8084800b43").await { Ok(message) => message, @@ -137,7 +141,7 @@ impl SearchDataCatalogTool { while retry_count < MAX_RETRIES { let request = ChatCompletionRequest { - model: "o3-mini".to_string(), + model: "gemini-2.0-flash-001".to_string(), messages: vec![AgentMessage::User { id: None, content: current_prompt.clone(), @@ -146,7 +150,27 @@ impl SearchDataCatalogTool { stream: Some(false), response_format: Some(ResponseFormat { type_: "json_object".to_string(), - json_schema: None, + json_schema: Some(json!({ + "type": "object", + "properties": { + "results": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "format": "uuid" + } + }, + "required": ["id"], + "additionalProperties": false + } + } + }, + "required": ["results"], + "additionalProperties": false + })), }), metadata: Some(Metadata { generation_name: "search_data_catalog".to_string(), @@ -154,7 +178,6 @@ impl SearchDataCatalogTool { session_id: session_id.to_string(), trace_id: session_id.to_string(), }), - reasoning_effort: Some(String::from("low")), max_completion_tokens: Some(8092), ..Default::default() }; @@ -288,7 +311,8 @@ impl ToolExecutor for SearchDataCatalogTool { } // Format prompt and perform search - let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?; + let prompt = + Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?; let search_results = match Self::perform_llm_search( prompt, &self.agent.get_user_id(), diff --git a/api/libs/handlers/src/chats/post_chat_handler.rs b/api/libs/handlers/src/chats/post_chat_handler.rs index 64d0e090d..1c1e9fdfb 100644 --- a/api/libs/handlers/src/chats/post_chat_handler.rs +++ b/api/libs/handlers/src/chats/post_chat_handler.rs @@ -849,7 +849,7 @@ pub struct BusterChatResponseFileMetadata { } #[derive(Debug, Serialize, Clone)] -#[serde(tag = "type")] +#[serde(tag = "type", rename_all = "snake_case")] pub enum BusterChatMessage { Text { id: String,