File tree Expand file tree Collapse file tree 2 files changed +5
-1
lines changed
Expand file tree Collapse file tree 2 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -231,6 +231,10 @@ def refresh(self):
231231
232232 dataloader_kwargs .pop ("force_has_data" , None )
233233
234+ # Ensure that dataset is loaded onto CPU if pin_memory is used
235+ if self ._pin_memory :
236+ self .dataset .to ("cpu" )
237+
234238 self ._pytorch_loader = torch .utils .data .DataLoader (
235239 self .dataset , ** dataloader_kwargs
236240 )
Original file line number Diff line number Diff line change @@ -81,7 +81,7 @@ def bg_count(self) -> float:
8181 def class_counts (self ) -> float :
8282 return self ._class_counts
8383
84- def to (self , device : str , non_blocking : bool = True ) -> None :
84+ def to (self , device : str | torch . device , non_blocking : bool = True ) -> None :
8585 self .store = self .store .to (device , non_blocking = non_blocking )
8686
8787 def set_spatial_transforms (self , transforms : Mapping [str , Any ] | None ) -> None :
You can’t perform that action at this time.
0 commit comments