Conversation

@NikhilNayak-debug (Contributor)

Summary

Adds a runnable continual learning example for the OSF method, as requested in PR #2685.

Changes

  • New example in examples/orthogonal_subspace_learning/
  • Demonstrates OSF preventing catastrophic forgetting on 3 sequential tasks (ScienceQA, NumGLUE, FOMC)
  • Includes full fine-tuning baseline for comparison
  • Progressive capacity allocation: 70% trainable (Task 1) → 50% (Task 2) → 30% (Task 3)
  • Tracks accuracy and backward transfer metrics

Results (2 epochs per task)

  • OSF: 53.42% average accuracy, +30.25% backward transfer
  • Full FT: 46.26% average accuracy, -6.00% forgetting
  • OSF prevents catastrophic forgetting and enables positive backward transfer (see the metric sketch below)
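
The two headline numbers can be reproduced from the per-task accuracies. A minimal sketch, assuming an accuracy matrix acc[i][j] = accuracy on task j after finishing training on task i (names are illustrative, not the script's actual code):

# Sketch of the summary metrics; `acc` is an assumed row-per-training-stage
# accuracy matrix, an illustrative layout rather than the script's own.
def average_accuracy(acc):
    # Mean accuracy over all tasks, measured after the final task is trained.
    final_row = acc[-1]
    return sum(final_row) / len(final_row)

def backward_transfer(acc):
    # For each earlier task j: final accuracy minus the accuracy right after
    # task j was learned. Positive means later training helped; negative
    # means forgetting ("Forgetting = Final Accuracy - Initial Accuracy").
    n = len(acc)
    deltas = [acc[n - 1][j] - acc[j][j] for j in range(n - 1)]
    return sum(deltas) / len(deltas)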

Files

  • osf_continual_learning.py - Main example script with OSF and baseline training
  • utils.py - Dataset loading and formatting utilities for 3 tasks
  • README.md - Comprehensive documentation with usage examples

Implementation Details

  • Uses meta-llama/Llama-3.1-8B-Instruct by default
  • Learning rate: 5e-6, batch size: 32
  • Progressive effective_rank allocation (0.3 → 0.5 → 0.7); see the sketch below
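
A minimal sketch of that schedule, assuming PEFT's usual get_peft_model pattern and an OSFConfig exposing the effective_rank knob named above (the class name and loop structure are assumptions, not the script's exact code):

# Schematic progressive-rank loop; OSFConfig/effective_rank are assumed names.
import torch
from peft import OSFConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

# effective_rank is the preserved (frozen) share of each weight's spectrum,
# so the trainable share shrinks across tasks: 70% -> 50% -> 30%.
for task_name, effective_rank in [("ScienceQA", 0.3), ("NumGLUE", 0.5), ("FOMC", 0.7)]:
    model = get_peft_model(base_model, OSFConfig(effective_rank=effective_rank))
    # ... fine-tune `model` on task_name, evaluate on all tasks seen so far ...
    # (schematic: the real script carries weights forward between tasks)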

Addresses feedback from #2685

@NikhilNayak-debug (Contributor, Author)

@githubnemo I have added the continual learning example as requested. Could you please review this PR?

The example demonstrates OSF on 3 sequential tasks (ScienceQA, NumGLUE, FOMC) with progressive rank allocation and compares it against a full fine-tuning baseline.

@githubnemo (Collaborator) left a comment

Hey, thanks, this looks very nice!

Just a quick review for now; I tried to run the code and got this exception:

$ python osf_continual_learning.py --model_name meta-llama/Llama-3.2-1B-Instruct --run_baseline

================================================================================
TRAINING WITH OSF (Orthogonal Subspace Fine-tuning)
================================================================================

Loading datasets...
Map: 100%|██████████| 200/200 [00:00<00:00, 9871.85 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 23863.81 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 17309.66 examples/s]
Map: 100%|██████████| 146/146 [00:00<00:00, 10534.10 examples/s]

================================================================================
TASK 1: ScienceQA
Effective Rank: 0.3 (preserving 30%)
================================================================================
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.

Training on ScienceQA...
{'loss': 6.7117, 'grad_norm': 220.0, 'learning_rate': 4.296875e-06, 'epoch': 0.31}                                                                                                                                   
{'loss': 1.8609, 'grad_norm': 136.0, 'learning_rate': 3.5156250000000003e-06, 'epoch': 0.62}                                                                                                                         
{'loss': 1.2447, 'grad_norm': 131.0, 'learning_rate': 2.7343750000000004e-06, 'epoch': 0.94}                                                                                                                         
{'loss': 1.1187, 'grad_norm': 118.5, 'learning_rate': 1.953125e-06, 'epoch': 1.25}                                                                                                                                   
{'loss': 1.0351, 'grad_norm': 117.0, 'learning_rate': 1.1718750000000001e-06, 'epoch': 1.56}                                                                                                                         
{'loss': 1.0219, 'grad_norm': 117.0, 'learning_rate': 3.90625e-07, 'epoch': 1.88}                                                                                                                                    
{'train_runtime': 22.6765, 'train_samples_per_second': 88.197, 'train_steps_per_second': 2.822, 'train_loss': 2.0924081951379776, 'epoch': 2.0}                                                                      
100%|██████████| 64/64 [00:22<00:00,  2.82it/s]

Evaluating on all tasks after training on ScienceQA:
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
100%|██████████| 25/25 [00:01<00:00, 20.69it/s]
Traceback (most recent call last):
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 694, in <module>
    main()
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 627, in main
    osf_history = train_with_osf(
                  ^^^^^^^^^^^^^^^
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 364, in train_with_osf
    loss, accuracy = evaluate_model(
                     ^^^^^^^^^^^^^^^
  File "examples/orthogonal_subspace_learning/osf_continual_learning.py", line 199, in evaluate_model
    loss = results["eval_loss"]
           ~~~~~~~^^^^^^^^^^^^^
KeyError: 'eval_loss'

Maybe that's on my side; I'm investigating.


In README.md:

- [OSF Documentation](../../docs/source/package_reference/osf.md)
- [PEFT Documentation](https://huggingface.co/docs/peft)
- [Original Paper](https://arxiv.org/abs/2504.07097)
@githubnemo (Collaborator):

Suggested change
- [Original Paper](https://arxiv.org/abs/2504.07097)
- [Original Paper](https://huggingface.co/papers/2504.07097)

@NikhilNayak-debug (Contributor, Author):

Changed the paper link!

@githubnemo (Collaborator) left a comment

I found the culprit and commented the solution above.

I ran the example with Llama-3.2 1B and got these results:

Final full fine-tuning model saved to ./osf_continual_learning_outputs/full_final

================================================================================
RESULTS COMPARISON: OSF vs Full Fine-tuning
================================================================================

--------------------------------------------------------------------------------
DETAILED RESULTS (Accuracy %)
--------------------------------------------------------------------------------
Task            After Task      OSF Acc %       Full FT Acc %   Difference     
--------------------------------------------------------------------------------
ScienceQA       ScienceQA       41.00           45.00                     -4.00
ScienceQA       NumGLUE         44.50           76.50                    -32.00
ScienceQA       FOMC            51.50           81.50                    -30.00
NumGLUE         NumGLUE         22.00           48.50                    -26.50
NumGLUE         FOMC            17.00           50.50                    -33.50
FOMC            FOMC            28.77           28.77                     +0.00

================================================================================
SUMMARY METRICS
================================================================================

1. Average Accuracy Across All 3 Tasks (After Final Task):
   OSF:     32.42%
   Full FT: 53.59%
   Difference: -21.17% (Full FT better)

2. Average Forgetting (Task 1 & 2):
   Forgetting = Final Accuracy - Initial Accuracy (negative is worse)

   ScienceQA:
     OSF:     +10.50% (initial: 41.00% → final: 51.50%)
     Full FT: +36.50% (initial: 45.00% → final: 81.50%)
     Difference: -26.00% (Full FT better)

   NumGLUE:
     OSF:     -5.00% (initial: 22.00% → final: 17.00%)
     Full FT: +2.00% (initial: 48.50% → final: 50.50%)
     Difference: -7.00% (Full FT better)

   Average Forgetting:
     OSF:     +2.75%
     Full FT: +19.25%
     Difference: -16.50% (Full FT better)

I'm not sure if that's expected; the effective rank is probably smaller since the hidden dimensions differ.

When I run the same experiment with --learning_rate=5e-5 I get the following:

================================================================================
RESULTS COMPARISON: OSF vs Full Fine-tuning
================================================================================

--------------------------------------------------------------------------------
DETAILED RESULTS (Accuracy %)
--------------------------------------------------------------------------------
Task            After Task      OSF Acc %       Full FT Acc %   Difference     
--------------------------------------------------------------------------------
ScienceQA       ScienceQA       100.00          39.50                    +60.50  (OSF better)
ScienceQA       NumGLUE         99.50           100.00                    -0.50
ScienceQA       FOMC            100.00          39.50                    +60.50  (OSF better)
NumGLUE         NumGLUE         54.50           17.00                    +37.50  (OSF better)
NumGLUE         FOMC            53.50           55.00                     -1.50
FOMC            FOMC            28.77           28.77                     +0.00

================================================================================
SUMMARY METRICS
================================================================================

1. Average Accuracy Across All 3 Tasks (After Final Task):
   OSF:     60.76%
   Full FT: 41.09%
   Difference: +19.67% (OSF better)

2. Average Forgetting (Task 1 & 2):
   Forgetting = Final Accuracy - Initial Accuracy (negative is worse)

   ScienceQA:
     OSF:     +0.00% (initial: 100.00% → final: 100.00%)
     Full FT: +0.00% (initial: 39.50% → final: 39.50%)
     Difference: +0.00% (Full FT better)

   NumGLUE:
     OSF:     -1.00% (initial: 54.50% → final: 53.50%)
     Full FT: +38.00% (initial: 17.00% → final: 55.00%)
     Difference: -39.00% (Full FT better)

   Average Forgetting:
     OSF:     -0.50%
     Full FT: +19.00%
     Difference: -19.50% (Full FT better)

Is this expected? This looks a bit off.

In the example script:

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    eval_dataset=eval_dataset,
@githubnemo (Collaborator):

Suggested change
    eval_dataset=eval_dataset,
    eval_dataset=eval_dataset,
    args=TrainingArguments(
        label_names=["labels"],
    ),

We need this to get an eval loss for PEFT models. See also: #1120 (comment)
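
Spelled out, a minimal version of the fixed evaluation setup (output_dir and batch size are illustrative; the essential addition is label_names, without which trainer.evaluate() omits eval_loss for a PeftModel):

from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,  # the wrapped PeftModel
    args=TrainingArguments(
        output_dir="./osf_eval",       # illustrative
        per_device_eval_batch_size=8,  # illustrative
        label_names=["labels"],        # lets Trainer find the labels a PeftModel hides
    ),
    data_collator=data_collator,
    eval_dataset=eval_dataset,
)

results = trainer.evaluate()
loss = results["eval_loss"]  # present once label_names is set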

@NikhilNayak-debug (Contributor, Author):

Added the above lines!

In the example script:

tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
@githubnemo (Collaborator):

Suggested change
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
    model_name, torch_dtype=torch.bfloat16, device_map="auto",

This isn't a requirement, or is it?

@NikhilNayak-debug (Contributor, Author):

Removed it!

@NikhilNayak-debug (Contributor, Author)

> I ran the example with Llama-3.2 1B and got these results. I'm not sure if that's expected; the effective rank is probably smaller since the hidden dimensions differ. Is this expected? This looks a bit off.

@githubnemo thanks for testing with Llama-3.2-1B! The results are expected; performance depends on the learning rate and effective rank. The hyperparameters in the script are tuned for Llama-3.1-8B, so smaller models need adjustment (as you found with --learning_rate=5e-5). I ran the same experiment on the 1B model using the same hyperparameters as for the 8B model and got higher average accuracy with OSF than with full fine-tuning.

The key point is that OSF generally shows better retention of earlier tasks (lower catastrophic forgetting) and higher average accuracy compared to full fine-tuning when properly tuned. This is good to go.
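
For reference, the adjusted 1B invocation combines the flags from the runs above:

$ python osf_continual_learning.py --model_name meta-llama/Llama-3.2-1B-Instruct --learning_rate=5e-5 --run_baseline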

@NikhilNayak-debug force-pushed the osf-continual-learning-example branch from e3df758 to c98baed (December 9, 2025 21:38)
@NikhilNayak-debug (Contributor, Author)

@githubnemo the PR is ready for review!
