Skip to content

Commit 46ae164

Browse files
committed
fix server panic on ping
1 parent 73458f1 commit 46ae164

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

server.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ const saltSize = 32
1616
type QueryHandler func(queryContext context.Context, query Query) *Result
1717
type OnShutdownCallback func(err error)
1818

19+
func defaultPingStatus(*IprotoServer) uint { return OKCommand }
20+
1921
type IprotoServer struct {
2022
sync.Mutex
2123
conn net.Conn
@@ -43,13 +45,14 @@ type IprotoServerOptions struct {
4345

4446
func NewIprotoServer(uuid string, handler QueryHandler, onShutdown OnShutdownCallback) *IprotoServer {
4547
return &IprotoServer{
46-
conn: nil,
47-
reader: nil,
48-
writer: nil,
49-
handler: handler,
50-
onShutdown: onShutdown,
51-
uuid: uuid,
52-
schemaID: 1,
48+
conn: nil,
49+
reader: nil,
50+
writer: nil,
51+
handler: handler,
52+
onShutdown: onShutdown,
53+
uuid: uuid,
54+
schemaID: 1,
55+
getPingStatus: defaultPingStatus,
5356
}
5457
}
5558

@@ -58,9 +61,8 @@ func (s *IprotoServer) WithOptions(opts *IprotoServerOptions) *IprotoServer {
5861
opts = &IprotoServerOptions{}
5962
}
6063
s.perf = opts.Perf
61-
s.getPingStatus = opts.GetPingStatus
62-
if s.getPingStatus == nil {
63-
s.getPingStatus = func(*IprotoServer) uint { return 0 }
64+
if opts.GetPingStatus != nil {
65+
s.getPingStatus = opts.GetPingStatus
6466
}
6567
return s
6668
}

server_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package tarantool
2+
3+
import (
4+
"context"
5+
"net"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestServerPing(t *testing.T) {
13+
handler := func(queryContext context.Context, query Query) *Result {
14+
return &Result{}
15+
}
16+
17+
s := NewIprotoServer("1", handler, nil)
18+
19+
listenAddr := make(chan string)
20+
go func() {
21+
ln, err := net.Listen("tcp", "127.0.0.1:0")
22+
require.NoError(t, err)
23+
defer ln.Close()
24+
25+
listenAddr <- ln.Addr().String()
26+
close(listenAddr)
27+
28+
conn, err := ln.Accept()
29+
require.NoError(t, err)
30+
31+
s.Accept(conn)
32+
}()
33+
34+
addr := <-listenAddr
35+
conn, err := Connect(addr, nil)
36+
require.NoError(t, err)
37+
38+
res := conn.Exec(context.Background(), &Ping{})
39+
assert.Equal(t, res.ErrorCode, OKCommand)
40+
assert.NoError(t, res.Error)
41+
42+
conn.Close()
43+
s.Shutdown()
44+
}

0 commit comments

Comments
 (0)