@@ -15,6 +15,7 @@ import (
15
15
"github.com/vmihailenco/msgpack/v5"
16
16
17
17
"github.com/coreos/go-oidc/v3/oidc"
18
+ cache "github.com/go-pkgz/expirable-cache/v3"
18
19
action "github.com/negasus/haproxy-spoe-go/action"
19
20
message "github.com/negasus/haproxy-spoe-go/message"
20
21
@@ -74,7 +75,7 @@ type OIDCAuthenticator struct {
74
75
75
76
signatureComputer * HmacSha256Computer
76
77
encryptor * AESEncryptor
77
- pkceVerifier string
78
+ pkceVerifierCache cache. Cache [ string , string ]
78
79
79
80
options OIDCAuthenticatorOptions
80
81
}
@@ -120,7 +121,7 @@ func NewOIDCAuthenticator(options OIDCAuthenticatorOptions) *OIDCAuthenticator {
120
121
options : options ,
121
122
signatureComputer : NewHmacSha256Computer (options .SignatureSecret ),
122
123
encryptor : NewAESEncryptor (options .EncryptionSecret ),
123
- pkceVerifier : oauth2 . GenerateVerifier (),
124
+ pkceVerifierCache : cache . NewCache [ string , string ] (),
124
125
}
125
126
126
127
go func () {
@@ -396,8 +397,15 @@ func (oa *OIDCAuthenticator) buildAuthorizationURL(domain string, oauthArgs OAut
396
397
}
397
398
398
399
var authorizationURL string
400
+ pkceVerifier := oauth2 .GenerateVerifier ()
401
+ stateStr := base64 .StdEncoding .EncodeToString (stateBytes )
402
+ cacheTTL := time .Second * 3600
403
+ if oa .options .CookieTTL != 0 {
404
+ cacheTTL = oa .options .CookieTTL
405
+ }
406
+ oa .pkceVerifierCache .Set (stateStr , pkceVerifier , cacheTTL )
399
407
err = oa .withOAuth2Config (domain , func (config oauth2.Config ) error {
400
- authorizationURL = config .AuthCodeURL (base64 . StdEncoding . EncodeToString ( stateBytes ) , oauth2 .S256ChallengeOption (oa . pkceVerifier ))
408
+ authorizationURL = config .AuthCodeURL (stateStr , oauth2 .S256ChallengeOption (pkceVerifier ))
401
409
return nil
402
410
})
403
411
if err != nil {
@@ -435,9 +443,15 @@ func (oa *OIDCAuthenticator) handleOAuth2Callback(tmpl *template.Template, error
435
443
436
444
domain := extractDomainFromHost (r .Host )
437
445
446
+ pkceVerifier , ok := oa .pkceVerifierCache .Get (stateB64Payload )
447
+ if ! ok {
448
+ logrus .Error ("cannot retrieve pkce verifier" )
449
+ http .Error (w , "Bad request" , http .StatusBadRequest )
450
+ return
451
+ }
438
452
var oauth2Token * oauth2.Token
439
453
err := oa .withOAuth2Config (domain , func (config oauth2.Config ) error {
440
- token , err := config .Exchange (r .Context (), r .URL .Query ().Get ("code" ), oauth2 .VerifierOption (oa . pkceVerifier ))
454
+ token , err := config .Exchange (r .Context (), r .URL .Query ().Get ("code" ), oauth2 .VerifierOption (pkceVerifier ))
441
455
oauth2Token = token
442
456
return err
443
457
})
0 commit comments