Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/MongoDb/LooseExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace YakShaveFx.OutboxKit.MongoDb;

internal static class LooseExtensions
{
public static ValueTask TryDisposeAsync(this IAsyncDisposable? disposable)
=> disposable?.DisposeAsync() ?? ValueTask.CompletedTask;
}
61 changes: 61 additions & 0 deletions src/MongoDb/Synchronization/ChangeStreamListener.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using MongoDB.Driver;
using Nito.AsyncEx;

namespace YakShaveFx.OutboxKit.MongoDb.Synchronization;

internal sealed class ChangeStreamListener(
AsyncAutoResetEvent autoResetEvent,
IChangeStreamCursor<ChangeStreamDocument<DistributedLockDocument>> cursor,
CancellationTokenSource cts) : IAsyncDisposable
{
public Task WaitAsync() => autoResetEvent.WaitAsync(cts.Token);

public static async Task<ChangeStreamListener> StartAsync(
IMongoCollection<DistributedLockDocument> collection,
DistributedLockDefinition lockDefinition,
CancellationToken ct)
{
var cursor = await collection.WatchAsync(
PipelineDefinitionBuilder
.For<ChangeStreamDocument<DistributedLockDocument>>()
.Match(d => d.DocumentKey["_id"] == lockDefinition.Id),
new ChangeStreamOptions
{
BatchSize = 1,
MaxAwaitTime = TimeSpan.FromMinutes(5)
},
ct);

var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
var autoResetEvent = new AsyncAutoResetEvent();
_ = Task.Run(async () =>
{
while (!cts.Token.IsCancellationRequested && await cursor.MoveNextAsync(cts.Token))
{
// only yield when a change is detected, don't care about the amount, just that there is a change
if (cursor.Current.Any())
{
autoResetEvent.Set();
}
}
}, cts.Token);

var watcher = new ChangeStreamListener(autoResetEvent, cursor, cts);
return watcher;
}

public async ValueTask DisposeAsync()
{
try
{
await cts.CancelAsync();
}
catch (Exception)
{
// try to cancel, but don't throw if it fails
}

cts.Dispose();
cursor.Dispose();
}
}
136 changes: 71 additions & 65 deletions src/MongoDb/Synchronization/DistributedLockThingy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,64 @@ internal sealed partial class DistributedLockThingy(

public async Task<IDistributedLock> AcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
{
await KeepTryingToAcquireAsync(lockDefinition, ct);
var keepAliveCts = new CancellationTokenSource();
OnAcquired(lockDefinition, keepAliveCts.Token);
return new DistributedLock(lockDefinition, keepAliveCts, ReleaseLockAsync);
// we want to start the listener even before acquiring the lock,
// so there's no window of opportunity in which the lock would be lost without the listener noticing it
var internalLockDefinition = await CreateInternalLockDefinitionAsync(lockDefinition, ct);
try
{
await KeepTryingToAcquireAsync(internalLockDefinition, ct);
var keepAliveCts = new CancellationTokenSource();
OnAcquired(internalLockDefinition, keepAliveCts.Token);
return new DistributedLock(internalLockDefinition, keepAliveCts, ReleaseLockAsync);
}
catch (Exception)
{
await internalLockDefinition.ChangeStreamListener.TryDisposeAsync();
throw;
}
}

public async Task<IDistributedLock?> TryAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
{
if (!await InnerTryAcquireAsync(lockDefinition, ct))
// we want to start the listener even before acquiring the lock,
// so there's no window of opportunity in which the lock would be lost without the listener noticing it
var internalLockDefinition = await CreateInternalLockDefinitionAsync(lockDefinition, ct);
try
{
if (!await InnerTryAcquireAsync(internalLockDefinition, ct))
{
await internalLockDefinition.ChangeStreamListener.TryDisposeAsync();
return null;
}

var keepAliveCts = new CancellationTokenSource();
OnAcquired(internalLockDefinition, keepAliveCts.Token);
return new DistributedLock(internalLockDefinition, keepAliveCts, ReleaseLockAsync);
}
catch (Exception)
{
return null;
await internalLockDefinition.ChangeStreamListener.TryDisposeAsync();
throw;
}
}

var keepAliveCts = new CancellationTokenSource();
OnAcquired(lockDefinition, keepAliveCts.Token);
return new DistributedLock(lockDefinition, keepAliveCts, ReleaseLockAsync);
private async Task<InternalDistributedLockDefinition> CreateInternalLockDefinitionAsync(
DistributedLockDefinition lockDefinition,
CancellationToken ct)
{
var changeStreamListener = _changeStreamsEnabled
? await ChangeStreamListener.StartAsync(_collection, lockDefinition, ct)
: null;

return new InternalDistributedLockDefinition
{
Definition = lockDefinition,
ChangeStreamListener = changeStreamListener
};
}

private async ValueTask ReleaseLockAsync(
DistributedLockDefinition lockDefinition,
InternalDistributedLockDefinition lockDefinition,
CancellationTokenSource keepAliveCts)
{
// try to release the lock, so others can acquire it before expiration
Expand Down Expand Up @@ -65,7 +103,7 @@ private async ValueTask ReleaseLockAsync(
}
}

private async Task<bool> InnerTryAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
private async Task<bool> InnerTryAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
{
try
{
Expand Down Expand Up @@ -94,7 +132,7 @@ private async Task<bool> InnerTryAcquireAsync(DistributedLockDefinition lockDefi
}
}

private FilterDefinition<DistributedLockDocument> GetUpsertFilter(DistributedLockDefinition lockDefinition)
private FilterDefinition<DistributedLockDocument> GetUpsertFilter(InternalDistributedLockDefinition lockDefinition)
=> Builders<DistributedLockDocument>.Filter.Or(
Builders<DistributedLockDocument>.Filter.And(
Builders<DistributedLockDocument>.Filter.Eq(d => d.Id, lockDefinition.Id),
Expand All @@ -103,7 +141,7 @@ private FilterDefinition<DistributedLockDocument> GetUpsertFilter(DistributedLoc
Builders<DistributedLockDocument>.Filter.Eq(d => d.Id, lockDefinition.Id),
Builders<DistributedLockDocument>.Filter.Lt(d => d.ExpiresAt, GetNow())));

private async Task KeepTryingToAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
private async Task KeepTryingToAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
{
if (!_changeStreamsEnabled)
{
Expand All @@ -122,7 +160,7 @@ private async Task KeepTryingToAcquireAsync(DistributedLockDefinition lockDefini
await linkedTokenSource.CancelAsync();
}

private async Task PollAndKeepTryingToAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
private async Task PollAndKeepTryingToAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
{
while (!ct.IsCancellationRequested)
{
Expand All @@ -139,30 +177,17 @@ private async Task PollAndKeepTryingToAcquireAsync(DistributedLockDefinition loc
}
}

private async Task WatchAndKeepTryingToAcquireAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
private async Task WatchAndKeepTryingToAcquireAsync(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
{
using var cursor = await _collection.WatchAsync(
PipelineDefinitionBuilder
.For<ChangeStreamDocument<DistributedLockDocument>>()
.Match(d =>
d.OperationType == ChangeStreamOperationType.Delete && d.DocumentKey["_id"] == lockDefinition.Id),
new ChangeStreamOptions
{
BatchSize = 1,
MaxAwaitTime = TimeSpan.FromMinutes(5)
},
ct);

while (!ct.IsCancellationRequested && await cursor.MoveNextAsync(ct))
while (!ct.IsCancellationRequested)
{
foreach (var _ in cursor.Current)
{
if (await InnerTryAcquireAsync(lockDefinition, ct)) return;
}
// listener is not null when change streams are enabled
await lockDefinition.ChangeStreamListener!.WaitAsync();
if (await InnerTryAcquireAsync(lockDefinition, ct)) return;
}
}

private void KickoffKeepAlive(DistributedLockDefinition lockDefinition, CancellationToken ct)
private void KickoffKeepAlive(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
=> _ = Task.Run(async () =>
{
var keepAliveInterval = lockDefinition.Duration / 2;
Expand Down Expand Up @@ -202,8 +227,14 @@ private void KickoffKeepAlive(DistributedLockDefinition lockDefinition, Cancella
try
{
var delayTask = Task.Delay(keepAliveInterval, timeProvider, linkedTokenSource.Token);
watchLockLossTask = WatchForPotentialLockLossAsync(lockDefinition, linkedTokenSource.Token);
await Task.WhenAny(delayTask, watchLockLossTask);

// listener is not null when change streams are enabled
await Task.WhenAny(delayTask, lockDefinition.ChangeStreamListener!.WaitAsync());

if (!delayTask.IsCompleted)
{
LogPotentiallyLost(logger, lockDefinition.Id, lockDefinition.Context);
}

// if awaiting was interrupted by cancellation, we can break
if (ct.IsCancellationRequested) break;
Expand Down Expand Up @@ -236,40 +267,16 @@ private void KickoffKeepAlive(DistributedLockDefinition lockDefinition, Cancella
}
}
}

LogKeepAliveStopped(logger, lockDefinition.Id, lockDefinition.Context);
}, CancellationToken.None);

private async Task WatchForPotentialLockLossAsync(DistributedLockDefinition lockDefinition, CancellationToken ct)
{
using var cursor = await _collection.WatchAsync(
PipelineDefinitionBuilder
.For<ChangeStreamDocument<DistributedLockDocument>>()
.Match(d => d.DocumentKey["_id"] == lockDefinition.Id),
new ChangeStreamOptions
{
BatchSize = 1,
MaxAwaitTime = TimeSpan.FromMinutes(5)
},
ct);

while (!ct.IsCancellationRequested && await cursor.MoveNextAsync(ct))
{
if (cursor.Current.Any())
{
LogPotentiallyLost(logger, lockDefinition.Id, lockDefinition.Context);
return;
}
}
}

private void OnAcquired(DistributedLockDefinition lockDefinition, CancellationToken ct)
private void OnAcquired(InternalDistributedLockDefinition lockDefinition, CancellationToken ct)
{
LogAcquired(logger, lockDefinition.Id, lockDefinition.Context);
KickoffKeepAlive(lockDefinition, ct);
}

private void OnLost(DistributedLockDefinition lockDefinition)
private void OnLost(InternalDistributedLockDefinition lockDefinition)
{
LogLost(logger, lockDefinition.Id, lockDefinition.Context);
_ = Task.Run(() => lockDefinition.OnLockLost());
Expand All @@ -279,7 +286,7 @@ private void OnLost(DistributedLockDefinition lockDefinition)

private long GetExpiresAt(TimeSpan duration) => timeProvider.GetUtcNow().Add(duration).ToUnixTimeMilliseconds();

private TimeSpan GetRemainingTime(DistributedLockDefinition lockDefinition, long expiresAt)
private TimeSpan GetRemainingTime(InternalDistributedLockDefinition lockDefinition, long expiresAt)
{
var remaining = expiresAt - timeProvider.GetUtcNow().ToUnixTimeMilliseconds();
return remaining > 0
Expand Down Expand Up @@ -321,11 +328,10 @@ private TimeSpan GetRemainingTime(DistributedLockDefinition lockDefinition, long
Message = "An error occurred executing lock keep alive (id \"{Id}\" context \"{Context}\")")]
private static partial void LogErrorExecutingKeepAlive(ILogger logger, Exception ex, string id, string? context);


private sealed class DistributedLock(
DistributedLockDefinition definition,
InternalDistributedLockDefinition definition,
CancellationTokenSource keepAliveCts,
Func<DistributedLockDefinition, CancellationTokenSource, ValueTask> releaseLock) : IDistributedLock
Func<InternalDistributedLockDefinition, CancellationTokenSource, ValueTask> releaseLock) : IDistributedLock
{
public ValueTask DisposeAsync() => releaseLock(definition, keepAliveCts);
}
Expand Down
15 changes: 15 additions & 0 deletions src/MongoDb/Synchronization/InternalDistributedLockDefinition.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
namespace YakShaveFx.OutboxKit.MongoDb.Synchronization;

// wraps the DistributedLockDefinition and adds additional properties for internal use
// so it's easier to pass it around
internal sealed class InternalDistributedLockDefinition
{
public required DistributedLockDefinition Definition { get; init; }
public required ChangeStreamListener? ChangeStreamListener { get; init; }

public string Id => Definition.Id;
public string Owner => Definition.Owner;
public string? Context => Definition.Context;
public TimeSpan Duration => Definition.Duration;
public Func<Task> OnLockLost => Definition.OnLockLost;
}
Loading