Skip to content

Commit a65ecfd

Browse files
owenrumneyCopilot
andauthored
test: add tests for areas that currently dont have coverage (#44)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 96819f8 commit a65ecfd

File tree

10 files changed

+679
-22
lines changed

10 files changed

+679
-22
lines changed

Makefile

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ test:
1414
go test -v ./... -ldflags "-X github.com/aquasecurity/trivy-mcp/pkg/version.TrivyVersion=$${trivy_version}" -coverprofile=coverage.out -covermode=atomic
1515
@echo "Tests completed."
1616

17+
.PHONY: coverage
18+
coverage:
19+
@echo "Generating coverage report..."
20+
@trivy_version=$$(cat go.mod | grep 'github.com/aquasecurity/trivy v' | awk '{ print $$2}') ;\
21+
echo Current trivy version: $$trivy_version ;\
22+
go test -v ./... -ldflags "-X github.com/aquasecurity/trivy-mcp/pkg/version.TrivyVersion=$${trivy_version}" -coverprofile=coverage.out -covermode=atomic
23+
@go tool cover -html=coverage.out -o coverage.html
24+
@echo "Coverage report generated: coverage.html"
25+
1726
.PHONY: build
1827
build: clean $(OUTPUTS)
1928
%/trivy-mcp:

internal/creds/aqua_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"github.com/stretchr/testify/require"
88
)
99

10+
// skipCI skips the test if running in a CI environment
11+
// This is useful to avoid running tests that require user interaction or keyring access etc
1012
func skipCI(t *testing.T) {
1113
if os.Getenv("GITHUB_ACTIONS") == "true" {
1214
t.Skip("Skipping test in CI environment")

internal/creds/verify_test.go

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
package creds
2+
3+
import (
4+
"io"
5+
"strings"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestDecodeJWT(t *testing.T) {
14+
tests := []struct {
15+
name string
16+
jwt string
17+
wantErr bool
18+
}{
19+
{
20+
name: "valid JWT",
21+
jwt: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE2MTYyMzkwMjJ9.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
22+
wantErr: false,
23+
},
24+
{
25+
name: "invalid JWT format",
26+
jwt: "invalid-jwt",
27+
wantErr: true,
28+
},
29+
{
30+
name: "invalid JWT payload",
31+
jwt: "header.invalid-payload.signature",
32+
wantErr: true,
33+
},
34+
}
35+
36+
for _, tt := range tests {
37+
t.Run(tt.name, func(t *testing.T) {
38+
claims, err := decodeJWT(tt.jwt)
39+
if tt.wantErr {
40+
require.Error(t, err)
41+
return
42+
}
43+
require.NoError(t, err)
44+
assert.NotNil(t, claims)
45+
})
46+
}
47+
}
48+
49+
func TestComputeHmac256(t *testing.T) {
50+
tests := []struct {
51+
name string
52+
message string
53+
secret string
54+
want string
55+
wantErr bool
56+
}{
57+
{
58+
name: "valid computation",
59+
message: "test message",
60+
secret: "test secret",
61+
want: "b5664a92da7fef821fa7ff75c00f711ba615dcb610de82edc440bc1337e251ef",
62+
wantErr: false,
63+
},
64+
{
65+
name: "empty message",
66+
message: "",
67+
secret: "test secret",
68+
want: "18914c0590232ac230ffa391cacdf29978282fd411ba0173587c59e607cb4af7",
69+
wantErr: false,
70+
},
71+
}
72+
73+
for _, tt := range tests {
74+
t.Run(tt.name, func(t *testing.T) {
75+
result, err := computeHmac256(tt.message, tt.secret)
76+
if tt.wantErr {
77+
require.Error(t, err)
78+
return
79+
}
80+
require.NoError(t, err)
81+
assert.Equal(t, tt.want, result)
82+
})
83+
}
84+
}
85+
86+
func TestGenerateToken(t *testing.T) {
87+
skipCI(t)
88+
89+
tests := []struct {
90+
name string
91+
creds *AquaCreds
92+
wantToken bool
93+
wantErr bool
94+
}{
95+
{
96+
name: "valid with existing valid token",
97+
creds: &AquaCreds{
98+
AquaKey: "test-key",
99+
AquaSecret: "test-secret",
100+
Region: "test-region",
101+
Token: "test-token",
102+
ExpiresAt: time.Now().Add(time.Hour).Unix(), // future expiry
103+
},
104+
wantToken: true,
105+
wantErr: false,
106+
},
107+
{
108+
name: "valid with expired token",
109+
creds: &AquaCreds{
110+
AquaKey: "test-key",
111+
AquaSecret: "test-secret",
112+
Region: "test-region",
113+
Token: "test-token",
114+
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // past expiry
115+
},
116+
wantToken: false,
117+
wantErr: true, // Will fail on Verify() since test-key and test-secret aren't valid
118+
},
119+
{
120+
name: "missing credentials",
121+
creds: &AquaCreds{
122+
Region: "test-region",
123+
},
124+
wantToken: false,
125+
wantErr: true,
126+
},
127+
{
128+
name: "missing region",
129+
creds: &AquaCreds{
130+
AquaKey: "test-key",
131+
AquaSecret: "test-secret",
132+
},
133+
wantToken: false,
134+
wantErr: true,
135+
},
136+
}
137+
138+
for _, tt := range tests {
139+
t.Run(tt.name, func(t *testing.T) {
140+
token, err := tt.creds.GenerateToken()
141+
if tt.wantErr {
142+
require.Error(t, err)
143+
return
144+
}
145+
require.NoError(t, err)
146+
if tt.wantToken {
147+
assert.NotEmpty(t, token)
148+
}
149+
})
150+
}
151+
}
152+
153+
func TestVerify(t *testing.T) {
154+
skipCI(t)
155+
156+
tests := []struct {
157+
name string
158+
creds *AquaCreds
159+
wantErr bool
160+
}{
161+
{
162+
name: "missing credentials",
163+
creds: &AquaCreds{
164+
Region: "test-region",
165+
},
166+
wantErr: true,
167+
},
168+
{
169+
name: "invalid credentials",
170+
creds: &AquaCreds{
171+
AquaKey: "invalid-key",
172+
AquaSecret: "invalid-secret",
173+
},
174+
wantErr: true,
175+
},
176+
}
177+
178+
for _, tt := range tests {
179+
t.Run(tt.name, func(t *testing.T) {
180+
err := tt.creds.Verify()
181+
if tt.wantErr {
182+
require.Error(t, err)
183+
return
184+
}
185+
require.NoError(t, err)
186+
// Valid credentials would be tested here, but we don't have real credentials for testing
187+
})
188+
}
189+
}
190+
191+
func TestGetUrls(t *testing.T) {
192+
tests := []struct {
193+
name string
194+
region string
195+
expectedSCSURL string
196+
expectedCSPMURL string
197+
}{
198+
{
199+
name: "default region",
200+
region: "",
201+
expectedSCSURL: "https://api.supply-chain.cloud.aquasec.com",
202+
expectedCSPMURL: "https://api.cloudsploit.com",
203+
},
204+
{
205+
name: "dev region",
206+
region: "dev",
207+
expectedSCSURL: "https://api.dev.supply-chain.cloud.aquasec.com",
208+
expectedCSPMURL: "https://stage.api.cloudsploit.com",
209+
},
210+
{
211+
name: "eu region",
212+
region: "eu",
213+
expectedSCSURL: "https://api.eu-1.supply-chain.cloud.aquasec.com",
214+
expectedCSPMURL: "https://eu-1.api.cloudsploit.com",
215+
},
216+
{
217+
name: "singapore region",
218+
region: "singapore",
219+
expectedSCSURL: "https://api.ap-1.supply-chain.cloud.aquasec.com",
220+
expectedCSPMURL: "https://ap-1.api.cloudsploit.com",
221+
},
222+
{
223+
name: "sydney region",
224+
region: "sydney",
225+
expectedSCSURL: "https://api.ap-2.supply-chain.cloud.aquasec.com",
226+
expectedCSPMURL: "https://ap-2.api.cloudsploit.com",
227+
},
228+
}
229+
230+
for _, tt := range tests {
231+
t.Run(tt.name, func(t *testing.T) {
232+
creds := &AquaCreds{Region: tt.region}
233+
scsURL, cspmURL := creds.GetUrls()
234+
assert.Equal(t, tt.expectedSCSURL, scsURL)
235+
assert.Equal(t, tt.expectedCSPMURL, cspmURL)
236+
})
237+
}
238+
}
239+
240+
func TestGetRawMessageData(t *testing.T) {
241+
tests := []struct {
242+
name string
243+
input string
244+
expected string
245+
}{
246+
{
247+
name: "simple string",
248+
input: "test data",
249+
expected: "test data",
250+
},
251+
{
252+
name: "empty string",
253+
input: "",
254+
expected: "",
255+
},
256+
{
257+
name: "json string",
258+
input: `{"status":200,"message":"success","data":"token123"}`,
259+
expected: `{"status":200,"message":"success","data":"token123"}`,
260+
},
261+
}
262+
263+
for _, tt := range tests {
264+
t.Run(tt.name, func(t *testing.T) {
265+
result := getRawMessageData(io.NopCloser(strings.NewReader(tt.input)))
266+
assert.Equal(t, tt.expected, result)
267+
})
268+
}
269+
}

pkg/commands/auth_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package commands
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestNewAuthCommand(t *testing.T) {
12+
cmd := NewAuthCommand()
13+
assert.Equal(t, "auth", cmd.Use)
14+
assert.Equal(t, "Auth tools for the Aqua Platform", cmd.Short)
15+
16+
subCmds := cmd.Commands()
17+
require.Len(t, subCmds, 4)
18+
19+
var foundLogin, foundLogout, foundStatus, foundToken bool
20+
for _, subCmd := range subCmds {
21+
switch subCmd.Use {
22+
case "login":
23+
foundLogin = true
24+
case "logout":
25+
foundLogout = true
26+
case "status":
27+
foundStatus = true
28+
case "token":
29+
foundToken = true
30+
}
31+
}
32+
assert.True(t, foundLogin)
33+
assert.True(t, foundLogout)
34+
assert.True(t, foundStatus)
35+
assert.True(t, foundToken)
36+
37+
for _, subCmd := range subCmds {
38+
if subCmd.Use == "token" {
39+
assert.True(t, subCmd.Hidden)
40+
}
41+
}
42+
}
43+
44+
func TestGetInput_existingValue(t *testing.T) {
45+
val, err := getInput("foo", "prompt: ", false)
46+
assert.NoError(t, err)
47+
assert.Equal(t, "foo", val)
48+
}
49+
50+
func TestGetInput_emptyInput(t *testing.T) {
51+
oldStdin := os.Stdin
52+
defer func() { os.Stdin = oldStdin }()
53+
r, w, _ := os.Pipe()
54+
os.Stdin = r
55+
go func() {
56+
_, err := w.Write([]byte("\n"))
57+
assert.NoError(t, err)
58+
require.NoError(t, w.Close())
59+
}()
60+
_, err := getInput("", "prompt: ", false)
61+
assert.Error(t, err)
62+
}
63+
64+
func TestGetRegionFromList_existing(t *testing.T) {
65+
val, err := getRegionFromList("US")
66+
assert.NoError(t, err)
67+
assert.Equal(t, "US", val)
68+
}
69+
70+
func TestGetRegionFromList_prompt(t *testing.T) {
71+
oldStdin := os.Stdin
72+
defer func() { os.Stdin = oldStdin }()
73+
r, w, _ := os.Pipe()
74+
os.Stdin = r
75+
go func() {
76+
_, err := w.Write([]byte("1\n"))
77+
assert.NoError(t, err)
78+
require.NoError(t, w.Close())
79+
}()
80+
val, err := getRegionFromList("")
81+
assert.NoError(t, err)
82+
assert.Equal(t, "US", val)
83+
}

0 commit comments

Comments
 (0)