Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,25 +1439,31 @@ def _get_name_dict(self):

return name_dict

def _can_use_partial_cache(self, partial_trajectory: Union[np.ndarray, None],
stop_index: Union[int, None]) -> bool:
def _can_use_partial_cache(self, partial_trajectory: Union[np.ndarray, None],
stop_index: Union[int, None],
total_points: Union[int, None] = None) -> bool:
"""Check if (partial) trajectory length is sufficent, or new data needs to be read."""
# no file updates = full read only
# otherwise check sufficient length is available
return (partial_trajectory is not None) and (not self._allow_file_updates or (stop_index is not None and len(partial_trajectory) >= stop_index))
if partial_trajectory is None:
return False
if not self._allow_file_updates:
return True
if stop_index is None:
# Unbounded reads can re-use cache if there were no updates
return total_points is not None and len(partial_trajectory) == total_points
return len(partial_trajectory) >= stop_index

def _get_trajectory(self, data_index, start_index = 0, stop_index = None):
if isinstance(self._data_2, dict):
self._verify_file_data()

# caching
current_data = self._data_2.get(data_index, None)
if self._can_use_partial_cache(current_data, stop_index):
nbr_points = self._data_2_info["nbr_points"]
if self._can_use_partial_cache(current_data, stop_index, nbr_points):
return self._data_2[data_index][start_index:stop_index]

file_position = self._data_2_info["file_position"]
sizeof_type = self._data_2_info["sizeof_type"]
nbr_points = self._data_2_info["nbr_points"]
nbr_variables = self._data_2_info["nbr_variables"]

# Account for sub-sets of data
Expand Down Expand Up @@ -1486,8 +1492,9 @@ def _get_diagnostics_trajectory(self, data_index, start_index = 0, stop_index =
""" Returns trajectory for the diagnostics variable that corresponds to index 'data_index'. """
self._verify_file_data()
# caching
current_data = self._data_3.get(data_index, None)
if self._can_use_partial_cache(current_data, stop_index):
current_data = self._data_3.get(data_index, None)
nbr_diag_points = int(self._data_3_info.shape[0])
if self._can_use_partial_cache(current_data, stop_index, nbr_diag_points):
return self._data_3[data_index][start_index:stop_index]
self._data_3[data_index] = self._read_trajectory_data(data_index, True, start_index, stop_index)
return self._data_3[data_index][start_index:stop_index]
Expand Down Expand Up @@ -1532,8 +1539,9 @@ def _get_interpolated_trajectory(self, data_index: int, start_index: int = 0, st
""" Returns an interpolated trajectory for variable of corresponding index 'data_index'. """
self._verify_file_data()

current_data = self._data_2.get(data_index, None)
if self._can_use_partial_cache(current_data, stop_index):
current_data = self._data_2.get(data_index, None)
nbr_diag_points = int(self._data_3_info.shape[0])
if self._can_use_partial_cache(current_data, stop_index, nbr_diag_points):
return self._data_2[data_index][start_index:stop_index]

diag_time_vector = self._get_diagnostics_trajectory(0, start_index, stop_index)
Expand Down
Loading