mirror of https://github.com/buster-so/buster.git
Merge pull request #265 from buster-so/dal/cli-final-features
Dal/cli final features
This commit is contained in:
commit
9060e9fa74
|
@ -12,6 +12,7 @@ members = [
|
|||
"libs/dataset_security",
|
||||
"libs/email",
|
||||
"libs/stored_values",
|
||||
"libs/raindrop",
|
||||
]
|
||||
resolver = "2"
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ sqlx = { workspace = true }
|
|||
stored_values = { path = "../stored_values" }
|
||||
tokio-retry = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
raindrop = { path = "../raindrop" }
|
||||
sql_analyzer = { path = "../sql_analyzer" }
|
||||
|
||||
# Development dependencies
|
||||
|
|
|
@ -11,9 +11,13 @@ use std::time::{Duration, Instant};
|
|||
use std::{collections::HashMap, env, sync::Arc};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tokio_retry::{strategy::ExponentialBackoff, Retry};
|
||||
use tracing::{error, warn};
|
||||
use tracing::{debug, error, info, instrument, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
// Raindrop imports
|
||||
use raindrop::types::{AiData as RaindropAiData, Event as RaindropEvent};
|
||||
use raindrop::RaindropClient;
|
||||
|
||||
// Type definition for tool registry to simplify complex type
|
||||
// No longer needed, defined below
|
||||
use crate::models::AgentThread;
|
||||
|
@ -561,6 +565,9 @@ impl Agent {
|
|||
trace_builder: Option<TraceBuilder>,
|
||||
parent_span: Option<braintrust::Span>,
|
||||
) -> Result<()> {
|
||||
// Attempt to initialize Raindrop client (non-blocking)
|
||||
let raindrop_client = RaindropClient::new().ok();
|
||||
|
||||
// Set the initial thread
|
||||
{
|
||||
let mut current = agent.current_thread.write().await;
|
||||
|
@ -721,6 +728,35 @@ impl Agent {
|
|||
..Default::default()
|
||||
};
|
||||
|
||||
// --- Track Request with Raindrop ---
|
||||
if let Some(client) = raindrop_client.clone() {
|
||||
let request_clone = request.clone(); // Clone request for tracking
|
||||
let user_id = agent.user_id.clone();
|
||||
let session_id = agent.session_id.to_string();
|
||||
let current_history = agent.get_conversation_history().await.unwrap_or_default();
|
||||
tokio::spawn(async move {
|
||||
let event = RaindropEvent {
|
||||
user_id: user_id.to_string(),
|
||||
event: "llm_request".to_string(),
|
||||
properties: Some(HashMap::from([(
|
||||
"conversation_history".to_string(),
|
||||
serde_json::to_value(¤t_history).unwrap_or(Value::Null),
|
||||
)])),
|
||||
attachments: None,
|
||||
ai_data: Some(RaindropAiData {
|
||||
model: request_clone.model.clone(),
|
||||
input: serde_json::to_string(&request_clone.messages).unwrap_or_default(),
|
||||
output: "".to_string(), // Output is not known yet
|
||||
convo_id: Some(session_id.clone()),
|
||||
}),
|
||||
event_id: None, // Raindrop assigns this
|
||||
timestamp: Some(chrono::Utc::now()),
|
||||
};
|
||||
if let Err(e) = client.track_events(vec![event]).await {}
|
||||
});
|
||||
}
|
||||
// --- End Track Request ---
|
||||
|
||||
// --- Retry Logic for Initial Stream Request ---
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms
|
||||
|
||||
|
@ -980,6 +1016,37 @@ impl Agent {
|
|||
// Update thread with assistant message
|
||||
agent.update_current_thread(final_message.clone()).await?;
|
||||
|
||||
// --- Track Response with Raindrop ---
|
||||
if let Some(client) = raindrop_client {
|
||||
let request_clone = request.clone(); // Clone again for response tracking
|
||||
let final_message_clone = final_message.clone();
|
||||
let user_id = agent.user_id.clone();
|
||||
let session_id = agent.session_id.to_string();
|
||||
// Get history *after* adding the final message
|
||||
let current_history = agent.get_conversation_history().await.unwrap_or_default();
|
||||
tokio::spawn(async move {
|
||||
let event = RaindropEvent {
|
||||
user_id: user_id.to_string(),
|
||||
event: "llm_response".to_string(),
|
||||
properties: Some(HashMap::from([(
|
||||
"conversation_history".to_string(),
|
||||
serde_json::to_value(¤t_history).unwrap_or(Value::Null),
|
||||
)])),
|
||||
attachments: None,
|
||||
ai_data: Some(RaindropAiData {
|
||||
model: request_clone.model.clone(),
|
||||
input: serde_json::to_string(&request_clone.messages).unwrap_or_default(),
|
||||
output: serde_json::to_string(&final_message_clone).unwrap_or_default(),
|
||||
convo_id: Some(session_id.clone()),
|
||||
}),
|
||||
event_id: None, // Raindrop assigns this
|
||||
timestamp: Some(chrono::Utc::now()),
|
||||
};
|
||||
if let Err(e) = client.track_events(vec![event]).await {}
|
||||
});
|
||||
}
|
||||
// --- End Track Response ---
|
||||
|
||||
// Get the updated thread state AFTER adding the final assistant message
|
||||
// This will be used for the potential recursive call later.
|
||||
let mut updated_thread_for_recursion = agent
|
||||
|
@ -1825,4 +1892,3 @@ mod tests {
|
|||
assert_eq!(agent.get_state_bool("bool_key").await, None);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "raindrop"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
# Workspace dependencies
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
thiserror = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["json", "rustls-tls"] }
|
||||
|
||||
# Non-workspace dependencies (if any, prefer workspace)
|
||||
dotenvy = "0.15"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
mockito = { workspace = true }
|
|
@ -0,0 +1,27 @@
|
|||
use thiserror::Error;
|
||||
use reqwest::header::InvalidHeaderValue;
|
||||
|
||||
/// Custom error types for the Raindrop SDK.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum RaindropError {
|
||||
#[error("Missing Raindrop API Write Key. Set the RAINDROP_WRITE_KEY environment variable.")]
|
||||
MissingApiKey,
|
||||
|
||||
#[error("Invalid header value provided: {0}")]
|
||||
InvalidHeaderValue(#[from] InvalidHeaderValue),
|
||||
|
||||
#[error("Failed to build HTTP client: {0}")]
|
||||
HttpClientBuildError(#[from] reqwest::Error),
|
||||
|
||||
#[error("HTTP request failed: {0}")]
|
||||
RequestError(reqwest::Error),
|
||||
|
||||
#[error("Raindrop API error: {status} - {body}")]
|
||||
ApiError {
|
||||
status: reqwest::StatusCode,
|
||||
body: String,
|
||||
},
|
||||
|
||||
#[error("Failed to serialize request body: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
#![doc = "A Rust SDK for interacting with the Raindrop.ai API."]
|
||||
|
||||
pub mod errors;
|
||||
pub mod types;
|
||||
|
||||
use anyhow::Context;
|
||||
use reqwest::{Client, header};
|
||||
use std::env;
|
||||
use tracing::{debug, error, instrument};
|
||||
|
||||
use errors::RaindropError;
|
||||
use types::{Event, Signal};
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "https://api.raindrop.ai/v1";
|
||||
|
||||
/// Client for interacting with the Raindrop API.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RaindropClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
write_key: String,
|
||||
}
|
||||
|
||||
impl RaindropClient {
|
||||
/// Creates a new RaindropClient.
|
||||
/// Reads the write key from the `RAINDROP_WRITE_KEY` environment variable.
|
||||
/// Uses the default Raindrop API base URL.
|
||||
pub fn new() -> Result<Self, RaindropError> {
|
||||
let write_key = env::var("RAINDROP_WRITE_KEY")
|
||||
.map_err(|_| RaindropError::MissingApiKey)?;
|
||||
let base_url = DEFAULT_BASE_URL.to_string();
|
||||
Self::build_client(write_key, base_url)
|
||||
}
|
||||
|
||||
/// Creates a new RaindropClient with a specific write key and base URL.
|
||||
/// Useful for testing or custom deployments.
|
||||
pub fn with_key_and_url(write_key: String, base_url: &str) -> Result<Self, RaindropError> {
|
||||
Self::build_client(write_key, base_url.to_string())
|
||||
}
|
||||
|
||||
/// Builds the underlying reqwest client.
|
||||
fn build_client(write_key: String, base_url: String) -> Result<Self, RaindropError> {
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::AUTHORIZATION,
|
||||
header::HeaderValue::from_str(&format!("Bearer {}", write_key))
|
||||
.map_err(RaindropError::InvalidHeaderValue)?,
|
||||
);
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
let client = Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.map_err(RaindropError::HttpClientBuildError)?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
write_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Tracks a batch of events.
|
||||
#[instrument(skip(self, events), fields(count = events.len()))]
|
||||
pub async fn track_events(&self, events: Vec<Event>) -> Result<(), RaindropError> {
|
||||
if events.is_empty() {
|
||||
debug!("No events to track, skipping API call.");
|
||||
return Ok(());
|
||||
}
|
||||
let url = format!("{}/events/track", self.base_url);
|
||||
self.post_data(&url, &events).await
|
||||
}
|
||||
|
||||
/// Tracks a batch of signals.
|
||||
#[instrument(skip(self, signals), fields(count = signals.len()))]
|
||||
pub async fn track_signals(&self, signals: Vec<Signal>) -> Result<(), RaindropError> {
|
||||
if signals.is_empty() {
|
||||
debug!("No signals to track, skipping API call.");
|
||||
return Ok(());
|
||||
}
|
||||
let url = format!("{}/signals/track", self.base_url);
|
||||
self.post_data(&url, &signals).await
|
||||
}
|
||||
|
||||
/// Helper function to POST JSON data to a specified URL.
|
||||
async fn post_data<T: serde::Serialize>(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &T,
|
||||
) -> Result<(), RaindropError> {
|
||||
debug!(url = url, "Sending POST request to Raindrop");
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.json(data)
|
||||
.send()
|
||||
.await
|
||||
.map_err(RaindropError::RequestError)?;
|
||||
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
debug!(url = url, status = %status, "Raindrop API call successful");
|
||||
Ok(())
|
||||
} else {
|
||||
let body = response.text().await.unwrap_or_else(|_| "Failed to read error body".to_string());
|
||||
error!(url = url, status = %status, body = body, "Raindrop API call failed");
|
||||
Err(RaindropError::ApiError { status, body })
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
use std::collections::HashMap;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Represents a single event to be tracked.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Event {
|
||||
pub user_id: String,
|
||||
pub event: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub properties: Option<HashMap<String, Value>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub attachments: Option<Vec<Attachment>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ai_data: Option<AiData>,
|
||||
|
||||
// Optional fields provided by Raindrop API
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub event_id: Option<String>, // Returned by Raindrop, optional on send
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timestamp: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Represents an attachment associated with an event.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Attachment {
|
||||
#[serde(rename = "type")] // Use `type` keyword in JSON
|
||||
pub attachment_type: String, // e.g., "image", "text", "json"
|
||||
pub value: String, // URL or content
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>, // e.g., "input", "output"
|
||||
}
|
||||
|
||||
/// Represents AI-specific data for an event.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct AiData {
|
||||
pub model: String,
|
||||
pub input: String,
|
||||
pub output: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub convo_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Represents a single signal to be tracked.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Signal {
|
||||
pub event_id: String, // The ID of the event this signal relates to
|
||||
pub signal_name: String, // e.g., "thumbs_down", "corrected_answer"
|
||||
pub signal_type: String, // e.g., "feedback", "correction"
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub properties: Option<HashMap<String, Value>>,
|
||||
|
||||
// Optional fields
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timestamp: Option<DateTime<Utc>>,
|
||||
}
|
Loading…
Reference in New Issue