
Conversation

@RayenTian (Contributor) commented Jan 20, 2026

Summary

  • Merge LoRA adapter weights into base linear weights when exporting dtensor state and skip standalone LoRA adapter tensors.
  • Add LoRA configuration defaults to grpo_math_1B.yaml and introduce a Qwen3-8B LoRA recipe.
  • Expand LoRA coverage in functional and unit tests (vLLM generation + GRPO LoRA suites).

Changes

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
    • Merge LoRA weights into base weights during state export.
  • Skip lora_A/lora_B tensors and release temporary tensors to reduce memory (see the sketch after this list).
  • examples/configs/grpo_math_1B.yaml
    • Add LoRA config section with defaults/documentation.
  • examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
    • New LoRA recipe for Qwen3-8B.
  • tests/functional/*
    • Add GRPO LoRA functional tests (sync/async/non-colocated) and include in nightly.
  • tests/unit/models/generation/test_vllm_generation.py
    • Add LoRA config coverage and parameters in vLLM tests.
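The export-time merge can be pictured roughly as below. This is a minimal, hedged sketch rather than the PR's actual implementation: the helper name `merge_lora_into_state_dict`, the PEFT-style key layout (`base_layer.weight`, `lora_A.weight`, `lora_B.weight`), and the precomputed `scale` are assumptions made for illustration.

```python
import torch


def merge_lora_into_state_dict(
    state_dict: dict[str, torch.Tensor], scale: float
) -> dict[str, torch.Tensor]:
    """Fold LoRA adapters into their base linear weights and drop the
    standalone adapter tensors (illustrative only; key names are assumed)."""
    merged: dict[str, torch.Tensor] = {}
    for key, weight in state_dict.items():
        # Standalone adapter tensors are skipped entirely; their contribution
        # is folded into the corresponding base weight below.
        if ".lora_A." in key or ".lora_B." in key:
            continue
        if key.endswith("base_layer.weight"):
            prefix = key[: -len("base_layer.weight")]
            lora_a = state_dict.get(prefix + "lora_A.weight")  # (r, in_features)
            lora_b = state_dict.get(prefix + "lora_B.weight")  # (out_features, r)
            if lora_a is not None and lora_b is not None:
                # W' = W + scale * (B @ A), aligned to the base weight's
                # device and dtype before the addition.
                delta = lora_b.to(weight.device, torch.float32) @ lora_a.to(
                    weight.device, torch.float32
                )
                weight = weight + (scale * delta).to(weight.dtype)
                del delta  # release the temporary to keep peak memory down
        merged[key] = weight
    return merged
```

Merging at export time means the generation engine only ever sees plain dense weights, so no adapter-aware plumbing is required on the vLLM side.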

Testing

  • Not run (manual PR only).

Notes

  • LoRA weight merge uses W + scale * (B @ A) with dtype/device alignment.
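Written out with the usual LoRA convention (assuming `scale` is the standard α/r factor, with A of shape r×k and B of shape d×r):

```math
W_{\text{merged}} = W + \frac{\alpha}{r}\, B A
```

After the merge, the exported weights are numerically equivalent (up to dtype rounding) to running the base layer plus the adapter at inference time.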

Signed-off-by: ruit <ruit@nvidia.com>
@github-actions

⚠️ File Consistency Check

Check based on commit: 1ccb5be (PR #1797 from ruit/lora_merge_weight)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@RayenTian RayenTian added the CI:L1 Run doctests, unit tests, and functional tests label Jan 20, 2026
@RayenTian RayenTian removed the CI:L1 Run doctests, unit tests, and functional tests label Jan 20, 2026
Signed-off-by: ruit <ruit@nvidia.com>
@github-actions repeated the File Consistency Check warning above for commit acad57c (PR #1797 from ruit/lora_merge_weight).

@RayenTian RayenTian added the CI:L1 Run doctests, unit tests, and functional tests label Jan 21, 2026
Signed-off-by: ruit <ruit@nvidia.com>
@github-actions repeated the File Consistency Check warning above for commit 68263ea (PR #1797 from ruit/lora_merge_weight).

print(f"Error applying torch.ops.aten.alias.default patch: {e}")


def patched_lora_linear_forward(self, x):
@RayenTian (Contributor, Author) commented:


This patch cannot guarantee that the computational logic is exactly identical to the original.
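For context, a PEFT-style LoRA linear forward (an illustrative sketch, not this repo's exact module) computes the base projection plus a scaled low-rank correction; a monkey-patch must reproduce all of this, including dropout and dtype behavior, which is the hard-to-guarantee part that the export-time weight merge sidesteps:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Illustrative LoRA-wrapped linear layer (hypothetical, PEFT-style)."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):
        super().__init__()
        self.base = base
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # adapter starts as a no-op
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base(x) + scaling * B(A(dropout(x)))
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.dropout(x)))
```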
