@@ -1435,10 +1435,6 @@ def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weig
14351435
14361436 def __init__ (self , model , load_device , offload_device , size = 0 , weight_inplace_update = False ):
14371437 super ().__init__ (model , load_device , offload_device , size , weight_inplace_update )
1438- #this is now way more dynamic and we dont support the same base model for both Dynamic
1439- #and non-dynamic patchers.
1440- if hasattr (self .model , "model_loaded_weight_memory" ):
1441- del self .model .model_loaded_weight_memory
14421438 if not hasattr (self .model , "dynamic_vbars" ):
14431439 self .model .dynamic_vbars = {}
14441440 self .non_dynamic_delegate_model = None
@@ -1461,9 +1457,7 @@ def _vbar_get(self, create=False):
14611457
14621458 def loaded_size (self ):
14631459 vbar = self ._vbar_get ()
1464- if vbar is None :
1465- return 0
1466- return vbar .loaded_size ()
1460+ return (vbar .loaded_size () if vbar is not None else 0 ) + self .model .model_loaded_weight_memory
14671461
14681462 def get_free_memory (self , device ):
14691463 #NOTE: on high condition / batch counts, estimate should have already vacated
@@ -1504,6 +1498,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
15041498
15051499 num_patches = 0
15061500 allocated_size = 0
1501+ self .model .model_loaded_weight_memory = 0
15071502
15081503 with self .use_ejected ():
15091504 self .unpatch_hooks ()
@@ -1512,10 +1507,6 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
15121507 if vbar is not None :
15131508 vbar .prioritize ()
15141509
1515- #We force reserve VRAM for the non comfy-weight so we dont have to deal
1516- #with pin and unpin syncrhonization which can be expensive for small weights
1517- #with a high layer rate (e.g. autoregressive LLMs).
1518- #prioritize the non-comfy weights (note the order reverse).
15191510 loading = self ._load_list (prio_comfy_cast_weights = True , default_device = device_to )
15201511 loading .sort (reverse = True )
15211512
@@ -1558,6 +1549,9 @@ def force_load_param(self, param_key, device_to):
15581549 if key in self .backup :
15591550 comfy .utils .set_attr_param (self .model , key , self .backup [key ].weight )
15601551 self .patch_weight_to_device (key , device_to = device_to )
1552+ weight , _ , _ = get_key_weight (self .model , key )
1553+ if weight is not None :
1554+ self .model .model_loaded_weight_memory += weight .numel () * weight .element_size ()
15611555
15621556 if hasattr (m , "comfy_cast_weights" ):
15631557 m .comfy_cast_weights = True
@@ -1583,21 +1577,15 @@ def force_load_param(self, param_key, device_to):
15831577 for param in params :
15841578 key = key_param_name_to_key (n , param )
15851579 weight , _ , _ = get_key_weight (self .model , key )
1586- weight .seed_key = key
1587- set_dirty (weight , dirty )
1588- geometry = weight
1589- model_dtype = getattr (m , param + "_comfy_model_dtype" , None ) or weight .dtype
1590- geometry = comfy .memory_management .TensorGeometry (shape = weight .shape , dtype = model_dtype )
1591- weight_size = geometry .numel () * geometry .element_size ()
1592- if vbar is not None and not hasattr (weight , "_v" ):
1593- weight ._v = vbar .alloc (weight_size )
1594- weight ._model_dtype = model_dtype
1595- allocated_size += weight_size
1596- vbar .set_watermark_limit (allocated_size )
1580+ if key not in self .backup :
1581+ self .backup [key ] = collections .namedtuple ('Dimension' , ['weight' , 'inplace_update' ])(weight , False )
1582+ comfy .utils .set_attr_param (self .model , key , weight .to (device = device_to ))
1583+ self .model .model_loaded_weight_memory += weight .numel () * weight .element_size ()
15971584
15981585 move_weight_functions (m , device_to )
15991586
1600- logging .info (f"Model { self .model .__class__ .__name__ } prepared for dynamic VRAM loading. { allocated_size // (1024 ** 2 )} MB Staged. { num_patches } patches attached." )
1587+ force_load_stat = f" Force pre-loaded { len (self .backup )} weights: { self .model .model_loaded_weight_memory // 1024 } KB." if len (self .backup ) > 0 else ""
1588+ logging .info (f"Model { self .model .__class__ .__name__ } prepared for dynamic VRAM loading. { allocated_size // (1024 ** 2 )} MB Staged. { num_patches } patches attached.{ force_load_stat } " )
16011589
16021590 self .model .device = device_to
16031591 self .model .current_weight_patches_uuid = self .patches_uuid
@@ -1613,7 +1601,16 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
16131601 assert self .load_device != torch .device ("cpu" )
16141602
16151603 vbar = self ._vbar_get ()
1616- return 0 if vbar is None else vbar .free_memory (memory_to_free )
1604+ freed = 0 if vbar is None else vbar .free_memory (memory_to_free )
1605+
1606+ if freed < memory_to_free :
1607+ for key in list (self .backup .keys ()):
1608+ bk = self .backup .pop (key )
1609+ comfy .utils .set_attr_param (self .model , key , bk .weight )
1610+ freed += self .model .model_loaded_weight_memory
1611+ self .model .model_loaded_weight_memory = 0
1612+
1613+ return freed
16171614
16181615 def partially_unload_ram (self , ram_to_unload ):
16191616 loading = self ._load_list (prio_comfy_cast_weights = True , default_device = self .offload_device )
@@ -1640,11 +1637,6 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
16401637 for m in self .model .modules ():
16411638 move_weight_functions (m , device_to )
16421639
1643- keys = list (self .backup .keys ())
1644- for k in keys :
1645- bk = self .backup [k ]
1646- comfy .utils .set_attr_param (self .model , k , bk .weight )
1647-
16481640 def partially_load (self , device_to , extra_memory = 0 , force_patch_weights = False ):
16491641 assert not force_patch_weights #See above
16501642 with self .use_ejected (skip_and_inject_on_exit_only = True ):
0 commit comments