mirror of https://github.com/buster-so/buster.git
fixes on issues
This commit is contained in:
parent
d4572ff7ce
commit
6fe50c78de
|
@ -14,7 +14,7 @@ use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{agent::Agent, tools::ToolExecutor};
|
use crate::{agent::Agent, tools::ToolExecutor};
|
||||||
|
|
||||||
use litellm::{ChatCompletionRequest, LiteLLMClient, AgentMessage, Metadata, ResponseFormat};
|
use litellm::{AgentMessage, ChatCompletionRequest, LiteLLMClient, Metadata, ResponseFormat};
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct SearchDataCatalogParams {
|
pub struct SearchDataCatalogParams {
|
||||||
|
@ -90,13 +90,17 @@ impl SearchDataCatalogTool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn format_search_prompt(query_params: &[String], datasets: &[DatasetRecord]) -> Result<String> {
|
async fn format_search_prompt(
|
||||||
|
query_params: &[String],
|
||||||
|
datasets: &[DatasetRecord],
|
||||||
|
) -> Result<String> {
|
||||||
let datasets_json = datasets
|
let datasets_json = datasets
|
||||||
.iter()
|
.iter()
|
||||||
.map(|d| d.to_llm_format())
|
.map(|d| d.to_llm_format())
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
Ok(SearchDataCatalogTool::get_search_prompt().await
|
Ok(SearchDataCatalogTool::get_search_prompt()
|
||||||
|
.await
|
||||||
.replace("{queries_joined_with_newlines}", &query_params.join("\n"))
|
.replace("{queries_joined_with_newlines}", &query_params.join("\n"))
|
||||||
.replace(
|
.replace(
|
||||||
"{datasets_array_as_json}",
|
"{datasets_array_as_json}",
|
||||||
|
@ -108,7 +112,7 @@ impl SearchDataCatalogTool {
|
||||||
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
|
if env::var("USE_BRAINTRUST_PROMPTS").is_err() {
|
||||||
return CATALOG_SEARCH_PROMPT.to_string();
|
return CATALOG_SEARCH_PROMPT.to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
|
let client = BraintrustClient::new(None, "96af8b2b-cf3c-494f-9092-44eb3d5b96ff").unwrap();
|
||||||
match get_prompt_system_message(&client, "812b3f76-20d5-49e3-884c-2c8084800b43").await {
|
match get_prompt_system_message(&client, "812b3f76-20d5-49e3-884c-2c8084800b43").await {
|
||||||
Ok(message) => message,
|
Ok(message) => message,
|
||||||
|
@ -137,7 +141,7 @@ impl SearchDataCatalogTool {
|
||||||
|
|
||||||
while retry_count < MAX_RETRIES {
|
while retry_count < MAX_RETRIES {
|
||||||
let request = ChatCompletionRequest {
|
let request = ChatCompletionRequest {
|
||||||
model: "o3-mini".to_string(),
|
model: "gemini-2.0-flash-001".to_string(),
|
||||||
messages: vec![AgentMessage::User {
|
messages: vec![AgentMessage::User {
|
||||||
id: None,
|
id: None,
|
||||||
content: current_prompt.clone(),
|
content: current_prompt.clone(),
|
||||||
|
@ -146,7 +150,27 @@ impl SearchDataCatalogTool {
|
||||||
stream: Some(false),
|
stream: Some(false),
|
||||||
response_format: Some(ResponseFormat {
|
response_format: Some(ResponseFormat {
|
||||||
type_: "json_object".to_string(),
|
type_: "json_object".to_string(),
|
||||||
json_schema: None,
|
json_schema: Some(json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"results": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {
|
||||||
|
"type": "string",
|
||||||
|
"format": "uuid"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["id"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["results"],
|
||||||
|
"additionalProperties": false
|
||||||
|
})),
|
||||||
}),
|
}),
|
||||||
metadata: Some(Metadata {
|
metadata: Some(Metadata {
|
||||||
generation_name: "search_data_catalog".to_string(),
|
generation_name: "search_data_catalog".to_string(),
|
||||||
|
@ -154,7 +178,6 @@ impl SearchDataCatalogTool {
|
||||||
session_id: session_id.to_string(),
|
session_id: session_id.to_string(),
|
||||||
trace_id: session_id.to_string(),
|
trace_id: session_id.to_string(),
|
||||||
}),
|
}),
|
||||||
reasoning_effort: Some(String::from("low")),
|
|
||||||
max_completion_tokens: Some(8092),
|
max_completion_tokens: Some(8092),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
@ -288,7 +311,8 @@ impl ToolExecutor for SearchDataCatalogTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format prompt and perform search
|
// Format prompt and perform search
|
||||||
let prompt = Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?;
|
let prompt =
|
||||||
|
Self::format_search_prompt(&[params.search_requirements.clone()], &datasets).await?;
|
||||||
let search_results = match Self::perform_llm_search(
|
let search_results = match Self::perform_llm_search(
|
||||||
prompt,
|
prompt,
|
||||||
&self.agent.get_user_id(),
|
&self.agent.get_user_id(),
|
||||||
|
|
|
@ -849,7 +849,7 @@ pub struct BusterChatResponseFileMetadata {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Clone)]
|
#[derive(Debug, Serialize, Clone)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub enum BusterChatMessage {
|
pub enum BusterChatMessage {
|
||||||
Text {
|
Text {
|
||||||
id: String,
|
id: String,
|
||||||
|
|
Loading…
Reference in New Issue