ok now generating descriptions ayo

dal 2025-02-12 08:41:17 -07:00
parent bfe61d52e2
commit 37854342da
2 changed files with 148 additions and 89 deletions

View File

@@ -7,6 +7,8 @@ use serde::{Deserialize, Serialize};
use serde_yaml;
use std::collections::HashMap;
use uuid::Uuid;
use regex::Regex;
use tokio::task::JoinSet;
use crate::{
database::{
@@ -22,6 +24,10 @@ use crate::{
credentials::get_data_source_credentials,
import_dataset_columns::{retrieve_dataset_columns_batch, DatasetColumnRecord},
},
clients::ai::{
openai::{OpenAiChatModel, OpenAiChatRole, OpenAiChatContent, OpenAiChatMessage},
llm_router::{llm_chat, LlmModel, LlmMessage},
},
},
};
@@ -156,6 +162,125 @@ pub async fn generate_datasets(
}
}
async fn enhance_yaml_with_descriptions(yaml: String) -> Result<String> {
const DESCRIPTION_PLACEHOLDER: &str = "{NEED DESCRIPTION HERE}";
// Skip OpenAI call if no placeholders exist
if !yaml.contains(DESCRIPTION_PLACEHOLDER) {
return Ok(yaml);
}
let messages = vec![
LlmMessage::new(
"developer".to_string(),
"You are a YAML description enhancer. Your output must be wrapped in markdown code blocks using ```yml format.
Your task is to ONLY replace text matching exactly \"{NEED DESCRIPTION HERE}\" with appropriate descriptions.
DO NOT modify any other part of the YAML.
DO NOT add any explanations or text outside the ```yml block.
Return the complete YAML wrapped in markdown, with only the placeholders replaced.".to_string(),
),
LlmMessage::new(
"user".to_string(),
yaml,
),
];
let response = llm_chat(
    LlmModel::OpenAi(OpenAiChatModel::O3Mini),
    &messages,
    0.1,  // presumably temperature
    2048, // presumably max output tokens
    30,   // presumably timeout in seconds
    None,
    false,
    None,
    &Uuid::new_v4(),
    &Uuid::new_v4(),
    crate::utils::clients::ai::langfuse::PromptName::CustomPrompt("enhance_yaml_descriptions".to_string()),
)
.await?;
// Extract YAML from markdown code blocks
let re = Regex::new(r"```yml\n([\s\S]*?)\n```").unwrap();
let yaml = match re.captures(&response) {
Some(caps) => caps.get(1).unwrap().as_str().to_string(),
None => return Err(anyhow!("Failed to extract YAML from response")),
};
Ok(yaml)
}
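The extraction is deliberately strict: the lazy `([\s\S]*?)` group only captures between a literal "```yml" fence pair, so a reply fenced as "```yaml" (or not fenced at all) falls through to the error branch. A minimal, self-contained check of the same pattern, using a made-up response string:

```rust
use regex::Regex;

fn main() {
    // Same pattern as above: lazily capture everything between ```yml fences.
    let re = Regex::new(r"```yml\n([\s\S]*?)\n```").unwrap();

    // Hypothetical model reply with chatter around the fenced block.
    let response = "Sure:\n```yml\nmodels:\n  - name: orders\n```\nDone.";
    let yaml = re
        .captures(response)
        .and_then(|caps| caps.get(1))
        .map(|m| m.as_str().to_string());

    assert_eq!(yaml.as_deref(), Some("models:\n  - name: orders"));
}
```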
async fn generate_model_yaml(
model_name: &str,
ds_columns: &[DatasetColumnRecord],
schema: &str,
) -> Result<String> {
// Filter columns for this model
let model_columns: Vec<_> = ds_columns
.iter()
.filter(|col| {
col.dataset_name.to_lowercase() == model_name.to_lowercase()
&& col.schema_name.to_lowercase() == schema.to_lowercase()
})
.collect();
if model_columns.is_empty() {
return Err(anyhow!("No columns found for model"));
}
let mut dimensions = Vec::new();
let mut measures = Vec::new();
// Process each column and categorize as dimension or measure
for col in model_columns {
match map_snowflake_type(&col.type_) {
ColumnMappingType::Dimension(semantic_type) => {
dimensions.push(Dimension {
name: col.name.clone(),
expr: col.name.clone(),
type_: semantic_type,
description: "{NEED DESCRIPTION HERE}".to_string(),
searchable: Some(false),
});
}
ColumnMappingType::Measure(measure_type) => {
measures.push(Measure {
name: col.name.clone(),
expr: col.name.clone(),
type_: measure_type,
agg: Some("sum".to_string()),
description: "{NEED DESCRIPTION HERE}".to_string(),
});
}
ColumnMappingType::Unsupported => {
tracing::warn!(
"Skipping unsupported column type: {} for column: {}",
col.type_,
col.name
);
}
}
}
let model = Model {
name: model_name.to_string(),
description: format!("Generated model for {}", model_name),
dimensions,
measures,
};
let config = ModelConfig {
models: vec![model],
};
let yaml = serde_yaml::to_string(&config)?;
// Enhance descriptions using OpenAI
let enhanced_yaml = enhance_yaml_with_descriptions(yaml).await?;
Ok(enhanced_yaml)
}
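Until enhance_yaml_with_descriptions runs, every description in the serialized output still holds the literal placeholder. A sketch of the pre-enhancement YAML shape, using hypothetical, trimmed stand-ins for the real Model and Dimension types (the actual field sets may differ):

```rust
use serde::Serialize;

// Trimmed stand-ins for illustration only; not the real structs.
#[derive(Serialize)]
struct Dimension {
    name: String,
    expr: String,
    #[serde(rename = "type")]
    type_: String,
    description: String,
}

#[derive(Serialize)]
struct Model {
    name: String,
    dimensions: Vec<Dimension>,
}

fn main() {
    let model = Model {
        name: "orders".into(),
        dimensions: vec![Dimension {
            name: "status".into(),
            expr: "status".into(),
            type_: "string".into(),
            description: "{NEED DESCRIPTION HERE}".into(),
        }],
    };
    // Prints YAML with the placeholder still in place, ready for the LLM pass.
    print!("{}", serde_yaml::to_string(&model).unwrap());
}
```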
async fn generate_datasets_handler(
request: &GenerateDatasetRequest,
organization_id: &Uuid,
@@ -190,32 +315,34 @@ async fn generate_datasets_handler(
Err(e) => return Err(anyhow!("Failed to get columns from data source: {}", e)),
};
// Check for existing datasets (just for logging/info purposes)
let existing_datasets: HashMap<String, Dataset> = datasets::table
.filter(datasets::data_source_id.eq(&data_source.id))
.filter(datasets::deleted_at.is_null())
.load::<Dataset>(&mut conn)
.await?
.into_iter()
.map(|d| (d.name.clone(), d))
.collect();
+    // Process models concurrently
+    let mut join_set = JoinSet::new();
+    for model_name in &request.model_names {
+        let model_name = model_name.clone();
+        let schema = request.schema.clone();
+        let ds_columns = ds_columns.clone();
+        join_set.spawn(async move {
+            let result = generate_model_yaml(&model_name, &ds_columns, &schema).await;
+            (model_name, result)
+        });
+    }
    let mut yml_contents = HashMap::new();
    let mut errors = HashMap::new();
-    // Process each model
-    for model_name in &request.model_names {
-        // Log if dataset already exists
-        if existing_datasets.contains_key(model_name) {
-            tracing::info!("Dataset {} already exists", model_name);
-        }
-        match generate_model_yaml(model_name, &ds_columns, &request.schema).await {
-            Ok(yaml) => {
-                yml_contents.insert(model_name.clone(), yaml);
-            }
-            Err(e) => {
-                errors.insert(model_name.clone(), e.to_string());
-            }
-        }
-    }
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok((model_name, Ok(yaml))) => {
+                yml_contents.insert(model_name, yaml);
+            }
+            Ok((model_name, Err(e))) => {
+                errors.insert(model_name, e.to_string());
+            }
+            Err(e) => {
+                tracing::error!("Task join error: {:?}", e);
+                return Err(anyhow!("Task execution failed"));
+            }
+        }
+    }
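Returning the (model_name, result) tuple from each task is what lets this loop route successes and failures into separate maps without shared state. A self-contained sketch of the same JoinSet pattern, with toy task bodies standing in for generate_model_yaml:

```rust
use std::collections::HashMap;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut join_set = JoinSet::new();
    for name in ["orders", "users", "bad_model"] {
        join_set.spawn(async move {
            // Each task carries its key home alongside a fallible result.
            let result: Result<String, String> = if name == "bad_model" {
                Err("no columns found".into())
            } else {
                Ok(format!("yaml for {name}"))
            };
            (name.to_string(), result)
        });
    }

    let mut ok = HashMap::new();
    let mut errs = HashMap::new();
    while let Some(joined) = join_set.join_next().await {
        match joined {
            Ok((name, Ok(yaml))) => { ok.insert(name, yaml); }
            Ok((name, Err(e))) => { errs.insert(name, e); }
            // The outer Err means the task panicked or was cancelled.
            Err(join_err) => eprintln!("join error: {join_err}"),
        }
    }
    assert_eq!((ok.len(), errs.len()), (2, 1));
}
```

The two error layers are distinct: the inner Result is a per-model failure worth reporting back, while the outer JoinError means the task itself died, which the handler above treats as fatal.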
@@ -224,72 +351,4 @@ async fn generate_datasets_handler(
yml_contents,
errors,
})
}
-async fn generate_model_yaml(
-    model_name: &str,
-    ds_columns: &[DatasetColumnRecord],
-    schema: &str,
-) -> Result<String> {
-    // Filter columns for this model
-    let model_columns: Vec<_> = ds_columns
-        .iter()
-        .filter(|col| {
-            col.dataset_name.to_lowercase() == model_name.to_lowercase()
-                && col.schema_name.to_lowercase() == schema.to_lowercase()
-        })
-        .collect();
-    if model_columns.is_empty() {
-        return Err(anyhow!("No columns found for model"));
-    }
-    let mut dimensions = Vec::new();
-    let mut measures = Vec::new();
-    // Process each column and categorize as dimension or measure
-    for col in model_columns {
-        match map_snowflake_type(&col.type_) {
-            ColumnMappingType::Dimension(semantic_type) => {
-                dimensions.push(Dimension {
-                    name: col.name.clone(),
-                    expr: col.name.clone(),
-                    type_: semantic_type,
-                    description: format!("Column {} from {}", col.name, model_name),
-                    searchable: Some(false),
-                });
-            }
-            ColumnMappingType::Measure(measure_type) => {
-                measures.push(Measure {
-                    name: col.name.clone(),
-                    expr: col.name.clone(),
-                    type_: measure_type,
-                    agg: Some("sum".to_string()), // Default aggregation
-                    description: format!("Column {} from {}", col.name, model_name),
-                });
-            }
-            ColumnMappingType::Unsupported => {
-                tracing::warn!(
-                    "Skipping unsupported column type: {} for column: {}",
-                    col.type_,
-                    col.name
-                );
-            }
-        }
-    }
-    let model = Model {
-        name: model_name.to_string(),
-        description: format!("Generated model for {}", model_name),
-        dimensions,
-        measures,
-    };
-    let config = ModelConfig {
-        models: vec![model],
-    };
-    let yaml = serde_yaml::to_string(&config)?;
-    Ok(yaml)
-}

View File

@@ -20,7 +20,7 @@ use gcp_bigquery_client::model::query_request::QueryRequest;
use sqlx::{FromRow, Row};
use uuid::Uuid;
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct DatasetColumnRecord {
pub dataset_name: String,
pub schema_name: String,
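The added Clone derive is what makes the per-task ds_columns.clone() in the handler possible: each spawned task must own a 'static copy of the column records. A toy illustration with a simplified record type:

```rust
// Simplified stand-in for DatasetColumnRecord, for illustration only.
#[derive(Debug, Clone)]
struct Record {
    name: String,
}

#[tokio::main]
async fn main() {
    let records = vec![Record { name: "id".into() }];
    let mut handles = Vec::new();
    for model in ["orders", "users"] {
        // Without Clone, the Vec could move into only one task.
        let records = records.clone();
        handles.push(tokio::spawn(async move {
            format!("{model}: {} columns", records.len())
        }));
    }
    for h in handles {
        println!("{}", h.await.unwrap());
    }
}
```

Wrapping the shared Vec in an Arc would avoid the deep copies; cloning per task trades a little memory for simplicity.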