Skip to content

Commit 61e8328

Browse files
committed
fix test
1 parent 19d9d35 commit 61e8328

File tree

2 files changed

+301
-0
lines changed

2 files changed

+301
-0
lines changed

pkg/mysql/crud.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package mysql
2+
3+
import (
4+
"context"
5+
6+
"github.com/zhufuyi/sponge/pkg/mysql/query"
7+
8+
"gorm.io/gorm"
9+
)
10+
11+
// TableName get table name
12+
func TableName(table interface{}) string {
13+
return GetTableName(table)
14+
}
15+
16+
// Create a new record
17+
// the param of 'table' must be pointer, eg: &StructName
18+
func Create(ctx context.Context, db *gorm.DB, table interface{}) error {
19+
return db.WithContext(ctx).Create(table).Error
20+
}
21+
22+
// Delete record
23+
// the param of 'table' must be pointer, eg: &StructName
24+
func Delete(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) error {
25+
return db.WithContext(ctx).Where(queryCondition, args...).Delete(table).Error
26+
}
27+
28+
// DeleteByID delete record by id
29+
// the param of 'table' must be pointer, eg: &StructName
30+
func DeleteByID(ctx context.Context, db *gorm.DB, table interface{}, id interface{}) error {
31+
return db.WithContext(ctx).Where("id = ?", id).Delete(table).Error
32+
}
33+
34+
// Update record
35+
// the param of 'table' must be pointer, eg: &StructName
36+
func Update(ctx context.Context, db *gorm.DB, table interface{}, column string, value interface{}, queryCondition interface{}, args ...interface{}) error {
37+
return db.WithContext(ctx).Model(table).Where(queryCondition, args...).Update(column, value).Error
38+
}
39+
40+
// Updates record
41+
// the param of 'table' must be pointer, eg: &StructName
42+
func Updates(ctx context.Context, db *gorm.DB, table interface{}, update KV, queryCondition interface{}, args ...interface{}) error {
43+
return db.WithContext(ctx).Model(table).Where(queryCondition, args...).Updates(update).Error
44+
}
45+
46+
// Get one record
47+
// the param of 'table' must be pointer, eg: &StructName
48+
func Get(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) error {
49+
return db.WithContext(ctx).Where(queryCondition, args...).First(table).Error
50+
}
51+
52+
// GetByID get record by id
53+
func GetByID(ctx context.Context, db *gorm.DB, table interface{}, id interface{}) error {
54+
return db.WithContext(ctx).Where("id = ?", id).First(table).Error
55+
}
56+
57+
// List multiple records, starting from page 0
58+
// the param of 'tables' must be a slice, eg: []StructName
59+
func List(ctx context.Context, db *gorm.DB, tables interface{}, page *query.Page, queryCondition interface{}, args ...interface{}) error {
60+
return db.WithContext(ctx).Order(page.Sort()).Limit(page.Size()).Offset(page.Offset()).Where(queryCondition, args...).Find(tables).Error
61+
}
62+
63+
// Count number of records
64+
// the param of 'table' must be pointer, eg: &StructName
65+
func Count(ctx context.Context, db *gorm.DB, table interface{}, queryCondition interface{}, args ...interface{}) (int64, error) {
66+
var count int64
67+
err := db.WithContext(ctx).Model(table).Where(queryCondition, args...).Count(&count).Error
68+
return count, err
69+
}

pkg/mysql/crud_test.go

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package mysql
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
"time"
7+
8+
"github.com/zhufuyi/sponge/pkg/gotest"
9+
"github.com/zhufuyi/sponge/pkg/mysql/query"
10+
11+
"github.com/DATA-DOG/go-sqlmock"
12+
"github.com/stretchr/testify/assert"
13+
"gorm.io/gorm"
14+
)
15+
16+
var table = &userExample{}
17+
18+
type userExample struct {
19+
Model `gorm:"embedded"`
20+
21+
Name string `gorm:"type:varchar(40);unique_index;not null" json:"name"`
22+
Age int `gorm:"not null" json:"age"`
23+
Gender string `gorm:"type:varchar(10);not null" json:"gender"`
24+
}
25+
26+
func newUserExampleDao() *gotest.Dao {
27+
testData := &userExample{Name: "ZhangSan", Age: 20, Gender: "male"}
28+
testData.ID = 1
29+
testData.CreatedAt = time.Now()
30+
testData.UpdatedAt = testData.CreatedAt
31+
32+
// init mock dao
33+
d := gotest.NewDao(nil, testData)
34+
35+
return d
36+
}
37+
38+
func TestTableName(t *testing.T) {
39+
t.Logf("table name = %s", TableName(&userExample{}))
40+
}
41+
42+
func TestCreate(t *testing.T) {
43+
d := newUserExampleDao()
44+
defer d.Close()
45+
testData := d.TestData.(*userExample)
46+
47+
d.SQLMock.ExpectBegin()
48+
d.SQLMock.ExpectExec("INSERT INTO .*").
49+
WithArgs(d.GetAnyArgs(testData)...).
50+
WillReturnResult(sqlmock.NewResult(1, 1))
51+
d.SQLMock.ExpectCommit()
52+
53+
err := Create(d.Ctx, d.DB, testData)
54+
assert.NoError(t, err)
55+
}
56+
57+
func TestDelete(t *testing.T) {
58+
d := newUserExampleDao()
59+
defer d.Close()
60+
testData := d.TestData.(*userExample)
61+
62+
d.SQLMock.ExpectBegin()
63+
d.SQLMock.ExpectExec("UPDATE .*").
64+
WithArgs(d.AnyTime, testData.Name).
65+
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
66+
d.SQLMock.ExpectCommit()
67+
68+
err := Delete(d.Ctx, d.DB, table, "name = ?", testData.Name)
69+
assert.NoError(t, err)
70+
}
71+
72+
func TestDeleteByID(t *testing.T) {
73+
d := newUserExampleDao()
74+
defer d.Close()
75+
testData := d.TestData.(*userExample)
76+
77+
d.SQLMock.ExpectBegin()
78+
d.SQLMock.ExpectExec("UPDATE .*").
79+
WithArgs(d.AnyTime, testData.ID).
80+
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
81+
d.SQLMock.ExpectCommit()
82+
83+
err := Delete(d.Ctx, d.DB, table, "id = ?", testData.ID)
84+
assert.NoError(t, err)
85+
}
86+
87+
func TestUpdate(t *testing.T) {
88+
d := newUserExampleDao()
89+
defer d.Close()
90+
testData := d.TestData.(*userExample)
91+
92+
d.SQLMock.ExpectBegin()
93+
d.SQLMock.ExpectExec("UPDATE .*").
94+
WithArgs(sqlmock.AnyArg(), d.AnyTime, testData.Name).
95+
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
96+
d.SQLMock.ExpectCommit()
97+
98+
err := Update(d.Ctx, d.DB, table, "age", gorm.Expr("age + ?", 1), "name = ?", testData.Name)
99+
assert.NoError(t, err)
100+
}
101+
102+
func TestUpdates(t *testing.T) {
103+
d := newUserExampleDao()
104+
defer d.Close()
105+
testData := d.TestData.(*userExample)
106+
107+
d.SQLMock.ExpectBegin()
108+
d.SQLMock.ExpectExec("UPDATE .*").
109+
WithArgs(sqlmock.AnyArg(), d.AnyTime, testData.Gender).
110+
WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1))
111+
d.SQLMock.ExpectCommit()
112+
113+
update := KV{"age": gorm.Expr("age + ?", 1)}
114+
err := Updates(d.Ctx, d.DB, table, update, "gender = ?", testData.Gender)
115+
assert.NoError(t, err)
116+
}
117+
118+
func TestGetByID(t *testing.T) {
119+
d := newUserExampleDao()
120+
defer d.Close()
121+
testData := d.TestData.(*userExample)
122+
123+
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
124+
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)
125+
126+
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(testData.ID).WillReturnRows(rows)
127+
128+
err := GetByID(d.Ctx, d.DB, table, testData.ID)
129+
assert.NoError(t, err)
130+
131+
t.Logf("%+v", table)
132+
}
133+
134+
func TestGet(t *testing.T) {
135+
d := newUserExampleDao()
136+
defer d.Close()
137+
testData := d.TestData.(*userExample)
138+
139+
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
140+
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)
141+
142+
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnRows(rows) // adjusted for number of fields
143+
144+
err := Get(d.Ctx, d.DB, table, "name = ?", testData.Name)
145+
assert.NoError(t, err)
146+
147+
t.Logf("%+v", table)
148+
}
149+
150+
func TestList(t *testing.T) {
151+
d := newUserExampleDao()
152+
defer d.Close()
153+
testData := d.TestData.(*userExample)
154+
155+
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
156+
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)
157+
158+
d.SQLMock.ExpectQuery("SELECT .*").WillReturnRows(rows)
159+
160+
page := query.NewPage(0, 10, "")
161+
tables := []userExample{}
162+
err := List(d.Ctx, d.DB, &tables, page, "")
163+
assert.NoError(t, err)
164+
165+
for _, user := range tables {
166+
t.Logf("%+v", user)
167+
}
168+
}
169+
170+
func TestCount(t *testing.T) {
171+
d := newUserExampleDao()
172+
defer d.Close()
173+
testData := d.TestData.(*userExample)
174+
175+
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
176+
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)
177+
178+
d.SQLMock.ExpectQuery("SELECT .*").
179+
WithArgs(sqlmock.AnyArg()).
180+
WillReturnRows(rows)
181+
182+
count, err := Count(d.Ctx, d.DB, table, "id > ?", 0)
183+
assert.NotNil(t, err)
184+
185+
t.Logf("count=%d", count)
186+
}
187+
188+
func TestTx(t *testing.T) {
189+
err := createUser()
190+
if err != nil {
191+
t.Fatal(err)
192+
}
193+
}
194+
195+
func createUser() error {
196+
d := newUserExampleDao()
197+
defer d.Close()
198+
testData := d.TestData.(*userExample)
199+
rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "name", "age", "gender"}).
200+
AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt, testData.Name, testData.Age, testData.Gender)
201+
d.SQLMock.ExpectBegin()
202+
d.SQLMock.ExpectQuery("SELECT .*").WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg()).WillReturnRows(rows) // adjusted for number of fields
203+
d.SQLMock.ExpectCommit()
204+
205+
// note that you should use tx as the database handle when you are in a transaction
206+
tx := d.DB.Begin()
207+
defer func() {
208+
if err := recover(); err != nil { // rollback after a panic during transaction execution
209+
tx.Rollback()
210+
fmt.Printf("transaction failed, err = %v\n", err)
211+
}
212+
}()
213+
214+
var err error
215+
if err = tx.Error; err != nil {
216+
return err
217+
}
218+
219+
if err = tx.WithContext(d.Ctx).Where("id = ?", testData.ID).First(table).Error; err != nil {
220+
tx.Rollback()
221+
return err
222+
}
223+
224+
panic("mock panic")
225+
226+
if err = tx.WithContext(d.Ctx).Create(&userExample{Name: "lisi", Age: table.Age + 2, Gender: "male"}).Error; err != nil {
227+
tx.Rollback()
228+
return err
229+
}
230+
231+
return tx.Commit().Error
232+
}

0 commit comments

Comments
 (0)