diff --git a/api/src/utils/tools/file_tools/search_data_catalog.rs b/api/src/utils/tools/file_tools/search_data_catalog.rs index fe66d0bca..cac7f2bd7 100644 --- a/api/src/utils/tools/file_tools/search_data_catalog.rs +++ b/api/src/utils/tools/file_tools/search_data_catalog.rs @@ -116,74 +116,116 @@ impl SearchDataCatalogTool { session_id: &Uuid, ) -> Result> { debug!("Performing LLM search"); - + // Setup LiteLLM client let llm_client = LiteLLMClient::new(None, None); - let request = ChatCompletionRequest { - model: "gemini-2".to_string(), - messages: vec![Message::User { - id: None, - content: prompt, - name: None, - }], - stream: Some(false), - response_format: Some(ResponseFormat { - type_: "json_object".to_string(), - json_schema: None, - }), - metadata: Some(Metadata { - generation_name: "search_data_catalog".to_string(), - user_id: user_id.to_string(), - session_id: session_id.to_string(), - trace_id: Uuid::new_v4().to_string(), - }), - // reasoning_effort: Some("low".to_string()), - ..Default::default() - }; + + // Maximum number of retries for parsing errors + const MAX_RETRIES: usize = 3; + let mut retry_count = 0; + let mut last_error = None; + let mut current_prompt = prompt; - // Get response from LLM - let response = llm_client.chat_completion(request).await.map_err(|e| { - error!(error = %e, "Failed to get response from LLM"); - anyhow::anyhow!("Failed to get response from LLM: {}", e) - })?; + while retry_count < MAX_RETRIES { + let request = ChatCompletionRequest { + model: "o3-mini".to_string(), + messages: vec![Message::User { + id: None, + content: current_prompt.clone(), + name: None, + }], + stream: Some(false), + response_format: Some(ResponseFormat { + type_: "json_object".to_string(), + json_schema: None, + }), + metadata: Some(Metadata { + generation_name: "search_data_catalog".to_string(), + user_id: user_id.to_string(), + session_id: session_id.to_string(), + trace_id: session_id.to_string(), + }), + reasoning_effort: Some("low".to_string()), + max_completion_tokens: Some(8092), + ..Default::default() + }; - // Parse LLM response - let content = match &response.choices[0].message { - Message::Assistant { - content: Some(content), - .. - } => content, - _ => { - error!("LLM response missing content"); - return Err(anyhow::anyhow!("LLM response missing content")); - } - }; - - // Parse into raw response first - let raw_response: RawLLMResponse = serde_json::from_str(content).map_err(|e| { - warn!(error = %e, "Failed to parse LLM response as JSON"); - anyhow::anyhow!("Failed to parse search results: {}", e) - })?; - - // Process each result, logging any invalid ones - let mut valid_results = Vec::new(); - let mut invalid_count = 0; - - for result in raw_response.results { - match parse_search_result(&result) { - Ok(result) => valid_results.push(result), + // Get response from LLM + let response = match llm_client.chat_completion(request).await { + Ok(resp) => resp, Err(e) => { - warn!(error = %e, "Invalid search result from LLM"); - invalid_count += 1; + error!(error = %e, "Failed to get response from LLM"); + return Err(anyhow::anyhow!("Failed to get response from LLM: {}", e)); + } + }; + + // Parse LLM response + let content = match &response.choices[0].message { + Message::Assistant { + content: Some(content), + .. + } => content, + _ => { + error!("LLM response missing content"); + return Err(anyhow::anyhow!("LLM response missing content")); + } + }; + + // Parse into raw response first + match serde_json::from_str::(content) { + Ok(raw_response) => { + // Process each result, logging any invalid ones + let mut valid_results = Vec::new(); + let mut invalid_count = 0; + + for result in raw_response.results { + match parse_search_result(&result) { + Ok(result) => valid_results.push(result), + Err(e) => { + warn!(error = %e, "Invalid search result from LLM"); + invalid_count += 1; + } + } + } + + if invalid_count > 0 { + warn!(count = invalid_count, "Found invalid search results"); + } + + return Ok(valid_results); + }, + Err(e) => { + // Store the error for potential return + let error_message = e.to_string(); + last_error = Some(error_message.clone()); + + // Log the error and retry + warn!( + error = %error_message, + retry = retry_count + 1, + max_retries = MAX_RETRIES, + "Failed to parse LLM response as JSON, retrying with error feedback..." + ); + + // Increment retry counter + retry_count += 1; + + // Only modify the prompt if we're going to retry + if retry_count < MAX_RETRIES { + // Add the error to the prompt to help the LLM correct its response + current_prompt = format!( + "{}\n\nYour previous response could not be parsed correctly. Error: {}\n\nPlease ensure your response is valid JSON with the exact format specified. The response must include a 'results' array containing objects with only an 'id' field that is a valid UUID string.", + current_prompt, error_message + ); + } } } } - if invalid_count > 0 { - warn!(count = invalid_count, "Found invalid search results"); - } - - Ok(valid_results) + // If we've exhausted all retries, return the last error + Err(anyhow::anyhow!("Failed to parse search results after {} retries: {}", + MAX_RETRIES, + last_error.unwrap_or_else(|| "Unknown error".to_string()))) } async fn get_datasets() -> Result> {