Skip to content

Commit ee9bcb9

Browse files
committed
feat: deduplicate identical requests
1 parent 8b1a51f commit ee9bcb9

File tree

4 files changed

+65
-16
lines changed

4 files changed

+65
-16
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ This is an implementation of a dataloader in Go.
99
- 200+ lines of code, easy to understand and maintain.
1010
- 100% test coverage, bug free and reliable.
1111
- Based on generics and can be used with any type of data.
12-
- Use a lru cache to store the loaded values.
12+
- Use an LRU cache to store the loaded values.
1313
- Can be used to batch and cache multiple requests.
14+
- Deduplicate identical requests, reducing the number of requests.
1415

1516
Installation
1617
---

benchmark_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func BenchmarkDataLoader(b *testing.B) {
2323
}
2424
})
2525

26-
b.Run("dataloader.AsyncLoad", func(b *testing.B) {
26+
b.Run("dataloader.Go", func(b *testing.B) {
2727
for i := 0; i < b.N; i++ {
2828
results := make([]<-chan Result[string], 10)
2929
for j := 0; j < 10; j++ {

dataloader.go

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ type dataLoader[K comparable, V any] struct {
3333
config config
3434
mu sync.Mutex
3535
batch []K
36-
chs []chan Result[V]
36+
chs map[K][]chan Result[V]
3737
}
3838

3939
// Interface is a `DataLoader` Interface which defines a public API for loading data from a particular
@@ -93,10 +93,10 @@ func New[K comparable, V any](loader Loader[K, V], options ...Option) Interface[
9393
dl := &dataLoader[K, V]{
9494
loader: loader,
9595
config: config,
96-
batch: make([]K, 0, config.BatchSize),
97-
chs: make([]chan Result[V], 0, config.BatchSize),
9896
}
9997

98+
dl.reset()
99+
100100
// Create a cache if the cache size is greater than 0
101101
if config.CacheSize > 0 {
102102
dl.cache = expirable.NewLRU[K, V](config.CacheSize, nil, config.CacheExpire)
@@ -150,18 +150,24 @@ func (d *dataLoader[K, V]) Go(ctx context.Context, key K) <-chan Result[V] {
150150
go d.scheduleBatch(ctx, ch)
151151
}
152152

153+
// Check if the key is in flight
154+
if chs, ok := d.chs[key]; ok {
155+
d.chs[key] = append(chs, ch)
156+
d.mu.Unlock()
157+
return ch
158+
}
159+
153160
// If the current batch is full, start processing it
154161
if len(d.batch) >= d.config.BatchSize {
155162
// spawn a new goroutine to process the batch
156163
go d.processBatch(ctx, d.batch, d.chs)
157164
// Create a new batch, and a new set of channels
158-
d.batch = make([]K, 0, d.config.BatchSize)
159-
d.chs = make([]chan Result[V], 0, d.config.BatchSize)
165+
d.reset()
160166
}
161167

162168
// Add the key and channel to the current batch
163169
d.batch = append(d.batch, key)
164-
d.chs = append(d.chs, ch)
170+
d.chs[key] = []chan Result[V]{ch}
165171

166172
// Unlock the DataLoader
167173
d.mu.Unlock()
@@ -204,15 +210,20 @@ func (d *dataLoader[K, V]) LoadMap(ctx context.Context, keys []K) map[K]Result[V
204210
return results
205211
}
206212

213+
// reset resets the DataLoader
214+
func (d *dataLoader[K, V]) reset() {
215+
d.batch = make([]K, 0, d.config.BatchSize)
216+
d.chs = make(map[K][]chan Result[V], d.config.BatchSize)
217+
}
218+
207219
// scheduleBatch schedules a batch to be processed
208220
func (d *dataLoader[K, V]) scheduleBatch(ctx context.Context, ch chan Result[V]) {
209221
select {
210222
case <-time.After(d.config.Wait):
211223
d.mu.Lock()
212224
if len(d.batch) > 0 {
213225
go d.processBatch(ctx, d.batch, d.chs)
214-
d.batch = make([]K, 0, d.config.BatchSize)
215-
d.chs = make([]chan Result[V], 0, d.config.BatchSize)
226+
d.reset()
216227
}
217228
d.mu.Unlock()
218229
case <-ctx.Done():
@@ -221,17 +232,19 @@ func (d *dataLoader[K, V]) scheduleBatch(ctx context.Context, ch chan Result[V])
221232
}
222233

223234
// processBatch processes a batch of keys
224-
func (d *dataLoader[K, V]) processBatch(ctx context.Context, keys []K, chs []chan Result[V]) {
235+
func (d *dataLoader[K, V]) processBatch(ctx context.Context, keys []K, chs map[K][]chan Result[V]) {
225236
defer func() {
226237
if r := recover(); r != nil {
227238
const size = 64 << 10
228239
buf := make([]byte, size)
229240
buf = buf[:runtime.Stack(buf, false)]
230241
fmt.Fprintf(os.Stderr, "Dataloader: Panic received in loader function: %v\n%s", r, buf)
231242

232-
for _, ch := range chs {
233-
ch <- Result[V]{err: fmt.Errorf("panic received in loader function: %v", r)}
234-
close(ch)
243+
for _, chs := range chs {
244+
for _, ch := range chs {
245+
ch <- Result[V]{err: fmt.Errorf("panic received in loader function: %v", r)}
246+
close(ch)
247+
}
235248
}
236249
return
237250
}
@@ -242,8 +255,11 @@ func (d *dataLoader[K, V]) processBatch(ctx context.Context, keys []K, chs []cha
242255
if results[i].err == nil && d.cache != nil {
243256
d.cache.Add(key, results[i].data)
244257
}
245-
chs[i] <- results[i]
246-
close(chs[i])
258+
259+
for _, ch := range chs[key] {
260+
ch <- results[i]
261+
close(ch)
262+
}
247263
}
248264
}
249265

dataloader_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,38 @@ func TestDataLoader(t *testing.T) {
2121
t.Run("LoadMap", testLoadMap)
2222
t.Run("Panic recovered", testPanicRecovered)
2323
t.Run("Prime", testPrime)
24+
t.Run("Inflight", testInflight)
25+
}
26+
27+
func testInflight(t *testing.T) {
28+
loader := New(func(ctx context.Context, keys []int) []Result[string] {
29+
if len(keys) != 5 {
30+
t.Errorf("Expected 5 keys, got %d", keys)
31+
}
32+
33+
results := make([]Result[string], len(keys))
34+
for i, key := range keys {
35+
results[i] = Result[string]{data: fmt.Sprintf("Result for %d", key)}
36+
}
37+
return results
38+
}, WithBatchSize(5))
39+
40+
chs := make([]<-chan Result[string], 0)
41+
for i := 0; i < 10; i++ {
42+
chs = append(chs, loader.Go(context.Background(), i/2))
43+
}
44+
45+
for idx, ch := range chs {
46+
result := <-ch
47+
data, err := result.Unwrap()
48+
if err != nil {
49+
t.Errorf("Unexpected error: %v", err)
50+
}
51+
52+
if data != fmt.Sprintf("Result for %d", idx/2) {
53+
t.Errorf("Unexpected result: %v", data)
54+
}
55+
}
2456
}
2557

2658
func testBasicFunctionality(t *testing.T) {

0 commit comments

Comments
 (0)