add option to provide audio samples for prediction #153
base: main
```diff
@@ -213,7 +213,7 @@ def window_audio_file(
 def get_audio_input(
-    audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray], sample_rate: int, overlap_len: int, hop_size: int
 ) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
     """
     Read wave file (as mono), pad appropriately, and return as
```
```diff
@@ -228,8 +228,20 @@ def get_audio_input(
     """
     assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"

-    audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+    # if a numpy array of samples is provided, use it directly
+    if isinstance(audio_path_or_array, np.ndarray):
+        audio_original = audio_path_or_array
+        if sample_rate is None:
+            raise ValueError("Sample rate must be provided when input is an array of audio samples.")
+        # resample audio if required
+        elif sample_rate != AUDIO_SAMPLE_RATE:
+            audio_original = librosa.resample(audio_original, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
+        # convert to mono if necessary (use audio_original so resampling is not discarded)
+        if audio_original.ndim != 1:
+            audio_original = librosa.to_mono(audio_original)
+    # load audio file
+    else:
+        audio_original, _ = librosa.load(str(audio_path_or_array), sr=AUDIO_SAMPLE_RATE, mono=True)

     original_length = audio_original.shape[0]
     audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
```
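A minimal standalone sketch of this dispatch logic, for illustration only: the `load_flexible_audio` name is hypothetical, and `AUDIO_SAMPLE_RATE` is assumed to be the 22050 Hz constant from `basic_pitch.constants`. It also shows one defensible ordering, collapsing to mono before resampling, which is not what the diff above does:

```python
# A self-contained sketch (not the PR's exact code): hypothetical helper that
# accepts either a file path or raw samples, mirroring the branch above.
import pathlib
from typing import Optional, Union

import librosa
import numpy as np

AUDIO_SAMPLE_RATE = 22050  # assumed value of basic_pitch.constants.AUDIO_SAMPLE_RATE


def load_flexible_audio(
    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
    sample_rate: Optional[int] = None,
) -> np.ndarray:
    """Return mono float32 samples at AUDIO_SAMPLE_RATE."""
    if isinstance(audio_path_or_array, np.ndarray):
        if sample_rate is None:
            raise ValueError("Sample rate must be provided when input is an array of audio samples.")
        audio = audio_path_or_array
        # collapse (channels, n_samples) to mono first, then resample once
        if audio.ndim != 1:
            audio = librosa.to_mono(audio)
        if sample_rate != AUDIO_SAMPLE_RATE:
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
        return audio.astype(np.float32)
    audio, _ = librosa.load(str(audio_path_or_array), sr=AUDIO_SAMPLE_RATE, mono=True)
    return audio
```

Collapsing to mono before resampling keeps the resampler working on a single channel, which is cheaper and sidesteps any ambiguity about which axis `librosa.resample` operates on.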
```diff
@@ -267,14 +279,16 @@ def unwrap_output(
 def run_inference(
-    audio_path: Union[pathlib.Path, str],
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
+    sample_rate: Optional[int],
     model_or_model_path: Union[Model, pathlib.Path, str],
     debug_file: Optional[pathlib.Path] = None,
 ) -> Dict[str, np.array]:
     """Run the model on the input audio path.

     Args:
-        audio_path: The audio to run inference on.
+        audio_path_or_array: The audio to run inference on. Can be either the path to an audio file or a numpy array of audio samples.
+        sample_rate: Sample rate of the audio file. Only used if audio_path_or_array is a np array.
```

> **Collaborator** (on lines +290 to +291): the same docstring comments apply here too.
```diff
         model_or_model_path: A loaded Model or path to a serialized model to load.
         debug_file: An optional path to output debug data to. Useful for testing/verification.

@@ -292,7 +306,7 @@ def run_inference(
     hop_size = AUDIO_N_SAMPLES - overlap_len

     output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
-    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
+    for audio_windowed, _, audio_original_length in get_audio_input(audio_path_or_array, sample_rate, overlap_len, hop_size):
         for k, v in model.predict(audio_windowed).items():
             output[k].append(v)
```
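To make the new call pattern concrete, here is a usage sketch of `run_inference` with raw samples. The import paths are assumed from the released basic-pitch package layout, and the synthetic sine input is purely illustrative:

```python
# Illustrative only: feed raw samples straight into run_inference.
import numpy as np

from basic_pitch import ICASSP_2022_MODEL_PATH
from basic_pitch.inference import run_inference

sr = 44100  # native rate of the samples; get_audio_input resamples as needed
t = np.linspace(0, 1, sr, endpoint=False)
samples = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)  # 1 s of A4

output = run_inference(samples, sr, ICASSP_2022_MODEL_PATH)
print(sorted(output))  # ['contour', 'note', 'onset']
```

Because `sr` differs from the model's expected rate here, the array branch in `get_audio_input` resamples before windowing.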
```diff
@@ -415,7 +429,8 @@ def save_note_events(
 def predict(
-    audio_path: Union[pathlib.Path, str],
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
+    sample_rate: int = None,
```
> **Collaborator:** nit: use `Optional[int]` type
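Applied, the reviewer's nit would make the annotation read as follows (my rendering, not a snippet from the thread):

```python
sample_rate: Optional[int] = None,
```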
```diff
     model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
     onset_threshold: float = 0.5,
     frame_threshold: float = 0.3,

@@ -426,6 +441,7 @@ def predict(
     melodia_trick: bool = True,
     debug_file: Optional[pathlib.Path] = None,
     midi_tempo: float = 120,
+    verbose: bool = False
```
> **Collaborator:** nit: It's a good idea, although basic pitch is already pretty verbose by default. I think it is fine in this PR to add a few log lines without needing to control them with a new `verbose` argument.
```diff
 ) -> Tuple[
     Dict[str, np.array],
     pretty_midi.PrettyMIDI,

@@ -434,7 +450,8 @@ def predict(
     """Run a single prediction.

     Args:
-        audio_path: File path for the audio to run inference on.
+        audio_path_or_array: File path for the audio to run inference on or array of audio samples.
```
> **Collaborator:** nit: It would be great to add information on the expected input shape, and to enforce it right at the beginning of the method. It looks like you're merging channels if multiple are provided; it could be worth adding a note about that in the docstring.
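For illustration only, the early shape check the reviewer describes might look like the following. The helper name is hypothetical, and the accepted `(channels, n_samples)` layout is an assumption chosen to match what `librosa.to_mono` expects:

```python
import numpy as np


def _validate_audio_array(audio: np.ndarray) -> None:
    # Hypothetical guard, not part of this PR: accept mono (n_samples,) or
    # channel-first (channels, n_samples) arrays, as librosa.to_mono expects.
    if audio.ndim not in (1, 2):
        raise ValueError(
            f"Expected audio with shape (n_samples,) or (channels, n_samples), got {audio.shape}"
        )
    if audio.ndim == 2 and audio.shape[0] > audio.shape[1]:
        # More "channels" than samples usually means the array is transposed;
        # fail loudly rather than silently mixing down the wrong axis.
        raise ValueError("Audio array appears to be (n_samples, channels); transpose it first.")
```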
```diff
+        sample_rate: Sample rate of the audio file. Only used if audio_path_or_array is a np array.
```
> **Collaborator:** nit (left as a suggested change on this line)
```diff
         model_or_model_path: A loaded Model or path to a serialized model to load.
         onset_threshold: Minimum energy required for an onset to be considered present.
         frame_threshold: Minimum energy requirement for a frame to be considered present.

@@ -449,9 +466,12 @@ def predict(
     """

     with no_tf_warnings():
-        print(f"Predicting MIDI for {audio_path}...")
+        if isinstance(audio_path_or_array, np.ndarray) and verbose:
+            print("Predicting MIDI ...")
```
> **Collaborator** left a suggested change on these lines.
```diff
+        elif verbose:
+            print(f"Predicting MIDI for {audio_path_or_array}...")

-        model_output = run_inference(audio_path, model_or_model_path, debug_file)
+        model_output = run_inference(audio_path_or_array, sample_rate, model_or_model_path, debug_file)
         min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
         midi_data, note_events = infer.model_output_to_notes(
             model_output,
```
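At the `predict` level, the end-to-end flow with raw samples might then look like this. This is a sketch assuming the three-element return tuple (model output, `PrettyMIDI` object, note events) of the released `basic_pitch.inference.predict`; the file names are illustrative:

```python
import librosa

from basic_pitch.inference import predict

# Load at the file's native rate and channel layout; predict handles
# resampling and the mono mixdown via get_audio_input.
samples, sr = librosa.load("guitar_take.wav", sr=None, mono=False)
model_output, midi_data, note_events = predict(samples, sample_rate=sr)
midi_data.write("guitar_take.mid")
```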