Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions nested_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,65 @@ func Rebuild(db *gorm.DB, source interface{}, doUpdate bool) (affectedCount int,
return
}


// RebuildBatched rebuild nodes as any nestedset which in the scope
// ```nestedset.RebuildBatched(db, &node, true, 1000)``` will rebuild [&node] as nestedset
Copy link
Member

@huacnlee huacnlee Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, we need to add a doc here to describe the limitation.

func RebuildBatched(db *gorm.DB, source interface{}, doUpdate bool, batchSize int) (affectedCount int, err error) {
tx, target, err := parseNode(db, source)
if err != nil {
return
}
err = tx.Transaction(func(tx *gorm.DB) (err error) {
allItems := []*nestedItem{}
err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where(formatSQL("", target)).
Order(formatSQL(":parent_id ASC NULLS FIRST, :lft ASC", target)).
Find(&allItems).
Error

if err != nil {
return
}
initTree(allItems).rebuild()

var itemsToUpdate []*nestedItem
for _, item := range allItems {
if item.IsChanged {
affectedCount += 1
if doUpdate {
itemsToUpdate = append(itemsToUpdate, item)
}
}
}
if doUpdate && len(itemsToUpdate) > 0 {
err = batchUpdate(tx, []string{"lft", "rgt", "depth", "children_count"}, target.DbNames, itemsToUpdate, batchSize)
if err != nil {
return
}
}
return nil
})
return
}

// batchUpdate performs a batched upsert (update on conflict) for the given columns and items.
func batchUpdate(db *gorm.DB, columns []string, dbNames map[string]string, items []*nestedItem, batchSize int) error {
if len(items) == 0 {
return nil
}

assignmentMap := map[string]interface{}{}
for _, column := range columns {
column = dbNames[column]
assignmentMap[column] = gorm.Expr("EXCLUDED." + column)
}

return db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: dbNames["id"]}},
DoUpdates: clause.Assignments(assignmentMap),
}).CreateInBatches(items, batchSize).Error
}

func moveIsValid(node, to nestedItem) error {
validLft, validRgt := node.Lft, node.Rgt
if (to.Lft >= validLft && to.Lft <= validRgt) || (to.Rgt >= validLft && to.Rgt <= validRgt) {
Expand Down
158 changes: 158 additions & 0 deletions nested_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,164 @@ func TestRebuild(t *testing.T) {
assertNodeEqual(t, lilysDresses, 4, 5, 1, 0, lilysClothing.ID)
}

func TestRebuildBatched(t *testing.T) {
const batchSize = 5
initData()
affectedCount, err := RebuildBatched(db, clothing, true, batchSize)
assert.NoError(t, err)
assert.Equal(t, 0, affectedCount)
reloadCategories()

assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
assertNodeEqual(t, eveningGowns, 12, 13, 3, 0, dresses.ID)
assertNodeEqual(t, sunDresses, 14, 15, 3, 0, dresses.ID)
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)

sunDresses.Rgt = 123
sunDresses.Lft = 12
sunDresses.Depth = 1
sunDresses.ChildrenCount = 100
err = db.Updates(&sunDresses).Error
assert.NoError(t, err)
reloadCategories()
assertNodeEqual(t, sunDresses, 12, 123, 1, 100, dresses.ID)

affectedCount, err = RebuildBatched(db, clothing, true, batchSize)
assert.NoError(t, err)
assert.Equal(t, 2, affectedCount)
reloadCategories()

assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID)
assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID)
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)

affectedCount, err = RebuildBatched(db, clothing, true, batchSize)
assert.NoError(t, err)
assert.Equal(t, 0, affectedCount)
reloadCategories()

assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID)
assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID)
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)

hat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Hat",
"ParentID": sql.NullInt64{Valid: false},
}).(*Category)

affectedCount, err = RebuildBatched(db, clothing, false, batchSize)
assert.NoError(t, err)
assert.Equal(t, 1, affectedCount)

affectedCount, err = RebuildBatched(db, clothing, true, batchSize)
assert.NoError(t, err)
assert.Equal(t, 1, affectedCount)
reloadCategories()
hat, _ = findNode(db, hat.ID)

assertNodeEqual(t, clothing, 1, 22, 0, 2, 0)
assertNodeEqual(t, mens, 2, 9, 1, 1, clothing.ID)
assertNodeEqual(t, suits, 3, 8, 2, 2, mens.ID)
assertNodeEqual(t, slacks, 4, 5, 3, 0, suits.ID)
assertNodeEqual(t, jackets, 6, 7, 3, 0, suits.ID)
assertNodeEqual(t, womens, 10, 21, 1, 3, clothing.ID)
assertNodeEqual(t, dresses, 11, 16, 2, 2, womens.ID)
assertNodeEqual(t, eveningGowns, 14, 15, 3, 0, dresses.ID)
assertNodeEqual(t, sunDresses, 12, 13, 3, 0, dresses.ID)
assertNodeEqual(t, skirts, 17, 18, 2, 0, womens.ID)
assertNodeEqual(t, blouses, 19, 20, 2, 0, womens.ID)
assertNodeEqual(t, hat, 23, 24, 0, 0, 0)

jacksClothing := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Jack's Clothing",
"ParentID": sql.NullInt64{Valid: false},
"UserType": "User",
"UserID": 8686,
}).(*Category)
jacksSuits := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Jack's Suits",
"ParentID": sql.NullInt64{Valid: true, Int64: jacksClothing.ID},
"UserType": "User",
"UserID": 8686,
}).(*Category)
jacksHat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Jack's Hat",
"UserType": "User",
"UserID": 8686,
"ParentID": sql.NullInt64{Valid: false},
}).(*Category)
jacksSlacks := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Jack's Slacks",
"ParentID": sql.NullInt64{Valid: true, Int64: jacksClothing.ID},
"UserType": "User",
"UserID": 8686,
}).(*Category)

lilysHat := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Lily's Hat",
"UserType": "User",
"UserID": 6666,
"ParentID": sql.NullInt64{Valid: false},
}).(*Category)
lilysClothing := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Lily's Clothing",
"ParentID": sql.NullInt64{Valid: false},
"UserType": "User",
"UserID": 6666,
}).(*Category)
lilysDresses := *CategoryFactory.MustCreateWithOption(map[string]interface{}{
"Title": "Lily's Dresses",
"ParentID": sql.NullInt64{Valid: true, Int64: lilysClothing.ID},
"UserType": "User",
"UserID": 6666,
}).(*Category)

affectedCount, err = RebuildBatched(db, jacksSuits, true, batchSize)
assert.NoError(t, err)
assert.Equal(t, 4, affectedCount)
affectedCount, err = RebuildBatched(db, lilysHat, true, batchSize)
assert.NoError(t, err)
assert.Equal(t, 3, affectedCount)
jacksClothing, _ = findNode(db, jacksClothing.ID)
jacksSuits, _ = findNode(db, jacksSuits.ID)
jacksSlacks, _ = findNode(db, jacksSlacks.ID)
jacksHat, _ = findNode(db, jacksHat.ID)
lilysHat, _ = findNode(db, lilysHat.ID)
lilysClothing, _ = findNode(db, lilysClothing.ID)
lilysDresses, _ = findNode(db, lilysDresses.ID)

assertNodeEqual(t, jacksClothing, 1, 6, 0, 2, 0)
assertNodeEqual(t, jacksSuits, 2, 3, 1, 0, jacksClothing.ID)
assertNodeEqual(t, jacksSlacks, 4, 5, 1, 0, jacksClothing.ID)
assertNodeEqual(t, jacksHat, 7, 8, 0, 0, 0)
assertNodeEqual(t, lilysHat, 1, 2, 0, 0, 0)
assertNodeEqual(t, lilysClothing, 3, 6, 0, 1, 0)
assertNodeEqual(t, lilysDresses, 4, 5, 1, 0, lilysClothing.ID)
}

func TestMoveToLeft(t *testing.T) {
// case 1
initData()
Expand Down