Skip to content

Commit 2f7bae7

Browse files
wbrunadonington
andcommitted
feat(server): cancel current generation on client disconnect
Co-authored-by: donington <jandastroy@gmail.com>
1 parent 771edfa commit 2f7bae7

1 file changed

Lines changed: 25 additions & 6 deletions

File tree

examples/server/main.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <chrono>
33
#include <filesystem>
44
#include <fstream>
5+
#include <future>
56
#include <iomanip>
67
#include <iostream>
78
#include <mutex>
@@ -365,6 +366,18 @@ int main(int argc, const char** argv) {
365366
return httplib::Server::HandlerResponse::Unhandled;
366367
});
367368

369+
auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) {
370+
std::future_status ft_status;
371+
do {
372+
if (!ft.valid())
373+
break;
374+
ft_status = ft.wait_for(std::chrono::milliseconds(1000));
375+
if (req.is_connection_closed()) {
376+
sd_cancel_generation(sd_ctx, SD_CANCEL_ALL);
377+
}
378+
} while (ft_status != std::future_status::ready);
379+
};
380+
368381
// root
369382
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
370383
if (!svr_params.serve_html_path.empty()) {
@@ -507,11 +520,13 @@ int main(int argc, const char** argv) {
507520
sd_image_t* results = nullptr;
508521
int num_results = 0;
509522

510-
{
523+
std::future<void> ft = std::async(std::launch::async, [&]() {
511524
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
512525
results = generate_image(sd_ctx, &img_gen_params);
513526
num_results = gen_params.batch_count;
514-
}
527+
});
528+
529+
wait_for_generation(ft, sd_ctx, req);
515530

516531
for (int i = 0; i < num_results; i++) {
517532
if (results[i].data == nullptr) {
@@ -748,11 +763,13 @@ int main(int argc, const char** argv) {
748763
sd_image_t* results = nullptr;
749764
int num_results = 0;
750765

751-
{
766+
std::future<void> ft = std::async(std::launch::async, [&]() {
752767
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
753768
results = generate_image(sd_ctx, &img_gen_params);
754769
num_results = gen_params.batch_count;
755-
}
770+
});
771+
772+
wait_for_generation(ft, sd_ctx, req);
756773

757774
json out;
758775
out["created"] = static_cast<long long>(std::time(nullptr));
@@ -1062,11 +1079,13 @@ int main(int argc, const char** argv) {
10621079
sd_image_t* results = nullptr;
10631080
int num_results = 0;
10641081

1065-
{
1082+
std::future<void> ft = std::async(std::launch::async, [&]() {
10661083
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
10671084
results = generate_image(sd_ctx, &img_gen_params);
10681085
num_results = gen_params.batch_count;
1069-
}
1086+
});
1087+
1088+
wait_for_generation(ft, sd_ctx, req);
10701089

10711090
json out;
10721091
out["images"] = json::array();

0 commit comments

Comments
 (0)