@@ -64,24 +64,21 @@ func (a *Authenticator) GetOrCreateUser(ctx context.Context) (*model.User, error
6464
6565 session := sessions .Get (ctx , a .sessionName )
6666
67- if session .Values [userIDKey ] == nil {
67+ user , err := a .getCurrentUser (ctx )
68+ if errors .Is (err , errNotFound ) {
6869 // Create new user since UserID for cookie has not been created yet
6970 user , err := a .createNewUser ()
7071 if err != nil {
7172 return nil , errors .Wrap (err , "failed to create new user" )
7273 }
7374
7475 session .Values [userIDKey ] = user .ID .String ()
75-
7676 err = sessions .Save (ctx , session )
7777 if err != nil {
7878 fmt .Println ("Failed to save session!" )
7979 return nil , errors .Wrap (err , "failed to save userID to session" )
8080 }
81- }
82-
83- user , err := a .getCurrentUser (session .Values [userIDKey ].(string ))
84- if err != nil {
81+ } else if err != nil {
8582 sentry .CaptureException (errors .New (fmt .Sprintf (
8683 "Failed to load user id %s from session\n " , session .Values [userIDKey ].(string ))))
8784 return nil , errors .New ("failed to load user id from session" )
@@ -90,6 +87,23 @@ func (a *Authenticator) GetOrCreateUser(ctx context.Context) (*model.User, error
9087 return user , nil
9188}
9289
90+ func (a * Authenticator ) GetUser (ctx context.Context ) (* model.User , error ) {
91+ a .lock .Lock ()
92+ defer a .lock .Unlock ()
93+
94+ session := sessions .Get (ctx , a .sessionName )
95+ user , err := a .getCurrentUser (ctx )
96+ if err != nil {
97+ if ! errors .Is (err , errNotFound ) {
98+ sentry .CaptureException (errors .New (fmt .Sprintf (
99+ "Failed to load user id %s from session\n " , session .Values [userIDKey ].(string ))))
100+ }
101+ return nil , fmt .Errorf ("failed to load user id from session: %w" , err )
102+ }
103+
104+ return user , nil
105+ }
106+
93107// CheckProjectAccess returns an error if the current user is not authorized to mutate
94108// the provided project.
95109//
@@ -105,7 +119,7 @@ func (a *Authenticator) CheckProjectAccess(ctx context.Context, proj *model.Proj
105119 return errors .New ("no userIdKey found in session" )
106120 }
107121
108- user , err := a .getCurrentUser (session . Values [ userIDKey ].( string ) )
122+ user , err := a .getCurrentUser (ctx )
109123 if err != nil {
110124 return errors .New ("access denied" )
111125 }
@@ -138,11 +152,23 @@ func (a *Authenticator) CheckProjectAccess(ctx context.Context, proj *model.Proj
138152 return errors .New ("access denied" )
139153}
140154
141- func (a * Authenticator ) getCurrentUser (userIDStr string ) (* model.User , error ) {
155+ var errNotFound = errors .New ("user not found" )
156+
157+ // getCurrentUser from the request context using the session to get the user ID
158+ // and retrieving that user from the storage.
159+ // expected errors:
160+ // - errNotFound user was not created
161+ func (a * Authenticator ) getCurrentUser (ctx context.Context ) (* model.User , error ) {
162+ session := sessions .Get (ctx , a .sessionName )
163+ if session .Values [userIDKey ] == nil {
164+ return nil , errNotFound
165+ }
166+ rawUserID := session .Values [userIDKey ].(string )
167+
142168 var user model.User
143169 var userID uuid.UUID
144170
145- err := userID .UnmarshalText ([]byte (userIDStr ))
171+ err := userID .UnmarshalText ([]byte (rawUserID ))
146172 if err != nil {
147173 return nil , errors .Wrap (err , "failed to unmarshal userIDStr" )
148174 }
0 commit comments