From 3d5c05f89d2eb6651e9ec8123789b135e68f2af8 Mon Sep 17 00:00:00 2001 From: dal Date: Tue, 25 Feb 2025 20:36:56 -0700 Subject: [PATCH] match redshift with postgres --- .../data_source_query_routes/query_router.rs | 2 +- .../redshift_query.rs | 133 +++++++++++++----- 2 files changed, 96 insertions(+), 39 deletions(-) diff --git a/api/src/utils/query_engine/data_source_query_routes/query_router.rs b/api/src/utils/query_engine/data_source_query_routes/query_router.rs index 474cbd74d..aa4babf04 100644 --- a/api/src/utils/query_engine/data_source_query_routes/query_router.rs +++ b/api/src/utils/query_engine/data_source_query_routes/query_router.rs @@ -129,7 +129,7 @@ async fn route_to_query( let redshift_client = get_redshift_connection(&credentials).await?; - let results = match redshift_query(redshift_client, sql.clone()).await { + let results = match redshift_query(redshift_client, sql.clone(), limit).await { Ok(results) => results, Err(e) => { tracing::error!("There was an issue while fetching the tables: {}", e); diff --git a/api/src/utils/query_engine/data_source_query_routes/redshift_query.rs b/api/src/utils/query_engine/data_source_query_routes/redshift_query.rs index 49eee66a7..a28fa2ddf 100644 --- a/api/src/utils/query_engine/data_source_query_routes/redshift_query.rs +++ b/api/src/utils/query_engine/data_source_query_routes/redshift_query.rs @@ -3,62 +3,119 @@ use futures::TryStreamExt; use indexmap::IndexMap; use anyhow::{Error, Result}; -use sqlx::{types::BigDecimal, Column, Pool, Postgres, Row}; -use num_traits::cast::ToPrimitive; +use sqlx::{Column, Pool, Postgres, Row}; +use tokio::task; use crate::utils::query_engine::data_types::DataType; pub async fn redshift_query( pg_pool: Pool, query: String, + limit: Option, ) -> Result>, Error> { let mut stream = sqlx::query(&query).fetch(&pg_pool); let mut result: Vec> = Vec::new(); - let mut count = 0; + let batch_size = 100; + + let mut rows = Vec::new(); while let Some(row) = stream.try_next().await? { - let mut row_map: IndexMap = IndexMap::new(); + rows.push(row); + count += 1; - for (i, column) in row.columns().iter().enumerate() { - let column_name = column.name(); - let type_info = column.type_info().clone().to_string(); - let column_value = match type_info.as_str() { - "BOOL" => DataType::Bool(Some(row.get::(i))), - "BYTEA" => DataType::Bytea(Some(row.get::, _>(i))), - "CHAR" => DataType::Char(Some(row.get::(i))), - "INT8" => DataType::Int8(Some(row.get::(i))), - "INT4" => DataType::Int4(Some(row.get::(i))), - "INT2" => DataType::Int2(Some(row.get::(i))), - "TEXT" | "VARCHAR" => DataType::Text(Some(row.get::(i))), - "FLOAT4" => DataType::Float4(Some(row.get::(i))), - "FLOAT8" => DataType::Float8(Some(row.get::(i))), - "NUMERIC" => { - let value: BigDecimal = row.get::(i); - let value: f64 = value.to_f64().unwrap(); - DataType::Float8(Some(value)) - } - "UUID" => DataType::Uuid(Some(row.get::(i))), - "TIMESTAMP" => DataType::Timestamp(Some(row.get::(i))), - "DATE" => DataType::Date(Some(row.get::(i))), - "TIME" => DataType::Time(Some(row.get::(i))), - "TIMESTAMPTZ" => { - DataType::Timestamptz(Some(row.get::, _>(i))) - } - "JSON" | "JSONB" => DataType::Json(Some(row.get::(i))), - _ => DataType::Unknown(Some(row.get::(i))), - }; - - row_map.insert(column_name.to_string(), column_value); + if count % batch_size == 0 { + let batch_result = process_batch(rows).await?; + result.extend(batch_result); + rows = Vec::new(); } - result.push(row_map); - - count += 1; - if count >= 1000 { + if let Some(limit) = limit { + if count >= limit { + break; + } + } else if count >= 1000 { + // Default limit of 1000 if no limit is specified break; } } + + // Process any remaining rows + if !rows.is_empty() { + let batch_result = process_batch(rows).await?; + result.extend(batch_result); + } + Ok(result) } + +async fn process_batch( + rows: Vec, +) -> Result>, Error> { + let mut tasks = Vec::new(); + + for (index, row) in rows.into_iter().enumerate() { + let task = task::spawn(async move { + let mut row_map: IndexMap = IndexMap::with_capacity(row.len()); + + for (i, column) in row.columns().iter().enumerate() { + let column_name = column.name(); + let type_info = column.type_info().clone().to_string(); + let column_value = match type_info.as_str() { + "BOOL" => DataType::Bool(row.try_get::(i).ok()), + "BYTEA" => DataType::Bytea(row.try_get::, _>(i).ok()), + "CHAR" => DataType::Char(row.try_get::(i).ok()), + "INT8" => DataType::Int8(row.try_get::(i).ok()), + "INT4" => DataType::Int4(row.try_get::(i).ok()), + "INT2" => DataType::Int2(row.try_get::(i).ok()), + "TEXT" | "VARCHAR" => DataType::Text(row.try_get::(i).ok()), + "FLOAT4" => DataType::Float4(row.try_get::(i).ok()), + "FLOAT8" => DataType::Float8(row.try_get::(i).ok()), + "NUMERIC" => { + DataType::Float8(row.try_get(i).ok().and_then( + |v: sqlx::types::BigDecimal| v.to_string().parse::().ok(), + )) + } + "UUID" => DataType::Uuid(row.try_get::(i).ok()), + "TIMESTAMP" => { + DataType::Timestamp(row.try_get::(i).ok()) + } + "DATE" => DataType::Date(row.try_get::(i).ok()), + "TIME" => DataType::Time(row.try_get::(i).ok()), + "TIMESTAMPTZ" => { + DataType::Timestamptz(row.try_get::, _>(i).ok()) + } + "JSON" | "JSONB" => DataType::Json(row.try_get::(i).ok()), + _ => DataType::Unknown(row.try_get::(i).ok()), + }; + + row_map.insert(column_name.to_string(), column_value); + } + + (index, row_map) + }); + + tasks.push(task); + } + + let batch_result: Vec<_> = match futures::future::try_join_all(tasks).await { + Ok(batch_result) => batch_result, + Err(e) => { + tracing::error!("Error joining tasks: {:?}", e); + Vec::new() + } + }; + + let mut sorted_result: Vec<(usize, IndexMap)> = + batch_result.into_iter().collect(); + + sorted_result.sort_by_key(|(index, _)| *index); + + let final_result: Vec> = sorted_result + .into_iter() + .map(|(_, row_map)| row_map) + .collect(); + + Ok(final_result) +}