Skip to content

Commit 535c16c

Browse files
authored
Widen OOM_EXCEPTION to AcceleratorError form (Comfy-Org#12835)
PyTorch only filters for OOMs in its own allocators; however, there are code paths that can OOM on allocators created outside the PyTorch allocators. These manifest as a torch.AcceleratorError, since PyTorch does not universally translate such errors to its OOM exception type. Handle that case. A log I have for this also shows a duplicate asynchronous report of the error, so call the async error discarder to clean up and make these OOMs behave like ordinary OOMs.
1 parent a912809 commit 535c16c

File tree

7 files changed

+27
-8
lines changed

7 files changed

+27
-8
lines changed

comfy/ldm/modules/attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
372372
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
373373
del s2
374374
break
375-
except model_management.OOM_EXCEPTION as e:
375+
except Exception as e:
376+
model_management.raise_non_oom(e)
376377
if first_op_done == False:
377378
model_management.soft_empty_cache(True)
378379
if cleared_cache == False:

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,8 @@ def slice_attention(q, k, v):
258258
r1[:, :, i:end] = torch.bmm(v, s2)
259259
del s2
260260
break
261-
except model_management.OOM_EXCEPTION as e:
261+
except Exception as e:
262+
model_management.raise_non_oom(e)
262263
model_management.soft_empty_cache(True)
263264
steps *= 2
264265
if steps > 128:
@@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
314315
try:
315316
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
316317
out = out.transpose(2, 3).reshape(orig_shape)
317-
except model_management.OOM_EXCEPTION:
318+
except Exception as e:
319+
model_management.raise_non_oom(e)
318320
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
319321
oom_fallback = True
320322
if oom_fallback:

comfy/ldm/modules/sub_quadratic_attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
169169
try:
170170
attn_probs = attn_scores.softmax(dim=-1)
171171
del attn_scores
172-
except model_management.OOM_EXCEPTION:
172+
except Exception as e:
173+
model_management.raise_non_oom(e)
173174
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
174175
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
175176
torch.exp(attn_scores, out=attn_scores)

comfy/model_management.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,18 @@ def mac_version():
270270
except:
271271
OOM_EXCEPTION = Exception
272272

273+
def is_oom(e):
274+
if isinstance(e, OOM_EXCEPTION):
275+
return True
276+
if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2:
277+
discard_cuda_async_error()
278+
return True
279+
return False
280+
281+
def raise_non_oom(e):
282+
if not is_oom(e):
283+
raise e
284+
273285
XFORMERS_VERSION = ""
274286
XFORMERS_ENABLED_VAE = True
275287
if args.disable_xformers:

comfy/sd.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,8 @@ def decode(self, samples_in, vae_options={}):
954954
if pixel_samples is None:
955955
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
956956
pixel_samples[x:x+batch_number] = out
957-
except model_management.OOM_EXCEPTION:
957+
except Exception as e:
958+
model_management.raise_non_oom(e)
958959
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
959960
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
960961
#exception and the exception itself refs them all until we get out of this except block.
@@ -1029,7 +1030,8 @@ def encode(self, pixel_samples):
10291030
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
10301031
samples[x:x + batch_number] = out
10311032

1032-
except model_management.OOM_EXCEPTION:
1033+
except Exception as e:
1034+
model_management.raise_non_oom(e)
10331035
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
10341036
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
10351037
#exception and the exception itself refs them all until we get out of this except block.

comfy_extras/nodes_upscale_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def execute(cls, upscale_model, image) -> io.NodeOutput:
8686
pbar = comfy.utils.ProgressBar(steps)
8787
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
8888
oom = False
89-
except model_management.OOM_EXCEPTION as e:
89+
except Exception as e:
90+
model_management.raise_non_oom(e)
9091
tile //= 2
9192
if tile < 128:
9293
raise e

execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ async def await_completion():
612612
logging.error(traceback.format_exc())
613613
tips = ""
614614

615-
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
615+
if comfy.model_management.is_oom(ex):
616616
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
617617
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
618618
logging.error("Got an OOM, unloading all loaded models.")

0 commit comments

Comments
 (0)