diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.cs index c7eb39857de6..81c9e807698f 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.cs @@ -2,20 +2,25 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.IO.Pipelines; using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Http.Headers; using System.Security.Cryptography.X509Certificates; +using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Client; +using Microsoft.AspNetCore.Http.Connections.Client.Internal; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.InternalTesting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Logging.Testing; using Moq; +using Moq.Protected; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests; @@ -157,4 +162,44 @@ await WithConnectionAsync( Assert.Equal("SendingHttpRequest", writeList[0].EventId.Name); Assert.Equal("UnsuccessfulHttpResponse", writeList[1].EventId.Name); } + + [Fact] + public async Task NegotiateAsyncAppendsCorrectAcceptHeader() + { + var testHttpHandler = new TestHttpMessageHandler(false); + var negotiateUrlTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + testHttpHandler.OnNegotiate((request, cancellationToken) => + { + var headerFound = request.Headers.Accept?.Contains(new MediaTypeWithQualityHeaderValue("*/*")) == true; + negotiateUrlTcs.SetResult(headerFound); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); + }); + + var httpOptions = new HttpConnectionOptions + { + Url = new Uri("http://fakeurl.org/"), + SkipNegotiation = false, + Transports = HttpTransportType.WebSockets, + HttpMessageHandlerFactory = inner => testHttpHandler + }; + + try + { + await WithConnectionAsync( + CreateConnection(httpOptions), + async (connection) => + { + await connection.StartAsync().DefaultTimeout(); + }); + } + catch + { + // ignore connection error + } + + Assert.True(negotiateUrlTcs.Task.IsCompleted); + var headerWasFound = await negotiateUrlTcs.Task.DefaultTimeout(); + Assert.True(headerWasFound); + } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/LongPollingTransportTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/LongPollingTransportTests.cs index 79872a236ace..1e3d00cf3a37 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/LongPollingTransportTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/LongPollingTransportTests.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Http.Headers; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -15,6 +16,7 @@ using Microsoft.AspNetCore.Http.Connections.Client.Internal; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.InternalTesting; +using Microsoft.Extensions.Logging.Abstractions; using Moq; using Moq.Protected; using Xunit; @@ -692,4 +694,59 @@ public async Task SendsDeleteRequestWhenTransportCompleted() Assert.Equal(TestUri, deleteRequest.RequestUri); } } + + [Fact] + public async Task PollRequestsContainCorrectAcceptHeader() + { + var testHttpHandler = new TestHttpMessageHandler(); + var responseTaskCompletionSource = new TaskCompletionSource(); + var requestCount = 0; + var allHeadersCorrect = true; + var secondRequestReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + testHttpHandler.OnRequest(async (request, next, cancellationToken) => + { + if (request.Headers.Accept?.Contains(new MediaTypeWithQualityHeaderValue("*/*")) != true) + { + allHeadersCorrect = false; + } + + requestCount++; + + if (requestCount == 2) + { + secondRequestReceived.SetResult(); + } + + if (requestCount >= 2) + { + if (allHeadersCorrect) + { + responseTaskCompletionSource.TrySetResult(new HttpResponseMessage(HttpStatusCode.OK)); + } + else + { + responseTaskCompletionSource.TrySetResult(new HttpResponseMessage(HttpStatusCode.NoContent)); + } + } + + return await Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using (var httpClient = new HttpClient(testHttpHandler)) + { + var loggerFactory = NullLoggerFactory.Instance; + var transport = new LongPollingTransport(httpClient, loggerFactory: loggerFactory); + + var startTask = transport.StartAsync(TestUri, TransferFormat.Text); + + await secondRequestReceived.Task.DefaultTimeout(); + + await transport.StopAsync(); + + Assert.True(responseTaskCompletionSource.Task.IsCompleted); + var response = await responseTaskCompletionSource.Task.DefaultTimeout(); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/SendUtilsTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/SendUtilsTests.cs new file mode 100644 index 000000000000..cb98da2eceb4 --- /dev/null +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/SendUtilsTests.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.IO.Pipelines; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text; +using Microsoft.AspNetCore.Http.Connections.Client.Internal; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.AspNetCore.InternalTesting; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests; +public partial class SendUtilsTests : VerifiableLoggedTest +{ + [Fact] + public async Task SendMessagesSetsCorrectAcceptHeader() + { + var testHttpHandler = new TestHttpMessageHandler(); + var responseTaskCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + testHttpHandler.OnRequest((request, next, cancellationToken) => + { + if (request.Headers.Accept?.Contains(new MediaTypeWithQualityHeaderValue("*/*")) == true) + { + responseTaskCompletionSource.SetResult(ResponseUtils.CreateResponse(HttpStatusCode.OK)); + } + else + { + responseTaskCompletionSource.SetResult(ResponseUtils.CreateResponse(HttpStatusCode.BadRequest)); + } + return responseTaskCompletionSource.Task; + }); + + using (var httpClient = new HttpClient(testHttpHandler)) + { + var pipe = new Pipe(); + var application = new DuplexPipe(pipe.Reader, pipe.Writer); + + // Simulate writing data to send + await application.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); + application.Output.Complete(); + + await SendUtils.SendMessages(new Uri("http://fakeuri.org"), application, httpClient, logger: Logger).DefaultTimeout(); + + var response = await responseTaskCompletionSource.Task.DefaultTimeout(); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + } +} diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs index ed40cf9055a3..bb987398048a 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs @@ -18,6 +18,7 @@ using Moq; using Moq.Protected; using Xunit; +using System.Net; namespace Microsoft.AspNetCore.SignalR.Client.Tests; @@ -409,4 +410,38 @@ public async Task SSETransportThrowsForInvalidTransferFormat(TransferFormat tran Assert.Equal("transferFormat", exception.ParamName); } } + + [Fact] + public async Task StartAsyncSetsCorrectAcceptHeaderForSSE() + { + var testHttpHandler = new TestHttpMessageHandler(); + var responseTaskCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + // Setting up the handler to check for 'text/event-stream' Accept header + testHttpHandler.OnRequest((request, next, cancellationToken) => + { + if (request.Headers.Accept?.Contains(new MediaTypeWithQualityHeaderValue("text/event-stream")) == true) + { + responseTaskCompletionSource.SetResult(new HttpResponseMessage(HttpStatusCode.OK)); + } + else + { + responseTaskCompletionSource.SetResult(new HttpResponseMessage(HttpStatusCode.NoContent)); + } + return responseTaskCompletionSource.Task; + }); + + using (var httpClient = new HttpClient(testHttpHandler)) + { + var sseTransport = new ServerSentEventsTransport(httpClient, loggerFactory: LoggerFactory); + + // Starting the SSE transport and verifying the outcome + await sseTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Text).DefaultTimeout(); + await sseTransport.StopAsync().DefaultTimeout(); + + Assert.True(responseTaskCompletionSource.Task.IsCompleted); + var response = await responseTaskCompletionSource.Task.DefaultTimeout(); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index 06153ad8be64..e193d346c037 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using System.Net.Http.Headers; namespace Microsoft.AspNetCore.Http.Connections.Client; @@ -469,6 +470,7 @@ private async Task NegotiateAsync(Uri url, HttpClient httpC #else request.Properties.Add("IsNegotiate", true); #endif + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("*/*")); // ResponseHeadersRead instructs SendAsync to return once headers are read // rather than buffer the entire response. This gives a small perf boost. diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs index 69e120b03a20..0f45bcfd9ae4 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/LongPollingTransport.cs @@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using System.Net.Http.Headers; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -51,6 +52,8 @@ public async Task StartAsync(Uri url, TransferFormat transferFormat, Cancellatio // Make initial long polling request // Server uses first long polling request to finish initializing connection and it returns without data var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("*/*")); + using (var response = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false)) { response.EnsureSuccessStatusCode(); @@ -149,7 +152,7 @@ private async Task Poll(Uri pollUrl, CancellationToken cancellationToken) while (!cancellationToken.IsCancellationRequested) { var request = new HttpRequestMessage(HttpMethod.Get, pollUrl); - + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("*/*")); HttpResponseMessage response; try @@ -228,6 +231,7 @@ private async Task SendDeleteRequest(Uri url) { Log.SendingDeleteRequest(_logger, url); var request = new HttpRequestMessage(HttpMethod.Delete, url); + var response = await _httpClient.SendAsync(request).ConfigureAwait(false); if (response.StatusCode == HttpStatusCode.NotFound) diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/SendUtils.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/SendUtils.cs index 2ad92f7d279f..ed0b963e654a 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/SendUtils.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/SendUtils.cs @@ -10,6 +10,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using System.Net.Http.Headers; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal; @@ -40,7 +41,7 @@ public static async Task SendMessages(Uri sendUrl, IDuplexPipe application, Http // Send them in a single post var request = new HttpRequestMessage(HttpMethod.Post, sendUrl); - + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("*/*")); request.Content = new ReadOnlySequenceContent(buffer); // ResponseHeadersRead instructs SendAsync to return once headers are read