Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
.idea/
*.iml

# Built binaries
caswaf

tmp/
tmpFiles/
*.tmp
Expand Down
Binary file added ip/17monipdb.dat
Binary file not shown.
34 changes: 34 additions & 0 deletions ip/ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ip

import "fmt"

func InitIpDb() {
err := Init("ip/17monipdb.dat")
if err != nil {
panic(err)
}
}

func IsAbroadIp(ip string) bool {
info, err := Find(ip)
if err != nil {
fmt.Printf("error: ip = %s, error = %s\n", ip, err.Error())
return false
}

return info.Country != "中国"
}
200 changes: 200 additions & 0 deletions ip/ip17mon.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Copyright 2022 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ip

import (
"bytes"
"encoding/binary"
"errors"
"io/ioutil"
"net"
)

const Null = "N/A"

var (
ErrInvalidIp = errors.New("invalid ip format")
std *Locator
)

// Init default locator with dataFile
func Init(dataFile string) (err error) {
if std != nil {
return
}
std, err = NewLocator(dataFile)
return
}

// Init default locator with data
func InitWithData(data []byte) {
if std != nil {
return
}
std = NewLocatorWithData(data)
return
}

// Find locationInfo by ip string
// It will return err when ipstr is not a valid format
func Find(ipstr string) (*LocationInfo, error) {
return std.Find(ipstr)
}

// Find locationInfo by uint32
func FindByUint(ip uint32) *LocationInfo {
return std.FindByUint(ip)
}

//-----------------------------------------------------------------------------

// New locator with dataFile
func NewLocator(dataFile string) (loc *Locator, err error) {
data, err := ioutil.ReadFile(dataFile)
if err != nil {
return
}
loc = NewLocatorWithData(data)
return
}

// New locator with data
func NewLocatorWithData(data []byte) (loc *Locator) {
loc = new(Locator)
loc.init(data)
return
}

type Locator struct {
textData []byte
indexData1 []uint32
indexData2 []int
indexData3 []int
index []int
}

type LocationInfo struct {
Country string
Region string
City string
Isp string
}

// Find locationInfo by ip string
// It will return err when ipstr is not a valid format
func (loc *Locator) Find(ipstr string) (info *LocationInfo, err error) {
ip := net.ParseIP(ipstr).To4()
if ip == nil || ip.To4() == nil {
err = ErrInvalidIp
return
}
info = loc.FindByUint(binary.BigEndian.Uint32([]byte(ip)))
return
}

// Find locationInfo by uint32
func (loc *Locator) FindByUint(ip uint32) (info *LocationInfo) {
end := len(loc.indexData1) - 1
if ip>>24 != 0xff {
end = loc.index[(ip>>24)+1]
}
idx := loc.findIndexOffset(ip, loc.index[ip>>24], end)
off := loc.indexData2[idx]
return newLocationInfo(loc.textData[off : off+loc.indexData3[idx]])
}

// binary search
func (loc *Locator) findIndexOffset(ip uint32, start, end int) int {
for start < end {
mid := (start + end) / 2
if ip > loc.indexData1[mid] {
start = mid + 1
} else {
end = mid
}
}

if loc.indexData1[end] >= ip {
return end
}

return start
}

func (loc *Locator) init(data []byte) {
textoff := int(binary.BigEndian.Uint32(data[:4]))

loc.textData = data[textoff-1024:]

loc.index = make([]int, 256)
for i := 0; i < 256; i++ {
off := 4 + i*4
loc.index[i] = int(binary.LittleEndian.Uint32(data[off : off+4]))
}

nidx := (textoff - 4 - 1024 - 1024) / 8

loc.indexData1 = make([]uint32, nidx)
loc.indexData2 = make([]int, nidx)
loc.indexData3 = make([]int, nidx)

for i := 0; i < nidx; i++ {
off := 4 + 1024 + i*8
loc.indexData1[i] = binary.BigEndian.Uint32(data[off : off+4])
loc.indexData2[i] = int(uint32(data[off+4]) | uint32(data[off+5])<<8 | uint32(data[off+6])<<16)
loc.indexData3[i] = int(data[off+7])
}
return
}

func newLocationInfo(str []byte) *LocationInfo {

var info *LocationInfo

fields := bytes.Split(str, []byte("\t"))
switch len(fields) {
case 4:
// free version
info = &LocationInfo{
Country: string(fields[0]),
Region: string(fields[1]),
City: string(fields[2]),
}
case 5:
// pay version
info = &LocationInfo{
Country: string(fields[0]),
Region: string(fields[1]),
City: string(fields[2]),
Isp: string(fields[4]),
}
default:
panic("unexpected ip info:" + string(str))
}

if len(info.Country) == 0 {
info.Country = Null
}
if len(info.Region) == 0 {
info.Region = Null
}
if len(info.City) == 0 {
info.City = Null
}
if len(info.Isp) == 0 {
info.Isp = Null
}
return info
}
2 changes: 2 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/beego/beego/plugins/cors"
_ "github.com/beego/beego/session/redis"
"github.com/casbin/caswaf/casdoor"
"github.com/casbin/caswaf/ip"
"github.com/casbin/caswaf/object"
"github.com/casbin/caswaf/proxy"
"github.com/casbin/caswaf/routers"
Expand All @@ -32,6 +33,7 @@ func main() {
object.CreateTables()
casdoor.InitCasdoorConfig()
proxy.InitHttpClient()
ip.InitIpDb()
object.InitSiteMap()
object.InitRuleMap()
run.InitAppMap()
Expand Down
24 changes: 17 additions & 7 deletions rule/rule_ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"net/http"
"strings"

"github.com/casbin/caswaf/ip"
"github.com/casbin/caswaf/object"
"github.com/casbin/caswaf/util"
)
Expand All @@ -34,10 +35,19 @@ func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request)
}
for _, expression := range expressions {
reason := fmt.Sprintf("expression matched: \"%s %s %s\"", clientIp, expression.Operator, expression.Value)

// Handle "is abroad" operator
if expression.Operator == "is abroad" {
if ip.IsAbroadIp(clientIp) {
return &RuleResult{Reason: reason}, nil
}
continue
}

ips := strings.Split(expression.Value, ",")
for _, ip := range ips {
if strings.Contains(ip, "/") {
_, ipNet, err := net.ParseCIDR(ip)
for _, ipStr := range ips {
if strings.Contains(ipStr, "/") {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return nil, err
}
Expand All @@ -54,21 +64,21 @@ func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request)
default:
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
}
} else if strings.ContainsAny(ip, ".:") {
} else if strings.ContainsAny(ipStr, ".:") {
switch expression.Operator {
case "is in":
if ip == clientIp {
if ipStr == clientIp {
return &RuleResult{Reason: reason}, nil
}
case "is not in":
if ip != clientIp {
if ipStr != clientIp {
return &RuleResult{Reason: reason}, nil
}
default:
return nil, fmt.Errorf("unknown operator: %s", expression.Operator)
}
} else {
return nil, fmt.Errorf("unknown IP or CIDR format: %s", ip)
return nil, fmt.Errorf("unknown IP or CIDR format: %s", ipStr)
}
}
}
Expand Down
Loading