Skip to content

Commit 91bf61b

Browse files
refactor: replace _restrict_freetable with _restricted_table on Diagram
Move OR/AND convergence logic into a single Diagram method that returns a FreeTable with the diagram's restrictions already applied. Callers no longer need to know about modes or pass restriction lists explicitly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3a2fc59 commit 91bf61b

File tree

1 file changed

+17
-39
lines changed

1 file changed

+17
-39
lines changed

src/datajoint/diagram.py

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -382,26 +382,22 @@ def cascade(self, table_expr, part_integrity="enforce"):
382382
result._propagate_restrictions(node, mode="cascade", part_integrity=part_integrity)
383383
return result
384384

385-
@staticmethod
386-
def _restrict_freetable(ft, restrictions, mode="cascade"):
385+
def _restricted_table(self, node):
387386
"""
388-
Apply cascade/restrict restrictions to a FreeTable.
389-
390-
Uses ``restrict()`` to properly convert each restriction (AndList,
391-
QueryExpression, etc.) into SQL via ``make_condition``, rather than
392-
assigning raw objects to ``_restriction`` which would produce
393-
invalid SQL in ``where_clause``.
394-
395-
For cascade mode (delete), restrictions from different parent edges
396-
are OR-ed: a row is deleted if ANY of its FK references point to a
397-
deleted row.
387+
Return a FreeTable for ``node`` with this diagram's restrictions applied.
398388
399-
For restrict mode (export), restrictions are AND-ed: a row is
400-
included only if ALL ancestor conditions are satisfied.
389+
Cascade restrictions are OR-combined (a row is affected if ANY
390+
FK reference points to a deleted row). Restrict conditions are
391+
AND-combined (a row is included only when ALL ancestor conditions
392+
are satisfied).
401393
"""
394+
from .table import FreeTable
395+
396+
ft = FreeTable(self._connection, node)
397+
restrictions = (self._cascade_restrictions or self._restrict_conditions).get(node, [])
402398
if not restrictions:
403399
return ft
404-
if mode == "cascade":
400+
if self._cascade_restrictions:
405401
# OR semantics — passing a list to restrict() creates an OrList
406402
return ft.restrict(restrictions)
407403
else:
@@ -473,10 +469,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
473469
continue
474470

475471
# Build parent FreeTable with current restriction
476-
parent_ft = FreeTable(self._connection, node)
477-
restr = restrictions[node]
478-
if restr:
479-
parent_ft = self._restrict_freetable(parent_ft, restr, mode=mode)
472+
parent_ft = self._restricted_table(node)
480473

481474
parent_attrs = self._restriction_attrs.get(node, set())
482475

@@ -531,10 +524,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
531524
and master_name not in visited_masters
532525
):
533526
visited_masters.add(master_name)
534-
child_ft = FreeTable(self._connection, target)
535-
child_restr = restrictions.get(target, [])
536-
if child_restr:
537-
child_ft = self._restrict_freetable(child_ft, child_restr, mode=mode)
527+
child_ft = self._restricted_table(target)
538528
master_ft = FreeTable(self._connection, master_name)
539529
from .condition import make_condition
540530

@@ -625,8 +615,6 @@ def delete(self, transaction=True, prompt=None, dry_run=False):
625615
Number of rows deleted from the root table, or (if ``dry_run``)
626616
a mapping of full table name to affected row count.
627617
"""
628-
from .table import FreeTable
629-
630618
if dry_run:
631619
return self.preview()
632620

@@ -644,8 +632,7 @@ def delete(self, transaction=True, prompt=None, dry_run=False):
644632
# Preview
645633
if prompt:
646634
for t in tables:
647-
ft = FreeTable(conn, t)
648-
ft = self._restrict_freetable(ft, self._cascade_restrictions[t])
635+
ft = self._restricted_table(t)
649636
logger.info("{table} ({count} tuples)".format(table=t, count=len(ft)))
650637

651638
# Start transaction
@@ -667,8 +654,7 @@ def delete(self, transaction=True, prompt=None, dry_run=False):
667654
deleted_tables = set()
668655
try:
669656
for table_name in reversed(tables):
670-
ft = FreeTable(conn, table_name)
671-
ft = self._restrict_freetable(ft, self._cascade_restrictions[table_name])
657+
ft = self._restricted_table(table_name)
672658
count = ft.delete_quick(get_count=True)
673659
if count > 0:
674660
deleted_tables.add(table_name)
@@ -789,20 +775,15 @@ def preview(self):
789775
dict[str, int]
790776
Mapping of full table name to affected row count.
791777
"""
792-
from .table import FreeTable
793-
794778
restrictions = self._cascade_restrictions or self._restrict_conditions
795-
mode = "cascade" if self._cascade_restrictions else "restrict"
796779
if not restrictions:
797780
raise DataJointError("No restrictions applied. " "Call cascade() or restrict() first.")
798781

799782
result = {}
800783
for node in topo_sort(self):
801784
if node.isdigit() or node not in restrictions:
802785
continue
803-
ft = FreeTable(self._connection, node)
804-
ft = self._restrict_freetable(ft, restrictions[node], mode=mode)
805-
result[node] = len(ft)
786+
result[node] = len(self._restricted_table(node))
806787

807788
for t, count in result.items():
808789
logger.info("{table} ({count} tuples)".format(table=t, count=count))
@@ -825,16 +806,13 @@ def prune(self):
825806

826807
result = Diagram(self)
827808
restrictions = result._cascade_restrictions or result._restrict_conditions
828-
mode = "cascade" if result._cascade_restrictions else "restrict"
829809

830810
if restrictions:
831811
# Restricted: check row counts under restriction
832812
for node in list(restrictions):
833813
if node.isdigit():
834814
continue
835-
ft = FreeTable(self._connection, node)
836-
ft = self._restrict_freetable(ft, restrictions[node], mode=mode)
837-
if len(ft) == 0:
815+
if len(result._restricted_table(node)) == 0:
838816
restrictions.pop(node)
839817
result._restriction_attrs.pop(node, None)
840818
result.nodes_to_show.discard(node)

0 commit comments

Comments
 (0)