Skip to content

Commit 31cf01b

Browse files
authored
Merge pull request #37 from smallstep/feat/guessSigAlg
Make it easier to sign a JWT with a key.
2 parents 6be1399 + a20952a commit 31cf01b

File tree

4 files changed

+147
-2
lines changed

4 files changed

+147
-2
lines changed

jose/parse.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package jose
22

33
import (
44
"bytes"
5+
"crypto"
56
"crypto/ecdsa"
67
"crypto/ed25519"
78
"crypto/elliptic"
@@ -351,6 +352,24 @@ func guessJWKAlgorithm(ctx *context, jwk *JSONWebKey) {
351352
}
352353
}
353354

355+
// guessSignatureAlgorithm returns the signature algorithm for a given private key.
356+
func guessSignatureAlgorithm(key crypto.PrivateKey) SignatureAlgorithm {
357+
switch k := key.(type) {
358+
case []byte:
359+
return DefaultOctSigAlgorithm
360+
case *ecdsa.PrivateKey:
361+
return SignatureAlgorithm(getECAlgorithm(k.Curve))
362+
case *rsa.PrivateKey:
363+
return DefaultRSASigAlgorithm
364+
case ed25519.PrivateKey:
365+
return EdDSA
366+
case x25519.PrivateKey, X25519Signer:
367+
return XEdDSA
368+
default:
369+
return ""
370+
}
371+
}
372+
354373
// guessKnownJWKAlgorithm sets the algorithm for keys that only have one
355374
// possible algorithm.
356375
func guessKnownJWKAlgorithm(ctx *context, jwk *JSONWebKey) {

jose/parse_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,3 +685,44 @@ func Test_guessKeyType(t *testing.T) {
685685
})
686686
}
687687
}
688+
689+
func Test_guessSignatureAlgorithm(t *testing.T) {
690+
must := func(args ...interface{}) crypto.PrivateKey {
691+
last := len(args) - 1
692+
if err := args[last]; err != nil {
693+
t.Fatal(err)
694+
}
695+
return args[last-1]
696+
}
697+
698+
_, x25519Key, err := x25519.GenerateKey(rand.Reader)
699+
if err != nil {
700+
t.Fatal(err)
701+
}
702+
703+
type args struct {
704+
key crypto.PrivateKey
705+
}
706+
tests := []struct {
707+
name string
708+
args args
709+
want SignatureAlgorithm
710+
}{
711+
{"byte", args{[]byte("the-key")}, HS256},
712+
{"ES256", args{must(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))}, ES256},
713+
{"ES384", args{must(ecdsa.GenerateKey(elliptic.P384(), rand.Reader))}, ES384},
714+
{"ES512", args{must(ecdsa.GenerateKey(elliptic.P521(), rand.Reader))}, ES512},
715+
{"RS256", args{must(rsa.GenerateKey(rand.Reader, 2048))}, RS256},
716+
{"EdDSA", args{must(ed25519.GenerateKey(rand.Reader))}, EdDSA},
717+
{"XEdDSA", args{x25519Key}, XEdDSA},
718+
{"XEdDSA with X25519Signer", args{X25519Signer(x25519Key)}, XEdDSA},
719+
{"empty", args{must(ecdsa.GenerateKey(elliptic.P224(), rand.Reader))}, ""},
720+
}
721+
for _, tt := range tests {
722+
t.Run(tt.name, func(t *testing.T) {
723+
if got := guessSignatureAlgorithm(tt.args.key); !reflect.DeepEqual(got, tt.want) {
724+
t.Errorf("guessSignatureAlgorithm() = %v, want %v", got, tt.want)
725+
}
726+
})
727+
}
728+
}

jose/types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ var ErrInvalidSubject = jwt.ErrInvalidSubject
118118
// ErrInvalidID indicates invalid jti claim.
119119
var ErrInvalidID = jwt.ErrInvalidID
120120

121+
// ErrIssuedInTheFuture indicates that the iat field is in the future.
122+
var ErrIssuedInTheFuture = jwt.ErrIssuedInTheFuture
123+
121124
// Key management algorithms
122125
//nolint:revive // use standard names in upper-case
123126
const (
@@ -236,6 +239,9 @@ func NewSigner(sig SigningKey, opts *SignerOptions) (Signer, error) {
236239
if k, ok := sig.Key.(x25519.PrivateKey); ok {
237240
sig.Key = X25519Signer(k)
238241
}
242+
if sig.Algorithm == "" {
243+
sig.Algorithm = guessSignatureAlgorithm(sig.Key)
244+
}
239245
return jose.NewSigner(sig, opts)
240246
}
241247

jose/types_test.go

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
// Code generated (comment to force golint to ignore this file). DO NOT EDIT.
2-
31
package jose
42

53
import (
4+
"crypto"
5+
"crypto/ecdsa"
6+
"crypto/ed25519"
7+
"crypto/elliptic"
8+
"crypto/rand"
9+
"crypto/rsa"
610
"reflect"
711
"testing"
812
"time"
913

1014
"github.com/pkg/errors"
15+
"go.step.sm/crypto/x25519"
1116
)
1217

1318
func TestNumericDate(t *testing.T) {
@@ -100,3 +105,77 @@ func TestTrimPrefix(t *testing.T) {
100105
})
101106
}
102107
}
108+
109+
func TestSignVerify(t *testing.T) {
110+
must := func(args ...interface{}) crypto.Signer {
111+
last := len(args) - 1
112+
if err := args[last]; err != nil {
113+
t.Fatal(err)
114+
}
115+
return args[last-1].(crypto.Signer)
116+
}
117+
118+
p224 := must(ecdsa.GenerateKey(elliptic.P224(), rand.Reader))
119+
p256 := must(ecdsa.GenerateKey(elliptic.P256(), rand.Reader))
120+
p384 := must(ecdsa.GenerateKey(elliptic.P384(), rand.Reader))
121+
p521 := must(ecdsa.GenerateKey(elliptic.P521(), rand.Reader))
122+
rsa2048 := must(rsa.GenerateKey(rand.Reader, 2048))
123+
edKey := must(ed25519.GenerateKey(rand.Reader))
124+
xKey := must(x25519.GenerateKey(rand.Reader))
125+
126+
type args struct {
127+
sig SigningKey
128+
opts *SignerOptions
129+
}
130+
tests := []struct {
131+
name string
132+
args args
133+
wantErr bool
134+
}{
135+
{"byte", args{SigningKey{Key: []byte("the-key")}, nil}, false},
136+
{"P256", args{SigningKey{Key: p256}, nil}, false},
137+
{"P384", args{SigningKey{Key: p384}, nil}, false},
138+
{"P521", args{SigningKey{Key: p521}, nil}, false},
139+
{"rsa2048", args{SigningKey{Key: rsa2048}, nil}, false},
140+
{"ed", args{SigningKey{Key: edKey}, nil}, false},
141+
{"x25519", args{SigningKey{Key: xKey}, nil}, false},
142+
{"fail P224", args{SigningKey{Key: p224}, nil}, true},
143+
}
144+
for _, tt := range tests {
145+
t.Run(tt.name, func(t *testing.T) {
146+
got, err := NewSigner(tt.args.sig, tt.args.opts)
147+
if (err != nil) != tt.wantErr {
148+
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
149+
return
150+
}
151+
if !tt.wantErr {
152+
payload := []byte(`{"sub": "sub"}`)
153+
jws, err := got.Sign(payload)
154+
if err != nil {
155+
t.Errorf("Signer.Sign() error = %v", err)
156+
return
157+
}
158+
jwt, err := ParseSigned(jws.FullSerialize())
159+
if err != nil {
160+
t.Errorf("ParseSigned() error = %v", err)
161+
return
162+
}
163+
164+
var claims Claims
165+
if signer, ok := tt.args.sig.Key.(crypto.Signer); ok {
166+
err = Verify(jwt, signer.Public(), &claims)
167+
} else {
168+
err = Verify(jwt, tt.args.sig.Key, &claims)
169+
}
170+
if err != nil {
171+
t.Errorf("JSONWebSignature.Verify() error = %v", err)
172+
return
173+
}
174+
want := Claims{Subject: "sub"}
175+
if !reflect.DeepEqual(claims, want) {
176+
t.Errorf("JSONWebSignature.Verify() claims = %v, want %v", claims, want)
177+
}
178+
}
179+
})
180+
}
181+
}

0 commit comments

Comments
 (0)