Skip to content

Commit 975bcfd

Browse files
committed
fix #169: support IS [NOT] DISTINCT FROM operator
1 parent fa64168 commit 975bcfd

File tree

3 files changed

+593
-332
lines changed

3 files changed

+593
-332
lines changed

args.go

Lines changed: 144 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type Args struct {
1616
// The default flavor used by `Args#Compile`
1717
Flavor Flavor
1818

19-
args []interface{}
19+
argValues []interface{}
2020
namedArgs map[string]int
2121
sqlNamedArgs map[string]int
2222
onlyNamed bool
@@ -47,7 +47,7 @@ func (args *Args) Add(arg interface{}) string {
4747
}
4848

4949
func (args *Args) add(arg interface{}) int {
50-
idx := len(args.args)
50+
idx := len(args.argValues)
5151

5252
switch a := arg.(type) {
5353
case sql.NamedArg:
@@ -56,7 +56,7 @@ func (args *Args) add(arg interface{}) int {
5656
}
5757

5858
if p, ok := args.sqlNamedArgs[a.Name]; ok {
59-
arg = args.args[p]
59+
arg = args.argValues[p]
6060
break
6161
}
6262

@@ -67,7 +67,7 @@ func (args *Args) add(arg interface{}) int {
6767
}
6868

6969
if p, ok := args.namedArgs[a.name]; ok {
70-
arg = args.args[p]
70+
arg = args.argValues[p]
7171
break
7272
}
7373

@@ -77,7 +77,7 @@ func (args *Args) add(arg interface{}) int {
7777
return idx
7878
}
7979

80-
args.args = append(args.args, arg)
80+
args.argValues = append(args.argValues, arg)
8181
return idx
8282
}
8383

@@ -97,55 +97,58 @@ func (args *Args) Compile(format string, initialValue ...interface{}) (query str
9797
//
9898
// See doc for `Compile` to learn details.
9999
func (args *Args) CompileWithFlavor(format string, flavor Flavor, initialValue ...interface{}) (query string, values []interface{}) {
100-
buf := newStringBuilder()
101100
idx := strings.IndexRune(format, '$')
102101
offset := 0
103-
values = initialValue
102+
ctx := &argsCompileContext{
103+
stringBuilder: newStringBuilder(),
104+
Flavor: flavor,
105+
Values: initialValue,
106+
}
104107

105-
if flavor == invalidFlavor {
106-
flavor = DefaultFlavor
108+
if ctx.Flavor == invalidFlavor {
109+
ctx.Flavor = DefaultFlavor
107110
}
108111

109112
for idx >= 0 && len(format) > 0 {
110113
if idx > 0 {
111-
buf.WriteString(format[:idx])
114+
ctx.WriteString(format[:idx])
112115
}
113116

114117
format = format[idx+1:]
115118

116119
// Treat the $ at the end of format is a normal $ rune.
117120
if len(format) == 0 {
118-
buf.WriteRune('$')
121+
ctx.WriteRune('$')
119122
break
120123
}
121124

122125
if r := format[0]; r == '$' {
123-
buf.WriteRune('$')
126+
ctx.WriteRune('$')
124127
format = format[1:]
125128
} else if r == '{' {
126-
format, values = args.compileNamed(buf, flavor, format, values)
129+
format = args.compileNamed(ctx, format)
127130
} else if !args.onlyNamed && '0' <= r && r <= '9' {
128-
format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
131+
format, offset = args.compileDigits(ctx, format, offset)
129132
} else if !args.onlyNamed && r == '?' {
130-
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
133+
format, offset = args.compileSuccessive(ctx, format[1:], offset)
131134
} else {
132135
// For unknown $ expression format, treat it as a normal $ rune.
133-
buf.WriteRune('$')
136+
ctx.WriteRune('$')
134137
}
135138

136139
idx = strings.IndexRune(format, '$')
137140
}
138141

139142
if len(format) > 0 {
140-
buf.WriteString(format)
143+
ctx.WriteString(format)
141144
}
142145

143-
query = buf.String()
144-
values = args.mergeSQLNamedArgs(values)
146+
query = ctx.String()
147+
values = args.mergeSQLNamedArgs(ctx)
145148
return
146149
}
147150

148-
func (args *Args) compileNamed(buf *stringBuilder, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
151+
func (args *Args) compileNamed(ctx *argsCompileContext, format string) string {
149152
i := 1
150153

151154
for ; i < len(format) && format[i] != '}'; i++ {
@@ -154,20 +157,20 @@ func (args *Args) compileNamed(buf *stringBuilder, flavor Flavor, format string,
154157

155158
// Invalid $ format. Ignore it.
156159
if i == len(format) {
157-
return format, values
160+
return format
158161
}
159162

160163
name := format[1:i]
161164
format = format[i+1:]
162165

163166
if p, ok := args.namedArgs[name]; ok {
164-
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
167+
format, _ = args.compileSuccessive(ctx, format, p)
165168
}
166169

167-
return format, values
170+
return format
168171
}
169172

170-
func (args *Args) compileDigits(buf *stringBuilder, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
173+
func (args *Args) compileDigits(ctx *argsCompileContext, format string, offset int) (string, int) {
171174
i := 1
172175

173176
for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
@@ -178,91 +181,37 @@ func (args *Args) compileDigits(buf *stringBuilder, flavor Flavor, format string
178181
format = format[i:]
179182

180183
if pointer, err := strconv.Atoi(digits); err == nil {
181-
return args.compileSuccessive(buf, flavor, format, values, pointer)
184+
return args.compileSuccessive(ctx, format, pointer)
182185
}
183186

184-
return format, values, offset
187+
return format, offset
185188
}
186189

187-
func (args *Args) compileSuccessive(buf *stringBuilder, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
188-
if offset >= len(args.args) {
189-
return format, values, offset
190+
func (args *Args) compileSuccessive(ctx *argsCompileContext, format string, offset int) (string, int) {
191+
if offset >= len(args.argValues) {
192+
return format, offset
190193
}
191194

192-
arg := args.args[offset]
193-
values = args.compileArg(buf, flavor, values, arg)
194-
195-
return format, values, offset + 1
196-
}
197-
198-
func (args *Args) compileArg(buf *stringBuilder, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
199-
switch a := arg.(type) {
200-
case Builder:
201-
var s string
202-
s, values = a.BuildWithFlavor(flavor, values...)
203-
buf.WriteString(s)
204-
case sql.NamedArg:
205-
buf.WriteRune('@')
206-
buf.WriteString(a.Name)
207-
case rawArgs:
208-
buf.WriteString(a.expr)
209-
case listArgs:
210-
if a.isTuple {
211-
buf.WriteRune('(')
212-
}
213-
214-
if len(a.args) > 0 {
215-
values = args.compileArg(buf, flavor, values, a.args[0])
216-
}
217-
218-
for i := 1; i < len(a.args); i++ {
219-
buf.WriteString(", ")
220-
values = args.compileArg(buf, flavor, values, a.args[i])
221-
}
222-
223-
if a.isTuple {
224-
buf.WriteRune(')')
225-
}
226-
227-
default:
228-
switch flavor {
229-
case MySQL, SQLite, CQL, ClickHouse, Presto, Informix:
230-
buf.WriteRune('?')
231-
case PostgreSQL:
232-
fmt.Fprintf(buf, "$%d", len(values)+1)
233-
case SQLServer:
234-
fmt.Fprintf(buf, "@p%d", len(values)+1)
235-
case Oracle:
236-
fmt.Fprintf(buf, ":%d", len(values)+1)
237-
default:
238-
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
239-
}
240-
241-
namedValues := parseNamedArgs(values)
195+
arg := args.argValues[offset]
196+
ctx.WriteValue(arg)
242197

243-
if n := len(namedValues); n == 0 {
244-
values = append(values, arg)
245-
} else {
246-
index := len(values) - n
247-
values = append(values[:index+1], namedValues...)
248-
values[index] = arg
249-
}
250-
}
251-
252-
return values
198+
return format, offset + 1
253199
}
254200

255-
func (args *Args) mergeSQLNamedArgs(values []interface{}) []interface{} {
256-
if len(args.sqlNamedArgs) == 0 {
257-
return values
201+
func (args *Args) mergeSQLNamedArgs(ctx *argsCompileContext) []interface{} {
202+
if len(args.sqlNamedArgs) == 0 && len(ctx.NamedArgs) == 0 {
203+
return ctx.Values
258204
}
259205

260-
namedValues := parseNamedArgs(values)
261-
existingNames := make(map[string]struct{}, len(namedValues))
206+
values := ctx.Values
207+
existingNames := make(map[string]struct{}, len(ctx.NamedArgs))
262208

263-
for _, v := range namedValues {
264-
if a, ok := v.(sql.NamedArg); ok {
265-
existingNames[a.Name] = struct{}{}
209+
// Add all named args to values.
210+
// Remove duplicated named args in this step.
211+
for _, arg := range ctx.NamedArgs {
212+
if _, ok := existingNames[arg.Name]; !ok {
213+
existingNames[arg.Name] = struct{}{}
214+
values = append(values, arg)
266215
}
267216
}
268217

@@ -280,19 +229,21 @@ func (args *Args) mergeSQLNamedArgs(values []interface{}) []interface{} {
280229
sort.Ints(ints)
281230

282231
for _, i := range ints {
283-
values = append(values, args.args[i])
232+
values = append(values, args.argValues[i])
284233
}
285234

286235
return values
287236
}
288237

289-
func parseNamedArgs(initialValue []interface{}) (namedValues []interface{}) {
238+
func parseNamedArgs(initialValue []interface{}) (values []interface{}, namedValues []sql.NamedArg) {
290239
if len(initialValue) == 0 {
291-
return nil
240+
values = initialValue
241+
return
292242
}
293243

294244
// sql.NamedArgs must be placed at the end of the initial value.
295-
i := len(initialValue)
245+
size := len(initialValue)
246+
i := size
296247

297248
for ; i > 0; i-- {
298249
switch initialValue[i-1].(type) {
@@ -303,6 +254,97 @@ func parseNamedArgs(initialValue []interface{}) (namedValues []interface{}) {
303254
break
304255
}
305256

306-
namedValues = initialValue[i:]
257+
if i == size {
258+
values = initialValue
259+
return
260+
}
261+
262+
values = initialValue[:i]
263+
namedValues = make([]sql.NamedArg, 0, size-i)
264+
265+
for ; i < size; i++ {
266+
namedValues = append(namedValues, initialValue[i].(sql.NamedArg))
267+
}
268+
307269
return
308270
}
271+
272+
type argsCompileContext struct {
273+
*stringBuilder
274+
275+
Flavor Flavor
276+
Values []interface{}
277+
NamedArgs []sql.NamedArg
278+
}
279+
280+
func (ctx *argsCompileContext) WriteValue(arg interface{}) {
281+
switch a := arg.(type) {
282+
case Builder:
283+
s, values := a.BuildWithFlavor(ctx.Flavor, ctx.Values...)
284+
ctx.WriteString(s)
285+
286+
// Add all values to ctx.
287+
// Named args must be located at the end of values.
288+
values, namedArgs := parseNamedArgs(values)
289+
ctx.Values = values
290+
ctx.NamedArgs = append(ctx.NamedArgs, namedArgs...)
291+
292+
case sql.NamedArg:
293+
ctx.WriteRune('@')
294+
ctx.WriteString(a.Name)
295+
ctx.NamedArgs = append(ctx.NamedArgs, a)
296+
297+
case rawArgs:
298+
ctx.WriteString(a.expr)
299+
300+
case listArgs:
301+
if a.isTuple {
302+
ctx.WriteRune('(')
303+
}
304+
305+
if len(a.args) > 0 {
306+
ctx.WriteValue(a.args[0])
307+
}
308+
309+
for i := 1; i < len(a.args); i++ {
310+
ctx.WriteString(", ")
311+
ctx.WriteValue(a.args[i])
312+
}
313+
314+
if a.isTuple {
315+
ctx.WriteRune(')')
316+
}
317+
318+
case condBuilder:
319+
a.Builder(ctx)
320+
321+
default:
322+
switch ctx.Flavor {
323+
case MySQL, SQLite, CQL, ClickHouse, Presto, Informix:
324+
ctx.WriteRune('?')
325+
case PostgreSQL:
326+
fmt.Fprintf(ctx, "$%d", len(ctx.Values)+1)
327+
case SQLServer:
328+
fmt.Fprintf(ctx, "@p%d", len(ctx.Values)+1)
329+
case Oracle:
330+
fmt.Fprintf(ctx, ":%d", len(ctx.Values)+1)
331+
default:
332+
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", ctx.Flavor, int(ctx.Flavor)))
333+
}
334+
335+
ctx.Values = append(ctx.Values, arg)
336+
}
337+
}
338+
339+
func (ctx *argsCompileContext) WriteValues(values []interface{}, sep string) {
340+
if len(values) == 0 {
341+
return
342+
}
343+
344+
ctx.WriteValue(values[0])
345+
346+
for _, v := range values[1:] {
347+
ctx.WriteString(sep)
348+
ctx.WriteValue(v)
349+
}
350+
}

0 commit comments

Comments
 (0)