Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 89 additions & 5 deletions test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,102 @@ bool write_wav(std::string path, float* samples, size_t size) {
std::cout << "WAV file '" << path << "' has been written" << std::endl;
return true;
}
const char* phrase = "Cada amanecer trae consigo nuevas oportunidades para crecer y aprender";

struct VITSParams {
int n_threads = -1;
std::string model_path = "./scripts/vits-spanish.ggml";
std::string phrase = "Cada amanecer trae consigo nuevas oportunidades para crecer y aprender";
std::string output_path = "./output.wav";
int64_t seed = -1;
};

void print_usage(int argc, char** argv) {
printf("usage: %s [arguments]\n", argv[0]);
printf("\n");
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model MODEL path to model\n");
printf(" -p, --phrase PHRASE phrase to say\n");
printf(" -o, --output OUTPUT_PATH path to write result image to (default: ./output.wav)\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
}

void parse_args(int argc, char** argv, VITSParams& params) {
bool invalid_arg = false;
std::string arg;
for (int i = 1; i < argc; i++) {
arg = argv[i];
if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.n_threads = std::stoi(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.model_path = argv[i];
} else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.output_path = argv[i];
} else if (arg == "-p" || arg == "--phrase") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.phrase = argv[i];
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.seed = std::stoll(argv[i]);
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
}
}
if (invalid_arg) {
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
}
if (params.n_threads <= 0) {
unsigned int n_threads = std::thread::hardware_concurrency();
params.n_threads = (n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4);
}
if (params.seed < 0) {
srand((int)time(NULL));
params.seed = rand();
}
}

int main(int argc, char ** argv) {
vits_model * model = vits_model_load_from_file("./scripts/vits-spanish.ggml");
int main(int argc, char** argv) {
VITSParams params;
parse_args(argc, argv, params);

vits_model * model = vits_model_load_from_file(params.model_path.c_str());
assert(model != nullptr);

auto result = vits_model_process(model, phrase);
//rng.seed(params.seed);

auto result = vits_model_process(model, params.phrase.c_str());
//auto result = vits_model_process(model, params.phrase.c_str(), params.n_threads);
if (result.size > 0) {
printf("Generated: %d samples of audio %f %f %f\n", result.size, result.data[0], result.data[1],
result.data[2]);
printf("Wrote to file: %s\n", write_wav("output.wav", result.data, result.size) ? "true" : "false");
printf("Wrote to file: %s\n", write_wav(params.output_path, result.data, result.size) ? "true" : "false");
}
vits_free_result(result);
vits_free_model(model);
Expand Down