16 changes: 11 additions & 5 deletions src/osekit/core_api/base_dataset.py
@@ -148,7 +148,10 @@ def move_files(self, folder: Path) -> None:
Destination folder in which the dataset files will be moved.

"""
for file in tqdm(self.files, disable=os.environ.get("DISABLE_TQDM", "")):
for file in tqdm(
self.files,
disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"),
):
file.move(folder)
self._folder = folder

@@ -186,7 +189,7 @@ def write(
last = len(self.data) if last is None else last
for data in tqdm(
self.data[first:last],
disable=os.environ.get("DISABLE_TQDM", ""),
disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"),
):
data.write(folder=folder, link=link)

@@ -348,7 +351,7 @@ def _get_base_data_from_files_timedelta_total(

for data_begin in tqdm(
date_range(begin, end, freq=freq, inclusive="left"),
disable=os.environ.get("DISABLE_TQDM", ""),
disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"),
):
data_end = Timestamp(data_begin + data_duration)
while (
@@ -395,7 +398,7 @@ def _get_base_data_from_files_timedelta_file(
files_chunk = []
for idx, file in tqdm(
enumerate(files[first:last]),
disable=os.environ.get("DISABLE_TQDM", ""),
disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"),
):
if file in files_chunk:
continue
@@ -492,7 +495,10 @@ def from_folder( # noqa: PLR0913
supported_file_extensions = []
valid_files = []
rejected_files = []
for file in tqdm(folder.iterdir(), disable=os.environ.get("DISABLE_TQDM", "")):
for file in tqdm(
folder.iterdir(),
disable=os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t"),
):
if file.suffix.lower() not in supported_file_extensions:
continue
try:
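
Note on the base_dataset.py hunks above: passing os.environ.get("DISABLE_TQDM", "") straight to tqdm's disable= meant any non-empty value (including "False") silenced the bars, so the diff replaces it with an explicit parse of the variable. A minimal standalone sketch of that parsing pattern; the helper name is_tqdm_disabled is illustrative and not part of this PR:

import os

def is_tqdm_disabled() -> bool:
    # "true", "1" and "t" (case-insensitive) disable the bars; anything else,
    # including an unset variable, keeps them enabled.
    return os.getenv("DISABLE_TQDM", "False").lower() in ("true", "1", "t")

os.environ["DISABLE_TQDM"] = "True"
assert is_tqdm_disabled()
os.environ["DISABLE_TQDM"] = "False"
assert not is_tqdm_disabled()
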
18 changes: 9 additions & 9 deletions src/osekit/public_api/export_analysis.py
@@ -239,24 +239,24 @@ def create_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--tqdm-disable",
required=False,
type=str,
default="true",
action=argparse.BooleanOptionalAction,
default=True,
help="Disable TQDM progress bars.",
)

parser.add_argument(
"--multiprocessing",
required=False,
type=str,
default="false",
action=argparse.BooleanOptionalAction,
default=False,
help="Turn multiprocessing on or off.",
)

parser.add_argument(
"--use-logging-setup",
required=False,
type=str,
default="false",
action=argparse.BooleanOptionalAction,
default=False,
help="Call osekit.setup_logging() before running the analysis.",
)

@@ -284,12 +284,12 @@ def main() -> None:
"""Export an analysis."""
args = create_parser().parse_args()

os.environ["DISABLE_TQDM"] = "" if not args.tqdm_disable else str(args.tqdm_disable)
os.environ["DISABLE_TQDM"] = str(args.tqdm_disable)

if args.use_logging_setup.lower() == "true":
if args.use_logging_setup:
setup_logging()

config.multiprocessing["is_active"] = args.multiprocessing.lower() == "true"
config.multiprocessing["is_active"] = args.multiprocessing
if (nb_processes := args.nb_processes) is not None:
config.multiprocessing["nb_processes"] = (
None if nb_processes.lower() == "none" else int(nb_processes)
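
The parser changes above swap string-typed options for argparse.BooleanOptionalAction (available since Python 3.9), which registers a paired --flag / --no-flag switch and stores a real bool, so the "true"/"false" string comparisons in main() can go away. A small self-contained sketch of that behaviour, reusing one of the option names for illustration only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--multiprocessing",
    action=argparse.BooleanOptionalAction,
    default=False,
    help="Turn multiprocessing on or off.",
)

assert parser.parse_args([]).multiprocessing is False
assert parser.parse_args(["--multiprocessing"]).multiprocessing is True
assert parser.parse_args(["--no-multiprocessing"]).multiprocessing is False
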
13 changes: 12 additions & 1 deletion src/osekit/utils/job.py
@@ -256,6 +256,16 @@ def progress(self) -> None:
return
self._status = JobStatus(self._status.value + 1)

def _build_arg_string(self) -> str:
"""Build a string representation of the job's arguments."""
arg_list = []
for key, value in self.script_args.items():
if isinstance(value, bool):
arg_list.append(f"--{'no-' if not value else ''}{key}")
else:
arg_list.append(f"--{key} {value}")
return " ".join(arg_list)

def write_pbs(self, path: Path) -> None:
"""Write a PBS file matching the job.

@@ -287,7 +297,8 @@ def write_pbs(self, path: Path) -> None:
for key, value in request.items()
if value
)
script = f"python {self.script_path} {' '.join(f'--{key} {value}' for key, value in self.script_args.items())}"

script = f"python {self.script_path} {self._build_arg_string()}"

pbs = f"{preamble}\n{request_str}\n{self.venv_activate_script}\n{script}"
with path.open("w") as file:
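
The new _build_arg_string keeps the generated PBS scripts compatible with the BooleanOptionalAction flags above: booleans are rendered as bare --flag / --no-flag switches, everything else keeps the --key value form. A standalone copy of the same logic for illustration (the free function below is not part of the module):

def build_arg_string(script_args: dict) -> str:
    arg_list = []
    for key, value in script_args.items():
        if isinstance(value, bool):
            # True -> "--key", False -> "--no-key", matching BooleanOptionalAction.
            arg_list.append(f"--{'no-' if not value else ''}{key}")
        else:
            arg_list.append(f"--{key} {value}")
    return " ".join(arg_list)

print(build_arg_string({"first": 10, "tqdm-disable": False, "multiprocessing": True}))
# --first 10 --no-tqdm-disable --multiprocessing
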
9 changes: 7 additions & 2 deletions src/osekit/utils/multiprocess_utils.py
@@ -43,7 +43,11 @@ def multiprocess(
if bypass_multiprocessing or not config.multiprocessing["is_active"]:
return list(
func(element, *args, **kwargs)
for element in tqdm(enumerable, disable=os.environ.get("DISABLE_TQDM", ""))
for element in tqdm(
enumerable,
disable=os.getenv("DISABLE_TQDM", "False").lower()
in ("true", "1", "t"),
)
)

partial_func = partial(func, *args, **kwargs)
@@ -53,6 +57,7 @@
tqdm(
pool.imap(partial_func, enumerable),
total=len(list(enumerable)),
disable=os.environ.get("DISABLE_TQDM", ""),
disable=os.getenv("DISABLE_TQDM", "False").lower()
in ("true", "1", "t"),
),
)
150 changes: 77 additions & 73 deletions tests/test_export_analysis.py
@@ -1,6 +1,7 @@
import argparse
import logging
import os
import shlex
from pathlib import Path

import pytest
@@ -10,6 +11,7 @@
from osekit.core_api.spectro_dataset import SpectroDataset
from osekit.public_api import export_analysis
from osekit.public_api.export_analysis import create_parser
from osekit.utils.job import Job


def test_parser_factory() -> None:
@@ -64,81 +66,81 @@ def test_argument_defaults() -> None:
assert args.downsampling_quality is None
assert args.upsampling_quality is None
assert args.umask == 0o002 # noqa: PLR2004
assert args.tqdm_disable == "true"
assert args.multiprocessing == "false"
assert args.use_logging_setup == "false"
assert args.tqdm_disable
assert not args.multiprocessing
assert not args.use_logging_setup
assert args.nb_processes is None
assert args.dataset_json_path is None


@pytest.fixture
def script_arguments() -> dict:
return {
"--analysis": 2,
"--ads-json": r"path/to/ads.json",
"--sds-json": r"path/to/ads.json",
"--subtype": "FLOAT",
"--matrix-folder-path": r"out/matrix",
"--spectrogram-folder-path": r"out/spectro",
"--welch-folder-path": r"out/welch",
"--first": 10,
"--last": 12,
"--downsampling-quality": "HQ",
"--upsampling-quality": "VHQ",
"--umask": 0o022,
"--tqdm-disable": "False",
"--multiprocessing": "True",
"--nb-processes": "3", # String because it might be "None"
"--use-logging-setup": "True",
"--dataset-json-path": r"path/to/dataset.json",
"analysis": 2,
"ads-json": r"path/to/ads.json",
"sds-json": r"path/to/ads.json",
"subtype": "FLOAT",
"matrix-folder-path": r"out/matrix",
"spectrogram-folder-path": r"out/spectro",
"welch-folder-path": r"out/welch",
"first": 10,
"last": 12,
"downsampling-quality": "HQ",
"upsampling-quality": "VHQ",
"umask": 0o022,
"tqdm-disable": False,
"multiprocessing": True,
"nb-processes": "3", # String because it might be "None"
"use-logging-setup": True,
"dataset-json-path": r"path/to/dataset.json",
}


def test_specified_arguments(script_arguments: dict) -> None:
parser = create_parser()

args = parser.parse_args(
[str(arg_part) for arg in script_arguments.items() for arg_part in arg],
)

assert args.analysis == script_arguments["--analysis"]
assert args.ads_json == script_arguments["--ads-json"]
assert args.sds_json == script_arguments["--sds-json"]
assert args.subtype == script_arguments["--subtype"]
assert args.matrix_folder_path == script_arguments["--matrix-folder-path"]
assert args.spectrogram_folder_path == script_arguments["--spectrogram-folder-path"]
assert args.welch_folder_path == script_arguments["--welch-folder-path"]
assert args.first == script_arguments["--first"]
assert args.last == script_arguments["--last"]
assert args.downsampling_quality == script_arguments["--downsampling-quality"]
assert args.upsampling_quality == script_arguments["--upsampling-quality"]
assert args.umask == script_arguments["--umask"]
assert args.tqdm_disable == script_arguments["--tqdm-disable"]
assert args.multiprocessing == script_arguments["--multiprocessing"]
assert args.use_logging_setup == script_arguments["--use-logging-setup"]
assert args.nb_processes == script_arguments["--nb-processes"]
assert args.dataset_json_path == script_arguments["--dataset-json-path"]
parsed_str = Job(Path(), script_arguments)._build_arg_string()

args = parser.parse_args(shlex.split(parsed_str))

assert args.analysis == script_arguments["analysis"]
assert args.ads_json == script_arguments["ads-json"]
assert args.sds_json == script_arguments["sds-json"]
assert args.subtype == script_arguments["subtype"]
assert args.matrix_folder_path == script_arguments["matrix-folder-path"]
assert args.spectrogram_folder_path == script_arguments["spectrogram-folder-path"]
assert args.welch_folder_path == script_arguments["welch-folder-path"]
assert args.first == script_arguments["first"]
assert args.last == script_arguments["last"]
assert args.downsampling_quality == script_arguments["downsampling-quality"]
assert args.upsampling_quality == script_arguments["upsampling-quality"]
assert args.umask == script_arguments["umask"]
assert args.tqdm_disable == script_arguments["tqdm-disable"]
assert args.multiprocessing == script_arguments["multiprocessing"]
assert args.use_logging_setup == script_arguments["use-logging-setup"]
assert args.nb_processes == script_arguments["nb-processes"]
assert args.dataset_json_path == script_arguments["dataset-json-path"]


def test_main_script(monkeypatch: pytest.MonkeyPatch, script_arguments: dict) -> None:
class MockedArgs:
def __init__(self, *args: list, **kwargs: dict) -> None:
self.analysis = script_arguments["--analysis"]
self.ads_json = script_arguments["--ads-json"]
self.sds_json = script_arguments["--sds-json"]
self.subtype = script_arguments["--subtype"]
self.matrix_folder_path = script_arguments["--matrix-folder-path"]
self.spectrogram_folder_path = script_arguments["--spectrogram-folder-path"]
self.welch_folder_path = script_arguments["--welch-folder-path"]
self.first = script_arguments["--first"]
self.last = script_arguments["--last"]
self.downsampling_quality = script_arguments["--downsampling-quality"]
self.upsampling_quality = script_arguments["--upsampling-quality"]
self.umask = script_arguments["--umask"]
self.tqdm_disable = script_arguments["--tqdm-disable"]
self.multiprocessing = script_arguments["--multiprocessing"]
self.use_logging_setup = script_arguments["--use-logging-setup"]
self.nb_processes = script_arguments["--nb-processes"]
self.analysis = script_arguments["analysis"]
self.ads_json = script_arguments["ads-json"]
self.sds_json = script_arguments["sds-json"]
self.subtype = script_arguments["subtype"]
self.matrix_folder_path = script_arguments["matrix-folder-path"]
self.spectrogram_folder_path = script_arguments["spectrogram-folder-path"]
self.welch_folder_path = script_arguments["welch-folder-path"]
self.first = script_arguments["first"]
self.last = script_arguments["last"]
self.downsampling_quality = script_arguments["downsampling-quality"]
self.upsampling_quality = script_arguments["upsampling-quality"]
self.umask = script_arguments["umask"]
self.tqdm_disable = script_arguments["tqdm-disable"]
self.multiprocessing = script_arguments["multiprocessing"]
self.use_logging_setup = script_arguments["use-logging-setup"]
self.nb_processes = script_arguments["nb-processes"]
self.dataset_json_path = "none"

def return_mocked_attr(*args: list, **kwargs: dict) -> MockedArgs:
@@ -172,31 +174,33 @@ def mock_write_analysis(*args: list, **kwargs: dict) -> None:

export_analysis.main()

assert os.environ["DISABLE_TQDM"] == str(args.tqdm_disable)
assert config.multiprocessing["is_active"] is True
assert config.multiprocessing["nb_processes"] == 3
assert (
os.environ["DISABLE_TQDM"].lower() in ("true", "1", "t")
) == args.tqdm_disable
assert config.multiprocessing["is_active"]
assert config.multiprocessing["nb_processes"] == 3 # noqa: PLR2004
assert (
config.resample_quality_settings["downsample"]
== script_arguments["--downsampling-quality"]
== script_arguments["downsampling-quality"]
)
assert (
config.resample_quality_settings["upsample"]
== script_arguments["--upsampling-quality"]
== script_arguments["upsampling-quality"]
)
assert calls["ads_json"] == Path(script_arguments["--ads-json"])
assert calls["sds_json"] == Path(script_arguments["--sds-json"])
assert calls["ads_json"] == Path(script_arguments["ads-json"])
assert calls["sds_json"] == Path(script_arguments["sds-json"])

# write_analysis
assert calls["analysis_type"].value == script_arguments["--analysis"]
assert calls["ads"] == Path(script_arguments["--ads-json"])
assert calls["sds"] == Path(script_arguments["--sds-json"])
assert calls["subtype"] == script_arguments["--subtype"]
assert calls["matrix_folder_path"] == Path(script_arguments["--matrix-folder-path"])
assert calls["analysis_type"].value == script_arguments["analysis"]
assert calls["ads"] == Path(script_arguments["ads-json"])
assert calls["sds"] == Path(script_arguments["sds-json"])
assert calls["subtype"] == script_arguments["subtype"]
assert calls["matrix_folder_path"] == Path(script_arguments["matrix-folder-path"])
assert calls["spectrogram_folder_path"] == Path(
script_arguments["--spectrogram-folder-path"],
script_arguments["spectrogram-folder-path"],
)
assert calls["welch_folder_path"] == Path(script_arguments["--welch-folder-path"])
assert calls["first"] == script_arguments["--first"]
assert calls["last"] == script_arguments["--last"]
assert calls["welch_folder_path"] == Path(script_arguments["welch-folder-path"])
assert calls["first"] == script_arguments["first"]
assert calls["last"] == script_arguments["last"]
assert calls["link"] is True
assert calls["logger"] == logging.getLogger()