Skip to content

Commit 6dc0c04

Browse files
committed
template增加arg自定义函数
1 parent 3127417 commit 6dc0c04

File tree

6 files changed

+257
-63
lines changed

6 files changed

+257
-63
lines changed

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,17 @@ DELETE FROM test_table
284284
{{where (ne .Id 0) "AND" "id" .Id "" | where (ne .Username "") "AND" "username" .Username | where (ne .Password "") "AND" "password" .Password}}
285285
{{end}}
286286
```
287-
其中where和set是gobatis的自定义函数,用于智能的生成动态sql
287+
其中where、set、arg是gobatis的自定义函数,用于智能生成动态sql
288+
289+
arg用于将对象动态转换为占位符,并保存为SQL参数,如:
290+
```cassandraql
291+
SELECT * FROM TABLENAME WHERE name = {{arg .Name}}
292+
```
293+
以mysql为例,将解析为:
294+
```cassandraql
295+
SELECT * FROM TABLENAME WHERE name = ?
296+
```
297+
同时Name的值将自动保存为SQL参数,自动传入,起到类似xml中#{MODEL.Name}的效果。
288298

289299
### 9、gobatis-cmd生成文件使用示例
290300

parsing/template/dynamic.go

Lines changed: 210 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,22 @@ package template
77

88
import (
99
"fmt"
10+
"strconv"
1011
"strings"
1112
"text/template"
1213
)
1314

15+
const (
16+
argPlaceHolder = "_xfali_Arg_Holder"
17+
argPlaceHolderLen = 17
18+
)
19+
20+
type Dynamic interface {
21+
getFuncMap() template.FuncMap
22+
getParam() []interface{}
23+
format(string) (string, []interface{})
24+
}
25+
1426
func dummyUpdateSet(b bool, column string, value interface{}, origin string) string {
1527
return origin
1628
}
@@ -19,11 +31,57 @@ func dummyWhere(b bool, cond, column string, value interface{}, origin string) s
1931
return origin
2032
}
2133

34+
func dummyParam(p interface{}) string {
35+
return fmt.Sprint(p)
36+
}
37+
2238
func commonAdd(a, b int) int {
2339
return a + b
2440
}
2541

26-
func mysqlUpdateSet(b bool, column string, value interface{}, origin string) string {
42+
type DummyDynamic struct{}
43+
44+
var dummyFuncMap = template.FuncMap{
45+
"set": dummyUpdateSet,
46+
"where": dummyWhere,
47+
"add": commonAdd,
48+
"arg": dummyParam,
49+
}
50+
51+
var gDummyDynamic = &DummyDynamic{}
52+
53+
func (d *DummyDynamic) getFuncMap() template.FuncMap {
54+
return dummyFuncMap
55+
}
56+
57+
func (d *DummyDynamic) getParam() []interface{} {
58+
return nil
59+
}
60+
61+
func (d *DummyDynamic) format(s string) (string, []interface{}) {
62+
return s, nil
63+
}
64+
65+
type MysqlDynamic struct {
66+
index int
67+
keys []string
68+
paramMap map[string]interface{}
69+
}
70+
71+
func (d *MysqlDynamic) getFuncMap() template.FuncMap {
72+
return template.FuncMap{
73+
"set": d.mysqlUpdateSet,
74+
"where": d.mysqlWhere,
75+
"add": commonAdd,
76+
"arg": d.Param,
77+
}
78+
}
79+
80+
func (d *MysqlDynamic) getParam() []interface{} {
81+
return nil
82+
}
83+
84+
func (d *MysqlDynamic) mysqlUpdateSet(b bool, column string, value interface{}, origin string) string {
2785
if !b {
2886
return origin
2987
}
@@ -42,74 +100,121 @@ func mysqlUpdateSet(b bool, column string, value interface{}, origin string) str
42100
buf.WriteString(column)
43101
buf.WriteString("` = ")
44102
if s, ok := value.(string); ok {
45-
buf.WriteString(`'`)
46-
buf.WriteString(s)
47-
buf.WriteString(`'`)
103+
if _, ok := d.paramMap[s]; ok {
104+
buf.WriteString(s)
105+
} else {
106+
buf.WriteString(`'`)
107+
buf.WriteString(s)
108+
buf.WriteString(`'`)
109+
}
48110
} else {
49111
buf.WriteString(fmt.Sprint(value))
50112
}
51113
return buf.String()
52114
}
53115

54-
func postgresUpdateSet(b bool, column string, value interface{}, origin string) string {
116+
func (d *MysqlDynamic) mysqlWhere(b bool, cond, column string, value interface{}, origin string) string {
55117
if !b {
56118
return origin
57119
}
58120

59121
buf := strings.Builder{}
60122
if origin == "" {
61-
buf.WriteString(" SET ")
123+
buf.WriteString(" WHERE ")
124+
cond = ""
62125
} else {
63-
origin = strings.TrimSpace(origin)
64-
buf.WriteString(origin)
65-
if origin[:len(origin)-1] != "," {
66-
buf.WriteString(",")
67-
}
126+
buf.WriteString(strings.TrimSpace(origin))
127+
buf.WriteString(" ")
128+
buf.WriteString(cond)
129+
buf.WriteString(" ")
68130
}
69-
buf.WriteString(`"`)
131+
132+
buf.WriteString("`")
70133
buf.WriteString(column)
71-
buf.WriteString(`"`)
72-
buf.WriteString(" = ")
134+
buf.WriteString("` = ")
73135
if s, ok := value.(string); ok {
74-
buf.WriteString(`'`)
75-
buf.WriteString(s)
76-
buf.WriteString(`'`)
136+
if _, ok := d.paramMap[s]; ok {
137+
buf.WriteString(s)
138+
} else {
139+
buf.WriteString(`'`)
140+
buf.WriteString(s)
141+
buf.WriteString(`'`)
142+
}
77143
} else {
78144
buf.WriteString(fmt.Sprint(value))
79145
}
80146
return buf.String()
81147
}
82148

83-
func mysqlWhere(b bool, cond, column string, value interface{}, origin string) string {
149+
func (d *MysqlDynamic) Param(p interface{}) string {
150+
d.index++
151+
key := argPlaceHolder + strconv.Itoa(d.index)
152+
d.paramMap[key] = p
153+
d.keys = append(d.keys, key)
154+
return key
155+
}
156+
157+
func (d *MysqlDynamic) format(s string) (string, []interface{}) {
158+
i := 0
159+
var params []interface{}
160+
for _, k := range d.keys {
161+
s, i = replace(s, k, "?", -1)
162+
if i > 0 {
163+
params = append(params, d.paramMap[k])
164+
}
165+
}
166+
return s, params
167+
}
168+
169+
type PostgresDynamic struct {
170+
index int
171+
keys [] string
172+
paramMap map[string]interface{}
173+
}
174+
175+
func (d *PostgresDynamic) getFuncMap() template.FuncMap {
176+
return template.FuncMap{
177+
"set": d.postgresUpdateSet,
178+
"where": d.postgresWhere,
179+
"add": commonAdd,
180+
"arg": d.Param,
181+
}
182+
}
183+
184+
func (d *PostgresDynamic) postgresUpdateSet(b bool, column string, value interface{}, origin string) string {
84185
if !b {
85186
return origin
86187
}
87188

88189
buf := strings.Builder{}
89190
if origin == "" {
90-
buf.WriteString(" WHERE ")
91-
cond = ""
191+
buf.WriteString(" SET ")
92192
} else {
93-
buf.WriteString(strings.TrimSpace(origin))
94-
buf.WriteString(" ")
95-
buf.WriteString(cond)
96-
buf.WriteString(" ")
193+
origin = strings.TrimSpace(origin)
194+
buf.WriteString(origin)
195+
if origin[:len(origin)-1] != "," {
196+
buf.WriteString(",")
197+
}
97198
}
98-
99-
buf.WriteString("`")
199+
buf.WriteString(`"`)
100200
buf.WriteString(column)
101-
buf.WriteString("` = ")
201+
buf.WriteString(`"`)
202+
buf.WriteString(" = ")
102203
if s, ok := value.(string); ok {
103-
buf.WriteString(`'`)
104-
buf.WriteString(s)
105-
buf.WriteString(`'`)
204+
if _, ok := d.paramMap[s]; ok {
205+
buf.WriteString(s)
206+
} else {
207+
buf.WriteString(`'`)
208+
buf.WriteString(s)
209+
buf.WriteString(`'`)
210+
}
106211
} else {
107212
buf.WriteString(fmt.Sprint(value))
108213
}
109214
return buf.String()
110215
}
111216

112-
func postgresWhere(b bool, cond, column string, value interface{}, origin string) string {
217+
func (d *PostgresDynamic) postgresWhere(b bool, cond, column string, value interface{}, origin string) string {
113218
if !b {
114219
return origin
115220
}
@@ -130,41 +235,94 @@ func postgresWhere(b bool, cond, column string, value interface{}, origin string
130235
buf.WriteString(`"`)
131236
buf.WriteString(" = ")
132237
if s, ok := value.(string); ok {
133-
buf.WriteString(`'`)
134-
buf.WriteString(s)
135-
buf.WriteString(`'`)
238+
if _, ok := d.paramMap[s]; ok {
239+
buf.WriteString(s)
240+
} else {
241+
buf.WriteString(`'`)
242+
buf.WriteString(s)
243+
buf.WriteString(`'`)
244+
}
136245
} else {
137246
buf.WriteString(fmt.Sprint(value))
138247
}
139248
return buf.String()
140249
}
141250

142-
var mysqlFuncMap = template.FuncMap{
143-
"set": mysqlUpdateSet,
144-
"where": mysqlWhere,
145-
"add": commonAdd,
251+
func (d *PostgresDynamic) getParam() []interface{} {
252+
return nil
146253
}
147254

148-
var postgresFuncMap = template.FuncMap{
149-
"set": postgresUpdateSet,
150-
"where": postgresWhere,
151-
"add": commonAdd,
255+
func (d *PostgresDynamic) Param(p interface{}) string {
256+
d.index++
257+
key := argPlaceHolder + strconv.Itoa(d.index)
258+
d.paramMap[key] = p
259+
d.keys = append(d.keys, key)
260+
return key
152261
}
153262

154-
var dummyFuncMap = template.FuncMap{
155-
"set": dummyUpdateSet,
156-
"where": dummyWhere,
157-
"add": commonAdd,
263+
func (d *PostgresDynamic) format(s string) (string, []interface{}) {
264+
i, index := 0, 1
265+
var params []interface{}
266+
for _, k := range d.keys {
267+
s, i = replace(s, k, "$"+strconv.Itoa(index), -1)
268+
if i > 0 {
269+
params = append(params, d.paramMap[k])
270+
index++
271+
}
272+
}
273+
return s, params
158274
}
159275

160-
var funcMap = map[string]template.FuncMap{
161-
"mysql": mysqlFuncMap,
162-
"postgres": postgresFuncMap,
276+
var dynamicMap = map[string]Dynamic{
277+
"mysql": &MysqlDynamic{paramMap: map[string]interface{}{}},
278+
"postgres": &PostgresDynamic{paramMap: map[string]interface{}{}},
163279
}
164280

165-
func selectFuncMap(driverName string) template.FuncMap {
166-
if v, ok := funcMap[driverName]; ok {
281+
func selectDynamic(driverName string) Dynamic {
282+
if v, ok := dynamicMap[driverName]; ok {
167283
return v
168284
}
169-
return dummyFuncMap
285+
return gDummyDynamic
286+
}
287+
288+
func replace(s, old, new string, n int) (string, int) {
289+
if old == new || n == 0 {
290+
return s, 0 // avoid allocation
291+
}
292+
293+
if old == "" {
294+
return s, 0
295+
}
296+
297+
if n < 0 {
298+
if m := strings.Count(s, old); m == 0 {
299+
return s, 0 // avoid allocation
300+
} else if n < 0 || m < n {
301+
n = m
302+
}
303+
}
304+
makeSize := len(s) + n*(len(new)-len(old))
305+
// Apply replacements to buffer.
306+
t := make([]byte, makeSize)
307+
w, count := 0, 0
308+
start := 0
309+
for {
310+
if n == 0 {
311+
break
312+
}
313+
j := start
314+
index := strings.Index(s[start:], old)
315+
if index == -1 {
316+
return string(t[0:w]), count
317+
} else {
318+
j += index
319+
count++
320+
}
321+
w += copy(t[w:], s[start:j])
322+
w += copy(t[w:], new)
323+
start = j + len(old)
324+
n--
325+
}
326+
w += copy(t[w:], s[start:])
327+
return string(t[0:w]), count
170328
}

parsing/template/parse.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ func (p *Parser) ParseMetadata(driverName string, params ...interface{}) (*sqlpa
2525
if len(params) > 0 {
2626
param = params[0]
2727
}
28-
fm := selectFuncMap(driverName)
29-
tpl := p.tpl.Funcs(fm)
28+
dynamic := selectDynamic(driverName)
29+
tpl := p.tpl.Funcs(dynamic.getFuncMap())
3030
err := tpl.Execute(&b, param)
3131
if err != nil {
3232
return nil, err
@@ -37,8 +37,7 @@ func (p *Parser) ParseMetadata(driverName string, params ...interface{}) (*sqlpa
3737
action := sql[:6]
3838
action = strings.ToLower(action)
3939
ret.Action = action
40-
ret.PrepareSql = sql
41-
ret.Params = nil
40+
ret.PrepareSql, ret.Params = dynamic.format(sql)
4241

4342
return ret, nil
4443
}

0 commit comments

Comments
 (0)