-
-
Notifications
You must be signed in to change notification settings - Fork 757
Overhaul publish_dataset extension
#9217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| Coroutine, | ||
| Iterable, | ||
| Iterator, | ||
| Mapping, | ||
| Sequence, | ||
| ) | ||
| from concurrent.futures import ThreadPoolExecutor | ||
|
|
@@ -43,6 +44,7 @@ | |
| TypedDict, | ||
| TypeVar, | ||
| cast, | ||
| overload, | ||
| ) | ||
|
|
||
| from packaging.version import parse as parse_version | ||
|
|
@@ -2864,57 +2866,77 @@ def retry(self, futures, asynchronous=None): | |
| """ | ||
| return self.sync(self._retry, futures, asynchronous=asynchronous) | ||
|
|
||
| @log_errors | ||
| async def _publish_dataset(self, *args, name=None, override=False, **kwargs): | ||
| coroutines = [] | ||
| uid = uuid.uuid4().hex | ||
| self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid}) | ||
|
|
||
| def add_coro(name, data): | ||
| keys = [f.key for f in futures_of(data)] | ||
|
|
||
| async def _(): | ||
| await self.scheduler.publish_wait_flush(uid=uid) | ||
| await self.scheduler.publish_put( | ||
| keys=keys, | ||
| name=name, | ||
| data=to_serialize(data), | ||
| override=override, | ||
| client=self.id, | ||
| ) | ||
|
|
||
| coroutines.append(_()) | ||
| async def _publish_dataset( | ||
| self, *args: Any, name: Key | None = None, override: bool = False, **kwargs: Any | ||
| ): | ||
| names: list[Key] = list(kwargs) | ||
| data = list(kwargs.values()) | ||
|
|
||
| if name: | ||
| if len(args) == 0: | ||
| if not args: | ||
| raise ValueError( | ||
| "If name is provided, expecting call signature like" | ||
| " publish_dataset(df, name='ds')" | ||
| ) | ||
| # in case this is a singleton, collapse it | ||
| elif len(args) == 1: | ||
| args = args[0] | ||
| add_coro(name, args) | ||
| names.append(name) | ||
| data.append(args[0] if len(args) == 1 else args) | ||
| elif args: | ||
| if len(args) != 1 or not isinstance(args[0], Mapping): | ||
| raise ValueError( | ||
| "If name is omitted, positional argument must be " | ||
| "a {name: value} dict" | ||
| ) | ||
| names.extend(args[0]) | ||
| data.extend(args[0].values()) | ||
|
|
||
| for name, data in kwargs.items(): | ||
| add_coro(name, data) | ||
| # Prevent race condition where the client persists the collection immediately | ||
| # before publish_dataset, but the persist command hasn't landed on the scheduler | ||
| # yet when the publish_put RPC call arrives asynchronously. | ||
| uid = uuid.uuid4().hex | ||
| self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid}) | ||
|
|
||
| await asyncio.gather(*coroutines) | ||
| await self.scheduler.publish_put( | ||
| names=names, | ||
| keys=[[f.key for f in futures_of(data_i)] for data_i in data], | ||
| data=[to_serialize(data_i) for data_i in data], | ||
| override=override, | ||
| uid=uid, | ||
| ) | ||
|
|
||
| @overload | ||
| def publish_dataset( | ||
| self, *args: Any, name: Key, override: bool = False, **kwargs | ||
| ): ... | ||
|
|
||
| def publish_dataset(self, *args, **kwargs): | ||
| @overload | ||
| def publish_dataset( | ||
| self, *args: Mapping[Key, Any], override: bool = False, **kwargs | ||
| ): ... | ||
|
|
||
| def publish_dataset( | ||
| self, *args: Any, name: Key | None = None, override: bool = False, **kwargs | ||
| ): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Aside: offering four different syntaxes to achieve the same thing is in clear violation of the Zen of Python:
Removing some of these syntaxes however would be a breaking change and is beyond the scope of this PR. |
||
| """ | ||
| Publish named datasets to scheduler | ||
|
|
||
| This stores a named reference to a dask collection or list of futures | ||
| This stores a named reference to one or more dask collections or futures | ||
| on the scheduler. These references are available to other Clients | ||
| which can download the collection or futures with ``get_dataset``. | ||
| which can download the collections or futures with ``get_dataset``. | ||
|
|
||
| Datasets are not immediately computed. You may wish to call | ||
| ``Client.persist`` prior to publishing a dataset. | ||
| Datasets are not immediately computed. You should call ``persist`` prior to | ||
| publishing a dataset. Any unpersisted keys will be stored on the scheduler | ||
| uncomputed and returned as-is to the user when calling ``get_dataset``. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| args : list of objects to publish as name | ||
| args : One or more objects to publish as `name`. | ||
| Alternatively, a single dict of {name: object} pairs. | ||
| name : Dask key (str, int, float, or tuple thereof) | ||
| Name to publish `args` under | ||
| override: bool | ||
| False (default) to raise KeyError if a dataset with the same name already | ||
| exists on the scheduler; True to overwrite it. | ||
| kwargs : dict | ||
| named collections to publish on the scheduler | ||
|
|
||
|
|
@@ -2924,7 +2946,10 @@ def publish_dataset(self, *args, **kwargs): | |
|
|
||
| >>> df = dd.read_csv('s3://...') # doctest: +SKIP | ||
| >>> df = c.persist(df) # doctest: +SKIP | ||
| >>> c.publish_dataset(my_dataset=df) # doctest: +SKIP | ||
| >>> c.publish_dataset({"my_dataset": df}) # doctest: +SKIP | ||
|
|
||
| Alternative invocation | ||
| >>> c.publish_dataset(my_dataset=df) | ||
|
|
||
| Alternative invocation | ||
| >>> c.publish_dataset(df, name='my_dataset') | ||
|
|
@@ -2946,30 +2971,50 @@ def publish_dataset(self, *args, **kwargs): | |
| Client.unpublish_dataset | ||
| Client.persist | ||
| """ | ||
| return self.sync(self._publish_dataset, *args, **kwargs) | ||
| return self.sync( | ||
| self._publish_dataset, *args, name=name, override=override, **kwargs | ||
| ) | ||
|
|
||
| def unpublish_dataset(self, name, **kwargs): | ||
| async def _unpublish_dataset(self, name: Key | list[Key]) -> None: | ||
| names = name if isinstance(name, list) else [name] | ||
| uid = uuid.uuid4().hex | ||
| # Prevent race condition where the user calls get_dataset() and immediately | ||
| # afterwards unpublish_dataset(), thinking that they are holding a reference | ||
| # to the futures locally, but the futures haven't been registered on the | ||
| # scheduler yet by the time unpublish_dataset lands on the scheduler. | ||
| # This method can't be made into just a batched send command as it would | ||
| # create another race condition, where unpublish_dataset() followed by | ||
| # get_dataset() would return the just-deleted data. | ||
| self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid}) | ||
| await self.scheduler.publish_delete(names=names, uid=uid) | ||
|
|
||
| def unpublish_dataset(self, name: Key | list[Key], **kwargs): | ||
| """ | ||
| Remove named datasets from scheduler | ||
|
|
||
| Parameters | ||
| ---------- | ||
| name : str | ||
| The name of the dataset to unpublish | ||
| name : Dask key (str, int, float, or tuple thereof), or list of keys | ||
| Name(s) of the dataset(s) to unpublish. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> c.list_datasets() # doctest: +SKIP | ||
| ['my_dataset'] | ||
| >>> c.unpublish_dataset('my_dataset') # doctest: +SKIP | ||
| ['foo', 'bar', 'baz'] | ||
| >>> c.unpublish_dataset('foo') # doctest: +SKIP | ||
| >>> c.list_datasets() # doctest: +SKIP | ||
| ['bar', 'baz'] | ||
| >>> c.unpublish_dataset(['bar', 'baz']) # doctest: +SKIP | ||
| >>> c.list_datasets() # doctest: +SKIP | ||
| [] | ||
|
|
||
| See Also | ||
| -------- | ||
| Client.publish_dataset | ||
| Client.list_datasets | ||
| Client.get_dataset | ||
| """ | ||
| return self.sync(self.scheduler.publish_delete, name=name, **kwargs) | ||
| return self.sync(self._unpublish_dataset, name=name, **kwargs) | ||
|
|
||
| def list_datasets(self, **kwargs): | ||
| """ | ||
|
|
@@ -2978,46 +3023,52 @@ def list_datasets(self, **kwargs): | |
| See Also | ||
| -------- | ||
| Client.publish_dataset | ||
| Client.unpublish_dataset | ||
| Client.get_dataset | ||
| """ | ||
| return self.sync(self.scheduler.publish_list, **kwargs) | ||
|
|
||
| async def _get_dataset(self, name, default=no_default): | ||
| with self.as_current(): | ||
| out = await self.scheduler.publish_get(name=name, client=self.id) | ||
| if out is None: | ||
| if default is no_default: | ||
| raise KeyError(f"Dataset '{name}' not found") | ||
| async def _get_dataset(self, name: Key | list[Key], default=no_default): | ||
| names = name if isinstance(name, list) else [name] | ||
| raw_outs = await self.scheduler.publish_get(names=names) | ||
|
|
||
| outs = [] | ||
| for out in raw_outs: | ||
| if out is None: | ||
| if default is no_default: | ||
| raise KeyError(f"Dataset '{name}' not found") | ||
| outs.append(default) | ||
| else: | ||
| return default | ||
| for fut in futures_of(out["data"]): | ||
| fut.bind_client(self) | ||
| for fut in futures_of(out["data"]): | ||
| fut.bind_client(self) | ||
| outs.append(out["data"]) | ||
|
|
||
| self._inform_scheduler_of_futures() | ||
| return out["data"] | ||
| return outs if isinstance(name, list) else outs[0] | ||
|
|
||
| def get_dataset(self, name, default=no_default, **kwargs): | ||
| def get_dataset(self, name: Key | list[Key], default=no_default, **kwargs): | ||
| """ | ||
| Get named dataset from the scheduler if present. | ||
| Return the default or raise a KeyError if not present. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| name : str | ||
| name of the dataset to retrieve | ||
| default : str | ||
| name : Dask key (str, int, float, or tuple thereof), or list of keys | ||
| name(s) of the dataset(s) to retrieve | ||
| default : Any | ||
| optional, not set by default | ||
| If set, do not raise a KeyError if the name is not present but | ||
| return this default | ||
| kwargs : dict | ||
| additional keyword arguments to _get_dataset | ||
|
Comment on lines
-3011
to
-3012
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kwargs actually contain the situationally useful |
||
|
|
||
| Returns | ||
| ------- | ||
| The dataset from the scheduler, if present | ||
| The dataset from the scheduler, if present. | ||
| If name is a list of keys, return a list of datasets in the same order. | ||
|
|
||
| See Also | ||
| -------- | ||
| Client.publish_dataset | ||
| Client.unpublish_dataset | ||
| Client.list_datasets | ||
| """ | ||
| return self.sync(self._get_dataset, name, default=default, **kwargs) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,22 @@ | |
| import asyncio | ||
| from collections import defaultdict | ||
| from collections.abc import MutableMapping | ||
| from typing import TYPE_CHECKING, TypedDict | ||
|
|
||
| from dask.typing import Key | ||
| from dask.utils import stringify | ||
|
|
||
| from distributed.protocol.serialize import Serialized | ||
| from distributed.utils import log_errors | ||
|
|
||
| if TYPE_CHECKING: | ||
| from distributed.scheduler import Scheduler | ||
|
|
||
|
|
||
| class PublishedDataset(TypedDict): | ||
| data: Serialized | ||
| keys: tuple[Key, ...] | ||
|
|
||
|
|
||
| class PublishExtension: | ||
| """An extension for the scheduler to manage collections | ||
|
|
@@ -18,51 +29,82 @@ class PublishExtension: | |
| * publish_delete | ||
| """ | ||
|
|
||
| scheduler: Scheduler | ||
| datasets: dict[Key, PublishedDataset] | ||
| _flush_received: defaultdict[bytes, asyncio.Event] | ||
|
|
||
| def __init__(self, scheduler): | ||
| self.scheduler = scheduler | ||
| self.datasets = dict() | ||
| self.datasets = {} | ||
|
|
||
| handlers = { | ||
| "publish_list": self.list, | ||
| "publish_put": self.put, | ||
| "publish_get": self.get, | ||
| "publish_delete": self.delete, | ||
| "publish_wait_flush": self.flush_wait, | ||
| } | ||
| stream_handlers = { | ||
| "publish_flush_batched_send": self.flush_receive, | ||
| "publish_flush_batched_send": self.flush_batched_send, | ||
| } | ||
|
|
||
| self.scheduler.handlers.update(handlers) | ||
| self.scheduler.stream_handlers.update(stream_handlers) | ||
| self._flush_received = defaultdict(asyncio.Event) | ||
|
|
||
| def flush_receive(self, uid, **kwargs): | ||
| def flush_batched_send(self, client: str, uid: bytes) -> None: | ||
| self._flush_received[uid].set() | ||
|
|
||
| async def flush_wait(self, uid): | ||
| await self._flush_received[uid].wait() | ||
| async def _sync_batched_send(self, uid: bytes) -> None: | ||
| """Wait for the client's batched-send to catch up with the same client's RPC | ||
| calls. Return True if the client is still connected; False otherwise. | ||
| """ | ||
| try: | ||
| await self._flush_received[uid].wait() | ||
| finally: | ||
| del self._flush_received[uid] | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed memory leak where each call to publish_dataset would create an immortal Note that there is still a potential memory leak left here, where the client flushes the batched comms, but then disconnects before the RPC call can be executed. |
||
|
|
||
| @log_errors | ||
| def put(self, keys=None, data=None, name=None, override=False, client=None): | ||
| if not override and name in self.datasets: | ||
| raise KeyError("Dataset %s already exists" % name) | ||
| self.scheduler.client_desires_keys(keys, f"published-{stringify(name)}") | ||
| self.datasets[name] = {"data": data, "keys": keys} | ||
|
Comment on lines
-48
to
-51
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed bug where |
||
| return {"status": "OK", "name": name} | ||
| async def put( | ||
| self, | ||
| names: tuple[Key, ...], | ||
| keys: tuple[tuple[Key, ...], ...], | ||
| data: tuple[Serialized, ...], | ||
| override: bool, | ||
| uid: bytes, | ||
| ) -> None: | ||
| await self._sync_batched_send(uid) | ||
|
|
||
| for name, keys_i, data_i in zip(names, keys, data, strict=True): | ||
| if name in self.datasets: | ||
| if override: | ||
| old = self.datasets.pop(name) | ||
| self.scheduler.client_releases_keys( | ||
| old["keys"], f"published-{stringify(name)}" | ||
| ) | ||
| else: | ||
| raise KeyError("Dataset %s already exists" % name) | ||
|
|
||
| self.scheduler.client_desires_keys(keys_i, f"published-{stringify(name)}") | ||
| self.datasets[name] = {"data": data_i, "keys": keys_i} | ||
|
|
||
| @log_errors | ||
| def delete(self, name=None): | ||
| out = self.datasets.pop(name, {"keys": []}) | ||
| self.scheduler.client_releases_keys(out["keys"], f"published-{stringify(name)}") | ||
| async def delete(self, names: tuple[Key, ...], uid: bytes) -> None: | ||
| await self._sync_batched_send(uid) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed race condition where the user calls get_dataset() and immediately afterwards unpublish_dataset(), thinking that they are holding a reference to the futures locally, but the scheduler hasn't noted it by the time unpublish_dataset lands on the scheduler. This caused the client holding a reference to forgotten keys. This is symmetrical to what #8577 fixed for publish_dataset. |
||
|
|
||
| for name in names: | ||
| out = self.datasets.pop(name, None) | ||
| if out is not None: | ||
| self.scheduler.client_releases_keys( | ||
| out["keys"], f"published-{stringify(name)}" | ||
| ) | ||
|
|
||
| @log_errors | ||
| def list(self, *args): | ||
| return list(sorted(self.datasets.keys(), key=str)) | ||
| def list(self) -> list[Key]: | ||
| return list(sorted(self.datasets, key=str)) | ||
|
|
||
| @log_errors | ||
| def get(self, name=None, client=None): | ||
| return self.datasets.get(name, None) | ||
| def get(self, names: tuple[Key, ...]) -> list[PublishedDataset | None]: # type: ignore[valid-type] | ||
| return [self.datasets.get(name, None) for name in names] | ||
|
|
||
|
|
||
| class Datasets(MutableMapping): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Frustratingly, mypy complains if you write
which would be more correct.