From 46ae1643cff3fc2d476a6742c4a57ef6ea7a1807 Mon Sep 17 00:00:00 2001 From: speshal71 Date: Wed, 5 Mar 2025 11:49:24 +0300 Subject: [PATCH] fix server panic on ping --- server.go | 22 ++++++++++++---------- server_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 10 deletions(-) create mode 100644 server_test.go diff --git a/server.go b/server.go index eb58ec3..116f876 100644 --- a/server.go +++ b/server.go @@ -16,6 +16,8 @@ const saltSize = 32 type QueryHandler func(queryContext context.Context, query Query) *Result type OnShutdownCallback func(err error) +func defaultPingStatus(*IprotoServer) uint { return OKCommand } + type IprotoServer struct { sync.Mutex conn net.Conn @@ -43,13 +45,14 @@ type IprotoServerOptions struct { func NewIprotoServer(uuid string, handler QueryHandler, onShutdown OnShutdownCallback) *IprotoServer { return &IprotoServer{ - conn: nil, - reader: nil, - writer: nil, - handler: handler, - onShutdown: onShutdown, - uuid: uuid, - schemaID: 1, + conn: nil, + reader: nil, + writer: nil, + handler: handler, + onShutdown: onShutdown, + uuid: uuid, + schemaID: 1, + getPingStatus: defaultPingStatus, } } @@ -58,9 +61,8 @@ func (s *IprotoServer) WithOptions(opts *IprotoServerOptions) *IprotoServer { opts = &IprotoServerOptions{} } s.perf = opts.Perf - s.getPingStatus = opts.GetPingStatus - if s.getPingStatus == nil { - s.getPingStatus = func(*IprotoServer) uint { return 0 } + if opts.GetPingStatus != nil { + s.getPingStatus = opts.GetPingStatus } return s } diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..4b30ba1 --- /dev/null +++ b/server_test.go @@ -0,0 +1,44 @@ +package tarantool + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServerPing(t *testing.T) { + handler := func(queryContext context.Context, query Query) *Result { + return &Result{} + } + + s := NewIprotoServer("1", handler, nil) + + listenAddr := make(chan string) + go func() { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + listenAddr <- ln.Addr().String() + close(listenAddr) + + conn, err := ln.Accept() + require.NoError(t, err) + + s.Accept(conn) + }() + + addr := <-listenAddr + conn, err := Connect(addr, nil) + require.NoError(t, err) + + res := conn.Exec(context.Background(), &Ping{}) + assert.Equal(t, res.ErrorCode, OKCommand) + assert.NoError(t, res.Error) + + conn.Close() + s.Shutdown() +}