Skip to content

Commit 03fde8a

Browse files
timsaucerclaude
andcommitted
Add deregister methods to SessionContext for UDFs and object stores
Expose the upstream DataFusion deregister methods (deregister_udf, deregister_udaf, deregister_udwf, deregister_udtf, deregister_object_store) in both the Rust PyO3 bindings and the Python wrappers, closing the gap identified in #1457.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2499409 commit 03fde8a

File tree

3 files changed

+191
-0
lines changed

3 files changed

+191
-0
lines changed

crates/core/src/context.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,20 @@ impl PySessionContext {
439439
Ok(())
440440
}
441441

442+
/// Deregister an object store with the given url
443+
#[pyo3(signature = (scheme, host=None))]
444+
pub fn deregister_object_store(
445+
&self,
446+
scheme: &str,
447+
host: Option<&str>,
448+
) -> PyDataFusionResult<()> {
449+
let host = host.unwrap_or("");
450+
let url_string = format!("{scheme}{host}");
451+
let url = Url::parse(&url_string).unwrap();
452+
self.ctx.runtime_env().deregister_object_store(&url)?;
453+
Ok(())
454+
}
455+
442456
#[allow(clippy::too_many_arguments)]
443457
#[pyo3(signature = (name, path, table_partition_cols=vec![],
444458
file_extension=".parquet",
@@ -492,6 +506,10 @@ impl PySessionContext {
492506
self.ctx.register_udtf(&name, func);
493507
}
494508

509+
/// Deregister the user-defined table function with the given `name`.
///
/// Forwards to the wrapped context's `deregister_udtf`; whatever that call
/// returns is discarded.
pub fn deregister_udtf(&self, name: &str) {
510+
self.ctx.deregister_udtf(name);
511+
}
512+
495513
#[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
496514
pub fn sql_with_options(
497515
&self,
@@ -975,16 +993,28 @@ impl PySessionContext {
975993
Ok(())
976994
}
977995

996+
/// Deregister the user-defined scalar function with the given `name`.
///
/// Forwards to the wrapped context's `deregister_udf`; whatever that call
/// returns is discarded.
pub fn deregister_udf(&self, name: &str) {
997+
self.ctx.deregister_udf(name);
998+
}
999+
9781000
/// Register a user-defined aggregate function with the wrapped context.
///
/// Hands the wrapper's `function` field to `self.ctx.register_udaf` and then
/// returns `Ok(())`.
pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
9791001
self.ctx.register_udaf(udaf.function);
9801002
Ok(())
9811003
}
9821004

1005+
/// Deregister the user-defined aggregate function with the given `name`.
///
/// Forwards to the wrapped context's `deregister_udaf`; whatever that call
/// returns is discarded.
pub fn deregister_udaf(&self, name: &str) {
1006+
self.ctx.deregister_udaf(name);
1007+
}
1008+
9831009
/// Register a user-defined window function with the wrapped context.
///
/// Hands the wrapper's `function` field to `self.ctx.register_udwf` and then
/// returns `Ok(())`.
pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
9841010
self.ctx.register_udwf(udwf.function);
9851011
Ok(())
9861012
}
9871013

1014+
/// Deregister the user-defined window function with the given `name`.
///
/// Forwards to the wrapped context's `deregister_udwf`; whatever that call
/// returns is discarded.
pub fn deregister_udwf(&self, name: &str) {
1015+
self.ctx.deregister_udwf(name);
1016+
}
1017+
9881018
#[pyo3(signature = (name="datafusion"))]
9891019
pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
9901020
let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(

python/datafusion/context.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,15 @@ def register_object_store(
568568
"""
569569
self.ctx.register_object_store(schema, store, host)
570570

571+
def deregister_object_store(self, schema: str, host: str | None = None) -> None:
572+
"""Remove an object store from the session.
573+
574+
Args:
575+
schema: URL scheme of the store, written with its ``://`` separator (e.g. ``"s3://"``); the binding concatenates it directly with ``host`` to form the store URL.
576+
host: Host part of the store URL (e.g. bucket name); ``None`` is treated as an empty host by the binding.
577+
"""
578+
self.ctx.deregister_object_store(schema, host)
579+
571580
def register_listing_table(
572581
self,
573582
name: str,
@@ -894,6 +903,14 @@ def register_udtf(self, func: TableFunction) -> None:
894903
"""Register a user defined table function."""
895904
self.ctx.register_udtf(func._udtf)
896905

906+
def deregister_udtf(self, name: str) -> None:
907+
"""Remove a user-defined table function from the session.
908+
909+
Args:
910+
name: Name of the UDTF to deregister; passed unchanged to the internal context's ``deregister_udtf``.
911+
"""
912+
self.ctx.deregister_udtf(name)
913+
897914
def register_record_batches(
898915
self, name: str, partitions: list[list[pa.RecordBatch]]
899916
) -> None:
@@ -1105,14 +1122,38 @@ def register_udf(self, udf: ScalarUDF) -> None:
11051122
"""Register a user-defined function (UDF) with the context."""
11061123
self.ctx.register_udf(udf._udf)
11071124

1125+
def deregister_udf(self, name: str) -> None:
1126+
"""Remove a user-defined scalar function from the session.
1127+
1128+
Args:
1129+
name: Name of the UDF to deregister; passed unchanged to the internal context's ``deregister_udf``.
1130+
"""
1131+
self.ctx.deregister_udf(name)
1132+
11081133
def register_udaf(self, udaf: AggregateUDF) -> None:
11091134
"""Register a user-defined aggregation function (UDAF) with the context.

Args:
udaf: The aggregate function wrapper; its ``_udaf`` attribute is what gets handed to the internal context.
"""
11101135
self.ctx.register_udaf(udaf._udaf)
11111136

1137+
def deregister_udaf(self, name: str) -> None:
1138+
"""Remove a user-defined aggregate function from the session.
1139+
1140+
Args:
1141+
name: Name of the UDAF to deregister; passed unchanged to the internal context's ``deregister_udaf``.
1142+
"""
1143+
self.ctx.deregister_udaf(name)
1144+
11121145
def register_udwf(self, udwf: WindowUDF) -> None:
11131146
"""Register a user-defined window function (UDWF) with the context.

Args:
udwf: The window function wrapper; its ``_udwf`` attribute is what gets handed to the internal context.
"""
11141147
self.ctx.register_udwf(udwf._udwf)
11151148

1149+
def deregister_udwf(self, name: str) -> None:
1150+
"""Remove a user-defined window function from the session.
1151+
1152+
Args:
1153+
name: Name of the UDWF to deregister; passed unchanged to the internal context's ``deregister_udwf``.
1154+
"""
1155+
self.ctx.deregister_udwf(name)
1156+
11161157
def catalog(self, name: str = "datafusion") -> Catalog:
11171158
"""Retrieve a catalog by name.

Args:
name: Name of the catalog to look up; defaults to ``"datafusion"``.

Returns:
A ``Catalog`` wrapping the object returned by the internal context's ``catalog`` call.
"""
11181159
return Catalog(self.ctx.catalog(name))

python/tests/test_context.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,126 @@ def test_deregister_table(ctx, database):
351351
assert public.names() == {"csv1", "csv2"}
352352

353353

354+
def test_deregister_udf():
"""Check that a scalar UDF is usable from SQL while registered and that the same query raises after ``deregister_udf``."""
355+
ctx = SessionContext()
356+
from datafusion import udf
357+
358+
is_null = udf(
359+
lambda x: x.is_null(),
360+
[pa.float64()],
361+
pa.bool_(),
362+
volatility="immutable",
363+
name="my_is_null",
364+
)
365+
ctx.register_udf(is_null)
366+
367+
# While registered, the UDF can be called from SQL and flags only the null row
368+
df = ctx.from_pydict({"a": [1.0, None]})
369+
ctx.register_table("t", df.into_view())
370+
result = ctx.sql("SELECT my_is_null(a) FROM t").collect()
371+
assert result[0].column(0) == pa.array([False, True])
372+
373+
# After deregistration the identical query is expected to raise RuntimeError
374+
ctx.deregister_udf("my_is_null")
375+
with pytest.raises(RuntimeError):
376+
ctx.sql("SELECT my_is_null(a) FROM t").collect()
377+
378+
379+
def test_deregister_udaf():
"""Check that an aggregate UDF is usable from SQL while registered and that the same query raises after ``deregister_udaf``."""
380+
import pyarrow.compute as pc
381+
382+
ctx = SessionContext()
383+
from datafusion import Accumulator, udaf
384+
385+
# Minimal accumulator that keeps a running float total
class MySum(Accumulator):
386+
def __init__(self):
387+
self._sum = 0.0
388+
389+
def update(self, values: pa.Array) -> None:
390+
self._sum += pc.sum(values).as_py()
391+
392+
def merge(self, states: list[pa.Array]) -> None:
393+
self._sum += pc.sum(states[0]).as_py()
394+
395+
# NOTE(review): the state holds a plain Python float, not a pyarrow scalar - confirm the binding accepts this
def state(self) -> list:
396+
return [self._sum]
397+
398+
# NOTE(review): annotated as pa.Scalar but returns the plain float total - confirm this is intended
def evaluate(self) -> pa.Scalar:
399+
return self._sum
400+
401+
my_sum = udaf(
402+
MySum,
403+
[pa.float64()],
404+
pa.float64(),
405+
[pa.float64()],
406+
volatility="immutable",
407+
name="my_sum",
408+
)
409+
ctx.register_udaf(my_sum)
410+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
411+
ctx.register_table("t", df.into_view())
412+
413+
# While registered, the aggregate sums the three input values
result = ctx.sql("SELECT my_sum(a) FROM t").collect()
414+
assert result[0].column(0) == pa.array([6.0])
415+
416+
# After deregistration the identical query is expected to raise RuntimeError
ctx.deregister_udaf("my_sum")
417+
with pytest.raises(RuntimeError):
418+
ctx.sql("SELECT my_sum(a) FROM t").collect()
419+
420+
421+
def test_deregister_udwf():
"""Check that a window UDF is usable from SQL while registered and that the same query raises after ``deregister_udwf``."""
422+
ctx = SessionContext()
423+
from datafusion import udwf
424+
from datafusion.user_defined import WindowEvaluator
425+
426+
# Evaluator that numbers the rows 1..num_rows and ignores the input values
class MyRowNumber(WindowEvaluator):
427+
def __init__(self):
428+
# Not read anywhere in this class
self._row = 0
429+
430+
def evaluate_all(self, values, num_rows):
431+
return pa.array(list(range(1, num_rows + 1)), type=pa.uint64())
432+
433+
my_row_number = udwf(
434+
MyRowNumber,
435+
[pa.float64()],
436+
pa.uint64(),
437+
volatility="immutable",
438+
name="my_row_number",
439+
)
440+
ctx.register_udwf(my_row_number)
441+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0]})
442+
ctx.register_table("t", df.into_view())
443+
444+
# While registered, the window function yields 1, 2, 3 for the three rows
result = ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect()
445+
assert result[0].column(0) == pa.array([1, 2, 3], type=pa.uint64())
446+
447+
# After deregistration the identical query is expected to raise RuntimeError
ctx.deregister_udwf("my_row_number")
448+
with pytest.raises(RuntimeError):
449+
ctx.sql("SELECT my_row_number(a) OVER () FROM t").collect()
450+
451+
452+
def test_deregister_udtf():
"""Check that a table function is usable from SQL while registered and that the same query raises after ``deregister_udtf``."""
453+
import pyarrow.dataset as ds
454+
455+
ctx = SessionContext()
456+
from datafusion import Table, udtf
457+
458+
# Callable that builds a one-column table with x = 1, 2, 3 on every call
class MyTable:
459+
def __call__(self):
460+
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})
461+
return Table(ds.dataset([batch]))
462+
463+
my_table = udtf(MyTable(), "my_table")
464+
ctx.register_udtf(my_table)
465+
466+
# While registered, selecting from the table function returns its rows
result = ctx.sql("SELECT * FROM my_table()").collect()
467+
assert result[0].column(0) == pa.array([1, 2, 3])
468+
469+
# After deregistration the identical query is expected to raise RuntimeError
ctx.deregister_udtf("my_table")
470+
with pytest.raises(RuntimeError):
471+
ctx.sql("SELECT * FROM my_table()").collect()
472+
473+
354474
def test_register_table_from_dataframe(ctx):
355475
df = ctx.from_pydict({"a": [1, 2]})
356476
ctx.register_table("df_tbl", df)

0 commit comments

Comments
 (0)