|
2 | 2 | #include <chrono> |
3 | 3 | #include <filesystem> |
4 | 4 | #include <fstream> |
| 5 | +#include <future> |
5 | 6 | #include <iomanip> |
6 | 7 | #include <iostream> |
7 | 8 | #include <mutex> |
@@ -365,6 +366,18 @@ int main(int argc, const char** argv) { |
365 | 366 | return httplib::Server::HandlerResponse::Unhandled; |
366 | 367 | }); |
367 | 368 |
|
| 369 | + auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) { |
| 370 | + std::future_status ft_status; |
| 371 | + do { |
| 372 | + if (!ft.valid()) |
| 373 | + break; |
| 374 | + ft_status = ft.wait_for(std::chrono::milliseconds(1000)); |
| 375 | + if (req.is_connection_closed()) { |
| 376 | + sd_cancel_generation(sd_ctx, SD_CANCEL_ALL); |
| 377 | + } |
| 378 | + } while (ft_status != std::future_status::ready); |
| 379 | + }; |
| 380 | + |
368 | 381 | // root |
369 | 382 | svr.Get("/", [&](const httplib::Request&, httplib::Response& res) { |
370 | 383 | if (!svr_params.serve_html_path.empty()) { |
@@ -507,11 +520,13 @@ int main(int argc, const char** argv) { |
507 | 520 | sd_image_t* results = nullptr; |
508 | 521 | int num_results = 0; |
509 | 522 |
|
510 | | - { |
| 523 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
511 | 524 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
512 | 525 | results = generate_image(sd_ctx, &img_gen_params); |
513 | 526 | num_results = gen_params.batch_count; |
514 | | - } |
| 527 | + }); |
| 528 | + |
| 529 | + wait_for_generation(ft, sd_ctx, req); |
515 | 530 |
|
516 | 531 | for (int i = 0; i < num_results; i++) { |
517 | 532 | if (results[i].data == nullptr) { |
@@ -748,11 +763,13 @@ int main(int argc, const char** argv) { |
748 | 763 | sd_image_t* results = nullptr; |
749 | 764 | int num_results = 0; |
750 | 765 |
|
751 | | - { |
| 766 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
752 | 767 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
753 | 768 | results = generate_image(sd_ctx, &img_gen_params); |
754 | 769 | num_results = gen_params.batch_count; |
755 | | - } |
| 770 | + }); |
| 771 | + |
| 772 | + wait_for_generation(ft, sd_ctx, req); |
756 | 773 |
|
757 | 774 | json out; |
758 | 775 | out["created"] = static_cast<long long>(std::time(nullptr)); |
@@ -1062,11 +1079,13 @@ int main(int argc, const char** argv) { |
1062 | 1079 | sd_image_t* results = nullptr; |
1063 | 1080 | int num_results = 0; |
1064 | 1081 |
|
1065 | | - { |
| 1082 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
1066 | 1083 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
1067 | 1084 | results = generate_image(sd_ctx, &img_gen_params); |
1068 | 1085 | num_results = gen_params.batch_count; |
1069 | | - } |
| 1086 | + }); |
| 1087 | + |
| 1088 | + wait_for_generation(ft, sd_ctx, req); |
1070 | 1089 |
|
1071 | 1090 | json out; |
1072 | 1091 | out["images"] = json::array(); |
|
0 commit comments