Skip to content

Commit ff43f86

Browse files
chore: refactor ssh pool implementation (#1041) (#1042)
* chore: modularize the MonitorServerStatus function * feat: no need to delete ssh client if server is already offline - because ssh check function already done the job * feat: reduce delete ssh client call * feat: optimize the server status fixing algo * feat: pre-validation for ssh has been added * chore: reduce complexity (cherry picked from commit 38e6664) Co-authored-by: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com>
1 parent a360b95 commit ff43f86

File tree

8 files changed

+121
-50
lines changed

8 files changed

+121
-50
lines changed

ssh_toolkit/command.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@ import (
1414
func ExecCommandOverSSH(cmd string,
1515
stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task
1616
host string, port int, user string, privateKey string, // for ssh client
17+
) error {
18+
return ExecCommandOverSSHWithOptions(cmd, stdoutBuf, stderrBuf, sessionTimeoutSeconds, host, port, user, privateKey, true)
19+
}
20+
21+
func ExecCommandOverSSHWithOptions(cmd string,
22+
stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task
23+
host string, port int, user string, privateKey string, // for ssh client
24+
validate bool, // if true, will validate if server is online
1725
) error {
1826
// fetch ssh client
19-
sshRecord, err := getSSHClient(host, port, user, privateKey)
27+
sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, validate)
2028
if err != nil {
2129
if isErrorWhenSSHClientNeedToBeRecreated(err) {
2230
DeleteSSHClient(host)
@@ -49,12 +57,9 @@ func ExecCommandOverSSH(cmd string,
4957
// run command
5058
err = session.Run(cmd)
5159
if err != nil {
52-
if isErrorWhenSSHClientNeedToBeRecreated(err) {
53-
DeleteSSHClient(host)
54-
}
55-
if isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) {
60+
if isErrorWhenSSHClientNeedToBeRecreated(err) || isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) {
5661
DeleteSSHClient(host)
57-
return fmt.Errorf("%s - %s", err, stderrBuf.String())
62+
return fmt.Errorf("%s - %s", err.Error(), stderrBuf.String())
5863
}
5964
return err
6065
}

ssh_toolkit/errors.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package ssh_toolkit
22

3-
import "strings"
3+
import (
4+
"strings"
5+
)
46

57
var errorsWhenSSHClientNeedToBeRecreated = []string{
68
"dial timeout",
@@ -23,17 +25,20 @@ var errorsWhenSSHClientNeedToBeRecreated = []string{
2325
"open failed",
2426
"handshake failed",
2527
"subsystem request failed",
26-
"EOF",
28+
"eof",
2729
"broken pipe",
2830
"closing write end of pipe",
31+
"connection reset by peer",
32+
"unexpected packet in response to channel open",
2933
}
3034

3135
func isErrorWhenSSHClientNeedToBeRecreated(err error) bool {
3236
if err == nil {
3337
return false
3438
}
39+
errMsg := strings.ToLower(err.Error())
3540
for _, msg := range errorsWhenSSHClientNeedToBeRecreated {
36-
if strings.Contains(err.Error(), msg) {
41+
if strings.Contains(errMsg, msg) {
3742
return true
3843
}
3944
}

ssh_toolkit/net_conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func NetConnOverSSH(
1212
host string, port int, user string, privateKey string, // for ssh client
1313
) (net.Conn, error) {
1414
// fetch ssh client
15-
sshRecord, err := getSSHClient(host, port, user, privateKey)
15+
sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, true)
1616
if err != nil {
1717
if isErrorWhenSSHClientNeedToBeRecreated(err) {
1818
DeleteSSHClient(host)

ssh_toolkit/pool.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ssh_toolkit
22

33
import (
4+
"errors"
45
"fmt"
56
"log"
67
"sync"
@@ -13,12 +14,23 @@ var sshClientPool *sshConnectionPool
1314

1415
func init() {
1516
sshClientPool = &sshConnectionPool{
16-
clients: make(map[string]*sshClient),
17-
mutex: &sync.RWMutex{},
17+
clients: make(map[string]*sshClient),
18+
mutex: &sync.RWMutex{},
19+
validator: nil,
1820
}
1921
}
2022

21-
func getSSHClient(host string, port int, user string, privateKey string) (*ssh.Client, error) {
23+
func SetValidator(validator ServerOnlineStatusValidator) {
24+
sshClientPool.mutex.Lock()
25+
defer sshClientPool.mutex.Unlock()
26+
sshClientPool.validator = &validator
27+
}
28+
29+
func getSSHClientWithOptions(host string, port int, user string, privateKey string, validate bool) (*ssh.Client, error) {
30+
// reject if server is offline
31+
if validate && sshClientPool.validator != nil && !(*sshClientPool.validator)(host) {
32+
return nil, errors.New("server is offline, cannot connect to it")
33+
}
2234
sshClientPool.mutex.RLock()
2335
clientEntry, ok := sshClientPool.clients[host]
2436
sshClientPool.mutex.RUnlock()
@@ -102,7 +114,7 @@ func DeleteSSHClient(host string) {
102114
}
103115
}
104116
clientEntry.mutex.Unlock()
117+
delete(sshClientPool.clients, host)
105118
}
106-
delete(sshClientPool.clients, host)
107119
sshClientPool.mutex.Unlock()
108120
}

ssh_toolkit/types.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
package ssh_toolkit
22

33
import (
4-
"golang.org/x/crypto/ssh"
54
"sync"
5+
6+
"golang.org/x/crypto/ssh"
67
)
78

89
type sshConnectionPool struct {
9-
clients map[string]*sshClient // map of <host:port> to sshClient
10-
mutex *sync.RWMutex
10+
clients map[string]*sshClient // map of <host:port> to sshClient
11+
mutex *sync.RWMutex
12+
validator *ServerOnlineStatusValidator
1113
}
1214

15+
type ServerOnlineStatusValidator func(host string) bool
16+
1317
type sshClient struct {
1418
client *ssh.Client
1519
mutex *sync.RWMutex

swiftwave_service/core/server.operations.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ package core
33
import (
44
"errors"
55
"fmt"
6-
"gorm.io/gorm"
76
"net"
87
"time"
8+
9+
"gorm.io/gorm"
910
)
1011

1112
// CreateServer creates a new server in the database
@@ -109,6 +110,16 @@ func FetchServerByID(db *gorm.DB, id uint) (*Server, error) {
109110
return &server, err
110111
}
111112

113+
// FetchServerByIP fetches a server by its IP from the database
114+
func FetchServerByIP(db *gorm.DB, ip string) (*Server, error) {
115+
var server Server
116+
err := db.Where("ip = ?", ip).First(&server).Error
117+
if errors.Is(err, gorm.ErrRecordNotFound) {
118+
return nil, errors.New("server not found")
119+
}
120+
return &server, err
121+
}
122+
112123
// FetchServerIDByHostName fetches a server by its hostname from the database
113124
func FetchServerIDByHostName(db *gorm.DB, hostName string) (uint, error) {
114125
var server Server

swiftwave_service/cronjob/server_status_monitor.go

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@ package cronjob
22

33
import (
44
"bytes"
5+
"strings"
6+
"sync"
7+
"time"
8+
59
"github.com/swiftwave-org/swiftwave/ssh_toolkit"
610
"github.com/swiftwave-org/swiftwave/swiftwave_service/core"
711
"github.com/swiftwave-org/swiftwave/swiftwave_service/logger"
8-
"strings"
9-
"time"
1012
)
1113

1214
func (m Manager) MonitorServerStatus() {
1315
logger.CronJobLogger.Println("Starting server status monitor [cronjob]")
1416
for {
1517
m.monitorServerStatus()
16-
time.Sleep(1 * time.Minute)
18+
time.Sleep(2 * time.Second)
1719
}
1820
}
1921

@@ -29,43 +31,64 @@ func (m Manager) monitorServerStatus() {
2931
if len(servers) == 0 {
3032
logger.CronJobLogger.Println("Skipping ! No server found")
3133
return
34+
}
35+
36+
var wg sync.WaitGroup
37+
for _, server := range servers {
38+
if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing {
39+
continue
40+
}
41+
wg.Add(1)
42+
go func(server core.Server) {
43+
defer wg.Done()
44+
m.checkAndUpdateServerStatus(server)
45+
}(server)
46+
}
47+
wg.Wait()
48+
}
49+
50+
func (m Manager) checkAndUpdateServerStatus(server core.Server) {
51+
if m.isServerOnline(server) {
52+
if server.Status != core.ServerOnline {
53+
err := core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server)
54+
if err != nil {
55+
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online >", server.HostName, err)
56+
} else {
57+
logger.CronJobLogger.Println("Server marked as online >", server.HostName)
58+
}
59+
}
3260
} else {
33-
for _, server := range servers {
34-
if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing {
35-
continue
61+
if server.Status != core.ServerOffline {
62+
err := core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server)
63+
if err != nil {
64+
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline >", server.HostName, err)
65+
} else {
66+
logger.CronJobLogger.Println("Server marked as offline >", server.HostName)
3667
}
37-
go func(server core.Server) {
38-
if server.Status == core.ServerOffline {
39-
ssh_toolkit.DeleteSSHClient(server.HostName)
40-
}
41-
if m.isServerOnline(server) {
42-
err = core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server)
43-
if err != nil {
44-
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online > ", server.HostName)
45-
} else {
46-
logger.CronJobLogger.Println("Server marked as online > ", server.HostName)
47-
}
48-
} else {
49-
err = core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server)
50-
if err != nil {
51-
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline > ", server.HostName)
52-
} else {
53-
logger.CronJobLogger.Println("Server marked as offline > ", server.HostName)
54-
}
55-
}
56-
}(server)
68+
} else {
69+
logger.CronJobLogger.Println("Server already offline >", server.HostName)
5770
}
5871
}
5972
}
6073

6174
func (m Manager) isServerOnline(server core.Server) bool {
75+
retries := 3 // try for 3 times before giving up
76+
if server.Status == core.ServerOffline {
77+
/**
78+
* If server is offline, try only once
79+
* Else, it will take total 30 seconds (3 retries * 10 seconds of default SSH timeout)
80+
*/
81+
retries = 1
82+
}
6283
// try for 3 times
63-
for i := 0; i < 3; i++ {
84+
for i := 0; i < retries; i++ {
6485
cmd := "echo ok"
6586
stdoutBuf := new(bytes.Buffer)
6687
stderrBuf := new(bytes.Buffer)
67-
err := ssh_toolkit.ExecCommandOverSSH(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey)
88+
err := ssh_toolkit.ExecCommandOverSSHWithOptions(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey, false)
6889
if err != nil {
90+
logger.CronJobLoggerError.Println("Error while checking if server is online", server.HostName, err.Error())
91+
time.Sleep(1 * time.Second)
6992
continue
7093
}
7194
if strings.Compare(strings.TrimSpace(stdoutBuf.String()), "ok") == 0 {

swiftwave_service/main.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@ package swiftwave
33
import (
44
"context"
55
"fmt"
6+
"log"
7+
"net/http"
8+
"strings"
9+
610
"github.com/fatih/color"
711
"github.com/golang-jwt/jwt/v5"
812
echojwt "github.com/labstack/echo-jwt/v4"
13+
"github.com/swiftwave-org/swiftwave/ssh_toolkit"
914
"github.com/swiftwave-org/swiftwave/swiftwave_service/config"
1015
"github.com/swiftwave-org/swiftwave/swiftwave_service/console"
1116
"github.com/swiftwave-org/swiftwave/swiftwave_service/core"
1217
"github.com/swiftwave-org/swiftwave/swiftwave_service/dashboard"
1318
"github.com/swiftwave-org/swiftwave/swiftwave_service/logger"
1419
"github.com/swiftwave-org/swiftwave/swiftwave_service/service_manager"
15-
"log"
16-
"net/http"
17-
"strings"
1820

1921
"github.com/labstack/echo/v4"
2022
"github.com/labstack/echo/v4/middleware"
@@ -32,6 +34,15 @@ func StartSwiftwave(config *config.Config) {
3234
}
3335
manager.Load(*config)
3436

37+
// Set the server status validator for ssh
38+
ssh_toolkit.SetValidator(func(host string) bool {
39+
server, err := core.FetchServerByIP(&manager.DbClient, host)
40+
if err != nil {
41+
return false
42+
}
43+
return server.Status != core.ServerOffline
44+
})
45+
3546
// Create pubsub default topics
3647
err := manager.PubSubClient.CreateTopic(manager.CancelImageBuildTopic)
3748
if err != nil {
@@ -50,7 +61,7 @@ func StartSwiftwave(config *config.Config) {
5061
cronjobManager.Start(true)
5162

5263
// create a channel to block the main thread
53-
var waitForever chan struct{}
64+
waitForever := make(chan struct{})
5465

5566
// StartSwiftwave the swift wave server
5667
go StartServer(config, manager, workerManager)

0 commit comments

Comments
 (0)