final clean up

This commit is contained in:
dal 2025-03-21 13:23:11 -06:00
parent 8bfd0f04af
commit dc483020be
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
10 changed files with 10 additions and 219 deletions

View File

@ -191,7 +191,7 @@ pub async fn llm_chat_stream(
async fn anthropic_chat_compiler(
model: &AnthropicChatModel,
messages: &Vec<LlmMessage>,
max_tokens: u32,
_max_tokens: u32,
temperature: f32,
timeout: u64,
stop: Option<Vec<String>>,

View File

@ -62,22 +62,6 @@ fn is_o3_model(model: &OpenAiChatModel) -> bool {
matches!(model, OpenAiChatModel::O3Mini)
}
/// Serde-style skip predicate: returns true when the `temperature` value
/// should be omitted for the given model (O3-family models reject it).
fn should_skip_temperature(val: &(&f32, &OpenAiChatModel)) -> bool {
    let &(_temperature, model) = val;
    is_o3_model(model)
}
/// Serde-style skip predicate: returns true when `max_tokens` should be
/// omitted for the given model (O3-family models reject it).
fn should_skip_max_tokens(val: &(&u32, &OpenAiChatModel)) -> bool {
    let &(_max_tokens, model) = val;
    is_o3_model(model)
}
/// Serde-style skip predicate: returns true when `top_p` should be omitted
/// for the given model (O3-family models reject it).
fn should_skip_top_p(val: &(&f32, &OpenAiChatModel)) -> bool {
    let &(_top_p, model) = val;
    is_o3_model(model)
}
/// Serde-style skip predicate: `reasoning_effort` is only meaningful for
/// O3-family models, so it is skipped for everything else (note the
/// inverted condition relative to the other skip predicates).
fn should_skip_reasoning_effort(val: &(&Option<ReasoningEffort>, &OpenAiChatModel)) -> bool {
    let &(_effort, model) = val;
    !is_o3_model(model)
}
#[derive(Serialize, Clone)]
pub struct OpenAiChatRequest {
model: OpenAiChatModel,

View File

@ -1,7 +1,7 @@
pub fn custom_response_system_prompt(
datasets: &String,
input: &String,
orchestrator_output_string: &String,
_input: &String,
_orchestrator_output_string: &String,
) -> String {
format!(
r#"##OVERVIEW

View File

@ -360,7 +360,7 @@ pub async fn snowflake_query(
}
}
}
arrow::datatypes::DataType::Decimal128(precision, scale) => {
arrow::datatypes::DataType::Decimal128(_precision, scale) => {
let array = column
.as_any()
.downcast_ref::<Decimal128Array>()
@ -374,7 +374,7 @@ pub async fn snowflake_query(
DataType::Float8(Some(float_val))
}
}
arrow::datatypes::DataType::Decimal256(precision, scale) => {
arrow::datatypes::DataType::Decimal256(_precision, scale) => {
let array = column
.as_any()
.downcast_ref::<Decimal256Array>()

View File

@ -235,7 +235,7 @@ pub async fn retrieve_dataset_columns_batch(
async fn get_snowflake_columns_batch(
datasets: &[(String, String)],
credentials: &SnowflakeCredentials,
database: Option<String>,
_database: Option<String>,
) -> Result<Vec<DatasetColumnRecord>> {
let snowflake_client = get_snowflake_client(credentials).await?;
@ -720,128 +720,3 @@ async fn get_bigquery_columns_batch(
Ok(columns)
}
/// Fetch column metadata for a single Snowflake table or view from
/// INFORMATION_SCHEMA, returning one `DatasetColumnRecord` per column,
/// ordered by ordinal position.
///
/// Returns an empty Vec for an empty result; errors if the client cannot be
/// built, the query fails, an expected result column is absent, or the
/// result arrives in a non-Arrow format.
async fn get_snowflake_columns(
dataset_name: &String,
schema_name: &String,
credentials: &SnowflakeCredentials,
) -> Result<Vec<DatasetColumnRecord>> {
let snowflake_client = get_snowflake_client(credentials).await?;
// Names are upper-cased before comparison — presumably because Snowflake
// stores unquoted identifiers in upper case; confirm for quoted identifiers.
let uppercase_dataset_name = dataset_name.to_uppercase();
let uppercase_schema_name = schema_name.to_uppercase();
// NOTE(review): names are interpolated directly into the SQL string — this
// assumes dataset/schema names are trusted, not user-controlled; verify callers.
let sql = format!(
"SELECT
c.COLUMN_NAME AS name,
c.DATA_TYPE AS type_,
CASE WHEN c.IS_NULLABLE = 'YES' THEN true ELSE false END AS nullable,
c.COMMENT AS comment,
t.TABLE_TYPE as source_type
FROM
INFORMATION_SCHEMA.COLUMNS c
JOIN
INFORMATION_SCHEMA.TABLES t
ON c.TABLE_NAME = t.TABLE_NAME
AND c.TABLE_SCHEMA = t.TABLE_SCHEMA
WHERE
c.TABLE_NAME = '{uppercase_dataset_name}'
AND c.TABLE_SCHEMA = '{uppercase_schema_name}'
ORDER BY c.ORDINAL_POSITION;",
);
// Execute the query using the Snowflake client
let results = snowflake_client
.exec(&sql)
.await
.map_err(|e| anyhow!("Error executing query: {:?}", e))?;
let mut columns = Vec::new();
if let snowflake_api::QueryResult::Arrow(record_batches) = results {
for batch in &record_batches {
// Resolve each result column by name once per batch; a missing column
// is a hard error (schema drift in the query above).
let schema = batch.schema();
let name_index = schema
.index_of("NAME")
.map_err(|e| anyhow!("Error getting index for NAME: {:?}", e))?;
let type_index = schema
.index_of("TYPE_")
.map_err(|e| anyhow!("Error getting index for TYPE_: {:?}", e))?;
let nullable_index = schema
.index_of("NULLABLE")
.map_err(|e| anyhow!("Error getting index for NULLABLE: {:?}", e))?;
let comment_index = schema
.index_of("COMMENT")
.map_err(|e| anyhow!("Error getting index for COMMENT: {:?}", e))?;
let source_type_index = schema
.index_of("SOURCE_TYPE")
.map_err(|e| anyhow!("Error getting index for SOURCE_TYPE: {:?}", e))?;
let name_column = batch.column(name_index);
let type_column = batch.column(type_index);
let nullable_column = batch.column(nullable_index);
let comment_column = batch.column(comment_index);
let source_type_column = batch.column(source_type_index);
// Downcast each Arrow column to its concrete array type; a mismatch is
// a hard error rather than a silent skip.
let name_array = name_column
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.ok_or_else(|| anyhow!("Expected StringArray for NAME"))?;
let type_array = type_column
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.ok_or_else(|| anyhow!("Expected StringArray for TYPE_"))?;
let nullable_array = nullable_column
.as_any()
.downcast_ref::<arrow::array::BooleanArray>()
.ok_or_else(|| anyhow!("Expected BooleanArray for NULLABLE"))?;
let comment_array = comment_column
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.ok_or_else(|| anyhow!("Expected StringArray for COMMENT"))?;
let source_type_array = source_type_column
.as_any()
.downcast_ref::<arrow::array::StringArray>()
.ok_or_else(|| anyhow!("Expected StringArray for SOURCE_TYPE"))?;
// Materialize one record per row, preserving NULL comments as None.
for i in 0..batch.num_rows() {
let name = name_array.value(i).to_string();
let type_ = type_array.value(i).to_string();
let nullable = nullable_array.value(i);
let comment = if comment_array.is_null(i) {
None
} else {
Some(comment_array.value(i).to_string())
};
// A NULL TABLE_TYPE defaults to "TABLE".
let source_type = if source_type_array.is_null(i) {
"TABLE".to_string()
} else {
source_type_array.value(i).to_string()
};
columns.push(DatasetColumnRecord {
dataset_name: dataset_name.clone(),
schema_name: schema_name.clone(),
name,
type_,
nullable,
comment,
source_type,
});
}
}
} else if let snowflake_api::QueryResult::Empty = results {
// No matching table/schema: not an error, just no columns.
return Ok(Vec::new());
} else {
return Err(anyhow!(
"Unexpected query result format from Snowflake. Expected Arrow format."
));
}
Ok(columns)
}

View File

@ -36,7 +36,7 @@ pub async fn query_engine(
.await
{
Ok(data_source) => data_source,
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
};
let results = match query_router(&data_source, sql, None, false).await {
@ -91,7 +91,7 @@ pub async fn modeling_query_engine(
.await
{
Ok(data_source) => data_source,
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
};
let results = match query_router(&data_source, sql, Some(25), false).await {

View File

@ -1,6 +1,4 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use database::enums::DataSourceType;
@ -16,9 +14,7 @@ pub enum TargetDialect {
Snowflake,
#[serde(rename = "tsql")]
SqlServer,
#[serde(rename = "mysql")]
MariaDb,
#[serde(rename = "postgres")]
Supabase,
}
@ -37,65 +33,3 @@ impl From<DataSourceType> for TargetDialect {
}
}
}
/// Transpile `sql` into `target_dialect` by shelling out to the bundled
/// `./python/sqlglot_transpiler` binary.
///
/// Best-effort by design: any failure (spawn error, non-zero exit status,
/// non-UTF-8 output) is logged via `tracing` and the ORIGINAL sql is
/// returned unchanged, so callers never lose their query. The `Result`
/// return therefore only surfaces errors from future extensions; today it
/// is always `Ok`.
pub async fn transpile_sql(sql: &String, target_dialect: TargetDialect) -> Result<String> {
    // Serializing a plain unit-variant enum cannot fail, so an expect with a
    // stated invariant is safer documentation than a bare unwrap.
    let serialized_dialect = serde_json::to_string(&target_dialect)
        .expect("serializing a TargetDialect unit variant is infallible");
    // Strip only the surrounding JSON quotes to recover the serde rename
    // (e.g. "postgres" -> postgres). `trim_matches` avoids the extra
    // allocation of `replace` and cannot touch interior characters — a
    // serialized unit variant contains no interior quotes, so the result is
    // identical.
    let dialect_arg = serialized_dialect.trim_matches('"');
    let transpiled_sql = match Command::new("./python/sqlglot_transpiler")
        .arg(sql)
        .arg(dialect_arg)
        .output()
        .await
    {
        Ok(output) => {
            if !output.status.success() {
                tracing::error!("Command failed with exit code: {}", output.status);
                return Ok(sql.to_string());
            }
            match String::from_utf8(output.stdout) {
                Ok(stdout) => stdout,
                Err(e) => {
                    tracing::error!("Error: {}", e);
                    return Ok(sql.to_string());
                }
            }
        }
        Err(e) => {
            tracing::error!("Error: {}", e);
            sql.to_string()
        }
    };
    Ok(transpiled_sql)
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): this is a smoke test only — it prints the transpiled SQL and
// asserts nothing beyond "did not panic". Because transpile_sql falls back to
// returning the input on any subprocess failure, this passes even when the
// ./python/sqlglot_transpiler binary is missing; confirm whether a real
// assertion (and a CI-provisioned transpiler) is wanted.
#[tokio::test]
async fn test_transpiler() {
// Representative multi-clause query exercising CTE, aggregate, ORDER BY
// and LIMIT handling in the transpiler.
let sql = "WITH customer_sales AS (
SELECT DISTINCT
customer_id,
customer_name,
SUM(total_sales_amount) AS total_sales
FROM sales_summary
GROUP BY customer_id, customer_name
)
SELECT
customer_name,
total_sales
FROM customer_sales
ORDER BY total_sales DESC
LIMIT 1;";
let target_dialect = TargetDialect::Postgres;
let transpiled_sql = transpile_sql(&sql.to_string(), target_dialect)
.await
.unwrap();
println!("transpiled_sql: {:?}", transpiled_sql);
}
}

View File

@ -36,7 +36,7 @@ pub async fn write_query_engine(
.await
{
Ok(data_source) => data_source,
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
};
let results = match query_router(&data_source, sql, None, true).await {

View File

@ -3,7 +3,6 @@ use tokio_stream::StreamExt;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use diesel::QueryDsl;
use serde::Serialize;
use uuid::Uuid;

View File

@ -32,7 +32,6 @@ pub struct StoredValueWithDistance {
const BATCH_SIZE: usize = 10_000;
const MAX_VALUE_LENGTH: usize = 50;
const TIMEOUT_SECONDS: u64 = 60;
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
let pool = get_pg_pool();
@ -77,7 +76,7 @@ pub async fn store_column_values(
dataset_id: &Uuid,
column_name: &str,
column_id: &Uuid,
data_source_id: &Uuid,
_data_source_id: &Uuid,
schema: &str,
table_name: &str,
) -> Result<()> {