**README.md** (7 additions, 2 deletions)
@@ -105,13 +105,18 @@ basic-pitch --help

**predict()**

-Import `basic-pitch` into your own Python code and run the [`predict`](basic_pitch/inference.py) functions directly, providing an `<input-audio-path>` and returning the model's prediction results:
+Import `basic-pitch` into your own Python code and run the [`predict`](basic_pitch/inference.py) function directly, providing an `<input-audio-path>` or an `<array-of-samples>` and returning the model's prediction results:

```python
import librosa

from basic_pitch.inference import predict
from basic_pitch import ICASSP_2022_MODEL_PATH

# get model predictions given an audio file
model_output, midi_data, note_events = predict(<input-audio-path>)

# or alternatively, provide an array of samples and its sample rate
audio_array, sample_rate = librosa.load(<input-audio-path>, mono=True, duration=10.0, offset=5.0)
model_output, midi_data, note_events = predict(audio_array, sample_rate)
```

- `<minimum-frequency>` & `<maximum-frequency>` (*float*s) set the minimum and maximum allowed note frequency, in Hz, returned by the model. Pitch events with frequencies outside this range will be excluded from the prediction results, as in the sketch below.
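
For example, a minimal sketch of bounding the pitch range from Python (assuming the `predict` keyword arguments `minimum_frequency` and `maximum_frequency` mirror these CLI options; the input file is hypothetical):

```python
from basic_pitch.inference import predict

# keep only notes between roughly C2 and C7 (values in Hz)
model_output, midi_data, note_events = predict(
    "vocals.wav",              # hypothetical input file
    minimum_frequency=65.4,    # pitch events below this are excluded
    maximum_frequency=2093.0,  # pitch events above this are excluded
)
```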
**basic_pitch/inference.py** (30 additions, 10 deletions)
@@ -213,7 +213,7 @@ def window_audio_file(


def get_audio_input(
audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
audio_path_or_array: Union[pathlib.Path, str, np.ndarray], sample_rate: int, overlap_len: int, hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
"""
Read wave file (as mono), pad appropriately, and return as
@@ -228,8 +228,20 @@ def get_audio_input(

"""
assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"

-    audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+    # if a numpy array of samples is provided, use it directly
+    if isinstance(audio_path_or_array, np.ndarray):
+        audio_original = audio_path_or_array
+        if sample_rate is None:
+            raise ValueError("Sample rate must be provided when input is an array of audio samples.")
+        # convert to mono if necessary
+        if audio_original.ndim != 1:
+            audio_original = librosa.to_mono(audio_original)
+        # resample to the model's rate if required
+        if sample_rate != AUDIO_SAMPLE_RATE:
+            audio_original = librosa.resample(audio_original, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
+    # otherwise load the audio file from disk
+    else:
+        audio_original, _ = librosa.load(str(audio_path_or_array), sr=AUDIO_SAMPLE_RATE, mono=True)

    original_length = audio_original.shape[0]
    audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
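
As an illustration of the new array branch, here is how `get_audio_input` might be driven directly (a sketch, not part of the PR; it assumes `AUDIO_SAMPLE_RATE` and `AUDIO_N_SAMPLES` are importable from `basic_pitch.constants`, and the overlap value is an arbitrary even number as the assert requires):

```python
import numpy as np

from basic_pitch.constants import AUDIO_N_SAMPLES, AUDIO_SAMPLE_RATE
from basic_pitch.inference import get_audio_input

# two seconds of a 440 Hz sine at 44.1 kHz, i.e. not at the model's rate
sr = 44100
t = np.linspace(0, 2.0, 2 * sr, endpoint=False)
audio = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

overlap_len = 7680  # even, as the assert above requires
hop_size = AUDIO_N_SAMPLES - overlap_len

# the array is resampled to AUDIO_SAMPLE_RATE internally, then windowed
for audio_windowed, window_times, original_length in get_audio_input(audio, sr, overlap_len, hop_size):
    print(audio_windowed.shape, original_length)
```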
@@ -267,14 +279,16 @@ def unwrap_output(


def run_inference(
-    audio_path: Union[pathlib.Path, str],
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
+    sample_rate: None,

> **Collaborator:** nit: use `Optional[int]` type

    model_or_model_path: Union[Model, pathlib.Path, str],
    debug_file: Optional[pathlib.Path] = None,
) -> Dict[str, np.array]:
    """Run the model on the input audio path.

    Args:
-        audio_path: The audio to run inference on.
+        audio_path_or_array: The audio to run inference on. Can be either the path to an audio file or a numpy array of audio samples.
+        sample_rate: Sample rate of the audio file. Only used if audio_path_or_array is a np array.

> **Collaborator** (on lines +290 to +291): the same docstring comments apply here too

        model_or_model_path: A loaded Model or path to a serialized model to load.
        debug_file: An optional path to output debug data to. Useful for testing/verification.

@@ -292,7 +306,7 @@ def run_inference(
    hop_size = AUDIO_N_SAMPLES - overlap_len

    output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
-    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
+    for audio_windowed, _, audio_original_length in get_audio_input(audio_path_or_array, sample_rate, overlap_len, hop_size):
        for k, v in model.predict(audio_windowed).items():
            output[k].append(v)

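A hypothetical call with the new `run_inference` signature (a sketch; the silent test signal is made up, and `ICASSP_2022_MODEL_PATH` is the packaged default model):

```python
import numpy as np

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import run_inference

# one second of silence at 48 kHz, passed as an array plus its sample rate
audio = np.zeros(48000, dtype=np.float32)
output = run_inference(audio, 48000, ICASSP_2022_MODEL_PATH)
print(output.keys())  # "note", "onset", and "contour" arrays
```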
@@ -415,7 +429,8 @@ def save_note_events(


def predict(
-    audio_path: Union[pathlib.Path, str],
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
+    sample_rate: int = None,

> **Collaborator:** nit: use `Optional[int]` type

    model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
    onset_threshold: float = 0.5,
    frame_threshold: float = 0.3,
@@ -426,6 +441,7 @@
    melodia_trick: bool = True,
    debug_file: Optional[pathlib.Path] = None,
    midi_tempo: float = 120,
+    verbose: bool = False

> **Collaborator:** nit: It's a good idea, although basic pitch is already pretty verbose by default. I think it is fine in this PR to add a few log lines without needing to control them with a new `verbose` parameter. We can think about controlling verbosity in future PRs, in my opinion :)

) -> Tuple[
    Dict[str, np.array],
    pretty_midi.PrettyMIDI,
@@ -434,7 +450,8 @@
"""Run a single prediction.

Args:
audio_path: File path for the audio to run inference on.
audio_path_or_array: File path for the audio to run inference on or array of audio samples.

> **Collaborator:** nit: It would be great to add information on the expected input shape, and to enforce it right at the beginning of the method. It looks like you're merging channels if multiple are provided; it could be worth adding a note about that in the docstring.
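
For concreteness, the guard the reviewer is asking for might look like this sketch (the `validate_audio_shape` helper is hypothetical, not part of the PR):

```python
import numpy as np

def validate_audio_shape(audio: np.ndarray) -> None:
    """Hypothetical guard: accept mono (n_samples,) or multichannel (n_channels, n_samples) input."""
    if audio.ndim not in (1, 2):
        raise ValueError(f"Expected a 1-D or 2-D audio array, got shape {audio.shape}")
```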

+        sample_rate: Sample rate of the audio file. Only used if audio_path_or_array is a np array.

> **Collaborator:** nit, suggested change for the `sample_rate` line: `sample_rate: Mandatory if audio_path_or_array is a np array. It should represent the sample rate of the provided array. Ignored if audio_path_or_array is a string.`

        model_or_model_path: A loaded Model or path to a serialized model to load.
        onset_threshold: Minimum energy required for an onset to be considered present.
        frame_threshold: Minimum energy requirement for a frame to be considered present.
@@ -449,9 +466,12 @@
"""

with no_tf_warnings():
print(f"Predicting MIDI for {audio_path}...")
if isinstance(audio_path_or_array, np.ndarray) and verbose:
print("Predicting MIDI ...")

> **Collaborator:** suggested change: log the array's shape rather than the bare message, e.g. `print("Predicting MIDI for input audio array of shape XX")`.

+        elif verbose:
+            print(f"Predicting MIDI for {audio_path_or_array}...")

-        model_output = run_inference(audio_path, model_or_model_path, debug_file)
+        model_output = run_inference(audio_path_or_array, sample_rate, model_or_model_path, debug_file)
        min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
        midi_data, note_events = infer.model_output_to_notes(
            model_output,
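
Putting the pieces together, a usage sketch of the extended `predict` API as proposed in this PR (the input path is hypothetical; loading with `sr=None` keeps the file's native rate so that the new resampling path is exercised):

```python
import librosa

from basic_pitch.inference import predict

# load at the file's native rate and keep all channels;
# predict() downmixes to mono and resamples internally
audio, sr = librosa.load("example.wav", sr=None, mono=False)

model_output, midi_data, note_events = predict(audio, sr, verbose=True)
midi_data.write("example.mid")  # midi_data is a pretty_midi.PrettyMIDI object
```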