Skip to content

Commit cc95a01

Browse files
committed
session: add GetGroupID and GetSessionIDs methods
This commit adds new getters: `GetGroupID` and `GetSessionIDs` to the session store which can be used to query the newly added indexes to get the associated group ID for a session ID or the associated set of session IDs for a group ID.
1 parent e90df06 commit cc95a01

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

session/interface.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ func NewSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type,
128128
return sess, nil
129129
}
130130

131+
// IDToGroupIndex defines an interface for the session ID to group ID index.
132+
type IDToGroupIndex interface {
133+
// GetGroupID will return the group ID for the given session ID.
134+
GetGroupID(sessionID ID) (ID, error)
135+
136+
// GetSessionIDs will return the set of session IDs that are in the
137+
// group with the given ID.
138+
GetSessionIDs(groupID ID) ([]ID, error)
139+
}
140+
131141
// Store is the interface a persistent storage must implement for storing and
132142
// retrieving Terminal Connect sessions.
133143
type Store interface {
@@ -160,4 +170,6 @@ type Store interface {
160170

161171
// GetSessionByID fetches the session with the given ID.
162172
GetSessionByID(id ID) (*Session, error)
173+
174+
IDToGroupIndex
163175
}

session/store.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,91 @@ func (db *DB) GetUnusedIDAndKeyPair() (ID, *btcec.PrivateKey, error) {
341341
return id, privKey, nil
342342
}
343343

344+
// GetGroupID will return the group ID for the given session ID.
345+
//
346+
// NOTE: this is part of the IDToGroupIndex interface.
347+
func (db *DB) GetGroupID(sessionID ID) (ID, error) {
348+
var groupID ID
349+
err := db.View(func(tx *bbolt.Tx) error {
350+
sessionBkt, err := getBucket(tx, sessionBucketKey)
351+
if err != nil {
352+
return err
353+
}
354+
355+
idIndex := sessionBkt.Bucket(idIndexKey)
356+
if idIndex == nil {
357+
return ErrDBInitErr
358+
}
359+
360+
sessionIDBkt := idIndex.Bucket(sessionID[:])
361+
if sessionIDBkt == nil {
362+
return fmt.Errorf("no index entry for session ID: %x",
363+
sessionID)
364+
}
365+
366+
groupIDBytes := sessionIDBkt.Get(groupIDKey)
367+
if len(groupIDBytes) == 0 {
368+
return fmt.Errorf("group ID not found for session "+
369+
"ID %x", sessionID)
370+
}
371+
372+
copy(groupID[:], groupIDBytes)
373+
374+
return nil
375+
})
376+
if err != nil {
377+
return groupID, err
378+
}
379+
380+
return groupID, nil
381+
}
382+
383+
// GetSessionIDs will return the set of session IDs that are in the
384+
// group with the given ID.
385+
//
386+
// NOTE: this is part of the IDToGroupIndex interface.
387+
func (db *DB) GetSessionIDs(groupID ID) ([]ID, error) {
388+
var sessionIDs []ID
389+
err := db.View(func(tx *bbolt.Tx) error {
390+
sessionBkt, err := getBucket(tx, sessionBucketKey)
391+
if err != nil {
392+
return err
393+
}
394+
395+
groupIndexBkt := sessionBkt.Bucket(groupIDIndexKey)
396+
if groupIndexBkt == nil {
397+
return ErrDBInitErr
398+
}
399+
400+
groupIDBkt := groupIndexBkt.Bucket(groupID[:])
401+
if groupIDBkt == nil {
402+
return fmt.Errorf("no sessions for group ID %v",
403+
groupID)
404+
}
405+
406+
sessionIDsBkt := groupIDBkt.Bucket(sessionIDKey)
407+
if sessionIDsBkt == nil {
408+
return fmt.Errorf("no sessions for group ID %v",
409+
groupID)
410+
}
411+
412+
return sessionIDsBkt.ForEach(func(_,
413+
sessionIDBytes []byte) error {
414+
415+
var sessionID ID
416+
copy(sessionID[:], sessionIDBytes)
417+
sessionIDs = append(sessionIDs, sessionID)
418+
419+
return nil
420+
})
421+
})
422+
if err != nil {
423+
return nil, err
424+
}
425+
426+
return sessionIDs, nil
427+
}
428+
344429
// addIdToKeyPair inserts the mapping from session ID to session key into the
345430
// id-index bucket. An error is returned if an entry for this ID already exists.
346431
func addIDToKeyPair(sessionBkt *bbolt.Bucket, id ID, sessionKey []byte) error {

session/store_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,68 @@ func TestLinkingSessions(t *testing.T) {
110110
require.NoError(t, db.CreateSession(s2))
111111
}
112112

113+
// TestIDToGroupIDIndex tests that the session-ID-to-group-ID and
114+
// group-ID-to-session-ID indexes work as expected by asserting the behaviour
115+
// of the GetGroupID and GetSessionIDs methods.
116+
func TestLinkedSessions(t *testing.T) {
117+
// Set up a new DB.
118+
db, err := NewDB(t.TempDir(), "test.db")
119+
require.NoError(t, err)
120+
t.Cleanup(func() {
121+
_ = db.Close()
122+
})
123+
124+
// Create a few sessions. The first one is a new session and the two
125+
// after are all linked to the prior one. All these sessions belong to
126+
// the same group. The group ID is equivalent to the session ID of the
127+
// first session.
128+
s1 := newSession(t, db, "session 1", nil)
129+
s2 := newSession(t, db, "session 2", &s1.GroupID)
130+
s3 := newSession(t, db, "session 3", &s2.GroupID)
131+
132+
// Persist the sessions.
133+
require.NoError(t, db.CreateSession(s1))
134+
require.NoError(t, db.CreateSession(s2))
135+
require.NoError(t, db.CreateSession(s3))
136+
137+
// Assert that the session ID to group ID index works as expected.
138+
for _, s := range []*Session{s1, s2, s3} {
139+
groupID, err := db.GetGroupID(s.ID)
140+
require.NoError(t, err)
141+
require.Equal(t, s1.ID, groupID)
142+
require.Equal(t, s.GroupID, groupID)
143+
}
144+
145+
// Assert that the group ID to session ID index works as expected.
146+
sIDs, err := db.GetSessionIDs(s1.GroupID)
147+
require.NoError(t, err)
148+
require.EqualValues(t, []ID{s1.ID, s2.ID, s3.ID}, sIDs)
149+
150+
// To ensure that different groups don't interfere with each other,
151+
// let's add another set of linked sessions not linked to the first.
152+
s4 := newSession(t, db, "session 4", nil)
153+
s5 := newSession(t, db, "session 5", &s4.GroupID)
154+
155+
require.NotEqual(t, s4.GroupID, s1.GroupID)
156+
157+
// Persist the sessions.
158+
require.NoError(t, db.CreateSession(s4))
159+
require.NoError(t, db.CreateSession(s5))
160+
161+
// Assert that the session ID to group ID index works as expected.
162+
for _, s := range []*Session{s4, s5} {
163+
groupID, err := db.GetGroupID(s.ID)
164+
require.NoError(t, err)
165+
require.Equal(t, s4.ID, groupID)
166+
require.Equal(t, s.GroupID, groupID)
167+
}
168+
169+
// Assert that the group ID to session ID index works as expected.
170+
sIDs, err = db.GetSessionIDs(s5.GroupID)
171+
require.NoError(t, err)
172+
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
173+
}
174+
113175
func newSession(t *testing.T, db Store, label string,
114176
linkedGroupID *ID) *Session {
115177

0 commit comments

Comments
 (0)