|
| 1 | +package sql |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "database/sql" |
| 6 | + "database/sql/driver" |
| 7 | + "reflect" |
| 8 | +) |
| 9 | + |
| 10 | +type BatchWriter struct { |
| 11 | + db *sql.DB |
| 12 | + tableName string |
| 13 | + Map func(ctx context.Context, model interface{}) (interface{}, error) |
| 14 | + Schema *Schema |
| 15 | + ToArray func(interface{}) interface { |
| 16 | + driver.Valuer |
| 17 | + sql.Scanner |
| 18 | + } |
| 19 | +} |
| 20 | +func NewBatchWriter(db *sql.DB, tableName string, modelType reflect.Type, options ...func(context.Context, interface{}) (interface{}, error)) *BatchWriter { |
| 21 | + var mp func(context.Context, interface{}) (interface{}, error) |
| 22 | + if len(options) > 0 && options[0] != nil { |
| 23 | + mp = options[0] |
| 24 | + } |
| 25 | + return NewBatchWriterWithArray(db, tableName, modelType, nil, mp) |
| 26 | +} |
| 27 | +func NewBatchWriterWithArray(db *sql.DB, tableName string, modelType reflect.Type, toArray func(interface{}) interface { |
| 28 | + driver.Valuer |
| 29 | + sql.Scanner |
| 30 | +}, options ...func(context.Context, interface{}) (interface{}, error)) *BatchWriter { |
| 31 | + var mp func(context.Context, interface{}) (interface{}, error) |
| 32 | + if len(options) > 0 && options[0] != nil { |
| 33 | + mp = options[0] |
| 34 | + } |
| 35 | + schema := CreateSchema(modelType) |
| 36 | + return &BatchWriter{db: db, tableName: tableName, Schema: schema, Map: mp, ToArray: toArray} |
| 37 | +} |
| 38 | + |
| 39 | +func (w *BatchWriter) Write(ctx context.Context, models interface{}) ([]int, []int, error) { |
| 40 | + successIndices := make([]int, 0) |
| 41 | + failIndices := make([]int, 0) |
| 42 | + var m interface{} |
| 43 | + var er0 error |
| 44 | + if w.Map != nil { |
| 45 | + m, er0 = MapModels(ctx, models, w.Map) |
| 46 | + if er0 != nil { |
| 47 | + s0 := reflect.ValueOf(m) |
| 48 | + _, er0b := InterfaceSlice(m) |
| 49 | + failIndices = ToArrayIndex(s0, failIndices) |
| 50 | + return successIndices, failIndices, er0b |
| 51 | + } |
| 52 | + } else { |
| 53 | + m = models |
| 54 | + } |
| 55 | + s := reflect.ValueOf(m) |
| 56 | + _, er2 := SaveBatchWithArray(ctx, w.db, w.tableName, m, w.ToArray, w.Schema) |
| 57 | + |
| 58 | + if er2 == nil { |
| 59 | + // Return full success |
| 60 | + successIndices = ToArrayIndex(s, successIndices) |
| 61 | + return successIndices, failIndices, er2 |
| 62 | + } else { |
| 63 | + // Return full fail |
| 64 | + failIndices = ToArrayIndex(s, failIndices) |
| 65 | + } |
| 66 | + return successIndices, failIndices, er2 |
| 67 | +} |
0 commit comments