Add LoRA support for Cosmos Predict 2.5 and fix pipeline to match official Cosmos repo #13664
Status: Open. terarachang wants to merge 15 commits into `huggingface:main` from `terarachang:cosmos_predict_2.5_lora_clean`.
Commits (15):
- `48fbe84` support lora for cosmos 2.5
- `5216816` Fix inconsistencies with cosmos official repo in VAE encoding, text e…
- `3d5551b` Support f_min and f_max in linear_scheduler warmup
- `f0b51ed` Add requirements and dataset preprocessing scripts to run examples
- `9038e10` Add LoRA training scripts
- `cee46cc` Add LoRA eval scripts
- `9e24319` add assets for blogpost
- `c3bb1aa` Fix(scheduler): device mismatch from upstream b114620 - move rk and b…
- `d32b085` Always upcast to fp32
- `80aa547` Directly inhrit from LoraBaseMixin
- `c8513c2` remove flash-attn2
- `49f5b35` Use _keep_in_fp32_modules instead of autocast
- `6acdc1c` remove the get_latent_shape_cthw method and fix style
- `d4b27f6` simplifiy the eval script to make it more user-friendly
- `3cc1210` overwrite scheduling_unipc_multistep.py with main's version
# LoRA fine-tuning for Cosmos Predict 2.5

This example shows how to fine-tune [Cosmos Predict 2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B) using LoRA on a custom video dataset.

## Requirements

Install the library from source and the example-specific dependencies:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e ".[dev]"
cd examples/cosmos
pip install -r requirements.txt
```
## Data preparation

The training script expects a dataset directory with the following layout:

```
<dataset_dir>/
├── videos/   # .mp4 files
└── metas/    # one .txt prompt file per video (same stem)
    ├── 0.txt
    ├── 1.txt
    └── ...
```
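Because each video must have a prompt file with the same stem, it can be worth sanity-checking the pairing before launching a long training run. A minimal sketch of such a check (the helper name is illustrative; it is not part of the training script):

```python
import os

def find_unpaired_videos(dataset_dir):
    """Return stems of .mp4 files under videos/ with no matching metas/<stem>.txt."""
    video_dir = os.path.join(dataset_dir, "videos")
    meta_dir = os.path.join(dataset_dir, "metas")
    video_stems = {
        os.path.splitext(name)[0]
        for name in os.listdir(video_dir)
        if name.endswith(".mp4")
    }
    meta_stems = {
        os.path.splitext(name)[0]
        for name in os.listdir(meta_dir)
        if name.endswith(".txt")
    }
    return sorted(video_stems - meta_stems)
```

An empty result means every video has a prompt; anything returned would be silently unusable by a stem-matching loader.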
### GR1 dataset (quick start)

The `download_and_preprocess_datasets.sh` script downloads the GR1-100 training set and the EVAL-175 test set, then runs the preprocessing script to create the per-video prompt files.

```bash
bash download_and_preprocess_datasets.sh
```

This produces:
- `gr1_dataset/train/` — training videos + prompts
- `gr1_dataset/test/` — evaluation images + prompts
## Training

Launch LoRA training with `accelerate`:

```bash
export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B"
export DATA_DIR="gr1_dataset/train"
export OUT_DIR="lora-output"

accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --revision diffusers/base/post-trained \
  --train_data_dir=$DATA_DIR \
  --output_dir=$OUT_DIR \
  --train_batch_size=1 \
  --num_train_epochs=500 \
  --checkpointing_epochs=100 \
  --seed=0 \
  --height 432 --width 768 \
  --allow_tf32 \
  --gradient_checkpointing \
  --lora_rank 32 --lora_alpha 32 \
  --report_to=wandb
```

Or use the provided shell script:

```bash
bash train_lora.sh
```
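The `--lora_rank 32 --lora_alpha 32` flags control the adapter's capacity and scaling: LoRA keeps the frozen weight `W` and learns a low-rank update `B @ A`, scaled by `alpha / rank`, so with `alpha == rank` the update is applied at scale 1. A minimal numpy sketch of the idea (illustrative only, not the training script's actual implementation):

```python
import numpy as np

def lora_forward(x, W, A, B, alpha):
    """y = x @ (W + (alpha / rank) * B @ A).T  -- frozen W plus low-rank update."""
    rank = A.shape[0]  # A: (rank, in_features), B: (out_features, rank)
    delta = (alpha / rank) * (B @ A)
    return x @ (W + delta).T

rng = np.random.default_rng(0)
in_f, out_f, rank, alpha = 8, 4, 2, 2
W = rng.normal(size=(out_f, in_f))
A = rng.normal(size=(rank, in_f))
B = np.zeros((out_f, rank))  # B starts at zero, so the adapter is a no-op initially
x = rng.normal(size=(1, in_f))
assert np.allclose(lora_forward(x, W, A, B, alpha), x @ W.T)
```

Only `A` and `B` are trained, which is why a rank-32 adapter for a 2B-parameter transformer stays small on disk.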
## Evaluation

Run inference with the trained LoRA adapter:

```bash
export DATA_DIR="gr1_dataset/test"
export LORA_DIR="lora-output"
export OUT_DIR="eval-output"

python eval_cosmos_predict25_lora.py \
  --data_dir $DATA_DIR \
  --output_dir $OUT_DIR \
  --lora_dir $LORA_DIR \
  --revision diffusers/base/post-trained \
  --height 432 --width 768 \
  --num_output_frames 93 \
  --num_steps 36 \
  --seed 0
```

Or use the provided shell script:

```bash
bash eval_lora.sh
```
`create_prompts_for_gr1_dataset.py`:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from tqdm import tqdm


"""example command
python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1
"""


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Create text prompts for GR1 dataset")
    parser.add_argument(
        "--dataset_path", type=str, default="datasets/benchmark_train/gr1", help="Root path to the dataset"
    )
    parser.add_argument(
        "--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt"
    )
    parser.add_argument(
        "--meta_csv", type=str, default=None, help="Metadata csv file (defaults to <dataset_path>/metadata.csv)"
    )
    return parser.parse_args()


def main(args) -> None:
    meta_csv = args.meta_csv or os.path.join(args.dataset_path, "metadata.csv")
    with open(meta_csv) as fp:
        meta_lines = fp.readlines()[1:]  # skip the CSV header row
    meta_txt_dir = os.path.join(args.dataset_path, "metas")
    os.makedirs(meta_txt_dir, exist_ok=True)

    for meta_line in tqdm(meta_lines):
        # Split only on the first comma so commas inside the prompt survive
        video_filename, prompt = meta_line.split(",", 1)
        prompt = prompt.strip("\n")
        if prompt.startswith('"') and prompt.endswith('"'):
            # Remove the surrounding CSV quotes
            prompt = prompt[1:-1]
        prompt = args.prompt_prefix + prompt
        meta_txt_filename = os.path.join(meta_txt_dir, os.path.basename(video_filename).replace(".mp4", ".txt"))
        with open(meta_txt_filename, "w") as fp:
            fp.write(prompt)

    print(f"Wrote {len(meta_lines)} prompt files to {meta_txt_dir}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
```
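The script's parsing step splits each metadata row only on the first comma, so CSV-quoted prompts that themselves contain commas survive intact. A small standalone sketch of that step (the function name is illustrative; the script inlines this logic):

```python
def parse_meta_line(meta_line, prompt_prefix=""):
    """Split a metadata.csv row into (video filename, prefixed prompt).

    Only the first comma separates the columns, so commas inside the
    prompt are preserved; surrounding CSV quotes are stripped.
    """
    video_filename, prompt = meta_line.split(",", 1)
    prompt = prompt.strip("\n")
    if prompt.startswith('"') and prompt.endswith('"'):
        prompt = prompt[1:-1]
    return video_filename, prompt_prefix + prompt

name, prompt = parse_meta_line(
    '3.mp4,"pick up the cup, then place it"\n',
    "The robot arm is performing a task. ",
)
# name is "3.mp4"; prompt keeps its inner comma
```

Note this is lighter than a full CSV parser (it does not handle escaped quotes inside prompts), which matches the assumptions the preprocessing script itself makes.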
`download_and_preprocess_datasets.sh`:

```bash
dataset_dir='gr1_dataset'
train_dir=$dataset_dir/train
test_dir=$dataset_dir/test

# Download and preprocess the training dataset
hf download nvidia/GR1-100 --repo-type dataset --local-dir datasets/benchmark_train/hf_gr1/ && \
mkdir -p datasets/benchmark_train/gr1/videos && \
mv datasets/benchmark_train/hf_gr1/gr1/*mp4 datasets/benchmark_train/gr1/videos && \
mv datasets/benchmark_train/hf_gr1/metadata.csv datasets/benchmark_train/gr1/

python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1

# Download the eval dataset
hf download nvidia/EVAL-175 --repo-type dataset --local-dir dream_gen_benchmark

# Move everything under the final dataset directory
mkdir -p $dataset_dir
mv datasets/benchmark_train/gr1 $train_dir
mv dream_gen_benchmark/gr1_object $test_dir
echo "Downloaded training data to $train_dir"
echo "Downloaded test data to $test_dir"

# Clean up staging directories
rm -rf datasets/ dream_gen_benchmark/
```
`eval_cosmos_predict25_lora.py`:

```python
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video, load_image


IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


class ImageDataset(Dataset):
    """Dataset that loads images and their corresponding text prompts.

    Expects a directory with:
        <filename>.jpg / .jpeg / .png — the conditioning image
        <filename>.txt — the prompt text
    """

    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.samples = []

        for filename in sorted(os.listdir(data_dir)):
            stem, ext = os.path.splitext(filename)
            if ext.lower() not in IMAGE_EXTENSIONS:
                continue
            img_path = os.path.join(data_dir, filename)
            txt_path = os.path.join(data_dir, stem + ".txt")
            if not os.path.exists(txt_path):
                print(f"WARNING: no prompt file found for {img_path}, skipping.")
                continue
            self.samples.append((img_path, txt_path, stem))

        if len(self.samples) == 0:
            raise ValueError(f"No valid image/prompt pairs found in {data_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, txt_path, stem = self.samples[idx]
        image = load_image(img_path)
        with open(txt_path) as f:
            prompt = f.read().strip()
        return {
            "image": image,
            "prompt": prompt,
            "stem": stem,
        }


def collate_fn(batch):
    """Keep images as a list (PIL images can't be stacked into a tensor)."""
    return {
        "images": [item["image"] for item in batch],
        "prompts": [item["prompt"] for item in batch],
        "stems": [item["stem"] for item in batch],
    }


def parse_args():
    parser = argparse.ArgumentParser(description="Eval Cosmos Predict 2.5 with optional LoRA weights.")

    parser.add_argument("--data_dir", type=str, required=True, help="Directory with image/prompt pairs.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save generated outputs.")
    parser.add_argument(
        "--model_id", type=str, default="nvidia/Cosmos-Predict2.5-2B", help="HuggingFace model repository."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="diffusers/base/post-trained",
        choices=["diffusers/base/post-trained", "diffusers/base/pre-trained"],
    )
    parser.add_argument("--lora_dir", type=str, default=None, help="Path to LoRA weights directory.")
    parser.add_argument("--num_output_frames", type=int, default=93, help="1 for image output, 93 for video output.")
    parser.add_argument("--num_steps", type=int, default=36, help="Number of inference steps.")
    parser.add_argument("--height", type=int, default=704, help="Output height in pixels (must be divisible by 16).")
    parser.add_argument("--width", type=int, default=1280, help="Output width in pixels (must be divisible by 16).")
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use.")
    parser.add_argument("--batch_size", type=int, default=1, help="Number of samples per batch.")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes.")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Defaults to the pipeline's built-in negative prompt.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = ImageDataset(args.data_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )

    print(f"Found {len(dataset)} examples.")

    class MockSafetyChecker:
        def to(self, *args, **kwargs):
            return self

        def check_text_safety(self, *args, **kwargs):
            return True

        def check_video_safety(self, video):
            return video

    pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
        args.model_id,
        revision=args.revision,
        device_map=args.device,
        torch_dtype=torch.bfloat16,
        safety_checker=MockSafetyChecker(),
    )

    if args.lora_dir is not None:
        pipe.load_lora_weights(args.lora_dir)
        pipe.fuse_lora(lora_scale=1.0)
        print(f"Loaded LoRA weights from {args.lora_dir}")

    progress = tqdm(total=len(dataset), desc="Generating")
    for batch in dataloader:
        images = batch["images"]
        prompts = batch["prompts"]
        stems = batch["stems"]

        for image, prompt, stem in zip(images, prompts, stems):
            frames = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=args.negative_prompt,
                num_frames=args.num_output_frames,
                num_inference_steps=args.num_steps,
                height=args.height,
                width=args.width,
            ).frames[0]  # NOTE: batch_size == 1

            out_path = os.path.join(args.output_dir, f"{stem}.mp4")
            export_to_video(frames, out_path, fps=16)

            tqdm.write(f"  Saved to: {out_path}")
            progress.update(1)


if __name__ == "__main__":
    main()
```
`eval_lora.sh`:

```bash
export DATA_DIR="gr1_dataset/test"
export LORA_DIR=YOUR_ADAPTER_DIR
export OUT_DIR=YOUR_EVAL_OUTPUT_DIR
revision="post-trained"

export TOKENIZERS_PARALLELISM=false
python eval_cosmos_predict25_lora.py \
  --data_dir $DATA_DIR \
  --output_dir $OUT_DIR \
  --lora_dir $LORA_DIR \
  --revision diffusers/base/$revision \
  --height 432 --width 768 \
  --num_output_frames 93 \
  --num_steps 36 \
  --seed 0
```
Review comment: Figures should be hosted somewhere else.