From 3beec0878c86c1abca57bc31967fa19da9d43d6a Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 4 Apr 2025 16:19:31 -0600 Subject: [PATCH 1/3] fix the string error --- .../src/routes/rest/routes/datasets/post_dataset.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/api/server/src/routes/rest/routes/datasets/post_dataset.rs b/api/server/src/routes/rest/routes/datasets/post_dataset.rs index a8ca1dd26..c4877b303 100644 --- a/api/server/src/routes/rest/routes/datasets/post_dataset.rs +++ b/api/server/src/routes/rest/routes/datasets/post_dataset.rs @@ -36,14 +36,14 @@ pub async fn post_dataset( Ok(None) => { return Err(( StatusCode::FORBIDDEN, - "User does not belong to any organization", + "User does not belong to any organization".to_string(), )); } Err(e) => { tracing::error!("Error getting user organization id: {:?}", e); return Err(( StatusCode::INTERNAL_SERVER_ERROR, - "Error getting user organization id", + "Error getting user organization id".to_string(), )); } }; @@ -53,14 +53,14 @@ pub async fn post_dataset( Ok(false) => { return Err(( StatusCode::FORBIDDEN, - "Insufficient permissions", + "Insufficient permissions".to_string(), )) } Err(e) => { tracing::error!("Error checking user permissions: {:?}", e); return Err(( StatusCode::INTERNAL_SERVER_ERROR, - "Error checking user permissions", + "Error checking user permissions".to_string(), )); } } @@ -78,7 +78,7 @@ pub async fn post_dataset( tracing::error!("Error creating dataset: {:?}", e); return Err(( StatusCode::INTERNAL_SERVER_ERROR, - "Error creating dataset", + "Error creating dataset".to_string(), )); } }; From 028eded9c56cf893901c731980f93980a21a2b55 Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 4 Apr 2025 16:22:58 -0600 Subject: [PATCH 2/3] cleaned up a few more utils --- api/server/src/utils/charting/mod.rs | 1 - api/server/src/utils/charting/types.rs | 383 ------------------ api/server/src/utils/mod.rs | 4 - .../serde_helpers/deserialization_helpers.rs | 21 - api/server/src/utils/serde_helpers/mod.rs | 1 - api/server/src/utils/validation/mod.rs | 5 - .../src/utils/validation/type_mapping.rs | 234 ----------- api/server/src/utils/validation/types.rs | 182 --------- 8 files changed, 831 deletions(-) delete mode 100644 api/server/src/utils/charting/mod.rs delete mode 100644 api/server/src/utils/charting/types.rs delete mode 100644 api/server/src/utils/serde_helpers/deserialization_helpers.rs delete mode 100644 api/server/src/utils/serde_helpers/mod.rs delete mode 100644 api/server/src/utils/validation/mod.rs delete mode 100644 api/server/src/utils/validation/type_mapping.rs delete mode 100644 api/server/src/utils/validation/types.rs diff --git a/api/server/src/utils/charting/mod.rs b/api/server/src/utils/charting/mod.rs deleted file mode 100644 index cd408564e..000000000 --- a/api/server/src/utils/charting/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod types; diff --git a/api/server/src/utils/charting/types.rs b/api/server/src/utils/charting/types.rs deleted file mode 100644 index 61f76a0b4..000000000 --- a/api/server/src/utils/charting/types.rs +++ /dev/null @@ -1,383 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub enum ViewType { - Chart, - Table, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub enum ChartType { - Line, - Bar, - Scatter, - Pie, - Metric, - Table, - Combo, -} - -impl ChartType { - pub fn to_string(&self) -> String { - match self { - ChartType::Line => "line".to_string(), - ChartType::Bar => "bar".to_string(), - ChartType::Scatter => "scatter".to_string(), - ChartType::Pie => "pie".to_string(), - ChartType::Metric => "metric".to_string(), - ChartType::Table => "table".to_string(), - ChartType::Combo => "combo".to_string(), - } - } - - pub fn from_string(chart_type: &str) -> ChartType { - match chart_type { - "line" => ChartType::Line, - "bar" => ChartType::Bar, - "scatter" => ChartType::Scatter, - "pie" => ChartType::Pie, - "metric" => ChartType::Metric, - "table" => ChartType::Table, - "combo" => ChartType::Combo, - _ => ChartType::Table, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct BusterChartConfig { - pub selected_chart_type: ChartType, - pub selected_view: ViewType, - #[serde(skip_serializing_if = "Option::is_none")] - pub column_label_formats: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub column_settings: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub colors: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub show_legend: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub grid_lines: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub show_legend_headline: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub goal_lines: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub trendlines: Option>, - #[serde(flatten)] - pub y_axis_config: YAxisConfig, - #[serde(flatten)] - pub x_axis_config: XAxisConfig, - #[serde(flatten)] - pub y2_axis_config: Y2AxisConfig, - #[serde(flatten)] - pub bar_chart_props: BarChartProps, - #[serde(flatten)] - pub line_chart_props: LineChartProps, - #[serde(flatten)] - pub scatter_chart_props: ScatterChartProps, - #[serde(flatten)] - pub pie_chart_props: PieChartProps, - #[serde(flatten)] - pub table_chart_props: TableChartProps, - #[serde(flatten)] - pub combo_chart_props: ComboChartProps, - #[serde(flatten)] - pub metric_chart_props: MetricChartProps, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct LineChartProps { - #[serde(skip_serializing_if = "Option::is_none")] - pub line_style: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub line_group_type: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct BarChartProps { - #[serde(skip_serializing_if = "Option::is_none")] - pub bar_and_line_axis: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub bar_layout: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub bar_sort_by: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub bar_group_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub bar_show_total_at_top: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct ScatterChartProps { - #[serde(skip_serializing_if = "Option::is_none")] - pub scatter_axis: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub scatter_dot_size: Option<(f64, f64)>, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct PieChartProps { - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_chart_axis: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_display_label_as: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_show_inner_label: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_inner_label_aggregate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_inner_label_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_label_position: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_donut_width: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub pie_minimum_slice_percentage: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct TableChartProps { - #[serde(skip_serializing_if = "Option::is_none")] - pub table_column_order: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub table_column_widths: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub table_header_background_color: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub table_header_font_color: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub table_column_font_color: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct ComboChartProps { - #[serde(skip_serializing_if = "Option::is_none")] - pub combo_chart_axis: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct MetricChartProps { - pub metric_column_id: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub metric_value_aggregate: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub metric_header: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub metric_sub_header: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub metric_value_label: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum MetricTitle { - String(String), - Derived(DerivedMetricTitle), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DerivedMetricTitle { - pub column_id: String, - pub use_value: bool, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ShowLegendHeadline { - Bool(bool), - String(String), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ColumnSettings { - #[serde(skip_serializing_if = "Option::is_none")] - pub show_data_labels: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub column_visualization: Option, - #[serde(flatten)] - pub line_settings: LineColumnSettings, - #[serde(flatten)] - pub bar_settings: BarColumnSettings, - #[serde(flatten)] - pub dot_settings: DotColumnSettings, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct LineColumnSettings { - #[serde(skip_serializing_if = "Option::is_none")] - pub line_dash_style: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub line_width: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub line_style: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub line_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub line_symbol_size: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct BarColumnSettings { - #[serde(skip_serializing_if = "Option::is_none")] - pub bar_roundness: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct DotColumnSettings { - #[serde(skip_serializing_if = "Option::is_none")] - pub line_symbol_size: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ColumnLabelFormat { - #[serde(skip_serializing_if = "Option::is_none")] - pub style: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub column_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub display_name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub number_separator_style: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub minimum_fraction_digits: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub maximum_fraction_digits: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub multiplier: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub prefix: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub suffix: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub replace_missing_data_with: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub use_relative_time: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_utc: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub make_label_human_readable: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub currency: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub date_format: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub convert_number_to: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct GoalLine { - pub show: bool, - pub value: f64, - pub show_goal_line_label: bool, - pub goal_line_label: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub goal_line_color: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Trendline { - pub show: bool, - pub show_trendline_label: bool, - pub trendline_label: Option, - pub type_: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub trendline_color: Option, - pub column_id: String, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct YAxisConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub y_axis_show_axis_label: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y_axis_show_axis_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y_axis_axis_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y_axis_start_axis_at_zero: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y_axis_scale_type: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct Y2AxisConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub y2_axis_show_axis_label: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y2_axis_show_axis_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y2_axis_axis_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y2_axis_start_axis_at_zero: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub y2_axis_scale_type: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct XAxisConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub x_axis_show_ticks: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub x_axis_show_axis_label: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub x_axis_show_axis_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub x_axis_axis_title: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub x_axis_label_rotation: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub x_axis_data_zoom: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct CategoryAxisStyleConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub category_show_total_at_top: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub category_axis_title: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct BarAndLineAxis { - pub x: Vec, - pub y: Vec, - pub category: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub tooltip: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ScatterAxis { - pub x: Vec, - pub y: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub category: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub size: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tooltip: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ComboChartAxis { - pub x: Vec, - pub y: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub y2: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub category: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tooltip: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct PieChartAxis { - pub x: Vec, - pub y: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub tooltip: Option>, -} diff --git a/api/server/src/utils/mod.rs b/api/server/src/utils/mod.rs index 7ba03c04e..bd4e4ad1b 100644 --- a/api/server/src/utils/mod.rs +++ b/api/server/src/utils/mod.rs @@ -1,11 +1,7 @@ -pub mod charting; pub mod clients; pub mod security; -pub mod serde_helpers; pub mod stored_values; -pub mod validation; pub use agents::*; pub use security::*; pub use stored_values::*; -pub use validation::*; diff --git a/api/server/src/utils/serde_helpers/deserialization_helpers.rs b/api/server/src/utils/serde_helpers/deserialization_helpers.rs deleted file mode 100644 index 65bc6797e..000000000 --- a/api/server/src/utils/serde_helpers/deserialization_helpers.rs +++ /dev/null @@ -1,21 +0,0 @@ -use serde::{de::Deserializer, Deserialize}; -use serde_json::Value; - -pub fn deserialize_double_option<'de, T, D>(deserializer: D) -> Result>, D::Error> -where - T: serde::Deserialize<'de>, - D: Deserializer<'de>, -{ - let value = Value::deserialize(deserializer)?; - - match value { - Value::Null => Ok(Some(None)), // explicit null - Value::Object(obj) if obj.is_empty() => Ok(None), // empty object - _ => { - match T::deserialize(value) { - Ok(val) => Ok(Some(Some(val))), - Err(_) => Ok(None) - } - } - } -} diff --git a/api/server/src/utils/serde_helpers/mod.rs b/api/server/src/utils/serde_helpers/mod.rs deleted file mode 100644 index 20b7457cd..000000000 --- a/api/server/src/utils/serde_helpers/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod deserialization_helpers; \ No newline at end of file diff --git a/api/server/src/utils/validation/mod.rs b/api/server/src/utils/validation/mod.rs deleted file mode 100644 index 474b840b6..000000000 --- a/api/server/src/utils/validation/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod types; -pub mod type_mapping; - -pub use types::*; -pub use type_mapping::*; \ No newline at end of file diff --git a/api/server/src/utils/validation/type_mapping.rs b/api/server/src/utils/validation/type_mapping.rs deleted file mode 100644 index d6768468c..000000000 --- a/api/server/src/utils/validation/type_mapping.rs +++ /dev/null @@ -1,234 +0,0 @@ -use database::enums::DataSourceType; -use once_cell::sync::Lazy; -use std::collections::HashMap; - -use query_engine::data_types::DataType; - -// Standard types we use across all data sources -#[derive(Debug, Clone, PartialEq)] -pub enum StandardType { - Text, - Integer, - Float, - Boolean, - Date, - Timestamp, - Json, - Unknown, -} - -impl StandardType { - pub fn from_data_type(data_type: &DataType) -> Self { - match data_type.simple_type() { - Some(simple_type) => match simple_type.as_str() { - "string" => StandardType::Text, - "number" => StandardType::Float, - "boolean" => StandardType::Boolean, - "date" => StandardType::Date, - "null" => StandardType::Unknown, - _ => StandardType::Unknown, - }, - None => StandardType::Unknown, - } - } - - pub fn from_str(type_str: &str) -> Self { - match type_str.to_lowercase().as_str() { - "text" | "string" | "varchar" | "char" => StandardType::Text, - "int" | "integer" | "bigint" | "smallint" => StandardType::Integer, - "float" | "double" | "decimal" | "numeric" => StandardType::Float, - "bool" | "boolean" => StandardType::Boolean, - "date" => StandardType::Date, - "timestamp" | "datetime" | "timestamptz" => StandardType::Timestamp, - "json" | "jsonb" => StandardType::Json, - _ => StandardType::Unknown, - } - } -} - -// Type mappings for each data source type to DataType -static TYPE_MAPPINGS: Lazy>> = - Lazy::new(|| { - let mut mappings = HashMap::new(); - - // Postgres mappings - let mut postgres = HashMap::new(); - postgres.insert("text", DataType::Text(None)); - postgres.insert("varchar", DataType::Text(None)); - postgres.insert("char", DataType::Text(None)); - postgres.insert("int", DataType::Int4(None)); - postgres.insert("integer", DataType::Int4(None)); - postgres.insert("bigint", DataType::Int8(None)); - postgres.insert("smallint", DataType::Int2(None)); - postgres.insert("float", DataType::Float4(None)); - postgres.insert("double precision", DataType::Float8(None)); - postgres.insert("numeric", DataType::Decimal(None)); - postgres.insert("decimal", DataType::Decimal(None)); - postgres.insert("boolean", DataType::Bool(None)); - postgres.insert("date", DataType::Date(None)); - postgres.insert("timestamp", DataType::Timestamp(None)); - postgres.insert("timestamptz", DataType::Timestamptz(None)); - postgres.insert("json", DataType::Json(None)); - postgres.insert("jsonb", DataType::Json(None)); - mappings.insert(DataSourceType::Postgres, postgres.clone()); - mappings.insert(DataSourceType::Supabase, postgres); - - // MySQL mappings - let mut mysql = HashMap::new(); - mysql.insert("varchar", DataType::Text(None)); - mysql.insert("text", DataType::Text(None)); - mysql.insert("char", DataType::Text(None)); - mysql.insert("int", DataType::Int4(None)); - mysql.insert("bigint", DataType::Int8(None)); - mysql.insert("tinyint", DataType::Int2(None)); - mysql.insert("float", DataType::Float4(None)); - mysql.insert("double", DataType::Float8(None)); - mysql.insert("decimal", DataType::Decimal(None)); - mysql.insert("boolean", DataType::Bool(None)); - mysql.insert("date", DataType::Date(None)); - mysql.insert("datetime", DataType::Timestamp(None)); - mysql.insert("timestamp", DataType::Timestamptz(None)); - mysql.insert("json", DataType::Json(None)); - mappings.insert(DataSourceType::MySql, mysql.clone()); - mappings.insert(DataSourceType::Mariadb, mysql); - - // BigQuery mappings - let mut bigquery = HashMap::new(); - bigquery.insert("STRING", DataType::Text(None)); - bigquery.insert("INT64", DataType::Int8(None)); - bigquery.insert("INTEGER", DataType::Int4(None)); - bigquery.insert("FLOAT64", DataType::Float8(None)); - bigquery.insert("NUMERIC", DataType::Decimal(None)); - bigquery.insert("BOOL", DataType::Bool(None)); - bigquery.insert("DATE", DataType::Date(None)); - bigquery.insert("TIMESTAMP", DataType::Timestamptz(None)); - bigquery.insert("JSON", DataType::Json(None)); - mappings.insert(DataSourceType::BigQuery, bigquery); - - // Snowflake mappings - let mut snowflake = HashMap::new(); - snowflake.insert("TEXT", DataType::Text(None)); - snowflake.insert("VARCHAR", DataType::Text(None)); - snowflake.insert("CHAR", DataType::Text(None)); - snowflake.insert("NUMBER", DataType::Decimal(None)); - snowflake.insert("DECIMAL", DataType::Decimal(None)); - snowflake.insert("INTEGER", DataType::Int4(None)); - snowflake.insert("BIGINT", DataType::Int8(None)); - snowflake.insert("BOOLEAN", DataType::Bool(None)); - snowflake.insert("DATE", DataType::Date(None)); - snowflake.insert("TIMESTAMP", DataType::Timestamptz(None)); - snowflake.insert("VARIANT", DataType::Json(None)); - mappings.insert(DataSourceType::Snowflake, snowflake); - - mappings - }); - -pub fn normalize_type(source_type: DataSourceType, type_str: &str) -> DataType { - TYPE_MAPPINGS - .get(&source_type) - .and_then(|mappings| mappings.get(type_str)) - .cloned() - .unwrap_or(DataType::Unknown(Some(type_str.to_string()))) -} - -pub fn types_compatible(source_type: DataSourceType, ds_type: &str, model_type: &str) -> bool { - let ds_data_type = normalize_type(source_type, ds_type); - let model_data_type = normalize_type(source_type, model_type); - - match (&ds_data_type, &model_data_type) { - // Allow integer -> float/decimal conversions - (DataType::Int2(_), DataType::Float4(_)) => true, - (DataType::Int2(_), DataType::Float8(_)) => true, - (DataType::Int2(_), DataType::Decimal(_)) => true, - (DataType::Int4(_), DataType::Float4(_)) => true, - (DataType::Int4(_), DataType::Float8(_)) => true, - (DataType::Int4(_), DataType::Decimal(_)) => true, - (DataType::Int8(_), DataType::Float4(_)) => true, - (DataType::Int8(_), DataType::Float8(_)) => true, - (DataType::Int8(_), DataType::Decimal(_)) => true, - - // Allow text for any type (common in views/computed columns) - (DataType::Text(_), _) => true, - - // Allow timestamp/timestamptz compatibility - (DataType::Timestamp(_), DataType::Timestamptz(_)) => true, - (DataType::Timestamptz(_), DataType::Timestamp(_)) => true, - - // Exact matches (using to_string to compare types, not values) - (a, b) if a.to_string() == b.to_string() => true, - - // Everything else is incompatible - _ => false, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_postgres_type_normalization() { - assert!(matches!( - normalize_type(DataSourceType::Postgres, "text"), - DataType::Text(_) - )); - assert!(matches!( - normalize_type(DataSourceType::Postgres, "integer"), - DataType::Int4(_) - )); - assert!(matches!( - normalize_type(DataSourceType::Postgres, "numeric"), - DataType::Decimal(_) - )); - } - - #[test] - fn test_bigquery_type_normalization() { - assert!(matches!( - normalize_type(DataSourceType::BigQuery, "STRING"), - DataType::Text(_) - )); - assert!(matches!( - normalize_type(DataSourceType::BigQuery, "INT64"), - DataType::Int8(_) - )); - assert!(matches!( - normalize_type(DataSourceType::BigQuery, "FLOAT64"), - DataType::Float8(_) - )); - } - - #[test] - fn test_type_compatibility() { - // Same types are compatible - assert!(types_compatible(DataSourceType::Postgres, "text", "text")); - - // Integer can be used as float - assert!(types_compatible( - DataSourceType::Postgres, - "integer", - "float" - )); - - // Text can be used for any type - assert!(types_compatible( - DataSourceType::Postgres, - "text", - "integer" - )); - - // Different types are incompatible - assert!(!types_compatible( - DataSourceType::Postgres, - "integer", - "text" - )); - - // Timestamp compatibility - assert!(types_compatible( - DataSourceType::Postgres, - "timestamp", - "timestamptz" - )); - } -} diff --git a/api/server/src/utils/validation/types.rs b/api/server/src/utils/validation/types.rs deleted file mode 100644 index 3802ac467..000000000 --- a/api/server/src/utils/validation/types.rs +++ /dev/null @@ -1,182 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ValidationResult { - pub success: bool, - pub model_name: String, - pub data_source_name: String, - pub schema: String, - pub errors: Vec, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ValidationError { - pub error_type: ValidationErrorType, - pub column_name: Option, - pub message: String, - pub suggestion: Option, - pub context: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub enum ValidationErrorType { - TableNotFound, - ColumnNotFound, - TypeMismatch, - DataSourceError, - ModelNotFound, - InvalidRelationship, - ExpressionError, - ProjectNotFound, - InvalidBusterYml, - DataSourceMismatch, - RequiredFieldMissing, - DataSourceNotFound, - SchemaError, - CredentialsError, - DatabaseError, - ValidationError, - InternalError, -} - -impl ValidationResult { - pub fn new(model_name: String, data_source_name: String, schema: String) -> Self { - Self { - success: true, - model_name, - data_source_name, - schema, - errors: Vec::new(), - } - } - - pub fn add_error(&mut self, error: ValidationError) { - self.success = false; - self.errors.push(error); - } -} - -impl ValidationError { - pub fn new( - error_type: ValidationErrorType, - column_name: Option, - message: String, - suggestion: Option, - ) -> Self { - Self { - error_type, - column_name, - message, - suggestion, - context: None, - } - } - - pub fn with_context(mut self, context: String) -> Self { - self.context = Some(context); - self - } - - pub fn table_not_found(table_name: &str) -> Self { - Self::new( - ValidationErrorType::TableNotFound, - None, - format!("Table '{}' not found in data source", table_name), - None, - ) - } - - pub fn column_not_found(column_name: &str) -> Self { - Self::new( - ValidationErrorType::ColumnNotFound, - Some(column_name.to_string()), - format!("Column '{}' not found in data source", column_name), - None, - ) - } - - pub fn type_mismatch(column_name: &str, expected: &str, found: &str) -> Self { - Self::new( - ValidationErrorType::TypeMismatch, - Some(column_name.to_string()), - format!( - "Column '{}' type mismatch. Expected: {}, Found: {}", - column_name, expected, found - ), - None, - ) - } - - pub fn data_source_error(message: String) -> Self { - Self::new( - ValidationErrorType::DataSourceError, - None, - message, - None, - ) - } - - pub fn model_not_found(model_name: &str) -> Self { - Self::new( - ValidationErrorType::ModelNotFound, - None, - format!("Model '{}' not found in data source", model_name), - None, - ) - } - - pub fn invalid_relationship(from: &str, to: &str, reason: &str) -> Self { - Self::new( - ValidationErrorType::InvalidRelationship, - None, - format!("Invalid relationship from '{}' to '{}': {}", from, to, reason), - None, - ) - } - - pub fn expression_error(column_name: &str, expr: &str, reason: &str) -> Self { - Self::new( - ValidationErrorType::ExpressionError, - Some(column_name.to_string()), - format!("Invalid expression '{}' for column '{}': {}", expr, column_name, reason), - None, - ) - } - - // New factory methods for enhanced error types - pub fn schema_error(schema_name: &str, reason: &str) -> Self { - Self::new( - ValidationErrorType::SchemaError, - None, - format!("Schema '{}' error: {}", schema_name, reason), - None, - ) - } - - pub fn credentials_error(data_source: &str, reason: &str) -> Self { - Self::new( - ValidationErrorType::CredentialsError, - None, - format!("Credentials error for data source '{}': {}", data_source, reason), - None, - ) - } - - pub fn database_error(message: String) -> Self { - Self::new( - ValidationErrorType::DatabaseError, - None, - message, - None, - ) - } - - pub fn internal_error(message: String) -> Self { - Self::new( - ValidationErrorType::InternalError, - None, - message, - None, - ) - } -} \ No newline at end of file From 2649fb7656eb4eea02a5d0df189c7f0fa70917bd Mon Sep 17 00:00:00 2001 From: dal Date: Fri, 4 Apr 2025 16:25:20 -0600 Subject: [PATCH 3/3] last few utils to clean up --- api/server/src/types/mod.rs | 0 api/server/src/utils/clients/ai/anthropic.rs | 238 ---------- .../src/utils/clients/ai/embedding_router.rs | 59 --- .../src/utils/clients/ai/hugging_face.rs | 57 --- api/server/src/utils/clients/ai/langfuse.rs | 394 ----------------- api/server/src/utils/clients/ai/llm_router.rs | 383 ---------------- api/server/src/utils/clients/ai/mod.rs | 7 - api/server/src/utils/clients/ai/ollama.rs | 67 --- api/server/src/utils/clients/ai/openai.rs | 417 ------------------ api/server/src/utils/clients/aws.rs | 65 --- api/server/src/utils/clients/mod.rs | 4 - api/server/src/utils/clients/posthog.rs | 190 -------- api/server/src/utils/clients/typesense.rs | 318 ------------- api/server/src/utils/mod.rs | 2 - api/server/src/utils/stored_values/mod.rs | 291 ------------ api/server/src/utils/stored_values/search.rs | 64 --- 16 files changed, 2556 deletions(-) delete mode 100644 api/server/src/types/mod.rs delete mode 100644 api/server/src/utils/clients/ai/anthropic.rs delete mode 100644 api/server/src/utils/clients/ai/embedding_router.rs delete mode 100644 api/server/src/utils/clients/ai/hugging_face.rs delete mode 100644 api/server/src/utils/clients/ai/langfuse.rs delete mode 100644 api/server/src/utils/clients/ai/llm_router.rs delete mode 100644 api/server/src/utils/clients/ai/mod.rs delete mode 100644 api/server/src/utils/clients/ai/ollama.rs delete mode 100644 api/server/src/utils/clients/ai/openai.rs delete mode 100644 api/server/src/utils/clients/aws.rs delete mode 100644 api/server/src/utils/clients/posthog.rs delete mode 100644 api/server/src/utils/clients/typesense.rs delete mode 100644 api/server/src/utils/stored_values/mod.rs delete mode 100644 api/server/src/utils/stored_values/search.rs diff --git a/api/server/src/types/mod.rs b/api/server/src/types/mod.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/api/server/src/utils/clients/ai/anthropic.rs b/api/server/src/utils/clients/ai/anthropic.rs deleted file mode 100644 index 7cd9fa9e8..000000000 --- a/api/server/src/utils/clients/ai/anthropic.rs +++ /dev/null @@ -1,238 +0,0 @@ -use anyhow::{anyhow, Result}; -use futures::StreamExt; -use std::{env, time::Duration}; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio_stream::wrappers::ReceiverStream; - -use serde::{Deserialize, Serialize}; - -use crate::utils::clients::sentry_utils::send_sentry_error; - -const ANTHROPIC_CHAT_URL: &str = "https://api.anthropic.com/v1/messages"; - -lazy_static::lazy_static! { - static ref ANTHROPIC_API_KEY: String = env::var("ANTHROPIC_API_KEY") - .expect("ANTHROPIC_API_KEY must be set"); - static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED") - .unwrap_or(String::from("true")) - .parse() - .expect("MONITORING_ENABLED must be a boolean"); -} - -#[derive(Serialize, Clone)] -pub enum AnthropicChatModel { - #[serde(rename = "claude-3-opus-20240229")] - Claude3Opus20240229, -} - -#[derive(Serialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum AnthropicChatRole { - User, - Assistant, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(rename_all = "lowercase")] -pub enum AnthropicContentType { - Text, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct AnthropicContent { - #[serde(rename = "type")] - pub _type: AnthropicContentType, - pub text: String, -} - -#[derive(Serialize, Clone)] -pub struct AnthropicChatMessage { - pub role: AnthropicChatRole, - pub content: Vec, -} - -#[derive(Serialize, Clone)] -pub struct AnthropicChatRequest { - pub model: AnthropicChatModel, - #[serde(skip_serializing_if = "Option::is_none")] - pub system: Option, - pub messages: Vec, - pub temperature: f32, - pub max_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop_sequences: Option>, - pub stream: bool, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct ChatCompletionResponse { - pub content: Vec, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct Content { - #[serde(rename = "type")] - pub _type: String, - pub text: String, -} - -pub async fn anthropic_chat( - model: &AnthropicChatModel, - system: Option, - messages: &Vec, - temperature: f32, - max_tokens: u32, - timeout: u64, - stop: Option>, -) -> Result { - let chat_request = AnthropicChatRequest { - model: model.clone(), - system, - messages: messages.clone(), - temperature, - max_tokens, - stop_sequences: stop, - stream: false, - }; - - let client = reqwest::Client::new(); - - let headers = { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - "x-api-key", - format!("{}", ANTHROPIC_API_KEY.to_string()) - .parse() - .unwrap(), - ); - headers.insert("anthropic-version", "2023-06-01".parse().unwrap()); - headers - }; - - let response = match client - .post(ANTHROPIC_CHAT_URL) - .headers(headers) - .json(&chat_request) - .timeout(Duration::from_secs(timeout)) - .send() - .await - { - Ok(response) => response, - Err(e) => { - tracing::error!("Unable to send request to Anthropic: {:?}", e); - let err = anyhow!("Unable to send request to Anthropic: {}", e); - send_sentry_error(&err.to_string(), None); - return Err(err); - } - }; - - let completion_res = match response.json::().await { - Ok(res) => res, - Err(e) => { - tracing::error!("Unable to parse response from Anthropic: {:?}", e); - let err = anyhow!("Unable to parse response from Anthropic: {}", e); - send_sentry_error(&err.to_string(), None); - return Err(err); - } - }; - - let content = match completion_res.content.get(0) { - Some(content) => content.text.clone(), - None => return Err(anyhow!("No content returned from Anthropic")), - }; - - Ok(content) -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct AnthropicChatDelta { - #[serde(rename = "type")] - pub _type: String, - pub delta: String, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct AnthropicChatStreamResponse { - #[serde(rename = "type")] - pub _type: String, - pub delta: Option, -} - -pub async fn anthropic_chat_stream( - model: &AnthropicChatModel, - system: Option, - messages: &Vec, - temperature: f32, - max_tokens: u32, - timeout: u64, - stop: Option>, -) -> Result> { - let chat_request = AnthropicChatRequest { - model: model.clone(), - system, - messages: messages.clone(), - temperature, - max_tokens, - stream: true, - stop_sequences: stop, - }; - - let client = reqwest::Client::new(); - - let headers = { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Bearer {}", ANTHROPIC_API_KEY.to_string()) - .parse() - .unwrap(), - ); - headers - }; - - let (_tx, rx): (Sender, Receiver) = mpsc::channel(100); - - tokio::spawn(async move { - let response = client - .post(ANTHROPIC_CHAT_URL) - .headers(headers) - .json(&chat_request) - .timeout(Duration::from_secs(timeout)) - .send() - .await - .map_err(|e| { - tracing::error!("Unable to send request to Anthropic: {:?}", e); - let err = anyhow!("Unable to send request to Anthropic: {}", e); - send_sentry_error(&err.to_string(), None); - err - }); - - if let Err(_e) = response { - return; - } - - let response = response.unwrap(); - let mut stream = response.bytes_stream(); - - let _buffer = String::new(); - - while let Some(item) = stream.next().await { - match item { - Ok(bytes) => { - let chunk = String::from_utf8(bytes.to_vec()).unwrap(); - println!("----------------------"); - println!("Chunk: {}", chunk); - println!("----------------------"); - } - Err(e) => { - tracing::error!("Error while streaming response: {:?}", e); - let err = anyhow!("Error while streaming response: {}", e); - send_sentry_error(&err.to_string(), None); - break; - } - } - } - }); - - Ok(ReceiverStream::new(rx)) -} diff --git a/api/server/src/utils/clients/ai/embedding_router.rs b/api/server/src/utils/clients/ai/embedding_router.rs deleted file mode 100644 index b10cb495d..000000000 --- a/api/server/src/utils/clients/ai/embedding_router.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::env; - -use anyhow::{anyhow, Result}; -use serde::{Deserialize, Serialize}; -use tokio::task; - -use super::{ - hugging_face::hugging_face_embedding, ollama::ollama_embedding, openai::ada_bulk_embedding, -}; - -#[derive(Serialize)] -pub struct EmbeddingRequest { - pub prompt: String, -} - -#[derive(Deserialize, Debug)] -pub struct EmbeddingResponse { - pub embedding: Vec>, -} - -pub enum EmbeddingProvider { - OpenAi, - Ollama, - HuggingFace, -} - -impl EmbeddingProvider { - pub fn get_embedding_provider() -> Result { - let embedding_provider = - env::var("EMBEDDING_PROVIDER").expect("An embedding provider is required."); - match embedding_provider.as_str() { - "openai" => Ok(EmbeddingProvider::OpenAi), - "ollama" => Ok(EmbeddingProvider::Ollama), - "huggingface" => Ok(EmbeddingProvider::HuggingFace), - _ => Err(anyhow!("Invalid embedding provider")), - } - } -} - -pub async fn embedding_router(prompts: Vec, for_retrieval: bool) -> Result>> { - let embedding_provider = EmbeddingProvider::get_embedding_provider()?; - - match embedding_provider { - EmbeddingProvider::Ollama => { - let tasks = prompts.into_iter().map(|prompt| { - task::spawn(async move { ollama_embedding(prompt, for_retrieval).await }) - }); - let results = futures::future::join_all(tasks).await; - let embeddings: Result>> = results - .into_iter() - .collect::, _>>()? - .into_iter() - .collect(); - embeddings - } - EmbeddingProvider::HuggingFace => hugging_face_embedding(prompts).await, - EmbeddingProvider::OpenAi => ada_bulk_embedding(prompts).await, - } -} diff --git a/api/server/src/utils/clients/ai/hugging_face.rs b/api/server/src/utils/clients/ai/hugging_face.rs deleted file mode 100644 index b40d33be5..000000000 --- a/api/server/src/utils/clients/ai/hugging_face.rs +++ /dev/null @@ -1,57 +0,0 @@ -use anyhow::{anyhow, Result}; -use serde::Serialize; -use std::env; - -use axum::http::HeaderMap; - -#[derive(Serialize)] -pub struct HuggingFaceEmbeddingRequest { - pub inputs: Vec, -} - -pub async fn hugging_face_embedding(prompts: Vec) -> Result>> { - let hugging_face_url = env::var("HUGGING_FACE_URL").expect("HUGGING_FACE_URL must be set"); - let hugging_face_api_key = - env::var("HUGGING_FACE_API_KEY").expect("HUGGING_FACE_API_KEY must be set"); - - let client = match reqwest::Client::builder().build() { - Ok(client) => client, - Err(e) => { - return Err(anyhow!("Error creating reqwest client: {:?}", e)); - } - }; - - let mut headers = HeaderMap::new(); - headers.insert( - reqwest::header::CONTENT_TYPE, - "application/json".parse().unwrap(), - ); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Bearer {}", hugging_face_api_key).parse().unwrap(), - ); - - let req = HuggingFaceEmbeddingRequest { inputs: prompts }; - - let res = match client - .post(hugging_face_url) - .headers(headers) - .json(&req) - .send() - .await - { - Ok(res) => res, - Err(e) => { - return Err(anyhow!("Error sending Ollama request: {:?}", e)); - } - }; - - let embeddings = match res.json::>>().await { - Ok(res) => res, - Err(e) => { - return Err(anyhow!("Error parsing Ollama response: {:?}", e)); - } - }; - - Ok(embeddings) -} diff --git a/api/server/src/utils/clients/ai/langfuse.rs b/api/server/src/utils/clients/ai/langfuse.rs deleted file mode 100644 index d9017606f..000000000 --- a/api/server/src/utils/clients/ai/langfuse.rs +++ /dev/null @@ -1,394 +0,0 @@ -use anyhow::{anyhow, Result}; -use axum::http::HeaderMap; -use base64::Engine; -use reqwest::Method; -use std::env; -use tiktoken_rs::o200k_base; - -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::utils::clients::sentry_utils::send_sentry_error; - -use super::{anthropic::AnthropicChatModel, llm_router::LlmModel, openai::OpenAiChatModel}; - -lazy_static::lazy_static! { - static ref LANGFUSE_API_URL: String = env::var("LANGFUSE_API_URL").unwrap_or("https://us.cloud.langfuse.com".to_string()); - static ref LANGFUSE_API_PUBLIC_KEY: String = env::var("LANGFUSE_PUBLIC_API_KEY").expect("LANGFUSE_PUBLIC_API_KEY must be set"); - static ref LANGFUSE_API_PRIVATE_KEY: String = env::var("LANGFUSE_PRIVATE_API_KEY").expect("LANGFUSE_PRIVATE_API_KEY must be set"); -} - -impl LlmModel { - pub fn generate_usage(&self, input: &String, output: &String) -> Usage { - let bpe = o200k_base().unwrap(); - - let input_token = bpe.encode_with_special_tokens(&input); - let output_token = bpe.encode_with_special_tokens(&output); - - match self { - LlmModel::OpenAi(OpenAiChatModel::O3Mini) => Usage { - input: input_token.len() as u32, - output: output_token.len() as u32, - unit: "TOKENS".to_string(), - input_cost: (input_token.len() as f64 / 1_000_000.0) * 2.5, - output_cost: (output_token.len() as f64 / 1_000_000.0) * 10.0, - total_cost: (input_token.len() as f64 / 1_000_000.0) * 2.5 - + (output_token.len() as f64 / 1_000_000.0) * 10.0, - }, - LlmModel::OpenAi(OpenAiChatModel::Gpt35Turbo) => Usage { - input: input_token.len() as u32, - output: output_token.len() as u32, - unit: "TOKENS".to_string(), - input_cost: (input_token.len() as f64 / 1_000_000.0) * 0.5, - output_cost: (output_token.len() as f64 / 1_000_000.0) * 1.5, - total_cost: (input_token.len() as f64 / 1_000_000.0) * 0.5 - + (output_token.len() as f64 / 1_000_000.0) * 1.5, - }, - LlmModel::OpenAi(OpenAiChatModel::Gpt4o) => Usage { - input: input_token.len() as u32, - output: output_token.len() as u32, - unit: "TOKENS".to_string(), - input_cost: (input_token.len() as f64 / 1_000_000.0) * 0.15, - output_cost: (output_token.len() as f64 / 1_000_000.0) * 0.6, - total_cost: (input_token.len() as f64 / 1_000_000.0) * 0.15 - + (output_token.len() as f64 / 1_000_000.0) * 0.6, - }, - LlmModel::Anthropic(AnthropicChatModel::Claude3Opus20240229) => Usage { - input: input_token.len() as u32, - output: output_token.len() as u32, - unit: "TOKENS".to_string(), - input_cost: (input_token.len() as f64 / 1_000_000.0) * 3.0, - output_cost: (output_token.len() as f64 / 1_000_000.0) * 15.0, - total_cost: (input_token.len() as f64 / 1_000_000.0) * 3.0 - + (output_token.len() as f64 / 1_000_000.0) * 15.0, - }, - } - } -} - -#[derive(Serialize, Clone, Debug)] -#[serde(rename_all = "snake_case")] -pub enum PromptName { - SelectDataset, - GenerateSql, - SelectTerm, - DataSummary, - AutoChartConfig, - LineChartConfig, - BarChartConfig, - ScatterChartConfig, - PieChartConfig, - MetricChartConfig, - TableConfig, - ColumnLabelFormat, - AdvancedVisualizationConfig, - NoDataReturnedResponse, - DataExplanation, - MetricTitle, - TimeFrame, - FixSqlPlanner, - FixSql, - GenerateColDescriptions, - GenerateDatasetDescription, - SummaryQuestion, - CustomPrompt(String), -} - -impl PromptName { - fn to_string(&self) -> String { - match self { - PromptName::SelectDataset => "select_dataset".to_string(), - PromptName::GenerateSql => "generate_sql".to_string(), - PromptName::SelectTerm => "select_term".to_string(), - PromptName::DataSummary => "data_summary".to_string(), - PromptName::AutoChartConfig => "auto_chart_config".to_string(), - PromptName::LineChartConfig => "line_chart_config".to_string(), - PromptName::BarChartConfig => "bar_chart_config".to_string(), - PromptName::ScatterChartConfig => "scatter_chart_config".to_string(), - PromptName::PieChartConfig => "pie_chart_config".to_string(), - PromptName::MetricChartConfig => "metric_chart_config".to_string(), - PromptName::TableConfig => "table_config".to_string(), - PromptName::ColumnLabelFormat => "column_label_format".to_string(), - PromptName::AdvancedVisualizationConfig => "advanced_visualization_config".to_string(), - PromptName::NoDataReturnedResponse => "no_data_returned_response".to_string(), - PromptName::DataExplanation => "data_explanation".to_string(), - PromptName::MetricTitle => "metric_title".to_string(), - PromptName::TimeFrame => "time_frame".to_string(), - PromptName::FixSqlPlanner => "fix_sql_planner".to_string(), - PromptName::FixSql => "fix_sql".to_string(), - PromptName::GenerateColDescriptions => "generate_col_descriptions".to_string(), - PromptName::GenerateDatasetDescription => "generate_dataset_description".to_string(), - PromptName::SummaryQuestion => "summary_question".to_string(), - PromptName::CustomPrompt(prompt) => prompt.clone(), - } - } -} - -#[derive(Serialize, Debug)] -#[serde(rename_all = "kebab-case")] -enum LangfuseIngestionType { - TraceCreate, - GenerationCreate, -} - -#[derive(Serialize, Debug)] -#[serde(rename_all = "camelCase")] -struct CreateTraceBody { - id: Uuid, - timestamp: DateTime, - name: String, - user_id: Uuid, - input: String, - output: String, - session_id: Uuid, - release: String, - version: String, - metadata: Metadata, - tags: Vec, - public: bool, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct GenerationCreateBody { - trace_id: Uuid, - name: String, - start_time: DateTime, - completion_start_time: DateTime, - input: String, - output: String, - level: String, - end_time: DateTime, - model: LlmModel, - id: Uuid, - usage: Usage, -} - -#[derive(Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct Usage { - input: u32, - output: u32, - unit: String, - input_cost: f64, - output_cost: f64, - total_cost: f64, -} - -#[derive(Serialize, Debug)] -struct Metadata {} - -#[derive(Serialize)] -#[serde(untagged)] -enum LangfuseRequestBody { - CreateTraceBody(CreateTraceBody), - GenerationCreateBody(GenerationCreateBody), -} - -#[derive(Serialize)] -struct LangfuseBatchItem { - id: Uuid, - r#type: LangfuseIngestionType, - body: LangfuseRequestBody, - timestamp: DateTime, -} - -#[derive(Serialize)] -struct LangfuseBatch { - batch: Vec, -} - -#[allow(dead_code)] -#[derive(Deserialize, Debug)] -struct LangfuseError { - id: String, - status: u16, - message: String, - error: String, -} - -#[allow(dead_code)] -#[derive(Deserialize, Debug)] -struct LangfuseSuccess { - id: String, - status: u16, -} - -#[allow(dead_code)] -#[derive(Deserialize, Debug)] -struct LangfuseResponse { - successes: Vec, - errors: Vec, -} - -/// Args: -/// - session_id: this can be a thread_id or any other type of chain event we have -/// - -pub async fn send_langfuse_request( - session_id: &Uuid, - prompt_name: PromptName, - context: Option, - start_time: DateTime, - end_time: DateTime, - input: String, - output: String, - user_id: &Uuid, - langfuse_model: &LlmModel, -) -> () { - let session_id = session_id.clone(); - let user_id = user_id.clone(); - let langfuse_model = langfuse_model.clone(); - - tokio::spawn(async move { - match langfuse_handler( - session_id, - prompt_name, - context, - start_time, - end_time, - input, - output, - user_id, - langfuse_model, - ) - .await - { - Ok(_) => (), - Err(e) => { - let err = anyhow!("Error sending Langfuse request: {:?}", e); - send_sentry_error(&err.to_string(), Some(&user_id)); - } - } - }); -} - -async fn langfuse_handler( - session_id: Uuid, - prompt_name: PromptName, - context: Option, - start_time: DateTime, - end_time: DateTime, - input: String, - output: String, - user_id: Uuid, - langfuse_model: LlmModel, -) -> Result<()> { - let input = match context { - Some(context) => format!("{} \n\n {}", context, input), - None => input, - }; - - let trace_id = Uuid::new_v4(); - - let langfuse_trace = LangfuseBatchItem { - id: Uuid::new_v4(), - r#type: LangfuseIngestionType::TraceCreate, - body: LangfuseRequestBody::CreateTraceBody(CreateTraceBody { - id: trace_id.clone(), - timestamp: Utc::now(), - name: prompt_name.clone().to_string(), - user_id: user_id, - input: serde_json::to_string(&input).unwrap(), - output: serde_json::to_string(&output).unwrap(), - session_id: session_id, - release: "1.0.0".to_string(), - version: "1.0.0".to_string(), - metadata: Metadata {}, - tags: vec![], - public: false, - }), - timestamp: Utc::now(), - }; - - let langfuse_generation = LangfuseBatchItem { - id: Uuid::new_v4(), - r#type: LangfuseIngestionType::GenerationCreate, - body: LangfuseRequestBody::GenerationCreateBody(GenerationCreateBody { - id: Uuid::new_v4(), - name: prompt_name.to_string(), - input: serde_json::to_string(&input).unwrap(), - output: serde_json::to_string(&output).unwrap(), - trace_id, - start_time, - completion_start_time: start_time, - level: "DEBUG".to_string(), - end_time, - model: langfuse_model.clone(), - usage: langfuse_model.generate_usage(&input, &output), - }), - timestamp: Utc::now(), - }; - - let langfuse_batch = LangfuseBatch { - batch: vec![langfuse_trace, langfuse_generation], - }; - - let client = match reqwest::Client::builder().build() { - Ok(client) => client, - Err(e) => { - send_sentry_error(&e.to_string(), Some(&user_id)); - return Err(anyhow!("Error creating reqwest client: {:?}", e)); - } - }; - - let mut headers = HeaderMap::new(); - headers.insert( - reqwest::header::CONTENT_TYPE, - "application/json".parse().unwrap(), - ); - headers.insert( - reqwest::header::AUTHORIZATION, - format!( - "Basic {}", - base64::engine::general_purpose::STANDARD.encode(format!( - "{}:{}", - *LANGFUSE_API_PUBLIC_KEY, *LANGFUSE_API_PRIVATE_KEY - )) - ) - .parse() - .unwrap(), - ); - - let res = match client - .request( - Method::POST, - LANGFUSE_API_URL.to_string() + "/api/public/ingestion", - ) - .headers(headers) - .json(&langfuse_batch) - .send() - .await - { - Ok(res) => res, - Err(e) => { - tracing::debug!("Error sending Langfuse request: {:?}", e); - send_sentry_error(&e.to_string(), Some(&user_id)); - return Err(anyhow!("Error sending Langfuse request: {:?}", e)); - } - }; - - let langfuse_res: LangfuseResponse = match res.json::().await { - Ok(res) => res, - Err(e) => { - tracing::debug!("Error parsing Langfuse response: {:?}", e); - send_sentry_error(&e.to_string(), Some(&user_id)); - return Err(anyhow!("Error parsing Langfuse response: {:?}", e)); - } - }; - - if langfuse_res.errors.len() > 0 { - for err in &langfuse_res.errors { - tracing::debug!("Langfuse error: {:?}", err); - send_sentry_error(&err.message, Some(&user_id)); - } - - return Err(anyhow::anyhow!( - "Langfuse errors: {:?}", - langfuse_res.errors - )); - } - - Ok(()) -} diff --git a/api/server/src/utils/clients/ai/llm_router.rs b/api/server/src/utils/clients/ai/llm_router.rs deleted file mode 100644 index 3e8720930..000000000 --- a/api/server/src/utils/clients/ai/llm_router.rs +++ /dev/null @@ -1,383 +0,0 @@ -use std::env; - -use anyhow::{anyhow, Result}; -use chrono::Utc; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::{ - sync::mpsc::{self, Receiver}, - task::JoinHandle, -}; -use tokio_stream::{wrappers::ReceiverStream, StreamExt}; -use uuid::Uuid; - -use super::{ - anthropic::{ - anthropic_chat, anthropic_chat_stream, AnthropicChatMessage, AnthropicChatModel, - AnthropicChatRole, AnthropicContent, AnthropicContentType, - }, - langfuse::{send_langfuse_request, PromptName}, - openai::{ - openai_chat, openai_chat_stream, OpenAiChatContent, OpenAiChatMessage, OpenAiChatModel, - OpenAiChatRole, - }, -}; -use lazy_static::lazy_static; - -lazy_static! { - static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED") - .unwrap_or(String::from("true")) - .parse() - .expect("MONITORING_ENABLED must be a boolean"); -} - -#[derive(Serialize, Clone)] -#[serde(untagged)] -pub enum LlmModel { - Anthropic(AnthropicChatModel), - OpenAi(OpenAiChatModel), -} - -#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum LlmRole { - System, - User, - Assistant, -} - -#[derive(Serialize, Deserialize, Clone)] -pub struct LlmMessage { - pub role: LlmRole, - pub content: String, -} - -impl LlmMessage { - pub fn new(role: String, content: String) -> Self { - let role = match role.to_lowercase().as_str() { - "system" => LlmRole::System, - "user" => LlmRole::User, - "assistant" => LlmRole::Assistant, - _ => LlmRole::User, - }; - Self { role, content } - } -} - -pub async fn llm_chat( - model: LlmModel, - messages: &Vec, - temperature: f32, - max_tokens: u32, - timeout: u64, - stop: Option>, - json_mode: bool, - json_schema: Option, - session_id: &Uuid, - user_id: &Uuid, - prompt_name: PromptName, -) -> Result { - let start_time = Utc::now(); - - let response_result = match &model { - LlmModel::Anthropic(model) => { - anthropic_chat_compiler(model, messages, max_tokens, temperature, timeout, stop).await - } - LlmModel::OpenAi(model) => { - openai_chat_compiler( - model, - messages, - 7048, - temperature, - timeout, - stop, - json_mode, - json_schema, - ) - .await - } - }; - - let response = match response_result { - Ok(response) => response, - Err(e) => return Err(anyhow!("LLM chat error: {}", e)), - }; - - let end_time = Utc::now(); - - send_langfuse_request( - session_id, - prompt_name, - None, - start_time, - end_time, - serde_json::to_string(&messages).unwrap(), - serde_json::to_string(&response).unwrap(), - user_id, - &model, - ) - .await; - - Ok(response) -} - -pub async fn llm_chat_stream( - model: LlmModel, - messages: Vec, - temperature: f32, - max_tokens: u32, - timeout: u64, - stop: Option>, - session_id: &Uuid, - user_id: &Uuid, - prompt_name: PromptName, -) -> Result<(Receiver, JoinHandle>)> { - let start_time = Utc::now(); - - let stream_result = match &model { - LlmModel::Anthropic(model) => { - anthropic_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop) - .await - } - LlmModel::OpenAi(model) => { - openai_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop) - .await - } - }; - - let mut stream = match stream_result { - Ok(stream) => stream, - Err(e) => return Err(anyhow!("LLM chat error: {}", e)), - }; - - let (tx, rx) = mpsc::channel(100); - - let res_future = { - let session_id = session_id.clone(); - let user_id = user_id.clone(); - - tokio::spawn(async move { - let mut response = String::new(); - while let Some(content) = stream.next().await { - response.push_str(&content); - - match tx.send(content).await { - Ok(_) => (), - Err(e) => return Err(anyhow!("Streaming Error: {}", e)), - } - } - - let end_time = Utc::now(); - - send_langfuse_request( - &session_id, - prompt_name, - None, - start_time, - end_time, - serde_json::to_string(&messages).unwrap(), - serde_json::to_string(&response).unwrap(), - &user_id, - &model, - ) - .await; - - Ok(response) - }) - }; - - Ok((rx, res_future)) -} - -async fn anthropic_chat_compiler( - model: &AnthropicChatModel, - messages: &Vec, - _max_tokens: u32, - temperature: f32, - timeout: u64, - stop: Option>, -) -> Result { - let system_message = match messages.iter().find(|m| m.role == LlmRole::System) { - Some(message) => Some(message.content.clone()), - None => None, - }; - let mut anthropic_messages = Vec::new(); - - for message in messages { - let anthropic_role = match message.role { - LlmRole::System => continue, - LlmRole::User => AnthropicChatRole::User, - LlmRole::Assistant => AnthropicChatRole::Assistant, - }; - - let anthropic_content = AnthropicContent { - text: message.content.clone(), - _type: AnthropicContentType::Text, - }; - - anthropic_messages.push(AnthropicChatMessage { - role: anthropic_role, - content: vec![anthropic_content], - }); - } - - let response = match anthropic_chat( - model, - system_message, - &anthropic_messages, - temperature, - 7048, - timeout, - stop, - ) - .await - { - Ok(response) => response, - Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)), - }; - - Ok(response) -} - -async fn openai_chat_compiler( - model: &OpenAiChatModel, - messages: &Vec, - max_tokens: u32, - temperature: f32, - timeout: u64, - stop: Option>, - json_mode: bool, - json_schema: Option, -) -> Result { - let mut openai_messages = Vec::new(); - - for message in messages { - let openai_role = match message.role { - LlmRole::System => OpenAiChatRole::System, - LlmRole::User => OpenAiChatRole::User, - LlmRole::Assistant => OpenAiChatRole::Assistant, - }; - - let openai_message = OpenAiChatMessage { - role: openai_role, - content: vec![OpenAiChatContent { - type_: "text".to_string(), - text: message.content.clone(), - }], - }; - - openai_messages.push(openai_message); - } - - let response = match openai_chat( - &model, - openai_messages, - temperature, - max_tokens, - timeout, - stop, - json_mode, - json_schema, - ) - .await - { - Ok(response) => response, - Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)), - }; - - Ok(response) -} - -async fn anthropic_chat_stream_compiler( - model: &AnthropicChatModel, - messages: &Vec, - max_tokens: u32, - temperature: f32, - timeout: u64, - stop: Option>, -) -> Result> { - let system_message = match messages.iter().find(|m| m.role == LlmRole::System) { - Some(message) => Some(message.content.clone()), - None => None, - }; - let mut anthropic_messages = Vec::new(); - - for message in messages { - let anthropic_role = match message.role { - LlmRole::System => continue, - LlmRole::User => AnthropicChatRole::User, - LlmRole::Assistant => AnthropicChatRole::Assistant, - }; - - let anthropic_content = AnthropicContent { - text: message.content.clone(), - _type: AnthropicContentType::Text, - }; - - anthropic_messages.push(AnthropicChatMessage { - role: anthropic_role, - content: vec![anthropic_content], - }); - } - - let stream = match anthropic_chat_stream( - model, - system_message, - &anthropic_messages, - temperature, - max_tokens, - timeout, - stop, - ) - .await - { - Ok(response) => response, - Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)), - }; - - Ok(stream) -} - -async fn openai_chat_stream_compiler( - model: &OpenAiChatModel, - messages: &Vec, - max_tokens: u32, - temperature: f32, - timeout: u64, - stop: Option>, -) -> Result> { - let mut openai_messages = Vec::new(); - - for message in messages { - let openai_role = match message.role { - LlmRole::System => OpenAiChatRole::System, - LlmRole::User => OpenAiChatRole::User, - LlmRole::Assistant => OpenAiChatRole::Assistant, - }; - - let openai_message = OpenAiChatMessage { - role: openai_role, - content: vec![OpenAiChatContent { - type_: "text".to_string(), - text: message.content.clone(), - }], - }; - - openai_messages.push(openai_message); - } - - let stream = match openai_chat_stream( - &model, - &openai_messages, - temperature, - max_tokens, - timeout, - stop, - ) - .await - { - Ok(response) => response, - Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)), - }; - - Ok(stream) -} diff --git a/api/server/src/utils/clients/ai/mod.rs b/api/server/src/utils/clients/ai/mod.rs deleted file mode 100644 index f4368095b..000000000 --- a/api/server/src/utils/clients/ai/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod anthropic; -pub mod embedding_router; -mod hugging_face; -pub mod langfuse; -pub mod llm_router; -pub mod ollama; -pub mod openai; \ No newline at end of file diff --git a/api/server/src/utils/clients/ai/ollama.rs b/api/server/src/utils/clients/ai/ollama.rs deleted file mode 100644 index d69e56c77..000000000 --- a/api/server/src/utils/clients/ai/ollama.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::env; - -use anyhow::{anyhow, Result}; -use axum::http::HeaderMap; -use serde::{Deserialize, Serialize}; - -#[derive(Serialize)] -pub struct OllamaEmbeddingRequest { - pub prompt: String, - pub model: String, -} - -#[derive(Deserialize, Debug)] -pub struct OllamaEmbeddingResponse { - pub embedding: Vec, -} - -pub async fn ollama_embedding(prompt: String, for_retrieval: bool) -> Result> { - let ollama_url = env::var("OLLAMA_URL").unwrap_or(String::from("http://localhost:11434")); - let model = env::var("EMBEDDING_MODEL").unwrap_or(String::from("mxbai-embed-large")); - - let prompt = if model == "mxbai-embed-large" && for_retrieval { - format!( - "Represent this sentence for searching relevant passages:: {}", - prompt - ) - } else { - prompt - }; - - let client = match reqwest::Client::builder().build() { - Ok(client) => client, - Err(e) => { - return Err(anyhow!("Error creating reqwest client: {:?}", e)); - } - }; - - let mut headers = HeaderMap::new(); - headers.insert( - reqwest::header::CONTENT_TYPE, - "application/json".parse().unwrap(), - ); - - let req = OllamaEmbeddingRequest { prompt, model }; - - let res = match client - .post(format!("{}/api/embeddings", ollama_url)) - .headers(headers) - .json(&req) - .send() - .await - { - Ok(res) => res, - Err(e) => { - return Err(anyhow!("Error sending Ollama request: {:?}", e)); - } - }; - - let ollama_res: OllamaEmbeddingResponse = match res.json::().await { - Ok(res) => res, - Err(e) => { - return Err(anyhow!("Error parsing Ollama response: {:?}", e)); - } - }; - - Ok(ollama_res.embedding) -} diff --git a/api/server/src/utils/clients/ai/openai.rs b/api/server/src/utils/clients/ai/openai.rs deleted file mode 100644 index f8462a391..000000000 --- a/api/server/src/utils/clients/ai/openai.rs +++ /dev/null @@ -1,417 +0,0 @@ -use anyhow::{anyhow, Result}; -use futures::StreamExt; -use serde_json::{json, Value}; -use std::{env, time::Duration}; -use tokio::sync::mpsc::{self, Receiver, Sender}; -use tokio_stream::wrappers::ReceiverStream; - -use serde::{Deserialize, Serialize}; - -use crate::utils::clients::sentry_utils::send_sentry_error; - -const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings"; - -lazy_static::lazy_static! { - static ref OPENAI_API_KEY: String = env::var("OPENAI_API_KEY") - .expect("OPENAI_API_KEY must be set"); - static ref OPENAI_CHAT_URL: String = env::var("OPENAI_CHAT_URL").unwrap_or("https://api.openai.com/v1/chat/completions".to_string()); -} - -#[derive(Serialize, Clone)] -pub enum OpenAiChatModel { - #[serde(rename = "gpt-4o-2024-11-20")] - Gpt4o, - #[serde(rename = "o3-mini")] - O3Mini, - #[serde(rename = "gpt-3.5-turbo")] - Gpt35Turbo, -} - -#[derive(Serialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum OpenAiChatRole { - System, - Developer, - User, - Assistant, -} - -#[derive(Serialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum ReasoningEffort { - Low, - Medium, - High, -} - -#[derive(Serialize, Clone)] -pub struct OpenAiChatContent { - #[serde(rename = "type")] - pub type_: String, - pub text: String, -} - -#[derive(Serialize, Clone)] -pub struct OpenAiChatMessage { - pub role: OpenAiChatRole, - pub content: Vec, -} - -// Helper functions for conditional serialization -fn is_o3_model(model: &OpenAiChatModel) -> bool { - matches!(model, OpenAiChatModel::O3Mini) -} - -#[derive(Serialize, Clone)] -pub struct OpenAiChatRequest { - model: OpenAiChatModel, - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - reasoning_effort: Option, - frequency_penalty: f32, - presence_penalty: f32, - #[serde(skip_serializing_if = "Option::is_none")] - stop: Option>, - stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] - response_format: Option, -} - -impl OpenAiChatRequest { - pub fn new( - model: OpenAiChatModel, - messages: Vec, - temperature: f32, - max_tokens: u32, - stop: Option>, - stream: bool, - response_format: Option, - ) -> Self { - let (temperature, max_tokens, top_p, reasoning_effort) = if is_o3_model(&model) { - (None, None, None, Some(ReasoningEffort::Low)) - } else { - (Some(temperature), Some(max_tokens), Some(1.0), None) - }; - - Self { - model, - messages, - temperature, - max_tokens, - top_p, - reasoning_effort, - frequency_penalty: 0.0, - presence_penalty: 0.0, - stop, - stream, - response_format, - } - } -} - -#[derive(Deserialize, Debug, Clone)] -pub struct ChatCompletionResponse { - pub choices: Vec, -} - -#[derive(Deserialize, Debug, Clone)] -pub struct Choice { - pub message: Message, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct Message { - pub content: String, -} - -pub async fn openai_chat( - model: &OpenAiChatModel, - messages: Vec, - temperature: f32, - max_tokens: u32, - timeout: u64, - stop: Option>, - json_mode: bool, - json_schema: Option, -) -> Result { - let response_format = match json_mode { - true => Some(json!({"type": "json_object"})), - false => None, - }; - - let response_format = match json_schema { - Some(schema) => Some(json!({"type": "json_schema", "json_schema": schema})), - None => response_format, - }; - - let chat_request = OpenAiChatRequest::new( - model.clone(), - messages, - temperature, - max_tokens, - stop, - false, - response_format, - ); - - let client = reqwest::Client::new(); - - let headers = { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Bearer {}", OPENAI_API_KEY.to_string()) - .parse() - .unwrap(), - ); - headers - }; - - let response = match client - .post(OPENAI_CHAT_URL.to_string()) - .headers(headers) - .json(&chat_request) - .timeout(Duration::from_secs(timeout)) - .send() - .await - { - Ok(response) => response, - Err(e) => { - tracing::error!("Unable to send request to OpenAI: {:?}", e); - let err = anyhow!("Unable to send request to OpenAI: {}", e); - send_sentry_error(&err.to_string(), None); - return Err(err); - } - }; - - let response_text = response.text().await.unwrap(); - - let completion_res = match serde_json::from_str::(&response_text) { - Ok(res) => res, - Err(e) => { - tracing::error!("Unable to parse response from OpenAI: {:?}", e); - let err = anyhow!("Unable to parse response from OpenAI: {}", e); - send_sentry_error(&err.to_string(), None); - return Err(err); - } - }; - - let content = match completion_res.choices.get(0) { - Some(choice) => choice.message.content.clone(), - None => return Err(anyhow!("No content returned from OpenAI")), - }; - - Ok(content) -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAiChatDelta { - pub content: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAiChatChoice { - pub delta: OpenAiChatDelta, - pub index: u32, - pub finish_reason: Option, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct OpenAiChatStreamResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub system_fingerprint: String, - pub choices: Vec, -} - -/// We can't do Langfuse traces directly in the stream function. -/// This is mainly because I'm too lazy to figure out how to set up all the passing with the stream reciever -/// Langfuse traces must be written directly on the stream call. - -pub async fn openai_chat_stream( - model: &OpenAiChatModel, - messages: &Vec, - temperature: f32, - max_tokens: u32, - timeout: u64, - stop: Option>, -) -> Result> { - let chat_request = OpenAiChatRequest::new( - model.clone(), - messages.clone(), - temperature, - max_tokens, - stop, - true, - None, - ); - - let client = reqwest::Client::new(); - - let headers = { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Bearer {}", OPENAI_API_KEY.to_string()) - .parse() - .unwrap(), - ); - headers - }; - - let (tx, rx): (Sender, Receiver) = mpsc::channel(100); - - tokio::spawn(async move { - let response = client - .post(OPENAI_CHAT_URL.to_string()) - .headers(headers) - .json(&chat_request) - .timeout(Duration::from_secs(timeout)) - .send() - .await - .map_err(|e| { - tracing::error!("Unable to send request to OpenAI: {:?}", e); - let err = anyhow!("Unable to send request to OpenAI: {}", e); - send_sentry_error(&err.to_string(), None); - err - }); - - if let Err(_e) = response { - return; - } - - let response = response.unwrap(); - let mut stream = response.bytes_stream(); - - let mut buffer = String::new(); - - while let Some(item) = stream.next().await { - match item { - Ok(bytes) => { - let chunk = String::from_utf8(bytes.to_vec()).unwrap(); - buffer.push_str(&chunk); - - while let Some(pos) = buffer.find("}\n") { - let (json_str, rest) = { - let (json_str, rest) = buffer.split_at(pos + 1); - (json_str.to_string(), rest.to_string()) - }; - buffer = rest; - - let json_str_trimmed = json_str.replace("data: ", ""); - match serde_json::from_str::( - &json_str_trimmed.trim(), - ) { - Ok(response) => { - if let Some(content) = &response.choices[0].delta.content { - if tx.send(content.clone()).await.is_err() { - break; - } - } - } - Err(e) => { - tracing::error!("Error parsing JSON response: {:?}", e); - let err = anyhow!("Error parsing JSON response: {}", e); - send_sentry_error(&err.to_string(), None); - } - } - } - } - Err(e) => { - tracing::error!("Error while streaming response: {:?}", e); - let err = anyhow!("Error while streaming response: {}", e); - send_sentry_error(&err.to_string(), None); - break; - } - } - } - }); - - Ok(ReceiverStream::new(rx)) -} - -#[derive(Serialize, Debug)] -pub struct AdaBulkEmbedding { - pub model: String, - pub input: Vec, - pub dimensions: u32, -} - -#[derive(Deserialize, Debug)] -pub struct AdaEmbeddingArray { - pub embedding: Vec, -} - -#[derive(Deserialize, Debug)] -pub struct AdaEmbeddingResponse { - pub data: Vec, -} - -pub async fn ada_bulk_embedding(text_list: Vec) -> Result>> { - let embedding_model = env::var("EMBEDDING_MODEL").expect("EMBEDDING_MODEL must be set"); - - let client = reqwest::Client::new(); - - let ada_bulk_embedding = AdaBulkEmbedding { - model: embedding_model, - input: text_list, - dimensions: 1024, - }; - - let embeddings_result = match client - .post(OPENAI_EMBEDDING_URL) - .headers({ - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Bearer {}", OPENAI_API_KEY.to_string()) - .parse() - .unwrap(), - ); - headers - }) - .timeout(Duration::from_secs(60)) - .json(&ada_bulk_embedding) - .send() - .await - { - Ok(res) => { - if !res.status().is_success() { - tracing::error!( - "There was an issue while getting the data source: {}", - res.text().await.unwrap() - ); - return Err(anyhow!("Error getting data source")); - } - - Ok(res) - } - Err(e) => Err(anyhow!(e.to_string())), - }; - - let embedding_text = embeddings_result.unwrap().text().await.unwrap(); - - let embeddings: AdaEmbeddingResponse = - match serde_json::from_str(embedding_text.clone().as_str()) { - Ok(embeddings) => embeddings, - Err(e) => { - tracing::error!( - "There was an issue decoding the bulk embedding json response: {}", - e - ); - return Err(anyhow!(e.to_string())); - } - }; - - let result: Vec> = embeddings.data.into_iter().map(|x| x.embedding).collect(); - - Ok(result) -} diff --git a/api/server/src/utils/clients/aws.rs b/api/server/src/utils/clients/aws.rs deleted file mode 100644 index 4587fa1e4..000000000 --- a/api/server/src/utils/clients/aws.rs +++ /dev/null @@ -1,65 +0,0 @@ -use anyhow::{anyhow, Result}; -use aws_config::meta::region::RegionProviderChain; -use aws_sdk_secretsmanager::Client; -use uuid::Uuid; - -pub async fn create_db_secret(secret: String) -> Result { - let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); - let config = aws_config::from_env().region(region_provider).load().await; - let client = Client::new(&config); - - let secret_id = Uuid::new_v4().to_string(); - - match client - .create_secret() - .name(secret_id.clone()) - .secret_string(secret) - .send() - .await - { - Ok(res) => { - tracing::info!("Successfully created secret in AWS."); - tracing::info!("Secret ID: {}", secret_id); - tracing::info!("Secret ARN: {}", res.arn.unwrap()); - } - Err(e) => { - tracing::error!("There was an issue while creating the secret in AWS."); - return Err(anyhow!(e)); - } - } - - Ok(secret_id) -} - -pub async fn delete_db_secret(secret_id: String) -> Result<()> { - let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); - let config = aws_config::from_env().region(region_provider).load().await; - let client = Client::new(&config); - - let _secret_response = match client.delete_secret().secret_id(secret_id).send().await { - Ok(secret_response) => secret_response, - Err(e) => return Err(anyhow!(e)), - }; - - Ok(()) -} - -pub async fn read_secret(secret_id: String) -> Result { - let region_provider = RegionProviderChain::default_provider().or_else("us-east-1"); - let config = aws_config::from_env().region(region_provider).load().await; - let client = Client::new(&config); - - let secret_response = client.get_secret_value().secret_id(secret_id).send().await; - - let secret = match secret_response { - Ok(secret) => secret, - Err(e) => return Err(anyhow!(e)), - }; - - let secret_string = match secret.secret_string { - Some(secret_string) => secret_string, - None => return Err(anyhow!("There was no secret string in the response")), - }; - - return Ok(secret_string); -} diff --git a/api/server/src/utils/clients/mod.rs b/api/server/src/utils/clients/mod.rs index 4ab14b1f8..98dae6a87 100644 --- a/api/server/src/utils/clients/mod.rs +++ b/api/server/src/utils/clients/mod.rs @@ -1,6 +1,2 @@ -pub mod ai; -// pub mod aws; pub mod email; -pub mod posthog; pub mod sentry_utils; -pub mod typesense; diff --git a/api/server/src/utils/clients/posthog.rs b/api/server/src/utils/clients/posthog.rs deleted file mode 100644 index a7ba8db10..000000000 --- a/api/server/src/utils/clients/posthog.rs +++ /dev/null @@ -1,190 +0,0 @@ -use anyhow::{anyhow, Result}; -use reqwest::header::AUTHORIZATION; -use std::env; - -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::utils::clients::sentry_utils::send_sentry_error; - -lazy_static::lazy_static! { - static ref POSTHOG_API_KEY: String = env::var("POSTHOG_API_KEY").expect("POSTHOG_API_KEY must be set"); -} - -const POSTHOG_API_URL: &str = "https://us.i.posthog.com/capture/"; - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct PosthogEvent { - event: PosthogEventType, - distinct_id: Option, - properties: PosthogEventProperties, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -#[serde(rename_all = "snake_case")] -pub enum PosthogEventType { - MetricCreated, - MetricFollowUp, - MetricAddedToDashboard, - MetricViewed, - DashboardViewed, - TitleManuallyUpdated, - SqlManuallyUpdated, - ChartStylingManuallyUpdated, - ChartStylingAutoUpdated, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum PosthogEventProperties { - MetricCreated(MetricCreatedProperties), - MetricFollowUp(MetricFollowUpProperties), - MetricAddedToDashboard(MetricAddedToDashboardProperties), - MetricViewed(MetricViewedProperties), - DashboardViewed(DashboardViewedProperties), - TitleManuallyUpdated(TitleManuallyUpdatedProperties), - SqlManuallyUpdated(SqlManuallyUpdatedProperties), - ChartStylingManuallyUpdated(ChartStylingManuallyUpdatedProperties), - ChartStylingAutoUpdated(ChartStylingAutoUpdatedProperties), -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct MetricCreatedProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct MetricFollowUpProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct MetricAddedToDashboardProperties { - pub message_id: Uuid, - pub thread_id: Uuid, - pub dashboard_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct MetricViewedProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct DashboardViewedProperties { - pub dashboard_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct TitleManuallyUpdatedProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct SqlManuallyUpdatedProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ChartStylingManuallyUpdatedProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct ChartStylingAutoUpdatedProperties { - pub message_id: Uuid, - pub thread_id: Uuid, -} - -pub async fn send_posthog_event_handler( - event_type: PosthogEventType, - user_id: Option, - properties: PosthogEventProperties, -) -> Result<()> { - let event = PosthogEvent { - event: event_type, - distinct_id: user_id, - properties, - }; - - tokio::spawn(async move { - match send_event(event).await { - Ok(_) => (), - Err(e) => { - send_sentry_error(&e.to_string(), None); - tracing::error!("Unable to send event to Posthog: {:?}", e); - } - } - }); - - Ok(()) -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct PosthogRequest { - #[serde(flatten)] - event: PosthogEvent, - api_key: String, -} - -async fn send_event(event: PosthogEvent) -> Result<()> { - let client = reqwest::Client::new(); - - let headers = { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - AUTHORIZATION, - format!("Bearer {}", POSTHOG_API_KEY.to_string()) - .parse() - .unwrap(), - ); - headers - }; - - let posthog_req = PosthogRequest { - event, - api_key: POSTHOG_API_KEY.to_string(), - }; - - match client - .post(POSTHOG_API_URL) - .headers(headers) - .json(&posthog_req) - .send() - .await - { - Ok(_) => (), - Err(e) => { - tracing::error!("Unable to send request to Posthog: {:?}", e); - return Err(anyhow!(e)); - } - }; - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use dotenv::dotenv; - - #[tokio::test] - async fn test_send_event() { - dotenv().ok(); - - let _ = send_event(PosthogEvent { - event: PosthogEventType::MetricCreated, - distinct_id: Some(Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap()), - properties: PosthogEventProperties::MetricCreated(MetricCreatedProperties { - message_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(), - thread_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(), - }), - }) - .await; - } -} diff --git a/api/server/src/utils/clients/typesense.rs b/api/server/src/utils/clients/typesense.rs deleted file mode 100644 index 4755c522a..000000000 --- a/api/server/src/utils/clients/typesense.rs +++ /dev/null @@ -1,318 +0,0 @@ -use anyhow::{anyhow, Result}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::env; -use uuid::Uuid; - -lazy_static::lazy_static! { - static ref TYPESENSE_API_KEY: String = env::var("TYPESENSE_API_KEY").unwrap_or("xyz".to_string()); - static ref TYPESENSE_API_HOST: String = env::var("TYPESENSE_API_HOST").unwrap_or("http://localhost:8108".to_string()); -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "snake_case")] -pub enum CollectionName { - Messages, - Collections, - Datasets, - PermissionGroups, - Teams, - DataSources, - Terms, - Dashboards, - #[serde(untagged)] - StoredValues(String), -} - -impl CollectionName { - pub fn to_string(&self) -> String { - match self { - CollectionName::Messages => "messages".to_string(), - CollectionName::Collections => "collections".to_string(), - CollectionName::Datasets => "datasets".to_string(), - CollectionName::PermissionGroups => "permission_groups".to_string(), - CollectionName::Teams => "teams".to_string(), - CollectionName::DataSources => "data_sources".to_string(), - CollectionName::Terms => "terms".to_string(), - CollectionName::Dashboards => "dashboards".to_string(), - CollectionName::StoredValues(s) => s.clone(), - } - } -} - -#[derive(Deserialize, Serialize)] -pub struct MessageDocument { - pub id: Uuid, - pub name: String, - pub summary_question: String, - pub organization_id: Uuid, -} - -#[derive(Deserialize, Serialize)] -pub struct GenericDocument { - pub id: Uuid, - pub name: String, - pub organization_id: Uuid, -} - -#[derive(Deserialize, Serialize, Clone, Debug)] -pub struct StoredValueDocument { - pub id: Uuid, - pub value: String, - pub dataset_id: Uuid, - pub dataset_column_id: Uuid, -} - -#[derive(Deserialize, Serialize)] -#[serde(untagged)] -pub enum Document { - Message(MessageDocument), - Dashboard(GenericDocument), - Dataset(GenericDocument), - PermissionGroup(GenericDocument), - Team(GenericDocument), - DataSource(GenericDocument), - Term(GenericDocument), - Collection(GenericDocument), - StoredValue(StoredValueDocument), -} - -impl Document { - pub fn id(&self) -> Uuid { - match self { - Document::Message(m) => m.id, - Document::Dashboard(d) => d.id, - Document::Dataset(d) => d.id, - Document::PermissionGroup(d) => d.id, - Document::Team(d) => d.id, - Document::DataSource(d) => d.id, - Document::Term(d) => d.id, - Document::Collection(d) => d.id, - Document::StoredValue(d) => d.id, - } - } - - pub fn into_stored_value_document(&self) -> &StoredValueDocument { - match self { - Document::StoredValue(d) => d, - _ => panic!("Document is not a StoredValueDocument"), - } - } -} - -pub async fn upsert_document(collection: CollectionName, document: Document) -> Result<()> { - let client = reqwest::Client::builder().build()?; - - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse()?); - headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?); - - match client - .request( - reqwest::Method::POST, - format!( - "{}/collections/{}/documents?action=upsert", - *TYPESENSE_API_HOST, - collection.to_string() - ), - ) - .headers(headers) - .json(&document) - .send() - .await - { - Ok(_) => (), - Err(e) => return Err(anyhow!("Error sending request: {}", e)), - }; - - Ok(()) -} - -#[derive(Serialize, Debug)] -pub struct SearchRequest { - pub searches: Vec, -} - -#[derive(Serialize, Debug)] -pub struct SearchRequestObject { - pub collection: CollectionName, - pub q: String, - pub query_by: String, - pub prefix: bool, - pub exclude_fields: String, - pub highlight_fields: String, - pub use_cache: bool, - pub filter_by: String, - pub vector_query: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub limit: Option, -} - -#[derive(Deserialize)] -pub struct RequestParams { - pub collection_name: CollectionName, -} - -#[derive(Deserialize)] -pub struct Highlight { - pub field: String, - pub matched_tokens: Vec, -} - -#[derive(Deserialize)] -pub struct HybridSearchInfo { - pub rank_fusion_score: f64, -} - -#[derive(Deserialize)] -pub struct Hit { - pub document: Document, - pub highlights: Vec, - pub hybrid_search_info: HybridSearchInfo, -} - -#[derive(Deserialize)] -pub struct SearchResult { - pub request_params: RequestParams, - pub hits: Vec, -} - -#[derive(Deserialize)] -pub struct SearchResponse { - pub results: Vec, -} - -pub async fn search_documents(search_reqs: Vec) -> Result { - let client = reqwest::Client::builder().build()?; - - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse()?); - headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?); - - let search_req = SearchRequest { - searches: search_reqs, - }; - - let res = match client - .request( - reqwest::Method::POST, - format!("{}/multi_search", *TYPESENSE_API_HOST), - ) - .headers(headers) - .json(&search_req) - .send() - .await - { - Ok(res) => { - let res = if res.status().is_success() { - res - } else { - return Err(anyhow!( - "Error sending request: {}", - res.text().await.unwrap() - )); - }; - - res - } - Err(e) => return Err(anyhow!("Error sending request: {}", e)), - }; - - let search_response = match res.json::().await { - Ok(search_response) => search_response, - Err(e) => return Err(anyhow!("Error parsing response: {}", e)), - }; - - Ok(search_response) -} - -pub async fn create_collection(schema: Value) -> Result<()> { - let client = reqwest::Client::builder().build()?; - - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse()?); - headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?); - - match client - .request( - reqwest::Method::POST, - format!("{}/collections", *TYPESENSE_API_HOST,), - ) - .headers(headers) - .json(&schema) - .send() - .await - { - Ok(_) => (), - Err(e) => return Err(anyhow!("Error sending request: {}", e)), - }; - - Ok(()) -} - -pub async fn delete_collection(collection_name: &String, filter_by: &String) -> Result<()> { - let client = reqwest::Client::builder().build()?; - - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse()?); - headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?); - - match client - .request( - reqwest::Method::DELETE, - format!( - "{}/collections/{}/documents?filter_by={}", - *TYPESENSE_API_HOST, collection_name, filter_by - ), - ) - .headers(headers) - .send() - .await - { - Ok(_) => (), - Err(e) => return Err(anyhow!("Error sending request: {}", e)), - }; - - Ok(()) -} - -pub async fn bulk_insert_documents(collection_name: &String, documents: &Vec) -> Result<()> -where - T: serde::Serialize, -{ - let client = reqwest::Client::builder().build()?; - - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse()?); - headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?); - - let serialized_documents = documents - .iter() - .map(|doc| serde_json::to_string(doc)) - .collect::, _>>()? - .join("\n"); - - match client - .request( - reqwest::Method::POST, - format!( - "{}/collections/{}/documents/import?action=create", - *TYPESENSE_API_HOST, collection_name - ), - ) - .headers(headers) - .body(serialized_documents) - .send() - .await - { - Ok(res) => { - if res.status().is_success() { - tracing::info!("Res: {:?}", res.text().await.unwrap()); - return Ok(()); - } - tracing::error!("Error sending request: {}", res.text().await.unwrap()); - return Err(anyhow!("Error sending bulk insert request")); - } - Err(e) => return Err(anyhow!("Error sending request: {}", e)), - }; -} diff --git a/api/server/src/utils/mod.rs b/api/server/src/utils/mod.rs index bd4e4ad1b..12f063180 100644 --- a/api/server/src/utils/mod.rs +++ b/api/server/src/utils/mod.rs @@ -1,7 +1,5 @@ pub mod clients; pub mod security; -pub mod stored_values; pub use agents::*; pub use security::*; -pub use stored_values::*; diff --git a/api/server/src/utils/stored_values/mod.rs b/api/server/src/utils/stored_values/mod.rs deleted file mode 100644 index 0150ee482..000000000 --- a/api/server/src/utils/stored_values/mod.rs +++ /dev/null @@ -1,291 +0,0 @@ -pub mod search; -use query_engine::data_source_query_routes::query_engine::query_engine; - -use query_engine::data_types::DataType; -pub use search::*; - -use crate::utils::clients::ai::embedding_router::embedding_router; -use anyhow::Result; -use chrono::Utc; -use database::enums::StoredValuesStatus; -use database::{pool::get_pg_pool, schema::dataset_columns}; -use diesel::prelude::*; -use diesel::sql_types::{Array, Float4, Integer, Text, Timestamptz, Uuid as SqlUuid}; -use diesel_async::RunQueryDsl; -use uuid::Uuid; - -#[derive(Debug, QueryableByName)] -pub struct StoredValueRow { - #[diesel(sql_type = Text)] - pub value: String, -} - -#[derive(Debug, QueryableByName)] -pub struct StoredValueWithDistance { - #[diesel(sql_type = Text)] - pub value: String, - #[diesel(sql_type = Text)] - pub column_name: String, - #[diesel(sql_type = SqlUuid)] - pub column_id: Uuid, -} - -const BATCH_SIZE: usize = 10_000; -const MAX_VALUE_LENGTH: usize = 50; - -pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> { - let pool = get_pg_pool(); - let mut conn = pool.get().await?; - - // Create schema and table using raw SQL - let schema_name = organization_id.to_string().replace("-", "_"); - let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS values_{}", schema_name); - - let create_table_sql = format!( - "CREATE TABLE IF NOT EXISTS values_{}.values_v1 ( - value text NOT NULL, - dataset_id uuid NOT NULL, - column_name text NOT NULL, - column_id uuid NOT NULL, - embedding vector(1024), - created_at timestamp with time zone NOT NULL DEFAULT now(), - UNIQUE(dataset_id, column_name, value) - )", - schema_name - ); - - let create_index_sql = format!( - "CREATE INDEX IF NOT EXISTS values_v1_embedding_idx - ON values_{}.values_v1 - USING ivfflat (embedding vector_cosine_ops)", - schema_name - ); - - diesel::sql_query(create_schema_sql) - .execute(&mut conn) - .await?; - diesel::sql_query(create_table_sql) - .execute(&mut conn) - .await?; - diesel::sql_query(create_index_sql) - .execute(&mut conn) - .await?; - - Ok(()) -} - -pub async fn store_column_values( - organization_id: &Uuid, - dataset_id: &Uuid, - column_name: &str, - column_id: &Uuid, - _data_source_id: &Uuid, - schema: &str, - table_name: &str, -) -> Result<()> { - let pool = get_pg_pool(); - let mut conn = pool.get().await?; - - // Create schema and table if they don't exist - ensure_stored_values_schema(organization_id).await?; - - // Query distinct values in batches - let mut offset = 0; - let mut first_batch = true; - let schema_name = organization_id.to_string().replace("-", "_"); - - loop { - let query = format!( - "SELECT DISTINCT \"{}\" as value - FROM {}.{} - WHERE \"{}\" IS NOT NULL - AND length(\"{}\") <= {} - ORDER BY \"{}\" - LIMIT {} OFFSET {}", - column_name, - schema, - table_name, - column_name, - column_name, - MAX_VALUE_LENGTH, - column_name, - BATCH_SIZE, - offset - ); - - let results = match query_engine(dataset_id, &query, None).await { - Ok(results) => results.data, - Err(e) => { - tracing::error!("Error querying stored values: {:?}", e); - if first_batch { - return Err(e); - } - vec![] - } - }; - - if results.is_empty() { - break; - } - - // Extract values from the query results - let values: Vec = results - .into_iter() - .filter_map(|row| { - if let Some(DataType::Text(Some(value))) = row.get("value") { - Some(value.clone()) - } else { - None - } - }) - .collect(); - - if values.is_empty() { - break; - } - - // If this is the first batch and we have 15 or fewer values, handle as enum - if first_batch && values.len() <= 15 { - // Get current description - let current_description = - diesel::sql_query("SELECT description FROM dataset_columns WHERE id = $1") - .bind::(column_id) - .get_result::(&mut conn) - .await - .ok() - .and_then(|row| Some(row.value)); - - // Format new description - let enum_list = format!("Values for this column are: {}", values.join(", ")); - let new_description = match current_description { - Some(desc) if !desc.is_empty() => format!("{}. {}", desc, enum_list), - _ => enum_list, - }; - - // Update column description - diesel::update(dataset_columns::table) - .filter(dataset_columns::id.eq(column_id)) - .set(( - dataset_columns::description.eq(new_description), - dataset_columns::stored_values_status.eq(StoredValuesStatus::Success), - dataset_columns::stored_values_count.eq(values.len() as i64), - dataset_columns::stored_values_last_synced.eq(Utc::now()), - )) - .execute(&mut conn) - .await?; - - return Ok(()); - } - - // Create embeddings for the batch - let embeddings = create_embeddings_batch(&values).await?; - - // Insert values and embeddings - for (value, embedding) in values.iter().zip(embeddings.iter()) { - let insert_sql = format!( - "INSERT INTO {}_values.values_v1 - (value, dataset_id, column_name, column_id, embedding, created_at) - VALUES ($1::text, $2::uuid, $3::text, $4::uuid, $5::vector, $6::timestamptz) - ON CONFLICT (dataset_id, column_name, value) - DO UPDATE SET created_at = EXCLUDED.created_at", - schema_name - ); - - diesel::sql_query(insert_sql) - .bind::(value) - .bind::(dataset_id) - .bind::(column_name) - .bind::(column_id) - .bind::, _>(embedding) - .bind::(Utc::now()) - .execute(&mut conn) - .await?; - } - - first_batch = false; - offset += BATCH_SIZE; - } - - Ok(()) -} - -async fn create_embeddings_batch(values: &[String]) -> Result>> { - let embeddings = embedding_router(values.to_vec(), true).await?; - Ok(embeddings) -} - -pub async fn search_stored_values( - organization_id: &Uuid, - dataset_id: &Uuid, - query_embedding: Vec, - limit: Option, -) -> Result> { - let pool = get_pg_pool(); - let mut conn = pool.get().await?; - - let limit = limit.unwrap_or(10); - - let schema_name = organization_id.to_string().replace("-", "_"); - let query = format!( - "SELECT value, column_name, column_id - FROM values_{}.values_v1 - WHERE dataset_id = $2::uuid - ORDER BY embedding <=> $1::vector - LIMIT $3::integer", - schema_name - ); - - let results: Vec = diesel::sql_query(query) - .bind::, _>(query_embedding) - .bind::(dataset_id) - .bind::(limit) - .load(&mut conn) - .await?; - - Ok(results - .into_iter() - .map(|r| (r.value, r.column_name, r.column_id)) - .collect()) -} - -pub struct StoredValueColumn { - pub organization_id: Uuid, - pub dataset_id: Uuid, - pub column_name: String, - pub column_id: Uuid, - pub data_source_id: Uuid, - pub schema: String, - pub table_name: String, -} - -pub async fn process_stored_values_background(columns: Vec) { - for column in columns { - match store_column_values( - &column.organization_id, - &column.dataset_id, - &column.column_name, - &column.column_id, - &column.data_source_id, - &column.schema, - &column.table_name, - ) - .await - { - Ok(_) => { - tracing::info!( - "Successfully processed stored values for column '{}' in dataset '{}'", - column.column_name, - column.table_name - ); - } - Err(e) => { - tracing::error!( - "Failed to process stored values for column '{}' in dataset '{}': {:?}", - column.column_name, - column.table_name, - e - ); - } - } - } -} diff --git a/api/server/src/utils/stored_values/search.rs b/api/server/src/utils/stored_values/search.rs deleted file mode 100644 index e42c27877..000000000 --- a/api/server/src/utils/stored_values/search.rs +++ /dev/null @@ -1,64 +0,0 @@ -use anyhow::Result; -use cohere_rust::{api::rerank::{ReRankModel, ReRankRequest}, Cohere}; -use serde::{Deserialize, Serialize}; -use uuid::Uuid; - -use crate::utils::clients::ai::embedding_router::embedding_router; - -use super::search_stored_values; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StoredValue { - pub value: String, - pub dataset_id: Uuid, - pub column_name: String, - pub column_id: Uuid, -} - -pub async fn search_values_for_dataset( - organization_id: &Uuid, - dataset_id: &Uuid, - query: String, -) -> Result> { - // Create embedding for the search query - let query_vec = vec![query.clone()]; - let query_embedding = embedding_router(query_vec, true).await?[0].clone(); - - // Get initial candidates using vector similarity - let candidates = search_stored_values( - organization_id, - dataset_id, - query_embedding, - Some(25), // Get more candidates for reranking - ).await?; - - // Extract just the values for reranking - let candidate_values: Vec = candidates.iter().map(|(value, _, _)| value.clone()).collect(); - - // Rerank the candidates using the cohere client - let co = Cohere::default(); - let request = ReRankRequest { - query: query.as_str(), - documents: &candidate_values, - model: ReRankModel::EnglishV3, - top_n: Some(10), - ..Default::default() - }; - - let response = co.rerank(&request).await?; - - // Convert to StoredValue structs - let values = response.into_iter() - .map(|result| { - let (value, column_name, column_id) = candidates[result.index as usize].clone(); - StoredValue { - value, - dataset_id: *dataset_id, - column_name, - column_id, - } - }) - .collect(); - - Ok(values) -} \ No newline at end of file