Skip to content

Commit 12da462

Browse files
authored
Add unit tests for CLI and Auth packages (prequel-dev#55)
* add cli unit/integration tests Signed-off-by: amanycodes <amanycodes@gmail.com> * updated coverage Signed-off-by: amanycodes <amanycodes@gmail.com> * added auth unit/integration tests Signed-off-by: amanycodes <amanycodes@gmail.com> * update coverage Signed-off-by: amanycodes <amanycodes@gmail.com> --------- Signed-off-by: amanycodes <amanycodes@gmail.com>
1 parent 5d5f9e5 commit 12da462

File tree

4 files changed

+411
-3
lines changed

4 files changed

+411
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# preq
2-
![Coverage](https://img.shields.io/badge/Coverage-29.7%25-red)
2+
![Coverage](https://img.shields.io/badge/Coverage-39.8%25-red)
33
[![Unit Tests](https://github.com/prequel-dev/cre/actions/workflows/build.yml/badge.svg)](https://github.com/prequel-dev/cre/actions/workflows/build.yml)
44
[![Unit Tests](https://github.com/prequel-dev/preq/actions/workflows/build.yml/badge.svg)](https://github.com/prequel-dev/preq/actions/workflows/build.yml)
55
[![Unit Tests](https://github.com/prequel-dev/prequel-compiler/actions/workflows/build.yml/badge.svg)](https://github.com/prequel-dev/prequel-compiler/actions/workflows/build.yml)

internal/pkg/auth/auth_test.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"crypto/rsa"
7+
"crypto/x509"
8+
"encoding/json"
9+
"encoding/pem"
10+
"net/http"
11+
"net/http/httptest"
12+
"os"
13+
"path/filepath"
14+
"strings"
15+
"testing"
16+
"time"
17+
18+
"github.com/golang-jwt/jwt"
19+
)
20+
21+
var (
22+
testPrivateKey *rsa.PrivateKey
23+
testPublicKey *rsa.PublicKey
24+
testPublicKeyPEM []byte
25+
)
26+
27+
func TestMain(m *testing.M) {
28+
var err error
29+
testPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
30+
if err != nil {
31+
panic("failed to generate test key: " + err.Error())
32+
}
33+
testPublicKey = &testPrivateKey.PublicKey
34+
35+
pubKeyBytes, err := x509.MarshalPKIXPublicKey(testPublicKey)
36+
if err != nil {
37+
panic("failed to marshal test public key: " + err.Error())
38+
}
39+
testPublicKeyPEM = pem.EncodeToMemory(&pem.Block{
40+
Type: "PUBLIC KEY",
41+
Bytes: pubKeyBytes,
42+
})
43+
44+
os.Exit(m.Run())
45+
}
46+
47+
func generateTestToken(claims *UserClaims, t *testing.T) string {
48+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
49+
tokenString, err := token.SignedString(testPrivateKey)
50+
if err != nil {
51+
t.Fatalf("Failed to sign test token: %v", err)
52+
}
53+
return tokenString
54+
}
55+
56+
func TestEmailClaim(t *testing.T) {
57+
t.Run("valid token", func(t *testing.T) {
58+
payload := `{"email": "test@example.com"}`
59+
encodedPayload := jwt.EncodeSegment([]byte(payload))
60+
jwtString := "header." + encodedPayload + ".signature"
61+
62+
email, err := EmailClaim(jwtString)
63+
if err != nil {
64+
t.Fatalf("Expected no error, but got: %v", err)
65+
}
66+
if email != "test@example.com" {
67+
t.Errorf("Expected email 'test@example.com', but got '%s'", email)
68+
}
69+
})
70+
71+
t.Run("invalid JWT segments", func(t *testing.T) {
72+
_, err := EmailClaim("just.one.part")
73+
if err == nil {
74+
t.Fatal("Expected an error for an invalid number of segments, but got nil")
75+
}
76+
})
77+
78+
t.Run("malformed base64 payload", func(t *testing.T) {
79+
_, err := EmailClaim("header.%%%%.signature")
80+
if err == nil {
81+
t.Fatal("Expected an error for a malformed payload, but got nil")
82+
}
83+
})
84+
85+
t.Run("payload missing email claim", func(t *testing.T) {
86+
payload := `{"name": "test user"}`
87+
encodedPayload := jwt.EncodeSegment([]byte(payload))
88+
jwtString := "header." + encodedPayload + ".signature"
89+
90+
_, err := EmailClaim(jwtString)
91+
if err != ErrInvalidTokenClaims {
92+
t.Fatalf("Expected error '%v', but got '%v'", ErrInvalidTokenClaims, err)
93+
}
94+
})
95+
}
96+
97+
func TestCheckLocalToken(t *testing.T) {
98+
originalKey := publicJwtKeyPEM
99+
publicJwtKeyPEM = testPublicKeyPEM
100+
t.Cleanup(func() {
101+
publicJwtKeyPEM = originalKey
102+
})
103+
104+
tempDir := t.TempDir()
105+
tokenPath := filepath.Join(tempDir, "test.token")
106+
107+
t.Run("valid token file", func(t *testing.T) {
108+
claims := &UserClaims{StandardClaims: jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Hour).Unix()}}
109+
tokenString := generateTestToken(claims, t)
110+
os.WriteFile(tokenPath, []byte(tokenString), 0644)
111+
112+
readToken, err := checkLocalToken(tokenPath)
113+
if err != nil {
114+
t.Fatalf("Expected no error for a valid token, but got: %v", err)
115+
}
116+
if readToken != tokenString {
117+
t.Error("Returned token does not match original token")
118+
}
119+
})
120+
121+
t.Run("expired token", func(t *testing.T) {
122+
claims := &UserClaims{StandardClaims: jwt.StandardClaims{ExpiresAt: time.Now().Add(-time.Hour).Unix()}}
123+
tokenString := generateTestToken(claims, t)
124+
os.WriteFile(tokenPath, []byte(tokenString), 0644)
125+
126+
_, err := checkLocalToken(tokenPath)
127+
if err == nil || !strings.Contains(err.Error(), "token is expired") {
128+
t.Fatalf("Expected an expiry error, but got: %v", err)
129+
}
130+
})
131+
132+
t.Run("malformed token file", func(t *testing.T) {
133+
os.WriteFile(tokenPath, []byte("this is not a jwt"), 0644)
134+
_, err := checkLocalToken(tokenPath)
135+
if err == nil {
136+
t.Fatal("Expected an error for a malformed token, but got nil")
137+
}
138+
})
139+
140+
t.Run("token signed with wrong key", func(t *testing.T) {
141+
otherPrivateKey, _ := rsa.GenerateKey(rand.Reader, 2048)
142+
token := jwt.NewWithClaims(jwt.SigningMethodRS256, &UserClaims{})
143+
tokenString, _ := token.SignedString(otherPrivateKey)
144+
os.WriteFile(tokenPath, []byte(tokenString), 0644)
145+
146+
_, err := checkLocalToken(tokenPath)
147+
if err == nil || !strings.Contains(err.Error(), "crypto/rsa: verification error") {
148+
t.Fatalf("Expected a signature verification error, but got: %v", err)
149+
}
150+
})
151+
152+
t.Run("token file does not exist", func(t *testing.T) {
153+
_, err := checkLocalToken("/path/that/does/not/exist.token")
154+
if err == nil {
155+
t.Fatal("Expected an error for a missing file, but got nil")
156+
}
157+
})
158+
}
159+
160+
func TestLogin_LocalTokenExists(t *testing.T) {
161+
originalKey := publicJwtKeyPEM
162+
publicJwtKeyPEM = testPublicKeyPEM
163+
t.Cleanup(func() {
164+
publicJwtKeyPEM = originalKey
165+
})
166+
167+
tempDir := t.TempDir()
168+
tokenPath := filepath.Join(tempDir, "login.token")
169+
170+
expectedToken := generateTestToken(&UserClaims{
171+
StandardClaims: jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Hour).Unix()},
172+
}, t)
173+
os.WriteFile(tokenPath, []byte(expectedToken), 0644)
174+
175+
token, err := Login(context.Background(), "http://dummy-addr", tokenPath)
176+
177+
if err != nil {
178+
t.Fatalf("Login failed when a valid local token exists: %v", err)
179+
}
180+
if token != expectedToken {
181+
t.Errorf("Login returned an incorrect token. Got %s, want %s", token, expectedToken)
182+
}
183+
}
184+
185+
func TestAuthenticationFlow_EndToEnd(t *testing.T) {
186+
originalKey := publicJwtKeyPEM
187+
publicJwtKeyPEM = testPublicKeyPEM
188+
t.Cleanup(func() {
189+
publicJwtKeyPEM = originalKey
190+
})
191+
192+
expectedClaims := &UserClaims{
193+
StandardClaims: jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Hour).Unix()},
194+
Email: "final-user@example.com",
195+
}
196+
finalTokenString := generateTestToken(expectedClaims, t)
197+
198+
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
199+
switch r.URL.Path {
200+
case "/v1/auth/rules":
201+
w.Header().Set("Content-Type", "application/json")
202+
json.NewEncoder(w).Encode(DeviceAuth{
203+
DeviceCode: "test-device-code",
204+
ExpiresIn: 60,
205+
Interval: 0,
206+
})
207+
case "/v1/auth/token_poll_rules":
208+
w.Header().Set("Content-Type", "application/json")
209+
json.NewEncoder(w).Encode(TokenPollResponse{
210+
AccessToken: "dummy-access-token",
211+
IdToken: "dummy-id-token",
212+
OrgUuid: "dummy-org-uuid",
213+
})
214+
case "/v1/auth/exchange_rules":
215+
w.Header().Set("Content-Type", "application/json")
216+
json.NewEncoder(w).Encode(Token{
217+
Token: finalTokenString,
218+
Type: TokenTypePrequel,
219+
})
220+
default:
221+
http.NotFound(w, r)
222+
t.Errorf("Received unexpected request to path: %s", r.URL.Path)
223+
}
224+
}))
225+
t.Cleanup(mockServer.Close)
226+
227+
deviceAuth, err := startAuth(context.Background(), mockServer.URL+"/v1/auth/rules")
228+
if err != nil {
229+
t.Fatalf("startAuth failed: %v", err)
230+
}
231+
232+
tokenPollResponse, err := pollToken(context.Background(), mockServer.URL, deviceAuth)
233+
if err != nil {
234+
t.Fatalf("pollToken failed: %v", err)
235+
}
236+
237+
finalToken, err := exchangeRulesToken(context.Background(), mockServer.URL, tokenPollResponse)
238+
if err != nil {
239+
t.Fatalf("exchangeRulesToken failed: %v", err)
240+
}
241+
242+
if finalToken.Token != finalTokenString {
243+
t.Errorf("Final token does not match expected. Got %s, want %s", finalToken.Token, finalTokenString)
244+
}
245+
}

internal/pkg/cli/cli.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ var (
4141
ruleUpdateFile = filepath.Join(defaultConfigDir, ".ruleupdate")
4242
)
4343

44+
var (
45+
getRulesFunc = func(ctx context.Context, conf *config.Config, configDir, cmdLineRules, token, ruleUpdateFile, baseAddr string, tlsPort, udpPort int) ([]utils.RulePathT, error) {
46+
return rules.GetRules(ctx, conf, configDir, cmdLineRules, token, ruleUpdateFile, baseAddr, tlsPort, udpPort)
47+
}
48+
loginUserFunc = func(ctx context.Context, baseAddr, ruleToken string) (string, error) {
49+
return auth.Login(ctx, baseAddr, ruleToken)
50+
}
51+
)
52+
4453
const (
4554
tlsPort = 443
4655
udpPort = 8081
@@ -105,7 +114,7 @@ func InitAndExecute(ctx context.Context) error {
105114
}
106115

107116
// Log in for community rule updates
108-
if token, err = auth.Login(ctx, baseAddr, ruleToken); err != nil {
117+
if token, err = loginUserFunc(ctx, baseAddr, ruleToken); err != nil {
109118
log.Error().Err(err).Msg("Failed to login")
110119

111120
// A notice will be printed if the email is not verified
@@ -128,7 +137,7 @@ func InitAndExecute(ctx context.Context) error {
128137
c.Skip = timez.DefaultSkip
129138
}
130139

131-
rulesPaths, err = rules.GetRules(ctx, c, defaultConfigDir, Options.Rules, token, ruleUpdateFile, baseAddr, tlsPort, udpPort)
140+
rulesPaths, err = getRulesFunc(ctx, c, defaultConfigDir, Options.Rules, token, ruleUpdateFile, baseAddr, tlsPort, udpPort)
132141
if err != nil {
133142
log.Error().Err(err).Msg("Failed to get rules")
134143
ux.RulesError(err)

0 commit comments

Comments
 (0)