Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ dev = [
"coverage>=7.11.0",
]

[tool.ruff]
exclude = ["scripts/**/*.py"]

[tool.ruff.lint.flake8-copyright]
author = "OSmOSE"

Expand Down
4 changes: 2 additions & 2 deletions src/post_processing/dataclass/data_aplose.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def set_ax(

return ax

def overview(self) -> None:
def overview(self, annotator: list[str] | None = None) -> None:
"""Overview of an APLOSE formatted DataFrame."""
overview(self.df)
overview(self.df, annotator)

def detection_perf(
self,
Expand Down
19 changes: 19 additions & 0 deletions src/post_processing/utils/filtering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ def find_delimiter(file: Path) -> str:
return delimiter


def filter_strong_detection(
df: DataFrame,
) -> DataFrame:
"""Filter strong detections of a DataFrame."""
if "type" in df.columns:
df = df[df["type"] == "WEAK"]
elif "is_box" in df.columns:
df = df[df["is_box"] == 0]
else:
msg = "Could not determine annotation type."
raise ValueError(msg)
if df.empty:
msg = "No weak detection found."
raise ValueError(msg)
return df


def filter_by_time(
df: DataFrame,
begin: Timestamp | None,
Expand Down Expand Up @@ -333,6 +350,8 @@ def load_detections(filters: DetectionFilter) -> DataFrame:

"""
df = read_dataframe(filters.detection_file)
if filters.box:
df = filter_strong_detection(df)
df = filter_by_time(df, filters.begin, filters.end)
df = filter_by_annotator(df, annotator=filters.annotator)
df = filter_by_label(df, label=filters.annotation)
Expand Down
4 changes: 3 additions & 1 deletion src/post_processing/utils/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def detection_perf(
timestamps: list[Timestamp] | None = None,
*,
ref: tuple[str, str],
) -> None:
) -> tuple[float, float, float]:
"""Compute performances metrics for detection.

Performances are computed with a reference annotator in
Expand Down Expand Up @@ -128,6 +128,8 @@ def detection_perf(
logging.info(f"Recall: {recall:.2f}")
logging.info(f"F-score: {f_score:.2f}")

return precision, recall, f_score


def _map_datetimes_to_vector(df: DataFrame, timestamps: list[int]) -> ndarray:
"""Map datetime ranges to a binary vector indicating overlap with timestamp bins.
Expand Down
13 changes: 11 additions & 2 deletions src/post_processing/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
round_begin_end_timestamps,
timedelta_to_str,
)
from post_processing.utils.filtering_utils import get_max_time, get_timezone
from post_processing.utils.filtering_utils import (
get_max_time,
get_timezone,
filter_by_annotator,
)
from post_processing.utils.metrics_utils import normalize_counts_by_effort

if TYPE_CHECKING:
Expand Down Expand Up @@ -368,15 +372,20 @@ def heatmap(df: DataFrame,
ax.set_xlabel(f"Time ({bin_size_str} bin)")


def overview(df: DataFrame) -> None:
def overview(df: DataFrame, annotator: list[str] | None = None) -> None:
"""Overview of an APLOSE formatted DataFrame.

Parameters
----------
df: DataFrame
The Dataframe to analyse.
annotator: list[str]
List of annotators.

"""
if annotator is not None:
df = filter_by_annotator(df, annotator)

summary_label = (
df.groupby("annotation")["annotator"] # noqa: PD010
.apply(Counter)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def sample_yaml(tmp_path: Path,
"end": None,
"annotator": "ann1",
"annotation": "lbl1",
"box": False,
"box": True,
"timestamp_file": f"{sample_csv_timestamp}",
"user_sel": "all",
"f_min": None,
Expand Down
5 changes: 4 additions & 1 deletion tests/test_DataAplose.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ def test_set_ax(sample_df: DataFrame) -> None:
assert isinstance(locator, mdates.HourLocator)


def test_from_yaml(sample_yaml: Path, sample_df: DataFrame) -> None:
def test_from_yaml(
sample_yaml: Path,
sample_df: DataFrame,
) -> None:
df_from_yaml = DataAplose.from_yaml(file=sample_yaml).df
df_expected = DataAplose(sample_df).filter_df(annotator="ann1", label="lbl1").reset_index(drop=True)
assert df_from_yaml.equals(df_expected)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_DetectionFilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_from_yaml(sample_yaml: Path,
"end": None,
"annotator": "ann1",
"annotation": "lbl1",
"box": False,
"box": True,
"timestamp_file": f"{sample_csv_timestamp}",
"user_sel": "all",
"f_min": None,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_filtering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from post_processing.utils.filtering_utils import (
filter_by_annotator,
filter_strong_detection,
filter_by_freq,
filter_by_label,
filter_by_score,
Expand Down Expand Up @@ -144,6 +145,17 @@ def test_filter_by_score_missing_column(sample_df: DataFrame) -> None:
filter_by_score(df, 0.5)


# filter_weak_strong_detection
def test_filter_weak_only(sample_df: DataFrame) -> None:
df = filter_strong_detection(sample_df)
assert set(df["is_box"]) == {0}


def test_filter_weak_empty(sample_df: DataFrame) -> None:
with pytest.raises(ValueError, match="No weak detection found"):
filter_strong_detection(sample_df[sample_df["is_box"] == 1])


def test_get_annotators(sample_df: DataFrame) -> None:
annotators = get_annotators(sample_df)
expected = sorted(set(sample_df["annotator"]))
Expand Down