Skip to content

Commit 35e9fce

Browse files
authored
Enable Pytorch Attention for gfx950 (Comfy-Org#12641)
1 parent c7f7d52 commit 35e9fce

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

comfy/model_management.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -350,7 +350,7 @@ def amd_min_version(device=None, min_rdna_version=0):
350350

351351
try:
352352
if is_amd():
353-
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
353+
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0]
354354
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
355355
if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1':
356356
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
@@ -378,7 +378,7 @@ def aotriton_supported(gpu_arch):
378378
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
379379
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
380380
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
381-
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
381+
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
382382
ENABLE_PYTORCH_ATTENTION = True
383383
if rocm_version >= (7, 0):
384384
if any((a in arch) for a in ["gfx1200", "gfx1201"]):

0 commit comments

Comments (0)