comfy/ldm/modules/attention.py (2 additions, 1 deletion)
@@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
             break
-        except model_management.OOM_EXCEPTION as e:
+        except Exception as e:
+            model_management.raise_non_oom(e)
             if first_op_done == False:
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:
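
This hunk shows the pattern the PR applies at every former `except model_management.OOM_EXCEPTION` site: catch broadly, let `raise_non_oom(e)` immediately re-raise anything that is not an out-of-memory condition, and fall through to the existing recovery code only for genuine OOMs. Below is a minimal runnable sketch of that control flow; `FakeOOM`, `compute`, and `run` are stand-in names for illustration, not ComfyUI code.

class FakeOOM(MemoryError):
    """Stand-in for model_management.OOM_EXCEPTION."""

def is_oom(e):
    return isinstance(e, FakeOOM)

def raise_non_oom(e):
    # Re-raise anything is_oom() does not recognize; a real OOM falls
    # through so the caller's except block can run its fallback logic.
    if not is_oom(e):
        raise e

def compute(fail_with=None):
    if fail_with is not None:
        raise fail_with
    return "ok"

def run(fail_with=None):
    try:
        return compute(fail_with)
    except Exception as e:
        raise_non_oom(e)        # non-OOM errors propagate unchanged
        return "fallback"       # only reached when e was an OOM

assert run() == "ok"
assert run(FakeOOM()) == "fallback"
try:
    run(ValueError("a real bug"))
except ValueError:
    pass  # surfaced to the caller, as before the PR

Since `raise_non_oom` re-raises the original exception object, its `__traceback__` is preserved and non-OOM bugs still point at the frame that raised them.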
comfy/ldm/modules/diffusionmodules/model.py (4 additions, 2 deletions)
@@ -258,7 +258,8 @@ def slice_attention(q, k, v):
                 r1[:, :, i:end] = torch.bmm(v, s2)
                 del s2
             break
-        except model_management.OOM_EXCEPTION as e:
+        except Exception as e:
+            model_management.raise_non_oom(e)
             model_management.soft_empty_cache(True)
             steps *= 2
             if steps > 128:
@@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
     try:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
-    except model_management.OOM_EXCEPTION:
+    except Exception as e:
+        model_management.raise_non_oom(e)
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
         oom_fallback = True
     if oom_fallback:
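
In `slice_attention` (first hunk above) the new handler guards the same retry loop as before: empty the cache on OOM, double `steps` so each slice is half as large, and give up past 128 steps. A toy reproduction of that loop shape under the new exception discipline; `is_oom` is a simplified stand-in for the helper this PR adds to model_management, and `sliced_bmm` is illustrative, not ComfyUI's function.

import torch

def is_oom(e):
    # Simplified stand-in for model_management.is_oom from this PR.
    return isinstance(e, torch.cuda.OutOfMemoryError)

def sliced_bmm(q, kT):
    """Compute q @ kT in progressively more slices when OOM is raised."""
    steps = 1
    while True:
        try:
            out = torch.empty(q.shape[0], q.shape[1], kT.shape[2],
                              device=q.device, dtype=q.dtype)
            slice_size = max(1, q.shape[1] // steps)
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                out[:, i:end] = torch.bmm(q[:, i:end], kT)
            return out
        except Exception as e:
            if not is_oom(e):
                raise e               # what raise_non_oom(e) does
            torch.cuda.empty_cache()  # drop cached blocks before retrying
            steps *= 2                # half-sized slices on the next attempt
            if steps > 128:
                raise e               # mirrors the hard limit in the diff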
comfy/ldm/modules/sub_quadratic_attention.py (2 additions, 1 deletion)
@@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except model_management.OOM_EXCEPTION:
+    except Exception as e:
+        model_management.raise_non_oom(e)
         logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values  # noqa: F821 attn_scores is not defined
         torch.exp(attn_scores, out=attn_scores)
comfy/model_management.py (12 additions, 0 deletions)
@@ -270,6 +270,18 @@ def mac_version():
 except:
     OOM_EXCEPTION = Exception
 
+def is_oom(e):
+    if isinstance(e, OOM_EXCEPTION):
+        return True
+    if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
+        discard_cuda_async_error()
+        return True
+    return False
+
+def raise_non_oom(e):
+    if not is_oom(e):
+        raise e
+
 XFORMERS_VERSION = ""
 XFORMERS_ENABLED_VAE = True
 if args.disable_xformers:
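
These helpers are the reason the except clauses had to widen: a CUDA async OOM arrives as a `torch.AcceleratorError`, and what identifies it as an OOM is its `error_code` attribute (2 is `cudaErrorMemoryAllocation` in the CUDA runtime), which an `except` clause cannot filter on; only exception types can be matched there, so the test moves into the handler body. The `discard_cuda_async_error()` call presumably clears the sticky async error state once the OOM is recognized. A hedged sketch of just the attribute test, using a stand-in class since `torch.AcceleratorError` only exists in recent PyTorch builds:

class AcceleratorError(RuntimeError):
    """Stand-in for torch.AcceleratorError (recent PyTorch only)."""
    def __init__(self, msg, error_code):
        super().__init__(msg)
        self.error_code = error_code

def is_oom(e):
    if isinstance(e, MemoryError):  # stand-in for OOM_EXCEPTION
        return True
    # error_code 2 == cudaErrorMemoryAllocation; a plain `except` clause
    # has no way to express "this type, but only with this attribute".
    if isinstance(e, AcceleratorError) and getattr(e, 'error_code', None) == 2:
        return True
    return False

assert is_oom(AcceleratorError("CUDA error: out of memory", error_code=2))
# 700 here is just an illustrative non-OOM code (an illegal memory access).
assert not is_oom(AcceleratorError("CUDA error: illegal memory access", error_code=700))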
comfy/sd.py (4 additions, 2 deletions)
@@ -954,7 +954,8 @@ def decode(self, samples_in, vae_options={}):
                 if pixel_samples is None:
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
-        except model_management.OOM_EXCEPTION:
+        except Exception as e:
+            model_management.raise_non_oom(e)
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             #NOTE: We don't know what tensors were allocated to stack variables at the time of the
             #exception and the exception itself refs them all until we get out of this except block.
@@ -1029,7 +1030,8 @@ def encode(self, pixel_samples):
                     samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 samples[x:x + batch_number] = out
 
-        except model_management.OOM_EXCEPTION:
+        except Exception as e:
+            model_management.raise_non_oom(e)
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             #NOTE: We don't know what tensors were allocated to stack variables at the time of the
             #exception and the exception itself refs them all until we get out of this except block.
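
Both VAE paths keep their tiled fallback outside the except block, and the NOTE in the diff explains why: the in-flight exception references every tensor that was live in the failed frame, so the retry should only start once the handler exits and those references are released. A schematic of that shape; the function names here are illustrative, not ComfyUI's API.

def raise_non_oom(e):
    # As added in comfy/model_management.py above (simplified):
    # re-raise anything that is not an OOM.
    if not isinstance(e, MemoryError):  # stand-in for is_oom(e)
        raise e

def decode_with_fallback(decode_full, decode_tiled, latents):
    try:
        return decode_full(latents)
    except Exception as e:
        raise_non_oom(e)
        # Do NOT retry here: `e` still pins the tensors of the failed call.
    # Handler exited, references released; now run the cheaper tiled path.
    return decode_tiled(latents)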
comfy_extras/nodes_upscale_model.py (2 additions, 1 deletion)
@@ -86,7 +86,8 @@ def execute(cls, upscale_model, image) -> io.NodeOutput:
                 pbar = comfy.utils.ProgressBar(steps)
                 s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                 oom = False
-            except model_management.OOM_EXCEPTION as e:
+            except Exception as e:
+                model_management.raise_non_oom(e)
                 tile //= 2
                 if tile < 128:
                     raise e
execution.py (1 addition, 1 deletion)
@@ -612,7 +612,7 @@ async def await_completion():
         logging.error(traceback.format_exc())
         tips = ""
 
-        if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
+        if comfy.model_management.is_oom(ex):
             tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
             logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
             logging.error("Got an OOM, unloading all loaded models.")