Skip to content

Commit ac197b0

Browse files
authored
Merge pull request #462 from smallstep/mariano/error-is
Allow to compare kms errors with errors.Is
2 parents 8e55bd9 + ca39242 commit ac197b0

File tree

4 files changed

+147
-13
lines changed

4 files changed

+147
-13
lines changed

kms/apiv1/options.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,13 @@ func (e NotImplementedError) Error() string {
7272
return "not implemented"
7373
}
7474

75+
func (e NotImplementedError) Is(target error) bool {
76+
_, ok := target.(NotImplementedError)
77+
return ok
78+
}
79+
7580
// AlreadyExistsError is the type of error returned if a key already exists. This
76-
// is currently only implmented for pkcs11 and tpmkms.
81+
// is currently only implemented for pkcs11, tpmkms, and mackms.
7782
type AlreadyExistsError struct {
7883
Message string
7984
}
@@ -82,7 +87,30 @@ func (e AlreadyExistsError) Error() string {
8287
if e.Message != "" {
8388
return e.Message
8489
}
85-
return "key already exists"
90+
return "already exists"
91+
}
92+
93+
func (e AlreadyExistsError) Is(target error) bool {
94+
_, ok := target.(AlreadyExistsError)
95+
return ok
96+
}
97+
98+
// NotFoundError is the type of error returned if a key or certificate does not
99+
// exist. This is currently only implemented for mackms.
100+
type NotFoundError struct {
101+
Message string
102+
}
103+
104+
func (e NotFoundError) Error() string {
105+
if e.Message != "" {
106+
return e.Message
107+
}
108+
return "not found"
109+
}
110+
111+
func (e NotFoundError) Is(target error) bool {
112+
_, ok := target.(NotFoundError)
113+
return ok
86114
}
87115

88116
// Type represents the KMS type used.

kms/apiv1/options_test.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ package apiv1
33
import (
44
"context"
55
"crypto"
6+
"errors"
7+
"fmt"
68
"os"
79
"testing"
10+
11+
"github.com/stretchr/testify/assert"
812
)
913

1014
type fakeKM struct{}
@@ -124,7 +128,7 @@ func TestErrAlreadyExists_Error(t *testing.T) {
124128
fields fields
125129
want string
126130
}{
127-
{"default", fields{}, "key already exists"},
131+
{"default", fields{}, "already exists"},
128132
{"custom", fields{"custom message: key already exists"}, "custom message: key already exists"},
129133
}
130134
for _, tt := range tests {
@@ -139,6 +143,30 @@ func TestErrAlreadyExists_Error(t *testing.T) {
139143
}
140144
}
141145

146+
func TestNotFoundError_Error(t *testing.T) {
147+
type fields struct {
148+
msg string
149+
}
150+
tests := []struct {
151+
name string
152+
fields fields
153+
want string
154+
}{
155+
{"default", fields{}, "not found"},
156+
{"custom", fields{"custom message: not found"}, "custom message: not found"},
157+
}
158+
for _, tt := range tests {
159+
t.Run(tt.name, func(t *testing.T) {
160+
e := NotFoundError{
161+
Message: tt.fields.msg,
162+
}
163+
if got := e.Error(); got != tt.want {
164+
t.Errorf("ErrAlreadyExists.Error() = %v, want %v", got, tt.want)
165+
}
166+
})
167+
}
168+
}
169+
142170
func TestTypeOf(t *testing.T) {
143171
type args struct {
144172
rawuri string
@@ -176,3 +204,29 @@ func TestTypeOf(t *testing.T) {
176204
})
177205
}
178206
}
207+
208+
func TestError_Is(t *testing.T) {
209+
tests := []struct {
210+
name string
211+
err error
212+
target error
213+
want bool
214+
}{
215+
{"ok not implemented", NotImplementedError{}, NotImplementedError{}, true},
216+
{"ok not implemented with message", NotImplementedError{Message: "something"}, NotImplementedError{}, true},
217+
{"ok already exists", AlreadyExistsError{}, AlreadyExistsError{}, true},
218+
{"ok already exists with message", AlreadyExistsError{Message: "something"}, AlreadyExistsError{}, true},
219+
{"ok not found", NotFoundError{}, NotFoundError{}, true},
220+
{"ok not found with message", NotFoundError{Message: "something"}, NotFoundError{}, true},
221+
{"fail not implemented", errors.New("not implemented"), NotImplementedError{}, false},
222+
{"fail already exists", errors.New("already exists"), AlreadyExistsError{}, false},
223+
{"fail not found", errors.New("not found"), NotFoundError{}, false},
224+
}
225+
for _, tt := range tests {
226+
t.Run(tt.name, func(t *testing.T) {
227+
assert.Equal(t, tt.want, errors.Is(tt.err, tt.target))
228+
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", tt.err), tt.target))
229+
assert.Equal(t, tt.want, errors.Is(fmt.Errorf("wrap 1: %w", fmt.Errorf("wrap 2: %w", tt.err)), tt.target))
230+
})
231+
}
232+
}

kms/mackms/mackms.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func (k *MacKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
141141

142142
key, err := getPrivateKey(u)
143143
if err != nil {
144-
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", err)
144+
return nil, fmt.Errorf("mackms GetPublicKey failed: %w", apiv1Error(err))
145145
}
146146
defer key.Release()
147147

@@ -263,7 +263,7 @@ func (k *MacKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
263263

264264
secKeyRef, err := security.SecKeyCreateRandomKey(attrs)
265265
if err != nil {
266-
return nil, fmt.Errorf("mackms CreateKey failed: %w", err)
266+
return nil, fmt.Errorf("mackms CreateKey failed: %w", apiv1Error(err))
267267
}
268268
defer secKeyRef.Release()
269269

@@ -307,7 +307,7 @@ func (k *MacKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er
307307

308308
key, err := getPrivateKey(u)
309309
if err != nil {
310-
return nil, fmt.Errorf("mackms CreateSigner failed: %w", err)
310+
return nil, fmt.Errorf("mackms CreateSigner failed: %w", apiv1Error(err))
311311
}
312312
defer key.Release()
313313

@@ -343,7 +343,7 @@ func (k *MacKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi
343343

344344
cert, err := loadCertificate(u.label, u.serialNumber, nil)
345345
if err != nil {
346-
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", err)
346+
return nil, fmt.Errorf("mackms LoadCertificate failed: %w", apiv1Error(err))
347347
}
348348

349349
return cert, nil
@@ -375,7 +375,7 @@ func (k *MacKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
375375

376376
// Store the certificate and update the label if required
377377
if err := storeCertificate(u.label, req.Certificate); err != nil {
378-
return fmt.Errorf("mackms StoreCertificate failed: %w", err)
378+
return fmt.Errorf("mackms StoreCertificate failed: %w", apiv1Error(err))
379379
}
380380

381381
return nil
@@ -402,7 +402,7 @@ func (k *MacKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([
402402

403403
cert, err := loadCertificate(u.label, u.serialNumber, nil)
404404
if err != nil {
405-
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", err)
405+
return nil, fmt.Errorf("mackms LoadCertificateChain failed1: %w", apiv1Error(err))
406406
}
407407

408408
chain := []*x509.Certificate{cert}
@@ -453,7 +453,7 @@ func (k *MacKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
453453

454454
// Store the certificate and update the label if required
455455
if err := storeCertificate(u.label, req.CertificateChain[0]); err != nil {
456-
return fmt.Errorf("mackms StoreCertificateChain failed: %w", err)
456+
return fmt.Errorf("mackms StoreCertificateChain failed: %w", apiv1Error(err))
457457
}
458458

459459
// Store the rest of the chain but do not fail if already exists
@@ -503,7 +503,7 @@ func (*MacKMS) DeleteKey(req *apiv1.DeleteKeyRequest) error {
503503
}
504504
// Extract logic to deleteItem to avoid defer on loops
505505
if err := deleteItem(dict, u.hash); err != nil {
506-
return fmt.Errorf("mackms DeleteKey failed: %w", err)
506+
return fmt.Errorf("mackms DeleteKey failed: %w", apiv1Error(err))
507507
}
508508
}
509509

@@ -548,7 +548,7 @@ func (*MacKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
548548
}
549549

550550
if err := deleteItem(query, nil); err != nil {
551-
return fmt.Errorf("mackms DeleteCertificate failed: %w", err)
551+
return fmt.Errorf("mackms DeleteCertificate failed: %w", apiv1Error(err))
552552
}
553553

554554
return nil
@@ -1003,3 +1003,18 @@ func ecdhToECDSAPublicKey(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) {
10031003
return nil, errors.New("failed to convert *ecdh.PublicKey to *ecdsa.PublicKey")
10041004
}
10051005
}
1006+
1007+
func apiv1Error(err error) error {
1008+
switch {
1009+
case errors.Is(err, security.ErrNotFound):
1010+
return apiv1.NotFoundError{
1011+
Message: err.Error(),
1012+
}
1013+
case errors.Is(err, security.ErrAlreadyExists):
1014+
return apiv1.AlreadyExistsError{
1015+
Message: err.Error(),
1016+
}
1017+
default:
1018+
return err
1019+
}
1020+
}

kms/mackms/mackms_test.go

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929
"crypto/x509"
3030
"crypto/x509/pkix"
3131
"encoding/hex"
32+
"fmt"
33+
"io"
3234
"math/big"
3335
"net/url"
3436
"testing"
@@ -1143,7 +1145,7 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
11431145
_, err := kms.LoadCertificate(&apiv1.LoadCertificateRequest{
11441146
Name: "mackms:serial=" + hex.EncodeToString(cert.SerialNumber.Bytes()),
11451147
})
1146-
assert.ErrorIs(t, err, security.ErrNotFound)
1148+
assert.ErrorIs(t, err, apiv1.NotFoundError{})
11471149
}
11481150

11491151
kms := &MacKMS{}
@@ -1196,3 +1198,38 @@ func TestMacKMS_DeleteCertificate(t *testing.T) {
11961198
})
11971199
}
11981200
}
1201+
1202+
func Test_apiv1Error(t *testing.T) {
1203+
type args struct {
1204+
err error
1205+
}
1206+
tests := []struct {
1207+
name string
1208+
args args
1209+
assertion assert.ErrorAssertionFunc
1210+
}{
1211+
{"ok not found", args{security.ErrNotFound}, func(t assert.TestingT, err error, msg ...interface{}) bool {
1212+
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
1213+
}},
1214+
{"ok not found wrapped", args{fmt.Errorf("something happened: %w", security.ErrNotFound)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
1215+
return assert.ErrorIs(t, err, apiv1.NotFoundError{}, msg...)
1216+
}},
1217+
{"ok already exists", args{security.ErrAlreadyExists}, func(t assert.TestingT, err error, msg ...interface{}) bool {
1218+
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
1219+
}},
1220+
{"ok already exists wrapped", args{fmt.Errorf("something happened: %w", security.ErrAlreadyExists)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
1221+
return assert.ErrorIs(t, err, apiv1.AlreadyExistsError{}, msg...)
1222+
}},
1223+
{"ok other", args{io.ErrUnexpectedEOF}, func(t assert.TestingT, err error, msg ...interface{}) bool {
1224+
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
1225+
}},
1226+
{"ok other wrapped", args{fmt.Errorf("something happened: %w", io.ErrUnexpectedEOF)}, func(t assert.TestingT, err error, msg ...interface{}) bool {
1227+
return assert.ErrorIs(t, err, io.ErrUnexpectedEOF, msg...)
1228+
}},
1229+
}
1230+
for _, tt := range tests {
1231+
t.Run(tt.name, func(t *testing.T) {
1232+
tt.assertion(t, apiv1Error(tt.args.err))
1233+
})
1234+
}
1235+
}

0 commit comments

Comments
 (0)