@@ -33,7 +33,7 @@ type dataLoader[K comparable, V any] struct {
33
33
config config
34
34
mu sync.Mutex
35
35
batch []K
36
- chs []chan Result [V ]
36
+ chs map [ K ] []chan Result [V ]
37
37
}
38
38
39
39
// Interface is a `DataLoader` Interface which defines a public API for loading data from a particular
@@ -93,10 +93,10 @@ func New[K comparable, V any](loader Loader[K, V], options ...Option) Interface[
93
93
dl := & dataLoader [K , V ]{
94
94
loader : loader ,
95
95
config : config ,
96
- batch : make ([]K , 0 , config .BatchSize ),
97
- chs : make ([]chan Result [V ], 0 , config .BatchSize ),
98
96
}
99
97
98
+ dl .reset ()
99
+
100
100
// Create a cache if the cache size is greater than 0
101
101
if config .CacheSize > 0 {
102
102
dl .cache = expirable .NewLRU [K , V ](config .CacheSize , nil , config .CacheExpire )
@@ -150,18 +150,24 @@ func (d *dataLoader[K, V]) Go(ctx context.Context, key K) <-chan Result[V] {
150
150
go d .scheduleBatch (ctx , ch )
151
151
}
152
152
153
+ // Check if the key is in flight
154
+ if chs , ok := d .chs [key ]; ok {
155
+ d .chs [key ] = append (chs , ch )
156
+ d .mu .Unlock ()
157
+ return ch
158
+ }
159
+
153
160
// If the current batch is full, start processing it
154
161
if len (d .batch ) >= d .config .BatchSize {
155
162
// spawn a new goroutine to process the batch
156
163
go d .processBatch (ctx , d .batch , d .chs )
157
164
// Create a new batch, and a new set of channels
158
- d .batch = make ([]K , 0 , d .config .BatchSize )
159
- d .chs = make ([]chan Result [V ], 0 , d .config .BatchSize )
165
+ d .reset ()
160
166
}
161
167
162
168
// Add the key and channel to the current batch
163
169
d .batch = append (d .batch , key )
164
- d .chs = append ( d . chs , ch )
170
+ d .chs [ key ] = [] chan Result [ V ]{ ch }
165
171
166
172
// Unlock the DataLoader
167
173
d .mu .Unlock ()
@@ -204,15 +210,20 @@ func (d *dataLoader[K, V]) LoadMap(ctx context.Context, keys []K) map[K]Result[V
204
210
return results
205
211
}
206
212
213
+ // reset resets the DataLoader
214
+ func (d * dataLoader [K , V ]) reset () {
215
+ d .batch = make ([]K , 0 , d .config .BatchSize )
216
+ d .chs = make (map [K ][]chan Result [V ], d .config .BatchSize )
217
+ }
218
+
207
219
// scheduleBatch schedules a batch to be processed
208
220
func (d * dataLoader [K , V ]) scheduleBatch (ctx context.Context , ch chan Result [V ]) {
209
221
select {
210
222
case <- time .After (d .config .Wait ):
211
223
d .mu .Lock ()
212
224
if len (d .batch ) > 0 {
213
225
go d .processBatch (ctx , d .batch , d .chs )
214
- d .batch = make ([]K , 0 , d .config .BatchSize )
215
- d .chs = make ([]chan Result [V ], 0 , d .config .BatchSize )
226
+ d .reset ()
216
227
}
217
228
d .mu .Unlock ()
218
229
case <- ctx .Done ():
@@ -221,17 +232,19 @@ func (d *dataLoader[K, V]) scheduleBatch(ctx context.Context, ch chan Result[V])
221
232
}
222
233
223
234
// processBatch processes a batch of keys
224
- func (d * dataLoader [K , V ]) processBatch (ctx context.Context , keys []K , chs []chan Result [V ]) {
235
+ func (d * dataLoader [K , V ]) processBatch (ctx context.Context , keys []K , chs map [ K ] []chan Result [V ]) {
225
236
defer func () {
226
237
if r := recover (); r != nil {
227
238
const size = 64 << 10
228
239
buf := make ([]byte , size )
229
240
buf = buf [:runtime .Stack (buf , false )]
230
241
fmt .Fprintf (os .Stderr , "Dataloader: Panic received in loader function: %v\n %s" , r , buf )
231
242
232
- for _ , ch := range chs {
233
- ch <- Result [V ]{err : fmt .Errorf ("panic received in loader function: %v" , r )}
234
- close (ch )
243
+ for _ , chs := range chs {
244
+ for _ , ch := range chs {
245
+ ch <- Result [V ]{err : fmt .Errorf ("panic received in loader function: %v" , r )}
246
+ close (ch )
247
+ }
235
248
}
236
249
return
237
250
}
@@ -242,8 +255,11 @@ func (d *dataLoader[K, V]) processBatch(ctx context.Context, keys []K, chs []cha
242
255
if results [i ].err == nil && d .cache != nil {
243
256
d .cache .Add (key , results [i ].data )
244
257
}
245
- chs [i ] <- results [i ]
246
- close (chs [i ])
258
+
259
+ for _ , ch := range chs [key ] {
260
+ ch <- results [i ]
261
+ close (ch )
262
+ }
247
263
}
248
264
}
249
265
0 commit comments