Skip to content

Commit b2c6916

Browse files
Add RebuildBatched (#23)
Original `Rebuild` was making thousands of individual update queries. I added `RebuildBatched` which can becalled as: ```go nestedset.RebuildBatched(db.WithContext(ctx), root, true, 1000) ```
1 parent b4d4ae0 commit b2c6916

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

nested_set.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,65 @@ func Rebuild(db *gorm.DB, source interface{}, doUpdate bool) (affectedCount int,
323323
return
324324
}
325325

326+
327+
// RebuildBatched rebuild nodes as any nestedset which in the scope
328+
// ```nestedset.RebuildBatched(db, &node, true, 1000)``` will rebuild [&node] as nestedset
329+
func RebuildBatched(db *gorm.DB, source interface{}, doUpdate bool, batchSize int) (affectedCount int, err error) {
330+
tx, target, err := parseNode(db, source)
331+
if err != nil {
332+
return
333+
}
334+
err = tx.Transaction(func(tx *gorm.DB) (err error) {
335+
allItems := []*nestedItem{}
336+
err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).
337+
Where(formatSQL("", target)).
338+
Order(formatSQL(":parent_id ASC NULLS FIRST, :lft ASC", target)).
339+
Find(&allItems).
340+
Error
341+
342+
if err != nil {
343+
return
344+
}
345+
initTree(allItems).rebuild()
346+
347+
var itemsToUpdate []*nestedItem
348+
for _, item := range allItems {
349+
if item.IsChanged {
350+
affectedCount += 1
351+
if doUpdate {
352+
itemsToUpdate = append(itemsToUpdate, item)
353+
}
354+
}
355+
}
356+
if doUpdate && len(itemsToUpdate) > 0 {
357+
err = batchUpdate(tx, []string{"lft", "rgt", "depth", "children_count"}, target.DbNames, itemsToUpdate, batchSize)
358+
if err != nil {
359+
return
360+
}
361+
}
362+
return nil
363+
})
364+
return
365+
}
366+
367+
// batchUpdate performs a batched upsert (update on conflict) for the given columns and items.
368+
func batchUpdate(db *gorm.DB, columns []string, dbNames map[string]string, items []*nestedItem, batchSize int) error {
369+
if len(items) == 0 {
370+
return nil
371+
}
372+
373+
assignmentMap := map[string]interface{}{}
374+
for _, column := range columns {
375+
column = dbNames[column]
376+
assignmentMap[column] = gorm.Expr("EXCLUDED." + column)
377+
}
378+
379+
return db.Clauses(clause.OnConflict{
380+
Columns: []clause.Column{{Name: dbNames["id"]}},
381+
DoUpdates: clause.Assignments(assignmentMap),
382+
}).CreateInBatches(items, batchSize).Error
383+
}
384+
326385
func moveIsValid(node, to nestedItem) error {
327386
validLft, validRgt := node.Lft, node.Rgt
328387
if (to.Lft >= validLft && to.Lft <= validRgt) || (to.Rgt >= validLft && to.Rgt <= validRgt) {

nested_set_test.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,164 @@ func TestRebuild(t *testing.T) {
344344
assertNodeEqual(t, lilysDresses, 4, 5, 1, 0, lilysClothing.ID)
345345
}
346346

347+
func TestRebuildBatched(t *testing.T) {
348+
const batchSize = 5
349+
initData()
350+
affectedCount, err := RebuildBatched(db, clothing, true, batchSize)
351+
assert.NoError(t, err)
352+
assert.Equal(t, 0, affectedCount)
353+
reloadCategories()
354+
355+
assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
356+
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
357+
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
358+
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
359+
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
360+
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
361+
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
362+
assertNodeEqual(t, eveningGowns, 12, 13, 3, 0, dresses.ID)
363+
assertNodeEqual(t, sunDresses, 14, 15, 3, 0, dresses.ID)
364+
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
365+
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)
366+
367+
sunDresses.Rgt = 123
368+
sunDresses.Lft = 12
369+
sunDresses.Depth = 1
370+
sunDresses.ChildrenCount = 100
371+
err = db.Updates(&sunDresses).Error
372+
assert.NoError(t, err)
373+
reloadCategories()
374+
assertNodeEqual(t, sunDresses, 12, 123, 1, 100, dresses.ID)
375+
376+
affectedCount, err = RebuildBatched(db, clothing, true, batchSize)
377+
assert.NoError(t, err)
378+
assert.Equal(t, 2, affectedCount)
379+
reloadCategories()
380+
381+
assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
382+
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
383+
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
384+
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
385+
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
386+
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
387+
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
388+
assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID)
389+
assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID)
390+
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
391+
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)
392+
393+
affectedCount, err = RebuildBatched(db, clothing, true, batchSize)
394+
assert.NoError(t, err)
395+
assert.Equal(t, 0, affectedCount)
396+
reloadCategories()
397+
398+
assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
399+
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
400+
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
401+
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
402+
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
403+
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
404+
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
405+
assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID)
406+
assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID)
407+
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
408+
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)
409+
410+
hat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
411+
"Title": "Hat",
412+
"ParentID": sql.NullInt64{Valid: false},
413+
}).(*Category)
414+
415+
affectedCount, err = RebuildBatched(db, clothing, false, batchSize)
416+
assert.NoError(t, err)
417+
assert.Equal(t, 1, affectedCount)
418+
419+
affectedCount, err = RebuildBatched(db, clothing, true, batchSize)
420+
assert.NoError(t, err)
421+
assert.Equal(t, 1, affectedCount)
422+
reloadCategories()
423+
hat, _ = findNode(db, hat.ID)
424+
425+
assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
426+
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
427+
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
428+
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
429+
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
430+
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
431+
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
432+
assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID)
433+
assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID)
434+
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
435+
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)
436+
assertNodeEqual(t, hat, 23, 24, 0, 0, 0)
437+
438+
jacksClothing := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
439+
"Title": "Jack's Clothing",
440+
"ParentID": sql.NullInt64{Valid: false},
441+
"UserType": "User",
442+
"UserID": 8686,
443+
}).(*Category)
444+
jacksSuits := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
445+
"Title": "Jack's Suits",
446+
"ParentID": sql.NullInt64{Valid: true, Int64: jacksClothing.ID},
447+
"UserType": "User",
448+
"UserID": 8686,
449+
}).(*Category)
450+
jacksHat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
451+
"Title": "Jack's Hat",
452+
"UserType": "User",
453+
"UserID": 8686,
454+
"ParentID": sql.NullInt64{Valid: false},
455+
}).(*Category)
456+
jacksSlacks := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
457+
"Title": "Jack's Slacks",
458+
"ParentID": sql.NullInt64{Valid: true, Int64: jacksClothing.ID},
459+
"UserType": "User",
460+
"UserID": 8686,
461+
}).(*Category)
462+
463+
lilysHat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
464+
"Title": "Lily's Hat",
465+
"UserType": "User",
466+
"UserID": 6666,
467+
"ParentID": sql.NullInt64{Valid: false},
468+
}).(*Category)
469+
lilysClothing := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
470+
"Title": "Lily's Clothing",
471+
"ParentID": sql.NullInt64{Valid: false},
472+
"UserType": "User",
473+
"UserID": 6666,
474+
}).(*Category)
475+
lilysDresses := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
476+
"Title": "Lily's Dresses",
477+
"ParentID": sql.NullInt64{Valid: true, Int64: lilysClothing.ID},
478+
"UserType": "User",
479+
"UserID": 6666,
480+
}).(*Category)
481+
482+
affectedCount, err = RebuildBatched(db, jacksSuits, true, batchSize)
483+
assert.NoError(t, err)
484+
assert.Equal(t, 4, affectedCount)
485+
affectedCount, err = RebuildBatched(db, lilysHat, true, batchSize)
486+
assert.NoError(t, err)
487+
assert.Equal(t, 3, affectedCount)
488+
jacksClothing, _ = findNode(db, jacksClothing.ID)
489+
jacksSuits, _ = findNode(db, jacksSuits.ID)
490+
jacksSlacks, _ = findNode(db, jacksSlacks.ID)
491+
jacksHat, _ = findNode(db, jacksHat.ID)
492+
lilysHat, _ = findNode(db, lilysHat.ID)
493+
lilysClothing, _ = findNode(db, lilysClothing.ID)
494+
lilysDresses, _ = findNode(db, lilysDresses.ID)
495+
496+
assertNodeEqual(t, jacksClothing, 1, 6, 0, 2, 0)
497+
assertNodeEqual(t, jacksSuits, 2, 3, 1, 0, jacksClothing.ID)
498+
assertNodeEqual(t, jacksSlacks, 4, 5, 1, 0, jacksClothing.ID)
499+
assertNodeEqual(t, jacksHat, 7, 8, 0, 0, 0)
500+
assertNodeEqual(t, lilysHat, 1, 2, 0, 0, 0)
501+
assertNodeEqual(t, lilysClothing, 3, 6, 0, 1, 0)
502+
assertNodeEqual(t, lilysDresses, 4, 5, 1, 0, lilysClothing.ID)
503+
}
504+
347505
func TestMoveToLeft(t *testing.T) {
348506
// case 1
349507
initData()

0 commit comments

Comments
 (0)