diff --git a/accepter_test.go b/accepter_test.go index 54bfff845..5cf888a8c 100644 --- a/accepter_test.go +++ b/accepter_test.go @@ -16,6 +16,7 @@ package quickfix import ( + "crypto/tls" "net" "testing" @@ -23,6 +24,7 @@ import ( proxyproto "github.com/pires/go-proxyproto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAcceptor_Start(t *testing.T) { @@ -83,3 +85,44 @@ func TestAcceptor_Start(t *testing.T) { }) } } + +func TestAcceptor_SetTLSConfig(t *testing.T) { + sessionSettings := NewSessionSettings() + sessionSettings.Set(config.BeginString, BeginStringFIX42) + sessionSettings.Set(config.SenderCompID, "sender") + sessionSettings.Set(config.TargetCompID, "target") + + genericSettings := NewSettings() + + genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001") + _, err := genericSettings.AddSession(sessionSettings) + require.NoError(t, err) + + logger, err := NewScreenLogFactory().Create() + require.NoError(t, err) + acceptor := &Acceptor{settings: genericSettings, globalLog: logger} + defer acceptor.Stop() + // example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function + // as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs. + customizedTLSConfig := tls.Config{ + Certificates: []tls.Certificate{}, + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key") + if err != nil { + return nil, err + } + return &cert, nil + }, + } + + acceptor.SetTLSConfig(&customizedTLSConfig) + assert.NoError(t, acceptor.Start()) + assert.Len(t, acceptor.listeners, 1) + + conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{ + InsecureSkipVerify: true, + }) + require.NoError(t, err) + assert.NotNil(t, conn) + defer conn.Close() +} diff --git a/acceptor.go b/acceptor.go index f5b9b281c..2f9f6c48b 100644 --- a/acceptor.go +++ b/acceptor.go @@ -48,6 +48,7 @@ type Acceptor struct { sessionHostPort map[SessionID]int listeners map[string]net.Listener connectionValidator ConnectionValidator + tlsConfig *tls.Config sessionFactory } @@ -81,9 +82,12 @@ func (a *Acceptor) Start() (err error) { a.listeners[address] = nil } - var tlsConfig *tls.Config - if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil { - return + if a.tlsConfig == nil { + var tlsConfig *tls.Config + if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil { + return + } + a.tlsConfig = tlsConfig } var useTCPProxy bool @@ -94,8 +98,8 @@ func (a *Acceptor) Start() (err error) { } for address := range a.listeners { - if tlsConfig != nil { - if a.listeners[address], err = tls.Listen("tcp", address, tlsConfig); err != nil { + if a.tlsConfig != nil { + if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil { return } } else if a.listeners[address], err = net.Listen("tcp", address); err != nil { @@ -421,3 +425,13 @@ LOOP: func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { a.connectionValidator = validator } + +// SetTLSConfig allows the creator of the Acceptor to specify a fully customizable tls.Config of their choice, +// which will be used in the Start() method. +// +// Note: when the caller explicitly provides a tls.Config with this function, +// it takes precendent over TLS settings specified in the acceptor's settings.GlobalSettings(), +// meaning that the `settings.GlobalSettings()` object is not inspected or used for the creation of the tls.Config. +func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) { + a.tlsConfig = tlsConfig +}