diff --git a/api/libs/agents/src/tools/categories/file_tools/create_metrics.rs b/api/libs/agents/src/tools/categories/file_tools/create_metrics.rs index e1093445d..7ab920515 100644 --- a/api/libs/agents/src/tools/categories/file_tools/create_metrics.rs +++ b/api/libs/agents/src/tools/categories/file_tools/create_metrics.rs @@ -12,6 +12,9 @@ use database::{ }; use diesel::insert_into; use diesel_async::RunQueryDsl; +use futures::future::join_all; +use indexmap::IndexMap; +use query_engine::data_types::DataType; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -87,27 +90,44 @@ impl ToolExecutor for CreateMetricFilesTool { let mut created_files = vec![]; let mut failed_files = vec![]; - // Process metric files + // Create futures for concurrent processing + let process_futures = files + .into_iter() + .map(|file| { + let tool_call_id_clone = tool_call_id.clone(); + let user_id = self.agent.get_user_id(); + + async move { + let result = process_metric_file( + tool_call_id_clone, + file.name.clone(), + file.yml_content.clone(), + &user_id, + ) + .await; + + (file.name.clone(), result) + } + }) + .collect::>(); + + // Wait for all futures to complete + let results = join_all(process_futures).await; + + // Process results let mut metric_records = vec![]; let mut metric_ymls = vec![]; let mut results_vec = vec![]; - // First pass - validate and prepare all records - for file in files { - match process_metric_file( - tool_call_id.clone(), - file.name.clone(), - file.yml_content.clone(), - &self.agent.get_user_id(), - ) - .await - { + + for (file_name, result) in results { + match result { Ok((metric_file, metric_yml, message, results)) => { metric_records.push(metric_file); metric_ymls.push(metric_yml); results_vec.push((message, results)); } Err(e) => { - failed_files.push((file.name, e)); + failed_files.push((file_name, e)); } } } diff --git a/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs b/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs index 7e636ea39..047f567b0 100644 --- a/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs +++ b/api/libs/agents/src/tools/categories/file_tools/modify_metrics.rs @@ -11,6 +11,7 @@ use database::{ }; use diesel::{upsert::excluded, ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; +use futures::future::join_all; use indexmap::IndexMap; use query_engine::{data_source_query_routes::query_engine::query_engine, data_types::DataType}; use serde_json::Value; @@ -114,92 +115,103 @@ impl ToolExecutor for ModifyMetricFilesTool { .await { Ok(files) => { - for file in files { - if let Some(modifications) = file_map.get(&file.id) { - match process_metric_file_modification( - file.clone(), - modifications, - start_time.elapsed().as_millis() as i64, - ) - .await - { - Ok(( - mut metric_file, - metric_yml, - results, - validation_message, - validation_results, - )) => { - // Calculate next version number from existing version history - let next_version = - match metric_file.version_history.get_latest_version() { - Some(version) => version.version_number + 1, - None => 1, - }; + // Create futures for concurrent processing of file modifications + let modification_futures = files + .into_iter() + .filter_map(|file| { + let modifications = file_map.get(&file.id)?; + let start_time_elapsed = start_time.elapsed().as_millis() as i64; + + Some(async move { + let result = process_metric_file_modification( + file.clone(), + modifications, + start_time_elapsed, + ).await; + + match result { + Ok((metric_file, metric_yml, results, validation_message, validation_results)) => { + Ok((metric_file, metric_yml, results, validation_message, validation_results)) + } + Err(e) => Err((modifications.file_name.clone(), e.to_string())), + } + }) + }) + .collect::>(); + + // Wait for all futures to complete + let results = join_all(modification_futures).await; + + // Process results + for result in results { + match result { + Ok((mut metric_file, metric_yml, results, validation_message, validation_results)) => { + // Calculate next version number from existing version history + let next_version = match metric_file.version_history.get_latest_version() { + Some(version) => version.version_number + 1, + None => 1, + }; - // Add new version to history - metric_file - .version_history - .add_version(next_version, metric_yml.clone()); + // Add new version to history + metric_file + .version_history + .add_version(next_version, metric_yml.clone()); - // Update metadata if SQL has changed - // The SQL is already validated by process_metric_file_modification - if results.iter().any(|r| r.modification_type == "content") { - // Update the name field from the metric_yml - // This is redundant but ensures the name is set correctly - metric_file.name = metric_yml.name.clone(); + // Update metadata if SQL has changed + // The SQL is already validated by process_metric_file_modification + if results.iter().any(|r| r.modification_type == "content") { + // Update the name field from the metric_yml + // This is redundant but ensures the name is set correctly + metric_file.name = metric_yml.name.clone(); - // Check if we have a dataset to work with - if !metric_yml.dataset_ids.is_empty() { - let dataset_id = metric_yml.dataset_ids[0]; + // Check if we have a dataset to work with + if !metric_yml.dataset_ids.is_empty() { + let dataset_id = metric_yml.dataset_ids[0]; - // Get data source for the dataset - match datasets::table - .filter(datasets::id.eq(dataset_id)) - .select(datasets::data_source_id) - .first::(&mut conn) + // Get data source for the dataset + match datasets::table + .filter(datasets::id.eq(dataset_id)) + .select(datasets::data_source_id) + .first::(&mut conn) + .await + { + Ok(data_source_id) => { + // Execute query to get metadata + match query_engine( + &data_source_id, + &metric_yml.sql, + Some(100), + ) .await - { - Ok(data_source_id) => { - // Execute query to get metadata - match query_engine( - &data_source_id, - &metric_yml.sql, - Some(100), - ) - .await - { - Ok(query_result) => { - // Update metadata - metric_file.data_metadata = - Some(query_result.metadata); - debug!("Updated metadata for metric file {}", metric_file.id); - } - Err(e) => { - debug!("Failed to execute SQL for metadata: {}", e); - // Continue with the update even if metadata refresh fails - } + { + Ok(query_result) => { + // Update metadata + metric_file.data_metadata = + Some(query_result.metadata); + debug!("Updated metadata for metric file {}", metric_file.id); + } + Err(e) => { + debug!("Failed to execute SQL for metadata: {}", e); + // Continue with the update even if metadata refresh fails } } - Err(e) => { - debug!("Failed to get data source ID: {}", e); - // Continue with the update even if we can't get data source - } + } + Err(e) => { + debug!("Failed to get data source ID: {}", e); + // Continue with the update even if we can't get data source } } } + } - batch.files.push(metric_file); - batch.ymls.push(metric_yml); - batch.modification_results.extend(results); - batch.validation_messages.push(validation_message); - batch.validation_results.push(validation_results); - } - Err(e) => { - batch - .failed_modifications - .push((modifications.file_name.clone(), e.to_string())); - } + batch.files.push(metric_file); + batch.ymls.push(metric_yml); + batch.modification_results.extend(results); + batch.validation_messages.push(validation_message); + batch.validation_results.push(validation_results); + } + Err((file_name, error)) => { + batch.failed_modifications.push((file_name, error)); } } }