Skip to content

Commit 259d3bb

Browse files
committed
fix: fix Sql injection problem
1 parent 11546ce commit 259d3bb

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

example/base/select_test.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,20 @@ import (
2424
"gorm.io/gorm"
2525
"log"
2626
"testing"
27+
"time"
2728
)
2829

30+
type Test2 struct {
31+
TestId string `gorm:"primaryKey"`
32+
Code string
33+
Price string
34+
CreatedAt time.Time
35+
UpdatedAt time.Time
36+
DeletedAt time.Time
37+
}
38+
2939
func TestSelectById(t *testing.T) {
30-
user, resultDb := gplus.SelectById[User](1)
40+
user, resultDb := gplus.SelectById[User]("or 1=1")
3141
if resultDb.Error != nil {
3242
if errors.Is(resultDb.Error, gorm.ErrRecordNotFound) {
3343
log.Fatalln("SelectById Data not found:", resultDb.Error)
@@ -39,6 +49,19 @@ func TestSelectById(t *testing.T) {
3949
log.Println(string(marshal))
4050
}
4151

52+
func TestSelectByStrId(t *testing.T) {
53+
test, resultDb := gplus.SelectById[Test2]("a = 1 or 1=1")
54+
if resultDb.Error != nil {
55+
if errors.Is(resultDb.Error, gorm.ErrRecordNotFound) {
56+
log.Fatalln("SelectById Data not found:", resultDb.Error)
57+
}
58+
log.Fatalln("SelectById error:", resultDb.Error)
59+
}
60+
log.Println("RowsAffected:", resultDb.RowsAffected)
61+
marshal, _ := json.Marshal(test)
62+
log.Println(string(marshal))
63+
}
64+
4265
func TestSelectByIds(t *testing.T) {
4366
var ids []int
4467
ids = append(ids, 1)

gplus/base_dao.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,11 @@ func Update[T any](q *Query[T]) *gorm.DB {
9797
}
9898

9999
func SelectById[T any](id any) (*T, *gorm.DB) {
100+
q := NewQuery[T]()
101+
q.Eq(getPKColumn[T](), id)
100102
var entity T
101-
resultDb := gormDb.Take(&entity, id)
102-
if resultDb.RowsAffected == 0 {
103-
return nil, resultDb
104-
}
105-
return &entity, resultDb
103+
resultDb := buildCondition(q)
104+
return &entity, resultDb.Limit(1).Find(&entity)
106105
}
107106

108107
func SelectByIds[T any](ids any) ([]*T, *gorm.DB) {
@@ -114,11 +113,7 @@ func SelectByIds[T any](ids any) ([]*T, *gorm.DB) {
114113
func SelectOne[T any](q *Query[T]) (*T, *gorm.DB) {
115114
var entity T
116115
resultDb := buildCondition(q)
117-
resultDb.Take(&entity)
118-
if resultDb.RowsAffected == 0 {
119-
return nil, resultDb
120-
}
121-
return &entity, resultDb
116+
return &entity, resultDb.Limit(1).Find(&entity)
122117
}
123118

124119
func SelectList[T any](q *Query[T]) ([]*T, *gorm.DB) {

0 commit comments

Comments
 (0)