Commit 612d50d

refactor(rdb_saver): Add SnapshotDataConsumer to SliceSnapshot (#4287)
* refactor(rdb_saver): Add SnapshotDataConsumer to SliceSnapshot

  Fixes #4218

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
1 parent d162094 commit 612d50d
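
At a high level, the commit replaces the two `std::function` callbacks that `SliceSnapshot` used to hold (`on_push_record`, `on_snapshot_finish`) with a single consumer interface implemented by `RdbSaver::Impl`. Below is a minimal sketch of the before/after shape of that pattern with stand-in types; it is not the Dragonfly code itself, just the idiom the diff applies.

#include <functional>
#include <string>

struct Context {  // stand-in for dfly::Context
  bool IsCancelled() const { return false; }
};

// Before: the producer stored two loose callbacks.
struct ProducerBefore {
  std::function<void(std::string)> on_push;
  std::function<void()> on_snapshot_finish;
};

// After: one interface groups the related operations, so the consumer's
// identity and the cancellation context are explicit at the call sites.
struct SnapshotDataConsumerInterface {
  virtual ~SnapshotDataConsumerInterface() = default;
  virtual void ConsumeData(std::string data, Context* cntx) = 0;
  virtual void Finalize() = 0;
};

struct ProducerAfter {
  SnapshotDataConsumerInterface* consumer;  // not owned
  Context* cntx;                            // not owned

  void Run() {
    consumer->ConsumeData("chunk", cntx);  // was: on_push("chunk")
    consumer->Finalize();                  // was: on_snapshot_finish()
  }
};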

File tree: 6 files changed, +79 −75 lines

  src/server/detail/save_stages_controller.cc
  src/server/dflycmd.cc
  src/server/rdb_save.cc
  src/server/rdb_save.h
  src/server/snapshot.cc
  src/server/snapshot.h

src/server/detail/save_stages_controller.cc

Lines changed: 1 addition & 1 deletion

@@ -124,7 +124,7 @@ GenericError RdbSnapshot::Start(SaveMode save_mode, const std::string& path,
 }
 
 error_code RdbSnapshot::SaveBody() {
-  return saver_->SaveBody(&cntx_);
+  return saver_->SaveBody(cntx_);
 }
 
 error_code RdbSnapshot::WaitSnapshotInShard(EngineShard* shard) {

src/server/dflycmd.cc

Lines changed: 1 addition & 1 deletion

@@ -586,7 +586,7 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
   }
 
   if (flow->start_partial_sync_at.has_value())
-    saver->StartIncrementalSnapshotInShard(cntx, shard, *flow->start_partial_sync_at);
+    saver->StartIncrementalSnapshotInShard(*flow->start_partial_sync_at, cntx, shard);
   else
     saver->StartSnapshotInShard(true, cntx, shard);

src/server/rdb_save.cc

Lines changed: 40 additions & 40 deletions

@@ -1065,7 +1065,7 @@ error_code AlignedBuffer::Flush() {
   return upstream_->Write(&ivec, 1);
 }
 
-class RdbSaver::Impl {
+class RdbSaver::Impl final : public SliceSnapshot::SnapshotDataConsumerInterface {
  private:
   void CleanShardSnapshots();
 
@@ -1078,11 +1078,16 @@ class RdbSaver::Impl {
   ~Impl();
 
   void StartSnapshotting(bool stream_journal, Context* cntx, EngineShard* shard);
-  void StartIncrementalSnapshotting(Context* cntx, EngineShard* shard, LSN start_lsn);
+  void StartIncrementalSnapshotting(LSN start_lsn, Context* cntx, EngineShard* shard);
 
   void StopSnapshotting(EngineShard* shard);
   void WaitForSnapshottingFinish(EngineShard* shard);
 
+  // Pushes snapshot data. Called from SliceSnapshot
+  void ConsumeData(std::string data, Context* cntx) override;
+  // Finalizes the snapshot writing. Called from SliceSnapshot
+  void Finalize() override;
+
   // used only for legacy rdb save flows.
   error_code ConsumeChannel(const Cancellation* cll);
 
@@ -1115,8 +1120,6 @@ class RdbSaver::Impl {
   }
 
  private:
-  void PushSnapshotData(Context* cntx, string record);
-  void FinalizeSnapshotWriting();
   error_code WriteRecord(io::Bytes src);
 
   unique_ptr<SliceSnapshot>& GetSnapshot(EngineShard* shard);
@@ -1252,49 +1255,26 @@ error_code RdbSaver::Impl::WriteRecord(io::Bytes src) {
   return ec;
 }
 
-void RdbSaver::Impl::PushSnapshotData(Context* cntx, string record) {
-  if (cntx->IsCancelled()) {
-    return;
-  }
-  if (channel_) {  // Rdb write to channel
-    channel_->Push(record);
-  } else {  // Write directly to socket
-    auto ec = WriteRecord(io::Buffer(record));
-    if (ec) {
-      cntx->ReportError(ec);
-    }
-  }
-}
-
-void RdbSaver::Impl::FinalizeSnapshotWriting() {
-  if (channel_) {
-    channel_->StartClosing();
-  }
-}
-
 void RdbSaver::Impl::StartSnapshotting(bool stream_journal, Context* cntx, EngineShard* shard) {
   auto& s = GetSnapshot(shard);
   auto& db_slice = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id());
-  auto on_snapshot_finish = std::bind(&RdbSaver::Impl::FinalizeSnapshotWriting, this);
-  auto push_cb = std::bind(&RdbSaver::Impl::PushSnapshotData, this, cntx, std::placeholders::_1);
 
-  s = std::make_unique<SliceSnapshot>(&db_slice, compression_mode_, push_cb, on_snapshot_finish);
+  s = std::make_unique<SliceSnapshot>(compression_mode_, &db_slice, this, cntx);
 
   const auto allow_flush = (save_mode_ != SaveMode::RDB) ? SliceSnapshot::SnapshotFlush::kAllow
                                                          : SliceSnapshot::SnapshotFlush::kDisallow;
 
-  s->Start(stream_journal, cntx->GetCancellation(), allow_flush);
+  s->Start(stream_journal, allow_flush);
 }
 
-void RdbSaver::Impl::StartIncrementalSnapshotting(Context* cntx, EngineShard* shard,
-                                                  LSN start_lsn) {
+void RdbSaver::Impl::StartIncrementalSnapshotting(LSN start_lsn, Context* cntx,
+                                                  EngineShard* shard) {
   auto& db_slice = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id());
   auto& s = GetSnapshot(shard);
-  auto on_finalize_cb = std::bind(&RdbSaver::Impl::FinalizeSnapshotWriting, this);
-  auto push_cb = std::bind(&RdbSaver::Impl::PushSnapshotData, this, cntx, std::placeholders::_1);
-  s = std::make_unique<SliceSnapshot>(&db_slice, compression_mode_, push_cb, on_finalize_cb);
 
-  s->StartIncremental(cntx, start_lsn);
+  s = std::make_unique<SliceSnapshot>(compression_mode_, &db_slice, this, cntx);
+
+  s->StartIncremental(start_lsn);
 }
 
 // called on save flow
@@ -1304,6 +1284,26 @@ void RdbSaver::Impl::WaitForSnapshottingFinish(EngineShard* shard) {
   snapshot->WaitSnapshotting();
 }
 
+void RdbSaver::Impl::ConsumeData(std::string data, Context* cntx) {
+  if (cntx->IsCancelled()) {
+    return;
+  }
+  if (channel_) {  // Rdb write to channel
+    channel_->Push(std::move(data));
+  } else {  // Write directly to socket
+    auto ec = WriteRecord(io::Buffer(data));
+    if (ec) {
+      cntx->ReportError(ec);
+    }
+  }
+}
+
+void RdbSaver::Impl::Finalize() {
+  if (channel_) {
+    channel_->StartClosing();
+  }
+}
+
 // called from replication flow
 void RdbSaver::Impl::StopSnapshotting(EngineShard* shard) {
   auto& snapshot = GetSnapshot(shard);
@@ -1462,8 +1462,8 @@ void RdbSaver::StartSnapshotInShard(bool stream_journal, Context* cntx, EngineShard* shard) {
   impl_->StartSnapshotting(stream_journal, cntx, shard);
 }
 
-void RdbSaver::StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn) {
-  impl_->StartIncrementalSnapshotting(cntx, shard, start_lsn);
+void RdbSaver::StartIncrementalSnapshotInShard(LSN start_lsn, Context* cntx, EngineShard* shard) {
+  impl_->StartIncrementalSnapshotting(start_lsn, cntx, shard);
 }
 
 error_code RdbSaver::WaitSnapshotInShard(EngineShard* shard) {
@@ -1489,17 +1489,17 @@ error_code RdbSaver::SaveHeader(const GlobalData& glob_state) {
   return error_code{};
 }
 
-error_code RdbSaver::SaveBody(Context* cntx) {
+error_code RdbSaver::SaveBody(const Context& cntx) {
   RETURN_ON_ERR(impl_->FlushSerializer());
 
   if (save_mode_ == SaveMode::RDB) {
     VLOG(1) << "SaveBody , snapshots count: " << impl_->Size();
-    error_code io_error = impl_->ConsumeChannel(cntx->GetCancellation());
+    error_code io_error = impl_->ConsumeChannel(cntx.GetCancellation());
     if (io_error) {
       return io_error;
    }
-    if (cntx->GetError()) {
-      return cntx->GetError();
+    if (cntx.GetError()) {
+      return cntx.GetError();
     }
   } else {
     DCHECK(save_mode_ == SaveMode::SUMMARY);
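
One point the diff implies but does not spell out: `SliceSnapshot` now keeps `consumer_` and `cntx_` as raw, non-owning pointers, so the `RdbSaver::Impl` that passes `this` must stay alive until the snapshot fiber finishes; presumably this is what `WaitForSnapshottingFinish` and `CleanShardSnapshots` help guarantee. A hedged illustration of that constraint, using `std::thread` in place of `fb2::Fiber` and hypothetical names:

#include <iostream>
#include <string>
#include <thread>

struct Consumer {  // plays the role of RdbSaver::Impl
  void ConsumeData(std::string data) { std::cout << data << "\n"; }
};

class Snapshot {  // plays the role of SliceSnapshot
 public:
  explicit Snapshot(Consumer* consumer) : consumer_(consumer) {}
  ~Snapshot() { Join(); }

  void Start() {
    worker_ = std::thread([this] { consumer_->ConsumeData("chunk"); });
  }

  // The owner must join before destroying the consumer, because the
  // worker dereferences consumer_, which Snapshot does not own.
  void Join() {
    if (worker_.joinable())
      worker_.join();
  }

 private:
  Consumer* consumer_;  // not owned; must outlive worker_
  std::thread worker_;
};

int main() {
  Consumer consumer;  // must outlive the snapshot's worker
  Snapshot snap(&consumer);
  snap.Start();
  snap.Join();
}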

src/server/rdb_save.h

Lines changed: 2 additions & 2 deletions

@@ -94,7 +94,7 @@ class RdbSaver {
   void StartSnapshotInShard(bool stream_journal, Context* cntx, EngineShard* shard);
 
   // Send only the incremental snapshot since start_lsn.
-  void StartIncrementalSnapshotInShard(Context* cntx, EngineShard* shard, LSN start_lsn);
+  void StartIncrementalSnapshotInShard(LSN start_lsn, Context* cntx, EngineShard* shard);
 
   // Stops full-sync serialization for replication in the shard's thread.
   std::error_code StopFullSyncInShard(EngineShard* shard);
@@ -107,7 +107,7 @@ class RdbSaver {
 
   // Writes the RDB file into sink. Waits for the serialization to finish.
   // Called only for save rdb flow and save df on summary file.
-  std::error_code SaveBody(Context* cntx);
+  std::error_code SaveBody(const Context& cntx);
 
   // Fills freq_map with the histogram of rdb types.
   void FillFreqMap(RdbTypeFreqMap* freq_map);

src/server/snapshot.cc

Lines changed: 17 additions & 22 deletions

@@ -37,13 +37,9 @@ constexpr size_t kMinBlobSize = 32_KB;
 
 }  // namespace
 
-SliceSnapshot::SliceSnapshot(DbSlice* slice, CompressionMode compression_mode,
-                             std::function<void(std::string)> on_push_record,
-                             std::function<void()> on_snapshot_finish)
-    : db_slice_(slice),
-      compression_mode_(compression_mode),
-      on_push_(on_push_record),
-      on_snapshot_finish_(on_snapshot_finish) {
+SliceSnapshot::SliceSnapshot(CompressionMode compression_mode, DbSlice* slice,
+                             SnapshotDataConsumerInterface* consumer, Context* cntx)
+    : db_slice_(slice), compression_mode_(compression_mode), consumer_(consumer), cntx_(cntx) {
   db_array_ = slice->databases();
   tl_slice_snapshots.insert(this);
 }
@@ -65,7 +61,7 @@ bool SliceSnapshot::IsSnaphotInProgress() {
   return tl_slice_snapshots.size() > 0;
 }
 
-void SliceSnapshot::Start(bool stream_journal, const Cancellation* cll, SnapshotFlush allow_flush) {
+void SliceSnapshot::Start(bool stream_journal, SnapshotFlush allow_flush) {
   DCHECK(!snapshot_fb_.IsJoinable());
 
   auto db_cb = absl::bind_front(&SliceSnapshot::OnDbChange, this);
@@ -95,19 +91,18 @@ void SliceSnapshot::Start(bool stream_journal, SnapshotFlush allow_flush) {
 
   VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_;
 
-  snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal, cll] {
-    IterateBucketsFb(cll, stream_journal);
+  snapshot_fb_ = fb2::Fiber("snapshot", [this, stream_journal] {
+    this->IterateBucketsFb(stream_journal);
     db_slice_->UnregisterOnChange(snapshot_version_);
-    on_snapshot_finish_();
+    consumer_->Finalize();
   });
 }
 
-void SliceSnapshot::StartIncremental(Context* cntx, LSN start_lsn) {
+void SliceSnapshot::StartIncremental(LSN start_lsn) {
   serializer_ = std::make_unique<RdbSerializer>(compression_mode_);
 
-  snapshot_fb_ = fb2::Fiber("incremental_snapshot", [cntx, start_lsn, this] {
-    this->SwitchIncrementalFb(cntx, start_lsn);
-  });
+  snapshot_fb_ = fb2::Fiber("incremental_snapshot",
+                            [start_lsn, this] { this->SwitchIncrementalFb(start_lsn); });
 }
 
 // Called only for replication use-case.
@@ -144,7 +139,7 @@ void SliceSnapshot::FinalizeJournalStream(bool cancel) {
   // and survived until it finished.
 
 // Serializes all the entries with version less than snapshot_version_.
-void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut) {
+void SliceSnapshot::IterateBucketsFb(bool send_full_sync_cut) {
   {
     auto fiber_name = absl::StrCat("SliceSnapshot-", ProactorBase::me()->GetPoolIndex());
    ThisFiber::SetName(std::move(fiber_name));
@@ -156,7 +151,7 @@ void SliceSnapshot::IterateBucketsFb(bool send_full_sync_cut) {
   }
 
   for (DbIndex db_indx = 0; db_indx < db_array_.size(); ++db_indx) {
-    if (cll->IsCancelled())
+    if (cntx_->IsCancelled())
       return;
 
     if (!db_array_[db_indx])
@@ -168,7 +163,7 @@ void SliceSnapshot::IterateBucketsFb(bool send_full_sync_cut) {
 
     VLOG(1) << "Start traversing " << pt->size() << " items for index " << db_indx;
     do {
-      if (cll->IsCancelled()) {
+      if (cntx_->IsCancelled()) {
         return;
       }
 
@@ -204,15 +199,15 @@ void SliceSnapshot::IterateBucketsFb(bool send_full_sync_cut) {
           << stats_.loop_serialized << "/" << stats_.side_saved << "/" << stats_.savecb_calls;
 }
 
-void SliceSnapshot::SwitchIncrementalFb(Context* cntx, LSN lsn) {
+void SliceSnapshot::SwitchIncrementalFb(LSN lsn) {
   auto* journal = db_slice_->shard_owner()->journal();
   DCHECK(journal);
   DCHECK_LE(lsn, journal->GetLsn()) << "The replica tried to sync from the future.";
 
   VLOG(1) << "Starting incremental snapshot from lsn=" << lsn;
 
   // The replica sends the LSN of the next entry it wants to receive.
-  while (!cntx->IsCancelled() && journal->IsLSNInBuffer(lsn)) {
+  while (!cntx_->IsCancelled() && journal->IsLSNInBuffer(lsn)) {
     serializer_->WriteJournalEntry(journal->GetEntry(lsn));
     PushSerialized(false);
     lsn++;
@@ -239,7 +234,7 @@ void SliceSnapshot::SwitchIncrementalFb(LSN lsn) {
     PushSerialized(true);
   } else {
     // We stopped but we didn't manage to send the whole stream.
-    cntx->ReportError(
+    cntx_->ReportError(
         std::make_error_code(errc::state_not_recoverable),
         absl::StrCat("Partial sync was unsuccessful because entry #", lsn,
                      " was dropped from the buffer. Current lsn=", journal->GetLsn()));
@@ -348,7 +343,7 @@ size_t SliceSnapshot::FlushSerialized(SerializerBase::FlushState flush_state) {
   seq_cond_.wait(lk, [&] { return id == this->last_pushed_id_ + 1; });
 
   // Blocking point.
-  on_push_(std::move(sfile.val));
+  consumer_->ConsumeData(std::move(sfile.val), cntx_);
 
   DCHECK_EQ(last_pushed_id_ + 1, id);
   last_pushed_id_ = id;

src/server/snapshot.h

Lines changed: 18 additions & 9 deletions

@@ -49,8 +49,18 @@ struct Entry;
 // over the sink until explicitly stopped.
 class SliceSnapshot {
  public:
-  SliceSnapshot(DbSlice* slice, CompressionMode compression_mode,
-                std::function<void(std::string)> on_push, std::function<void()> on_snapshot_finish);
+  // Represents a target for receiving snapshot data.
+  struct SnapshotDataConsumerInterface {
+    virtual ~SnapshotDataConsumerInterface() = default;
+
+    // Receives a chunk of snapshot data for processing
+    virtual void ConsumeData(std::string data, Context* cntx) = 0;
+    // Finalizes the snapshot writing
+    virtual void Finalize() = 0;
+  };
+
+  SliceSnapshot(CompressionMode compression_mode, DbSlice* slice,
+                SnapshotDataConsumerInterface* consumer, Context* cntx);
   ~SliceSnapshot();
 
   static size_t GetThreadLocalMemoryUsage();
@@ -60,15 +70,14 @@ class SliceSnapshot {
   // In journal streaming mode it needs to be stopped by either Stop or Cancel.
   enum class SnapshotFlush { kAllow, kDisallow };
 
-  void Start(bool stream_journal, const Cancellation* cll,
-             SnapshotFlush allow_flush = SnapshotFlush::kDisallow);
+  void Start(bool stream_journal, SnapshotFlush allow_flush = SnapshotFlush::kDisallow);
 
   // Initialize a snapshot that sends only the missing journal updates
   // since start_lsn and then registers a callback that switches into the
   // journal streaming mode until stopped.
   // If we're slower than the buffer and can't continue, `Cancel()` is
   // called.
-  void StartIncremental(Context* cntx, LSN start_lsn);
+  void StartIncremental(LSN start_lsn);
 
   // Finalizes journal streaming writes. Only called for replication.
   // Blocking. Must be called from the Snapshot thread.
@@ -83,10 +92,10 @@ class SliceSnapshot {
  private:
   // Main snapshotting fiber that iterates over all buckets in the db slice
   // and submits them to SerializeBucket.
-  void IterateBucketsFb(const Cancellation* cll, bool send_full_sync_cut);
+  void IterateBucketsFb(bool send_full_sync_cut);
 
   // A fiber function that switches to the incremental mode
-  void SwitchIncrementalFb(Context* cntx, LSN lsn);
+  void SwitchIncrementalFb(LSN lsn);
 
   // Called on traversing cursor by IterateBucketsFb.
   bool BucketSaveCb(PrimeTable::bucket_iterator it);
@@ -171,8 +180,8 @@ class SliceSnapshot {
 
   ThreadLocalMutex big_value_mu_;
 
-  std::function<void(std::string)> on_push_;
-  std::function<void()> on_snapshot_finish_;
+  SnapshotDataConsumerInterface* consumer_;
+  Context* cntx_;
 };
 
 }  // namespace dfly
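
For orientation, a consumer of the new interface might look like the toy below, which buffers chunks in memory the way a unit-test double could. `BufferConsumer` and the standalone `main` are hypothetical; a real caller would hand the consumer to the `SliceSnapshot` constructor together with a `CompressionMode`, a `DbSlice*`, and a `Context*`, as declared above.

#include <iostream>
#include <string>

struct Context {  // stand-in for dfly::Context
  bool IsCancelled() const { return false; }
};

struct SnapshotDataConsumerInterface {  // as declared in snapshot.h
  virtual ~SnapshotDataConsumerInterface() = default;
  virtual void ConsumeData(std::string data, Context* cntx) = 0;
  virtual void Finalize() = 0;
};

// Hypothetical test double: collects everything it is fed.
class BufferConsumer final : public SnapshotDataConsumerInterface {
 public:
  void ConsumeData(std::string data, Context* cntx) override {
    if (!cntx->IsCancelled())
      buffer_ += data;
  }
  void Finalize() override { finalized_ = true; }

  const std::string& buffer() const { return buffer_; }
  bool finalized() const { return finalized_; }

 private:
  std::string buffer_;
  bool finalized_ = false;
};

int main() {
  Context cntx;
  BufferConsumer consumer;
  // Real wiring would be:
  //   SliceSnapshot snap(compression_mode, &db_slice, &consumer, &cntx);
  consumer.ConsumeData("hello ", &cntx);
  consumer.ConsumeData("world", &cntx);
  consumer.Finalize();
  std::cout << consumer.buffer() << (consumer.finalized() ? " [done]" : "") << "\n";
}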
