Skip to content

Commit 68e5987

Browse files
Ensure dataset is loaded onto CPU when pin_memory is used
1 parent 02373d6 commit 68e5987

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

src/cellmap_data/dataloader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ def refresh(self):
231231

232232
dataloader_kwargs.pop("force_has_data", None)
233233

234+
# Ensure that dataset is loaded onto CPU if pin_memory is used
235+
if self._pin_memory:
236+
self.dataset.to("cpu")
237+
234238
self._pytorch_loader = torch.utils.data.DataLoader(
235239
self.dataset, **dataloader_kwargs
236240
)

src/cellmap_data/empty_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def bg_count(self) -> float:
8181
def class_counts(self) -> float:
8282
return self._class_counts
8383

84-
def to(self, device: str, non_blocking: bool = True) -> None:
84+
def to(self, device: str | torch.device, non_blocking: bool = True) -> None:
8585
self.store = self.store.to(device, non_blocking=non_blocking)
8686

8787
def set_spatial_transforms(self, transforms: Mapping[str, Any] | None) -> None:

0 commit comments

Comments
 (0)