Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cosyvoice/cli/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import os
import re
import inflect
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.audio_utils import log_mel_spectrogram
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation


Expand Down Expand Up @@ -95,7 +95,7 @@ def _extract_text_token_generator(self, text_generator):
def _extract_speech_token(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
feat = log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None,
{self.speech_tokenizer_session.get_inputs()[0].name:
feat.detach().cpu().numpy(),
Expand Down
4 changes: 2 additions & 2 deletions cosyvoice/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import pyarrow.parquet as pq
from io import BytesIO
import numpy as np
import whisper
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
from cosyvoice.utils.audio_utils import log_mel_spectrogram
from cosyvoice.utils.onnx import embedding_extractor, online_feature

AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
Expand Down Expand Up @@ -193,7 +193,7 @@ def compute_whisper_fbank(data, num_frames=-1, mode='train'):
if num_frames != -1:
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
sample['whisper_feat'] = log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
yield sample


Expand Down
5 changes: 2 additions & 3 deletions cosyvoice/tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Optional
import torch
from transformers import AutoTokenizer
from whisper.tokenizer import Tokenizer

import tiktoken

LANGUAGES = {
Expand Down Expand Up @@ -213,7 +211,7 @@ def get_tokenizer(
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
):
if language is not None:
language = language.lower()
if language not in LANGUAGES:
Expand All @@ -233,6 +231,7 @@ def get_tokenizer(

encoding = get_encoding(name=encoding_name, num_languages=num_languages)

from whisper.tokenizer import Tokenizer
return Tokenizer(
encoding=encoding, num_languages=num_languages, language=language, task=task
)
Expand Down
56 changes: 56 additions & 0 deletions cosyvoice/utils/audio_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torchaudio


def log_mel_spectrogram(audio, n_mels=128, n_fft=400, hop_length=160, sample_rate=16000):
"""Compute a log-mel spectrogram from a waveform tensor.

This is a drop-in replacement for ``whisper.log_mel_spectrogram`` that uses
only ``torch`` and ``torchaudio``, avoiding the heavy ``openai-whisper``
dependency. The output is numerically equivalent for the default Whisper
parameters (n_fft=400, hop_length=160, sample_rate=16000).

Args:
audio: 1-D or 2-D float tensor of raw audio at *sample_rate* Hz.
n_mels: Number of mel-frequency bins.
n_fft: FFT window size.
hop_length: Hop length for STFT.
sample_rate: Expected sample rate of *audio*.

Returns:
Tensor of shape ``(n_mels, n_frames)`` (if 1-D input) or
``(batch, n_mels, n_frames)`` (if 2-D input).
"""
window = torch.hann_window(n_fft).to(audio.device)
stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

mel_filters = torchaudio.functional.melscale_fbanks(
n_freqs=n_fft // 2 + 1,
f_min=0.0,
f_max=sample_rate / 2.0,
n_mels=n_mels,
sample_rate=sample_rate,
norm="slaney",
mel_scale="slaney",
).to(audio.device)

mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.amax() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
openai-whisper==20231117
protobuf==4.25
pyarrow==18.1.0
pydantic==2.7.0
Expand Down
4 changes: 2 additions & 2 deletions tools/extract_speech_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import onnxruntime
import numpy as np
import torchaudio
import whisper
from cosyvoice.utils.audio_utils import log_mel_spectrogram


def single_job(utt):
Expand All @@ -34,7 +34,7 @@ def single_job(utt):
logging.warning('do not support extract speech token for audio longer than 30s')
speech_token = []
else:
feat = whisper.log_mel_spectrogram(audio, n_mels=128)
feat = log_mel_spectrogram(audio, n_mels=128)
speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
return utt, speech_token
Expand Down