mirror of https://github.com/buster-so/buster.git
final clean up
This commit is contained in:
parent
8bfd0f04af
commit
dc483020be
|
@ -191,7 +191,7 @@ pub async fn llm_chat_stream(
|
|||
async fn anthropic_chat_compiler(
|
||||
model: &AnthropicChatModel,
|
||||
messages: &Vec<LlmMessage>,
|
||||
max_tokens: u32,
|
||||
_max_tokens: u32,
|
||||
temperature: f32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
|
|
|
@ -62,22 +62,6 @@ fn is_o3_model(model: &OpenAiChatModel) -> bool {
|
|||
matches!(model, OpenAiChatModel::O3Mini)
|
||||
}
|
||||
|
||||
fn should_skip_temperature(val: &(&f32, &OpenAiChatModel)) -> bool {
|
||||
is_o3_model(val.1)
|
||||
}
|
||||
|
||||
fn should_skip_max_tokens(val: &(&u32, &OpenAiChatModel)) -> bool {
|
||||
is_o3_model(val.1)
|
||||
}
|
||||
|
||||
fn should_skip_top_p(val: &(&f32, &OpenAiChatModel)) -> bool {
|
||||
is_o3_model(val.1)
|
||||
}
|
||||
|
||||
fn should_skip_reasoning_effort(val: &(&Option<ReasoningEffort>, &OpenAiChatModel)) -> bool {
|
||||
!is_o3_model(val.1)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct OpenAiChatRequest {
|
||||
model: OpenAiChatModel,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
pub fn custom_response_system_prompt(
|
||||
datasets: &String,
|
||||
input: &String,
|
||||
orchestrator_output_string: &String,
|
||||
_input: &String,
|
||||
_orchestrator_output_string: &String,
|
||||
) -> String {
|
||||
format!(
|
||||
r#"##OVERVIEW
|
||||
|
|
|
@ -360,7 +360,7 @@ pub async fn snowflake_query(
|
|||
}
|
||||
}
|
||||
}
|
||||
arrow::datatypes::DataType::Decimal128(precision, scale) => {
|
||||
arrow::datatypes::DataType::Decimal128(_precision, scale) => {
|
||||
let array = column
|
||||
.as_any()
|
||||
.downcast_ref::<Decimal128Array>()
|
||||
|
@ -374,7 +374,7 @@ pub async fn snowflake_query(
|
|||
DataType::Float8(Some(float_val))
|
||||
}
|
||||
}
|
||||
arrow::datatypes::DataType::Decimal256(precision, scale) => {
|
||||
arrow::datatypes::DataType::Decimal256(_precision, scale) => {
|
||||
let array = column
|
||||
.as_any()
|
||||
.downcast_ref::<Decimal256Array>()
|
||||
|
|
|
@ -235,7 +235,7 @@ pub async fn retrieve_dataset_columns_batch(
|
|||
async fn get_snowflake_columns_batch(
|
||||
datasets: &[(String, String)],
|
||||
credentials: &SnowflakeCredentials,
|
||||
database: Option<String>,
|
||||
_database: Option<String>,
|
||||
) -> Result<Vec<DatasetColumnRecord>> {
|
||||
let snowflake_client = get_snowflake_client(credentials).await?;
|
||||
|
||||
|
@ -720,128 +720,3 @@ async fn get_bigquery_columns_batch(
|
|||
|
||||
Ok(columns)
|
||||
}
|
||||
|
||||
async fn get_snowflake_columns(
|
||||
dataset_name: &String,
|
||||
schema_name: &String,
|
||||
credentials: &SnowflakeCredentials,
|
||||
) -> Result<Vec<DatasetColumnRecord>> {
|
||||
let snowflake_client = get_snowflake_client(credentials).await?;
|
||||
|
||||
let uppercase_dataset_name = dataset_name.to_uppercase();
|
||||
let uppercase_schema_name = schema_name.to_uppercase();
|
||||
|
||||
let sql = format!(
|
||||
"SELECT
|
||||
c.COLUMN_NAME AS name,
|
||||
c.DATA_TYPE AS type_,
|
||||
CASE WHEN c.IS_NULLABLE = 'YES' THEN true ELSE false END AS nullable,
|
||||
c.COMMENT AS comment,
|
||||
t.TABLE_TYPE as source_type
|
||||
FROM
|
||||
INFORMATION_SCHEMA.COLUMNS c
|
||||
JOIN
|
||||
INFORMATION_SCHEMA.TABLES t
|
||||
ON c.TABLE_NAME = t.TABLE_NAME
|
||||
AND c.TABLE_SCHEMA = t.TABLE_SCHEMA
|
||||
WHERE
|
||||
c.TABLE_NAME = '{uppercase_dataset_name}'
|
||||
AND c.TABLE_SCHEMA = '{uppercase_schema_name}'
|
||||
ORDER BY c.ORDINAL_POSITION;",
|
||||
);
|
||||
|
||||
// Execute the query using the Snowflake client
|
||||
let results = snowflake_client
|
||||
.exec(&sql)
|
||||
.await
|
||||
.map_err(|e| anyhow!("Error executing query: {:?}", e))?;
|
||||
|
||||
let mut columns = Vec::new();
|
||||
|
||||
if let snowflake_api::QueryResult::Arrow(record_batches) = results {
|
||||
for batch in &record_batches {
|
||||
let schema = batch.schema();
|
||||
|
||||
let name_index = schema
|
||||
.index_of("NAME")
|
||||
.map_err(|e| anyhow!("Error getting index for NAME: {:?}", e))?;
|
||||
let type_index = schema
|
||||
.index_of("TYPE_")
|
||||
.map_err(|e| anyhow!("Error getting index for TYPE_: {:?}", e))?;
|
||||
let nullable_index = schema
|
||||
.index_of("NULLABLE")
|
||||
.map_err(|e| anyhow!("Error getting index for NULLABLE: {:?}", e))?;
|
||||
let comment_index = schema
|
||||
.index_of("COMMENT")
|
||||
.map_err(|e| anyhow!("Error getting index for COMMENT: {:?}", e))?;
|
||||
let source_type_index = schema
|
||||
.index_of("SOURCE_TYPE")
|
||||
.map_err(|e| anyhow!("Error getting index for SOURCE_TYPE: {:?}", e))?;
|
||||
|
||||
let name_column = batch.column(name_index);
|
||||
let type_column = batch.column(type_index);
|
||||
let nullable_column = batch.column(nullable_index);
|
||||
let comment_column = batch.column(comment_index);
|
||||
let source_type_column = batch.column(source_type_index);
|
||||
|
||||
let name_array = name_column
|
||||
.as_any()
|
||||
.downcast_ref::<arrow::array::StringArray>()
|
||||
.ok_or_else(|| anyhow!("Expected StringArray for NAME"))?;
|
||||
|
||||
let type_array = type_column
|
||||
.as_any()
|
||||
.downcast_ref::<arrow::array::StringArray>()
|
||||
.ok_or_else(|| anyhow!("Expected StringArray for TYPE_"))?;
|
||||
|
||||
let nullable_array = nullable_column
|
||||
.as_any()
|
||||
.downcast_ref::<arrow::array::BooleanArray>()
|
||||
.ok_or_else(|| anyhow!("Expected BooleanArray for NULLABLE"))?;
|
||||
|
||||
let comment_array = comment_column
|
||||
.as_any()
|
||||
.downcast_ref::<arrow::array::StringArray>()
|
||||
.ok_or_else(|| anyhow!("Expected StringArray for COMMENT"))?;
|
||||
|
||||
let source_type_array = source_type_column
|
||||
.as_any()
|
||||
.downcast_ref::<arrow::array::StringArray>()
|
||||
.ok_or_else(|| anyhow!("Expected StringArray for SOURCE_TYPE"))?;
|
||||
|
||||
for i in 0..batch.num_rows() {
|
||||
let name = name_array.value(i).to_string();
|
||||
let type_ = type_array.value(i).to_string();
|
||||
let nullable = nullable_array.value(i);
|
||||
let comment = if comment_array.is_null(i) {
|
||||
None
|
||||
} else {
|
||||
Some(comment_array.value(i).to_string())
|
||||
};
|
||||
let source_type = if source_type_array.is_null(i) {
|
||||
"TABLE".to_string()
|
||||
} else {
|
||||
source_type_array.value(i).to_string()
|
||||
};
|
||||
|
||||
columns.push(DatasetColumnRecord {
|
||||
dataset_name: dataset_name.clone(),
|
||||
schema_name: schema_name.clone(),
|
||||
name,
|
||||
type_,
|
||||
nullable,
|
||||
comment,
|
||||
source_type,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if let snowflake_api::QueryResult::Empty = results {
|
||||
return Ok(Vec::new());
|
||||
} else {
|
||||
return Err(anyhow!(
|
||||
"Unexpected query result format from Snowflake. Expected Arrow format."
|
||||
));
|
||||
}
|
||||
|
||||
Ok(columns)
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ pub async fn query_engine(
|
|||
.await
|
||||
{
|
||||
Ok(data_source) => data_source,
|
||||
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
|
||||
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
|
||||
};
|
||||
|
||||
let results = match query_router(&data_source, sql, None, false).await {
|
||||
|
@ -91,7 +91,7 @@ pub async fn modeling_query_engine(
|
|||
.await
|
||||
{
|
||||
Ok(data_source) => data_source,
|
||||
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
|
||||
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
|
||||
};
|
||||
|
||||
let results = match query_router(&data_source, sql, Some(25), false).await {
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::process::Command;
|
||||
|
||||
use database::enums::DataSourceType;
|
||||
|
||||
|
@ -16,9 +14,7 @@ pub enum TargetDialect {
|
|||
Snowflake,
|
||||
#[serde(rename = "tsql")]
|
||||
SqlServer,
|
||||
#[serde(rename = "mysql")]
|
||||
MariaDb,
|
||||
#[serde(rename = "postgres")]
|
||||
Supabase,
|
||||
}
|
||||
|
||||
|
@ -37,65 +33,3 @@ impl From<DataSourceType> for TargetDialect {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn transpile_sql(sql: &String, target_dialect: TargetDialect) -> Result<String> {
|
||||
let serialized_dialect = serde_json::to_string(&target_dialect).unwrap();
|
||||
|
||||
let transpiled_sql = match Command::new("./python/sqlglot_transpiler")
|
||||
.arg(sql)
|
||||
.arg(serialized_dialect.replace("\"", ""))
|
||||
.output()
|
||||
.await
|
||||
{
|
||||
Ok(output) => {
|
||||
if !output.status.success() {
|
||||
tracing::error!("Command failed with exit code: {}", output.status);
|
||||
return Ok(sql.to_string());
|
||||
}
|
||||
|
||||
let stdout = match String::from_utf8(output.stdout) {
|
||||
Ok(stdout) => stdout,
|
||||
Err(e) => {
|
||||
tracing::error!("Error: {}", e);
|
||||
return Ok(sql.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
stdout
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error: {}", e);
|
||||
sql.to_string()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(transpiled_sql)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_transpiler() {
|
||||
let sql = "WITH customer_sales AS (
|
||||
SELECT DISTINCT
|
||||
customer_id,
|
||||
customer_name,
|
||||
SUM(total_sales_amount) AS total_sales
|
||||
FROM sales_summary
|
||||
GROUP BY customer_id, customer_name
|
||||
)
|
||||
SELECT
|
||||
customer_name,
|
||||
total_sales
|
||||
FROM customer_sales
|
||||
ORDER BY total_sales DESC
|
||||
LIMIT 1;";
|
||||
let target_dialect = TargetDialect::Postgres;
|
||||
let transpiled_sql = transpile_sql(&sql.to_string(), target_dialect)
|
||||
.await
|
||||
.unwrap();
|
||||
println!("transpiled_sql: {:?}", transpiled_sql);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ pub async fn write_query_engine(
|
|||
.await
|
||||
{
|
||||
Ok(data_source) => data_source,
|
||||
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
|
||||
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
|
||||
};
|
||||
|
||||
let results = match query_router(&data_source, sql, None, true).await {
|
||||
|
|
|
@ -3,7 +3,6 @@ use tokio_stream::StreamExt;
|
|||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use diesel::QueryDsl;
|
||||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ pub struct StoredValueWithDistance {
|
|||
|
||||
const BATCH_SIZE: usize = 10_000;
|
||||
const MAX_VALUE_LENGTH: usize = 50;
|
||||
const TIMEOUT_SECONDS: u64 = 60;
|
||||
|
||||
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
|
||||
let pool = get_pg_pool();
|
||||
|
@ -77,7 +76,7 @@ pub async fn store_column_values(
|
|||
dataset_id: &Uuid,
|
||||
column_name: &str,
|
||||
column_id: &Uuid,
|
||||
data_source_id: &Uuid,
|
||||
_data_source_id: &Uuid,
|
||||
schema: &str,
|
||||
table_name: &str,
|
||||
) -> Result<()> {
|
||||
|
|
Loading…
Reference in New Issue