Skip to content

Commit 0d7bbfc

Browse files
committed
Added unit tests for saml
1 parent 025d848 commit 0d7bbfc

File tree

1 file changed

+396
-0
lines changed

1 file changed

+396
-0
lines changed

internal/server/saml_test.go

Lines changed: 396 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,396 @@
1+
// Copyright (c) ClaceIO, LLC
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package server
5+
6+
import (
7+
"context"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
12+
"github.com/gorilla/sessions"
13+
"github.com/openrundev/openrun/internal/metadata"
14+
"github.com/openrundev/openrun/internal/testutil"
15+
"github.com/openrundev/openrun/internal/types"
16+
saml2 "github.com/russellhaering/gosaml2"
17+
)
18+
19+
func TestGenSAMLCookieName(t *testing.T) {
20+
t.Parallel()
21+
22+
tests := []struct {
23+
name string
24+
provider string
25+
want string
26+
}{
27+
{
28+
name: "simple provider name",
29+
provider: "okta",
30+
want: "okta_openrun_saml_session",
31+
},
32+
{
33+
name: "provider with prefix",
34+
provider: "saml_google",
35+
want: "saml_google_openrun_saml_session",
36+
},
37+
{
38+
name: "empty provider",
39+
provider: "",
40+
want: "_openrun_saml_session",
41+
},
42+
}
43+
44+
for _, tt := range tests {
45+
t.Run(tt.name, func(t *testing.T) {
46+
t.Parallel()
47+
got := genSAMLCookieName(tt.provider)
48+
testutil.AssertEqualsString(t, "cookie name", tt.want, got)
49+
})
50+
}
51+
}
52+
53+
func TestBuildSAMLUrl(t *testing.T) {
54+
t.Parallel()
55+
56+
tests := []struct {
57+
name string
58+
baseUrl string
59+
providerName string
60+
endpoint string
61+
want string
62+
}{
63+
{
64+
name: "basic url without trailing slash",
65+
baseUrl: "https://example.com",
66+
providerName: "okta",
67+
endpoint: "acs",
68+
want: "https://example.com/_openrun/sso/okta/acs",
69+
},
70+
{
71+
name: "basic url with trailing slash",
72+
baseUrl: "https://example.com/",
73+
providerName: "okta",
74+
endpoint: "acs",
75+
want: "https://example.com/_openrun/sso/okta/acs",
76+
},
77+
{
78+
name: "url with path",
79+
baseUrl: "https://example.com/app",
80+
providerName: "google",
81+
endpoint: "metadata",
82+
want: "https://example.com/app/_openrun/sso/google/metadata",
83+
},
84+
{
85+
name: "url with path and trailing slash",
86+
baseUrl: "https://example.com/app/",
87+
providerName: "azure",
88+
endpoint: "slo",
89+
want: "https://example.com/app/_openrun/sso/azure/slo",
90+
},
91+
{
92+
name: "localhost url",
93+
baseUrl: "http://localhost:8080",
94+
providerName: "test",
95+
endpoint: "redirect",
96+
want: "http://localhost:8080/_openrun/sso/test/redirect",
97+
},
98+
}
99+
100+
for _, tt := range tests {
101+
t.Run(tt.name, func(t *testing.T) {
102+
t.Parallel()
103+
got := buildSAMLUrl(tt.baseUrl, tt.providerName, tt.endpoint)
104+
testutil.AssertEqualsString(t, "saml url", tt.want, got)
105+
})
106+
}
107+
}
108+
109+
func TestFirstNonEmpty(t *testing.T) {
110+
t.Parallel()
111+
112+
tests := []struct {
113+
name string
114+
slices [][]string
115+
want []string
116+
}{
117+
{
118+
name: "first slice non-empty",
119+
slices: [][]string{{"a", "b"}, {"c", "d"}, {"e", "f"}},
120+
want: []string{"a", "b"},
121+
},
122+
{
123+
name: "first slice empty, second non-empty",
124+
slices: [][]string{{}, {"c", "d"}, {"e", "f"}},
125+
want: []string{"c", "d"},
126+
},
127+
{
128+
name: "all slices empty",
129+
slices: [][]string{{}, {}, {}},
130+
want: []string{},
131+
},
132+
{
133+
name: "no slices",
134+
slices: [][]string{},
135+
want: []string{},
136+
},
137+
{
138+
name: "single non-empty slice",
139+
slices: [][]string{{"only"}},
140+
want: []string{"only"},
141+
},
142+
{
143+
name: "single empty slice",
144+
slices: [][]string{{}},
145+
want: []string{},
146+
},
147+
{
148+
name: "nil first, non-empty second",
149+
slices: [][]string{nil, {"value"}},
150+
want: []string{"value"},
151+
},
152+
{
153+
name: "all nil slices",
154+
slices: [][]string{nil, nil, nil},
155+
want: []string{},
156+
},
157+
}
158+
159+
for _, tt := range tests {
160+
t.Run(tt.name, func(t *testing.T) {
161+
t.Parallel()
162+
got := firstNonEmpty(tt.slices...)
163+
164+
if len(got) != len(tt.want) {
165+
t.Errorf("length mismatch: want %d, got %d", len(tt.want), len(got))
166+
return
167+
}
168+
169+
for i := range got {
170+
if got[i] != tt.want[i] {
171+
t.Errorf("element %d: want %s, got %s", i, tt.want[i], got[i])
172+
}
173+
}
174+
})
175+
}
176+
}
177+
178+
func TestNewSAMLManager(t *testing.T) {
179+
t.Parallel()
180+
181+
logger := createTestLogger()
182+
config := &types.ServerConfig{
183+
GlobalConfig: types.GlobalConfig{
184+
AdminUser: "admin",
185+
},
186+
}
187+
cookieStore := sessions.NewCookieStore([]byte("test-key"))
188+
db := &metadata.Metadata{}
189+
190+
manager := NewSAMLManager(logger, config, cookieStore, db)
191+
192+
if manager == nil {
193+
t.Fatal("NewSAMLManager returned nil")
194+
}
195+
196+
if manager.Logger == nil {
197+
t.Error("Logger is nil")
198+
}
199+
200+
if manager.config != config {
201+
t.Error("config not set correctly")
202+
}
203+
204+
if manager.cookieStore != cookieStore {
205+
t.Error("cookieStore not set correctly")
206+
}
207+
208+
if manager.db != db {
209+
t.Error("db not set correctly")
210+
}
211+
}
212+
213+
func TestSAMLManager_ValidateSAMLProvider(t *testing.T) {
214+
t.Parallel()
215+
216+
tests := []struct {
217+
name string
218+
setupProviders map[string]bool
219+
authType string
220+
want bool
221+
}{
222+
{
223+
name: "valid provider with rbac prefix",
224+
setupProviders: map[string]bool{
225+
"saml_okta": true,
226+
},
227+
authType: "rbac:saml_okta",
228+
want: true,
229+
},
230+
{
231+
name: "valid provider without rbac prefix",
232+
setupProviders: map[string]bool{
233+
"saml_google": true,
234+
},
235+
authType: "saml_google",
236+
want: true,
237+
},
238+
{
239+
name: "non-existent provider",
240+
setupProviders: map[string]bool{
241+
"saml_okta": true,
242+
},
243+
authType: "rbac:saml_azure",
244+
want: false,
245+
},
246+
{
247+
name: "empty providers map",
248+
setupProviders: map[string]bool{},
249+
authType: "rbac:saml_okta",
250+
want: false,
251+
},
252+
{
253+
name: "provider without saml prefix",
254+
setupProviders: map[string]bool{
255+
"okta": true,
256+
},
257+
authType: "rbac:okta",
258+
want: true,
259+
},
260+
}
261+
262+
for _, tt := range tests {
263+
t.Run(tt.name, func(t *testing.T) {
264+
t.Parallel()
265+
266+
logger := createTestLogger()
267+
config := &types.ServerConfig{}
268+
cookieStore := sessions.NewCookieStore([]byte("test-key"))
269+
db := &metadata.Metadata{}
270+
271+
manager := NewSAMLManager(logger, config, cookieStore, db)
272+
manager.providers = make(map[string]*saml2.SAMLServiceProvider)
273+
274+
// Setup mock providers
275+
for name := range tt.setupProviders {
276+
manager.providers[name] = &saml2.SAMLServiceProvider{}
277+
}
278+
279+
got := manager.ValidateSAMLProvider(tt.authType)
280+
testutil.AssertEqualsBool(t, "validation result", tt.want, got)
281+
})
282+
}
283+
}
284+
285+
func TestSAMLManager_Metadata(t *testing.T) {
286+
t.Parallel()
287+
288+
t.Run("provider not found", func(t *testing.T) {
289+
t.Parallel()
290+
291+
logger := createTestLogger()
292+
config := &types.ServerConfig{}
293+
cookieStore := sessions.NewCookieStore([]byte("test-key"))
294+
db := &metadata.Metadata{}
295+
296+
manager := NewSAMLManager(logger, config, cookieStore, db)
297+
manager.providers = make(map[string]*saml2.SAMLServiceProvider)
298+
299+
w := httptest.NewRecorder()
300+
301+
// Call metadata logic directly
302+
sp := manager.providers["nonexistent"]
303+
if sp == nil {
304+
http.Error(w, "provider not found", http.StatusNotFound)
305+
}
306+
307+
resp := w.Result()
308+
defer resp.Body.Close() //nolint:errcheck
309+
310+
testutil.AssertEqualsInt(t, "status code", http.StatusNotFound, resp.StatusCode)
311+
})
312+
}
313+
314+
func TestSAMLManager_Setup(t *testing.T) {
315+
t.Parallel()
316+
317+
tests := []struct {
318+
name string
319+
samlConfigs map[string]types.SAMLConfig
320+
expectError bool
321+
}{
322+
{
323+
name: "empty config",
324+
samlConfigs: map[string]types.SAMLConfig{},
325+
expectError: false,
326+
},
327+
{
328+
name: "missing callback url",
329+
samlConfigs: map[string]types.SAMLConfig{
330+
"okta": {
331+
MetadataURL: "https://example.com/metadata",
332+
UsePost: false,
333+
},
334+
},
335+
expectError: true,
336+
},
337+
}
338+
339+
for _, tt := range tests {
340+
t.Run(tt.name, func(t *testing.T) {
341+
t.Parallel()
342+
343+
logger := createTestLogger()
344+
config := &types.ServerConfig{
345+
SAML: tt.samlConfigs,
346+
}
347+
cookieStore := sessions.NewCookieStore([]byte("test-key"))
348+
db := &metadata.Metadata{}
349+
350+
manager := NewSAMLManager(logger, config, cookieStore, db)
351+
err := manager.Setup(context.Background())
352+
353+
if tt.expectError {
354+
if err == nil {
355+
t.Error("expected error but got none")
356+
}
357+
} else {
358+
if err != nil {
359+
t.Errorf("unexpected error: %v", err)
360+
}
361+
}
362+
})
363+
}
364+
}
365+
366+
func TestSAMLManager_SetupInitializationState(t *testing.T) {
367+
t.Parallel()
368+
369+
logger := createTestLogger()
370+
config := &types.ServerConfig{
371+
SAML: map[string]types.SAMLConfig{},
372+
}
373+
cookieStore := sessions.NewCookieStore([]byte("test-key"))
374+
db := &metadata.Metadata{}
375+
376+
manager := NewSAMLManager(logger, config, cookieStore, db)
377+
378+
// Before setup
379+
if manager.providerConfigs != nil {
380+
t.Error("providerConfigs should be nil before setup")
381+
}
382+
if manager.providers != nil {
383+
t.Error("providers should be nil before setup")
384+
}
385+
386+
// After setup
387+
err := manager.Setup(context.Background())
388+
testutil.AssertNoError(t, err)
389+
390+
if manager.providerConfigs == nil {
391+
t.Error("providerConfigs should be initialized after setup")
392+
}
393+
if manager.providers == nil {
394+
t.Error("providers should be initialized after setup")
395+
}
396+
}

0 commit comments

Comments
 (0)