@@ -433,13 +433,14 @@ func (s *SAMLManager) login(w http.ResponseWriter, r *http.Request, providerName
433433 session .Values [AUTH_KEY ] = false
434434 session .Values [PROVIDER_NAME_KEY ] = providerName
435435 session .Values [NONCE_KEY ] = nonce
436+ session .Values [REDIRECT_URL ] = redirectUrl
436437 if err := session .Save (r , w ); err != nil {
437438 http .Error (w , "error saving session: " + err .Error (), http .StatusInternalServerError )
438439 return
439440 }
440441
441442 // The relay state is the session id and the redirect url, encoded in base64
442- relayState := generateRelayString ( sessionId , redirectUrl )
443+ relayState := base64 . URLEncoding . EncodeToString ([] byte ( sessionId ) )
443444 if sp .IdentityProviderSSOBinding == saml2 .BindingHttpPost {
444445 body , err := sp .BuildAuthBodyPost (relayState )
445446 if err != nil {
@@ -462,22 +463,6 @@ func (s *SAMLManager) login(w http.ResponseWriter, r *http.Request, providerName
462463 http .Redirect (w , r , url , http .StatusFound )
463464}
464465
465- func generateRelayString (sessionId , redirectUrl string ) string {
466- return base64 .URLEncoding .EncodeToString ([]byte (fmt .Sprintf ("%s;%s" , sessionId , redirectUrl )))
467- }
468-
469- func parseRelayString (relay string ) (sessionId , redirecturl string , err error ) {
470- relayStr , err := base64 .URLEncoding .DecodeString (relay )
471- if err != nil {
472- return "" , "" , fmt .Errorf ("error decoding relay state: %w" , err )
473- }
474- sessionId , redirectUrl , ok := strings .Cut (string (relayStr ), ";" )
475- if ! ok {
476- return "" , "" , fmt .Errorf ("error parsing relay state" )
477- }
478- return sessionId , redirectUrl , nil
479- }
480-
481466func (s * SAMLManager ) acs (w http.ResponseWriter , r * http.Request ) {
482467 providerName := chi .URLParam (r , "provider" )
483468 sp := s .providers [providerName ]
@@ -486,6 +471,8 @@ func (s *SAMLManager) acs(w http.ResponseWriter, r *http.Request) {
486471 return
487472 }
488473
474+ const maxACSBody = 10 << 20 // 10 MiB
475+ r .Body = http .MaxBytesReader (w , r .Body , maxACSBody )
489476 if err := r .ParseForm (); err != nil {
490477 http .Error (w , "parse form: " + err .Error (), http .StatusBadRequest )
491478 return
@@ -525,26 +512,27 @@ func (s *SAMLManager) acs(w http.ResponseWriter, r *http.Request) {
525512 )
526513 s .Trace ().Str ("user_id" , ai .NameID ).Str ("provider_name" , providerName ).Msgf ("authenticated saml user with groups %+v" , groups )
527514
528- sessionId , redirectUrl , err := parseRelayString (r .PostFormValue ("RelayState" ))
515+ sessionIdBytes , err := base64 . URLEncoding . DecodeString (r .PostFormValue ("RelayState" ))
529516 if err != nil {
530517 http .Error (w , err .Error (), http .StatusInternalServerError )
531518 return
532519 }
533-
520+ sessionId := string ( sessionIdBytes )
534521 stateMap , err := s .db .FetchKV (r .Context (), sessionId )
535522 if err != nil {
536523 http .Error (w , "error fetching KV state: " + err .Error (), http .StatusInternalServerError )
537524 return
538525 }
539526
540- if stateMap [PROVIDER_NAME_KEY ] != providerName || stateMap [ REDIRECT_URL ] != redirectUrl {
527+ if stateMap [PROVIDER_NAME_KEY ] != providerName {
541528 http .Error (w , "error matching session state" , http .StatusInternalServerError )
542529 return
543530 }
544531 if stateMap [AUTH_KEY ] != false {
545532 http .Error (w , "error matching session state, expected auth to be false" , http .StatusInternalServerError )
546533 return
547534 }
535+ redirectUrl := stateMap [REDIRECT_URL ].(string )
548536
549537 // Update the state map, set to authenticated and add the user id and groups
550538 stateMap [AUTH_KEY ] = true
@@ -582,17 +570,12 @@ func (s *SAMLManager) redirect(w http.ResponseWriter, r *http.Request) {
582570 }
583571
584572 sessionNonce := session .Values [NONCE_KEY ].(string )
585- sessionId , redirectUrl , err := parseRelayString (relayStr )
573+ sessionIdBytes , err := base64 . URLEncoding . DecodeString (relayStr )
586574 if err != nil {
587575 http .Error (w , err .Error (), http .StatusInternalServerError )
588576 return
589577 }
590-
591- if auth , ok := session .Values [AUTH_KEY ].(bool ); ! ok || auth {
592- // already authenticated, redirect to original url
593- http .Redirect (w , r , redirectUrl , http .StatusFound )
594- return
595- }
578+ sessionId := string (sessionIdBytes )
596579
597580 // Get the state map, delete the entry from database, validate state, set the session values
598581 // in the cookie and then redirect to original url
@@ -607,6 +590,17 @@ func (s *SAMLManager) redirect(w http.ResponseWriter, r *http.Request) {
607590 http .Error (w , "error deleting state: " + err .Error (), http .StatusInternalServerError )
608591 return
609592 }
593+ redirectUrl , ok := session .Values [REDIRECT_URL ].(string )
594+ if ! ok {
595+ http .Error (w , "error matching session, redirect url not found" , http .StatusInternalServerError )
596+ return
597+ }
598+
599+ if auth , ok := session .Values [AUTH_KEY ].(bool ); ! ok || auth {
600+ // already authenticated, redirect to original url
601+ http .Redirect (w , r , redirectUrl , http .StatusFound )
602+ return
603+ }
610604
611605 if stateMap [PROVIDER_NAME_KEY ] != providerName || stateMap [REDIRECT_URL ] != redirectUrl {
612606 http .Error (w , "error matching session state" , http .StatusInternalServerError )
0 commit comments