From 39392681c4a42ade314172f9063c0027048defba Mon Sep 17 00:00:00 2001 From: Alexandre Beslic Date: Thu, 18 Jul 2024 20:06:01 +0100 Subject: [PATCH] Fix stuck call to Dial when calling Stop on the initiator This commit fixes an issue that occurs when calling Start() and then Stop() on the initiator while the connection attempt is likely to fail and time out. Calling initiator.Stop() blocks because Dial keeps trying to connect until it times out, and only returns once it reaches the 'waitForReconnectInterval' call. We mitigate this problem by using a proxy.ContextDialer and passing a cancellable context to the dialer.DialContext method in 'handleConnection'. We start a goroutine that listens on stopChan in order to call cancel() explicitly and thus exit the DialContext method. Note: there are scenarios where cancel() will be called twice. This choice was made to avoid a larger refactor of the reconnect logic; since the call to cancel() is idempotent, it does not lead to any adverse effect. 
fixes quickfixgo#653 Signed-off-by: Alexandre Beslic --- dialer.go | 18 ++++++++++++++++-- initiator.go | 26 +++++++++++++++++++++++--- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/dialer.go b/dialer.go index a8e4c1893..de4419ff8 100644 --- a/dialer.go +++ b/dialer.go @@ -25,7 +25,7 @@ import ( "github.com/quickfixgo/quickfix/config" ) -func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) { +func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) { stdDialer := &net.Dialer{} if settings.HasSetting(config.SocketTimeout) { timeout, err := settings.DurationSetting(config.SocketTimeout) @@ -73,9 +73,23 @@ func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error } } - dialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, dialer) + var proxyDialer proxy.Dialer + + proxyDialer, err = proxy.SOCKS5("tcp", fmt.Sprintf("%s:%d", proxyHost, proxyPort), proxyAuth, stdDialer) + if err != nil { + return + } + + if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok { + dialer = contextDialer + } else { + err = fmt.Errorf("proxy does not support context dialer") + return + } + default: err = fmt.Errorf("unsupported proxy type %s", proxyType) } + return } diff --git a/initiator.go b/initiator.go index 8f7a76200..18451477e 100644 --- a/initiator.go +++ b/initiator.go @@ -17,6 +17,7 @@ package quickfix import ( "bufio" + "context" "crypto/tls" "strings" "sync" @@ -50,7 +51,7 @@ func (i *Initiator) Start() (err error) { return } - var dialer proxy.Dialer + var dialer proxy.ContextDialer if dialer, err = loadDialerConfig(settings); err != nil { return } @@ -142,7 +143,7 @@ func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bo return true } -func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.Dialer) { +func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, 
dialer proxy.ContextDialer) { var wg sync.WaitGroup wg.Add(1) go func() { @@ -162,6 +163,19 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di return } + ctx, cancel := context.WithCancel(context.Background()) + + // We start a goroutine in order to be able to cancel the dialer mid-connection + // on receiving a stop signal to stop the initiator. + go func() { + select { + case <-i.stopChan: + cancel() + case <-ctx.Done(): + return + } + }() + var disconnected chan interface{} var msgIn chan fixIn var msgOut chan []byte @@ -169,7 +183,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)] session.log.OnEventf("Connecting to: %v", address) - netConn, err := dialer.Dial("tcp", address) + netConn, err := dialer.DialContext(ctx, "tcp", address) if err != nil { session.log.OnEventf("Failed to connect: %v", err) goto reconnect @@ -208,6 +222,10 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di close(disconnected) }() + // This ensures we properly cleanup the goroutine and context used for + // dial cancelation after successful connection. + cancel() + select { case <-disconnected: case <-i.stopChan: @@ -215,6 +233,8 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di } reconnect: + cancel() + connectionAttempt++ session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval) if !i.waitForReconnectInterval(session.ReconnectInterval) {