Skip to content

Commit a5b0a27

Browse files
authored
More unit tests for the auth package (#28)
* Added more unit tests for auth package * Added some unit tests for low-level JWT processing * More test cases * Fixing a typo
1 parent 9bd40b8 commit a5b0a27

File tree

5 files changed

+248
-23
lines changed

5 files changed

+248
-23
lines changed

auth/auth_test.go

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
package auth
1616

1717
import (
18+
"encoding/json"
1819
"errors"
20+
"fmt"
1921
"io/ioutil"
2022
"log"
2123
"os"
@@ -81,6 +83,38 @@ func TestMain(m *testing.M) {
8183
os.Exit(m.Run())
8284
}
8385

86+
func TestNewClientInvalidCredentials(t *testing.T) {
87+
creds := &google.DefaultCredentials{
88+
JSON: []byte("foo"),
89+
}
90+
conf := &internal.AuthConfig{
91+
Ctx: context.Background(),
92+
Creds: creds,
93+
}
94+
if c, err := NewClient(conf); c != nil || err == nil {
95+
t.Errorf("NewCient() = (%v,%v); want = (nil, error)", c, err)
96+
}
97+
}
98+
99+
func TestNewClientInvalidPrivateKey(t *testing.T) {
100+
sa := map[string]interface{}{
101+
"private_key": "foo",
102+
"client_email": "bar@test.com",
103+
}
104+
b, err := json.Marshal(sa)
105+
if err != nil {
106+
t.Fatal(err)
107+
}
108+
creds := &google.DefaultCredentials{JSON: b}
109+
conf := &internal.AuthConfig{
110+
Ctx: context.Background(),
111+
Creds: creds,
112+
}
113+
if c, err := NewClient(conf); c != nil || err == nil {
114+
t.Errorf("NewCient() = (%v,%v); want = (nil, error)", c, err)
115+
}
116+
}
117+
84118
func TestCustomToken(t *testing.T) {
85119
token, err := client.CustomToken("user1")
86120
if err != nil {
@@ -118,13 +152,14 @@ func TestCustomTokenError(t *testing.T) {
118152
}{
119153
{"EmptyName", "", nil},
120154
{"LongUid", strings.Repeat("a", 129), nil},
121-
{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234"}},
155+
{"ReservedClaim", "uid", map[string]interface{}{"sub": "1234"}},
156+
{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234", "aud": "foo"}},
122157
}
123158

124159
for _, tc := range cases {
125160
token, err := client.CustomTokenWithClaims(tc.uid, tc.claims)
126161
if token != "" || err == nil {
127-
t.Errorf("CustomTokenWithClaims(%q) = (%q, %v); want: (\"\", error)", tc.name, token, err)
162+
t.Errorf("CustomTokenWithClaims(%q) = (%q, %v); want = (\"\", error)", tc.name, token, err)
128163
}
129164
}
130165
}
@@ -137,12 +172,12 @@ func TestCustomTokenInvalidCredential(t *testing.T) {
137172

138173
token, err := s.CustomToken("user1")
139174
if token != "" || err == nil {
140-
t.Errorf("CustomTokenWithClaims() = (%q, %v); want: (\"\", error)", token, err)
175+
t.Errorf("CustomTokenWithClaims() = (%q, %v); want = (\"\", error)", token, err)
141176
}
142177

143178
token, err = s.CustomTokenWithClaims("user1", map[string]interface{}{"foo": "bar"})
144179
if token != "" || err == nil {
145-
t.Errorf("CustomTokenWithClaims() = (%q, %v); want: (\"\", error)", token, err)
180+
t.Errorf("CustomTokenWithClaims() = (%q, %v); want = (\"\", error)", token, err)
146181
}
147182
}
148183

@@ -152,15 +187,23 @@ func TestVerifyIDToken(t *testing.T) {
152187
t.Fatal(err)
153188
}
154189
if ft.Claims["admin"] != true {
155-
t.Errorf("Claims['admin'] = %v; want: true", ft.Claims["admin"])
190+
t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
156191
}
157192
if ft.UID != ft.Subject {
158193
t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
159194
}
160195
}
161196

197+
func TestVerifyIDTokenInvalidSignature(t *testing.T) {
198+
parts := strings.Split(testIDToken, ".")
199+
token := fmt.Sprintf("%s:%s:invalidsignature", parts[0], parts[1])
200+
if ft, err := client.VerifyIDToken(token); ft != nil || err == nil {
201+
t.Errorf("VerifyiDToken('invalid-signature') = (%v, %v); want = (nil, error)", ft, err)
202+
}
203+
}
204+
162205
func TestVerifyIDTokenError(t *testing.T) {
163-
var now int64 = 1000
206+
now := time.Now().Unix()
164207
cases := []struct {
165208
name string
166209
token string
@@ -172,22 +215,18 @@ func TestVerifyIDTokenError(t *testing.T) {
172215
{"EmptySubject", getIDToken(mockIDTokenPayload{"sub": ""})},
173216
{"IntSubject", getIDToken(mockIDTokenPayload{"sub": 10})},
174217
{"LongSubject", getIDToken(mockIDTokenPayload{"sub": strings.Repeat("a", 129)})},
175-
{"FutureToken", getIDToken(mockIDTokenPayload{"iat": time.Unix(now+1, 0)})},
218+
{"FutureToken", getIDToken(mockIDTokenPayload{"iat": now + 1000})},
176219
{"ExpiredToken", getIDToken(mockIDTokenPayload{
177-
"iat": time.Unix(now-10, 0),
178-
"exp": time.Unix(now-1, 0),
220+
"iat": now - 1000,
221+
"exp": now - 100,
179222
})},
180223
{"EmptyToken", ""},
181224
{"BadFormatToken", "foobar"},
182225
}
183226

184-
clk = &mockClock{now: time.Unix(now, 0)}
185-
defer func() {
186-
clk = &systemClock{}
187-
}()
188227
for _, tc := range cases {
189228
if _, err := client.VerifyIDToken(tc.token); err == nil {
190-
t.Errorf("VerifyyIDToken(%q) = nil; want error", tc.name)
229+
t.Errorf("VerifyIDToken(%q) = nil; want error", tc.name)
191230
}
192231
}
193232
}

auth/crypto.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,10 @@ func (k *httpKeySource) refreshKeys() error {
148148

149149
func findMaxAge(resp *http.Response) (*time.Duration, error) {
150150
cc := resp.Header.Get("cache-control")
151-
for _, value := range strings.Split(cc, ", ") {
151+
for _, value := range strings.Split(cc, ",") {
152152
value = strings.TrimSpace(value)
153-
if strings.HasPrefix(value, "max-age") {
153+
if strings.HasPrefix(value, "max-age=") {
154154
sep := strings.Index(value, "=")
155-
if sep == -1 {
156-
return nil, errors.New("Malformed cache-control header")
157-
}
158155
seconds, err := strconv.ParseInt(value[sep+1:], 10, 64)
159156
if err != nil {
160157
return nil, err

auth/crypto_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package auth
1616

1717
import (
18+
"errors"
1819
"fmt"
1920
"io"
2021
"io/ioutil"
@@ -122,6 +123,122 @@ func TestHTTPKeySourceWithClient(t *testing.T) {
122123
}
123124
}
124125

126+
func TestHTTPKeySourceEmptyResponse(t *testing.T) {
127+
hc, _ := newHTTPClient([]byte(""))
128+
ks, err := newHTTPKeySource(context.Background(), "http://mock.url", option.WithHTTPClient(hc))
129+
if err != nil {
130+
t.Fatal(err)
131+
}
132+
133+
if keys, err := ks.Keys(); keys != nil || err == nil {
134+
t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err)
135+
}
136+
}
137+
138+
func TestHTTPKeySourceIncorrectResponse(t *testing.T) {
139+
hc, _ := newHTTPClient([]byte("{\"foo\": 1}"))
140+
ks, err := newHTTPKeySource(context.Background(), "http://mock.url", option.WithHTTPClient(hc))
141+
if err != nil {
142+
t.Fatal(err)
143+
}
144+
145+
if keys, err := ks.Keys(); keys != nil || err == nil {
146+
t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err)
147+
}
148+
}
149+
150+
func TestHTTPKeySourceTransportError(t *testing.T) {
151+
hc := &http.Client{
152+
Transport: &mockHTTPResponse{
153+
Err: errors.New("transport error"),
154+
},
155+
}
156+
ks, err := newHTTPKeySource(context.Background(), "http://mock.url", option.WithHTTPClient(hc))
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
161+
if keys, err := ks.Keys(); keys != nil || err == nil {
162+
t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err)
163+
}
164+
}
165+
166+
func TestFindMaxAge(t *testing.T) {
167+
cases := []struct {
168+
cc string
169+
want int64
170+
}{
171+
{"max-age=100", 100},
172+
{"public, max-age=100", 100},
173+
{"public,max-age=100", 100},
174+
}
175+
for _, tc := range cases {
176+
resp := &http.Response{
177+
Header: http.Header{"Cache-Control": {tc.cc}},
178+
}
179+
age, err := findMaxAge(resp)
180+
if err != nil {
181+
t.Errorf("findMaxAge(%q) = %v", tc.cc, err)
182+
} else if *age != (time.Duration(tc.want) * time.Second) {
183+
t.Errorf("findMaxAge(%q) = %v; want %v", tc.cc, *age, tc.want)
184+
}
185+
}
186+
}
187+
188+
func TestFindMaxAgeError(t *testing.T) {
189+
cases := []string{
190+
"",
191+
"max-age 100",
192+
"max-age: 100",
193+
"max-age2=100",
194+
"max-age=foo",
195+
}
196+
for _, tc := range cases {
197+
resp := &http.Response{
198+
Header: http.Header{"Cache-Control": []string{tc}},
199+
}
200+
if age, err := findMaxAge(resp); age != nil || err == nil {
201+
t.Errorf("findMaxAge(%q) = (%v, %v); want = (nil, err)", tc, age, err)
202+
}
203+
}
204+
}
205+
206+
func TestParsePublicKeys(t *testing.T) {
207+
b, err := ioutil.ReadFile("../testdata/public_certs.json")
208+
if err != nil {
209+
t.Fatal(err)
210+
}
211+
keys, err := parsePublicKeys(b)
212+
if err != nil {
213+
t.Fatal(err)
214+
}
215+
if len(keys) != 3 {
216+
t.Errorf("parsePublicKeys() = %d; want: %d", len(keys), 3)
217+
}
218+
}
219+
220+
func TestParsePublicKeysError(t *testing.T) {
221+
cases := []string{
222+
"",
223+
"not-json",
224+
}
225+
for _, tc := range cases {
226+
if keys, err := parsePublicKeys([]byte(tc)); keys != nil || err == nil {
227+
t.Errorf("parsePublicKeys(%q) = (%v, %v); want: (nil, err)", tc, keys, err)
228+
}
229+
}
230+
}
231+
232+
func TestDefaultServiceAcctSigner(t *testing.T) {
233+
signer := &serviceAcctSigner{}
234+
if email, err := signer.Email(); email != "" || err == nil {
235+
t.Errorf("Email() = (%v, %v); want = ('', error)", email, err)
236+
}
237+
if sig, err := signer.Sign([]byte("")); sig != nil || err == nil {
238+
t.Errorf("Sign() = (%v, %v); want = ('', error)", sig, err)
239+
}
240+
}
241+
125242
func verifyHTTPKeySource(ks *httpKeySource, rc *mockReadCloser) error {
126243
mc := &mockClock{now: time.Unix(0, 0)}
127244
ks.Clock = mc

auth/jwt.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,7 @@ func decode(s string, i interface{}) error {
8080
if err != nil {
8181
return err
8282
}
83-
if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(i); err != nil {
84-
return err
85-
}
86-
return nil
83+
return json.NewDecoder(bytes.NewBuffer(decoded)).Decode(i)
8784
}
8885

8986
func encodeToken(s signer, h jwtHeader, p jwtPayload) (string, error) {

auth/jwt_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package auth
2+
3+
import (
4+
"encoding/base64"
5+
"errors"
6+
"strings"
7+
"testing"
8+
)
9+
10+
func TestEncodeToken(t *testing.T) {
11+
h := defaultHeader()
12+
p := mockIDTokenPayload{"key": "value"}
13+
s, err := encodeToken(&mockSigner{}, h, p)
14+
if err != nil {
15+
t.Fatal(err)
16+
}
17+
parts := strings.Split(s, ".")
18+
if len(parts) != 3 {
19+
t.Errorf("encodeToken() = %d; want: %d", len(parts), 3)
20+
}
21+
22+
var header jwtHeader
23+
if err := decode(parts[0], &header); err != nil {
24+
t.Fatal(err)
25+
} else if h != header {
26+
t.Errorf("decode(header) = %v; want = %v", header, h)
27+
}
28+
29+
payload := make(mockIDTokenPayload)
30+
if err := decode(parts[1], &payload); err != nil {
31+
t.Fatal(err)
32+
} else if len(payload) != 1 || payload["key"] != "value" {
33+
t.Errorf("decode(payload) = %v; want = %v", payload, p)
34+
}
35+
36+
if sig, err := base64.RawURLEncoding.DecodeString(parts[2]); err != nil {
37+
t.Fatal(err)
38+
} else if string(sig) != "signature" {
39+
t.Errorf("decode(signature) = %q; want = %q", string(sig), "signature")
40+
}
41+
}
42+
43+
func TestEncodeSignError(t *testing.T) {
44+
h := defaultHeader()
45+
p := mockIDTokenPayload{"key": "value"}
46+
signer := &mockSigner{
47+
err: errors.New("sign error"),
48+
}
49+
if s, err := encodeToken(signer, h, p); s != "" || err == nil {
50+
t.Errorf("encodeToken() = (%v, %v); want = ('', error)", s, err)
51+
}
52+
}
53+
54+
func TestEncodeInvalidPayload(t *testing.T) {
55+
h := defaultHeader()
56+
p := mockIDTokenPayload{"key": func() {}}
57+
if s, err := encodeToken(&mockSigner{}, h, p); s != "" || err == nil {
58+
t.Errorf("encodeToken() = (%v, %v); want = ('', error)", s, err)
59+
}
60+
}
61+
62+
type mockSigner struct {
63+
err error
64+
}
65+
66+
func (s *mockSigner) Email() (string, error) {
67+
return "", nil
68+
}
69+
70+
func (s *mockSigner) Sign(b []byte) ([]byte, error) {
71+
if s.err != nil {
72+
return nil, s.err
73+
}
74+
return []byte("signature"), nil
75+
}

0 commit comments

Comments
 (0)