@@ -382,26 +382,22 @@ def cascade(self, table_expr, part_integrity="enforce"):
382382 result ._propagate_restrictions (node , mode = "cascade" , part_integrity = part_integrity )
383383 return result
384384
385- @staticmethod
386- def _restrict_freetable (ft , restrictions , mode = "cascade" ):
385+ def _restricted_table (self , node ):
387386 """
388- Apply cascade/restrict restrictions to a FreeTable.
389-
390- Uses ``restrict()`` to properly convert each restriction (AndList,
391- QueryExpression, etc.) into SQL via ``make_condition``, rather than
392- assigning raw objects to ``_restriction`` which would produce
393- invalid SQL in ``where_clause``.
394-
395- For cascade mode (delete), restrictions from different parent edges
396- are OR-ed: a row is deleted if ANY of its FK references point to a
397- deleted row.
387+ Return a FreeTable for ``node`` with this diagram's restrictions applied.
398388
399- For restrict mode (export), restrictions are AND-ed: a row is
400- included only if ALL ancestor conditions are satisfied.
389+ Cascade restrictions are OR-combined (a row is affected if ANY
390+ FK reference points to a deleted row). Restrict conditions are
391+ AND-combined (a row is included only when ALL ancestor conditions
392+ are satisfied).
401393 """
394+ from .table import FreeTable
395+
396+ ft = FreeTable (self ._connection , node )
397+ restrictions = (self ._cascade_restrictions or self ._restrict_conditions ).get (node , [])
402398 if not restrictions :
403399 return ft
404- if mode == "cascade" :
400+ if self . _cascade_restrictions :
405401 # OR semantics — passing a list to restrict() creates an OrList
406402 return ft .restrict (restrictions )
407403 else :
@@ -473,10 +469,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
473469 continue
474470
475471 # Build parent FreeTable with current restriction
476- parent_ft = FreeTable (self ._connection , node )
477- restr = restrictions [node ]
478- if restr :
479- parent_ft = self ._restrict_freetable (parent_ft , restr , mode = mode )
472+ parent_ft = self ._restricted_table (node )
480473
481474 parent_attrs = self ._restriction_attrs .get (node , set ())
482475
@@ -531,10 +524,7 @@ def _propagate_restrictions(self, start_node, mode, part_integrity="enforce"):
531524 and master_name not in visited_masters
532525 ):
533526 visited_masters .add (master_name )
534- child_ft = FreeTable (self ._connection , target )
535- child_restr = restrictions .get (target , [])
536- if child_restr :
537- child_ft = self ._restrict_freetable (child_ft , child_restr , mode = mode )
527+ child_ft = self ._restricted_table (target )
538528 master_ft = FreeTable (self ._connection , master_name )
539529 from .condition import make_condition
540530
@@ -625,8 +615,6 @@ def delete(self, transaction=True, prompt=None, dry_run=False):
625615 Number of rows deleted from the root table, or (if ``dry_run``)
626616 a mapping of full table name to affected row count.
627617 """
628- from .table import FreeTable
629-
630618 if dry_run :
631619 return self .preview ()
632620
@@ -644,8 +632,7 @@ def delete(self, transaction=True, prompt=None, dry_run=False):
644632 # Preview
645633 if prompt :
646634 for t in tables :
647- ft = FreeTable (conn , t )
648- ft = self ._restrict_freetable (ft , self ._cascade_restrictions [t ])
635+ ft = self ._restricted_table (t )
649636 logger .info ("{table} ({count} tuples)" .format (table = t , count = len (ft )))
650637
651638 # Start transaction
@@ -667,8 +654,7 @@ def delete(self, transaction=True, prompt=None, dry_run=False):
667654 deleted_tables = set ()
668655 try :
669656 for table_name in reversed (tables ):
670- ft = FreeTable (conn , table_name )
671- ft = self ._restrict_freetable (ft , self ._cascade_restrictions [table_name ])
657+ ft = self ._restricted_table (table_name )
672658 count = ft .delete_quick (get_count = True )
673659 if count > 0 :
674660 deleted_tables .add (table_name )
@@ -789,20 +775,15 @@ def preview(self):
789775 dict[str, int]
790776 Mapping of full table name to affected row count.
791777 """
792- from .table import FreeTable
793-
794778 restrictions = self ._cascade_restrictions or self ._restrict_conditions
795- mode = "cascade" if self ._cascade_restrictions else "restrict"
796779 if not restrictions :
797780 raise DataJointError ("No restrictions applied. " "Call cascade() or restrict() first." )
798781
799782 result = {}
800783 for node in topo_sort (self ):
801784 if node .isdigit () or node not in restrictions :
802785 continue
803- ft = FreeTable (self ._connection , node )
804- ft = self ._restrict_freetable (ft , restrictions [node ], mode = mode )
805- result [node ] = len (ft )
786+ result [node ] = len (self ._restricted_table (node ))
806787
807788 for t , count in result .items ():
808789 logger .info ("{table} ({count} tuples)" .format (table = t , count = count ))
@@ -825,16 +806,13 @@ def prune(self):
825806
826807 result = Diagram (self )
827808 restrictions = result ._cascade_restrictions or result ._restrict_conditions
828- mode = "cascade" if result ._cascade_restrictions else "restrict"
829809
830810 if restrictions :
831811 # Restricted: check row counts under restriction
832812 for node in list (restrictions ):
833813 if node .isdigit ():
834814 continue
835- ft = FreeTable (self ._connection , node )
836- ft = self ._restrict_freetable (ft , restrictions [node ], mode = mode )
837- if len (ft ) == 0 :
815+ if len (result ._restricted_table (node )) == 0 :
838816 restrictions .pop (node )
839817 result ._restriction_attrs .pop (node , None )
840818 result .nodes_to_show .discard (node )
0 commit comments