refactor for models #849
base: main
Conversation
Summary of Changes

Hello @gushiqiao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request undertakes a substantial refactoring of the model architecture by introducing a BaseTransformerModel base class that consolidates functionality shared across the transformer models.
Code Review
The pull request introduces a BaseTransformerModel to consolidate common functionality across different transformer models, such as weight loading, CPU/CUDA offloading, lazy loading, and distributed weight loading. This refactoring significantly reduces code duplication in hunyuan_video, longcat_image, ltx2, qwen_image, wan, and z_image models. A generic LoraAdapter class has also been introduced to handle LoRA weight application, replacing model-specific wrappers. The changes improve maintainability and modularity by centralizing core logic. I've identified a few areas for improvement regarding redundant garbage collection calls, a potential issue with device handling for safe_open, and some inconsistent error handling.
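To make the described pattern concrete, here is a minimal sketch of a shared base class with subclass hooks. It is illustrative only and does not reproduce the PR's actual BaseTransformerModel or LoraAdapter APIs; all method names below are assumptions.

```python
from abc import ABC, abstractmethod

import torch


class BaseTransformerModel(ABC):
    """Illustrative base: shared config/device/weight-loading scaffolding."""

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config.get("device", "cuda"))
        self.sensitive_layer = {}  # subclasses may register dtype-sensitive key patterns

    def load_weights(self, weight_dict):
        # Common path shared by all models: optional quantization check,
        # then a per-key copy onto the target device.
        if self.config.get("dit_quantized", False):
            self._check_dit_quantized()
        for name, tensor in weight_dict.items():
            self._assign_weight(name, tensor.to(self.device))

    @abstractmethod
    def _assign_weight(self, name, tensor):
        """Model-specific placement of a single named tensor."""

    def _check_dit_quantized(self):
        # Hook for models that support quantized DiT weights; default is a no-op.
        pass
```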
```python
if dit_quantized:
    self._check_dit_quantized()
```
```python
if self.device.type != "cpu" and dist.is_initialized():
    device = dist.get_rank()
else:
    device = str(self.device)
```
When dist.is_initialized() is true, dist.get_rank() returns an integer. Passing an integer directly as the device argument to safe_open might not be correctly interpreted as a CUDA device string (e.g., 'cuda:0'). It should likely be formatted as f"cuda:{dist.get_rank()}" to ensure safe_open correctly identifies the target device.
Suggested change:

```diff
  if self.device.type != "cpu" and dist.is_initialized():
-     device = dist.get_rank()
+     device = f"cuda:{dist.get_rank()}"
  else:
      device = str(self.device)
```
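For reference, safetensors' safe_open accepts a device string such as "cuda:0" for its device argument. A small hedged sketch of the resolved call (the file path is purely illustrative):

```python
import torch.distributed as dist
from safetensors import safe_open

# Resolve the target device as a string so safe_open maps tensors onto the right GPU.
device = f"cuda:{dist.get_rank()}" if dist.is_initialized() else "cpu"

# "model.safetensors" is a placeholder path for illustration.
with safe_open("model.safetensors", framework="pt", device=device) as f:
    first_key = next(iter(f.keys()))
    tensor = f.get_tensor(first_key)  # tensor is materialized on `device`
```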
```python
def _seq_parallel_pre_process(self, pre_infer_out):
    world_size = dist.get_world_size(self.seq_p_group)
    cur_rank = dist.get_rank(self.seq_p_group)
```
This _seq_parallel_pre_process method raises a NotImplementedError, but the __init__ method checks self.config["seq_parallel"] and if true, _infer_cond_uncond will call this method. This means if sequence parallel is enabled, the model will fail at runtime. This indicates a logical inconsistency: either sequence parallel should not be enabled for this model, or this method needs a proper implementation.
Suggested change:

```diff
- def _seq_parallel_pre_process(self, pre_infer_out):
-     world_size = dist.get_world_size(self.seq_p_group)
-     cur_rank = dist.get_rank(self.seq_p_group)
+ @torch.no_grad()
+ def _seq_parallel_pre_process(self, pre_infer_out):
+     # TODO: Implement sequence parallel pre-processing for QwenImageTransformerModel
+     raise NotImplementedError("Sequence parallel pre-process is not implemented for QwenImageTransformerModel")
```
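If the option of preventing seq_parallel from being enabled is taken, one way to fail fast is a guard in the constructor. This is only a sketch, assuming the config["seq_parallel"] key mentioned above and a hypothetical constructor shape:

```python
def __init__(self, config, *args, **kwargs):
    super().__init__(config, *args, **kwargs)
    # Reject the unsupported configuration up front instead of failing mid-inference.
    if config.get("seq_parallel", False):
        raise ValueError(
            "QwenImageTransformerModel does not implement sequence parallelism yet; "
            "set config['seq_parallel'] = False."
        )
```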
```python
def _seq_parallel_post_process(self, noise_pred):
    world_size = dist.get_world_size(self.seq_p_group)
    gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
```
Similar to _seq_parallel_pre_process, this method raises a NotImplementedError but will be called if self.config["seq_parallel"] is true. This creates a runtime error if sequence parallel is enabled. Ensure that if seq_parallel is supported, these methods are properly implemented, or prevent seq_parallel from being enabled for this model.
Suggested change:

```diff
- def _seq_parallel_post_process(self, noise_pred):
-     world_size = dist.get_world_size(self.seq_p_group)
-     gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
+ @torch.no_grad()
+ def _seq_parallel_post_process(self, noise_pred):
+     # TODO: Implement sequence parallel post-processing for QwenImageTransformerModel
+     raise NotImplementedError("Sequence parallel post-process is not implemented for QwenImageTransformerModel")
```
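If instead a real implementation is wanted, a sequence-parallel post-process typically all-gathers the per-rank chunks and stitches them back together. A hedged sketch building on the hunk above (seq_p_group and the gathered list come from the snippet; the concatenation dimension is an assumption):

```python
import torch
import torch.distributed as dist

def _seq_parallel_post_process(self, noise_pred):
    world_size = dist.get_world_size(self.seq_p_group)
    # Collect every rank's slice of the prediction.
    gathered_noise_pred = [torch.empty_like(noise_pred) for _ in range(world_size)]
    dist.all_gather(gathered_noise_pred, noise_pred, group=self.seq_p_group)
    # Assumption: the sequence was split along dim=1, so concatenate there to restore it.
    return torch.cat(gathered_noise_pred, dim=1)
```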
```python
assert self.config.get("dit_quant_scheme", "Default") in [
    "fp8-pertensor",
    "fp8-triton",
    "int8-triton",
    "fp8-vllm",
    "int8-vllm",
    "int8-vllm-hygon-dcu",
    "fp8-q8f",
    "int8-q8f",
    "fp8-b128-deepgemm",
    "fp8-sgl",
    "int8-sgl",
    "int8-torchao",
    "fp8-torchao",
    "nvfp4",
    "mxfp4",
    "mxfp6-mxfp8",
    "mxfp8",
    "int8-tmo",
    "gguf-Q8_0",
    "gguf-Q6_K",
    "gguf-Q5_K_S",
    "gguf-Q5_K_M",
    "gguf-Q5_0",
    "gguf-Q5_1",
    "gguf-Q4_K_S",
    "gguf-Q4_K_M",
    "gguf-Q4_0",
    "gguf-Q4_1",
    "gguf-Q3_K_S",
    "gguf-Q3_K_M",
    "int8-npu",
]
```
```python
with safe_open(file_path, framework="pt", device=device) as f:
    return {
        key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
```
The sensitive_layer is initialized as an empty dictionary {} in __init__ and _init_weights. If it remains empty, the condition all(s not in key for s in sensitive_layer) will always evaluate to True, effectively bypassing any sensitive layer specific dtype handling. If sensitive_layer is meant to be populated by subclasses, this should be explicitly documented, or a default non-empty value should be provided if it's always expected to have patterns.
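The vacuous-truth behaviour described here is easy to verify in isolation (the key and sensitive_layer values below are illustrative stand-ins):

```python
sensitive_layer = {}  # empty, as initialized in __init__ / _init_weights
key = "transformer_blocks.0.norm.weight"

# all(...) over an empty iterable is vacuously True, so every key falls through
# to GET_DTYPE() and the sensitive-dtype branch is effectively dead code.
print(all(s not in key for s in sensitive_layer))  # True
```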
```python
for key in sorted(synced_meta_dict.keys()):
    if is_weight_loader:
        distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)

if target_device == "cpu":
    if is_weight_loader:
        gpu_tensor = distributed_weight_dict[key].cuda()
        dist.broadcast(gpu_tensor, src=global_src_rank)
        distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
        del gpu_tensor
        torch.cuda.empty_cache()
    else:
        gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
        dist.broadcast(gpu_tensor, src=global_src_rank)
        distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
        del gpu_tensor
        torch.cuda.empty_cache()
```
The logic for CPU offload during distributed weight loading involves moving tensors to CUDA, broadcasting, and then moving them back to CPU. This process adds unnecessary overhead. If dist.broadcast supports direct broadcasting of CPU tensors (which it typically does), it would be more efficient to perform the broadcast directly on the CPU tensors without intermediate CUDA transfers. If not, this workaround is necessary but should be noted for potential optimization.
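As a note on that potential optimization: dist.broadcast can operate on CPU tensors only when the process group's backend supports them (gloo does; NCCL does not). A hedged sketch under that assumption, reusing the names from the hunk above:

```python
import torch.distributed as dist

if target_device == "cpu" and dist.get_backend() == "gloo":
    # Broadcast the CPU tensor in place: no .cuda() staging copy and no
    # per-key torch.cuda.empty_cache() call.
    dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
else:
    # NCCL-only process groups still need the GPU-staged round trip shown above.
    ...
```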