Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/datafusion_integration/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use datafusion::prelude::SessionConfig;
use crate::kubernetes::K8sClientPool;
use crate::output::QueryResult;

use super::preprocess::{preprocess_sql, validate_read_only};
use super::preprocess::{preprocess_sql_with_registry, validate_read_only};
use super::provider::K8sTableProvider;

/// Table information with native types (for data layer)
Expand Down Expand Up @@ -95,8 +95,15 @@ impl K8sSessionContext {

/// Execute a SQL query and return the results as Arrow RecordBatches
pub async fn execute_sql(&self, sql: &str) -> DFResult<Vec<RecordBatch>> {
// Preprocess first (compiles PRQL to SQL if detected, fixes arrow precedence)
let processed_sql = preprocess_sql(sql)
// Get registry for table-aware JSON column detection
let registry = self
.pool
.get_registry(None)
.await
.map_err(|e| datafusion::error::DataFusionError::Plan(e.to_string()))?;

// Preprocess with table-aware JSON columns (compiles PRQL, converts JSON paths)
let processed_sql = preprocess_sql_with_registry(sql, &registry)
.map_err(|e| datafusion::error::DataFusionError::Plan(e.to_string()))?;

// Validate the resulting SQL is read-only
Expand Down
184 changes: 183 additions & 1 deletion src/datafusion_integration/preprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@
//! ```

use super::{json_path, prql};
use crate::kubernetes::discovery::ResourceRegistry;
use anyhow::Result;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer};
use regex::Regex;
use std::collections::HashSet;
use std::sync::LazyLock;

/// Regex to match arrows followed by comparison operators (left side)
Expand Down Expand Up @@ -120,7 +123,112 @@ fn fix_arrow_precedence(sql: &str) -> String {
.into_owned()
}

/// Preprocess a query for execution.
/// Extract table names from a SQL query.
///
/// Uses DataFusion's tokenizer to find identifiers following FROM and JOIN keywords.
/// This is a best-effort extraction - missing some tables is acceptable since we
/// fall back to DEFAULT_JSON_OBJECT_COLUMNS for unrecognized columns.
///
/// # Examples
///
/// ```ignore
/// extract_table_names("SELECT * FROM pods") // -> ["pods"]
/// extract_table_names("SELECT * FROM pods p JOIN services s ON ...") // -> ["pods", "services"]
/// extract_table_names("SELECT * FROM pods WHERE x IN (SELECT * FROM namespaces)") // -> ["pods", "namespaces"]
/// ```
fn extract_table_names(sql: &str) -> Vec<String> {
let dialect = PostgreSqlDialect {};
let tokens = match Tokenizer::new(&dialect, sql).tokenize() {
Ok(t) => t,
Err(_) => return vec![],
};

let mut table_names = Vec::new();
let mut i = 0;

while i < tokens.len() {
// Look for FROM or JOIN keywords
if let Token::Word(word) = &tokens[i] {
let keyword = word.value.to_uppercase();
if keyword == "FROM" || keyword == "JOIN" {
// Skip whitespace and find the next identifier
i += 1;
while i < tokens.len() && matches!(tokens[i], Token::Whitespace(_)) {
i += 1;
}

// The next token should be a table name (identifier or word)
if let Some(Token::Word(table_word)) = tokens.get(i) {
// Skip keywords that might follow FROM (like SELECT in subqueries)
let upper = table_word.value.to_uppercase();
if !matches!(
upper.as_str(),
"SELECT" | "WITH" | "LATERAL" | "UNNEST" | "("
) {
table_names.push(table_word.value.to_lowercase());
}
}
}
}
i += 1;
}

table_names
}

/// Collect the JSON column names that apply to the given tables.
///
/// The result is the union of the default JSON columns (as produced by
/// `json_path::build_json_columns_set` with no extra columns) and whatever the
/// registry reports for `table_names`.
fn build_json_columns_for_tables(
    table_names: &[String],
    registry: &ResourceRegistry,
) -> HashSet<String> {
    // Defaults are always part of the result, even when no table is recognized.
    let defaults = json_path::build_json_columns_set(&[]);
    let table_specific = registry.get_json_columns_for_tables(table_names);

    defaults.into_iter().chain(table_specific).collect()
}

/// Preprocess a query using table-aware JSON column detection.
///
/// Intended for callers that have a `ResourceRegistry` at hand: the tables
/// referenced by the query are looked up in the registry so that dot-notation
/// on their JSON columns is converted in addition to the default columns.
///
/// The pipeline is:
/// 1. compile PRQL to SQL when the input is detected as PRQL,
/// 2. extract the table names referenced by the SQL,
/// 3. build the JSON column set for those tables,
/// 4. convert JSON path syntax using that set,
/// 5. fix JSON arrow operator precedence.
///
/// # Arguments
///
/// * `sql` - The SQL or PRQL query to preprocess
/// * `registry` - The resource registry consulted for per-table JSON columns
///
/// # Errors
///
/// Returns an error when PRQL compilation fails.
///
/// # Returns
///
/// The preprocessed SQL ready for execution
pub fn preprocess_sql_with_registry(sql: &str, registry: &ResourceRegistry) -> Result<String> {
    // PRQL input is normalized to plain SQL first; plain SQL passes through untouched.
    let compiled = if prql::is_prql(sql) {
        prql::compile_prql(&prql::preprocess_prql_json_paths(sql))?
    } else {
        sql.to_string()
    };

    // Only the tables actually referenced decide which extra JSON columns apply.
    let referenced_tables = extract_table_names(&compiled);
    let json_columns = build_json_columns_for_tables(&referenced_tables, registry);

    // Rewrite JSON paths, then parenthesize arrow expressions for the parser.
    let with_json_paths = json_path::preprocess_json_paths(&compiled, Some(&json_columns));
    Ok(fix_arrow_precedence(&with_json_paths))
}

/// Preprocess a query for execution (without registry - uses defaults only).
///
/// This function handles:
/// 1. **PRQL detection and compilation**: Queries starting with `from`, `let`, or `prql`
Expand All @@ -132,6 +240,9 @@ fn fix_arrow_precedence(sql: &str) -> String {
/// 3. **JSON arrow precedence fix**: Wraps arrow expressions in parentheses when used
/// with comparison operators to work around DataFusion parser precedence.
///
/// Note: This function only recognizes DEFAULT_JSON_OBJECT_COLUMNS (spec, status, labels, etc.).
/// For table-aware JSON column detection, use `preprocess_sql_with_registry` instead.
///
/// # Examples
///
/// ```ignore
Expand All @@ -153,6 +264,7 @@ fn fix_arrow_precedence(sql: &str) -> String {
/// preprocess_sql("SELECT * FROM pods WHERE labels->>'app' = 'nginx'")?;
/// // Returns: SELECT * FROM pods WHERE (labels->>'app') = 'nginx'
/// ```
#[allow(dead_code)] // Used by tests and for backward compatibility
pub fn preprocess_sql(sql: &str) -> Result<String> {
// Step 1: Compile PRQL to SQL if detected
let sql = if prql::is_prql(sql) {
Expand Down Expand Up @@ -709,4 +821,74 @@ mod tests {
let result = preprocess_sql(sql).unwrap();
assert_eq!(result, sql);
}

// ==================== Table extraction tests ====================

#[test]
fn test_extract_table_names_simple() {
    // A lone FROM clause yields exactly its table.
    assert_eq!(extract_table_names("SELECT * FROM pods"), ["pods"]);
}

#[test]
fn test_extract_table_names_with_alias() {
    // The alias following the table name must not be reported as a table.
    assert_eq!(extract_table_names("SELECT * FROM pods p"), ["pods"]);
}

#[test]
fn test_extract_table_names_join() {
    // Both the FROM table and the JOIN table are reported, in query order.
    let sql = "SELECT * FROM pods p JOIN services s ON p.name = s.name";
    assert_eq!(extract_table_names(sql), ["pods", "services"]);
}

#[test]
fn test_extract_table_names_multiple_joins() {
    // Every JOIN contributes its table, in the order it appears.
    let sql = "SELECT * FROM pods p \
               JOIN services s ON p.name = s.name \
               JOIN deployments d ON d.name = p.name";
    assert_eq!(
        extract_table_names(sql),
        ["pods", "services", "deployments"]
    );
}

#[test]
fn test_extract_table_names_subquery() {
    // Tables inside a subquery are found as well as the outer table.
    let sql = "SELECT * FROM pods WHERE namespace IN (SELECT name FROM namespaces)";
    assert_eq!(extract_table_names(sql), ["pods", "namespaces"]);
}

#[test]
fn test_extract_table_names_case_insensitive() {
    // Table names come back lowercased regardless of how they were written.
    assert_eq!(extract_table_names("SELECT * FROM Pods"), ["pods"]);
}

#[test]
fn test_extract_table_names_left_join() {
    // LEFT precedes JOIN as a separate keyword; both tables must still be found.
    let found = extract_table_names("SELECT * FROM pods LEFT JOIN services ON true");
    for expected in ["pods", "services"] {
        assert!(found.iter().any(|name| name == expected));
    }
}

// ==================== Table-aware preprocessing tests ====================
// Note: These tests need a registry, which requires more setup.
// The integration is tested via execute_sql in context.rs.

#[test]
fn test_build_json_columns_includes_defaults() {
    use crate::kubernetes::discovery::ResourceRegistry;

    // With a fresh registry and no tables, only the built-in defaults can be present.
    let json_columns = build_json_columns_for_tables(&[], &ResourceRegistry::new());

    for expected in ["spec", "status", "labels", "annotations"] {
        assert!(json_columns.contains(expected));
    }
}
}
Loading