Skip to content

Commit 361b9a8

Browse files
authored
fix pinning with model defined dtype (Comfy-Org#12208)
Previously, pinned memory was sized from the model-defined geometry even though it holds the unchanged CPU-side weight. Fix the pinner to size the pin from the CPU weight rather than the model-defined geometry; this either saves RAM or prevents buffer overruns when the dtypes mismatch. Fix the model-defined weight caster to interpret the gathered buffer as [ s.weight, s.bias ], since xfer_dest may now be the flattened pin. Fix the detection of whether a cast is needed so it no longer runs only when there is no pin.
1 parent 667a1b8 commit 361b9a8

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

comfy/ops.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
9696
pin = comfy.pinned_memory.get_pin(s)
9797
if pin is not None:
9898
xfer_source = [ pin ]
99-
else:
100-
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
101-
if data is None:
102-
continue
103-
if data.dtype != geometry.dtype:
104-
cast_dest = xfer_dest
105-
if cast_dest is None:
106-
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
107-
xfer_dest = None
108-
break
99+
100+
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
101+
if data is None:
102+
continue
103+
if data.dtype != geometry.dtype:
104+
cast_dest = xfer_dest
105+
if cast_dest is None:
106+
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
107+
xfer_dest = None
108+
break
109109

110110
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
111111
offload_stream = comfy.model_management.get_offload_stream(device)
@@ -132,7 +132,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
132132
comfy.model_management.sync_stream(device, offload_stream)
133133

134134
if cast_dest is not None:
135-
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest),
135+
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
136136
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
137137
if post_cast is not None:
138138
post_cast.copy_(pre_cast)

comfy/pinned_memory.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ def pin_memory(module):
1111
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
1212
return
1313
#FIXME: This is a RAM cache trigger event
14-
params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ])
15-
size = comfy.memory_management.vram_aligned_size(params)
14+
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
1615
pin = torch.empty((size,), dtype=torch.uint8)
1716
if comfy.model_management.pin_memory(pin):
1817
module._pin = pin

0 commit comments

Comments
 (0)