Skip to content

Commit dcd6595

Browse files
Make more intermediate values follow the intermediate dtype. (Comfy-Org#13051)
1 parent b67ed2a commit dcd6595

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

comfy/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
6464
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
6565

6666
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
67-
samples = samples.to(comfy.model_management.intermediate_device())
67+
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
6868
return samples
6969

7070
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
7171
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
72-
samples = samples.to(comfy.model_management.intermediate_device())
72+
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
7373
return samples

comfy/sd1_clip.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def encode_token_weights(self, token_weight_pairs):
4646
out, pooled = o[:2]
4747

4848
if pooled is not None:
49-
first_pooled = pooled[0:1].to(model_management.intermediate_device())
49+
first_pooled = pooled[0:1].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
5050
else:
5151
first_pooled = pooled
5252

@@ -63,16 +63,16 @@ def encode_token_weights(self, token_weight_pairs):
6363
output.append(z)
6464

6565
if (len(output) == 0):
66-
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
66+
r = (out[-1:].to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
6767
else:
68-
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
68+
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype()), first_pooled)
6969

7070
if len(o) > 2:
7171
extra = {}
7272
for k in o[2]:
7373
v = o[2][k]
7474
if k == "attention_mask":
75-
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
75+
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device(), dtype=model_management.intermediate_dtype())
7676
extra[k] = v
7777

7878
r = r + (extra,)

0 commit comments

Comments
 (0)