mirror of https://github.com/buster-so/buster.git
ok now generating descriptions ayo
This commit is contained in:
parent bfe61d52e2
commit 37854342da
@@ -7,6 +7,8 @@ use serde::{Deserialize, Serialize};
 use serde_yaml;
 use std::collections::HashMap;
 use uuid::Uuid;
+use regex::Regex;
+use tokio::task::JoinSet;
 
 use crate::{
     database::{
@@ -22,6 +24,10 @@ use crate::{
         credentials::get_data_source_credentials,
         import_dataset_columns::{retrieve_dataset_columns_batch, DatasetColumnRecord},
     },
+    clients::ai::{
+        openai::{OpenAiChatModel, OpenAiChatRole, OpenAiChatContent, OpenAiChatMessage},
+        llm_router::{llm_chat, LlmModel, LlmMessage},
+    },
 },
};
 
@@ -156,6 +162,125 @@ pub async fn generate_datasets(
     }
 }
 
+async fn enhance_yaml_with_descriptions(yaml: String) -> Result<String> {
+    const DESCRIPTION_PLACEHOLDER: &str = "{NEED DESCRIPTION HERE}";
+
+    // Skip OpenAI call if no placeholders exist
+    if !yaml.contains(DESCRIPTION_PLACEHOLDER) {
+        return Ok(yaml);
+    }
+
+    let messages = vec![
+        LlmMessage::new(
+            "developer".to_string(),
+            "You are a YAML description enhancer. Your output must be wrapped in markdown code blocks using ```yml format.
+Your task is to ONLY replace text matching exactly \"{NEED DESCRIPTION HERE}\" with appropriate descriptions.
+DO NOT modify any other part of the YAML.
+DO NOT add any explanations or text outside the ```yml block.
+Return the complete YAML wrapped in markdown, with only the placeholders replaced.".to_string(),
+        ),
+        LlmMessage::new(
+            "user".to_string(),
+            yaml,
+        ),
+    ];
+
+    let response = llm_chat(
+        LlmModel::OpenAi(OpenAiChatModel::O3Mini),
+        &messages,
+        0.1,
+        2048,
+        30,
+        None,
+        false,
+        None,
+        &Uuid::new_v4(),
+        &Uuid::new_v4(),
+        crate::utils::clients::ai::langfuse::PromptName::CustomPrompt("enhance_yaml_descriptions".to_string()),
+    )
+    .await?;
+
+    // Extract YAML from markdown code blocks
+    let re = Regex::new(r"```yml\n([\s\S]*?)\n```").unwrap();
+    let yaml = match re.captures(&response) {
+        Some(caps) => caps.get(1).unwrap().as_str().to_string(),
+        None => return Err(anyhow!("Failed to extract YAML from response")),
+    };
+
+    Ok(yaml)
+}
+async fn generate_model_yaml(
+    model_name: &str,
+    ds_columns: &[DatasetColumnRecord],
+    schema: &str,
+) -> Result<String> {
+    // Filter columns for this model
+    let model_columns: Vec<_> = ds_columns
+        .iter()
+        .filter(|col| {
+            col.dataset_name.to_lowercase() == model_name.to_lowercase()
+                && col.schema_name.to_lowercase() == schema.to_lowercase()
+        })
+        .collect();
+
+    if model_columns.is_empty() {
+        return Err(anyhow!("No columns found for model"));
+    }
+
+    let mut dimensions = Vec::new();
+    let mut measures = Vec::new();
+
+    // Process each column and categorize as dimension or measure
+    for col in model_columns {
+        match map_snowflake_type(&col.type_) {
+            ColumnMappingType::Dimension(semantic_type) => {
+                dimensions.push(Dimension {
+                    name: col.name.clone(),
+                    expr: col.name.clone(),
+                    type_: semantic_type,
+                    description: "{NEED DESCRIPTION HERE}".to_string(),
+                    searchable: Some(false),
+                });
+            }
+            ColumnMappingType::Measure(measure_type) => {
+                measures.push(Measure {
+                    name: col.name.clone(),
+                    expr: col.name.clone(),
+                    type_: measure_type,
+                    agg: Some("sum".to_string()),
+                    description: "{NEED DESCRIPTION HERE}".to_string(),
+                });
+            }
+            ColumnMappingType::Unsupported => {
+                tracing::warn!(
+                    "Skipping unsupported column type: {} for column: {}",
+                    col.type_,
+                    col.name
+                );
+            }
+        }
+    }
+
+    let model = Model {
+        name: model_name.to_string(),
+        description: format!("Generated model for {}", model_name),
+        dimensions,
+        measures,
+    };
+
+    let config = ModelConfig {
+        models: vec![model],
+    };
+
+    let yaml = serde_yaml::to_string(&config)?;
+
+    // Enhance descriptions using OpenAI
+    let enhanced_yaml = enhance_yaml_with_descriptions(yaml).await?;
+
+    Ok(enhanced_yaml)
+}
 
 async fn generate_datasets_handler(
     request: &GenerateDatasetRequest,
     organization_id: &Uuid,
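A note on the extraction step in `enhance_yaml_with_descriptions` above: the model is told to answer inside a ```yml fence, and a non-greedy regex pulls the body back out. The sketch below is not part of the commit (the response string is invented); it only exercises the same pattern in isolation:

```rust
use regex::Regex;

fn main() {
    // A response shaped the way the system prompt demands: YAML inside a ```yml fence.
    let response = "Sure:\n```yml\nmodels:\n  - name: orders\n    description: Order facts\n```";

    // Same pattern as the commit: [\s\S] matches across newlines, *? keeps the capture non-greedy.
    let re = Regex::new(r"```yml\n([\s\S]*?)\n```").unwrap();
    match re.captures(response) {
        Some(caps) => println!("extracted:\n{}", caps.get(1).unwrap().as_str()),
        None => eprintln!("no ```yml block in response"),
    }
}
```

A response with prose around the fence is still handled; a missing fence surfaces as the "Failed to extract YAML from response" error in the commit.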
@@ -190,32 +315,34 @@ async fn generate_datasets_handler(
         Err(e) => return Err(anyhow!("Failed to get columns from data source: {}", e)),
     };
 
     // Check for existing datasets (just for logging/info purposes)
     let existing_datasets: HashMap<String, Dataset> = datasets::table
         .filter(datasets::data_source_id.eq(&data_source.id))
         .filter(datasets::deleted_at.is_null())
         .load::<Dataset>(&mut conn)
         .await?
         .into_iter()
         .map(|d| (d.name.clone(), d))
         .collect();
+    // Process models concurrently
+    let mut join_set = JoinSet::new();
+
+    for model_name in &request.model_names {
+        let model_name = model_name.clone();
+        let schema = request.schema.clone();
+        let ds_columns = ds_columns.clone();
+
+        join_set.spawn(async move {
+            let result = generate_model_yaml(&model_name, &ds_columns, &schema).await;
+            (model_name, result)
+        });
+    }
 
     let mut yml_contents = HashMap::new();
     let mut errors = HashMap::new();
 
-    // Process each model
-    for model_name in &request.model_names {
-        // Log if dataset already exists
-        if existing_datasets.contains_key(model_name) {
-            tracing::info!("Dataset {} already exists", model_name);
-        }
-
-        match generate_model_yaml(model_name, &ds_columns, &request.schema).await {
-            Ok(yaml) => {
-                yml_contents.insert(model_name.clone(), yaml);
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok((model_name, Ok(yaml))) => {
+                yml_contents.insert(model_name, yaml);
             }
-            Err(e) => {
-                errors.insert(model_name.clone(), e.to_string());
+            Ok((model_name, Err(e))) => {
+                errors.insert(model_name, e.to_string());
+            }
+            Err(e) => {
+                tracing::error!("Task join error: {:?}", e);
+                return Err(anyhow!("Task execution failed"));
             }
         }
     }
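The sequential per-model loop above gives way to fan-out on `tokio::task::JoinSet`. Each task carries its model name alongside its result so successes and failures can be attributed after `join_next()`. Here is a self-contained sketch of the same pattern, with `build_yaml` as an invented stand-in for `generate_model_yaml`:

```rust
use tokio::task::JoinSet;

// Stand-in for generate_model_yaml; the real function queries columns and calls the LLM.
async fn build_yaml(model: String) -> Result<String, String> {
    Ok(format!("models:\n  - name: {model}"))
}

#[tokio::main]
async fn main() {
    let mut join_set = JoinSet::new();
    for model in ["orders", "customers"] {
        let model = model.to_string();
        join_set.spawn(async move {
            // Return the key with the result so the collector knows which model it belongs to.
            let result = build_yaml(model.clone()).await;
            (model, result)
        });
    }

    while let Some(joined) = join_set.join_next().await {
        match joined {
            Ok((model, Ok(yaml))) => println!("{model}:\n{yaml}"),
            Ok((model, Err(e))) => eprintln!("{model} failed: {e}"),
            Err(e) => eprintln!("task panicked or was cancelled: {e:?}"),
        }
    }
}
```

One behavioral difference worth noting: in the commit, a join error aborts the whole request with "Task execution failed" rather than being recorded as a per-model error.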
@@ -224,72 +351,4 @@
         yml_contents,
         errors,
     })
 }
-
-async fn generate_model_yaml(
-    model_name: &str,
-    ds_columns: &[DatasetColumnRecord],
-    schema: &str,
-) -> Result<String> {
-    // Filter columns for this model
-    let model_columns: Vec<_> = ds_columns
-        .iter()
-        .filter(|col| {
-            col.dataset_name.to_lowercase() == model_name.to_lowercase()
-                && col.schema_name.to_lowercase() == schema.to_lowercase()
-        })
-        .collect();
-
-    if model_columns.is_empty() {
-        return Err(anyhow!("No columns found for model"));
-    }
-
-    let mut dimensions = Vec::new();
-    let mut measures = Vec::new();
-
-    // Process each column and categorize as dimension or measure
-    for col in model_columns {
-        match map_snowflake_type(&col.type_) {
-            ColumnMappingType::Dimension(semantic_type) => {
-                dimensions.push(Dimension {
-                    name: col.name.clone(),
-                    expr: col.name.clone(),
-                    type_: semantic_type,
-                    description: format!("Column {} from {}", col.name, model_name),
-                    searchable: Some(false),
-                });
-            }
-            ColumnMappingType::Measure(measure_type) => {
-                measures.push(Measure {
-                    name: col.name.clone(),
-                    expr: col.name.clone(),
-                    type_: measure_type,
-                    agg: Some("sum".to_string()), // Default aggregation
-                    description: format!("Column {} from {}", col.name, model_name),
-                });
-            }
-            ColumnMappingType::Unsupported => {
-                tracing::warn!(
-                    "Skipping unsupported column type: {} for column: {}",
-                    col.type_,
-                    col.name
-                );
-            }
-        }
-    }
-
-    let model = Model {
-        name: model_name.to_string(),
-        description: format!("Generated model for {}", model_name),
-        dimensions,
-        measures,
-    };
-
-    let config = ModelConfig {
-        models: vec![model],
-    };
-
-    let yaml = serde_yaml::to_string(&config)?;
-
-    Ok(yaml)
-}
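The removal above is the old copy of `generate_model_yaml`, relocated earlier in the file with one behavioral change: the boilerplate `format!("Column {} from {}", ...)` descriptions give way to a `{NEED DESCRIPTION HERE}` placeholder that `enhance_yaml_with_descriptions` later fills in. A minimal sketch of that placeholder contract, with no LLM call and an invented replacement string:

```rust
const DESCRIPTION_PLACEHOLDER: &str = "{NEED DESCRIPTION HERE}";

fn main() {
    let yaml = "models:\n  - name: orders\n    description: {NEED DESCRIPTION HERE}\n";

    // Mirrors the early-return guard in the commit: YAML without placeholders skips enhancement.
    if !yaml.contains(DESCRIPTION_PLACEHOLDER) {
        println!("nothing to enhance");
        return;
    }

    // The real code delegates this replacement to the LLM; a literal swap shows the contract.
    let enhanced = yaml.replace(DESCRIPTION_PLACEHOLDER, "Fact table of customer orders");
    println!("{enhanced}");
}
```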
@@ -20,7 +20,7 @@ use gcp_bigquery_client::model::query_request::QueryRequest;
 use sqlx::{FromRow, Row};
 use uuid::Uuid;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct DatasetColumnRecord {
     pub dataset_name: String,
     pub schema_name: String,
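The added `Clone` derive is what the concurrency change above depends on: each spawned task receives its own `ds_columns.clone()`, and `Vec<DatasetColumnRecord>` is only `Clone` when its element type is. A trimmed illustration, with the struct's fields reduced for brevity:

```rust
#[derive(Debug, Clone)]
struct DatasetColumnRecord {
    dataset_name: String,
    schema_name: String,
}

fn main() {
    let ds_columns = vec![DatasetColumnRecord {
        dataset_name: "orders".into(),
        schema_name: "public".into(),
    }];

    // Compiles only because DatasetColumnRecord derives Clone.
    let per_task = ds_columns.clone();
    println!("{} column(s) cloned for the task", per_task.len());
}
```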