
Conversation

@gushiqiao
Contributor

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @gushiqiao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request undertakes a substantial refactoring of the model architecture by introducing a BaseTransformerModel class. This new base class centralizes common functionalities such as weight loading, quantization handling, CPU/CUDA offloading, and LoRA integration, which were previously duplicated across various transformer model implementations. The refactoring simplifies the code in individual model classes, making them more maintainable and reducing redundancy. Additionally, a generic LoraAdapter has been created to replace model-specific LoRA wrappers, further streamlining the LoRA application process.

Highlights

  • New Base Model Class: A BaseTransformerModel class has been introduced to centralize common functionalities like weight loading, quantization, CPU/CUDA offloading, and LoRA handling across various transformer models.
  • Model Refactoring: Existing transformer models (e.g., HunyuanVideo15Model, LTX2Model, QwenImageTransformerModel, WanModel, ZImageTransformerModel) have been refactored to inherit from the new BaseTransformerModel, significantly reducing code duplication.
  • Centralized LoRA Management: A generic LoraAdapter class has been created and moved to a common location, replacing model-specific LoRA wrappers and streamlining the application of LoRA weights.
  • Simplified Initialization: Subclasses now have simpler __init__ methods, as much of the model setup and weight management logic is delegated to the BaseTransformerModel; a minimal sketch of this pattern follows below.
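
To make the described pattern concrete, here is a minimal, hypothetical sketch of how a shared base class and a generic LoRA adapter could fit together. Everything beyond the names BaseTransformerModel, LoraAdapter, and WanModel (method names, config keys, and the LoRA key convention) is an assumption for illustration, not code taken from this pull request.

import torch
import torch.nn as nn


class LoraAdapter:
    # Hypothetical generic adapter: merges each LoRA pair into the matching
    # base weight as W += scale * (B @ A). The PR's real API may differ.
    def __init__(self, lora_state_dict, scale=1.0):
        self.lora_state_dict = lora_state_dict
        self.scale = scale

    @torch.no_grad()
    def apply(self, weight_dict):
        for name, weight in weight_dict.items():
            lora_a = self.lora_state_dict.get(f"{name}.lora_A")
            lora_b = self.lora_state_dict.get(f"{name}.lora_B")
            if lora_a is not None and lora_b is not None:
                weight.add_(self.scale * (lora_b @ lora_a))


class BaseTransformerModel(nn.Module):
    # Hypothetical base class: owns weight loading, quantization checks,
    # offload, and LoRA handling so subclasses only describe their architecture.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dit_quantized = config.get("dit_quantized", False)
        if self.dit_quantized:
            self._check_dit_quantized()

    def _check_dit_quantized(self):
        ...  # validate config["dit_quant_scheme"] against the supported schemes

    def apply_lora(self, lora_state_dict, scale=1.0):
        LoraAdapter(lora_state_dict, scale).apply(dict(self.named_parameters()))


class WanModel(BaseTransformerModel):
    # Subclass keeps only model-specific construction; shared logic stays above.
    def __init__(self, config):
        super().__init__(config)

In this layout, adding a new model mostly means writing the block construction and any model-specific pre/post-processing, while loading, offload, and LoRA application stay in one place.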



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a BaseTransformerModel to consolidate common functionality across different transformer models, such as weight loading, CPU/CUDA offloading, lazy loading, and distributed weight loading. This refactoring significantly reduces code duplication in hunyuan_video, longcat_image, ltx2, qwen_image, wan, and z_image models. A generic LoraAdapter class has also been introduced to handle LoRA weight application, replacing model-specific wrappers. The changes improve maintainability and modularity by centralizing core logic. I've identified a few areas for improvement regarding redundant garbage collection calls, a potential issue with device handling for safe_open, and some inconsistent error handling.

Comment on lines +75 to +76
if dit_quantized:
    self._check_dit_quantized()

high

The dit_quantized variable is not defined in this scope. It should be self.dit_quantized to refer to the instance attribute initialized earlier in the __init__ method.

Suggested change
-if dit_quantized:
-    self._check_dit_quantized()
+if self.dit_quantized:
+    self._check_dit_quantized()

Comment on lines +229 to +232
if self.device.type != "cpu" and dist.is_initialized():
    device = dist.get_rank()
else:
    device = str(self.device)

high

When dist.is_initialized() is true, dist.get_rank() returns an integer. Passing an integer directly as the device argument to safe_open might not be correctly interpreted as a CUDA device string (e.g., 'cuda:0'). It should likely be formatted as f"cuda:{dist.get_rank()}" to ensure safe_open correctly identifies the target device.

Suggested change
-if self.device.type != "cpu" and dist.is_initialized():
-    device = dist.get_rank()
-else:
-    device = str(self.device)
+if self.device.type != "cpu" and dist.is_initialized():
+    device = f"cuda:{dist.get_rank()}"
+else:
+    device = str(self.device)

Comment on lines +72 to +74
def _seq_parallel_pre_process(self, pre_infer_out):
    world_size = dist.get_world_size(self.seq_p_group)
    cur_rank = dist.get_rank(self.seq_p_group)

high

This _seq_parallel_pre_process method raises a NotImplementedError, but the __init__ method checks self.config["seq_parallel"] and if true, _infer_cond_uncond will call this method. This means if sequence parallel is enabled, the model will fail at runtime. This indicates a logical inconsistency: either sequence parallel should not be enabled for this model, or this method needs a proper implementation.

Suggested change
-def _seq_parallel_pre_process(self, pre_infer_out):
-    world_size = dist.get_world_size(self.seq_p_group)
-    cur_rank = dist.get_rank(self.seq_p_group)
+@torch.no_grad()
+def _seq_parallel_pre_process(self, pre_infer_out):
+    # TODO: Implement sequence parallel pre-processing for QwenImageTransformerModel
+    raise NotImplementedError("Sequence parallel pre-process is not implemented for QwenImageTransformerModel")

Comment on lines +83 to +85
def _seq_parallel_post_process(self, noise_pred):
    world_size = dist.get_world_size(self.seq_p_group)
    gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]

high

Similar to _seq_parallel_pre_process, this method raises a NotImplementedError but will be called if self.config["seq_parallel"] is true. This creates a runtime error if sequence parallel is enabled. Ensure that if seq_parallel is supported, these methods are properly implemented, or prevent seq_parallel from being enabled for this model.

Suggested change
-def _seq_parallel_post_process(self, noise_pred):
-    world_size = dist.get_world_size(self.seq_p_group)
-    gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
+@torch.no_grad()
+def _seq_parallel_post_process(self, noise_pred):
+    # TODO: Implement sequence parallel post-processing for QwenImageTransformerModel
+    raise NotImplementedError("Sequence parallel post-process is not implemented for QwenImageTransformerModel")

Comment on lines +84 to +116
assert self.config.get("dit_quant_scheme", "Default") in [
"fp8-pertensor",
"fp8-triton",
"int8-triton",
"fp8-vllm",
"int8-vllm",
"int8-vllm-hygon-dcu",
"fp8-q8f",
"int8-q8f",
"fp8-b128-deepgemm",
"fp8-sgl",
"int8-sgl",
"int8-torchao",
"fp8-torchao",
"nvfp4",
"mxfp4",
"mxfp6-mxfp8",
"mxfp8",
"int8-tmo",
"gguf-Q8_0",
"gguf-Q6_K",
"gguf-Q5_K_S",
"gguf-Q5_K_M",
"gguf-Q5_0",
"gguf-Q5_1",
"gguf-Q4_K_S",
"gguf-Q4_K_M",
"gguf-Q4_0",
"gguf-Q4_1",
"gguf-Q3_K_S",
"gguf-Q3_K_M",
"int8-npu",
]

medium

This assertion is very long and impacts readability. Consider defining a set of valid quantization schemes as a class attribute and checking membership against that set. This would make the code cleaner and easier to maintain if new schemes are added or existing ones are modified.
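
A hedged sketch of that suggestion, assuming the scheme names stay exactly as listed in the assertion above; the attribute name and the enclosing method are illustrative, since the assertion's actual location inside the PR is not shown here.

class BaseTransformerModel:
    # Single source of truth for supported schemes; entries mirror the list in
    # the assertion above. Adding a new scheme only touches this set.
    VALID_DIT_QUANT_SCHEMES = frozenset({
        "fp8-pertensor", "fp8-triton", "int8-triton", "fp8-vllm", "int8-vllm",
        "int8-vllm-hygon-dcu", "fp8-q8f", "int8-q8f", "fp8-b128-deepgemm",
        "fp8-sgl", "int8-sgl", "int8-torchao", "fp8-torchao", "nvfp4", "mxfp4",
        "mxfp6-mxfp8", "mxfp8", "int8-tmo", "gguf-Q8_0", "gguf-Q6_K",
        "gguf-Q5_K_S", "gguf-Q5_K_M", "gguf-Q5_0", "gguf-Q5_1", "gguf-Q4_K_S",
        "gguf-Q4_K_M", "gguf-Q4_0", "gguf-Q4_1", "gguf-Q3_K_S", "gguf-Q3_K_M",
        "int8-npu",
    })

    def _check_dit_quantized(self):
        scheme = self.config.get("dit_quant_scheme", "Default")
        assert scheme in self.VALID_DIT_QUANT_SCHEMES, f"Unsupported dit_quant_scheme: {scheme!r}"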


with safe_open(file_path, framework="pt", device=device) as f:
    return {
        key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))

medium

The sensitive_layer attribute is initialized as an empty dictionary {} in __init__ and _init_weights. If it remains empty, the condition all(s not in key for s in sensitive_layer) always evaluates to True, effectively bypassing any sensitive-layer-specific dtype handling. If sensitive_layer is meant to be populated by subclasses, this should be explicitly documented; if it is always expected to contain patterns, a non-empty default should be provided.
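
For illustration, here is a hedged sketch of how a subclass could opt in, assuming the base __init__ signature takes the config; the class name reuses WanModel from this PR, but the patterns themselves are invented for the example.

class WanModel(BaseTransformerModel):
    def __init__(self, config):
        super().__init__(config)
        # Hypothetical substring patterns: any weight whose key contains one of
        # these is loaded with GET_SENSITIVE_DTYPE() instead of GET_DTYPE().
        self.sensitive_layer = {"norm", "time_embedding", "modulation"}

A set (or the empty dict currently used) iterates the same way in all(s not in key for s in sensitive_layer), so either container works; the important part is documenting that subclasses are expected to fill it.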

Comment on lines +440 to +456
for key in sorted(synced_meta_dict.keys()):
    if is_weight_loader:
        distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)

    if target_device == "cpu":
        if is_weight_loader:
            gpu_tensor = distributed_weight_dict[key].cuda()
            dist.broadcast(gpu_tensor, src=global_src_rank)
            distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
            del gpu_tensor
            torch.cuda.empty_cache()
        else:
            gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
            dist.broadcast(gpu_tensor, src=global_src_rank)
            distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
            del gpu_tensor
            torch.cuda.empty_cache()

medium

The logic for CPU offload during distributed weight loading involves moving tensors to CUDA, broadcasting, and then moving them back to CPU. This process adds unnecessary overhead. If dist.broadcast supports direct broadcasting of CPU tensors (which it typically does), it would be more efficient to perform the broadcast directly on the CPU tensors without intermediate CUDA transfers. If not, this workaround is necessary but should be noted for potential optimization.
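
A hedged sketch of that alternative for the target_device == "cpu" branch, assuming a gloo process group is available (NCCL cannot broadcast CPU tensors); variable names follow the hunk above, and the helper itself is hypothetical.

import torch.distributed as dist

def broadcast_cpu_weights(distributed_weight_dict, weight_dict, is_weight_loader, global_src_rank, cpu_group):
    # cpu_group is assumed to be created once with dist.new_group(backend="gloo"),
    # which supports CPU tensors, so no temporary CUDA copy is needed.
    for key in sorted(distributed_weight_dict.keys()):
        if is_weight_loader:
            distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
        # Broadcast the CPU tensor in place from the weight-loader rank.
        dist.broadcast(distributed_weight_dict[key], src=global_src_rank, group=cpu_group)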
