query engine optimization

dal committed 2025-04-01 13:22:35 -06:00
parent 74a2c4a493
commit c6a22b12d9
No known key found for this signature in database
GPG Key ID: 16F4B0E1E9F61122
7 changed files with 423 additions and 263 deletions

View File

@@ -5,6 +5,7 @@ use chrono::{NaiveDate, NaiveTime};
use gcp_bigquery_client::{model::query_request::QueryRequest, Client};
use serde_json::{Number, Value};
use crate::data_types::DataType;
pub async fn bigquery_query(
@@ -78,42 +79,94 @@ pub async fn bigquery_query(
Ok(typed_rows)
}
fn parse_string_to_datatype(s: &str) -> DataType {
if let Ok(value) = s.parse::<i32>() {
DataType::Int4(Some(value))
} else if let Ok(value) = s.parse::<i64>() {
DataType::Int8(Some(value))
} else if let Ok(value) = s.parse::<f32>() {
DataType::Float4(Some(value))
} else if let Ok(value) = s.parse::<f64>() {
DataType::Float8(Some(value))
} else if let Ok(value) = s.parse::<bool>() {
DataType::Bool(Some(value))
} else if let Ok(value) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
DataType::Date(Some(value))
} else if let Ok(value) = NaiveTime::parse_from_str(s, "%H:%M:%S%.f") {
DataType::Time(Some(value))
} else if let Ok(value) = serde_json::from_str::<Value>(s) {
DataType::Json(Some(value))
} else {
DataType::Text(Some(s.to_string()))
#[cfg_attr(test, allow(dead_code))]
pub fn parse_string_to_datatype(s: &str) -> DataType {
// Fast path for empty strings or simple text
if s.is_empty() || !s.starts_with(&['-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f', '{', '['][..]) {
return DataType::Text(Some(s.to_string()));
}
// Check for boolean values first (very fast)
if s == "true" {
return DataType::Bool(Some(true));
} else if s == "false" {
return DataType::Bool(Some(false));
}
// Try to parse as integer first
if let Ok(value) = s.parse::<i32>() {
return DataType::Int4(Some(value));
}
// Check first character for efficiency
match s.chars().next().unwrap() {
// Likely number
'-' | '0'..='9' => {
// Try larger integer types
if let Ok(value) = s.parse::<i64>() {
return DataType::Int8(Some(value));
}
// Try floating point
if let Ok(value) = s.parse::<f64>() {
if value >= f32::MIN as f64 && value <= f32::MAX as f64 {
return DataType::Float4(Some(value as f32));
} else {
return DataType::Float8(Some(value));
}
}
// Check for date format (YYYY-MM-DD)
if s.len() == 10 && s.chars().nth(4) == Some('-') && s.chars().nth(7) == Some('-') {
if let Ok(value) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
return DataType::Date(Some(value));
}
}
// Check for time format
if s.contains(':') && s.len() >= 8 {
if let Ok(value) = NaiveTime::parse_from_str(s, "%H:%M:%S%.f") {
return DataType::Time(Some(value));
}
}
},
// Likely JSON object or array
'{' | '[' => {
if let Ok(value) = serde_json::from_str::<Value>(s) {
return DataType::Json(Some(value));
}
},
_ => {}
}
// Default to text
DataType::Text(Some(s.to_string()))
}
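A self-contained sketch of the first-character dispatch used above, with a simplified `Parsed` enum standing in for the crate's `DataType` (all names here are illustrative):

```rust
// Sketch: route a string by its first character before attempting any
// expensive parse. `Parsed` is a stand-in for DataType.
#[derive(Debug, PartialEq)]
enum Parsed {
    Int(i64),
    Bool(bool),
    Text(String),
}

fn classify(s: &str) -> Parsed {
    // Anything that cannot begin a number, boolean, or JSON value is text.
    if s.is_empty()
        || !s.starts_with(&['-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f', '{', '['][..])
    {
        return Parsed::Text(s.to_string());
    }
    // Exact boolean matches are the cheapest check after the fast path.
    match s {
        "true" => return Parsed::Bool(true),
        "false" => return Parsed::Bool(false),
        _ => {}
    }
    if let Ok(v) = s.parse::<i64>() {
        return Parsed::Int(v);
    }
    Parsed::Text(s.to_string())
}

fn main() {
    assert_eq!(classify("123"), Parsed::Int(123));
    assert_eq!(classify("true"), Parsed::Bool(true));
    assert_eq!(classify("hello world"), Parsed::Text("hello world".into()));
}
```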
fn parse_number_to_datatype(n: &Number) -> DataType {
if let Some(i) = n.as_i64() {
if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
DataType::Int4(Some(i as i32))
} else {
DataType::Int8(Some(i))
}
} else if let Some(f) = n.as_f64() {
if f >= f32::MIN as f64 && f <= f32::MAX as f64 {
DataType::Float4(Some(f as f32))
} else {
DataType::Float8(Some(f))
}
} else {
DataType::Unknown(Some("Invalid number".to_string()))
}
}
#[cfg_attr(test, allow(dead_code))]
pub fn parse_number_to_datatype(n: &Number) -> DataType {
// Check if it's an integer first (more common case)
if n.is_i64() {
let i = n.as_i64().unwrap();
// Use 32-bit int where possible to save memory
if i >= i32::MIN as i64 && i <= i32::MAX as i64 {
return DataType::Int4(Some(i as i32));
} else {
return DataType::Int8(Some(i));
}
}
// Then check for float
if n.is_f64() {
let f = n.as_f64().unwrap();
// Use 32-bit float where possible to save memory
if f >= f32::MIN as f64 && f <= f32::MAX as f64 {
return DataType::Float4(Some(f as f32));
} else {
return DataType::Float8(Some(f));
}
}
// Should rarely happen
DataType::Unknown(Some("Invalid number".to_string()))
}
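The narrowing rule in both functions is the same: keep the 32-bit representation only when the 64-bit value fits. A minimal sketch, using `i32::try_from` as an equivalent spelling of the integer range check (helper names are hypothetical):

```rust
// Narrow i64 -> i32 when the value fits; i32::try_from encodes the same
// range check as the explicit i32::MIN/i32::MAX comparison in the diff.
fn narrow_int(i: i64) -> Result<i32, i64> {
    i32::try_from(i).map_err(|_| i)
}

// Narrow f64 -> f32 when the value lies inside f32's finite range. Note
// the cast can still round (0.1 has no exact binary form), so this trades
// a little precision for memory.
fn narrow_float(f: f64) -> Result<f32, f64> {
    if f >= f32::MIN as f64 && f <= f32::MAX as f64 {
        Ok(f as f32)
    } else {
        Err(f)
    }
}

fn main() {
    assert_eq!(narrow_int(42), Ok(42));
    assert_eq!(narrow_int(i64::MAX), Err(i64::MAX));
    assert_eq!(narrow_float(1.5), Ok(1.5_f32));
    assert!(narrow_float(1e300).is_err());
}
```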

View File

@@ -12,32 +12,46 @@ pub async fn databricks_query(
query: String,
limit: Option<i64>,
) -> Result<Vec<IndexMap<std::string::String, DataType>>, Error> {
let results = match databricks_client.query(query).await {
// Apply the limit directly at the database level
let default_limit = 5000;
let limit_value = limit.unwrap_or(default_limit);
// Append LIMIT to the query if it doesn't already contain a LIMIT clause
let sql_with_limit = if !query.to_lowercase().contains("limit") {
format!("{} LIMIT {}", query, limit_value)
} else {
query
};
// Execute the query with the limit
let results = match databricks_client.query(sql_with_limit).await {
Ok(results) => results,
Err(e) => {
tracing::error!("Error: {}", e);
tracing::error!("Error executing Databricks query: {}", e);
return Err(anyhow!(e.to_string()));
}
};
let mut result: Vec<IndexMap<String, DataType>> = Vec::new();
let max_rows = limit.unwrap_or(5000) as usize;
// Create vector with estimated capacity
let mut result: Vec<IndexMap<String, DataType>> = Vec::with_capacity(limit_value as usize);
// Get rows from results
let rows = match results.result.data_array {
Some(rows) => rows.into_iter().take(max_rows).collect::<Vec<_>>(),
Some(rows) => rows,
None => return Ok(Vec::new()),
};
let columns = results.manifest.schema.columns;
// Process rows with optimized type conversions
for row in rows {
let columns_clone = columns.clone();
for (i, column) in columns_clone.iter().enumerate() {
let mut row_map: IndexMap<String, DataType> = IndexMap::new();
let mut row_map: IndexMap<String, DataType> = IndexMap::with_capacity(columns.len());
for (i, column) in columns.iter().enumerate() {
let column_name = column.name.clone();
let type_info = column.type_name.clone();
// Use match with string type info for efficient type conversion
let column_value = match type_info.as_str() {
"BIGINT" => DataType::Int8(row[i].parse::<i64>().ok()),
"BOOL" => DataType::Bool(row[i].parse::<bool>().ok()),
@@ -57,10 +71,11 @@ pub async fn databricks_query(
_ => DataType::Unknown(Some(row[i].to_string())),
};
row_map.insert(column_name.to_string(), column_value);
result.push(row_map);
row_map.insert(column_name, column_value);
}
result.push(row_map);
}
Ok(result)
}
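The pushdown above hinges on a plain substring check: any query that merely mentions "limit" (a `rate_limit` column, for instance) is left unmodified rather than risking a doubled clause. A minimal sketch of the rewrite and that false positive (helper name hypothetical):

```rust
// Append a LIMIT clause unless the query already mentions one anywhere.
// This is a substring test, not SQL parsing, so identifiers containing
// "limit" suppress the rewrite.
fn with_limit(query: &str, limit: i64) -> String {
    if query.to_lowercase().contains("limit") {
        query.to_string()
    } else {
        format!("{} LIMIT {}", query, limit)
    }
}

fn main() {
    assert_eq!(with_limit("SELECT 1", 5000), "SELECT 1 LIMIT 5000");
    assert_eq!(with_limit("SELECT 1 LIMIT 10", 5000), "SELECT 1 LIMIT 10");
    // False positive: no clause is appended, so the database-side cap is lost.
    assert_eq!(
        with_limit("SELECT rate_limit FROM quotas", 5000),
        "SELECT rate_limit FROM quotas"
    );
}
```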

View File

@@ -2,73 +2,68 @@ use chrono::Utc;
use indexmap::IndexMap;
use anyhow::Error;
use futures::{future::join_all, TryStreamExt};
use futures::TryStreamExt;
use sqlx::{Column, MySql, Pool, Row};
use tokio::task;
use crate::data_types::DataType;
pub async fn mysql_query(
pg_pool: Pool<MySql>,
pool: Pool<MySql>,
query: String,
limit: Option<i64>,
) -> Result<Vec<IndexMap<std::string::String, DataType>>, Error> {
let mut stream = sqlx::query(&query).fetch(&pg_pool);
// Apply the limit directly at the database level
let default_limit = 5000;
let limit_value = limit.unwrap_or(default_limit);
// Append LIMIT to the query if it doesn't already contain a LIMIT clause
let sql_with_limit = if !query.to_lowercase().contains("limit") {
format!("{} LIMIT ?", query)
} else {
query
};
// Create query with the limit parameter
let mut stream = sqlx::query(&sql_with_limit)
.bind(limit_value)
.fetch(&pool);
let mut result: Vec<IndexMap<String, DataType>> = Vec::new();
let mut count = 0;
// Pre-allocate result vector with estimated capacity to reduce allocations
let mut result: Vec<IndexMap<String, DataType>> = Vec::with_capacity(limit_value as usize);
// Process all rows without spawning tasks per row
while let Some(row) = stream.try_next().await? {
let mut row_map: IndexMap<String, DataType> = IndexMap::new();
let mut row_map: IndexMap<String, DataType> = IndexMap::with_capacity(row.len());
let mut row_value_handlers = Vec::new();
row_value_handlers.push(task::spawn(async move {
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let type_info = column.type_info().clone().to_string();
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let type_info = column.type_info().clone().to_string();
let column_value = match type_info.as_str() {
"BOOL" | "BOOLEAN" => DataType::Bool(row.try_get::<bool, _>(i).ok()),
"BIT" => DataType::Bytea(row.try_get::<Vec<u8>, _>(i).ok()),
"CHAR" => DataType::Char(row.try_get::<String, _>(i).ok()),
"BIGINT" => DataType::Int8(row.try_get::<i64, _>(i).ok()),
"MEDIUMINT" | "INT" | "INTEGER" => DataType::Int4(row.try_get::<i32, _>(i).ok()),
"TINYINT" | "SMALLINT" => DataType::Int2(row.try_get::<i16, _>(i).ok()),
"TEXT" | "VARCHAR" => DataType::Text(row.try_get::<String, _>(i).ok()),
"FLOAT" => DataType::Float4(row.try_get::<f32, _>(i).ok()),
"DOUBLE" => DataType::Float8(row.try_get::<f64, _>(i).ok()),
"DECIMAL" | "DEC" => DataType::Float8(row.try_get::<f64, _>(i).ok()),
"UUID" => DataType::Uuid(row.try_get::<uuid::Uuid, _>(i).ok()),
"TIMESTAMP" | "DATETIME" => DataType::Timestamp(row.try_get::<chrono::NaiveDateTime, _>(i).ok()),
"DATE" => DataType::Date(row.try_get::<chrono::NaiveDate, _>(i).ok()),
"TIME" => DataType::Time(row.try_get::<chrono::NaiveTime, _>(i).ok()),
"TIMESTAMPTZ" => DataType::Timestamptz(row.try_get::<chrono::DateTime<Utc>, _>(i).ok()),
"JSON" | "JSONB" => DataType::Json(row.try_get::<serde_json::Value, _>(i).ok()),
_ => DataType::Unknown(row.try_get::<String, _>(i).ok()),
};
let column_value = match type_info.as_str() {
"BOOL" | "BOOLEAN" => DataType::Bool(row.try_get::<bool, _>(i).ok()),
"BIT" => DataType::Bytea(row.try_get::<Vec<u8>, _>(i).ok()),
"CHAR" => DataType::Char(row.try_get::<String, _>(i).ok()),
"BIGINT" => DataType::Int8(row.try_get::<i64, _>(i).ok()),
"MEDIUMINT" | "INT" | "INTEGER" => DataType::Int4(row.try_get::<i32, _>(i).ok()),
"TINYINT" | "SMALLINT" => DataType::Int2(row.try_get::<i16, _>(i).ok()),
"TEXT" | "VARCHAR" => DataType::Text(row.try_get::<String, _>(i).ok()),
"FLOAT" => DataType::Float4(row.try_get::<f32, _>(i).ok()),
"DOUBLE" => DataType::Float8(row.try_get::<f64, _>(i).ok()),
"DECIMAL" | "DEC" => DataType::Float8(row.try_get::<f64, _>(i).ok()),
"UUID" => DataType::Uuid(row.try_get::<uuid::Uuid, _>(i).ok()),
"TIMESTAMP" | "DATETIME" => DataType::Timestamp(row.try_get::<chrono::NaiveDateTime, _>(i).ok()),
"DATE" => DataType::Date(row.try_get::<chrono::NaiveDate, _>(i).ok()),
"TIME" => DataType::Time(row.try_get::<chrono::NaiveTime, _>(i).ok()),
"TIMESTAMPTZ" => DataType::Timestamptz(row.try_get::<chrono::DateTime<Utc>, _>(i).ok()),
"JSON" | "JSONB" => DataType::Json(row.try_get::<serde_json::Value, _>(i).ok()),
_ => DataType::Unknown(row.try_get::<String, _>(i).ok()),
};
row_map.insert(column_name.to_string(), column_value);
}
row_map
}));
let row_value_handlers_results = join_all(row_value_handlers).await;
for row_value_handler_result in row_value_handlers_results {
let row = row_value_handler_result.unwrap();
result.push(row)
row_map.insert(column_name.to_string(), column_value);
}
count += 1;
if let Some(row_limit) = limit {
if count >= row_limit {
break;
}
} else if count >= 5000 {
break;
}
result.push(row_map);
}
Ok(result)
}
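For reference, a stripped-down sketch of the streaming shape used here: bind the limit as a `?` parameter and drain rows sequentially with `try_next()`, with no per-row `task::spawn`. The connection URL and table are placeholders:

```rust
use futures::TryStreamExt;
use sqlx::{mysql::MySqlPoolOptions, Row};

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    // Placeholder connection string for illustration only.
    let pool = MySqlPoolOptions::new()
        .max_connections(5)
        .connect("mysql://user:pass@localhost/db")
        .await?;

    // The limit travels as a bound parameter, so the server stops early.
    let mut stream = sqlx::query("SELECT id FROM example_table LIMIT ?")
        .bind(5000_i64)
        .fetch(&pool);

    // Sequential drain: each row is decoded on the spot, preserving order.
    let mut ids = Vec::new();
    while let Some(row) = stream.try_next().await? {
        ids.push(row.try_get::<i64, _>(0)?);
    }
    println!("fetched {} rows", ids.len());
    Ok(())
}
```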

View File

@@ -6,7 +6,6 @@ use indexmap::IndexMap;
use anyhow::{Error, Result};
use sqlx::{Column, Pool, Postgres, Row};
use tokio::task;
use crate::data_types::DataType;
use sqlparser::ast::{Expr, Ident, ObjectName, VisitMut, VisitorMut};
@@ -59,6 +58,7 @@ pub async fn postgres_query(
query: String,
limit: Option<i64>,
) -> Result<Vec<IndexMap<std::string::String, DataType>>, Error> {
// Parse the query and quote identifiers
let dialect = PostgreSqlDialect {};
let mut ast = Parser::parse_sql(&dialect, &query)?;
@@ -69,106 +69,61 @@ pub async fn postgres_query(
let formatted_sql = ast[0].to_string();
let mut stream = sqlx::raw_sql(&formatted_sql).fetch(&pg_pool);
// Apply the limit directly at the database level
let default_limit = 5000;
let limit_value = limit.unwrap_or(default_limit);
// Append LIMIT to the query with parameter
let sql_with_limit = format!("{} LIMIT $1", formatted_sql);
// Create query with the limit parameter
let mut stream = sqlx::query(&sql_with_limit)
.bind(limit_value)
.fetch(&pg_pool);
let mut result: Vec<IndexMap<String, DataType>> = Vec::new();
let mut count = 0;
let batch_size = 100;
let mut rows = Vec::new();
// Pre-allocate result vector with estimated capacity to reduce allocations
let mut result: Vec<IndexMap<String, DataType>> = Vec::with_capacity(limit_value as usize);
// Process all rows without spawning tasks per row
while let Some(row) = stream.try_next().await? {
rows.push(row);
count += 1;
let mut row_map: IndexMap<String, DataType> = IndexMap::with_capacity(row.len());
if count % batch_size == 0 {
let batch_result = process_batch(rows).await?;
result.extend(batch_result);
rows = Vec::new();
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let type_info = column.type_info().clone().to_string();
let column_value = match type_info.as_str() {
"BOOL" => DataType::Bool(row.try_get::<bool, _>(i).ok()),
"BYTEA" => DataType::Bytea(row.try_get::<Vec<u8>, _>(i).ok()),
"CHAR" => DataType::Char(row.try_get::<String, _>(i).ok()),
"INT8" => DataType::Int8(row.try_get::<i64, _>(i).ok()),
"INT4" => DataType::Int4(row.try_get::<i32, _>(i).ok()),
"INT2" => DataType::Int2(row.try_get::<i16, _>(i).ok()),
"TEXT" | "VARCHAR" | "USER-DEFINED" => DataType::Text(row.try_get::<String, _>(i).ok()),
"FLOAT4" => DataType::Float4(row.try_get::<f32, _>(i).ok()),
"FLOAT8" => DataType::Float8(row.try_get::<f64, _>(i).ok()),
"NUMERIC" => {
DataType::Float8(row.try_get(i).ok().and_then(
|v: sqlx::types::BigDecimal| v.to_string().parse::<f64>().ok(),
))
}
"UUID" => DataType::Uuid(row.try_get::<uuid::Uuid, _>(i).ok()),
"TIMESTAMP" => {
DataType::Timestamp(row.try_get::<chrono::NaiveDateTime, _>(i).ok())
}
"DATE" => DataType::Date(row.try_get::<chrono::NaiveDate, _>(i).ok()),
"TIME" => DataType::Time(row.try_get::<chrono::NaiveTime, _>(i).ok()),
"TIMESTAMPTZ" => {
DataType::Timestamptz(row.try_get::<chrono::DateTime<Utc>, _>(i).ok())
}
"JSON" | "JSONB" => DataType::Json(row.try_get::<serde_json::Value, _>(i).ok()),
_ => DataType::Unknown(row.try_get::<String, _>(i).ok()),
};
row_map.insert(column_name.to_string(), column_value);
}
if let Some(limit) = limit {
if count >= limit {
break;
}
}
}
// Process any remaining rows
if !rows.is_empty() {
let batch_result = process_batch(rows).await?;
result.extend(batch_result);
result.push(row_map);
}
Ok(result)
}
async fn process_batch(
rows: Vec<sqlx::postgres::PgRow>,
) -> Result<Vec<IndexMap<String, DataType>>, Error> {
let mut tasks = Vec::new();
for (index, row) in rows.into_iter().enumerate() {
let task = task::spawn(async move {
let mut row_map: IndexMap<String, DataType> = IndexMap::with_capacity(row.len());
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let type_info = column.type_info().clone().to_string();
let column_value = match type_info.as_str() {
"BOOL" => DataType::Bool(row.try_get::<bool, _>(i).ok()),
"BYTEA" => DataType::Bytea(row.try_get::<Vec<u8>, _>(i).ok()),
"CHAR" => DataType::Char(row.try_get::<String, _>(i).ok()),
"INT8" => DataType::Int8(row.try_get::<i64, _>(i).ok()),
"INT4" => DataType::Int4(row.try_get::<i32, _>(i).ok()),
"INT2" => DataType::Int2(row.try_get::<i16, _>(i).ok()),
"TEXT" | "VARCHAR" | "USER-DEFINED" => DataType::Text(row.try_get::<String, _>(i).ok()),
"FLOAT4" => DataType::Float4(row.try_get::<f32, _>(i).ok()),
"FLOAT8" => DataType::Float8(row.try_get::<f64, _>(i).ok()),
"NUMERIC" => {
DataType::Float8(row.try_get(i).ok().and_then(
|v: sqlx::types::BigDecimal| v.to_string().parse::<f64>().ok(),
))
}
"UUID" => DataType::Uuid(row.try_get::<uuid::Uuid, _>(i).ok()),
"TIMESTAMP" => {
DataType::Timestamp(row.try_get::<chrono::NaiveDateTime, _>(i).ok())
}
"DATE" => DataType::Date(row.try_get::<chrono::NaiveDate, _>(i).ok()),
"TIME" => DataType::Time(row.try_get::<chrono::NaiveTime, _>(i).ok()),
"TIMESTAMPTZ" => {
DataType::Timestamptz(row.try_get::<chrono::DateTime<Utc>, _>(i).ok())
}
"JSON" | "JSONB" => DataType::Json(row.try_get::<serde_json::Value, _>(i).ok()),
_ => DataType::Unknown(row.try_get::<String, _>(i).ok()),
};
row_map.insert(column_name.to_string(), column_value);
}
(index, row_map)
});
tasks.push(task);
}
let batch_result: Vec<_> = match futures::future::try_join_all(tasks).await {
Ok(batch_result) => batch_result,
Err(e) => {
tracing::error!("Error joining tasks: {:?}", e);
Vec::new()
}
};
let mut sorted_result: Vec<(usize, IndexMap<String, DataType>)> =
batch_result.into_iter().collect();
sorted_result.sort_by_key(|(index, _)| *index);
let final_result: Vec<IndexMap<String, DataType>> = sorted_result
.into_iter()
.map(|(_, row_map)| row_map)
.collect();
Ok(final_result)
}
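The deleted `process_batch` path had to tag every row with its index and re-sort after joining, since spawned tasks complete in arbitrary order. A small sketch of the two shapes, using plain integers in place of rows, to show what the sequential loop avoids:

```rust
#[tokio::main]
async fn main() {
    let rows: Vec<i32> = (0..4).collect();

    // Old shape: a task per row, then index bookkeeping to restore order.
    let mut tasks = Vec::new();
    for (index, row) in rows.clone().into_iter().enumerate() {
        tasks.push(tokio::spawn(async move { (index, row * 2) }));
    }
    let mut tagged = futures::future::try_join_all(tasks).await.unwrap();
    tagged.sort_by_key(|(index, _)| *index);
    let via_tasks: Vec<i32> = tagged.into_iter().map(|(_, v)| v).collect();

    // New shape: a plain loop keeps order for free and spawns nothing.
    let sequential: Vec<i32> = rows.iter().map(|r| r * 2).collect();

    assert_eq!(via_tasks, sequential);
}
```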

View File

@@ -50,6 +50,126 @@ pub async fn query_engine(
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
use sqlx::postgres::PgPoolOptions;
use std::env;
// Test that postgres_query properly applies the limit at the database level
#[tokio::test]
async fn test_postgres_query_with_limit() {
use crate::data_source_query_routes::postgres_query::postgres_query;
// Skip test if no test database is available
let database_url = match env::var("TEST_DATABASE_URL") {
Ok(url) => url,
Err(_) => return, // Skip test if env var not available
};
// Create a pool with the test database
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
.expect("Failed to connect to Postgres");
// Test with explicit limit
let results = postgres_query(
pool.clone(),
"SELECT generate_series(1, 100) AS num".to_string(),
Some(10),
)
.await
.expect("Query should succeed");
assert_eq!(results.len(), 10, "Should return exactly 10 rows with limit 10");
// Test with default limit (5000)
let results = postgres_query(
pool.clone(),
"SELECT generate_series(1, 6000) AS num".to_string(),
None,
)
.await
.expect("Query should succeed");
assert_eq!(results.len(), 5000, "Should return exactly 5000 rows with default limit");
// Test with limit greater than default
let results = postgres_query(
pool,
"SELECT generate_series(1, 6000) AS num".to_string(),
Some(6000),
)
.await
.expect("Query should succeed");
assert_eq!(results.len(), 6000, "Should return exactly 6000 rows with limit 6000");
}
// Test that mysql_query properly applies the limit at the database level
#[tokio::test]
async fn test_mysql_query_with_limit() {
use crate::data_source_query_routes::mysql_query::mysql_query;
use sqlx::mysql::MySqlPoolOptions;
// Skip test if no test database is available
let database_url = match env::var("TEST_MYSQL_URL") {
Ok(url) => url,
Err(_) => return, // Skip test if env var not available
};
// Create a pool with the test database
let pool = MySqlPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
.expect("Failed to connect to MySQL");
// Test with explicit limit
let results = mysql_query(
pool.clone(),
"SELECT * FROM (SELECT 1 AS num UNION SELECT 2 UNION SELECT 3 UNION SELECT 4 UNION SELECT 5 UNION SELECT 6 UNION SELECT 7 UNION SELECT 8 UNION SELECT 9 UNION SELECT 10) AS t".to_string(),
Some(5),
)
.await
.expect("Query should succeed");
assert_eq!(results.len(), 5, "Should return exactly 5 rows with limit 5");
}
// Test parsing functions in the bigquery connector
#[test]
fn test_bigquery_string_parsing() {
use crate::data_source_query_routes::bigquery_query;
// Test integer parsing
match bigquery_query::parse_string_to_datatype("123") {
DataType::Int4(Some(123)) => {}, // Success
other => panic!("Expected Int4(123), got {:?}", other),
}
// Test boolean parsing
match bigquery_query::parse_string_to_datatype("true") {
DataType::Bool(Some(true)) => {}, // Success
other => panic!("Expected Bool(true), got {:?}", other),
}
// Test date parsing
match bigquery_query::parse_string_to_datatype("2023-01-01") {
DataType::Date(Some(_)) => {}, // Success
other => panic!("Expected Date, got {:?}", other),
}
// Test text fallback
match bigquery_query::parse_string_to_datatype("hello world") {
DataType::Text(Some(text)) if text == "hello world" => {}, // Success
other => panic!("Expected Text(hello world), got {:?}", other),
}
}
}
async fn route_to_query(
data_source_id: &Uuid,
sql: &str,

View File

@@ -13,14 +13,28 @@ pub async fn redshift_query(
query: String,
limit: Option<i64>,
) -> Result<Vec<IndexMap<std::string::String, DataType>>, Error> {
let mut stream = sqlx::query(&query).fetch(&pg_pool);
// Apply the limit directly at the database level
let default_limit = 5000;
let limit_value = limit.unwrap_or(default_limit);
// Append LIMIT to the query if it doesn't already contain a LIMIT clause
let sql_with_limit = if !query.to_lowercase().contains("limit") {
format!("{} LIMIT $1", query)
} else {
query
};
// Create query with the limit parameter
let mut stream = sqlx::query(&sql_with_limit)
.bind(limit_value)
.fetch(&pg_pool);
let mut result: Vec<IndexMap<String, DataType>> = Vec::new();
let mut count = 0;
// Pre-allocate result vector with estimated capacity
let mut result: Vec<IndexMap<String, DataType>> = Vec::with_capacity(limit_value as usize);
// Process rows sequentially
while let Some(row) = stream.try_next().await? {
let mut row_map: IndexMap<String, DataType> = IndexMap::new();
let mut row_map: IndexMap<String, DataType> = IndexMap::with_capacity(row.len());
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
@@ -55,15 +69,7 @@ pub async fn redshift_query(
}
result.push(row_map);
count += 1;
if let Some(row_limit) = limit {
if count >= row_limit {
break;
}
} else if count >= 5000 {
break;
}
}
Ok(result)
}
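Pre-allocating to the limit trades memory for fewer reallocations: the vector reserves room for the full default of 5000 rows up front, even when the query returns a handful. A tiny sketch of that trade-off:

```rust
fn main() {
    let limit = 5000usize;

    // Reserves the whole limit immediately; pushes never reallocate.
    let mut preallocated: Vec<u64> = Vec::with_capacity(limit);

    // Grows geometrically as rows arrive (typically 4, 8, 16, ...).
    let mut growing: Vec<u64> = Vec::new();

    for i in 0..100u64 {
        preallocated.push(i);
        growing.push(i);
    }

    assert!(preallocated.capacity() >= limit);
    println!(
        "preallocated capacity: {}, grown capacity: {}",
        preallocated.capacity(),
        growing.capacity()
    );
}
```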

View File

@@ -1,10 +1,9 @@
use crate::data_types::DataType;
use anyhow::{anyhow, Error, Result};
use chrono::NaiveDateTime;
use futures::future::join_all;
use indexmap::IndexMap;
use tiberius::{numeric::Decimal, Client, ColumnType};
use tokio::{net::TcpStream, task};
use tokio::net::TcpStream;
use tokio_util::compat::Compat;
pub async fn sql_server_query(
@@ -12,82 +11,99 @@ pub async fn sql_server_query(
query: String,
limit: Option<i64>,
) -> Result<Vec<IndexMap<std::string::String, DataType>>, Error> {
let rows = match client.query(query, &[]).await {
// Apply the limit directly at the database level
let default_limit = 5000;
let limit_value = limit.unwrap_or(default_limit);
// Check if query already has TOP/OFFSET syntax
let sql_with_limit = if !query.to_lowercase().contains("top") && !query.to_lowercase().contains("offset") {
// Add TOP clause for SQL Server
let trimmed_query = query.trim_start();
// Find position of SELECT to insert TOP after it
if let Some(select_pos) = trimmed_query.to_lowercase().find("select") {
let (before_select, after_select) = trimmed_query.split_at(select_pos + 6);
format!("{} TOP({}) {}", before_select, limit_value, after_select)
} else {
// If no SELECT found, return original query
query
}
} else {
query
};
// Execute the query with limit
let rows = match client.query(&sql_with_limit, &[]).await {
Ok(rows) => rows,
Err(e) => {
tracing::error!("Unable to execute query: {:?}", e);
let err = anyhow!("Unable to execute query: {}", e);
return Err(err);
return Err(anyhow!("Unable to execute query: {}", e));
}
};
let mut result: Vec<IndexMap<String, DataType>> = Vec::new();
// Pre-allocate result vector with estimated capacity
let mut result: Vec<IndexMap<String, DataType>> = Vec::with_capacity(limit_value as usize);
let query_result = match rows.into_first_result().await {
Ok(query_result) => query_result.into_iter().take(limit.unwrap_or(5000) as usize),
Ok(query_result) => query_result,
Err(e) => {
tracing::error!("Unable to fetch query result: {:?}", e);
let err = anyhow!("Unable to fetch query result: {}", e);
return Err(err);
return Err(anyhow!("Unable to fetch query result: {}", e));
}
};
// Process rows sequentially without spawning tasks
for row in query_result {
let mut row_value_handlers = Vec::new();
row_value_handlers.push(task::spawn(async move {
let mut row_map = IndexMap::new();
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let type_info = column.column_type();
let column_value = match type_info {
ColumnType::Text
| ColumnType::NVarchar
| ColumnType::NChar
| ColumnType::BigChar
| ColumnType::NText
| ColumnType::BigVarChar => {
DataType::Text(row.get::<&str, _>(i).map(|v| v.to_string()))
}
ColumnType::Int8 => DataType::Bool(row.get::<bool, _>(i)),
ColumnType::Int4 => DataType::Int4(row.get::<i32, _>(i)),
ColumnType::Int2 | ColumnType::Int1 => DataType::Int2(row.get::<i16, _>(i)),
ColumnType::Float4 => DataType::Float4(row.get::<f32, _>(i)),
ColumnType::Float8 => DataType::Float8(row.get::<f64, _>(i)),
ColumnType::Bit => DataType::Bool(row.get::<bool, _>(i)),
ColumnType::Null => DataType::Null,
ColumnType::Datetime4 => {
DataType::Timestamp(row.get::<NaiveDateTime, _>(i))
}
ColumnType::Money => DataType::Int8(row.get::<i64, _>(i)),
ColumnType::Datetime => DataType::Timestamp(row.get::<NaiveDateTime, _>(i)),
ColumnType::Money4 => DataType::Int8(row.get::<i64, _>(i)),
ColumnType::Guid => DataType::Uuid(row.get::<uuid::Uuid, _>(i)),
ColumnType::Intn => DataType::Int4(row.get::<i32, _>(i)),
ColumnType::Decimaln => DataType::Decimal(row.get::<Decimal, _>(i)),
ColumnType::Numericn => DataType::Decimal(row.get::<Decimal, _>(i)),
ColumnType::Floatn => DataType::Float8(row.get::<f64, _>(i)),
ColumnType::Datetimen => {
DataType::Timestamp(row.get::<NaiveDateTime, _>(i))
}
ColumnType::Daten => DataType::Date(row.get::<NaiveDateTime, _>(i).map(|v| v.date())),
ColumnType::Timen => DataType::Time(row.get::<NaiveDateTime, _>(i).map(|v| v.time())),
ColumnType::Datetime2 => DataType::Timestamp(row.get::<NaiveDateTime, _>(i)),
ColumnType::DatetimeOffsetn => DataType::Timestamp(row.get::<NaiveDateTime, _>(i)),
_ => {
tracing::debug!("No match found");
DataType::Null
}
};
tracing::debug!("column_value: {:?}", column_value);
row_map.insert(column_name.to_string(), column_value);
}
row_map
}));
let row_value_handlers_results = join_all(row_value_handlers).await;
for row_value_handler_result in row_value_handlers_results {
let row = row_value_handler_result.unwrap();
result.push(row);
let mut row_map = IndexMap::with_capacity(row.columns().len());
for (i, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let type_info = column.column_type();
let column_value = match type_info {
ColumnType::Text
| ColumnType::NVarchar
| ColumnType::NChar
| ColumnType::BigChar
| ColumnType::NText
| ColumnType::BigVarChar => {
DataType::Text(row.get::<&str, _>(i).map(|v| v.to_string()))
}
ColumnType::Int8 => DataType::Bool(row.get::<bool, _>(i)),
ColumnType::Int4 => DataType::Int4(row.get::<i32, _>(i)),
ColumnType::Int2 | ColumnType::Int1 => DataType::Int2(row.get::<i16, _>(i)),
ColumnType::Float4 => DataType::Float4(row.get::<f32, _>(i)),
ColumnType::Float8 => DataType::Float8(row.get::<f64, _>(i)),
ColumnType::Bit => DataType::Bool(row.get::<bool, _>(i)),
ColumnType::Null => DataType::Null,
ColumnType::Datetime4 => {
DataType::Timestamp(row.get::<NaiveDateTime, _>(i))
}
ColumnType::Money => DataType::Int8(row.get::<i64, _>(i)),
ColumnType::Datetime => DataType::Timestamp(row.get::<NaiveDateTime, _>(i)),
ColumnType::Money4 => DataType::Int8(row.get::<i64, _>(i)),
ColumnType::Guid => DataType::Uuid(row.get::<uuid::Uuid, _>(i)),
ColumnType::Intn => DataType::Int4(row.get::<i32, _>(i)),
ColumnType::Decimaln => DataType::Decimal(row.get::<Decimal, _>(i)),
ColumnType::Numericn => DataType::Decimal(row.get::<Decimal, _>(i)),
ColumnType::Floatn => DataType::Float8(row.get::<f64, _>(i)),
ColumnType::Datetimen => {
DataType::Timestamp(row.get::<NaiveDateTime, _>(i))
}
ColumnType::Daten => DataType::Date(row.get::<NaiveDateTime, _>(i).map(|v| v.date())),
ColumnType::Timen => DataType::Time(row.get::<NaiveDateTime, _>(i).map(|v| v.time())),
ColumnType::Datetime2 => DataType::Timestamp(row.get::<NaiveDateTime, _>(i)),
ColumnType::DatetimeOffsetn => DataType::Timestamp(row.get::<NaiveDateTime, _>(i)),
_ => {
tracing::debug!("No match found for type: {:?}", type_info);
DataType::Null
}
};
row_map.insert(column_name.to_string(), column_value);
}
result.push(row_map);
}
Ok(result)
}
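SQL Server has no LIMIT clause, so the rewrite splices `TOP(n)` directly after the first SELECT. A standalone sketch of that splice (helper name hypothetical); the substring guards are pragmatic rather than a parser, so a `SELECT DISTINCT` query would end up as the invalid `SELECT TOP(n) DISTINCT ...`, and any query containing "top" or "offset" as a substring is left untouched:

```rust
// Insert TOP(n) after the first SELECT unless the query already appears
// to page its own results. Mirrors the rewrite above in standalone form.
fn with_top(query: &str, limit: i64) -> String {
    let lower = query.to_lowercase();
    if lower.contains("top") || lower.contains("offset") {
        return query.to_string();
    }
    let trimmed = query.trim_start();
    match trimmed.to_lowercase().find("select") {
        Some(pos) => {
            let (before, after) = trimmed.split_at(pos + 6);
            format!("{} TOP({}) {}", before, limit, after)
        }
        None => query.to_string(),
    }
}

fn main() {
    // The doubled space is harmless to SQL Server's parser.
    assert_eq!(with_top("SELECT id FROM t", 10), "SELECT TOP(10)  id FROM t");
    assert_eq!(with_top("SELECT TOP(5) id FROM t", 10), "SELECT TOP(5) id FROM t");
}
```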