Skip to content

Commit 3649cbd

Browse files
authored
refactor: modify updateById to solve zero-value update problem (#51)
1 parent 0220223 commit 3649cbd

File tree

4 files changed

+90
-96
lines changed

4 files changed

+90
-96
lines changed

gplus/cache.go

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,32 +33,39 @@ var modelInstanceCache sync.Map
3333
// Cache 缓存实体对象所有的字段名
3434
func Cache(models ...any) {
3535
for _, model := range models {
36-
valueOf := reflect.ValueOf(model).Elem()
37-
typeOf := reflect.TypeOf(model).Elem()
38-
39-
for i := 0; i < valueOf.NumField(); i++ {
40-
field := typeOf.Field(i)
41-
// 如果当前实体嵌入了其他实体,同样需要缓存它的字段名
42-
if field.Anonymous {
43-
// 如果存在多重嵌套,通过递归方式获取他们的字段名
44-
subFieldMap := getSubFieldColumnNameMap(valueOf, field)
45-
for key, value := range subFieldMap {
46-
columnNameCache.Store(key, value)
47-
}
48-
} else {
49-
// 获取对象字段指针值
50-
pointer := valueOf.Field(i).Addr().Pointer()
51-
name := parseColumnName(field)
52-
columnNameCache.Store(pointer, name)
53-
}
36+
columnNameMap := getColumnNameMap(model)
37+
for pointer, columnName := range columnNameMap {
38+
columnNameCache.Store(pointer, columnName)
5439
}
55-
5640
// 缓存对象
5741
modelTypeStr := reflect.TypeOf(model).Elem().String()
5842
modelInstanceCache.Store(modelTypeStr, model)
5943
}
6044
}
6145

46+
func getColumnNameMap(model any) map[uintptr]string {
47+
var columnNameMap = make(map[uintptr]string)
48+
valueOf := reflect.ValueOf(model).Elem()
49+
typeOf := reflect.TypeOf(model).Elem()
50+
for i := 0; i < valueOf.NumField(); i++ {
51+
field := typeOf.Field(i)
52+
// 如果当前实体嵌入了其他实体,同样需要缓存它的字段名
53+
if field.Anonymous {
54+
// 如果存在多重嵌套,通过递归方式获取他们的字段名
55+
subFieldMap := getSubFieldColumnNameMap(valueOf, field)
56+
for pointer, columnName := range subFieldMap {
57+
columnNameMap[pointer] = columnName
58+
}
59+
} else {
60+
// 获取对象字段指针值
61+
pointer := valueOf.Field(i).Addr().Pointer()
62+
columnName := parseColumnName(field)
63+
columnNameMap[pointer] = columnName
64+
}
65+
}
66+
return columnNameMap
67+
}
68+
6269
// GetModel 获取
6370
func GetModel[T any]() *T {
6471
modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String()
@@ -73,7 +80,6 @@ func GetModel[T any]() *T {
7380
// 递归获取嵌套字段名
7481
func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) map[uintptr]string {
7582
result := make(map[uintptr]string)
76-
7783
modelType := field.Type
7884
if modelType.Kind() == reflect.Ptr {
7985
modelType = modelType.Elem()
@@ -95,7 +101,7 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField)
95101
return result
96102
}
97103

98-
// 获取字段名称
104+
// 解析字段名称
99105
func parseColumnName(field reflect.StructField) string {
100106
tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";")
101107
name, ok := tagSetting["COLUMN"]
@@ -104,3 +110,22 @@ func parseColumnName(field reflect.StructField) string {
104110
}
105111
return globalDb.Config.NamingStrategy.ColumnName("", field.Name)
106112
}
113+
114+
func getColumnName(v any) string {
115+
var columnName string
116+
valueOf := reflect.ValueOf(v)
117+
switch valueOf.Kind() {
118+
case reflect.String:
119+
return v.(string)
120+
case reflect.Pointer:
121+
if name, ok := columnNameCache.Load(valueOf.Pointer()); ok {
122+
return name.(string)
123+
}
124+
// 如果是Function类型,解析字段名称
125+
if reflect.TypeOf(v).Elem() == reflect.TypeOf((*Function)(nil)).Elem() {
126+
f := v.(*Function)
127+
return f.funStr
128+
}
129+
}
130+
return columnName
131+
}

gplus/dao.go

Lines changed: 39 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,26 @@ func Delete[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB {
109109
// UpdateById 根据 ID 更新
110110
func UpdateById[T any](entity *T, opts ...OptionFunc) *gorm.DB {
111111
db := getDb(opts...)
112+
113+
// 如果用户没有设置选择更新的字段,默认更新所有的字段,包括零值更新
114+
updateAllIfNeed(entity, opts, db)
115+
112116
resultDb := db.Model(entity).Updates(entity)
113117
return resultDb
114118
}
115119

120+
func updateAllIfNeed(entity any, opts []OptionFunc, db *gorm.DB) {
121+
option := getOption(opts)
122+
if len(option.Selects) == 0 {
123+
columnNameMap := getColumnNameMap(entity)
124+
var columnNames []string
125+
for _, columnName := range columnNameMap {
126+
columnNames = append(columnNames, columnName)
127+
}
128+
db.Select(columnNames)
129+
}
130+
}
131+
116132
// Update 根据 Map 更新
117133
func Update[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB {
118134
db := getDb(opts...)
@@ -143,12 +159,6 @@ func SelectOne[T any](q *QueryCond[T], opts ...OptionFunc) (*T, *gorm.DB) {
143159
return &entity, resultDb.First(&entity)
144160
}
145161

146-
// Exists 根据条件判断记录是否存在
147-
func Exists[T any](q *QueryCond[T], opts ...OptionFunc) (bool, error) {
148-
_, dbRes := SelectOne[T](q, opts...)
149-
return dbRes.RowsAffected > 0, dbRes.Error
150-
}
151-
152162
// SelectList 根据条件查询多条记录
153163
func SelectList[T any](q *QueryCond[T], opts ...OptionFunc) ([]*T, *gorm.DB) {
154164
resultDb := buildCondition(q, opts...)
@@ -157,32 +167,6 @@ func SelectList[T any](q *QueryCond[T], opts ...OptionFunc) ([]*T, *gorm.DB) {
157167
return results, resultDb
158168
}
159169

160-
// SelectListModel 根据条件查询多条记录
161-
// 第一个泛型代表数据库表实体
162-
// 第二个泛型代表返回记录实体
163-
func SelectListModel[T any, R any](q *QueryCond[T], opts ...OptionFunc) ([]*R, *gorm.DB) {
164-
resultDb := buildCondition(q, opts...)
165-
var results []*R
166-
resultDb.Scan(&results)
167-
return results, resultDb
168-
}
169-
170-
// SelectListByMap 根据 Map 查询多条记录
171-
func SelectListByMap[T any](q *QueryCond[T], opts ...OptionFunc) ([]*T, *gorm.DB) {
172-
resultDb := buildCondition(q, opts...)
173-
var results []*T
174-
resultDb.Find(&results)
175-
return results, resultDb
176-
}
177-
178-
// SelectListMaps 根据条件查询,返回Map记录
179-
func SelectListMaps[T any](q *QueryCond[T], opts ...OptionFunc) ([]map[string]any, *gorm.DB) {
180-
resultDb := buildCondition(q, opts...)
181-
var results []map[string]any
182-
resultDb.Find(&results)
183-
return results, resultDb
184-
}
185-
186170
// SelectPage 根据条件分页查询记录
187171
func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Page[T], *gorm.DB) {
188172
option := getOption(opts)
@@ -203,28 +187,24 @@ func SelectPage[T any](page *Page[T], q *QueryCond[T], opts ...OptionFunc) (*Pag
203187
return page, resultDb
204188
}
205189

206-
// SelectPageModel 根据条件分页查询记录
207-
// 第一个泛型代表数据库表实体
208-
// 第二个泛型代表返回记录实体
209-
func SelectPageModel[T any, R any](page *Page[R], q *QueryCond[T], opts ...OptionFunc) (*Page[R], *gorm.DB) {
210-
option := getOption(opts)
211-
// 如果需要分页忽略总数,不查询总数
212-
if !option.IgnoreTotal {
213-
total, countDb := SelectCount[T](q, opts...)
214-
if countDb.Error != nil {
215-
return page, countDb
216-
}
217-
page.Total = total
218-
}
190+
// SelectCount 根据条件查询记录数量
191+
func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) {
192+
var count int64
219193
resultDb := buildCondition(q, opts...)
220-
var results []*R
221-
resultDb.Scopes(paginate(page)).Scan(&results)
222-
page.Records = results
223-
return page, resultDb
194+
resultDb.Count(&count)
195+
return count, resultDb
224196
}
225197

226-
// SelectPageMaps 根据条件分页查询,返回分页Map记录
227-
func SelectPageMaps[T any](page *Page[map[string]any], q *QueryCond[T], opts ...OptionFunc) (*Page[map[string]any], *gorm.DB) {
198+
// Exists 根据条件判断记录是否存在
199+
func Exists[T any](q *QueryCond[T], opts ...OptionFunc) (bool, error) {
200+
_, dbRes := SelectOne[T](q, opts...)
201+
return dbRes.RowsAffected > 0, dbRes.Error
202+
}
203+
204+
// SelectPageGeneric 根据传入的泛型封装分页记录
205+
// 第一个泛型代表数据库表实体
206+
// 第二个泛型代表返回记录实体
207+
func SelectPageGeneric[T any, R any](page *Page[R], q *QueryCond[T], opts ...OptionFunc) (*Page[R], *gorm.DB) {
228208
option := getOption(opts)
229209
// 如果需要分页忽略总数,不查询总数
230210
if !option.IgnoreTotal {
@@ -235,20 +215,21 @@ func SelectPageMaps[T any](page *Page[map[string]any], q *QueryCond[T], opts ...
235215
page.Total = total
236216
}
237217
resultDb := buildCondition(q, opts...)
238-
var results []map[string]any
239-
resultDb.Scopes(paginate(page)).Find(&results)
218+
var results []R
219+
resultDb.Scopes(paginate(page)).Scan(&results)
240220
for _, m := range results {
241221
page.Records = append(page.Records, &m)
242222
}
243223
return page, resultDb
244224
}
245225

246-
// SelectCount 根据条件查询记录数量
247-
func SelectCount[T any](q *QueryCond[T], opts ...OptionFunc) (int64, *gorm.DB) {
248-
var count int64
226+
// SelectGeneric 根据传入的泛型封装记录
227+
// 第一个泛型代表数据库表实体
228+
// 第二个泛型代表返回记录实体
229+
func SelectGeneric[T any, R any](q *QueryCond[T], opts ...OptionFunc) (R, *gorm.DB) {
230+
var entity R
249231
resultDb := buildCondition(q, opts...)
250-
resultDb.Count(&count)
251-
return count, resultDb
232+
return entity, resultDb.Scan(&entity)
252233
}
253234

254235
func Begin(opts ...*sql.TxOptions) *gorm.DB {

gplus/function.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ func Count(columnName any) *Function {
9494
return &Function{funStr: addBracket(constants.COUNT, getColumnName(columnName))}
9595
}
9696

97+
func As(columnName any, asName any) string {
98+
return getColumnName(columnName) + " " + constants.As + " " + getColumnName(asName)
99+
}
100+
97101
func addBracket(function string, columnNameStr string) string {
98102
return function + constants.LeftBracket + columnNameStr + constants.RightBracket
99103
}

gplus/query.go

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,7 @@ func (q *QueryCond[T]) NotBetween(column any, start, end any) *QueryCond[T] {
201201
// Distinct 去除重复字段值
202202
func (q *QueryCond[T]) Distinct(columns ...any) *QueryCond[T] {
203203
for _, v := range columns {
204-
if columnName, ok := columnNameCache.Load(reflect.ValueOf(v).Pointer()); ok {
205-
q.distinctColumns = append(q.distinctColumns, columnName.(string))
206-
}
204+
q.distinctColumns = append(q.distinctColumns, getColumnName(v))
207205
}
208206
return q
209207
}
@@ -344,17 +342,3 @@ func (q *QueryCond[T]) buildOrder(orderType string, columns ...string) {
344342
q.orderBuilder.WriteString(orderType)
345343
}
346344
}
347-
348-
func getColumnName(v any) string {
349-
var columnName string
350-
valueOf := reflect.ValueOf(v)
351-
switch valueOf.Kind() {
352-
case reflect.String:
353-
columnName = v.(string)
354-
case reflect.Pointer:
355-
if name, ok := columnNameCache.Load(valueOf.Pointer()); ok {
356-
columnName = name.(string)
357-
}
358-
}
359-
return columnName
360-
}

0 commit comments

Comments
 (0)