Merge branch 'evals' of https://github.com/buster-so/buster into evals

This commit is contained in:
Nate Kelley 2025-04-04 16:41:42 -06:00
commit 93439d45b0
No known key found for this signature in database
GPG Key ID: FD90372AB8D98B4F
24 changed files with 5 additions and 3392 deletions

View File

@ -36,14 +36,14 @@ pub async fn post_dataset(
Ok(None) => { Ok(None) => {
return Err(( return Err((
StatusCode::FORBIDDEN, StatusCode::FORBIDDEN,
"User does not belong to any organization", "User does not belong to any organization".to_string(),
)); ));
} }
Err(e) => { Err(e) => {
tracing::error!("Error getting user organization id: {:?}", e); tracing::error!("Error getting user organization id: {:?}", e);
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"Error getting user organization id", "Error getting user organization id".to_string(),
)); ));
} }
}; };
@ -53,14 +53,14 @@ pub async fn post_dataset(
Ok(false) => { Ok(false) => {
return Err(( return Err((
StatusCode::FORBIDDEN, StatusCode::FORBIDDEN,
"Insufficient permissions", "Insufficient permissions".to_string(),
)) ))
} }
Err(e) => { Err(e) => {
tracing::error!("Error checking user permissions: {:?}", e); tracing::error!("Error checking user permissions: {:?}", e);
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"Error checking user permissions", "Error checking user permissions".to_string(),
)); ));
} }
} }
@ -78,7 +78,7 @@ pub async fn post_dataset(
tracing::error!("Error creating dataset: {:?}", e); tracing::error!("Error creating dataset: {:?}", e);
return Err(( return Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"Error creating dataset", "Error creating dataset".to_string(),
)); ));
} }
}; };

View File

@ -1 +0,0 @@
/// Declares the `types` submodule.
pub mod types;

View File

@ -1,383 +0,0 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level presentation mode for a metric: rendered chart vs. raw table.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum ViewType {
    Chart,
    Table,
}
/// All chart kinds the frontend can render.
///
/// Serialized in camelCase via serde; `to_string`/`from_string` in the
/// `impl` below use a separate lowercase wire format.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum ChartType {
    Line,
    Bar,
    Scatter,
    Pie,
    Metric,
    Table,
    Combo,
}
impl ChartType {
    /// Returns the lowercase wire-format name for this chart type.
    ///
    /// NOTE: this inherent method shadows `ToString::to_string`; it is kept
    /// with this name for backward compatibility with existing callers.
    pub fn to_string(&self) -> String {
        self.as_str().to_string()
    }

    /// Static name for each variant — single source of truth shared by
    /// both conversion helpers (previously the string table was duplicated).
    fn as_str(&self) -> &'static str {
        match self {
            ChartType::Line => "line",
            ChartType::Bar => "bar",
            ChartType::Scatter => "scatter",
            ChartType::Pie => "pie",
            ChartType::Metric => "metric",
            ChartType::Table => "table",
            ChartType::Combo => "combo",
        }
    }

    /// Parses a lowercase chart-type name.
    ///
    /// Unknown names fall back to `ChartType::Table` (matches previous
    /// behavior; keeps parsing infallible).
    pub fn from_string(chart_type: &str) -> ChartType {
        match chart_type {
            "line" => ChartType::Line,
            "bar" => ChartType::Bar,
            "scatter" => ChartType::Scatter,
            "pie" => ChartType::Pie,
            "metric" => ChartType::Metric,
            "table" => ChartType::Table,
            "combo" => ChartType::Combo,
            _ => ChartType::Table,
        }
    }
}
/// Complete chart configuration payload exchanged with the frontend.
///
/// The per-chart-type option groups at the bottom are `#[serde(flatten)]`ed,
/// so the whole config serializes as a single flat JSON object; optional
/// fields are omitted when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BusterChartConfig {
    pub selected_chart_type: ChartType,
    pub selected_view: ViewType,
    // Presumably keyed by column id — confirm against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub column_label_formats: Option<HashMap<String, ColumnLabelFormat>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub column_settings: Option<HashMap<String, ColumnSettings>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub colors: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub show_legend: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grid_lines: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub show_legend_headline: Option<ShowLegendHeadline>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub goal_lines: Option<Vec<GoalLine>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trendlines: Option<Vec<Trendline>>,
    // Axis and per-chart-type option groups, flattened into this object.
    #[serde(flatten)]
    pub y_axis_config: YAxisConfig,
    #[serde(flatten)]
    pub x_axis_config: XAxisConfig,
    #[serde(flatten)]
    pub y2_axis_config: Y2AxisConfig,
    #[serde(flatten)]
    pub bar_chart_props: BarChartProps,
    #[serde(flatten)]
    pub line_chart_props: LineChartProps,
    #[serde(flatten)]
    pub scatter_chart_props: ScatterChartProps,
    #[serde(flatten)]
    pub pie_chart_props: PieChartProps,
    #[serde(flatten)]
    pub table_chart_props: TableChartProps,
    #[serde(flatten)]
    pub combo_chart_props: ComboChartProps,
    #[serde(flatten)]
    pub metric_chart_props: MetricChartProps,
}
/// Options specific to line charts.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct LineChartProps {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_group_type: Option<String>,
}

/// Options specific to bar charts.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct BarChartProps {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bar_and_line_axis: Option<BarAndLineAxis>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bar_layout: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bar_sort_by: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bar_group_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bar_show_total_at_top: Option<bool>,
}

/// Options specific to scatter charts.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ScatterChartProps {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scatter_axis: Option<ScatterAxis>,
    // Presumably a (min, max) dot-size range — TODO confirm with frontend.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scatter_dot_size: Option<(f64, f64)>,
}

/// Options specific to pie/donut charts.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct PieChartProps {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_chart_axis: Option<PieChartAxis>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_display_label_as: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_show_inner_label: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_inner_label_aggregate: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_inner_label_title: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_label_position: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_donut_width: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pie_minimum_slice_percentage: Option<f64>,
}

/// Options specific to table views.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct TableChartProps {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_column_order: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_column_widths: Option<HashMap<String, f64>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_header_background_color: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_header_font_color: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_column_font_color: Option<String>,
}

/// Options specific to combo charts.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ComboChartProps {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub combo_chart_axis: Option<ComboChartAxis>,
}

/// Options specific to single-value metric charts.
/// `metric_column_id` is the only non-optional field in any props group.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct MetricChartProps {
    pub metric_column_id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metric_value_aggregate: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metric_header: Option<MetricTitle>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metric_sub_header: Option<MetricTitle>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metric_value_label: Option<String>,
}
/// Metric header/sub-header: either a literal string or a value derived
/// from a column. Untagged, so JSON accepts either shape directly.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MetricTitle {
    String(String),
    Derived(DerivedMetricTitle),
}

/// Title derived from a column.
// NOTE(review): semantics of `use_value` (column name vs. its value?)
// are not visible here — confirm against the frontend.
#[derive(Debug, Serialize, Deserialize)]
pub struct DerivedMetricTitle {
    pub column_id: String,
    pub use_value: bool,
}

/// Legend headline toggle: untagged, so JSON `true`/`false` or a string
/// are both accepted.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ShowLegendHeadline {
    Bool(bool),
    String(String),
}

/// Per-column display settings; the per-visualization groups below are
/// flattened into a single JSON object.
#[derive(Debug, Serialize, Deserialize)]
pub struct ColumnSettings {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub show_data_labels: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub column_visualization: Option<String>,
    #[serde(flatten)]
    pub line_settings: LineColumnSettings,
    #[serde(flatten)]
    pub bar_settings: BarColumnSettings,
    #[serde(flatten)]
    pub dot_settings: DotColumnSettings,
}

/// Line-rendering options for a single column.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct LineColumnSettings {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_dash_style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_width: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_symbol_size: Option<f64>,
}

/// Bar-rendering options for a single column.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct BarColumnSettings {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bar_roundness: Option<f64>,
}

/// Dot/scatter options for a single column.
// NOTE(review): `line_symbol_size` also exists on `LineColumnSettings`;
// both groups are flattened into `ColumnSettings`, so the two fields map
// to the same JSON key — confirm this duplication is intended.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct DotColumnSettings {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_symbol_size: Option<f64>,
}
/// Formatting rules for a single column's labels/values (number style,
/// affixes, date handling, currency, etc.). Every field is optional and
/// omitted from JSON when unset.
#[derive(Debug, Serialize, Deserialize)]
pub struct ColumnLabelFormat {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub column_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub number_separator_style: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub minimum_fraction_digits: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_fraction_digits: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub multiplier: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    // Arbitrary JSON replacement value for missing data.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub replace_missing_data_with: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_relative_time: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_utc: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub make_label_human_readable: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub currency: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub date_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub convert_number_to: Option<String>,
}
/// A horizontal goal/target line overlaid on a chart.
#[derive(Debug, Serialize, Deserialize)]
pub struct GoalLine {
    pub show: bool,
    pub value: f64,
    pub show_goal_line_label: bool,
    pub goal_line_label: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub goal_line_color: Option<String>,
}

/// A computed trendline attached to one column.
#[derive(Debug, Serialize, Deserialize)]
pub struct Trendline {
    pub show: bool,
    pub show_trendline_label: bool,
    pub trendline_label: Option<String>,
    // NOTE(review): serializes as the JSON key "type_"; if the frontend
    // expects "type", this needs #[serde(rename = "type")] — confirm.
    pub type_: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trendline_color: Option<String>,
    pub column_id: String,
}
/// Primary (left) Y-axis display options.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct YAxisConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y_axis_show_axis_label: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y_axis_show_axis_title: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y_axis_axis_title: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y_axis_start_axis_at_zero: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y_axis_scale_type: Option<String>,
}

/// Secondary (right) Y-axis display options — mirrors `YAxisConfig`.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Y2AxisConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y2_axis_show_axis_label: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y2_axis_show_axis_title: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y2_axis_axis_title: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y2_axis_start_axis_at_zero: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y2_axis_scale_type: Option<String>,
}

/// X-axis display options.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct XAxisConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x_axis_show_ticks: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x_axis_show_axis_label: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x_axis_show_axis_title: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x_axis_axis_title: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x_axis_label_rotation: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub x_axis_data_zoom: Option<bool>,
}

/// Category-axis styling options.
// NOTE(review): not flattened into `BusterChartConfig` above — confirm
// whether this type is still referenced anywhere.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct CategoryAxisStyleConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category_show_total_at_top: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category_axis_title: Option<String>,
}
/// Column assignments per encoding channel for bar and line charts.
#[derive(Debug, Serialize, Deserialize)]
pub struct BarAndLineAxis {
    pub x: Vec<String>,
    pub y: Vec<String>,
    pub category: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tooltip: Option<Vec<String>>,
}

/// Column assignments per encoding channel for scatter charts.
#[derive(Debug, Serialize, Deserialize)]
pub struct ScatterAxis {
    pub x: Vec<String>,
    pub y: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tooltip: Option<Vec<String>>,
}

/// Column assignments for combo charts; `y2` targets the secondary axis.
#[derive(Debug, Serialize, Deserialize)]
pub struct ComboChartAxis {
    pub x: Vec<String>,
    pub y: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub y2: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tooltip: Option<Vec<String>>,
}

/// Column assignments for pie charts.
#[derive(Debug, Serialize, Deserialize)]
pub struct PieChartAxis {
    pub x: Vec<String>,
    pub y: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tooltip: Option<Vec<String>>,
}

View File

@ -1,238 +0,0 @@
use anyhow::{anyhow, Result};
use futures::StreamExt;
use std::{env, time::Duration};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;
use serde::{Deserialize, Serialize};
use crate::utils::clients::sentry_utils::send_sentry_error;
/// Anthropic Messages API endpoint.
const ANTHROPIC_CHAT_URL: &str = "https://api.anthropic.com/v1/messages";

lazy_static::lazy_static! {
    // Read lazily at first use; panics at that point if unset.
    static ref ANTHROPIC_API_KEY: String = env::var("ANTHROPIC_API_KEY")
        .expect("ANTHROPIC_API_KEY must be set");
    // Defaults to enabled when the variable is absent; panics on a
    // non-boolean value.
    static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED")
        .unwrap_or(String::from("true"))
        .parse()
        .expect("MONITORING_ENABLED must be a boolean");
}
/// Anthropic model identifiers, serialized as the API's model string.
#[derive(Serialize, Clone)]
pub enum AnthropicChatModel {
    #[serde(rename = "claude-3-opus-20240229")]
    Claude3Opus20240229,
}

/// Message author role, lowercase on the wire.
#[derive(Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum AnthropicChatRole {
    User,
    Assistant,
}

/// Content-block type; only plain text is modeled here.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "lowercase")]
pub enum AnthropicContentType {
    Text,
}

/// One content block inside a message.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct AnthropicContent {
    // `type` is a Rust keyword, hence the renamed field.
    #[serde(rename = "type")]
    pub _type: AnthropicContentType,
    pub text: String,
}

/// A single chat turn (role plus content blocks).
#[derive(Serialize, Clone)]
pub struct AnthropicChatMessage {
    pub role: AnthropicChatRole,
    pub content: Vec<AnthropicContent>,
}

/// Request body for the Messages API.
#[derive(Serialize, Clone)]
pub struct AnthropicChatRequest {
    pub model: AnthropicChatModel,
    // System prompt is a top-level field, not a message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    pub messages: Vec<AnthropicChatMessage>,
    pub temperature: f32,
    pub max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    pub stream: bool,
}

/// Non-streaming response body (only the fields this module reads).
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletionResponse {
    pub content: Vec<Content>,
}

/// A response content block.
#[derive(Deserialize, Debug, Clone)]
pub struct Content {
    #[serde(rename = "type")]
    pub _type: String,
    pub text: String,
}
/// Sends a non-streaming chat request to the Anthropic Messages API and
/// returns the text of the first content block.
///
/// # Errors
/// Returns an error (and reports it to Sentry) when the request cannot be
/// sent, the API responds with a non-success status, the body cannot be
/// parsed, or the response contains no content blocks.
pub async fn anthropic_chat(
    model: &AnthropicChatModel,
    system: Option<String>,
    messages: &Vec<AnthropicChatMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
) -> Result<String> {
    let chat_request = AnthropicChatRequest {
        model: model.clone(),
        system,
        messages: messages.clone(),
        temperature,
        max_tokens,
        stop_sequences: stop,
        stream: false,
    };

    let client = reqwest::Client::new();

    let headers = {
        let mut headers = reqwest::header::HeaderMap::new();
        // The Messages API authenticates with `x-api-key` plus a pinned
        // `anthropic-version` header. (Previously the key was wrapped in a
        // redundant `format!("{}", ...)`.)
        headers.insert("x-api-key", ANTHROPIC_API_KEY.parse().unwrap());
        headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
        headers
    };

    let response = match client
        .post(ANTHROPIC_CHAT_URL)
        .headers(headers)
        .json(&chat_request)
        .timeout(Duration::from_secs(timeout))
        .send()
        .await
    {
        Ok(response) => response,
        Err(e) => {
            tracing::error!("Unable to send request to Anthropic: {:?}", e);
            let err = anyhow!("Unable to send request to Anthropic: {}", e);
            send_sentry_error(&err.to_string(), None);
            return Err(err);
        }
    };

    // Surface API-level failures (auth, rate limit, bad request) directly
    // instead of failing later with an opaque JSON parse error.
    if !response.status().is_success() {
        let status = response.status();
        let err = anyhow!("Anthropic returned non-success status: {}", status);
        send_sentry_error(&err.to_string(), None);
        return Err(err);
    }

    let completion_res = match response.json::<ChatCompletionResponse>().await {
        Ok(res) => res,
        Err(e) => {
            tracing::error!("Unable to parse response from Anthropic: {:?}", e);
            let err = anyhow!("Unable to parse response from Anthropic: {}", e);
            send_sentry_error(&err.to_string(), None);
            return Err(err);
        }
    };

    // Only the first content block is used; any additional blocks are
    // ignored (matches previous behavior).
    match completion_res.content.first() {
        Some(content) => Ok(content.text.clone()),
        None => Err(anyhow!("No content returned from Anthropic")),
    }
}
/// One delta event in a streaming response.
// NOTE(review): `delta` is modeled as a plain `String`; confirm against
// the Anthropic streaming event format (the delta may be an object with
// its own `type`/`text` fields) before wiring up stream parsing.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AnthropicChatDelta {
    #[serde(rename = "type")]
    pub _type: String,
    pub delta: String,
}

/// One server-sent event envelope from the streaming endpoint; `delta`
/// is absent for non-delta event types.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct AnthropicChatStreamResponse {
    #[serde(rename = "type")]
    pub _type: String,
    pub delta: Option<AnthropicChatDelta>,
}
pub async fn anthropic_chat_stream(
model: &AnthropicChatModel,
system: Option<String>,
messages: &Vec<AnthropicChatMessage>,
temperature: f32,
max_tokens: u32,
timeout: u64,
stop: Option<Vec<String>>,
) -> Result<ReceiverStream<String>> {
let chat_request = AnthropicChatRequest {
model: model.clone(),
system,
messages: messages.clone(),
temperature,
max_tokens,
stream: true,
stop_sequences: stop,
};
let client = reqwest::Client::new();
let headers = {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", ANTHROPIC_API_KEY.to_string())
.parse()
.unwrap(),
);
headers
};
let (_tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel(100);
tokio::spawn(async move {
let response = client
.post(ANTHROPIC_CHAT_URL)
.headers(headers)
.json(&chat_request)
.timeout(Duration::from_secs(timeout))
.send()
.await
.map_err(|e| {
tracing::error!("Unable to send request to Anthropic: {:?}", e);
let err = anyhow!("Unable to send request to Anthropic: {}", e);
send_sentry_error(&err.to_string(), None);
err
});
if let Err(_e) = response {
return;
}
let response = response.unwrap();
let mut stream = response.bytes_stream();
let _buffer = String::new();
while let Some(item) = stream.next().await {
match item {
Ok(bytes) => {
let chunk = String::from_utf8(bytes.to_vec()).unwrap();
println!("----------------------");
println!("Chunk: {}", chunk);
println!("----------------------");
}
Err(e) => {
tracing::error!("Error while streaming response: {:?}", e);
let err = anyhow!("Error while streaming response: {}", e);
send_sentry_error(&err.to_string(), None);
break;
}
}
}
});
Ok(ReceiverStream::new(rx))
}

View File

@ -1,59 +0,0 @@
use std::env;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use tokio::task;
use super::{
hugging_face::hugging_face_embedding, ollama::ollama_embedding, openai::ada_bulk_embedding,
};
/// Request body for a single-prompt embedding call.
#[derive(Serialize)]
pub struct EmbeddingRequest {
    pub prompt: String,
}

/// Response body: one embedding vector per input.
#[derive(Deserialize, Debug)]
pub struct EmbeddingResponse {
    pub embedding: Vec<Vec<f32>>,
}

/// Supported embedding backends, selected via the `EMBEDDING_PROVIDER`
/// environment variable.
pub enum EmbeddingProvider {
    OpenAi,
    Ollama,
    HuggingFace,
}
impl EmbeddingProvider {
    /// Reads `EMBEDDING_PROVIDER` from the environment and maps it to a
    /// provider variant ("openai" | "ollama" | "huggingface").
    ///
    /// # Errors
    /// Returns an error when the variable is unset or names an unknown
    /// provider. (Previously an unset variable panicked via `expect`,
    /// defeating the `Result` return type.)
    pub fn get_embedding_provider() -> Result<EmbeddingProvider> {
        let embedding_provider = env::var("EMBEDDING_PROVIDER")
            .map_err(|_| anyhow!("An embedding provider is required."))?;
        match embedding_provider.as_str() {
            "openai" => Ok(EmbeddingProvider::OpenAi),
            "ollama" => Ok(EmbeddingProvider::Ollama),
            "huggingface" => Ok(EmbeddingProvider::HuggingFace),
            _ => Err(anyhow!("Invalid embedding provider")),
        }
    }
}
/// Embeds each prompt with the provider named by `EMBEDDING_PROVIDER`.
///
/// The Ollama branch fans out one spawned task per prompt (presumably
/// because that backend takes a single prompt per call — confirm);
/// `join_all` preserves input order. OpenAI and Hugging Face receive the
/// whole batch in one call. Note `for_retrieval` is only forwarded to the
/// Ollama backend.
///
/// # Errors
/// Fails when the provider is unknown/unset, a spawned task panics, or
/// the underlying provider call fails.
pub async fn embedding_router(prompts: Vec<String>, for_retrieval: bool) -> Result<Vec<Vec<f32>>> {
    let embedding_provider = EmbeddingProvider::get_embedding_provider()?;
    match embedding_provider {
        EmbeddingProvider::Ollama => {
            let tasks = prompts.into_iter().map(|prompt| {
                task::spawn(async move { ollama_embedding(prompt, for_retrieval).await })
            });
            let results = futures::future::join_all(tasks).await;
            // First collect/? surfaces task join errors (panics); the second
            // collect flattens the per-prompt embedding Results.
            let embeddings: Result<Vec<Vec<f32>>> = results
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?
                .into_iter()
                .collect();
            embeddings
        }
        EmbeddingProvider::HuggingFace => hugging_face_embedding(prompts).await,
        EmbeddingProvider::OpenAi => ada_bulk_embedding(prompts).await,
    }
}

View File

@ -1,57 +0,0 @@
use anyhow::{anyhow, Result};
use serde::Serialize;
use std::env;
use axum::http::HeaderMap;
/// Request body for the Hugging Face inference endpoint: a batch of
/// strings to embed.
#[derive(Serialize)]
pub struct HuggingFaceEmbeddingRequest {
    pub inputs: Vec<String>,
}
pub async fn hugging_face_embedding(prompts: Vec<String>) -> Result<Vec<Vec<f32>>> {
let hugging_face_url = env::var("HUGGING_FACE_URL").expect("HUGGING_FACE_URL must be set");
let hugging_face_api_key =
env::var("HUGGING_FACE_API_KEY").expect("HUGGING_FACE_API_KEY must be set");
let client = match reqwest::Client::builder().build() {
Ok(client) => client,
Err(e) => {
return Err(anyhow!("Error creating reqwest client: {:?}", e));
}
};
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", hugging_face_api_key).parse().unwrap(),
);
let req = HuggingFaceEmbeddingRequest { inputs: prompts };
let res = match client
.post(hugging_face_url)
.headers(headers)
.json(&req)
.send()
.await
{
Ok(res) => res,
Err(e) => {
return Err(anyhow!("Error sending Ollama request: {:?}", e));
}
};
let embeddings = match res.json::<Vec<Vec<f32>>>().await {
Ok(res) => res,
Err(e) => {
return Err(anyhow!("Error parsing Ollama response: {:?}", e));
}
};
Ok(embeddings)
}

View File

@ -1,394 +0,0 @@
use anyhow::{anyhow, Result};
use axum::http::HeaderMap;
use base64::Engine;
use reqwest::Method;
use std::env;
use tiktoken_rs::o200k_base;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::utils::clients::sentry_utils::send_sentry_error;
use super::{anthropic::AnthropicChatModel, llm_router::LlmModel, openai::OpenAiChatModel};
lazy_static::lazy_static! {
    // Base URL defaults to Langfuse cloud (US region) when unset.
    static ref LANGFUSE_API_URL: String = env::var("LANGFUSE_API_URL").unwrap_or("https://us.cloud.langfuse.com".to_string());
    // Key pair used for HTTP Basic auth; panics at first use if unset.
    static ref LANGFUSE_API_PUBLIC_KEY: String = env::var("LANGFUSE_PUBLIC_API_KEY").expect("LANGFUSE_PUBLIC_API_KEY must be set");
    static ref LANGFUSE_API_PRIVATE_KEY: String = env::var("LANGFUSE_PRIVATE_API_KEY").expect("LANGFUSE_PRIVATE_API_KEY must be set");
}
impl LlmModel {
    /// Estimates token usage and cost (USD) for one request/response pair.
    ///
    /// Tokens are counted locally with the `o200k_base` tokenizer, so the
    /// numbers approximate what the provider actually billed. Rates are
    /// USD per million tokens.
    ///
    /// The per-model rate table replaces four near-identical `Usage`
    /// constructions that previously repeated every cost formula.
    pub fn generate_usage(&self, input: &String, output: &String) -> Usage {
        let bpe = o200k_base().unwrap();
        let input_tokens = bpe.encode_with_special_tokens(input).len();
        let output_tokens = bpe.encode_with_special_tokens(output).len();

        // (input_rate, output_rate) in USD per 1M tokens.
        // NOTE(review): verify these against current provider pricing.
        let (input_rate, output_rate) = match self {
            LlmModel::OpenAi(OpenAiChatModel::O3Mini) => (2.5, 10.0),
            LlmModel::OpenAi(OpenAiChatModel::Gpt35Turbo) => (0.5, 1.5),
            LlmModel::OpenAi(OpenAiChatModel::Gpt4o) => (0.15, 0.6),
            LlmModel::Anthropic(AnthropicChatModel::Claude3Opus20240229) => (3.0, 15.0),
        };

        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_rate;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_rate;

        Usage {
            input: input_tokens as u32,
            output: output_tokens as u32,
            unit: "TOKENS".to_string(),
            input_cost,
            output_cost,
            total_cost: input_cost + output_cost,
        }
    }
}
/// Names of the prompts traced to Langfuse; `CustomPrompt` carries an
/// arbitrary caller-supplied name.
#[derive(Serialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum PromptName {
    SelectDataset,
    GenerateSql,
    SelectTerm,
    DataSummary,
    AutoChartConfig,
    LineChartConfig,
    BarChartConfig,
    ScatterChartConfig,
    PieChartConfig,
    MetricChartConfig,
    TableConfig,
    ColumnLabelFormat,
    AdvancedVisualizationConfig,
    NoDataReturnedResponse,
    DataExplanation,
    MetricTitle,
    TimeFrame,
    FixSqlPlanner,
    FixSql,
    GenerateColDescriptions,
    GenerateDatasetDescription,
    SummaryQuestion,
    CustomPrompt(String),
}
impl PromptName {
    /// Returns the snake_case identifier reported to Langfuse; custom
    /// prompts pass their own string through unchanged.
    fn to_string(&self) -> String {
        let name: &str = match self {
            PromptName::SelectDataset => "select_dataset",
            PromptName::GenerateSql => "generate_sql",
            PromptName::SelectTerm => "select_term",
            PromptName::DataSummary => "data_summary",
            PromptName::AutoChartConfig => "auto_chart_config",
            PromptName::LineChartConfig => "line_chart_config",
            PromptName::BarChartConfig => "bar_chart_config",
            PromptName::ScatterChartConfig => "scatter_chart_config",
            PromptName::PieChartConfig => "pie_chart_config",
            PromptName::MetricChartConfig => "metric_chart_config",
            PromptName::TableConfig => "table_config",
            PromptName::ColumnLabelFormat => "column_label_format",
            PromptName::AdvancedVisualizationConfig => "advanced_visualization_config",
            PromptName::NoDataReturnedResponse => "no_data_returned_response",
            PromptName::DataExplanation => "data_explanation",
            PromptName::MetricTitle => "metric_title",
            PromptName::TimeFrame => "time_frame",
            PromptName::FixSqlPlanner => "fix_sql_planner",
            PromptName::FixSql => "fix_sql",
            PromptName::GenerateColDescriptions => "generate_col_descriptions",
            PromptName::GenerateDatasetDescription => "generate_dataset_description",
            PromptName::SummaryQuestion => "summary_question",
            // Custom names are already owned strings; return directly.
            PromptName::CustomPrompt(prompt) => return prompt.clone(),
        };
        name.to_string()
    }
}
/// Event kind for one Langfuse ingestion batch item (kebab-case wire form).
#[derive(Serialize, Debug)]
#[serde(rename_all = "kebab-case")]
enum LangfuseIngestionType {
    TraceCreate,
    GenerationCreate,
}

/// Payload of a `trace-create` ingestion event.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct CreateTraceBody {
    id: Uuid,
    timestamp: DateTime<Utc>,
    name: String,
    user_id: Uuid,
    input: String,
    output: String,
    session_id: Uuid,
    release: String,
    version: String,
    metadata: Metadata,
    tags: Vec<String>,
    public: bool,
}

/// Payload of a `generation-create` ingestion event, linked to its trace
/// via `trace_id`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerationCreateBody {
    trace_id: Uuid,
    name: String,
    start_time: DateTime<Utc>,
    completion_start_time: DateTime<Utc>,
    input: String,
    output: String,
    level: String,
    end_time: DateTime<Utc>,
    model: LlmModel,
    id: Uuid,
    usage: Usage,
}

/// Token counts and estimated costs attached to a generation.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Usage {
    input: u32,
    output: u32,
    unit: String,
    input_cost: f64,
    output_cost: f64,
    total_cost: f64,
}

/// Placeholder — no trace metadata is sent yet.
#[derive(Serialize, Debug)]
struct Metadata {}

/// Untagged union of the two event payloads, so each serializes as its
/// inner body directly.
#[derive(Serialize)]
#[serde(untagged)]
enum LangfuseRequestBody {
    CreateTraceBody(CreateTraceBody),
    GenerationCreateBody(GenerationCreateBody),
}

/// One item of an ingestion batch: event id, kind, payload, timestamp.
#[derive(Serialize)]
struct LangfuseBatchItem {
    id: Uuid,
    // `type` is a Rust keyword; raw identifier keeps the JSON key "type".
    r#type: LangfuseIngestionType,
    body: LangfuseRequestBody,
    timestamp: DateTime<Utc>,
}

/// Top-level body POSTed to the ingestion endpoint.
#[derive(Serialize)]
struct LangfuseBatch {
    batch: Vec<LangfuseBatchItem>,
}

/// Per-item error entry in the ingestion response.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct LangfuseError {
    id: String,
    status: u16,
    message: String,
    error: String,
}

/// Per-item success entry in the ingestion response.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct LangfuseSuccess {
    id: String,
    status: u16,
}

/// Ingestion response: Langfuse reports partial success per item.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
struct LangfuseResponse {
    successes: Vec<LangfuseSuccess>,
    errors: Vec<LangfuseError>,
}
/// Fire-and-forget ingestion of a trace + generation pair into Langfuse.
///
/// Args:
/// - session_id: this can be a thread_id or any other type of chain event we have
///
/// Spawns a background task so the caller never waits on (or sees errors
/// from) the monitoring pipeline; failures are reported to Sentry only.
/// (Cleanup: dropped the explicit `-> ()` return and the `match` on a
/// unit result; `Uuid` is `Copy`, so no `clone()` needed.)
pub async fn send_langfuse_request(
    session_id: &Uuid,
    prompt_name: PromptName,
    context: Option<String>,
    start_time: DateTime<Utc>,
    end_time: DateTime<Utc>,
    input: String,
    output: String,
    user_id: &Uuid,
    langfuse_model: &LlmModel,
) {
    // Copy/clone the borrowed args so the spawned task owns them ('static).
    let session_id = *session_id;
    let user_id = *user_id;
    let langfuse_model = langfuse_model.clone();
    tokio::spawn(async move {
        if let Err(e) = langfuse_handler(
            session_id,
            prompt_name,
            context,
            start_time,
            end_time,
            input,
            output,
            user_id,
            langfuse_model,
        )
        .await
        {
            let err = anyhow!("Error sending Langfuse request: {:?}", e);
            send_sentry_error(&err.to_string(), Some(&user_id));
        }
    });
}
/// Builds and POSTs one ingestion batch (a trace plus its generation) to
/// the Langfuse `/api/public/ingestion` endpoint.
///
/// Every failure is both returned as an error and forwarded to Sentry,
/// tagged with `user_id`. A 2xx HTTP response can still carry per-item
/// errors, which are checked at the end.
async fn langfuse_handler(
    session_id: Uuid,
    prompt_name: PromptName,
    context: Option<String>,
    start_time: DateTime<Utc>,
    end_time: DateTime<Utc>,
    input: String,
    output: String,
    user_id: Uuid,
    langfuse_model: LlmModel,
) -> Result<()> {
    // Prepend the optional context so it shows up with the prompt input.
    let input = match context {
        Some(context) => format!("{} \n\n {}", context, input),
        None => input,
    };
    // Shared id linking the generation event to its trace.
    let trace_id = Uuid::new_v4();
    let langfuse_trace = LangfuseBatchItem {
        id: Uuid::new_v4(),
        r#type: LangfuseIngestionType::TraceCreate,
        body: LangfuseRequestBody::CreateTraceBody(CreateTraceBody {
            id: trace_id.clone(),
            timestamp: Utc::now(),
            name: prompt_name.clone().to_string(),
            user_id: user_id,
            // Values are JSON-encoded strings (Langfuse accepts free-form
            // input/output payloads).
            input: serde_json::to_string(&input).unwrap(),
            output: serde_json::to_string(&output).unwrap(),
            session_id: session_id,
            release: "1.0.0".to_string(),
            version: "1.0.0".to_string(),
            metadata: Metadata {},
            tags: vec![],
            public: false,
        }),
        timestamp: Utc::now(),
    };
    let langfuse_generation = LangfuseBatchItem {
        id: Uuid::new_v4(),
        r#type: LangfuseIngestionType::GenerationCreate,
        body: LangfuseRequestBody::GenerationCreateBody(GenerationCreateBody {
            id: Uuid::new_v4(),
            name: prompt_name.to_string(),
            input: serde_json::to_string(&input).unwrap(),
            output: serde_json::to_string(&output).unwrap(),
            trace_id,
            start_time,
            // No true first-token time is recorded, so reuse start_time.
            completion_start_time: start_time,
            level: "DEBUG".to_string(),
            end_time,
            model: langfuse_model.clone(),
            // Token counts/costs estimated locally from input and output.
            usage: langfuse_model.generate_usage(&input, &output),
        }),
        timestamp: Utc::now(),
    };
    // Both events ship in a single batch request.
    let langfuse_batch = LangfuseBatch {
        batch: vec![langfuse_trace, langfuse_generation],
    };
    let client = match reqwest::Client::builder().build() {
        Ok(client) => client,
        Err(e) => {
            send_sentry_error(&e.to_string(), Some(&user_id));
            return Err(anyhow!("Error creating reqwest client: {:?}", e));
        }
    };
    let mut headers = HeaderMap::new();
    headers.insert(
        reqwest::header::CONTENT_TYPE,
        "application/json".parse().unwrap(),
    );
    // HTTP Basic auth: base64("public_key:private_key").
    headers.insert(
        reqwest::header::AUTHORIZATION,
        format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode(format!(
                "{}:{}",
                *LANGFUSE_API_PUBLIC_KEY, *LANGFUSE_API_PRIVATE_KEY
            ))
        )
        .parse()
        .unwrap(),
    );
    let res = match client
        .request(
            Method::POST,
            LANGFUSE_API_URL.to_string() + "/api/public/ingestion",
        )
        .headers(headers)
        .json(&langfuse_batch)
        .send()
        .await
    {
        Ok(res) => res,
        Err(e) => {
            tracing::debug!("Error sending Langfuse request: {:?}", e);
            send_sentry_error(&e.to_string(), Some(&user_id));
            return Err(anyhow!("Error sending Langfuse request: {:?}", e));
        }
    };
    let langfuse_res: LangfuseResponse = match res.json::<LangfuseResponse>().await {
        Ok(res) => res,
        Err(e) => {
            tracing::debug!("Error parsing Langfuse response: {:?}", e);
            send_sentry_error(&e.to_string(), Some(&user_id));
            return Err(anyhow!("Error parsing Langfuse response: {:?}", e));
        }
    };
    // Langfuse reports partial failures per batch item even on HTTP 2xx.
    if langfuse_res.errors.len() > 0 {
        for err in &langfuse_res.errors {
            tracing::debug!("Langfuse error: {:?}", err);
            send_sentry_error(&err.message, Some(&user_id));
        }
        return Err(anyhow::anyhow!(
            "Langfuse errors: {:?}",
            langfuse_res.errors
        ));
    }
    Ok(())
}

View File

@ -1,383 +0,0 @@
use std::env;
use anyhow::{anyhow, Result};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::{
sync::mpsc::{self, Receiver},
task::JoinHandle,
};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use uuid::Uuid;
use super::{
anthropic::{
anthropic_chat, anthropic_chat_stream, AnthropicChatMessage, AnthropicChatModel,
AnthropicChatRole, AnthropicContent, AnthropicContentType,
},
langfuse::{send_langfuse_request, PromptName},
openai::{
openai_chat, openai_chat_stream, OpenAiChatContent, OpenAiChatMessage, OpenAiChatModel,
OpenAiChatRole,
},
};
use lazy_static::lazy_static;
lazy_static! {
    // Global monitoring flag read once from the MONITORING_ENABLED env var
    // (defaults to "true"); panics at first use if the value is not a boolean.
    // NOTE(review): not referenced anywhere in this file's visible code —
    // confirm it is still used elsewhere before relying on it.
    static ref MONITORING_ENABLED: bool = env::var("MONITORING_ENABLED")
        .unwrap_or(String::from("true"))
        .parse()
        .expect("MONITORING_ENABLED must be a boolean");
}
/// Provider-tagged model selector. `untagged` so it serializes as the inner
/// model's own representation (no enum wrapper in the JSON).
#[derive(Serialize, Clone)]
#[serde(untagged)]
pub enum LlmModel {
    Anthropic(AnthropicChatModel),
    OpenAi(OpenAiChatModel),
}
/// Provider-agnostic chat role; mapped onto each provider's own role type by
/// the compiler functions below.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LlmRole {
    System,
    User,
    Assistant,
}
/// Provider-agnostic chat message: a role plus plain-text content.
#[derive(Serialize, Deserialize, Clone)]
pub struct LlmMessage {
    pub role: LlmRole,
    pub content: String,
}
impl LlmMessage {
    /// Build an `LlmMessage` from a free-form role string.
    ///
    /// Matching is case-insensitive; any unrecognized role (including "user")
    /// maps to `LlmRole::User`, so this constructor never fails.
    pub fn new(role: String, content: String) -> Self {
        let lowered = role.to_lowercase();
        let role = if lowered == "system" {
            LlmRole::System
        } else if lowered == "assistant" {
            LlmRole::Assistant
        } else {
            // "user" and anything unknown both fall back to User.
            LlmRole::User
        };
        Self { role, content }
    }
}
/// Run one non-streaming chat completion against the selected provider and
/// record the exchange to Langfuse.
///
/// * `max_tokens` — NOTE(review): only forwarded on the Anthropic path (whose
///   compiler in turn ignores it); the OpenAI path passes a hard-coded 7048
///   instead. Confirm whether that is intentional.
/// * `json_mode` / `json_schema` — only forwarded to OpenAI.
///
/// # Errors
/// Returns an error if the underlying provider call fails.
pub async fn llm_chat(
    model: LlmModel,
    messages: &Vec<LlmMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
    json_mode: bool,
    json_schema: Option<Value>,
    session_id: &Uuid,
    user_id: &Uuid,
    prompt_name: PromptName,
) -> Result<String> {
    let start_time = Utc::now();
    let response_result = match &model {
        LlmModel::Anthropic(model) => {
            anthropic_chat_compiler(model, messages, max_tokens, temperature, timeout, stop).await
        }
        LlmModel::OpenAi(model) => {
            openai_chat_compiler(
                model,
                messages,
                // NOTE(review): hard-coded token cap; the caller's `max_tokens`
                // is ignored here.
                7048,
                temperature,
                timeout,
                stop,
                json_mode,
                json_schema,
            )
            .await
        }
    };
    let response = match response_result {
        Ok(response) => response,
        Err(e) => return Err(anyhow!("LLM chat error: {}", e)),
    };
    let end_time = Utc::now();
    // Telemetry. `response` is already a String, so `to_string` here
    // JSON-quotes it a second time before it reaches Langfuse.
    send_langfuse_request(
        session_id,
        prompt_name,
        None,
        start_time,
        end_time,
        serde_json::to_string(&messages).unwrap(),
        serde_json::to_string(&response).unwrap(),
        user_id,
        &model,
    )
    .await;
    Ok(response)
}
/// Open a streaming chat completion against the selected provider.
///
/// Returns a channel receiver yielding content chunks as they arrive, plus a
/// join handle resolving to the fully accumulated response once the stream
/// ends. The Langfuse trace is written from the background task only after
/// the stream completes.
///
/// # Errors
/// Errors immediately if the provider stream cannot be opened; the join
/// handle errors if a chunk cannot be forwarded (receiver dropped).
pub async fn llm_chat_stream(
    model: LlmModel,
    messages: Vec<LlmMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
    session_id: &Uuid,
    user_id: &Uuid,
    prompt_name: PromptName,
) -> Result<(Receiver<String>, JoinHandle<Result<String>>)> {
    let start_time = Utc::now();
    let stream_result = match &model {
        LlmModel::Anthropic(model) => {
            anthropic_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop)
                .await
        }
        LlmModel::OpenAi(model) => {
            openai_chat_stream_compiler(model, &messages, max_tokens, temperature, timeout, stop)
                .await
        }
    };
    let mut stream = match stream_result {
        Ok(stream) => stream,
        Err(e) => return Err(anyhow!("LLM chat error: {}", e)),
    };
    // Bounded forwarding channel; the provider stream is drained on a
    // background task so the caller can consume chunks immediately.
    let (tx, rx) = mpsc::channel(100);
    let res_future = {
        let session_id = session_id.clone();
        let user_id = user_id.clone();
        tokio::spawn(async move {
            let mut response = String::new();
            while let Some(content) = stream.next().await {
                // Accumulate the full text for telemetry while forwarding
                // each chunk to the caller.
                response.push_str(&content);
                match tx.send(content).await {
                    Ok(_) => (),
                    Err(e) => return Err(anyhow!("Streaming Error: {}", e)),
                }
            }
            let end_time = Utc::now();
            // Telemetry. `response` is already a String, so `to_string` here
            // JSON-quotes it a second time before it reaches Langfuse.
            send_langfuse_request(
                &session_id,
                prompt_name,
                None,
                start_time,
                end_time,
                serde_json::to_string(&messages).unwrap(),
                serde_json::to_string(&response).unwrap(),
                &user_id,
                &model,
            )
            .await;
            Ok(response)
        })
    };
    Ok((rx, res_future))
}
async fn anthropic_chat_compiler(
model: &AnthropicChatModel,
messages: &Vec<LlmMessage>,
_max_tokens: u32,
temperature: f32,
timeout: u64,
stop: Option<Vec<String>>,
) -> Result<String> {
let system_message = match messages.iter().find(|m| m.role == LlmRole::System) {
Some(message) => Some(message.content.clone()),
None => None,
};
let mut anthropic_messages = Vec::new();
for message in messages {
let anthropic_role = match message.role {
LlmRole::System => continue,
LlmRole::User => AnthropicChatRole::User,
LlmRole::Assistant => AnthropicChatRole::Assistant,
};
let anthropic_content = AnthropicContent {
text: message.content.clone(),
_type: AnthropicContentType::Text,
};
anthropic_messages.push(AnthropicChatMessage {
role: anthropic_role,
content: vec![anthropic_content],
});
}
let response = match anthropic_chat(
model,
system_message,
&anthropic_messages,
temperature,
7048,
timeout,
stop,
)
.await
{
Ok(response) => response,
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
};
Ok(response)
}
async fn openai_chat_compiler(
model: &OpenAiChatModel,
messages: &Vec<LlmMessage>,
max_tokens: u32,
temperature: f32,
timeout: u64,
stop: Option<Vec<String>>,
json_mode: bool,
json_schema: Option<Value>,
) -> Result<String> {
let mut openai_messages = Vec::new();
for message in messages {
let openai_role = match message.role {
LlmRole::System => OpenAiChatRole::System,
LlmRole::User => OpenAiChatRole::User,
LlmRole::Assistant => OpenAiChatRole::Assistant,
};
let openai_message = OpenAiChatMessage {
role: openai_role,
content: vec![OpenAiChatContent {
type_: "text".to_string(),
text: message.content.clone(),
}],
};
openai_messages.push(openai_message);
}
let response = match openai_chat(
&model,
openai_messages,
temperature,
max_tokens,
timeout,
stop,
json_mode,
json_schema,
)
.await
{
Ok(response) => response,
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
};
Ok(response)
}
async fn anthropic_chat_stream_compiler(
model: &AnthropicChatModel,
messages: &Vec<LlmMessage>,
max_tokens: u32,
temperature: f32,
timeout: u64,
stop: Option<Vec<String>>,
) -> Result<ReceiverStream<String>> {
let system_message = match messages.iter().find(|m| m.role == LlmRole::System) {
Some(message) => Some(message.content.clone()),
None => None,
};
let mut anthropic_messages = Vec::new();
for message in messages {
let anthropic_role = match message.role {
LlmRole::System => continue,
LlmRole::User => AnthropicChatRole::User,
LlmRole::Assistant => AnthropicChatRole::Assistant,
};
let anthropic_content = AnthropicContent {
text: message.content.clone(),
_type: AnthropicContentType::Text,
};
anthropic_messages.push(AnthropicChatMessage {
role: anthropic_role,
content: vec![anthropic_content],
});
}
let stream = match anthropic_chat_stream(
model,
system_message,
&anthropic_messages,
temperature,
max_tokens,
timeout,
stop,
)
.await
{
Ok(response) => response,
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
};
Ok(stream)
}
async fn openai_chat_stream_compiler(
model: &OpenAiChatModel,
messages: &Vec<LlmMessage>,
max_tokens: u32,
temperature: f32,
timeout: u64,
stop: Option<Vec<String>>,
) -> Result<ReceiverStream<String>> {
let mut openai_messages = Vec::new();
for message in messages {
let openai_role = match message.role {
LlmRole::System => OpenAiChatRole::System,
LlmRole::User => OpenAiChatRole::User,
LlmRole::Assistant => OpenAiChatRole::Assistant,
};
let openai_message = OpenAiChatMessage {
role: openai_role,
content: vec![OpenAiChatContent {
type_: "text".to_string(),
text: message.content.clone(),
}],
};
openai_messages.push(openai_message);
}
let stream = match openai_chat_stream(
&model,
&openai_messages,
temperature,
max_tokens,
timeout,
stop,
)
.await
{
Ok(response) => response,
Err(e) => return Err(anyhow!("Anthropic chat error: {}", e)),
};
Ok(stream)
}

View File

@ -1,7 +0,0 @@
mod anthropic;
pub mod embedding_router;
mod hugging_face;
pub mod langfuse;
pub mod llm_router;
pub mod ollama;
pub mod openai;

View File

@ -1,67 +0,0 @@
use std::env;
use anyhow::{anyhow, Result};
use axum::http::HeaderMap;
use serde::{Deserialize, Serialize};
/// Request body for Ollama's `/api/embeddings` endpoint.
#[derive(Serialize)]
pub struct OllamaEmbeddingRequest {
    pub prompt: String,
    pub model: String,
}
/// Response body from Ollama's `/api/embeddings` endpoint.
#[derive(Deserialize, Debug)]
pub struct OllamaEmbeddingResponse {
    pub embedding: Vec<f32>,
}
/// Embed `prompt` with a locally hosted Ollama model.
///
/// * `for_retrieval` — when true and the model is `mxbai-embed-large`, the
///   prompt is wrapped in that model's documented query prefix so queries and
///   passages land in the same embedding space.
///
/// Endpoint and model come from the `OLLAMA_URL` / `EMBEDDING_MODEL` env vars,
/// defaulting to `http://localhost:11434` and `mxbai-embed-large`.
///
/// # Errors
/// Fails if the HTTP client cannot be built, the request cannot be sent, or
/// the response body does not parse.
pub async fn ollama_embedding(prompt: String, for_retrieval: bool) -> Result<Vec<f32>> {
    let ollama_url = env::var("OLLAMA_URL").unwrap_or(String::from("http://localhost:11434"));
    let model = env::var("EMBEDDING_MODEL").unwrap_or(String::from("mxbai-embed-large"));
    let prompt = if model == "mxbai-embed-large" && for_retrieval {
        // Bug fix: the prefix previously ended in a doubled colon ("::");
        // mxbai-embed-large's documented retrieval prefix uses a single colon.
        format!(
            "Represent this sentence for searching relevant passages: {}",
            prompt
        )
    } else {
        prompt
    };
    let client = match reqwest::Client::builder().build() {
        Ok(client) => client,
        Err(e) => {
            return Err(anyhow!("Error creating reqwest client: {:?}", e));
        }
    };
    let mut headers = HeaderMap::new();
    headers.insert(
        reqwest::header::CONTENT_TYPE,
        "application/json".parse().unwrap(),
    );
    let req = OllamaEmbeddingRequest { prompt, model };
    let res = match client
        .post(format!("{}/api/embeddings", ollama_url))
        .headers(headers)
        .json(&req)
        .send()
        .await
    {
        Ok(res) => res,
        Err(e) => {
            return Err(anyhow!("Error sending Ollama request: {:?}", e));
        }
    };
    let ollama_res: OllamaEmbeddingResponse = match res.json::<OllamaEmbeddingResponse>().await {
        Ok(res) => res,
        Err(e) => {
            return Err(anyhow!("Error parsing Ollama response: {:?}", e));
        }
    };
    Ok(ollama_res.embedding)
}

View File

@ -1,417 +0,0 @@
use anyhow::{anyhow, Result};
use futures::StreamExt;
use serde_json::{json, Value};
use std::{env, time::Duration};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;
use serde::{Deserialize, Serialize};
use crate::utils::clients::sentry_utils::send_sentry_error;
/// OpenAI embeddings endpoint used by `ada_bulk_embedding`.
const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings";
lazy_static::lazy_static! {
    // Required; first access panics if the env var is missing.
    static ref OPENAI_API_KEY: String = env::var("OPENAI_API_KEY")
        .expect("OPENAI_API_KEY must be set");
    // Chat endpoint; overridable via env (e.g. to target a compatible proxy).
    static ref OPENAI_CHAT_URL: String = env::var("OPENAI_CHAT_URL").unwrap_or("https://api.openai.com/v1/chat/completions".to_string());
}
/// Supported OpenAI chat models; serialized as the exact API model ids.
#[derive(Serialize, Clone)]
pub enum OpenAiChatModel {
    #[serde(rename = "gpt-4o-2024-11-20")]
    Gpt4o,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "gpt-3.5-turbo")]
    Gpt35Turbo,
}
/// Chat roles as the OpenAI API spells them (lowercase on the wire).
#[derive(Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum OpenAiChatRole {
    System,
    Developer,
    User,
    Assistant,
}
/// Reasoning-effort hint sent for o3-family models (lowercase on the wire).
#[derive(Serialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
/// One content part of a chat message: a "type" discriminator plus the text.
#[derive(Serialize, Clone)]
pub struct OpenAiChatContent {
    #[serde(rename = "type")]
    pub type_: String,
    pub text: String,
}
/// A chat message: role plus one or more content parts.
#[derive(Serialize, Clone)]
pub struct OpenAiChatMessage {
    pub role: OpenAiChatRole,
    pub content: Vec<OpenAiChatContent>,
}
// Helper functions for conditional serialization.
/// True for o3-family models, for which `OpenAiChatRequest::new` omits the
/// sampling fields and sets `reasoning_effort` instead.
fn is_o3_model(model: &OpenAiChatModel) -> bool {
    match model {
        OpenAiChatModel::O3Mini => true,
        _ => false,
    }
}
/// Request body for the chat-completions endpoint. `Option` fields are left
/// out of the JSON entirely when `None`; see `OpenAiChatRequest::new` for the
/// o3-vs-standard-model split.
#[derive(Serialize, Clone)]
pub struct OpenAiChatRequest {
    model: OpenAiChatModel,
    messages: Vec<OpenAiChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    // Always serialized; fixed at 0.0 by the constructor.
    frequency_penalty: f32,
    presence_penalty: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    stream: bool,
    // JSON-mode / JSON-schema constraint, when requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<Value>,
}
impl OpenAiChatRequest {
    /// Assemble a chat-completions request body.
    ///
    /// For o3-family models the sampling knobs (`temperature`, `max_tokens`,
    /// `top_p`) are omitted and `reasoning_effort` is fixed to `Low`; every
    /// other model gets the caller's sampling values plus `top_p = 1.0`.
    pub fn new(
        model: OpenAiChatModel,
        messages: Vec<OpenAiChatMessage>,
        temperature: f32,
        max_tokens: u32,
        stop: Option<Vec<String>>,
        stream: bool,
        response_format: Option<Value>,
    ) -> Self {
        let o3 = is_o3_model(&model);
        Self {
            model,
            messages,
            temperature: if o3 { None } else { Some(temperature) },
            max_tokens: if o3 { None } else { Some(max_tokens) },
            top_p: if o3 { None } else { Some(1.0) },
            reasoning_effort: if o3 { Some(ReasoningEffort::Low) } else { None },
            frequency_penalty: 0.0,
            presence_penalty: 0.0,
            stop,
            stream,
            response_format,
        }
    }
}
/// Minimal view of a chat-completions response: only `choices` is decoded.
#[derive(Deserialize, Debug, Clone)]
pub struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
}
/// One completion choice; only the message is decoded.
#[derive(Deserialize, Debug, Clone)]
pub struct Choice {
    pub message: Message,
}
/// The assistant message content of a choice.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    pub content: String,
}
pub async fn openai_chat(
model: &OpenAiChatModel,
messages: Vec<OpenAiChatMessage>,
temperature: f32,
max_tokens: u32,
timeout: u64,
stop: Option<Vec<String>>,
json_mode: bool,
json_schema: Option<Value>,
) -> Result<String> {
let response_format = match json_mode {
true => Some(json!({"type": "json_object"})),
false => None,
};
let response_format = match json_schema {
Some(schema) => Some(json!({"type": "json_schema", "json_schema": schema})),
None => response_format,
};
let chat_request = OpenAiChatRequest::new(
model.clone(),
messages,
temperature,
max_tokens,
stop,
false,
response_format,
);
let client = reqwest::Client::new();
let headers = {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", OPENAI_API_KEY.to_string())
.parse()
.unwrap(),
);
headers
};
let response = match client
.post(OPENAI_CHAT_URL.to_string())
.headers(headers)
.json(&chat_request)
.timeout(Duration::from_secs(timeout))
.send()
.await
{
Ok(response) => response,
Err(e) => {
tracing::error!("Unable to send request to OpenAI: {:?}", e);
let err = anyhow!("Unable to send request to OpenAI: {}", e);
send_sentry_error(&err.to_string(), None);
return Err(err);
}
};
let response_text = response.text().await.unwrap();
let completion_res = match serde_json::from_str::<ChatCompletionResponse>(&response_text) {
Ok(res) => res,
Err(e) => {
tracing::error!("Unable to parse response from OpenAI: {:?}", e);
let err = anyhow!("Unable to parse response from OpenAI: {}", e);
send_sentry_error(&err.to_string(), None);
return Err(err);
}
};
let content = match completion_res.choices.get(0) {
Some(choice) => choice.message.content.clone(),
None => return Err(anyhow!("No content returned from OpenAI")),
};
Ok(content)
}
/// Incremental content delta within a streamed chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAiChatDelta {
    pub content: Option<String>,
}
/// One choice within a streamed chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAiChatChoice {
    pub delta: OpenAiChatDelta,
    pub index: u32,
    pub finish_reason: Option<String>,
}
/// A single server-sent chunk of a streaming chat completion.
/// NOTE(review): `system_fingerprint` is non-optional here, so chunks that
/// omit it will fail to deserialize — confirm against the responses in use.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAiChatStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<OpenAiChatChoice>,
}
/// Open a streaming chat completion and return a channel of content deltas.
///
/// Langfuse traces are NOT written here — the caller that drains the stream
/// is responsible for tracing, since the full response is only known once the
/// stream ends.
///
/// The HTTP call and SSE parsing run on a spawned task; the returned
/// `ReceiverStream` yields each `delta.content` chunk in order and ends when
/// the task finishes or errors (errors are logged and reported to Sentry, not
/// surfaced through the stream).
pub async fn openai_chat_stream(
    model: &OpenAiChatModel,
    messages: &Vec<OpenAiChatMessage>,
    temperature: f32,
    max_tokens: u32,
    timeout: u64,
    stop: Option<Vec<String>>,
) -> Result<ReceiverStream<String>> {
    let chat_request = OpenAiChatRequest::new(
        model.clone(),
        messages.clone(),
        temperature,
        max_tokens,
        stop,
        true,
        None,
    );
    let client = reqwest::Client::new();
    let headers = {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::AUTHORIZATION,
            format!("Bearer {}", OPENAI_API_KEY.to_string())
                .parse()
                .unwrap(),
        );
        headers
    };
    let (tx, rx): (Sender<String>, Receiver<String>) = mpsc::channel(100);
    tokio::spawn(async move {
        let response = client
            .post(OPENAI_CHAT_URL.to_string())
            .headers(headers)
            .json(&chat_request)
            .timeout(Duration::from_secs(timeout))
            .send()
            .await
            .map_err(|e| {
                tracing::error!("Unable to send request to OpenAI: {:?}", e);
                let err = anyhow!("Unable to send request to OpenAI: {}", e);
                send_sentry_error(&err.to_string(), None);
                err
            });
        // On a send failure the task just ends; the receiver sees the stream
        // close without any items.
        if let Err(_e) = response {
            return;
        }
        let response = response.unwrap();
        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        while let Some(item) = stream.next().await {
            match item {
                Ok(bytes) => {
                    // NOTE(review): assumes each network chunk is valid UTF-8
                    // on its own; a multi-byte char split across chunks would
                    // panic this unwrap — confirm.
                    let chunk = String::from_utf8(bytes.to_vec()).unwrap();
                    buffer.push_str(&chunk);
                    // Heuristic SSE framing: treat every "}\n" as the end of
                    // one "data: {...}" JSON object. NOTE(review): this
                    // assumes "}\n" never occurs inside a payload string.
                    while let Some(pos) = buffer.find("}\n") {
                        let (json_str, rest) = {
                            let (json_str, rest) = buffer.split_at(pos + 1);
                            (json_str.to_string(), rest.to_string())
                        };
                        buffer = rest;
                        // Strip the SSE "data: " prefix before parsing.
                        let json_str_trimmed = json_str.replace("data: ", "");
                        match serde_json::from_str::<OpenAiChatStreamResponse>(
                            &json_str_trimmed.trim(),
                        ) {
                            Ok(response) => {
                                // NOTE(review): choices[0] is indexed
                                // unchecked; a chunk with an empty choices
                                // array would panic this task — confirm.
                                if let Some(content) = &response.choices[0].delta.content {
                                    // Receiver dropped: stop forwarding.
                                    if tx.send(content.clone()).await.is_err() {
                                        break;
                                    }
                                }
                            }
                            Err(e) => {
                                // Non-chunk lines (e.g. a terminal sentinel,
                                // if the API sends one) fail to parse and are
                                // only logged.
                                tracing::error!("Error parsing JSON response: {:?}", e);
                                let err = anyhow!("Error parsing JSON response: {}", e);
                                send_sentry_error(&err.to_string(), None);
                            }
                        }
                    }
                }
                Err(e) => {
                    tracing::error!("Error while streaming response: {:?}", e);
                    let err = anyhow!("Error while streaming response: {}", e);
                    send_sentry_error(&err.to_string(), None);
                    break;
                }
            }
        }
    });
    Ok(ReceiverStream::new(rx))
}
/// Request body for the OpenAI embeddings endpoint (batch input).
#[derive(Serialize, Debug)]
pub struct AdaBulkEmbedding {
    pub model: String,
    pub input: Vec<String>,
    pub dimensions: u32,
}
/// One embedding vector from the response's `data` array.
#[derive(Deserialize, Debug)]
pub struct AdaEmbeddingArray {
    pub embedding: Vec<f32>,
}
/// Embeddings response body; only `data` is decoded.
#[derive(Deserialize, Debug)]
pub struct AdaEmbeddingResponse {
    pub data: Vec<AdaEmbeddingArray>,
}
/// Embed a batch of strings via the OpenAI embeddings endpoint, requesting
/// 1024-dimensional vectors.
///
/// The model name comes from the `EMBEDDING_MODEL` env var (panics if unset,
/// matching the original contract).
///
/// # Errors
/// Fails on transport errors, non-success HTTP statuses, or a response body
/// that does not match `AdaEmbeddingResponse`.
pub async fn ada_bulk_embedding(text_list: Vec<String>) -> Result<Vec<Vec<f32>>> {
    let embedding_model = env::var("EMBEDDING_MODEL").expect("EMBEDDING_MODEL must be set");
    let client = reqwest::Client::new();
    let ada_bulk_embedding = AdaBulkEmbedding {
        model: embedding_model,
        input: text_list,
        dimensions: 1024,
    };
    let res = match client
        .post(OPENAI_EMBEDDING_URL)
        .headers({
            let mut headers = reqwest::header::HeaderMap::new();
            headers.insert(
                reqwest::header::AUTHORIZATION,
                format!("Bearer {}", OPENAI_API_KEY.to_string())
                    .parse()
                    .unwrap(),
            );
            headers
        })
        .timeout(Duration::from_secs(60))
        .json(&ada_bulk_embedding)
        .send()
        .await
    {
        Ok(res) => res,
        // Bug fix: the error used to be wrapped in a `Result` binding that was
        // then `unwrap()`ed below, panicking on any transport failure. Return
        // it instead.
        Err(e) => return Err(anyhow!(e.to_string())),
    };
    if !res.status().is_success() {
        tracing::error!(
            "There was an issue while getting the data source: {}",
            res.text()
                .await
                .unwrap_or_else(|_| "<unreadable body>".to_string())
        );
        return Err(anyhow!("Error getting data source"));
    }
    // Bug fix: the body read also used `unwrap()`; propagate the error.
    let embedding_text = match res.text().await {
        Ok(text) => text,
        Err(e) => return Err(anyhow!(e.to_string())),
    };
    let embeddings: AdaEmbeddingResponse = match serde_json::from_str(embedding_text.as_str()) {
        Ok(embeddings) => embeddings,
        Err(e) => {
            tracing::error!(
                "There was an issue decoding the bulk embedding json response: {}",
                e
            );
            return Err(anyhow!(e.to_string()));
        }
    };
    // One Vec<f32> per input string, in response order.
    let result: Vec<Vec<f32>> = embeddings.data.into_iter().map(|x| x.embedding).collect();
    Ok(result)
}

View File

@ -1,65 +0,0 @@
use anyhow::{anyhow, Result};
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_secretsmanager::Client;
use uuid::Uuid;
/// Store `secret` in AWS Secrets Manager under a freshly generated UUID name
/// and return that name.
///
/// The region resolves from the environment, falling back to `us-east-1`.
///
/// # Errors
/// Returns the AWS SDK error if the create call fails.
pub async fn create_db_secret(secret: String) -> Result<String> {
    let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
    let config = aws_config::from_env().region(region_provider).load().await;
    let client = Client::new(&config);
    let secret_id = Uuid::new_v4().to_string();
    match client
        .create_secret()
        .name(secret_id.clone())
        .secret_string(secret)
        .send()
        .await
    {
        Ok(res) => {
            tracing::info!("Successfully created secret in AWS.");
            tracing::info!("Secret ID: {}", secret_id);
            // Bug fix: `res.arn.unwrap()` could panic if AWS omitted the ARN;
            // log a placeholder instead of crashing.
            tracing::info!(
                "Secret ARN: {}",
                res.arn.as_deref().unwrap_or("<not returned>")
            );
        }
        Err(e) => {
            tracing::error!("There was an issue while creating the secret in AWS.");
            return Err(anyhow!(e));
        }
    }
    Ok(secret_id)
}
/// Delete a secret from AWS Secrets Manager by its id.
///
/// The region resolves from the environment, falling back to `us-east-1`.
///
/// # Errors
/// Returns the AWS SDK error if the delete call fails.
pub async fn delete_db_secret(secret_id: String) -> Result<()> {
    let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
    let config = aws_config::from_env().region(region_provider).load().await;
    let client = Client::new(&config);
    if let Err(e) = client.delete_secret().secret_id(secret_id).send().await {
        return Err(anyhow!(e));
    }
    Ok(())
}
/// Fetch the plaintext string value of a secret from AWS Secrets Manager.
///
/// The region resolves from the environment, falling back to `us-east-1`.
///
/// # Errors
/// Fails if the AWS call errors or the secret has no string payload (e.g. a
/// binary-only secret).
pub async fn read_secret(secret_id: String) -> Result<String> {
    let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
    let config = aws_config::from_env().region(region_provider).load().await;
    let client = Client::new(&config);
    let secret = client
        .get_secret_value()
        .secret_id(secret_id)
        .send()
        .await
        .map_err(|e| anyhow!(e))?;
    secret
        .secret_string
        .ok_or_else(|| anyhow!("There was no secret string in the response"))
}

View File

@ -1,6 +1,2 @@
pub mod ai;
// pub mod aws;
pub mod email; pub mod email;
pub mod posthog;
pub mod sentry_utils; pub mod sentry_utils;
pub mod typesense;

View File

@ -1,190 +0,0 @@
use anyhow::{anyhow, Result};
use reqwest::header::AUTHORIZATION;
use std::env;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::utils::clients::sentry_utils::send_sentry_error;
lazy_static::lazy_static! {
    // Required; first access panics if the env var is missing.
    static ref POSTHOG_API_KEY: String = env::var("POSTHOG_API_KEY").expect("POSTHOG_API_KEY must be set");
}
/// PostHog single-event capture endpoint (US cloud).
const POSTHOG_API_URL: &str = "https://us.i.posthog.com/capture/";
/// A single analytics event as sent to PostHog's capture endpoint.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PosthogEvent {
    event: PosthogEventType,
    // PostHog's "distinct id"; our user id when known.
    distinct_id: Option<Uuid>,
    properties: PosthogEventProperties,
}
/// Event names; serialized in snake_case (e.g. `metric_created`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum PosthogEventType {
    MetricCreated,
    MetricFollowUp,
    MetricAddedToDashboard,
    MetricViewed,
    DashboardViewed,
    TitleManuallyUpdated,
    SqlManuallyUpdated,
    ChartStylingManuallyUpdated,
    ChartStylingAutoUpdated,
}
/// Per-event-type property payloads.
/// NOTE(review): this enum is NOT `#[serde(untagged)]`, so properties are
/// serialized wrapped in the variant name (e.g. `{"MetricCreated": {...}}`) —
/// confirm that is the shape PostHog expects.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum PosthogEventProperties {
    MetricCreated(MetricCreatedProperties),
    MetricFollowUp(MetricFollowUpProperties),
    MetricAddedToDashboard(MetricAddedToDashboardProperties),
    MetricViewed(MetricViewedProperties),
    DashboardViewed(DashboardViewedProperties),
    TitleManuallyUpdated(TitleManuallyUpdatedProperties),
    SqlManuallyUpdated(SqlManuallyUpdatedProperties),
    ChartStylingManuallyUpdated(ChartStylingManuallyUpdatedProperties),
    ChartStylingAutoUpdated(ChartStylingAutoUpdatedProperties),
}
// Property payloads, one struct per event type. Most carry the message and
// thread ids so events can be joined back to a conversation.
/// Properties for `MetricCreated`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricCreatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Properties for `MetricFollowUp`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricFollowUpProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Properties for `MetricAddedToDashboard`; also records the dashboard.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricAddedToDashboardProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
    pub dashboard_id: Uuid,
}
/// Properties for `MetricViewed`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MetricViewedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Properties for `DashboardViewed`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DashboardViewedProperties {
    pub dashboard_id: Uuid,
}
/// Properties for `TitleManuallyUpdated`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TitleManuallyUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Properties for `SqlManuallyUpdated`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SqlManuallyUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Properties for `ChartStylingManuallyUpdated`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChartStylingManuallyUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Properties for `ChartStylingAutoUpdated`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChartStylingAutoUpdatedProperties {
    pub message_id: Uuid,
    pub thread_id: Uuid,
}
/// Fire-and-forget a PostHog analytics event.
///
/// The event is sent on a background tokio task so the caller never waits on
/// (or sees errors from) the network call; failures are logged and forwarded
/// to Sentry. Always returns `Ok(())`.
pub async fn send_posthog_event_handler(
    event_type: PosthogEventType,
    user_id: Option<Uuid>,
    properties: PosthogEventProperties,
) -> Result<()> {
    let event = PosthogEvent {
        event: event_type,
        distinct_id: user_id,
        properties,
    };
    tokio::spawn(async move {
        if let Err(e) = send_event(event).await {
            send_sentry_error(&e.to_string(), None);
            tracing::error!("Unable to send event to Posthog: {:?}", e);
        }
    });
    Ok(())
}
/// Wire format for the capture call: the event's fields flattened to the top
/// level plus the project `api_key` in the body.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PosthogRequest {
    #[serde(flatten)]
    event: PosthogEvent,
    api_key: String,
}
/// POST a single event to the PostHog capture endpoint.
///
/// The API key travels both in the `Authorization` header and, via
/// `PosthogRequest`, in the JSON body. Only transport errors are surfaced;
/// the HTTP status is not checked.
async fn send_event(event: PosthogEvent) -> Result<()> {
    let client = reqwest::Client::new();
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(
        AUTHORIZATION,
        format!("Bearer {}", POSTHOG_API_KEY.to_string())
            .parse()
            .unwrap(),
    );
    let posthog_req = PosthogRequest {
        event,
        api_key: POSTHOG_API_KEY.to_string(),
    };
    if let Err(e) = client
        .post(POSTHOG_API_URL)
        .headers(headers)
        .json(&posthog_req)
        .send()
        .await
    {
        tracing::error!("Unable to send request to Posthog: {:?}", e);
        return Err(anyhow!(e));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use dotenv::dotenv;
    // NOTE(review): live integration test — it sends a real event to PostHog
    // using credentials from .env and discards the result (`let _ =`), so it
    // can never fail on a network/API error. Consider gating with #[ignore].
    #[tokio::test]
    async fn test_send_event() {
        dotenv().ok();
        let _ = send_event(PosthogEvent {
            event: PosthogEventType::MetricCreated,
            distinct_id: Some(Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap()),
            properties: PosthogEventProperties::MetricCreated(MetricCreatedProperties {
                message_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(),
                thread_id: Uuid::parse_str("c2dd64cd-f7f3-4884-bc91-d46ae431901e").unwrap(),
            }),
        })
        .await;
    }
}

View File

@ -1,318 +0,0 @@
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::env;
use uuid::Uuid;
lazy_static::lazy_static! {
    // Both default to values suitable for a local dev Typesense instance.
    static ref TYPESENSE_API_KEY: String = env::var("TYPESENSE_API_KEY").unwrap_or("xyz".to_string());
    static ref TYPESENSE_API_HOST: String = env::var("TYPESENSE_API_HOST").unwrap_or("http://localhost:8108".to_string());
}
/// Names of the Typesense collections this service talks to.
/// `StoredValues` carries a dynamically generated collection name; its
/// `untagged` attribute lets it (de)serialize as a bare string while the
/// fixed variants use snake_case names.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "snake_case")]
pub enum CollectionName {
    Messages,
    Collections,
    Datasets,
    PermissionGroups,
    Teams,
    DataSources,
    Terms,
    Dashboards,
    #[serde(untagged)]
    StoredValues(String),
}
/// Render the collection's Typesense name.
///
/// Implemented as `Display` (instead of the previous inherent `to_string`
/// method, which clippy flags as `inherent_to_string`): the standard
/// `ToString` blanket impl keeps `.to_string()` working for existing callers
/// while also enabling direct use in `format!`/`{}` contexts.
impl std::fmt::Display for CollectionName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            CollectionName::Messages => "messages",
            CollectionName::Collections => "collections",
            CollectionName::Datasets => "datasets",
            CollectionName::PermissionGroups => "permission_groups",
            CollectionName::Teams => "teams",
            CollectionName::DataSources => "data_sources",
            CollectionName::Terms => "terms",
            CollectionName::Dashboards => "dashboards",
            CollectionName::StoredValues(s) => s.as_str(),
        };
        f.write_str(name)
    }
}
/// Search document for a message; includes the summary question for matching.
#[derive(Deserialize, Serialize)]
pub struct MessageDocument {
    pub id: Uuid,
    pub name: String,
    pub summary_question: String,
    pub organization_id: Uuid,
}
/// Shared shape used by most searchable entities (dashboards, datasets,
/// teams, data sources, terms, collections, permission groups).
#[derive(Deserialize, Serialize)]
pub struct GenericDocument {
    pub id: Uuid,
    pub name: String,
    pub organization_id: Uuid,
}
/// A stored column value indexed for value-level search.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct StoredValueDocument {
    pub id: Uuid,
    pub value: String,
    pub dataset_id: Uuid,
    pub dataset_column_id: Uuid,
}
/// Any document we index in Typesense.
/// NOTE(review): `untagged` deserialization picks the first variant whose
/// fields match, so the `GenericDocument`-backed variants are
/// indistinguishable when decoding — the chosen variant may not reflect the
/// actual source collection.
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum Document {
    Message(MessageDocument),
    Dashboard(GenericDocument),
    Dataset(GenericDocument),
    PermissionGroup(GenericDocument),
    Team(GenericDocument),
    DataSource(GenericDocument),
    Term(GenericDocument),
    Collection(GenericDocument),
    StoredValue(StoredValueDocument),
}
impl Document {
    /// The document's id, regardless of variant.
    pub fn id(&self) -> Uuid {
        match self {
            Document::Message(m) => m.id,
            Document::Dashboard(d) => d.id,
            Document::Dataset(d) => d.id,
            Document::PermissionGroup(d) => d.id,
            Document::Team(d) => d.id,
            Document::DataSource(d) => d.id,
            Document::Term(d) => d.id,
            Document::Collection(d) => d.id,
            Document::StoredValue(d) => d.id,
        }
    }
    /// Borrow the inner `StoredValueDocument`.
    ///
    /// # Panics
    /// Panics on any other variant — callers must already know the document
    /// came from a stored-values collection.
    // NOTE(review): despite the `into_` prefix this borrows rather than
    // consumes; an `as_`-style name (or an Option return) would fit Rust
    // conventions, but renaming would break existing callers.
    pub fn into_stored_value_document(&self) -> &StoredValueDocument {
        match self {
            Document::StoredValue(d) => d,
            _ => panic!("Document is not a StoredValueDocument"),
        }
    }
}
pub async fn upsert_document(collection: CollectionName, document: Document) -> Result<()> {
let client = reqwest::Client::builder().build()?;
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Content-Type", "application/json".parse()?);
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
match client
.request(
reqwest::Method::POST,
format!(
"{}/collections/{}/documents?action=upsert",
*TYPESENSE_API_HOST,
collection.to_string()
),
)
.headers(headers)
.json(&document)
.send()
.await
{
Ok(_) => (),
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
};
Ok(())
}
/// Body of a Typesense `multi_search` call.
#[derive(Serialize, Debug)]
pub struct SearchRequest {
    pub searches: Vec<SearchRequestObject>,
}
/// One search within a multi-search: query, filters and hybrid vector query
/// for a single collection.
#[derive(Serialize, Debug)]
pub struct SearchRequestObject {
    pub collection: CollectionName,
    pub q: String,
    // Comma-separated field list to match `q` against.
    pub query_by: String,
    pub prefix: bool,
    pub exclude_fields: String,
    pub highlight_fields: String,
    pub use_cache: bool,
    pub filter_by: String,
    // Vector-query expression for hybrid (keyword + vector) search.
    pub vector_query: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit: Option<i64>,
}
/// Echo of the request parameters returned with each result.
#[derive(Deserialize)]
pub struct RequestParams {
    pub collection_name: CollectionName,
}
/// A highlighted field with the tokens that matched.
#[derive(Deserialize)]
pub struct Highlight {
    pub field: String,
    pub matched_tokens: Vec<String>,
}
/// Hybrid-search scoring info for a hit.
#[derive(Deserialize)]
pub struct HybridSearchInfo {
    pub rank_fusion_score: f64,
}
/// One matching document plus its highlights and score.
#[derive(Deserialize)]
pub struct Hit {
    pub document: Document,
    pub highlights: Vec<Highlight>,
    pub hybrid_search_info: HybridSearchInfo,
}
/// Results for one of the multi-search queries.
#[derive(Deserialize)]
pub struct SearchResult {
    pub request_params: RequestParams,
    pub hits: Vec<Hit>,
}
/// Top-level multi-search response: one `SearchResult` per request, in order.
#[derive(Deserialize)]
pub struct SearchResponse {
    pub results: Vec<SearchResult>,
}
pub async fn search_documents(search_reqs: Vec<SearchRequestObject>) -> Result<SearchResponse> {
let client = reqwest::Client::builder().build()?;
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Content-Type", "application/json".parse()?);
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
let search_req = SearchRequest {
searches: search_reqs,
};
let res = match client
.request(
reqwest::Method::POST,
format!("{}/multi_search", *TYPESENSE_API_HOST),
)
.headers(headers)
.json(&search_req)
.send()
.await
{
Ok(res) => {
let res = if res.status().is_success() {
res
} else {
return Err(anyhow!(
"Error sending request: {}",
res.text().await.unwrap()
));
};
res
}
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
};
let search_response = match res.json::<SearchResponse>().await {
Ok(search_response) => search_response,
Err(e) => return Err(anyhow!("Error parsing response: {}", e)),
};
Ok(search_response)
}
pub async fn create_collection(schema: Value) -> Result<()> {
let client = reqwest::Client::builder().build()?;
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Content-Type", "application/json".parse()?);
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
match client
.request(
reqwest::Method::POST,
format!("{}/collections", *TYPESENSE_API_HOST,),
)
.headers(headers)
.json(&schema)
.send()
.await
{
Ok(_) => (),
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
};
Ok(())
}
pub async fn delete_collection(collection_name: &String, filter_by: &String) -> Result<()> {
let client = reqwest::Client::builder().build()?;
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Content-Type", "application/json".parse()?);
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
match client
.request(
reqwest::Method::DELETE,
format!(
"{}/collections/{}/documents?filter_by={}",
*TYPESENSE_API_HOST, collection_name, filter_by
),
)
.headers(headers)
.send()
.await
{
Ok(_) => (),
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
};
Ok(())
}
pub async fn bulk_insert_documents<T>(collection_name: &String, documents: &Vec<T>) -> Result<()>
where
T: serde::Serialize,
{
let client = reqwest::Client::builder().build()?;
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Content-Type", "application/json".parse()?);
headers.insert("X-TYPESENSE-API-KEY", TYPESENSE_API_KEY.parse()?);
let serialized_documents = documents
.iter()
.map(|doc| serde_json::to_string(doc))
.collect::<Result<Vec<String>, _>>()?
.join("\n");
match client
.request(
reqwest::Method::POST,
format!(
"{}/collections/{}/documents/import?action=create",
*TYPESENSE_API_HOST, collection_name
),
)
.headers(headers)
.body(serialized_documents)
.send()
.await
{
Ok(res) => {
if res.status().is_success() {
tracing::info!("Res: {:?}", res.text().await.unwrap());
return Ok(());
}
tracing::error!("Error sending request: {}", res.text().await.unwrap());
return Err(anyhow!("Error sending bulk insert request"));
}
Err(e) => return Err(anyhow!("Error sending request: {}", e)),
};
}

View File

@ -1,11 +1,5 @@
pub mod charting;
pub mod clients; pub mod clients;
pub mod security; pub mod security;
pub mod serde_helpers;
pub mod stored_values;
pub mod validation;
pub use agents::*; pub use agents::*;
pub use security::*; pub use security::*;
pub use stored_values::*;
pub use validation::*;

View File

@ -1,21 +0,0 @@
use serde::{de::Deserializer, Deserialize};
use serde_json::Value;
/// Deserialize a field into a double `Option`, distinguishing three cases:
///
/// * JSON `null`            => `Some(None)`  (explicitly cleared)
/// * empty object `{}`      => `None`        (treated as "not provided")
/// * anything else          => `Some(Some(v))` if it parses as `T`,
///                             otherwise `None` (parse failures are swallowed,
///                             not surfaced as deserialization errors)
pub fn deserialize_double_option<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
where
    T: serde::Deserialize<'de>,
    D: Deserializer<'de>,
{
    let raw = Value::deserialize(deserializer)?;
    if raw.is_null() {
        // Explicit null in the payload.
        return Ok(Some(None));
    }
    if raw.as_object().map_or(false, |obj| obj.is_empty()) {
        // Empty-object sentinel.
        return Ok(None);
    }
    // Present value: try the inner type, collapsing failures to None.
    Ok(T::deserialize(raw).map(|v| Some(Some(v))).unwrap_or(None))
}

View File

@ -1 +0,0 @@
pub mod deserialization_helpers;

View File

@ -1,291 +0,0 @@
pub mod search;
use query_engine::data_source_query_routes::query_engine::query_engine;
use query_engine::data_types::DataType;
pub use search::*;
use crate::utils::clients::ai::embedding_router::embedding_router;
use anyhow::Result;
use chrono::Utc;
use database::enums::StoredValuesStatus;
use database::{pool::get_pg_pool, schema::dataset_columns};
use diesel::prelude::*;
use diesel::sql_types::{Array, Float4, Integer, Text, Timestamptz, Uuid as SqlUuid};
use diesel_async::RunQueryDsl;
use uuid::Uuid;
/// Row wrapper for raw SQL queries that project a single text column.
/// `QueryableByName` maps fields by result-column name, so the query's
/// output column must be named (or aliased to) `value`.
#[derive(Debug, QueryableByName)]
pub struct StoredValueRow {
    #[diesel(sql_type = Text)]
    pub value: String,
}
/// Row returned by the vector-similarity search over stored values.
/// NOTE(review): despite the name there is no distance field — the similarity
/// ordering happens in SQL and only the matched value/column are returned.
#[derive(Debug, QueryableByName)]
pub struct StoredValueWithDistance {
    #[diesel(sql_type = Text)]
    pub value: String,
    #[diesel(sql_type = Text)]
    pub column_name: String,
    #[diesel(sql_type = SqlUuid)]
    pub column_id: Uuid,
}
// Number of distinct values fetched from the source table per query page.
const BATCH_SIZE: usize = 10_000;
// Values longer than this (in characters) are excluded from storage.
const MAX_VALUE_LENGTH: usize = 50;
/// Idempotently create the per-organization stored-values objects:
/// schema `values_{org}`, table `values_v1` (value + dataset/column ids +
/// 1024-dim embedding, unique per (dataset_id, column_name, value)), and an
/// ivfflat cosine index on the embedding column.
///
/// The `vector(1024)` column and ivfflat index require the pgvector
/// extension to be installed in the target database.
pub async fn ensure_stored_values_schema(organization_id: &Uuid) -> Result<()> {
    let pool = get_pg_pool();
    let mut conn = pool.get().await?;
    // Create schema and table using raw SQL.
    // UUID dashes are replaced because '-' is not valid in an unquoted
    // Postgres identifier.
    let schema_name = organization_id.to_string().replace("-", "_");
    let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS values_{}", schema_name);
    let create_table_sql = format!(
        "CREATE TABLE IF NOT EXISTS values_{}.values_v1 (
            value text NOT NULL,
            dataset_id uuid NOT NULL,
            column_name text NOT NULL,
            column_id uuid NOT NULL,
            embedding vector(1024),
            created_at timestamp with time zone NOT NULL DEFAULT now(),
            UNIQUE(dataset_id, column_name, value)
        )",
        schema_name
    );
    let create_index_sql = format!(
        "CREATE INDEX IF NOT EXISTS values_v1_embedding_idx
        ON values_{}.values_v1
        USING ivfflat (embedding vector_cosine_ops)",
        schema_name
    );
    // Order matters: schema before table before index.
    diesel::sql_query(create_schema_sql)
        .execute(&mut conn)
        .await?;
    diesel::sql_query(create_table_sql)
        .execute(&mut conn)
        .await?;
    diesel::sql_query(create_index_sql)
        .execute(&mut conn)
        .await?;
    Ok(())
}
/// Extract distinct values for a dataset column and persist them (with
/// embeddings) into the per-organization `values_{org}` schema.
///
/// Behavior:
/// * Values are read from the source table in pages of `BATCH_SIZE`,
///   skipping NULLs and values longer than `MAX_VALUE_LENGTH`.
/// * If the first page yields 15 or fewer distinct values, the column is
///   treated as an enum: the values are appended to the column description
///   and nothing is stored in the values table.
/// * Otherwise every value is embedded and upserted into
///   `values_{org}.values_v1`.
///
/// Fixes in this revision:
/// * the description lookup now aliases the column to `value` so it actually
///   maps onto `StoredValueRow` (QueryableByName maps by column name);
/// * the insert now targets `values_{org}.values_v1` — the schema created by
///   `ensure_stored_values_schema` and read by `search_stored_values` — where
///   it previously wrote to the nonexistent `{org}_values` schema.
pub async fn store_column_values(
    organization_id: &Uuid,
    dataset_id: &Uuid,
    column_name: &str,
    column_id: &Uuid,
    _data_source_id: &Uuid,
    schema: &str,
    table_name: &str,
) -> Result<()> {
    let pool = get_pg_pool();
    let mut conn = pool.get().await?;
    // Create schema and table if they don't exist.
    ensure_stored_values_schema(organization_id).await?;
    // '-' is not valid in an unquoted Postgres identifier.
    let schema_name = organization_id.to_string().replace("-", "_");
    // Query distinct values in batches.
    let mut offset = 0;
    let mut first_batch = true;
    loop {
        // NOTE: schema/table/column names are interpolated directly into the
        // SQL; they are expected to come from trusted dataset metadata.
        let query = format!(
            "SELECT DISTINCT \"{}\" as value
            FROM {}.{}
            WHERE \"{}\" IS NOT NULL
            AND length(\"{}\") <= {}
            ORDER BY \"{}\"
            LIMIT {} OFFSET {}",
            column_name,
            schema,
            table_name,
            column_name,
            column_name,
            MAX_VALUE_LENGTH,
            column_name,
            BATCH_SIZE,
            offset
        );
        let results = match query_engine(dataset_id, &query, None).await {
            Ok(results) => results.data,
            Err(e) => {
                tracing::error!("Error querying stored values: {:?}", e);
                // Fail hard only if we could not read anything at all;
                // otherwise keep whatever batches already succeeded.
                if first_batch {
                    return Err(e);
                }
                vec![]
            }
        };
        if results.is_empty() {
            break;
        }
        // Extract text values from the query results.
        let values: Vec<String> = results
            .into_iter()
            .filter_map(|row| {
                if let Some(DataType::Text(Some(value))) = row.get("value") {
                    Some(value.clone())
                } else {
                    None
                }
            })
            .collect();
        if values.is_empty() {
            break;
        }
        // Low-cardinality columns are documented as enums instead of stored.
        if first_batch && values.len() <= 15 {
            // Fetch the current description; alias to `value` so the result
            // column matches the `StoredValueRow` field name.
            let current_description =
                diesel::sql_query("SELECT description AS value FROM dataset_columns WHERE id = $1")
                    .bind::<SqlUuid, _>(column_id)
                    .get_result::<StoredValueRow>(&mut conn)
                    .await
                    .ok()
                    .map(|row| row.value);
            // Append the enum values to any existing description.
            let enum_list = format!("Values for this column are: {}", values.join(", "));
            let new_description = match current_description {
                Some(desc) if !desc.is_empty() => format!("{}. {}", desc, enum_list),
                _ => enum_list,
            };
            // Update the column and mark it as synced.
            diesel::update(dataset_columns::table)
                .filter(dataset_columns::id.eq(column_id))
                .set((
                    dataset_columns::description.eq(new_description),
                    dataset_columns::stored_values_status.eq(StoredValuesStatus::Success),
                    dataset_columns::stored_values_count.eq(values.len() as i64),
                    dataset_columns::stored_values_last_synced.eq(Utc::now()),
                ))
                .execute(&mut conn)
                .await?;
            return Ok(());
        }
        // Create embeddings for the batch.
        let embeddings = create_embeddings_batch(&values).await?;
        // Upsert values and embeddings into the schema created above.
        for (value, embedding) in values.iter().zip(embeddings.iter()) {
            let insert_sql = format!(
                "INSERT INTO values_{}.values_v1
                (value, dataset_id, column_name, column_id, embedding, created_at)
                VALUES ($1::text, $2::uuid, $3::text, $4::uuid, $5::vector, $6::timestamptz)
                ON CONFLICT (dataset_id, column_name, value)
                DO UPDATE SET created_at = EXCLUDED.created_at",
                schema_name
            );
            diesel::sql_query(insert_sql)
                .bind::<Text, _>(value)
                .bind::<SqlUuid, _>(dataset_id)
                .bind::<Text, _>(column_name)
                .bind::<SqlUuid, _>(column_id)
                .bind::<Array<Float4>, _>(embedding)
                .bind::<Timestamptz, _>(Utc::now())
                .execute(&mut conn)
                .await?;
        }
        first_batch = false;
        offset += BATCH_SIZE;
    }
    Ok(())
}
/// Embed a slice of values via the shared embedding router.
/// NOTE(review): `true` is forwarded unchanged to the router — confirm its
/// meaning against `embedding_router`'s signature.
async fn create_embeddings_batch(values: &[String]) -> Result<Vec<Vec<f32>>> {
    Ok(embedding_router(values.to_vec(), true).await?)
}
/// Nearest-neighbour search over a dataset's stored column values.
///
/// Orders rows by the `<=>` (cosine distance) operator between the stored
/// embedding and `query_embedding`, returning up to `limit` (default 10)
/// `(value, column_name, column_id)` tuples.
pub async fn search_stored_values(
    organization_id: &Uuid,
    dataset_id: &Uuid,
    query_embedding: Vec<f32>,
    limit: Option<i32>,
) -> Result<Vec<(String, String, Uuid)>> {
    let pool = get_pg_pool();
    let mut conn = pool.get().await?;
    let limit = limit.unwrap_or(10);
    // '-' is not valid in an unquoted Postgres identifier.
    let schema_name = organization_id.to_string().replace("-", "_");
    // Bind order below: $1 = embedding, $2 = dataset id, $3 = limit.
    let query = format!(
        "SELECT value, column_name, column_id
        FROM values_{}.values_v1
        WHERE dataset_id = $2::uuid
        ORDER BY embedding <=> $1::vector
        LIMIT $3::integer",
        schema_name
    );
    let results: Vec<StoredValueWithDistance> = diesel::sql_query(query)
        .bind::<Array<Float4>, _>(query_embedding)
        .bind::<SqlUuid, _>(dataset_id)
        .bind::<Integer, _>(limit)
        .load(&mut conn)
        .await?;
    Ok(results
        .into_iter()
        .map(|r| (r.value, r.column_name, r.column_id))
        .collect())
}
/// Work item for background stored-values processing: everything
/// `store_column_values` needs to process one column.
pub struct StoredValueColumn {
    pub organization_id: Uuid,
    pub dataset_id: Uuid,
    pub column_name: String,
    pub column_id: Uuid,
    pub data_source_id: Uuid,
    pub schema: String,
    pub table_name: String,
}
/// Sequentially process a batch of columns, storing their values.
///
/// Failures are logged and do not stop processing of the remaining columns.
/// NOTE(review): the log messages say "dataset" but interpolate `table_name`.
pub async fn process_stored_values_background(columns: Vec<StoredValueColumn>) {
    for col in columns {
        let outcome = store_column_values(
            &col.organization_id,
            &col.dataset_id,
            &col.column_name,
            &col.column_id,
            &col.data_source_id,
            &col.schema,
            &col.table_name,
        )
        .await;
        if let Err(e) = outcome {
            tracing::error!(
                "Failed to process stored values for column '{}' in dataset '{}': {:?}",
                col.column_name,
                col.table_name,
                e
            );
        } else {
            tracing::info!(
                "Successfully processed stored values for column '{}' in dataset '{}'",
                col.column_name,
                col.table_name
            );
        }
    }
}

View File

@ -1,64 +0,0 @@
use anyhow::Result;
use cohere_rust::{api::rerank::{ReRankModel, ReRankRequest}, Cohere};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::utils::clients::ai::embedding_router::embedding_router;
use super::search_stored_values;
/// A stored column value returned from search, tagged with the dataset and
/// column it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredValue {
    pub value: String,
    pub dataset_id: Uuid,
    pub column_name: String,
    pub column_id: Uuid,
}
pub async fn search_values_for_dataset(
organization_id: &Uuid,
dataset_id: &Uuid,
query: String,
) -> Result<Vec<StoredValue>> {
// Create embedding for the search query
let query_vec = vec![query.clone()];
let query_embedding = embedding_router(query_vec, true).await?[0].clone();
// Get initial candidates using vector similarity
let candidates = search_stored_values(
organization_id,
dataset_id,
query_embedding,
Some(25), // Get more candidates for reranking
).await?;
// Extract just the values for reranking
let candidate_values: Vec<String> = candidates.iter().map(|(value, _, _)| value.clone()).collect();
// Rerank the candidates using the cohere client
let co = Cohere::default();
let request = ReRankRequest {
query: query.as_str(),
documents: &candidate_values,
model: ReRankModel::EnglishV3,
top_n: Some(10),
..Default::default()
};
let response = co.rerank(&request).await?;
// Convert to StoredValue structs
let values = response.into_iter()
.map(|result| {
let (value, column_name, column_id) = candidates[result.index as usize].clone();
StoredValue {
value,
dataset_id: *dataset_id,
column_name,
column_id,
}
})
.collect();
Ok(values)
}

View File

@ -1,5 +0,0 @@
pub mod types;
pub mod type_mapping;
pub use types::*;
pub use type_mapping::*;

View File

@ -1,234 +0,0 @@
use database::enums::DataSourceType;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use query_engine::data_types::DataType;
// Standard types we use across all data sources.
// Source-specific type names are normalized onto this shared vocabulary
// by the `impl StandardType` constructors below.
#[derive(Debug, Clone, PartialEq)]
pub enum StandardType {
    Text,
    Integer,
    Float,
    Boolean,
    Date,
    Timestamp,
    Json,
    // Fallback for unrecognized or null-like type names.
    Unknown,
}
impl StandardType {
    /// Map a runtime `DataType` onto a standard type via its simple-type tag.
    /// `"null"`, unrecognized tags, and a missing tag all map to `Unknown`.
    pub fn from_data_type(data_type: &DataType) -> Self {
        match data_type.simple_type().as_deref() {
            Some("string") => StandardType::Text,
            Some("number") => StandardType::Float,
            Some("boolean") => StandardType::Boolean,
            Some("date") => StandardType::Date,
            _ => StandardType::Unknown,
        }
    }

    /// Parse a case-insensitive SQL-ish type name into a standard type.
    pub fn from_str(type_str: &str) -> Self {
        let normalized = type_str.to_lowercase();
        match normalized.as_str() {
            "text" | "string" | "varchar" | "char" => StandardType::Text,
            "int" | "integer" | "bigint" | "smallint" => StandardType::Integer,
            "float" | "double" | "decimal" | "numeric" => StandardType::Float,
            "bool" | "boolean" => StandardType::Boolean,
            "date" => StandardType::Date,
            "timestamp" | "datetime" | "timestamptz" => StandardType::Timestamp,
            "json" | "jsonb" => StandardType::Json,
            _ => StandardType::Unknown,
        }
    }
}
// Type mappings for each data source type to DataType.
// Lookup keys are case-sensitive and follow each source's native casing
// (lower-case for Postgres/MySQL, upper-case for BigQuery/Snowflake).
static TYPE_MAPPINGS: Lazy<HashMap<DataSourceType, HashMap<&'static str, DataType>>> =
    Lazy::new(|| {
        let mut mappings = HashMap::new();
        // Postgres mappings (shared with Supabase, which is Postgres-backed).
        let mut postgres = HashMap::new();
        postgres.insert("text", DataType::Text(None));
        postgres.insert("varchar", DataType::Text(None));
        postgres.insert("char", DataType::Text(None));
        postgres.insert("int", DataType::Int4(None));
        postgres.insert("integer", DataType::Int4(None));
        postgres.insert("bigint", DataType::Int8(None));
        postgres.insert("smallint", DataType::Int2(None));
        postgres.insert("float", DataType::Float4(None));
        postgres.insert("double precision", DataType::Float8(None));
        postgres.insert("numeric", DataType::Decimal(None));
        postgres.insert("decimal", DataType::Decimal(None));
        postgres.insert("boolean", DataType::Bool(None));
        postgres.insert("date", DataType::Date(None));
        postgres.insert("timestamp", DataType::Timestamp(None));
        postgres.insert("timestamptz", DataType::Timestamptz(None));
        postgres.insert("json", DataType::Json(None));
        postgres.insert("jsonb", DataType::Json(None));
        mappings.insert(DataSourceType::Postgres, postgres.clone());
        mappings.insert(DataSourceType::Supabase, postgres);
        // MySQL mappings (shared with MariaDB).
        // NOTE(review): "timestamp" maps to Timestamptz while "datetime"
        // maps to Timestamp — presumably reflecting MySQL TIMESTAMP's
        // UTC conversion semantics; confirm this is intentional.
        let mut mysql = HashMap::new();
        mysql.insert("varchar", DataType::Text(None));
        mysql.insert("text", DataType::Text(None));
        mysql.insert("char", DataType::Text(None));
        mysql.insert("int", DataType::Int4(None));
        mysql.insert("bigint", DataType::Int8(None));
        mysql.insert("tinyint", DataType::Int2(None));
        mysql.insert("float", DataType::Float4(None));
        mysql.insert("double", DataType::Float8(None));
        mysql.insert("decimal", DataType::Decimal(None));
        mysql.insert("boolean", DataType::Bool(None));
        mysql.insert("date", DataType::Date(None));
        mysql.insert("datetime", DataType::Timestamp(None));
        mysql.insert("timestamp", DataType::Timestamptz(None));
        mysql.insert("json", DataType::Json(None));
        mappings.insert(DataSourceType::MySql, mysql.clone());
        mappings.insert(DataSourceType::Mariadb, mysql);
        // BigQuery mappings (upper-case keys).
        let mut bigquery = HashMap::new();
        bigquery.insert("STRING", DataType::Text(None));
        bigquery.insert("INT64", DataType::Int8(None));
        bigquery.insert("INTEGER", DataType::Int4(None));
        bigquery.insert("FLOAT64", DataType::Float8(None));
        bigquery.insert("NUMERIC", DataType::Decimal(None));
        bigquery.insert("BOOL", DataType::Bool(None));
        bigquery.insert("DATE", DataType::Date(None));
        bigquery.insert("TIMESTAMP", DataType::Timestamptz(None));
        bigquery.insert("JSON", DataType::Json(None));
        mappings.insert(DataSourceType::BigQuery, bigquery);
        // Snowflake mappings (upper-case keys).
        let mut snowflake = HashMap::new();
        snowflake.insert("TEXT", DataType::Text(None));
        snowflake.insert("VARCHAR", DataType::Text(None));
        snowflake.insert("CHAR", DataType::Text(None));
        snowflake.insert("NUMBER", DataType::Decimal(None));
        snowflake.insert("DECIMAL", DataType::Decimal(None));
        snowflake.insert("INTEGER", DataType::Int4(None));
        snowflake.insert("BIGINT", DataType::Int8(None));
        snowflake.insert("BOOLEAN", DataType::Bool(None));
        snowflake.insert("DATE", DataType::Date(None));
        snowflake.insert("TIMESTAMP", DataType::Timestamptz(None));
        snowflake.insert("VARIANT", DataType::Json(None));
        mappings.insert(DataSourceType::Snowflake, snowflake);
        mappings
    });
/// Map a data-source-specific type name onto the engine's `DataType`.
///
/// Unknown source types or unmapped type names fall back to
/// `DataType::Unknown` carrying the original type string.
///
/// NOTE(review): the lookup is case-sensitive — confirm callers pass type
/// names in the source's native casing (see `TYPE_MAPPINGS` keys).
pub fn normalize_type(source_type: DataSourceType, type_str: &str) -> DataType {
    TYPE_MAPPINGS
        .get(&source_type)
        .and_then(|mappings| mappings.get(type_str))
        .cloned()
        // `unwrap_or` would allocate the fallback String on every hit;
        // build it lazily instead.
        .unwrap_or_else(|| DataType::Unknown(Some(type_str.to_string())))
}
/// Decide whether a data-source column type can satisfy a model column type.
///
/// Rules, in order:
/// * any integer source type may widen to float/decimal;
/// * a text source column is accepted for any model type (common for
///   views/computed columns);
/// * timestamp and timestamptz are interchangeable;
/// * otherwise the normalized types must match exactly.
pub fn types_compatible(source_type: DataSourceType, ds_type: &str, model_type: &str) -> bool {
    let ds = normalize_type(source_type, ds_type);
    let model = normalize_type(source_type, model_type);
    match (&ds, &model) {
        // Integer -> float/decimal widening.
        (
            DataType::Int2(_) | DataType::Int4(_) | DataType::Int8(_),
            DataType::Float4(_) | DataType::Float8(_) | DataType::Decimal(_),
        ) => true,
        // Text can stand in for any type.
        (DataType::Text(_), _) => true,
        // Timestamp/timestamptz compatibility, both directions.
        (DataType::Timestamp(_), DataType::Timestamptz(_))
        | (DataType::Timestamptz(_), DataType::Timestamp(_)) => true,
        // Exact match, compared via the Display form so payload values
        // are ignored; everything else is incompatible.
        (a, b) => a.to_string() == b.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Spot-checks that Postgres type names normalize to the expected DataType.
    #[test]
    fn test_postgres_type_normalization() {
        assert!(matches!(
            normalize_type(DataSourceType::Postgres, "text"),
            DataType::Text(_)
        ));
        assert!(matches!(
            normalize_type(DataSourceType::Postgres, "integer"),
            DataType::Int4(_)
        ));
        assert!(matches!(
            normalize_type(DataSourceType::Postgres, "numeric"),
            DataType::Decimal(_)
        ));
    }
    // BigQuery uses upper-case type names; verify they resolve.
    #[test]
    fn test_bigquery_type_normalization() {
        assert!(matches!(
            normalize_type(DataSourceType::BigQuery, "STRING"),
            DataType::Text(_)
        ));
        assert!(matches!(
            normalize_type(DataSourceType::BigQuery, "INT64"),
            DataType::Int8(_)
        ));
        assert!(matches!(
            normalize_type(DataSourceType::BigQuery, "FLOAT64"),
            DataType::Float8(_)
        ));
    }
    // Exercises each compatibility rule: exact match, int->float widening,
    // text-as-anything, asymmetry, and timestamp/timestamptz equivalence.
    #[test]
    fn test_type_compatibility() {
        // Same types are compatible
        assert!(types_compatible(DataSourceType::Postgres, "text", "text"));
        // Integer can be used as float
        assert!(types_compatible(
            DataSourceType::Postgres,
            "integer",
            "float"
        ));
        // Text can be used for any type
        assert!(types_compatible(
            DataSourceType::Postgres,
            "text",
            "integer"
        ));
        // Different types are incompatible
        assert!(!types_compatible(
            DataSourceType::Postgres,
            "integer",
            "text"
        ));
        // Timestamp compatibility
        assert!(types_compatible(
            DataSourceType::Postgres,
            "timestamp",
            "timestamptz"
        ));
    }
}

View File

@ -1,182 +0,0 @@
use serde::{Deserialize, Serialize};
/// Outcome of validating a single model against its data source.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ValidationResult {
    // True until the first error is recorded via `add_error`.
    pub success: bool,
    pub model_name: String,
    pub data_source_name: String,
    pub schema: String,
    pub errors: Vec<ValidationError>,
}
/// A single validation failure, optionally tied to a column and carrying a
/// suggested fix and extra context.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ValidationError {
    pub error_type: ValidationErrorType,
    // Set when the error concerns a specific column; None for table/model/
    // data-source level errors.
    pub column_name: Option<String>,
    pub message: String,
    pub suggestion: Option<String>,
    // Free-form extra detail, attached via `with_context`.
    pub context: Option<String>,
}
/// Category of a validation failure; serialized with the error for clients
/// to branch on.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ValidationErrorType {
    TableNotFound,
    ColumnNotFound,
    TypeMismatch,
    DataSourceError,
    ModelNotFound,
    InvalidRelationship,
    ExpressionError,
    ProjectNotFound,
    InvalidBusterYml,
    DataSourceMismatch,
    RequiredFieldMissing,
    DataSourceNotFound,
    SchemaError,
    CredentialsError,
    DatabaseError,
    ValidationError,
    InternalError,
}
impl ValidationResult {
pub fn new(model_name: String, data_source_name: String, schema: String) -> Self {
Self {
success: true,
model_name,
data_source_name,
schema,
errors: Vec::new(),
}
}
pub fn add_error(&mut self, error: ValidationError) {
self.success = false;
self.errors.push(error);
}
}
impl ValidationError {
    /// Generic constructor; prefer the named factory methods below, which
    /// pin the error type and message format.
    pub fn new(
        error_type: ValidationErrorType,
        column_name: Option<String>,
        message: String,
        suggestion: Option<String>,
    ) -> Self {
        Self {
            error_type,
            column_name,
            message,
            suggestion,
            context: None,
        }
    }
    /// Attach free-form context to an existing error (builder style).
    pub fn with_context(mut self, context: String) -> Self {
        self.context = Some(context);
        self
    }
    /// The referenced table does not exist in the data source.
    pub fn table_not_found(table_name: &str) -> Self {
        Self::new(
            ValidationErrorType::TableNotFound,
            None,
            format!("Table '{}' not found in data source", table_name),
            None,
        )
    }
    /// The referenced column does not exist in the data source.
    pub fn column_not_found(column_name: &str) -> Self {
        Self::new(
            ValidationErrorType::ColumnNotFound,
            Some(column_name.to_string()),
            format!("Column '{}' not found in data source", column_name),
            None,
        )
    }
    /// The column exists but its type differs from the model's declaration.
    pub fn type_mismatch(column_name: &str, expected: &str, found: &str) -> Self {
        Self::new(
            ValidationErrorType::TypeMismatch,
            Some(column_name.to_string()),
            format!(
                "Column '{}' type mismatch. Expected: {}, Found: {}",
                column_name, expected, found
            ),
            None,
        )
    }
    /// Wrap an arbitrary data-source-level failure message.
    pub fn data_source_error(message: String) -> Self {
        Self::new(
            ValidationErrorType::DataSourceError,
            None,
            message,
            None,
        )
    }
    /// The referenced model does not exist in the data source.
    pub fn model_not_found(model_name: &str) -> Self {
        Self::new(
            ValidationErrorType::ModelNotFound,
            None,
            format!("Model '{}' not found in data source", model_name),
            None,
        )
    }
    /// A declared relationship between two models is invalid.
    pub fn invalid_relationship(from: &str, to: &str, reason: &str) -> Self {
        Self::new(
            ValidationErrorType::InvalidRelationship,
            None,
            format!("Invalid relationship from '{}' to '{}': {}", from, to, reason),
            None,
        )
    }
    /// A column expression failed to validate.
    pub fn expression_error(column_name: &str, expr: &str, reason: &str) -> Self {
        Self::new(
            ValidationErrorType::ExpressionError,
            Some(column_name.to_string()),
            format!("Invalid expression '{}' for column '{}': {}", expr, column_name, reason),
            None,
        )
    }
    // New factory methods for enhanced error types
    /// A schema-level problem (missing schema, bad permissions, etc.).
    pub fn schema_error(schema_name: &str, reason: &str) -> Self {
        Self::new(
            ValidationErrorType::SchemaError,
            None,
            format!("Schema '{}' error: {}", schema_name, reason),
            None,
        )
    }
    /// Credentials for the data source were missing or rejected.
    pub fn credentials_error(data_source: &str, reason: &str) -> Self {
        Self::new(
            ValidationErrorType::CredentialsError,
            None,
            format!("Credentials error for data source '{}': {}", data_source, reason),
            None,
        )
    }
    /// Wrap an arbitrary database failure message.
    pub fn database_error(message: String) -> Self {
        Self::new(
            ValidationErrorType::DatabaseError,
            None,
            message,
            None,
        )
    }
    /// Wrap an unexpected internal failure message.
    pub fn internal_error(message: String) -> Self {
        Self::new(
            ValidationErrorType::InternalError,
            None,
            message,
            None,
        )
    }
}