Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test(session, sqlalchemy):
session.install(f"sqlalchemy~={sqlalchemy}.0")
session.install("-e", ".")
pytest_args = session.posargs or ["--pyargs", "sqlalchemy_mptt"]
session.run("pytest", *pytest_args, env={"SQLALCHEMY_SILENCE_UBER_WARNING": "1"})
session.run("pytest", *pytest_args, env={"SQLALCHEMY_WARN_20": "1"})


@nox.session(default=False)
Expand Down
216 changes: 84 additions & 132 deletions sqlalchemy_mptt/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# vim:fenc=utf-8
#
# Copyright © 2014 uralbash <root@uralbash.ru>
# Copyright (c) 2025 Fayaz Yusuf Khan <fayaz.yusuf.khan@gmail.com>
#
# Distributed under terms of the MIT license.

Expand All @@ -13,11 +14,13 @@
import weakref

# SQLAlchemy
from sqlalchemy import and_, case, event, select, inspection
from sqlalchemy import and_, event, inspection
from sqlalchemy.orm import object_session
from sqlalchemy.sql import func
from sqlalchemy.orm.base import NO_VALUE

from sqlalchemy_mptt.sqlalchemy_compat import compat_layer


def _insert_subtree(
table,
Expand All @@ -41,9 +44,9 @@ def _insert_subtree(
delta_rgt = delta_lft + node_size - 1

connection.execute(
table.update(
table_pk.in_(subtree)
).values(
table.update()
.where(table_pk.in_(subtree))
.values(
lft=table.c.lft - node_pos_left + delta_lft,
rgt=table.c.rgt - node_pos_right + delta_rgt,
level=table.c.level - node_level + parent_level + 1,
Expand All @@ -53,21 +56,14 @@ def _insert_subtree(

# step 2: update key of right side
connection.execute(
table.update(
and_(
table.c.rgt > delta_lft - 1,
table_pk.notin_(subtree),
table.c.tree_id == parent_tree_id
)
).values(
table.update()
.where(table.c.rgt > delta_lft - 1)
.where(table_pk.notin_(subtree))
.where(table.c.tree_id == parent_tree_id)
.values(
rgt=table.c.rgt + node_size,
lft=case(
[
(
table.c.lft > left_sibling['lft'],
table.c.lft + node_size
)
],
lft=compat_layer.case(
(table.c.lft > left_sibling['lft'], table.c.lft + node_size),
else_=table.c.lft
)
)
Expand All @@ -93,10 +89,8 @@ def mptt_before_insert(mapper, connection, instance):
instance.right = 2
instance.level = instance.get_default_level()
tree_id = connection.scalar(
select(
[
func.max(table.c.tree_id) + 1
]
compat_layer.select(
func.max(table.c.tree_id) + 1
)
) or 1
instance.tree_id = tree_id
Expand All @@ -105,40 +99,28 @@ def mptt_before_insert(mapper, connection, instance):
parent_pos_right,
parent_tree_id,
parent_level) = connection.execute(
select(
[
table.c.lft,
table.c.rgt,
table.c.tree_id,
table.c.level
]
compat_layer.select(
table.c.lft,
table.c.rgt,
table.c.tree_id,
table.c.level
).where(
table_pk == instance.parent_id
)
).fetchone()

# Update key of right side
connection.execute(
table.update(
and_(table.c.rgt >= parent_pos_right,
table.c.tree_id == parent_tree_id)
).values(
lft=case(
[
(
table.c.lft > parent_pos_right,
table.c.lft + 2
)
],
table.update()
.where(table.c.rgt >= parent_pos_right)
.where(table.c.tree_id == parent_tree_id)
.values(
lft=compat_layer.case(
(table.c.lft > parent_pos_right, table.c.lft + 2),
else_=table.c.lft
),
rgt=case(
[
(
table.c.rgt >= parent_pos_right,
table.c.rgt + 2
)
],
rgt=compat_layer.case(
(table.c.rgt >= parent_pos_right, table.c.rgt + 2),
else_=table.c.rgt
)
)
Expand All @@ -157,11 +139,9 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
db_pk = instance.get_pk_column()
table_pk = getattr(table.c, db_pk.name)
lft, rgt = connection.execute(
select(
[
table.c.lft,
table.c.rgt
]
compat_layer.select(
table.c.lft,
table.c.rgt
).where(
table_pk == pk
)
Expand All @@ -171,7 +151,7 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
if delete:
mapper.base_mapper.confirm_deleted_rows = False
connection.execute(
table.delete(
table.delete().where(
table_pk == pk
)
)
Expand All @@ -190,28 +170,16 @@ def mptt_before_delete(mapper, connection, instance, delete=True):
END
"""
connection.execute(
table.update(
and_(
table.c.rgt > rgt,
table.c.tree_id == tree_id
)
).values(
lft=case(
[
(
table.c.lft > lft,
table.c.lft - delta
)
],
table.update()
.where(table.c.rgt > rgt)
.where(table.c.tree_id == tree_id)
.values(
lft=compat_layer.case(
(table.c.lft > lft, table.c.lft - delta),
else_=table.c.lft
),
rgt=case(
[
(
table.c.rgt >= rgt,
table.c.rgt - delta
)
],
rgt=compat_layer.case(
(table.c.rgt >= rgt, table.c.rgt - delta),
else_=table.c.rgt
)
)
Expand Down Expand Up @@ -242,26 +210,22 @@ def mptt_before_update(mapper, connection, instance):
right_sibling_level,
right_sibling_tree_id
) = connection.execute(
select(
[
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.level,
table.c.tree_id
]
compat_layer.select(
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.level,
table.c.tree_id
).where(
table_pk == instance.mptt_move_before
)
).fetchone()
current_lvl_nodes = connection.execute(
select(
[
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.tree_id
]
compat_layer.select(
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.tree_id
).where(
and_(
table.c.level == right_sibling_level,
Expand Down Expand Up @@ -295,13 +259,11 @@ def mptt_before_update(mapper, connection, instance):
left_sibling_parent,
left_sibling_tree_id
) = connection.execute(
select(
[
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.tree_id
]
compat_layer.select(
table.c.lft,
table.c.rgt,
table.c.parent_id,
table.c.tree_id
).where(
table_pk == instance.mptt_move_after
)
Expand All @@ -320,7 +282,7 @@ def mptt_before_update(mapper, connection, instance):
ORDER BY left_key
"""
subtree = connection.execute(
select([table_pk])
compat_layer.select(table_pk)
.where(
and_(
table.c.lft >= instance.left,
Expand All @@ -344,14 +306,12 @@ def mptt_before_update(mapper, connection, instance):
node_parent_id,
node_level
) = connection.execute(
select(
[
table.c.lft,
table.c.rgt,
table.c.tree_id,
table.c.parent_id,
table.c.level
]
compat_layer.select(
table.c.lft,
table.c.rgt,
table.c.tree_id,
table.c.parent_id,
table.c.level
).where(
table_pk == node_id
)
Expand All @@ -374,14 +334,12 @@ def mptt_before_update(mapper, connection, instance):
parent_tree_id,
parent_level
) = connection.execute(
select(
[
table_pk,
table.c.rgt,
table.c.lft,
table.c.tree_id,
table.c.level
]
compat_layer.select(
table_pk,
table.c.rgt,
table.c.lft,
table.c.tree_id,
table.c.level
).where(
table_pk == instance.parent_id
)
Expand All @@ -404,14 +362,12 @@ def mptt_before_update(mapper, connection, instance):
parent_tree_id,
parent_level
) = connection.execute(
select(
[
table_pk,
table.c.rgt,
table.c.lft,
table.c.tree_id,
table.c.level
]
compat_layer.select(
table_pk,
table.c.rgt,
table.c.lft,
table.c.tree_id,
table.c.level
).where(
table_pk == instance.parent_id
)
Expand Down Expand Up @@ -449,28 +405,24 @@ def mptt_before_update(mapper, connection, instance):
if left_sibling_tree_id or left_sibling_tree_id == 0:
tree_id = left_sibling_tree_id + 1
connection.execute(
table.update(
table.c.tree_id > left_sibling_tree_id
).values(
table.update()
.where(table.c.tree_id > left_sibling_tree_id)
.values(
tree_id=table.c.tree_id + 1
)
)
# if just insert
else:
tree_id = connection.scalar(
select(
[
func.max(table.c.tree_id) + 1
]
compat_layer.select(
func.max(table.c.tree_id) + 1
)
)

connection.execute(
table.update(
table_pk.in_(
subtree
)
).values(
table.update()
.where(table_pk.in_(subtree))
.values(
lft=table.c.lft - node_pos_left + 1,
rgt=table.c.rgt - node_pos_left + 1,
level=table.c.level - node_level + default_level,
Expand Down
Loading