Skip to content

Commit 3ad36d6

Browse files
Allow model patches to have a cleanup function. (Comfy-Org#12878)
The function gets called after sampling is finished.
1 parent 8086468 commit 3ad36d6

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

comfy/model_patcher.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,27 @@ def model_patches_models(self):
599599

600600
return models
601601

602+
def model_patches_call_function(self, function_name="cleanup", arguments={}):
603+
to = self.model_options["transformer_options"]
604+
if "patches" in to:
605+
patches = to["patches"]
606+
for name in patches:
607+
patch_list = patches[name]
608+
for i in range(len(patch_list)):
609+
if hasattr(patch_list[i], function_name):
610+
getattr(patch_list[i], function_name)(**arguments)
611+
if "patches_replace" in to:
612+
patches = to["patches_replace"]
613+
for name in patches:
614+
patch_list = patches[name]
615+
for k in patch_list:
616+
if hasattr(patch_list[k], function_name):
617+
getattr(patch_list[k], function_name)(**arguments)
618+
if "model_function_wrapper" in self.model_options:
619+
wrap_func = self.model_options["model_function_wrapper"]
620+
if hasattr(wrap_func, function_name):
621+
getattr(wrap_func, function_name)(**arguments)
622+
602623
def model_dtype(self):
603624
if hasattr(self.model, "get_dtype"):
604625
return self.model.get_dtype()
@@ -1062,6 +1083,7 @@ def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float3
10621083
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
10631084

10641085
def cleanup(self):
1086+
self.model_patches_call_function(function_name="cleanup")
10651087
self.clean_hooks()
10661088
if hasattr(self.model, "current_patcher"):
10671089
self.model.current_patcher = None

0 commit comments

Comments
 (0)