mirror of https://github.com/buster-so/buster.git
final clean up
This commit is contained in:
parent
8bfd0f04af
commit
dc483020be
|
@ -191,7 +191,7 @@ pub async fn llm_chat_stream(
|
||||||
async fn anthropic_chat_compiler(
|
async fn anthropic_chat_compiler(
|
||||||
model: &AnthropicChatModel,
|
model: &AnthropicChatModel,
|
||||||
messages: &Vec<LlmMessage>,
|
messages: &Vec<LlmMessage>,
|
||||||
max_tokens: u32,
|
_max_tokens: u32,
|
||||||
temperature: f32,
|
temperature: f32,
|
||||||
timeout: u64,
|
timeout: u64,
|
||||||
stop: Option<Vec<String>>,
|
stop: Option<Vec<String>>,
|
||||||
|
|
|
@ -62,22 +62,6 @@ fn is_o3_model(model: &OpenAiChatModel) -> bool {
|
||||||
matches!(model, OpenAiChatModel::O3Mini)
|
matches!(model, OpenAiChatModel::O3Mini)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn should_skip_temperature(val: &(&f32, &OpenAiChatModel)) -> bool {
|
|
||||||
is_o3_model(val.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn should_skip_max_tokens(val: &(&u32, &OpenAiChatModel)) -> bool {
|
|
||||||
is_o3_model(val.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn should_skip_top_p(val: &(&f32, &OpenAiChatModel)) -> bool {
|
|
||||||
is_o3_model(val.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn should_skip_reasoning_effort(val: &(&Option<ReasoningEffort>, &OpenAiChatModel)) -> bool {
|
|
||||||
!is_o3_model(val.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
#[derive(Serialize, Clone)]
|
||||||
pub struct OpenAiChatRequest {
|
pub struct OpenAiChatRequest {
|
||||||
model: OpenAiChatModel,
|
model: OpenAiChatModel,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
pub fn custom_response_system_prompt(
|
pub fn custom_response_system_prompt(
|
||||||
datasets: &String,
|
datasets: &String,
|
||||||
input: &String,
|
_input: &String,
|
||||||
orchestrator_output_string: &String,
|
_orchestrator_output_string: &String,
|
||||||
) -> String {
|
) -> String {
|
||||||
format!(
|
format!(
|
||||||
r#"##OVERVIEW
|
r#"##OVERVIEW
|
||||||
|
|
|
@ -360,7 +360,7 @@ pub async fn snowflake_query(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
arrow::datatypes::DataType::Decimal128(precision, scale) => {
|
arrow::datatypes::DataType::Decimal128(_precision, scale) => {
|
||||||
let array = column
|
let array = column
|
||||||
.as_any()
|
.as_any()
|
||||||
.downcast_ref::<Decimal128Array>()
|
.downcast_ref::<Decimal128Array>()
|
||||||
|
@ -374,7 +374,7 @@ pub async fn snowflake_query(
|
||||||
DataType::Float8(Some(float_val))
|
DataType::Float8(Some(float_val))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
arrow::datatypes::DataType::Decimal256(precision, scale) => {
|
arrow::datatypes::DataType::Decimal256(_precision, scale) => {
|
||||||
let array = column
|
let array = column
|
||||||
.as_any()
|
.as_any()
|
||||||
.downcast_ref::<Decimal256Array>()
|
.downcast_ref::<Decimal256Array>()
|
||||||
|
|
|
@ -235,7 +235,7 @@ pub async fn retrieve_dataset_columns_batch(
|
||||||
async fn get_snowflake_columns_batch(
|
async fn get_snowflake_columns_batch(
|
||||||
datasets: &[(String, String)],
|
datasets: &[(String, String)],
|
||||||
credentials: &SnowflakeCredentials,
|
credentials: &SnowflakeCredentials,
|
||||||
database: Option<String>,
|
_database: Option<String>,
|
||||||
) -> Result<Vec<DatasetColumnRecord>> {
|
) -> Result<Vec<DatasetColumnRecord>> {
|
||||||
let snowflake_client = get_snowflake_client(credentials).await?;
|
let snowflake_client = get_snowflake_client(credentials).await?;
|
||||||
|
|
||||||
|
@ -720,128 +720,3 @@ async fn get_bigquery_columns_batch(
|
||||||
|
|
||||||
Ok(columns)
|
Ok(columns)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_snowflake_columns(
|
|
||||||
dataset_name: &String,
|
|
||||||
schema_name: &String,
|
|
||||||
credentials: &SnowflakeCredentials,
|
|
||||||
) -> Result<Vec<DatasetColumnRecord>> {
|
|
||||||
let snowflake_client = get_snowflake_client(credentials).await?;
|
|
||||||
|
|
||||||
let uppercase_dataset_name = dataset_name.to_uppercase();
|
|
||||||
let uppercase_schema_name = schema_name.to_uppercase();
|
|
||||||
|
|
||||||
let sql = format!(
|
|
||||||
"SELECT
|
|
||||||
c.COLUMN_NAME AS name,
|
|
||||||
c.DATA_TYPE AS type_,
|
|
||||||
CASE WHEN c.IS_NULLABLE = 'YES' THEN true ELSE false END AS nullable,
|
|
||||||
c.COMMENT AS comment,
|
|
||||||
t.TABLE_TYPE as source_type
|
|
||||||
FROM
|
|
||||||
INFORMATION_SCHEMA.COLUMNS c
|
|
||||||
JOIN
|
|
||||||
INFORMATION_SCHEMA.TABLES t
|
|
||||||
ON c.TABLE_NAME = t.TABLE_NAME
|
|
||||||
AND c.TABLE_SCHEMA = t.TABLE_SCHEMA
|
|
||||||
WHERE
|
|
||||||
c.TABLE_NAME = '{uppercase_dataset_name}'
|
|
||||||
AND c.TABLE_SCHEMA = '{uppercase_schema_name}'
|
|
||||||
ORDER BY c.ORDINAL_POSITION;",
|
|
||||||
);
|
|
||||||
|
|
||||||
// Execute the query using the Snowflake client
|
|
||||||
let results = snowflake_client
|
|
||||||
.exec(&sql)
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow!("Error executing query: {:?}", e))?;
|
|
||||||
|
|
||||||
let mut columns = Vec::new();
|
|
||||||
|
|
||||||
if let snowflake_api::QueryResult::Arrow(record_batches) = results {
|
|
||||||
for batch in &record_batches {
|
|
||||||
let schema = batch.schema();
|
|
||||||
|
|
||||||
let name_index = schema
|
|
||||||
.index_of("NAME")
|
|
||||||
.map_err(|e| anyhow!("Error getting index for NAME: {:?}", e))?;
|
|
||||||
let type_index = schema
|
|
||||||
.index_of("TYPE_")
|
|
||||||
.map_err(|e| anyhow!("Error getting index for TYPE_: {:?}", e))?;
|
|
||||||
let nullable_index = schema
|
|
||||||
.index_of("NULLABLE")
|
|
||||||
.map_err(|e| anyhow!("Error getting index for NULLABLE: {:?}", e))?;
|
|
||||||
let comment_index = schema
|
|
||||||
.index_of("COMMENT")
|
|
||||||
.map_err(|e| anyhow!("Error getting index for COMMENT: {:?}", e))?;
|
|
||||||
let source_type_index = schema
|
|
||||||
.index_of("SOURCE_TYPE")
|
|
||||||
.map_err(|e| anyhow!("Error getting index for SOURCE_TYPE: {:?}", e))?;
|
|
||||||
|
|
||||||
let name_column = batch.column(name_index);
|
|
||||||
let type_column = batch.column(type_index);
|
|
||||||
let nullable_column = batch.column(nullable_index);
|
|
||||||
let comment_column = batch.column(comment_index);
|
|
||||||
let source_type_column = batch.column(source_type_index);
|
|
||||||
|
|
||||||
let name_array = name_column
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<arrow::array::StringArray>()
|
|
||||||
.ok_or_else(|| anyhow!("Expected StringArray for NAME"))?;
|
|
||||||
|
|
||||||
let type_array = type_column
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<arrow::array::StringArray>()
|
|
||||||
.ok_or_else(|| anyhow!("Expected StringArray for TYPE_"))?;
|
|
||||||
|
|
||||||
let nullable_array = nullable_column
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<arrow::array::BooleanArray>()
|
|
||||||
.ok_or_else(|| anyhow!("Expected BooleanArray for NULLABLE"))?;
|
|
||||||
|
|
||||||
let comment_array = comment_column
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<arrow::array::StringArray>()
|
|
||||||
.ok_or_else(|| anyhow!("Expected StringArray for COMMENT"))?;
|
|
||||||
|
|
||||||
let source_type_array = source_type_column
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<arrow::array::StringArray>()
|
|
||||||
.ok_or_else(|| anyhow!("Expected StringArray for SOURCE_TYPE"))?;
|
|
||||||
|
|
||||||
for i in 0..batch.num_rows() {
|
|
||||||
let name = name_array.value(i).to_string();
|
|
||||||
let type_ = type_array.value(i).to_string();
|
|
||||||
let nullable = nullable_array.value(i);
|
|
||||||
let comment = if comment_array.is_null(i) {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(comment_array.value(i).to_string())
|
|
||||||
};
|
|
||||||
let source_type = if source_type_array.is_null(i) {
|
|
||||||
"TABLE".to_string()
|
|
||||||
} else {
|
|
||||||
source_type_array.value(i).to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
columns.push(DatasetColumnRecord {
|
|
||||||
dataset_name: dataset_name.clone(),
|
|
||||||
schema_name: schema_name.clone(),
|
|
||||||
name,
|
|
||||||
type_,
|
|
||||||
nullable,
|
|
||||||
comment,
|
|
||||||
source_type,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if let snowflake_api::QueryResult::Empty = results {
|
|
||||||
return Ok(Vec::new());
|
|
||||||
} else {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"Unexpected query result format from Snowflake. Expected Arrow format."
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(columns)
|
|
||||||
}
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ pub async fn query_engine(
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(data_source) => data_source,
|
Ok(data_source) => data_source,
|
||||||
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
|
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
|
||||||
};
|
};
|
||||||
|
|
||||||
let results = match query_router(&data_source, sql, None, false).await {
|
let results = match query_router(&data_source, sql, None, false).await {
|
||||||
|
@ -91,7 +91,7 @@ pub async fn modeling_query_engine(
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(data_source) => data_source,
|
Ok(data_source) => data_source,
|
||||||
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
|
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
|
||||||
};
|
};
|
||||||
|
|
||||||
let results = match query_router(&data_source, sql, Some(25), false).await {
|
let results = match query_router(&data_source, sql, Some(25), false).await {
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
use anyhow::Result;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::process::Command;
|
|
||||||
|
|
||||||
use database::enums::DataSourceType;
|
use database::enums::DataSourceType;
|
||||||
|
|
||||||
|
@ -16,9 +14,7 @@ pub enum TargetDialect {
|
||||||
Snowflake,
|
Snowflake,
|
||||||
#[serde(rename = "tsql")]
|
#[serde(rename = "tsql")]
|
||||||
SqlServer,
|
SqlServer,
|
||||||
#[serde(rename = "mysql")]
|
|
||||||
MariaDb,
|
MariaDb,
|
||||||
#[serde(rename = "postgres")]
|
|
||||||
Supabase,
|
Supabase,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,65 +33,3 @@ impl From<DataSourceType> for TargetDialect {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn transpile_sql(sql: &String, target_dialect: TargetDialect) -> Result<String> {
|
|
||||||
let serialized_dialect = serde_json::to_string(&target_dialect).unwrap();
|
|
||||||
|
|
||||||
let transpiled_sql = match Command::new("./python/sqlglot_transpiler")
|
|
||||||
.arg(sql)
|
|
||||||
.arg(serialized_dialect.replace("\"", ""))
|
|
||||||
.output()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(output) => {
|
|
||||||
if !output.status.success() {
|
|
||||||
tracing::error!("Command failed with exit code: {}", output.status);
|
|
||||||
return Ok(sql.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let stdout = match String::from_utf8(output.stdout) {
|
|
||||||
Ok(stdout) => stdout,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Error: {}", e);
|
|
||||||
return Ok(sql.to_string());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
stdout
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Error: {}", e);
|
|
||||||
sql.to_string()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(transpiled_sql)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_transpiler() {
|
|
||||||
let sql = "WITH customer_sales AS (
|
|
||||||
SELECT DISTINCT
|
|
||||||
customer_id,
|
|
||||||
customer_name,
|
|
||||||
SUM(total_sales_amount) AS total_sales
|
|
||||||
FROM sales_summary
|
|
||||||
GROUP BY customer_id, customer_name
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
customer_name,
|
|
||||||
total_sales
|
|
||||||
FROM customer_sales
|
|
||||||
ORDER BY total_sales DESC
|
|
||||||
LIMIT 1;";
|
|
||||||
let target_dialect = TargetDialect::Postgres;
|
|
||||||
let transpiled_sql = transpile_sql(&sql.to_string(), target_dialect)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
println!("transpiled_sql: {:?}", transpiled_sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ pub async fn write_query_engine(
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(data_source) => data_source,
|
Ok(data_source) => data_source,
|
||||||
Err(e) => return Err(anyhow::anyhow!("Data source not found")),
|
Err(_) => return Err(anyhow::anyhow!("Data source not found")),
|
||||||
};
|
};
|
||||||
|
|
||||||
let results = match query_router(&data_source, sql, None, true).await {
|
let results = match query_router(&data_source, sql, None, true).await {
|
||||||
|
|
|
@ -3,7 +3,6 @@ use tokio_stream::StreamExt;
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use diesel::QueryDsl;
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,6 @@ pub struct StoredValueWithDistance {
|
||||||
|
|
||||||
const BATCH_SIZE: usize = 10_000;
|
const BATCH_SIZE: usize = 10_000;
|
||||||
const MAX_VALUE_LENGTH: usize = 50;
|
const MAX_VALUE_LENGTH: usize = 50;
|
||||||
const TIMEOUT_SECONDS: u64 = 60;
|
|
||||||
|
|
||||||
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
|
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
|
||||||
let pool = get_pg_pool();
|
let pool = get_pg_pool();
|
||||||
|
@ -77,7 +76,7 @@ pub async fn store_column_values(
|
||||||
dataset_id: &Uuid,
|
dataset_id: &Uuid,
|
||||||
column_name: &str,
|
column_name: &str,
|
||||||
column_id: &Uuid,
|
column_id: &Uuid,
|
||||||
data_source_id: &Uuid,
|
_data_source_id: &Uuid,
|
||||||
schema: &str,
|
schema: &str,
|
||||||
table_name: &str,
|
table_name: &str,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
|
Loading…
Reference in New Issue