Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 213 additions & 5 deletions examples/stream/stream.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "common-sdl.h"
#include "common.h"
#include "sense-voice.h"
#include "silero-vad.h"
#include <algorithm>
#include <cassert>
#include <cstdio>
Expand All @@ -9,6 +10,41 @@
#include <string>
#include <thread>
#include <vector>
#include <cstring>

// Canonical 44-byte RIFF/WAVE header for 16-bit mono PCM output.
// Instances are written to disk verbatim (reinterpret_cast + ofstream::write),
// so the in-memory layout must match the on-disk format byte for byte.
struct WAVHeader {
    char riff[4] = {'R', 'I', 'F', 'F'};
    uint32_t file_size;                          // total file size minus 8; patched when the stream is closed
    char wave[4] = {'W', 'A', 'V', 'E'};
    char fmt[4] = {'f', 'm', 't', ' '};
    uint32_t fmt_size = 16;                      // size of the PCM "fmt " chunk
    uint16_t audio_format = 1;                   // 1 = uncompressed PCM
    uint16_t num_channels = 1;                   // mono
    uint32_t sample_rate = SENSE_VOICE_SAMPLE_RATE;
    uint32_t byte_rate;                          // sample_rate * num_channels * bits_per_sample / 8
    uint16_t block_align;                        // bytes per sample frame
    uint16_t bits_per_sample = 16;
    char data[4] = {'d', 'a', 't', 'a'};
    uint32_t data_size;                          // PCM payload size in bytes; patched when the stream is closed

    WAVHeader() {
        byte_rate   = sample_rate * num_channels * bits_per_sample / 8;
        block_align = num_channels * bits_per_sample / 8;
        file_size   = 0; // updated once all samples have been written
        data_size   = 0; // updated once all samples have been written
    }
};

// The header is serialized with a raw write(), so compiler-inserted padding
// would corrupt the output file. Every member happens to be naturally
// aligned, but lock the layout in explicitly.
static_assert(sizeof(WAVHeader) == 44, "WAVHeader must be exactly 44 bytes (no padding)");

// Convert normalized float audio samples to 16-bit signed PCM.
// Each sample is clamped to [-1.0, 1.0] before scaling so out-of-range
// input cannot wrap around during the integer conversion.
void float_to_pcm16(const std::vector<float>& float_audio, std::vector<int16_t>& pcm16_audio) {
    pcm16_audio.resize(float_audio.size());
    std::transform(float_audio.begin(), float_audio.end(), pcm16_audio.begin(),
                   [](float s) -> int16_t {
                       // clamp, then map [-1, 1] onto the int16 range
                       const float clamped = std::max(-1.0f, std::min(1.0f, s));
                       return static_cast<int16_t>(clamped * 32767.0f);
                   });
}

struct sense_voice_stream_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
Expand All @@ -27,6 +63,10 @@ struct sense_voice_stream_params {
std::string language = "auto";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
std::string audio_out; // 用于输出音频流到文件

// silero-vad 参数
float speech_prob_threshold = 0.1f;
};


Expand All @@ -46,9 +86,11 @@ void sense_voice_stream_usage(int /*argc*/, char **argv, const sense_voice_strea
fprintf(stderr, " -l LANG, --language LANG [%-7s] [SenseVoice] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] [SenseVoice] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] [IO] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -o FNAME, --output-audio FNAME [%-7s] [IO] audio output file name\n", params.audio_out.c_str());
fprintf(stderr, " -ng, --no-gpu [%-7s] [SenseVoice] disable GPU inference\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] [SenseVoice] flash attention during inference\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " --use-itn [%-7s] [SenseVoice] Filter duplicate tokens when outputting\n", params.use_itn ? "true" : "false");
fprintf(stderr, " -spt --speech-prob-threshold [%-7.2f] [VAD] speech probability threshold for VAD\n", params.speech_prob_threshold);
fprintf(stderr, "\n");
}

Expand All @@ -70,6 +112,8 @@ static bool get_stream_params(int argc, char **argv, sense_voice_stream_params &
params.model = argv[++i];
} else if (arg == "-f" || arg == "--file") {
params.fname_out = argv[++i];
} else if (arg == "-o" || arg == "--output-audio") {
params.audio_out = argv[++i];
} else if (arg == "-ng" || arg == "--no-gpu") {
params.use_gpu = false;
} else if (arg == "-fa" || arg == "--flash-attn") {
Expand All @@ -88,6 +132,8 @@ static bool get_stream_params(int argc, char **argv, sense_voice_stream_params &
params.chunk_size = std::stoi(argv[++i]);
} else if (arg == "--use-itn") {
params.use_itn = true;
} else if (arg == "--speech-prob-threshold" || arg == "-spt") {
params.speech_prob_threshold = std::stof(argv[++i]);
}

else {
Expand All @@ -106,6 +152,22 @@ void sense_voice_free(struct sense_voice_context *ctx) {

ggml_backend_buffer_free(ctx->model.buffer);

// 释放VAD相关资源 - 添加空指针检查
if (ctx->state) {
if (ctx->state->vad_ctx) {
ggml_free(ctx->state->vad_ctx);
ctx->state->vad_ctx = nullptr;
}
if (ctx->state->vad_lstm_hidden_state_buffer) {
ggml_backend_buffer_free(ctx->state->vad_lstm_hidden_state_buffer);
ctx->state->vad_lstm_hidden_state_buffer = nullptr;
}
if (ctx->state->vad_lstm_context_buffer) {
ggml_backend_buffer_free(ctx->state->vad_lstm_context_buffer);
ctx->state->vad_lstm_context_buffer = nullptr;
}
}

sense_voice_free_state(ctx->state);

delete ctx->model.model->encoder;
Expand All @@ -118,6 +180,11 @@ void sense_voice_free(struct sense_voice_context *ctx) {
int main(int argc, char **argv) {
sense_voice_stream_params params;
if (get_stream_params(argc, argv, params) == false) return 1;

// VAD 常量定义(与 zcr_main/main.cc 保持一致)
const int VAD_LSTM_STATE_MEMORY_SIZE = 2048;
const int VAD_LSTM_STATE_DIM = 128;

const int n_sample_step = params.chunk_size * 1e-3 * SENSE_VOICE_SAMPLE_RATE;
const int keep_nomute_step = params.chunk_size * params.min_mute_chunks * 1e-3 * SENSE_VOICE_SAMPLE_RATE;
const int max_nomute_step = params.chunk_size * params.max_nomute_chunks * 1e-3 * SENSE_VOICE_SAMPLE_RATE;
Expand All @@ -136,16 +203,75 @@ int main(int argc, char **argv) {

bool is_running = true;
struct sense_voice_context *ctx = sense_voice_small_init_from_file_with_params(params.model.c_str(), cparams);

if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize sense voice context\n");
return 3;
}

// 设置语言ID(重要:必须设置才能正确识别)
ctx->language_id = sense_voice_lang_id(params.language.c_str());
if (ctx->language_id == -1) {
fprintf(stderr, "warning: unknown language '%s', using auto detection\n", params.language.c_str());
ctx->language_id = sense_voice_lang_id("auto");
}

fprintf(stderr, "Language: %s (ID: %d)\n", params.language.c_str(), ctx->language_id);

std::vector<float> pcmf32_audio;
std::vector<double> pcmf32;
std::vector<double> pcmf32_tmp;// 传递给模型用

// 预留合理的容量,避免频繁的内存重分配
pcmf32.reserve(max_nomute_step * 2);
pcmf32_audio.reserve(params.chunk_size * SENSE_VOICE_SAMPLE_RATE / 1000 * 4); // 预留4个chunk的空间
pcmf32_tmp.reserve(max_nomute_step); // 为临时处理预留空间

// 文本输出文件
std::ofstream fout;
if (params.fname_out.length() > 0) {
fout.open(params.fname_out);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open output file '%s'!\n", __func__, params.fname_out.c_str());
fprintf(stderr, "%s: failed to open text output file '%s'!\n", __func__, params.fname_out.c_str());
return 1;
}
}

// 音频输出文件
std::ofstream audio_fout;
WAVHeader wav_header;
uint32_t total_samples = 0;
if (params.audio_out.length() > 0) {
audio_fout.open(params.audio_out, std::ios::binary);
if (!audio_fout.is_open()) {
fprintf(stderr, "%s: failed to open audio output file '%s'!\n", __func__, params.audio_out.c_str());
return 1;
}
// 写入 WAV 文件头(先写入占位符,稍后更新)
audio_fout.write(reinterpret_cast<const char*>(&wav_header), sizeof(WAVHeader));
}

// 初始化 VAD 状态(与 main.cc 保持一致)
if (params.use_vad) {
// init state
ctx->state->vad_ctx = ggml_init({VAD_LSTM_STATE_MEMORY_SIZE, nullptr, true});
ctx->state->vad_lstm_context = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
ctx->state->vad_lstm_hidden_state = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);

ctx->state->vad_lstm_context_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
ggml_nbytes(ctx->state->vad_lstm_context)
+ ggml_backend_get_alignment(ctx->state->backends[0]));
ctx->state->vad_lstm_hidden_state_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
ggml_nbytes(ctx->state->vad_lstm_hidden_state)
+ ggml_backend_get_alignment(ctx->state->backends[0]));
auto context_alloc = ggml_tallocr_new(ctx->state->vad_lstm_context_buffer);
ggml_tallocr_alloc(&context_alloc, ctx->state->vad_lstm_context);

auto state_alloc = ggml_tallocr_new(ctx->state->vad_lstm_hidden_state_buffer);
ggml_tallocr_alloc(&state_alloc, ctx->state->vad_lstm_hidden_state);

ggml_set_zero(ctx->state->vad_lstm_context);
ggml_set_zero(ctx->state->vad_lstm_hidden_state);
}

{
Expand Down Expand Up @@ -186,6 +312,22 @@ int main(int argc, char **argv) {
// 获取新的音频,不论是否检测音频数据先把数据捞出来
std::this_thread::sleep_for(std::chrono::milliseconds(params.chunk_size));
audio.get(params.chunk_size, pcmf32_audio);

// 如果需要输出音频到文件,即时写入
if (audio_fout.is_open() && !pcmf32_audio.empty()) {
// 转换为 16 位 PCM 格式
std::vector<int16_t> pcm16_data;
float_to_pcm16(pcmf32_audio, pcm16_data);

// 写入 PCM 数据
audio_fout.write(reinterpret_cast<const char*>(pcm16_data.data()),
pcm16_data.size() * sizeof(int16_t));
audio_fout.flush(); // 确保即时输出

// 更新样本计数
total_samples += pcm16_data.size();
}

// 转移到pcmf32中,直接识别pcmf32
pcmf32.insert(pcmf32.end(), pcmf32_audio.begin(), pcmf32_audio.end());
pcmf32_audio.clear();
Expand All @@ -212,7 +354,8 @@ int main(int argc, char **argv) {
// 时间长度太长了直接换行重新开始
if (R_new_chunk >= max_nomute_step + idenitified_floats) {
printf("\n");
pcmf32_tmp = std::vector<double>(pcmf32.begin() + R_new_chunk - idenitified_floats, pcmf32.end());
// 优化:使用assign而不是重新构造vector
pcmf32_tmp.assign(pcmf32.begin() + R_new_chunk - idenitified_floats, pcmf32.end());
pcmf32 = pcmf32_tmp;
idenitified_floats = R_new_chunk;
}
Expand All @@ -224,7 +367,43 @@ int main(int argc, char **argv) {
// 新进来的所有chunk有可能导致序列分拆,需要注意
for (int i = L_new_chunk; i < R_new_chunk; i += n_sample_step) {
// int R_this_chunk = i + n_sample_step;
bool isnomute = vad_energy_zcr<double>(pcmf32.begin() + i - idenitified_floats, n_sample_step, SENSE_VOICE_SAMPLE_RATE, 1e-5, 0.2);
// 使用 silero-vad 替换 vad_energy_zcr
bool isnomute = false;

// VAD检测 - 使用zcr_main/main.cc的方法
int actual_chunk_size = n_sample_step;
int vad_chunk_size = std::max(640, actual_chunk_size);
std::vector<float> vad_chunk(vad_chunk_size, 0);

int start_idx = i - idenitified_floats;

// 确保不越界访问
for (int j = 0; j < actual_chunk_size && start_idx + j < pcmf32.size(); j++) {
if (start_idx + j >= 0) {
vad_chunk[j] = static_cast<float>(pcmf32[start_idx + j]) / 32768.0f;
}
}

// 如果实际chunk小于640,用最后一个样本值填充
if (actual_chunk_size < 640) {
float last_sample = (actual_chunk_size > 0) ? vad_chunk[actual_chunk_size - 1] : 0.0f;
for (int j = actual_chunk_size; j < 640; j++) {
vad_chunk[j] = last_sample;
}
}

float speech_prob = 0;
if (silero_vad_encode_internal(*ctx, *ctx->state, vad_chunk, params.n_threads, speech_prob)) {
isnomute = (speech_prob >= params.speech_prob_threshold);
// 调试信息:显示VAD结果
// if (speech_prob > 0.1) { // 只显示有意义的概率
// fprintf(stderr, "VAD: prob=%.3f, threshold=%.3f, isnomute=%d\n",
// speech_prob, params.threshold, isnomute);
// }
} else {
// 如果 VAD 处理失败,回退到vad_energy_zcr函数
isnomute = vad_energy_zcr<double>(pcmf32.begin() + start_idx, n_sample_step, SENSE_VOICE_SAMPLE_RATE);
}
// fprintf(stderr, "Mute || isnomute = %d, ML = %d, MR = %d, NML = %d, NMR = %d, R_new_chunk = %d, i = %d, size = %d, idenitified = %d\n", isnomute, mute.first, mute.second, nomute.first, nomute.second, R_new_chunk, i, pcmf32.size(), idenitified_floats);
if (nomute.first == -1) {
if (isnomute) nomute.first = i;
Expand Down Expand Up @@ -290,19 +469,48 @@ int main(int argc, char **argv) {
}
// 调整idenitified_floats并且减少pcmf32的长度
if (nomute.second > 0) {
pcmf32_tmp = std::vector<double>(pcmf32.begin() + (nomute.second - idenitified_floats), pcmf32.end());
// 优化:使用assign而不是重新构造vector
pcmf32_tmp.assign(pcmf32.begin() + (nomute.second - idenitified_floats), pcmf32.end());
pcmf32 = pcmf32_tmp;
idenitified_floats = nomute.second;
nomute.second = 0;
} else if (nomute.first == -1) {
pcmf32_tmp = std::vector<double>(pcmf32.begin() + (R_new_chunk - idenitified_floats), pcmf32.end());
// 优化:使用assign而不是重新构造vector
pcmf32_tmp.assign(pcmf32.begin() + (R_new_chunk - idenitified_floats), pcmf32.end());
pcmf32 = pcmf32_tmp;
idenitified_floats = R_new_chunk;
}

// 检查缓冲区大小并发出警告
if (pcmf32.size() > 2 * max_nomute_step) {
fprintf(stderr, "Warning: Audio buffer size (%.2f MB, %.2f sec) exceeds recommended limit. Consider optimizing processing speed.\n",
pcmf32.size() * sizeof(double) / 1e6,
pcmf32.size() / (double)SENSE_VOICE_SAMPLE_RATE);
}
}
fflush(stdout);
}
audio.pause();

// 关闭输出文件
if (fout.is_open()) {
fout.close();
}
if (audio_fout.is_open()) {
// 更新 WAV 文件头中的文件大小信息
wav_header.data_size = total_samples * sizeof(int16_t);
wav_header.file_size = sizeof(WAVHeader) - 8 + wav_header.data_size;

// 重新定位到文件开头并写入更新后的文件头
audio_fout.seekp(0, std::ios::beg);
audio_fout.write(reinterpret_cast<const char*>(&wav_header), sizeof(WAVHeader));
audio_fout.close();

fprintf(stderr, "Audio saved to '%s' (%u samples, %.2f seconds)\n",
params.audio_out.c_str(), total_samples,
(float)total_samples / SENSE_VOICE_SAMPLE_RATE);
}

sense_voice_free(ctx);
return 0;
}
22 changes: 22 additions & 0 deletions sense-voice/csrc/sense-voice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,17 @@ int sense_voice_pcm_to_feature_with_state(struct sense_voice_context *ctx,

// set input
{
// 释放之前的资源以防止内存泄漏
if (state->feature.ctx) {
ggml_free(state->feature.ctx);
state->feature.ctx = nullptr;
}
if (state->feature.buffer) {
ggml_backend_buffer_free(state->feature.buffer);
state->feature.buffer = nullptr;
}
state->feature.tensor = nullptr;

// init features
state->feature.n_len = state->feature.data.size() / (state->feature.n_mel * state->feature.lfr_m);
state->feature.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
Expand Down Expand Up @@ -816,6 +827,17 @@ int sense_voice_batch_pcm_to_feature_with_state(struct sense_voice_context *ctx,

// set input
{
// 释放之前的资源以防止内存泄漏
if (state->feature.ctx) {
ggml_free(state->feature.ctx);
state->feature.ctx = nullptr;
}
if (state->feature.buffer) {
ggml_backend_buffer_free(state->feature.buffer);
state->feature.buffer = nullptr;
}
state->feature.tensor = nullptr;

// init features
state->feature.n_len = state->feature.data.size() / (state->feature.n_mel * state->feature.lfr_m);
state->feature.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
Expand Down
Loading