Skip to content

Commit edaa34c

Browse files
committed
Refactor adapter
1 parent 674e972 commit edaa34c

File tree

11 files changed

+1147
-17
lines changed

11 files changed

+1147
-17
lines changed

adapter/adapter.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ type Adapter[T any, K any] struct {
1616
*Writer[*T]
1717
Map map[string]int
1818
Fields string
19-
Keys []string
2019
IdMap bool
2120
}
2221

@@ -41,11 +40,10 @@ func NewSqlAdapterWithVersionAndArray[T any, K any](db *sql.DB, tableName string
4140
return nil, errors.New("T must be a struct")
4241
}
4342

44-
_, primaryKeys := q.FindPrimaryKeys(modelType)
4543
var k K
4644
kType := reflect.TypeOf(k)
4745
idMap := false
48-
if len(primaryKeys) > 1 {
46+
if len(adapter.Keys) > 1 {
4947
if kType.Kind() == reflect.Map {
5048
idMap = true
5149
} else if kType.Kind() != reflect.Struct {
@@ -58,7 +56,7 @@ func NewSqlAdapterWithVersionAndArray[T any, K any](db *sql.DB, tableName string
5856
return nil, err
5957
}
6058
fields := q.BuildFieldsBySchema(adapter.Schema)
61-
return &Adapter[T, K]{adapter, fieldsIndex, fields, primaryKeys, idMap}, nil
59+
return &Adapter[T, K]{adapter, fieldsIndex, fields, idMap}, nil
6260
}
6361
func (a *Adapter[T, K]) All(ctx context.Context) ([]T, error) {
6462
var objs []T

adapter/writer.go

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type Writer[T any] struct {
1616
DB *sql.DB
1717
Table string
1818
Schema *q.Schema
19+
Keys []string
1920
JsonColumnMap map[string]string
2021
BuildParam func(int) string
2122
Driver string
@@ -63,10 +64,14 @@ func NewSqlWriterWithVersionAndArray[T any](db *sql.DB, tableName string, versio
6364
if modelType.Kind() == reflect.Ptr {
6465
modelType = modelType.Elem()
6566
}
67+
_, primaryKeys := q.FindPrimaryKeys(modelType)
68+
if len(primaryKeys) == 0 {
69+
return nil, fmt.Errorf("require primary key for table '%s'", tableName)
70+
}
6671
schema := q.CreateSchema(modelType)
6772
jsonColumnMapT := q.MakeJsonColumnMap(modelType)
6873
jsonColumnMap := q.GetWritableColumns(schema.Fields, jsonColumnMapT)
69-
adapter := &Writer[T]{DB: db, Table: tableName, Schema: schema, JsonColumnMap: jsonColumnMap, BuildParam: buildParam, Driver: drivr, BoolSupport: boolSupport, ToArray: toArray, TxKey: "tx", versionIndex: -1}
74+
adapter := &Writer[T]{DB: db, Table: tableName, Schema: schema, Keys: primaryKeys, JsonColumnMap: jsonColumnMap, BuildParam: buildParam, Driver: drivr, BoolSupport: boolSupport, ToArray: toArray, TxKey: "tx", versionIndex: -1}
7075
if len(versionField) > 0 {
7176
index := q.FindFieldIndex(modelType, versionField)
7277
if index >= 0 {
@@ -87,7 +92,7 @@ func (a *Writer[T]) Create(ctx context.Context, model T) (int64, error) {
8792
query, args := q.BuildToInsertWithVersion(a.Table, model, a.versionIndex, a.BuildParam, a.BoolSupport, a.ToArray, a.Schema)
8893
res, err := tx.ExecContext(ctx, query, args...)
8994
if err != nil {
90-
return -1, err
95+
return q.HandleDuplicate(a.DB, err)
9196
}
9297
rowsAffected, err := res.RowsAffected()
9398
if err != nil {
@@ -113,11 +118,29 @@ func (a *Writer[T]) Update(ctx context.Context, model T) (int64, error) {
113118
if err != nil {
114119
return rowsAffected, err
115120
}
116-
if rowsAffected > 0 && a.versionIndex >= 0 {
117-
vo := reflect.ValueOf(model)
118-
if vo.Kind() == reflect.Ptr {
119-
vo = reflect.Indirect(vo)
121+
vo := reflect.ValueOf(model)
122+
if vo.Kind() == reflect.Ptr {
123+
vo = reflect.Indirect(vo)
124+
}
125+
if rowsAffected < 1 {
126+
var values []interface{}
127+
query1 := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
128+
le := len(a.Keys)
129+
var where []string
130+
for i := 0; i < le; i++ {
131+
where = append(where, fmt.Sprintf("%s = %s", a.Schema.Keys[i].Column), a.BuildParam(i+1))
132+
}
133+
query2 := query1 + " where " + strings.Join(where, " and ")
134+
rows, er2 := tx.QueryContext(ctx, query2, values...)
135+
if er2 != nil {
136+
return -1, err
120137
}
138+
defer rows.Close()
139+
for rows.Next() {
140+
return -1, nil
141+
}
142+
return 0, nil
143+
} else if a.versionIndex >= 0 {
121144
currentVersion := vo.Field(a.versionIndex).Interface()
122145
increaseVersion(vo, a.versionIndex, currentVersion)
123146
}
@@ -159,7 +182,30 @@ func (a *Writer[T]) Patch(ctx context.Context, model map[string]interface{}) (in
159182
if err != nil {
160183
return rowsAffected, err
161184
}
162-
if rowsAffected > 0 && a.versionIndex >= 0 {
185+
if rowsAffected < 1 {
186+
var query2 string
187+
var values []interface{}
188+
query1 := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
189+
if len(a.Keys) == 1 {
190+
query2, values = q.BuildFindByIdWithDB(a.DB, query1, model[a.Keys[0]], a.JsonColumnMap, a.Keys, a.BuildParam)
191+
} else {
192+
im := make(map[string]interface{})
193+
le := len(a.Keys)
194+
for i := 0; i < le; i++ {
195+
im[a.Keys[i]] = model[a.Keys[i]]
196+
}
197+
query2, values = q.BuildFindByIdWithDB(a.DB, query1, im, a.JsonColumnMap, a.Keys, a.BuildParam)
198+
}
199+
rows, er2 := tx.QueryContext(ctx, query2, values...)
200+
if er2 != nil {
201+
return -1, err
202+
}
203+
defer rows.Close()
204+
for rows.Next() {
205+
return -1, nil
206+
}
207+
return 0, nil
208+
} else if a.versionIndex >= 0 {
163209
currentVersion, vok := model[a.versionJson]
164210
if !vok {
165211
return -1, fmt.Errorf("%s must be in model for patch", a.versionJson)

dao/dao.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package dao
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"encoding/json"
8+
"errors"
9+
"fmt"
10+
"reflect"
11+
12+
q "github.com/core-go/sql"
13+
)
14+
15+
type Dao[T any, K any] struct {
16+
*Writer[*T]
17+
Map map[string]int
18+
Fields string
19+
IdMap bool
20+
}
21+
22+
func NewDao[T any, K any](db *sql.DB, tableName string, opts ...func(int) string) (*Dao[T, K], error) {
23+
return NewSqlDaoWithVersionAndArray[T, K](db, tableName, "", nil, opts...)
24+
}
25+
func NewDaoWithVersion[T any, K any](db *sql.DB, tableName string, versionField string, opts ...func(int) string) (*Dao[T, K], error) {
26+
return NewSqlDaoWithVersionAndArray[T, K](db, tableName, versionField, nil, opts...)
27+
}
28+
func NewSqlDaoWithVersionAndArray[T any, K any](db *sql.DB, tableName string, versionField string, toArray func(interface{}) interface {
29+
driver.Valuer
30+
sql.Scanner
31+
}, opts ...func(int) string) (*Dao[T, K], error) {
32+
adapter, err := NewSqlWriterWithVersionAndArray[*T](db, tableName, versionField, toArray, opts...)
33+
if err != nil {
34+
return nil, err
35+
}
36+
37+
var t T
38+
modelType := reflect.TypeOf(t)
39+
if modelType.Kind() != reflect.Struct {
40+
return nil, errors.New("T must be a struct")
41+
}
42+
43+
var k K
44+
kType := reflect.TypeOf(k)
45+
idMap := false
46+
if len(adapter.Keys) > 1 {
47+
if kType.Kind() == reflect.Map {
48+
idMap = true
49+
} else if kType.Kind() != reflect.Struct {
50+
return nil, errors.New("for composite keys, K must be a struct or a map")
51+
}
52+
}
53+
54+
fieldsIndex, err := q.GetColumnIndexes(modelType)
55+
if err != nil {
56+
return nil, err
57+
}
58+
fields := q.BuildFieldsBySchema(adapter.Schema)
59+
return &Dao[T, K]{adapter, fieldsIndex, fields, idMap}, nil
60+
}
61+
func (a *Dao[T, K]) All(ctx context.Context) ([]T, error) {
62+
var objs []T
63+
query := fmt.Sprintf("select %s from %s", a.Fields, a.Table)
64+
tx := q.GetExec(ctx, a.DB, a.TxKey)
65+
err := q.Query(ctx, tx, a.Map, &objs, query)
66+
return objs, err
67+
}
68+
func toMap(obj interface{}) (map[string]interface{}, error) {
69+
b, err := json.Marshal(obj)
70+
if err != nil {
71+
return nil, err
72+
}
73+
im := make(map[string]interface{})
74+
er2 := json.Unmarshal(b, &im)
75+
return im, er2
76+
}
77+
func (a *Dao[T, K]) getId(k K) (interface{}, error) {
78+
if len(a.Keys) >= 2 && !a.IdMap {
79+
ri, err := toMap(k)
80+
return ri, err
81+
} else {
82+
return k, nil
83+
}
84+
}
85+
func (a *Dao[T, K]) Load(ctx context.Context, id K) (*T, error) {
86+
ip, er0 := a.getId(id)
87+
if er0 != nil {
88+
return nil, er0
89+
}
90+
var objs []T
91+
query := fmt.Sprintf("select %s from %s ", a.Fields, a.Table)
92+
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Keys, a.BuildParam)
93+
tx := q.GetExec(ctx, a.DB, a.TxKey)
94+
err := q.Query(ctx, tx, a.Map, &objs, query1, args...)
95+
if err != nil {
96+
return nil, err
97+
}
98+
if len(objs) > 0 {
99+
return &objs[0], nil
100+
}
101+
return nil, nil
102+
}
103+
func (a *Dao[T, K]) Exist(ctx context.Context, id K) (bool, error) {
104+
ip, er0 := a.getId(id)
105+
if er0 != nil {
106+
return false, er0
107+
}
108+
query := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
109+
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Keys, a.BuildParam)
110+
tx := q.GetExec(ctx, a.DB, a.TxKey)
111+
rows, err := tx.QueryContext(ctx, query1, args...)
112+
if err != nil {
113+
return false, err
114+
}
115+
defer rows.Close()
116+
for rows.Next() {
117+
return true, nil
118+
}
119+
return false, nil
120+
}
121+
func (a *Dao[T, K]) Delete(ctx context.Context, id K) (int64, error) {
122+
ip, er0 := a.getId(id)
123+
if er0 != nil {
124+
return -1, er0
125+
}
126+
query := fmt.Sprintf("delete from %s ", a.Table)
127+
query1, args := q.BuildFindByIdWithDB(a.DB, query, ip, a.JsonColumnMap, a.Keys, a.BuildParam)
128+
tx := q.GetExec(ctx, a.DB, a.TxKey)
129+
res, err := tx.ExecContext(ctx, query1, args...)
130+
if err != nil {
131+
return -1, err
132+
}
133+
return res.RowsAffected()
134+
}

dao/search.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package dao
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"reflect"
8+
9+
q "github.com/core-go/sql"
10+
)
11+
12+
type SearchDao[T any, K any, F any] struct {
13+
*Dao[T, K]
14+
BuildQuery func(F) (string, []interface{})
15+
Mp func(*T)
16+
Map map[string]int
17+
ToArray func(interface{}) interface {
18+
driver.Valuer
19+
sql.Scanner
20+
}
21+
}
22+
23+
func NewSearchDao[T any, K any, F any](db *sql.DB, table string, buildQuery func(F) (string, []interface{}), options ...func(*T)) (*SearchDao[T, K, F], error) {
24+
return NewSearchDaoWithArray[T, K, F](db, table, buildQuery, nil, "", nil, options...)
25+
}
26+
func NewSearchDaoWithVersion[T any, K any, F any](db *sql.DB, table string, buildQuery func(F) (string, []interface{}), versionField string, options ...func(*T)) (*SearchDao[T, K, F], error) {
27+
return NewSearchDaoWithArray[T, K, F](db, table, buildQuery, nil, versionField, nil, options...)
28+
}
29+
func NewSearchDaoWithArray[T any, K any, F any](db *sql.DB, table string, buildQuery func(F) (string, []interface{}), toArray func(interface{}) interface {
30+
driver.Valuer
31+
sql.Scanner
32+
}, versionField string, buildParam func(int) string, opts ...func(*T)) (*SearchDao[T, K, F], error) {
33+
daObj, err := NewSqlDaoWithVersionAndArray[T, K](db, table, versionField, toArray, buildParam)
34+
if err != nil {
35+
return nil, err
36+
}
37+
var mp func(*T)
38+
if len(opts) >= 1 {
39+
mp = opts[0]
40+
}
41+
var t T
42+
modelType := reflect.TypeOf(t)
43+
if modelType.Kind() == reflect.Ptr {
44+
modelType = modelType.Elem()
45+
}
46+
fieldsIndex, err := q.GetColumnIndexes(modelType)
47+
if err != nil {
48+
return nil, err
49+
}
50+
builder := &SearchDao[T, K, F]{Dao: daObj, Map: fieldsIndex, BuildQuery: buildQuery, Mp: mp, ToArray: toArray}
51+
return builder, nil
52+
}
53+
54+
func (b *SearchDao[T, K, F]) Search(ctx context.Context, filter F, limit int64, offset int64) ([]T, int64, error) {
55+
var objs []T
56+
query, args := b.BuildQuery(filter)
57+
total, er2 := q.BuildFromQuery(ctx, b.DB, b.Map, &objs, query, args, limit, offset, b.ToArray)
58+
if b.Mp != nil {
59+
l := len(objs)
60+
for i := 0; i < l; i++ {
61+
b.Mp(&objs[i])
62+
}
63+
}
64+
return objs, total, er2
65+
}

dao/search_builder.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package dao
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"errors"
8+
"reflect"
9+
10+
q "github.com/core-go/sql"
11+
)
12+
13+
type SearchBuilder[T any, F any] struct {
14+
Database *sql.DB
15+
BuildQuery func(F) (string, []interface{})
16+
fieldsIndex map[string]int
17+
Map func(*T)
18+
ToArray func(interface{}) interface {
19+
driver.Valuer
20+
sql.Scanner
21+
}
22+
}
23+
24+
func NewSearchBuilder[T any, F any](db *sql.DB, buildQuery func(F) (string, []interface{}), opts ...func(*T)) (*SearchBuilder[T, F], error) {
25+
return NewSearchBuilderWithArray[T, F](db, buildQuery, nil, opts...)
26+
}
27+
func NewSearchBuilderWithArray[T any, F any](db *sql.DB, buildQuery func(F) (string, []interface{}), toArray func(interface{}) interface {
28+
driver.Valuer
29+
sql.Scanner
30+
}, opts ...func(*T)) (*SearchBuilder[T, F], error) {
31+
var t T
32+
modelType := reflect.TypeOf(t)
33+
if modelType.Kind() != reflect.Struct {
34+
return nil, errors.New("T must be a struct")
35+
}
36+
var mp func(*T)
37+
if len(opts) >= 1 {
38+
mp = opts[0]
39+
}
40+
fieldsIndex, err := q.GetColumnIndexes(modelType)
41+
if err != nil {
42+
return nil, err
43+
}
44+
builder := &SearchBuilder[T, F]{Database: db, fieldsIndex: fieldsIndex, BuildQuery: buildQuery, Map: mp, ToArray: toArray}
45+
return builder, nil
46+
}
47+
48+
func (b *SearchBuilder[T, F]) Search(ctx context.Context, m F, limit int64, offset int64) ([]T, int64, error) {
49+
sql, params := b.BuildQuery(m)
50+
var objs []T
51+
total, er2 := q.BuildFromQuery(ctx, b.Database, b.fieldsIndex, &objs, sql, params, limit, offset, b.ToArray)
52+
if b.Map != nil {
53+
l := len(objs)
54+
for i := 0; i < l; i++ {
55+
b.Map(&objs[i])
56+
}
57+
}
58+
return objs, total, er2
59+
}

0 commit comments

Comments
 (0)