Skip to content

Commit 4131899

Browse files
timsaucerclaude
andcommitted
feat: SessionContext UDF lookup helpers
Expose `udf(name)` / `udaf(name)` / `udwf(name)` lookups symmetric with the existing `register_udf` / `register_udaf` / `register_udwf` setters, plus `udfs()` / `udafs()` / `udwfs()` for enumerating registered function names. Looked-up functions come back as the same `ScalarUDF` / `AggregateUDF` / `WindowUDF` wrappers users already get from registration, so they can be called as expressions or re-registered into a different session. Returns Vec<String> from the list helpers (sorted) rather than the raw HashSet upstream returns, so calling code gets a stable ordering. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 51fc2bc commit 4131899

3 files changed

Lines changed: 114 additions & 1 deletion

File tree

crates/core/src/context.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ use datafusion::datasource::listing::{
3535
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
3636
};
3737
use datafusion::datasource::{MemTable, TableProvider};
38-
use datafusion::execution::TaskContextProvider;
3938
use datafusion::execution::context::{
4039
DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
4140
};
@@ -44,6 +43,7 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, Unboun
4443
use datafusion::execution::options::{ArrowReadOptions, ReadOptions};
4544
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
4645
use datafusion::execution::session_state::SessionStateBuilder;
46+
use datafusion::execution::{FunctionRegistry, TaskContextProvider};
4747
use datafusion::prelude::{
4848
AvroReadOptions, CsvReadOptions, DataFrame, JsonReadOptions, ParquetReadOptions,
4949
};
@@ -1072,6 +1072,39 @@ impl PySessionContext {
10721072
self.ctx.deregister_udwf(name);
10731073
}
10741074

1075+
pub fn udf(&self, name: &str) -> PyDataFusionResult<PyScalarUDF> {
1076+
let function = (*self.ctx.udf(name)?).clone();
1077+
Ok(PyScalarUDF { function })
1078+
}
1079+
1080+
pub fn udaf(&self, name: &str) -> PyDataFusionResult<PyAggregateUDF> {
1081+
let function = (*self.ctx.udaf(name)?).clone();
1082+
Ok(PyAggregateUDF { function })
1083+
}
1084+
1085+
pub fn udwf(&self, name: &str) -> PyDataFusionResult<PyWindowUDF> {
1086+
let function = (*self.ctx.udwf(name)?).clone();
1087+
Ok(PyWindowUDF { function })
1088+
}
1089+
1090+
pub fn udfs(&self) -> Vec<String> {
1091+
let mut names: Vec<String> = self.ctx.udfs().into_iter().collect();
1092+
names.sort();
1093+
names
1094+
}
1095+
1096+
pub fn udafs(&self) -> Vec<String> {
1097+
let mut names: Vec<String> = self.ctx.udafs().into_iter().collect();
1098+
names.sort();
1099+
names
1100+
}
1101+
1102+
pub fn udwfs(&self) -> Vec<String> {
1103+
let mut names: Vec<String> = self.ctx.udwfs().into_iter().collect();
1104+
names.sort();
1105+
names
1106+
}
1107+
10751108
#[pyo3(signature = (name="datafusion"))]
10761109
pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
10771110
let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(

python/datafusion/context.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,65 @@ def deregister_udwf(self, name: str) -> None:
13101310
"""
13111311
self.ctx.deregister_udwf(name)
13121312

1313+
def udf(self, name: str) -> ScalarUDF:
1314+
"""Look up a registered scalar UDF by name.
1315+
1316+
Args:
1317+
name: Name of the registered scalar UDF.
1318+
1319+
Raises:
1320+
Exception: If no scalar UDF is registered under ``name``.
1321+
"""
1322+
from datafusion.user_defined import ScalarUDF as _ScalarUDF # noqa: PLC0415
1323+
1324+
wrapper = _ScalarUDF.__new__(_ScalarUDF)
1325+
wrapper._udf = self.ctx.udf(name)
1326+
return wrapper
1327+
1328+
def udaf(self, name: str) -> AggregateUDF:
1329+
"""Look up a registered aggregate UDF by name.
1330+
1331+
Args:
1332+
name: Name of the registered aggregate UDF.
1333+
1334+
Raises:
1335+
Exception: If no aggregate UDF is registered under ``name``.
1336+
"""
1337+
from datafusion.user_defined import ( # noqa: PLC0415
1338+
AggregateUDF as _AggregateUDF,
1339+
)
1340+
1341+
wrapper = _AggregateUDF.__new__(_AggregateUDF)
1342+
wrapper._udaf = self.ctx.udaf(name)
1343+
return wrapper
1344+
1345+
def udwf(self, name: str) -> WindowUDF:
1346+
"""Look up a registered window UDF by name.
1347+
1348+
Args:
1349+
name: Name of the registered window UDF.
1350+
1351+
Raises:
1352+
Exception: If no window UDF is registered under ``name``.
1353+
"""
1354+
from datafusion.user_defined import WindowUDF as _WindowUDF # noqa: PLC0415
1355+
1356+
wrapper = _WindowUDF.__new__(_WindowUDF)
1357+
wrapper._udwf = self.ctx.udwf(name)
1358+
return wrapper
1359+
1360+
def udfs(self) -> list[str]:
1361+
"""Return the sorted names of all registered scalar UDFs."""
1362+
return self.ctx.udfs()
1363+
1364+
def udafs(self) -> list[str]:
1365+
"""Return the sorted names of all registered aggregate UDFs."""
1366+
return self.ctx.udafs()
1367+
1368+
def udwfs(self) -> list[str]:
1369+
"""Return the sorted names of all registered window UDFs."""
1370+
return self.ctx.udwfs()
1371+
13131372
def catalog(self, name: str = "datafusion") -> Catalog:
13141373
"""Retrieve a catalog by name."""
13151374
return Catalog(self.ctx.catalog(name))

python/tests/test_udf.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,27 @@ def test_register_udf(ctx, df) -> None:
7676
assert result == pa.array([False, False, True])
7777

7878

79+
def test_udf_lookup(ctx, df) -> None:
80+
is_null = udf(
81+
lambda x: x.is_null(),
82+
[pa.float64()],
83+
pa.bool_(),
84+
volatility="immutable",
85+
name="lookup_is_null",
86+
)
87+
ctx.register_udf(is_null)
88+
89+
assert "lookup_is_null" in ctx.udfs()
90+
91+
looked_up = ctx.udf("lookup_is_null")
92+
df_result = df.select(looked_up(column("b")))
93+
result = df_result.collect()[0].column(0)
94+
assert result == pa.array([False, False, True])
95+
96+
with pytest.raises(Exception, match="no UDF named"):
97+
ctx.udf("does_not_exist")
98+
99+
79100
class OverThresholdUDF:
80101
def __init__(self, threshold: int = 0) -> None:
81102
self.threshold = threshold

0 commit comments

Comments
 (0)