Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 109 additions & 58 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Coroutine,
Iterable,
Iterator,
Mapping,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -43,6 +44,7 @@
TypedDict,
TypeVar,
cast,
overload,
)

from packaging.version import parse as parse_version
Expand Down Expand Up @@ -2864,57 +2866,77 @@ def retry(self, futures, asynchronous=None):
"""
return self.sync(self._retry, futures, asynchronous=asynchronous)

@log_errors
async def _publish_dataset(self, *args, name=None, override=False, **kwargs):
coroutines = []
uid = uuid.uuid4().hex
self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid})

def add_coro(name, data):
keys = [f.key for f in futures_of(data)]

async def _():
await self.scheduler.publish_wait_flush(uid=uid)
await self.scheduler.publish_put(
keys=keys,
name=name,
data=to_serialize(data),
override=override,
client=self.id,
)

coroutines.append(_())
async def _publish_dataset(
self, *args: Any, name: Key | None = None, override: bool = False, **kwargs: Any
):
names: list[Key] = list(kwargs)
data = list(kwargs.values())

if name:
if len(args) == 0:
if not args:
raise ValueError(
"If name is provided, expecting call signature like"
" publish_dataset(df, name='ds')"
)
# in case this is a singleton, collapse it
elif len(args) == 1:
args = args[0]
add_coro(name, args)
names.append(name)
data.append(args[0] if len(args) == 1 else args)
elif args:
if len(args) != 1 or not isinstance(args[0], Mapping):
raise ValueError(
"If name is omitted, positional argument must be "
"a {name: value} dict"
)
names.extend(args[0])
data.extend(args[0].values())

for name, data in kwargs.items():
add_coro(name, data)
# Prevent race condition where the client persists the collection immediately
# before publish_dataset, but the persist command hasn't landed on the scheduler
# yet when the publish_put RPC call arrives asynchronously.
uid = uuid.uuid4().hex
self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid})

await asyncio.gather(*coroutines)
await self.scheduler.publish_put(
names=names,
keys=[[f.key for f in futures_of(data_i)] for data_i in data],
data=[to_serialize(data_i) for data_i in data],
override=override,
uid=uid,
)

# Overload: name= is given explicitly as a keyword argument alongside the data
@overload
def publish_dataset(
    self, *args: Any, name: Key, override: bool = False, **kwargs
): ...

def publish_dataset(self, *args, **kwargs):
@overload
def publish_dataset(
self, *args: Mapping[Key, Any], override: bool = False, **kwargs
Copy link
Copy Markdown
Collaborator Author

@crusaderky crusaderky Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Frustratingly, mypy complains if you write

    @overload
    def publish_dataset(
        self, arg: Mapping[Key, Any], /, *, override: bool = False, **kwargs
    ): ...

which would be more correct.

): ...

def publish_dataset(
self, *args: Any, name: Key | None = None, override: bool = False, **kwargs
):
Copy link
Copy Markdown
Collaborator Author

@crusaderky crusaderky Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: offering four different syntaxes to achieve the same thing is in clear violation of the Zen of Python:

There should be one-- and preferably only one --obvious way to do it.

Removing some of these syntaxes however would be a breaking change and is beyond the scope of this PR.

"""
Publish named datasets to scheduler

This stores a named reference to a dask collection or list of futures
This stores a named reference to one or more dask collections or futures
on the scheduler. These references are available to other Clients
which can download the collection or futures with ``get_dataset``.
which can download the collections or futures with ``get_dataset``.

Datasets are not immediately computed. You may wish to call
``Client.persist`` prior to publishing a dataset.
Datasets are not immediately computed. You should call ``persist`` prior to
publishing a dataset. Any unpersisted keys will be stored on the scheduler
uncomputed and returned as-is to the user when calling ``get_dataset``.

Parameters
----------
args : list of objects to publish as name
args : One or more objects to publish as `name`.
Alternatively, a single dict of {name: object} pairs.
name : Dask key (str, int, float, or tuple thereof)
Name to publish `args` under
override: bool
False (default) to raise KeyError if a dataset with the same name already
exists on the scheduler; True to overwrite it.
kwargs : dict
named collections to publish on the scheduler

Expand All @@ -2924,7 +2946,10 @@ def publish_dataset(self, *args, **kwargs):

>>> df = dd.read_csv('s3://...') # doctest: +SKIP
>>> df = c.persist(df) # doctest: +SKIP
>>> c.publish_dataset(my_dataset=df) # doctest: +SKIP
>>> c.publish_dataset({"my_dataset": df}) # doctest: +SKIP

Alternative invocation
>>> c.publish_dataset(my_dataset=df)  # doctest: +SKIP

Alternative invocation
>>> c.publish_dataset(df, name='my_dataset')
Expand All @@ -2946,30 +2971,50 @@ def publish_dataset(self, *args, **kwargs):
Client.unpublish_dataset
Client.persist
"""
return self.sync(self._publish_dataset, *args, **kwargs)
return self.sync(
self._publish_dataset, *args, name=name, override=override, **kwargs
)

def unpublish_dataset(self, name, **kwargs):
async def _unpublish_dataset(self, name: Key | list[Key]) -> None:
    """Asynchronous implementation of ``unpublish_dataset``.

    Parameters
    ----------
    name:
        A single dataset name, or a list of names, to delete from the scheduler.
    """
    # Normalize to a list so that a single name and a batch share one code path
    names = name if isinstance(name, list) else [name]
    uid = uuid.uuid4().hex
    # Prevent race condition where the user calls get_dataset() and immediately
    # afterwards unpublish_dataset(), thinking that they are holding a reference
    # to the futures locally, but the futures haven't been registered on the
    # scheduler yet by the time unpublish_dataset lands on the scheduler.
    # This method can't be made into just a batched send command as it would
    # create another race condition, where unpublish_dataset() followed by
    # get_dataset() would return the just-deleted data.
    self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid})
    await self.scheduler.publish_delete(names=names, uid=uid)

def unpublish_dataset(self, name: Key | list[Key], **kwargs):
"""
Remove named datasets from scheduler

Parameters
----------
name : str
The name of the dataset to unpublish
name : Dask key (str, int, float, or tuple thereof), or list of keys
Name(s) of the dataset(s) to unpublish.

Examples
--------
>>> c.list_datasets() # doctest: +SKIP
['my_dataset']
>>> c.unpublish_dataset('my_dataset') # doctest: +SKIP
['foo', 'bar', 'baz']
>>> c.unpublish_dataset('foo') # doctest: +SKIP
>>> c.list_datasets() # doctest: +SKIP
['bar', 'baz']
>>> c.unpublish_dataset(['bar', 'baz']) # doctest: +SKIP
>>> c.list_datasets() # doctest: +SKIP
[]

See Also
--------
Client.publish_dataset
Client.list_datasets
Client.get_dataset
"""
return self.sync(self.scheduler.publish_delete, name=name, **kwargs)
return self.sync(self._unpublish_dataset, name=name, **kwargs)

def list_datasets(self, **kwargs):
"""
Expand All @@ -2978,46 +3023,52 @@ def list_datasets(self, **kwargs):
See Also
--------
Client.publish_dataset
Client.unpublish_dataset
Client.get_dataset
"""
return self.sync(self.scheduler.publish_list, **kwargs)

async def _get_dataset(self, name, default=no_default):
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)
if out is None:
if default is no_default:
raise KeyError(f"Dataset '{name}' not found")
async def _get_dataset(self, name: Key | list[Key], default=no_default):
names = name if isinstance(name, list) else [name]
raw_outs = await self.scheduler.publish_get(names=names)

outs = []
for out in raw_outs:
if out is None:
if default is no_default:
raise KeyError(f"Dataset '{name}' not found")
outs.append(default)
else:
return default
for fut in futures_of(out["data"]):
fut.bind_client(self)
for fut in futures_of(out["data"]):
fut.bind_client(self)
outs.append(out["data"])

self._inform_scheduler_of_futures()
return out["data"]
return outs if isinstance(name, list) else outs[0]

def get_dataset(self, name, default=no_default, **kwargs):
def get_dataset(self, name: Key | list[Key], default=no_default, **kwargs):
"""
Get named dataset from the scheduler if present.
Return the default or raise a KeyError if not present.

Parameters
----------
name : str
name of the dataset to retrieve
default : str
name : Dask key (str, int, float, or tuple thereof), or list of keys
name(s) of the dataset(s) to retrieve
default : Any
optional, not set by default
If set, do not raise a KeyError if the name is not present but
return this default
kwargs : dict
additional keyword arguments to _get_dataset
Comment on lines -3011 to -3012
Copy link
Copy Markdown
Collaborator Author

@crusaderky crusaderky Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs actually contain the situationally useful asynchronous=False to be passed to a synchronous client.
However their lack of documentation is endemic and is out of scope for this PR.
This change simply aligns the documentation of this method to all other client methods.


Returns
-------
The dataset from the scheduler, if present
The dataset from the scheduler, if present.
If name is a list of keys, return a list of datasets in the same order.

See Also
--------
Client.publish_dataset
Client.unpublish_dataset
Client.list_datasets
"""
return self.sync(self._get_dataset, name, default=default, **kwargs)
Expand Down
80 changes: 61 additions & 19 deletions distributed/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@
import asyncio
from collections import defaultdict
from collections.abc import MutableMapping
from typing import TYPE_CHECKING, TypedDict

from dask.typing import Key
from dask.utils import stringify

from distributed.protocol.serialize import Serialized
from distributed.utils import log_errors

if TYPE_CHECKING:
from distributed.scheduler import Scheduler


class PublishedDataset(TypedDict):
    """Scheduler-side record for a single published dataset."""

    # Serialized collection or futures, exactly as received from the client
    data: Serialized
    # Dask keys referenced by the dataset; the scheduler keeps them alive
    # (via client_desires_keys) for as long as the dataset is published
    keys: tuple[Key, ...]


class PublishExtension:
"""An extension for the scheduler to manage collections
Expand All @@ -18,51 +29,82 @@ class PublishExtension:
* publish_delete
"""

scheduler: Scheduler
datasets: dict[Key, PublishedDataset]
_flush_received: defaultdict[bytes, asyncio.Event]

def __init__(self, scheduler):
self.scheduler = scheduler
self.datasets = dict()
self.datasets = {}

handlers = {
"publish_list": self.list,
"publish_put": self.put,
"publish_get": self.get,
"publish_delete": self.delete,
"publish_wait_flush": self.flush_wait,
}
stream_handlers = {
"publish_flush_batched_send": self.flush_receive,
"publish_flush_batched_send": self.flush_batched_send,
}

self.scheduler.handlers.update(handlers)
self.scheduler.stream_handlers.update(stream_handlers)
self._flush_received = defaultdict(asyncio.Event)

def flush_receive(self, uid, **kwargs):
def flush_batched_send(self, client: str, uid: bytes) -> None:
    """Stream handler: record that the client's batched-send stream has
    delivered everything it sent before the ``publish_flush_batched_send``
    marker identified by ``uid``.
    """
    # Wakes up the RPC handler blocked in _sync_batched_send(uid)
    self._flush_received[uid].set()

async def flush_wait(self, uid):
await self._flush_received[uid].wait()
async def _sync_batched_send(self, uid: bytes) -> None:
    """Wait for the client's batched-send stream to catch up with the same
    client's RPC calls, i.e. block until ``flush_batched_send`` has been
    called with the same ``uid``.
    """
    try:
        await self._flush_received[uid].wait()
    finally:
        # Always drop the event so it does not accumulate on the scheduler;
        # each uid is generated fresh (uuid4) per client call, so it is
        # never waited on twice
        del self._flush_received[uid]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed memory leak where each call to publish_dataset would create an immortal asyncio.Event instance on the scheduler.

Note that there is still a potential memory leak left here, where the client flushes the batched comms, but then disconnects before the RPC call can be executed.
I tried fixing this use case but gave up, as I ended up with code that was both severely over-engineered and fragile to race conditions. Namely, one must be thoughtful when testing scheduler.clients, because the register-client endpoint is neither an async RPC nor a batched comm, and I found myself in cases where the batched comms had arrived but the client hadn't been registered yet.


@log_errors
def put(self, keys=None, data=None, name=None, override=False, client=None):
if not override and name in self.datasets:
raise KeyError("Dataset %s already exists" % name)
self.scheduler.client_desires_keys(keys, f"published-{stringify(name)}")
self.datasets[name] = {"data": data, "keys": keys}
Comment on lines -48 to -51
Copy link
Copy Markdown
Collaborator Author

@crusaderky crusaderky Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed bug where publish_dataset(..., override=True) would cause any keys from the original dataset to become immortal unless they were also present in the new dataset

return {"status": "OK", "name": name}
async def put(
    self,
    names: tuple[Key, ...],
    keys: tuple[tuple[Key, ...], ...],
    data: tuple[Serialized, ...],
    override: bool,
    uid: bytes,
) -> None:
    """RPC handler: store one or more named datasets on the scheduler.

    Parameters
    ----------
    names:
        Name of each dataset to publish
    keys:
        For each dataset, the dask keys it references
    data:
        For each dataset, the serialized collection or futures
    override:
        If True, replace datasets that already exist with the same name;
        if False, raise KeyError instead.
    uid:
        Id of the matching ``publish_flush_batched_send`` marker, used to
        synchronize with the client's batched-send stream

    Raises
    ------
    KeyError
        If ``override`` is False and a dataset with the same name exists
    """
    # Wait until the messages the client batch-sent before calling this RPC
    # (e.g. the persist of these very keys) have landed on the scheduler
    await self._sync_batched_send(uid)

    for name, keys_i, data_i in zip(names, keys, data, strict=True):
        if name in self.datasets:
            if override:
                # Release the keys of the dataset being replaced, so that
                # keys absent from the new dataset don't stay alive forever
                old = self.datasets.pop(name)
                self.scheduler.client_releases_keys(
                    old["keys"], f"published-{stringify(name)}"
                )
            else:
                raise KeyError("Dataset %s already exists" % name)

        # Keep the dataset's keys alive on behalf of a synthetic client
        self.scheduler.client_desires_keys(keys_i, f"published-{stringify(name)}")
        self.datasets[name] = {"data": data_i, "keys": keys_i}

@log_errors
def delete(self, name=None):
out = self.datasets.pop(name, {"keys": []})
self.scheduler.client_releases_keys(out["keys"], f"published-{stringify(name)}")
async def delete(self, names: tuple[Key, ...], uid: bytes) -> None:
await self._sync_batched_send(uid)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed race condition where the user calls get_dataset() and immediately afterwards unpublish_dataset(), thinking that they are holding a reference to the futures locally, but the scheduler hasn't noted it by the time unpublish_dataset lands on the scheduler. This caused the client holding a reference to forgotten keys. This is symmetrical to what #8577 fixed for publish_dataset.


for name in names:
out = self.datasets.pop(name, None)
if out is not None:
self.scheduler.client_releases_keys(
out["keys"], f"published-{stringify(name)}"
)

@log_errors
def list(self, *args):
return list(sorted(self.datasets.keys(), key=str))
def list(self) -> list[Key]:
return list(sorted(self.datasets, key=str))

@log_errors
def get(self, name=None, client=None):
return self.datasets.get(name, None)
def get(self, names: tuple[Key, ...]) -> list[PublishedDataset | None]: # type: ignore[valid-type]
return [self.datasets.get(name, None) for name in names]


class Datasets(MutableMapping):
Expand Down
Loading
Loading