Merged
3 changes: 2 additions & 1 deletion diffsynth/diffusion/base_pipeline.py
@@ -4,6 +4,7 @@
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
+from ..core.device.npu_compatible_device import get_device_type
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput
@@ -61,7 +62,7 @@ class BasePipeline(torch.nn.Module):

def __init__(
self,
device="cuda", torch_dtype=torch.float16,
device=get_device_type(), torch_dtype=torch.float16,
height_division_factor=64, width_division_factor=64,
time_division_factor=None, time_division_remainder=None,
):
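The imported helper lives in diffsynth/core/device/npu_compatible_device.py, which this PR does not show. A minimal sketch of what such a detection helper plausibly looks like, assuming it probes for an Ascend NPU first, then CUDA, then falls back to CPU:

# Hypothetical sketch of get_device_type(); the real implementation in
# diffsynth/core/device/npu_compatible_device.py may differ.
import torch

def get_device_type() -> str:
    try:
        import torch_npu  # noqa: F401 -- registers the "npu" backend with torch
        if torch.npu.is_available():
            return "npu"
    except ImportError:
        pass
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"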
4 changes: 3 additions & 1 deletion diffsynth/models/dinov3_image_encoder.py
@@ -2,6 +2,8 @@
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch

+from ..core.device.npu_compatible_device import get_device_type


class DINOv3ImageEncoder(DINOv3ViTModel):
def __init__(self):
@@ -70,7 +72,7 @@ def __init__(self):
}
)

def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
bool_masked_pos = None
11 changes: 6 additions & 5 deletions diffsynth/models/longcat_video_dit.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward


@@ -373,7 +374,7 @@ def forward(self, x, t, latent_shape):
B, N, C = x.shape
T, _, _ = latent_shape

-        with amp.autocast('cuda', dtype=torch.float32):
+        with amp.autocast(get_device_type(), dtype=torch.float32):
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -583,7 +584,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.

# compute modulation params in fp32
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +603,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
else:
x_s = attn_outputs

-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)

@@ -615,7 +616,7 @@ def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return
# ffn with modulation
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)

@@ -797,7 +798,7 @@ def forward(

hidden_states = self.x_embedder(hidden_states) # [B, N, C]

-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]

encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
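torch.amp.autocast takes the device type as a plain string, so swapping the hardcoded 'cuda' for get_device_type() is mechanical. A minimal illustration of the pattern used above, assuming get_device_type() returns 'cuda', 'npu', or 'cpu'; note that CPU autocast only supports reduced-precision dtypes, so the fp32 pin is only meaningful on accelerators:

import torch

device_type = get_device_type()  # "cuda", "npu", or "cpu"
x = torch.randn(4, 8, device=device_type)
if device_type != "cpu":
    # Pin the modulation math to fp32 even inside a bf16 autocast region.
    with torch.amp.autocast(device_type=device_type, dtype=torch.float32):
        y = x.float() @ x.float().T
else:
    y = x.float() @ x.float().T  # CPU runs fp32 by default; no autocast needed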
2 changes: 1 addition & 1 deletion diffsynth/models/nexus_gen_ar_model.py
@@ -583,7 +584,7 @@ def _sample(
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
is_compileable = is_compileable and not self.generation_config.disable_compile
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
):
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
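The widened membership test works because torch.device(...).type normalizes any device spec ('npu:0', 'cuda:1') to its bare type string. An isolated sketch of the same gate; maybe_compile is an illustrative name, not part of the PR:

import torch

def maybe_compile(module: torch.nn.Module, compile_all_devices: bool = False):
    # torch.device("npu:0").type == "npu", so index suffixes never leak in here.
    if get_device_type() in ("cuda", "npu") or compile_all_devices:
        return torch.compile(module)  # compile only on accelerators
    return module  # stay eager on CPU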
4 changes: 3 additions & 1 deletion diffsynth/models/siglip2_image_encoder.py
@@ -2,6 +2,8 @@
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch

+from diffsynth.core.device.npu_compatible_device import get_device_type


class Siglip2ImageEncoder(SiglipVisionTransformer):
def __init__(self):
@@ -47,7 +49,7 @@ def __init__(self):
}
)

def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
output_attentions = False
19 changes: 10 additions & 9 deletions diffsynth/models/step1x_text_encoder.py
@@ -1,10 +1,11 @@
import torch
from typing import Optional, Union
from .qwen_image_text_encoder import QwenImageTextEncoder
+from ..core.device.npu_compatible_device import get_device_type, get_torch_device


class Step1xEditEmbedder(torch.nn.Module):
-    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
+    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
super().__init__()
self.max_length = max_length
self.dtype = dtype
@@ -77,13 +78,13 @@ def forward(self, caption, ref_images):
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
Contributor review comment (high): Using get_torch_device().current_device() will cause a crash on CPU-only systems. The get_torch_device function falls back to torch.cuda when the device is 'cpu', and torch.cuda.current_device() will then fail if no CUDA device is available. A simpler and more robust approach is to use get_device_type() directly, as it returns a device string ('cpu', 'cuda', 'npu') that is accepted by PyTorch tensor creation functions.
Suggested change:
-            device=get_torch_device().current_device(),
+            device=get_device_type(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
Contributor review comment (high): Same issue as above; prefer get_device_type() here as well.
Suggested change:
-            device=get_torch_device().current_device(),
+            device=get_device_type(),
)
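The failure mode the reviewer describes is easy to see with a sketch of get_torch_device, assuming (as the comment states) that it maps the detected device type to the matching torch backend module and falls back to torch.cuda for 'cpu':

# Hypothetical sketch of the crash path; the real get_torch_device may differ.
import torch

def get_torch_device():
    backends = {"npu": getattr(torch, "npu", None), "cuda": torch.cuda}
    return backends.get(get_device_type()) or torch.cuda  # 'cpu' falls through to torch.cuda

# On a CPU-only machine, get_torch_device().current_device() raises at runtime
# because torch.cuda.current_device() tries to initialize CUDA. Passing the
# plain string from get_device_type() to tensor factories avoids the call.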

def split_string(s):
@@ -158,7 +159,7 @@ def split_string(s):
else:
token_list.append(token_each)

new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())

new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

@@ -167,15 +168,15 @@ def split_string(s):
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
.to(get_device_type())
)
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
outputs = self.model_forward(
self.model,
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
pixel_values=inputs.pixel_values.to(get_device_type()),
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
output_hidden_states=True,
)

@@ -188,7 +189,7 @@ def split_string(s):
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device=get_torch_device().current_device(),
Contributor review comment (high): Same issue as above; prefer get_device_type() here as well.
Suggested change:
-            device=get_torch_device().current_device(),
+            device=get_device_type(),
)

return embs, masks
6 changes: 3 additions & 3 deletions diffsynth/models/z_image_dit.py
@@ -8,7 +8,7 @@

from .general_modules import RMSNorm
from ..core.attention import attention_forward
-from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward


@@ -40,7 +40,7 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):

@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
@@ -105,7 +105,7 @@ def forward(self, hidden_states, freqs_cis, attention_mask):

# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
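Disabling autocast around the rotary embedding keeps the complex-valued math in full precision regardless of the surrounding mixed-precision context; torch.view_as_complex also requires a float32 (or float64) input, hence the explicit .float(). A minimal repro of the pattern:

import torch

device_type = get_device_type()
x = torch.randn(1, 16, 4, 64, device=device_type)  # [B, seq, heads, head_dim]
with torch.amp.autocast(device_type, enabled=False):
    # Pair up the last dim into (real, imag) and view as complex numbers.
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_out = torch.view_as_real(xc).flatten(3)  # back to [1, 16, 4, 64]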
5 changes: 3 additions & 2 deletions diffsynth/pipelines/flux2_image.py
@@ -6,6 +6,7 @@
import numpy as np
from typing import Union, List, Optional, Tuple

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -19,7 +20,7 @@

class Flux2ImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -45,7 +46,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
11 changes: 6 additions & 5 deletions diffsynth/pipelines/flux_image.py
@@ -6,6 +6,7 @@
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -55,7 +56,7 @@ def forward(self, conditionings: list[torch.Tensor], controlnet_inputs: list[Con

class FluxImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -117,7 +118,7 @@ def enable_lora_merger(self):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
@@ -377,7 +378,7 @@ def encode_prompt(
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -558,7 +559,7 @@ def encode_prompt(
text_encoder_2,
prompt,
positive=True,
device="cuda",
device=get_device_type(),
t5_sequence_length=512,
):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
@@ -793,7 +794,7 @@ def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controlle


class InfinitYou(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__()
from facexlib.recognition import init_recognition_model
from insightface.app import FaceAnalysis
5 changes: 3 additions & 2 deletions diffsynth/pipelines/qwen_image.py
@@ -6,6 +6,7 @@
import numpy as np
from math import prod

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
@@ -22,7 +23,7 @@

class QwenImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -60,7 +61,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
processor_config: ModelConfig = None,
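With the new default, callers no longer need to name a device at all; the same construction resolves to CUDA, NPU, or CPU depending on what get_device_type() detected. A usage sketch (the model_configs list is assumed to be prepared elsewhere):

import torch

pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    model_configs=model_configs,  # prepared elsewhere; contents elided
)  # device defaults to get_device_type()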
7 changes: 4 additions & 3 deletions diffsynth/pipelines/wan_video.py
@@ -11,6 +11,7 @@
from typing_extensions import Literal
from transformers import Wav2Vec2Processor

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
@@ -30,7 +31,7 @@

class WanVideoPipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
@@ -98,7 +99,7 @@ def enable_usp(self):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
audio_processor_config: ModelConfig = None,
@@ -964,7 +965,7 @@ def __init__(self):
onload_model_names=("vae",)
)

def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:
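The (lat_t - 1) * 4 + 1 term in get_i2v_mask mirrors Wan's 4x temporal VAE compression, where the first frame is encoded on its own. A worked shape example, with illustrative latent dimensions:

import torch

lat_t, lat_h, lat_w = 21, 60, 104        # example latent grid
pixel_frames = (lat_t - 1) * 4 + 1       # 21 latent frames cover 81 pixel frames
msk = torch.zeros(1, pixel_frames, lat_h, lat_w, device=get_device_type())
assert msk.shape == (1, 81, 60, 104)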
5 changes: 3 additions & 2 deletions diffsynth/pipelines/z_image.py
@@ -6,6 +6,7 @@
import numpy as np
from typing import Union, List, Optional, Tuple, Iterable, Dict

+from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..core.data.operators import ImageCropAndResize
@@ -25,7 +26,7 @@

class ZImagePipeline(BasePipeline):

def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
@@ -58,7 +59,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
vram_limit: float = None,
3 changes: 2 additions & 1 deletion diffsynth/utils/controlnet/annotator.py
@@ -1,12 +1,13 @@
from typing_extensions import Literal, TypeAlias

+from diffsynth.core.device.npu_compatible_device import get_device_type

Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
]

class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
if not skip_processor:
if processor_id == "canny":
from controlnet_aux.processor import CannyDetector