238 changes: 194 additions & 44 deletions datafusion/datasource-parquet/src/opener.rs
@@ -24,6 +24,7 @@ use crate::{
apply_file_schema_type_coercions, coerce_int96_to_resolution, row_filter,
};
use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::datatypes::DataType;
use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener};
use datafusion_physical_expr::projection::ProjectionExprs;
use datafusion_physical_expr::utils::reassign_expr_columns;
@@ -35,8 +36,10 @@ use std::task::{Context, Poll};

use arrow::datatypes::{SchemaRef, TimeUnit};
use datafusion_common::encryption::FileDecryptionProperties;

use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
use datafusion_common::stats::Precision;
use datafusion_common::{
ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics, exec_err,
};
use datafusion_datasource::{PartitionedFile, TableSchema};
use datafusion_physical_expr::simplifier::PhysicalExprSimplifier;
use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory;
@@ -137,59 +140,60 @@ impl FileOpener for ParquetOpener {

let batch_size = self.batch_size;

// Build partition values map for replacing partition column references
// with their literal values from this file's partition values.
// Calculate the output schema from the original projection (before literal replacement)
// so we get correct field names from column references
let logical_file_schema = Arc::clone(self.table_schema.file_schema());
let output_schema = Arc::new(
self.projection
.project_schema(self.table_schema.table_schema())?,
);

// Build a combined map for replacing column references with literal values.
// This includes:
// 1. Partition column values from the file path (e.g., region=us-west-2)
// 2. Constant columns detected from file statistics (where min == max)
//
// For example, given
// 1. `region` is a partition column,
// 2. predicate `host IN ('us-east-1', 'eu-central-1')`:
// 3. The file path is `/data/region=us-west-2/...`
// (that is the partition column value is `us-west-2`)
// Although partition columns *are* constant columns, we don't want to rely on
// statistics for them being populated if we can use the partition values
// (which are guaranteed to be present).
//
// The predicate would be rewritten to
// ```sql
// 'us-west-2` IN ('us-east-1', 'eu-central-1')
// ```
// which can be further simplified to `FALSE`, meaning
// the file can be skipped entirely.
// For example, given a partition column `region` and predicate
// `region IN ('us-east-1', 'eu-central-1')` with file path
// `/data/region=us-west-2/...`, the predicate is rewritten to
// `'us-west-2' IN ('us-east-1', 'eu-central-1')` which simplifies to FALSE.
//
// While this particular optimization is done during logical planning,
// there are other cases where partition columns may appear in more
// complex predicates that cannot be simplified until we are about to
// open the file (such as dynamic predicates)
let partition_values: HashMap<&str, &ScalarValue> = self
// While partition column optimization is done during logical planning,
// there are cases where partition columns may appear in more complex
// predicates that cannot be simplified until we open the file (such as
// dynamic predicates).
let mut literal_columns: HashMap<String, ScalarValue> = self
.table_schema
.table_partition_cols()
.iter()
.zip(partitioned_file.partition_values.iter())
.map(|(field, value)| (field.name().as_str(), value))
.map(|(field, value)| (field.name().clone(), value.clone()))
.collect();
// Add constant columns from file statistics.
// Note that if there are statistics for partition columns there will be overlap,
// but since we use a HashMap, we'll just overwrite the partition values with the
// constant values from statistics (which should be the same).
literal_columns.extend(constant_columns_from_stats(
partitioned_file.statistics.as_deref(),
&logical_file_schema,
));

// Calculate the output schema from the original projection (before literal replacement)
// so we get correct field names from column references
let logical_file_schema = Arc::clone(self.table_schema.file_schema());
let output_schema = Arc::new(
self.projection
.project_schema(self.table_schema.table_schema())?,
);

// Apply partition column replacement to projection expressions
// Apply literal replacements to projection and predicate
let mut projection = self.projection.clone();
if !partition_values.is_empty() {
let mut predicate = self.predicate.clone();
if !literal_columns.is_empty() {
projection = projection.try_map_exprs(|expr| {
replace_columns_with_literals(Arc::clone(&expr), &partition_values)
replace_columns_with_literals(Arc::clone(&expr), &literal_columns)
})?;
predicate = predicate
.map(|p| replace_columns_with_literals(p, &literal_columns))
.transpose()?;
}

// Apply partition column replacement to predicate
let mut predicate = if partition_values.is_empty() {
self.predicate.clone()
} else {
self.predicate
.clone()
.map(|p| replace_columns_with_literals(p, &partition_values))
.transpose()?
};
let reorder_predicates = self.reorder_filters;
let pushdown_filters = self.pushdown_filters;
let force_filter_selections = self.force_filter_selections;
@@ -581,6 +585,64 @@ fn copy_arrow_reader_metrics(
}
}

type ConstantColumns = HashMap<String, ScalarValue>;

/// Extract constant column values from statistics, keyed by column name in the logical file schema.
fn constant_columns_from_stats(
statistics: Option<&Statistics>,
file_schema: &SchemaRef,
) -> ConstantColumns {
let mut constants = HashMap::new();
let Some(statistics) = statistics else {
return constants;
};

let num_rows = match statistics.num_rows {
Precision::Exact(num_rows) => Some(num_rows),
_ => None,
};

for (idx, column_stats) in statistics
.column_statistics
.iter()
.take(file_schema.fields().len())
.enumerate()
{
let field = file_schema.field(idx);
if let Some(value) =
constant_value_from_stats(column_stats, num_rows, field.data_type())
{
constants.insert(field.name().clone(), value);
}
}

constants
}

fn constant_value_from_stats(
column_stats: &ColumnStatistics,
num_rows: Option<usize>,
data_type: &DataType,
) -> Option<ScalarValue> {
if let (Precision::Exact(min), Precision::Exact(max)) =
(&column_stats.min_value, &column_stats.max_value)
&& min == max
&& !min.is_null()
&& matches!(column_stats.null_count, Precision::Exact(0))
{
return Some(min.clone());
}

if let (Some(num_rows), Precision::Exact(nulls)) =
(num_rows, &column_stats.null_count)
&& *nulls == num_rows
{
return ScalarValue::try_new_null(data_type).ok();
}

None
}

/// Wraps an inner RecordBatchStream and a [`FilePruner`]
///
/// This can terminate the scan early when some dynamic filters is updated after
@@ -841,7 +903,8 @@ fn should_enable_page_index(
mod test {
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use super::{ConstantColumns, constant_columns_from_stats};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use bytes::{BufMut, BytesMut};
use datafusion_common::{
ColumnStatistics, DataFusionError, ScalarValue, Statistics, record_batch,
@@ -850,17 +913,104 @@
use datafusion_datasource::{PartitionedFile, TableSchema, file_stream::FileOpener};
use datafusion_expr::{col, lit};
use datafusion_physical_expr::{
PhysicalExpr, expressions::DynamicFilterPhysicalExpr, planner::logical2physical,
PhysicalExpr,
expressions::{Column, DynamicFilterPhysicalExpr, Literal},
planner::logical2physical,
projection::ProjectionExprs,
};
use datafusion_physical_expr_adapter::DefaultPhysicalExprAdapterFactory;
use datafusion_physical_expr_adapter::{
DefaultPhysicalExprAdapterFactory, replace_columns_with_literals,
};
use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet;
use futures::{Stream, StreamExt};
use object_store::{ObjectStore, memory::InMemory, path::Path};
use parquet::arrow::ArrowWriter;

use crate::{DefaultParquetFileReaderFactory, opener::ParquetOpener};

fn constant_int_stats() -> (Statistics, SchemaRef) {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
let statistics = Statistics {
num_rows: Precision::Exact(3),
total_byte_size: Precision::Absent,
column_statistics: vec![
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::from(5i32)),
min_value: Precision::Exact(ScalarValue::from(5i32)),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Absent,
},
ColumnStatistics::new_unknown(),
],
};
(statistics, schema)
}

#[test]
fn extract_constant_columns_non_null() {
let (statistics, schema) = constant_int_stats();
let constants = constant_columns_from_stats(Some(&statistics), &schema);
assert_eq!(constants.len(), 1);
assert_eq!(constants.get("a"), Some(&ScalarValue::from(5i32)));
assert!(!constants.contains_key("b"));
}

#[test]
fn extract_constant_columns_all_null() {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
let statistics = Statistics {
num_rows: Precision::Exact(2),
total_byte_size: Precision::Absent,
column_statistics: vec![ColumnStatistics {
null_count: Precision::Exact(2),
max_value: Precision::Absent,
min_value: Precision::Absent,
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Absent,
}],
};

let constants = constant_columns_from_stats(Some(&statistics), &schema);
assert_eq!(
constants.get("a"),
Some(&ScalarValue::Utf8(None)),
"all-null column should be treated as constant null"
);
}

#[test]
fn rewrite_projection_to_literals() {
let (statistics, schema) = constant_int_stats();
let constants = constant_columns_from_stats(Some(&statistics), &schema);
let projection = ProjectionExprs::from_indices(&[0, 1], &schema);

let rewritten = projection
.try_map_exprs(|expr| replace_columns_with_literals(expr, &constants))
.unwrap();
let exprs = rewritten.as_ref();
assert!(exprs[0].expr.as_any().downcast_ref::<Literal>().is_some());
assert!(exprs[1].expr.as_any().downcast_ref::<Column>().is_some());

// Only column `b` should remain in the projection mask
assert_eq!(rewritten.column_indices(), vec![1]);
}

#[test]
fn rewrite_physical_expr_literal() {
let mut constants = ConstantColumns::new();
constants.insert("a".to_string(), ScalarValue::from(7i32));
let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));

let rewritten = replace_columns_with_literals(expr, &constants).unwrap();
assert!(rewritten.as_any().downcast_ref::<Literal>().is_some());
}

async fn count_batches_and_rows(
mut stream: std::pin::Pin<
Box<
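
Aside (not part of this diff): the rewrite described in the opener comments above can be illustrated with a small, self-contained sketch. It assumes a hypothetical partition column `region` with path value `us-west-2` and a hypothetical column `host` whose file statistics report min == max with no nulls; the combined map then drives `replace_columns_with_literals`.

use std::collections::HashMap;
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr_adapter::replace_columns_with_literals;

fn rewrite_example() -> Result<Arc<dyn PhysicalExpr>> {
    // Hypothetical combined map: one entry from the file path's partition value,
    // one entry from file statistics where min == max and null_count == 0.
    let mut literal_columns: HashMap<String, ScalarValue> = HashMap::new();
    literal_columns.insert("region".to_string(), ScalarValue::from("us-west-2"));
    literal_columns.insert("host".to_string(), ScalarValue::from("h1"));

    // Predicate: region = 'us-east-1' AND host = 'h1'
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("region", 0)),
            Operator::Eq,
            Arc::new(Literal::new(ScalarValue::from("us-east-1"))),
        )),
        Operator::And,
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("host", 1)),
            Operator::Eq,
            Arc::new(Literal::new(ScalarValue::from("h1"))),
        )),
    ));

    // Both column references become literals:
    // 'us-west-2' = 'us-east-1' AND 'h1' = 'h1'
    replace_columns_with_literals(predicate, &literal_columns)
}

A later pass (such as the `PhysicalExprSimplifier` the opener already applies) can then constant-fold the rewritten predicate to FALSE and skip the file entirely.
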
18 changes: 13 additions & 5 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
@@ -19,7 +19,9 @@
//! [`PhysicalExprAdapterFactory`], default implementations,
//! and [`replace_columns_with_literals`].

use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;

use arrow::compute::can_cast_types;
@@ -50,19 +52,25 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
/// # Arguments
/// - `expr`: The physical expression in which to replace column references.
/// - `replacements`: A mapping from column names to their corresponding literal `ScalarValue`s.
/// Accepts various HashMap types including `HashMap<&str, &ScalarValue>`,
/// `HashMap<String, ScalarValue>`, `HashMap<String, &ScalarValue>`, etc.
///
/// # Returns
/// - `Result<Arc<dyn PhysicalExpr>>`: The rewritten physical expression with columns replaced by literals.
pub fn replace_columns_with_literals(
pub fn replace_columns_with_literals<K, V>(
expr: Arc<dyn PhysicalExpr>,
replacements: &HashMap<&str, &ScalarValue>,
) -> Result<Arc<dyn PhysicalExpr>> {
expr.transform(|expr| {
replacements: &HashMap<K, V>,
) -> Result<Arc<dyn PhysicalExpr>>
where
K: Borrow<str> + Eq + Hash,
V: Borrow<ScalarValue>,
{
expr.transform_down(|expr| {
if let Some(column) = expr.as_any().downcast_ref::<Column>()
&& let Some(replacement_value) = replacements.get(column.name())
{
return Ok(Transformed::yes(expressions::lit(
(*replacement_value).clone(),
replacement_value.borrow().clone(),
)));
}
Ok(Transformed::no(expr))
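
Usage note (not part of the diff): with the `Borrow`-based bounds above, the same helper accepts both the borrowed map shape used by pre-existing callers and the owned map the opener now builds. A minimal sketch, assuming only the public items shown in this PR:

use std::collections::HashMap;
use std::sync::Arc;

use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr_adapter::replace_columns_with_literals;

fn accepts_both_map_shapes() -> Result<()> {
    let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
    let value = ScalarValue::from(7i32);

    // Borrowed keys and values: K = &str, V = &ScalarValue.
    let borrowed: HashMap<&str, &ScalarValue> = HashMap::from([("a", &value)]);
    let _ = replace_columns_with_literals(Arc::clone(&expr), &borrowed)?;

    // Owned keys and values: K = String, V = ScalarValue (as built in the opener).
    let owned: HashMap<String, ScalarValue> =
        HashMap::from([("a".to_string(), value.clone())]);
    let _ = replace_columns_with_literals(expr, &owned)?;

    Ok(())
}
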