diff --git a/README.md b/README.md index a967da07..115ecd22 100644 --- a/README.md +++ b/README.md @@ -101,34 +101,45 @@ func main() { ```go package example +import "io" + func BeginRegistration(w http.ResponseWriter, r *http.Request) { - user := datastore.GetUser() // Find or create the new user - options, session, err := webAuthn.BeginRegistration(user) - // handle errors if present - // store the sessionData values - JSONResponse(w, options, http.StatusOK) // return the options generated - // options.publicKey contain our registration options + user := datastore.GetUser() // Find or create the new user + options, session, err := webAuthn.BeginRegistration(user) + // handle errors if present + // store the sessionData values + JSONResponse(w, options, http.StatusOK) // return the options generated + // options.publicKey contain our registration options } func FinishRegistration(w http.ResponseWriter, r *http.Request) { - user := datastore.GetUser() // Get the user - - // Get the session data stored from the function above - session := datastore.GetSession() - - credential, err := webAuthn.FinishRegistration(user, session, r) - if err != nil { - // Handle Error and return. + user := datastore.GetUser() // Get the user - return - } - - // If creation was successful, store the credential object - // Pseudocode to add the user credential. - user.AddCredential(credential) - datastore.SaveUser(user) + // Get the session data stored from the function above + session := datastore.GetSession() - JSONResponse(w, "Registration Success", http.StatusOK) // Handle next steps + body, err := io.ReadAll(r.Body) + if err != nil{ + // Handle Error and return. + + return + } + + defer body.Close() + + credential, err := webAuthn.FinishRegistration(user, session, body) + if err != nil { + // Handle Error and return. + + return + } + + // If creation was successful, store the credential object + // Pseudocode to add the user credential. + user.AddCredential(credential) + datastore.SaveUser(user) + + JSONResponse(w, "Registration Success", http.StatusOK) // Handle next steps } ``` @@ -159,8 +170,17 @@ func FinishLogin(w http.ResponseWriter, r *http.Request) { // Get the session data stored from the function above session := datastore.GetSession() - - credential, err := webAuthn.FinishLogin(user, session, r) + + body, err := io.ReadAll(r.Body) + if err != nil{ + // Handle Error and return. + + return + } + + defer body.Close() + + credential, err := webAuthn.FinishLogin(user, session, body) if err != nil { // Handle Error and return. @@ -176,6 +196,7 @@ func FinishLogin(w http.ResponseWriter, r *http.Request) { JSONResponse(w, "Login Success", http.StatusOK) } + ``` ## Modifying Credential Options diff --git a/protocol/assertion.go b/protocol/assertion.go index 0b5eaefe..609c4e6b 100644 --- a/protocol/assertion.go +++ b/protocol/assertion.go @@ -5,8 +5,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io" - "net/http" "github.com/go-webauthn/webauthn/protocol/webauthncose" ) @@ -49,24 +47,21 @@ type ParsedAssertionResponse struct { // ParseCredentialRequestResponse parses the credential request response into a format that is either required by the // specification or makes the assertion verification steps easier to complete. This takes a http.Request that contains // the assertion response data in a raw, mostly base64 encoded format, and parses the data into manageable structures. -func ParseCredentialRequestResponse(response *http.Request) (*ParsedCredentialAssertionData, error) { - if response == nil || response.Body == nil { +func ParseCredentialRequestResponse(credentialResponse []byte) (*ParsedCredentialAssertionData, error) { + if credentialResponse == nil { return nil, ErrBadRequest.WithDetails("No response given") } - defer response.Body.Close() - defer io.Copy(io.Discard, response.Body) - - return ParseCredentialRequestResponseBody(response.Body) + return ParseCredentialRequestResponseBody(credentialResponse) } // ParseCredentialRequestResponseBody parses the credential request response into a format that is either required by // the specification or makes the assertion verification steps easier to complete. This takes an io.Reader that contains // the assertion response data in a raw, mostly base64 encoded format, and parses the data into manageable structures. -func ParseCredentialRequestResponseBody(body io.Reader) (par *ParsedCredentialAssertionData, err error) { +func ParseCredentialRequestResponseBody(credentialResponse []byte) (par *ParsedCredentialAssertionData, err error) { var car CredentialAssertionResponse - if err = decodeBody(body, &car); err != nil { + if err = json.Unmarshal(credentialResponse, &car); err != nil { return nil, ErrBadRequest.WithDetails("Parse error for Assertion").WithInfo(err.Error()) } diff --git a/protocol/assertion_test.go b/protocol/assertion_test.go index 2731850c..2f704441 100644 --- a/protocol/assertion_test.go +++ b/protocol/assertion_test.go @@ -1,9 +1,7 @@ package protocol import ( - "bytes" "encoding/base64" - "io" "testing" "github.com/stretchr/testify/assert" @@ -94,24 +92,13 @@ func TestParseCredentialRequestResponse(t *testing.T) { }, errString: "", }, - { - name: "ShouldHandleTrailingData", - args: args{ - "trailingData", - }, - expected: nil, - errString: "Parse error for Assertion", - errType: "invalid_request", - errDetails: "Parse error for Assertion", - errInfo: "The body contains trailing data", - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - body := io.NopCloser(bytes.NewReader([]byte(testAssertionResponses[tc.args.responseName]))) + body := []byte(testAssertionResponses[tc.args.responseName]) actual, err := ParseCredentialRequestResponseBody(body) diff --git a/protocol/credential.go b/protocol/credential.go index d532e2b1..29f2464a 100644 --- a/protocol/credential.go +++ b/protocol/credential.go @@ -3,8 +3,7 @@ package protocol import ( "crypto/sha256" "encoding/base64" - "io" - "net/http" + "encoding/json" ) // Credential is the basic credential type from the Credential Management specification that is inherited by WebAuthn's @@ -60,23 +59,20 @@ type ParsedCredentialCreationData struct { // ParseCredentialCreationResponse is a non-agnostic function for parsing a registration response from the http library // from stdlib. It handles some standard cleanup operations. -func ParseCredentialCreationResponse(response *http.Request) (*ParsedCredentialCreationData, error) { - if response == nil || response.Body == nil { +func ParseCredentialCreationResponse(creationResponse []byte) (*ParsedCredentialCreationData, error) { + if creationResponse == nil { return nil, ErrBadRequest.WithDetails("No response given") } - defer response.Body.Close() - defer io.Copy(io.Discard, response.Body) - - return ParseCredentialCreationResponseBody(response.Body) + return ParseCredentialCreationResponseBody(creationResponse) } // ParseCredentialCreationResponseBody is an agnostic version of ParseCredentialCreationResponse. Implementers are // therefore responsible for managing cleanup. -func ParseCredentialCreationResponseBody(body io.Reader) (pcc *ParsedCredentialCreationData, err error) { +func ParseCredentialCreationResponseBody(creationResponse []byte) (pcc *ParsedCredentialCreationData, err error) { var ccr CredentialCreationResponse - if err = decodeBody(body, &ccr); err != nil { + if err = json.Unmarshal(creationResponse, &ccr); err != nil { return nil, ErrBadRequest.WithDetails("Parse error for Registration").WithInfo(err.Error()) } @@ -204,7 +200,7 @@ func (pcc *ParsedCredentialCreationData) Verify(storedChallenge string, verifyUs // 9. Return the appid extension value from the Session data. func (ppkc ParsedPublicKeyCredential) GetAppID(authExt AuthenticationExtensions, credentialAttestationType string) (appID string, err error) { var ( - value, clientValue interface{} + value, clientValue any enableAppID, ok bool ) diff --git a/protocol/credential_test.go b/protocol/credential_test.go index 26d0256c..7000b4a6 100644 --- a/protocol/credential_test.go +++ b/protocol/credential_test.go @@ -1,9 +1,7 @@ package protocol import ( - "bytes" "encoding/base64" - "io" "testing" "github.com/stretchr/testify/assert" @@ -94,22 +92,11 @@ func TestParseCredentialCreationResponse(t *testing.T) { }, errString: "", }, - { - name: "ShouldHandleTrailingData", - args: args{ - responseName: "trailingData", - }, - expected: nil, - errString: "Parse error for Registration", - errType: "invalid_request", - errDetails: "Parse error for Registration", - errInfo: "The body contains trailing data", - }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - body := io.NopCloser(bytes.NewReader([]byte(testCredentialRequestResponses[tc.args.responseName]))) + body := []byte(testCredentialRequestResponses[tc.args.responseName]) actual, err := ParseCredentialCreationResponseBody(body) diff --git a/protocol/decoder.go b/protocol/decoder.go deleted file mode 100644 index 92e8a81c..00000000 --- a/protocol/decoder.go +++ /dev/null @@ -1,23 +0,0 @@ -package protocol - -import ( - "encoding/json" - "errors" - "io" -) - -func decodeBody(body io.Reader, v any) (err error) { - decoder := json.NewDecoder(body) - - if err = decoder.Decode(v); err != nil { - return err - } - - _, err = decoder.Token() - - if !errors.Is(err, io.EOF) { - return errors.New("The body contains trailing data") - } - - return nil -} diff --git a/webauthn/login.go b/webauthn/login.go index 391f35b9..d2ad1f83 100644 --- a/webauthn/login.go +++ b/webauthn/login.go @@ -3,6 +3,7 @@ package webauthn import ( "bytes" "fmt" + "io" "net/http" "net/url" "time" @@ -162,8 +163,13 @@ func WithLoginRelyingPartyID(id string) LoginOption { } // FinishLogin takes the response from the client and validate it against the user credentials and stored session data. -func (webauthn *WebAuthn) FinishLogin(user User, session SessionData, response *http.Request) (*Credential, error) { - parsedResponse, err := protocol.ParseCredentialRequestResponse(response) +func (webauthn *WebAuthn) FinishLogin(user User, session SessionData, clientResponse any) (*Credential, error) { + body, err := webauthn.processResponse(clientResponse) + if err != nil { + return nil, err + } + + parsedResponse, err := protocol.ParseCredentialRequestResponse(body) if err != nil { return nil, err } @@ -171,11 +177,39 @@ func (webauthn *WebAuthn) FinishLogin(user User, session SessionData, response * return webauthn.ValidateLogin(user, session, parsedResponse) } +func (webauthn *WebAuthn) processResponse(data any) ([]byte, error) { + var ( + body []byte + err error + ) + + switch cl := data.(type) { + case *http.Request: + body, err = io.ReadAll(cl.Body) + _ = cl.Body.Close() + + case io.Reader: + body, err = io.ReadAll(cl) + + case []byte: + body = cl + + default: + return nil, protocol.ErrBadRequest.WithDetails("Invalid client response type") + } + + if err != nil { + return nil, err + } + + return body, nil +} + // FinishDiscoverableLogin takes the response from the client and validate it against the handler and stored session data. // The handler helps to find out which user must be used to validate the response. This is a function defined in your // business code that will retrieve the user from your persistent data. -func (webauthn *WebAuthn) FinishDiscoverableLogin(handler DiscoverableUserHandler, session SessionData, response *http.Request) (*Credential, error) { - parsedResponse, err := protocol.ParseCredentialRequestResponse(response) +func (webauthn *WebAuthn) FinishDiscoverableLogin(handler DiscoverableUserHandler, session SessionData, clientResponse []byte) (*Credential, error) { + parsedResponse, err := protocol.ParseCredentialRequestResponse(clientResponse) if err != nil { return nil, err } diff --git a/webauthn/registration.go b/webauthn/registration.go index a0d6e3a6..4a9e9bb1 100644 --- a/webauthn/registration.go +++ b/webauthn/registration.go @@ -3,7 +3,6 @@ package webauthn import ( "bytes" "fmt" - "net/http" "net/url" "time" @@ -203,7 +202,12 @@ func WithRegistrationRelyingPartyName(name string) RegistrationOption { // FinishRegistration takes the response from the authenticator and client and verify the credential against the user's // credentials and session data. -func (webauthn *WebAuthn) FinishRegistration(user User, session SessionData, response *http.Request) (*Credential, error) { +func (webauthn *WebAuthn) FinishRegistration(user User, session SessionData, authenticatorResponse any) (*Credential, error) { + response, err := webauthn.processResponse(authenticatorResponse) + if err != nil { + return nil, err + } + parsedResponse, err := protocol.ParseCredentialCreationResponse(response) if err != nil { return nil, err