Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/post_processing/dataclass/detection_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class DetectionFilter:
f_max: float | None = None
score: float | None = None
box: bool = False
filename_format: str = None

@classmethod
def from_yaml(
Expand Down Expand Up @@ -86,7 +87,6 @@ def from_dict(
filters = []
for detection_file, filters_dict in parameters.items():
df_preview = read_dataframe(Path(detection_file), nrows=5)

filters_dict["timebin_origin"] = Timedelta(
max(df_preview["end_time"]),
"s",
Expand Down
131 changes: 104 additions & 27 deletions src/post_processing/utils/filtering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,32 +196,62 @@ def get_dataset(df: DataFrame) -> list[str]:
def get_canonical_tz(tz):
"""Return timezone of object as a pytz timezone."""
if isinstance(tz, datetime.timezone):
if tz == datetime.timezone.utc:
if tz == datetime.UTC:
return pytz.utc
offset_minutes = int(tz.utcoffset(None).total_seconds() / 60)
return pytz.FixedOffset(offset_minutes)
if hasattr(tz, "zone") and tz.zone:
return pytz.timezone(tz.zone)
if hasattr(tz, "key"):
return pytz.timezone(tz.key)
else:
msg = f"Unknown timezone: {tz}"
raise TypeError(msg)
msg = f"Unknown timezone: {tz}"
raise TypeError(msg)


def get_timezone(df: DataFrame):
"""Return timezone(s) from DataFrame."""
"""Return timezone(s) from APLOSE DataFrame.

Parameters
----------
df: DataFrame
APLOSE result Dataframe

Returns
-------
tzoffset: list[tzoffset]
list of timezones

"""
timezones = {get_canonical_tz(ts.tzinfo) for ts in df["start_datetime"]}

if len(timezones) == 1:
return next(iter(timezones))
return list(timezones)


def check_timestamp(df: DataFrame, timestamp_audio: list[Timestamp]) -> None:
"""Check if provided timestamp_audio list is correctly formated.

Parameters
----------
df: DataFrame APLOSE results Dataframe.
timestamp_audio: A list of timestamps. Each timestamp is the start datetime of the
corresponding audio file for each detection in df.

"""
if timestamp_audio is None:
msg = "`timestamp_wav` is empty"
raise ValueError(msg)
if len(timestamp_audio) != len(df):
msg = "`timestamp_wav` is not the same length as `df`"
raise ValueError(msg)


def reshape_timebin(
df: DataFrame,
*,
timebin_new: Timedelta | None,
timestamp: list[Timestamp] | None = None,
timestamp_audio: list[Timestamp] | None = None,
) -> DataFrame:
"""Reshape an APLOSE result DataFrame according to a new time bin.

Expand All @@ -231,8 +261,9 @@ def reshape_timebin(
An APLOSE result DataFrame.
timebin_new: Timedelta
The size of the new time bin.
timestamp: list[Timestamp]
A list of Timestamp objects.
timestamp_audio: list[Timestamp]
A list of Timestamp objects corresponding to the shape
in which the data should be reshaped.

Returns
-------
Expand All @@ -247,14 +278,20 @@ def reshape_timebin(
if not timebin_new:
return df

check_timestamp(df, timestamp_audio)

annotators = get_annotators(df)
labels = get_labels(df)
max_freq = get_max_freq(df)
dataset = get_dataset(df)

if isinstance(get_timezone(df), list):
df["start_datetime"] = [to_datetime(elem, utc=True) for elem in df["start_datetime"]]
df["end_datetime"] = [to_datetime(elem, utc=True) for elem in df["end_datetime"]]
df["start_datetime"] = [to_datetime(elem, utc=True)
for elem in df["start_datetime"]
]
df["end_datetime"] = [to_datetime(elem, utc=True)
for elem in df["end_datetime"]
]

results = []
for ant in annotators:
Expand All @@ -264,13 +301,13 @@ def reshape_timebin(
if df_1annot_1label.empty:
continue

if timestamp is not None:
if timestamp_audio is not None:
# I do not remember if this is a regular case or not
# might need to be deleted
origin_timebin = timestamp[1] - timestamp[0]
step = int(timebin_new / origin_timebin)
time_vector = timestamp[0::step]
else:
#origin_timebin = timestamp_audio[1] - timestamp_audio[0]
#step = int(timebin_new / origin_timebin)
#time_vector = timestamp_audio[0::step]
#else:
t1 = min(df_1annot_1label["start_datetime"]).floor(timebin_new)
t2 = max(df_1annot_1label["end_datetime"]).ceil(timebin_new)
time_vector = date_range(start=t1, end=t2, freq=timebin_new)
Expand All @@ -280,14 +317,19 @@ def reshape_timebin(
filenames = df_1annot_1label["filename"].to_list()

# filename_vector
filename_vector = [
filenames[
bisect.bisect_left(ts_detect_beg, ts) - (ts not in ts_detect_beg)
]
if bisect.bisect_left(ts_detect_beg, ts) > 0
else filenames[0]
for ts in time_vector
]
filename_vector = []
for ts in time_vector:
if (bisect.bisect_left(ts_detect_beg, ts) > 0 and
bisect.bisect_left(ts_detect_beg, ts) != len(ts_detect_beg)):
idx = bisect.bisect_left(ts_detect_beg, ts)
filename_vector.append(
filenames[idx] if timestamp_audio[idx] <= ts else
filenames[idx - 1],
)
elif bisect.bisect_left(ts_detect_beg, ts) == len(ts_detect_beg):
filename_vector.append(filenames[-1])
else:
filename_vector.append(filenames[0])

# detection vector
detect_vec = [0] * len(time_vector)
Expand Down Expand Up @@ -327,8 +369,39 @@ def reshape_timebin(
),
)

return concat(results).sort_values(by=["start_datetime", "end_datetime", "annotator", "annotation"]).reset_index(drop=True)
return (concat(results).
sort_values(by=["start_datetime", "end_datetime",
"annotator", "annotation"]).reset_index(drop=True)
)


def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]:
"""Get start timestamps of the wav files of each detection contained in df.

Parameters.
----------
df: DataFrame
An APLOSE result DataFrame.
date_parser: str
date parser of the wav file

Returns
-------
List of Timestamps corresponding to the wav files' start timestamps
of each detection contained in df.

"""
tz = get_timezone(df)
try:
return [
to_datetime(
ts,
format=date_parser,
).tz_localize(tz) for ts in df["filename"]
]
except ValueError:
msg = """Could not parse timestamps from `df["filename"]`."""
raise ValueError(msg) from None

def ensure_in_list(value: str, candidates: list[str], label: str) -> None:
"""Check for non-valid elements of a list."""
Expand Down Expand Up @@ -366,10 +439,14 @@ def load_detections(filters: DetectionFilter) -> DataFrame:
df = filter_by_label(df, label=filters.annotation)
df = filter_by_freq(df, filters.f_min, filters.f_max)
df = filter_by_score(df, filters.score)
df = reshape_timebin(df, filters.timebin_new)
filename_ts = get_filename_timestamps(df, filters.filename_format)
df = reshape_timebin(df,
timebin_new=filters.timebin_new,
timestamp_audio=filename_ts
)

annotators = get_annotators(df)
if len(annotators) > 1 and filters.user_sel in ["union", "intersection"]:
if len(annotators) > 1 and filters.user_sel in {"union", "intersection"}:
df = intersection_or_union(df, user_sel=filters.user_sel)

return df.sort_values(by=["start_datetime", "end_datetime"]).reset_index(drop=True)
Expand All @@ -385,7 +462,7 @@ def intersection_or_union(df: DataFrame, user_sel: str) -> DataFrame:
if user_sel == "all":
return df

if user_sel not in ("intersection", "union"):
if user_sel not in {"intersection", "union"}:
msg = "'user_sel' must be either 'intersection' or 'union'"
raise ValueError(msg)

Expand Down
10 changes: 6 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,11 @@ def sample_csv_timestamp(tmp_path: Path, sample_status: DataFrame) -> Path:


@pytest.fixture
def sample_yaml(tmp_path: Path,
sample_csv_result: Path,
sample_csv_timestamp: Path,
) -> Path:
def sample_yaml(
tmp_path: Path,
sample_csv_result: Path,
sample_csv_timestamp: Path,
) -> Path:
yaml_content = {
f"{sample_csv_result}": {
"timebin_new": None,
Expand All @@ -177,6 +178,7 @@ def sample_yaml(tmp_path: Path,
"annotation": "lbl1",
"box": True,
"timestamp_file": f"{sample_csv_timestamp}",
"filename_format": "%Y_%m_%d_%H_%M_%S",
"user_sel": "all",
"f_min": None,
"f_max": None,
Expand Down
1 change: 1 addition & 0 deletions tests/test_DetectionFilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def test_from_yaml(sample_yaml: Path,
"annotation": "lbl1",
"box": True,
"timestamp_file": f"{sample_csv_timestamp}",
"filename_format": "%Y_%m_%d_%H_%M_%S",
"user_sel": "all",
"f_min": None,
"f_max": None,
Expand Down
57 changes: 43 additions & 14 deletions tests/test_filtering_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import csv
from pathlib import Path
from zoneinfo import ZoneInfo

import pytest
import pytz
from pandas import DataFrame, Timedelta, Timestamp, date_range, concat, to_datetime
from pandas import DataFrame, Timedelta, Timestamp, concat, to_datetime

from post_processing.utils.filtering_utils import (
filter_by_annotator,
Expand Down Expand Up @@ -296,7 +298,7 @@ def test_get_timezone_single(sample_df: DataFrame) -> None:
def test_get_timezone_several(sample_df: DataFrame) -> None:
new_row = {
"dataset": "dataset",
"filename": "filename",
"filename": "2025_01_26_06_20_00",
"start_time": 0,
"end_time": 2,
"start_frequency": 100,
Expand Down Expand Up @@ -382,7 +384,7 @@ def test_no_timebin_returns_original(sample_df: DataFrame) -> None:
def test_no_timebin_several_tz(sample_df: DataFrame) -> None:
new_row = {
"dataset": "dataset",
"filename": "filename",
"filename": "2025_01_26_06_20_00",
"start_time": 0,
"end_time": 2,
"start_frequency": 100,
Expand All @@ -398,13 +400,23 @@ def test_no_timebin_several_tz(sample_df: DataFrame) -> None:
[sample_df, DataFrame([new_row])],
ignore_index=False
)

df_out = reshape_timebin(sample_df, timebin_new=None)
timestamp_wav = to_datetime(sample_df["filename"],
format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(pytz.UTC)
df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=None)
assert df_out.equals(sample_df)


def test_no_timebin_original_timebin(sample_df: DataFrame) -> None:
df_out = reshape_timebin(sample_df, timebin_new=Timedelta("1min"))
tz = get_timezone(sample_df)
timestamp_wav = to_datetime(
sample_df["filename"],
format="%Y_%m_%d_%H_%M_%S"
).dt.tz_localize(tz)
df_out = reshape_timebin(
sample_df,
timestamp_audio=timestamp_wav,
timebin_new=Timedelta("1min"),
)
expected = DataFrame(
{
"dataset": ["sample_dataset"] * 18,
Expand Down Expand Up @@ -486,7 +498,16 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None:


def test_simple_reshape_hourly(sample_df: DataFrame) -> None:
df_out = reshape_timebin(sample_df, timebin_new=Timedelta(hours=1))
tz = get_timezone(sample_df)
timestamp_wav = to_datetime(
sample_df["filename"],
format="%Y_%m_%d_%H_%M_%S"
).dt.tz_localize(tz)
df_out = reshape_timebin(
sample_df,
timestamp_audio=timestamp_wav,
timebin_new=Timedelta(hours=1),
)
assert not df_out.empty
assert all(df_out["end_time"] == 3600.0)
assert df_out["end_frequency"].max() == sample_df["end_frequency"].max()
Expand All @@ -495,22 +516,27 @@ def test_simple_reshape_hourly(sample_df: DataFrame) -> None:


def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None:
df_out = reshape_timebin(sample_df, timebin_new=Timedelta(days=1))
tz = get_timezone(sample_df)
timestamp_wav = to_datetime(
sample_df["filename"],
format="%Y_%m_%d_%H_%M_%S"
).dt.tz_localize(tz)
df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1))
assert not df_out.empty
assert all(df_out["end_time"] == 86400.0)
assert df_out["start_datetime"].min() >= sample_df["start_datetime"].min().floor("D")
assert df_out["end_datetime"].max() <= sample_df["end_datetime"].max().ceil("D")


def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None:
t0 = sample_df["start_datetime"].min().floor("30min")
t1 = sample_df["end_datetime"].max().ceil("30min")
ts_vec = list(date_range(t0, t1, freq="30min"))

tz = get_timezone(sample_df)
timestamp_wav = to_datetime(sample_df["filename"],
format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz)
df_out = reshape_timebin(
sample_df,
timebin_new=Timedelta(hours=1),
timestamp=ts_vec,
timestamp_audio=timestamp_wav,
timebin_new=Timedelta(hours=1)
)

assert not df_out.empty
Expand All @@ -519,8 +545,11 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None:


def test_empty_result_when_no_matching(sample_df: DataFrame) -> None:
tz = get_timezone(sample_df)
timestamp_wav = to_datetime(sample_df["filename"],
format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz)
with pytest.raises(ValueError, match="DataFrame is empty"):
reshape_timebin(DataFrame(), Timedelta(hours=1))
reshape_timebin(DataFrame(), timestamp_audio=timestamp_wav, timebin_new=Timedelta(hours=1))


# %% ensure_no_invalid
Expand Down