Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 69 additions & 18 deletions src/osekit/core_api/spectro_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def __init__(
self._db_ref = db_ref
self.v_lim = v_lim
self.colormap = "viridis" if colormap is None else colormap
self.previous_data = None
self.next_data = None

@staticmethod
def get_default_ax() -> plt.Axes:
Expand Down Expand Up @@ -249,11 +251,26 @@ def get_value(self) -> np.ndarray:
padding="zeros",
)

sx = self._merge_with_previous(sx)
sx = self._remove_overlap_with_next(sx)

if self.sx_dtype is float:
sx = abs(sx) ** 2

return sx

def _merge_with_previous(self, data: np.ndarray) -> np.ndarray:
if self.previous_data is None:
return data
olap = SpectroData.get_overlapped_bins(self.previous_data, self)
return np.hstack((olap, data[:, olap.shape[1] :]))

def _remove_overlap_with_next(self, data: np.ndarray) -> np.ndarray:
if self.next_data is None:
return data
olap = SpectroData.get_overlapped_bins(self, self.next_data)
return data[:, : -olap.shape[1]]

def get_welch(
self,
nperseg: int | None = None,
Expand Down Expand Up @@ -567,7 +584,7 @@ def split(self, nb_subdata: int = 2) -> list[SpectroData]:
self.audio_data.split_frames(start_frame=a, stop_frame=b)
for a, b in itertools.pairwise(split_frames)
]
return [
sd_split = [
SpectroData.from_audio_data(
data=ad,
fft=self.fft,
Expand All @@ -577,6 +594,12 @@ def split(self, nb_subdata: int = 2) -> list[SpectroData]:
for ad in ad_split
]

for sd1, sd2 in itertools.pairwise(sd_split):
sd1.next_data = sd2
sd2.previous_data = sd1

return sd_split

def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray:
if not all(
np.array_equal(items[0].file.freq, i.file.freq)
Expand All @@ -588,23 +611,51 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray:
if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1:
raise ValueError("Items don't have the same time resolution.")

output = items[0].get_value(fft=self.fft, sx_dtype=self.sx_dtype)
for item in items[1:]:
p1_le = self.fft.lower_border_end[1] - self.fft.p_min
output = np.hstack(
(
output[:, :-p1_le],
(
output[:, -p1_le:]
+ item.get_value(fft=self.fft, sx_dtype=self.sx_dtype)[
:,
:p1_le,
]
),
item.get_value(fft=self.fft, sx_dtype=self.sx_dtype)[:, p1_le:],
),
)
return output
return np.hstack(
[item.get_value(fft=self.fft, sx_dtype=self.sx_dtype) for item in items],
)

@classmethod
def get_overlapped_bins(cls, sd1: SpectroData, sd2: SpectroData) -> np.ndarray:
"""Compute the bins that overflow between the two spectro data.

The idea is that if there is a SpectroData sd2 that follows sd1,
sd1.get_value() will return the bins up to the first overlapping bin,
and sd2 will return the bins from the first overlapping bin.

Signal processing guys might want to burn my house to the ground for it,
but it seems to effectively resolve the issue we have with visible junction
between spectrogram zoomed parts.

Parameters
----------
sd1: SpectroData
The spectro data that ends before sd2.
sd2: SpectroData
The spectro data that starts after sd1.

Returns
-------
np.ndarray:
The overlapped bins.
If there are p bins, sd1 and sd2 values should be concatenated as:
np.hstack(sd1[:,:-p], result, sd2[:,p:])

"""
fft = sd1.fft
sd1_ub = fft.upper_border_begin(sd1.audio_data.shape[0])
sd1_bin_start = fft.nearest_k_p(k=sd1_ub[0], left=True)
sd2_lb = fft.lower_border_end
sd2_bin_stop = fft.nearest_k_p(k=sd2_lb[0], left=False)

ad1 = sd1.audio_data.split_frames(start_frame=sd1_bin_start)
ad2 = sd2.audio_data.split_frames(stop_frame=sd2_bin_stop)

sd_part1 = SpectroData.from_audio_data(ad1, fft=fft).get_value()
sd_part2 = SpectroData.from_audio_data(ad2, fft=fft).get_value()

p1_le = fft.lower_border_end[1] - fft.p_min
return sd_part1[:, -p1_le:] + sd_part2[:, :p1_le]

@classmethod
def from_files(
Expand Down
18 changes: 12 additions & 6 deletions src/osekit/core_api/spectro_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,25 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray:

start_bin = (
next(
idx
for idx, t in enumerate(time)
if self.begin + Timedelta(seconds=t) > start
(
idx
for idx, t in enumerate(time)
if self.begin + Timedelta(seconds=t) > start
),
1,
)
- 1
)
start_bin = max(start_bin, 0)

stop_bin = (
next(
idx
for idx, t in list(enumerate(time))[::-1]
if self.begin + Timedelta(seconds=t) < stop
(
idx
for idx, t in list(enumerate(time))[::-1]
if self.begin + Timedelta(seconds=t) < stop
),
len(time) - 2,
)
+ 1
)
Expand Down
16 changes: 10 additions & 6 deletions tests/test_spectro.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,14 @@ def test_spectro_parameters_in_npz_files(
pytest.param(
{
"duration": 6,
"sample_rate": 1_024,
"sample_rate": 28_000,
"nb_files": 1,
"date_begin": pd.Timestamp("2024-01-01 12:00:00"),
},
None,
None,
6,
ShortTimeFFT(hamming(1_024), 100, 1_024),
ShortTimeFFT(hamming(1_024), 100, 28_000),
id="6_seconds_split_in_6_with_overlap",
),
pytest.param(
Expand All @@ -289,7 +289,7 @@ def test_spectro_parameters_in_npz_files(
Instrument(end_to_end_db=150.0),
None,
6,
ShortTimeFFT(hamming(1_024), 100, 1_024),
ShortTimeFFT(hamming(1_024), 1_024, 1_024),
id="audio_data_with_instrument",
),
pytest.param(
Expand All @@ -302,7 +302,7 @@ def test_spectro_parameters_in_npz_files(
None,
Normalization.ZSCORE,
6,
ShortTimeFFT(hamming(1_024), 100, 1_024),
ShortTimeFFT(hamming(1_024), 1_024, 1_024),
id="audio_data_with_normalization",
),
],
Expand All @@ -328,7 +328,7 @@ def test_spectrogram_from_npz_files(

sd_split = sd.split(nb_chunks)

import soundfile as sf
import soundfile as sf # noqa: PLC0415

for spectro in sd_split:
spectro.write(tmp_path / "output")
Expand Down Expand Up @@ -985,10 +985,14 @@ def test_spectrodata_split(
colormap=colormap,
)
sd_parts = sd.split(parts)
for sd_part in sd_parts:
for idx, sd_part in enumerate(sd_parts):
assert sd_part.fft is sd.fft
assert sd_part.v_lim == sd.v_lim
assert sd_part.colormap == sd.colormap
if idx > 0:
assert sd_part.previous_data == sd_parts[idx - 1]
if idx < len(sd_parts) - 1:
assert sd_part.next_data == sd_parts[idx + 1]
assert sd_parts[0].begin == sd.begin
assert sd_parts[-1].end == sd.end

Expand Down