This commit is contained in:
dal 2025-03-21 12:54:54 -06:00
parent 39385acf9d
commit dba826d874
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
35 changed files with 174 additions and 494 deletions

View File

@ -1020,6 +1020,9 @@ mod tests {
)
.await?;
let _params = params.as_object().unwrap();
let _tool_call_id = tool_call_id.clone();
// Simulate a delay
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
@ -1088,7 +1091,7 @@ mod tests {
vec![AgentMessage::user("Hello, world!".to_string())],
);
let response = match agent.process_thread(&thread).await {
let _response = match agent.process_thread(&thread).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread: {:?}", e),
};
@ -1111,7 +1114,7 @@ mod tests {
let weather_tool = WeatherTool::new(Arc::new(agent.clone()));
// Add tool to agent
agent.add_tool(weather_tool.get_name(), weather_tool);
let _ = agent.add_tool(weather_tool.get_name(), weather_tool);
let thread = AgentThread::new(
None,
@ -1121,7 +1124,7 @@ mod tests {
)],
);
let response = match agent.process_thread(&thread).await {
let _response = match agent.process_thread(&thread).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread: {:?}", e),
};
@ -1142,7 +1145,7 @@ mod tests {
let weather_tool = WeatherTool::new(Arc::new(agent.clone()));
agent.add_tool(weather_tool.get_name(), weather_tool);
let _ = agent.add_tool(weather_tool.get_name(), weather_tool);
let thread = AgentThread::new(
None,
@ -1152,7 +1155,7 @@ mod tests {
)],
);
let response = match agent.process_thread(&thread).await {
let _response = match agent.process_thread(&thread).await {
Ok(response) => response,
Err(e) => panic!("Error processing thread: {:?}", e),
};

View File

@ -1063,12 +1063,6 @@ mod tests {
use uuid::Uuid;
// Mock functions for testing
#[cfg(test)]
pub(crate) async fn validate_metric_ids(ids: &[Uuid]) -> Result<Vec<Uuid>> {
// For tests, just return an empty vector indicating all IDs are valid
Ok(Vec::new())
}
#[tokio::test]
async fn test_validate_sql_empty() {

View File

@ -350,7 +350,6 @@ async fn get_modify_dashboards_content_to_replace_description() -> String {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::tools::categories::file_tools::common::{
@ -433,15 +432,6 @@ mod tests {
#[test]
fn test_tool_parameter_validation() {
let tool = ModifyDashboardFilesTool {
agent: Arc::new(Agent::new(
"o3-mini".to_string(),
HashMap::new(),
Uuid::new_v4(),
Uuid::new_v4(),
"test_agent".to_string(),
)),
};
// Test valid parameters
let valid_params = json!({

View File

@ -379,7 +379,6 @@ async fn get_metric_id_description() -> String {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use crate::tools::file_tools::common::{apply_modifications_to_content, Modification};
@ -459,15 +458,6 @@ mod tests {
#[test]
fn test_tool_parameter_validation() {
let tool = ModifyMetricFilesTool {
agent: Arc::new(Agent::new(
"o3-mini".to_string(),
HashMap::new(),
Uuid::new_v4(),
Uuid::new_v4(),
"test_agent".to_string(),
)),
};
// Test valid parameters
let valid_params = json!({

View File

@ -466,7 +466,7 @@ mod tests {
"yml_content": "description: Test dataset\nschema:\n - name: id\n type: uuid"
});
let parsed = parse_search_result(&result).unwrap();
let _ = parse_search_result(&result).unwrap();
}
#[test]

View File

@ -40,7 +40,6 @@ pub struct MessageWithUser {
#[derive(Queryable)]
struct AssetPermissionInfo {
identity_id: Uuid,
role: AssetPermissionRole,
email: String,
name: Option<String>,
@ -54,8 +53,8 @@ pub async fn get_chat_handler(chat_id: &Uuid, user_id: &Uuid) -> Result<ChatWith
Err(e) => return Err(anyhow!("Failed to get database connection: {}", e)),
};
let chat_id = chat_id.clone();
let user_id = user_id.clone();
let chat_id = *chat_id;
let user_id = *user_id;
tokio::spawn(async move {
chats::table
@ -84,7 +83,7 @@ pub async fn get_chat_handler(chat_id: &Uuid, user_id: &Uuid) -> Result<ChatWith
Err(e) => return Err(anyhow!("Failed to get database connection: {}", e)),
};
let chat_id = chat_id.clone();
let chat_id = *chat_id;
tokio::spawn(async move {
messages::table
@ -115,7 +114,7 @@ pub async fn get_chat_handler(chat_id: &Uuid, user_id: &Uuid) -> Result<ChatWith
Err(e) => return Err(anyhow!("Failed to get database connection: {}", e)),
};
let chat_id = chat_id.clone();
let chat_id = *chat_id;
tokio::spawn(async move {
// Query individual permissions for this chat
@ -126,7 +125,6 @@ pub async fn get_chat_handler(chat_id: &Uuid, user_id: &Uuid) -> Result<ChatWith
.filter(asset_permissions::identity_type.eq(IdentityType::User))
.filter(asset_permissions::deleted_at.is_null())
.select((
asset_permissions::identity_id,
asset_permissions::role,
users::email,
users::name,

View File

@ -5,7 +5,10 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use database::{pool::get_pg_pool, schema::messages};
use database::{
pool::get_pg_pool,
schema::{chats, messages},
};
#[derive(Debug, Serialize, Deserialize)]
pub struct GetRawLlmMessagesRequest {
@ -18,13 +21,18 @@ pub struct GetRawLlmMessagesResponse {
pub raw_llm_messages: Value,
}
pub async fn get_raw_llm_messages_handler(chat_id: Uuid) -> Result<GetRawLlmMessagesResponse> {
pub async fn get_raw_llm_messages_handler(
chat_id: Uuid,
organization_id: Uuid,
) -> Result<GetRawLlmMessagesResponse> {
let pool = get_pg_pool();
let mut conn = pool.get().await?;
// Get messages for the chat, ordered by creation time
let raw_llm_messages: Value = messages::table
.inner_join(chats::table.on(messages::chat_id.eq(chats::id)))
.filter(messages::chat_id.eq(chat_id))
.filter(chats::organization_id.eq(organization_id))
.filter(messages::deleted_at.is_null())
.order_by(messages::created_at.desc())
.select(messages::raw_llm_messages)

View File

@ -26,7 +26,7 @@ pub async fn create_chat_sharing_handler(
) -> Result<()> {
// 1. Validate the chat exists
let chat_exists = get_chat_exists(chat_id).await?;
if !chat_exists {
return Err(anyhow!("Chat not found"));
}
@ -38,7 +38,8 @@ pub async fn create_chat_sharing_handler(
*user_id,
IdentityType::User,
AssetPermissionRole::FullAccess, // Owner role implicitly has FullAccess permissions
).await?;
)
.await?;
if !has_permission {
return Err(anyhow!("User does not have permission to share this chat"));
@ -46,19 +47,22 @@ pub async fn create_chat_sharing_handler(
// 3. Process each email and create sharing permissions
for (email, role) in emails_and_roles {
match create_share_by_email(
&email,
*chat_id,
AssetType::Chat,
role,
*user_id,
).await {
match create_share_by_email(&email, *chat_id, AssetType::Chat, role, *user_id).await {
Ok(_) => {
tracing::info!("Created sharing permission for email: {} on chat: {} with role: {:?}", email, chat_id, role);
},
tracing::info!(
"Created sharing permission for email: {} on chat: {} with role: {:?}",
email,
chat_id,
role
);
}
Err(e) => {
tracing::error!("Failed to create sharing for email {}: {}", email, e);
return Err(anyhow!("Failed to create sharing for email {}: {}", email, e));
return Err(anyhow!(
"Failed to create sharing for email {}: {}",
email,
e
));
}
}
}
@ -69,55 +73,16 @@ pub async fn create_chat_sharing_handler(
/// Helper function to check if a chat exists
pub async fn get_chat_exists(chat_id: &Uuid) -> Result<bool> {
let mut conn = get_pg_pool().get().await?;
let chat_exists = chats::table
.filter(chats::id.eq(chat_id))
.filter(chats::deleted_at.is_null())
.count()
.get_result::<i64>(&mut conn)
.await?;
Ok(chat_exists > 0)
}
#[cfg(test)]
mod tests {
use super::*;
use database::enums::{AssetPermissionRole, AssetType, IdentityType};
use uuid::Uuid;
// Mock function to test permission checking
async fn mock_has_permission(
_asset_id: Uuid,
_asset_type: AssetType,
_identity_id: Uuid,
_identity_type: IdentityType,
_required_role: AssetPermissionRole,
) -> Result<bool> {
// For testing, return true to simulate having permission
Ok(true)
}
// This test would require a test database setup
// Mock implementation to demonstrate testing approach
#[tokio::test]
async fn test_create_chat_sharing_handler_permissions() {
// This test would need a properly mocked database
// Just demonstrating the structure
// Setup test data
let chat_id = Uuid::new_v4();
let user_id = Uuid::new_v4();
let emails_and_roles = vec![
("test@example.com".to_string(), AssetPermissionRole::Viewer),
];
// In a real test, we would use a test database
// and set up the necessary mocks
// Example assertion
// let result = create_chat_sharing_handler(&chat_id, &user_id, emails_and_roles).await;
// assert!(result.is_ok());
}
}
mod tests {}

View File

@ -19,7 +19,6 @@ use crate::collections::types::{
#[derive(Queryable)]
struct AssetPermissionInfo {
identity_id: Uuid,
role: AssetPermissionRole,
email: String,
name: Option<String>,
@ -87,7 +86,6 @@ pub async fn get_collection_handler(
.filter(asset_permissions::identity_type.eq(IdentityType::User))
.filter(asset_permissions::deleted_at.is_null())
.select((
asset_permissions::identity_id,
asset_permissions::role,
users::email,
users::name,

View File

@ -97,22 +97,11 @@ pub async fn create_collection_sharing_handler(
#[cfg(test)]
mod tests {
use super::*;
use database::enums::AssetPermissionRole;
use uuid::Uuid;
#[tokio::test]
async fn test_create_collection_sharing_collection_not_found() {
// Test case: Collection not found
// Expected: Error with "Collection not found" message
let collection_id = Uuid::new_v4();
let user_id = Uuid::new_v4();
let request = vec![ShareRecipient {
email: "test@example.com".to_string(),
role: AssetPermissionRole::Viewer,
}];
// Since we can't easily mock the function in an integration test
// This is just a placeholder for the real test
// A proper test would use a test database or more sophisticated mocking

View File

@ -82,17 +82,12 @@ pub async fn delete_collection_sharing_handler(
#[cfg(test)]
mod tests {
use uuid::Uuid;
#[tokio::test]
async fn test_delete_collection_sharing_collection_not_found() {
// Test case: Collection not found
// Expected: Error with "Collection not found" message
let collection_id = Uuid::new_v4();
let user_id = Uuid::new_v4();
let emails = vec!["test@example.com".to_string()];
// Since we can't easily mock the function in an integration test
// This is just a placeholder for the real test
// A proper test would use a test database or more sophisticated mocking

View File

@ -1,9 +1,7 @@
use anyhow::{anyhow, Result};
use chrono::Utc;
use database::{
collections::fetch_collection,
enums::AssetPermissionRole,
pool::get_pg_pool,
collections::fetch_collection, enums::AssetPermissionRole, pool::get_pg_pool,
schema::collections,
};
use diesel::{update, ExpressionMethods};
@ -12,9 +10,7 @@ use std::sync::Arc;
use tokio;
use uuid::Uuid;
use crate::collections::types::{
CollectionState, UpdateCollectionObject, UpdateCollectionRequest,
};
use crate::collections::types::{CollectionState, UpdateCollectionObject, UpdateCollectionRequest};
/// Handler for updating a collection
///
@ -186,11 +182,6 @@ async fn update_collection_record(
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Once;
use uuid::Uuid;
static INIT: Once = Once::new();
// Notice: This is a basic smoke test to check that our changes compile and work
// It doesn't actually hit the database, but verifies that the handler's structure is correct
@ -211,11 +202,14 @@ mod tests {
// Check that our handler function accepts the request with the correct types
// This is mostly a compilation test to verify our refactoring didn't break the interface
let result = update_collection_handler(&user_id, collection_id, req).await;
// We expect an error since we're not actually hitting the database
assert!(result.is_err());
// Check that the error contains the expected message
assert!(result.unwrap_err().to_string().contains("Collection not found"));
assert!(result
.unwrap_err()
.to_string()
.contains("Collection not found"));
}
}

View File

@ -25,7 +25,9 @@ struct QueryableDashboardFile {
name: String,
file_name: String,
content: Value,
#[allow(dead_code)]
filter: Option<String>,
#[allow(dead_code)]
organization_id: Uuid,
created_by: Uuid,
created_at: chrono::DateTime<chrono::Utc>,
@ -38,7 +40,6 @@ struct QueryableDashboardFile {
#[derive(Queryable)]
struct AssetPermissionInfo {
identity_id: Uuid,
role: AssetPermissionRole,
email: String,
name: Option<String>,
@ -148,7 +149,6 @@ pub async fn get_dashboard_handler(dashboard_id: &Uuid, user_id: &Uuid, version_
.filter(asset_permissions::identity_type.eq(IdentityType::User))
.filter(asset_permissions::deleted_at.is_null())
.select((
asset_permissions::identity_id,
asset_permissions::role,
users::email,
users::name,

View File

@ -1,19 +1,11 @@
use anyhow::{anyhow, Result};
use diesel::{
ExpressionMethods, QueryDsl,
Queryable, Selectable,
};
use chrono::{DateTime, Utc};
use diesel::{ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use chrono::{DateTime, Utc};
use database::{
enums::Verification,
pool::get_pg_pool,
schema::dashboard_files,
};
use database::{enums::Verification, pool::get_pg_pool, schema::dashboard_files};
use super::{BusterDashboardListItem, DashboardMember};
@ -29,16 +21,6 @@ pub struct DashboardsListRequest {
pub only_my_dashboards: Option<bool>,
}
#[derive(Queryable, Selectable)]
#[diesel(table_name = dashboard_files)]
struct QueryableDashboardFile {
id: Uuid,
name: String,
created_by: Uuid,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
pub async fn list_dashboard_handler(
_user_id: &Uuid,
request: DashboardsListRequest,
@ -62,20 +44,17 @@ pub async fn list_dashboard_handler(
))
.filter(dashboard_files::deleted_at.is_null())
.distinct()
.order((dashboard_files::updated_at.desc(), dashboard_files::id.asc()))
.order((
dashboard_files::updated_at.desc(),
dashboard_files::id.asc(),
))
.offset(offset)
.limit(request.page_size)
.into_boxed();
// Execute the query
let dashboard_results = match dashboard_statement
.load::<(
Uuid,
String,
Uuid,
DateTime<Utc>,
DateTime<Utc>,
)>(&mut conn)
.load::<(Uuid, String, Uuid, DateTime<Utc>, DateTime<Utc>)>(&mut conn)
.await
{
Ok(results) => results,
@ -85,26 +64,24 @@ pub async fn list_dashboard_handler(
// Transform query results into BusterDashboardListItem
let dashboards = dashboard_results
.into_iter()
.map(
|(id, name, created_by, created_at, updated_at)| {
let owner = DashboardMember {
id: created_by,
name: "Unknown".to_string(),
avatar_url: None,
};
.map(|(id, name, created_by, created_at, updated_at)| {
let owner = DashboardMember {
id: created_by,
name: "Unknown".to_string(),
avatar_url: None,
};
BusterDashboardListItem {
id,
name,
created_at,
last_edited: updated_at,
owner,
members: vec![],
status: Verification::Verified, // Default status, can be updated if needed
is_shared: false,
}
},
)
BusterDashboardListItem {
id,
name,
created_at,
last_edited: updated_at,
owner,
members: vec![],
status: Verification::Verified, // Default status, can be updated if needed
is_shared: false,
}
})
.collect();
Ok(dashboards)

View File

@ -568,55 +568,6 @@ async fn get_chats_from_collections(
Ok(chat_objects)
}
async fn get_threads_from_collections(
collection_ids: &[Uuid],
) -> Result<Vec<(Uuid, FavoriteObject)>> {
let mut conn = match get_pg_pool().get().await {
Ok(conn) => conn,
Err(e) => return Err(anyhow!("Error getting connection from pool: {:?}", e)),
};
let threads_records: Vec<(Uuid, Uuid, Option<String>)> = match threads_deprecated::table
.inner_join(
collections_to_assets::table.on(threads_deprecated::id.eq(collections_to_assets::asset_id)),
)
.inner_join(messages_deprecated::table.on(threads_deprecated::id.eq(messages_deprecated::thread_id)))
.select((
collections_to_assets::collection_id,
threads_deprecated::id,
messages_deprecated::title,
))
.filter(collections_to_assets::asset_type.eq(AssetType::Thread))
.filter(collections_to_assets::collection_id.eq_any(collection_ids))
.filter(threads_deprecated::deleted_at.is_null())
.filter(collections_to_assets::deleted_at.is_null())
.filter(messages_deprecated::deleted_at.is_null())
.filter(messages_deprecated::draft_session_id.is_null())
.order((threads_deprecated::id, messages_deprecated::created_at.desc()))
.distinct_on(threads_deprecated::id)
.load::<(Uuid, Uuid, Option<String>)>(&mut conn)
.await
{
Ok(threads_records) => threads_records,
Err(e) => return Err(anyhow!("Error loading threads records: {:?}", e)),
};
let thread_objects: Vec<(Uuid, FavoriteObject)> = threads_records
.iter()
.map(|(collection_id, id, name)| {
(
*collection_id,
FavoriteObject {
id: *id,
name: name.clone().unwrap_or_else(|| String::from("Untitled")),
type_: AssetType::Thread,
},
)
})
.collect();
Ok(thread_objects)
}
async fn get_favorite_metrics(metric_ids: Arc<Vec<Uuid>>) -> Result<Vec<FavoriteObject>> {
let mut conn = match get_pg_pool().get().await {
Ok(conn) => conn,

View File

@ -30,7 +30,7 @@ pub struct LogListItem {
pub struct PaginationInfo {
pub has_more: bool,
pub next_page: Option<i32>,
pub total_items: i32, // Number of items in current page
pub total_items: i32, // Number of items in current page
}
#[derive(Debug, Serialize, Deserialize)]
@ -53,35 +53,37 @@ struct ChatWithUser {
}
/// List logs with pagination support
///
///
/// This function efficiently retrieves a list of chats (logs) with their associated user information.
/// It supports pagination using page number and limits results using page_size.
/// Unlike the regular chats endpoint, logs are not restricted to the user and are visible to everyone.
///
///
/// Returns a list of log items with user information and pagination details.
pub async fn list_logs_handler(
request: ListLogsRequest,
organization_id: Uuid,
) -> Result<Vec<LogListItem>, anyhow::Error> {
use database::schema::{chats, users};
let mut conn = get_pg_pool().get().await?;
// Start building the query
let mut query = chats::table
.inner_join(users::table.on(chats::created_by.eq(users::id)))
.filter(chats::deleted_at.is_null())
.filter(chats::organization_id.eq(organization_id))
.into_boxed();
// Calculate offset based on page number
let page = request.page.unwrap_or(1);
let offset = (page - 1) * request.page_size;
// Order by creation date descending and apply pagination
query = query
.order_by(chats::created_at.desc())
.offset(offset as i64)
.limit((request.page_size + 1) as i64);
// Execute query and select required fields
let results: Vec<ChatWithUser> = query
.select((
@ -95,18 +97,19 @@ pub async fn list_logs_handler(
))
.load::<ChatWithUser>(&mut conn)
.await?;
// Check if there are more results and prepare pagination info
let has_more = results.len() > request.page_size as usize;
let items: Vec<LogListItem> = results
.into_iter()
.take(request.page_size as usize)
.map(|chat| {
let created_by_avatar = chat.user_attributes
let created_by_avatar = chat
.user_attributes
.get("avatar")
.and_then(|v| v.as_str())
.map(String::from);
LogListItem {
id: chat.id.to_string(),
title: chat.title,
@ -127,6 +130,6 @@ pub async fn list_logs_handler(
next_page: if has_more { Some(page + 1) } else { None },
total_items: items.len() as i32,
};
Ok(items)
}

View File

@ -24,6 +24,7 @@ struct QueryableMetricFile {
file_name: String,
content: MetricYml,
verification: Verification,
#[allow(dead_code)]
evaluation_obj: Option<Value>,
evaluation_summary: Option<String>,
evaluation_score: Option<f64>,
@ -42,20 +43,8 @@ struct DatasetInfo {
name: String,
}
#[derive(Queryable, Selectable)]
#[diesel(table_name = users)]
struct UserInfo {
id: Uuid,
email: String,
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
name: Option<String>,
#[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Text>)]
avatar_url: Option<String>,
}
#[derive(Queryable)]
struct AssetPermissionInfo {
identity_id: Uuid,
role: AssetPermissionRole,
email: String,
name: Option<String>,
@ -225,7 +214,6 @@ pub async fn get_metric_handler(
.filter(asset_permissions::identity_type.eq(IdentityType::User))
.filter(asset_permissions::deleted_at.is_null())
.select((
asset_permissions::identity_id,
asset_permissions::role,
users::email,
users::name,

View File

@ -29,7 +29,9 @@ pub async fn setup_test_environment() -> Result<()> {
// Initialize database pools only once
INIT.call_once(|| {
init_pools();
// Create a runtime for the sync context and block on the async init_pools function
let rt = tokio::runtime::Runtime::new().expect("Failed to create runtime");
let _ = rt.block_on(init_pools());
});
Ok(())

View File

@ -1005,9 +1005,7 @@ mod tests {
id,
content,
tool_calls,
name,
progress,
initial,
..
} => {
assert_eq!(id, &None);
assert_eq!(content, &None);

View File

@ -43,21 +43,6 @@ pub async fn find_user_by_email(email: &str) -> Result<Option<User>> {
#[cfg(test)]
mod tests {
use super::*;
use database::models::User;
use uuid::Uuid;
fn mock_user() -> User {
User {
id: Uuid::new_v4(),
email: "test@example.com".to_string(),
name: Some("Test User".to_string()),
config: serde_json::Value::Null,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
attributes: serde_json::Value::Null,
avatar_url: Some("https://example.com/avatar.png".to_string()),
}
}
// Test for invalid email format
#[tokio::test]

View File

@ -1,175 +0,0 @@
use anyhow::{anyhow, Result};
use database::schema::{api_keys, users};
use diesel::{ExpressionMethods, JoinOnDsl, QueryDsl};
use diesel_async::RunQueryDsl;
use std::{collections::HashMap, env};
use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::database::{models::User, pool::get_pg_pool};
/// Authentication is done via Bearer token with a JWT issued from Supabase. We also offer API access that
/// is done via a Bearer token with a JWT issued from us.
///
/// The user ID is always included as the `sub` in the JWT.
///
/// In the JWT that we issue, we provide an extra field of `api_key_id` that contains the ID of the API key that is being used.
/// The reason why we have the api_key_id is because a user could have multiple API keys and we want to be able to route to the correct key and track it.
#[derive(Serialize, Deserialize, Debug)]
struct JwtClaims {
pub aud: String,
pub sub: String,
pub exp: u64,
}
pub async fn auth(mut req: Request, next: Next) -> Result<Response, StatusCode> {
let is_ws = req
.headers()
.get("upgrade")
.and_then(|v| v.to_str().ok())
.map(|v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false);
let handle_auth_error = |msg: &str| {
if is_ws {
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("Sec-WebSocket-Protocol", "close")
.header("Sec-WebSocket-Close-Code", "4001") // Custom close code
.header("Sec-WebSocket-Close-Reason", msg)
.body(axum::body::Body::empty())
.unwrap())
} else {
Err(StatusCode::UNAUTHORIZED)
}
};
let buster_wh_token = env::var("BUSTER_WH_TOKEN").expect("BUSTER_WH_TOKEN is not set");
let bearer_token = req.headers().get("Authorization").and_then(|value| {
value.to_str().ok().and_then(|v| {
if v.starts_with("Bearer ") {
v.split_whitespace().nth(1)
} else {
Some(v)
}
})
});
if let Some(token) = bearer_token {
if token == buster_wh_token {
return Ok(next.run(req).await);
}
}
let token = if bearer_token.is_none() {
match req
.uri()
.query()
.and_then(|query| serde_urlencoded::from_str::<HashMap<String, String>>(query).ok())
.and_then(|params| params.get("authentication").cloned())
{
Some(token) => token,
None => {
tracing::error!("No token found in request");
return handle_auth_error("No token found");
}
}
} else {
bearer_token.unwrap().to_string()
};
let user = match authorize_current_user(&token).await {
Ok(user) => match user {
Some(user) => user,
None => return Err(StatusCode::UNAUTHORIZED),
},
Err(e) => {
tracing::error!("Authorization error: {}", e);
return handle_auth_error("invalid jwt");
}
};
req.extensions_mut().insert(user);
Ok(next.run(req).await)
}
async fn authorize_current_user(token: &str) -> Result<Option<User>> {
let pg_pool = get_pg_pool();
let _conn = pg_pool.get().await.map_err(|e| {
tracing::error!("Pool connection error in auth: {:?}", e);
anyhow!("Database connection error in auth")
})?;
let key = env::var("JWT_SECRET").expect("JWT_SECRET is not set");
let mut validation = Validation::new(Algorithm::HS256);
validation.set_audience(&["authenticated", "api"]);
let token_data =
match decode::<JwtClaims>(token, &DecodingKey::from_secret(key.as_ref()), &validation) {
Ok(jwt_claims) => jwt_claims.claims,
Err(e) => {
return Err(anyhow!("Error while decoding the token: {}", e));
}
};
let user = match token_data.aud.contains("api") {
true => find_user_by_api_key(token).await,
false => find_user_by_id(&Uuid::parse_str(&token_data.sub).unwrap()).await,
};
let user = match user {
Ok(user) => user,
Err(e) => {
tracing::error!("Error while querying user: {}", e);
return Err(anyhow!("Error while querying user: {}", e));
}
};
Ok(user)
}
async fn find_user_by_id(id: &Uuid) -> Result<Option<User>> {
let mut conn = match get_pg_pool().get().await {
Ok(conn) => conn,
Err(e) => return Err(anyhow!("Error while querying user: {}", e)),
};
let user = match users::table
.filter(users::id.eq(id))
.first::<User>(&mut conn)
.await
{
Ok(user) => user,
Err(e) => return Err(anyhow!("Error while querying user: {}", e)),
};
Ok(Some(user))
}
async fn find_user_by_api_key(token: &str) -> Result<Option<User>> {
let mut conn = match get_pg_pool().get().await {
Ok(conn) => conn,
Err(e) => return Err(anyhow!("Error while querying user: {}", e)),
};
let user = match users::table
.inner_join(api_keys::table.on(users::id.eq(api_keys::owner_id)))
.filter(api_keys::key.eq(token))
.filter(api_keys::deleted_at.is_null())
.select(users::all_columns)
.first::<User>(&mut conn)
.await
{
Ok(user) => user,
Err(e) => return Err(anyhow!("Error while querying user: {}", e)),
};
Ok(Some(user))
}

View File

@ -1,12 +0,0 @@
use axum::http::{header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}, Method};
use tower_http::cors::{Any, CorsLayer};
pub fn cors() -> CorsLayer {
let cors = CorsLayer::new()
.allow_methods(vec![Method::GET, Method::POST, Method::PUT, Method::DELETE])
.allow_origin(Any)
.allow_headers([AUTHORIZATION, ACCEPT, CONTENT_TYPE]);
cors
}

View File

@ -1,2 +0,0 @@
pub mod auth;
pub mod cors;

View File

@ -1,5 +1,4 @@
mod routes;
mod buster_middleware;
mod types;
pub mod utils;

View File

@ -21,6 +21,7 @@ pub fn router() -> Router {
pub enum ApiResponse<T> {
OK,
#[allow(dead_code)]
Created,
NoContent,
JsonData(T),

View File

@ -374,6 +374,7 @@ async fn is_organization_admin_or_owner(
let user_organization_id = match get_user_organization_id(&user_id).await {
Ok(organization_id) => organization_id,
Err(e) => {
tracing::error!("Error getting user organization id: {}", e);
return Ok(false);
}
};

View File

@ -10,11 +10,25 @@ pub async fn get_chat_raw_llm_messages(
Extension(user): Extension<AuthenticatedUser>,
Path(chat_id): Path<Uuid>,
) -> Result<ApiResponse<GetRawLlmMessagesResponse>, (StatusCode, &'static str)> {
match get_raw_llm_messages_handler(chat_id).await {
let organization_id = match user.organizations.get(0) {
Some(organization) => organization.id,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to get organization id",
));
}
};
match get_raw_llm_messages_handler(chat_id, organization_id).await {
Ok(response) => Ok(ApiResponse::JsonData(response)),
Err(e) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to get raw LLM messages",
)),
Err(e) => {
// Log the error for debugging and monitoring
tracing::error!("Failed to get raw LLM messages: {}", e);
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to get raw LLM messages",
))
}
}
}

View File

@ -1,40 +1,51 @@
use anyhow::Result;
use axum::{
extract::Path,
http::StatusCode,
Extension,
};
use axum::{extract::Path, http::StatusCode, Extension};
use diesel::prelude::*;
use diesel_async::RunQueryDsl;
use uuid::Uuid;
use database::pool::get_pg_pool;
use database::models::DatasetGroup;
use database::schema::dataset_groups;
use crate::routes::rest::ApiResponse;
use super::list_dataset_groups::DatasetGroupInfo;
use crate::routes::rest::ApiResponse;
use database::models::DatasetGroup;
use database::pool::get_pg_pool;
use database::schema::dataset_groups;
use middleware::AuthenticatedUser;
pub async fn get_dataset_group(
Extension(user): Extension<AuthenticatedUser>,
Path(dataset_group_id): Path<Uuid>,
) -> Result<ApiResponse<DatasetGroupInfo>, (StatusCode, &'static str)> {
let dataset_group = match get_dataset_group_handler(dataset_group_id).await {
Ok(group) => group,
Err(e) => {
tracing::error!("Error getting dataset group: {:?}", e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, "Error getting dataset group"));
let organization_id = match user.organizations.get(0) {
Some(organization) => organization.id,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Error getting organization id",
));
}
};
Ok(ApiResponse::JsonData(dataset_group))
match get_dataset_group_handler(dataset_group_id, organization_id).await {
Ok(group) => Ok(ApiResponse::JsonData(group)),
Err(e) => {
tracing::error!("Error getting dataset group: {:?}", e);
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Error getting dataset group",
));
}
}
}
async fn get_dataset_group_handler(dataset_group_id: Uuid) -> Result<DatasetGroupInfo> {
async fn get_dataset_group_handler(
dataset_group_id: Uuid,
organization_id: Uuid,
) -> Result<DatasetGroupInfo> {
let mut conn = get_pg_pool().get().await?;
let dataset_group = dataset_groups::table
.filter(dataset_groups::id.eq(dataset_group_id))
.filter(dataset_groups::organization_id.eq(organization_id))
.filter(dataset_groups::deleted_at.is_null())
.first::<DatasetGroup>(&mut *conn)
.await
@ -46,4 +57,4 @@ async fn get_dataset_group_handler(dataset_group_id: Uuid) -> Result<DatasetGrou
created_at: dataset_group.created_at,
updated_at: dataset_group.updated_at,
})
}
}

View File

@ -1,3 +1,5 @@
#![allow(dead_code, unused_imports, unused_variables)]
use anyhow::Result;
use axum::{extract::Json, Extension};
use chrono::{DateTime, Utc};

View File

@ -83,7 +83,9 @@ struct Measure {
// Add type mapping enum
#[derive(Debug)]
enum ColumnMappingType {
#[allow(dead_code)]
Dimension(String), // String holds the semantic type
#[allow(dead_code)]
Measure(String), // String holds the measure type (e.g., "number")
Unsupported,
}
@ -235,7 +237,7 @@ async fn generate_model_yaml(
// Process each column and categorize as dimension or measure
for col in model_columns {
match map_snowflake_type(&col.type_) {
ColumnMappingType::Dimension(semantic_type) => {
ColumnMappingType::Dimension(_) => {
dimensions.push(Dimension {
name: col.name.clone(),
expr: col.name.clone(),
@ -244,7 +246,7 @@ async fn generate_model_yaml(
searchable: Some(false),
});
}
ColumnMappingType::Measure(measure_type) => {
ColumnMappingType::Measure(_) => {
measures.push(Measure {
name: col.name.clone(),
expr: col.name.clone(),

View File

@ -100,7 +100,10 @@ async fn get_dataset_data_sample_handler(
let sql = format!("SELECT * FROM {}.{} LIMIT 25", schema, database_name);
match query_engine(dataset_id, &sql).await {
Ok(data) => data,
Err(e) => Vec::new(),
Err(e) => {
tracing::error!("Error getting dataset data: {:?}", e);
Vec::new()
}
}
};

View File

@ -104,13 +104,13 @@ async fn list_datasets_handler(
admin_view: Option<bool>,
enabled: Option<bool>,
imported: Option<bool>,
permission_group_id: Option<Uuid>,
_permission_group_id: Option<Uuid>,
_belongs_to: Option<bool>,
data_source_id: Option<Uuid>,
) -> Result<Vec<ListDatasetObject>> {
let page = page.unwrap_or(0);
let page_size = page_size.unwrap_or(25);
let admin_view = admin_view.unwrap_or(false);
let _admin_view = admin_view.unwrap_or(false);
let mut conn = match get_pg_pool().get().await {
Ok(conn) => conn,
@ -281,8 +281,8 @@ async fn get_org_datasets(
async fn get_restricted_user_datasets(
user_id: &Uuid,
page: i64,
page_size: i64,
_page: i64,
_page_size: i64,
) -> Result<Vec<ListDatasetObject>> {
// Direct dataset access
let direct_user_permissioned_datasets_handle = {

View File

@ -1,7 +1,5 @@
use axum::{extract::Query, http::StatusCode, Extension};
use handlers::logs::list_logs_handler::{
list_logs_handler, ListLogsRequest, LogListItem,
};
use handlers::logs::list_logs_handler::{list_logs_handler, ListLogsRequest, LogListItem};
use middleware::AuthenticatedUser;
use serde::Deserialize;
@ -27,7 +25,17 @@ pub async fn list_logs_route(
page_size: query.page_size,
};
match list_logs_handler(request).await {
let organization_id = match user.organizations.get(0) {
Some(organization) => organization.id,
_ => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Error getting organization id",
));
}
};
match list_logs_handler(request, organization_id).await {
Ok(response) => Ok(ApiResponse::JsonData(response)),
Err(e) => {
tracing::error!("Error listing logs: {}", e);

View File

@ -257,7 +257,6 @@ async fn ws_handler(stream: WebSocket, user: AuthenticatedUser, shutdown_tx: Arc
let mut tasks = JoinSet::new();
let start_time = Instant::now();
let last_pong = Arc::new(Mutex::new(Instant::now()));
let (ping_timeout_tx, mut ping_timeout_rx) = oneshot::channel();

View File

@ -241,7 +241,7 @@ pub async fn unsubscribe_from_stream(
.del::<String, redis::Value>(draft_subscription.clone())
.await
{
Ok(response) => {}
Ok(_) => {}
Err(e) => {
tracing::warn!("Error deleting draft subscription key: {}", e);
}
@ -274,6 +274,7 @@ pub async fn set_key_value(key: &String, value: &String) -> Result<()> {
Ok(())
}
#[allow(dead_code)]
pub async fn delete_key_value(key: String) -> Result<()> {
let mut redis_conn = match get_redis_pool().get().await {
Ok(conn) => conn,