@@ -599,6 +599,27 @@ def model_patches_models(self):
599599
600600 return models
601601
602+ def model_patches_call_function (self , function_name = "cleanup" , arguments = {}):
603+ to = self .model_options ["transformer_options" ]
604+ if "patches" in to :
605+ patches = to ["patches" ]
606+ for name in patches :
607+ patch_list = patches [name ]
608+ for i in range (len (patch_list )):
609+ if hasattr (patch_list [i ], function_name ):
610+ getattr (patch_list [i ], function_name )(** arguments )
611+ if "patches_replace" in to :
612+ patches = to ["patches_replace" ]
613+ for name in patches :
614+ patch_list = patches [name ]
615+ for k in patch_list :
616+ if hasattr (patch_list [k ], function_name ):
617+ getattr (patch_list [k ], function_name )(** arguments )
618+ if "model_function_wrapper" in self .model_options :
619+ wrap_func = self .model_options ["model_function_wrapper" ]
620+ if hasattr (wrap_func , function_name ):
621+ getattr (wrap_func , function_name )(** arguments )
622+
602623 def model_dtype (self ):
603624 if hasattr (self .model , "get_dtype" ):
604625 return self .model .get_dtype ()
@@ -1062,6 +1083,7 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3
10621083 return comfy .lora .calculate_weight (patches , weight , key , intermediate_dtype = intermediate_dtype )
10631084
10641085 def cleanup (self ):
1086+ self .model_patches_call_function (function_name = "cleanup" )
10651087 self .clean_hooks ()
10661088 if hasattr (self .model , "current_patcher" ):
10671089 self .model .current_patcher = None
0 commit comments