11 changes: 9 additions & 2 deletions README.md
@@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a

## Get Started

### Local

#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
@@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).

## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
### Cloud

#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version, for those who don't have the hardware to run ComfyUI locally.

## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or the older [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).

## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
59 changes: 59 additions & 0 deletions comfy/memory_management.py
@@ -1,9 +1,68 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor


class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int


def read_tensor_file_slice_into(tensor, destination):
    """Try to read the file slice backing `tensor` directly into `destination`.

    Returns True when the bytes were read straight from the backing file into
    the destination's CPU storage; returns False when the caller should fall
    back to a regular tensor copy.
    """
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False

if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False

dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True

info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False

file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
return False

if info.size == 0:
return True

buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))

try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()

class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype
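For context, here is a minimal usage sketch of the new fast path (not part of the diff). It assumes the loader has already tagged the source tensor's untyped storage with a `TensorFileSlice` under the `_comfy_tensor_file_slice` attribute, which is what `read_tensor_file_slice_into` looks for; the file name and sizes below are made up.

```python
import threading
import torch
from comfy.memory_management import TensorFileSlice, read_tensor_file_slice_into

# Hypothetical raw weight file: 16 float32 values = 64 bytes of zeros.
with open("weights.bin", "wb") as f:
    f.write(bytes(16 * 4))

src = torch.empty(16, dtype=torch.float32)  # placeholder standing in for the lazily loaded tensor
fh = open("weights.bin", "rb")              # the loader keeps this handle alive
src.untyped_storage()._comfy_tensor_file_slice = TensorFileSlice(
    file_ref=fh, thread_id=threading.get_ident(), offset=0, size=src.nbytes)

dst = torch.empty(16, dtype=torch.float32)  # CPU destination of at least `size` bytes
if not read_tensor_file_slice_into(src, dst):  # True: bytes were read straight from the file
    dst.copy_(src)                             # False: fall back to a normal copy
```

The fast path declines (returns False) for non-CPU destinations, cross-thread calls, undersized destinations, or read errors, so the fallback copy always stays in place.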
74 changes: 62 additions & 12 deletions comfy/model_management.py
@@ -505,6 +505,28 @@ def module_size(module):
module_mem += t.nbytes
return module_mem

def module_mmap_residency(module, free=False):
    """Return (mmap_touched_bytes, total_bytes) for the module's weights.

    With free=True, also clear each storage's mmap-touched flag and bounce every
    backing mmap object once.
    """
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem

class LoadedModel:
def __init__(self, model):
self._set_model(model)
@@ -532,6 +554,9 @@ def model(self):
def model_memory(self):
return self.model.model_size()

def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)

def model_loaded_memory(self):
return self.model.loaded_size()

@@ -633,7 +658,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -646,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False

for x in sorted(can_unload):
can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
ram_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram()
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
@@ -661,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
if ram_to_free > 0:
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)

for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)

for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@@ -729,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu


total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
#want to do.
#FIXME: This should subtract off the to_load current pin consumption.
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required

for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])

for device in total_memory_required:
if device != torch.device("cpu"):
@@ -1225,6 +1270,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0)
if tensor is None:
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)


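The per-device budgeting added to `load_models_gpu` above is easiest to see with concrete numbers. This is an illustrative sketch only, with made-up sizes; it mirrors the arithmetic in the diff rather than calling the real APIs.

```python
GiB = 1024 ** 3

# Made-up figures for a single model on one device.
model_memory    = 12 * GiB  # total weight bytes (second value from model_mmap_residency())
resident_memory =  4 * GiB  # bytes already resident via touched mmaps (first value)
pinned_memory   =  1 * GiB  # bytes already pinned (model.pinned_memory_size())

pins_required = model_memory - pinned_memory    # 11 GiB of pinning headroom requested
ram_required  = model_memory - resident_memory  #  8 GiB of pageable RAM still needed

# free_memory() then works in two passes: first it unloads/unpins models against
# pins_required (compared with get_free_ram()), and only afterwards releases mmap
# residency until ram_required fits within psutil.virtual_memory().available.
```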
17 changes: 17 additions & 0 deletions comfy/model_patcher.py
@@ -297,6 +297,9 @@ def model_size(self):
self.size = comfy.model_management.module_size(self.model)
return self.size

def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)

def get_ram_usage(self):
return self.model_size()

@@ -1063,6 +1066,10 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

return self.model.model_loaded_weight_memory - current_used

def pinned_memory_size(self):
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0

def partially_unload_ram(self, ram_to_unload):
pass

@@ -1653,6 +1660,16 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals

return freed

def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total

def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
102 changes: 79 additions & 23 deletions comfy/ops.py
@@ -306,6 +306,33 @@ class CastWeightBiasOp:
bias_function = []

class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)

if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")

if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")

class Linear(torch.nn.Linear, CastWeightBiasOp):

def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
@@ -333,29 +360,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k,v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)

#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.in_features, self.out_features),
bias_shape=(self.out_features,),
)


def reset_parameters(self):
@@ -547,6 +561,48 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)

class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return

torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Keep shape/dtype visible for module introspection without reserving storage.
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
self.weight = torch.nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
requires_grad=False,
)
self.bias = None
self.weight_comfy_model_dtype = dtype

def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):

if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)

def reset_parameters(self):
self.bias = None
return None
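To illustrate the helper that `Linear` and `Embedding` now share, here is a small sketch (not part of the diff) of what `_lazy_load_from_state_dict` does when a key is absent: it installs a zero-filled placeholder and records the missing key. The module, prefix, and shapes are made up, and the helper is called directly here purely for demonstration; it assumes a ComfyUI checkout is importable.

```python
import torch
import comfy.ops

lin = torch.nn.Linear(4, 2, bias=False)
lin.weight = None  # simulate a lazily constructed module that has no weight yet
missing_keys, unexpected_keys = [], []

# Empty state dict: nothing to assign, so the helper reconciles the missing weight.
comfy.ops.disable_weight_init._lazy_load_from_state_dict(
    lin, {}, "layer.", {"assign_to_params_buffers": False},
    missing_keys, unexpected_keys, weight_shape=(2, 4))

print(missing_keys)  # ['layer.weight']; lin.weight is now a zero Parameter of shape (2, 4)
```

When a `weight` or `bias` entry is present, the helper assigns it as a non-grad `Parameter`, cloning unless the `assign_to_params_buffers` metadata flag is set.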