@@ -10,34 +10,83 @@ import (
10
10
"time"
11
11
12
12
"github.com/pkg/errors"
13
+ log "github.com/sirupsen/logrus"
13
14
14
15
"golang.org/x/crypto/ssh"
15
16
)
16
17
17
18
// SSHConnection encapsulates the connection to the device
18
19
type SSHConnection struct {
19
- device * Device
20
- client * ssh.Client
21
- conn net.Conn
22
- lastUsed time.Time
23
- mu sync.Mutex
24
- done chan struct {}
20
+ device * Device
21
+ sshClient * ssh.Client
22
+ tcpConn net.Conn
23
+ isConnected bool
24
+ mu sync.RWMutex // protects sshClient, tcpConn and isConnected
25
+ lastUsed time.Time
26
+ lastUsedMu sync.RWMutex
27
+ done chan struct {}
28
+ keepAliveInterval time.Duration
29
+ keepAliveTimeout time.Duration
25
30
}
26
31
27
- // RunCommand runs a command against the device
28
- func (c * SSHConnection ) RunCommand (cmd string ) ([]byte , error ) {
32
+ func NewSSHConnection (device * Device , keepAliveInterval time.Duration , keepAliveTimeout time.Duration ) * SSHConnection {
33
+ return & SSHConnection {
34
+ device : device ,
35
+ keepAliveInterval : keepAliveInterval ,
36
+ keepAliveTimeout : keepAliveTimeout ,
37
+ done : make (chan struct {}),
38
+ }
39
+ }
40
+
41
+ func (c * SSHConnection ) Start (expiredConnectionTimeout time.Duration ) error {
42
+ err := c .connect ()
43
+ if err != nil {
44
+ return err
45
+ }
46
+
47
+ go c .keepalive (expiredConnectionTimeout )
48
+ return nil
49
+ }
50
+
51
+ func (c * SSHConnection ) Stop (err error ) {
52
+ log .Infof ("Stopping SSH connection with %s (reason: %v)" , c .device .Host , err )
53
+
29
54
c .mu .Lock ()
30
55
defer c .mu .Unlock ()
31
56
32
- c .lastUsed = time .Now ()
57
+ if ! c .isConnected {
58
+ return
59
+ }
60
+
61
+ close (c .done )
33
62
34
- if c .client == nil {
35
- return nil , errors .New (fmt .Sprintf ("not connected with %s" , c .conn .RemoteAddr ().String ()))
63
+ if c .sshClient != nil {
64
+ c .sshClient .Close ()
65
+ c .sshClient = nil
36
66
}
37
67
38
- session , err := c .client .NewSession ()
68
+ if c .tcpConn != nil {
69
+ c .tcpConn .Close ()
70
+ c .tcpConn = nil
71
+ }
72
+
73
+ c .isConnected = false
74
+ }
75
+
76
+ // RunCommand runs a command against the device
77
+ func (c * SSHConnection ) RunCommand (cmd string ) ([]byte , error ) {
78
+ c .setLastUsed (time .Now ())
79
+
80
+ sshClient := c .getSSHClient ()
81
+ if sshClient == nil {
82
+ c .Stop (fmt .Errorf ("No ssh client" ))
83
+ return nil , errors .New (fmt .Sprintf ("no SSH client to %s" , c .device .Host ))
84
+ }
85
+
86
+ session , err := c .sshClient .NewSession ()
39
87
if err != nil {
40
- return nil , errors .Wrapf (err , "could not open session with %s" , c .conn .RemoteAddr ().String ())
88
+ c .Stop (fmt .Errorf ("SSH session failure" ))
89
+ return nil , errors .Wrapf (err , "could not open session with %s" , c .device .Host )
41
90
}
42
91
defer session .Close ()
43
92
@@ -46,37 +95,114 @@ func (c *SSHConnection) RunCommand(cmd string) ([]byte, error) {
46
95
47
96
err = session .Run (cmd )
48
97
if err != nil {
49
- return nil , errors .Wrapf (err , "could not run command %q on %s" , cmd , c .conn .RemoteAddr ().String ())
98
+ c .Stop (fmt .Errorf ("failed running command" ))
99
+ return nil , errors .Wrapf (err , "could not run command %q on %s" , cmd , c .device .Host )
50
100
}
51
101
52
102
return b .Bytes (), nil
53
103
}
54
104
55
- func (c * SSHConnection ) isConnected () bool {
56
- return c .conn != nil
105
+ func (c * SSHConnection ) keepalive (expiredConnectionTimeout time.Duration ) {
106
+ for {
107
+ select {
108
+ case <- time .After (c .keepAliveInterval ):
109
+ terminated := c .terminateIfLifetimeExpired (expiredConnectionTimeout )
110
+ if terminated {
111
+ return
112
+ }
113
+
114
+ _ = c .tcpConn .SetDeadline (time .Now ().Add (c .keepAliveTimeout ))
115
+
116
+ ok := c .testSSHClient ()
117
+ if ! ok {
118
+ return
119
+ }
120
+ case <- c .done :
121
+ return
122
+ }
123
+ }
57
124
}
58
125
59
- func (c * SSHConnection ) terminate () {
60
- c .mu .Lock ()
61
- defer c .mu .Unlock ()
126
+ func (c * SSHConnection ) terminateIfLifetimeExpired (expiredConnectionTimeout time.Duration ) bool {
127
+ if time .Since (c .GetLastUsed ()) > expiredConnectionTimeout {
128
+ c .Stop (fmt .Errorf ("lifetime expired" ))
129
+ return true
130
+ }
131
+
132
+ return false
133
+ }
134
+
135
+ func (c * SSHConnection ) testSSHClient () bool {
136
+ sshClient := c .getSSHClient ()
62
137
63
- c .conn .Close ()
138
+ _ , _ , err := sshClient .SendRequest ("keepalive@golang.org" , true , nil )
139
+ if err != nil {
140
+ log .Infof ("SSH keepalive request to %s failed: %v" , c .device , err )
141
+ c .Stop (fmt .Errorf ("keepalive failed" ))
142
+ return false
143
+ }
64
144
65
- c .client = nil
66
- c .conn = nil
145
+ return true
67
146
}
68
147
69
- func (c * SSHConnection ) close () {
148
+ func (c * SSHConnection ) connect () error {
149
+ cfg := & ssh.ClientConfig {
150
+ HostKeyCallback : ssh .InsecureIgnoreHostKey (),
151
+ Timeout : timeoutInSeconds * time .Second ,
152
+ }
153
+
154
+ c .device .Auth (cfg )
155
+
156
+ host := tcpAddressForHost (c .device .Host )
157
+ log .Infof ("Establishing TCP connection with %s" , host )
158
+
159
+ tcpConn , err := net .DialTimeout ("tcp" , host , cfg .Timeout )
160
+ if err != nil {
161
+ return fmt .Errorf ("could not open tcp connection: %w" , err )
162
+ }
163
+
164
+ sshConn , chans , reqs , err := ssh .NewClientConn (tcpConn , host , cfg )
165
+ if err != nil {
166
+ tcpConn .Close ()
167
+ return fmt .Errorf ("could not connect to device: %w" , err )
168
+ }
169
+
70
170
c .mu .Lock ()
71
171
defer c .mu .Unlock ()
72
172
73
- if c .client != nil {
74
- c .client .Close ()
75
- }
173
+ c .tcpConn = tcpConn
174
+ c .sshClient = ssh .NewClient (sshConn , chans , reqs )
175
+ c .isConnected = true
176
+
177
+ return nil
178
+ }
179
+
180
+ func (c * SSHConnection ) setLastUsed (t time.Time ) {
181
+ c .lastUsedMu .Lock ()
182
+ defer c .lastUsedMu .Unlock ()
183
+
184
+ c .lastUsed = t
185
+ }
186
+
187
+ func (c * SSHConnection ) GetLastUsed () time.Time {
188
+ c .lastUsedMu .RLock ()
189
+ defer c .lastUsedMu .RUnlock ()
190
+
191
+ return c .lastUsed
192
+ }
193
+
194
+ func (c * SSHConnection ) getSSHClient () * ssh.Client {
195
+ c .mu .RLock ()
196
+ defer c .mu .RUnlock ()
197
+
198
+ return c .sshClient
199
+ }
200
+
201
+ func (c * SSHConnection ) IsConnected () bool {
202
+ c .mu .RLock ()
203
+ defer c .mu .RUnlock ()
76
204
77
- c .done <- struct {}{}
78
- c .conn = nil
79
- c .client = nil
205
+ return c .isConnected
80
206
}
81
207
82
208
// Host returns the hostname of the connected device
0 commit comments