Skip to content

Commit 7810f49

Browse files
authored
comfy aimdo 0.2.11 + Improved RAM Pressure release strategies - Windows speedups (Comfy-Org#12925)
* Implement seek and read for pins. Sourcing pins from an mmap is bad because it is a CPU->CPU copy that attempts to fully buffer the same data twice. Instead, use seek and read, which avoids the mmap buffering while usually being a faster read in the first place (avoiding mmap faulting etc). * pinned_memory: Use the Aimdo pinner. The aimdo pinner bypasses PyTorch's CPU allocator, which can leak Windows commit charge. * ops: bypass init() of weight for embedding layer. This similarly consumes a large commit charge, especially for TEs. It can cause a permanently leaked commit charge which can destabilize systems close to the commit ceiling and generally confuses the RAM stats. * model_patcher: implement pinned memory counter. Implement a pinned memory counter for better accounting of what volume of memory the pins hold. * implement touch accounting. Implement accounting of touching mmapped tensors. * mm+mp: add residency mmap getter * utils: use the aimdo mmap to load sft files * model_management: Implement tighter RAM pressure semantics. Implement a pressure release on entire MMAPs, as Windows performs faster when mmaps are unloaded and model loads ramp into fully unallocated RAM. Make freeing for pins a completely separate concept. Now that pins are loadable directly from the original file and don't touch the mmap, tighten the freeing budget to just the currently loaded model minus what you have left over. This still over-frees pins, but it's a lot better than before. So after the pins are freed with that algorithm, bounce entire MMAPs to free RAM based on what the model needs, deducting any known resident-in-mmap tensors from the free quota to keep it as tight as possible. * comfy-aimdo 0.2.11 Comfy aimdo 0.2.11 * mm: Implement the file_slice path for QT * ruff * ops: put meta-tensors in place to allow custom nodes to check geometry
1 parent e1f10ca commit 7810f49

File tree

7 files changed

+258
-50
lines changed

7 files changed

+258
-50
lines changed

comfy/memory_management.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,68 @@
11
import math
2+
import ctypes
3+
import threading
4+
import dataclasses
25
import torch
36
from typing import NamedTuple
47

58
from comfy.quant_ops import QuantizedTensor
69

10+
11+
class TensorFileSlice(NamedTuple):
    """Location of a tensor's raw bytes within its source checkpoint file.

    Attached to a tensor's untyped storage as ``_comfy_tensor_file_slice`` so
    the bytes can later be re-read straight from disk instead of copied out of
    an mmap (see read_tensor_file_slice_into).
    """
    # Open file object the bytes can be re-read from (may be None, which
    # disables the fast read path).
    file_ref: object
    # threading.get_ident() of the thread allowed to use file_ref; reads are
    # refused from any other thread.
    thread_id: int
    # Byte offset of the tensor data within the file.
    offset: int
    # Length of the tensor data in bytes.
    size: int
17+
18+
def read_tensor_file_slice_into(tensor, destination):
    """Copy *tensor*'s bytes into *destination* by re-reading its source file.

    Fast path that avoids an mmap-mediated CPU->CPU copy: when the tensor's
    untyped storage carries a ``_comfy_tensor_file_slice`` record, the bytes
    are read straight from the original file into *destination*'s buffer.

    For QuantizedTensor pairs the quantized payload is read recursively and
    the quant params are copied over, preserving the destination's recorded
    ``orig_dtype``.

    Returns True when the destination was fully populated; False means the
    fast path does not apply (or failed) and the caller must fall back to a
    normal tensor copy. The function never raises for I/O failures.
    """
    if isinstance(tensor, QuantizedTensor):
        if not isinstance(destination, QuantizedTensor):
            return False
        if tensor._layout_cls != destination._layout_cls:
            return False

        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
            return False

        # Copy the quant params but keep the destination's own orig_dtype.
        dst_orig_dtype = destination._params.orig_dtype
        destination._params.copy_from(tensor._params, non_blocking=False)
        destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
        return True

    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    # Only usable for CPU destinations large enough to hold the slice, and
    # only from the thread that owns the file object (the file position is
    # shared state).
    if (destination.device.type != "cpu"
        or file_obj is None
        or threading.get_ident() != info.thread_id
        or destination.numel() * destination.element_size() < info.size):
        return False

    if info.size == 0:
        return True

    # NOTE(review): assumes destination's storage is contiguous CPU memory
    # starting at data_ptr() — confirm callers guarantee this.
    buf_type = ctypes.c_ubyte * info.size
    view = memoryview(buf_type.from_address(destination.data_ptr()))

    try:
        try:
            file_obj.seek(info.offset)
        except OSError:
            # Fail soft like the read loop below: a seek error (e.g. a pipe or
            # otherwise unseekable/broken file) must trigger the caller's
            # fallback copy, not propagate out of the fast path.
            return False
        done = 0
        while done < info.size:
            try:
                n = file_obj.readinto(view[done:])
            except OSError:
                return False
            if n <= 0:
                # EOF or a non-blocking stall before the slice was complete.
                return False
            done += n
        return True
    finally:
        # Release the memoryview so the ctypes buffer holds no dangling
        # reference into the tensor's memory.
        view.release()
65+
766
class TensorGeometry(NamedTuple):
867
shape: any
968
dtype: torch.dtype

comfy/model_management.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,28 @@ def module_size(module):
505505
module_mem += t.nbytes
506506
return module_mem
507507

508+
def module_mmap_residency(module, free=False):
    """Measure (and optionally release) mmap-backed residency of a module.

    Walks the module's state dict and returns a pair
    ``(mmap_touched_mem, module_mem)``: the bytes belonging to tensors whose
    backing storage is flagged ``_comfy_tensor_mmap_touched``, and the total
    state-dict bytes. With ``free=True`` each touched flag is cleared and
    ``bounce()`` is called at most once per distinct mmap object.
    """
    touched_bytes = 0
    total_bytes = 0
    bounced = set()
    for tensor in module.state_dict().values():
        total_bytes += tensor.nbytes
        # Quantized tensors keep their payload in _qdata's storage.
        if isinstance(tensor, comfy.quant_ops.QuantizedTensor):
            storage = tensor._qdata.untyped_storage()
        else:
            storage = tensor.untyped_storage()
        if not getattr(storage, "_comfy_tensor_mmap_touched", False):
            continue
        touched_bytes += tensor.nbytes
        if not free:
            continue
        storage._comfy_tensor_mmap_touched = False
        mmap_obj = storage._comfy_tensor_mmap_refs[0]
        if mmap_obj not in bounced:
            mmap_obj.bounce()
            bounced.add(mmap_obj)
    return touched_bytes, total_bytes
529+
508530
class LoadedModel:
509531
def __init__(self, model):
510532
self._set_model(model)
@@ -532,6 +554,9 @@ def model(self):
532554
def model_memory(self):
533555
return self.model.model_size()
534556

557+
def model_mmap_residency(self, free=False):
    """Proxy to the wrapped model's mmap-residency accounting.

    Returns the model's (resident_bytes, total_bytes) pair; with free=True
    the model also releases what it reports as resident.
    """
    residency = self.model.model_mmap_residency(free=free)
    return residency
559+
535560
def model_loaded_memory(self):
536561
return self.model.loaded_size()
537562

@@ -633,7 +658,7 @@ def extra_reserved_memory():
633658
def minimum_inference_memory():
634659
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
635660

636-
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
661+
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
637662
cleanup_models_gc()
638663
unloaded_model = []
639664
can_unload = []
@@ -646,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
646671
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
647672
shift_model.currently_used = False
648673

649-
for x in sorted(can_unload):
674+
can_unload_sorted = sorted(can_unload)
675+
for x in can_unload_sorted:
650676
i = x[-1]
651677
memory_to_free = 1e32
652-
ram_to_free = 1e32
678+
pins_to_free = 1e32
653679
if not DISABLE_SMART_MEMORY:
654680
memory_to_free = memory_required - get_free_memory(device)
655-
ram_to_free = ram_required - get_free_ram()
681+
pins_to_free = pins_required - get_free_ram()
656682
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
657683
#don't actually unload dynamic models for the sake of other dynamic models
658684
#as that works on-demand.
@@ -661,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
661687
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
662688
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
663689
unloaded_model.append(i)
664-
if ram_to_free > 0:
690+
if pins_to_free > 0:
691+
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
692+
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
693+
694+
for x in can_unload_sorted:
695+
i = x[-1]
696+
ram_to_free = ram_required - psutil.virtual_memory().available
697+
if ram_to_free <= 0 and i not in unloaded_model:
698+
continue
699+
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
700+
if resident_memory > 0:
665701
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
666-
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
667702

668703
for i in sorted(unloaded_model, reverse=True):
669704
unloaded_models.append(current_loaded_models.pop(i))
@@ -729,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
729764

730765

731766
total_memory_required = {}
767+
total_pins_required = {}
732768
total_ram_required = {}
733769
for loaded_model in models_to_load:
734-
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
735-
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
736-
#want to do.
737-
#FIXME: This should subtract off the to_load current pin consumption.
738-
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
770+
device = loaded_model.device
771+
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
772+
resident_memory, model_memory = loaded_model.model_mmap_residency()
773+
pinned_memory = loaded_model.model.pinned_memory_size()
774+
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
775+
#make this JIT to keep as much pinned as possible.
776+
pins_required = model_memory - pinned_memory
777+
ram_required = model_memory - resident_memory
778+
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
779+
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
739780

740781
for device in total_memory_required:
741782
if device != torch.device("cpu"):
742-
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
783+
free_memory(total_memory_required[device] * 1.1 + extra_mem,
784+
device,
785+
for_dynamic=free_for_dynamic,
786+
pins_required=total_pins_required[device],
787+
ram_required=total_ram_required[device])
743788

744789
for device in total_memory_required:
745790
if device != torch.device("cpu"):
@@ -1225,6 +1270,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
12251270
dest_view = dest_views.pop(0)
12261271
if tensor is None:
12271272
continue
1273+
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
1274+
continue
1275+
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
1276+
if hasattr(storage, "_comfy_tensor_mmap_touched"):
1277+
storage._comfy_tensor_mmap_touched = True
12281278
dest_view.copy_(tensor, non_blocking=non_blocking)
12291279

12301280

comfy/model_patcher.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ def model_size(self):
297297
self.size = comfy.model_management.module_size(self.model)
298298
return self.size
299299

300+
def model_mmap_residency(self, free=False):
    """Report mmap residency of the patched model.

    Delegates to comfy.model_management.module_mmap_residency, returning its
    (touched_bytes, total_bytes) pair; free=True also releases the touched
    mmaps.
    """
    mm = comfy.model_management
    return mm.module_mmap_residency(self.model, free=free)
302+
300303
def get_ram_usage(self):
301304
return self.model_size()
302305

@@ -1063,6 +1066,10 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
10631066

10641067
return self.model.model_loaded_weight_memory - current_used
10651068

1069+
def pinned_memory_size(self):
    """Bytes of pinned host memory attributed to this patcher.

    Pinned-memory pressure tracking is only implemented for the DynamicVram
    loading path, so the base implementation reports nothing pinned.
    """
    return 0
1072+
10661073
def partially_unload_ram(self, ram_to_unload):
10671074
pass
10681075

@@ -1653,6 +1660,16 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
16531660

16541661
return freed
16551662

1663+
def pinned_memory_size(self):
    """Total bytes currently held in pinned host memory for this model.

    Sums the size of every pin registered (via comfy.pinned_memory) for the
    modules in the dynamic load list.
    """
    total_bytes = 0
    for entry in self._load_list(for_dynamic=True):
        module = entry[4]
        pin = comfy.pinned_memory.get_pin(module)
        if pin is None:
            continue
        total_bytes += pin.numel() * pin.element_size()
    return total_bytes
1672+
16561673
def partially_unload_ram(self, ram_to_unload):
16571674
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
16581675
for x in loading:

comfy/ops.py

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,33 @@ class CastWeightBiasOp:
306306
bias_function = []
307307

308308
class disable_weight_init:
309+
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
                               missing_keys, unexpected_keys, weight_shape,
                               bias_shape=None):
    """Minimal state-dict loader that assigns weight/bias without default init.

    Assigns ``weight`` (and, when *bias_shape* is given, ``bias``) onto
    *module* from *state_dict*, honoring the ``assign_to_params_buffers``
    flag in *local_metadata* (clone by default, direct assignment when
    requested). Any other key is recorded in *unexpected_keys*. Missing
    tensors are reconciled with zero-filled parameters of *weight_shape* /
    *bias_shape* and recorded in *missing_keys*; the bias is only
    reconstructed when the module opts in via ``comfy_need_lazy_init_bias``.
    """
    assign = local_metadata.get("assign_to_params_buffers", False)

    def as_param(value):
        # Honor assign_to_params_buffers: clone by default, assign directly
        # when the loader asked for it.
        return torch.nn.Parameter(value if assign else value.clone(), requires_grad=False)

    strip = len(prefix)
    for full_key, value in state_dict.items():
        local_key = full_key[strip:]
        if local_key == "weight":
            module.weight = as_param(value)
        elif bias_shape is not None and local_key == "bias" and value is not None:
            module.bias = as_param(value)
        else:
            unexpected_keys.append(full_key)

    # Reconcile default construction of tensors the checkpoint lacked.
    if module.weight is None:
        module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
        missing_keys.append(prefix + "weight")

    if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
        module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
        missing_keys.append(prefix + "bias")
309336
class Linear(torch.nn.Linear, CastWeightBiasOp):
310337

311338
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
@@ -333,29 +360,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
333360
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
334361
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
335362
missing_keys, unexpected_keys, error_msgs)
336-
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
337-
prefix_len = len(prefix)
338-
for k,v in state_dict.items():
339-
if k[prefix_len:] == "weight":
340-
if not assign_to_params_buffers:
341-
v = v.clone()
342-
self.weight = torch.nn.Parameter(v, requires_grad=False)
343-
elif k[prefix_len:] == "bias" and v is not None:
344-
if not assign_to_params_buffers:
345-
v = v.clone()
346-
self.bias = torch.nn.Parameter(v, requires_grad=False)
347-
else:
348-
unexpected_keys.append(k)
349-
350-
#Reconcile default construction of the weight if its missing.
351-
if self.weight is None:
352-
v = torch.zeros(self.in_features, self.out_features)
353-
self.weight = torch.nn.Parameter(v, requires_grad=False)
354-
missing_keys.append(prefix+"weight")
355-
if self.bias is None and self.comfy_need_lazy_init_bias:
356-
v = torch.zeros(self.out_features,)
357-
self.bias = torch.nn.Parameter(v, requires_grad=False)
358-
missing_keys.append(prefix+"bias")
363+
disable_weight_init._lazy_load_from_state_dict(
364+
self,
365+
state_dict,
366+
prefix,
367+
local_metadata,
368+
missing_keys,
369+
unexpected_keys,
370+
weight_shape=(self.in_features, self.out_features),
371+
bias_shape=(self.out_features,),
372+
)
359373

360374

361375
def reset_parameters(self):
@@ -547,6 +561,48 @@ def forward(self, *args, **kwargs):
547561
return super().forward(*args, **kwargs)
548562

549563
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
564+
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
             norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
             _freeze=False, device=None, dtype=None):
    """Embedding that skips weight allocation on the Windows+aimdo path.

    Off that path this behaves exactly like torch.nn.Embedding. On it, the
    weight is created on the "meta" device so no real storage (and hence no
    Windows commit charge) is reserved until _load_from_state_dict supplies
    the actual weights, while shape/dtype stay visible to introspection.
    """
    if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         _freeze, device, dtype)
        return

    # Lazy path: bypass nn.Embedding.__init__ (which would allocate and
    # initialize a real weight tensor) and set the attributes directly.
    torch.nn.Module.__init__(self)
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.sparse = sparse
    # Keep shape/dtype visible for module introspection without reserving storage.
    embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
    self.weight = torch.nn.Parameter(
        torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
        requires_grad=False,
    )
    # NOTE(review): _weight and _freeze are ignored on the lazy path
    # (requires_grad is always False here) — presumably no caller passes
    # them on this path; confirm.
    self.bias = None
    self.weight_comfy_model_dtype = dtype
589+
590+
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                          strict, missing_keys, unexpected_keys, error_msgs):
    """Load embedding weights, using the lazy loader on the Windows+aimdo path
    and torch's default machinery everywhere else."""
    lazy_path = comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    if not lazy_path:
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
    # Embeddings carry no bias, so only a weight shape is supplied.
    disable_weight_init._lazy_load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        missing_keys,
        unexpected_keys,
        weight_shape=(self.num_embeddings, self.embedding_dim),
    )
605+
550606
def reset_parameters(self):
    """Deliberate no-op init: comfy ops skip torch's default weight
    initialization; only ensure the (unused) bias attribute is cleared."""
    self.bias = None
    return None

0 commit comments

Comments
 (0)