Skip to content

Commit 7751e83

Browse files
authored
feat: refactor ActionResult to RuleResult (#88)
1 parent dd654ff commit 7751e83

25 files changed

+235
-111
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ logs/
2424
lastupdate.tmp
2525
commentsRouter*.go
2626
acme_account.key
27+
caswaf

routers/base.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2023 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
package routers
216

317
import (

rule/rule.go

Lines changed: 45 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,19 @@ import (
2222
)
2323

2424
type Rule interface {
25-
checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error)
25+
checkRule(expressions []*object.Expression, req *http.Request) (*RuleResult, error)
2626
}
2727

28-
type ActionResult struct {
29-
Type string
28+
type RuleResult struct {
29+
Action string
3030
StatusCode int
31+
Reason string
3132
}
3233

33-
func CheckRules(ruleIds []string, r *http.Request) (*ActionResult, string, error) {
34+
func CheckRules(ruleIds []string, r *http.Request) (*RuleResult, error) {
3435
rules, err := object.GetRulesByRuleIds(ruleIds)
3536
if err != nil {
36-
return nil, "", err
37+
return nil, err
3738
}
3839
for i, rule := range rules {
3940
var ruleObj Rule
@@ -51,63 +52,57 @@ func CheckRules(ruleIds []string, r *http.Request) (*ActionResult, string, error
5152
case "Compound":
5253
ruleObj = &CompoundRule{}
5354
default:
54-
return nil, "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId())
55+
return nil, fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId())
5556
}
5657

57-
isHit, action, reason, err := ruleObj.checkRule(rule.Expressions, r)
58+
result, err := ruleObj.checkRule(rule.Expressions, r)
5859
if err != nil {
59-
return nil, "", err
60+
return nil, err
6061
}
6162

62-
// Use rule's action if no action specified by the rule check
63-
if action == "" {
64-
action = rule.Action
65-
}
66-
67-
// Determine status code
68-
statusCode := rule.StatusCode
69-
if statusCode == 0 {
70-
// Set default status codes if not specified
71-
switch action {
72-
case "Block":
73-
statusCode = 403
74-
case "Drop":
75-
statusCode = 400
76-
case "Allow":
77-
statusCode = 200
78-
case "CAPTCHA":
79-
statusCode = 302
80-
default:
81-
return nil, "", fmt.Errorf("unknown rule action: %s for rule: %s", action, rule.GetId())
63+
if result != nil {
64+
// Use rule's action if no action specified by the rule check
65+
if result.Action == "" {
66+
result.Action = rule.Action
8267
}
83-
}
84-
85-
actionResult := &ActionResult{
86-
Type: action,
87-
StatusCode: statusCode,
88-
}
89-
90-
if isHit {
91-
if action == "Block" || action == "Drop" {
92-
if rule.Reason != "" {
93-
reason = rule.Reason
68+
69+
// Determine status code
70+
if result.StatusCode == 0 {
71+
if rule.StatusCode != 0 {
72+
result.StatusCode = rule.StatusCode
9473
} else {
95-
reason = fmt.Sprintf("hit rule %s: %s", ruleIds[i], reason)
74+
// Set default status codes if not specified
75+
switch result.Action {
76+
case "Block":
77+
result.StatusCode = 403
78+
case "Drop":
79+
result.StatusCode = 400
80+
case "Allow":
81+
result.StatusCode = 200
82+
case "CAPTCHA":
83+
result.StatusCode = 302
84+
default:
85+
return nil, fmt.Errorf("unknown rule action: %s for rule: %s", result.Action, rule.GetId())
86+
}
87+
}
88+
}
89+
90+
// Update reason if rule has custom reason
91+
if result.Action == "Block" || result.Action == "Drop" {
92+
if rule.Reason != "" {
93+
result.Reason = rule.Reason
94+
} else if result.Reason != "" {
95+
result.Reason = fmt.Sprintf("hit rule %s: %s", ruleIds[i], result.Reason)
9696
}
97-
return actionResult, reason, nil
98-
} else if action == "Allow" {
99-
return actionResult, reason, nil
100-
} else if action == "CAPTCHA" {
101-
return actionResult, reason, nil
102-
} else {
103-
return nil, "", fmt.Errorf("unknown rule action: %s for rule: %s", action, rule.GetId())
10497
}
98+
99+
return result, nil
105100
}
106101
}
107102

108103
// Default action if no rule matched
109-
return &ActionResult{
110-
Type: "Allow",
104+
return &RuleResult{
105+
Action: "Allow",
111106
StatusCode: 200,
112-
}, "", nil
107+
}, nil
113108
}

rule/rule_compound.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ import (
2424

2525
type CompoundRule struct{}
2626

27-
func (r *CompoundRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
27+
func (r *CompoundRule) checkRule(expressions []*object.Expression, req *http.Request) (*RuleResult, error) {
2828
operators := util.NewStack()
2929
res := true
3030
for _, expression := range expressions {
3131
isHit := true
32-
action, _, err := CheckRules([]string{expression.Value}, req)
32+
result, err := CheckRules([]string{expression.Value}, req)
3333
if err != nil {
34-
return false, "", "", err
34+
return nil, err
3535
}
36-
if action.Type == "" {
36+
if result == nil || result.Action == "" {
3737
isHit = false
3838
}
3939
switch expression.Operator {
@@ -43,7 +43,7 @@ func (r *CompoundRule) checkRule(expressions []*object.Expression, req *http.Req
4343
operators.Push(res)
4444
res = isHit
4545
default:
46-
return false, "", "", fmt.Errorf("unknown operator: %s", expression.Operator)
46+
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
4747
}
4848
if operators.Size() > 0 {
4949
last, ok := operators.Pop()
@@ -53,5 +53,8 @@ func (r *CompoundRule) checkRule(expressions []*object.Expression, req *http.Req
5353
}
5454
}
5555
}
56-
return res, "", "", nil
56+
if res {
57+
return &RuleResult{}, nil
58+
}
59+
return nil, nil
5760
}

rule/rule_ip.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ import (
2626

2727
type IpRule struct{}
2828

29-
func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
29+
func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request) (*RuleResult, error) {
3030
clientIp := util.GetClientIp(req)
3131
netIp, err := parseIp(clientIp)
3232
if err != nil {
33-
return false, "", "", err
33+
return nil, err
3434
}
3535
for _, expression := range expressions {
3636
reason := fmt.Sprintf("expression matched: \"%s %s %s\"", clientIp, expression.Operator, expression.Value)
@@ -39,40 +39,40 @@ func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request)
3939
if strings.Contains(ip, "/") {
4040
_, ipNet, err := net.ParseCIDR(ip)
4141
if err != nil {
42-
return false, "", "", err
42+
return nil, err
4343
}
4444

4545
switch expression.Operator {
4646
case "is in":
4747
if ipNet.Contains(netIp) {
48-
return true, "", reason, nil
48+
return &RuleResult{Reason: reason}, nil
4949
}
5050
case "is not in":
5151
if !ipNet.Contains(netIp) {
52-
return true, "", reason, nil
52+
return &RuleResult{Reason: reason}, nil
5353
}
5454
default:
55-
return false, "", "", fmt.Errorf("unknown operator: %s", expression.Operator)
55+
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
5656
}
5757
} else if strings.ContainsAny(ip, ".:") {
5858
switch expression.Operator {
5959
case "is in":
6060
if ip == clientIp {
61-
return true, "", reason, nil
61+
return &RuleResult{Reason: reason}, nil
6262
}
6363
case "is not in":
6464
if ip != clientIp {
65-
return true, "", reason, nil
65+
return &RuleResult{Reason: reason}, nil
6666
}
6767
default:
68-
return false, "", "", fmt.Errorf("unknown operator: %s", expression.Operator)
68+
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
6969
}
7070
} else {
71-
return false, "", "", fmt.Errorf("unknown IP or CIDR format: %s", ip)
71+
return nil, fmt.Errorf("unknown IP or CIDR format: %s", ip)
7272
}
7373
}
7474
}
75-
return false, "", "", nil
75+
return nil, nil
7676
}
7777

7878
func parseIp(ipStr string) (net.IP, error) {

rule/rule_ip_rate.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func (i *IpRateLimiter) GetLimiter(ip string) *rate.Limiter {
8080
return limiter
8181
}
8282

83-
func (r *IpRateRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
83+
func (r *IpRateRule) checkRule(expressions []*object.Expression, req *http.Request) (*RuleResult, error) {
8484
expression := expressions[0] // IpRate rule should have only one expression
8585
clientIp := util.GetClientIp(req)
8686

@@ -89,7 +89,10 @@ func (r *IpRateRule) checkRule(expressions []*object.Expression, req *http.Reque
8989
if ok {
9090
blockTime := util.ParseInt(expression.Value)
9191
if time.Now().Sub(createAt) < time.Duration(blockTime)*time.Second {
92-
return true, "Block", "Rate limit exceeded", nil
92+
return &RuleResult{
93+
Action: "Block",
94+
Reason: "Rate limit exceeded",
95+
}, nil
9396
} else {
9497
delete(blackList, clientIp)
9598
}
@@ -112,17 +115,20 @@ func (r *IpRateRule) checkRule(expressions []*object.Expression, req *http.Reque
112115
limiter.SetBurst(ipRateLimiter.b)
113116
err := limiter.Wait(req.Context())
114117
if err != nil {
115-
return false, "", "", err
118+
return nil, err
116119
}
117120
} else {
118121
// If the rate limit is exceeded, add the client IP to the blacklist
119122
allow := limiter.Allow()
120123
if !allow {
121124
blackList[r.ruleName] = map[string]time.Time{}
122125
blackList[r.ruleName][clientIp] = time.Now()
123-
return true, "Block", "Rate limit exceeded", nil
126+
return &RuleResult{
127+
Action: "Block",
128+
Reason: "Rate limit exceeded",
129+
}, nil
124130
}
125131
}
126132

127-
return false, "", "", nil
133+
return nil, nil
128134
}

rule/rule_ip_rate_test.go

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2023 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
package rule
216

317
import (
@@ -113,19 +127,26 @@ func TestIpRateRule_checkRule(t *testing.T) {
113127
ruleName: tt.fields.ruleName,
114128
}
115129
for i, arg := range tt.args.args {
116-
got, got1, got2, err := r.checkRule(arg.expressions, arg.req)
130+
result, err := r.checkRule(arg.expressions, arg.req)
117131
if (err != nil) != tt.wantErr[i] {
118132
t.Errorf("checkRule() error = %v, wantErr %v", err, tt.wantErr)
119133
return
120134
}
135+
got := result != nil
136+
got1 := ""
137+
got2 := ""
138+
if result != nil {
139+
got1 = result.Action
140+
got2 = result.Reason
141+
}
121142
if got != tt.want[i] {
122-
t.Errorf("checkRule() got = %v, want %v", got, tt.want)
143+
t.Errorf("checkRule() got = %v, want %v", got, tt.want[i])
123144
}
124145
if got1 != tt.want1[i] {
125-
t.Errorf("checkRule() got1 = %v, want %v", got1, tt.want1)
146+
t.Errorf("checkRule() got1 = %v, want %v", got1, tt.want1[i])
126147
}
127148
if got2 != tt.want2[i] {
128-
t.Errorf("checkRule() got2 = %v, want %v", got2, tt.want2)
149+
t.Errorf("checkRule() got2 = %v, want %v", got2, tt.want2[i])
129150
}
130151
}
131152
})

rule/rule_ua.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,39 +25,39 @@ import (
2525

2626
type UaRule struct{}
2727

28-
func (r *UaRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
28+
func (r *UaRule) checkRule(expressions []*object.Expression, req *http.Request) (*RuleResult, error) {
2929
userAgent := req.UserAgent()
3030
for _, expression := range expressions {
3131
ua := expression.Value
3232
reason := fmt.Sprintf("expression matched: \"%s %s %s\"", userAgent, expression.Operator, expression.Value)
3333
switch expression.Operator {
3434
case "contains":
3535
if strings.Contains(userAgent, ua) {
36-
return true, "", reason, nil
36+
return &RuleResult{Reason: reason}, nil
3737
}
3838
case "does not contain":
3939
if !strings.Contains(userAgent, ua) {
40-
return true, "", reason, nil
40+
return &RuleResult{Reason: reason}, nil
4141
}
4242
case "equals":
4343
if userAgent == ua {
44-
return true, "", reason, nil
44+
return &RuleResult{Reason: reason}, nil
4545
}
4646
case "does not equal":
4747
if strings.Compare(userAgent, ua) != 0 {
48-
return true, "", reason, nil
48+
return &RuleResult{Reason: reason}, nil
4949
}
5050
case "match":
5151
// regex match
5252
isHit, err := regexp.MatchString(ua, userAgent)
5353
if err != nil {
54-
return false, "", "", err
54+
return nil, err
5555
}
5656
if isHit {
57-
return true, "", reason, nil
57+
return &RuleResult{Reason: reason}, nil
5858
}
5959
}
6060
}
6161

62-
return false, "", "", nil
62+
return nil, nil
6363
}

0 commit comments

Comments
 (0)