mirror of https://github.com/buster-so/buster.git
Merge branch 'evals' of https://github.com/buster-so/buster into evals
This commit is contained in:
commit
93439d45b0
|
@ -36,14 +36,14 @@ pub async fn post_dataset(
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::FORBIDDEN,
|
StatusCode::FORBIDDEN,
|
||||||
"User does not belong to any organization",
|
"User does not belong to any organization".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("Error getting user organization id: {:?}", e);
|
tracing::error!("Error getting user organization id: {:?}", e);
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"Error getting user organization id",
|
"Error getting user organization id".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -53,14 +53,14 @@ pub async fn post_dataset(
|
||||||
Ok(false) => {
|
Ok(false) => {
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::FORBIDDEN,
|
StatusCode::FORBIDDEN,
|
||||||
"Insufficient permissions",
|
"Insufficient permissions".to_string(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("Error checking user permissions: {:?}", e);
|
tracing::error!("Error checking user permissions: {:?}", e);
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"Error checking user permissions",
|
"Error checking user permissions".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -78,7 +78,7 @@ pub async fn post_dataset(
|
||||||
tracing::error!("Error creating dataset: {:?}", e);
|
tracing::error!("Error creating dataset: {:?}", e);
|
||||||
return Err((
|
return Err((
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
"Error creating dataset",
|
"Error creating dataset".to_string(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
pub mod types;
|
|
|
@ -1,383 +0,0 @@
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
pub enum ViewType {
|
|
||||||
Chart,
|
|
||||||
Table,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
pub enum ChartType {
|
|
||||||
Line,
|
|
||||||
Bar,
|
|
||||||
Scatter,
|
|
||||||
Pie,
|
|
||||||
Metric,
|
|
||||||
Table,
|
|
||||||
Combo,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChartType {
|
|
||||||
pub fn to_string(&self) -> String {
|
|
||||||
match self {
|
|
||||||
ChartType::Line => "line".to_string(),
|
|
||||||
ChartType::Bar => "bar".to_string(),
|
|
||||||
ChartType::Scatter => "scatter".to_string(),
|
|
||||||
ChartType::Pie => "pie".to_string(),
|
|
||||||
ChartType::Metric => "metric".to_string(),
|
|
||||||
ChartType::Table => "table".to_string(),
|
|
||||||
ChartType::Combo => "combo".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_string(chart_type: &str) -> ChartType {
|
|
||||||
match chart_type {
|
|
||||||
"line" => ChartType::Line,
|
|
||||||
"bar" => ChartType::Bar,
|
|
||||||
"scatter" => ChartType::Scatter,
|
|
||||||
"pie" => ChartType::Pie,
|
|
||||||
"metric" => ChartType::Metric,
|
|
||||||
"table" => ChartType::Table,
|
|
||||||
"combo" => ChartType::Combo,
|
|
||||||
_ => ChartType::Table,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct BusterChartConfig {
|
|
||||||
pub selected_chart_type: ChartType,
|
|
||||||
pub selected_view: ViewType,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub column_label_formats: Option<HashMap<String, ColumnLabelFormat>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub column_settings: Option<HashMap<String, ColumnSettings>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub colors: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub show_legend: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub grid_lines: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub show_legend_headline: Option<ShowLegendHeadline>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub goal_lines: Option<Vec<GoalLine>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub trendlines: Option<Vec<Trendline>>,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub y_axis_config: YAxisConfig,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub x_axis_config: XAxisConfig,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub y2_axis_config: Y2AxisConfig,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub bar_chart_props: BarChartProps,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub line_chart_props: LineChartProps,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub scatter_chart_props: ScatterChartProps,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub pie_chart_props: PieChartProps,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub table_chart_props: TableChartProps,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub combo_chart_props: ComboChartProps,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub metric_chart_props: MetricChartProps,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct LineChartProps {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_style: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_group_type: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct BarChartProps {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub bar_and_line_axis: Option<BarAndLineAxis>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub bar_layout: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub bar_sort_by: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub bar_group_type: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub bar_show_total_at_top: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct ScatterChartProps {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub scatter_axis: Option<ScatterAxis>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub scatter_dot_size: Option<(f64, f64)>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct PieChartProps {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_chart_axis: Option<PieChartAxis>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_display_label_as: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_show_inner_label: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_inner_label_aggregate: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_inner_label_title: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_label_position: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_donut_width: Option<f64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub pie_minimum_slice_percentage: Option<f64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct TableChartProps {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub table_column_order: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub table_column_widths: Option<HashMap<String, f64>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub table_header_background_color: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub table_header_font_color: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub table_column_font_color: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct ComboChartProps {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub combo_chart_axis: Option<ComboChartAxis>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct MetricChartProps {
|
|
||||||
pub metric_column_id: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub metric_value_aggregate: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub metric_header: Option<MetricTitle>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub metric_sub_header: Option<MetricTitle>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub metric_value_label: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum MetricTitle {
|
|
||||||
String(String),
|
|
||||||
Derived(DerivedMetricTitle),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct DerivedMetricTitle {
|
|
||||||
pub column_id: String,
|
|
||||||
pub use_value: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ShowLegendHeadline {
|
|
||||||
Bool(bool),
|
|
||||||
String(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct ColumnSettings {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub show_data_labels: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub column_visualization: Option<String>,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub line_settings: LineColumnSettings,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub bar_settings: BarColumnSettings,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub dot_settings: DotColumnSettings,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct LineColumnSettings {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_dash_style: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_width: Option<f64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_style: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_type: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_symbol_size: Option<f64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct BarColumnSettings {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub bar_roundness: Option<f64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct DotColumnSettings {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub line_symbol_size: Option<f64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct ColumnLabelFormat {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub style: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub column_type: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub display_name: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub number_separator_style: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub minimum_fraction_digits: Option<i32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub maximum_fraction_digits: Option<i32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub multiplier: Option<f64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub prefix: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub suffix: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub replace_missing_data_with: Option<serde_json::Value>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub use_relative_time: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub is_utc: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub make_label_human_readable: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub currency: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub date_format: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub convert_number_to: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct GoalLine {
|
|
||||||
pub show: bool,
|
|
||||||
pub value: f64,
|
|
||||||
pub show_goal_line_label: bool,
|
|
||||||
pub goal_line_label: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub goal_line_color: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct Trendline {
|
|
||||||
pub show: bool,
|
|
||||||
pub show_trendline_label: bool,
|
|
||||||
pub trendline_label: Option<String>,
|
|
||||||
pub type_: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub trendline_color: Option<String>,
|
|
||||||
pub column_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct YAxisConfig {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y_axis_show_axis_label: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y_axis_show_axis_title: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y_axis_axis_title: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y_axis_start_axis_at_zero: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y_axis_scale_type: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct Y2AxisConfig {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y2_axis_show_axis_label: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y2_axis_show_axis_title: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y2_axis_axis_title: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y2_axis_start_axis_at_zero: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y2_axis_scale_type: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct XAxisConfig {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub x_axis_show_ticks: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub x_axis_show_axis_label: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub x_axis_show_axis_title: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub x_axis_axis_title: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub x_axis_label_rotation: Option<i32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub x_axis_data_zoom: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
|
||||||
pub struct CategoryAxisStyleConfig {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub category_show_total_at_top: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub category_axis_title: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct BarAndLineAxis {
|
|
||||||
pub x: Vec<String>,
|
|
||||||
pub y: Vec<String>,
|
|
||||||
pub category: Vec<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tooltip: Option<Vec<String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct ScatterAxis {
|
|
||||||
pub x: Vec<String>,
|
|
||||||
pub y: Vec<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub category: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub size: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tooltip: Option<Vec<String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct ComboChartAxis {
|
|
||||||
pub x: Vec<String>,
|
|
||||||
pub y: Vec<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub y2: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub category: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tooltip: Option<Vec<String>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct PieChartAxis {
|
|
||||||
pub x: Vec<String>,
|
|
||||||
pub y: Vec<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tooltip: Option<Vec<String>>,
|
|
||||||
}
|
|
|
@ -1,238 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use futures::StreamExt;
|
|
||||||
use std::{env, time::Duration};
|
|
||||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
|
||||||
|
|
||||||
const ANTHROPIC_CHAT_URL: &str = "https://api.anthropic.com/v1/messages";
|
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
|
||||||
static ref ANTHROPIC_API_KEY: String = env::var("ANTHROPIC_API_KEY")
|
|
||||||
.expect("ANTHROPIC_API_KEY must be set");
|
|
||||||
static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED")
|
|
||||||
.unwrap_or(String::from("true"))
|
|
||||||
.parse()
|
|
||||||
.expect("MONITORING_ENABLED must be a boolean");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub enum AnthropicChatModel {
|
|
||||||
#[serde(rename = "claude-3-opus-20240229")]
|
|
||||||
Claude3Opus20240229,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum AnthropicChatRole {
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum AnthropicContentType {
|
|
||||||
Text,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
|
||||||
pub struct AnthropicContent {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub _type: AnthropicContentType,
|
|
||||||
pub text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub struct AnthropicChatMessage {
|
|
||||||
pub role: AnthropicChatRole,
|
|
||||||
pub content: Vec<AnthropicContent>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub struct AnthropicChatRequest {
|
|
||||||
pub model: AnthropicChatModel,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system: Option<String>,
|
|
||||||
pub messages: Vec<AnthropicChatMessage>,
|
|
||||||
pub temperature: f32,
|
|
||||||
pub max_tokens: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_sequences: Option<Vec<String>>,
|
|
||||||
pub stream: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub content: Vec<Content>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
|
||||||
pub struct Content {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub _type: String,
|
|
||||||
pub text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn anthropic_chat(
|
|
||||||
model: &AnthropicChatModel,
|
|
||||||
system: Option<String>,
|
|
||||||
messages: &Vec<AnthropicChatMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let chat_request = AnthropicChatRequest {
|
|
||||||
model: model.clone(),
|
|
||||||
system,
|
|
||||||
messages: messages.clone(),
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
stop_sequences: stop,
|
|
||||||
stream: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let headers = {
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
"x-api-key",
|
|
||||||
format!("{}", ANTHROPIC_API_KEY.to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
|
|
||||||
headers
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = match client
|
|
||||||
.post(ANTHROPIC_CHAT_URL)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&chat_request)
|
|
||||||
.timeout(Duration::from_secs(timeout))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Unable to send request to Anthropic: {:?}", e);
|
|
||||||
let err = anyhow!("Unable to send request to Anthropic: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let completion_res = match response.json::<ChatCompletionResponse>().await {
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Unable to parse response from Anthropic: {:?}", e);
|
|
||||||
let err = anyhow!("Unable to parse response from Anthropic: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = match completion_res.content.get(0) {
|
|
||||||
Some(content) => content.text.clone(),
|
|
||||||
None => return Err(anyhow!("No content returned from Anthropic")),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct AnthropicChatDelta {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub _type: String,
|
|
||||||
pub delta: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct AnthropicChatStreamResponse {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub _type: String,
|
|
||||||
pub delta: Option<AnthropicChatDelta>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn anthropic_chat_stream(
|
|
||||||
model: &AnthropicChatModel,
|
|
||||||
system: Option<String>,
|
|
||||||
messages: &Vec<AnthropicChatMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
) -> Result<ReceiverStream<String>> {
|
|
||||||
let chat_request = AnthropicChatRequest {
|
|
||||||
model: model.clone(),
|
|
||||||
system,
|
|
||||||
messages: messages.clone(),
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
stream: true,
|
|
||||||
stop_sequences: stop,
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let headers = {
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::AUTHORIZATION,
|
|
||||||
format!("Bearer {}", ANTHROPIC_API_KEY.to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers
|
|
||||||
};
|
|
||||||
|
|
||||||
let (_tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel(100);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let response = client
|
|
||||||
.post(ANTHROPIC_CHAT_URL)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&chat_request)
|
|
||||||
.timeout(Duration::from_secs(timeout))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::error!("Unable to send request to Anthropic: {:?}", e);
|
|
||||||
let err = anyhow!("Unable to send request to Anthropic: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
err
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(_e) = response {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = response.unwrap();
|
|
||||||
let mut stream = response.bytes_stream();
|
|
||||||
|
|
||||||
let _buffer = String::new();
|
|
||||||
|
|
||||||
while let Some(item) = stream.next().await {
|
|
||||||
match item {
|
|
||||||
Ok(bytes) => {
|
|
||||||
let chunk = String::from_utf8(bytes.to_vec()).unwrap();
|
|
||||||
println!("----------------------");
|
|
||||||
println!("Chunk: {}", chunk);
|
|
||||||
println!("----------------------");
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Error while streaming response: {:?}", e);
|
|
||||||
let err = anyhow!("Error while streaming response: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(ReceiverStream::new(rx))
|
|
||||||
}
|
|
|
@ -1,59 +0,0 @@
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tokio::task;
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
hugging_face::hugging_face_embedding, ollama::ollama_embedding, openai::ada_bulk_embedding,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct EmbeddingRequest {
|
|
||||||
pub prompt: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct EmbeddingResponse {
|
|
||||||
pub embedding: Vec<Vec<f32>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum EmbeddingProvider {
|
|
||||||
OpenAi,
|
|
||||||
Ollama,
|
|
||||||
HuggingFace,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EmbeddingProvider {
|
|
||||||
pub fn get_embedding_provider() -> Result<EmbeddingProvider> {
|
|
||||||
let embedding_provider =
|
|
||||||
env::var("EMBEDDING_PROVIDER").expect("An embedding provider is required.");
|
|
||||||
match embedding_provider.as_str() {
|
|
||||||
"openai" => Ok(EmbeddingProvider::OpenAi),
|
|
||||||
"ollama" => Ok(EmbeddingProvider::Ollama),
|
|
||||||
"huggingface" => Ok(EmbeddingProvider::HuggingFace),
|
|
||||||
_ => Err(anyhow!("Invalid embedding provider")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn embedding_router(prompts: Vec<String>, for_retrieval: bool) -> Result<Vec<Vec<f32>>> {
|
|
||||||
let embedding_provider = EmbeddingProvider::get_embedding_provider()?;
|
|
||||||
|
|
||||||
match embedding_provider {
|
|
||||||
EmbeddingProvider::Ollama => {
|
|
||||||
let tasks = prompts.into_iter().map(|prompt| {
|
|
||||||
task::spawn(async move { ollama_embedding(prompt, for_retrieval).await })
|
|
||||||
});
|
|
||||||
let results = futures::future::join_all(tasks).await;
|
|
||||||
let embeddings: Result<Vec<Vec<f32>>> = results
|
|
||||||
.into_iter()
|
|
||||||
.collect::<Result<Vec<_>, _>>()?
|
|
||||||
.into_iter()
|
|
||||||
.collect();
|
|
||||||
embeddings
|
|
||||||
}
|
|
||||||
EmbeddingProvider::HuggingFace => hugging_face_embedding(prompts).await,
|
|
||||||
EmbeddingProvider::OpenAi => ada_bulk_embedding(prompts).await,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,57 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use serde::Serialize;
|
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use axum::http::HeaderMap;
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct HuggingFaceEmbeddingRequest {
|
|
||||||
pub inputs: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn hugging_face_embedding(prompts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
|
||||||
let hugging_face_url = env::var("HUGGING_FACE_URL").expect("HUGGING_FACE_URL must be set");
|
|
||||||
let hugging_face_api_key =
|
|
||||||
env::var("HUGGING_FACE_API_KEY").expect("HUGGING_FACE_API_KEY must be set");
|
|
||||||
|
|
||||||
let client = match reqwest::Client::builder().build() {
|
|
||||||
Ok(client) => client,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(anyhow!("Error creating reqwest client: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::CONTENT_TYPE,
|
|
||||||
"application/json".parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::AUTHORIZATION,
|
|
||||||
format!("Bearer {}", hugging_face_api_key).parse().unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let req = HuggingFaceEmbeddingRequest { inputs: prompts };
|
|
||||||
|
|
||||||
let res = match client
|
|
||||||
.post(hugging_face_url)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&req)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(anyhow!("Error sending Ollama request: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let embeddings = match res.json::<Vec<Vec<f32>>>().await {
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(anyhow!("Error parsing Ollama response: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(embeddings)
|
|
||||||
}
|
|
|
@ -1,394 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use axum::http::HeaderMap;
|
|
||||||
use base64::Engine;
|
|
||||||
use reqwest::Method;
|
|
||||||
use std::env;
|
|
||||||
use tiktoken_rs::o200k_base;
|
|
||||||
|
|
||||||
use chrono::{DateTime, Utc};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
|
||||||
|
|
||||||
use super::{anthropic::AnthropicChatModel, llm_router::LlmModel, openai::OpenAiChatModel};
|
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
|
||||||
static ref LANGFUSE_API_URL: String = env::var("LANGFUSE_API_URL").unwrap_or("https://us.cloud.langfuse.com".to_string());
|
|
||||||
static ref LANGFUSE_API_PUBLIC_KEY: String = env::var("LANGFUSE_PUBLIC_API_KEY").expect("LANGFUSE_PUBLIC_API_KEY must be set");
|
|
||||||
static ref LANGFUSE_API_PRIVATE_KEY: String = env::var("LANGFUSE_PRIVATE_API_KEY").expect("LANGFUSE_PRIVATE_API_KEY must be set");
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LlmModel {
|
|
||||||
pub fn generate_usage(&self, input: &String, output: &String) -> Usage {
|
|
||||||
let bpe = o200k_base().unwrap();
|
|
||||||
|
|
||||||
let input_token = bpe.encode_with_special_tokens(&input);
|
|
||||||
let output_token = bpe.encode_with_special_tokens(&output);
|
|
||||||
|
|
||||||
match self {
|
|
||||||
LlmModel::OpenAi(OpenAiChatModel::O3Mini) => Usage {
|
|
||||||
input: input_token.len() as u32,
|
|
||||||
output: output_token.len() as u32,
|
|
||||||
unit: "TOKENS".to_string(),
|
|
||||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 2.5,
|
|
||||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 10.0,
|
|
||||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 2.5
|
|
||||||
+ (output_token.len() as f64 / 1_000_000.0) * 10.0,
|
|
||||||
},
|
|
||||||
LlmModel::OpenAi(OpenAiChatModel::Gpt35Turbo) => Usage {
|
|
||||||
input: input_token.len() as u32,
|
|
||||||
output: output_token.len() as u32,
|
|
||||||
unit: "TOKENS".to_string(),
|
|
||||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 0.5,
|
|
||||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 1.5,
|
|
||||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 0.5
|
|
||||||
+ (output_token.len() as f64 / 1_000_000.0) * 1.5,
|
|
||||||
},
|
|
||||||
LlmModel::OpenAi(OpenAiChatModel::Gpt4o) => Usage {
|
|
||||||
input: input_token.len() as u32,
|
|
||||||
output: output_token.len() as u32,
|
|
||||||
unit: "TOKENS".to_string(),
|
|
||||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 0.15,
|
|
||||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 0.6,
|
|
||||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 0.15
|
|
||||||
+ (output_token.len() as f64 / 1_000_000.0) * 0.6,
|
|
||||||
},
|
|
||||||
LlmModel::Anthropic(AnthropicChatModel::Claude3Opus20240229) => Usage {
|
|
||||||
input: input_token.len() as u32,
|
|
||||||
output: output_token.len() as u32,
|
|
||||||
unit: "TOKENS".to_string(),
|
|
||||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 3.0,
|
|
||||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 15.0,
|
|
||||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 3.0
|
|
||||||
+ (output_token.len() as f64 / 1_000_000.0) * 15.0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone, Debug)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum PromptName {
|
|
||||||
SelectDataset,
|
|
||||||
GenerateSql,
|
|
||||||
SelectTerm,
|
|
||||||
DataSummary,
|
|
||||||
AutoChartConfig,
|
|
||||||
LineChartConfig,
|
|
||||||
BarChartConfig,
|
|
||||||
ScatterChartConfig,
|
|
||||||
PieChartConfig,
|
|
||||||
MetricChartConfig,
|
|
||||||
TableConfig,
|
|
||||||
ColumnLabelFormat,
|
|
||||||
AdvancedVisualizationConfig,
|
|
||||||
NoDataReturnedResponse,
|
|
||||||
DataExplanation,
|
|
||||||
MetricTitle,
|
|
||||||
TimeFrame,
|
|
||||||
FixSqlPlanner,
|
|
||||||
FixSql,
|
|
||||||
GenerateColDescriptions,
|
|
||||||
GenerateDatasetDescription,
|
|
||||||
SummaryQuestion,
|
|
||||||
CustomPrompt(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PromptName {
|
|
||||||
fn to_string(&self) -> String {
|
|
||||||
match self {
|
|
||||||
PromptName::SelectDataset => "select_dataset".to_string(),
|
|
||||||
PromptName::GenerateSql => "generate_sql".to_string(),
|
|
||||||
PromptName::SelectTerm => "select_term".to_string(),
|
|
||||||
PromptName::DataSummary => "data_summary".to_string(),
|
|
||||||
PromptName::AutoChartConfig => "auto_chart_config".to_string(),
|
|
||||||
PromptName::LineChartConfig => "line_chart_config".to_string(),
|
|
||||||
PromptName::BarChartConfig => "bar_chart_config".to_string(),
|
|
||||||
PromptName::ScatterChartConfig => "scatter_chart_config".to_string(),
|
|
||||||
PromptName::PieChartConfig => "pie_chart_config".to_string(),
|
|
||||||
PromptName::MetricChartConfig => "metric_chart_config".to_string(),
|
|
||||||
PromptName::TableConfig => "table_config".to_string(),
|
|
||||||
PromptName::ColumnLabelFormat => "column_label_format".to_string(),
|
|
||||||
PromptName::AdvancedVisualizationConfig => "advanced_visualization_config".to_string(),
|
|
||||||
PromptName::NoDataReturnedResponse => "no_data_returned_response".to_string(),
|
|
||||||
PromptName::DataExplanation => "data_explanation".to_string(),
|
|
||||||
PromptName::MetricTitle => "metric_title".to_string(),
|
|
||||||
PromptName::TimeFrame => "time_frame".to_string(),
|
|
||||||
PromptName::FixSqlPlanner => "fix_sql_planner".to_string(),
|
|
||||||
PromptName::FixSql => "fix_sql".to_string(),
|
|
||||||
PromptName::GenerateColDescriptions => "generate_col_descriptions".to_string(),
|
|
||||||
PromptName::GenerateDatasetDescription => "generate_dataset_description".to_string(),
|
|
||||||
PromptName::SummaryQuestion => "summary_question".to_string(),
|
|
||||||
PromptName::CustomPrompt(prompt) => prompt.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
#[serde(rename_all = "kebab-case")]
|
|
||||||
enum LangfuseIngestionType {
|
|
||||||
TraceCreate,
|
|
||||||
GenerationCreate,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
struct CreateTraceBody {
|
|
||||||
id: Uuid,
|
|
||||||
timestamp: DateTime<Utc>,
|
|
||||||
name: String,
|
|
||||||
user_id: Uuid,
|
|
||||||
input: String,
|
|
||||||
output: String,
|
|
||||||
session_id: Uuid,
|
|
||||||
release: String,
|
|
||||||
version: String,
|
|
||||||
metadata: Metadata,
|
|
||||||
tags: Vec<String>,
|
|
||||||
public: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
struct GenerationCreateBody {
|
|
||||||
trace_id: Uuid,
|
|
||||||
name: String,
|
|
||||||
start_time: DateTime<Utc>,
|
|
||||||
completion_start_time: DateTime<Utc>,
|
|
||||||
input: String,
|
|
||||||
output: String,
|
|
||||||
level: String,
|
|
||||||
end_time: DateTime<Utc>,
|
|
||||||
model: LlmModel,
|
|
||||||
id: Uuid,
|
|
||||||
usage: Usage,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
#[serde(rename_all = "camelCase")]
|
|
||||||
pub struct Usage {
|
|
||||||
input: u32,
|
|
||||||
output: u32,
|
|
||||||
unit: String,
|
|
||||||
input_cost: f64,
|
|
||||||
output_cost: f64,
|
|
||||||
total_cost: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
struct Metadata {}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
enum LangfuseRequestBody {
|
|
||||||
CreateTraceBody(CreateTraceBody),
|
|
||||||
GenerationCreateBody(GenerationCreateBody),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct LangfuseBatchItem {
|
|
||||||
id: Uuid,
|
|
||||||
r#type: LangfuseIngestionType,
|
|
||||||
body: LangfuseRequestBody,
|
|
||||||
timestamp: DateTime<Utc>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct LangfuseBatch {
|
|
||||||
batch: Vec<LangfuseBatchItem>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct LangfuseError {
|
|
||||||
id: String,
|
|
||||||
status: u16,
|
|
||||||
message: String,
|
|
||||||
error: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct LangfuseSuccess {
|
|
||||||
id: String,
|
|
||||||
status: u16,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct LangfuseResponse {
|
|
||||||
successes: Vec<LangfuseSuccess>,
|
|
||||||
errors: Vec<LangfuseError>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Args:
|
|
||||||
/// - session_id: this can be a thread_id or any other type of chain event we have
|
|
||||||
///
|
|
||||||
|
|
||||||
pub async fn send_langfuse_request(
|
|
||||||
session_id: &Uuid,
|
|
||||||
prompt_name: PromptName,
|
|
||||||
context: Option<String>,
|
|
||||||
start_time: DateTime<Utc>,
|
|
||||||
end_time: DateTime<Utc>,
|
|
||||||
input: String,
|
|
||||||
output: String,
|
|
||||||
user_id: &Uuid,
|
|
||||||
langfuse_model: &LlmModel,
|
|
||||||
) -> () {
|
|
||||||
let session_id = session_id.clone();
|
|
||||||
let user_id = user_id.clone();
|
|
||||||
let langfuse_model = langfuse_model.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
match langfuse_handler(
|
|
||||||
session_id,
|
|
||||||
prompt_name,
|
|
||||||
context,
|
|
||||||
start_time,
|
|
||||||
end_time,
|
|
||||||
input,
|
|
||||||
output,
|
|
||||||
user_id,
|
|
||||||
langfuse_model,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => {
|
|
||||||
let err = anyhow!("Error sending Langfuse request: {:?}", e);
|
|
||||||
send_sentry_error(&err.to_string(), Some(&user_id));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn langfuse_handler(
|
|
||||||
session_id: Uuid,
|
|
||||||
prompt_name: PromptName,
|
|
||||||
context: Option<String>,
|
|
||||||
start_time: DateTime<Utc>,
|
|
||||||
end_time: DateTime<Utc>,
|
|
||||||
input: String,
|
|
||||||
output: String,
|
|
||||||
user_id: Uuid,
|
|
||||||
langfuse_model: LlmModel,
|
|
||||||
) -> Result<()> {
|
|
||||||
let input = match context {
|
|
||||||
Some(context) => format!("{} \n\n {}", context, input),
|
|
||||||
None => input,
|
|
||||||
};
|
|
||||||
|
|
||||||
let trace_id = Uuid::new_v4();
|
|
||||||
|
|
||||||
let langfuse_trace = LangfuseBatchItem {
|
|
||||||
id: Uuid::new_v4(),
|
|
||||||
r#type: LangfuseIngestionType::TraceCreate,
|
|
||||||
body: LangfuseRequestBody::CreateTraceBody(CreateTraceBody {
|
|
||||||
id: trace_id.clone(),
|
|
||||||
timestamp: Utc::now(),
|
|
||||||
name: prompt_name.clone().to_string(),
|
|
||||||
user_id: user_id,
|
|
||||||
input: serde_json::to_string(&input).unwrap(),
|
|
||||||
output: serde_json::to_string(&output).unwrap(),
|
|
||||||
session_id: session_id,
|
|
||||||
release: "1.0.0".to_string(),
|
|
||||||
version: "1.0.0".to_string(),
|
|
||||||
metadata: Metadata {},
|
|
||||||
tags: vec![],
|
|
||||||
public: false,
|
|
||||||
}),
|
|
||||||
timestamp: Utc::now(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let langfuse_generation = LangfuseBatchItem {
|
|
||||||
id: Uuid::new_v4(),
|
|
||||||
r#type: LangfuseIngestionType::GenerationCreate,
|
|
||||||
body: LangfuseRequestBody::GenerationCreateBody(GenerationCreateBody {
|
|
||||||
id: Uuid::new_v4(),
|
|
||||||
name: prompt_name.to_string(),
|
|
||||||
input: serde_json::to_string(&input).unwrap(),
|
|
||||||
output: serde_json::to_string(&output).unwrap(),
|
|
||||||
trace_id,
|
|
||||||
start_time,
|
|
||||||
completion_start_time: start_time,
|
|
||||||
level: "DEBUG".to_string(),
|
|
||||||
end_time,
|
|
||||||
model: langfuse_model.clone(),
|
|
||||||
usage: langfuse_model.generate_usage(&input, &output),
|
|
||||||
}),
|
|
||||||
timestamp: Utc::now(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let langfuse_batch = LangfuseBatch {
|
|
||||||
batch: vec![langfuse_trace, langfuse_generation],
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = match reqwest::Client::builder().build() {
|
|
||||||
Ok(client) => client,
|
|
||||||
Err(e) => {
|
|
||||||
send_sentry_error(&e.to_string(), Some(&user_id));
|
|
||||||
return Err(anyhow!("Error creating reqwest client: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::CONTENT_TYPE,
|
|
||||||
"application/json".parse().unwrap(),
|
|
||||||
);
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::AUTHORIZATION,
|
|
||||||
format!(
|
|
||||||
"Basic {}",
|
|
||||||
base64::engine::general_purpose::STANDARD.encode(format!(
|
|
||||||
"{}:{}",
|
|
||||||
*LANGFUSE_API_PUBLIC_KEY, *LANGFUSE_API_PRIVATE_KEY
|
|
||||||
))
|
|
||||||
)
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let res = match client
|
|
||||||
.request(
|
|
||||||
Method::POST,
|
|
||||||
LANGFUSE_API_URL.to_string() + "/api/public/ingestion",
|
|
||||||
)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&langfuse_batch)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::debug!("Error sending Langfuse request: {:?}", e);
|
|
||||||
send_sentry_error(&e.to_string(), Some(&user_id));
|
|
||||||
return Err(anyhow!("Error sending Langfuse request: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let langfuse_res: LangfuseResponse = match res.json::<LangfuseResponse>().await {
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::debug!("Error parsing Langfuse response: {:?}", e);
|
|
||||||
send_sentry_error(&e.to_string(), Some(&user_id));
|
|
||||||
return Err(anyhow!("Error parsing Langfuse response: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if langfuse_res.errors.len() > 0 {
|
|
||||||
for err in &langfuse_res.errors {
|
|
||||||
tracing::debug!("Langfuse error: {:?}", err);
|
|
||||||
send_sentry_error(&err.message, Some(&user_id));
|
|
||||||
}
|
|
||||||
|
|
||||||
return Err(anyhow::anyhow!(
|
|
||||||
"Langfuse errors: {:?}",
|
|
||||||
langfuse_res.errors
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
|
@ -1,383 +0,0 @@
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use chrono::Utc;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
use tokio::{
|
|
||||||
sync::mpsc::{self, Receiver},
|
|
||||||
task::JoinHandle,
|
|
||||||
};
|
|
||||||
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
anthropic::{
|
|
||||||
anthropic_chat, anthropic_chat_stream, AnthropicChatMessage, AnthropicChatModel,
|
|
||||||
AnthropicChatRole, AnthropicContent, AnthropicContentType,
|
|
||||||
},
|
|
||||||
langfuse::{send_langfuse_request, PromptName},
|
|
||||||
openai::{
|
|
||||||
openai_chat, openai_chat_stream, OpenAiChatContent, OpenAiChatMessage, OpenAiChatModel,
|
|
||||||
OpenAiChatRole,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use lazy_static::lazy_static;
|
|
||||||
|
|
||||||
lazy_static! {
|
|
||||||
static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED")
|
|
||||||
.unwrap_or(String::from("true"))
|
|
||||||
.parse()
|
|
||||||
.expect("MONITORING_ENABLED must be a boolean");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum LlmModel {
|
|
||||||
Anthropic(AnthropicChatModel),
|
|
||||||
OpenAi(OpenAiChatModel),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub enum LlmRole {
|
|
||||||
System,
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone)]
|
|
||||||
pub struct LlmMessage {
|
|
||||||
pub role: LlmRole,
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LlmMessage {
|
|
||||||
pub fn new(role: String, content: String) -> Self {
|
|
||||||
let role = match role.to_lowercase().as_str() {
|
|
||||||
"system" => LlmRole::System,
|
|
||||||
"user" => LlmRole::User,
|
|
||||||
"assistant" => LlmRole::Assistant,
|
|
||||||
_ => LlmRole::User,
|
|
||||||
};
|
|
||||||
Self { role, content }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn llm_chat(
|
|
||||||
model: LlmModel,
|
|
||||||
messages: &Vec<LlmMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
json_mode: bool,
|
|
||||||
json_schema: Option<Value>,
|
|
||||||
session_id: &Uuid,
|
|
||||||
user_id: &Uuid,
|
|
||||||
prompt_name: PromptName,
|
|
||||||
) -> Result<String> {
|
|
||||||
let start_time = Utc::now();
|
|
||||||
|
|
||||||
let response_result = match &model {
|
|
||||||
LlmModel::Anthropic(model) => {
|
|
||||||
anthropic_chat_compiler(model, messages, max_tokens, temperature, timeout, stop).await
|
|
||||||
}
|
|
||||||
LlmModel::OpenAi(model) => {
|
|
||||||
openai_chat_compiler(
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
7048,
|
|
||||||
temperature,
|
|
||||||
timeout,
|
|
||||||
stop,
|
|
||||||
json_mode,
|
|
||||||
json_schema,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = match response_result {
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => return Err(anyhow!("LLM chat error: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let end_time = Utc::now();
|
|
||||||
|
|
||||||
send_langfuse_request(
|
|
||||||
session_id,
|
|
||||||
prompt_name,
|
|
||||||
None,
|
|
||||||
start_time,
|
|
||||||
end_time,
|
|
||||||
serde_json::to_string(&messages).unwrap(),
|
|
||||||
serde_json::to_string(&response).unwrap(),
|
|
||||||
user_id,
|
|
||||||
&model,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn llm_chat_stream(
|
|
||||||
model: LlmModel,
|
|
||||||
messages: Vec<LlmMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
session_id: &Uuid,
|
|
||||||
user_id: &Uuid,
|
|
||||||
prompt_name: PromptName,
|
|
||||||
) -> Result<(Receiver<String>, JoinHandle<Result<String>>)> {
|
|
||||||
let start_time = Utc::now();
|
|
||||||
|
|
||||||
let stream_result = match &model {
|
|
||||||
LlmModel::Anthropic(model) => {
|
|
||||||
anthropic_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
LlmModel::OpenAi(model) => {
|
|
||||||
openai_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut stream = match stream_result {
|
|
||||||
Ok(stream) => stream,
|
|
||||||
Err(e) => return Err(anyhow!("LLM chat error: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (tx, rx) = mpsc::channel(100);
|
|
||||||
|
|
||||||
let res_future = {
|
|
||||||
let session_id = session_id.clone();
|
|
||||||
let user_id = user_id.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut response = String::new();
|
|
||||||
while let Some(content) = stream.next().await {
|
|
||||||
response.push_str(&content);
|
|
||||||
|
|
||||||
match tx.send(content).await {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => return Err(anyhow!("Streaming Error: {}", e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let end_time = Utc::now();
|
|
||||||
|
|
||||||
send_langfuse_request(
|
|
||||||
&session_id,
|
|
||||||
prompt_name,
|
|
||||||
None,
|
|
||||||
start_time,
|
|
||||||
end_time,
|
|
||||||
serde_json::to_string(&messages).unwrap(),
|
|
||||||
serde_json::to_string(&response).unwrap(),
|
|
||||||
&user_id,
|
|
||||||
&model,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
})
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok((rx, res_future))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn anthropic_chat_compiler(
|
|
||||||
model: &AnthropicChatModel,
|
|
||||||
messages: &Vec<LlmMessage>,
|
|
||||||
_max_tokens: u32,
|
|
||||||
temperature: f32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let system_message = match messages.iter().find(|m| m.role == LlmRole::System) {
|
|
||||||
Some(message) => Some(message.content.clone()),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
let mut anthropic_messages = Vec::new();
|
|
||||||
|
|
||||||
for message in messages {
|
|
||||||
let anthropic_role = match message.role {
|
|
||||||
LlmRole::System => continue,
|
|
||||||
LlmRole::User => AnthropicChatRole::User,
|
|
||||||
LlmRole::Assistant => AnthropicChatRole::Assistant,
|
|
||||||
};
|
|
||||||
|
|
||||||
let anthropic_content = AnthropicContent {
|
|
||||||
text: message.content.clone(),
|
|
||||||
_type: AnthropicContentType::Text,
|
|
||||||
};
|
|
||||||
|
|
||||||
anthropic_messages.push(AnthropicChatMessage {
|
|
||||||
role: anthropic_role,
|
|
||||||
content: vec![anthropic_content],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = match anthropic_chat(
|
|
||||||
model,
|
|
||||||
system_message,
|
|
||||||
&anthropic_messages,
|
|
||||||
temperature,
|
|
||||||
7048,
|
|
||||||
timeout,
|
|
||||||
stop,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn openai_chat_compiler(
|
|
||||||
model: &OpenAiChatModel,
|
|
||||||
messages: &Vec<LlmMessage>,
|
|
||||||
max_tokens: u32,
|
|
||||||
temperature: f32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
json_mode: bool,
|
|
||||||
json_schema: Option<Value>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let mut openai_messages = Vec::new();
|
|
||||||
|
|
||||||
for message in messages {
|
|
||||||
let openai_role = match message.role {
|
|
||||||
LlmRole::System => OpenAiChatRole::System,
|
|
||||||
LlmRole::User => OpenAiChatRole::User,
|
|
||||||
LlmRole::Assistant => OpenAiChatRole::Assistant,
|
|
||||||
};
|
|
||||||
|
|
||||||
let openai_message = OpenAiChatMessage {
|
|
||||||
role: openai_role,
|
|
||||||
content: vec![OpenAiChatContent {
|
|
||||||
type_: "text".to_string(),
|
|
||||||
text: message.content.clone(),
|
|
||||||
}],
|
|
||||||
};
|
|
||||||
|
|
||||||
openai_messages.push(openai_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = match openai_chat(
|
|
||||||
&model,
|
|
||||||
openai_messages,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
timeout,
|
|
||||||
stop,
|
|
||||||
json_mode,
|
|
||||||
json_schema,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn anthropic_chat_stream_compiler(
|
|
||||||
model: &AnthropicChatModel,
|
|
||||||
messages: &Vec<LlmMessage>,
|
|
||||||
max_tokens: u32,
|
|
||||||
temperature: f32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
) -> Result<ReceiverStream<String>> {
|
|
||||||
let system_message = match messages.iter().find(|m| m.role == LlmRole::System) {
|
|
||||||
Some(message) => Some(message.content.clone()),
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
let mut anthropic_messages = Vec::new();
|
|
||||||
|
|
||||||
for message in messages {
|
|
||||||
let anthropic_role = match message.role {
|
|
||||||
LlmRole::System => continue,
|
|
||||||
LlmRole::User => AnthropicChatRole::User,
|
|
||||||
LlmRole::Assistant => AnthropicChatRole::Assistant,
|
|
||||||
};
|
|
||||||
|
|
||||||
let anthropic_content = AnthropicContent {
|
|
||||||
text: message.content.clone(),
|
|
||||||
_type: AnthropicContentType::Text,
|
|
||||||
};
|
|
||||||
|
|
||||||
anthropic_messages.push(AnthropicChatMessage {
|
|
||||||
role: anthropic_role,
|
|
||||||
content: vec![anthropic_content],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
let stream = match anthropic_chat_stream(
|
|
||||||
model,
|
|
||||||
system_message,
|
|
||||||
&anthropic_messages,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
timeout,
|
|
||||||
stop,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn openai_chat_stream_compiler(
|
|
||||||
model: &OpenAiChatModel,
|
|
||||||
messages: &Vec<LlmMessage>,
|
|
||||||
max_tokens: u32,
|
|
||||||
temperature: f32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
) -> Result<ReceiverStream<String>> {
|
|
||||||
let mut openai_messages = Vec::new();
|
|
||||||
|
|
||||||
for message in messages {
|
|
||||||
let openai_role = match message.role {
|
|
||||||
LlmRole::System => OpenAiChatRole::System,
|
|
||||||
LlmRole::User => OpenAiChatRole::User,
|
|
||||||
LlmRole::Assistant => OpenAiChatRole::Assistant,
|
|
||||||
};
|
|
||||||
|
|
||||||
let openai_message = OpenAiChatMessage {
|
|
||||||
role: openai_role,
|
|
||||||
content: vec![OpenAiChatContent {
|
|
||||||
type_: "text".to_string(),
|
|
||||||
text: message.content.clone(),
|
|
||||||
}],
|
|
||||||
};
|
|
||||||
|
|
||||||
openai_messages.push(openai_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
let stream = match openai_chat_stream(
|
|
||||||
&model,
|
|
||||||
&openai_messages,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
timeout,
|
|
||||||
stop,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(stream)
|
|
||||||
}
|
|
|
@ -1,7 +0,0 @@
|
||||||
mod anthropic;
|
|
||||||
pub mod embedding_router;
|
|
||||||
mod hugging_face;
|
|
||||||
pub mod langfuse;
|
|
||||||
pub mod llm_router;
|
|
||||||
pub mod ollama;
|
|
||||||
pub mod openai;
|
|
|
@ -1,67 +0,0 @@
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use axum::http::HeaderMap;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
pub struct OllamaEmbeddingRequest {
|
|
||||||
pub prompt: String,
|
|
||||||
pub model: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct OllamaEmbeddingResponse {
|
|
||||||
pub embedding: Vec<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ollama_embedding(prompt: String, for_retrieval: bool) -> Result<Vec<f32>> {
|
|
||||||
let ollama_url = env::var("OLLAMA_URL").unwrap_or(String::from("http://localhost:11434"));
|
|
||||||
let model = env::var("EMBEDDING_MODEL").unwrap_or(String::from("mxbai-embed-large"));
|
|
||||||
|
|
||||||
let prompt = if model == "mxbai-embed-large" && for_retrieval {
|
|
||||||
format!(
|
|
||||||
"Represent this sentence for searching relevant passages:: {}",
|
|
||||||
prompt
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
prompt
|
|
||||||
};
|
|
||||||
|
|
||||||
let client = match reqwest::Client::builder().build() {
|
|
||||||
Ok(client) => client,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(anyhow!("Error creating reqwest client: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::CONTENT_TYPE,
|
|
||||||
"application/json".parse().unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let req = OllamaEmbeddingRequest { prompt, model };
|
|
||||||
|
|
||||||
let res = match client
|
|
||||||
.post(format!("{}/api/embeddings", ollama_url))
|
|
||||||
.headers(headers)
|
|
||||||
.json(&req)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(anyhow!("Error sending Ollama request: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let ollama_res: OllamaEmbeddingResponse = match res.json::<OllamaEmbeddingResponse>().await {
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
return Err(anyhow!("Error parsing Ollama response: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(ollama_res.embedding)
|
|
||||||
}
|
|
|
@ -1,417 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use futures::StreamExt;
|
|
||||||
use serde_json::{json, Value};
|
|
||||||
use std::{env, time::Duration};
|
|
||||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
|
||||||
use tokio_stream::wrappers::ReceiverStream;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
|
||||||
|
|
||||||
const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings";
|
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
|
||||||
static ref OPENAI_API_KEY: String = env::var("OPENAI_API_KEY")
|
|
||||||
.expect("OPENAI_API_KEY must be set");
|
|
||||||
static ref OPENAI_CHAT_URL: String = env::var("OPENAI_CHAT_URL").unwrap_or("https://api.openai.com/v1/chat/completions".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub enum OpenAiChatModel {
|
|
||||||
#[serde(rename = "gpt-4o-2024-11-20")]
|
|
||||||
Gpt4o,
|
|
||||||
#[serde(rename = "o3-mini")]
|
|
||||||
O3Mini,
|
|
||||||
#[serde(rename = "gpt-3.5-turbo")]
|
|
||||||
Gpt35Turbo,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum OpenAiChatRole {
|
|
||||||
System,
|
|
||||||
Developer,
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum ReasoningEffort {
|
|
||||||
Low,
|
|
||||||
Medium,
|
|
||||||
High,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub struct OpenAiChatContent {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub type_: String,
|
|
||||||
pub text: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub struct OpenAiChatMessage {
|
|
||||||
pub role: OpenAiChatRole,
|
|
||||||
pub content: Vec<OpenAiChatContent>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions for conditional serialization
|
|
||||||
fn is_o3_model(model: &OpenAiChatModel) -> bool {
|
|
||||||
matches!(model, OpenAiChatModel::O3Mini)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
|
||||||
pub struct OpenAiChatRequest {
|
|
||||||
model: OpenAiChatModel,
|
|
||||||
messages: Vec<OpenAiChatMessage>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
top_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
reasoning_effort: Option<ReasoningEffort>,
|
|
||||||
frequency_penalty: f32,
|
|
||||||
presence_penalty: f32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
stream: bool,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
response_format: Option<Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAiChatRequest {
|
|
||||||
pub fn new(
|
|
||||||
model: OpenAiChatModel,
|
|
||||||
messages: Vec<OpenAiChatMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
stream: bool,
|
|
||||||
response_format: Option<Value>,
|
|
||||||
) -> Self {
|
|
||||||
let (temperature, max_tokens, top_p, reasoning_effort) = if is_o3_model(&model) {
|
|
||||||
(None, None, None, Some(ReasoningEffort::Low))
|
|
||||||
} else {
|
|
||||||
(Some(temperature), Some(max_tokens), Some(1.0), None)
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
messages,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
top_p,
|
|
||||||
reasoning_effort,
|
|
||||||
frequency_penalty: 0.0,
|
|
||||||
presence_penalty: 0.0,
|
|
||||||
stop,
|
|
||||||
stream,
|
|
||||||
response_format,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub choices: Vec<Choice>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
|
||||||
pub struct Choice {
|
|
||||||
pub message: Message,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct Message {
|
|
||||||
pub content: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn openai_chat(
|
|
||||||
model: &OpenAiChatModel,
|
|
||||||
messages: Vec<OpenAiChatMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
json_mode: bool,
|
|
||||||
json_schema: Option<Value>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let response_format = match json_mode {
|
|
||||||
true => Some(json!({"type": "json_object"})),
|
|
||||||
false => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let response_format = match json_schema {
|
|
||||||
Some(schema) => Some(json!({"type": "json_schema", "json_schema": schema})),
|
|
||||||
None => response_format,
|
|
||||||
};
|
|
||||||
|
|
||||||
let chat_request = OpenAiChatRequest::new(
|
|
||||||
model.clone(),
|
|
||||||
messages,
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
stop,
|
|
||||||
false,
|
|
||||||
response_format,
|
|
||||||
);
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let headers = {
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::AUTHORIZATION,
|
|
||||||
format!("Bearer {}", OPENAI_API_KEY.to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = match client
|
|
||||||
.post(OPENAI_CHAT_URL.to_string())
|
|
||||||
.headers(headers)
|
|
||||||
.json(&chat_request)
|
|
||||||
.timeout(Duration::from_secs(timeout))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => response,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Unable to send request to OpenAI: {:?}", e);
|
|
||||||
let err = anyhow!("Unable to send request to OpenAI: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let response_text = response.text().await.unwrap();
|
|
||||||
|
|
||||||
let completion_res = match serde_json::from_str::<ChatCompletionResponse>(&response_text) {
|
|
||||||
Ok(res) => res,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Unable to parse response from OpenAI: {:?}", e);
|
|
||||||
let err = anyhow!("Unable to parse response from OpenAI: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let content = match completion_res.choices.get(0) {
|
|
||||||
Some(choice) => choice.message.content.clone(),
|
|
||||||
None => return Err(anyhow!("No content returned from OpenAI")),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(content)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct OpenAiChatDelta {
|
|
||||||
pub content: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct OpenAiChatChoice {
|
|
||||||
pub delta: OpenAiChatDelta,
|
|
||||||
pub index: u32,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct OpenAiChatStreamResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub system_fingerprint: String,
|
|
||||||
pub choices: Vec<OpenAiChatChoice>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// We can't do Langfuse traces directly in the stream function.
|
|
||||||
/// This is mainly because I'm too lazy to figure out how to set up all the passing with the stream reciever
|
|
||||||
/// Langfuse traces must be written directly on the stream call.
|
|
||||||
|
|
||||||
pub async fn openai_chat_stream(
|
|
||||||
model: &OpenAiChatModel,
|
|
||||||
messages: &Vec<OpenAiChatMessage>,
|
|
||||||
temperature: f32,
|
|
||||||
max_tokens: u32,
|
|
||||||
timeout: u64,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
) -> Result<ReceiverStream<String>> {
|
|
||||||
let chat_request = OpenAiChatRequest::new(
|
|
||||||
model.clone(),
|
|
||||||
messages.clone(),
|
|
||||||
temperature,
|
|
||||||
max_tokens,
|
|
||||||
stop,
|
|
||||||
true,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let headers = {
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::AUTHORIZATION,
|
|
||||||
format!("Bearer {}", OPENAI_API_KEY.to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers
|
|
||||||
};
|
|
||||||
|
|
||||||
let (tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel(100);
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let response = client
|
|
||||||
.post(OPENAI_CHAT_URL.to_string())
|
|
||||||
.headers(headers)
|
|
||||||
.json(&chat_request)
|
|
||||||
.timeout(Duration::from_secs(timeout))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
tracing::error!("Unable to send request to OpenAI: {:?}", e);
|
|
||||||
let err = anyhow!("Unable to send request to OpenAI: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
err
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Err(_e) = response {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = response.unwrap();
|
|
||||||
let mut stream = response.bytes_stream();
|
|
||||||
|
|
||||||
let mut buffer = String::new();
|
|
||||||
|
|
||||||
while let Some(item) = stream.next().await {
|
|
||||||
match item {
|
|
||||||
Ok(bytes) => {
|
|
||||||
let chunk = String::from_utf8(bytes.to_vec()).unwrap();
|
|
||||||
buffer.push_str(&chunk);
|
|
||||||
|
|
||||||
while let Some(pos) = buffer.find("}\n") {
|
|
||||||
let (json_str, rest) = {
|
|
||||||
let (json_str, rest) = buffer.split_at(pos + 1);
|
|
||||||
(json_str.to_string(), rest.to_string())
|
|
||||||
};
|
|
||||||
buffer = rest;
|
|
||||||
|
|
||||||
let json_str_trimmed = json_str.replace("data: ", "");
|
|
||||||
match serde_json::from_str::<OpenAiChatStreamResponse>(
|
|
||||||
&json_str_trimmed.trim(),
|
|
||||||
) {
|
|
||||||
Ok(response) => {
|
|
||||||
if let Some(content) = &response.choices[0].delta.content {
|
|
||||||
if tx.send(content.clone()).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Error parsing JSON response: {:?}", e);
|
|
||||||
let err = anyhow!("Error parsing JSON response: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Error while streaming response: {:?}", e);
|
|
||||||
let err = anyhow!("Error while streaming response: {}", e);
|
|
||||||
send_sentry_error(&err.to_string(), None);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(ReceiverStream::new(rx))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
pub struct AdaBulkEmbedding {
|
|
||||||
pub model: String,
|
|
||||||
pub input: Vec<String>,
|
|
||||||
pub dimensions: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct AdaEmbeddingArray {
|
|
||||||
pub embedding: Vec<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
pub struct AdaEmbeddingResponse {
|
|
||||||
pub data: Vec<AdaEmbeddingArray>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn ada_bulk_embedding(text_list: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
|
||||||
let embedding_model = env::var("EMBEDDING_MODEL").expect("EMBEDDING_MODEL must be set");
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let ada_bulk_embedding = AdaBulkEmbedding {
|
|
||||||
model: embedding_model,
|
|
||||||
input: text_list,
|
|
||||||
dimensions: 1024,
|
|
||||||
};
|
|
||||||
|
|
||||||
let embeddings_result = match client
|
|
||||||
.post(OPENAI_EMBEDDING_URL)
|
|
||||||
.headers({
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
reqwest::header::AUTHORIZATION,
|
|
||||||
format!("Bearer {}", OPENAI_API_KEY.to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers
|
|
||||||
})
|
|
||||||
.timeout(Duration::from_secs(60))
|
|
||||||
.json(&ada_bulk_embedding)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => {
|
|
||||||
if !res.status().is_success() {
|
|
||||||
tracing::error!(
|
|
||||||
"There was an issue while getting the data source: {}",
|
|
||||||
res.text().await.unwrap()
|
|
||||||
);
|
|
||||||
return Err(anyhow!("Error getting data source"));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
Err(e) => Err(anyhow!(e.to_string())),
|
|
||||||
};
|
|
||||||
|
|
||||||
let embedding_text = embeddings_result.unwrap().text().await.unwrap();
|
|
||||||
|
|
||||||
let embeddings: AdaEmbeddingResponse =
|
|
||||||
match serde_json::from_str(embedding_text.clone().as_str()) {
|
|
||||||
Ok(embeddings) => embeddings,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(
|
|
||||||
"There was an issue decoding the bulk embedding json response: {}",
|
|
||||||
e
|
|
||||||
);
|
|
||||||
return Err(anyhow!(e.to_string()));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let result: Vec<Vec<f32>> = embeddings.data.into_iter().map(|x| x.embedding).collect();
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use aws_config::meta::region::RegionProviderChain;
|
|
||||||
use aws_sdk_secretsmanager::Client;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
pub async fn create_db_secret(secret: String) -> Result<String> {
|
|
||||||
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
|
|
||||||
let config = aws_config::from_env().region(region_provider).load().await;
|
|
||||||
let client = Client::new(&config);
|
|
||||||
|
|
||||||
let secret_id = Uuid::new_v4().to_string();
|
|
||||||
|
|
||||||
match client
|
|
||||||
.create_secret()
|
|
||||||
.name(secret_id.clone())
|
|
||||||
.secret_string(secret)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => {
|
|
||||||
tracing::info!("Successfully created secret in AWS.");
|
|
||||||
tracing::info!("Secret ID: {}", secret_id);
|
|
||||||
tracing::info!("Secret ARN: {}", res.arn.unwrap());
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("There was an issue while creating the secret in AWS.");
|
|
||||||
return Err(anyhow!(e));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(secret_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_db_secret(secret_id: String) -> Result<()> {
|
|
||||||
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
|
|
||||||
let config = aws_config::from_env().region(region_provider).load().await;
|
|
||||||
let client = Client::new(&config);
|
|
||||||
|
|
||||||
let _secret_response = match client.delete_secret().secret_id(secret_id).send().await {
|
|
||||||
Ok(secret_response) => secret_response,
|
|
||||||
Err(e) => return Err(anyhow!(e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read_secret(secret_id: String) -> Result<String> {
|
|
||||||
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
|
|
||||||
let config = aws_config::from_env().region(region_provider).load().await;
|
|
||||||
let client = Client::new(&config);
|
|
||||||
|
|
||||||
let secret_response = client.get_secret_value().secret_id(secret_id).send().await;
|
|
||||||
|
|
||||||
let secret = match secret_response {
|
|
||||||
Ok(secret) => secret,
|
|
||||||
Err(e) => return Err(anyhow!(e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let secret_string = match secret.secret_string {
|
|
||||||
Some(secret_string) => secret_string,
|
|
||||||
None => return Err(anyhow!("There was no secret string in the response")),
|
|
||||||
};
|
|
||||||
|
|
||||||
return Ok(secret_string);
|
|
||||||
}
|
|
|
@ -1,6 +1,2 @@
|
||||||
pub mod ai;
|
|
||||||
// pub mod aws;
|
|
||||||
pub mod email;
|
pub mod email;
|
||||||
pub mod posthog;
|
|
||||||
pub mod sentry_utils;
|
pub mod sentry_utils;
|
||||||
pub mod typesense;
|
|
||||||
|
|
|
@ -1,190 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use reqwest::header::AUTHORIZATION;
|
|
||||||
use std::env;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
|
||||||
static ref POSTHOG_API_KEY: String = env::var("POSTHOG_API_KEY").expect("POSTHOG_API_KEY must be set");
|
|
||||||
}
|
|
||||||
|
|
||||||
const POSTHOG_API_URL: &str = "https://us.i.posthog.com/capture/";
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct PosthogEvent {
|
|
||||||
event: PosthogEventType,
|
|
||||||
distinct_id: Option<Uuid>,
|
|
||||||
properties: PosthogEventProperties,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum PosthogEventType {
|
|
||||||
MetricCreated,
|
|
||||||
MetricFollowUp,
|
|
||||||
MetricAddedToDashboard,
|
|
||||||
MetricViewed,
|
|
||||||
DashboardViewed,
|
|
||||||
TitleManuallyUpdated,
|
|
||||||
SqlManuallyUpdated,
|
|
||||||
ChartStylingManuallyUpdated,
|
|
||||||
ChartStylingAutoUpdated,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub enum PosthogEventProperties {
|
|
||||||
MetricCreated(MetricCreatedProperties),
|
|
||||||
MetricFollowUp(MetricFollowUpProperties),
|
|
||||||
MetricAddedToDashboard(MetricAddedToDashboardProperties),
|
|
||||||
MetricViewed(MetricViewedProperties),
|
|
||||||
DashboardViewed(DashboardViewedProperties),
|
|
||||||
TitleManuallyUpdated(TitleManuallyUpdatedProperties),
|
|
||||||
SqlManuallyUpdated(SqlManuallyUpdatedProperties),
|
|
||||||
ChartStylingManuallyUpdated(ChartStylingManuallyUpdatedProperties),
|
|
||||||
ChartStylingAutoUpdated(ChartStylingAutoUpdatedProperties),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct MetricCreatedProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct MetricFollowUpProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct MetricAddedToDashboardProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
pub dashboard_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct MetricViewedProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct DashboardViewedProperties {
|
|
||||||
pub dashboard_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct TitleManuallyUpdatedProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct SqlManuallyUpdatedProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct ChartStylingManuallyUpdatedProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct ChartStylingAutoUpdatedProperties {
|
|
||||||
pub message_id: Uuid,
|
|
||||||
pub thread_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send_posthog_event_handler(
|
|
||||||
event_type: PosthogEventType,
|
|
||||||
user_id: Option<Uuid>,
|
|
||||||
properties: PosthogEventProperties,
|
|
||||||
) -> Result<()> {
|
|
||||||
let event = PosthogEvent {
|
|
||||||
event: event_type,
|
|
||||||
distinct_id: user_id,
|
|
||||||
properties,
|
|
||||||
};
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
match send_event(event).await {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => {
|
|
||||||
send_sentry_error(&e.to_string(), None);
|
|
||||||
tracing::error!("Unable to send event to Posthog: {:?}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
|
||||||
pub struct PosthogRequest {
|
|
||||||
#[serde(flatten)]
|
|
||||||
event: PosthogEvent,
|
|
||||||
api_key: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_event(event: PosthogEvent) -> Result<()> {
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
|
|
||||||
let headers = {
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert(
|
|
||||||
AUTHORIZATION,
|
|
||||||
format!("Bearer {}", POSTHOG_API_KEY.to_string())
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
headers
|
|
||||||
};
|
|
||||||
|
|
||||||
let posthog_req = PosthogRequest {
|
|
||||||
event,
|
|
||||||
api_key: POSTHOG_API_KEY.to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
match client
|
|
||||||
.post(POSTHOG_API_URL)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&posthog_req)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Unable to send request to Posthog: {:?}", e);
|
|
||||||
return Err(anyhow!(e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use dotenv::dotenv;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_send_event() {
|
|
||||||
dotenv().ok();
|
|
||||||
|
|
||||||
let _ = send_event(PosthogEvent {
|
|
||||||
event: PosthogEventType::MetricCreated,
|
|
||||||
distinct_id: Some(Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap()),
|
|
||||||
properties: PosthogEventProperties::MetricCreated(MetricCreatedProperties {
|
|
||||||
message_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(),
|
|
||||||
thread_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(),
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,318 +0,0 @@
|
||||||
use anyhow::{anyhow, Result};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::env;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
|
||||||
static ref TYPESENSE_API_KEY: String = env::var("TYPESENSE_API_KEY").unwrap_or("xyz".to_string());
|
|
||||||
static ref TYPESENSE_API_HOST: String = env::var("TYPESENSE_API_HOST").unwrap_or("http://localhost:8108".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum CollectionName {
|
|
||||||
Messages,
|
|
||||||
Collections,
|
|
||||||
Datasets,
|
|
||||||
PermissionGroups,
|
|
||||||
Teams,
|
|
||||||
DataSources,
|
|
||||||
Terms,
|
|
||||||
Dashboards,
|
|
||||||
#[serde(untagged)]
|
|
||||||
StoredValues(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CollectionName {
|
|
||||||
pub fn to_string(&self) -> String {
|
|
||||||
match self {
|
|
||||||
CollectionName::Messages => "messages".to_string(),
|
|
||||||
CollectionName::Collections => "collections".to_string(),
|
|
||||||
CollectionName::Datasets => "datasets".to_string(),
|
|
||||||
CollectionName::PermissionGroups => "permission_groups".to_string(),
|
|
||||||
CollectionName::Teams => "teams".to_string(),
|
|
||||||
CollectionName::DataSources => "data_sources".to_string(),
|
|
||||||
CollectionName::Terms => "terms".to_string(),
|
|
||||||
CollectionName::Dashboards => "dashboards".to_string(),
|
|
||||||
CollectionName::StoredValues(s) => s.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize)]
|
|
||||||
pub struct MessageDocument {
|
|
||||||
pub id: Uuid,
|
|
||||||
pub name: String,
|
|
||||||
pub summary_question: String,
|
|
||||||
pub organization_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize)]
|
|
||||||
pub struct GenericDocument {
|
|
||||||
pub id: Uuid,
|
|
||||||
pub name: String,
|
|
||||||
pub organization_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Clone, Debug)]
|
|
||||||
pub struct StoredValueDocument {
|
|
||||||
pub id: Uuid,
|
|
||||||
pub value: String,
|
|
||||||
pub dataset_id: Uuid,
|
|
||||||
pub dataset_column_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum Document {
|
|
||||||
Message(MessageDocument),
|
|
||||||
Dashboard(GenericDocument),
|
|
||||||
Dataset(GenericDocument),
|
|
||||||
PermissionGroup(GenericDocument),
|
|
||||||
Team(GenericDocument),
|
|
||||||
DataSource(GenericDocument),
|
|
||||||
Term(GenericDocument),
|
|
||||||
Collection(GenericDocument),
|
|
||||||
StoredValue(StoredValueDocument),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Document {
|
|
||||||
pub fn id(&self) -> Uuid {
|
|
||||||
match self {
|
|
||||||
Document::Message(m) => m.id,
|
|
||||||
Document::Dashboard(d) => d.id,
|
|
||||||
Document::Dataset(d) => d.id,
|
|
||||||
Document::PermissionGroup(d) => d.id,
|
|
||||||
Document::Team(d) => d.id,
|
|
||||||
Document::DataSource(d) => d.id,
|
|
||||||
Document::Term(d) => d.id,
|
|
||||||
Document::Collection(d) => d.id,
|
|
||||||
Document::StoredValue(d) => d.id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_stored_value_document(&self) -> &StoredValueDocument {
|
|
||||||
match self {
|
|
||||||
Document::StoredValue(d) => d,
|
|
||||||
_ => panic!("Document is not a StoredValueDocument"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn upsert_document(collection: CollectionName, document: Document) -> Result<()> {
|
|
||||||
let client = reqwest::Client::builder().build()?;
|
|
||||||
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert("Content-Type", "application/json".parse()?);
|
|
||||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
|
||||||
|
|
||||||
match client
|
|
||||||
.request(
|
|
||||||
reqwest::Method::POST,
|
|
||||||
format!(
|
|
||||||
"{}/collections/{}/documents?action=upsert",
|
|
||||||
*TYPESENSE_API_HOST,
|
|
||||||
collection.to_string()
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&document)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
pub struct SearchRequest {
|
|
||||||
pub searches: Vec<SearchRequestObject>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Debug)]
|
|
||||||
pub struct SearchRequestObject {
|
|
||||||
pub collection: CollectionName,
|
|
||||||
pub q: String,
|
|
||||||
pub query_by: String,
|
|
||||||
pub prefix: bool,
|
|
||||||
pub exclude_fields: String,
|
|
||||||
pub highlight_fields: String,
|
|
||||||
pub use_cache: bool,
|
|
||||||
pub filter_by: String,
|
|
||||||
pub vector_query: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub limit: Option<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct RequestParams {
|
|
||||||
pub collection_name: CollectionName,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct Highlight {
|
|
||||||
pub field: String,
|
|
||||||
pub matched_tokens: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct HybridSearchInfo {
|
|
||||||
pub rank_fusion_score: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct Hit {
|
|
||||||
pub document: Document,
|
|
||||||
pub highlights: Vec<Highlight>,
|
|
||||||
pub hybrid_search_info: HybridSearchInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct SearchResult {
|
|
||||||
pub request_params: RequestParams,
|
|
||||||
pub hits: Vec<Hit>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct SearchResponse {
|
|
||||||
pub results: Vec<SearchResult>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn search_documents(search_reqs: Vec<SearchRequestObject>) -> Result<SearchResponse> {
|
|
||||||
let client = reqwest::Client::builder().build()?;
|
|
||||||
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert("Content-Type", "application/json".parse()?);
|
|
||||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
|
||||||
|
|
||||||
let search_req = SearchRequest {
|
|
||||||
searches: search_reqs,
|
|
||||||
};
|
|
||||||
|
|
||||||
let res = match client
|
|
||||||
.request(
|
|
||||||
reqwest::Method::POST,
|
|
||||||
format!("{}/multi_search", *TYPESENSE_API_HOST),
|
|
||||||
)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&search_req)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => {
|
|
||||||
let res = if res.status().is_success() {
|
|
||||||
res
|
|
||||||
} else {
|
|
||||||
return Err(anyhow!(
|
|
||||||
"Error sending request: {}",
|
|
||||||
res.text().await.unwrap()
|
|
||||||
));
|
|
||||||
};
|
|
||||||
|
|
||||||
res
|
|
||||||
}
|
|
||||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
let search_response = match res.json::<SearchResponse>().await {
|
|
||||||
Ok(search_response) => search_response,
|
|
||||||
Err(e) => return Err(anyhow!("Error parsing response: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(search_response)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create_collection(schema: Value) -> Result<()> {
|
|
||||||
let client = reqwest::Client::builder().build()?;
|
|
||||||
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert("Content-Type", "application/json".parse()?);
|
|
||||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
|
||||||
|
|
||||||
match client
|
|
||||||
.request(
|
|
||||||
reqwest::Method::POST,
|
|
||||||
format!("{}/collections", *TYPESENSE_API_HOST,),
|
|
||||||
)
|
|
||||||
.headers(headers)
|
|
||||||
.json(&schema)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_collection(collection_name: &String, filter_by: &String) -> Result<()> {
|
|
||||||
let client = reqwest::Client::builder().build()?;
|
|
||||||
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert("Content-Type", "application/json".parse()?);
|
|
||||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
|
||||||
|
|
||||||
match client
|
|
||||||
.request(
|
|
||||||
reqwest::Method::DELETE,
|
|
||||||
format!(
|
|
||||||
"{}/collections/{}/documents?filter_by={}",
|
|
||||||
*TYPESENSE_API_HOST, collection_name, filter_by
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.headers(headers)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn bulk_insert_documents<T>(collection_name: &String, documents: &Vec<T>) -> Result<()>
|
|
||||||
where
|
|
||||||
T: serde::Serialize,
|
|
||||||
{
|
|
||||||
let client = reqwest::Client::builder().build()?;
|
|
||||||
|
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
|
||||||
headers.insert("Content-Type", "application/json".parse()?);
|
|
||||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
|
||||||
|
|
||||||
let serialized_documents = documents
|
|
||||||
.iter()
|
|
||||||
.map(|doc| serde_json::to_string(doc))
|
|
||||||
.collect::<Result<Vec<String>, _>>()?
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
match client
|
|
||||||
.request(
|
|
||||||
reqwest::Method::POST,
|
|
||||||
format!(
|
|
||||||
"{}/collections/{}/documents/import?action=create",
|
|
||||||
*TYPESENSE_API_HOST, collection_name
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.headers(headers)
|
|
||||||
.body(serialized_documents)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => {
|
|
||||||
if res.status().is_success() {
|
|
||||||
tracing::info!("Res: {:?}", res.text().await.unwrap());
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
tracing::error!("Error sending request: {}", res.text().await.unwrap());
|
|
||||||
return Err(anyhow!("Error sending bulk insert request"));
|
|
||||||
}
|
|
||||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
|
||||||
};
|
|
||||||
}
|
|
|
@ -1,11 +1,5 @@
|
||||||
pub mod charting;
|
|
||||||
pub mod clients;
|
pub mod clients;
|
||||||
pub mod security;
|
pub mod security;
|
||||||
pub mod serde_helpers;
|
|
||||||
pub mod stored_values;
|
|
||||||
pub mod validation;
|
|
||||||
|
|
||||||
pub use agents::*;
|
pub use agents::*;
|
||||||
pub use security::*;
|
pub use security::*;
|
||||||
pub use stored_values::*;
|
|
||||||
pub use validation::*;
|
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
use serde::{de::Deserializer, Deserialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
pub fn deserialize_double_option<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
|
|
||||||
where
|
|
||||||
T: serde::Deserialize<'de>,
|
|
||||||
D: Deserializer<'de>,
|
|
||||||
{
|
|
||||||
let value = Value::deserialize(deserializer)?;
|
|
||||||
|
|
||||||
match value {
|
|
||||||
Value::Null => Ok(Some(None)), // explicit null
|
|
||||||
Value::Object(obj) if obj.is_empty() => Ok(None), // empty object
|
|
||||||
_ => {
|
|
||||||
match T::deserialize(value) {
|
|
||||||
Ok(val) => Ok(Some(Some(val))),
|
|
||||||
Err(_) => Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1 +0,0 @@
|
||||||
pub mod deserialization_helpers;
|
|
|
@ -1,291 +0,0 @@
|
||||||
pub mod search;
|
|
||||||
use query_engine::data_source_query_routes::query_engine::query_engine;
|
|
||||||
|
|
||||||
use query_engine::data_types::DataType;
|
|
||||||
pub use search::*;
|
|
||||||
|
|
||||||
use crate::utils::clients::ai::embedding_router::embedding_router;
|
|
||||||
use anyhow::Result;
|
|
||||||
use chrono::Utc;
|
|
||||||
use database::enums::StoredValuesStatus;
|
|
||||||
use database::{pool::get_pg_pool, schema::dataset_columns};
|
|
||||||
use diesel::prelude::*;
|
|
||||||
use diesel::sql_types::{Array, Float4, Integer, Text, Timestamptz, Uuid as SqlUuid};
|
|
||||||
use diesel_async::RunQueryDsl;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[derive(Debug, QueryableByName)]
|
|
||||||
pub struct StoredValueRow {
|
|
||||||
#[diesel(sql_type = Text)]
|
|
||||||
pub value: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, QueryableByName)]
|
|
||||||
pub struct StoredValueWithDistance {
|
|
||||||
#[diesel(sql_type = Text)]
|
|
||||||
pub value: String,
|
|
||||||
#[diesel(sql_type = Text)]
|
|
||||||
pub column_name: String,
|
|
||||||
#[diesel(sql_type = SqlUuid)]
|
|
||||||
pub column_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
const BATCH_SIZE: usize = 10_000;
|
|
||||||
const MAX_VALUE_LENGTH: usize = 50;
|
|
||||||
|
|
||||||
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
|
|
||||||
let pool = get_pg_pool();
|
|
||||||
let mut conn = pool.get().await?;
|
|
||||||
|
|
||||||
// Create schema and table using raw SQL
|
|
||||||
let schema_name = organization_id.to_string().replace("-", "_");
|
|
||||||
let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS values_{}", schema_name);
|
|
||||||
|
|
||||||
let create_table_sql = format!(
|
|
||||||
"CREATE TABLE IF NOT EXISTS values_{}.values_v1 (
|
|
||||||
value text NOT NULL,
|
|
||||||
dataset_id uuid NOT NULL,
|
|
||||||
column_name text NOT NULL,
|
|
||||||
column_id uuid NOT NULL,
|
|
||||||
embedding vector(1024),
|
|
||||||
created_at timestamp with time zone NOT NULL DEFAULT now(),
|
|
||||||
UNIQUE(dataset_id, column_name, value)
|
|
||||||
)",
|
|
||||||
schema_name
|
|
||||||
);
|
|
||||||
|
|
||||||
let create_index_sql = format!(
|
|
||||||
"CREATE INDEX IF NOT EXISTS values_v1_embedding_idx
|
|
||||||
ON values_{}.values_v1
|
|
||||||
USING ivfflat (embedding vector_cosine_ops)",
|
|
||||||
schema_name
|
|
||||||
);
|
|
||||||
|
|
||||||
diesel::sql_query(create_schema_sql)
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await?;
|
|
||||||
diesel::sql_query(create_table_sql)
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await?;
|
|
||||||
diesel::sql_query(create_index_sql)
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn store_column_values(
|
|
||||||
organization_id: &Uuid,
|
|
||||||
dataset_id: &Uuid,
|
|
||||||
column_name: &str,
|
|
||||||
column_id: &Uuid,
|
|
||||||
_data_source_id: &Uuid,
|
|
||||||
schema: &str,
|
|
||||||
table_name: &str,
|
|
||||||
) -> Result<()> {
|
|
||||||
let pool = get_pg_pool();
|
|
||||||
let mut conn = pool.get().await?;
|
|
||||||
|
|
||||||
// Create schema and table if they don't exist
|
|
||||||
ensure_stored_values_schema(organization_id).await?;
|
|
||||||
|
|
||||||
// Query distinct values in batches
|
|
||||||
let mut offset = 0;
|
|
||||||
let mut first_batch = true;
|
|
||||||
let schema_name = organization_id.to_string().replace("-", "_");
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let query = format!(
|
|
||||||
"SELECT DISTINCT \"{}\" as value
|
|
||||||
FROM {}.{}
|
|
||||||
WHERE \"{}\" IS NOT NULL
|
|
||||||
AND length(\"{}\") <= {}
|
|
||||||
ORDER BY \"{}\"
|
|
||||||
LIMIT {} OFFSET {}",
|
|
||||||
column_name,
|
|
||||||
schema,
|
|
||||||
table_name,
|
|
||||||
column_name,
|
|
||||||
column_name,
|
|
||||||
MAX_VALUE_LENGTH,
|
|
||||||
column_name,
|
|
||||||
BATCH_SIZE,
|
|
||||||
offset
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = match query_engine(dataset_id, &query, None).await {
|
|
||||||
Ok(results) => results.data,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("Error querying stored values: {:?}", e);
|
|
||||||
if first_batch {
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
vec![]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if results.is_empty() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract values from the query results
|
|
||||||
let values: Vec<String> = results
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|row| {
|
|
||||||
if let Some(DataType::Text(Some(value))) = row.get("value") {
|
|
||||||
Some(value.clone())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if values.is_empty() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this is the first batch and we have 15 or fewer values, handle as enum
|
|
||||||
if first_batch && values.len() <= 15 {
|
|
||||||
// Get current description
|
|
||||||
let current_description =
|
|
||||||
diesel::sql_query("SELECT description FROM dataset_columns WHERE id = $1")
|
|
||||||
.bind::<SqlUuid, _>(column_id)
|
|
||||||
.get_result::<StoredValueRow>(&mut conn)
|
|
||||||
.await
|
|
||||||
.ok()
|
|
||||||
.and_then(|row| Some(row.value));
|
|
||||||
|
|
||||||
// Format new description
|
|
||||||
let enum_list = format!("Values for this column are: {}", values.join(", "));
|
|
||||||
let new_description = match current_description {
|
|
||||||
Some(desc) if !desc.is_empty() => format!("{}. {}", desc, enum_list),
|
|
||||||
_ => enum_list,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Update column description
|
|
||||||
diesel::update(dataset_columns::table)
|
|
||||||
.filter(dataset_columns::id.eq(column_id))
|
|
||||||
.set((
|
|
||||||
dataset_columns::description.eq(new_description),
|
|
||||||
dataset_columns::stored_values_status.eq(StoredValuesStatus::Success),
|
|
||||||
dataset_columns::stored_values_count.eq(values.len() as i64),
|
|
||||||
dataset_columns::stored_values_last_synced.eq(Utc::now()),
|
|
||||||
))
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create embeddings for the batch
|
|
||||||
let embeddings = create_embeddings_batch(&values).await?;
|
|
||||||
|
|
||||||
// Insert values and embeddings
|
|
||||||
for (value, embedding) in values.iter().zip(embeddings.iter()) {
|
|
||||||
let insert_sql = format!(
|
|
||||||
"INSERT INTO {}_values.values_v1
|
|
||||||
(value, dataset_id, column_name, column_id, embedding, created_at)
|
|
||||||
VALUES ($1::text, $2::uuid, $3::text, $4::uuid, $5::vector, $6::timestamptz)
|
|
||||||
ON CONFLICT (dataset_id, column_name, value)
|
|
||||||
DO UPDATE SET created_at = EXCLUDED.created_at",
|
|
||||||
schema_name
|
|
||||||
);
|
|
||||||
|
|
||||||
diesel::sql_query(insert_sql)
|
|
||||||
.bind::<Text, _>(value)
|
|
||||||
.bind::<SqlUuid, _>(dataset_id)
|
|
||||||
.bind::<Text, _>(column_name)
|
|
||||||
.bind::<SqlUuid, _>(column_id)
|
|
||||||
.bind::<Array<Float4>, _>(embedding)
|
|
||||||
.bind::<Timestamptz, _>(Utc::now())
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
first_batch = false;
|
|
||||||
offset += BATCH_SIZE;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_embeddings_batch(values: &[String]) -> Result<Vec<Vec<f32>>> {
|
|
||||||
let embeddings = embedding_router(values.to_vec(), true).await?;
|
|
||||||
Ok(embeddings)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn search_stored_values(
|
|
||||||
organization_id: &Uuid,
|
|
||||||
dataset_id: &Uuid,
|
|
||||||
query_embedding: Vec<f32>,
|
|
||||||
limit: Option<i32>,
|
|
||||||
) -> Result<Vec<(String, String, Uuid)>> {
|
|
||||||
let pool = get_pg_pool();
|
|
||||||
let mut conn = pool.get().await?;
|
|
||||||
|
|
||||||
let limit = limit.unwrap_or(10);
|
|
||||||
|
|
||||||
let schema_name = organization_id.to_string().replace("-", "_");
|
|
||||||
let query = format!(
|
|
||||||
"SELECT value, column_name, column_id
|
|
||||||
FROM values_{}.values_v1
|
|
||||||
WHERE dataset_id = $2::uuid
|
|
||||||
ORDER BY embedding <=> $1::vector
|
|
||||||
LIMIT $3::integer",
|
|
||||||
schema_name
|
|
||||||
);
|
|
||||||
|
|
||||||
let results: Vec<StoredValueWithDistance> = diesel::sql_query(query)
|
|
||||||
.bind::<Array<Float4>, _>(query_embedding)
|
|
||||||
.bind::<SqlUuid, _>(dataset_id)
|
|
||||||
.bind::<Integer, _>(limit)
|
|
||||||
.load(&mut conn)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(results
|
|
||||||
.into_iter()
|
|
||||||
.map(|r| (r.value, r.column_name, r.column_id))
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct StoredValueColumn {
|
|
||||||
pub organization_id: Uuid,
|
|
||||||
pub dataset_id: Uuid,
|
|
||||||
pub column_name: String,
|
|
||||||
pub column_id: Uuid,
|
|
||||||
pub data_source_id: Uuid,
|
|
||||||
pub schema: String,
|
|
||||||
pub table_name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn process_stored_values_background(columns: Vec<StoredValueColumn>) {
|
|
||||||
for column in columns {
|
|
||||||
match store_column_values(
|
|
||||||
&column.organization_id,
|
|
||||||
&column.dataset_id,
|
|
||||||
&column.column_name,
|
|
||||||
&column.column_id,
|
|
||||||
&column.data_source_id,
|
|
||||||
&column.schema,
|
|
||||||
&column.table_name,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => {
|
|
||||||
tracing::info!(
|
|
||||||
"Successfully processed stored values for column '{}' in dataset '{}'",
|
|
||||||
column.column_name,
|
|
||||||
column.table_name
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(
|
|
||||||
"Failed to process stored values for column '{}' in dataset '{}': {:?}",
|
|
||||||
column.column_name,
|
|
||||||
column.table_name,
|
|
||||||
e
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,64 +0,0 @@
|
||||||
use anyhow::Result;
|
|
||||||
use cohere_rust::{api::rerank::{ReRankModel, ReRankRequest}, Cohere};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use crate::utils::clients::ai::embedding_router::embedding_router;
|
|
||||||
|
|
||||||
use super::search_stored_values;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct StoredValue {
|
|
||||||
pub value: String,
|
|
||||||
pub dataset_id: Uuid,
|
|
||||||
pub column_name: String,
|
|
||||||
pub column_id: Uuid,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn search_values_for_dataset(
|
|
||||||
organization_id: &Uuid,
|
|
||||||
dataset_id: &Uuid,
|
|
||||||
query: String,
|
|
||||||
) -> Result<Vec<StoredValue>> {
|
|
||||||
// Create embedding for the search query
|
|
||||||
let query_vec = vec![query.clone()];
|
|
||||||
let query_embedding = embedding_router(query_vec, true).await?[0].clone();
|
|
||||||
|
|
||||||
// Get initial candidates using vector similarity
|
|
||||||
let candidates = search_stored_values(
|
|
||||||
organization_id,
|
|
||||||
dataset_id,
|
|
||||||
query_embedding,
|
|
||||||
Some(25), // Get more candidates for reranking
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
// Extract just the values for reranking
|
|
||||||
let candidate_values: Vec<String> = candidates.iter().map(|(value, _, _)| value.clone()).collect();
|
|
||||||
|
|
||||||
// Rerank the candidates using the cohere client
|
|
||||||
let co = Cohere::default();
|
|
||||||
let request = ReRankRequest {
|
|
||||||
query: query.as_str(),
|
|
||||||
documents: &candidate_values,
|
|
||||||
model: ReRankModel::EnglishV3,
|
|
||||||
top_n: Some(10),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let response = co.rerank(&request).await?;
|
|
||||||
|
|
||||||
// Convert to StoredValue structs
|
|
||||||
let values = response.into_iter()
|
|
||||||
.map(|result| {
|
|
||||||
let (value, column_name, column_id) = candidates[result.index as usize].clone();
|
|
||||||
StoredValue {
|
|
||||||
value,
|
|
||||||
dataset_id: *dataset_id,
|
|
||||||
column_name,
|
|
||||||
column_id,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(values)
|
|
||||||
}
|
|
|
@ -1,5 +0,0 @@
|
||||||
pub mod types;
|
|
||||||
pub mod type_mapping;
|
|
||||||
|
|
||||||
pub use types::*;
|
|
||||||
pub use type_mapping::*;
|
|
|
@ -1,234 +0,0 @@
|
||||||
use database::enums::DataSourceType;
|
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use query_engine::data_types::DataType;
|
|
||||||
|
|
||||||
// Standard types we use across all data sources
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
|
||||||
pub enum StandardType {
|
|
||||||
Text,
|
|
||||||
Integer,
|
|
||||||
Float,
|
|
||||||
Boolean,
|
|
||||||
Date,
|
|
||||||
Timestamp,
|
|
||||||
Json,
|
|
||||||
Unknown,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StandardType {
|
|
||||||
pub fn from_data_type(data_type: &DataType) -> Self {
|
|
||||||
match data_type.simple_type() {
|
|
||||||
Some(simple_type) => match simple_type.as_str() {
|
|
||||||
"string" => StandardType::Text,
|
|
||||||
"number" => StandardType::Float,
|
|
||||||
"boolean" => StandardType::Boolean,
|
|
||||||
"date" => StandardType::Date,
|
|
||||||
"null" => StandardType::Unknown,
|
|
||||||
_ => StandardType::Unknown,
|
|
||||||
},
|
|
||||||
None => StandardType::Unknown,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_str(type_str: &str) -> Self {
|
|
||||||
match type_str.to_lowercase().as_str() {
|
|
||||||
"text" | "string" | "varchar" | "char" => StandardType::Text,
|
|
||||||
"int" | "integer" | "bigint" | "smallint" => StandardType::Integer,
|
|
||||||
"float" | "double" | "decimal" | "numeric" => StandardType::Float,
|
|
||||||
"bool" | "boolean" => StandardType::Boolean,
|
|
||||||
"date" => StandardType::Date,
|
|
||||||
"timestamp" | "datetime" | "timestamptz" => StandardType::Timestamp,
|
|
||||||
"json" | "jsonb" => StandardType::Json,
|
|
||||||
_ => StandardType::Unknown,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type mappings for each data source type to DataType
|
|
||||||
static TYPE_MAPPINGS: Lazy<HashMap<DataSourceType, HashMap<&'static str, DataType>>> =
|
|
||||||
Lazy::new(|| {
|
|
||||||
let mut mappings = HashMap::new();
|
|
||||||
|
|
||||||
// Postgres mappings
|
|
||||||
let mut postgres = HashMap::new();
|
|
||||||
postgres.insert("text", DataType::Text(None));
|
|
||||||
postgres.insert("varchar", DataType::Text(None));
|
|
||||||
postgres.insert("char", DataType::Text(None));
|
|
||||||
postgres.insert("int", DataType::Int4(None));
|
|
||||||
postgres.insert("integer", DataType::Int4(None));
|
|
||||||
postgres.insert("bigint", DataType::Int8(None));
|
|
||||||
postgres.insert("smallint", DataType::Int2(None));
|
|
||||||
postgres.insert("float", DataType::Float4(None));
|
|
||||||
postgres.insert("double precision", DataType::Float8(None));
|
|
||||||
postgres.insert("numeric", DataType::Decimal(None));
|
|
||||||
postgres.insert("decimal", DataType::Decimal(None));
|
|
||||||
postgres.insert("boolean", DataType::Bool(None));
|
|
||||||
postgres.insert("date", DataType::Date(None));
|
|
||||||
postgres.insert("timestamp", DataType::Timestamp(None));
|
|
||||||
postgres.insert("timestamptz", DataType::Timestamptz(None));
|
|
||||||
postgres.insert("json", DataType::Json(None));
|
|
||||||
postgres.insert("jsonb", DataType::Json(None));
|
|
||||||
mappings.insert(DataSourceType::Postgres, postgres.clone());
|
|
||||||
mappings.insert(DataSourceType::Supabase, postgres);
|
|
||||||
|
|
||||||
// MySQL mappings
|
|
||||||
let mut mysql = HashMap::new();
|
|
||||||
mysql.insert("varchar", DataType::Text(None));
|
|
||||||
mysql.insert("text", DataType::Text(None));
|
|
||||||
mysql.insert("char", DataType::Text(None));
|
|
||||||
mysql.insert("int", DataType::Int4(None));
|
|
||||||
mysql.insert("bigint", DataType::Int8(None));
|
|
||||||
mysql.insert("tinyint", DataType::Int2(None));
|
|
||||||
mysql.insert("float", DataType::Float4(None));
|
|
||||||
mysql.insert("double", DataType::Float8(None));
|
|
||||||
mysql.insert("decimal", DataType::Decimal(None));
|
|
||||||
mysql.insert("boolean", DataType::Bool(None));
|
|
||||||
mysql.insert("date", DataType::Date(None));
|
|
||||||
mysql.insert("datetime", DataType::Timestamp(None));
|
|
||||||
mysql.insert("timestamp", DataType::Timestamptz(None));
|
|
||||||
mysql.insert("json", DataType::Json(None));
|
|
||||||
mappings.insert(DataSourceType::MySql, mysql.clone());
|
|
||||||
mappings.insert(DataSourceType::Mariadb, mysql);
|
|
||||||
|
|
||||||
// BigQuery mappings
|
|
||||||
let mut bigquery = HashMap::new();
|
|
||||||
bigquery.insert("STRING", DataType::Text(None));
|
|
||||||
bigquery.insert("INT64", DataType::Int8(None));
|
|
||||||
bigquery.insert("INTEGER", DataType::Int4(None));
|
|
||||||
bigquery.insert("FLOAT64", DataType::Float8(None));
|
|
||||||
bigquery.insert("NUMERIC", DataType::Decimal(None));
|
|
||||||
bigquery.insert("BOOL", DataType::Bool(None));
|
|
||||||
bigquery.insert("DATE", DataType::Date(None));
|
|
||||||
bigquery.insert("TIMESTAMP", DataType::Timestamptz(None));
|
|
||||||
bigquery.insert("JSON", DataType::Json(None));
|
|
||||||
mappings.insert(DataSourceType::BigQuery, bigquery);
|
|
||||||
|
|
||||||
// Snowflake mappings
|
|
||||||
let mut snowflake = HashMap::new();
|
|
||||||
snowflake.insert("TEXT", DataType::Text(None));
|
|
||||||
snowflake.insert("VARCHAR", DataType::Text(None));
|
|
||||||
snowflake.insert("CHAR", DataType::Text(None));
|
|
||||||
snowflake.insert("NUMBER", DataType::Decimal(None));
|
|
||||||
snowflake.insert("DECIMAL", DataType::Decimal(None));
|
|
||||||
snowflake.insert("INTEGER", DataType::Int4(None));
|
|
||||||
snowflake.insert("BIGINT", DataType::Int8(None));
|
|
||||||
snowflake.insert("BOOLEAN", DataType::Bool(None));
|
|
||||||
snowflake.insert("DATE", DataType::Date(None));
|
|
||||||
snowflake.insert("TIMESTAMP", DataType::Timestamptz(None));
|
|
||||||
snowflake.insert("VARIANT", DataType::Json(None));
|
|
||||||
mappings.insert(DataSourceType::Snowflake, snowflake);
|
|
||||||
|
|
||||||
mappings
|
|
||||||
});
|
|
||||||
|
|
||||||
pub fn normalize_type(source_type: DataSourceType, type_str: &str) -> DataType {
|
|
||||||
TYPE_MAPPINGS
|
|
||||||
.get(&source_type)
|
|
||||||
.and_then(|mappings| mappings.get(type_str))
|
|
||||||
.cloned()
|
|
||||||
.unwrap_or(DataType::Unknown(Some(type_str.to_string())))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn types_compatible(source_type: DataSourceType, ds_type: &str, model_type: &str) -> bool {
|
|
||||||
let ds_data_type = normalize_type(source_type, ds_type);
|
|
||||||
let model_data_type = normalize_type(source_type, model_type);
|
|
||||||
|
|
||||||
match (&ds_data_type, &model_data_type) {
|
|
||||||
// Allow integer -> float/decimal conversions
|
|
||||||
(DataType::Int2(_), DataType::Float4(_)) => true,
|
|
||||||
(DataType::Int2(_), DataType::Float8(_)) => true,
|
|
||||||
(DataType::Int2(_), DataType::Decimal(_)) => true,
|
|
||||||
(DataType::Int4(_), DataType::Float4(_)) => true,
|
|
||||||
(DataType::Int4(_), DataType::Float8(_)) => true,
|
|
||||||
(DataType::Int4(_), DataType::Decimal(_)) => true,
|
|
||||||
(DataType::Int8(_), DataType::Float4(_)) => true,
|
|
||||||
(DataType::Int8(_), DataType::Float8(_)) => true,
|
|
||||||
(DataType::Int8(_), DataType::Decimal(_)) => true,
|
|
||||||
|
|
||||||
// Allow text for any type (common in views/computed columns)
|
|
||||||
(DataType::Text(_), _) => true,
|
|
||||||
|
|
||||||
// Allow timestamp/timestamptz compatibility
|
|
||||||
(DataType::Timestamp(_), DataType::Timestamptz(_)) => true,
|
|
||||||
(DataType::Timestamptz(_), DataType::Timestamp(_)) => true,
|
|
||||||
|
|
||||||
// Exact matches (using to_string to compare types, not values)
|
|
||||||
(a, b) if a.to_string() == b.to_string() => true,
|
|
||||||
|
|
||||||
// Everything else is incompatible
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_postgres_type_normalization() {
|
|
||||||
assert!(matches!(
|
|
||||||
normalize_type(DataSourceType::Postgres, "text"),
|
|
||||||
DataType::Text(_)
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
normalize_type(DataSourceType::Postgres, "integer"),
|
|
||||||
DataType::Int4(_)
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
normalize_type(DataSourceType::Postgres, "numeric"),
|
|
||||||
DataType::Decimal(_)
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_bigquery_type_normalization() {
|
|
||||||
assert!(matches!(
|
|
||||||
normalize_type(DataSourceType::BigQuery, "STRING"),
|
|
||||||
DataType::Text(_)
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
normalize_type(DataSourceType::BigQuery, "INT64"),
|
|
||||||
DataType::Int8(_)
|
|
||||||
));
|
|
||||||
assert!(matches!(
|
|
||||||
normalize_type(DataSourceType::BigQuery, "FLOAT64"),
|
|
||||||
DataType::Float8(_)
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_type_compatibility() {
|
|
||||||
// Same types are compatible
|
|
||||||
assert!(types_compatible(DataSourceType::Postgres, "text", "text"));
|
|
||||||
|
|
||||||
// Integer can be used as float
|
|
||||||
assert!(types_compatible(
|
|
||||||
DataSourceType::Postgres,
|
|
||||||
"integer",
|
|
||||||
"float"
|
|
||||||
));
|
|
||||||
|
|
||||||
// Text can be used for any type
|
|
||||||
assert!(types_compatible(
|
|
||||||
DataSourceType::Postgres,
|
|
||||||
"text",
|
|
||||||
"integer"
|
|
||||||
));
|
|
||||||
|
|
||||||
// Different types are incompatible
|
|
||||||
assert!(!types_compatible(
|
|
||||||
DataSourceType::Postgres,
|
|
||||||
"integer",
|
|
||||||
"text"
|
|
||||||
));
|
|
||||||
|
|
||||||
// Timestamp compatibility
|
|
||||||
assert!(types_compatible(
|
|
||||||
DataSourceType::Postgres,
|
|
||||||
"timestamp",
|
|
||||||
"timestamptz"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,182 +0,0 @@
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
||||||
pub struct ValidationResult {
|
|
||||||
pub success: bool,
|
|
||||||
pub model_name: String,
|
|
||||||
pub data_source_name: String,
|
|
||||||
pub schema: String,
|
|
||||||
pub errors: Vec<ValidationError>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
||||||
pub struct ValidationError {
|
|
||||||
pub error_type: ValidationErrorType,
|
|
||||||
pub column_name: Option<String>,
|
|
||||||
pub message: String,
|
|
||||||
pub suggestion: Option<String>,
|
|
||||||
pub context: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
|
||||||
pub enum ValidationErrorType {
|
|
||||||
TableNotFound,
|
|
||||||
ColumnNotFound,
|
|
||||||
TypeMismatch,
|
|
||||||
DataSourceError,
|
|
||||||
ModelNotFound,
|
|
||||||
InvalidRelationship,
|
|
||||||
ExpressionError,
|
|
||||||
ProjectNotFound,
|
|
||||||
InvalidBusterYml,
|
|
||||||
DataSourceMismatch,
|
|
||||||
RequiredFieldMissing,
|
|
||||||
DataSourceNotFound,
|
|
||||||
SchemaError,
|
|
||||||
CredentialsError,
|
|
||||||
DatabaseError,
|
|
||||||
ValidationError,
|
|
||||||
InternalError,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ValidationResult {
|
|
||||||
pub fn new(model_name: String, data_source_name: String, schema: String) -> Self {
|
|
||||||
Self {
|
|
||||||
success: true,
|
|
||||||
model_name,
|
|
||||||
data_source_name,
|
|
||||||
schema,
|
|
||||||
errors: Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_error(&mut self, error: ValidationError) {
|
|
||||||
self.success = false;
|
|
||||||
self.errors.push(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ValidationError {
|
|
||||||
pub fn new(
|
|
||||||
error_type: ValidationErrorType,
|
|
||||||
column_name: Option<String>,
|
|
||||||
message: String,
|
|
||||||
suggestion: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
error_type,
|
|
||||||
column_name,
|
|
||||||
message,
|
|
||||||
suggestion,
|
|
||||||
context: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_context(mut self, context: String) -> Self {
|
|
||||||
self.context = Some(context);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn table_not_found(table_name: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::TableNotFound,
|
|
||||||
None,
|
|
||||||
format!("Table '{}' not found in data source", table_name),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn column_not_found(column_name: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::ColumnNotFound,
|
|
||||||
Some(column_name.to_string()),
|
|
||||||
format!("Column '{}' not found in data source", column_name),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn type_mismatch(column_name: &str, expected: &str, found: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::TypeMismatch,
|
|
||||||
Some(column_name.to_string()),
|
|
||||||
format!(
|
|
||||||
"Column '{}' type mismatch. Expected: {}, Found: {}",
|
|
||||||
column_name, expected, found
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn data_source_error(message: String) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::DataSourceError,
|
|
||||||
None,
|
|
||||||
message,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn model_not_found(model_name: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::ModelNotFound,
|
|
||||||
None,
|
|
||||||
format!("Model '{}' not found in data source", model_name),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn invalid_relationship(from: &str, to: &str, reason: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::InvalidRelationship,
|
|
||||||
None,
|
|
||||||
format!("Invalid relationship from '{}' to '{}': {}", from, to, reason),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn expression_error(column_name: &str, expr: &str, reason: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::ExpressionError,
|
|
||||||
Some(column_name.to_string()),
|
|
||||||
format!("Invalid expression '{}' for column '{}': {}", expr, column_name, reason),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// New factory methods for enhanced error types
|
|
||||||
pub fn schema_error(schema_name: &str, reason: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::SchemaError,
|
|
||||||
None,
|
|
||||||
format!("Schema '{}' error: {}", schema_name, reason),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn credentials_error(data_source: &str, reason: &str) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::CredentialsError,
|
|
||||||
None,
|
|
||||||
format!("Credentials error for data source '{}': {}", data_source, reason),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn database_error(message: String) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::DatabaseError,
|
|
||||||
None,
|
|
||||||
message,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn internal_error(message: String) -> Self {
|
|
||||||
Self::new(
|
|
||||||
ValidationErrorType::InternalError,
|
|
||||||
None,
|
|
||||||
message,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue