Skip to content

Commit 9f96444

Browse files
authored
refactor: Refactor query condition,support nested query (#54)
1 parent 4839bc3 commit 9f96444

File tree

9 files changed

+360
-194
lines changed

9 files changed

+360
-194
lines changed

constants/keyword.go

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,26 @@
1818
package constants
1919

2020
const (
21-
And = "AND"
22-
Or = "OR"
23-
In = "IN"
24-
Not = "NOT"
25-
Like = "LIKE"
26-
Eq = "="
27-
Ne = "<>"
28-
Gt = ">"
29-
Ge = ">="
30-
Lt = "<"
31-
Le = "<="
32-
Between = "BETWEEN"
33-
Desc = "DESC"
34-
Asc = "ASC"
35-
As = "AS"
36-
SUM = "SUM"
37-
AVG = "AVG"
38-
MAX = "MAX"
39-
MIN = "MIN"
40-
COUNT = "COUNT"
21+
And = "AND"
22+
Or = "OR"
23+
In = "IN"
24+
Not = "NOT"
25+
Like = "LIKE"
26+
Eq = "="
27+
Ne = "<>"
28+
Gt = ">"
29+
Ge = ">="
30+
Lt = "<"
31+
Le = "<="
32+
IsNull = "IS NULL"
33+
IsNotNull = "IS NOT NULL"
34+
Between = "BETWEEN"
35+
Desc = "DESC"
36+
Asc = "ASC"
37+
As = "AS"
38+
SUM = "SUM"
39+
AVG = "AVG"
40+
MAX = "MAX"
41+
MIN = "MIN"
42+
COUNT = "COUNT"
4143
)

gplus/dao.go

Lines changed: 63 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,16 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
1817
package gplus
1918

2019
import (
2120
"database/sql"
22-
"reflect"
23-
2421
"github.com/acmestack/gorm-plus/constants"
2522
"gorm.io/gorm"
2623
"gorm.io/gorm/schema"
2724
"gorm.io/gorm/utils"
25+
"reflect"
26+
"strings"
2827
)
2928

3029
var globalDb *gorm.DB
@@ -45,8 +44,7 @@ type Page[T any] struct {
4544
type Dao[T any] struct{}
4645

4746
func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) {
48-
q := &QueryCond[T]{}
49-
return q, nil
47+
return NewQuery[T]()
5048
}
5149

5250
func NewPage[T any](current, size int) *Page[T] {
@@ -268,8 +266,11 @@ func paginate[T any](p *Page[T]) func(db *gorm.DB) *gorm.DB {
268266

269267
func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB {
270268
db := getDb(opts...)
269+
// 这里清空参数,避免用户重复使用一个query条件
270+
q.queryArgs = make([]any, 0)
271271
resultDb := db.Model(new(T))
272272
if q != nil {
273+
273274
if len(q.distinctColumns) > 0 {
274275
resultDb.Distinct(q.distinctColumns)
275276
}
@@ -278,19 +279,11 @@ func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB {
278279
resultDb.Select(q.selectColumns)
279280
}
280281

281-
if q.queryBuilder.Len() > 0 {
282-
283-
if q.andNestBuilder.Len() > 0 {
284-
q.queryArgs = append(q.queryArgs, q.andNestArgs...)
285-
q.queryBuilder.WriteString(q.andNestBuilder.String())
286-
}
287-
288-
if q.orNestBuilder.Len() > 0 {
289-
q.queryArgs = append(q.queryArgs, q.orNestArgs...)
290-
q.queryBuilder.WriteString(q.orNestBuilder.String())
291-
}
292-
293-
resultDb.Where(q.queryBuilder.String(), q.queryArgs...)
282+
expressions := q.queryExpressions
283+
if len(expressions) > 0 {
284+
var sqlBuilder strings.Builder
285+
q.queryArgs = buildSqlAndArgs[T](expressions, &sqlBuilder, q.queryArgs)
286+
resultDb.Where(sqlBuilder.String(), q.queryArgs...)
294287
}
295288

296289
if q.orderBuilder.Len() > 0 {
@@ -316,29 +309,31 @@ func buildCondition[T any](q *QueryCond[T], opts ...OptionFunc) *gorm.DB {
316309
return resultDb
317310
}
318311

319-
func getPkColumnName[T any]() string {
320-
var entity T
321-
entityType := reflect.TypeOf(entity)
322-
numField := entityType.NumField()
323-
var columnName string
324-
for i := 0; i < numField; i++ {
325-
field := entityType.Field(i)
326-
tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";")
327-
isPrimaryKey := utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"])
328-
if isPrimaryKey {
329-
name, ok := tagSetting["COLUMN"]
330-
if !ok {
331-
namingStrategy := schema.NamingStrategy{}
332-
name = namingStrategy.ColumnName("", field.Name)
312+
func buildSqlAndArgs[T any](expressions []any, sqlBuilder *strings.Builder, queryArgs []any) []any {
313+
for _, v := range expressions {
314+
// 判断是否是columnValue类型
315+
switch segment := v.(type) {
316+
case *columnPointer:
317+
sqlBuilder.WriteString(segment.getSqlSegment() + " ")
318+
case *sqlKeyword:
319+
sqlBuilder.WriteString(segment.getSqlSegment() + " ")
320+
case *columnValue:
321+
if segment.value == constants.And {
322+
sqlBuilder.WriteString(segment.value.(string) + " ")
323+
continue
333324
}
334-
columnName = name
335-
break
325+
if segment.value != "" {
326+
sqlBuilder.WriteString("? ")
327+
queryArgs = append(queryArgs, segment.value)
328+
}
329+
case *QueryCond[T]:
330+
sqlBuilder.WriteString(constants.LeftBracket + " ")
331+
// 递归处理条件
332+
queryArgs = buildSqlAndArgs[T](segment.queryExpressions, sqlBuilder, queryArgs)
333+
sqlBuilder.WriteString(constants.RightBracket + " ")
336334
}
337335
}
338-
if columnName == "" {
339-
return constants.DefaultPrimaryName
340-
}
341-
return columnName
336+
return queryArgs
342337
}
343338

344339
func getDb(opts ...OptionFunc) *gorm.DB {
@@ -359,6 +354,14 @@ func getDb(opts ...OptionFunc) *gorm.DB {
359354
return db
360355
}
361356

357+
func getOption(opts []OptionFunc) Option {
358+
var config Option
359+
for _, op := range opts {
360+
op(&config)
361+
}
362+
return config
363+
}
364+
362365
func setSelectIfNeed(option Option, db *gorm.DB) {
363366
if len(option.Selects) > 0 {
364367
var columnNames []string
@@ -381,10 +384,27 @@ func setOmitIfNeed(option Option, db *gorm.DB) {
381384
}
382385
}
383386

384-
func getOption(opts []OptionFunc) Option {
385-
var config Option
386-
for _, op := range opts {
387-
op(&config)
387+
func getPkColumnName[T any]() string {
388+
var entity T
389+
entityType := reflect.TypeOf(entity)
390+
numField := entityType.NumField()
391+
var columnName string
392+
for i := 0; i < numField; i++ {
393+
field := entityType.Field(i)
394+
tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";")
395+
isPrimaryKey := utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"])
396+
if isPrimaryKey {
397+
name, ok := tagSetting["COLUMN"]
398+
if !ok {
399+
namingStrategy := schema.NamingStrategy{}
400+
name = namingStrategy.ColumnName("", field.Name)
401+
}
402+
columnName = name
403+
break
404+
}
388405
}
389-
return config
406+
if columnName == "" {
407+
return constants.DefaultPrimaryName
408+
}
409+
return columnName
390410
}

gplus/function.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ func (f *Function) NotIn(values ...any) (string, []any) {
6767
}
6868

6969
func (f *Function) Between(start int64, end int64) (string, int64, int64) {
70-
return f.funStr + " " + constants.Between + " ? and ?", start, end
70+
return f.funStr + " " + constants.Between + " ? " + constants.And + " ?", start, end
7171
}
7272

7373
func (f *Function) NotBetween(start int64, end int64) (string, int64, int64) {
74-
return f.funStr + " " + constants.Not + " " + constants.Between + " ? and ?", start, end
74+
return f.funStr + " " + constants.Not + " " + constants.Between + " ? " + constants.And + " ?", start, end
7575
}
7676

7777
func Sum(columnName any) *Function {

0 commit comments

Comments
 (0)