Skip to content

Commit f842722

Browse files
fix issue where watch/keepalive started after returning the acquired lock, so change streams didn't detect lock loss (#38)
1 parent a1025df commit f842722

File tree

5 files changed

+175
-86
lines changed

5 files changed

+175
-86
lines changed

src/MongoDb/LooseExtensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace YakShaveFx.OutboxKit.MongoDb;
2+
3+
internal static class LooseExtensions
4+
{
5+
public static ValueTask TryDisposeAsync(this IAsyncDisposable? disposable)
6+
=> disposable?.DisposeAsync() ?? ValueTask.CompletedTask;
7+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using MongoDB.Driver;
2+
using Nito.AsyncEx;
3+
4+
namespace YakShaveFx.OutboxKit.MongoDb.Synchronization;
5+
6+
internal sealed class ChangeStreamListener(
7+
AsyncAutoResetEvent autoResetEvent,
8+
IChangeStreamCursor<ChangeStreamDocument<DistributedLockDocument>> cursor,
9+
CancellationTokenSource cts) : IAsyncDisposable
10+
{
11+
public Task WaitAsync() => autoResetEvent.WaitAsync(cts.Token);
12+
13+
public static async Task<ChangeStreamListener> StartAsync(
14+
IMongoCollection<DistributedLockDocument> collection,
15+
DistributedLockDefinition lockDefinition,
16+
CancellationToken ct)
17+
{
18+
var cursor = await collection.WatchAsync(
19+
PipelineDefinitionBuilder
20+
.For<ChangeStreamDocument<DistributedLockDocument>>()
21+
.Match(d => d.DocumentKey["_id"] == lockDefinition.Id),
22+
new ChangeStreamOptions
23+
{
24+
BatchSize = 1,
25+
MaxAwaitTime = TimeSpan.FromMinutes(5)
26+
},
27+
ct);
28+
29+
var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
30+
var autoResetEvent = new AsyncAutoResetEvent();
31+
_ = Task.Run(async () =>
32+
{
33+
while (!cts.Token.IsCancellationRequested && await cursor.MoveNextAsync(cts.Token))
34+
{
35+
// only yield when a change is detected, don't care about the amount, just that there is a change
36+
if (cursor.Current.Any())
37+
{
38+
autoResetEvent.Set();
39+
}
40+
}
41+
}, cts.Token);
42+
43+
var watcher = new ChangeStreamListener(autoResetEvent, cursor, cts);
44+
return watcher;
45+
}
46+
47+
public async ValueTask DisposeAsync()
48+
{
49+
try
50+
{
51+
await cts.CancelAsync();
52+
}
53+
catch (Exception)
54+
{
55+
// try to cancel, but don't throw if it fails
56+
}
57+
58+
cts.Dispose();
59+
cursor.Dispose();
60+
}
61+
}

src/MongoDb/Synchronization/DistributedLockThingy.cs

Lines changed: 71 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,64 @@ internal sealed partial class DistributedLockThingy(
1616

1717
public async Task<IDistributedLock> AcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
1818
{
19-
await KeepTryingToAcquireAsync(lockDefinition, ct);
20-
var keepAliveCts = new CancellationTokenSource();
21-
OnAcquired(lockDefinition, keepAliveCts.Token);
22-
return new DistributedLock(lockDefinition, keepAliveCts, ReleaseLockAsync);
19+
// we want to start the listener even before acquiring the lock,
20+
// so there's no window of opportunity in which the lock would be lost without the listener noticing it
21+
var internalLockDefinition = await CreateInternalLockDefinitionAsync(lockDefinition, ct);
22+
try
23+
{
24+
await KeepTryingToAcquireAsync(internalLockDefinition, ct);
25+
var keepAliveCts = new CancellationTokenSource();
26+
OnAcquired(internalLockDefinition, keepAliveCts.Token);
27+
return new DistributedLock(internalLockDefinition, keepAliveCts, ReleaseLockAsync);
28+
}
29+
catch (Exception)
30+
{
31+
await internalLockDefinition.ChangeStreamListener.TryDisposeAsync();
32+
throw;
33+
}
2334
}
2435

2536
public async Task<IDistributedLock?> TryAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
2637
{
27-
if (!await InnerTryAcquireAsync(lockDefinition, ct))
38+
// we want to start the listener even before acquiring the lock,
39+
// so there's no window of opportunity in which the lock would be lost without the listener noticing it
40+
var internalLockDefinition = await CreateInternalLockDefinitionAsync(lockDefinition, ct);
41+
try
42+
{
43+
if (!await InnerTryAcquireAsync(internalLockDefinition, ct))
44+
{
45+
await internalLockDefinition.ChangeStreamListener.TryDisposeAsync();
46+
return null;
47+
}
48+
49+
var keepAliveCts = new CancellationTokenSource();
50+
OnAcquired(internalLockDefinition, keepAliveCts.Token);
51+
return new DistributedLock(internalLockDefinition, keepAliveCts, ReleaseLockAsync);
52+
}
53+
catch (Exception)
2854
{
29-
return null;
55+
await internalLockDefinition.ChangeStreamListener.TryDisposeAsync();
56+
throw;
3057
}
58+
}
3159

32-
var keepAliveCts = new CancellationTokenSource();
33-
OnAcquired(lockDefinition, keepAliveCts.Token);
34-
return new DistributedLock(lockDefinition, keepAliveCts, ReleaseLockAsync);
60+
private async Task<InternalDistributedLockDefinition> CreateInternalLockDefinitionAsync(
61+
DistributedLockDefinition lockDefinition,
62+
CancellationToken ct)
63+
{
64+
var changeStreamListener = _changeStreamsEnabled
65+
? await ChangeStreamListener.StartAsync(_collection, lockDefinition, ct)
66+
: null;
67+
68+
return new InternalDistributedLockDefinition
69+
{
70+
Definition = lockDefinition,
71+
ChangeStreamListener = changeStreamListener
72+
};
3573
}
3674

3775
private async ValueTask ReleaseLockAsync(
38-
DistributedLockDefinition lockDefinition,
76+
InternalDistributedLockDefinition lockDefinition,
3977
CancellationTokenSource keepAliveCts)
4078
{
4179
// try to release the lock, so others can acquire it before expiration
@@ -65,7 +103,7 @@ private async ValueTask ReleaseLockAsync(
65103
}
66104
}
67105

68-
private async Task<bool> InnerTryAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
106+
private async Task<bool> InnerTryAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
69107
{
70108
try
71109
{
@@ -94,7 +132,7 @@ private async Task<bool> InnerTryAcquireAsync(DistributedLockDefinition lockDefi
94132
}
95133
}
96134

97-
private FilterDefinition<DistributedLockDocument> GetUpsertFilter(DistributedLockDefinition lockDefinition)
135+
private FilterDefinition<DistributedLockDocument> GetUpsertFilter(InternalDistributedLockDefinition lockDefinition)
98136
=> Builders<DistributedLockDocument>.Filter.Or(
99137
Builders<DistributedLockDocument>.Filter.And(
100138
Builders<DistributedLockDocument>.Filter.Eq(d => d.Id, lockDefinition.Id),
@@ -103,7 +141,7 @@ private FilterDefinition<DistributedLockDocument> GetUpsertFilter(DistributedLoc
103141
Builders<DistributedLockDocument>.Filter.Eq(d => d.Id, lockDefinition.Id),
104142
Builders<DistributedLockDocument>.Filter.Lt(d => d.ExpiresAt, GetNow())));
105143

106-
private async Task KeepTryingToAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
144+
private async Task KeepTryingToAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
107145
{
108146
if (!_changeStreamsEnabled)
109147
{
@@ -122,7 +160,7 @@ private async Task KeepTryingToAcquireAsync(DistributedLockDefinition lockDefini
122160
await linkedTokenSource.CancelAsync();
123161
}
124162

125-
private async Task PollAndKeepTryingToAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
163+
private async Task PollAndKeepTryingToAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
126164
{
127165
while (!ct.IsCancellationRequested)
128166
{
@@ -139,30 +177,17 @@ private async Task PollAndKeepTryingToAcquireAsync(DistributedLockDefinition loc
139177
}
140178
}
141179

142-
private async Task WatchAndKeepTryingToAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
180+
private async Task WatchAndKeepTryingToAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
143181
{
144-
using var cursor = await _collection.WatchAsync(
145-
PipelineDefinitionBuilder
146-
.For<ChangeStreamDocument<DistributedLockDocument>>()
147-
.Match(d =>
148-
d.OperationType == ChangeStreamOperationType.Delete && d.DocumentKey["_id"] == lockDefinition.Id),
149-
new ChangeStreamOptions
150-
{
151-
BatchSize = 1,
152-
MaxAwaitTime = TimeSpan.FromMinutes(5)
153-
},
154-
ct);
155-
156-
while (!ct.IsCancellationRequested && await cursor.MoveNextAsync(ct))
182+
while (!ct.IsCancellationRequested)
157183
{
158-
foreach (var _ in cursor.Current)
159-
{
160-
if (await InnerTryAcquireAsync(lockDefinition, ct)) return;
161-
}
184+
// listener is not null when change streams are enabled
185+
await lockDefinition.ChangeStreamListener!.WaitAsync();
186+
if (await InnerTryAcquireAsync(lockDefinition, ct)) return;
162187
}
163188
}
164189

165-
private void KickoffKeepAlive(DistributedLockDefinition lockDefinition, CancellationToken ct)
190+
private void KickoffKeepAlive(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
166191
=> _ = Task.Run(async () =>
167192
{
168193
var keepAliveInterval = lockDefinition.Duration / 2;
@@ -202,8 +227,14 @@ private void KickoffKeepAlive(DistributedLockDefinition lockDefinition, Cancella
202227
try
203228
{
204229
var delayTask = Task.Delay(keepAliveInterval, timeProvider, linkedTokenSource.Token);
205-
watchLockLossTask = WatchForPotentialLockLossAsync(lockDefinition, linkedTokenSource.Token);
206-
await Task.WhenAny(delayTask, watchLockLossTask);
230+
231+
// listener is not null when change streams are enabled
232+
await Task.WhenAny(delayTask, lockDefinition.ChangeStreamListener!.WaitAsync());
233+
234+
if (!delayTask.IsCompleted)
235+
{
236+
LogPotentiallyLost(logger, lockDefinition.Id, lockDefinition.Context);
237+
}
207238

208239
// if awaiting was interrupted by cancellation, we can break
209240
if (ct.IsCancellationRequested) break;
@@ -236,40 +267,16 @@ private void KickoffKeepAlive(DistributedLockDefinition lockDefinition, Cancella
236267
}
237268
}
238269
}
239-
240270
LogKeepAliveStopped(logger, lockDefinition.Id, lockDefinition.Context);
241271
}, CancellationToken.None);
242272

243-
private async Task WatchForPotentialLockLossAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
244-
{
245-
using var cursor = await _collection.WatchAsync(
246-
PipelineDefinitionBuilder
247-
.For<ChangeStreamDocument<DistributedLockDocument>>()
248-
.Match(d => d.DocumentKey["_id"] == lockDefinition.Id),
249-
new ChangeStreamOptions
250-
{
251-
BatchSize = 1,
252-
MaxAwaitTime = TimeSpan.FromMinutes(5)
253-
},
254-
ct);
255-
256-
while (!ct.IsCancellationRequested && await cursor.MoveNextAsync(ct))
257-
{
258-
if (cursor.Current.Any())
259-
{
260-
LogPotentiallyLost(logger, lockDefinition.Id, lockDefinition.Context);
261-
return;
262-
}
263-
}
264-
}
265-
266-
private void OnAcquired(DistributedLockDefinition lockDefinition, CancellationToken ct)
273+
private void OnAcquired(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
267274
{
268275
LogAcquired(logger, lockDefinition.Id, lockDefinition.Context);
269276
KickoffKeepAlive(lockDefinition, ct);
270277
}
271278

272-
private void OnLost(DistributedLockDefinition lockDefinition)
279+
private void OnLost(InternalDistributedLockDefinition lockDefinition)
273280
{
274281
LogLost(logger, lockDefinition.Id, lockDefinition.Context);
275282
_ = Task.Run(() => lockDefinition.OnLockLost());
@@ -279,7 +286,7 @@ private void OnLost(DistributedLockDefinition lockDefinition)
279286

280287
private long GetExpiresAt(TimeSpan duration) => timeProvider.GetUtcNow().Add(duration).ToUnixTimeMilliseconds();
281288

282-
private TimeSpan GetRemainingTime(DistributedLockDefinition lockDefinition, long expiresAt)
289+
private TimeSpan GetRemainingTime(InternalDistributedLockDefinition lockDefinition, long expiresAt)
283290
{
284291
var remaining = expiresAt - timeProvider.GetUtcNow().ToUnixTimeMilliseconds();
285292
return remaining > 0
@@ -321,11 +328,10 @@ private TimeSpan GetRemainingTime(DistributedLockDefinition lockDefinition, long
321328
Message = "An error occurred executing lock keep alive (id \"{Id}\" context \"{Context}\")")]
322329
private static partial void LogErrorExecutingKeepAlive(ILogger logger, Exception ex, string id, string? context);
323330

324-
325331
private sealed class DistributedLock(
326-
DistributedLockDefinition definition,
332+
InternalDistributedLockDefinition definition,
327333
CancellationTokenSource keepAliveCts,
328-
Func<DistributedLockDefinition, CancellationTokenSource, ValueTask> releaseLock) : IDistributedLock
334+
Func<InternalDistributedLockDefinition, CancellationTokenSource, ValueTask> releaseLock) : IDistributedLock
329335
{
330336
public ValueTask DisposeAsync() => releaseLock(definition, keepAliveCts);
331337
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
namespace YakShaveFx.OutboxKit.MongoDb.Synchronization;
2+
3+
// wraps the DistributedLockDefinition and adds additional properties for internal use
4+
// so it's easier to pass it around
5+
internal sealed class InternalDistributedLockDefinition
6+
{
7+
public required DistributedLockDefinition Definition { get; init; }
8+
public required ChangeStreamListener? ChangeStreamListener { get; init; }
9+
10+
public string Id => Definition.Id;
11+
public string Owner => Definition.Owner;
12+
public string? Context => Definition.Context;
13+
public TimeSpan Duration => Definition.Duration;
14+
public Func<Task> OnLockLost => Definition.OnLockLost;
15+
}

0 commit comments

Comments
 (0)