From 227e3f3e0d24e57fba02524f32273132089fc3f7 Mon Sep 17 00:00:00 2001 From: Frode Jacobsen Date: Sat, 17 May 2025 14:18:15 +0200 Subject: [PATCH] Make PooledSocket.ReadAsync respect receive timeout setting --- src/Enyim.Caching/Memcached/PooledSocket.cs | 39 ++++-- test/MemcachedTest/PooledSocketTest.cs | 143 ++++++++++++++++++++ 2 files changed, 168 insertions(+), 14 deletions(-) create mode 100644 test/MemcachedTest/PooledSocketTest.cs diff --git a/src/Enyim.Caching/Memcached/PooledSocket.cs b/src/Enyim.Caching/Memcached/PooledSocket.cs index 8c47f1d9..79d2b79b 100755 --- a/src/Enyim.Caching/Memcached/PooledSocket.cs +++ b/src/Enyim.Caching/Memcached/PooledSocket.cs @@ -24,7 +24,7 @@ public partial class PooledSocket : IDisposable private bool _isSocketDisposed; private readonly EndPoint _endpoint; private readonly int _connectionTimeout; - + private readonly int _receiveTimeout; private NetworkStream _inputStream; private SslStream _sslStream; #if NET5_0_OR_GREATER @@ -71,6 +71,7 @@ public PooledSocket(EndPoint endpoint, TimeSpan connectionTimeout, TimeSpan rece socket.ReceiveTimeout = rcv; socket.SendTimeout = rcv; + _receiveTimeout = rcv; _socket = socket; } @@ -425,21 +426,31 @@ public async Task ReadAsync(byte[] buffer, int offset, int count) { try { - int currentRead = (_useSslStream - ? await _sslStream.ReadAsync(buffer, offset, shouldRead).ConfigureAwait(false) - : await _inputStream.ReadAsync(buffer, offset, shouldRead).ConfigureAwait(false)); - if (currentRead == count) - break; - if (currentRead < 1) - throw new IOException("The socket seems to be disconnected"); - - read += currentRead; - offset += currentRead; - shouldRead -= currentRead; + var readTask = _useSslStream + ? _sslStream.ReadAsync(buffer, offset, shouldRead) + : _inputStream.ReadAsync(buffer, offset, shouldRead); + var timeoutTask = Task.Delay(_receiveTimeout); + + if (await Task.WhenAny(readTask, timeoutTask).ConfigureAwait(false) == readTask) + { + int currentRead = await readTask.ConfigureAwait(false); + if (currentRead == count) + break; + if (currentRead < 1) + throw new IOException("The socket seems to be disconnected"); + + read += currentRead; + offset += currentRead; + shouldRead -= currentRead; + } + else + { + throw new TimeoutException($"Timeout to read from {_endpoint}."); + } } catch (Exception ex) { - if (ex is IOException || ex is SocketException) + if (ex is IOException || ex is SocketException || ex is TimeoutException) { _isAlive = false; } @@ -648,4 +659,4 @@ private IPEndPoint GetIPEndPoint(EndPoint endpoint) * * ************************************************************/ -#endregion \ No newline at end of file +#endregion diff --git a/test/MemcachedTest/PooledSocketTest.cs b/test/MemcachedTest/PooledSocketTest.cs new file mode 100644 index 00000000..db747357 --- /dev/null +++ b/test/MemcachedTest/PooledSocketTest.cs @@ -0,0 +1,143 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Enyim.Caching.Memcached; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace MemcachedTest; + +public class PooledSocketTest +{ + [Fact] + public async Task ReadSync_ShouldTimeoutOrFail_WhenServerResponseIsSlow() + { + // Arrange + var logger = new NullLogger(); + const int port = 12345; + var server = new SlowLorisServer(); + using var cts = new CancellationTokenSource(); + await server.StartAsync(port, cts.Token); + var endpoint = new IPEndPoint(IPAddress.Loopback, port); + var socket = new PooledSocket( + endpoint, + TimeSpan.FromSeconds(5), + TimeSpan.FromMilliseconds(50), + logger, + useSslStream: false, + useIPv6: false, + sslClientAuthOptions: null + ); + await socket.ConnectAsync(); + var buffer = new byte[server.Response.Length]; + + // Act + var timer = Stopwatch.StartNew(); + var ex = Record.Exception(() => + { + socket.Read(buffer, 0, server.Response.Length); + }); + timer.Stop(); + + // Assert + Assert.True(timer.Elapsed < TimeSpan.FromMilliseconds(500), "Read took too long"); + Assert.NotNull(ex); + Assert.True( + ex is TimeoutException or IOException, + $"Expected TimeoutException or IOException, got {ex.GetType().Name}: {ex.Message}" + ); + + await cts.CancelAsync(); + server.Stop(); + } + + [Fact] + public async Task ReadAsync_ShouldTimeoutOrFail_WhenServerResponseIsSlow() + { + // Arrange + var logger = new NullLogger(); + const int port = 12345; + var server = new SlowLorisServer(); + using var cts = new CancellationTokenSource(); + + await server.StartAsync(port, cts.Token); + + var endpoint = new IPEndPoint(IPAddress.Loopback, port); + var socket = new PooledSocket( + endpoint, + TimeSpan.FromSeconds(5), + TimeSpan.FromMilliseconds(50), + logger, + useSslStream: false, + useIPv6: false, + sslClientAuthOptions: null + ); + + await socket.ConnectAsync(); + + var buffer = new byte[server.Response.Length]; + + // Act + var timer = Stopwatch.StartNew(); + var ex = await Record.ExceptionAsync(async () => + { + await socket.ReadAsync(buffer, 0, server.Response.Length); + }); + timer.Stop(); + + // Assert + Assert.True(timer.Elapsed < TimeSpan.FromMilliseconds(500), "ReadAsync took too long"); + Assert.NotNull(ex); + Assert.True( + ex is TimeoutException or IOException, + $"Expected TimeoutException or IOException, got {ex.GetType().Name}: {ex.Message}" + ); + + // Cleanup + await cts.CancelAsync(); + server.Stop(); + } +} + +public class SlowLorisServer +{ + private TcpListener _listener; + private CancellationToken _token; + public readonly byte[] Response = "Hello, I'm slow!"u8.ToArray(); + + public Task StartAsync(int port, CancellationToken token) + { + _token = token; + _listener = new TcpListener(IPAddress.Loopback, port); + _listener.Start(); + + _ = Task.Run(async () => + { + while (!token.IsCancellationRequested) + { + var client = await _listener.AcceptTcpClientAsync(token); + _ = Task.Run(() => HandleClientAsync(client), token); + } + }, token); + return Task.CompletedTask; + } + + private async Task HandleClientAsync(TcpClient client) + { + await using var stream = client.GetStream(); + for (var i = 0; i < Response.Length; i++) + { + await stream.WriteAsync(Response, i, 1, _token); + await Task.Delay(100, _token); + } + await stream.FlushAsync(_token); + client.Close(); + } + + public void Stop() => _listener.Stop(); +} +