diff --git a/api/http/common.go b/api/http/common.go index 4004450e02..3705eb0336 100644 --- a/api/http/common.go +++ b/api/http/common.go @@ -184,6 +184,21 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { return } + switch retErr := err.(type) { + case errors.RequestError: + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(retErr); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + case errors.AuthNError, errors.AuthZError: + w.WriteHeader(http.StatusUnauthorized) + if err := json.NewEncoder(w).Encode(retErr); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + var wrapper error if errors.Contains(err, apiutil.ErrValidation) { wrapper, err = errors.Unwrap(err) diff --git a/api/http/util/errors.go b/api/http/util/errors.go index 8af1f00994..cabf32e752 100644 --- a/api/http/util/errors.go +++ b/api/http/util/errors.go @@ -52,10 +52,10 @@ var ( ErrInvalidIDFormat = errors.New("invalid id format provided") // ErrNameSize indicates that name size exceeds the max. - ErrNameSize = errors.New("invalid name size") + ErrNameSize = errors.NewRequestError("invalid name size") // ErrEmailSize indicates that email size exceeds the max. - ErrEmailSize = errors.New("invalid email size") + ErrEmailSize = errors.NewRequestError("invalid email size") // ErrInvalidRole indicates that an invalid role. ErrInvalidRole = errors.New("invalid client role") @@ -124,7 +124,7 @@ var ( ErrInvalidContact = errors.New("invalid Subscription contact") // ErrMissingEmail indicates missing email. - ErrMissingEmail = errors.New("missing email") + ErrMissingEmail = errors.NewRequestError("missing email") // ErrInvalidEmail indicates missing email. ErrInvalidEmail = errors.New("invalid email") @@ -169,16 +169,16 @@ var ( ErrMissingIdentity = errors.New("missing entity identity") // ErrMissingSecret indicates missing secret. - ErrMissingSecret = errors.New("missing secret") + ErrMissingSecret = errors.NewRequestError("missing secret") // ErrPasswordFormat indicates weak password. - ErrPasswordFormat = errors.New("password does not meet the requirements") + ErrPasswordFormat = errors.NewRequestError("password does not meet the requirements") // ErrMissingName indicates missing identity name. - ErrMissingName = errors.New("missing identity name") + ErrMissingName = errors.NewRequestError("missing identity name") // ErrMissingRoute indicates missing route. - ErrMissingRoute = errors.New("missing route") + ErrMissingRoute = errors.NewRequestError("missing route") // ErrInvalidLevel indicates an invalid group level. ErrInvalidLevel = errors.New("invalid group level (should be between 0 and 5)") @@ -211,37 +211,37 @@ var ( ErrMissingTo = errors.New("missing to time value") // ErrEmptyMessage indicates empty message. - ErrEmptyMessage = errors.New("empty message") + ErrEmptyMessage = errors.NewRequestError("empty message") // ErrMissingEntityType indicates missing entity type. - ErrMissingEntityType = errors.New("missing entity type") + ErrMissingEntityType = errors.NewRequestError("missing entity type") // ErrInvalidEntityType indicates invalid entity type. - ErrInvalidEntityType = errors.New("invalid entity type") + ErrInvalidEntityType = errors.NewRequestError("invalid entity type") // ErrInvalidTimeFormat indicates invalid time format i.e not unix time. - ErrInvalidTimeFormat = errors.New("invalid time format use unix time") + ErrInvalidTimeFormat = errors.NewRequestError("invalid time format use unix time") // ErrEmptySearchQuery indicates search query should not be empty. - ErrEmptySearchQuery = errors.New("search query must not be empty") + ErrEmptySearchQuery = errors.NewRequestError("search query must not be empty") // ErrLenSearchQuery indicates search query length. - ErrLenSearchQuery = errors.New("search query must be at least 3 characters") + ErrLenSearchQuery = errors.NewRequestError("search query must be at least 3 characters") // ErrMissingDomainID indicates missing domainID. - ErrMissingDomainID = errors.New("missing domainID") + ErrMissingDomainID = errors.NewRequestError("missing domainID") // ErrMissingUsername indicates missing user name. - ErrMissingUsername = errors.New("missing username") + ErrMissingUsername = errors.NewRequestError("missing username") // ErrInvalidUsername indicates invalid user name. - ErrInvalidUsername = errors.New("invalid username") + ErrInvalidUsername = errors.NewRequestError("invalid username") // ErrMissingFirstName indicates missing first name. - ErrMissingFirstName = errors.New("missing first name") + ErrMissingFirstName = errors.NewRequestError("missing first name") // ErrMissingLastName indicates missing last name. - ErrMissingLastName = errors.New("missing last name") + ErrMissingLastName = errors.NewRequestError("missing last name") // ErrInvalidProfilePictureURL indicates that the profile picture url is invalid. ErrInvalidProfilePictureURL = errors.New("invalid profile picture url") @@ -255,23 +255,23 @@ var ( ErrUnsupportedTokenType = errors.New("unsupported content token type") // ErrMissingUserID indicates missing user ID. - ErrMissingUserID = errors.New("missing user id") + ErrMissingUserID = errors.NewRequestError("missing user id") // ErrMissingPATID indicates missing pat ID. ErrMissingPATID = errors.New("missing pat id") // ErrInvalidNameFormat indicates invalid name format. - ErrInvalidNameFormat = errors.New("invalid name format") + ErrInvalidNameFormat = errors.NewRequestError("invalid name format") // ErrInvalidRouteFormat indicates invalid route format. - ErrInvalidRouteFormat = errors.New("invalid route format") + ErrInvalidRouteFormat = errors.NewRequestError("invalid route format") // ErrMissingUsernameEmail indicates missing user name / email. - ErrMissingUsernameEmail = errors.New("missing username / email") + ErrMissingUsernameEmail = errors.NewRequestError("missing username / email") // ErrInvalidVerification indicates invalid email verification. - ErrInvalidVerification = errors.New("invalid verification") + ErrInvalidVerification = errors.NewRequestError("invalid verification") // ErrEmailNotVerified indicates invalid email not verified. - ErrEmailNotVerified = errors.New("email not verified") + ErrEmailNotVerified = errors.NewRequestError("email not verified") ) diff --git a/docker/addons/certs/openbao-entrypoint.sh b/docker/addons/certs/openbao-entrypoint.sh index 36f629b30f..69c2627377 100755 --- a/docker/addons/certs/openbao-entrypoint.sh +++ b/docker/addons/certs/openbao-entrypoint.sh @@ -270,7 +270,10 @@ if [ ! -f /opt/openbao/data/configured ]; then key_usage=\"DigitalSignature,KeyEncipherment,KeyAgreement\" \ ext_key_usage=\"ServerAuth,ClientAuth,OCSPSigning\" \ use_csr_common_name=true \ - use_csr_sans=false \ + use_csr_sans=true \ + copy_extensions=true \ + allowed_extensions=\"*\" \ + basic_constraints_valid_for_non_ca=true \ max_ttl=720h \ ttl=720h" @@ -284,6 +287,9 @@ path "pki_int/issue/${AM_CERTS_OPENBAO_PKI_ROLE}" { path "pki_int/sign/${AM_CERTS_OPENBAO_PKI_ROLE}" { capabilities = ["create", "update"] } +path "pki_int/sign-verbatim/${AM_CERTS_OPENBAO_PKI_ROLE}" { + capabilities = ["create", "update"] +} path "pki_int/certs" { capabilities = ["list"] } diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index b7642cbd4d..4f40ca1733 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -90,6 +90,11 @@ func Wrap(wrapper, err error) error { if wrapper == nil || err == nil { return wrapper } + + if ne, ok := err.(NewError); ok { + return ne.Wrap(wrapper) + } + if w, ok := wrapper.(Error); ok { return &customError{ msg: w.Msg(), diff --git a/pkg/errors/errortypes.go b/pkg/errors/errortypes.go new file mode 100644 index 0000000000..2d7481266f --- /dev/null +++ b/pkg/errors/errortypes.go @@ -0,0 +1,138 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package errors + +import ( + "encoding/json" + "errors" +) + +type NewError interface { + // Error implements the error interface. + Error() string + + // Msg returns error message. + Msg() string + + // Err returns wrapped error. + Unwrap() error + + Wrap(e error) error + + // MarshalJSON returns a marshaled error. + MarshalJSON() ([]byte, error) +} + +// NewError specifies an that request could be processed and error which should be addressed by user. +type newError struct { + Err error // Contains other internal details and errors as wrapped error + Message string // Message for end users returned by API layer or other end layer +} + +func (e newError) Error() string { + if e.Err == nil { + return e.Message + } + return e.Message + " : " + e.Err.Error() +} + +func (e newError) Unwrap() error { + return e.Err +} + +func (e newError) Msg() string { + return e.Message +} + +func (e newError) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + Err string `json:"error"` + }{ + Err: e.Message, + }) +} + +var _ NewError = (*RequestError)(nil) + +type RequestError struct { + newError +} + +func (e RequestError) Wrap(err error) error { + e.Err = errors.Join(err, e.Err) + return e +} + +func NewRequestError(message string) error { + return RequestError{ + newError: newError{ + Message: message, + }, + } +} + +func NewRequestErrorWithErr(message string, err error) error { + return RequestError{ + newError: newError{ + Message: message, + Err: err, + }, + } +} + +var _ NewError = (*AuthNError)(nil) + +type AuthNError struct { + newError +} + +func (e AuthNError) Wrap(err error) error { + e.Err = errors.Join(err, e.Err) + return e +} + +func NewAuthNError(message string) error { + return AuthNError{ + newError: newError{ + Message: message, + }, + } +} + +func NewAuthNErrorWithErr(message string, err error) error { + return AuthNError{ + newError: newError{ + Message: message, + Err: err, + }, + } +} + +var _ NewError = (*AuthZError)(nil) + +type AuthZError struct { + newError +} + +func (e AuthZError) Wrap(err error) error { + e.Err = errors.Join(err, e.Err) + return e +} + +func NewAuthZError(message string) error { + return AuthZError{ + newError: newError{ + Message: message, + }, + } +} + +func NewAuthZErrorWithErr(message string, err error) error { + return AuthZError{ + newError: newError{ + Message: message, + Err: err, + }, + } +} diff --git a/pkg/errors/handler.go b/pkg/errors/handler.go new file mode 100644 index 0000000000..ab85890006 --- /dev/null +++ b/pkg/errors/handler.go @@ -0,0 +1,14 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package errors + +type Mapper interface { + GetError(key string) (error, bool) +} + +type Handler interface { + HandleError(wrapper, err error) error +} + +type HandlerOption func(*Handler) diff --git a/pkg/errors/types.go b/pkg/errors/types.go index d747381b44..dadcffc3a9 100644 --- a/pkg/errors/types.go +++ b/pkg/errors/types.go @@ -25,10 +25,13 @@ var ( ErrRollbackTx = errors.New("failed to rollback transaction") // ErrAuthentication indicates failure occurred while authenticating the entity. - ErrAuthentication = errors.New("failed to perform authentication over the entity") + ErrAuthentication = NewAuthNError("failed to perform authentication over the entity") // ErrAuthorization indicates failure occurred while authorizing the entity. - ErrAuthorization = errors.New("failed to perform authorization over the entity") + ErrAuthorization = NewAuthZError("failed to perform authorization over the entity") + + // ErrDomainAuthorization indicates failure occurred while authorizing the domain. + ErrDomainAuthorization = NewAuthZError("failed to perform authorization over the domain") // ErrMissingDomainMember indicates member is not part of a domain. ErrMissingDomainMember = errors.New("member id is not member of domain") diff --git a/pkg/postgres/errorhandler.go b/pkg/postgres/errorhandler.go new file mode 100644 index 0000000000..60dcb8e091 --- /dev/null +++ b/pkg/postgres/errorhandler.go @@ -0,0 +1,52 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/jackc/pgx/v5/pgconn" +) + +var _ errors.Handler = (*errHandler)(nil) + +type errHandler struct { + duplicateErrors errors.Mapper +} + +func WithDuplicateErrors(mapper errors.Mapper) errors.HandlerOption { + return func(eh *errors.Handler) { + if h, ok := (*eh).(*errHandler); ok { + h.duplicateErrors = mapper + } + } +} + +func NewErrorHandler(opts ...errors.HandlerOption) errors.Handler { + var eh errors.Handler = &errHandler{} + for _, opt := range opts { + opt(&eh) + } + return eh +} + +// Handle handles the error. +func (eh errHandler) HandleError(wrapper, err error) error { + pqErr, ok := err.(*pgconn.PgError) + if ok { + switch pqErr.Code { + case errDuplicate: + if knownErr, ok := eh.duplicateErrors.GetError(pqErr.ConstraintName); ok { + return errors.Wrap(wrapper, knownErr) + } + return errors.Wrap(repoerr.ErrConflict, err) + case errInvalid, errInvalidChar, errTruncation, errUntranslatable: + return errors.Wrap(repoerr.ErrMalformedEntity, err) + case errFK: + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + + return errors.Wrap(wrapper, err) +} diff --git a/users/api/endpoints.go b/users/api/endpoints.go index 60b5910db4..5c97cd7118 100644 --- a/users/api/endpoints.go +++ b/users/api/endpoints.go @@ -6,10 +6,8 @@ package api import ( "context" - apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/users" "github.com/go-kit/kit/endpoint" ) @@ -18,7 +16,7 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin return func(ctx context.Context, request any) (any, error) { req := request.(createUserReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session := authn.Session{} @@ -26,7 +24,7 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin if !selfRegister { session, ok = ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } } @@ -48,7 +46,7 @@ func sendVerificationEndpoint(svc users.Service) endpoint.Endpoint { session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } if err := svc.SendVerification(ctx, session); err != nil { @@ -63,7 +61,7 @@ func verifyEmailEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(verifyEmailReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } if _, err := svc.VerifyEmail(ctx, req.token); err != nil { @@ -78,12 +76,12 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(viewUserReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } user, err := svc.View(ctx, session, req.id) if err != nil { @@ -98,7 +96,7 @@ func viewProfileEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } client, err := svc.ViewProfile(ctx, session) if err != nil { @@ -113,12 +111,12 @@ func listUsersEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(listUsersReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } pm := users.Page{ @@ -162,7 +160,7 @@ func searchUsersEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(searchUsersReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } pm := users.Page{ @@ -200,12 +198,12 @@ func updateEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateUserReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } usr := users.UserReq{ @@ -227,12 +225,12 @@ func updateTagsEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateUserTagsReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } usr := users.UserReq{ @@ -252,12 +250,12 @@ func updateEmailEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateEmailReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } user, err := svc.UpdateEmail(ctx, session, req.id, req.Email) @@ -283,7 +281,7 @@ func passwordResetRequestEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(passResetReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } if err := svc.SendPasswordReset(ctx, req.Email); err != nil { @@ -301,12 +299,12 @@ func passwordResetEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(resetTokenReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } if err := svc.ResetSecret(ctx, session, req.Password); err != nil { return nil, err @@ -320,12 +318,12 @@ func updateSecretEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateUserSecretReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } user, err := svc.UpdateSecret(ctx, session, req.OldSecret, req.NewSecret) if err != nil { @@ -340,12 +338,12 @@ func updateUsernameEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateUsernameReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthorization + return nil, errors.ErrAuthentication } user, err := svc.UpdateUsername(ctx, session, req.id, req.Username) @@ -361,7 +359,7 @@ func updateProfilePictureEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateProfilePictureReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } usr := users.UserReq{ @@ -370,7 +368,7 @@ func updateProfilePictureEndpoint(svc users.Service) endpoint.Endpoint { session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthorization + return nil, errors.ErrAuthentication } user, err := svc.UpdateProfilePicture(ctx, session, req.id, usr) @@ -386,7 +384,7 @@ func updateRoleEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(updateUserRoleReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } user := users.User{ @@ -396,7 +394,7 @@ func updateRoleEndpoint(svc users.Service) endpoint.Endpoint { session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } user, err := svc.UpdateRole(ctx, session, user) @@ -412,7 +410,7 @@ func issueTokenEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(loginUserReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } token, err := svc.IssueToken(ctx, req.Username, req.Password) @@ -432,12 +430,12 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(tokenReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } token, err := svc.RefreshToken(ctx, session, req.RefreshToken) @@ -457,12 +455,12 @@ func enableEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(changeUserStatusReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } user, err := svc.Enable(ctx, session, req.id) @@ -478,12 +476,12 @@ func disableEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(changeUserStatusReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } user, err := svc.Disable(ctx, session, req.id) @@ -499,12 +497,12 @@ func deleteEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(changeUserStatusReq) if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return nil, err } session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { - return nil, svcerr.ErrAuthentication + return nil, errors.ErrAuthentication } if err := svc.Delete(ctx, session, req.id); err != nil { diff --git a/users/postgres/errors.go b/users/postgres/errors.go new file mode 100644 index 0000000000..aa68373b40 --- /dev/null +++ b/users/postgres/errors.go @@ -0,0 +1,27 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "github.com/absmach/supermq/pkg/errors" +) + +var _ errors.Mapper = (*duplicateErrors)(nil) + +type duplicateErrors struct{} + +func (d duplicateErrors) GetError(key string) (error, bool) { + switch key { + case "clients_email_key": + return errors.NewRequestError("email id already exists"), true + case "clients_username_key": + return errors.NewRequestError("username not available"), true + default: + return nil, false + } +} + +func NewDuplicateErrors() errors.Mapper { + return duplicateErrors{} +} diff --git a/users/postgres/init.go b/users/postgres/init.go index 29c1a19c84..ebf84a4a28 100644 --- a/users/postgres/init.go +++ b/users/postgres/init.go @@ -118,6 +118,15 @@ func Migration() *migrate.MemoryMigrationSource { `DROP TABLE users_verifications;`, }, }, + { + Id: "clients_08", + Up: []string{ + `ALTER TABLE users RENAME CONSTRAINT clients_identity_key TO clients_email_key;`, + }, + Down: []string{ + `ALTER TABLE users RENAME CONSTRAINT clients_email_key TO clients_identity_key;`, + }, + }, }, } } diff --git a/users/postgres/users.go b/users/postgres/users.go index 9f30eefcee..9ee1c8808d 100644 --- a/users/postgres/users.go +++ b/users/postgres/users.go @@ -22,11 +22,16 @@ import ( type userRepo struct { Repository users.UserRepository + eh errors.Handler } func NewRepository(db postgres.Database) users.Repository { + errHandlerOptions := []errors.HandlerOption{ + postgres.WithDuplicateErrors(NewDuplicateErrors()), + } return &userRepo{ Repository: users.UserRepository{DB: db}, + eh: postgres.NewErrorHandler(errHandlerOptions...), } } @@ -42,7 +47,7 @@ func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu) if err != nil { - return users.User{}, postgres.HandleError(repoerr.ErrCreateEntity, err) + return users.User{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err) } defer row.Close() @@ -232,7 +237,7 @@ func (repo *userRepo) update(ctx context.Context, user users.User, query string) row, err := repo.Repository.DB.NamedQueryContext(ctx, query, dbu) if err != nil { - return users.User{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + return users.User{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err) } defer row.Close()