diff --git a/httplib.h b/httplib.h index 121ce34955..01459fed5c 100644 --- a/httplib.h +++ b/httplib.h @@ -665,6 +665,7 @@ struct Response { Headers headers; std::string body; std::string location; // Redirect location + std::function is_alive; bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", @@ -722,6 +723,7 @@ class Stream { virtual bool is_readable() const = 0; virtual bool is_writable() const = 0; + virtual bool is_alive() const = 0; virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -2333,6 +2335,7 @@ class BufferStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + bool is_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3205,6 +3208,7 @@ class SocketStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + bool is_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3235,6 +3239,7 @@ class SSLSocketStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + bool is_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -4601,6 +4606,7 @@ inline bool redirect(T &cli, Request &req, Response &res, } Response new_res; + new_res.is_alive = res.is_alive; auto ret = cli.send(new_req, new_res, error); if (ret) { @@ -5821,6 +5827,10 @@ inline bool SocketStream::is_writable() const { is_socket_alive(sock_); } +inline bool SocketStream::is_alive() const { + return is_socket_alive(sock_); +} + inline ssize_t SocketStream::read(char *ptr, size_t size) { #ifdef _WIN32 size = @@ -5895,6 +5905,8 @@ inline bool BufferStream::is_readable() const { return true; } inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::is_alive() const { return true; } + inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 auto len_read = buffer._Copy_s(ptr, size, size, position); @@ -6991,6 +7003,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, Request req; Response res; + res.is_alive = [&strm]() { return strm.is_alive(); }; res.version = "HTTP/1.1"; res.headers = default_headers_; @@ -8904,6 +8917,10 @@ inline bool SSLSocketStream::is_writable() const { is_socket_alive(sock_); } +inline bool SSLSocketStream::is_alive() const { + return is_socket_alive(sock_); +} + inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); diff --git a/test/fuzzing/server_fuzzer.cc b/test/fuzzing/server_fuzzer.cc index 3cffbae244..f92415fc68 100644 --- a/test/fuzzing/server_fuzzer.cc +++ b/test/fuzzing/server_fuzzer.cc @@ -27,6 +27,8 @@ class FuzzedStream : public httplib::Stream { bool is_writable() const override { return true; } + bool is_alive() const override { return true; } + void get_remote_ip_and_port(std::string &ip, int &port) const override { ip = "127.0.0.1"; port = 8080; diff --git a/test/test.cc b/test/test.cc index 76c6f60d42..7e6fb6886c 100644 --- a/test/test.cc +++ b/test/test.cc @@ -5463,7 +5463,7 @@ TEST(LongPollingTest, ClientCloseDetection) { while (count > 0 && sink.is_writable()) { this_thread::sleep_for(chrono::milliseconds(10)); } - EXPECT_FALSE(sink.is_writable()); // the socket is closed + EXPECT_FALSE(sink.is_writable()); return true; }); }); @@ -5487,6 +5487,44 @@ TEST(LongPollingTest, ClientCloseDetection) { ASSERT_FALSE(res); } +TEST(LongPollingTest, ClientCloseDetectionOnResponse) { + Server svr; + + bool cancelled = false; + std::thread processing_thread; + svr.Get("/events", [&](const Request & /*req*/, Response &res) { + processing_thread = std::thread([&]() { + EXPECT_TRUE(res.is_alive()); + auto count = 10; + while (count > 0 && res.is_alive()) { + this_thread::sleep_for(chrono::milliseconds(10)); + } + EXPECT_FALSE(res.is_alive()); + cancelled = true; + }); + }); + + auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + listen_thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + Client cli("localhost", PORT); + + cli.Get("/events", [&](const char *data, size_t data_length) { + EXPECT_EQ("hello", string(data, data_length)); + return false; // close the socket immediately. + }); + + processing_thread.join(); + + ASSERT_TRUE(cancelled); +} + TEST(GetWithParametersTest, GetWithParameters) { Server svr;