Skip to content

Commit f203307

Browse files
authored
fix: ShutdownWithContext and ctx.Done() exist race. (#1908)
* fix: ShutdownWithContext and ctx.Done() exist race. * fix: Even if ln.Close() err, the Shutdown process should still proceed. * refactor: remove END label.
1 parent 7b74fc9 commit f203307

File tree

2 files changed

+66
-22
lines changed

2 files changed

+66
-22
lines changed

server.go

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,8 @@ func (s *Server) Shutdown() error {
18871887
//
18881888
// ShutdownWithContext does not close keepalive connections so it's recommended to set ReadTimeout and IdleTimeout
18891889
// to something else than 0.
1890+
//
1891+
// When ShutdownWithContext returns errors, any operation to the Server is unavailable.
18901892
func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
18911893
s.mu.Lock()
18921894
defer s.mu.Unlock()
@@ -1898,11 +1900,7 @@ func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
18981900
return nil
18991901
}
19001902

1901-
for _, ln := range s.ln {
1902-
if err = ln.Close(); err != nil {
1903-
return err
1904-
}
1905-
}
1903+
lnerr := s.closeListenersLocked()
19061904

19071905
if s.done != nil {
19081906
close(s.done)
@@ -1913,28 +1911,25 @@ func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
19131911
// Now we just have to wait until all workers are done or timeout.
19141912
ticker := time.NewTicker(time.Millisecond * 100)
19151913
defer ticker.Stop()
1916-
END:
1914+
19171915
for {
19181916
s.closeIdleConns()
19191917

19201918
if open := atomic.LoadInt32(&s.open); open == 0 {
1921-
break
1919+
// There may be a pending request to call ctx.Done(). Therefore, we only set it to nil when open == 0.
1920+
s.done = nil
1921+
return lnerr
19221922
}
19231923
// This is not an optimal solution but using a sync.WaitGroup
19241924
// here causes data races as it's hard to prevent Add() to be called
19251925
// while Wait() is waiting.
19261926
select {
19271927
case <-ctx.Done():
1928-
err = ctx.Err()
1929-
break END
1928+
return ctx.Err()
19301929
case <-ticker.C:
19311930
continue
19321931
}
19331932
}
1934-
1935-
s.done = nil
1936-
s.ln = nil
1937-
return err
19381933
}
19391934

19401935
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
@@ -2749,15 +2744,7 @@ func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) {
27492744
// Note: Because creating a new channel for every request is just too expensive, so
27502745
// RequestCtx.s.done is only closed when the server is shutting down.
27512746
func (ctx *RequestCtx) Done() <-chan struct{} {
2752-
// fix use new variables to prevent panic caused by modifying the original done chan to nil.
2753-
done := ctx.s.done
2754-
2755-
if done == nil {
2756-
done = make(chan struct{}, 1)
2757-
done <- struct{}{}
2758-
return done
2759-
}
2760-
return done
2747+
return ctx.s.done
27612748
}
27622749

27632750
// Err returns a non-nil error value after Done is closed,
@@ -2934,6 +2921,17 @@ func (s *Server) closeIdleConns() {
29342921
s.idleConnsMu.Unlock()
29352922
}
29362923

2924+
func (s *Server) closeListenersLocked() error {
2925+
var err error
2926+
for _, ln := range s.ln {
2927+
if cerr := ln.Close(); cerr != nil && err == nil {
2928+
err = cerr
2929+
}
2930+
}
2931+
s.ln = nil
2932+
return err
2933+
}
2934+
29372935
// A ConnState represents the state of a client connection to a server.
29382936
// It's used by the optional Server.ConnState hook.
29392937
type ConnState int

server_race_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//go:build race
2+
3+
package fasthttp
4+
5+
import (
6+
"context"
7+
"github.com/valyala/fasthttp/fasthttputil"
8+
"math"
9+
"testing"
10+
)
11+
12+
func TestServerDoneRace(t *testing.T) {
13+
t.Parallel()
14+
15+
s := &Server{
16+
Handler: func(ctx *RequestCtx) {
17+
for i := 0; i < math.MaxInt; i++ {
18+
ctx.Done()
19+
}
20+
},
21+
}
22+
23+
ln := fasthttputil.NewInmemoryListener()
24+
defer ln.Close()
25+
26+
go func() {
27+
if err := s.Serve(ln); err != nil {
28+
t.Errorf("unexpected error: %v", err)
29+
}
30+
}()
31+
32+
c, err := ln.Dial()
33+
if err != nil {
34+
t.Fatalf("unexpected error: %v", err)
35+
}
36+
defer c.Close()
37+
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: go.dev\r\nContent-Length: 3\r\n\r\nABC" +
38+
"\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it
39+
"GET / HTTP/1.1\r\nHost: go.dev\r\n\r\n")); err != nil {
40+
t.Fatal(err)
41+
}
42+
ctx, cancelFunc := context.WithCancel(context.Background())
43+
cancelFunc()
44+
45+
s.ShutdownWithContext(ctx)
46+
}

0 commit comments

Comments
 (0)