diff --git a/projects/Applications/CreateChannel/Program.cs b/projects/Applications/CreateChannel/Program.cs index 4ed61c6f9..f008cc24f 100644 --- a/projects/Applications/CreateChannel/Program.cs +++ b/projects/Applications/CreateChannel/Program.cs @@ -31,10 +31,12 @@ using System; using System.Diagnostics; -using System.Threading; +using System.Globalization; +using System.Runtime.ExceptionServices; using System.Threading.Tasks; using RabbitMQ.Client; +using RabbitMQ.Client.Exceptions; namespace CreateChannel { @@ -44,11 +46,11 @@ public static class Program private const int ChannelsToOpen = 50; private static int channelsOpened; - private static AutoResetEvent doneEvent; + private readonly static TaskCompletionSource s_tcs = new(); public static async Task Main() { - doneEvent = new AutoResetEvent(false); + AppDomain.CurrentDomain.FirstChanceException += CurrentDomain_FirstChanceException; var connectionFactory = new ConnectionFactory { }; await using IConnection connection = await connectionFactory.CreateConnectionAsync(); @@ -67,26 +69,48 @@ public static async Task Main() for (int j = 0; j < channels.Length; j++) { + if (j % 2 == 0) + { + try + { + await channels[j].QueueDeclarePassiveAsync(Guid.NewGuid().ToString()); + } + catch (Exception) + { + } + } await channels[j].DisposeAsync(); } } - doneEvent.Set(); + s_tcs.SetResult(true); }); Console.WriteLine($"{Repeats} times opening {ChannelsToOpen} channels on a connection. => Total channel open/close: {Repeats * ChannelsToOpen}"); Console.WriteLine(); Console.WriteLine("Opened"); - while (!doneEvent.WaitOne(500)) + while (false == s_tcs.Task.IsCompleted) { + await Task.Delay(500); Console.WriteLine($"{channelsOpened,5}"); } watch.Stop(); Console.WriteLine($"{channelsOpened,5}"); Console.WriteLine(); Console.WriteLine($"Took {watch.Elapsed.TotalMilliseconds} ms"); + } + + private static string Now => DateTime.UtcNow.ToString("s", CultureInfo.InvariantCulture); - Console.ReadLine(); + private static void CurrentDomain_FirstChanceException(object sender, FirstChanceExceptionEventArgs e) + { + if (e.Exception is OperationInterruptedException) + { + } + else + { + Console.Error.WriteLine("{0} [ERROR] {1}", Now, e.Exception); + } } } } diff --git a/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs b/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs index 8781453b9..35dc1cedd 100644 --- a/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs +++ b/projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs @@ -96,6 +96,8 @@ public ConfiguredTaskAwaitable.ConfiguredTaskAwaiter GetAwaiter() return _tcsConfiguredTaskAwaitable.GetAwaiter(); } + public abstract ProtocolCommandId[] HandledProtocolCommandIds { get; } + public async Task HandleCommandAsync(IncomingCommand cmd) { try @@ -203,6 +205,9 @@ public ConnectionSecureOrTuneAsyncRpcContinuation(TimeSpan continuationTimeout, { } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [ProtocolCommandId.ConnectionSecure, ProtocolCommandId.ConnectionTune]; + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.ConnectionSecure) @@ -240,6 +245,9 @@ public SimpleAsyncRpcContinuation(ProtocolCommandId expectedCommandId, TimeSpan _expectedCommandId = expectedCommandId; } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [_expectedCommandId]; + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == _expectedCommandId) @@ -297,6 +305,9 @@ public BasicConsumeAsyncRpcContinuation(IAsyncBasicConsumer consumer, IConsumerD _consumerDispatcher = consumerDispatcher; } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [ProtocolCommandId.BasicConsumeOk]; + protected override async Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.BasicConsumeOk) @@ -326,6 +337,9 @@ public BasicGetAsyncRpcContinuation(Func adjustDeliveryTag, _adjustDeliveryTag = adjustDeliveryTag; } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [ProtocolCommandId.BasicGetOk, ProtocolCommandId.BasicGetEmpty]; + internal DateTime StartTime { get; } = DateTime.UtcNow; protected override Task DoHandleCommandAsync(IncomingCommand cmd) @@ -441,6 +455,9 @@ public QueueDeclareAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellati { } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [ProtocolCommandId.QueueDeclareOk]; + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueueDeclareOk) @@ -481,6 +498,9 @@ public QueueDeleteAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellatio { } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [ProtocolCommandId.QueueDeleteOk]; + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueueDeleteOk) @@ -504,6 +524,9 @@ public QueuePurgeAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellation { } + public override ProtocolCommandId[] HandledProtocolCommandIds + => [ProtocolCommandId.QueuePurgeOk]; + protected override Task DoHandleCommandAsync(IncomingCommand cmd) { if (cmd.CommandId == ProtocolCommandId.QueuePurgeOk) diff --git a/projects/RabbitMQ.Client/Impl/Channel.cs b/projects/RabbitMQ.Client/Impl/Channel.cs index 8d373cea2..d8e3ef09a 100644 --- a/projects/RabbitMQ.Client/Impl/Channel.cs +++ b/projects/RabbitMQ.Client/Impl/Channel.cs @@ -208,6 +208,13 @@ public Task CloseAsync(ushort replyCode, string replyText, bool abort, public async Task CloseAsync(ShutdownEventArgs args, bool abort, CancellationToken cancellationToken) { + CancellationToken argCancellationToken = cancellationToken; + if (IsOpen) + { + // Note: we really do need to try and close this channel! + cancellationToken = CancellationToken.None; + } + bool enqueued = false; var k = new ChannelCloseAsyncRpcContinuation(ContinuationTimeout, cancellationToken); @@ -259,6 +266,7 @@ await ConsumerDispatcher.WaitForShutdownAsync() MaybeDisposeContinuation(enqueued, k); _rpcSemaphore.Release(); ChannelShutdownAsync -= k.OnConnectionShutdownAsync; + argCancellationToken.ThrowIfCancellationRequested(); } } @@ -296,7 +304,15 @@ await ModelSendAsync(in method, k.CancellationToken) // negotiation finishes } - return await k; + try + { + return await k; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -332,7 +348,15 @@ await ModelSendAsync(in method, k.CancellationToken) // negotiation finishes } - return await k; + try + { + return await k; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -341,20 +365,6 @@ await ModelSendAsync(in method, k.CancellationToken) } } - protected bool Enqueue(IRpcContinuation k) - { - if (IsOpen) - { - _continuationQueue.Enqueue(k); - return true; - } - else - { - k.HandleChannelShutdown(CloseReason); - return false; - } - } - internal async Task OpenAsync(CreateChannelOptions createChannelOptions, CancellationToken cancellationToken) { @@ -375,11 +385,19 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + try + { + bool result = await k; + Debug.Assert(result); - await MaybeConfirmSelect(cancellationToken) - .ConfigureAwait(false); + await MaybeConfirmSelect(cancellationToken) + .ConfigureAwait(false); + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -402,6 +420,20 @@ await Session.CloseAsync(reason) m_connectionStartCell?.TrySetResult(null); } + private bool Enqueue(IRpcContinuation k) + { + if (IsOpen) + { + _continuationQueue.Enqueue(k); + return true; + } + else + { + k.HandleChannelShutdown(CloseReason); + return false; + } + } + private async Task HandleCommandAsync(IncomingCommand cmd, CancellationToken cancellationToken) { /* @@ -412,6 +444,11 @@ private async Task HandleCommandAsync(IncomingCommand cmd, CancellationToken can */ try { + if (_continuationQueue.ShouldIgnoreCommand(cmd.CommandId)) + { + return; + } + if (false == await DispatchCommandAsync(cmd, cancellationToken) .ConfigureAwait(false)) { @@ -921,11 +958,19 @@ await ModelSendAsync(in method, k.CancellationToken) { enqueued = Enqueue(k); - await ModelSendAsync(in method, k.CancellationToken) - .ConfigureAwait(false); + try + { + await ModelSendAsync(in method, k.CancellationToken) + .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + bool result = await k; + Debug.Assert(result); + } + catch + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } return; @@ -957,7 +1002,15 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - return await k; + try + { + return await k; + } + catch + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -983,17 +1036,25 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - BasicGetResult? result = await k; + try + { + BasicGetResult? result = await k; - using Activity? activity = result != null - ? RabbitMQActivitySource.BasicGet(result.RoutingKey, - result.Exchange, - result.DeliveryTag, result.BasicProperties, result.Body.Length) - : RabbitMQActivitySource.BasicGetEmpty(queue); + using Activity? activity = result != null + ? RabbitMQActivitySource.BasicGet(result.RoutingKey, + result.Exchange, + result.DeliveryTag, result.BasicProperties, result.Body.Length) + : RabbitMQActivitySource.BasicGetEmpty(queue); - activity?.SetStartTime(k.StartTime); + activity?.SetStartTime(k.StartTime); - return result; + return result; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1030,9 +1091,17 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); - return; + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1057,9 +1126,17 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); - return; + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1093,8 +1170,17 @@ await ModelSendAsync(in method, k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } return; @@ -1137,8 +1223,17 @@ await ModelSendAsync(in method, k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } return; @@ -1174,8 +1269,16 @@ await ModelSendAsync(in method, k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + try + { + bool result = await k; + Debug.Assert(result); + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } return; @@ -1212,8 +1315,16 @@ await ModelSendAsync(in method, k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + try + { + bool result = await k; + Debug.Assert(result); + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } return; @@ -1276,16 +1387,24 @@ await ModelSendAsync(in method, k.CancellationToken) { enqueued = Enqueue(k); - await ModelSendAsync(in method, k.CancellationToken) - .ConfigureAwait(false); + try + { + await ModelSendAsync(in method, k.CancellationToken) + .ConfigureAwait(false); - QueueDeclareOk result = await k; - if (false == passive) + QueueDeclareOk result = await k; + if (false == passive) + { + CurrentQueue = result.QueueName; + } + + return result; + } + catch (OperationCanceledException) { - CurrentQueue = result.QueueName; + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; } - - return result; } } finally @@ -1320,8 +1439,16 @@ await ModelSendAsync(in method, k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); + try + { + bool result = await k; + Debug.Assert(result); + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } return; @@ -1375,7 +1502,15 @@ await ModelSendAsync(in method, k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - return await k; + try + { + return await k; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } } finally @@ -1401,7 +1536,15 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - return await k; + try + { + return await k; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1427,9 +1570,17 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); - return; + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1453,9 +1604,17 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); - return; + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1479,9 +1638,17 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); - return; + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { @@ -1505,9 +1672,17 @@ await _rpcSemaphore.WaitAsync(k.CancellationToken) await ModelSendAsync(in method, k.CancellationToken) .ConfigureAwait(false); - bool result = await k; - Debug.Assert(result); - return; + try + { + bool result = await k; + Debug.Assert(result); + return; + } + catch (OperationCanceledException) + { + _continuationQueue.RpcCanceled(k.HandledProtocolCommandIds); + throw; + } } finally { diff --git a/projects/RabbitMQ.Client/Impl/Connection.cs b/projects/RabbitMQ.Client/Impl/Connection.cs index 2bf1ccc9a..b2c271cc9 100644 --- a/projects/RabbitMQ.Client/Impl/Connection.cs +++ b/projects/RabbitMQ.Client/Impl/Connection.cs @@ -320,6 +320,24 @@ public Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, b /// internal async Task CloseAsync(ShutdownEventArgs reason, bool abort, TimeSpan timeout, CancellationToken cancellationToken) { + CancellationToken argCancellationToken = cancellationToken; + + if (abort && timeout < InternalConstants.DefaultConnectionAbortTimeout) + { + timeout = InternalConstants.DefaultConnectionAbortTimeout; + } + + if (false == abort && timeout < InternalConstants.DefaultConnectionCloseTimeout) + { + timeout = InternalConstants.DefaultConnectionCloseTimeout; + } + + if (IsOpen) + { + // Note: we really do need to try and close this connection! + cancellationToken = CancellationToken.None; + } + using var timeoutCts = new CancellationTokenSource(timeout); using var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutCts.Token, cancellationToken); @@ -335,6 +353,7 @@ internal async Task CloseAsync(ShutdownEventArgs reason, bool abort, TimeSpan ti { await OnShutdownAsync(reason) .ConfigureAwait(false); + await _session0.SetSessionClosingAsync(false, cts.Token) .ConfigureAwait(false); @@ -393,7 +412,7 @@ await _session0.TransmitAsync(method, cts.Token) try { - await _mainLoopTask.WaitAsync(timeout, cts.Token) + await _mainLoopTask.WaitAsync(cts.Token) .ConfigureAwait(false); } catch @@ -412,6 +431,8 @@ await _frameHandler.CloseAsync(cts.Token) throw; } } + + argCancellationToken.ThrowIfCancellationRequested(); } internal async Task ClosedViaPeerAsync(ShutdownEventArgs reason, CancellationToken cancellationToken) diff --git a/projects/RabbitMQ.Client/Impl/RpcContinuationQueue.cs b/projects/RabbitMQ.Client/Impl/RpcContinuationQueue.cs index 5622abe36..cac609bd3 100644 --- a/projects/RabbitMQ.Client/Impl/RpcContinuationQueue.cs +++ b/projects/RabbitMQ.Client/Impl/RpcContinuationQueue.cs @@ -30,10 +30,13 @@ //--------------------------------------------------------------------------- using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Threading; using System.Threading.Tasks; using RabbitMQ.Client.Events; +using RabbitMQ.Client.Framing; namespace RabbitMQ.Client.Impl { @@ -65,6 +68,7 @@ public void Dispose() } private static readonly EmptyRpcContinuation s_tmp = new EmptyRpcContinuation(); + private readonly Queue _rpcCancellationQueue = new(); private IRpcContinuation _outstandingRpc = s_tmp; ///Enqueue a continuation, marking a pending RPC. @@ -138,5 +142,29 @@ public bool TryPeek([NotNullWhen(true)] out T? continuation) where T : class, continuation = default; return false; } + + public void RpcCanceled(ProtocolCommandId[] protocolCommandIds) + { + _rpcCancellationQueue.Enqueue(protocolCommandIds); + } + + public bool ShouldIgnoreCommand(ProtocolCommandId commandId) + { + // rabbitmq/rabbitmq-dotnet-client#1802 + // This keeps track of ProtocolCommandId values from previous RPC + // commands that have timed out. + bool rv = false; + + if (_rpcCancellationQueue.Count > 0) + { + ProtocolCommandId[] lastErroredCommandIds = _rpcCancellationQueue.Dequeue(); + if (lastErroredCommandIds.Contains(commandId)) + { + rv = true; + } + } + + return rv; + } } } diff --git a/projects/Test/Integration/TestBasicPublish.cs b/projects/Test/Integration/TestBasicPublish.cs index fd1b69a96..06fc35a8b 100644 --- a/projects/Test/Integration/TestBasicPublish.cs +++ b/projects/Test/Integration/TestBasicPublish.cs @@ -183,7 +183,7 @@ public async Task TestMaxInboundMessageBodySize() int count = 0; byte[] msg0 = _encoding.GetBytes("hi"); - byte[] msg1 = GetRandomBody(maxMsgSize * 20); + byte[] msg1 = GetRandomBody(maxMsgSize * 64); ConnectionFactory cf = CreateConnectionFactory(); cf.AutomaticRecoveryEnabled = false; diff --git a/projects/Test/Integration/TestToxiproxy.cs b/projects/Test/Integration/TestToxiproxy.cs index 75e4af1e7..feb96288b 100644 --- a/projects/Test/Integration/TestToxiproxy.cs +++ b/projects/Test/Integration/TestToxiproxy.cs @@ -404,6 +404,52 @@ public async Task TestPublisherConfirmationThrottling() Assert.Equal(TotalMessageCount, publishCount); } + [SkippableFact] + [Trait("Category", "Toxiproxy")] + public async Task TestRpcContinuationTimeout_GH1802() + { + Skip.IfNot(AreToxiproxyTestsEnabled, "RABBITMQ_TOXIPROXY_TESTS is not set, skipping test"); + + ConnectionFactory cf = CreateConnectionFactory(); + cf.Endpoint = new AmqpTcpEndpoint(IPAddress.Loopback.ToString(), _proxyPort); + cf.ContinuationTimeout = TimeSpan.FromSeconds(1); + cf.AutomaticRecoveryEnabled = false; + cf.TopologyRecoveryEnabled = false; + + await using IConnection conn = await cf.CreateConnectionAsync(); + await using IChannel ch = await conn.CreateChannelAsync(); + + string toxicName = $"rmq-localhost-bandwidth-{Now}-{GenerateShortUuid()}"; + var bandwidthToxic = new BandwidthToxic + { + Name = toxicName + }; + bandwidthToxic.Attributes.Rate = 0; + bandwidthToxic.Toxicity = 1.0; + bandwidthToxic.Stream = ToxicDirection.DownStream; + + Task addToxicTask = _toxiproxyManager.AddToxicAsync(bandwidthToxic); + + await Task.Delay(TimeSpan.FromSeconds(1)); + + bool sawContinuationTimeout = false; + try + { + ch.ContinuationTimeout = TimeSpan.FromMilliseconds(5); + QueueDeclareOk q = await ch.QueueDeclareAsync(); + } + catch (OperationCanceledException) + { + sawContinuationTimeout = true; + } + + await _toxiproxyManager.RemoveToxicAsync(toxicName); + + await ch.CloseAsync(); + + Assert.True(sawContinuationTimeout); + } + private bool AreToxiproxyTestsEnabled { get