Skip to content

Commit 2dd8295

Browse files
authored
Merge pull request #32 from coreruleset/feat/multivalue-actions
feat: add setvar action
2 parents f7152ab + 508f7bd commit 2dd8295

File tree

11 files changed

+383
-78
lines changed

11 files changed

+383
-78
lines changed

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
22
github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw=
3-
github.com/coreruleset/seclang_parser v0.1.1 h1:u0tUmM4cWxkMouDzELIDwBVw0mcBCjzo89DRmPiE+8U=
4-
github.com/coreruleset/seclang_parser v0.1.1/go.mod h1:joNgWwutayILH6THEiq2ypgYxmu816pdvVctt0rLRv4=
53
github.com/coreruleset/seclang_parser v0.2.0 h1:Rj1ZpLZxF2owZg8zgoCx59UUeHBlIxBh85A1WEMvGQU=
64
github.com/coreruleset/seclang_parser v0.2.0/go.mod h1:joNgWwutayILH6THEiq2ypgYxmu816pdvVctt0rLRv4=
75
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

listener/actions.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ func (l *ExtendedSeclangParserListener) EnterDisruptive_action_with_params(ctx *
4444
}
4545

4646
func (l *ExtendedSeclangParserListener) EnterNon_disruptive_action_with_params(ctx *parser.Non_disruptive_action_with_paramsContext) {
47-
l.setParam = func(value string) {
48-
action := types.StringToNonDisruptiveAction(ctx.GetText())
49-
err := l.currentDirective.GetActions().AddNonDisruptiveActionWithParam(action, value)
50-
if err != nil {
51-
panic(fmt.Sprintf("failed to add non-disruptive action with param: %v", err))
47+
if ctx.GetText() != "setvar" {
48+
l.setParam = func(value string) {
49+
action := types.StringToNonDisruptiveAction(ctx.GetText())
50+
err := l.currentDirective.GetActions().AddNonDisruptiveActionWithParam(action, value)
51+
if err != nil {
52+
panic(fmt.Sprintf("failed to add non-disruptive action with param: %v", err))
53+
}
5254
}
5355
}
5456
}
@@ -79,3 +81,23 @@ func (l *ExtendedSeclangParserListener) EnterAction_value_types(ctx *parser.Acti
7981
l.setParam = doNothingFuncString
8082
}
8183
}
84+
85+
func (l *ExtendedSeclangParserListener) EnterCol_name(ctx *parser.Col_nameContext) {
86+
l.varName = ctx.GetText()
87+
}
88+
89+
func (l *ExtendedSeclangParserListener) EnterSetvar_stmt(ctx *parser.Setvar_stmtContext) {
90+
l.varValue = ctx.GetText()
91+
}
92+
93+
func (l *ExtendedSeclangParserListener) EnterAssignment(ctx *parser.AssignmentContext) {
94+
l.parameter = ctx.GetText()
95+
}
96+
97+
func (l *ExtendedSeclangParserListener) EnterVar_assignment(ctx *parser.Var_assignmentContext) {
98+
l.currentDirective.GetActions().AddSetvarAction(l.varName, l.varValue, l.parameter, ctx.GetText())
99+
l.varName = ""
100+
l.varValue = ""
101+
l.parameter = ""
102+
l.setParam = doNothingFuncString
103+
}

listener_test.go

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ func mustNewActionWithParam[T types.ActionType](action T, param string) types.Ac
2727
return newAction
2828
}
2929

30+
func mustNewSetvarAction(collection types.CollectionName, operation string, vars []types.VarAssignment) types.Action {
31+
newAction, err := types.NewSetvarAction(collection, operation, vars)
32+
if err != nil {
33+
panic(err)
34+
}
35+
return newAction
36+
}
37+
3038
type testCase struct {
3139
name string
3240
payload string
@@ -161,27 +169,29 @@ SecAction \
161169
DisruptiveAction: mustNewActionOnly(types.Pass),
162170
NonDisruptiveActions: []types.Action{
163171
mustNewActionOnly(types.NoLog),
164-
mustNewActionWithParam(types.SetVar, "tx.blocking_inbound_anomaly_score=0"),
165-
mustNewActionWithParam(types.SetVar, "tx.detection_inbound_anomaly_score=0"),
166-
mustNewActionWithParam(types.SetVar, "tx.inbound_anomaly_score_pl1=0"),
167-
mustNewActionWithParam(types.SetVar, "tx.inbound_anomaly_score_pl2=0"),
168-
mustNewActionWithParam(types.SetVar, "tx.inbound_anomaly_score_pl3=0"),
169-
mustNewActionWithParam(types.SetVar, "tx.inbound_anomaly_score_pl4=0"),
170-
mustNewActionWithParam(types.SetVar, "tx.sql_injection_score=0"),
171-
mustNewActionWithParam(types.SetVar, "tx.xss_score=0"),
172-
mustNewActionWithParam(types.SetVar, "tx.rfi_score=0"),
173-
mustNewActionWithParam(types.SetVar, "tx.lfi_score=0"),
174-
mustNewActionWithParam(types.SetVar, "tx.rce_score=0"),
175-
mustNewActionWithParam(types.SetVar, "tx.php_injection_score=0"),
176-
mustNewActionWithParam(types.SetVar, "tx.http_violation_score=0"),
177-
mustNewActionWithParam(types.SetVar, "tx.session_fixation_score=0"),
178-
mustNewActionWithParam(types.SetVar, "tx.blocking_outbound_anomaly_score=0"),
179-
mustNewActionWithParam(types.SetVar, "tx.detection_outbound_anomaly_score=0"),
180-
mustNewActionWithParam(types.SetVar, "tx.outbound_anomaly_score_pl1=0"),
181-
mustNewActionWithParam(types.SetVar, "tx.outbound_anomaly_score_pl2=0"),
182-
mustNewActionWithParam(types.SetVar, "tx.outbound_anomaly_score_pl3=0"),
183-
mustNewActionWithParam(types.SetVar, "tx.outbound_anomaly_score_pl4=0"),
184-
mustNewActionWithParam(types.SetVar, "tx.anomaly_score=0"),
172+
mustNewSetvarAction(types.TX, "=", []types.VarAssignment{
173+
{Variable: "blocking_inbound_anomaly_score", Value: "0"},
174+
{Variable: "detection_inbound_anomaly_score", Value: "0"},
175+
{Variable: "inbound_anomaly_score_pl1", Value: "0"},
176+
{Variable: "inbound_anomaly_score_pl2", Value: "0"},
177+
{Variable: "inbound_anomaly_score_pl3", Value: "0"},
178+
{Variable: "inbound_anomaly_score_pl4", Value: "0"},
179+
{Variable: "sql_injection_score", Value: "0"},
180+
{Variable: "xss_score", Value: "0"},
181+
{Variable: "rfi_score", Value: "0"},
182+
{Variable: "lfi_score", Value: "0"},
183+
{Variable: "rce_score", Value: "0"},
184+
{Variable: "php_injection_score", Value: "0"},
185+
{Variable: "http_violation_score", Value: "0"},
186+
{Variable: "session_fixation_score", Value: "0"},
187+
{Variable: "blocking_outbound_anomaly_score", Value: "0"},
188+
{Variable: "detection_outbound_anomaly_score", Value: "0"},
189+
{Variable: "outbound_anomaly_score_pl1", Value: "0"},
190+
{Variable: "outbound_anomaly_score_pl2", Value: "0"},
191+
{Variable: "outbound_anomaly_score_pl3", Value: "0"},
192+
{Variable: "outbound_anomaly_score_pl4", Value: "0"},
193+
{Variable: "anomaly_score", Value: "0"},
194+
}),
185195
},
186196
},
187197
},
@@ -286,7 +296,7 @@ SecRule REQUEST_LINE "@rx (?i)^(?:get /[^#\?]*(?:\?[^\s\v#]*)?(?:#[^\s\v]*)?|(?:
286296
DisruptiveAction: mustNewActionOnly(types.Block),
287297
NonDisruptiveActions: []types.Action{
288298
mustNewActionWithParam(types.LogData, "%{request_line}"),
289-
mustNewActionWithParam(types.SetVar, "tx.inbound_anomaly_score_pl1=+%{tx.warning_anomaly_score}"),
299+
mustNewSetvarAction(types.TX, "=+", []types.VarAssignment{{Variable: "inbound_anomaly_score_pl1", Value: "%{tx.warning_anomaly_score}"}}),
290300
},
291301
},
292302
},
@@ -667,16 +677,18 @@ SecComponentSignature "OWASP_CRS/4.0.1-dev"`,
667677

668678
func TestLoadSecLang(t *testing.T) {
669679
for _, test := range listenerTestCases {
670-
got := types.ConfigurationList{}
671-
input := antlr.NewInputStream(test.payload)
672-
lexer := parser.NewSecLangLexer(input)
673-
stream := antlr.NewCommonTokenStream(lexer, 0)
674-
p := parser.NewSecLangParser(stream)
675-
start := p.Configuration()
676-
var seclangListener listener.ExtendedSeclangParserListener
677-
antlr.ParseTreeWalkerDefault.Walk(&seclangListener, start)
678-
got = seclangListener.ConfigurationList
680+
t.Run(test.name, func(t *testing.T) {
681+
got := types.ConfigurationList{}
682+
input := antlr.NewInputStream(test.payload)
683+
lexer := parser.NewSecLangLexer(input)
684+
stream := antlr.NewCommonTokenStream(lexer, 0)
685+
p := parser.NewSecLangParser(stream)
686+
start := p.Configuration()
687+
var seclangListener listener.ExtendedSeclangParserListener
688+
antlr.ParseTreeWalkerDefault.Walk(&seclangListener, start)
689+
got = seclangListener.ConfigurationList
679690

680-
require.Equalf(t, got, test.expected, test.name)
691+
require.Equalf(t, test.expected, got, test.name)
692+
})
681693
}
682694
}

types/actions.go

Lines changed: 139 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package types
22

33
import (
44
"fmt"
5+
"strings"
56
)
67

78
type SeclangActions struct {
@@ -110,6 +111,92 @@ func (a ActionWithParam) GetParam() string {
110111
return ""
111112
}
112113

114+
type VarAssignment struct {
115+
Variable string `yaml:"variable"`
116+
Value string `yaml:"value"`
117+
}
118+
119+
type SetvarAction struct {
120+
Collection CollectionName `yaml:"collection,omitempty"`
121+
Operation string `yaml:"operation,omitempty"`
122+
Assignments []VarAssignment `yaml:"assignments,omitempty"`
123+
}
124+
125+
// GetKey returns the action name (it is always "setvar")
126+
func (a SetvarAction) GetKey() string {
127+
return SetVar.String()
128+
}
129+
130+
// ToString allows to implement the Action interface
131+
func (a SetvarAction) ToString() string {
132+
if len(a.Assignments) == 0 {
133+
return ""
134+
}
135+
136+
var result []string
137+
// Reconstruct the setvar actions
138+
for _, asg := range a.Assignments {
139+
result = append(result, SetVar.String()+":"+a.Collection.String()+"."+asg.Variable+a.Operation+asg.Value)
140+
}
141+
return strings.Join(result, ", ")
142+
}
143+
144+
func (a *SetvarAction) AppendAssignment(variable, value string) error {
145+
a.Assignments = append(a.Assignments, VarAssignment{Variable: variable, Value: value})
146+
return nil
147+
}
148+
149+
func (a SetvarAction) GetAllParams() []string {
150+
if len(a.Assignments) == 0 {
151+
return []string{}
152+
}
153+
154+
var result []string
155+
// Get all the variables
156+
for _, asg := range a.Assignments {
157+
res := SetVar.String() + ":" + a.Collection.String() + "." + asg.Variable + a.Operation + asg.Value
158+
result = append(result, res)
159+
}
160+
return result
161+
}
162+
163+
func (s VarAssignment) MarshalYAML() (interface{}, error) {
164+
if s.Variable == "" || s.Value == "" {
165+
return nil, fmt.Errorf("invalid variable assignment: missing variable name or value")
166+
}
167+
return map[string]string{s.Variable: s.Value}, nil
168+
}
169+
170+
func (s SetvarAction) MarshalYAML() (interface{}, error) {
171+
if s.Collection == UNKNOWN_COLLECTION || s.Operation == "" || len(s.Assignments) == 0 {
172+
return nil, fmt.Errorf("invalid setvar action: missing collection name, operation, or assignments")
173+
}
174+
if s.Collection == TX && s.Operation == "=" {
175+
// Default case
176+
res := map[string][]VarAssignment{}
177+
res["setvar"] = s.Assignments
178+
return res, nil
179+
} else {
180+
// Non-default case, collection is different to `TX` or operation is different to `=`.
181+
// Fields are re-mapped to a mirrored struct in order to preserve the order in the YAML
182+
res := map[string]struct {
183+
Collection CollectionName
184+
Operation string
185+
Assignments []VarAssignment
186+
}{}
187+
res["setvar"] = struct {
188+
Collection CollectionName
189+
Operation string
190+
Assignments []VarAssignment
191+
}{
192+
Collection: s.Collection,
193+
Operation: s.Operation,
194+
Assignments: s.Assignments,
195+
}
196+
return res, nil
197+
}
198+
}
199+
113200
// ActionType is a constraint for all action types
114201
type ActionType interface {
115202
DisruptiveAction | FlowAction | DataAction | NonDisruptiveAction
@@ -130,6 +217,14 @@ func NewActionWithParam[T ActionType](action T, param string) (ActionWithParam,
130217
return ActionWithParam{actionStr: param}, nil
131218
}
132219

220+
// NewSetvarAction creates a new SetvarAction with the given collection name, operation, and variable assignments
221+
func NewSetvarAction(collection CollectionName, operation string, vars []VarAssignment) (SetvarAction, error) {
222+
if collection != GLOBAL && collection != IP && collection != RESOURCE && collection != SESSION && collection != TX && collection != USER {
223+
return SetvarAction{}, fmt.Errorf("invalid setvar action: invalid collection name '%s'", collection)
224+
}
225+
return SetvarAction{Collection: collection, Operation: operation, Assignments: vars}, nil
226+
}
227+
133228
type DisruptiveAction int
134229

135230
const (
@@ -431,6 +526,42 @@ func (s *SeclangActions) AddNonDisruptiveActionWithParam(action NonDisruptiveAct
431526
return nil
432527
}
433528

529+
// AddSetvarAction adds a setvar action to the NonDisruptiveActions list
530+
func (s *SeclangActions) AddSetvarAction(collection, variable, operation, value string) error {
531+
colName := stringToCollectionName(strings.ToUpper(collection))
532+
// Check if there is already a setvar action in the last position
533+
if len(s.NonDisruptiveActions) > 0 {
534+
lastAction := s.NonDisruptiveActions[len(s.NonDisruptiveActions)-1]
535+
if lastAction.GetKey() != SetVar.String() || lastAction.(SetvarAction).Collection != colName || lastAction.(SetvarAction).Operation != operation {
536+
// If the last action is not setvar, we need to create a new one
537+
newAction, err := NewSetvarAction(colName, operation, []VarAssignment{{Variable: variable, Value: value}})
538+
if err != nil {
539+
return err
540+
}
541+
s.NonDisruptiveActions = append(s.NonDisruptiveActions, newAction)
542+
} else {
543+
// If the last action is setvar, we need to append the param to it
544+
sv, ok := lastAction.(SetvarAction)
545+
if !ok {
546+
return fmt.Errorf("invalid action type: expected SetvarAction, got %T", lastAction)
547+
}
548+
err := sv.AppendAssignment(variable, value)
549+
if err != nil {
550+
return err
551+
}
552+
s.NonDisruptiveActions[len(s.NonDisruptiveActions)-1] = sv
553+
}
554+
} else {
555+
// If there are no actions yet, we need to create a new setvar action
556+
newAction, err := NewSetvarAction(colName, operation, []VarAssignment{{Variable: variable, Value: value}})
557+
if err != nil {
558+
return err
559+
}
560+
s.NonDisruptiveActions = append(s.NonDisruptiveActions, newAction)
561+
}
562+
return nil
563+
}
564+
434565
func (s *SeclangActions) AddNonDisruptiveActionOnly(action NonDisruptiveAction) error {
435566
newAction, err := NewActionOnly(action)
436567
if err != nil {
@@ -518,35 +649,24 @@ func (s *SeclangActions) GetActionByKey(key string) Action {
518649
return ActionWithParam{}
519650
}
520651

521-
func (s *SeclangActions) GetActionsByKey(key string) []ActionWithParam {
522-
actions := []ActionWithParam{}
523-
// if s.DisruptiveAction != nil {
524-
// if s.DisruptiveAction.ToString() == key {
525-
// actions = append(actions, s.DisruptiveAction)
526-
// }
527-
// }
652+
func (s *SeclangActions) GetActionsByKey(key string) []Action {
653+
actions := []Action{}
654+
if s.DisruptiveAction != nil && s.DisruptiveAction.GetKey() == key {
655+
actions = append(actions, s.DisruptiveAction)
656+
}
528657
for _, action := range s.NonDisruptiveActions {
529658
if action.GetKey() == key {
530-
aP, ok := action.(ActionWithParam)
531-
if ok {
532-
actions = append(actions, aP)
533-
}
659+
actions = append(actions, action)
534660
}
535661
}
536662
for _, action := range s.FlowActions {
537663
if action.GetKey() == key {
538-
aP, ok := action.(ActionWithParam)
539-
if ok {
540-
actions = append(actions, aP)
541-
}
664+
actions = append(actions, action)
542665
}
543666
}
544667
for _, action := range s.DataActions {
545668
if action.GetKey() == key {
546-
aP, ok := action.(ActionWithParam)
547-
if ok {
548-
actions = append(actions, aP)
549-
}
669+
actions = append(actions, action)
550670
}
551671
}
552672
return actions

0 commit comments

Comments
 (0)