Skip to content

Commit 0ea6f99

Browse files
committed
增加获得总记录数支持
1 parent a82c5d6 commit 0ea6f99

File tree

3 files changed

+113
-15
lines changed

3 files changed

+113
-15
lines changed

builder.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ func (b *builder) Page(page, pageSize int) *builder {
2626
return b
2727
}
2828

29+
//分页
30+
//page 页码
31+
//pageSize 分页大小
32+
func (b *builder) PageWithTotal(page, pageSize int) *builder {
33+
b.ctx = StartPageWithTotal(page, pageSize, b.ctx)
34+
return b
35+
}
36+
2937
//手动指定字段和排序
3038
//field 字段
3139
//order 排序 [ASC | DESC]

pagehelper.go

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package pagehelper
1111
import (
1212
"context"
1313
"fmt"
14+
"github.com/xfali/gobatis"
1415
"github.com/xfali/gobatis/common"
1516
"github.com/xfali/gobatis/executor"
1617
"github.com/xfali/gobatis/factory"
@@ -24,11 +25,18 @@ import (
2425
const (
2526
pageHelperValue = "_page_helper_value"
2627
orderHelperValue = "_order_helper_value"
28+
totalHelperValue = "_total_helper_value"
2729

2830
ASC = "ASC"
2931
DESC = "DESC"
3032
)
3133

34+
var (
35+
OrderByModifier = modifyOrderSql
36+
PageModifier = modifyPageSql
37+
CountModifier = modifyCountSql
38+
)
39+
3240
type OrderParam struct {
3341
Field string
3442
Order string
@@ -37,12 +45,27 @@ type OrderParam struct {
3745
type PageParam struct {
3846
Page int
3947
PageSize int
48+
49+
total bool
4050
}
4151

4252
func New(f factory.Factory) *Factory {
4353
return &Factory{f}
4454
}
4555

56+
func GetTotal(ctx context.Context) int64 {
57+
if ctx == nil {
58+
return 0
59+
}
60+
p := ctx.Value(totalHelperValue)
61+
if p != nil {
62+
if t, ok := p.(int64); ok {
63+
return t
64+
}
65+
}
66+
return 0
67+
}
68+
4669
type Factory struct {
4770
fac factory.Factory
4871
}
@@ -57,7 +80,15 @@ type Executor struct {
5780
//pageSize 分页大小
5881
//ctx 初始context
5982
func StartPage(page, pageSize int, ctx context.Context) context.Context {
60-
return context.WithValue(ctx, pageHelperValue, &PageParam{Page: page, PageSize: pageSize})
83+
return context.WithValue(ctx, pageHelperValue, &PageParam{Page: page, PageSize: pageSize, total: false})
84+
}
85+
86+
//分页(包含total信息)
87+
//page 页码
88+
//pageSize 分页大小
89+
//ctx 初始context
90+
func StartPageWithTotal(page, pageSize int, ctx context.Context) context.Context {
91+
return context.WithValue(ctx, pageHelperValue, &PageParam{Page: page, PageSize: pageSize, total: true})
6192
}
6293

6394
//排序
@@ -89,17 +120,22 @@ func (exec *Executor) Rollback(require bool) error {
89120
}
90121

91122
func (exec *Executor) Query(ctx context.Context, result reflection.Object, sql string, params ...interface{}) error {
123+
originSql := sql
92124
o := ctx.Value(orderHelperValue)
93125
if o != nil {
94126
if param, ok := o.(*OrderParam); ok {
95-
sql = modifySqlOrder(sql, param)
127+
sql = OrderByModifier(sql, param)
96128
}
97129
}
98130

99131
p := ctx.Value(pageHelperValue)
100132
if p != nil {
101133
if param, ok := p.(*PageParam); ok {
102-
sql = modifySql(sql, param)
134+
sql = PageModifier(sql, param)
135+
if param.total == true {
136+
total := exec.getTotal(ctx, originSql, params...)
137+
ctx = context.WithValue(ctx, totalHelperValue, total)
138+
}
103139
}
104140
}
105141
exec.log(logging.DEBUG, "PageHelper Query: [%s], params: %s\n", sql, fmt.Sprint(params))
@@ -130,7 +166,7 @@ func (f *Factory) CreateExecutor(transaction transaction.Transaction) executor.E
130166
}
131167
}
132168

133-
func modifySqlOrder(sql string, p *OrderParam) string {
169+
func modifyOrderSql(sql string, p *OrderParam) string {
134170
if p.Field == "" {
135171
return sql
136172
}
@@ -140,9 +176,28 @@ func modifySqlOrder(sql string, p *OrderParam) string {
140176
return b.String()
141177
}
142178

143-
func modifySql(sql string, p *PageParam) string {
179+
func modifyPageSql(sql string, p *PageParam) string {
144180
b := strings.Builder{}
145181
b.WriteString(strings.TrimSpace(sql))
146182
b.WriteString(fmt.Sprintf(" LIMIT %d, %d ", p.Page*p.PageSize, p.PageSize))
147183
return b.String()
148184
}
185+
186+
func (exec *Executor) getTotal(ctx context.Context, sql string, params ...interface{}) int64 {
187+
totalSql := CountModifier(sql)
188+
var total int64
189+
obj, err := gobatis.ParseObject(&total)
190+
if err == nil {
191+
exec.exec.Query(ctx, obj, totalSql, params...)
192+
return total
193+
}
194+
return 0
195+
}
196+
197+
func modifyCountSql(sql string) string {
198+
b := strings.Builder{}
199+
b.WriteString("SELECT COUNT(0) FROM (")
200+
b.WriteString(sql)
201+
b.WriteString(")")
202+
return b.String()
203+
}

pagehelper_test.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,40 +124,75 @@ func TestPageHelper2(t *testing.T) {
124124
}
125125

126126
func TestModifyPage(t *testing.T) {
127-
sql := modifySql("select * from x", &PageParam{1, 2})
127+
sql := PageModifier("select * from x", &PageParam{1, 2, false})
128128
t.Log(sql)
129+
if strings.TrimSpace(sql) != `select * from x LIMIT 2, 2` {
130+
t.Fail()
131+
}
129132
}
130133

131134
func order(sql string, params ...interface{}) (string, []interface{}) {
132-
return modifySqlOrder(sql, &OrderParam{"test", ASC}), params
135+
return OrderByModifier(sql, &OrderParam{"test", ASC}), params
133136
}
134137

135138
func TestModifyOrder(t *testing.T) {
136139
sql, p := order("select ? from x", "field1")
137140
t.Log(sql)
138-
if len(p) != 2 {
139-
t.Fatal()
140-
}
141141
for _, v := range p {
142142
t.Log(v)
143143
}
144+
145+
if strings.TrimSpace(sql) != "select ? from x ORDER BY `test` ASC" {
146+
t.Fail()
147+
}
148+
}
149+
150+
func TestModifyCount(t *testing.T) {
151+
sql := CountModifier("select ? from x")
152+
t.Log(sql)
153+
154+
if strings.TrimSpace(sql) != "SELECT COUNT(0) FROM (select ? from x)" {
155+
t.Fail()
156+
}
157+
}
158+
159+
160+
func TestChangeModifyCount(t *testing.T) {
161+
CountModifier = func(sql string) string {
162+
return "test " + sql
163+
}
164+
sql := CountModifier("select ? from x")
165+
t.Log(sql)
166+
167+
if strings.TrimSpace(sql) != "test select ? from x" {
168+
t.Fail()
169+
}
170+
}
171+
172+
func TestGetTotal(t *testing.T) {
173+
ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
174+
ctx = OrderBy("test", ASC, ctx)
175+
ctx = StartPage(1, 2, ctx)
176+
177+
total := GetTotal(ctx)
178+
t.Log(total)
179+
if total != 0 {
180+
t.Fail()
181+
}
144182
}
145183

146184
func TestModifyOrderAndPage(t *testing.T) {
147185
sql, p := order("select ? from x", "field1")
148186
t.Log(sql)
149-
if len(p) != 2 {
150-
t.Fatal()
151-
}
152187

153-
sql = modifySql(sql, &PageParam{1, 2})
188+
sql = PageModifier(sql, &PageParam{1, 2, false})
154189

155190
t.Log(sql)
156191
for _, v := range p {
157192
t.Log(v)
158193
}
159194

160-
if strings.TrimSpace(sql) != "select ? from x ORDER BY ? ASC LIMIT 2, 2" {
195+
if strings.TrimSpace(sql) != "select ? from x ORDER BY `test` ASC LIMIT 2, 2" {
161196
t.Fail()
162197
}
163198
}

0 commit comments

Comments
 (0)