Skip to content

Commit 4bf8c80

Browse files
authored
Merge pull request #73 from mmwillet/fix-unreachable-cleanup-server
reorders worker init so that cleanup is actually reached
2 parents b1cdd52 + af4910e commit 4bf8c80

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

examples/server/server.cpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ void terminate(worker_pool * pool) {
297297
(*pool)[0]->task_queue->terminate();
298298
(*pool)[0]->response_map->terminate();
299299
}
300+
}
301+
302+
void complete(worker_pool * pool) {
300303
for (auto w : *pool) {
301304
if (w->thread) {
302305
w->thread->join();
@@ -762,30 +765,18 @@ int main(int argc, const char ** argv) {
762765
svr->wait_until_ready();
763766
fprintf(stdout, "%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port"), *args.get_int_param("--n-http-threads"));
764767

765-
// load the model
766-
fprintf(stdout, "%s: loading model and initializing main loop\n", __func__);
767768

768-
// It might make sense in the long run to have the primary thread run clean up on the response map and keep the model workers parallel.
769-
// pool = initialize_workers(args, tqueue, rmap);
770769
pool = new worker_pool;
771-
for (int i = *args.get_int_param("--n-parallelism"); i > 0; i--) {
772-
if (i == 1) {
773-
fprintf(stdout, "%s: server is listening on http://%s:%d\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port"));
774-
worker * w = new worker(tqueue, rmap, args.get_string_param("--text-encoder-path"), *args.get_int_param("--timeout"));
775-
state.store(READY);
776-
pool->push_back(w);
777-
init_worker(&model_map, *args.get_int_param("--n-threads"), !args.get_bool_param("--use-metal"), default_generation_config, w);
778-
} else {
779-
worker * w = new worker(tqueue, rmap, args.get_string_param("--text-encoder-path"), *args.get_int_param("--timeout"));
780-
w->thread = new std::thread(init_worker, &model_map, *args.get_int_param("--n-threads"), !args.get_bool_param("--use-metal"), default_generation_config, w);
781-
pool->push_back(w);
782-
}
783-
}
770+
shutdown_handler = [&](int) {
771+
// this should unblock the primary thread;
772+
terminate(pool);
773+
return;
774+
};
784775

785776
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
786777
struct sigaction sigint_action;
787778
sigint_action.sa_handler = signal_handler;
788-
sigemptyset (&sigint_action.sa_mask);
779+
sigemptyset(&sigint_action.sa_mask);
789780
sigint_action.sa_flags = 0;
790781
sigaction(SIGINT, &sigint_action, NULL);
791782
sigaction(SIGTERM, &sigint_action, NULL);
@@ -796,9 +787,25 @@ int main(int argc, const char ** argv) {
796787
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
797788
#endif
798789

799-
clean_up();
790+
fprintf(stdout, "%s: loading model and initializing main loop\n", __func__);
791+
// It might make sense in the long run to have the primary thread run clean up on the response map and keep the model workers parallel.
792+
for (int i = *args.get_int_param("--n-parallelism"); i > 0; i--) {
793+
if (i == 1) {
794+
fprintf(stdout, "%s: server is listening on http://%s:%d\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port"));
795+
worker * w = new worker(tqueue, rmap, args.get_string_param("--text-encoder-path"), *args.get_int_param("--timeout"));
796+
state.store(READY);
797+
pool->push_back(w);
798+
init_worker(&model_map, *args.get_int_param("--n-threads"), !args.get_bool_param("--use-metal"), default_generation_config, w);
799+
} else {
800+
worker * w = new worker(tqueue, rmap, args.get_string_param("--text-encoder-path"), *args.get_int_param("--timeout"));
801+
w->thread = new std::thread(init_worker, &model_map, *args.get_int_param("--n-threads"), !args.get_bool_param("--use-metal"), default_generation_config, w);
802+
pool->push_back(w);
803+
}
804+
}
805+
fprintf(stdout, "%s: HTTP server listening on hostname: %s and port: %d, is shutting down.\n", __func__, args.get_string_param("--host").c_str(), *args.get_int_param("--port"));
806+
svr->stop();
800807
t.join();
801-
terminate(pool);
808+
complete(pool);
802809
rmap->cleanup_thread->join();
803810

804811
return 0;

0 commit comments

Comments
 (0)