@@ -16,7 +16,7 @@ type Args struct {
16
16
// The default flavor used by `Args#Compile`
17
17
Flavor Flavor
18
18
19
- args []interface {}
19
+ argValues []interface {}
20
20
namedArgs map [string ]int
21
21
sqlNamedArgs map [string ]int
22
22
onlyNamed bool
@@ -47,7 +47,7 @@ func (args *Args) Add(arg interface{}) string {
47
47
}
48
48
49
49
func (args * Args ) add (arg interface {}) int {
50
- idx := len (args .args )
50
+ idx := len (args .argValues )
51
51
52
52
switch a := arg .(type ) {
53
53
case sql.NamedArg :
@@ -56,7 +56,7 @@ func (args *Args) add(arg interface{}) int {
56
56
}
57
57
58
58
if p , ok := args .sqlNamedArgs [a .Name ]; ok {
59
- arg = args .args [p ]
59
+ arg = args .argValues [p ]
60
60
break
61
61
}
62
62
@@ -67,7 +67,7 @@ func (args *Args) add(arg interface{}) int {
67
67
}
68
68
69
69
if p , ok := args .namedArgs [a .name ]; ok {
70
- arg = args .args [p ]
70
+ arg = args .argValues [p ]
71
71
break
72
72
}
73
73
@@ -77,7 +77,7 @@ func (args *Args) add(arg interface{}) int {
77
77
return idx
78
78
}
79
79
80
- args .args = append (args .args , arg )
80
+ args .argValues = append (args .argValues , arg )
81
81
return idx
82
82
}
83
83
@@ -97,55 +97,58 @@ func (args *Args) Compile(format string, initialValue ...interface{}) (query str
97
97
//
98
98
// See doc for `Compile` to learn details.
99
99
func (args * Args ) CompileWithFlavor (format string , flavor Flavor , initialValue ... interface {}) (query string , values []interface {}) {
100
- buf := newStringBuilder ()
101
100
idx := strings .IndexRune (format , '$' )
102
101
offset := 0
103
- values = initialValue
102
+ ctx := & argsCompileContext {
103
+ stringBuilder : newStringBuilder (),
104
+ Flavor : flavor ,
105
+ Values : initialValue ,
106
+ }
104
107
105
- if flavor == invalidFlavor {
106
- flavor = DefaultFlavor
108
+ if ctx . Flavor == invalidFlavor {
109
+ ctx . Flavor = DefaultFlavor
107
110
}
108
111
109
112
for idx >= 0 && len (format ) > 0 {
110
113
if idx > 0 {
111
- buf .WriteString (format [:idx ])
114
+ ctx .WriteString (format [:idx ])
112
115
}
113
116
114
117
format = format [idx + 1 :]
115
118
116
119
// Treat the $ at the end of format is a normal $ rune.
117
120
if len (format ) == 0 {
118
- buf .WriteRune ('$' )
121
+ ctx .WriteRune ('$' )
119
122
break
120
123
}
121
124
122
125
if r := format [0 ]; r == '$' {
123
- buf .WriteRune ('$' )
126
+ ctx .WriteRune ('$' )
124
127
format = format [1 :]
125
128
} else if r == '{' {
126
- format , values = args .compileNamed (buf , flavor , format , values )
129
+ format = args .compileNamed (ctx , format )
127
130
} else if ! args .onlyNamed && '0' <= r && r <= '9' {
128
- format , values , offset = args .compileDigits (buf , flavor , format , values , offset )
131
+ format , offset = args .compileDigits (ctx , format , offset )
129
132
} else if ! args .onlyNamed && r == '?' {
130
- format , values , offset = args .compileSuccessive (buf , flavor , format [1 :], values , offset )
133
+ format , offset = args .compileSuccessive (ctx , format [1 :], offset )
131
134
} else {
132
135
// For unknown $ expression format, treat it as a normal $ rune.
133
- buf .WriteRune ('$' )
136
+ ctx .WriteRune ('$' )
134
137
}
135
138
136
139
idx = strings .IndexRune (format , '$' )
137
140
}
138
141
139
142
if len (format ) > 0 {
140
- buf .WriteString (format )
143
+ ctx .WriteString (format )
141
144
}
142
145
143
- query = buf .String ()
144
- values = args .mergeSQLNamedArgs (values )
146
+ query = ctx .String ()
147
+ values = args .mergeSQLNamedArgs (ctx )
145
148
return
146
149
}
147
150
148
- func (args * Args ) compileNamed (buf * stringBuilder , flavor Flavor , format string , values [] interface {}) ( string , [] interface {}) {
151
+ func (args * Args ) compileNamed (ctx * argsCompileContext , format string ) string {
149
152
i := 1
150
153
151
154
for ; i < len (format ) && format [i ] != '}' ; i ++ {
@@ -154,20 +157,20 @@ func (args *Args) compileNamed(buf *stringBuilder, flavor Flavor, format string,
154
157
155
158
// Invalid $ format. Ignore it.
156
159
if i == len (format ) {
157
- return format , values
160
+ return format
158
161
}
159
162
160
163
name := format [1 :i ]
161
164
format = format [i + 1 :]
162
165
163
166
if p , ok := args .namedArgs [name ]; ok {
164
- format , values , _ = args .compileSuccessive (buf , flavor , format , values , p )
167
+ format , _ = args .compileSuccessive (ctx , format , p )
165
168
}
166
169
167
- return format , values
170
+ return format
168
171
}
169
172
170
- func (args * Args ) compileDigits (buf * stringBuilder , flavor Flavor , format string , values [] interface {}, offset int ) (string , [] interface {} , int ) {
173
+ func (args * Args ) compileDigits (ctx * argsCompileContext , format string , offset int ) (string , int ) {
171
174
i := 1
172
175
173
176
for ; i < len (format ) && '0' <= format [i ] && format [i ] <= '9' ; i ++ {
@@ -178,91 +181,37 @@ func (args *Args) compileDigits(buf *stringBuilder, flavor Flavor, format string
178
181
format = format [i :]
179
182
180
183
if pointer , err := strconv .Atoi (digits ); err == nil {
181
- return args .compileSuccessive (buf , flavor , format , values , pointer )
184
+ return args .compileSuccessive (ctx , format , pointer )
182
185
}
183
186
184
- return format , values , offset
187
+ return format , offset
185
188
}
186
189
187
- func (args * Args ) compileSuccessive (buf * stringBuilder , flavor Flavor , format string , values [] interface {}, offset int ) (string , [] interface {} , int ) {
188
- if offset >= len (args .args ) {
189
- return format , values , offset
190
+ func (args * Args ) compileSuccessive (ctx * argsCompileContext , format string , offset int ) (string , int ) {
191
+ if offset >= len (args .argValues ) {
192
+ return format , offset
190
193
}
191
194
192
- arg := args .args [offset ]
193
- values = args .compileArg (buf , flavor , values , arg )
194
-
195
- return format , values , offset + 1
196
- }
197
-
198
- func (args * Args ) compileArg (buf * stringBuilder , flavor Flavor , values []interface {}, arg interface {}) []interface {} {
199
- switch a := arg .(type ) {
200
- case Builder :
201
- var s string
202
- s , values = a .BuildWithFlavor (flavor , values ... )
203
- buf .WriteString (s )
204
- case sql.NamedArg :
205
- buf .WriteRune ('@' )
206
- buf .WriteString (a .Name )
207
- case rawArgs :
208
- buf .WriteString (a .expr )
209
- case listArgs :
210
- if a .isTuple {
211
- buf .WriteRune ('(' )
212
- }
213
-
214
- if len (a .args ) > 0 {
215
- values = args .compileArg (buf , flavor , values , a .args [0 ])
216
- }
217
-
218
- for i := 1 ; i < len (a .args ); i ++ {
219
- buf .WriteString (", " )
220
- values = args .compileArg (buf , flavor , values , a .args [i ])
221
- }
222
-
223
- if a .isTuple {
224
- buf .WriteRune (')' )
225
- }
226
-
227
- default :
228
- switch flavor {
229
- case MySQL , SQLite , CQL , ClickHouse , Presto , Informix :
230
- buf .WriteRune ('?' )
231
- case PostgreSQL :
232
- fmt .Fprintf (buf , "$%d" , len (values )+ 1 )
233
- case SQLServer :
234
- fmt .Fprintf (buf , "@p%d" , len (values )+ 1 )
235
- case Oracle :
236
- fmt .Fprintf (buf , ":%d" , len (values )+ 1 )
237
- default :
238
- panic (fmt .Errorf ("Args.CompileWithFlavor: invalid flavor %v (%v)" , flavor , int (flavor )))
239
- }
240
-
241
- namedValues := parseNamedArgs (values )
195
+ arg := args .argValues [offset ]
196
+ ctx .WriteValue (arg )
242
197
243
- if n := len (namedValues ); n == 0 {
244
- values = append (values , arg )
245
- } else {
246
- index := len (values ) - n
247
- values = append (values [:index + 1 ], namedValues ... )
248
- values [index ] = arg
249
- }
250
- }
251
-
252
- return values
198
+ return format , offset + 1
253
199
}
254
200
255
- func (args * Args ) mergeSQLNamedArgs (values [] interface {} ) []interface {} {
256
- if len (args .sqlNamedArgs ) == 0 {
257
- return values
201
+ func (args * Args ) mergeSQLNamedArgs (ctx * argsCompileContext ) []interface {} {
202
+ if len (args .sqlNamedArgs ) == 0 && len ( ctx . NamedArgs ) == 0 {
203
+ return ctx . Values
258
204
}
259
205
260
- namedValues := parseNamedArgs ( values )
261
- existingNames := make (map [string ]struct {}, len (namedValues ))
206
+ values := ctx . Values
207
+ existingNames := make (map [string ]struct {}, len (ctx . NamedArgs ))
262
208
263
- for _ , v := range namedValues {
264
- if a , ok := v .(sql.NamedArg ); ok {
265
- existingNames [a .Name ] = struct {}{}
209
+ // Add all named args to values.
210
+ // Remove duplicated named args in this step.
211
+ for _ , arg := range ctx .NamedArgs {
212
+ if _ , ok := existingNames [arg .Name ]; ! ok {
213
+ existingNames [arg .Name ] = struct {}{}
214
+ values = append (values , arg )
266
215
}
267
216
}
268
217
@@ -280,19 +229,21 @@ func (args *Args) mergeSQLNamedArgs(values []interface{}) []interface{} {
280
229
sort .Ints (ints )
281
230
282
231
for _ , i := range ints {
283
- values = append (values , args .args [i ])
232
+ values = append (values , args .argValues [i ])
284
233
}
285
234
286
235
return values
287
236
}
288
237
289
- func parseNamedArgs (initialValue []interface {}) (namedValues []interface {}) {
238
+ func parseNamedArgs (initialValue []interface {}) (values []interface {}, namedValues []sql. NamedArg ) {
290
239
if len (initialValue ) == 0 {
291
- return nil
240
+ values = initialValue
241
+ return
292
242
}
293
243
294
244
// sql.NamedArgs must be placed at the end of the initial value.
295
- i := len (initialValue )
245
+ size := len (initialValue )
246
+ i := size
296
247
297
248
for ; i > 0 ; i -- {
298
249
switch initialValue [i - 1 ].(type ) {
@@ -303,6 +254,97 @@ func parseNamedArgs(initialValue []interface{}) (namedValues []interface{}) {
303
254
break
304
255
}
305
256
306
- namedValues = initialValue [i :]
257
+ if i == size {
258
+ values = initialValue
259
+ return
260
+ }
261
+
262
+ values = initialValue [:i ]
263
+ namedValues = make ([]sql.NamedArg , 0 , size - i )
264
+
265
+ for ; i < size ; i ++ {
266
+ namedValues = append (namedValues , initialValue [i ].(sql.NamedArg ))
267
+ }
268
+
307
269
return
308
270
}
271
+
272
+ type argsCompileContext struct {
273
+ * stringBuilder
274
+
275
+ Flavor Flavor
276
+ Values []interface {}
277
+ NamedArgs []sql.NamedArg
278
+ }
279
+
280
+ func (ctx * argsCompileContext ) WriteValue (arg interface {}) {
281
+ switch a := arg .(type ) {
282
+ case Builder :
283
+ s , values := a .BuildWithFlavor (ctx .Flavor , ctx .Values ... )
284
+ ctx .WriteString (s )
285
+
286
+ // Add all values to ctx.
287
+ // Named args must be located at the end of values.
288
+ values , namedArgs := parseNamedArgs (values )
289
+ ctx .Values = values
290
+ ctx .NamedArgs = append (ctx .NamedArgs , namedArgs ... )
291
+
292
+ case sql.NamedArg :
293
+ ctx .WriteRune ('@' )
294
+ ctx .WriteString (a .Name )
295
+ ctx .NamedArgs = append (ctx .NamedArgs , a )
296
+
297
+ case rawArgs :
298
+ ctx .WriteString (a .expr )
299
+
300
+ case listArgs :
301
+ if a .isTuple {
302
+ ctx .WriteRune ('(' )
303
+ }
304
+
305
+ if len (a .args ) > 0 {
306
+ ctx .WriteValue (a .args [0 ])
307
+ }
308
+
309
+ for i := 1 ; i < len (a .args ); i ++ {
310
+ ctx .WriteString (", " )
311
+ ctx .WriteValue (a .args [i ])
312
+ }
313
+
314
+ if a .isTuple {
315
+ ctx .WriteRune (')' )
316
+ }
317
+
318
+ case condBuilder :
319
+ a .Builder (ctx )
320
+
321
+ default :
322
+ switch ctx .Flavor {
323
+ case MySQL , SQLite , CQL , ClickHouse , Presto , Informix :
324
+ ctx .WriteRune ('?' )
325
+ case PostgreSQL :
326
+ fmt .Fprintf (ctx , "$%d" , len (ctx .Values )+ 1 )
327
+ case SQLServer :
328
+ fmt .Fprintf (ctx , "@p%d" , len (ctx .Values )+ 1 )
329
+ case Oracle :
330
+ fmt .Fprintf (ctx , ":%d" , len (ctx .Values )+ 1 )
331
+ default :
332
+ panic (fmt .Errorf ("Args.CompileWithFlavor: invalid flavor %v (%v)" , ctx .Flavor , int (ctx .Flavor )))
333
+ }
334
+
335
+ ctx .Values = append (ctx .Values , arg )
336
+ }
337
+ }
338
+
339
+ func (ctx * argsCompileContext ) WriteValues (values []interface {}, sep string ) {
340
+ if len (values ) == 0 {
341
+ return
342
+ }
343
+
344
+ ctx .WriteValue (values [0 ])
345
+
346
+ for _ , v := range values [1 :] {
347
+ ctx .WriteString (sep )
348
+ ctx .WriteValue (v )
349
+ }
350
+ }
0 commit comments