Commit d6836d1

ccciudatu authored and adragomir committed
[HSTACK] Building blocks for Ray DataFusionDatasource
1 parent 1a0896d commit d6836d1

4 files changed: +146, -7 lines


python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,7 @@
 from . import functions, object_store, substrait

 # The following imports are okay to remain as opaque to the user.
-from ._internal import Config
+from ._internal import Config, partition_stream
 from .catalog import Catalog, Database, Table
 from .common import (
     DFSchema,
@@ -86,6 +86,7 @@
     "read_avro",
     "read_csv",
     "read_json",
+    "partition_stream",
 ]

python/datafusion/dataframe.py

Lines changed: 3 additions & 0 deletions
@@ -805,6 +805,9 @@ def count(self) -> int:
         """
         return self.df.count()

+    def distributed_plan(self):
+        return self.df.distributed_plan()
+
     @deprecated("Use :py:func:`unnest_columns` instead.")
     def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame:
         """See :py:func:`unnest_columns`."""

src/dataframe.rs

Lines changed: 139 additions & 6 deletions
@@ -27,30 +27,38 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
-use datafusion::common::UnnestOptions;
-use datafusion::config::{CsvOptions, TableParquetOptions};
+use datafusion::common::stats::Precision;
+use datafusion::common::{DFSchema, DataFusionError, Statistics, UnnestOptions};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::TableProvider;
+use datafusion::execution::runtime_env::RuntimeEnvBuilder;
+use datafusion::datasource::physical_plan::ParquetExec;
 use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 use datafusion::prelude::*;
+
+use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
+use datafusion_proto::protobuf::PhysicalPlanNode;
+use deltalake::delta_datafusion::DeltaPhysicalCodec;
+use prost::Message;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;

 use crate::catalog::PyTable;
+use crate::common::df_schema::PyDFSchema;
 use crate::errors::{py_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
 use crate::sql::logical::PyLogicalPlan;
 use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
-use crate::{
-    errors::PyDataFusionResult,
-    expr::{sort_expr::PySortExpr, PyExpr},
-};
+use crate::{errors::PyDataFusionResult, expr::{sort_expr::PySortExpr, PyExpr}};

 // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
 // - we have not decided on the table_provider approach yet
@@ -697,6 +705,131 @@ impl PyDataFrame {
     fn count(&self, py: Python) -> PyDataFusionResult<usize> {
         Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
     }
+
+    fn distributed_plan(&self, py: Python<'_>) -> PyResult<DistributedPlan> {
+        let future_plan = self.df.as_ref().clone().create_physical_plan();
+        wait_for_future(py, future_plan)
+            .map(DistributedPlan::new)
+            .map_err(py_datafusion_err)
+    }
+
+}
+
+#[pyclass(get_all)]
+#[derive(Debug, Clone)]
+pub struct DistributedPlan {
+    physical_plan: PyExecutionPlan,
+}
+
+#[pymethods]
+impl DistributedPlan {
+
+    fn serialize(&self) -> PyResult<Vec<u8>> {
+        PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
+            .map(|node| node.encode_to_vec())
+            .map_err(py_datafusion_err)
+    }
+
+    fn partition_count(&self) -> usize {
+        self.plan().output_partitioning().partition_count()
+    }
+
+    fn num_bytes(&self) -> Option<usize> {
+        self.stats_field(|stats| stats.total_byte_size)
+    }
+
+    fn num_rows(&self) -> Option<usize> {
+        self.stats_field(|stats| stats.num_rows)
+    }
+
+    fn schema(&self) -> PyResult<PyDFSchema> {
+        DFSchema::try_from(self.plan().schema())
+            .map(PyDFSchema::from)
+            .map_err(py_datafusion_err)
+    }
+
+    fn set_desired_parallelism(&mut self, desired_parallelism: usize) -> PyResult<()> {
+        if self.plan().output_partitioning().partition_count() == desired_parallelism {
+            return Ok(())
+        }
+        let updated_plan = self.plan().clone().transform_up(|node| {
+            if let Some(parquet) = node.as_any().downcast_ref::<ParquetExec>() {
+                // Remove redundant ranges from partition files because ParquetExec refuses to repartition
+                // if any file has a range defined (even when the range actually covers the entire file).
+                // The EnforceDistribution optimizer rule adds ranges for both full and partial files,
+                // so this tries to revert that to trigger a repartition when no files are actually split.
+                let mut file_groups = parquet.base_config().file_groups.clone();
+                for group in file_groups.iter_mut() {
+                    for file in group.iter_mut() {
+                        if let Some(range) = &file.range {
+                            if range.start == 0 && range.end == file.object_meta.size as i64 {
+                                file.range = None; // remove redundant range
+                            }
+                        }
+                    }
+                }
+                if let Some(repartitioned) = parquet.clone().into_builder().with_file_groups(file_groups)
+                    .build_arc()
+                    .repartitioned(desired_parallelism, &ConfigOptions::default())? {
+                    Ok(Transformed::yes(repartitioned))
+                } else {
+                    Ok(Transformed::no(node))
+                }
+            } else {
+                Ok(Transformed::no(node))
+            }
+        }).map_err(py_datafusion_err)?.data;
+        self.physical_plan = PyExecutionPlan::new(updated_plan);
+        Ok(())
+    }
+}
+
+impl DistributedPlan {
+
+    fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
+        Self {
+            physical_plan: PyExecutionPlan::new(plan)
+        }
+    }
+
+    fn plan(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.physical_plan.plan
+    }
+
+    fn stats_field(&self, field: fn(Statistics) -> Precision<usize>) -> Option<usize> {
+        if let Ok(stats) = self.physical_plan.plan.statistics() {
+            match field(stats) {
+                Precision::Exact(n) => Some(n),
+                _ => None,
+            }
+        } else {
+            None
+        }
+    }
+
+}
+
+#[pyfunction]
+pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
+    deltalake::ensure_initialized();
+    let node = PhysicalPlanNode::decode(serialized_plan)
+        .map_err(|e| DataFusionError::External(Box::new(e)))
+        .map_err(py_datafusion_err)?;
+    let ctx = SessionContext::new();
+    let plan = node.try_into_physical_plan(&ctx, ctx.runtime_env().as_ref(), codec())
+        .map_err(py_datafusion_err)?;
+    let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
+        plan.execute(partition, ctx.task_ctx())
+    });
+    wait_for_future(py, stream_with_runtime)
+        .map_err(py_datafusion_err)?
+        .map(PyRecordBatchStream::new)
+        .map_err(py_datafusion_err)
+}
+
+fn codec() -> &'static dyn PhysicalExtensionCodec {
+    static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
+    &CODEC
 }

 /// Print DataFrame
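
To make the serialize/partition_stream pair concrete, here is a rough local roundtrip sketch. It assumes the returned stream can be iterated like the streams produced by DataFrame.execute_stream(), with each batch convertible to PyArrow via to_pyarrow(); the table name and path are again hypothetical.

from datafusion import SessionContext, partition_stream

ctx = SessionContext()
ctx.register_parquet("events", "data/events.parquet")  # hypothetical dataset

plan = ctx.sql("SELECT * FROM events").distributed_plan()
plan_bytes = plan.serialize()  # protobuf-encoded physical plan

# Replay each partition from the serialized bytes; partition_stream builds its
# own SessionContext internally, so only the bytes and an index are needed.
for i in range(plan.partition_count()):
    for batch in partition_stream(plan_bytes, i):
        print(batch.to_pyarrow().num_rows)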

src/lib.rs

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     #[cfg(feature = "substrait")]
     setup_substrait_module(py, &m)?;

+    m.add_class::<dataframe::DistributedPlan>()?;
+    m.add_wrapped(wrap_pyfunction!(dataframe::partition_stream))?;
     Ok(())
 }
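
Since the commit title frames these exports as building blocks for a Ray DataFusionDatasource, a sketch of the intended driver/worker split may help. It is purely illustrative: the actual Ray Datasource wiring is not part of this commit, and the dataset is hypothetical.

import ray
from datafusion import SessionContext, partition_stream

@ray.remote
def read_partition(plan_bytes: bytes, partition: int) -> list:
    # Worker side: decode the protobuf plan and stream one partition to PyArrow.
    return [batch.to_pyarrow() for batch in partition_stream(plan_bytes, partition)]

# Driver side: plan once, serialize, and fan partitions out as Ray tasks.
ctx = SessionContext()
ctx.register_parquet("events", "data/events.parquet")  # hypothetical dataset
plan = ctx.sql("SELECT * FROM events").distributed_plan()
plan.set_desired_parallelism(4)
plan_bytes = plan.serialize()

futures = [read_partition.remote(plan_bytes, i) for i in range(plan.partition_count())]
batches_per_partition = ray.get(futures)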
