Skip to content

Commit a1d8557

Browse files
refactor: add __iter__/__reversed__ to Diagram; simplify delete/drop
Diagram now supports Python iteration protocol, yielding FreeTable objects in topological order. Table.delete() and Table.drop() use reversed(diagram) instead of manual topo_sort loops. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 793d0b3 commit a1d8557

File tree

2 files changed

+48
-32
lines changed

2 files changed

+48
-32
lines changed

src/datajoint/diagram.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -621,14 +621,36 @@ def preview(self):
621621
raise DataJointError("No restrictions applied. " "Call cascade() or restrict() first.")
622622

623623
result = {}
624+
for ft in self:
625+
if ft.full_table_name in restrictions:
626+
count = len(ft)
627+
result[ft.full_table_name] = count
628+
logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=count))
629+
return result
630+
631+
def __iter__(self):
632+
"""
633+
Iterate over non-alias nodes in topological order (parents first).
634+
635+
Yields restricted ``FreeTable`` objects when cascade or restrict
636+
conditions have been applied, unrestricted ``FreeTable`` otherwise.
637+
638+
Alias nodes (used internally for multi-FK edges) are skipped.
639+
"""
624640
for node in topo_sort(self):
625-
if node.isdigit() or node not in restrictions:
626-
continue
627-
result[node] = len(self._restricted_table(node))
641+
if not node.isdigit() and node in self.nodes_to_show:
642+
yield self._restricted_table(node)
628643

629-
for t, count in result.items():
630-
logger.info("{table} ({count} tuples)".format(table=t, count=count))
631-
return result
644+
def __reversed__(self):
645+
"""
646+
Iterate in reverse topological order (leaves first).
647+
648+
Same as ``__iter__`` but reversed — useful for cascading
649+
deletes and drops.
650+
"""
651+
for node in reversed(topo_sort(self)):
652+
if not node.isdigit() and node in self.nodes_to_show:
653+
yield self._restricted_table(node)
632654

633655
def prune(self):
634656
"""

src/datajoint/table.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from .condition import make_condition
1616
from .declare import alter, declare
17-
from .dependencies import extract_master, topo_sort
17+
from .dependencies import extract_master
1818
from .errors import (
1919
AccessError,
2020
DataJointError,
@@ -1019,15 +1019,10 @@ def delete(
10191019
conn = self.connection
10201020
prompt = conn._config["safemode"] if prompt is None else prompt
10211021

1022-
# Get non-alias nodes in topological order (graph is already trimmed by cascade())
1023-
all_sorted = topo_sort(diagram)
1024-
tables = [t for t in all_sorted if not t.isdigit()]
1025-
10261022
# Preview
10271023
if prompt:
1028-
for t in tables:
1029-
ft = diagram._restricted_table(t)
1030-
logger.info("{table} ({count} tuples)".format(table=t, count=len(ft)))
1024+
for ft in diagram:
1025+
logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=len(ft)))
10311026

10321027
# Start transaction
10331028
if transaction:
@@ -1047,13 +1042,12 @@ def delete(
10471042
root_count = 0
10481043
deleted_tables = set()
10491044
try:
1050-
for table_name in reversed(tables):
1051-
ft = diagram._restricted_table(table_name)
1045+
for ft in reversed(diagram):
10521046
count = ft.delete_quick(get_count=True)
10531047
if count > 0:
1054-
deleted_tables.add(table_name)
1055-
logger.info("Deleting {count} rows from {table}".format(count=count, table=table_name))
1056-
if table_name == tables[0]:
1048+
deleted_tables.add(ft.full_table_name)
1049+
logger.info("Deleting {count} rows from {table}".format(count=count, table=ft.full_table_name))
1050+
if ft.full_table_name == self.full_table_name:
10571051
root_count = count
10581052
except IntegrityError as error:
10591053
if transaction:
@@ -1175,34 +1169,34 @@ def drop(self, prompt: bool | None = None, part_integrity: str = "enforce", dry_
11751169
conn = self.connection
11761170
prompt = conn._config["safemode"] if prompt is None else prompt
11771171

1178-
tables = [t for t in topo_sort(diagram) if not t.isdigit() and t in diagram.nodes_to_show]
1172+
table_names = [ft.full_table_name for ft in diagram]
11791173

11801174
if part_integrity == "enforce":
1181-
for part in tables:
1182-
master = extract_master(part)
1183-
if master and master not in tables:
1175+
for name in table_names:
1176+
master = extract_master(name)
1177+
if master and master not in table_names:
11841178
raise DataJointError(
11851179
"Attempt to drop part table {part} before its " "master {master}. Drop the master first.".format(
1186-
part=part, master=master
1180+
part=name, master=master
11871181
)
11881182
)
11891183

11901184
if dry_run:
11911185
result = {}
1192-
for t in tables:
1193-
count = len(FreeTable(conn, t))
1194-
result[t] = count
1195-
logger.info("{table} ({count} tuples)".format(table=t, count=count))
1186+
for ft in diagram:
1187+
count = len(ft)
1188+
result[ft.full_table_name] = count
1189+
logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=count))
11961190
return result
11971191

11981192
do_drop = True
11991193
if prompt:
1200-
for t in tables:
1201-
logger.info("{table} ({count} tuples)".format(table=t, count=len(FreeTable(conn, t))))
1194+
for ft in diagram:
1195+
logger.info("{table} ({count} tuples)".format(table=ft.full_table_name, count=len(ft)))
12021196
do_drop = user_choice("Proceed?", default="no") == "yes"
12031197
if do_drop:
1204-
for t in reversed(tables):
1205-
FreeTable(conn, t).drop_quick()
1198+
for ft in reversed(diagram):
1199+
ft.drop_quick()
12061200
logger.info("Tables dropped. Restart kernel.")
12071201

12081202
def describe(self, context=None, printout=False):

0 commit comments

Comments
 (0)