@@ -25,7 +25,6 @@ import (
25
25
const (
26
26
pageHelperValue = "_page_helper_value"
27
27
orderHelperValue = "_order_helper_value"
28
- totalHelperValue = "_total_helper_value"
29
28
30
29
ASC = "ASC"
31
30
DESC = "DESC"
@@ -46,7 +45,8 @@ type PageParam struct {
46
45
Page int
47
46
PageSize int
48
47
49
- total bool
48
+ countColumn string
49
+ total int64
50
50
}
51
51
52
52
func New (f factory.Factory ) * Factory {
@@ -57,10 +57,10 @@ func GetTotal(ctx context.Context) int64 {
57
57
if ctx == nil {
58
58
return 0
59
59
}
60
- p := ctx .Value (totalHelperValue )
60
+ p := ctx .Value (pageHelperValue )
61
61
if p != nil {
62
- if t , ok := p .(int64 ); ok {
63
- return t
62
+ if param , ok := p .(* PageParam ); ok {
63
+ return param . total
64
64
}
65
65
}
66
66
return 0
@@ -80,15 +80,15 @@ type Executor struct {
80
80
//pageSize 分页大小
81
81
//ctx 初始context
82
82
func StartPage (page , pageSize int , ctx context.Context ) context.Context {
83
- return context .WithValue (ctx , pageHelperValue , & PageParam {Page : page , PageSize : pageSize , total : false })
83
+ return context .WithValue (ctx , pageHelperValue , & PageParam {Page : page , PageSize : pageSize , total : 0 })
84
84
}
85
85
86
86
//分页(包含total信息)
87
87
//page 页码
88
88
//pageSize 分页大小
89
89
//ctx 初始context
90
- func StartPageWithTotal (page , pageSize int , ctx context.Context ) context.Context {
91
- return context .WithValue (ctx , pageHelperValue , & PageParam {Page : page , PageSize : pageSize , total : true })
90
+ func StartPageWithTotal (page , pageSize int , countColumn string , ctx context.Context ) context.Context {
91
+ return context .WithValue (ctx , pageHelperValue , & PageParam {Page : page , PageSize : pageSize , total : - 1 , countColumn : countColumn })
92
92
}
93
93
94
94
//排序
@@ -132,9 +132,8 @@ func (exec *Executor) Query(ctx context.Context, result reflection.Object, sql s
132
132
if p != nil {
133
133
if param , ok := p .(* PageParam ); ok {
134
134
sql = PageModifier (sql , param )
135
- if param .total == true {
136
- total := exec .getTotal (ctx , originSql , params ... )
137
- ctx = context .WithValue (ctx , totalHelperValue , total )
135
+ if param .total == - 1 {
136
+ param .total = exec .getTotal (ctx , originSql , param .countColumn , params ... )
138
137
}
139
138
}
140
139
}
@@ -183,8 +182,8 @@ func modifyPageSql(sql string, p *PageParam) string {
183
182
return b .String ()
184
183
}
185
184
186
- func (exec * Executor ) getTotal (ctx context.Context , sql string , params ... interface {}) int64 {
187
- totalSql := CountModifier (sql )
185
+ func (exec * Executor ) getTotal (ctx context.Context , sql , countColumn string , params ... interface {}) int64 {
186
+ totalSql := CountModifier (sql , countColumn )
188
187
var total int64
189
188
obj , err := gobatis .ParseObject (& total )
190
189
if err == nil {
@@ -194,10 +193,15 @@ func (exec *Executor) getTotal(ctx context.Context, sql string, params ...interf
194
193
return 0
195
194
}
196
195
197
- func modifyCountSql (sql string ) string {
196
+ func modifyCountSql (sql , countColumn string ) string {
197
+ if countColumn == "" {
198
+ countColumn = "0"
199
+ }
198
200
b := strings.Builder {}
199
- b .WriteString ("SELECT COUNT(0) FROM (" )
200
- b .WriteString (sql )
201
- b .WriteString (")" )
201
+ b .WriteString ("SELECT COUNT(`" )
202
+ b .WriteString (countColumn )
203
+ b .WriteString ("`) FROM (" )
204
+ b .WriteString (strings .TrimSpace (sql ))
205
+ b .WriteString (") AS __hp_tempCountTl" )
202
206
return b .String ()
203
207
}
0 commit comments