Skip to content

Commit 2b62de5

Browse files
committed
Handle discovered ports TTL
1 parent a6b143e commit 2b62de5

File tree

2 files changed

+163
-3
lines changed

2 files changed

+163
-3
lines changed

cache.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//
2+
// This file is part of mdns-discovery.
3+
//
4+
// Copyright 2018-2021 ARDUINO SA (http://www.arduino.cc/)
5+
//
6+
// This software is released under the GNU General Public License version 3,
7+
// which covers the main part of arduino-cli.
8+
// The terms of this license can be found at:
9+
// https://www.gnu.org/licenses/gpl-3.0.en.html
10+
//
11+
// You can be released from the requirements of the above licenses by purchasing
12+
// a commercial license. Buying such a license is mandatory if you want to modify or
13+
// otherwise use the software for commercial activities involving the Arduino
14+
// software without disclosing the source code of your own applications. To purchase
15+
// a commercial license, send an email to license@arduino.cc.
16+
//
17+
18+
package main
19+
20+
import (
21+
"context"
22+
"sync"
23+
"time"
24+
25+
discovery "github.com/arduino/pluggable-discovery-protocol-handler"
26+
)
27+
28+
// cacheItem stores the Port discovered and its timer to handle TTL.
29+
type cacheItem struct {
30+
Port *discovery.Port
31+
timerTTL *time.Timer
32+
ctx context.Context
33+
cancelFunc context.CancelFunc
34+
}
35+
36+
// portsCached is a cache to store discovered ports with a TTL.
37+
// Ports that reach their TTL are automatically deleted and
38+
// main discovery process is notified.
39+
// All operations on the internal data are thread safe.
40+
type portsCache struct {
41+
data map[string]*cacheItem
42+
dataMutex sync.Mutex
43+
itemsTTL time.Duration
44+
deletionCallback func(port *discovery.Port)
45+
}
46+
47+
// NewCache creates a new portsCache and returns it.
48+
// itemsTTL is the TTL of a single item, when it's reached
49+
// the stored item is deleted.
50+
func NewCache(itemsTTL time.Duration) *portsCache {
51+
return &portsCache{
52+
itemsTTL: itemsTTL,
53+
data: make(map[string]*cacheItem),
54+
}
55+
}
56+
57+
// Set stores a new port and sets its TTL.
58+
// If the specified key is already found the item's TTL is
59+
// renewed.
60+
// Set is thread safe.
61+
func (c *portsCache) Set(key string, port *discovery.Port) {
62+
c.dataMutex.Lock()
63+
defer c.dataMutex.Unlock()
64+
// We need a cancellable context to avoid leaving
65+
// goroutines hanging if an item's timer TTL is stopped.
66+
ctx, cancelFunc := context.WithCancel(context.Background())
67+
item, ok := c.data[key]
68+
timerTTL := time.NewTimer(c.itemsTTL)
69+
if ok {
70+
item.timerTTL.Stop()
71+
// If we stop the timer the goroutine that waits
72+
// for it to go off would just hang forever if we
73+
// don't cancel the item's context
74+
item.cancelFunc()
75+
item.timerTTL = timerTTL
76+
item.ctx = ctx
77+
item.cancelFunc = cancelFunc
78+
} else {
79+
item = &cacheItem{
80+
Port: port,
81+
timerTTL: timerTTL,
82+
ctx: ctx,
83+
cancelFunc: cancelFunc,
84+
}
85+
c.data[key] = item
86+
}
87+
88+
go func(key string, item *cacheItem) {
89+
select {
90+
case <-item.timerTTL.C:
91+
c.dataMutex.Lock()
92+
defer c.dataMutex.Unlock()
93+
c.deletionCallback(item.Port)
94+
delete(c.data, key)
95+
return
96+
case <-item.ctx.Done():
97+
// The TTL has been renewed, we also stop the timer
98+
// that keeps track of the TTL.
99+
// The channel of a stopped timer will never fire so
100+
// if keep waiting for it in this goroutine it will
101+
// hang forever.
102+
// Using a context we handle this gracefully.
103+
return
104+
}
105+
}(key, item)
106+
}
107+
108+
// Get returns the item stored with the specified key and true.
109+
// If the item is not found returns a nil port and false.
110+
// Get is thread safe.
111+
func (c *portsCache) Get(key string) (*discovery.Port, bool) {
112+
c.dataMutex.Lock()
113+
defer c.dataMutex.Unlock()
114+
if item, ok := c.data[key]; ok {
115+
return item.Port, ok
116+
}
117+
return nil, false
118+
}
119+
120+
// Clear removes all the stored items and stops their TTL timers.
121+
// Clear is thread safe.
122+
func (c *portsCache) Clear() {
123+
c.dataMutex.Lock()
124+
defer c.dataMutex.Unlock()
125+
for key, item := range c.data {
126+
item.timerTTL.Stop()
127+
item.cancelFunc()
128+
delete(c.data, key)
129+
}
130+
}

main.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,30 @@ func main() {
4040
fmt.Fprintf(os.Stderr, "Error: %s\n", err.Error())
4141
os.Exit(1)
4242
}
43-
4443
}
4544

4645
const mdnsServiceName = "_arduino._tcp"
4746

47+
// Discovered ports stay alive for this amount of time
48+
// since the last time they've been found by an mDNS query.
49+
const portsTTL = time.Second * 60
50+
51+
// This is interval at which mDNS queries are made.
52+
const discoveryInterval = time.Second * 15
53+
4854
// MDNSDiscovery is the implementation of the network pluggable-discovery
4955
type MDNSDiscovery struct {
5056
cancelFunc func()
5157
entriesChan chan *mdns.ServiceEntry
58+
59+
portsCache *portsCache
5260
}
5361

5462
// Hello handles the pluggable-discovery HELLO command
5563
func (d *MDNSDiscovery) Hello(userAgent string, protocolVersion int) error {
5664
// The mdns library used has some logs statement that we must disable
5765
log.SetOutput(ioutil.Discard)
66+
d.portsCache = NewCache(portsTTL)
5867
return nil
5968
}
6069

@@ -68,6 +77,9 @@ func (d *MDNSDiscovery) Stop() error {
6877
close(d.entriesChan)
6978
d.entriesChan = nil
7079
}
80+
if d.portsCache != nil {
81+
d.portsCache.Clear()
82+
}
7183
return nil
7284
}
7385

@@ -82,20 +94,34 @@ func (d *MDNSDiscovery) StartSync(eventCB discovery.EventCallback, errorCB disco
8294
return fmt.Errorf("already syncing")
8395
}
8496

97+
if d.portsCache.deletionCallback == nil {
98+
// We can't set the cache deletion callback at creation,
99+
// this is the only place we can get the callback
100+
d.portsCache.deletionCallback = func(port *discovery.Port) {
101+
eventCB("remove", port)
102+
}
103+
}
104+
85105
d.entriesChan = make(chan *mdns.ServiceEntry, 4)
86106
var receiver <-chan *mdns.ServiceEntry = d.entriesChan
87107
var sender chan<- *mdns.ServiceEntry = d.entriesChan
88108

89109
go func() {
90110
for entry := range receiver {
91-
eventCB("add", toDiscoveryPort(entry))
111+
port := toDiscoveryPort(entry)
112+
key := portKey(port)
113+
if _, ok := d.portsCache.Get(key); !ok {
114+
// Port is not cached so let the user know a new one has been found
115+
eventCB("add", port)
116+
}
117+
d.portsCache.Set(portKey(port), port)
92118
}
93119
}()
94120

95121
params := &mdns.QueryParam{
96122
Service: mdnsServiceName,
97123
Domain: "local",
98-
Timeout: time.Second * 15,
124+
Timeout: discoveryInterval,
99125
Entries: sender,
100126
WantUnicastResponse: false,
101127
}
@@ -117,6 +143,10 @@ func (d *MDNSDiscovery) StartSync(eventCB discovery.EventCallback, errorCB disco
117143
return nil
118144
}
119145

146+
func portKey(p *discovery.Port) string {
147+
return fmt.Sprintf("%s:%s %s", p.Address, p.Properties.Get("port"), p.Properties.Get("board"))
148+
}
149+
120150
func toDiscoveryPort(entry *mdns.ServiceEntry) *discovery.Port {
121151
ip := ""
122152
if len(entry.AddrV4) > 0 {

0 commit comments

Comments
 (0)