Skip to content

Commit 563112f

Browse files
authored
feat: add cache method to store columnName (#43)
* feat: add cache method to store columnName * ci: add license
1 parent a219afc commit 563112f

File tree

4 files changed

+280
-135
lines changed

4 files changed

+280
-135
lines changed

gplus/cache.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed to the AcmeStack under one or more contributor license
3+
* agreements. See the NOTICE file distributed with this work for
4+
* additional information regarding copyright ownership.
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package gplus
19+
20+
import (
21+
"gorm.io/gorm/schema"
22+
"reflect"
23+
"sync"
24+
)
25+
26+
// 缓存项目中所有实体字段名,储存格式:key为字段指针值,value为字段名
27+
// 通过缓存实体的字段名,方便gplus通过字段指针获取到对应的字段名
28+
var columnNameCache sync.Map
29+
30+
// 缓存实体对象,主要给NewQuery方法返回使用
31+
var modelInstanceCache sync.Map
32+
33+
// Cache 缓存实体对象所有的字段名
34+
func Cache(model any, namingStrategy ...schema.Namer) {
35+
valueOf := reflect.ValueOf(model).Elem()
36+
typeOf := reflect.TypeOf(model).Elem()
37+
38+
for i := 0; i < valueOf.NumField(); i++ {
39+
field := typeOf.Field(i)
40+
// 如果当前实体嵌入了其他实体,同样需要缓存它的字段名
41+
if field.Anonymous {
42+
// 如果存在多重嵌套,通过递归方式获取他们的字段名
43+
subFieldMap := getSubFieldColumnNameMap(valueOf, field)
44+
for key, value := range subFieldMap {
45+
columnNameCache.Store(key, value)
46+
}
47+
} else {
48+
// 获取对象字段指针值
49+
pointer := valueOf.Field(i).Addr().Pointer()
50+
name := parseColumnName(field, namingStrategy...)
51+
columnNameCache.Store(pointer, name)
52+
}
53+
}
54+
55+
// 缓存对象
56+
modelTypeStr := reflect.TypeOf(model).Elem().String()
57+
modelInstanceCache.Store(modelTypeStr, model)
58+
}
59+
60+
// 递归获取嵌套字段名
61+
func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) map[uintptr]string {
62+
result := make(map[uintptr]string)
63+
64+
modelType := field.Type
65+
if modelType.Kind() == reflect.Ptr {
66+
modelType = modelType.Elem()
67+
}
68+
for j := 0; j < modelType.NumField(); j++ {
69+
subField := modelType.Field(j)
70+
if subField.Anonymous {
71+
nestedFields := getSubFieldColumnNameMap(valueOf, subField)
72+
for key, value := range nestedFields {
73+
result[key] = value
74+
}
75+
} else {
76+
pointer := valueOf.FieldByName(modelType.Field(j).Name).Addr().Pointer()
77+
name := parseColumnName(modelType.Field(j))
78+
result[pointer] = name
79+
}
80+
}
81+
82+
return result
83+
}
84+
85+
// 获取字段名称
86+
func parseColumnName(field reflect.StructField, namingStrategy ...schema.Namer) string {
87+
tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";")
88+
name, ok := tagSetting["COLUMN"]
89+
if ok {
90+
return name
91+
}
92+
93+
if len(namingStrategy) > 0 {
94+
return namingStrategy[0].ColumnName("", field.Name)
95+
}
96+
97+
strategy := schema.NamingStrategy{}
98+
return strategy.ColumnName("", field.Name)
99+
}

0 commit comments

Comments
 (0)