Skip to content

Commit 3d6c342

Browse files
Create or update dataset (#23)
1 parent 5ca5491 commit 3d6c342

9 files changed

Lines changed: 599 additions & 508 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ repos:
99
hooks:
1010
- id: sync-with-uv
1111
- repo: https://github.com/charliermarsh/ruff-pre-commit
12-
rev: v0.14.7
12+
rev: v0.14.11
1313
hooks:
1414
- id: ruff-check
1515
args: [--fix, --exit-non-zero-on-fix]

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Changed
11+
12+
- `tilebox-datasets`: The `create_dataset` method of the `Client` has been removed. Use `create_or_update_dataset` instead.
13+
1014
### Fixed
1115

1216
- `tilebox-storage`: Fixed a bug on Windows, where the `CopernicusStorageClient` and `USGSLandsatStorageClient` were

tilebox-datasets/tilebox/datasets/aio/client.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from _tilebox.grpc.aio.channel import open_channel
44
from _tilebox.grpc.aio.error import with_pythonic_errors
5+
from _tilebox.grpc.error import NotFoundError
56
from tilebox.datasets.aio.dataset import DatasetClient
67
from tilebox.datasets.client import Client as BaseClient
78
from tilebox.datasets.client import token_from_env
@@ -33,33 +34,38 @@ def __init__(self, *, url: str = "https://api.tilebox.com", token: str | None =
3334
)
3435
self._client = BaseClient(service)
3536

36-
async def create_dataset(
37+
async def create_or_update_dataset(
3738
self,
3839
kind: DatasetKind,
3940
code_name: str,
40-
fields: list[FieldDict],
41+
fields: list[FieldDict] | None = None,
4142
*,
4243
name: str | None = None,
43-
description: str | None = None,
4444
) -> DatasetClient:
4545
"""Create a new dataset.
4646
4747
Args:
4848
kind: The kind of the dataset.
4949
code_name: The code name of the dataset.
50-
fields: The fields of the dataset.
50+
fields: The custom fields of the dataset.
5151
name: The name of the dataset. Defaults to the code name.
52-
description: A short description of the dataset. Optional.
5352
5453
Returns:
5554
The created dataset.
5655
"""
57-
if name is None:
58-
name = code_name
59-
if description is None:
60-
description = ""
6156

62-
return await self._client.create_dataset(kind, code_name, fields, name, description, DatasetClient)
57+
try:
58+
dataset = await self.dataset(code_name)
59+
except NotFoundError:
60+
return await self._client.create_dataset(kind, code_name, fields or [], name or code_name, DatasetClient)
61+
62+
return await self._client.update_dataset(
63+
kind,
64+
dataset._dataset.id, # noqa: SLF001
65+
fields or [],
66+
name or dataset._dataset.name, # noqa: SLF001
67+
DatasetClient,
68+
)
6369

6470
async def datasets(self) -> Group:
6571
"""Fetch all available datasets."""

tilebox-datasets/tilebox/datasets/client.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,32 @@ class Client:
2626
def __init__(self, service: TileboxDatasetService) -> None:
2727
self._service = service
2828

29-
def create_dataset( # noqa: PLR0913
30-
self, kind: DatasetKind, code_name: str, fields: list[FieldDict], name: str, summary: str, dataset_type: type[T]
29+
def create_dataset(
30+
self,
31+
kind: DatasetKind,
32+
code_name: str,
33+
fields: list[FieldDict] | None,
34+
name: str | None,
35+
py_dataset_class: type[T],
3136
) -> Promise[T]:
3237
return (
33-
self._service.create_dataset(kind, code_name, fields, name, summary)
38+
self._service.create_dataset(kind, code_name, name or code_name, fields or [])
3439
.then(_ensure_registered)
35-
.then(lambda dataset: dataset_type(self._service, dataset))
40+
.then(lambda dataset: py_dataset_class(self._service, dataset))
41+
)
42+
43+
def update_dataset(
44+
self,
45+
kind: DatasetKind,
46+
dataset_id: UUID,
47+
fields: list[FieldDict] | None,
48+
name: str | None,
49+
py_dataset_class: type[T],
50+
) -> Promise[T]:
51+
return (
52+
self._service.update_dataset(kind, dataset_id, name, fields or [])
53+
.then(_ensure_registered)
54+
.then(lambda dataset: py_dataset_class(self._service, dataset))
3655
)
3756

3857
def datasets(self, dataset_type: type[T]) -> Promise[Group]:

tilebox-datasets/tilebox/datasets/service.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
GetDatasetRequest,
3838
ListDatasetsRequest,
3939
Package,
40+
UpdateDatasetRequest,
4041
)
4142
from tilebox.datasets.datasets.v1.datasets_pb2_grpc import DatasetServiceStub
4243
from tilebox.datasets.query.pagination import Pagination
@@ -64,24 +65,70 @@ def __init__(
6465
self._data_ingestion_service = data_ingestion_service_stub
6566

6667
def create_dataset(
67-
self, kind: DatasetKind, code_name: str, fields: list[FieldDict], name: str, summary: str
68+
self, kind: DatasetKind, code_name: str, name: str, custom_fields: list[FieldDict]
6869
) -> Promise[Dataset]:
6970
"""Create a new dataset.
7071
7172
Args:
7273
kind: The kind of the dataset.
7374
code_name: The code name of the dataset.
74-
fields: The fields of the dataset.
7575
name: The name of the dataset.
76-
summary: A short summary of the dataset.
76+
fields: The custom fields of the dataset
7777
7878
Returns:
7979
The created dataset.
8080
"""
81-
dataset_type = DatasetType(kind, _REQUIRED_FIELDS_PER_DATASET_KIND[kind] + [Field.from_dict(f) for f in fields])
82-
req = CreateDatasetRequest(name=name, type=dataset_type.to_message(), summary=summary, code_name=code_name)
81+
dataset_type = DatasetType(
82+
kind, _REQUIRED_FIELDS_PER_DATASET_KIND[kind] + [Field.from_dict(f) for f in custom_fields]
83+
)
84+
req = CreateDatasetRequest(name=name, type=dataset_type.to_message(), code_name=code_name)
8385
return Promise.resolve(self._dataset_service.CreateDataset(req)).then(Dataset.from_message)
8486

87+
def update_dataset(
88+
self, kind: DatasetKind, dataset_id: UUID, name: str | None, custom_fields: list[FieldDict]
89+
) -> Promise[Dataset]:
90+
"""Update a dataset.
91+
92+
Args:
93+
kind: The kind of the dataset to update, cannot be changed.
94+
dataset_id: The id of the dataset to update, cannot be changed.
95+
name: The new name of the dataset.
96+
custom_fields: The new list of custom fields of the dataset.
97+
98+
Returns:
99+
The updated dataset.
100+
"""
101+
dataset_type = DatasetType(
102+
kind, _REQUIRED_FIELDS_PER_DATASET_KIND[kind] + [Field.from_dict(f) for f in custom_fields]
103+
)
104+
req = UpdateDatasetRequest(id=uuid_to_uuid_message(dataset_id), name=name, type=dataset_type.to_message())
105+
return Promise.resolve(self._dataset_service.UpdateDataset(req)).then(Dataset.from_message)
106+
107+
def create_or_update_dataset(
108+
self, kind: DatasetKind, code_name: str, name: str, custom_fields: list[FieldDict]
109+
) -> Promise[Dataset]:
110+
"""Create a new dataset, or update it if it already exists.
111+
112+
Args:
113+
kind: The kind of the dataset.
114+
code_name: The code name of the dataset.
115+
name: The name of the dataset.
116+
custom_fields: The custom fields of the dataset
117+
118+
Returns:
119+
The created or updated dataset.
120+
"""
121+
return (
122+
Promise.resolve(self._dataset_service.GetDataset(GetDatasetRequest(slug=code_name)))
123+
.then(
124+
did_fulfill=lambda dataset: self.update_dataset(
125+
kind, Dataset.from_message(dataset).id, name, custom_fields
126+
),
127+
did_reject=lambda _: self.create_dataset(kind, code_name, name, custom_fields),
128+
)
129+
.then(Dataset.from_message)
130+
)
131+
85132
def list_datasets(self) -> Promise[ListDatasetsResponse]:
86133
"""List all datasets and dataset groups."""
87134
return Promise.resolve(

tilebox-datasets/tilebox/datasets/sync/client.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from uuid import UUID
22

33
from _tilebox.grpc.channel import open_channel
4-
from _tilebox.grpc.error import with_pythonic_errors
4+
from _tilebox.grpc.error import NotFoundError, with_pythonic_errors
55
from tilebox.datasets.client import Client as BaseClient
66
from tilebox.datasets.client import token_from_env
77
from tilebox.datasets.data.datasets import DatasetKind, FieldDict
@@ -33,33 +33,38 @@ def __init__(self, *, url: str = "https://api.tilebox.com", token: str | None =
3333
)
3434
self._client = BaseClient(service)
3535

36-
def create_dataset(
36+
def create_or_update_dataset(
3737
self,
3838
kind: DatasetKind,
3939
code_name: str,
40-
fields: list[FieldDict],
40+
fields: list[FieldDict] | None = None,
4141
*,
4242
name: str | None = None,
43-
description: str | None = None,
4443
) -> DatasetClient:
4544
"""Create a new dataset.
4645
4746
Args:
4847
kind: The kind of the dataset.
4948
code_name: The code name of the dataset.
50-
fields: The fields of the dataset.
49+
fields: The custom fields of the dataset.
5150
name: The name of the dataset. Defaults to the code name.
52-
description: A short description of the dataset. Optional.
5351
5452
Returns:
5553
The created dataset.
5654
"""
57-
if name is None:
58-
name = code_name
59-
if description is None:
60-
description = ""
6155

62-
return self._client.create_dataset(kind, code_name, fields, name, description, DatasetClient).get()
56+
try:
57+
dataset = self.dataset(code_name)
58+
except NotFoundError:
59+
return self._client.create_dataset(kind, code_name, fields or [], name or code_name, DatasetClient).get()
60+
61+
return self._client.update_dataset(
62+
kind,
63+
dataset._dataset.id, # noqa: SLF001
64+
fields or [],
65+
name or dataset._dataset.name, # noqa: SLF001
66+
DatasetClient,
67+
).get()
6368

6469
def datasets(self) -> Group:
6570
"""Fetch all available datasets."""

tilebox-workflows/tilebox/workflows/runner/task_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,9 +559,9 @@ def submit_subtask(
559559
def submit_subtasks(
560560
self,
561561
tasks: Sequence[TaskInstance],
562+
depends_on: FutureTask | list[FutureTask] | None = None,
562563
cluster: str | None = None,
563564
max_retries: int = 0,
564-
depends_on: FutureTask | list[FutureTask] | None = None,
565565
) -> list[FutureTask]:
566566
return [
567567
self.submit_subtask(task, cluster=cluster, max_retries=max_retries, depends_on=depends_on) for task in tasks
@@ -575,7 +575,7 @@ def submit_batch(
575575
DeprecationWarning,
576576
stacklevel=2,
577577
)
578-
return self.submit_subtasks(tasks, cluster, max_retries)
578+
return self.submit_subtasks(tasks, cluster=cluster, max_retries=max_retries)
579579

580580
def progress(self, label: str | None = None) -> ProgressUpdate:
581581
if label == "":

tilebox-workflows/tilebox/workflows/task.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,11 @@ class ExecutionContext(ABC):
356356

357357
@abstractmethod
358358
def submit_subtask(
359-
self, task: Task, depends_on: list[FutureTask] | None = None, cluster: str | None = None, max_retries: int = 0
359+
self,
360+
task: Task,
361+
depends_on: FutureTask | list[FutureTask] | None = None,
362+
cluster: str | None = None,
363+
max_retries: int = 0,
360364
) -> FutureTask:
361365
"""Submit a subtask of the current task.
362366
@@ -374,7 +378,11 @@ def submit_subtask(
374378

375379
@abstractmethod
376380
def submit_subtasks(
377-
self, tasks: Sequence[Task], cluster: str | None = None, max_retries: int = 0
381+
self,
382+
tasks: Sequence[Task],
383+
depends_on: FutureTask | list[FutureTask] | None = None,
384+
cluster: str | None = None,
385+
max_retries: int = 0,
378386
) -> list[FutureTask]:
379387
"""Submit a batch of subtasks of the current task. Similar to `submit_subtask`, but for multiple tasks."""
380388

0 commit comments

Comments
 (0)