Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion lightx2v/models/networks/hunyuan_video/infer/pre_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
from einops import rearrange

from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM

from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
from .module_io import HunyuanVideo15InferModuleOutput

# Optional fast path: sgl_kernel provides a fused CUDA timestep-embedding kernel.
# Only enable it when the kernel imports cleanly AND the active platform is CUDA
# (PLATFORM is set by lightx2v_platform at device init — see init_ai_device).
try:
    from sgl_kernel.elementwise import timestep_embedding as timestep_embedding_cuda

    # Import succeeded, but the kernel is CUDA-only; gate on the runtime platform.
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = PLATFORM == "cuda"
except ImportError:
    # sgl_kernel not installed — fall back to the pure-PyTorch implementation below.
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False


def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Expand Down Expand Up @@ -201,6 +208,14 @@ def timestep_embedding(self, t, dim, max_period=10000):

.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
if TIMESTEP_EMBEDDING_CUDA_AVAILABLE:
return timestep_embedding_cuda(
t,
dim,
flip_sin_to_cos=True,
max_period=max_period,
)

half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
Expand Down
20 changes: 19 additions & 1 deletion lightx2v/models/schedulers/qwen_image/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
from torch.nn import functional as F

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM

# Optional fast path: sgl_kernel provides a fused CUDA timestep-embedding kernel.
# Enabled only when the import succeeds AND the active platform is CUDA
# (PLATFORM is populated by lightx2v_platform during device initialization).
try:
    from sgl_kernel.elementwise import timestep_embedding as timestep_embedding_cuda

    # Kernel is CUDA-only; gate availability on the runtime platform, not just the import.
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = PLATFORM == "cuda"
except ImportError:
    # sgl_kernel not installed — the scheduler falls back to the pure-PyTorch path.
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False


def calculate_shift(
Expand Down Expand Up @@ -160,6 +167,17 @@ def get_timestep_embedding(
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""

if TIMESTEP_EMBEDDING_CUDA_AVAILABLE:
return timestep_embedding_cuda(
timesteps,
embedding_dim,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
scale=scale,
max_period=max_period,
)

assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

half_dim = embedding_dim // 2
Expand Down
20 changes: 19 additions & 1 deletion lightx2v/models/schedulers/z_image/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
from torch.nn import functional as F

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM

# Optional fast path: sgl_kernel provides a fused CUDA timestep-embedding kernel.
# Enabled only when the import succeeds AND the active platform is CUDA
# (PLATFORM is populated by lightx2v_platform during device initialization).
try:
    from sgl_kernel.elementwise import timestep_embedding as timestep_embedding_cuda

    # Kernel is CUDA-only; gate availability on the runtime platform, not just the import.
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = PLATFORM == "cuda"
except ImportError:
    # sgl_kernel not installed — the scheduler falls back to the pure-PyTorch path.
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False


def calculate_shift(
Expand Down Expand Up @@ -172,6 +179,17 @@ def get_timestep_embedding(
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""

if TIMESTEP_EMBEDDING_CUDA_AVAILABLE:
return timestep_embedding_cuda(
timesteps,
embedding_dim,
flip_sin_to_cos=flip_sin_to_cos,
downscale_freq_shift=downscale_freq_shift,
scale=scale,
max_period=max_period,
)

assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

half_dim = embedding_dim // 2
Expand Down
1 change: 1 addition & 0 deletions lightx2v_platform/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def init_ai_device(platform="cuda"):
if platform_device is None:
available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
raise RuntimeError(f"Unsupported PLATFORM: {platform}. Available PLATFORM: {available_platforms}")
global_var.PLATFORM = platform
global_var.AI_DEVICE = platform_device.get_device()
platform_device.init_device_env()
logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
Expand Down
1 change: 1 addition & 0 deletions lightx2v_platform/base/global_var.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Process-wide platform globals, populated by init_ai_device() in base.py.
# AI_DEVICE: the resolved device handle/identifier for the active platform.
AI_DEVICE = None
# PLATFORM: the platform name string (e.g. "cuda") passed to init_ai_device().
PLATFORM = None