Skip to content

Commit 6b425f7

Browse files
authored
fix data race sending mail (#82)
* fix data race sending mail when timeout exceed * change localhost to 127.0.0.1 in test * minimal fixes in test * reduce loop logic
1 parent 6250c42 commit 6b425f7

File tree

2 files changed

+136
-13
lines changed

2 files changed

+136
-13
lines changed

email.go

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/textproto"
1111
"strconv"
1212
"strings"
13+
"sync"
1314
"time"
1415

1516
"github.com/toorop/go-dkim"
@@ -55,6 +56,7 @@ type SMTPServer struct {
5556

5657
// SMTPClient represents a SMTP Client for send email
5758
type SMTPClient struct {
59+
mu sync.Mutex
5860
Client *smtpClient
5961
KeepAlive bool
6062
SendTimeout time.Duration
@@ -865,21 +867,29 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) {
865867

866868
// Reset send RSET command to smtp client
867869
func (smtpClient *SMTPClient) Reset() error {
870+
smtpClient.mu.Lock()
871+
defer smtpClient.mu.Unlock()
868872
return smtpClient.Client.reset()
869873
}
870874

871875
// Noop send NOOP command to smtp client
872876
func (smtpClient *SMTPClient) Noop() error {
877+
smtpClient.mu.Lock()
878+
defer smtpClient.mu.Unlock()
873879
return smtpClient.Client.noop()
874880
}
875881

876882
// Quit send QUIT command to smtp client
877883
func (smtpClient *SMTPClient) Quit() error {
884+
smtpClient.mu.Lock()
885+
defer smtpClient.mu.Unlock()
878886
return smtpClient.Client.quit()
879887
}
880888

881889
// Close closes the connection
882890
func (smtpClient *SMTPClient) Close() error {
891+
smtpClient.mu.Lock()
892+
defer smtpClient.mu.Unlock()
883893
return smtpClient.Client.close()
884894
}
885895

@@ -909,14 +919,14 @@ func send(from string, to []string, msg string, client *SMTPClient) error {
909919
if client.SendTimeout != 0 {
910920
smtpSendChannel = make(chan error, 1)
911921

912-
go func(from string, to []string, msg string, c *smtpClient) {
913-
smtpSendChannel <- sendMailProcess(from, to, msg, c)
914-
}(from, to, msg, client.Client)
922+
go func(from string, to []string, msg string, client *SMTPClient) {
923+
smtpSendChannel <- sendMailProcess(from, to, msg, client)
924+
}(from, to, msg, client)
915925
}
916926

917927
if client.SendTimeout == 0 {
918928
// no SendTimeout, just fire the sendMailProcess
919-
return sendMailProcess(from, to, msg, client.Client)
929+
return sendMailProcess(from, to, msg, client)
920930
}
921931

922932
// get the send result or timeout result, which ever happens first
@@ -928,35 +938,36 @@ func send(from string, to []string, msg string, client *SMTPClient) error {
928938
checkKeepAlive(client)
929939
return errors.New("Mail Error: SMTP Send timed out")
930940
}
931-
932941
}
933942
}
934943

935944
return errors.New("Mail Error: No SMTP Client Provided")
936945
}
937946

938-
func sendMailProcess(from string, to []string, msg string, c *smtpClient) error {
947+
func sendMailProcess(from string, to []string, msg string, c *SMTPClient) error {
948+
c.mu.Lock()
949+
defer c.mu.Unlock()
939950

940951
cmdArgs := make(map[string]string)
941952

942-
if _, ok := c.ext["SIZE"]; ok {
953+
if _, ok := c.Client.ext["SIZE"]; ok {
943954
cmdArgs["SIZE"] = strconv.Itoa(len(msg))
944955
}
945956

946957
// Set the sender
947-
if err := c.mail(from, cmdArgs); err != nil {
958+
if err := c.Client.mail(from, cmdArgs); err != nil {
948959
return err
949960
}
950961

951962
// Set the recipients
952963
for _, address := range to {
953-
if err := c.rcpt(address); err != nil {
964+
if err := c.Client.rcpt(address); err != nil {
954965
return err
955966
}
956967
}
957968

958969
// Send the data command
959-
w, err := c.data()
970+
w, err := c.Client.data()
960971
if err != nil {
961972
return err
962973
}
@@ -978,9 +989,9 @@ func sendMailProcess(from string, to []string, msg string, c *smtpClient) error
978989
// check if keepAlive for close or reset
979990
func checkKeepAlive(client *SMTPClient) {
980991
if client.KeepAlive {
981-
client.Client.reset()
992+
client.Reset()
982993
} else {
983-
client.Client.quit()
984-
client.Client.close()
994+
client.Quit()
995+
client.Close()
985996
}
986997
}

email_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package mail
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"net"
7+
"testing"
8+
"time"
9+
)
10+
11+
func TestSendRace(t *testing.T) {
12+
port := 56666
13+
port2 := 56667
14+
timeout := 1 * time.Second
15+
16+
responses := []string{
17+
`220 test connected`,
18+
`250 after helo`,
19+
`250 after mail from`,
20+
`250 after rcpt to`,
21+
`354 after data`,
22+
}
23+
24+
startService(port, responses, 5*time.Second)
25+
startService(port2, responses, 0)
26+
27+
server := NewSMTPClient()
28+
server.ConnectTimeout = timeout
29+
server.SendTimeout = timeout
30+
server.KeepAlive = false
31+
server.Host = `127.0.0.1`
32+
server.Port = port
33+
34+
smtpClient, err := server.Connect()
35+
if err != nil {
36+
log.Fatalf("couldn't connect: %s", err.Error())
37+
}
38+
defer smtpClient.Close()
39+
40+
// create another server in other port to test timeouts
41+
server.Port = port2
42+
smtpClient2, err := server.Connect()
43+
if err != nil {
44+
log.Fatalf("couldn't connect: %s", err.Error())
45+
}
46+
defer smtpClient2.Close()
47+
48+
msg := NewMSG().
49+
SetFrom(`foo@bar`).
50+
AddTo(`rcpt@bar`).
51+
SetSubject("subject").
52+
SetBody(TextPlain, "body")
53+
54+
// the smtpClient2 has not timeout
55+
err = msg.Send(smtpClient2)
56+
if err != nil {
57+
log.Fatalf("couldn't send: %s", err.Error())
58+
}
59+
60+
// the smtpClient send to listener with the last response is after SendTimeout, so when this error is returned the test succeed.
61+
err = msg.Send(smtpClient)
62+
if err != nil && err.Error() != "Mail Error: SMTP Send timed out" {
63+
log.Fatalf("couldn't send: %s", err.Error())
64+
}
65+
}
66+
67+
func startService(port int, responses []string, timeout time.Duration) {
68+
log.Printf("starting service at %d...\n", port)
69+
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
70+
if err != nil {
71+
log.Fatalf("couldn't listen to port %d: %s", port, err)
72+
}
73+
74+
go func() {
75+
for {
76+
conn, err := listener.Accept()
77+
if err != nil {
78+
log.Fatalf("couldn't listen accept the request in port %d", port)
79+
}
80+
go respond(conn, responses, timeout)
81+
}
82+
}()
83+
}
84+
85+
func respond(conn net.Conn, responses []string, timeout time.Duration) {
86+
buf := make([]byte, 1024)
87+
for _, resp := range responses {
88+
write(conn, resp)
89+
n, err := conn.Read(buf)
90+
if err != nil {
91+
log.Println("couldn't read data")
92+
return
93+
}
94+
readStr := string(buf[:n])
95+
log.Printf("READ:%s", string(readStr))
96+
}
97+
98+
// if timeout, sleep for that time, otherwise sent a 250 OK
99+
if timeout > 0 {
100+
time.Sleep(timeout)
101+
} else {
102+
write(conn, "250 OK")
103+
}
104+
105+
conn.Close()
106+
fmt.Print("\n\n")
107+
}
108+
109+
func write(conn net.Conn, command string) {
110+
log.Printf("WRITE:%s", command)
111+
conn.Write([]byte(command + "\n"))
112+
}

0 commit comments

Comments
 (0)