Skip to content

Commit 54be465

Browse files
committed
feat: add PreDialPlugin
1 parent 62063fb commit 54be465

File tree

3 files changed

+67
-19
lines changed

3 files changed

+67
-19
lines changed

peer.go

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,15 @@ func (p *peer) CountSession() int {
207207

208208
// Dial connects with the peer of the destination address.
209209
func (p *peer) Dial(addr string, protoFunc ...ProtoFunc) (Session, *Status) {
210+
stat := p.pluginContainer.preDial(p.dialer.localAddr, addr)
211+
if !stat.OK() {
212+
return nil, stat
213+
}
210214
var sess = newSession(p, nil, protoFunc)
211215
_, err := p.dialer.dialWithRetry(addr, "", func(conn net.Conn) error {
212216
sess.socket.Reset(conn, protoFunc...)
213217
sess.socket.SetID(sess.LocalAddr().String())
214-
if stat := p.pluginContainer.postDial(sess, false); !stat.OK() {
218+
if stat = p.pluginContainer.postDial(sess, false); !stat.OK() {
215219
conn.Close()
216220
return stat.Cause()
217221
}
@@ -227,23 +231,26 @@ func (p *peer) Dial(addr string, protoFunc ...ProtoFunc) (Session, *Status) {
227231
oldID := sess.ID()
228232
oldIP := sess.LocalAddr().String()
229233
oldConn := sess.getConn()
230-
231-
_, err := p.dialer.dialWithRetry(addr, oldID, func(conn net.Conn) error {
232-
sess.socket.Reset(conn, protoFunc...)
233-
if oldIP == oldID {
234-
sess.socket.SetID(sess.LocalAddr().String())
235-
} else {
236-
sess.socket.SetID(oldID)
237-
}
238-
sess.changeStatus(statusPreparing)
239-
if stat := p.pluginContainer.postDial(sess, true); !stat.OK() {
240-
conn.Close()
241-
sess.changeStatus(statusRedialing)
242-
return stat.Cause()
243-
}
244-
return nil
245-
})
246-
234+
var err error
235+
if stat := p.pluginContainer.preDial(p.dialer.localAddr, addr); stat.OK() {
236+
_, err = p.dialer.dialWithRetry(addr, oldID, func(conn net.Conn) error {
237+
sess.socket.Reset(conn, protoFunc...)
238+
if oldIP == oldID {
239+
sess.socket.SetID(sess.LocalAddr().String())
240+
} else {
241+
sess.socket.SetID(oldID)
242+
}
243+
sess.changeStatus(statusPreparing)
244+
if stat := p.pluginContainer.postDial(sess, true); !stat.OK() {
245+
conn.Close()
246+
sess.changeStatus(statusRedialing)
247+
return stat.Cause()
248+
}
249+
return nil
250+
})
251+
} else {
252+
err = stat.Cause()
253+
}
247254
if err != nil {
248255
sess.closeLocked()
249256
sess.tryChangeStatus(statusRedialFailed, statusRedialing)

plugin.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ type (
4949
Plugin
5050
PostListen(net.Addr) error
5151
}
52+
// PreDialPlugin is executed before dialing.
53+
PreDialPlugin interface {
54+
Plugin
55+
PreDial(localAddr net.Addr, remoteAddr string) *Status
56+
}
5257
// PostDialPlugin is executed after dialing.
5358
PostDialPlugin interface {
5459
Plugin
@@ -362,7 +367,32 @@ func (p *pluginSingleContainer) postListen(addr net.Addr) {
362367
return
363368
}
364369

365-
// PostDial executes the defined plugins after dialing.
370+
// preDial executes the defined plugins before dialing.
371+
func (p *pluginSingleContainer) preDial(localAddr net.Addr, remoteAddr string) (stat *Status) {
372+
var pluginName string
373+
defer func() {
374+
if p := recover(); p != nil {
375+
Errorf("[PreDialPlugin:%s] network:%s, localAddr%s ,remoteAddr:%s, panic:%v\n%s", pluginName, localAddr.Network(), localAddr.String(), remoteAddr, p, goutil.PanicTrace(2))
376+
stat = statDialFailed.Copy(p)
377+
}
378+
}()
379+
for _, plugin := range p.plugins {
380+
if _plugin, ok := plugin.(PreDialPlugin); ok {
381+
pluginName = plugin.Name()
382+
if stat = _plugin.PreDial(localAddr, remoteAddr); !stat.OK() {
383+
LazyDebugf(func() string {
384+
return fmt.Sprintf("[PreDialPlugin:%s] network:%s, localAddr%s ,remoteAddr:%s, error:%s",
385+
pluginName, localAddr.Network(), localAddr.String(), remoteAddr, stat.String(),
386+
)
387+
})
388+
return stat
389+
}
390+
}
391+
}
392+
return nil
393+
}
394+
395+
// postDial executes the defined plugins after dialing.
366396
func (p *pluginSingleContainer) postDial(sess PreSession, isRedial bool) (stat *Status) {
367397
var pluginName string
368398
defer func() {

plugin_impl.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ var (
88
_ PostNewPeerPlugin = (*PluginImpl)(nil)
99
_ PostRegPlugin = (*PluginImpl)(nil)
1010
_ PostListenPlugin = (*PluginImpl)(nil)
11+
_ PreDialPlugin = (*PluginImpl)(nil)
1112
_ PostDialPlugin = (*PluginImpl)(nil)
1213
_ PostAcceptPlugin = (*PluginImpl)(nil)
1314
_ PreWriteCallPlugin = (*PluginImpl)(nil)
@@ -41,6 +42,8 @@ type PluginImpl struct {
4142
OnPostReg func(*Handler) error
4243
// OnPostListen is called after a listener is created.
4344
OnPostListen func(net.Addr) error
45+
// OnPreDial is called before a dial is created.
46+
OnPreDial func(localAddr net.Addr, remoteAddr string) *Status
4447
// OnPostDial is called after a dial is created.
4548
OnPostDial func(sess PreSession, isRedial bool) *Status
4649
// OnPostAccept is called after a session is accepted.
@@ -118,6 +121,14 @@ func (p *PluginImpl) PostListen(addr net.Addr) error {
118121
return p.OnPostListen(addr)
119122
}
120123

124+
// PreDial is called before a dial is created.
125+
func (p *PluginImpl) PreDial(localAddr net.Addr, remoteAddr string) *Status {
126+
if p.OnPreDial == nil {
127+
return nil
128+
}
129+
return p.OnPreDial(localAddr, remoteAddr)
130+
}
131+
121132
// PostDial is called after a dial is created.
122133
func (p *PluginImpl) PostDial(sess PreSession, isRedial bool) *Status {
123134
if p.OnPostDial == nil {

0 commit comments

Comments
 (0)