Skip to content

Commit f5a2138

Browse files
authored
Parse teams with saml configs containing access group objects as role values (#264)
1 parent fd43378 commit f5a2138

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

client/team.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67

78
"github.com/hashicorp/terraform-plugin-log/tflog"
@@ -14,9 +15,51 @@ type TeamCreateRequest struct {
1415
Plan string `json:"plan"`
1516
}
1617

18+
type SamlRoleAccessGroupID struct {
19+
AccessGroupID string `json:"accessGroupId"`
20+
}
21+
22+
type SamlRole struct {
23+
Role *string
24+
AccessGroupID *SamlRoleAccessGroupID
25+
}
26+
27+
func (f *SamlRole) UnmarshalJSON(data []byte) error {
28+
var role string
29+
if err := json.Unmarshal(data, &role); err == nil {
30+
f.Role = &role
31+
return nil
32+
}
33+
var ag SamlRoleAccessGroupID
34+
if err := json.Unmarshal(data, &ag); err == nil {
35+
f.AccessGroupID = &ag
36+
return nil
37+
}
38+
return fmt.Errorf("received json is neither Role string nor AccessGroupID map")
39+
}
40+
41+
type SamlRoles map[string]string
42+
43+
func (f *SamlRoles) UnmarshalJSON(data []byte) error {
44+
var result map[string]SamlRole
45+
if err := json.Unmarshal(data, &result); err != nil {
46+
return err
47+
}
48+
tmp := make(SamlRoles)
49+
for k, v := range result {
50+
k := k
51+
v := v
52+
if v.Role != nil {
53+
tmp[k] = *(v.Role)
54+
}
55+
}
56+
*f = tmp
57+
return nil
58+
}
59+
1760
type SamlConfig struct {
18-
Enforced bool `json:"enforced,omitempty"`
19-
Roles map[string]string `json:"roles,omitempty"`
61+
Enforced bool `json:"enforced,omitempty"`
62+
Roles SamlRoles `json:"roles,omitempty"`
2063
}
2164

2265
type TaxID struct {

client/team_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
)
10+
11+
func TestGetTeam(t *testing.T) {
12+
type TestCase struct {
13+
Name string
14+
ResponseJSON string
15+
}
16+
17+
for _, tc := range []TestCase{
18+
{
19+
Name: "SAML",
20+
ResponseJSON: `{ "saml": { "roles": { "A": "OWNER", "B": { "accessGroupId": "foo" } } } }`,
21+
},
22+
} {
23+
t.Run(tc.Name, func(t *testing.T) {
24+
h := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
25+
fmt.Fprintln(w, tc.ResponseJSON)
26+
}))
27+
cl := New("INVALID")
28+
cl.baseURL = fmt.Sprintf("http://%s", h.Listener.Addr().String())
29+
_, err := cl.GetTeam(context.Background(), "INVALID")
30+
if err != nil {
31+
t.Error(err)
32+
}
33+
})
34+
}
35+
}

0 commit comments

Comments
 (0)