Skip to content

Commit 506512f

Browse files
Add KeepAliveTimeout support to WebSocketMiddleware (#57293)
1 parent 8ffa2b2 commit 506512f

File tree

10 files changed

+347
-85
lines changed

10 files changed

+347
-85
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
#nullable enable
22
Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.get -> bool
33
Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.set -> void
4+
Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.get -> System.TimeSpan?
5+
Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.set -> void

src/Http/Http.Features/src/WebSocketAcceptContext.cs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,44 @@ namespace Microsoft.AspNetCore.Http;
1111
public class WebSocketAcceptContext
1212
{
1313
private int _serverMaxWindowBits = 15;
14+
private TimeSpan? _keepAliveTimeout;
1415

1516
/// <summary>
1617
/// Gets or sets the subprotocol being negotiated.
1718
/// </summary>
1819
public virtual string? SubProtocol { get; set; }
1920

2021
/// <summary>
21-
/// The interval to send pong frames. This is a heart-beat that keeps the connection alive.
22+
/// The interval to send keep-alive frames. This is a heart-beat that keeps the connection alive.
2223
/// </summary>
24+
/// <remarks>
25+
/// May be either a Ping or a Pong frame, depending on if <see cref="KeepAliveTimeout" /> is set.
26+
/// </remarks>
2327
public virtual TimeSpan? KeepAliveInterval { get; set; }
2428

29+
/// <summary>
30+
/// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted.
31+
/// </summary>
32+
/// <remarks>
33+
/// <c>null</c> means use the value from <c>WebSocketOptions.KeepAliveTimeout</c>.
34+
/// <see cref="Timeout.InfiniteTimeSpan"/> and <see cref="TimeSpan.Zero"/> are valid values and will disable the timeout.
35+
/// </remarks>
36+
/// <exception cref="ArgumentOutOfRangeException">
37+
/// <see cref="TimeSpan"/> is less than <see cref="TimeSpan.Zero"/>.
38+
/// </exception>
39+
public TimeSpan? KeepAliveTimeout
40+
{
41+
get => _keepAliveTimeout;
42+
set
43+
{
44+
if (value is not null && value != Timeout.InfiniteTimeSpan)
45+
{
46+
ArgumentOutOfRangeException.ThrowIfLessThan(value.Value, TimeSpan.Zero);
47+
}
48+
_keepAliveTimeout = value;
49+
}
50+
}
51+
2552
/// <summary>
2653
/// Enables support for the 'permessage-deflate' WebSocket extension.<para />
2754
/// Be aware that enabling compression over encrypted connections makes the application subject to CRIME/BREACH type attacks.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Net.WebSockets;
5+
using Microsoft.AspNetCore.Http;
6+
7+
namespace Microsoft.AspNetCore.WebSockets;
8+
9+
/// <summary>
10+
/// Used in <see cref="WebSocketMiddleware"/> to wrap the <see cref="HttpContext"/>.Request.Body stream
11+
/// so that we can call <see cref="HttpContext.Abort"/> when the stream is disposed and the WebSocket is in the <see cref="WebSocketState.Aborted"/> state.
12+
/// The Stream provided by Kestrel (and maybe other servers) noops in Dispose as it doesn't know whether it's a graceful close or not
13+
/// and can result in truncated responses if in the graceful case.
14+
///
15+
/// This handles explicit <see cref="WebSocket.Abort"/> calls as well as the Keep-Alive timeout setting <see cref="WebSocketState.Aborted"/> and disposing the stream.
16+
/// </summary>
17+
/// <remarks>
18+
/// Workaround for https://github.com/dotnet/runtime/issues/44272
19+
/// </remarks>
20+
internal sealed class AbortStream : Stream
21+
{
22+
private readonly Stream _innerStream;
23+
private readonly HttpContext _httpContext;
24+
25+
public WebSocket? WebSocket { get; set; }
26+
27+
public AbortStream(HttpContext httpContext, Stream innerStream)
28+
{
29+
_innerStream = innerStream;
30+
_httpContext = httpContext;
31+
}
32+
33+
public override bool CanRead => _innerStream.CanRead;
34+
35+
public override bool CanSeek => _innerStream.CanSeek;
36+
37+
public override bool CanWrite => _innerStream.CanWrite;
38+
39+
public override bool CanTimeout => _innerStream.CanTimeout;
40+
41+
public override long Length => _innerStream.Length;
42+
43+
public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }
44+
45+
public override void Flush()
46+
{
47+
_innerStream.Flush();
48+
}
49+
50+
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
51+
{
52+
return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
53+
}
54+
55+
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
56+
{
57+
return _innerStream.ReadAsync(buffer, cancellationToken);
58+
}
59+
60+
public override int Read(byte[] buffer, int offset, int count)
61+
{
62+
return _innerStream.Read(buffer, offset, count);
63+
}
64+
65+
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
66+
{
67+
return _innerStream.BeginRead(buffer, offset, count, callback, state);
68+
}
69+
70+
public override int EndRead(IAsyncResult asyncResult)
71+
{
72+
return _innerStream.EndRead(asyncResult);
73+
}
74+
75+
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
76+
{
77+
return _innerStream.BeginWrite(buffer, offset, count, callback, state);
78+
}
79+
80+
public override void EndWrite(IAsyncResult asyncResult)
81+
{
82+
_innerStream.EndWrite(asyncResult);
83+
}
84+
85+
public override long Seek(long offset, SeekOrigin origin)
86+
{
87+
return _innerStream.Seek(offset, origin);
88+
}
89+
90+
public override void SetLength(long value)
91+
{
92+
_innerStream.SetLength(value);
93+
}
94+
95+
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
96+
{
97+
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
98+
}
99+
100+
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
101+
{
102+
return _innerStream.WriteAsync(buffer, cancellationToken);
103+
}
104+
105+
public override void Write(byte[] buffer, int offset, int count)
106+
{
107+
_innerStream.Write(buffer, offset, count);
108+
}
109+
110+
public override Task FlushAsync(CancellationToken cancellationToken)
111+
{
112+
return _innerStream.FlushAsync(cancellationToken);
113+
}
114+
115+
public override ValueTask DisposeAsync()
116+
{
117+
return _innerStream.DisposeAsync();
118+
}
119+
120+
protected override void Dispose(bool disposing)
121+
{
122+
// Currently, if ManagedWebSocket sets the Aborted state it calls Stream.Dispose after
123+
if (WebSocket?.State == WebSocketState.Aborted)
124+
{
125+
_httpContext.Abort();
126+
}
127+
_innerStream.Dispose();
128+
}
129+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
#nullable enable
2+
Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.get -> System.TimeSpan
3+
Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.set -> void

src/Middleware/WebSockets/src/ServerWebSocket.cs

Lines changed: 0 additions & 80 deletions
This file was deleted.

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
141141
bool serverContextTakeover = true;
142142
int serverMaxWindowBits = 15;
143143
TimeSpan keepAliveInterval = _options.KeepAliveInterval;
144+
TimeSpan keepAliveTimeout = _options.KeepAliveTimeout;
144145
if (acceptContext != null)
145146
{
146147
subProtocol = acceptContext.SubProtocol;
147148
enableCompression = acceptContext.DangerousEnableCompression;
148149
serverContextTakeover = !acceptContext.DisableServerContextTakeover;
149150
serverMaxWindowBits = acceptContext.ServerMaxWindowBits;
150151
keepAliveInterval = acceptContext.KeepAliveInterval ?? keepAliveInterval;
152+
keepAliveTimeout = acceptContext.KeepAliveTimeout ?? keepAliveTimeout;
151153
}
152154

153155
#pragma warning disable CS0618 // Type or member is obsolete
@@ -208,15 +210,18 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
208210
// Disable request timeout, if there is one, after the websocket has been accepted
209211
_context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();
210212

211-
var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
213+
var abortStream = new AbortStream(_context, opaqueTransport);
214+
var wrappedSocket = WebSocket.CreateFromStream(abortStream, new WebSocketCreationOptions()
212215
{
213216
IsServer = true,
214217
KeepAliveInterval = keepAliveInterval,
218+
KeepAliveTimeout = keepAliveTimeout,
215219
SubProtocol = subProtocol,
216220
DangerousDeflateOptions = deflateOptions
217221
});
218222

219-
return new ServerWebSocket(wrappedSocket, _context);
223+
abortStream.WebSocket = wrappedSocket;
224+
return wrappedSocket;
220225
}
221226

222227
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)

src/Middleware/WebSockets/src/WebSocketOptions.cs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace Microsoft.AspNetCore.Builder;
88
/// </summary>
99
public class WebSocketOptions
1010
{
11+
private TimeSpan _keepAliveTimeout = Timeout.InfiniteTimeSpan;
12+
1113
/// <summary>
1214
/// Constructs the <see cref="WebSocketOptions"/> class with default values.
1315
/// </summary>
@@ -18,11 +20,40 @@ public WebSocketOptions()
1820
}
1921

2022
/// <summary>
21-
/// Gets or sets the frequency at which to send Ping/Pong keep-alive control frames.
23+
/// The interval to send keep-alive frames. This is a heart-beat that keeps the connection alive.
2224
/// The default is two minutes.
2325
/// </summary>
26+
/// <remarks>
27+
/// May be either a Ping or a Pong frame, depending on if <see cref="KeepAliveTimeout" /> is set.
28+
/// </remarks>
2429
public TimeSpan KeepAliveInterval { get; set; }
2530

31+
/// <summary>
32+
/// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted.
33+
/// </summary>
34+
/// <remarks>
35+
/// Default value is <see cref="Timeout.InfiniteTimeSpan"/>.
36+
/// <see cref="Timeout.InfiniteTimeSpan"/> and <see cref="TimeSpan.Zero"/> will disable the timeout.
37+
/// </remarks>
38+
/// <exception cref="ArgumentOutOfRangeException">
39+
/// <see cref="TimeSpan"/> is less than <see cref="TimeSpan.Zero"/>.
40+
/// </exception>
41+
public TimeSpan KeepAliveTimeout
42+
{
43+
get
44+
{
45+
return _keepAliveTimeout;
46+
}
47+
set
48+
{
49+
if (value != Timeout.InfiniteTimeSpan)
50+
{
51+
ArgumentOutOfRangeException.ThrowIfLessThan(value, TimeSpan.Zero);
52+
}
53+
_keepAliveTimeout = value;
54+
}
55+
}
56+
2657
/// <summary>
2758
/// Gets or sets the size of the protocol buffer used to receive and parse frames.
2859
/// The default is 4kb.

src/Middleware/WebSockets/test/UnitTests/AddWebSocketsTests.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using Microsoft.AspNetCore.Builder;
@@ -17,14 +17,22 @@ public void AddWebSocketsConfiguresOptions()
1717
serviceCollection.AddWebSockets(o =>
1818
{
1919
o.KeepAliveInterval = TimeSpan.FromSeconds(1000);
20+
o.KeepAliveTimeout = TimeSpan.FromSeconds(1234);
2021
o.AllowedOrigins.Add("someString");
2122
});
2223

2324
var services = serviceCollection.BuildServiceProvider();
2425
var socketOptions = services.GetRequiredService<IOptions<WebSocketOptions>>().Value;
2526

2627
Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval);
28+
Assert.Equal(TimeSpan.FromSeconds(1234), socketOptions.KeepAliveTimeout);
2729
Assert.Single(socketOptions.AllowedOrigins);
2830
Assert.Equal("someString", socketOptions.AllowedOrigins[0]);
2931
}
32+
33+
[Fact]
34+
public void ThrowsForBadOptions()
35+
{
36+
Assert.Throws<ArgumentOutOfRangeException>(() => new WebSocketOptions() { KeepAliveTimeout = TimeSpan.FromMicroseconds(-1) });
37+
}
3038
}

0 commit comments

Comments
 (0)