Skip to content

Commit 09bcbdd

Browse files
authored
ModelPatcherDynamic: Force load all non-comfy weights (Comfy-Org#12739)
* model_management: Remove the non-comfy dynamic `_v` caster. * Force pre-load of non-comfy weights to the GPU in ModelPatcherDynamic. Non-comfy weights may expect to be pre-cast to the target device without in-model casting. Previously they were allocated in the vbar with `_v`, which required the `_v` fault path in `cast_to`. Instead, back up the original CPU weight and move it directly to the GPU at load time.
1 parent dff0a4a commit 09bcbdd

File tree

2 files changed

+21
-69
lines changed

2 files changed

+21
-69
lines changed

comfy/model_management.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@
3232
import comfy.utils
3333
import comfy.quant_ops
3434

35-
import comfy_aimdo.torch
36-
import comfy_aimdo.model_vbar
37-
3835
class VRAMState(Enum):
3936
DISABLED = 0 #No vram present: no need to move models to vram
4037
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -1206,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
12061203

12071204

12081205
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
1209-
if hasattr(weight, "_v"):
1210-
#Unexpected usage patterns. There is no reason these don't work but they
1211-
#have no testing and no callers do this.
1212-
assert r is None
1213-
assert stream is None
1214-
1215-
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
1216-
1217-
if dtype is None:
1218-
dtype = weight._model_dtype
1219-
1220-
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
1221-
if signature is not None:
1222-
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
1223-
v_tensor = weight._v_tensor
1224-
else:
1225-
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
1226-
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
1227-
weight._v_tensor = v_tensor
1228-
weight._v_signature = signature
1229-
#Send it over
1230-
v_tensor.copy_(weight, non_blocking=non_blocking)
1231-
return v_tensor.to(dtype=dtype)
1232-
1233-
r = torch.empty_like(weight, dtype=dtype, device=device)
1234-
1235-
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
1236-
#Offloaded casting could skip this, however it would make the quantizations
1237-
#inconsistent between loaded and offloaded weights. So force the double casting
1238-
#that would happen in regular flow to make offload deterministic.
1239-
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
1240-
cast_buffer.copy_(weight, non_blocking=non_blocking)
1241-
weight = cast_buffer
1242-
r.copy_(weight, non_blocking=non_blocking)
1243-
1244-
return r
1245-
12461206
if device is None or weight.device == device:
12471207
if not copy:
12481208
if dtype is None or weight.dtype == dtype:

comfy/model_patcher.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,10 +1435,6 @@ def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weig
14351435

14361436
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
14371437
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
1438-
#this is now way more dynamic and we don't support the same base model for both Dynamic
1439-
#and non-dynamic patchers.
1440-
if hasattr(self.model, "model_loaded_weight_memory"):
1441-
del self.model.model_loaded_weight_memory
14421438
if not hasattr(self.model, "dynamic_vbars"):
14431439
self.model.dynamic_vbars = {}
14441440
self.non_dynamic_delegate_model = None
@@ -1461,9 +1457,7 @@ def _vbar_get(self, create=False):
14611457

14621458
def loaded_size(self):
14631459
vbar = self._vbar_get()
1464-
if vbar is None:
1465-
return 0
1466-
return vbar.loaded_size()
1460+
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
14671461

14681462
def get_free_memory(self, device):
14691463
#NOTE: on high condition / batch counts, estimate should have already vacated
@@ -1504,6 +1498,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
15041498

15051499
num_patches = 0
15061500
allocated_size = 0
1501+
self.model.model_loaded_weight_memory = 0
15071502

15081503
with self.use_ejected():
15091504
self.unpatch_hooks()
@@ -1512,10 +1507,6 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
15121507
if vbar is not None:
15131508
vbar.prioritize()
15141509

1515-
#We force reserve VRAM for the non-comfy weight so we don't have to deal
1516-
#with pin and unpin synchronization which can be expensive for small weights
1517-
#with a high layer rate (e.g. autoregressive LLMs).
1518-
#prioritize the non-comfy weights (note the order reverse).
15191510
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
15201511
loading.sort(reverse=True)
15211512

@@ -1558,6 +1549,9 @@ def force_load_param(self, param_key, device_to):
15581549
if key in self.backup:
15591550
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
15601551
self.patch_weight_to_device(key, device_to=device_to)
1552+
weight, _, _ = get_key_weight(self.model, key)
1553+
if weight is not None:
1554+
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
15611555

15621556
if hasattr(m, "comfy_cast_weights"):
15631557
m.comfy_cast_weights = True
@@ -1583,21 +1577,15 @@ def force_load_param(self, param_key, device_to):
15831577
for param in params:
15841578
key = key_param_name_to_key(n, param)
15851579
weight, _, _ = get_key_weight(self.model, key)
1586-
weight.seed_key = key
1587-
set_dirty(weight, dirty)
1588-
geometry = weight
1589-
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
1590-
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
1591-
weight_size = geometry.numel() * geometry.element_size()
1592-
if vbar is not None and not hasattr(weight, "_v"):
1593-
weight._v = vbar.alloc(weight_size)
1594-
weight._model_dtype = model_dtype
1595-
allocated_size += weight_size
1596-
vbar.set_watermark_limit(allocated_size)
1580+
if key not in self.backup:
1581+
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
1582+
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
1583+
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
15971584

15981585
move_weight_functions(m, device_to)
15991586

1600-
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
1587+
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
1588+
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
16011589

16021590
self.model.device = device_to
16031591
self.model.current_weight_patches_uuid = self.patches_uuid
@@ -1613,7 +1601,16 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
16131601
assert self.load_device != torch.device("cpu")
16141602

16151603
vbar = self._vbar_get()
1616-
return 0 if vbar is None else vbar.free_memory(memory_to_free)
1604+
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
1605+
1606+
if freed < memory_to_free:
1607+
for key in list(self.backup.keys()):
1608+
bk = self.backup.pop(key)
1609+
comfy.utils.set_attr_param(self.model, key, bk.weight)
1610+
freed += self.model.model_loaded_weight_memory
1611+
self.model.model_loaded_weight_memory = 0
1612+
1613+
return freed
16171614

16181615
def partially_unload_ram(self, ram_to_unload):
16191616
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
@@ -1640,11 +1637,6 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
16401637
for m in self.model.modules():
16411638
move_weight_functions(m, device_to)
16421639

1643-
keys = list(self.backup.keys())
1644-
for k in keys:
1645-
bk = self.backup[k]
1646-
comfy.utils.set_attr_param(self.model, k, bk.weight)
1647-
16481640
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
16491641
assert not force_patch_weights #See above
16501642
with self.use_ejected(skip_and_inject_on_exit_only=True):

0 commit comments

Comments
 (0)