Skip to content

Commit 70187f4

Browse files
committed
Correcting a blocking issue in Fan Out
1 parent 06a285f commit 70187f4

File tree

2 files changed

+105
-28
lines changed

2 files changed

+105
-28
lines changed

stream.go

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ package stream
1313
import (
1414
"context"
1515
"reflect"
16+
17+
. "go.structs.dev/gen"
1618
)
1719

1820
// Pipe accepts an incoming data channel and pipes it to the supplied
@@ -131,25 +133,45 @@ func FanOut[T any](ctx context.Context, in <-chan T, out ...chan<- T) {
131133
return
132134
}
133135

134-
for _, o := range out {
135-
// Closure to catch panic on closed channel write.
136-
// Continue Loop
137-
func() {
138-
select {
139-
case <-ctx.Done():
140-
return
141-
case o <- v:
142-
}
143-
}()
136+
// Closure to catch panic on closed channel write.
137+
selectCases := make([]reflect.SelectCase, 0, len(out)+1)
138+
139+
// 0 index is context
140+
selectCases = append(selectCases, reflect.SelectCase{
141+
Dir: reflect.SelectRecv,
142+
Chan: reflect.ValueOf(ctx.Done()),
143+
})
144+
145+
for _, outc := range out {
146+
// Skip nil channels until they are non-nil
147+
if outc == nil {
148+
continue
149+
}
150+
151+
selectCases = append(selectCases, reflect.SelectCase{
152+
Dir: reflect.SelectSend,
153+
Chan: reflect.ValueOf(outc),
154+
Send: reflect.ValueOf(v),
155+
})
156+
}
157+
158+
for len(selectCases) > 1 {
159+
chosen, _, _ := reflect.Select(selectCases)
160+
161+
// The context was cancelled.
162+
if chosen == 0 {
163+
return
164+
}
165+
166+
selectCases = Exclude(selectCases, selectCases[chosen])
144167
}
145168
}
146169

147170
}
148171
}
149172

150173
// Distribute accepts an incoming data channel and distributes the data among
151-
// the supplied outgoing data channels. This distribution is done stochastically
152-
// using the cryptographic random number generator.
174+
// the supplied outgoing data channels using a dynamic select statement.
153175
//
154176
// NOTE: Execute the Distribute function in a goroutine if parallel execution is
155177
// desired. Cancelling the context or closing the incoming channel is important
@@ -170,24 +192,19 @@ func Distribute[T any](ctx context.Context, in <-chan T, out ...chan<- T) {
170192
return
171193
}
172194

173-
// Closure to catch panic on closed channel write.
174-
func() {
175-
defer recover()
176-
177-
selectCases := make([]reflect.SelectCase, 0, len(out)+1)
178-
for _, outc := range out {
179-
selectCases = append(selectCases, reflect.SelectCase{
180-
Dir: reflect.SelectSend,
181-
Chan: reflect.ValueOf(outc),
182-
Send: reflect.ValueOf(v),
183-
})
184-
}
195+
selectCases := make([]reflect.SelectCase, 0, len(out)+1)
196+
for _, outc := range out {
185197
selectCases = append(selectCases, reflect.SelectCase{
186-
Dir: reflect.SelectRecv,
187-
Chan: reflect.ValueOf(ctx.Done()),
198+
Dir: reflect.SelectSend,
199+
Chan: reflect.ValueOf(outc),
200+
Send: reflect.ValueOf(v),
188201
})
189-
_, _, _ = reflect.Select(selectCases)
190-
}()
202+
}
203+
selectCases = append(selectCases, reflect.SelectCase{
204+
Dir: reflect.SelectRecv,
205+
Chan: reflect.ValueOf(ctx.Done()),
206+
})
207+
_, _, _ = reflect.Select(selectCases)
191208
}
192209

193210
}

stream_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,66 @@ func Test_Distribute_ZeroOut(t *testing.T) {
443443
Distribute(ctx, in)
444444
}
445445

446+
func Test_FanOut(t *testing.T) {
447+
ctx, cancel := context.WithCancel(context.Background())
448+
defer cancel()
449+
450+
c1, c2, c3 := make(chan int), make(chan int), make(chan int)
451+
var c4 chan int
452+
data := Ints[int](1000)
453+
454+
go FanOut(ctx, Slice[int](data).Chan(ctx), c1, c2, c3, c4)
455+
456+
seen := make(map[int]int)
457+
for i := 0; i < len(data)*3; i++ {
458+
select {
459+
case <-ctx.Done():
460+
t.Fatal("context cancelled")
461+
return
462+
case _, ok := <-c1:
463+
if !ok {
464+
return
465+
}
466+
467+
seen[1]++
468+
case _, ok := <-c2:
469+
if !ok {
470+
return
471+
}
472+
473+
seen[2]++
474+
case _, ok := <-c3:
475+
if !ok {
476+
return
477+
}
478+
479+
seen[3]++
480+
case _, ok := <-c4:
481+
if !ok {
482+
return
483+
}
484+
485+
seen[4]++
486+
}
487+
}
488+
489+
if len(seen) != 3 {
490+
t.Fatalf("expected %v, got %v", len(data)-1, len(seen))
491+
}
492+
493+
for k, v := range seen {
494+
if k == 4 {
495+
if v > 0 {
496+
t.Fatalf("expected %v, got %v", 0, v)
497+
}
498+
}
499+
500+
if v != len(data) {
501+
t.Fatalf("Chan C%v: expected %v, got %v", k, len(data), v)
502+
}
503+
}
504+
}
505+
446506
func Test_FanOut_ZeroOut(t *testing.T) {
447507
ctx, cancel := context.WithCancel(context.Background())
448508
defer cancel()

0 commit comments

Comments
 (0)