mirror of https://github.com/buster-so/buster.git
Merge branch 'evals' of https://github.com/buster-so/buster into evals
This commit is contained in:
commit
93439d45b0
|
@ -36,14 +36,14 @@ pub async fn post_dataset(
|
|||
Ok(None) => {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
"User does not belong to any organization",
|
||||
"User does not belong to any organization".to_string(),
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error getting user organization id: {:?}", e);
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Error getting user organization id",
|
||||
"Error getting user organization id".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
@ -53,14 +53,14 @@ pub async fn post_dataset(
|
|||
Ok(false) => {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
"Insufficient permissions",
|
||||
"Insufficient permissions".to_string(),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error checking user permissions: {:?}", e);
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Error checking user permissions",
|
||||
"Error checking user permissions".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ pub async fn post_dataset(
|
|||
tracing::error!("Error creating dataset: {:?}", e);
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Error creating dataset",
|
||||
"Error creating dataset".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
pub mod types;
|
|
@ -1,383 +0,0 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum ViewType {
|
||||
Chart,
|
||||
Table,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum ChartType {
|
||||
Line,
|
||||
Bar,
|
||||
Scatter,
|
||||
Pie,
|
||||
Metric,
|
||||
Table,
|
||||
Combo,
|
||||
}
|
||||
|
||||
impl ChartType {
|
||||
pub fn to_string(&self) -> String {
|
||||
match self {
|
||||
ChartType::Line => "line".to_string(),
|
||||
ChartType::Bar => "bar".to_string(),
|
||||
ChartType::Scatter => "scatter".to_string(),
|
||||
ChartType::Pie => "pie".to_string(),
|
||||
ChartType::Metric => "metric".to_string(),
|
||||
ChartType::Table => "table".to_string(),
|
||||
ChartType::Combo => "combo".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_string(chart_type: &str) -> ChartType {
|
||||
match chart_type {
|
||||
"line" => ChartType::Line,
|
||||
"bar" => ChartType::Bar,
|
||||
"scatter" => ChartType::Scatter,
|
||||
"pie" => ChartType::Pie,
|
||||
"metric" => ChartType::Metric,
|
||||
"table" => ChartType::Table,
|
||||
"combo" => ChartType::Combo,
|
||||
_ => ChartType::Table,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BusterChartConfig {
|
||||
pub selected_chart_type: ChartType,
|
||||
pub selected_view: ViewType,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub column_label_formats: Option<HashMap<String, ColumnLabelFormat>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub column_settings: Option<HashMap<String, ColumnSettings>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub colors: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub show_legend: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub grid_lines: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub show_legend_headline: Option<ShowLegendHeadline>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub goal_lines: Option<Vec<GoalLine>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub trendlines: Option<Vec<Trendline>>,
|
||||
#[serde(flatten)]
|
||||
pub y_axis_config: YAxisConfig,
|
||||
#[serde(flatten)]
|
||||
pub x_axis_config: XAxisConfig,
|
||||
#[serde(flatten)]
|
||||
pub y2_axis_config: Y2AxisConfig,
|
||||
#[serde(flatten)]
|
||||
pub bar_chart_props: BarChartProps,
|
||||
#[serde(flatten)]
|
||||
pub line_chart_props: LineChartProps,
|
||||
#[serde(flatten)]
|
||||
pub scatter_chart_props: ScatterChartProps,
|
||||
#[serde(flatten)]
|
||||
pub pie_chart_props: PieChartProps,
|
||||
#[serde(flatten)]
|
||||
pub table_chart_props: TableChartProps,
|
||||
#[serde(flatten)]
|
||||
pub combo_chart_props: ComboChartProps,
|
||||
#[serde(flatten)]
|
||||
pub metric_chart_props: MetricChartProps,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct LineChartProps {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_style: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_group_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct BarChartProps {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bar_and_line_axis: Option<BarAndLineAxis>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bar_layout: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bar_sort_by: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bar_group_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bar_show_total_at_top: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct ScatterChartProps {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scatter_axis: Option<ScatterAxis>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub scatter_dot_size: Option<(f64, f64)>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct PieChartProps {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_chart_axis: Option<PieChartAxis>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_display_label_as: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_show_inner_label: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_inner_label_aggregate: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_inner_label_title: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_label_position: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_donut_width: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pie_minimum_slice_percentage: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct TableChartProps {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_column_order: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_column_widths: Option<HashMap<String, f64>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_header_background_color: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_header_font_color: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_column_font_color: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct ComboChartProps {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub combo_chart_axis: Option<ComboChartAxis>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct MetricChartProps {
|
||||
pub metric_column_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metric_value_aggregate: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metric_header: Option<MetricTitle>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metric_sub_header: Option<MetricTitle>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metric_value_label: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum MetricTitle {
|
||||
String(String),
|
||||
Derived(DerivedMetricTitle),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DerivedMetricTitle {
|
||||
pub column_id: String,
|
||||
pub use_value: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ShowLegendHeadline {
|
||||
Bool(bool),
|
||||
String(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ColumnSettings {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub show_data_labels: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub column_visualization: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub line_settings: LineColumnSettings,
|
||||
#[serde(flatten)]
|
||||
pub bar_settings: BarColumnSettings,
|
||||
#[serde(flatten)]
|
||||
pub dot_settings: DotColumnSettings,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct LineColumnSettings {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_dash_style: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_width: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_style: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_symbol_size: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct BarColumnSettings {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bar_roundness: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct DotColumnSettings {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line_symbol_size: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ColumnLabelFormat {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub style: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub column_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub number_separator_style: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub minimum_fraction_digits: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub maximum_fraction_digits: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub multiplier: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prefix: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suffix: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub replace_missing_data_with: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_relative_time: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub is_utc: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub make_label_human_readable: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub currency: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub date_format: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub convert_number_to: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GoalLine {
|
||||
pub show: bool,
|
||||
pub value: f64,
|
||||
pub show_goal_line_label: bool,
|
||||
pub goal_line_label: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub goal_line_color: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Trendline {
|
||||
pub show: bool,
|
||||
pub show_trendline_label: bool,
|
||||
pub trendline_label: Option<String>,
|
||||
pub type_: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub trendline_color: Option<String>,
|
||||
pub column_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct YAxisConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y_axis_show_axis_label: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y_axis_show_axis_title: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y_axis_axis_title: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y_axis_start_axis_at_zero: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y_axis_scale_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct Y2AxisConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y2_axis_show_axis_label: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y2_axis_show_axis_title: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y2_axis_axis_title: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y2_axis_start_axis_at_zero: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y2_axis_scale_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct XAxisConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub x_axis_show_ticks: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub x_axis_show_axis_label: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub x_axis_show_axis_title: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub x_axis_axis_title: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub x_axis_label_rotation: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub x_axis_data_zoom: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct CategoryAxisStyleConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub category_show_total_at_top: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub category_axis_title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BarAndLineAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
pub category: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ScatterAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub category: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub size: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ComboChartAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub y2: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub category: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PieChartAxis {
|
||||
pub x: Vec<String>,
|
||||
pub y: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tooltip: Option<Vec<String>>,
|
||||
}
|
|
@ -1,238 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use futures::StreamExt;
|
||||
use std::{env, time::Duration};
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
||||
|
||||
const ANTHROPIC_CHAT_URL: &str = "https://api.anthropic.com/v1/messages";
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref ANTHROPIC_API_KEY: String = env::var("ANTHROPIC_API_KEY")
|
||||
.expect("ANTHROPIC_API_KEY must be set");
|
||||
static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED")
|
||||
.unwrap_or(String::from("true"))
|
||||
.parse()
|
||||
.expect("MONITORING_ENABLED must be a boolean");
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub enum AnthropicChatModel {
|
||||
#[serde(rename = "claude-3-opus-20240229")]
|
||||
Claude3Opus20240229,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AnthropicChatRole {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AnthropicContentType {
|
||||
Text,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct AnthropicContent {
|
||||
#[serde(rename = "type")]
|
||||
pub _type: AnthropicContentType,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct AnthropicChatMessage {
|
||||
pub role: AnthropicChatRole,
|
||||
pub content: Vec<AnthropicContent>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct AnthropicChatRequest {
|
||||
pub model: AnthropicChatModel,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<String>,
|
||||
pub messages: Vec<AnthropicChatMessage>,
|
||||
pub temperature: f32,
|
||||
pub max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub content: Vec<Content>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Content {
|
||||
#[serde(rename = "type")]
|
||||
pub _type: String,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub async fn anthropic_chat(
|
||||
model: &AnthropicChatModel,
|
||||
system: Option<String>,
|
||||
messages: &Vec<AnthropicChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
) -> Result<String> {
|
||||
let chat_request = AnthropicChatRequest {
|
||||
model: model.clone(),
|
||||
system,
|
||||
messages: messages.clone(),
|
||||
temperature,
|
||||
max_tokens,
|
||||
stop_sequences: stop,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let headers = {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"x-api-key",
|
||||
format!("{}", ANTHROPIC_API_KEY.to_string())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
|
||||
headers
|
||||
};
|
||||
|
||||
let response = match client
|
||||
.post(ANTHROPIC_CHAT_URL)
|
||||
.headers(headers)
|
||||
.json(&chat_request)
|
||||
.timeout(Duration::from_secs(timeout))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to send request to Anthropic: {:?}", e);
|
||||
let err = anyhow!("Unable to send request to Anthropic: {}", e);
|
||||
send_sentry_error(&err.to_string(), None);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let completion_res = match response.json::<ChatCompletionResponse>().await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to parse response from Anthropic: {:?}", e);
|
||||
let err = anyhow!("Unable to parse response from Anthropic: {}", e);
|
||||
send_sentry_error(&err.to_string(), None);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let content = match completion_res.content.get(0) {
|
||||
Some(content) => content.text.clone(),
|
||||
None => return Err(anyhow!("No content returned from Anthropic")),
|
||||
};
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct AnthropicChatDelta {
|
||||
#[serde(rename = "type")]
|
||||
pub _type: String,
|
||||
pub delta: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct AnthropicChatStreamResponse {
|
||||
#[serde(rename = "type")]
|
||||
pub _type: String,
|
||||
pub delta: Option<AnthropicChatDelta>,
|
||||
}
|
||||
|
||||
pub async fn anthropic_chat_stream(
|
||||
model: &AnthropicChatModel,
|
||||
system: Option<String>,
|
||||
messages: &Vec<AnthropicChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
) -> Result<ReceiverStream<String>> {
|
||||
let chat_request = AnthropicChatRequest {
|
||||
model: model.clone(),
|
||||
system,
|
||||
messages: messages.clone(),
|
||||
temperature,
|
||||
max_tokens,
|
||||
stream: true,
|
||||
stop_sequences: stop,
|
||||
};
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let headers = {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", ANTHROPIC_API_KEY.to_string())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers
|
||||
};
|
||||
|
||||
let (_tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel(100);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let response = client
|
||||
.post(ANTHROPIC_CHAT_URL)
|
||||
.headers(headers)
|
||||
.json(&chat_request)
|
||||
.timeout(Duration::from_secs(timeout))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Unable to send request to Anthropic: {:?}", e);
|
||||
let err = anyhow!("Unable to send request to Anthropic: {}", e);
|
||||
send_sentry_error(&err.to_string(), None);
|
||||
err
|
||||
});
|
||||
|
||||
if let Err(_e) = response {
|
||||
return;
|
||||
}
|
||||
|
||||
let response = response.unwrap();
|
||||
let mut stream = response.bytes_stream();
|
||||
|
||||
let _buffer = String::new();
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(bytes) => {
|
||||
let chunk = String::from_utf8(bytes.to_vec()).unwrap();
|
||||
println!("----------------------");
|
||||
println!("Chunk: {}", chunk);
|
||||
println!("----------------------");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error while streaming response: {:?}", e);
|
||||
let err = anyhow!("Error while streaming response: {}", e);
|
||||
send_sentry_error(&err.to_string(), None);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(ReceiverStream::new(rx))
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
use std::env;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task;
|
||||
|
||||
use super::{
|
||||
hugging_face::hugging_face_embedding, ollama::ollama_embedding, openai::ada_bulk_embedding,
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub embedding: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
pub enum EmbeddingProvider {
|
||||
OpenAi,
|
||||
Ollama,
|
||||
HuggingFace,
|
||||
}
|
||||
|
||||
impl EmbeddingProvider {
|
||||
pub fn get_embedding_provider() -> Result<EmbeddingProvider> {
|
||||
let embedding_provider =
|
||||
env::var("EMBEDDING_PROVIDER").expect("An embedding provider is required.");
|
||||
match embedding_provider.as_str() {
|
||||
"openai" => Ok(EmbeddingProvider::OpenAi),
|
||||
"ollama" => Ok(EmbeddingProvider::Ollama),
|
||||
"huggingface" => Ok(EmbeddingProvider::HuggingFace),
|
||||
_ => Err(anyhow!("Invalid embedding provider")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embedding_router(prompts: Vec<String>, for_retrieval: bool) -> Result<Vec<Vec<f32>>> {
|
||||
let embedding_provider = EmbeddingProvider::get_embedding_provider()?;
|
||||
|
||||
match embedding_provider {
|
||||
EmbeddingProvider::Ollama => {
|
||||
let tasks = prompts.into_iter().map(|prompt| {
|
||||
task::spawn(async move { ollama_embedding(prompt, for_retrieval).await })
|
||||
});
|
||||
let results = futures::future::join_all(tasks).await;
|
||||
let embeddings: Result<Vec<Vec<f32>>> = results
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.collect();
|
||||
embeddings
|
||||
}
|
||||
EmbeddingProvider::HuggingFace => hugging_face_embedding(prompts).await,
|
||||
EmbeddingProvider::OpenAi => ada_bulk_embedding(prompts).await,
|
||||
}
|
||||
}
|
|
@ -1,57 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use serde::Serialize;
|
||||
use std::env;
|
||||
|
||||
use axum::http::HeaderMap;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct HuggingFaceEmbeddingRequest {
|
||||
pub inputs: Vec<String>,
|
||||
}
|
||||
|
||||
pub async fn hugging_face_embedding(prompts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
let hugging_face_url = env::var("HUGGING_FACE_URL").expect("HUGGING_FACE_URL must be set");
|
||||
let hugging_face_api_key =
|
||||
env::var("HUGGING_FACE_API_KEY").expect("HUGGING_FACE_API_KEY must be set");
|
||||
|
||||
let client = match reqwest::Client::builder().build() {
|
||||
Ok(client) => client,
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Error creating reqwest client: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", hugging_face_api_key).parse().unwrap(),
|
||||
);
|
||||
|
||||
let req = HuggingFaceEmbeddingRequest { inputs: prompts };
|
||||
|
||||
let res = match client
|
||||
.post(hugging_face_url)
|
||||
.headers(headers)
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Error sending Ollama request: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = match res.json::<Vec<Vec<f32>>>().await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Error parsing Ollama response: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
|
@ -1,394 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use axum::http::HeaderMap;
|
||||
use base64::Engine;
|
||||
use reqwest::Method;
|
||||
use std::env;
|
||||
use tiktoken_rs::o200k_base;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
||||
|
||||
use super::{anthropic::AnthropicChatModel, llm_router::LlmModel, openai::OpenAiChatModel};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref LANGFUSE_API_URL: String = env::var("LANGFUSE_API_URL").unwrap_or("https://us.cloud.langfuse.com".to_string());
|
||||
static ref LANGFUSE_API_PUBLIC_KEY: String = env::var("LANGFUSE_PUBLIC_API_KEY").expect("LANGFUSE_PUBLIC_API_KEY must be set");
|
||||
static ref LANGFUSE_API_PRIVATE_KEY: String = env::var("LANGFUSE_PRIVATE_API_KEY").expect("LANGFUSE_PRIVATE_API_KEY must be set");
|
||||
}
|
||||
|
||||
impl LlmModel {
|
||||
pub fn generate_usage(&self, input: &String, output: &String) -> Usage {
|
||||
let bpe = o200k_base().unwrap();
|
||||
|
||||
let input_token = bpe.encode_with_special_tokens(&input);
|
||||
let output_token = bpe.encode_with_special_tokens(&output);
|
||||
|
||||
match self {
|
||||
LlmModel::OpenAi(OpenAiChatModel::O3Mini) => Usage {
|
||||
input: input_token.len() as u32,
|
||||
output: output_token.len() as u32,
|
||||
unit: "TOKENS".to_string(),
|
||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 2.5,
|
||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 10.0,
|
||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 2.5
|
||||
+ (output_token.len() as f64 / 1_000_000.0) * 10.0,
|
||||
},
|
||||
LlmModel::OpenAi(OpenAiChatModel::Gpt35Turbo) => Usage {
|
||||
input: input_token.len() as u32,
|
||||
output: output_token.len() as u32,
|
||||
unit: "TOKENS".to_string(),
|
||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 0.5,
|
||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 1.5,
|
||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 0.5
|
||||
+ (output_token.len() as f64 / 1_000_000.0) * 1.5,
|
||||
},
|
||||
LlmModel::OpenAi(OpenAiChatModel::Gpt4o) => Usage {
|
||||
input: input_token.len() as u32,
|
||||
output: output_token.len() as u32,
|
||||
unit: "TOKENS".to_string(),
|
||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 0.15,
|
||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 0.6,
|
||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 0.15
|
||||
+ (output_token.len() as f64 / 1_000_000.0) * 0.6,
|
||||
},
|
||||
LlmModel::Anthropic(AnthropicChatModel::Claude3Opus20240229) => Usage {
|
||||
input: input_token.len() as u32,
|
||||
output: output_token.len() as u32,
|
||||
unit: "TOKENS".to_string(),
|
||||
input_cost: (input_token.len() as f64 / 1_000_000.0) * 3.0,
|
||||
output_cost: (output_token.len() as f64 / 1_000_000.0) * 15.0,
|
||||
total_cost: (input_token.len() as f64 / 1_000_000.0) * 3.0
|
||||
+ (output_token.len() as f64 / 1_000_000.0) * 15.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PromptName {
|
||||
SelectDataset,
|
||||
GenerateSql,
|
||||
SelectTerm,
|
||||
DataSummary,
|
||||
AutoChartConfig,
|
||||
LineChartConfig,
|
||||
BarChartConfig,
|
||||
ScatterChartConfig,
|
||||
PieChartConfig,
|
||||
MetricChartConfig,
|
||||
TableConfig,
|
||||
ColumnLabelFormat,
|
||||
AdvancedVisualizationConfig,
|
||||
NoDataReturnedResponse,
|
||||
DataExplanation,
|
||||
MetricTitle,
|
||||
TimeFrame,
|
||||
FixSqlPlanner,
|
||||
FixSql,
|
||||
GenerateColDescriptions,
|
||||
GenerateDatasetDescription,
|
||||
SummaryQuestion,
|
||||
CustomPrompt(String),
|
||||
}
|
||||
|
||||
impl PromptName {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
PromptName::SelectDataset => "select_dataset".to_string(),
|
||||
PromptName::GenerateSql => "generate_sql".to_string(),
|
||||
PromptName::SelectTerm => "select_term".to_string(),
|
||||
PromptName::DataSummary => "data_summary".to_string(),
|
||||
PromptName::AutoChartConfig => "auto_chart_config".to_string(),
|
||||
PromptName::LineChartConfig => "line_chart_config".to_string(),
|
||||
PromptName::BarChartConfig => "bar_chart_config".to_string(),
|
||||
PromptName::ScatterChartConfig => "scatter_chart_config".to_string(),
|
||||
PromptName::PieChartConfig => "pie_chart_config".to_string(),
|
||||
PromptName::MetricChartConfig => "metric_chart_config".to_string(),
|
||||
PromptName::TableConfig => "table_config".to_string(),
|
||||
PromptName::ColumnLabelFormat => "column_label_format".to_string(),
|
||||
PromptName::AdvancedVisualizationConfig => "advanced_visualization_config".to_string(),
|
||||
PromptName::NoDataReturnedResponse => "no_data_returned_response".to_string(),
|
||||
PromptName::DataExplanation => "data_explanation".to_string(),
|
||||
PromptName::MetricTitle => "metric_title".to_string(),
|
||||
PromptName::TimeFrame => "time_frame".to_string(),
|
||||
PromptName::FixSqlPlanner => "fix_sql_planner".to_string(),
|
||||
PromptName::FixSql => "fix_sql".to_string(),
|
||||
PromptName::GenerateColDescriptions => "generate_col_descriptions".to_string(),
|
||||
PromptName::GenerateDatasetDescription => "generate_dataset_description".to_string(),
|
||||
PromptName::SummaryQuestion => "summary_question".to_string(),
|
||||
PromptName::CustomPrompt(prompt) => prompt.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Discriminator for the two ingestion event kinds this module emits;
/// serialized kebab-case ("trace-create" / "generation-create").
#[derive(Serialize, Debug)]
#[serde(rename_all = "kebab-case")]
enum LangfuseIngestionType {
    TraceCreate,
    GenerationCreate,
}

/// Payload for a Langfuse `trace-create` event.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CreateTraceBody {
    id: Uuid,
    timestamp: DateTime<Utc>,
    // Prompt name, used as the trace's display name.
    name: String,
    user_id: Uuid,
    // JSON-encoded prompt input.
    input: String,
    // JSON-encoded model output.
    output: String,
    // Groups related traces (e.g. all calls within one thread).
    session_id: Uuid,
    release: String,
    version: String,
    metadata: Metadata,
    tags: Vec<String>,
    public: bool,
}

/// Payload for a Langfuse `generation-create` event; linked to its parent
/// trace via `trace_id`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerationCreateBody {
    trace_id: Uuid,
    name: String,
    start_time: DateTime<Utc>,
    completion_start_time: DateTime<Utc>,
    input: String,
    output: String,
    level: String,
    end_time: DateTime<Utc>,
    model: LlmModel,
    id: Uuid,
    usage: Usage,
}

/// Token counts and derived costs for a single generation.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    input: u32,
    output: u32,
    // Unit label for the counts (this module always sends "TOKENS").
    unit: String,
    input_cost: f64,
    output_cost: f64,
    total_cost: f64,
}

/// Placeholder — no trace metadata is currently recorded.
#[derive(Serialize, Debug)]
struct Metadata {}

/// Union of the event payloads; `untagged` so each body serializes as the
/// inner struct with no wrapper key.
#[derive(Serialize)]
#[serde(untagged)]
enum LangfuseRequestBody {
    CreateTraceBody(CreateTraceBody),
    GenerationCreateBody(GenerationCreateBody),
}

/// One event in an ingestion batch.
#[derive(Serialize)]
struct LangfuseBatchItem {
    id: Uuid,
    // Serialized field is `type`, a Rust keyword — hence the raw identifier.
    r#type: LangfuseIngestionType,
    body: LangfuseRequestBody,
    timestamp: DateTime<Utc>,
}

/// Top-level request body for the `/api/public/ingestion` endpoint.
#[derive(Serialize)]
struct LangfuseBatch {
    batch: Vec<LangfuseBatchItem>,
}

/// Per-item failure reported by the ingestion endpoint.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct LangfuseError {
    id: String,
    status: u16,
    message: String,
    error: String,
}

/// Per-item success acknowledgement.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct LangfuseSuccess {
    id: String,
    status: u16,
}

/// Ingestion response: the endpoint reports successes and failures per batch
/// item rather than failing the whole request.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct LangfuseResponse {
    successes: Vec<LangfuseSuccess>,
    errors: Vec<LangfuseError>,
}
|
||||
|
||||
/// Args:
|
||||
/// - session_id: this can be a thread_id or any other type of chain event we have
|
||||
///
|
||||
|
||||
pub async fn send_langfuse_request(
|
||||
session_id: &Uuid,
|
||||
prompt_name: PromptName,
|
||||
context: Option<String>,
|
||||
start_time: DateTime<Utc>,
|
||||
end_time: DateTime<Utc>,
|
||||
input: String,
|
||||
output: String,
|
||||
user_id: &Uuid,
|
||||
langfuse_model: &LlmModel,
|
||||
) -> () {
|
||||
let session_id = session_id.clone();
|
||||
let user_id = user_id.clone();
|
||||
let langfuse_model = langfuse_model.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match langfuse_handler(
|
||||
session_id,
|
||||
prompt_name,
|
||||
context,
|
||||
start_time,
|
||||
end_time,
|
||||
input,
|
||||
output,
|
||||
user_id,
|
||||
langfuse_model,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
let err = anyhow!("Error sending Langfuse request: {:?}", e);
|
||||
send_sentry_error(&err.to_string(), Some(&user_id));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Builds one Langfuse ingestion batch containing a trace plus a generation
/// nested under it, POSTs the batch, and surfaces any per-item errors.
async fn langfuse_handler(
    session_id: Uuid,
    prompt_name: PromptName,
    context: Option<String>,
    start_time: DateTime<Utc>,
    end_time: DateTime<Utc>,
    input: String,
    output: String,
    user_id: Uuid,
    langfuse_model: LlmModel,
) -> Result<()> {
    // Prepend any caller-supplied context so it appears in the recorded trace.
    let input = match context {
        Some(context) => format!("{} \n\n {}", context, input),
        None => input,
    };

    let trace_id = Uuid::new_v4();

    let langfuse_trace = LangfuseBatchItem {
        id: Uuid::new_v4(),
        r#type: LangfuseIngestionType::TraceCreate,
        body: LangfuseRequestBody::CreateTraceBody(CreateTraceBody {
            id: trace_id.clone(),
            timestamp: Utc::now(),
            name: prompt_name.clone().to_string(),
            user_id: user_id,
            input: serde_json::to_string(&input).unwrap(),
            output: serde_json::to_string(&output).unwrap(),
            session_id: session_id,
            release: "1.0.0".to_string(),
            version: "1.0.0".to_string(),
            metadata: Metadata {},
            tags: vec![],
            public: false,
        }),
        timestamp: Utc::now(),
    };

    // The generation references the trace above via `trace_id`.
    let langfuse_generation = LangfuseBatchItem {
        id: Uuid::new_v4(),
        r#type: LangfuseIngestionType::GenerationCreate,
        body: LangfuseRequestBody::GenerationCreateBody(GenerationCreateBody {
            id: Uuid::new_v4(),
            name: prompt_name.to_string(),
            input: serde_json::to_string(&input).unwrap(),
            output: serde_json::to_string(&output).unwrap(),
            trace_id,
            start_time,
            completion_start_time: start_time,
            level: "DEBUG".to_string(),
            end_time,
            model: langfuse_model.clone(),
            usage: langfuse_model.generate_usage(&input, &output),
        }),
        timestamp: Utc::now(),
    };

    let langfuse_batch = LangfuseBatch {
        batch: vec![langfuse_trace, langfuse_generation],
    };

    let client = match reqwest::Client::builder().build() {
        Ok(client) => client,
        Err(e) => {
            send_sentry_error(&e.to_string(), Some(&user_id));
            return Err(anyhow!("Error creating reqwest client: {:?}", e));
        }
    };

    // Langfuse uses HTTP basic auth: public key as username, private key as
    // password, base64-encoded.
    let mut headers = HeaderMap::new();
    headers.insert(
        reqwest::header::CONTENT_TYPE,
        "application/json".parse().unwrap(),
    );
    headers.insert(
        reqwest::header::AUTHORIZATION,
        format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode(format!(
                "{}:{}",
                *LANGFUSE_API_PUBLIC_KEY, *LANGFUSE_API_PRIVATE_KEY
            ))
        )
        .parse()
        .unwrap(),
    );

    let res = match client
        .request(
            Method::POST,
            LANGFUSE_API_URL.to_string() + "/api/public/ingestion",
        )
        .headers(headers)
        .json(&langfuse_batch)
        .send()
        .await
    {
        Ok(res) => res,
        Err(e) => {
            tracing::debug!("Error sending Langfuse request: {:?}", e);
            send_sentry_error(&e.to_string(), Some(&user_id));
            return Err(anyhow!("Error sending Langfuse request: {:?}", e));
        }
    };

    let langfuse_res: LangfuseResponse = match res.json::<LangfuseResponse>().await {
        Ok(res) => res,
        Err(e) => {
            tracing::debug!("Error parsing Langfuse response: {:?}", e);
            send_sentry_error(&e.to_string(), Some(&user_id));
            return Err(anyhow!("Error parsing Langfuse response: {:?}", e));
        }
    };

    // The ingestion endpoint reports partial failures per batch item; treat
    // any item-level error as a failure of the whole call.
    if langfuse_res.errors.len() > 0 {
        for err in &langfuse_res.errors {
            tracing::debug!("Langfuse error: {:?}", err);
            send_sentry_error(&err.message, Some(&user_id));
        }

        return Err(anyhow::anyhow!(
            "Langfuse errors: {:?}",
            langfuse_res.errors
        ));
    }

    Ok(())
}
|
|
@ -1,383 +0,0 @@
|
|||
use std::env;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::Utc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::{
|
||||
sync::mpsc::{self, Receiver},
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{
|
||||
anthropic::{
|
||||
anthropic_chat, anthropic_chat_stream, AnthropicChatMessage, AnthropicChatModel,
|
||||
AnthropicChatRole, AnthropicContent, AnthropicContentType,
|
||||
},
|
||||
langfuse::{send_langfuse_request, PromptName},
|
||||
openai::{
|
||||
openai_chat, openai_chat_stream, OpenAiChatContent, OpenAiChatMessage, OpenAiChatModel,
|
||||
OpenAiChatRole,
|
||||
},
|
||||
};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
lazy_static! {
    // Global monitoring toggle read once from the environment; defaults to
    // enabled. First access panics if the value is not a valid boolean.
    // NOTE(review): not referenced anywhere in this file's visible code —
    // confirm it is still used.
    static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED")
        .unwrap_or(String::from("true"))
        .parse()
        .expect("MONITORING_ENABLED must be a boolean");
}
|
||||
|
||||
/// The model to call, tagged by provider. `untagged` so it serializes as the
/// inner model value with no provider wrapper.
#[derive(Serialize, Clone)]
#[serde(untagged)]
pub enum LlmModel {
    Anthropic(AnthropicChatModel),
    OpenAi(OpenAiChatModel),
}

/// Provider-agnostic chat role.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LlmRole {
    System,
    User,
    Assistant,
}

/// Provider-agnostic chat message; converted to each provider's wire format
/// by the compiler helpers below.
#[derive(Serialize, Deserialize, Clone)]
pub struct LlmMessage {
    pub role: LlmRole,
    pub content: String,
}
|
||||
|
||||
impl LlmMessage {
|
||||
pub fn new(role: String, content: String) -> Self {
|
||||
let role = match role.to_lowercase().as_str() {
|
||||
"system" => LlmRole::System,
|
||||
"user" => LlmRole::User,
|
||||
"assistant" => LlmRole::Assistant,
|
||||
_ => LlmRole::User,
|
||||
};
|
||||
Self { role, content }
|
||||
}
|
||||
}
|
||||
|
||||
/// Routes a blocking chat request to the selected provider, then records the
/// exchange to Langfuse before returning the model's text response.
///
/// `json_mode` and `json_schema` only apply to the OpenAI path.
pub async fn llm_chat(
    model: LlmModel,
    messages: &Vec<LlmMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
    json_mode: bool,
    json_schema: Option<Value>,
    session_id: &Uuid,
    user_id: &Uuid,
    prompt_name: PromptName,
) -> Result<String> {
    let start_time = Utc::now();

    let response_result = match &model {
        LlmModel::Anthropic(model) => {
            anthropic_chat_compiler(model, messages, max_tokens, temperature, timeout, stop).await
        }
        LlmModel::OpenAi(model) => {
            // NOTE(review): the caller's `max_tokens` is ignored here and a
            // hard-coded 7048 is passed instead — confirm the override is
            // intentional.
            openai_chat_compiler(
                model,
                messages,
                7048,
                temperature,
                timeout,
                stop,
                json_mode,
                json_schema,
            )
            .await
        }
    };

    let response = match response_result {
        Ok(response) => response,
        Err(e) => return Err(anyhow!("LLM chat error: {}", e)),
    };

    let end_time = Utc::now();

    // `send_langfuse_request` spawns its own task, so awaiting it does not
    // block on the Langfuse network call.
    send_langfuse_request(
        session_id,
        prompt_name,
        None,
        start_time,
        end_time,
        serde_json::to_string(&messages).unwrap(),
        serde_json::to_string(&response).unwrap(),
        user_id,
        &model,
    )
    .await;

    Ok(response)
}
|
||||
|
||||
/// Routes a streaming chat request to the selected provider.
///
/// Returns a channel receiver that yields content chunks as they arrive,
/// plus a join handle that resolves to the full concatenated response once
/// the stream ends (after recording the exchange to Langfuse).
pub async fn llm_chat_stream(
    model: LlmModel,
    messages: Vec<LlmMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
    session_id: &Uuid,
    user_id: &Uuid,
    prompt_name: PromptName,
) -> Result<(Receiver<String>, JoinHandle<Result<String>>)> {
    let start_time = Utc::now();

    let stream_result = match &model {
        LlmModel::Anthropic(model) => {
            anthropic_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop)
                .await
        }
        LlmModel::OpenAi(model) => {
            openai_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop)
                .await
        }
    };

    let mut stream = match stream_result {
        Ok(stream) => stream,
        Err(e) => return Err(anyhow!("LLM chat error: {}", e)),
    };

    let (tx, rx) = mpsc::channel(100);

    let res_future = {
        let session_id = session_id.clone();
        let user_id = user_id.clone();

        tokio::spawn(async move {
            // Forward each chunk to the caller while accumulating the full
            // text for the Langfuse trace.
            let mut response = String::new();
            while let Some(content) = stream.next().await {
                response.push_str(&content);

                match tx.send(content).await {
                    Ok(_) => (),
                    // The receiver was dropped; abort the stream task.
                    Err(e) => return Err(anyhow!("Streaming Error: {}", e)),
                }
            }

            let end_time = Utc::now();

            send_langfuse_request(
                &session_id,
                prompt_name,
                None,
                start_time,
                end_time,
                serde_json::to_string(&messages).unwrap(),
                serde_json::to_string(&response).unwrap(),
                &user_id,
                &model,
            )
            .await;

            Ok(response)
        })
    };

    Ok((rx, res_future))
}
|
||||
|
||||
async fn anthropic_chat_compiler(
|
||||
model: &AnthropicChatModel,
|
||||
messages: &Vec<LlmMessage>,
|
||||
_max_tokens: u32,
|
||||
temperature: f32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
) -> Result<String> {
|
||||
let system_message = match messages.iter().find(|m| m.role == LlmRole::System) {
|
||||
Some(message) => Some(message.content.clone()),
|
||||
None => None,
|
||||
};
|
||||
let mut anthropic_messages = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
let anthropic_role = match message.role {
|
||||
LlmRole::System => continue,
|
||||
LlmRole::User => AnthropicChatRole::User,
|
||||
LlmRole::Assistant => AnthropicChatRole::Assistant,
|
||||
};
|
||||
|
||||
let anthropic_content = AnthropicContent {
|
||||
text: message.content.clone(),
|
||||
_type: AnthropicContentType::Text,
|
||||
};
|
||||
|
||||
anthropic_messages.push(AnthropicChatMessage {
|
||||
role: anthropic_role,
|
||||
content: vec![anthropic_content],
|
||||
});
|
||||
}
|
||||
|
||||
let response = match anthropic_chat(
|
||||
model,
|
||||
system_message,
|
||||
&anthropic_messages,
|
||||
temperature,
|
||||
7048,
|
||||
timeout,
|
||||
stop,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn openai_chat_compiler(
|
||||
model: &OpenAiChatModel,
|
||||
messages: &Vec<LlmMessage>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
json_mode: bool,
|
||||
json_schema: Option<Value>,
|
||||
) -> Result<String> {
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
let openai_role = match message.role {
|
||||
LlmRole::System => OpenAiChatRole::System,
|
||||
LlmRole::User => OpenAiChatRole::User,
|
||||
LlmRole::Assistant => OpenAiChatRole::Assistant,
|
||||
};
|
||||
|
||||
let openai_message = OpenAiChatMessage {
|
||||
role: openai_role,
|
||||
content: vec![OpenAiChatContent {
|
||||
type_: "text".to_string(),
|
||||
text: message.content.clone(),
|
||||
}],
|
||||
};
|
||||
|
||||
openai_messages.push(openai_message);
|
||||
}
|
||||
|
||||
let response = match openai_chat(
|
||||
&model,
|
||||
openai_messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
timeout,
|
||||
stop,
|
||||
json_mode,
|
||||
json_schema,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn anthropic_chat_stream_compiler(
|
||||
model: &AnthropicChatModel,
|
||||
messages: &Vec<LlmMessage>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
) -> Result<ReceiverStream<String>> {
|
||||
let system_message = match messages.iter().find(|m| m.role == LlmRole::System) {
|
||||
Some(message) => Some(message.content.clone()),
|
||||
None => None,
|
||||
};
|
||||
let mut anthropic_messages = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
let anthropic_role = match message.role {
|
||||
LlmRole::System => continue,
|
||||
LlmRole::User => AnthropicChatRole::User,
|
||||
LlmRole::Assistant => AnthropicChatRole::Assistant,
|
||||
};
|
||||
|
||||
let anthropic_content = AnthropicContent {
|
||||
text: message.content.clone(),
|
||||
_type: AnthropicContentType::Text,
|
||||
};
|
||||
|
||||
anthropic_messages.push(AnthropicChatMessage {
|
||||
role: anthropic_role,
|
||||
content: vec![anthropic_content],
|
||||
});
|
||||
}
|
||||
|
||||
let stream = match anthropic_chat_stream(
|
||||
model,
|
||||
system_message,
|
||||
&anthropic_messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
timeout,
|
||||
stop,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
||||
};
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
async fn openai_chat_stream_compiler(
|
||||
model: &OpenAiChatModel,
|
||||
messages: &Vec<LlmMessage>,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
) -> Result<ReceiverStream<String>> {
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
for message in messages {
|
||||
let openai_role = match message.role {
|
||||
LlmRole::System => OpenAiChatRole::System,
|
||||
LlmRole::User => OpenAiChatRole::User,
|
||||
LlmRole::Assistant => OpenAiChatRole::Assistant,
|
||||
};
|
||||
|
||||
let openai_message = OpenAiChatMessage {
|
||||
role: openai_role,
|
||||
content: vec![OpenAiChatContent {
|
||||
type_: "text".to_string(),
|
||||
text: message.content.clone(),
|
||||
}],
|
||||
};
|
||||
|
||||
openai_messages.push(openai_message);
|
||||
}
|
||||
|
||||
let stream = match openai_chat_stream(
|
||||
&model,
|
||||
&openai_messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
timeout,
|
||||
stop,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
|
||||
};
|
||||
|
||||
Ok(stream)
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
// LLM provider clients and routing helpers. Provider modules that are only
// used through the routers stay private; the routers and Langfuse reporting
// are the public surface.
mod anthropic;
pub mod embedding_router;
mod hugging_face;
pub mod langfuse;
pub mod llm_router;
pub mod ollama;
pub mod openai;
|
|
@ -1,67 +0,0 @@
|
|||
use std::env;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use axum::http::HeaderMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request body for Ollama's `/api/embeddings` endpoint.
#[derive(Serialize)]
pub struct OllamaEmbeddingRequest {
    pub prompt: String,
    pub model: String,
}

/// Response body: the embedding vector for the submitted prompt.
#[derive(Deserialize, Debug)]
pub struct OllamaEmbeddingResponse {
    pub embedding: Vec<f32>,
}
|
||||
|
||||
pub async fn ollama_embedding(prompt: String, for_retrieval: bool) -> Result<Vec<f32>> {
|
||||
let ollama_url = env::var("OLLAMA_URL").unwrap_or(String::from("http://localhost:11434"));
|
||||
let model = env::var("EMBEDDING_MODEL").unwrap_or(String::from("mxbai-embed-large"));
|
||||
|
||||
let prompt = if model == "mxbai-embed-large" && for_retrieval {
|
||||
format!(
|
||||
"Represent this sentence for searching relevant passages:: {}",
|
||||
prompt
|
||||
)
|
||||
} else {
|
||||
prompt
|
||||
};
|
||||
|
||||
let client = match reqwest::Client::builder().build() {
|
||||
Ok(client) => client,
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Error creating reqwest client: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
|
||||
let req = OllamaEmbeddingRequest { prompt, model };
|
||||
|
||||
let res = match client
|
||||
.post(format!("{}/api/embeddings", ollama_url))
|
||||
.headers(headers)
|
||||
.json(&req)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Error sending Ollama request: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let ollama_res: OllamaEmbeddingResponse = match res.json::<OllamaEmbeddingResponse>().await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
return Err(anyhow!("Error parsing Ollama response: {:?}", e));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ollama_res.embedding)
|
||||
}
|
|
@ -1,417 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use futures::StreamExt;
|
||||
use serde_json::{json, Value};
|
||||
use std::{env, time::Duration};
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
||||
|
||||
// OpenAI endpoints and credentials. The embedding URL is fixed; the chat URL
// can be overridden via OPENAI_CHAT_URL (e.g. to target a compatible proxy).
const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings";

lazy_static::lazy_static! {
    // Required; the first access panics if the variable is unset.
    static ref OPENAI_API_KEY: String = env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY must be set");
    static ref OPENAI_CHAT_URL: String = env::var("OPENAI_CHAT_URL").unwrap_or("https://api.openai.com/v1/chat/completions".to_string());
}
|
||||
|
||||
/// Supported OpenAI chat models; the serde renames map each variant to the
/// exact API model identifier string.
#[derive(Serialize, Clone)]
pub enum OpenAiChatModel {
    #[serde(rename = "gpt-4o-2024-11-20")]
    Gpt4o,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "gpt-3.5-turbo")]
    Gpt35Turbo,
}

/// Chat roles as OpenAI names them (serialized lowercase).
#[derive(Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum OpenAiChatRole {
    System,
    Developer,
    User,
    Assistant,
}

/// Reasoning-effort hint sent for o3-family models (serialized lowercase).
#[derive(Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}

/// One content part of a chat message; this module only sends "text" parts.
#[derive(Serialize, Clone)]
pub struct OpenAiChatContent {
    // Serialized field is `type`, a Rust keyword — hence the trailing
    // underscore plus rename.
    #[serde(rename = "type")]
    pub type_: String,
    pub text: String,
}

/// A chat message in OpenAI's multi-part content format.
#[derive(Serialize, Clone)]
pub struct OpenAiChatMessage {
    pub role: OpenAiChatRole,
    pub content: Vec<OpenAiChatContent>,
}
|
||||
|
||||
// Helper functions for conditional serialization
/// True when `model` belongs to the o3 reasoning family, which is given a
/// `reasoning_effort` hint instead of the usual sampling parameters (see
/// `OpenAiChatRequest::new`).
fn is_o3_model(model: &OpenAiChatModel) -> bool {
    matches!(model, OpenAiChatModel::O3Mini)
}
|
||||
|
||||
/// Request body for the chat-completions endpoint. Optional fields are
/// omitted from the JSON when `None`, which is how the model-specific
/// parameter sets (see `is_o3_model`) are expressed on the wire.
#[derive(Serialize, Clone)]
pub struct OpenAiChatRequest {
    model: OpenAiChatModel,
    messages: Vec<OpenAiChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    frequency_penalty: f32,
    presence_penalty: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    stream: bool,
    // e.g. {"type": "json_object"} or a json_schema constraint object.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Value>,
}
|
||||
|
||||
impl OpenAiChatRequest {
|
||||
pub fn new(
|
||||
model: OpenAiChatModel,
|
||||
messages: Vec<OpenAiChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
stop: Option<Vec<String>>,
|
||||
stream: bool,
|
||||
response_format: Option<Value>,
|
||||
) -> Self {
|
||||
let (temperature, max_tokens, top_p, reasoning_effort) = if is_o3_model(&model) {
|
||||
(None, None, None, Some(ReasoningEffort::Low))
|
||||
} else {
|
||||
(Some(temperature), Some(max_tokens), Some(1.0), None)
|
||||
};
|
||||
|
||||
Self {
|
||||
model,
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
top_p,
|
||||
reasoning_effort,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop,
|
||||
stream,
|
||||
response_format,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Minimal view of a non-streaming chat-completions response; only the
/// fields this module reads are modeled.
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
}

/// One completion choice.
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    pub message: Message,
}

/// The assistant message text of a choice.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub content: String,
}
|
||||
|
||||
pub async fn openai_chat(
|
||||
model: &OpenAiChatModel,
|
||||
messages: Vec<OpenAiChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
timeout: u64,
|
||||
stop: Option<Vec<String>>,
|
||||
json_mode: bool,
|
||||
json_schema: Option<Value>,
|
||||
) -> Result<String> {
|
||||
let response_format = match json_mode {
|
||||
true => Some(json!({"type": "json_object"})),
|
||||
false => None,
|
||||
};
|
||||
|
||||
let response_format = match json_schema {
|
||||
Some(schema) => Some(json!({"type": "json_schema", "json_schema": schema})),
|
||||
None => response_format,
|
||||
};
|
||||
|
||||
let chat_request = OpenAiChatRequest::new(
|
||||
model.clone(),
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens,
|
||||
stop,
|
||||
false,
|
||||
response_format,
|
||||
);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let headers = {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", OPENAI_API_KEY.to_string())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers
|
||||
};
|
||||
|
||||
let response = match client
|
||||
.post(OPENAI_CHAT_URL.to_string())
|
||||
.headers(headers)
|
||||
.json(&chat_request)
|
||||
.timeout(Duration::from_secs(timeout))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to send request to OpenAI: {:?}", e);
|
||||
let err = anyhow!("Unable to send request to OpenAI: {}", e);
|
||||
send_sentry_error(&err.to_string(), None);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let response_text = response.text().await.unwrap();
|
||||
|
||||
let completion_res = match serde_json::from_str::<ChatCompletionResponse>(&response_text) {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to parse response from OpenAI: {:?}", e);
|
||||
let err = anyhow!("Unable to parse response from OpenAI: {}", e);
|
||||
send_sentry_error(&err.to_string(), None);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let content = match completion_res.choices.get(0) {
|
||||
Some(choice) => choice.message.content.clone(),
|
||||
None => return Err(anyhow!("No content returned from OpenAI")),
|
||||
};
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
/// Incremental content carried by one streamed chunk; `None` for chunks
/// that carry no new text (e.g. those signaling a finish reason).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAiChatDelta {
    pub content: Option<String>,
}

/// One choice within a streamed chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAiChatChoice {
    pub delta: OpenAiChatDelta,
    pub index: u32,
    pub finish_reason: Option<String>,
}

/// One server-sent chunk of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAiChatStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<OpenAiChatChoice>,
}
|
||||
|
||||
/// We can't do Langfuse traces directly in the stream function.
/// This is mainly because I'm too lazy to figure out how to set up all the passing with the stream receiver.
/// Langfuse traces must be written directly on the stream call.
///
/// Opens a streaming chat completion and returns a stream of content deltas.
/// The HTTP request and event parsing run on a spawned task; parse errors
/// are reported to Sentry without terminating the stream.
pub async fn openai_chat_stream(
    model: &OpenAiChatModel,
    messages: &Vec<OpenAiChatMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
) -> Result<ReceiverStream<String>> {
    let chat_request = OpenAiChatRequest::new(
        model.clone(),
        messages.clone(),
        temperature,
        max_tokens,
        stop,
        true,
        None,
    );

    let client = reqwest::Client::new();

    let headers = {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::AUTHORIZATION,
            format!("Bearer {}", OPENAI_API_KEY.to_string())
                .parse()
                .unwrap(),
        );
        headers
    };

    let (tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel(100);

    tokio::spawn(async move {
        let response = client
            .post(OPENAI_CHAT_URL.to_string())
            .headers(headers)
            .json(&chat_request)
            .timeout(Duration::from_secs(timeout))
            .send()
            .await
            .map_err(|e| {
                tracing::error!("Unable to send request to OpenAI: {:?}", e);
                let err = anyhow!("Unable to send request to OpenAI: {}", e);
                send_sentry_error(&err.to_string(), None);
                err
            });

        // On a failed send the sender is dropped, which ends the caller's
        // stream; the error was already reported above.
        if let Err(_e) = response {
            return;
        }

        let response = response.unwrap();
        let mut stream = response.bytes_stream();

        // Raw bytes accumulate here because a network chunk may contain a
        // partial event.
        let mut buffer = String::new();

        while let Some(item) = stream.next().await {
            match item {
                Ok(bytes) => {
                    let chunk = String::from_utf8(bytes.to_vec()).unwrap();
                    buffer.push_str(&chunk);

                    // NOTE(review): events are framed by scanning for "}\n"
                    // rather than the SSE blank-line delimiter — this relies
                    // on OpenAI's exact output formatting; confirm it holds.
                    while let Some(pos) = buffer.find("}\n") {
                        let (json_str, rest) = {
                            let (json_str, rest) = buffer.split_at(pos + 1);
                            (json_str.to_string(), rest.to_string())
                        };
                        buffer = rest;

                        // Strip the SSE "data: " prefix before JSON parsing.
                        let json_str_trimmed = json_str.replace("data: ", "");
                        match serde_json::from_str::<OpenAiChatStreamResponse>(
                            &json_str_trimmed.trim(),
                        ) {
                            Ok(response) => {
                                if let Some(content) = &response.choices[0].delta.content {
                                    // A send error means the receiver was
                                    // dropped; stop producing.
                                    if tx.send(content.clone()).await.is_err() {
                                        break;
                                    }
                                }
                            }
                            Err(e) => {
                                tracing::error!("Error parsing JSON response: {:?}", e);
                                let err = anyhow!("Error parsing JSON response: {}", e);
                                send_sentry_error(&err.to_string(), None);
                            }
                        }
                    }
                }
                Err(e) => {
                    tracing::error!("Error while streaming response: {:?}", e);
                    let err = anyhow!("Error while streaming response: {}", e);
                    send_sentry_error(&err.to_string(), None);
                    break;
                }
            }
        }
    });

    Ok(ReceiverStream::new(rx))
}
|
||||
|
||||
/// Request body for OpenAI's embeddings endpoint (batch input).
#[derive(Serialize, Debug)]
pub struct AdaBulkEmbedding {
    pub model: String,
    pub input: Vec<String>,
    pub dimensions: u32,
}

/// A single embedding vector in the response.
#[derive(Deserialize, Debug)]
pub struct AdaEmbeddingArray {
    pub embedding: Vec<f32>,
}

/// Embeddings endpoint response.
#[derive(Deserialize, Debug)]
pub struct AdaEmbeddingResponse {
    pub data: Vec<AdaEmbeddingArray>,
}
|
||||
|
||||
pub async fn ada_bulk_embedding(text_list: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||
let embedding_model = env::var("EMBEDDING_MODEL").expect("EMBEDDING_MODEL must be set");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let ada_bulk_embedding = AdaBulkEmbedding {
|
||||
model: embedding_model,
|
||||
input: text_list,
|
||||
dimensions: 1024,
|
||||
};
|
||||
|
||||
let embeddings_result = match client
|
||||
.post(OPENAI_EMBEDDING_URL)
|
||||
.headers({
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", OPENAI_API_KEY.to_string())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers
|
||||
})
|
||||
.timeout(Duration::from_secs(60))
|
||||
.json(&ada_bulk_embedding)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if !res.status().is_success() {
|
||||
tracing::error!(
|
||||
"There was an issue while getting the data source: {}",
|
||||
res.text().await.unwrap()
|
||||
);
|
||||
return Err(anyhow!("Error getting data source"));
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
Err(e) => Err(anyhow!(e.to_string())),
|
||||
};
|
||||
|
||||
let embedding_text = embeddings_result.unwrap().text().await.unwrap();
|
||||
|
||||
let embeddings: AdaEmbeddingResponse =
|
||||
match serde_json::from_str(embedding_text.clone().as_str()) {
|
||||
Ok(embeddings) => embeddings,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"There was an issue decoding the bulk embedding json response: {}",
|
||||
e
|
||||
);
|
||||
return Err(anyhow!(e.to_string()));
|
||||
}
|
||||
};
|
||||
|
||||
let result: Vec<Vec<f32>> = embeddings.data.into_iter().map(|x| x.embedding).collect();
|
||||
|
||||
Ok(result)
|
||||
}
|
|
@ -1,65 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use aws_config::meta::region::RegionProviderChain;
|
||||
use aws_sdk_secretsmanager::Client;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn create_db_secret(secret: String) -> Result<String> {
|
||||
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
|
||||
let config = aws_config::from_env().region(region_provider).load().await;
|
||||
let client = Client::new(&config);
|
||||
|
||||
let secret_id = Uuid::new_v4().to_string();
|
||||
|
||||
match client
|
||||
.create_secret()
|
||||
.name(secret_id.clone())
|
||||
.secret_string(secret)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
tracing::info!("Successfully created secret in AWS.");
|
||||
tracing::info!("Secret ID: {}", secret_id);
|
||||
tracing::info!("Secret ARN: {}", res.arn.unwrap());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("There was an issue while creating the secret in AWS.");
|
||||
return Err(anyhow!(e));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(secret_id)
|
||||
}
|
||||
|
||||
pub async fn delete_db_secret(secret_id: String) -> Result<()> {
|
||||
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
|
||||
let config = aws_config::from_env().region(region_provider).load().await;
|
||||
let client = Client::new(&config);
|
||||
|
||||
let _secret_response = match client.delete_secret().secret_id(secret_id).send().await {
|
||||
Ok(secret_response) => secret_response,
|
||||
Err(e) => return Err(anyhow!(e)),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read_secret(secret_id: String) -> Result<String> {
|
||||
let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
|
||||
let config = aws_config::from_env().region(region_provider).load().await;
|
||||
let client = Client::new(&config);
|
||||
|
||||
let secret_response = client.get_secret_value().secret_id(secret_id).send().await;
|
||||
|
||||
let secret = match secret_response {
|
||||
Ok(secret) => secret,
|
||||
Err(e) => return Err(anyhow!(e)),
|
||||
};
|
||||
|
||||
let secret_string = match secret.secret_string {
|
||||
Some(secret_string) => secret_string,
|
||||
None => return Err(anyhow!("There was no secret string in the response")),
|
||||
};
|
||||
|
||||
return Ok(secret_string);
|
||||
}
|
|
@ -1,6 +1,2 @@
|
|||
pub mod ai;
|
||||
// pub mod aws;
|
||||
pub mod email;
|
||||
pub mod posthog;
|
||||
pub mod sentry_utils;
|
||||
pub mod typesense;
|
||||
|
|
|
@ -1,190 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use std::env;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::clients::sentry_utils::send_sentry_error;
|
||||
|
||||
lazy_static::lazy_static! {
    /// PostHog project API key; the process aborts at first use if unset.
    static ref POSTHOG_API_KEY: String = env::var("POSTHOG_API_KEY").expect("POSTHOG_API_KEY must be set");
}

/// PostHog US-cloud single-event capture endpoint.
const POSTHOG_API_URL: &str = "https://us.i.posthog.com/capture/";
|
||||
|
||||
/// A single analytics event in the shape PostHog's capture API expects.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PosthogEvent {
    /// Which product event occurred.
    event: PosthogEventType,
    /// The user the event is attributed to, when known.
    distinct_id: Option<Uuid>,
    /// Event-specific payload; the variant should match `event`.
    properties: PosthogEventProperties,
}
|
||||
|
||||
/// Product events tracked in PostHog; serialized in snake_case
/// (e.g. `MetricCreated` -> "metric_created").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum PosthogEventType {
    MetricCreated,
    MetricFollowUp,
    MetricAddedToDashboard,
    MetricViewed,
    DashboardViewed,
    TitleManuallyUpdated,
    SqlManuallyUpdated,
    ChartStylingManuallyUpdated,
    ChartStylingAutoUpdated,
}
|
||||
|
||||
/// Per-event property payloads; one variant per `PosthogEventType`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PosthogEventProperties {
    MetricCreated(MetricCreatedProperties),
    MetricFollowUp(MetricFollowUpProperties),
    MetricAddedToDashboard(MetricAddedToDashboardProperties),
    MetricViewed(MetricViewedProperties),
    DashboardViewed(DashboardViewedProperties),
    TitleManuallyUpdated(TitleManuallyUpdatedProperties),
    SqlManuallyUpdated(SqlManuallyUpdatedProperties),
    ChartStylingManuallyUpdated(ChartStylingManuallyUpdatedProperties),
    ChartStylingAutoUpdated(ChartStylingAutoUpdatedProperties),
}
|
||||
|
||||
/// Properties attached to a `metric_created` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricCreatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}

/// Properties attached to a `metric_follow_up` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricFollowUpProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}

/// Properties attached to a `metric_added_to_dashboard` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricAddedToDashboardProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
    pub dashboard_id: Uuid,
}

/// Properties attached to a `metric_viewed` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricViewedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}

/// Properties attached to a `dashboard_viewed` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DashboardViewedProperties {
    pub dashboard_id: Uuid,
}

/// Properties attached to a `title_manually_updated` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TitleManuallyUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}

/// Properties attached to a `sql_manually_updated` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SqlManuallyUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}

/// Properties attached to a `chart_styling_manually_updated` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChartStylingManuallyUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}

/// Properties attached to a `chart_styling_auto_updated` event.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChartStylingAutoUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
|
||||
|
||||
pub async fn send_posthog_event_handler(
|
||||
event_type: PosthogEventType,
|
||||
user_id: Option<Uuid>,
|
||||
properties: PosthogEventProperties,
|
||||
) -> Result<()> {
|
||||
let event = PosthogEvent {
|
||||
event: event_type,
|
||||
distinct_id: user_id,
|
||||
properties,
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
match send_event(event).await {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
send_sentry_error(&e.to_string(), None);
|
||||
tracing::error!("Unable to send event to Posthog: {:?}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Wire format for the capture call: the event fields flattened to the
/// top level plus the project `api_key`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PosthogRequest {
    #[serde(flatten)]
    event: PosthogEvent,
    api_key: String,
}
|
||||
|
||||
async fn send_event(event: PosthogEvent) -> Result<()> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let headers = {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
format!("Bearer {}", POSTHOG_API_KEY.to_string())
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers
|
||||
};
|
||||
|
||||
let posthog_req = PosthogRequest {
|
||||
event,
|
||||
api_key: POSTHOG_API_KEY.to_string(),
|
||||
};
|
||||
|
||||
match client
|
||||
.post(POSTHOG_API_URL)
|
||||
.headers(headers)
|
||||
.json(&posthog_req)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
tracing::error!("Unable to send request to Posthog: {:?}", e);
|
||||
return Err(anyhow!(e));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use dotenv::dotenv;

    // Smoke test: fires a real capture request using POSTHOG_API_KEY from
    // .env. The result is deliberately ignored — this only exercises
    // serialization and the request path, not delivery.
    #[tokio::test]
    async fn test_send_event() {
        dotenv().ok();

        let _ = send_event(PosthogEvent {
            event: PosthogEventType::MetricCreated,
            distinct_id: Some(Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap()),
            properties: PosthogEventProperties::MetricCreated(MetricCreatedProperties {
                message_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(),
                thread_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(),
            }),
        })
        .await;
    }
}
|
|
@ -1,318 +0,0 @@
|
|||
use anyhow::{anyhow, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::env;
|
||||
use uuid::Uuid;
|
||||
|
||||
lazy_static::lazy_static! {
    /// Typesense API key; falls back to the local-dev default "xyz".
    static ref TYPESENSE_API_KEY: String = env::var("TYPESENSE_API_KEY").unwrap_or("xyz".to_string());
    /// Typesense base URL; falls back to a local instance.
    static ref TYPESENSE_API_HOST: String = env::var("TYPESENSE_API_HOST").unwrap_or("http://localhost:8108".to_string());
}
|
||||
|
||||
/// Names of the Typesense collections this service reads and writes.
///
/// Serialized in snake_case; `StoredValues` is untagged so it serializes
/// as the bare collection-name string it carries.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum CollectionName {
    Messages,
    Collections,
    Datasets,
    PermissionGroups,
    Teams,
    DataSources,
    Terms,
    Dashboards,
    /// A dynamically named stored-values collection.
    #[serde(untagged)]
    StoredValues(String),
}
|
||||
|
||||
impl CollectionName {
|
||||
pub fn to_string(&self) -> String {
|
||||
match self {
|
||||
CollectionName::Messages => "messages".to_string(),
|
||||
CollectionName::Collections => "collections".to_string(),
|
||||
CollectionName::Datasets => "datasets".to_string(),
|
||||
CollectionName::PermissionGroups => "permission_groups".to_string(),
|
||||
CollectionName::Teams => "teams".to_string(),
|
||||
CollectionName::DataSources => "data_sources".to_string(),
|
||||
CollectionName::Terms => "terms".to_string(),
|
||||
CollectionName::Dashboards => "dashboards".to_string(),
|
||||
CollectionName::StoredValues(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Typesense document shape for the `messages` collection.
#[derive(Deserialize, Serialize)]
pub struct MessageDocument {
    pub id: Uuid,
    pub name: String,
    pub summary_question: String,
    /// Used to scope search results to the caller's organization.
    pub organization_id: Uuid,
}
|
||||
|
||||
/// Shared document shape for the id/name-only collections
/// (dashboards, datasets, permission groups, teams, data sources,
/// terms, collections).
#[derive(Deserialize, Serialize)]
pub struct GenericDocument {
    pub id: Uuid,
    pub name: String,
    /// Used to scope search results to the caller's organization.
    pub organization_id: Uuid,
}
|
||||
|
||||
/// Document shape for a single stored column value, tied back to the
/// dataset and column it came from.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct StoredValueDocument {
    pub id: Uuid,
    pub value: String,
    pub dataset_id: Uuid,
    pub dataset_column_id: Uuid,
}
|
||||
|
||||
/// Any document this service indexes in Typesense.
///
/// Untagged: (de)serialization is driven purely by field shape, so most
/// variants share the `GenericDocument` layout.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum Document {
    Message(MessageDocument),
    Dashboard(GenericDocument),
    Dataset(GenericDocument),
    PermissionGroup(GenericDocument),
    Team(GenericDocument),
    DataSource(GenericDocument),
    Term(GenericDocument),
    Collection(GenericDocument),
    StoredValue(StoredValueDocument),
}
|
||||
|
||||
impl Document {
    /// Returns the document's unique id regardless of variant.
    pub fn id(&self) -> Uuid {
        match self {
            Document::Message(m) => m.id,
            Document::Dashboard(d) => d.id,
            Document::Dataset(d) => d.id,
            Document::PermissionGroup(d) => d.id,
            Document::Team(d) => d.id,
            Document::DataSource(d) => d.id,
            Document::Term(d) => d.id,
            Document::Collection(d) => d.id,
            Document::StoredValue(d) => d.id,
        }
    }

    /// Borrows the inner `StoredValueDocument`.
    ///
    /// # Panics
    /// Panics on any variant other than `StoredValue` — callers must know
    /// the variant before calling this.
    // NOTE(review): despite the `into_` prefix this borrows rather than
    // consumes; an `as_stored_value_document` name would match convention.
    pub fn into_stored_value_document(&self) -> &StoredValueDocument {
        match self {
            Document::StoredValue(d) => d,
            _ => panic!("Document is not a StoredValueDocument"),
        }
    }
}
|
||||
|
||||
pub async fn upsert_document(collection: CollectionName, document: Document) -> Result<()> {
|
||||
let client = reqwest::Client::builder().build()?;
|
||||
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/json".parse()?);
|
||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
||||
|
||||
match client
|
||||
.request(
|
||||
reqwest::Method::POST,
|
||||
format!(
|
||||
"{}/collections/{}/documents?action=upsert",
|
||||
*TYPESENSE_API_HOST,
|
||||
collection.to_string()
|
||||
),
|
||||
)
|
||||
.headers(headers)
|
||||
.json(&document)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Body of a Typesense multi-search call: one entry per collection query.
#[derive(Serialize, Debug)]
pub struct SearchRequest {
    pub searches: Vec<SearchRequestObject>,
}

/// A single per-collection search within a multi-search request.
#[derive(Serialize, Debug)]
pub struct SearchRequestObject {
    /// Which collection to search.
    pub collection: CollectionName,
    /// Free-text query string.
    pub q: String,
    /// Comma-separated field names to match against.
    pub query_by: String,
    pub prefix: bool,
    /// Comma-separated fields to omit from returned documents.
    pub exclude_fields: String,
    pub highlight_fields: String,
    pub use_cache: bool,
    /// Typesense filter expression (e.g. organization scoping).
    pub filter_by: String,
    /// Vector-search clause for hybrid keyword + embedding search.
    pub vector_query: String,
    /// Max hits to return; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<i64>,
}
|
||||
|
||||
/// Echo of the request parameters inside each per-collection result.
#[derive(Deserialize)]
pub struct RequestParams {
    pub collection_name: CollectionName,
}

/// A highlighted field on a hit, with the query tokens that matched.
#[derive(Deserialize)]
pub struct Highlight {
    pub field: String,
    pub matched_tokens: Vec<String>,
}

/// Hybrid-search scoring info returned per hit.
#[derive(Deserialize)]
pub struct HybridSearchInfo {
    pub rank_fusion_score: f64,
}

/// One matching document plus its highlights and fusion score.
#[derive(Deserialize)]
pub struct Hit {
    pub document: Document,
    pub highlights: Vec<Highlight>,
    pub hybrid_search_info: HybridSearchInfo,
}

/// All hits for one collection within a multi-search response.
#[derive(Deserialize)]
pub struct SearchResult {
    pub request_params: RequestParams,
    pub hits: Vec<Hit>,
}

/// Top-level multi-search response: one result set per request entry.
#[derive(Deserialize)]
pub struct SearchResponse {
    pub results: Vec<SearchResult>,
}
|
||||
|
||||
pub async fn search_documents(search_reqs: Vec<SearchRequestObject>) -> Result<SearchResponse> {
|
||||
let client = reqwest::Client::builder().build()?;
|
||||
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/json".parse()?);
|
||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
||||
|
||||
let search_req = SearchRequest {
|
||||
searches: search_reqs,
|
||||
};
|
||||
|
||||
let res = match client
|
||||
.request(
|
||||
reqwest::Method::POST,
|
||||
format!("{}/multi_search", *TYPESENSE_API_HOST),
|
||||
)
|
||||
.headers(headers)
|
||||
.json(&search_req)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
let res = if res.status().is_success() {
|
||||
res
|
||||
} else {
|
||||
return Err(anyhow!(
|
||||
"Error sending request: {}",
|
||||
res.text().await.unwrap()
|
||||
));
|
||||
};
|
||||
|
||||
res
|
||||
}
|
||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
||||
};
|
||||
|
||||
let search_response = match res.json::<SearchResponse>().await {
|
||||
Ok(search_response) => search_response,
|
||||
Err(e) => return Err(anyhow!("Error parsing response: {}", e)),
|
||||
};
|
||||
|
||||
Ok(search_response)
|
||||
}
|
||||
|
||||
pub async fn create_collection(schema: Value) -> Result<()> {
|
||||
let client = reqwest::Client::builder().build()?;
|
||||
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/json".parse()?);
|
||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
||||
|
||||
match client
|
||||
.request(
|
||||
reqwest::Method::POST,
|
||||
format!("{}/collections", *TYPESENSE_API_HOST,),
|
||||
)
|
||||
.headers(headers)
|
||||
.json(&schema)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_collection(collection_name: &String, filter_by: &String) -> Result<()> {
|
||||
let client = reqwest::Client::builder().build()?;
|
||||
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/json".parse()?);
|
||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
||||
|
||||
match client
|
||||
.request(
|
||||
reqwest::Method::DELETE,
|
||||
format!(
|
||||
"{}/collections/{}/documents?filter_by={}",
|
||||
*TYPESENSE_API_HOST, collection_name, filter_by
|
||||
),
|
||||
)
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn bulk_insert_documents<T>(collection_name: &String, documents: &Vec<T>) -> Result<()>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
{
|
||||
let client = reqwest::Client::builder().build()?;
|
||||
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert("Content-Type", "application/json".parse()?);
|
||||
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
|
||||
|
||||
let serialized_documents = documents
|
||||
.iter()
|
||||
.map(|doc| serde_json::to_string(doc))
|
||||
.collect::<Result<Vec<String>, _>>()?
|
||||
.join("\n");
|
||||
|
||||
match client
|
||||
.request(
|
||||
reqwest::Method::POST,
|
||||
format!(
|
||||
"{}/collections/{}/documents/import?action=create",
|
||||
*TYPESENSE_API_HOST, collection_name
|
||||
),
|
||||
)
|
||||
.headers(headers)
|
||||
.body(serialized_documents)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
tracing::info!("Res: {:?}", res.text().await.unwrap());
|
||||
return Ok(());
|
||||
}
|
||||
tracing::error!("Error sending request: {}", res.text().await.unwrap());
|
||||
return Err(anyhow!("Error sending bulk insert request"));
|
||||
}
|
||||
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
|
||||
};
|
||||
}
|
|
@ -1,11 +1,5 @@
|
|||
pub mod charting;
|
||||
pub mod clients;
|
||||
pub mod security;
|
||||
pub mod serde_helpers;
|
||||
pub mod stored_values;
|
||||
pub mod validation;
|
||||
|
||||
pub use agents::*;
|
||||
pub use security::*;
|
||||
pub use stored_values::*;
|
||||
pub use validation::*;
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
use serde::{de::Deserializer, Deserialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Deserializes a field that must distinguish "absent" from "explicitly
/// null".
///
/// Returns:
/// - `Ok(Some(None))` when the JSON value is an explicit `null`
///   ("set the field to nothing"),
/// - `Ok(None)` when the value is an empty object `{}` ("not provided"),
/// - `Ok(Some(Some(v)))` when the value deserializes as `T`.
///
/// NOTE(review): a present value that fails to deserialize as `T` is
/// silently mapped to `Ok(None)` rather than surfacing the error — confirm
/// callers depend on this lenient behavior before tightening it.
pub fn deserialize_double_option<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
where
    T: serde::Deserialize<'de>,
    D: Deserializer<'de>,
{
    // Buffer into a serde_json::Value first so we can inspect its shape.
    let value = Value::deserialize(deserializer)?;

    match value {
        Value::Null => Ok(Some(None)), // explicit null
        Value::Object(obj) if obj.is_empty() => Ok(None), // empty object
        _ => {
            match T::deserialize(value) {
                Ok(val) => Ok(Some(Some(val))),
                Err(_) => Ok(None)
            }
        }
    }
}
|
|
@ -1 +0,0 @@
|
|||
pub mod deserialization_helpers;
|
|
@ -1,291 +0,0 @@
|
|||
pub mod search;
|
||||
use query_engine::data_source_query_routes::query_engine::query_engine;
|
||||
|
||||
use query_engine::data_types::DataType;
|
||||
pub use search::*;
|
||||
|
||||
use crate::utils::clients::ai::embedding_router::embedding_router;
|
||||
use anyhow::Result;
|
||||
use chrono::Utc;
|
||||
use database::enums::StoredValuesStatus;
|
||||
use database::{pool::get_pg_pool, schema::dataset_columns};
|
||||
use diesel::prelude::*;
|
||||
use diesel::sql_types::{Array, Float4, Integer, Text, Timestamptz, Uuid as SqlUuid};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Row shape for raw SQL queries that return a single text column.
/// `QueryableByName` matches by column name, so the selected column must
/// be named (or aliased) `value`.
#[derive(Debug, QueryableByName)]
pub struct StoredValueRow {
    #[diesel(sql_type = Text)]
    pub value: String,
}
|
||||
|
||||
/// Row shape for stored-values similarity queries: the matched value plus
/// the source column it belongs to. Column names in the SQL must match the
/// field names (`QueryableByName` maps by name).
#[derive(Debug, QueryableByName)]
pub struct StoredValueWithDistance {
    #[diesel(sql_type = Text)]
    pub value: String,
    #[diesel(sql_type = Text)]
    pub column_name: String,
    #[diesel(sql_type = SqlUuid)]
    pub column_id: Uuid,
}
|
||||
|
||||
/// Number of distinct values fetched from the source table per query page.
const BATCH_SIZE: usize = 10_000;
/// Values longer than this are skipped when syncing stored values.
const MAX_VALUE_LENGTH: usize = 50;
|
||||
|
||||
/// Ensures the per-organization stored-values schema, table, and ivfflat
/// vector index exist. Every statement uses `IF NOT EXISTS`, so calling
/// this repeatedly is safe.
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
    let pool = get_pg_pool();
    let mut conn = pool.get().await?;

    // Create schema and table using raw SQL.
    // Postgres identifiers cannot contain '-', so dashes become underscores;
    // the resulting schema is named `values_<org>`.
    let schema_name = organization_id.to_string().replace("-", "_");
    let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS values_{}", schema_name);

    let create_table_sql = format!(
        "CREATE TABLE IF NOT EXISTS values_{}.values_v1 (
            value text NOT NULL,
            dataset_id uuid NOT NULL,
            column_name text NOT NULL,
            column_id uuid NOT NULL,
            embedding vector(1024),
            created_at timestamp with time zone NOT NULL DEFAULT now(),
            UNIQUE(dataset_id, column_name, value)
        )",
        schema_name
    );

    // ivfflat + vector_cosine_ops backs the `<=>` cosine-distance ordering
    // used by search_stored_values.
    let create_index_sql = format!(
        "CREATE INDEX IF NOT EXISTS values_v1_embedding_idx
        ON values_{}.values_v1
        USING ivfflat (embedding vector_cosine_ops)",
        schema_name
    );

    diesel::sql_query(create_schema_sql)
        .execute(&mut conn)
        .await?;
    diesel::sql_query(create_table_sql)
        .execute(&mut conn)
        .await?;
    diesel::sql_query(create_index_sql)
        .execute(&mut conn)
        .await?;

    Ok(())
}
|
||||
|
||||
pub async fn store_column_values(
|
||||
organization_id: &Uuid,
|
||||
dataset_id: &Uuid,
|
||||
column_name: &str,
|
||||
column_id: &Uuid,
|
||||
_data_source_id: &Uuid,
|
||||
schema: &str,
|
||||
table_name: &str,
|
||||
) -> Result<()> {
|
||||
let pool = get_pg_pool();
|
||||
let mut conn = pool.get().await?;
|
||||
|
||||
// Create schema and table if they don't exist
|
||||
ensure_stored_values_schema(organization_id).await?;
|
||||
|
||||
// Query distinct values in batches
|
||||
let mut offset = 0;
|
||||
let mut first_batch = true;
|
||||
let schema_name = organization_id.to_string().replace("-", "_");
|
||||
|
||||
loop {
|
||||
let query = format!(
|
||||
"SELECT DISTINCT \"{}\" as value
|
||||
FROM {}.{}
|
||||
WHERE \"{}\" IS NOT NULL
|
||||
AND length(\"{}\") <= {}
|
||||
ORDER BY \"{}\"
|
||||
LIMIT {} OFFSET {}",
|
||||
column_name,
|
||||
schema,
|
||||
table_name,
|
||||
column_name,
|
||||
column_name,
|
||||
MAX_VALUE_LENGTH,
|
||||
column_name,
|
||||
BATCH_SIZE,
|
||||
offset
|
||||
);
|
||||
|
||||
let results = match query_engine(dataset_id, &query, None).await {
|
||||
Ok(results) => results.data,
|
||||
Err(e) => {
|
||||
tracing::error!("Error querying stored values: {:?}", e);
|
||||
if first_batch {
|
||||
return Err(e);
|
||||
}
|
||||
vec![]
|
||||
}
|
||||
};
|
||||
|
||||
if results.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Extract values from the query results
|
||||
let values: Vec<String> = results
|
||||
.into_iter()
|
||||
.filter_map(|row| {
|
||||
if let Some(DataType::Text(Some(value))) = row.get("value") {
|
||||
Some(value.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if values.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// If this is the first batch and we have 15 or fewer values, handle as enum
|
||||
if first_batch && values.len() <= 15 {
|
||||
// Get current description
|
||||
let current_description =
|
||||
diesel::sql_query("SELECT description FROM dataset_columns WHERE id = $1")
|
||||
.bind::<SqlUuid, _>(column_id)
|
||||
.get_result::<StoredValueRow>(&mut conn)
|
||||
.await
|
||||
.ok()
|
||||
.and_then(|row| Some(row.value));
|
||||
|
||||
// Format new description
|
||||
let enum_list = format!("Values for this column are: {}", values.join(", "));
|
||||
let new_description = match current_description {
|
||||
Some(desc) if !desc.is_empty() => format!("{}. {}", desc, enum_list),
|
||||
_ => enum_list,
|
||||
};
|
||||
|
||||
// Update column description
|
||||
diesel::update(dataset_columns::table)
|
||||
.filter(dataset_columns::id.eq(column_id))
|
||||
.set((
|
||||
dataset_columns::description.eq(new_description),
|
||||
dataset_columns::stored_values_status.eq(StoredValuesStatus::Success),
|
||||
dataset_columns::stored_values_count.eq(values.len() as i64),
|
||||
dataset_columns::stored_values_last_synced.eq(Utc::now()),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create embeddings for the batch
|
||||
let embeddings = create_embeddings_batch(&values).await?;
|
||||
|
||||
// Insert values and embeddings
|
||||
for (value, embedding) in values.iter().zip(embeddings.iter()) {
|
||||
let insert_sql = format!(
|
||||
"INSERT INTO {}_values.values_v1
|
||||
(value, dataset_id, column_name, column_id, embedding, created_at)
|
||||
VALUES ($1::text, $2::uuid, $3::text, $4::uuid, $5::vector, $6::timestamptz)
|
||||
ON CONFLICT (dataset_id, column_name, value)
|
||||
DO UPDATE SET created_at = EXCLUDED.created_at",
|
||||
schema_name
|
||||
);
|
||||
|
||||
diesel::sql_query(insert_sql)
|
||||
.bind::<Text, _>(value)
|
||||
.bind::<SqlUuid, _>(dataset_id)
|
||||
.bind::<Text, _>(column_name)
|
||||
.bind::<SqlUuid, _>(column_id)
|
||||
.bind::<Array<Float4>, _>(embedding)
|
||||
.bind::<Timestamptz, _>(Utc::now())
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
}
|
||||
|
||||
first_batch = false;
|
||||
offset += BATCH_SIZE;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_embeddings_batch(values: &[String]) -> Result<Vec<Vec<f32>>> {
|
||||
let embeddings = embedding_router(values.to_vec(), true).await?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
/// Returns up to `limit` (default 10) stored values for `dataset_id`,
/// ordered by cosine distance to `query_embedding`.
///
/// Each result tuple is `(value, column_name, column_id)`.
pub async fn search_stored_values(
    organization_id: &Uuid,
    dataset_id: &Uuid,
    query_embedding: Vec<f32>,
    limit: Option<i32>,
) -> Result<Vec<(String, String, Uuid)>> {
    let pool = get_pg_pool();
    let mut conn = pool.get().await?;

    let limit = limit.unwrap_or(10);

    // Schema name mirrors ensure_stored_values_schema: dashes -> underscores.
    let schema_name = organization_id.to_string().replace("-", "_");
    let query = format!(
        "SELECT value, column_name, column_id
        FROM values_{}.values_v1
        WHERE dataset_id = $2::uuid
        ORDER BY embedding <=> $1::vector
        LIMIT $3::integer",
        schema_name
    );

    // Binds are positional: first .bind() is $1 (embedding), then $2
    // (dataset_id), then $3 (limit). Keep the .bind() order in sync with
    // the placeholders above.
    let results: Vec<StoredValueWithDistance> = diesel::sql_query(query)
        .bind::<Array<Float4>, _>(query_embedding)
        .bind::<SqlUuid, _>(dataset_id)
        .bind::<Integer, _>(limit)
        .load(&mut conn)
        .await?;

    Ok(results
        .into_iter()
        .map(|r| (r.value, r.column_name, r.column_id))
        .collect())
}
|
||||
|
||||
/// Everything needed to sync one source column's stored values:
/// identifies the organization/dataset/column plus the physical
/// schema.table location to read from.
pub struct StoredValueColumn {
    pub organization_id: Uuid,
    pub dataset_id: Uuid,
    pub column_name: String,
    pub column_id: Uuid,
    pub data_source_id: Uuid,
    pub schema: String,
    pub table_name: String,
}
|
||||
|
||||
pub async fn process_stored_values_background(columns: Vec<StoredValueColumn>) {
|
||||
for column in columns {
|
||||
match store_column_values(
|
||||
&column.organization_id,
|
||||
&column.dataset_id,
|
||||
&column.column_name,
|
||||
&column.column_id,
|
||||
&column.data_source_id,
|
||||
&column.schema,
|
||||
&column.table_name,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
tracing::info!(
|
||||
"Successfully processed stored values for column '{}' in dataset '{}'",
|
||||
column.column_name,
|
||||
column.table_name
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"Failed to process stored values for column '{}' in dataset '{}': {:?}",
|
||||
column.column_name,
|
||||
column.table_name,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,64 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use cohere_rust::{api::rerank::{ReRankModel, ReRankRequest}, Cohere};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::clients::ai::embedding_router::embedding_router;
|
||||
|
||||
use super::search_stored_values;
|
||||
|
||||
/// A stored column value returned from search, tied back to the dataset
/// and column it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredValue {
    pub value: String,
    pub dataset_id: Uuid,
    pub column_name: String,
    pub column_id: Uuid,
}
|
||||
|
||||
pub async fn search_values_for_dataset(
|
||||
organization_id: &Uuid,
|
||||
dataset_id: &Uuid,
|
||||
query: String,
|
||||
) -> Result<Vec<StoredValue>> {
|
||||
// Create embedding for the search query
|
||||
let query_vec = vec![query.clone()];
|
||||
let query_embedding = embedding_router(query_vec, true).await?[0].clone();
|
||||
|
||||
// Get initial candidates using vector similarity
|
||||
let candidates = search_stored_values(
|
||||
organization_id,
|
||||
dataset_id,
|
||||
query_embedding,
|
||||
Some(25), // Get more candidates for reranking
|
||||
).await?;
|
||||
|
||||
// Extract just the values for reranking
|
||||
let candidate_values: Vec<String> = candidates.iter().map(|(value, _, _)| value.clone()).collect();
|
||||
|
||||
// Rerank the candidates using the cohere client
|
||||
let co = Cohere::default();
|
||||
let request = ReRankRequest {
|
||||
query: query.as_str(),
|
||||
documents: &candidate_values,
|
||||
model: ReRankModel::EnglishV3,
|
||||
top_n: Some(10),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = co.rerank(&request).await?;
|
||||
|
||||
// Convert to StoredValue structs
|
||||
let values = response.into_iter()
|
||||
.map(|result| {
|
||||
let (value, column_name, column_id) = candidates[result.index as usize].clone();
|
||||
StoredValue {
|
||||
value,
|
||||
dataset_id: *dataset_id,
|
||||
column_name,
|
||||
column_id,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(values)
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
pub mod types;
|
||||
pub mod type_mapping;
|
||||
|
||||
pub use types::*;
|
||||
pub use type_mapping::*;
|
|
@ -1,234 +0,0 @@
|
|||
use database::enums::DataSourceType;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use query_engine::data_types::DataType;
|
||||
|
||||
// Standard types we use across all data sources
|
||||
// Standard types we use across all data sources
/// Normalized column type shared across every supported data source.
#[derive(Debug, Clone, PartialEq)]
pub enum StandardType {
    Text,
    Integer,
    Float,
    Boolean,
    Date,
    Timestamp,
    Json,
    /// Fallback for anything that cannot be classified.
    Unknown,
}
|
||||
|
||||
impl StandardType {
|
||||
pub fn from_data_type(data_type: &DataType) -> Self {
|
||||
match data_type.simple_type() {
|
||||
Some(simple_type) => match simple_type.as_str() {
|
||||
"string" => StandardType::Text,
|
||||
"number" => StandardType::Float,
|
||||
"boolean" => StandardType::Boolean,
|
||||
"date" => StandardType::Date,
|
||||
"null" => StandardType::Unknown,
|
||||
_ => StandardType::Unknown,
|
||||
},
|
||||
None => StandardType::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(type_str: &str) -> Self {
|
||||
match type_str.to_lowercase().as_str() {
|
||||
"text" | "string" | "varchar" | "char" => StandardType::Text,
|
||||
"int" | "integer" | "bigint" | "smallint" => StandardType::Integer,
|
||||
"float" | "double" | "decimal" | "numeric" => StandardType::Float,
|
||||
"bool" | "boolean" => StandardType::Boolean,
|
||||
"date" => StandardType::Date,
|
||||
"timestamp" | "datetime" | "timestamptz" => StandardType::Timestamp,
|
||||
"json" | "jsonb" => StandardType::Json,
|
||||
_ => StandardType::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Type mappings for each data source type to DataType
|
||||
static TYPE_MAPPINGS: Lazy<HashMap<DataSourceType, HashMap<&'static str, DataType>>> =
|
||||
Lazy::new(|| {
|
||||
let mut mappings = HashMap::new();
|
||||
|
||||
// Postgres mappings
|
||||
let mut postgres = HashMap::new();
|
||||
postgres.insert("text", DataType::Text(None));
|
||||
postgres.insert("varchar", DataType::Text(None));
|
||||
postgres.insert("char", DataType::Text(None));
|
||||
postgres.insert("int", DataType::Int4(None));
|
||||
postgres.insert("integer", DataType::Int4(None));
|
||||
postgres.insert("bigint", DataType::Int8(None));
|
||||
postgres.insert("smallint", DataType::Int2(None));
|
||||
postgres.insert("float", DataType::Float4(None));
|
||||
postgres.insert("double precision", DataType::Float8(None));
|
||||
postgres.insert("numeric", DataType::Decimal(None));
|
||||
postgres.insert("decimal", DataType::Decimal(None));
|
||||
postgres.insert("boolean", DataType::Bool(None));
|
||||
postgres.insert("date", DataType::Date(None));
|
||||
postgres.insert("timestamp", DataType::Timestamp(None));
|
||||
postgres.insert("timestamptz", DataType::Timestamptz(None));
|
||||
postgres.insert("json", DataType::Json(None));
|
||||
postgres.insert("jsonb", DataType::Json(None));
|
||||
mappings.insert(DataSourceType::Postgres, postgres.clone());
|
||||
mappings.insert(DataSourceType::Supabase, postgres);
|
||||
|
||||
// MySQL mappings
|
||||
let mut mysql = HashMap::new();
|
||||
mysql.insert("varchar", DataType::Text(None));
|
||||
mysql.insert("text", DataType::Text(None));
|
||||
mysql.insert("char", DataType::Text(None));
|
||||
mysql.insert("int", DataType::Int4(None));
|
||||
mysql.insert("bigint", DataType::Int8(None));
|
||||
mysql.insert("tinyint", DataType::Int2(None));
|
||||
mysql.insert("float", DataType::Float4(None));
|
||||
mysql.insert("double", DataType::Float8(None));
|
||||
mysql.insert("decimal", DataType::Decimal(None));
|
||||
mysql.insert("boolean", DataType::Bool(None));
|
||||
mysql.insert("date", DataType::Date(None));
|
||||
mysql.insert("datetime", DataType::Timestamp(None));
|
||||
mysql.insert("timestamp", DataType::Timestamptz(None));
|
||||
mysql.insert("json", DataType::Json(None));
|
||||
mappings.insert(DataSourceType::MySql, mysql.clone());
|
||||
mappings.insert(DataSourceType::Mariadb, mysql);
|
||||
|
||||
// BigQuery mappings
|
||||
let mut bigquery = HashMap::new();
|
||||
bigquery.insert("STRING", DataType::Text(None));
|
||||
bigquery.insert("INT64", DataType::Int8(None));
|
||||
bigquery.insert("INTEGER", DataType::Int4(None));
|
||||
bigquery.insert("FLOAT64", DataType::Float8(None));
|
||||
bigquery.insert("NUMERIC", DataType::Decimal(None));
|
||||
bigquery.insert("BOOL", DataType::Bool(None));
|
||||
bigquery.insert("DATE", DataType::Date(None));
|
||||
bigquery.insert("TIMESTAMP", DataType::Timestamptz(None));
|
||||
bigquery.insert("JSON", DataType::Json(None));
|
||||
mappings.insert(DataSourceType::BigQuery, bigquery);
|
||||
|
||||
// Snowflake mappings
|
||||
let mut snowflake = HashMap::new();
|
||||
snowflake.insert("TEXT", DataType::Text(None));
|
||||
snowflake.insert("VARCHAR", DataType::Text(None));
|
||||
snowflake.insert("CHAR", DataType::Text(None));
|
||||
snowflake.insert("NUMBER", DataType::Decimal(None));
|
||||
snowflake.insert("DECIMAL", DataType::Decimal(None));
|
||||
snowflake.insert("INTEGER", DataType::Int4(None));
|
||||
snowflake.insert("BIGINT", DataType::Int8(None));
|
||||
snowflake.insert("BOOLEAN", DataType::Bool(None));
|
||||
snowflake.insert("DATE", DataType::Date(None));
|
||||
snowflake.insert("TIMESTAMP", DataType::Timestamptz(None));
|
||||
snowflake.insert("VARIANT", DataType::Json(None));
|
||||
mappings.insert(DataSourceType::Snowflake, snowflake);
|
||||
|
||||
mappings
|
||||
});
|
||||
|
||||
pub fn normalize_type(source_type: DataSourceType, type_str: &str) -> DataType {
|
||||
TYPE_MAPPINGS
|
||||
.get(&source_type)
|
||||
.and_then(|mappings| mappings.get(type_str))
|
||||
.cloned()
|
||||
.unwrap_or(DataType::Unknown(Some(type_str.to_string())))
|
||||
}
|
||||
|
||||
pub fn types_compatible(source_type: DataSourceType, ds_type: &str, model_type: &str) -> bool {
|
||||
let ds_data_type = normalize_type(source_type, ds_type);
|
||||
let model_data_type = normalize_type(source_type, model_type);
|
||||
|
||||
match (&ds_data_type, &model_data_type) {
|
||||
// Allow integer -> float/decimal conversions
|
||||
(DataType::Int2(_), DataType::Float4(_)) => true,
|
||||
(DataType::Int2(_), DataType::Float8(_)) => true,
|
||||
(DataType::Int2(_), DataType::Decimal(_)) => true,
|
||||
(DataType::Int4(_), DataType::Float4(_)) => true,
|
||||
(DataType::Int4(_), DataType::Float8(_)) => true,
|
||||
(DataType::Int4(_), DataType::Decimal(_)) => true,
|
||||
(DataType::Int8(_), DataType::Float4(_)) => true,
|
||||
(DataType::Int8(_), DataType::Float8(_)) => true,
|
||||
(DataType::Int8(_), DataType::Decimal(_)) => true,
|
||||
|
||||
// Allow text for any type (common in views/computed columns)
|
||||
(DataType::Text(_), _) => true,
|
||||
|
||||
// Allow timestamp/timestamptz compatibility
|
||||
(DataType::Timestamp(_), DataType::Timestamptz(_)) => true,
|
||||
(DataType::Timestamptz(_), DataType::Timestamp(_)) => true,
|
||||
|
||||
// Exact matches (using to_string to compare types, not values)
|
||||
(a, b) if a.to_string() == b.to_string() => true,
|
||||
|
||||
// Everything else is incompatible
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postgres_type_normalization() {
        let norm = |t| normalize_type(DataSourceType::Postgres, t);
        assert!(matches!(norm("text"), DataType::Text(_)));
        assert!(matches!(norm("integer"), DataType::Int4(_)));
        assert!(matches!(norm("numeric"), DataType::Decimal(_)));
    }

    #[test]
    fn test_bigquery_type_normalization() {
        let norm = |t| normalize_type(DataSourceType::BigQuery, t);
        assert!(matches!(norm("STRING"), DataType::Text(_)));
        assert!(matches!(norm("INT64"), DataType::Int8(_)));
        assert!(matches!(norm("FLOAT64"), DataType::Float8(_)));
    }

    #[test]
    fn test_type_compatibility() {
        // Identical types are compatible.
        assert!(types_compatible(DataSourceType::Postgres, "text", "text"));
        // An integer column may satisfy a float model type.
        assert!(types_compatible(DataSourceType::Postgres, "integer", "float"));
        // Text satisfies any model type.
        assert!(types_compatible(DataSourceType::Postgres, "text", "integer"));
        // ...but the reverse narrowing is rejected.
        assert!(!types_compatible(DataSourceType::Postgres, "integer", "text"));
        // Timestamp and timestamptz are interchangeable.
        assert!(types_compatible(DataSourceType::Postgres, "timestamp", "timestamptz"));
    }
}
|
|
@ -1,182 +0,0 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Outcome of validating a single model against its data source.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ValidationResult {
    // True until the first error is recorded via `add_error`.
    pub success: bool,
    pub model_name: String,
    pub data_source_name: String,
    pub schema: String,
    // Every problem found; empty when `success` is true.
    pub errors: Vec<ValidationError>,
}
|
||||
|
||||
/// A single problem found during validation, with optional pointers to the
/// offending column and a suggested fix.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ValidationError {
    pub error_type: ValidationErrorType,
    // Set when the error concerns a specific column.
    pub column_name: Option<String>,
    // Human-readable description of what went wrong.
    pub message: String,
    // Optional hint on how to resolve the error.
    pub suggestion: Option<String>,
    // Extra free-form detail, attached via `with_context`.
    pub context: Option<String>,
}
|
||||
|
||||
/// Discriminates the kinds of failures a validation run can report.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ValidationErrorType {
    TableNotFound,
    ColumnNotFound,
    TypeMismatch,
    DataSourceError,
    ModelNotFound,
    InvalidRelationship,
    ExpressionError,
    ProjectNotFound,
    InvalidBusterYml,
    DataSourceMismatch,
    RequiredFieldMissing,
    DataSourceNotFound,
    SchemaError,
    CredentialsError,
    DatabaseError,
    ValidationError,
    InternalError,
}
|
||||
|
||||
impl ValidationResult {
|
||||
pub fn new(model_name: String, data_source_name: String, schema: String) -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
model_name,
|
||||
data_source_name,
|
||||
schema,
|
||||
errors: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_error(&mut self, error: ValidationError) {
|
||||
self.success = false;
|
||||
self.errors.push(error);
|
||||
}
|
||||
}
|
||||
|
||||
impl ValidationError {
|
||||
pub fn new(
|
||||
error_type: ValidationErrorType,
|
||||
column_name: Option<String>,
|
||||
message: String,
|
||||
suggestion: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
error_type,
|
||||
column_name,
|
||||
message,
|
||||
suggestion,
|
||||
context: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_context(mut self, context: String) -> Self {
|
||||
self.context = Some(context);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn table_not_found(table_name: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::TableNotFound,
|
||||
None,
|
||||
format!("Table '{}' not found in data source", table_name),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn column_not_found(column_name: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::ColumnNotFound,
|
||||
Some(column_name.to_string()),
|
||||
format!("Column '{}' not found in data source", column_name),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn type_mismatch(column_name: &str, expected: &str, found: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::TypeMismatch,
|
||||
Some(column_name.to_string()),
|
||||
format!(
|
||||
"Column '{}' type mismatch. Expected: {}, Found: {}",
|
||||
column_name, expected, found
|
||||
),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn data_source_error(message: String) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::DataSourceError,
|
||||
None,
|
||||
message,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn model_not_found(model_name: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::ModelNotFound,
|
||||
None,
|
||||
format!("Model '{}' not found in data source", model_name),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn invalid_relationship(from: &str, to: &str, reason: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::InvalidRelationship,
|
||||
None,
|
||||
format!("Invalid relationship from '{}' to '{}': {}", from, to, reason),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn expression_error(column_name: &str, expr: &str, reason: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::ExpressionError,
|
||||
Some(column_name.to_string()),
|
||||
format!("Invalid expression '{}' for column '{}': {}", expr, column_name, reason),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
// New factory methods for enhanced error types
|
||||
pub fn schema_error(schema_name: &str, reason: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::SchemaError,
|
||||
None,
|
||||
format!("Schema '{}' error: {}", schema_name, reason),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn credentials_error(data_source: &str, reason: &str) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::CredentialsError,
|
||||
None,
|
||||
format!("Credentials error for data source '{}': {}", data_source, reason),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn database_error(message: String) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::DatabaseError,
|
||||
None,
|
||||
message,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn internal_error(message: String) -> Self {
|
||||
Self::new(
|
||||
ValidationErrorType::InternalError,
|
||||
None,
|
||||
message,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue