Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
import comfy.utils
import comfy.quant_ops

import comfy_aimdo.torch
import comfy_aimdo.model_vbar

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
Expand Down Expand Up @@ -1206,43 +1203,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):


def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these wouldn't work, but they
#are untested and no callers do this.
assert r is None
assert stream is None

cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])

if dtype is None:
dtype = weight._model_dtype

signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)

r = torch.empty_like(weight, dtype=dtype, device=device)

if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
#inconsistent between loaded and offloaded weights. So force the double casting
#that would happen in regular flow to make offload deterministic.
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_buffer.copy_(weight, non_blocking=non_blocking)
weight = cast_buffer
r.copy_(weight, non_blocking=non_blocking)

return r

if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
Expand Down
50 changes: 21 additions & 29 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,10 +1435,6 @@ def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weig

def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#This is now way more dynamic and we don't support the same base model for both
#dynamic and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
Expand All @@ -1461,9 +1457,7 @@ def _vbar_get(self, create=False):

def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory

def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
Expand Down Expand Up @@ -1504,6 +1498,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False

num_patches = 0
allocated_size = 0
self.model.model_loaded_weight_memory = 0

with self.use_ejected():
self.unpatch_hooks()
Expand All @@ -1512,10 +1507,6 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
if vbar is not None:
vbar.prioritize()

#We force-reserve VRAM for the non-comfy weights so we don't have to deal
#with pin and unpin synchronization, which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#Prioritize the non-comfy weights (note the reversed sort order below).
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)

Expand Down Expand Up @@ -1558,6 +1549,9 @@ def force_load_param(self, param_key, device_to):
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
weight, _, _ = get_key_weight(self.model, key)
if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()

if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
Expand All @@ -1583,21 +1577,15 @@ def force_load_param(self, param_key, device_to):
for param in params:
key = key_param_name_to_key(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()

move_weight_functions(m, device_to)

logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")

self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
Expand All @@ -1613,7 +1601,16 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
assert self.load_device != torch.device("cpu")

vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)

if freed < memory_to_free:
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0

return freed

def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
Expand All @@ -1640,11 +1637,6 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
for m in self.model.modules():
move_weight_functions(m, device_to)

keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)

def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):
Expand Down
10 changes: 5 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags

import comfy_aimdo.control

if enables_dynamic_vram():
comfy_aimdo.control.init()

if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'

setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)

import comfy_aimdo.control

if enables_dynamic_vram():
comfy_aimdo.control.init()

if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'

Expand Down
Loading