Skip to content

Commit fbd3fd2

Browse files
authored
Merge pull request #625 from hslatman/hs/acme-revocation
ACME Certificate Revocation
2 parents 53ebd85 + 00539d0 commit fbd3fd2

26 files changed

+2334
-44
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
66

77
## [Unreleased - 0.18.1] - DATE
88
### Added
9+
- Support for ACME revocation.
910
### Changed
1011
### Deprecated
1112
### Removed

acme/api/handler.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,17 @@ func (h *Handler) Route(r api.Router) {
100100
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
101101
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
102102

103+
validatingMiddleware := func(next nextHTTP) nextHTTP {
104+
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))
105+
}
103106
extractPayloadByJWK := func(next nextHTTP) nextHTTP {
104-
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
107+
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
105108
}
106109
extractPayloadByKid := func(next nextHTTP) nextHTTP {
107-
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
110+
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
111+
}
112+
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
113+
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
108114
}
109115

110116
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
@@ -117,6 +123,7 @@ func (h *Handler) Route(r api.Router) {
117123
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
118124
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
119125
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
126+
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
120127
}
121128

122129
// GetNonce just sets the right header since a Nonce is added to each response

acme/api/middleware.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
262262
// Store the JWK in the context.
263263
ctx = context.WithValue(ctx, jwkContextKey, jwk)
264264

265-
// Get Account or continue to generate a new one.
265+
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key
266266
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
267267
switch {
268268
case errors.Is(err, acme.ErrNotFound):
269-
// For NewAccount requests ...
269+
// For NewAccount and Revoke requests ...
270270
break
271271
case err != nil:
272272
api.WriteError(w, err)
@@ -352,6 +352,42 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
352352
}
353353
}
354354

355+
// extractOrLookupJWK forwards handling to either extractJWK or
356+
// lookupJWK based on the presence of a JWK or a KID, respectively.
357+
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
358+
return func(w http.ResponseWriter, r *http.Request) {
359+
ctx := r.Context()
360+
jws, err := jwsFromContext(ctx)
361+
if err != nil {
362+
api.WriteError(w, err)
363+
return
364+
}
365+
366+
// at this point the JWS has already been verified (if correctly configured in middleware),
367+
// and it can be used to check if a JWK exists. This flow is used when the ACME client
368+
// signed the payload with a certificate private key.
369+
if canExtractJWKFrom(jws) {
370+
h.extractJWK(next)(w, r)
371+
return
372+
}
373+
374+
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
375+
// signed the payload with an account private key.
376+
h.lookupJWK(next)(w, r)
377+
}
378+
}
379+
380+
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
381+
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
382+
if jws == nil {
383+
return false
384+
}
385+
if len(jws.Signatures) == 0 {
386+
return false
387+
}
388+
return jws.Signatures[0].Protected.JSONWebKey != nil
389+
}
390+
355391
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
356392
// Make sure to parse and validate the JWS before running this middleware.
357393
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {

acme/api/middleware_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,3 +1472,187 @@ func TestHandler_validateJWS(t *testing.T) {
14721472
})
14731473
}
14741474
}
1475+
1476+
func Test_canExtractJWKFrom(t *testing.T) {
1477+
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
1478+
assert.FatalError(t, err)
1479+
type args struct {
1480+
jws *jose.JSONWebSignature
1481+
}
1482+
tests := []struct {
1483+
name string
1484+
args args
1485+
want bool
1486+
}{
1487+
{
1488+
name: "no-jws",
1489+
args: args{
1490+
jws: nil,
1491+
},
1492+
want: false,
1493+
},
1494+
{
1495+
name: "no-signatures",
1496+
args: args{
1497+
jws: &jose.JSONWebSignature{
1498+
Signatures: []jose.Signature{},
1499+
},
1500+
},
1501+
want: false,
1502+
},
1503+
{
1504+
name: "no-jwk",
1505+
args: args{
1506+
jws: &jose.JSONWebSignature{
1507+
Signatures: []jose.Signature{
1508+
{
1509+
Protected: jose.Header{},
1510+
},
1511+
},
1512+
},
1513+
},
1514+
want: false,
1515+
},
1516+
{
1517+
name: "ok",
1518+
args: args{
1519+
jws: &jose.JSONWebSignature{
1520+
Signatures: []jose.Signature{
1521+
{
1522+
Protected: jose.Header{
1523+
JSONWebKey: jwk,
1524+
},
1525+
},
1526+
},
1527+
},
1528+
},
1529+
want: true,
1530+
},
1531+
}
1532+
for _, tt := range tests {
1533+
t.Run(tt.name, func(t *testing.T) {
1534+
if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
1535+
t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
1536+
}
1537+
})
1538+
}
1539+
}
1540+
1541+
func TestHandler_extractOrLookupJWK(t *testing.T) {
1542+
u := "https://ca.smallstep.com/acme/account"
1543+
type test struct {
1544+
db acme.DB
1545+
linker Linker
1546+
statusCode int
1547+
ctx context.Context
1548+
err *acme.Error
1549+
next func(w http.ResponseWriter, r *http.Request)
1550+
}
1551+
var tests = map[string]func(t *testing.T) test{
1552+
"ok/extract": func(t *testing.T) test {
1553+
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
1554+
assert.FatalError(t, err)
1555+
kid, err := jwk.Thumbprint(crypto.SHA256)
1556+
assert.FatalError(t, err)
1557+
pub := jwk.Public()
1558+
pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
1559+
so := new(jose.SignerOptions)
1560+
so.WithHeader("jwk", pub) // JWK for certificate private key flow
1561+
signer, err := jose.NewSigner(jose.SigningKey{
1562+
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
1563+
Key: jwk.Key,
1564+
}, so)
1565+
assert.FatalError(t, err)
1566+
signed, err := signer.Sign([]byte("foo"))
1567+
assert.FatalError(t, err)
1568+
raw, err := signed.CompactSerialize()
1569+
assert.FatalError(t, err)
1570+
parsedJWS, err := jose.ParseJWS(raw)
1571+
assert.FatalError(t, err)
1572+
return test{
1573+
linker: NewLinker("dns", "acme"),
1574+
db: &acme.MockDB{
1575+
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
1576+
assert.Equals(t, kid, pub.KeyID)
1577+
return nil, acme.ErrNotFound
1578+
},
1579+
},
1580+
ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS),
1581+
statusCode: 200,
1582+
next: func(w http.ResponseWriter, r *http.Request) {
1583+
w.Write(testBody)
1584+
},
1585+
}
1586+
},
1587+
"ok/lookup": func(t *testing.T) test {
1588+
prov := newProv()
1589+
provName := url.PathEscape(prov.GetName())
1590+
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
1591+
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
1592+
assert.FatalError(t, err)
1593+
accID := "accID"
1594+
prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
1595+
so := new(jose.SignerOptions)
1596+
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) // KID for account private key flow
1597+
signer, err := jose.NewSigner(jose.SigningKey{
1598+
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
1599+
Key: jwk.Key,
1600+
}, so)
1601+
assert.FatalError(t, err)
1602+
jws, err := signer.Sign([]byte("baz"))
1603+
assert.FatalError(t, err)
1604+
raw, err := jws.CompactSerialize()
1605+
assert.FatalError(t, err)
1606+
parsedJWS, err := jose.ParseJWS(raw)
1607+
assert.FatalError(t, err)
1608+
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
1609+
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
1610+
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
1611+
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
1612+
return test{
1613+
linker: NewLinker("test.ca.smallstep.com", "acme"),
1614+
db: &acme.MockDB{
1615+
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
1616+
assert.Equals(t, accID, acc.ID)
1617+
return acc, nil
1618+
},
1619+
},
1620+
ctx: ctx,
1621+
statusCode: 200,
1622+
next: func(w http.ResponseWriter, r *http.Request) {
1623+
w.Write(testBody)
1624+
},
1625+
}
1626+
},
1627+
}
1628+
for name, prep := range tests {
1629+
tc := prep(t)
1630+
t.Run(name, func(t *testing.T) {
1631+
h := &Handler{db: tc.db, linker: tc.linker}
1632+
req := httptest.NewRequest("GET", u, nil)
1633+
req = req.WithContext(tc.ctx)
1634+
w := httptest.NewRecorder()
1635+
h.extractOrLookupJWK(tc.next)(w, req)
1636+
res := w.Result()
1637+
1638+
assert.Equals(t, res.StatusCode, tc.statusCode)
1639+
1640+
body, err := io.ReadAll(res.Body)
1641+
res.Body.Close()
1642+
assert.FatalError(t, err)
1643+
1644+
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
1645+
var ae acme.Error
1646+
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
1647+
1648+
assert.Equals(t, ae.Type, tc.err.Type)
1649+
assert.Equals(t, ae.Detail, tc.err.Detail)
1650+
assert.Equals(t, ae.Identifier, tc.err.Identifier)
1651+
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
1652+
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
1653+
} else {
1654+
assert.Equals(t, bytes.TrimSpace(body), testBody)
1655+
}
1656+
})
1657+
}
1658+
}

0 commit comments

Comments
 (0)