161 changes: 150 additions & 11 deletions batbot/__init__.py
@@ -17,10 +17,13 @@
    output_paths, compressed_paths, metadata_path, metadata = spectrogram.compute(filepath)
"""

import concurrent.futures
import time
from multiprocessing import Manager
from os.path import basename, exists, join, splitext
from pathlib import Path

import pooch
from tqdm import tqdm

from batbot import utils

@@ -60,10 +63,12 @@ def fetch(pull=False, config=None):

def pipeline(
filepath,
config=None,
# classifier_thresh=classifier.CONFIGS[None]['thresh'],
clean=True,
    out_file_stem=None,
    output_folder=None,
fast_mode=False,
force_overwrite=False,
quiet=False,
debug=False,
):
"""
Run the ML pipeline on a given WAV filepath and return the classification results
@@ -93,12 +98,138 @@
    Returns:
        tuple ( list, list, str ): spectrogram output paths, compressed
            spectrogram paths, and the metadata filepath
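
    Example (a minimal sketch; 'call.wav' and the output stem are assumed names):
        >>> output_paths, compressed_paths, metadata_path = pipeline(
        ...     'call.wav', out_file_stem='output/call'
        ... )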
"""

# Generate spectrogram
    output_paths, compressed_paths, metadata_path, metadata = spectrogram.compute(
filepath,
out_file_stem=out_file_stem,
output_folder=output_folder,
fast_mode=fast_mode,
force_overwrite=force_overwrite,
quiet=quiet,
debug=debug,
)

    return output_paths, compressed_paths, metadata_path


def pipeline_multi_wrapper(
filepaths,
out_file_stems=None,
fast_mode=False,
force_overwrite=False,
worker_position=None,
quiet=False,
tqdm_lock=None,
):
"""Fault-tolerant wrapper for multiple inputs.

Args:
filepaths (_type_): _description_
out_file_stems (_type_, optional): _description_. Defaults to None.
fast_mode (bool, optional): _description_. Defaults to False.
force_overwrite (bool, optional): _description_. Defaults to False.

Returns:
_type_: _description_
"""

if out_file_stems is not None:
        assert len(filepaths) == len(
            out_file_stems
        ), 'Input filepaths and out_file_stems have different lengths.'
else:
out_file_stems = [None] * len(filepaths)

outputs = {'output_paths': [], 'compressed_paths': [], 'metadata_paths': [], 'failed_files': []}
if tqdm_lock is not None:
tqdm.set_lock(tqdm_lock)
for in_file, out_stem in tqdm(
zip(filepaths, out_file_stems),
desc='Processing, worker {}'.format(worker_position),
position=worker_position,
total=len(filepaths),
leave=True,
):
try:
output_paths, compressed_paths, metadata_path = pipeline(
in_file,
out_file_stem=out_stem,
fast_mode=fast_mode,
force_overwrite=force_overwrite,
quiet=quiet,
)
outputs['output_paths'].extend(output_paths)
outputs['compressed_paths'].extend(compressed_paths)
outputs['metadata_paths'].append(metadata_path)
except Exception as e:
outputs['failed_files'].append((str(in_file), e))

return tuple(outputs.values())
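
# A minimal usage sketch for pipeline_multi_wrapper (illustrative only; the
# file names and output stems are assumptions, not part of this module):
#
#   wav_files = ['recordings/a.wav', 'recordings/b.wav']
#   outputs, compressed, metadata, failed = pipeline_multi_wrapper(
#       wav_files, out_file_stems=['out/a', 'out/b'], worker_position=0
#   )
#   for path, error in failed:
#       print(f'{path} failed: {error}')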


def parallel_pipeline(
in_file_chunks,
out_stem_chunks=None,
fast_mode=False,
force_overwrite=False,
num_workers=0,
threaded=False,
quiet=False,
desc=None,
):
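    """Run pipeline_multi_wrapper over pre-chunked inputs in parallel.

    Args:
        in_file_chunks (list of list): WAV filepaths, split into one chunk
            per worker.
        out_stem_chunks (list of list, optional): output file stems matching
            in_file_chunks. Defaults to None.
        num_workers (int, optional): maximum number of workers. Defaults to 0,
            which falls back to a single worker.
        threaded (bool, optional): use threads instead of processes.
        desc (str, optional): label for the overall progress bar.

    Returns:
        tuple ( list, list, list, list ): output paths, compressed paths,
        metadata paths, and (filepath, exception) pairs for failed files.
    """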

if out_stem_chunks is None:
out_stem_chunks = [None] * len(in_file_chunks)

    if len(in_file_chunks) == 0:
        # Keep the 4-tuple shape of a normal run so callers can always unpack
        return ([], [], [], [])
else:
assert len(in_file_chunks) == len(
out_stem_chunks
), 'in_file_chunks and out_stem_chunks must have the same length.'

if threaded:
executor_cls = concurrent.futures.ThreadPoolExecutor
else:
executor_cls = concurrent.futures.ProcessPoolExecutor

    # Guard against num_workers=0: use at least one worker, at most one per chunk
    num_workers = max(1, min(len(in_file_chunks), num_workers))

outputs = {'output_paths': [], 'compressed_paths': [], 'metadata_paths': [], 'failed_files': []}

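    # Share a single lock across workers so nested tqdm bars do not garble output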
lock_manager = Manager()
tqdm_lock = lock_manager.Lock()

with tqdm(total=len(in_file_chunks), disable=quiet, desc=desc) as progress:
with executor_cls(max_workers=num_workers) as executor:

futures = [
executor.submit(
pipeline_multi_wrapper,
filepaths=file_chunk,
out_file_stems=out_stem_chunk,
fast_mode=fast_mode,
force_overwrite=force_overwrite,
worker_position=index % num_workers,
quiet=quiet,
tqdm_lock=tqdm_lock,
)
for index, (file_chunk, out_stem_chunk) in enumerate(
zip(in_file_chunks, out_stem_chunks)
)
]

            for future in concurrent.futures.as_completed(futures):
                output_paths, compressed_paths, metadata_paths, failed_files = future.result()
                outputs['output_paths'].extend(output_paths)
                outputs['compressed_paths'].extend(compressed_paths)
                outputs['metadata_paths'].extend(metadata_paths)
outputs['failed_files'].extend(failed_files)
progress.update(1)

return tuple(outputs.values())
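
# A minimal sketch of preparing chunks for parallel_pipeline (illustrative;
# the glob pattern and chunk count are assumptions, not part of this module):
#
#   from pathlib import Path
#   wavs = sorted(str(p) for p in Path('recordings').glob('*.wav'))
#   chunks = [wavs[i::4] for i in range(4)]  # four interleaved chunks
#   outputs, compressed, metadata, failed = parallel_pipeline(
#       chunks, num_workers=4, desc='Spectrograms'
#   )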


def batch(
@@ -140,7 +271,7 @@ def batch(
# Run tiling
batch = {}
for filepath in filepaths:
        _, _, _, metadata = spectrogram.compute(filepath)
batch[filepath] = metadata

raise NotImplementedError
@@ -164,7 +295,15 @@ def example():
assert exists(wav_filepath)

log.debug(f'Running pipeline on WAV: {wav_filepath}')
output_stem = join('output', splitext(basename(wav_filepath))[0])
start_time = time.time()
results = pipeline(
wav_filepath, out_file_stem=output_stem, fast_mode=False, force_overwrite=True
)
stop_time = time.time()
    print('Example pipeline completed in {:.2f} seconds.'.format(stop_time - start_time))

log.debug(results)
186 changes: 0 additions & 186 deletions batbot/batbot.py

This file was deleted.
