Skip to content
26 changes: 15 additions & 11 deletions python/sedonadb/python/sedonadb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def __init__(self):
self.__impl = None
self.options = Options()

# Alternate constructor: build a context around an already-existing `impl`
# and `options` pair instead of the defaults assigned by `__init__`.
#
# Parameters:
#   impl    -- object stored as the instance's private `__impl` attribute.
#   options -- object stored as the instance's `options` attribute.
# Returns the newly created instance of `cls`.
#
# NOTE(review): presumably a pre-set `__impl` makes the lazy `_impl` property
# skip its own initialization; the property body is not visible here -- confirm.
@classmethod
def _init_from_impl(cls, impl, options):
# Run the regular constructor first so every default attribute exists.
instance = cls()
# `__impl` is name-mangled by Python; this assignment reaches the same
# attribute as `self.__impl` in `__init__` only because it is written
# inside the class body.
instance.__impl = impl
instance.options = options
return instance

@property
def _impl(self):
"""Lazily initialize the internal Rust context on first use.
Expand Down Expand Up @@ -145,7 +152,7 @@ def create_data_frame(self, obj: Any, schema: Any = None) -> DataFrame:
│ 1 │
└───────┘
"""
return _create_data_frame(self._impl, obj, schema, self.options)
return _create_data_frame(self, obj, schema)

def view(self, name: str) -> DataFrame:
"""Create a [DataFrame][sedonadb.dataframe.DataFrame] from a named view
Expand All @@ -169,7 +176,7 @@ def view(self, name: str) -> DataFrame:
>>> sd.drop_view("foofy")

"""
return DataFrame(self._impl, self._impl.view(name), self.options)
return DataFrame(self, self._impl.view(name))

def drop_view(self, name: str) -> None:
"""Remove a named view
Expand Down Expand Up @@ -271,11 +278,10 @@ def read_parquet(
geometry_columns = json.dumps(geometry_columns)

return DataFrame(
self._impl,
self,
self._impl.read_parquet(
[str(path) for path in table_paths], options, geometry_columns, validate
),
self.options,
)

def read_pyogrio(
Expand Down Expand Up @@ -344,11 +350,10 @@ def read_pyogrio(
spec = spec.with_options(options)

return DataFrame(
self._impl,
self,
self._impl.read_external_format(
spec, [str(path) for path in table_paths], False
),
self.options,
)

def read_format(
Expand Down Expand Up @@ -388,11 +393,10 @@ def read_format(
table_paths = [table_paths]

return DataFrame(
self._impl,
self,
self._impl.read_external_format(
spec, [str(path) for path in table_paths], check_extension
),
self.options,
)

def sql(
Expand Down Expand Up @@ -438,7 +442,7 @@ def sql(
└────────────┘

"""
df = DataFrame(self._impl, self._impl.sql(sql), self.options)
df = DataFrame(self, self._impl.sql(sql))

if params is not None:
if isinstance(params, (tuple, list)):
Expand Down Expand Up @@ -509,7 +513,7 @@ def col(self, name: str, qualifier: Optional[str] = None) -> Expr:
>>> sd.col("x", "t")
Expr(t.x)
"""
return col_expr(name, qualifier=qualifier)
return col_expr(name, qualifier=qualifier, ctx=self)

def lit(self, value: Any) -> LiteralExpr:
"""Create a literal (constant) expression
Expand All @@ -536,7 +540,7 @@ def lit(self, value: Any) -> LiteralExpr:
- pyproj CRS objects become PROJJSON strings (e.g., so they may be used
in `ST_SetCRS()`, `ST_Point()`, or `ST_GeomFromWKT()`).
"""
return lit_expr(value)
return lit_expr(value, ctx=self)


def connect() -> SedonaContext:
Expand Down
74 changes: 31 additions & 43 deletions python/sedonadb/python/sedonadb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,9 @@ class DataFrame:
└────────────┘
"""

def __init__(self, ctx, impl, options):
def __init__(self, ctx, impl):
self._ctx = ctx
self._impl = impl
self._options = options

@property
def schema(self):
Expand Down Expand Up @@ -199,11 +198,7 @@ def alias(self, name: str) -> "DataFrame":
references in join expressions. This is the equivalent of aliasing a subquery
in SQL (`(SELECT * FROM df) AS name`).
"""
return DataFrame(
self._ctx,
self._impl.alias(name),
self._options,
)
return DataFrame(self._ctx, self._impl.alias(name))

def __getitem__(self, key: Union[str, int]) -> Expr:
"""Reference a single column by name or position.
Expand Down Expand Up @@ -262,8 +257,9 @@ def __getitem__(self, key: Union[str, int]) -> Expr:
"DataFrame slicing is not supported. "
"Use df.limit(n) or df.limit(n, offset=k)."
)

inner_expr = self._impl.qualified_column_expr(key)
return Expr(inner_expr)
return Expr(inner_expr, self._ctx)

def __getattr__(self, name):
"""Syntactic sugar for column access
Expand Down Expand Up @@ -372,7 +368,7 @@ def select(
f"got {type(e).__name__} for '{name}'"
)

return DataFrame(self._ctx, self._impl.select(coerced), self._options)
return DataFrame(self._ctx, self._impl.select(coerced))

def filter(self, *exprs: Expr) -> "DataFrame":
"""Filter rows by one or more boolean expressions.
Expand Down Expand Up @@ -425,7 +421,6 @@ def filter(self, *exprs: Expr) -> "DataFrame":
return DataFrame(
self._ctx,
self._impl.filter([e._impl for e in exprs]),
self._options,
)

def sort(self, *keys: Union[str, Expr, SortExpr]) -> "DataFrame":
Expand Down Expand Up @@ -492,7 +487,7 @@ def sort(self, *keys: Union[str, Expr, SortExpr]) -> "DataFrame":
f"got {type(k).__name__}"
)

return DataFrame(self._ctx, self._impl.sort(coerced), self._options)
return DataFrame(self._ctx, self._impl.sort(coerced))

def drop(self, *cols: str) -> "DataFrame":
"""Drop the named columns.
Expand Down Expand Up @@ -537,7 +532,7 @@ def drop(self, *cols: str) -> "DataFrame":
f"Column(s) {unknown} not found. Available columns: {columns}"
)

return DataFrame(self._ctx, self._impl.drop_columns(list(cols)), self._options)
return DataFrame(self._ctx, self._impl.drop_columns(list(cols)))

def agg(self, *exprs: Expr, **named_exprs: Expr) -> "DataFrame":
"""Aggregate the entire DataFrame to a single row.
Expand Down Expand Up @@ -593,7 +588,6 @@ def agg(self, *exprs: Expr, **named_exprs: Expr) -> "DataFrame":
return DataFrame(
self._ctx,
self._impl.aggregate([], [e._impl for e in all_exprs]),
self._options,
)

def group_by(self, *keys: Union[str, Expr]) -> "GroupedDataFrame":
Expand Down Expand Up @@ -670,7 +664,7 @@ def limit(self, n: Optional[int], /, *, offset: int = 0) -> "DataFrame":
└───────┘

"""
return DataFrame(self._ctx, self._impl.limit(n, offset), self._options)
return DataFrame(self._ctx, self._impl.limit(n, offset))

def execute(self) -> None:
"""Execute the plan represented by this DataFrame
Expand Down Expand Up @@ -751,7 +745,6 @@ def with_params(self, *args: List[Any], **kwargs: Dict[str, Any]):
return DataFrame(
self._ctx,
self._impl.with_params(positional_params, named_params),
self._options,
)

def __arrow_c_schema__(self):
Expand All @@ -775,7 +768,7 @@ def __arrow_c_stream__(self, requested_schema: Any = None):
Args:
requested_schema: A PyCapsule representing the desired output schema.
"""
return self._impl.to_stream(self._ctx, simplify=False).__arrow_c_stream__(
return self._impl.to_stream(self._ctx._impl, simplify=False).__arrow_c_stream__(
requested_schema=requested_schema
)

Expand Down Expand Up @@ -807,7 +800,7 @@ def to_arrow_reader(self, *, simplify: bool = False) -> "pa.RecordBatchReader":
import pyarrow as pa

return pa.RecordBatchReader.from_stream(
self._impl.to_stream(self._ctx, simplify=simplify)
self._impl.to_stream(self._ctx._impl, simplify=simplify)
)

def arrow(self, *, simplify: bool = False) -> "pa.RecordBatchReader":
Expand Down Expand Up @@ -837,7 +830,7 @@ def to_view(self, name: str, overwrite: bool = False):
└────────────┘

"""
self._impl.to_view(self._ctx, name, overwrite)
self._impl.to_view(self._ctx._impl, name, overwrite)

def to_memtable(self) -> "DataFrame":
"""Collect a data frame into a memtable
Expand All @@ -860,7 +853,7 @@ def to_memtable(self) -> "DataFrame":
└────────────┘

"""
return DataFrame(self._ctx, self._impl.to_memtable(self._ctx), self._options)
return DataFrame(self._ctx, self._impl.to_memtable(self._ctx._impl))

def __datafusion_table_provider__(self):
return self._impl.__datafusion_table_provider__()
Expand Down Expand Up @@ -1034,7 +1027,7 @@ def to_parquet(
sort_by = []

self._impl.to_parquet(
self._ctx,
self._ctx._impl,
str(path),
options,
partition_by,
Expand Down Expand Up @@ -1116,7 +1109,7 @@ def to_pyogrio(

# GDAL does not support newer Arrow types like string views until 3.14, so we export a
# reader with simpler types here
self_simplified = self._impl.to_stream(self._ctx, simplify=True)
self_simplified = self._impl.to_stream(self._ctx._impl, simplify=True)

# Writer: pyogrio.write_arrow() via Cython ogr_write_arrow()
# https://github.com/geopandas/pyogrio/blob/3b2d40273b501c10ecf46cbd37c6e555754c89af/pyogrio/raw.py#L755-L897
Expand Down Expand Up @@ -1164,7 +1157,7 @@ def show(

"""
width = self._out_width(width)
print(self._impl.show(self._ctx, limit, width, ascii), end="")
print(self._impl.show(self._ctx._impl, limit, width, ascii), end="")

def explain(
self,
Expand Down Expand Up @@ -1207,23 +1200,21 @@ def explain(
│ ┆ │
└───────────────┴─────────────────────────────────┘
"""
return DataFrame(self._ctx, self._impl.explain(type, format), self._options)
return DataFrame(self._ctx, self._impl.explain(type, format))

def __repr__(self) -> str:
if self._options.interactive:
if self._ctx.options.interactive:
width = self._out_width()
return self._impl.show(self._ctx, 10, width, ascii=False).strip()
return self._impl.show(self._ctx._impl, 10, width, ascii=False).strip()
else:
return super().__repr__()

def _simplify_storage_types(self):
return DataFrame(
self._ctx, self._impl.simplify_storage_types(self._ctx), self._options
)
return DataFrame(self._ctx, self._impl.simplify_storage_types(self._ctx._impl))

def _out_width(self, width=None) -> int:
if width is None:
width = self._options.width
width = self._ctx.options.width

if width is None:
import shutil
Expand All @@ -1233,7 +1224,7 @@ def _out_width(self, width=None) -> int:
return width


def _create_data_frame(ctx_impl, obj, schema, options) -> DataFrame:
def _create_data_frame(ctx, obj, schema) -> DataFrame:
"""Create a DataFrame (internal)

This is defined here because we need it in future dataframe methods like
Expand All @@ -1243,7 +1234,7 @@ def _create_data_frame(ctx_impl, obj, schema, options) -> DataFrame:
# If we're dealing with an anonymous data frame on the same context,
# just return it. Otherwise, fall back to the default interpretation
# (which uses __datafusion_table_provider__).
if isinstance(obj, DataFrame) and obj._ctx is ctx_impl and schema is None:
if isinstance(obj, DataFrame) and obj._ctx is ctx and schema is None:
return obj

# We special case a few object types where collecting the __arrow_c_stream__
Expand All @@ -1252,22 +1243,22 @@ def _create_data_frame(ctx_impl, obj, schema, options) -> DataFrame:
# This includes geopandas/pandas DataFrames, pyarrow tables, and Polars tables.
type_name = _qualified_type_name(obj)
if type_name in SPECIAL_CASED_SCANS:
return SPECIAL_CASED_SCANS[type_name](ctx_impl, obj, schema, options)
return SPECIAL_CASED_SCANS[type_name](ctx, obj, schema)

# The default implementation handles objects that implement
# __datafusion_table_provider__ or __arrow_c_stream__. For objects implementing
# __arrow_c_stream__, this currently will only work for a single scan (i.e.,
# the returned data frame can't be previewed before the query is computed).
return _scan_default(ctx_impl, obj, schema, options)
return _scan_default(ctx, obj, schema)


def _scan_default(ctx_impl, obj, schema, options):
impl = ctx_impl.create_data_frame(obj, schema)
return DataFrame(ctx_impl, impl, options)
def _scan_default(ctx, obj, schema):
impl = ctx._impl.create_data_frame(obj, schema)
return DataFrame(ctx, impl)


def _scan_collected_default(ctx_impl, obj, schema, options):
return _scan_default(ctx_impl, obj, schema, options).to_memtable()
def _scan_collected_default(ctx, obj, schema):
return _scan_default(ctx, obj, schema).to_memtable()


class GroupedDataFrame:
Expand Down Expand Up @@ -1317,14 +1308,11 @@ def agg(self, *exprs: Expr, **named_exprs: Expr) -> DataFrame:
[g._impl for g in self._group_exprs],
[e._impl for e in all_exprs],
),
self._df._options,
)


def _scan_geopandas(ctx_impl, obj, schema, options):
return _scan_collected_default(
ctx_impl, obj.to_arrow(geometry_encoding="WKB"), schema, options
)
def _scan_geopandas(ctx, obj, schema):
return _scan_collected_default(ctx, obj.to_arrow(geometry_encoding="WKB"), schema)


def _qualified_type_name(obj):
Expand Down
Loading