Skip to content

Commit 954be60

Browse files
committed
Refactor code
1 parent f4311b9 commit 954be60

File tree

5 files changed

+117
-58
lines changed

5 files changed

+117
-58
lines changed

batch.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func BuildToInsertBatchWithSchema(table string, models interface{}, driver strin
118118
if buildParam == nil {
119119
buildParam = GetBuildByDriver(driver)
120120
}
121-
var cols []FieldDB
121+
var cols []*FieldDB
122122
// var schema map[string]FieldDB
123123
if len(options) > 0 && options[0] != nil {
124124
cols = options[0].Columns

native.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func BuildToSaveWithSchema(table string, model interface{}, driver string, build
3232
if mv.Kind() == reflect.Ptr {
3333
mv = mv.Elem()
3434
}
35-
var cols, keys []FieldDB
35+
var cols, keys []*FieldDB
3636
// var schema map[string]FieldDB
3737
if len(options) > 0 && options[0] != nil {
3838
m := options[0]

one.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func BuildToInsertWithSchema(table string, model interface{}, versionIndex int,
3232
sql.Scanner
3333
}, options ...*Schema) (string, []interface{}) {
3434
modelType := reflect.TypeOf(model)
35-
var cols []FieldDB
35+
var cols []*FieldDB
3636
if len(options) > 0 && options[0] != nil {
3737
cols = options[0].Columns
3838
} else {
@@ -124,7 +124,7 @@ func BuildToUpdateWithVersion(table string, model interface{}, versionIndex int,
124124
driver.Valuer
125125
sql.Scanner
126126
}, options ...*Schema) (string, []interface{}) {
127-
var cols, keys []FieldDB
127+
var cols, keys []*FieldDB
128128
// var schema map[string]FieldDB
129129
modelType := reflect.TypeOf(model)
130130
if len(options) > 0 && options[0] != nil {
@@ -225,22 +225,22 @@ func BuildToUpdateWithVersion(table string, model interface{}, versionIndex int,
225225
query := fmt.Sprintf("update %v set %v where %v", table, strings.Join(values, ","), strings.Join(where, " and "))
226226
return query, args
227227
}
228-
func BuildToPatch(table string, model map[string]interface{}, keyColumns []string, buildParam func(int) string, options ...map[string]FieldDB) (string, []interface{}) {
228+
func BuildToPatch(table string, model map[string]interface{}, keyColumns []string, buildParam func(int) string, options ...map[string]*FieldDB) (string, []interface{}) {
229229
return BuildToPatchWithVersion(table, model, keyColumns, buildParam, nil, "", options...)
230230
}
231231
func BuildToPatchWithArray(table string, model map[string]interface{}, keyColumns []string, buildParam func(int) string, toArray func(interface{}) interface {
232232
driver.Valuer
233233
sql.Scanner
234-
}, options ...map[string]FieldDB) (string, []interface{}) {
234+
}, options ...map[string]*FieldDB) (string, []interface{}) {
235235
return BuildToPatchWithVersion(table, model, keyColumns, buildParam, toArray, "", options...)
236236
}
237237

238238
// BuildToPatchWithVersion3 model with db column name
239239
func BuildToPatchWithVersion(table string, model map[string]interface{}, keyColumns []string, buildParam func(int) string, toArray func(interface{}) interface {
240240
driver.Valuer
241241
sql.Scanner
242-
}, version string, options ...map[string]FieldDB) (string, []interface{}) { //version column name db
243-
var schema map[string]FieldDB
242+
}, version string, options ...map[string]*FieldDB) (string, []interface{}) { //version column name db
243+
var schema map[string]*FieldDB
244244
if len(options) > 0 {
245245
schema = options[0]
246246
}

string_service.go

Lines changed: 101 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -60,62 +60,121 @@ func (s *StringService) Load(ctx context.Context, key string, max int64) ([]stri
6060
}
6161

6262
func (s *StringService) Save(ctx context.Context, values []string) (int64, error) {
63-
mainScope := BatchStatement{}
6463
driver := s.Driver
65-
for _, e := range values {
66-
mainScope.Values = append(mainScope.Values, e)
64+
l := len(values)
65+
if l == 0 {
66+
return 0, nil
6767
}
68-
query := ""
69-
holders := BuildPlaceHolders(len(mainScope.Values), s.BuildParam)
70-
if driver == DriverPostgres {
71-
query = fmt.Sprintf("insert into %s (%s) values %s on conflict do nothing",
72-
s.Table,
73-
s.Field,
74-
holders,
75-
)
68+
if driver == DriverPostgres || driver == DriverMysql {
69+
ps := make([]string, 0)
70+
p := make([]interface{}, 0)
71+
for _, str := range values {
72+
p = append(p, str)
73+
}
74+
if driver == DriverPostgres {
75+
for i := 1; i <= l; i++ {
76+
ps = append(ps, "(" + BuildDollarParam(i) + ")")
77+
}
78+
} else {
79+
for i := 1; i <= l; i++ {
80+
ps = append(ps, "(?)")
81+
}
82+
}
83+
var query string
84+
if driver == DriverPostgres {
85+
query = fmt.Sprintf("insert into %s (%s) values %s on conflict do nothing", s.Table, s.Field, strings.Join(ps, ","))
86+
} else {
87+
query = fmt.Sprintf("insert ignore into %s (%s) values %s", s.Table, s.Field, strings.Join(ps, ","))
88+
}
89+
tx, err := s.DB.Begin()
90+
if err != nil {
91+
return -1, err
92+
}
93+
res, err := tx.ExecContext(ctx, query, p...)
94+
if err != nil {
95+
er := tx.Rollback()
96+
if er != nil {
97+
return -1, er
98+
}
99+
return -1, err
100+
}
101+
err = tx.Commit()
102+
if err != nil {
103+
return -1, err
104+
}
105+
return res.RowsAffected()
76106
} else if driver == DriverSqlite3 {
77-
query = fmt.Sprintf("insert or ignore into %s (%s) values %s",
78-
s.Table,
79-
s.Field,
80-
holders,
81-
)
82-
} else if driver == DriverMysql {
83-
query = fmt.Sprintf("insert ignore %s (%s) values %s ",
84-
s.Table,
85-
s.Field,
86-
holders,
87-
)
88-
} else if driver == DriverOracle || driver == DriverMssql {
89-
onDupe := s.Table + "." + s.Field + " = " + "temp." + s.Field
90-
value := "temp." + s.Field
91-
query = fmt.Sprintf("merge into %s using (values %s) as temp (%s) on %s when not matched then insert (%s) values (%s);",
92-
s.Table,
93-
holders,
94-
s.Field,
95-
onDupe,
96-
s.Field,
97-
value,
98-
)
107+
tx, err := s.DB.Begin()
108+
if err != nil {
109+
return -1, err
110+
}
111+
var c int64
112+
c = 0
113+
for _, e := range values {
114+
query := fmt.Sprintf("insert or ignore into %s (%s) values (?)", s.Table, s.Field)
115+
res, err := tx.ExecContext(ctx, query, e)
116+
if err != nil {
117+
er := tx.Rollback()
118+
if er != nil {
119+
return -1, er
120+
}
121+
return -1, err
122+
}
123+
a, err := res.RowsAffected()
124+
if err != nil {
125+
return -1, err
126+
}
127+
c = c + a
128+
}
129+
err = tx.Commit()
130+
if err != nil {
131+
return -1, err
132+
}
133+
return c, nil
99134
} else {
100-
return 0, fmt.Errorf("unsupported db vendor, current vendor is %s", driver)
101-
}
102-
mainScope.Query = query
103-
x, err := s.DB.ExecContext(ctx, mainScope.Query, mainScope.Values...)
104-
if err != nil {
105-
return 0, err
135+
mainScope := BatchStatement{}
136+
for _, e := range values {
137+
mainScope.Values = append(mainScope.Values, e)
138+
}
139+
query := ""
140+
holders := BuildPlaceHolders(len(mainScope.Values), s.BuildParam)
141+
if driver == DriverOracle || driver == DriverMssql {
142+
onDupe := s.Table + "." + s.Field + " = " + "temp." + s.Field
143+
value := "temp." + s.Field
144+
query = fmt.Sprintf("merge into %s using (values %s) as temp (%s) on %s when not matched then insert (%s) values (%s);",
145+
s.Table,
146+
holders,
147+
s.Field,
148+
onDupe,
149+
s.Field,
150+
value,
151+
)
152+
} else {
153+
return 0, fmt.Errorf("unsupported db vendor, current vendor is %s", driver)
154+
}
155+
mainScope.Query = query
156+
x, err := s.DB.ExecContext(ctx, mainScope.Query, mainScope.Values...)
157+
if err != nil {
158+
return 0, err
159+
}
160+
return x.RowsAffected()
106161
}
107-
return x.RowsAffected()
108162
}
109163

110164
func (s *StringService) Delete(ctx context.Context, values []string) (int64, error) {
111165
var arrValue []string
112166
le := len(values)
167+
buildParam := GetBuild(s.DB)
168+
p := make([]interface{}, 0)
169+
for _, str := range values {
170+
p = append(p, str)
171+
}
113172
for i := 1; i <= le; i++ {
114-
param := BuildParam(i)
173+
param := buildParam(i)
115174
arrValue = append(arrValue, param)
116175
}
117176
query := `delete from ` + s.Table + ` where ` + s.Field + ` in (` + strings.Join(arrValue, ",") + `)`
118-
x, err := s.DB.ExecContext(ctx, query)
177+
x, err := s.DB.ExecContext(ctx, query, p...)
119178
if err != nil {
120179
return 0, err
121180
}

util.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ type FieldDB struct {
3535
type Schema struct {
3636
SKeys []string
3737
SColumns []string
38-
Keys []FieldDB
39-
Columns []FieldDB
40-
Fields map[string]FieldDB
38+
Keys []*FieldDB
39+
Columns []*FieldDB
40+
Fields map[string]*FieldDB
4141
}
4242

4343
func CreateSchema(modelType reflect.Type) *Schema {
@@ -48,9 +48,9 @@ func CreateSchema(modelType reflect.Type) *Schema {
4848
numField := m.NumField()
4949
scolumns := make([]string, 0)
5050
skeys := make([]string, 0)
51-
columns := make([]FieldDB, 0)
52-
keys := make([]FieldDB, 0)
53-
schema := make(map[string]FieldDB, 0)
51+
columns := make([]*FieldDB, 0)
52+
keys := make([]*FieldDB, 0)
53+
schema := make(map[string]*FieldDB, 0)
5454
for idx := 0; idx < numField; idx++ {
5555
field := m.Field(idx)
5656
tag, _ := field.Tag.Lookup("gorm")
@@ -74,7 +74,7 @@ func CreateSchema(modelType reflect.Type) *Schema {
7474
tagJsons := strings.Split(jTag, ",")
7575
json = tagJsons[0]
7676
}
77-
f := FieldDB{
77+
f := &FieldDB{
7878
JSON: json,
7979
Column: col,
8080
Index: idx,
@@ -113,7 +113,7 @@ func CreateSchema(modelType reflect.Type) *Schema {
113113
s := &Schema{SColumns: scolumns, SKeys: skeys, Columns: columns, Keys: keys, Fields: schema}
114114
return s
115115
}
116-
func MakeSchema(modelType reflect.Type) ([]FieldDB, []FieldDB) {
116+
func MakeSchema(modelType reflect.Type) ([]*FieldDB, []*FieldDB) {
117117
m := CreateSchema(modelType)
118118
return m.Columns, m.Keys
119119
}

0 commit comments

Comments
 (0)