@@ -97,14 +97,34 @@ def collect_instance_variables(self, func_node: ast.FunctionDef) -> None:
9797 if self .class_name is None or self .only_function_name != "forward" :
9898 return
9999
100- for node in ast .walk (func_node ):
100+ class_name = self .class_name
101+ instance_vars = self .instance_variable_names
102+
103+ # Manually traverse only assignment nodes instead of walking entire tree
104+ nodes_to_check = list (func_node .body )
105+ while nodes_to_check :
106+ node = nodes_to_check .pop ()
107+
101108 # Look for assignments like: model = ClassName(...)
102109 if isinstance (node , ast .Assign ):
103- if isinstance (node .value , ast .Call ) and isinstance (node .value .func , ast .Name ):
104- if node .value .func .id == self .class_name :
110+ value = node .value
111+ if isinstance (value , ast .Call ):
112+ func = value .func
113+ if isinstance (func , ast .Name ) and func .id == class_name :
105114 for target in node .targets :
106115 if isinstance (target , ast .Name ):
107- self .instance_variable_names .add (target .id )
116+ instance_vars .add (target .id )
117+
118+ # Add nested statements to check
119+ if hasattr (node , 'body' ):
120+ nodes_to_check .extend (node .body )
121+ if hasattr (node , 'orelse' ):
122+ nodes_to_check .extend (node .orelse )
123+ if hasattr (node , 'finalbody' ):
124+ nodes_to_check .extend (node .finalbody )
125+ if hasattr (node , 'handlers' ):
126+ for handler in node .handlers :
127+ nodes_to_check .extend (handler .body )
108128
109129 def find_and_update_line_node (
110130 self , test_node : ast .stmt , node_name : str , index : str , test_class_name : str | None = None
0 commit comments