Skip to content

Commit 690a44f

Browse files
committed
bridge: Fix credential reset after unauthorized sign-in
1 parent af7fc65 commit 690a44f

File tree

11 files changed

+352
-54
lines changed

11 files changed

+352
-54
lines changed

bridge/device/cloud/manager.go

Lines changed: 98 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"context"
2323
"crypto/tls"
2424
"crypto/x509"
25+
"errors"
2526
"fmt"
2627
"reflect"
2728
"sync"
@@ -41,6 +42,7 @@ import (
4142
"github.com/plgd-dev/device/v2/schema/cloud"
4243
"github.com/plgd-dev/device/v2/schema/device"
4344
plgdResources "github.com/plgd-dev/device/v2/schema/resources"
45+
"github.com/plgd-dev/go-coap/v3/message"
4446
"github.com/plgd-dev/go-coap/v3/message/codes"
4547
"github.com/plgd-dev/go-coap/v3/message/pool"
4648
"github.com/plgd-dev/go-coap/v3/mux"
@@ -50,8 +52,6 @@ import (
5052
"github.com/plgd-dev/go-coap/v3/tcp/client"
5153
)
5254

53-
const tickInterval = time.Second * 10
54-
5555
type (
5656
GetLinksFilteredBy func(endpoints schema.Endpoints, deviceIDfilter uuid.UUID, resourceTypesFitler []string, policyBitMaskFitler schema.BitMask) (links schema.ResourceLinks)
5757
GetCertificates func(deviceID string) []tls.Certificate
@@ -82,22 +82,25 @@ type Manager struct {
8282
caPool CAPoolGetter
8383
getCertificates GetCertificates
8484
removeCloudCAs RemoveCloudCAs
85+
tickInterval time.Duration
8586

8687
private struct {
8788
mutex sync.Mutex
8889
cfg Configuration
8990
previousCloudIDs []string
9091
readyToPublishResources map[string]struct{}
9192
readyToUnpublishResources map[string]struct{}
93+
creds ocfCloud.CoapSignUpResponse
9294
}
9395

9496
logger log.Logger
95-
creds ocfCloud.CoapSignUpResponse
9697
client *client.Conn
9798
signedIn bool
9899
resourcesPublished bool
100+
forceRefreshToken bool
99101
done chan struct{}
100102
stopped atomic.Bool
103+
reconnect atomic.Bool
101104
trigger chan bool
102105
loop *eventloop.Loop
103106
}
@@ -114,7 +117,8 @@ func New(cfg Config, deviceID uuid.UUID, save func(), handler net.RequestHandler
114117
removeCloudCAs: func(...string) {
115118
// do nothing
116119
},
117-
logger: log.NewNilLogger(),
120+
logger: log.NewNilLogger(),
121+
tickInterval: time.Second * 10,
118122
}
119123
for _, opt := range opts {
120124
opt(&o)
@@ -133,6 +137,7 @@ func New(cfg Config, deviceID uuid.UUID, save func(), handler net.RequestHandler
133137
removeCloudCAs: o.removeCloudCAs,
134138
logger: o.logger,
135139
loop: loop,
140+
tickInterval: o.tickInterval,
136141
}
137142
c.private.cfg.ProvisioningStatus = cloud.ProvisioningStatus_UNINITIALIZED
138143
c.importConfig(cfg)
@@ -186,6 +191,13 @@ func (c *Manager) handleTrigger(value reflect.Value, closed bool) {
186191
if wantToReset {
187192
c.resetCredentials(ctx, true)
188193
}
194+
if c.reconnect.CompareAndSwap(true, false) {
195+
err := c.close()
196+
if err != nil && !errors.Is(err, context.Canceled) {
197+
c.logger.Errorf("cannot close connection for reconnect: %w", err)
198+
}
199+
return
200+
}
189201
if !c.isInitialized() {
190202
// resources will be published after sign in
191203
c.resetPublishing()
@@ -220,7 +232,7 @@ func (c *Manager) Init() {
220232
if c.private.cfg.URL != "" {
221233
c.triggerRunner(false)
222234
}
223-
t := time.NewTicker(tickInterval)
235+
t := time.NewTicker(c.tickInterval)
224236
handlers := []eventloop.Handler{
225237
eventloop.NewReadHandler(reflect.ValueOf(c.trigger), c.handleTrigger),
226238
eventloop.NewReadHandler(reflect.ValueOf(t.C), c.handleTimer),
@@ -242,14 +254,16 @@ func (c *Manager) resetCredentials(ctx context.Context, signOff bool) {
242254
c.logger.Debugf("%w", err)
243255
}
244256
}
245-
c.creds = ocfCloud.CoapSignUpResponse{}
246-
c.signedIn = false
257+
c.setCreds(ocfCloud.CoapSignUpResponse{})
247258
c.resourcesPublished = false
259+
c.forceRefreshToken = false
260+
c.reconnect.Store(false)
248261
if err := c.close(); err != nil {
249262
c.logger.Warnf("cannot close connection: %w", err)
250263
}
251264
c.save()
252265
c.removePreviousCloudIDs()
266+
c.logger.Infof("reset credentials")
253267
}
254268

255269
func (c *Manager) cleanup() {
@@ -347,29 +361,40 @@ func validUntil(expiresIn int64) time.Time {
347361
}
348362

349363
func (c *Manager) setCreds(creds ocfCloud.CoapSignUpResponse) {
350-
c.creds = creds
364+
c.private.mutex.Lock()
365+
defer c.private.mutex.Unlock()
366+
c.private.creds = creds
351367
c.signedIn = false
352368
}
353369

370+
func (c *Manager) updateCreds(f func(creds *ocfCloud.CoapSignUpResponse)) {
371+
c.private.mutex.Lock()
372+
defer c.private.mutex.Unlock()
373+
f(&c.private.creds)
374+
}
375+
354376
func (c *Manager) getCreds() ocfCloud.CoapSignUpResponse {
355-
return c.creds
377+
c.private.mutex.Lock()
378+
defer c.private.mutex.Unlock()
379+
return c.private.creds
356380
}
357381

358382
func (c *Manager) isCredsExpiring() bool {
359-
if c.creds.ValidUntil.IsZero() {
383+
creds := c.getCreds()
384+
if creds.ValidUntil.IsZero() {
360385
return false
361386
}
362-
diff := time.Until(c.creds.ValidUntil)
363-
if diff < tickInterval*2 {
387+
diff := time.Until(creds.ValidUntil)
388+
if diff < c.tickInterval*2 {
364389
// refresh token before it expires
365390
return true
366391
}
367392
// refresh token when it is 1/3 before it expires
368-
return time.Now().After(c.creds.ValidUntil.Add(-diff / 3))
393+
return time.Now().After(creds.ValidUntil.Add(-diff / 3))
369394
}
370395

371-
func getResourceTypesFilter(request *mux.Message) []string {
372-
queries, _ := request.Options().Queries()
396+
func getResourceTypesFilter(messageOptions message.Options) []string {
397+
queries, _ := messageOptions.Queries()
373398
resourceTypesFitler := []string{}
374399
for _, q := range queries {
375400
if len(q) > 3 && q[:3] == "rt=" {
@@ -379,37 +404,64 @@ func getResourceTypesFilter(request *mux.Message) []string {
379404
return resourceTypesFitler
380405
}
381406

382-
func (c *Manager) serveCOAP(w mux.ResponseWriter, request *mux.Message) {
383-
request.Message.AddQuery("di=" + c.deviceID.String())
384-
r := net.Request{
385-
Message: request.Message,
386-
Endpoints: nil,
387-
Conn: w.Conn(),
407+
func inFilterSupportedCodes(request *mux.Message) bool {
408+
switch request.Code() {
409+
case codes.POST, codes.PUT, codes.DELETE, codes.GET:
410+
return true
411+
default:
412+
return false
388413
}
389-
var resp *pool.Message
414+
}
415+
416+
func (c *Manager) handleDeviceResource(r *net.Request) (*pool.Message, error) {
417+
links := c.getLinks(schema.Endpoints{}, c.deviceID, nil, resources.PublishToCloud)
418+
for _, link := range links {
419+
if link.HasType(device.ResourceType) {
420+
_ = r.SetPath(link.Href)
421+
break
422+
}
423+
}
424+
return c.handler(r)
425+
}
426+
427+
func (c *Manager) handleDiscoveryResource(r *net.Request) (*pool.Message, error) {
428+
links := c.getLinks(schema.Endpoints{}, c.deviceID, getResourceTypesFilter(r.Message.Options()), resources.PublishToCloud)
429+
links = patchDeviceLink(links)
430+
links = discovery.PatchLinks(links, c.deviceID.String())
431+
return resources.CreateResponseContent(r.Context(), links, codes.Content)
432+
}
433+
434+
func (c *Manager) getHandler(r *net.Request) func(r *net.Request) (*pool.Message, error) {
435+
h := c.handler
390436
p, err := r.Path()
391437
if err == nil {
392438
switch p {
393439
case device.ResourceURI:
394-
links := c.getLinks(schema.Endpoints{}, c.deviceID, nil, resources.PublishToCloud)
395-
for _, link := range links {
396-
if link.HasType(device.ResourceType) {
397-
_ = r.SetPath(link.Href)
398-
break
399-
}
400-
}
401-
resp, err = c.handler(&r)
440+
h = c.handleDeviceResource
402441
case plgdResources.ResourceURI:
403-
links := c.getLinks(schema.Endpoints{}, c.deviceID, getResourceTypesFilter(request), resources.PublishToCloud)
404-
links = patchDeviceLink(links)
405-
links = discovery.PatchLinks(links, c.deviceID.String())
406-
resp, err = resources.CreateResponseContent(request.Context(), links, codes.Content)
407-
default:
408-
resp, err = c.handler(&r)
442+
h = c.handleDiscoveryResource
409443
}
410-
} else {
411-
resp, err = c.handler(&r)
412444
}
445+
return h
446+
}
447+
448+
func (c *Manager) serveCOAP(w mux.ResponseWriter, request *mux.Message) {
449+
if !inFilterSupportedCodes(request) {
450+
// ignore unsupported request
451+
if w.Conn().Context().Err() == nil {
452+
// log only if connection is still open
453+
c.logger.Debugf("unsupported request: %v\n", request)
454+
}
455+
return
456+
}
457+
request.Message.AddQuery("di=" + c.deviceID.String())
458+
r := net.Request{
459+
Message: request.Message,
460+
Endpoints: nil,
461+
Conn: w.Conn(),
462+
}
463+
h := c.getHandler(&r)
464+
resp, err := h(&r)
413465
if err != nil {
414466
resp = net.CreateResponseError(request.Context(), err, request.Token())
415467
}
@@ -502,8 +554,9 @@ func patchDeviceLink(links schema.ResourceLinks) schema.ResourceLinks {
502554

503555
func (c *Manager) connect(ctx context.Context) error {
504556
funcs := make([]func(ctx context.Context) error, 0, 5)
505-
if c.isCredsExpiring() {
557+
if c.isCredsExpiring() || c.forceRefreshToken {
506558
funcs = append(funcs, c.refreshToken)
559+
c.forceRefreshToken = false
507560
}
508561
funcs = append(funcs, []func(ctx context.Context) error{
509562
c.signUp,
@@ -517,7 +570,7 @@ func (c *Manager) connect(ctx context.Context) error {
517570
}
518571
for _, f := range funcs {
519572
r := func(ctx context.Context) error {
520-
fctx, cancel := context.WithTimeout(ctx, time.Second*10)
573+
fctx, cancel := context.WithTimeout(ctx, c.tickInterval)
521574
defer cancel()
522575
return f(fctx)
523576
}
@@ -584,6 +637,11 @@ func (c *Manager) popReadyToPublishResources() map[string]struct{} {
584637
return res
585638
}
586639

640+
func (c *Manager) Reconnect() {
641+
c.reconnect.Store(true)
642+
c.triggerRunner(false)
643+
}
644+
587645
func (c *Manager) popReadyToUnpublishResources(count int) []string {
588646
c.private.mutex.Lock()
589647
defer c.private.mutex.Unlock()

bridge/device/cloud/manager_test.go

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ import (
2727
"time"
2828

2929
"github.com/google/uuid"
30+
"github.com/plgd-dev/device/v2/bridge/device"
3031
"github.com/plgd-dev/device/v2/bridge/device/cloud"
3132
"github.com/plgd-dev/device/v2/bridge/net"
3233
"github.com/plgd-dev/device/v2/bridge/resources"
3334
bridgeTest "github.com/plgd-dev/device/v2/bridge/test"
3435
"github.com/plgd-dev/device/v2/pkg/codec/cbor"
3536
codecOcf "github.com/plgd-dev/device/v2/pkg/codec/ocf"
37+
ocfCloud "github.com/plgd-dev/device/v2/pkg/ocf/cloud"
3638
cloudSchema "github.com/plgd-dev/device/v2/schema/cloud"
3739
"github.com/plgd-dev/device/v2/schema/interfaces"
3840
"github.com/plgd-dev/device/v2/test"
@@ -42,6 +44,7 @@ import (
4244
"github.com/plgd-dev/go-coap/v3/message"
4345
"github.com/plgd-dev/go-coap/v3/message/codes"
4446
"github.com/plgd-dev/go-coap/v3/message/pool"
47+
"github.com/plgd-dev/go-coap/v3/message/status"
4548
"github.com/stretchr/testify/require"
4649
)
4750

@@ -68,7 +71,79 @@ func (r *resourceDataSync) copy() resourceData {
6871
}
6972
}
7073

71-
// device is restarted with an imported configuration with valid cloud credentials
74+
func getUnauthorizedError() status.Status {
75+
msg := pool.NewMessage(context.Background())
76+
msg.SetCode(codes.Unauthorized)
77+
return status.Errorf(msg, "unauthorized")
78+
}
79+
80+
func TestManagerDeviceBecomesUnauthorized(t *testing.T) {
81+
ch := mockCoapGW.NewCoapHandlerWithCounter(3600)
82+
customHandler := mockCoapGW.NewCustomHandler(ch)
83+
makeHandler := func(*mockCoapGWService.Service, ...mockCoapGWService.Option) mockCoapGWService.ServiceHandler {
84+
return customHandler
85+
}
86+
coapShutdown := mockCoapGW.New(t, makeHandler, func(mockCoapGWService.ServiceHandler) {
87+
h := ch
88+
fmt.Printf("%+v\n", h.CallCounter.Data)
89+
// d1 -> signup + signin + publish
90+
// d2 -> should use the stored credentials to skip signup and only do sign in + publish
91+
require.Equal(t, 1, h.CallCounter.Data[mockCoapGW.SignUpKey])
92+
require.Equal(t, 1, h.CallCounter.Data[mockCoapGW.SignInKey])
93+
require.Equal(t, 1, h.CallCounter.Data[mockCoapGW.PublishKey])
94+
require.Equal(t, 0, h.CallCounter.Data[mockCoapGW.UnpublishKey])
95+
require.Equal(t, 0, h.CallCounter.Data[mockCoapGW.RefreshTokenKey])
96+
})
97+
defer coapShutdown()
98+
99+
s1 := bridgeTest.NewBridgeService(t)
100+
t.Cleanup(func() {
101+
_ = s1.Shutdown()
102+
})
103+
deviceID := uuid.New().String()
104+
tickInterval := time.Second
105+
d1 := bridgeTest.NewBridgedDevice(t, s1, deviceID, true, false, device.WithCloudOptions(cloud.WithTickInterval(tickInterval)))
106+
s1Shutdown := bridgeTest.RunBridgeService(s1)
107+
t.Cleanup(func() {
108+
_ = s1Shutdown()
109+
})
110+
111+
c, err := testClient.NewTestSecureClientWithBridgeSupport()
112+
require.NoError(t, err)
113+
defer func() {
114+
errC := c.Close(context.Background())
115+
require.NoError(t, errC)
116+
}()
117+
118+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
119+
defer cancel()
120+
err = c.OnboardDevice(ctx, deviceID, "authorizationProvider", "coaps+tcp://"+mockCoapGW.COAP_GW_HOST, "authorizationCode", test.CloudSID())
121+
require.NoError(t, err)
122+
123+
// wait for sign in
124+
require.Equal(t, 1, ch.WaitForSignIn(time.Second*20))
125+
126+
// wait for publish
127+
require.Equal(t, 1, ch.WaitForPublish(time.Second*20))
128+
129+
customHandler.SetSignIn(func(ocfCloud.CoapSignInRequest) (ocfCloud.CoapSignInResponse, error) {
130+
return ocfCloud.CoapSignInResponse{}, getUnauthorizedError()
131+
})
132+
customHandler.SetRefreshToken(func(ocfCloud.CoapRefreshTokenRequest) (ocfCloud.CoapRefreshTokenResponse, error) {
133+
return ocfCloud.CoapRefreshTokenResponse{}, getUnauthorizedError()
134+
})
135+
136+
d1.GetCloudManager().Reconnect()
137+
for i := 0; i < 5; i++ {
138+
cfg := d1.GetCloudManager().ExportConfig()
139+
if cfg.AccessToken == "" {
140+
return
141+
}
142+
time.Sleep(tickInterval)
143+
}
144+
require.Fail(t, "cloud manager should be reset, but it is not")
145+
}
146+
72147
func TestProvisioningOnDeviceRestart(t *testing.T) {
73148
ch := mockCoapGW.NewCoapHandlerWithCounter(-1)
74149
makeHandler := func(*mockCoapGWService.Service, ...mockCoapGWService.Option) mockCoapGWService.ServiceHandler {

0 commit comments

Comments
 (0)