Skip to content

Commit a2bb7a1

Browse files
committed
MySQL: 新增支持 slice 的方式传参,以支持 in 子句。
1 parent a8a258b commit a2bb7a1

File tree

3 files changed

+167
-2
lines changed

3 files changed

+167
-2
lines changed

mysql/mysql_db_client.go

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ func bindArgs(sqlText string, args ...any) (string, []any, error) {
121121
return "", nil, fmt.Errorf("%w:\nlack of parameter\nsql = %s", sqlmer.ErrParseParamFailed, namedParsedResult.Sql)
122122
}
123123
}
124+
namedParsedResult.Sql, resultArgs = extendInParams(namedParsedResult.Sql, resultArgs)
124125
return namedParsedResult.Sql, resultArgs, nil
125126
}
126127

127128
// slice 语句中使用的顺序未必是递增的,切可能重复引用同一个索引,所以这里也需要整理顺序。
128129
for _, paramName := range namedParsedResult.Names {
129-
130130
if paramName[0] != 'p' { // 要求数字参数的格式为 @p1....@pN
131131
return "", nil, fmt.Errorf("%w: parsing parameter failed\nsql = %s", sqlmer.ErrParseParamFailed, namedParsedResult.Sql)
132132
}
@@ -149,6 +149,7 @@ func bindArgs(sqlText string, args ...any) (string, []any, error) {
149149
resultArgs = append(resultArgs, args[index])
150150
}
151151

152+
namedParsedResult.Sql, resultArgs = extendInParams(namedParsedResult.Sql, resultArgs)
152153
return namedParsedResult.Sql, resultArgs, nil
153154
}
154155

@@ -245,6 +246,67 @@ func parseMySqlNamedSql(sqlText string) *mysqlNamedParsedResult {
245246
return parsedResult
246247
}
247248

249+
// extendInParams 用于处理 SQL IN 子句的参数展开
250+
// 将切片类型的参数展开为多个问号占位符
251+
func extendInParams(sqlText string, params []any) (string, []any) {
252+
var newParams []any = make([]any, 0, len(params))
253+
var newSqlBuilder strings.Builder
254+
255+
paramIndex := 0
256+
inString := false
257+
for _, r := range sqlText {
258+
if r == '\'' {
259+
inString = !inString
260+
newSqlBuilder.WriteRune(r)
261+
continue
262+
}
263+
264+
if inString {
265+
newSqlBuilder.WriteRune(r)
266+
continue
267+
}
268+
269+
if r != '?' {
270+
newSqlBuilder.WriteRune(r)
271+
continue
272+
}
273+
274+
if paramIndex >= len(params) { // 后面没参数了,无需判断了。
275+
newSqlBuilder.WriteRune(r)
276+
continue
277+
}
278+
279+
param := params[paramIndex]
280+
paramValue := reflect.ValueOf(param)
281+
paramIndex++
282+
283+
// 处理切片类型。
284+
// 排除 []byte,因为虽然 []byte 也是切片类型,但是它是二进制数据,不应该被展开。
285+
if paramValue.Kind() == reflect.Slice && paramValue.Type() != reflect.TypeOf([]byte{}) {
286+
paramLen := paramValue.Len()
287+
if paramLen == 0 {
288+
// 空切片替换为 SQL 不可能条件。
289+
newSqlBuilder.WriteString("NULL")
290+
continue
291+
}
292+
293+
// 生成占位符。
294+
placeholders := strings.Repeat(",?", paramLen)
295+
newSqlBuilder.WriteString(placeholders[1:])
296+
297+
// 展开参数。
298+
for i := 0; i < paramLen; i++ {
299+
newParams = append(newParams, paramValue.Index(i).Interface())
300+
}
301+
} else {
302+
newSqlBuilder.WriteRune('?')
303+
newParams = append(newParams, param)
304+
}
305+
}
306+
307+
return newSqlBuilder.String(), newParams
308+
}
309+
248310
// getScanTypeFn 根据驱动配置返回一个可以正确获取 Scan 类型的函数。
249311
func getScanTypeFn(cfg *mysqlDriver.Config) sqlen.GetScanTypeFunc {
250312
var scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})

mysql/mysql_db_client_test.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func Test_internalDbClient_Scalar(t *testing.T) {
5252
wantErr bool
5353
}{
5454
{
55-
"mysql",
55+
"mysql1",
5656
mysqlClient,
5757
args{
5858
"SELECT Id FROM go_TypeTest WHERE id=@p1",
@@ -61,6 +61,26 @@ func Test_internalDbClient_Scalar(t *testing.T) {
6161
int64(1),
6262
false,
6363
},
64+
{
65+
"mysql2",
66+
mysqlClient,
67+
args{
68+
"SELECT Id FROM go_TypeTest WHERE id in (@p1)",
69+
[]any{[]int{1}},
70+
},
71+
int64(1),
72+
false,
73+
},
74+
{
75+
"mysql3",
76+
mysqlClient,
77+
args{
78+
"SELECT COUNT(1) FROM go_TypeTest WHERE id in (@p1)",
79+
[]any{[]int{1, 2, 3}},
80+
},
81+
int64(3),
82+
false,
83+
},
6484
}
6585
for _, tt := range tests {
6686
t.Run(tt.name, func(t *testing.T) {

mysql/mysql_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,64 @@ func Test_parseMySqlNamedSql(t *testing.T) {
4343
}
4444
}
4545

46+
func TestExtendInParams(t *testing.T) {
47+
tests := []struct {
48+
name string
49+
sql string
50+
params []any
51+
expSQL string
52+
expParams []any
53+
}{
54+
{
55+
"single",
56+
"select 1 from t where id = ?",
57+
[]any{1},
58+
"select 1 from t where id = ?",
59+
[]any{1},
60+
},
61+
{
62+
"slice1",
63+
"select 1 from t where id in (?)",
64+
[]any{[]int{1, 2, 3}},
65+
"select 1 from t where id in (?,?,?)",
66+
[]any{1, 2, 3},
67+
},
68+
{
69+
"singleWithSlice",
70+
"select 1 from t where id!=? AND id in (?)",
71+
[]any{5, []int{1, 2, 3}},
72+
"select 1 from t where id!=? AND id in (?,?,?)",
73+
[]any{5, 1, 2, 3},
74+
},
75+
{
76+
"empty",
77+
"select 1 from t where id in (?)",
78+
[]any{[]int{}},
79+
"select 1 from t where id in (NULL)",
80+
[]any{},
81+
},
82+
{
83+
"regular",
84+
"select 1 from t where name = ? and age = ?",
85+
[]any{"Alice", 30},
86+
"select 1 from t where name = ? and age = ?",
87+
[]any{"Alice", 30},
88+
},
89+
}
90+
91+
for _, tt := range tests {
92+
t.Run(tt.name, func(t *testing.T) {
93+
gotSQL, gotParams := extendInParams(tt.sql, tt.params)
94+
if gotSQL != tt.expSQL {
95+
t.Errorf("Expected SQL: %s, got: %s", tt.expSQL, gotSQL)
96+
}
97+
if !reflect.DeepEqual(gotParams, tt.expParams) {
98+
t.Errorf("Expected Params: %v, got: %v", tt.expParams, gotParams)
99+
}
100+
})
101+
}
102+
}
103+
46104
func Test_bindMySqlArgs(t *testing.T) {
47105
testCases := []struct {
48106
name string
@@ -157,6 +215,31 @@ func Test_bindMySqlArgs(t *testing.T) {
157215
[]any{3, 3},
158216
nil,
159217
},
218+
{
219+
"inwhere1",
220+
"SELECT * FROM go_TypeTest WHERE id IN (@ids)",
221+
[]any{
222+
map[string]any{
223+
"ids": []int{1, 2, 3},
224+
},
225+
},
226+
"SELECT * FROM go_TypeTest WHERE id IN (?,?,?)",
227+
[]any{1, 2, 3},
228+
nil,
229+
},
230+
{
231+
"inwhere2",
232+
"SELECT * FROM go_TypeTest WHERE id!=@noid AND id IN (@ids)",
233+
[]any{
234+
map[string]any{
235+
"noid": 4,
236+
"ids": []int{1, 2, 3},
237+
},
238+
},
239+
"SELECT * FROM go_TypeTest WHERE id!=? AND id IN (?,?,?)",
240+
[]any{4, 1, 2, 3},
241+
nil,
242+
},
160243
}
161244

162245
for _, tt := range testCases {

0 commit comments

Comments
 (0)