Merge pull request #265 from buster-so/dal/cli-final-features

Dal/cli final features
This commit is contained in:
dal 2025-05-05 14:27:24 -07:00 committed by GitHub
commit 9060e9fa74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 300 additions and 2 deletions

View File

@ -12,6 +12,7 @@ members = [
"libs/dataset_security",
"libs/email",
"libs/stored_values",
"libs/raindrop",
]
resolver = "2"

View File

@ -34,6 +34,7 @@ sqlx = { workspace = true }
stored_values = { path = "../stored_values" }
tokio-retry = { workspace = true }
thiserror = { workspace = true }
raindrop = { path = "../raindrop" }
sql_analyzer = { path = "../sql_analyzer" }
# Development dependencies

View File

@ -11,9 +11,13 @@ use std::time::{Duration, Instant};
use std::{collections::HashMap, env, sync::Arc};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio_retry::{strategy::ExponentialBackoff, Retry};
use tracing::{error, warn};
use tracing::{debug, error, info, instrument, warn};
use uuid::Uuid;
// Raindrop imports
use raindrop::types::{AiData as RaindropAiData, Event as RaindropEvent};
use raindrop::RaindropClient;
// Type definition for tool registry to simplify complex type
// No longer needed, defined below
use crate::models::AgentThread;
@ -561,6 +565,9 @@ impl Agent {
trace_builder: Option<TraceBuilder>,
parent_span: Option<braintrust::Span>,
) -> Result<()> {
// Attempt to initialize Raindrop client (non-blocking)
let raindrop_client = RaindropClient::new().ok();
// Set the initial thread
{
let mut current = agent.current_thread.write().await;
@ -721,6 +728,35 @@ impl Agent {
..Default::default()
};
// --- Track Request with Raindrop ---
if let Some(client) = raindrop_client.clone() {
let request_clone = request.clone(); // Clone request for tracking
let user_id = agent.user_id.clone();
let session_id = agent.session_id.to_string();
let current_history = agent.get_conversation_history().await.unwrap_or_default();
tokio::spawn(async move {
let event = RaindropEvent {
user_id: user_id.to_string(),
event: "llm_request".to_string(),
properties: Some(HashMap::from([(
"conversation_history".to_string(),
serde_json::to_value(&current_history).unwrap_or(Value::Null),
)])),
attachments: None,
ai_data: Some(RaindropAiData {
model: request_clone.model.clone(),
input: serde_json::to_string(&request_clone.messages).unwrap_or_default(),
output: "".to_string(), // Output is not known yet
convo_id: Some(session_id.clone()),
}),
event_id: None, // Raindrop assigns this
timestamp: Some(chrono::Utc::now()),
};
if let Err(e) = client.track_events(vec![event]).await {}
});
}
// --- End Track Request ---
// --- Retry Logic for Initial Stream Request ---
let retry_strategy = ExponentialBackoff::from_millis(100).take(3); // Retry 3 times, ~100ms, ~200ms, ~400ms
@ -980,6 +1016,37 @@ impl Agent {
// Update thread with assistant message
agent.update_current_thread(final_message.clone()).await?;
// --- Track Response with Raindrop ---
if let Some(client) = raindrop_client {
let request_clone = request.clone(); // Clone again for response tracking
let final_message_clone = final_message.clone();
let user_id = agent.user_id.clone();
let session_id = agent.session_id.to_string();
// Get history *after* adding the final message
let current_history = agent.get_conversation_history().await.unwrap_or_default();
tokio::spawn(async move {
let event = RaindropEvent {
user_id: user_id.to_string(),
event: "llm_response".to_string(),
properties: Some(HashMap::from([(
"conversation_history".to_string(),
serde_json::to_value(&current_history).unwrap_or(Value::Null),
)])),
attachments: None,
ai_data: Some(RaindropAiData {
model: request_clone.model.clone(),
input: serde_json::to_string(&request_clone.messages).unwrap_or_default(),
output: serde_json::to_string(&final_message_clone).unwrap_or_default(),
convo_id: Some(session_id.clone()),
}),
event_id: None, // Raindrop assigns this
timestamp: Some(chrono::Utc::now()),
};
if let Err(e) = client.track_events(vec![event]).await {}
});
}
// --- End Track Response ---
// Get the updated thread state AFTER adding the final assistant message
// This will be used for the potential recursive call later.
let mut updated_thread_for_recursion = agent
@ -1825,4 +1892,3 @@ mod tests {
assert_eq!(agent.get_state_bool("bool_key").await, None);
}
}

View File

@ -0,0 +1,22 @@
[package]
name = "raindrop"
version = "0.1.0"
edition = "2021"
[dependencies]
# Workspace dependencies
anyhow = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tokio = { workspace = true, features = ["full"] }
tracing = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
thiserror = { workspace = true }
reqwest = { workspace = true, features = ["json", "rustls-tls"] }
# Non-workspace dependencies (if any, prefer workspace)
dotenvy = "0.15"
[dev-dependencies]
tokio-test = { workspace = true }
mockito = { workspace = true }

View File

@ -0,0 +1,27 @@
use thiserror::Error;
use reqwest::header::InvalidHeaderValue;
/// Custom error types for the Raindrop SDK.
#[derive(Error, Debug)]
pub enum RaindropError {
#[error("Missing Raindrop API Write Key. Set the RAINDROP_WRITE_KEY environment variable.")]
MissingApiKey,
#[error("Invalid header value provided: {0}")]
InvalidHeaderValue(#[from] InvalidHeaderValue),
#[error("Failed to build HTTP client: {0}")]
HttpClientBuildError(#[from] reqwest::Error),
#[error("HTTP request failed: {0}")]
RequestError(reqwest::Error),
#[error("Raindrop API error: {status} - {body}")]
ApiError {
status: reqwest::StatusCode,
body: String,
},
#[error("Failed to serialize request body: {0}")]
SerializationError(#[from] serde_json::Error),
}

View File

@ -0,0 +1,114 @@
#![doc = "A Rust SDK for interacting with the Raindrop.ai API."]
pub mod errors;
pub mod types;
use anyhow::Context;
use reqwest::{Client, header};
use std::env;
use tracing::{debug, error, instrument};
use errors::RaindropError;
use types::{Event, Signal};
const DEFAULT_BASE_URL: &str = "https://api.raindrop.ai/v1";
/// Client for interacting with the Raindrop API.
#[derive(Debug, Clone)]
pub struct RaindropClient {
client: Client,
base_url: String,
write_key: String,
}
impl RaindropClient {
/// Creates a new RaindropClient.
/// Reads the write key from the `RAINDROP_WRITE_KEY` environment variable.
/// Uses the default Raindrop API base URL.
pub fn new() -> Result<Self, RaindropError> {
let write_key = env::var("RAINDROP_WRITE_KEY")
.map_err(|_| RaindropError::MissingApiKey)?;
let base_url = DEFAULT_BASE_URL.to_string();
Self::build_client(write_key, base_url)
}
/// Creates a new RaindropClient with a specific write key and base URL.
/// Useful for testing or custom deployments.
pub fn with_key_and_url(write_key: String, base_url: &str) -> Result<Self, RaindropError> {
Self::build_client(write_key, base_url.to_string())
}
/// Builds the underlying reqwest client.
fn build_client(write_key: String, base_url: String) -> Result<Self, RaindropError> {
let mut headers = header::HeaderMap::new();
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_str(&format!("Bearer {}", write_key))
.map_err(RaindropError::InvalidHeaderValue)?,
);
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
let client = Client::builder()
.default_headers(headers)
.build()
.map_err(RaindropError::HttpClientBuildError)?;
Ok(Self {
client,
base_url,
write_key,
})
}
/// Tracks a batch of events.
#[instrument(skip(self, events), fields(count = events.len()))]
pub async fn track_events(&self, events: Vec<Event>) -> Result<(), RaindropError> {
if events.is_empty() {
debug!("No events to track, skipping API call.");
return Ok(());
}
let url = format!("{}/events/track", self.base_url);
self.post_data(&url, &events).await
}
/// Tracks a batch of signals.
#[instrument(skip(self, signals), fields(count = signals.len()))]
pub async fn track_signals(&self, signals: Vec<Signal>) -> Result<(), RaindropError> {
if signals.is_empty() {
debug!("No signals to track, skipping API call.");
return Ok(());
}
let url = format!("{}/signals/track", self.base_url);
self.post_data(&url, &signals).await
}
/// Helper function to POST JSON data to a specified URL.
async fn post_data<T: serde::Serialize>(
&self,
url: &str,
data: &T,
) -> Result<(), RaindropError> {
debug!(url = url, "Sending POST request to Raindrop");
let response = self
.client
.post(url)
.json(data)
.send()
.await
.map_err(RaindropError::RequestError)?;
let status = response.status();
if status.is_success() {
debug!(url = url, status = %status, "Raindrop API call successful");
Ok(())
} else {
let body = response.text().await.unwrap_or_else(|_| "Failed to read error body".to_string());
error!(url = url, status = %status, body = body, "Raindrop API call failed");
Err(RaindropError::ApiError { status, body })
}
}
}

View File

@ -0,0 +1,67 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Represents a single event to be tracked.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Event {
pub user_id: String,
pub event: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub attachments: Option<Vec<Attachment>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ai_data: Option<AiData>,
// Optional fields provided by Raindrop API
#[serde(skip_serializing_if = "Option::is_none")]
pub event_id: Option<String>, // Returned by Raindrop, optional on send
#[serde(skip_serializing_if = "Option::is_none")]
pub timestamp: Option<DateTime<Utc>>,
}
/// Represents an attachment associated with an event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Attachment {
#[serde(rename = "type")] // Use `type` keyword in JSON
pub attachment_type: String, // e.g., "image", "text", "json"
pub value: String, // URL or content
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>, // e.g., "input", "output"
}
/// Represents AI-specific data for an event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct AiData {
pub model: String,
pub input: String,
pub output: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub convo_id: Option<String>,
}
/// Represents a single signal to be tracked.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Signal {
pub event_id: String, // The ID of the event this signal relates to
pub signal_name: String, // e.g., "thumbs_down", "corrected_answer"
pub signal_type: String, // e.g., "feedback", "correction"
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Value>>,
// Optional fields
#[serde(skip_serializing_if = "Option::is_none")]
pub user_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub timestamp: Option<DateTime<Utc>>,
}