diff --git a/api/src/utils/query_engine/data_source_query_routes/snowflake_query.rs b/api/src/utils/query_engine/data_source_query_routes/snowflake_query.rs
index 51ce89925..fc31cd58f 100644
--- a/api/src/utils/query_engine/data_source_query_routes/snowflake_query.rs
+++ b/api/src/utils/query_engine/data_source_query_routes/snowflake_query.rs
@@ -11,13 +11,39 @@ use arrow::array::{
 use arrow::datatypes::TimeUnit;
 
 use anyhow::{anyhow, Error};
-use chrono::{LocalResult, TimeZone, Utc, NaiveTime};
+use chrono::{DateTime, LocalResult, NaiveTime, TimeZone, Utc};
 use snowflake_api::SnowflakeApi;
 use serde_json::Value;
 
 use crate::utils::query_engine::data_types::DataType;
 
+// Add helper functions at the top level
+fn process_string_value(value: String) -> String {
+    value.to_lowercase()
+}
+
+fn process_json_value(value: Value) -> Value {
+    match value {
+        Value::String(s) => Value::String(s.to_lowercase()),
+        Value::Array(arr) => Value::Array(arr.into_iter().map(process_json_value).collect()),
+        Value::Object(map) => {
+            let new_map = map.into_iter()
+                .map(|(k, v)| (k.to_lowercase(), process_json_value(v)))
+                .collect();
+            Value::Object(new_map)
+        }
+        _ => value,
+    }
+}
+
+fn parse_snowflake_timestamp(epoch_data: i64, subsec_nanos: u32) -> Result<DateTime<Utc>, Error> {
+    match Utc.timestamp_opt(epoch_data, subsec_nanos) {
+        LocalResult::Single(dt) => Ok(dt),
+        _ => Err(anyhow!("Invalid timestamp value"))
+    }
+}
+
 pub async fn snowflake_query(
     mut snowflake_client: SnowflakeApi,
     query: String,
@@ -100,12 +126,12 @@ pub async fn snowflake_query(
                 arrow::datatypes::DataType::Utf8 => {
                     let array = column.as_any().downcast_ref::<StringArray>().unwrap();
                     if array.is_null(row_idx) { DataType::Null }
-                    else { DataType::Text(Some(array.value(row_idx).to_string())) }
+                    else { DataType::Text(Some(process_string_value(array.value(row_idx).to_string()))) }
                 }
                 arrow::datatypes::DataType::LargeUtf8 => {
                     let array = column.as_any().downcast_ref::<LargeStringArray>().unwrap();
                     if array.is_null(row_idx) { DataType::Null }
-                    else { DataType::Text(Some(array.value(row_idx).to_string())) }
+                    else { DataType::Text(Some(process_string_value(array.value(row_idx).to_string()))) }
                 }
                 arrow::datatypes::DataType::Binary => {
                     let array = column.as_any().downcast_ref::<BinaryArray>().unwrap();
@@ -144,8 +170,9 @@ pub async fn snowflake_query(
                 }
                 arrow::datatypes::DataType::Timestamp(unit, tz) => {
                     let array = column.as_any().downcast_ref::().unwrap();
-                    if array.is_null(row_idx) { DataType::Null }
-                    else {
+                    if array.is_null(row_idx) {
+                        DataType::Null
+                    } else {
                         let nanos = array.value(row_idx);
                         let (secs, subsec_nanos) = match unit {
                             TimeUnit::Second => (nanos, 0),
@@ -153,12 +180,16 @@ pub async fn snowflake_query(
                             TimeUnit::Microsecond => (nanos / 1_000_000, (nanos % 1_000_000) * 1000),
                             TimeUnit::Nanosecond => (nanos / 1_000_000_000, nanos % 1_000_000_000),
                         };
-                        match Utc.timestamp_opt(secs as i64, subsec_nanos as u32) {
-                            LocalResult::Single(dt) => match tz {
+
+                        match parse_snowflake_timestamp(secs as i64, subsec_nanos as u32) {
+                            Ok(dt) => match tz {
                                 Some(_) => DataType::Timestamptz(Some(dt)),
                                 None => DataType::Timestamp(Some(dt.naive_utc())),
                             },
-                            _ => DataType::Null,
+                            Err(e) => {
+                                tracing::error!("Failed to parse timestamp: {}", e);
+                                DataType::Null
+                            }
                         }
                     }
                 }
@@ -281,14 +312,14 @@ pub async fn snowflake_query(
                                 } else if let Some(num) = values.as_any().downcast_ref::() {
                                     Some(Value::Number(num.value(i).into()))
                                 } else if let Some(str) = values.as_any().downcast_ref::() {
-                                    Some(Value::String(str.value(i).to_string()))
+                                    Some(Value::String(process_string_value(str.value(i).to_string())))
                                 } else {
                                     None
                                 }
                             })
                             .collect()
                     );
-                    DataType::Json(Some(json_array))
+                    DataType::Json(Some(process_json_value(json_array)))
                 }
             }
             arrow::datatypes::DataType::Struct(fields) => {
@@ -309,7 +340,7 @@ pub async fn snowflake_query(
                         };
                         map.insert(field_name.to_string(), value);
                     }
-                    DataType::Json(Some(Value::Object(map)))
+                    DataType::Json(Some(process_json_value(Value::Object(map))))
                 }
             }
             arrow::datatypes::DataType::Union(_, _) => {
@@ -347,7 +378,7 @@ pub async fn snowflake_query(
                         json_map.insert(key.to_string(), Value::Number(value.into()));
                     }
                 }
-                DataType::Json(Some(Value::Object(json_map)))
+                DataType::Json(Some(process_json_value(Value::Object(json_map))))
             }
         }
         arrow::datatypes::DataType::RunEndEncoded(_, _) => {
diff --git a/api/src/utils/stored_values/mod.rs b/api/src/utils/stored_values/mod.rs
index 19c776ac5..da2044e76 100644
--- a/api/src/utils/stored_values/mod.rs
+++ b/api/src/utils/stored_values/mod.rs
@@ -7,7 +7,8 @@ use chrono::Utc;
 use diesel::prelude::*;
 use diesel_async::RunQueryDsl;
 use uuid::Uuid;
-use crate::database::lib::get_pg_pool;
+use crate::database::enums::StoredValuesStatus;
+use crate::database::{lib::get_pg_pool, schema::dataset_columns};
 use crate::utils::clients::ai::embedding_router::embedding_router;
 use diesel::sql_types::{Text, Uuid as SqlUuid, Array, Float4, Timestamptz, Integer};
 
@@ -88,6 +89,9 @@
 
     // Query distinct values in batches
     let mut offset = 0;
+    let mut first_batch = true;
+    let schema_name = organization_id.to_string().replace("-", "_");
+
     loop {
         let query = format!(
             "SELECT DISTINCT \"{}\" as value
@@ -104,6 +108,9 @@
             Ok(results) => results,
             Err(e) => {
                 tracing::error!("Error querying stored values: {:?}", e);
+                if first_batch {
+                    return Err(e);
+                }
                 vec![]
             }
         };
@@ -128,10 +135,41 @@
             break;
         }
 
+        // If this is the first batch and we have 15 or fewer values, handle as enum
+        if first_batch && values.len() <= 15 {
+            // Get current description
+            let current_description = diesel::sql_query("SELECT description FROM dataset_columns WHERE id = $1")
+                .bind::<SqlUuid, _>(column_id)
+                .get_result::(&mut conn)
+                .await
+                .ok()
+                .and_then(|row| Some(row.value));
+
+            // Format new description
+            let enum_list = format!("Values for this column are: {}", values.join(", "));
+            let new_description = match current_description {
+                Some(desc) if !desc.is_empty() => format!("{}. {}", desc, enum_list),
+                _ => enum_list,
+            };
+
+            // Update column description
+            diesel::update(dataset_columns::table)
+                .filter(dataset_columns::id.eq(column_id))
+                .set((
+                    dataset_columns::description.eq(new_description),
+                    dataset_columns::stored_values_status.eq(StoredValuesStatus::Success),
+                    dataset_columns::stored_values_count.eq(values.len() as i64),
+                    dataset_columns::stored_values_last_synced.eq(Utc::now()),
+                ))
+                .execute(&mut conn)
+                .await?;
+
+            return Ok(());
+        }
+
         // Create embeddings for the batch
         let embeddings = create_embeddings_batch(&values).await?;
 
-        let schema_name = organization_id.to_string().replace("-", "_");
         // Insert values and embeddings
         for (value, embedding) in values.iter().zip(embeddings.iter()) {
             let insert_sql = format!(
@@ -154,6 +192,7 @@
                 .await?;
         }
 
+        first_batch = false;
         offset += BATCH_SIZE;
     }
 
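
As a quick sanity check of the lowercasing behavior the patch introduces, here is a minimal standalone sketch: process_json_value is copied verbatim from the snowflake_query.rs hunk above, while the main function and its sample values are illustrative and not part of the patch.

use serde_json::{json, Value};

// Copied from the snowflake_query.rs hunk above: lowercases every string
// value and every object key, recursing through arrays and nested objects.
fn process_json_value(value: Value) -> Value {
    match value {
        Value::String(s) => Value::String(s.to_lowercase()),
        Value::Array(arr) => Value::Array(arr.into_iter().map(process_json_value).collect()),
        Value::Object(map) => {
            let new_map = map.into_iter()
                .map(|(k, v)| (k.to_lowercase(), process_json_value(v)))
                .collect();
            Value::Object(new_map)
        }
        _ => value,
    }
}

fn main() {
    // Illustrative input, not taken from the patch.
    let input = json!({"Region": "EMEA", "Tags": ["Retail", "B2B"], "Count": 3});
    let output = process_json_value(input);
    // Object keys and string values are lowercased; non-strings pass through unchanged.
    assert_eq!(output, json!({"region": "emea", "tags": ["retail", "b2b"], "count": 3}));
}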