@@ -20,6 +20,9 @@ package gplus
20
20
import (
21
21
"github.com/acmestack/gorm-plus/constants"
22
22
"gorm.io/gorm"
23
+ "gorm.io/gorm/schema"
24
+ "gorm.io/gorm/utils"
25
+ "reflect"
23
26
)
24
27
25
28
var gormDb * gorm.DB
@@ -64,15 +67,15 @@ func InsertBatchSize[T any](entities []*T, batchSize int) *gorm.DB {
64
67
return resultDb
65
68
}
66
69
67
- func DeleteById [T any ](id any , primaryKeyColumn ... string ) * gorm.DB {
70
+ func DeleteById [T any ](id any ) * gorm.DB {
68
71
var entity T
69
- resultDb := gormDb .Where (getPKColumn ( primaryKeyColumn ), id ).Delete (& entity )
72
+ resultDb := gormDb .Where (getPKColumn [ T ]( ), id ).Delete (& entity )
70
73
return resultDb
71
74
}
72
75
73
- func DeleteByIds [T any ](ids any , primaryKeyColumn ... string ) * gorm.DB {
76
+ func DeleteByIds [T any ](ids any ) * gorm.DB {
74
77
q := NewQuery [T ]()
75
- q .In (getPKColumn ( primaryKeyColumn ), ids )
78
+ q .In (getPKColumn [ T ]( ), ids )
76
79
resultDb := Delete [T ](q )
77
80
return resultDb
78
81
}
@@ -83,8 +86,8 @@ func Delete[T any](q *Query[T]) *gorm.DB {
83
86
return resultDb
84
87
}
85
88
86
- func UpdateById [T any ](entity * T , id any , primaryKeyColumn ... string ) * gorm.DB {
87
- resultDb := gormDb .Model (& entity ).Where (getPKColumn ( primaryKeyColumn ), id ).Updates (entity )
89
+ func UpdateById [T any ](entity * T , id any ) * gorm.DB {
90
+ resultDb := gormDb .Model (& entity ).Where (getPKColumn [ T ]( ), id ).Updates (entity )
88
91
return resultDb
89
92
}
90
93
@@ -102,9 +105,9 @@ func SelectById[T any](id any) (*T, *gorm.DB) {
102
105
return & entity , resultDb
103
106
}
104
107
105
- func SelectByIds [T any ](ids any , primaryKeyColumn ... string ) ([]* T , * gorm.DB ) {
108
+ func SelectByIds [T any ](ids any ) ([]* T , * gorm.DB ) {
106
109
q := NewQuery [T ]()
107
- q .In (getPKColumn ( primaryKeyColumn ), ids )
110
+ q .In (getPKColumn [ T ]( ), ids )
108
111
return SelectList [T ](q )
109
112
}
110
113
@@ -222,9 +225,22 @@ func buildCondition[T any](q *Query[T]) *gorm.DB {
222
225
}
223
226
224
227
// getPKColumn 获取主键key
225
- func getPKColumn (primaryKeyColumn []string ) string {
226
- if len (primaryKeyColumn ) > 0 {
227
- return primaryKeyColumn [0 ]
228
+ func getPKColumn [T any ]() string {
229
+ var entity T
230
+ entityType := reflect .TypeOf (entity )
231
+ numField := entityType .NumField ()
232
+ var columnName string
233
+ for i := 0 ; i < numField ; i ++ {
234
+ field := entityType .Field (i )
235
+ tagSetting := schema .ParseTagSetting (field .Tag .Get ("gorm" ), ";" )
236
+ isPrimaryKey := utils .CheckTruth (tagSetting ["PRIMARYKEY" ], tagSetting ["PRIMARY_KEY" ])
237
+ if isPrimaryKey {
238
+ columnName = tagSetting ["COLUMN" ]
239
+ break
240
+ }
241
+ }
242
+ if columnName == "" {
243
+ return constants .DefaultPrimaryName
228
244
}
229
- return constants . PK
245
+ return columnName
230
246
}
0 commit comments