Skip to content

Commit 1c4a785

Browse files
authored
feat: add "is abroad" operator to IP rules (#92)
1 parent 532406a commit 1c4a785

File tree

6 files changed

+256
-7
lines changed

6 files changed

+256
-7
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
.idea/
1818
*.iml
1919

20+
# Built binaries
21+
caswaf
22+
2023
tmp/
2124
tmpFiles/
2225
*.tmp

ip/17monipdb.dat

1.48 MB
Binary file not shown.

ip/ip.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2024 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ip
16+
17+
import "fmt"
18+
19+
func InitIpDb() {
20+
err := Init("ip/17monipdb.dat")
21+
if err != nil {
22+
panic(err)
23+
}
24+
}
25+
26+
func IsAbroadIp(ip string) bool {
27+
info, err := Find(ip)
28+
if err != nil {
29+
fmt.Printf("error: ip = %s, error = %s\n", ip, err.Error())
30+
return false
31+
}
32+
33+
return info.Country != "中国"
34+
}

ip/ip17mon.go

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
// Copyright 2022 The casbin Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ip
16+
17+
import (
18+
"bytes"
19+
"encoding/binary"
20+
"errors"
21+
"io/ioutil"
22+
"net"
23+
)
24+
25+
const Null = "N/A"
26+
27+
var (
28+
ErrInvalidIp = errors.New("invalid ip format")
29+
std *Locator
30+
)
31+
32+
// Init default locator with dataFile
33+
func Init(dataFile string) (err error) {
34+
if std != nil {
35+
return
36+
}
37+
std, err = NewLocator(dataFile)
38+
return
39+
}
40+
41+
// Init default locator with data
42+
func InitWithData(data []byte) {
43+
if std != nil {
44+
return
45+
}
46+
std = NewLocatorWithData(data)
47+
return
48+
}
49+
50+
// Find locationInfo by ip string
51+
// It will return err when ipstr is not a valid format
52+
func Find(ipstr string) (*LocationInfo, error) {
53+
return std.Find(ipstr)
54+
}
55+
56+
// Find locationInfo by uint32
57+
func FindByUint(ip uint32) *LocationInfo {
58+
return std.FindByUint(ip)
59+
}
60+
61+
//-----------------------------------------------------------------------------
62+
63+
// New locator with dataFile
64+
func NewLocator(dataFile string) (loc *Locator, err error) {
65+
data, err := ioutil.ReadFile(dataFile)
66+
if err != nil {
67+
return
68+
}
69+
loc = NewLocatorWithData(data)
70+
return
71+
}
72+
73+
// New locator with data
74+
func NewLocatorWithData(data []byte) (loc *Locator) {
75+
loc = new(Locator)
76+
loc.init(data)
77+
return
78+
}
79+
80+
type Locator struct {
81+
textData []byte
82+
indexData1 []uint32
83+
indexData2 []int
84+
indexData3 []int
85+
index []int
86+
}
87+
88+
type LocationInfo struct {
89+
Country string
90+
Region string
91+
City string
92+
Isp string
93+
}
94+
95+
// Find locationInfo by ip string
96+
// It will return err when ipstr is not a valid format
97+
func (loc *Locator) Find(ipstr string) (info *LocationInfo, err error) {
98+
ip := net.ParseIP(ipstr).To4()
99+
if ip == nil || ip.To4() == nil {
100+
err = ErrInvalidIp
101+
return
102+
}
103+
info = loc.FindByUint(binary.BigEndian.Uint32([]byte(ip)))
104+
return
105+
}
106+
107+
// Find locationInfo by uint32
108+
func (loc *Locator) FindByUint(ip uint32) (info *LocationInfo) {
109+
end := len(loc.indexData1) - 1
110+
if ip>>24 != 0xff {
111+
end = loc.index[(ip>>24)+1]
112+
}
113+
idx := loc.findIndexOffset(ip, loc.index[ip>>24], end)
114+
off := loc.indexData2[idx]
115+
return newLocationInfo(loc.textData[off : off+loc.indexData3[idx]])
116+
}
117+
118+
// binary search
119+
func (loc *Locator) findIndexOffset(ip uint32, start, end int) int {
120+
for start < end {
121+
mid := (start + end) / 2
122+
if ip > loc.indexData1[mid] {
123+
start = mid + 1
124+
} else {
125+
end = mid
126+
}
127+
}
128+
129+
if loc.indexData1[end] >= ip {
130+
return end
131+
}
132+
133+
return start
134+
}
135+
136+
func (loc *Locator) init(data []byte) {
137+
textoff := int(binary.BigEndian.Uint32(data[:4]))
138+
139+
loc.textData = data[textoff-1024:]
140+
141+
loc.index = make([]int, 256)
142+
for i := 0; i < 256; i++ {
143+
off := 4 + i*4
144+
loc.index[i] = int(binary.LittleEndian.Uint32(data[off : off+4]))
145+
}
146+
147+
nidx := (textoff - 4 - 1024 - 1024) / 8
148+
149+
loc.indexData1 = make([]uint32, nidx)
150+
loc.indexData2 = make([]int, nidx)
151+
loc.indexData3 = make([]int, nidx)
152+
153+
for i := 0; i < nidx; i++ {
154+
off := 4 + 1024 + i*8
155+
loc.indexData1[i] = binary.BigEndian.Uint32(data[off : off+4])
156+
loc.indexData2[i] = int(uint32(data[off+4]) | uint32(data[off+5])<<8 | uint32(data[off+6])<<16)
157+
loc.indexData3[i] = int(data[off+7])
158+
}
159+
return
160+
}
161+
162+
func newLocationInfo(str []byte) *LocationInfo {
163+
164+
var info *LocationInfo
165+
166+
fields := bytes.Split(str, []byte("\t"))
167+
switch len(fields) {
168+
case 4:
169+
// free version
170+
info = &LocationInfo{
171+
Country: string(fields[0]),
172+
Region: string(fields[1]),
173+
City: string(fields[2]),
174+
}
175+
case 5:
176+
// pay version
177+
info = &LocationInfo{
178+
Country: string(fields[0]),
179+
Region: string(fields[1]),
180+
City: string(fields[2]),
181+
Isp: string(fields[4]),
182+
}
183+
default:
184+
panic("unexpected ip info:" + string(str))
185+
}
186+
187+
if len(info.Country) == 0 {
188+
info.Country = Null
189+
}
190+
if len(info.Region) == 0 {
191+
info.Region = Null
192+
}
193+
if len(info.City) == 0 {
194+
info.City = Null
195+
}
196+
if len(info.Isp) == 0 {
197+
info.Isp = Null
198+
}
199+
return info
200+
}

main.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/beego/beego/plugins/cors"
2020
_ "github.com/beego/beego/session/redis"
2121
"github.com/casbin/caswaf/casdoor"
22+
"github.com/casbin/caswaf/ip"
2223
"github.com/casbin/caswaf/object"
2324
"github.com/casbin/caswaf/proxy"
2425
"github.com/casbin/caswaf/routers"
@@ -32,6 +33,7 @@ func main() {
3233
object.CreateTables()
3334
casdoor.InitCasdoorConfig()
3435
proxy.InitHttpClient()
36+
ip.InitIpDb()
3537
object.InitSiteMap()
3638
object.InitRuleMap()
3739
run.InitAppMap()

rule/rule_ip.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"net/http"
2121
"strings"
2222

23+
"github.com/casbin/caswaf/ip"
2324
"github.com/casbin/caswaf/object"
2425
"github.com/casbin/caswaf/util"
2526
)
@@ -34,10 +35,19 @@ func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request)
3435
}
3536
for _, expression := range expressions {
3637
reason := fmt.Sprintf("expression matched: \"%s %s %s\"", clientIp, expression.Operator, expression.Value)
38+
39+
// Handle "is abroad" operator
40+
if expression.Operator == "is abroad" {
41+
if ip.IsAbroadIp(clientIp) {
42+
return &RuleResult{Reason: reason}, nil
43+
}
44+
continue
45+
}
46+
3747
ips := strings.Split(expression.Value, ",")
38-
for _, ip := range ips {
39-
if strings.Contains(ip, "/") {
40-
_, ipNet, err := net.ParseCIDR(ip)
48+
for _, ipStr := range ips {
49+
if strings.Contains(ipStr, "/") {
50+
_, ipNet, err := net.ParseCIDR(ipStr)
4151
if err != nil {
4252
return nil, err
4353
}
@@ -54,21 +64,21 @@ func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request)
5464
default:
5565
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
5666
}
57-
} else if strings.ContainsAny(ip, ".:") {
67+
} else if strings.ContainsAny(ipStr, ".:") {
5868
switch expression.Operator {
5969
case "is in":
60-
if ip == clientIp {
70+
if ipStr == clientIp {
6171
return &RuleResult{Reason: reason}, nil
6272
}
6373
case "is not in":
64-
if ip != clientIp {
74+
if ipStr != clientIp {
6575
return &RuleResult{Reason: reason}, nil
6676
}
6777
default:
6878
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
6979
}
7080
} else {
71-
return nil, fmt.Errorf("unknown IP or CIDR format: %s", ip)
81+
return nil, fmt.Errorf("unknown IP or CIDR format: %s", ipStr)
7282
}
7383
}
7484
}

0 commit comments

Comments
 (0)