Merge branch 'staging' into dal/cli-release

This commit is contained in:
dal 2025-02-04 16:14:47 -07:00
commit 05e75a5652
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
2 changed files with 84 additions and 14 deletions

View File

@ -11,13 +11,39 @@ use arrow::array::{
use arrow::datatypes::TimeUnit;
use anyhow::{anyhow, Error};
use chrono::{LocalResult, TimeZone, Utc, NaiveTime};
use chrono::{DateTime, LocalResult, NaiveTime, TimeZone, Utc};
use snowflake_api::SnowflakeApi;
use serde_json::Value;
use crate::utils::query_engine::data_types::DataType;
// Add helper functions at the top level
fn process_string_value(value: String) -> String {
value.to_lowercase()
}
fn process_json_value(value: Value) -> Value {
match value {
Value::String(s) => Value::String(s.to_lowercase()),
Value::Array(arr) => Value::Array(arr.into_iter().map(process_json_value).collect()),
Value::Object(map) => {
let new_map = map.into_iter()
.map(|(k, v)| (k.to_lowercase(), process_json_value(v)))
.collect();
Value::Object(new_map)
}
_ => value,
}
}
fn parse_snowflake_timestamp(epoch_data: i64, subsec_nanos: u32) -> Result<DateTime<Utc>, Error> {
match Utc.timestamp_opt(epoch_data, subsec_nanos) {
LocalResult::Single(dt) => Ok(dt),
_ => Err(anyhow!("Invalid timestamp value"))
}
}
pub async fn snowflake_query(
mut snowflake_client: SnowflakeApi,
query: String,
@ -100,12 +126,12 @@ pub async fn snowflake_query(
arrow::datatypes::DataType::Utf8 => {
let array = column.as_any().downcast_ref::<StringArray>().unwrap();
if array.is_null(row_idx) { DataType::Null }
else { DataType::Text(Some(array.value(row_idx).to_string())) }
else { DataType::Text(Some(process_string_value(array.value(row_idx).to_string()))) }
}
arrow::datatypes::DataType::LargeUtf8 => {
let array = column.as_any().downcast_ref::<LargeStringArray>().unwrap();
if array.is_null(row_idx) { DataType::Null }
else { DataType::Text(Some(array.value(row_idx).to_string())) }
else { DataType::Text(Some(process_string_value(array.value(row_idx).to_string()))) }
}
arrow::datatypes::DataType::Binary => {
let array = column.as_any().downcast_ref::<BinaryArray>().unwrap();
@ -144,8 +170,9 @@ pub async fn snowflake_query(
}
arrow::datatypes::DataType::Timestamp(unit, tz) => {
let array = column.as_any().downcast_ref::<TimestampNanosecondArray>().unwrap();
if array.is_null(row_idx) { DataType::Null }
else {
if array.is_null(row_idx) {
DataType::Null
} else {
let nanos = array.value(row_idx);
let (secs, subsec_nanos) = match unit {
TimeUnit::Second => (nanos, 0),
@ -153,12 +180,16 @@ pub async fn snowflake_query(
TimeUnit::Microsecond => (nanos / 1_000_000, (nanos % 1_000_000) * 1000),
TimeUnit::Nanosecond => (nanos / 1_000_000_000, nanos % 1_000_000_000),
};
match Utc.timestamp_opt(secs as i64, subsec_nanos as u32) {
LocalResult::Single(dt) => match tz {
match parse_snowflake_timestamp(secs as i64, subsec_nanos as u32) {
Ok(dt) => match tz {
Some(_) => DataType::Timestamptz(Some(dt)),
None => DataType::Timestamp(Some(dt.naive_utc())),
},
_ => DataType::Null,
Err(e) => {
tracing::error!("Failed to parse timestamp: {}", e);
DataType::Null
}
}
}
}
@ -281,14 +312,14 @@ pub async fn snowflake_query(
} else if let Some(num) = values.as_any().downcast_ref::<Int64Array>() {
Some(Value::Number(num.value(i).into()))
} else if let Some(str) = values.as_any().downcast_ref::<StringArray>() {
Some(Value::String(str.value(i).to_string()))
Some(Value::String(process_string_value(str.value(i).to_string())))
} else {
None
}
})
.collect()
);
DataType::Json(Some(json_array))
DataType::Json(Some(process_json_value(json_array)))
}
}
arrow::datatypes::DataType::Struct(fields) => {
@ -309,7 +340,7 @@ pub async fn snowflake_query(
};
map.insert(field_name.to_string(), value);
}
DataType::Json(Some(Value::Object(map)))
DataType::Json(Some(process_json_value(Value::Object(map))))
}
}
arrow::datatypes::DataType::Union(_, _) => {
@ -347,7 +378,7 @@ pub async fn snowflake_query(
json_map.insert(key.to_string(), Value::Number(value.into()));
}
}
DataType::Json(Some(Value::Object(json_map)))
DataType::Json(Some(process_json_value(Value::Object(json_map))))
}
}
arrow::datatypes::DataType::RunEndEncoded(_, _) => {

View File

@ -7,7 +7,8 @@ use chrono::Utc;
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use uuid::Uuid;
use crate::database::lib::get_pg_pool;
use crate::database::enums::StoredValuesStatus;
use crate::database::{lib::get_pg_pool, schema::dataset_columns};
use crate::utils::clients::ai::embedding_router::embedding_router;
use diesel::sql_types::{Text, Uuid as SqlUuid, Array, Float4, Timestamptz, Integer};
@ -88,6 +89,9 @@ pub async fn store_column_values(
// Query distinct values in batches
let mut offset = 0;
let mut first_batch = true;
let schema_name = organization_id.to_string().replace("-", "_");
loop {
let query = format!(
"SELECT DISTINCT \"{}\" as value
@ -104,6 +108,9 @@ pub async fn store_column_values(
Ok(results) => results,
Err(e) => {
tracing::error!("Error querying stored values: {:?}", e);
if first_batch {
return Err(e);
}
vec![]
}
};
@ -128,10 +135,41 @@ pub async fn store_column_values(
break;
}
// If this is the first batch and we have 15 or fewer values, handle as enum
if first_batch && values.len() <= 15 {
// Get current description
let current_description = diesel::sql_query("SELECT description FROM dataset_columns WHERE id = $1")
.bind::<SqlUuid, _>(column_id)
.get_result::<StoredValueRow>(&mut conn)
.await
.ok()
.and_then(|row| Some(row.value));
// Format new description
let enum_list = format!("Values for this column are: {}", values.join(", "));
let new_description = match current_description {
Some(desc) if !desc.is_empty() => format!("{}. {}", desc, enum_list),
_ => enum_list,
};
// Update column description
diesel::update(dataset_columns::table)
.filter(dataset_columns::id.eq(column_id))
.set((
dataset_columns::description.eq(new_description),
dataset_columns::stored_values_status.eq(StoredValuesStatus::Success),
dataset_columns::stored_values_count.eq(values.len() as i64),
dataset_columns::stored_values_last_synced.eq(Utc::now()),
))
.execute(&mut conn)
.await?;
return Ok(());
}
// Create embeddings for the batch
let embeddings = create_embeddings_batch(&values).await?;
let schema_name = organization_id.to_string().replace("-", "_");
// Insert values and embeddings
for (value, embedding) in values.iter().zip(embeddings.iter()) {
let insert_sql = format!(
@ -154,6 +192,7 @@ pub async fn store_column_values(
.await?;
}
first_batch = false;
offset += BATCH_SIZE;
}