Skip to content

Commit 75fc090

Browse files
authored
Merge pull request #31 from coreruleset/feat/var-enums
feat: add enums for variables and collections
2 parents 7220a8b + 82d7698 commit 75fc090

File tree

6 files changed

+717
-153
lines changed

6 files changed

+717
-153
lines changed

types/collection_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package types
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"go.yaml.in/yaml/v4"
8+
)
9+
10+
var (
11+
colTests = []struct {
12+
collection CollectionName
13+
yamlStr string
14+
}{
15+
{ARGS, "ARGS"},
16+
{ARGS_GET, "ARGS_GET"},
17+
{ARGS_GET_NAMES, "ARGS_GET_NAMES"},
18+
{ARGS_NAMES, "ARGS_NAMES"},
19+
{ARGS_POST_NAMES, "ARGS_POST_NAMES"},
20+
{ARGS_POST, "ARGS_POST"},
21+
{ENV, "ENV"},
22+
{FILES, "FILES"},
23+
{GEO, "GEO"},
24+
{GLOBAL, "GLOBAL"},
25+
{IP, "IP"},
26+
{MATCHED_VARS_NAMES, "MATCHED_VARS_NAMES"},
27+
{MATCHED_VARS, "MATCHED_VARS"},
28+
{MULTIPART_PART_HEADERS, "MULTIPART_PART_HEADERS"},
29+
{PERF_RULES, "PERF_RULES"},
30+
{REQUEST_COOKIES_NAMES, "REQUEST_COOKIES_NAMES"},
31+
{REQUEST_COOKIES, "REQUEST_COOKIES"},
32+
{REQUEST_HEADERS_NAMES, "REQUEST_HEADERS_NAMES"},
33+
{REQUEST_HEADERS, "REQUEST_HEADERS"},
34+
{RESPONSE_HEADERS_NAMES, "RESPONSE_HEADERS_NAMES"},
35+
{RESPONSE_HEADERS, "RESPONSE_HEADERS"},
36+
{RULE, "RULE"},
37+
{SESSION, "SESSION"},
38+
{TX, "TX"},
39+
{XML, "XML"},
40+
}
41+
)
42+
43+
func TestCollectionNameToString(t *testing.T) {
44+
for _, tt := range colTests {
45+
t.Run(tt.yamlStr, func(t *testing.T) {
46+
if tt.collection.String() != tt.yamlStr {
47+
t.Errorf("Expected %q, got %q", tt.yamlStr, tt.collection.String())
48+
}
49+
})
50+
}
51+
}
52+
53+
func TestStringToCollectionName(t *testing.T) {
54+
for _, tt := range colTests {
55+
t.Run(tt.yamlStr, func(t *testing.T) {
56+
collection := stringToCollectionName(tt.yamlStr)
57+
if collection != tt.collection {
58+
t.Errorf("Expected %q, got %q", tt.collection, collection)
59+
}
60+
})
61+
}
62+
}
63+
64+
func TestMarshalCollectionName(t *testing.T) {
65+
for _, tt := range colTests {
66+
t.Run(tt.yamlStr, func(t *testing.T) {
67+
data, err := yaml.Marshal(tt.collection)
68+
if err != nil {
69+
t.Fatalf("Failed to marshal: %v", err)
70+
}
71+
if string(data) != tt.yamlStr+"\n" {
72+
t.Errorf("Expected %q, got %q", tt.yamlStr+"\n", data)
73+
}
74+
})
75+
}
76+
}
77+
78+
func TestUnknownCollectionName(t *testing.T) {
79+
t.Run("marshal unknown", func(t *testing.T) {
80+
unknown := UNKNOWN_COLLECTION
81+
_, err := unknown.MarshalYAML()
82+
assert.Error(t, err)
83+
assert.Equal(t, "Unknown collection name", err.Error())
84+
})
85+
}

types/collections.go

Lines changed: 109 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,36 @@ type Collection struct {
99
Count bool `yaml:"count,omitempty"`
1010
}
1111

12-
type CollectionName string
12+
type CollectionName int
1313

1414
const (
1515
// Collections
16-
ARGS CollectionName = "ARGS"
17-
ARGS_GET CollectionName = "ARGS_GET"
18-
ARGS_GET_NAMES CollectionName = "ARGS_GET_NAMES"
19-
ARGS_NAMES CollectionName = "ARGS_NAMES"
20-
ARGS_POST_NAMES CollectionName = "ARGS_POST_NAMES"
21-
ARGS_POST CollectionName = "ARGS_POST"
22-
ENV CollectionName = "ENV"
23-
FILES CollectionName = "FILES"
24-
GEO CollectionName = "GEO"
25-
GLOBAL CollectionName = "GLOBAL"
26-
IP CollectionName = "IP"
27-
MATCHED_VARS_NAMES CollectionName = "MATCHED_VARS_NAMES"
28-
MATCHED_VARS CollectionName = "MATCHED_VARS"
29-
MULTIPART_PART_HEADERS CollectionName = "MULTIPART_PART_HEADERS"
30-
PERF_RULES CollectionName = "PERF_RULES"
31-
REQUEST_COOKIES_NAMES CollectionName = "REQUEST_COOKIES_NAMES"
32-
REQUEST_COOKIES CollectionName = "REQUEST_COOKIES"
33-
REQUEST_HEADERS_NAMES CollectionName = "REQUEST_HEADERS_NAMES"
34-
REQUEST_HEADERS CollectionName = "REQUEST_HEADERS"
35-
RESPONSE_HEADERS_NAMES CollectionName = "RESPONSE_HEADERS_NAMES"
36-
RESPONSE_HEADERS CollectionName = "RESPONSE_HEADERS"
37-
RULE CollectionName = "RULE"
38-
SESSION CollectionName = "SESSION"
39-
TX CollectionName = "TX"
40-
XML CollectionName = "XML"
16+
UNKNOWN_COLLECTION CollectionName = iota
17+
ARGS
18+
ARGS_GET
19+
ARGS_GET_NAMES
20+
ARGS_NAMES
21+
ARGS_POST_NAMES
22+
ARGS_POST
23+
ENV
24+
FILES
25+
GEO
26+
GLOBAL
27+
IP
28+
MATCHED_VARS_NAMES
29+
MATCHED_VARS
30+
MULTIPART_PART_HEADERS
31+
PERF_RULES
32+
REQUEST_COOKIES_NAMES
33+
REQUEST_COOKIES
34+
REQUEST_HEADERS_NAMES
35+
REQUEST_HEADERS
36+
RESPONSE_HEADERS_NAMES
37+
RESPONSE_HEADERS
38+
RULE
39+
SESSION
40+
TX
41+
XML
4142
)
4243

4344
var (
@@ -70,26 +71,102 @@ var (
7071
}
7172
)
7273

74+
func (c CollectionName) String() string {
75+
switch c {
76+
case ARGS:
77+
return "ARGS"
78+
case ARGS_GET:
79+
return "ARGS_GET"
80+
case ARGS_GET_NAMES:
81+
return "ARGS_GET_NAMES"
82+
case ARGS_NAMES:
83+
return "ARGS_NAMES"
84+
case ARGS_POST_NAMES:
85+
return "ARGS_POST_NAMES"
86+
case ARGS_POST:
87+
return "ARGS_POST"
88+
case ENV:
89+
return "ENV"
90+
case FILES:
91+
return "FILES"
92+
case GEO:
93+
return "GEO"
94+
case GLOBAL:
95+
return "GLOBAL"
96+
case IP:
97+
return "IP"
98+
case MATCHED_VARS_NAMES:
99+
return "MATCHED_VARS_NAMES"
100+
case MATCHED_VARS:
101+
return "MATCHED_VARS"
102+
case MULTIPART_PART_HEADERS:
103+
return "MULTIPART_PART_HEADERS"
104+
case PERF_RULES:
105+
return "PERF_RULES"
106+
case REQUEST_COOKIES_NAMES:
107+
return "REQUEST_COOKIES_NAMES"
108+
case REQUEST_COOKIES:
109+
return "REQUEST_COOKIES"
110+
case REQUEST_HEADERS_NAMES:
111+
return "REQUEST_HEADERS_NAMES"
112+
case REQUEST_HEADERS:
113+
return "REQUEST_HEADERS"
114+
case RESPONSE_HEADERS_NAMES:
115+
return "RESPONSE_HEADERS_NAMES"
116+
case RESPONSE_HEADERS:
117+
return "RESPONSE_HEADERS"
118+
case RULE:
119+
return "RULE"
120+
case SESSION:
121+
return "SESSION"
122+
case TX:
123+
return "TX"
124+
case XML:
125+
return "XML"
126+
default:
127+
return "unknown"
128+
}
129+
}
130+
131+
func (c CollectionName) MarshalYAML() (interface{}, error) {
132+
if c == UNKNOWN_COLLECTION {
133+
return nil, fmt.Errorf("Unknown collection name")
134+
}
135+
return c.String(), nil
136+
}
137+
138+
func (c *CollectionName) UnmarshalYAML(unmarshal func(interface{}) error) error {
139+
var name string
140+
if err := unmarshal(&name); err != nil {
141+
return err
142+
}
143+
*c = stringToCollectionName(name)
144+
if *c == UNKNOWN_COLLECTION {
145+
return fmt.Errorf("Collection name %s is not valid", name)
146+
}
147+
return nil
148+
}
149+
73150
func CollectionsToString(collections []Collection, separator string) string {
74151
result := ""
75152
for i, collection := range collections {
76153
if len(collection.Arguments) == 0 && len(collection.Excluded) == 0 {
77154
if collection.Count {
78155
result += "&"
79156
}
80-
result += string(collection.Name)
157+
result += collection.Name.String()
81158
} else {
82159
for j, arg := range collection.Arguments {
83160
if collection.Count {
84161
result += "&"
85162
}
86-
result += string(collection.Name) + ":" + arg
163+
result += collection.Name.String() + ":" + arg
87164
if j != len(collection.Arguments)-1 || len(collection.Excluded) > 0 {
88165
result += separator
89166
}
90167
}
91168
for j, excluded := range collection.Excluded {
92-
result += "!" + string(collection.Name) + ":" + excluded
169+
result += "!" + collection.Name.String() + ":" + excluded
93170
if j != len(collection.Excluded)-1 {
94171
result += separator
95172
}
@@ -102,10 +179,10 @@ func CollectionsToString(collections []Collection, separator string) string {
102179
return result
103180
}
104181

105-
func GetCollection(name string) (CollectionName, error) {
182+
func stringToCollectionName(name string) CollectionName {
106183
col, exists := allCollections[name]
107184
if !exists {
108-
return "", fmt.Errorf("Invalid collection name: %s", name)
185+
return UNKNOWN_COLLECTION
109186
}
110-
return col, nil
187+
return col
111188
}

types/secrule.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package types
22

33
import (
4+
"fmt"
45
"slices"
56
)
67

@@ -34,9 +35,9 @@ func (d SecRule) GetTransformations() Transformations {
3435
}
3536

3637
func (s *SecRule) AddVariable(name string, excluded bool) error {
37-
variable, err := GetVariable(name)
38-
if err != nil {
39-
return err
38+
variable := stringToVariableName(name)
39+
if variable == UNKNOWN_VAR {
40+
return fmt.Errorf("Invalid variable name: %s", name)
4041
}
4142
if excluded {
4243
vars := []Variable{}
@@ -53,9 +54,9 @@ func (s *SecRule) AddVariable(name string, excluded bool) error {
5354
}
5455

5556
func (s *SecRule) AddCollection(name, value string, excluded, asCount bool) error {
56-
col, err := GetCollection(name)
57-
if err != nil {
58-
return err
57+
col := stringToCollectionName(name)
58+
if col == UNKNOWN_COLLECTION {
59+
return fmt.Errorf("Invalid collection name: %s", name)
5960
}
6061
if excluded && !asCount {
6162
results := []Collection{}

types/update_target.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package types
22

33
import (
4+
"fmt"
45
"strconv"
56
)
67

@@ -22,9 +23,9 @@ func NewUpdateTargetDirective() *UpdateTargetDirective {
2223
}
2324

2425
func (d *UpdateTargetDirective) AddVariable(name string, excluded bool) error {
25-
variable, err := GetVariable(name)
26-
if err != nil {
27-
return err
26+
variable := stringToVariableName(name)
27+
if variable == UNKNOWN_VAR {
28+
return fmt.Errorf("Invalid variable name: %s", name)
2829
}
2930
if excluded {
3031
vars := []Variable{}
@@ -41,9 +42,9 @@ func (d *UpdateTargetDirective) AddVariable(name string, excluded bool) error {
4142
}
4243

4344
func (d *UpdateTargetDirective) AddCollection(name, value string, excluded, asCount bool) error {
44-
col, err := GetCollection(name)
45-
if err != nil {
46-
return err
45+
col := stringToCollectionName(name)
46+
if col == UNKNOWN_COLLECTION {
47+
return fmt.Errorf("Invalid collection name: %s", name)
4748
}
4849
if excluded && !asCount {
4950
results := []Collection{}

0 commit comments

Comments
 (0)