
Commit 2d064fd

Author: LittleMouse (committed)

[update] update cosy_voice
1 parent bb48236 commit 2d064fd

File tree

  • projects/llm_framework/main_cosy_voice/src

1 file changed: +46 -49 lines changed


projects/llm_framework/main_cosy_voice/src/main.cpp

Lines changed: 46 additions & 49 deletions
@@ -102,7 +102,6 @@ class llm_task {
         model_ = config_body.at("model");
         response_format_ = config_body.at("response_format");
         enoutput_ = config_body.at("enoutput");
-        prompt_ = config_body.at("prompt");
 
         if (config_body.contains("input")) {
             if (config_body["input"].is_string()) {
@@ -327,7 +326,6 @@ class llm_task {
                              g_llm_finished);
         };
 
-        // Start the LLM in a separate thread
         std::thread llm_thread(llm_thread_func);
 
         int token_offset = 0;
@@ -336,8 +334,7 @@ class llm_task {
             SLOGE("Error, prompt speech token len %d < 75", prompt_token_len);
             return -1;
         }
-        // int prompt_token_align_len = int(prompt_token_len / lToken2Wav.token_hop_len) * lToken2Wav.token_hop_len;
-        int prompt_token_align_len = 75;  // only support 75 now
+        int prompt_token_align_len = 75;
 
         std::vector<float> prompt_speech_embeds_flow1;
         prompt_speech_embeds_flow1.insert(prompt_speech_embeds_flow1.begin(), prompt_speech_embeds_flow.begin(),
@@ -381,28 +378,31 @@ class llm_task {
                 token_offset += this_token_hop_len;
                 output.insert(output.end(), speech.begin(), speech.end());
                 std::string path = "output_" + std::to_string(i) + ".wav";
+                std::vector<int16_t> wav_pcm_data;
+                wav_pcm_data.resize(speech.size());
+                for (size_t k = 0; k < speech.size(); ++k) {
+                    float sample = speech[k];
+                    if (sample > 1.0f) sample = 1.0f;
+                    if (sample < -1.0f) sample = -1.0f;
+                    wav_pcm_data[k] = static_cast<int16_t>(sample * 32767.0f);
+                }
+                if (out_callback_) {
+                    out_callback_(std::string(reinterpret_cast<char *>(wav_pcm_data.data()),
+                                              wav_pcm_data.size() * sizeof(int16_t)),
+                                  false);
+                }
 
                 saveVectorAsWavFloat(speech, path, 24000, 1);
                 i += 1;
-
-            }
-
-            else if (g_llm_finished.load()) {
+            } else if (g_llm_finished.load()) {
                 std::cout << "[Main/Token2Wav Thread] Buffer is empty and LLM finished. Exiting.\n";
                 lock.unlock();
                 break;
-            }
-            // Check exit condition: Buffer is empty and LLM is done
-            else {
-                // This else branch is technically not needed because the wait condition
-                // ensures we only get here if one of the conditions is true.
-                // But it's good practice to structure logic clearly.
-                // In this specific loop, we will always process if we wake up.
-                lock.unlock();  // Make sure to unlock if not processing
+            } else {
+                lock.unlock();
             }
         }
 
-        // Wait for the LLM thread to finish
         if (llm_thread.joinable()) {
            llm_thread.join();
         }
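Note: the block added above converts the model's normalized float samples to signed 16-bit PCM and streams each intermediate chunk through out_callback_ with finish = false. A minimal sketch of the same clamp-and-scale conversion as a standalone helper (the name float_to_pcm16 is illustrative, not part of this commit):

#include <algorithm>
#include <cstdint>
#include <vector>

// Convert normalized float samples (expected in [-1.0, 1.0]) to signed 16-bit PCM.
std::vector<int16_t> float_to_pcm16(const std::vector<float> &speech)
{
    std::vector<int16_t> pcm(speech.size());
    for (size_t k = 0; k < speech.size(); ++k) {
        float sample = std::clamp(speech[k], -1.0f, 1.0f);  // clip so the cast cannot overflow int16
        pcm[k] = static_cast<int16_t>(sample * 32767.0f);
    }
    return pcm;
}
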
@@ -419,11 +419,23 @@ class llm_task {
         token.insert(token.end(), g_token_buffer.begin() + start, g_token_buffer.end());
         auto speech = lToken2Wav.infer(token, prompt_speech_embeds_flow1, prompt_feat1, spk_embeds,
                                        token_offset - start, true);
-        // TODO: process the generated audio in a separate thread
         output.insert(output.end(), speech.begin(), speech.end());
         std::string path = "output_" + std::to_string(i) + ".wav";
         saveVectorAsWavFloat(speech, path, 24000, 1);
         saveVectorAsWavFloat(output, "output.wav", 24000, 1);
+        std::vector<int16_t> wav_pcm_data;
+        wav_pcm_data.resize(speech.size());
+        for (size_t k = 0; k < speech.size(); ++k) {
+            float sample = speech[k];
+            if (sample > 1.0f) sample = 1.0f;
+            if (sample < -1.0f) sample = -1.0f;
+            wav_pcm_data[k] = static_cast<int16_t>(sample * 32767.0f);
+        }
+        if (out_callback_) {
+            out_callback_(
+                std::string(reinterpret_cast<char *>(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)),
+                true);
+        }
 
         SLOGI("tts total use time: %.3f s", time_total.cost() / 1000);
         reset();
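Note: the final chunk above is converted the same way but delivered with finish = true. Judging from these call sites, out_callback_ receives the raw int16 bytes packed into a std::string plus a finish flag. A sketch of a consumer that appends the streamed chunks to a raw PCM file (16-bit mono at 24 kHz per the saveVectorAsWavFloat calls above; the set_out_callback registration is hypothetical):

#include <fstream>
#include <functional>
#include <string>

int main()
{
    std::ofstream pcm_out("tts_stream.pcm", std::ios::binary);
    std::function<void(const std::string &, bool)> on_chunk = [&](const std::string &pcm_bytes, bool finished) {
        pcm_out.write(pcm_bytes.data(), static_cast<std::streamsize>(pcm_bytes.size()));
        if (finished) pcm_out.flush();  // last chunk of the utterance
    };
    // task.set_out_callback(on_chunk);  // hypothetical: wire the lambda up as out_callback_
    on_chunk(std::string(4, '\0'), true);  // stand-in call so the sketch runs on its own
    return 0;
}
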
@@ -560,23 +572,26 @@ class llm_cosy_voice : public StackFlow {
         if (!(llm_task_obj && llm_channel)) {
             return;
         }
-        SLOGI("send:%s", data.c_str());
+        std::string base64_data;
+        if (!data.empty()) {
+            int len = encode_base64(data, base64_data);
+        }
         if (llm_channel->enstream_) {
             static int count = 0;
             nlohmann::json data_body;
             data_body["index"] = count++;
-            data_body["delta"] = data;
-            if (!finish)
-                data_body["delta"] = data;
+            if (!data.empty())
+                data_body["delta"] = base64_data;
             else
-                data_body["delta"] = std::string("");
+                data_body["delta"] = "";
             data_body["finish"] = finish;
             if (finish) count = 0;
-            SLOGI("send stream");
             llm_channel->send(llm_task_obj->response_format_, data_body, LLM_NO_ERROR);
         } else if (finish) {
-            SLOGI("send utf-8");
-            llm_channel->send(llm_task_obj->response_format_, data, LLM_NO_ERROR);
+            llm_channel->send(llm_task_obj->response_format_, base64_data, LLM_NO_ERROR);
+        }
+        if (llm_task_obj->response_format_.find("sys") != std::string::npos) {
+            unit_call("audio", "queue_play", data);
         }
     }
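Note: in stream mode the PCM bytes are now base64-encoded before being placed in the "delta" field, and a "sys" response format additionally hands the raw bytes to unit_call("audio", "queue_play", ...). A sketch of the shape of one streamed message (field names taken from this diff; the payload string is only a placeholder for base64-encoded PCM):

#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    nlohmann::json chunk;
    chunk["index"]  = 0;           // increments per chunk, reset to 0 once finish is true
    chunk["delta"]  = "UklGRg==";  // placeholder: base64 of the raw int16 PCM bytes
    chunk["finish"] = false;       // true only on the last message of the utterance
    std::cout << chunk.dump() << "\n";
    return 0;
}
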

@@ -652,24 +667,6 @@ class llm_cosy_voice : public StackFlow {
             llm_task_obj->inference_async(sample_unescapeString(*next_data));
         }
 
-    void task_asr_data(const std::weak_ptr<llm_task> llm_task_obj_weak,
-                       const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
-                       const std::string &data)
-    {
-        auto llm_task_obj = llm_task_obj_weak.lock();
-        auto llm_channel = llm_channel_weak.lock();
-        if (!(llm_task_obj && llm_channel)) {
-            return;
-        }
-        if (object.find("stream") != std::string::npos) {
-            if (sample_json_str_get(data, "finish") == "true") {
-                llm_task_obj->inference_async(sample_json_str_get(data, "delta"));
-            }
-        } else {
-            llm_task_obj->inference_async(data);
-        }
-    }
-
     void kws_awake(const std::weak_ptr<llm_task> llm_task_obj_weak,
                    const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
                    const std::string &data)
@@ -716,14 +713,14 @@ class llm_cosy_voice : public StackFlow {
                             std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1, std::placeholders::_2));
 
         for (const auto input : llm_task_obj->inputs_) {
-            if (input.find("llm") != std::string::npos) {
+            if (input.find("tts") != std::string::npos) {
                 llm_channel->subscriber_work_id(
                     "", std::bind(&llm_cosy_voice::task_user_data, this, std::weak_ptr<llm_task>(llm_task_obj),
                                   std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
                                   std::placeholders::_2));
-            } else if ((input.find("asr") != std::string::npos) || (input.find("whisper") != std::string::npos)) {
+            } else if ((input.find("llm") != std::string::npos) || (input.find("vlm") != std::string::npos)) {
                 llm_channel->subscriber_work_id(
-                    input, std::bind(&llm_cosy_voice::task_asr_data, this, std::weak_ptr<llm_task>(llm_task_obj),
+                    input, std::bind(&llm_cosy_voice::task_user_data, this, std::weak_ptr<llm_task>(llm_task_obj),
                                      std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
                                      std::placeholders::_2));
             } else if (input.find("kws") != std::string::npos) {
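Note: subscription routing now treats "tts" inputs as direct user input and sends "llm"/"vlm" work ids through task_user_data (task_asr_data was removed above). A sketch of that substring-based routing reduced to a standalone function (the enum is illustrative only; the source names come from this diff):

#include <string>

enum class input_route { user, upstream_model, wakeup, none };

// Mirror of the substring checks used above when subscribing to inputs_.
input_route classify_input(const std::string &input)
{
    if (input.find("tts") != std::string::npos) return input_route::user;
    if (input.find("llm") != std::string::npos ||
        input.find("vlm") != std::string::npos) return input_route::upstream_model;  // text from an upstream LLM/VLM unit
    if (input.find("kws") != std::string::npos) return input_route::wakeup;          // keyword-spotting wake event
    return input_route::none;
}
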
@@ -760,10 +757,10 @@ class llm_cosy_voice : public StackFlow {
         }
         auto llm_channel = get_channel(work_id);
         auto llm_task_obj = llm_task_[work_id_num];
-        if (data.find("asr") != std::string::npos) {
+        if (data.find("llm") != std::string::npos) {
             ret = llm_channel->subscriber_work_id(
                 data,
-                std::bind(&llm_cosy_voice::task_asr_data, this, std::weak_ptr<llm_task>(llm_task_obj),
+                std::bind(&llm_cosy_voice::task_user_data, this, std::weak_ptr<llm_task>(llm_task_obj),
                           std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1, std::placeholders::_2));
             llm_task_obj->inputs_.push_back(data);
         } else if (data.find("kws") != std::string::npos) {
