feat: lr scheduler #151
Open
Chamberlain0w0 wants to merge 11 commits into master from
Conversation
…r accessors, passthrough SetLearningRate/GetLearningRate, and add initial_learning_rate and its accessors
…StepLR, LinearLR, LambdaLR and SequentialLR
…base class, add factory method Create&lt;T&gt;() with two-phase init, and update all tests to use the Create&lt;T&gt;() factory method.
- Change Step() to virtual with a default implementation.
- Add pure virtual ComputeLR() for subclasses to implement.
- Adapt test helpers (IdentityScheduler, LinearDecayScheduler) to implement ComputeLR() instead of Step().
- All existing tests pass without behavioral changes.
BREAKING CHANGE: Subclasses must implement ComputeLR() instead of Step().
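The two-phase `Create<T>()` pattern described in this commit can be sketched roughly as below. This is an illustrative reconstruction, not the PR's actual code: the exact signatures, the `lr()` accessor, and the `ConstantLR` example class are assumptions; only the names `Create<T>()`, `Step()`, `ComputeLR()`, and `InitialStep()` come from the commit messages. The point of the split is that virtual calls such as `InitialStep()` only run after the derived object is fully constructed, which a constructor cannot guarantee.

```cpp
#include <cassert>
#include <memory>
#include <utility>

class LRScheduler {
public:
    // Phase 1: construct the derived object; phase 2: initialize it through
    // virtual dispatch, which is only safe once construction has finished.
    template <typename T, typename... Args>
    static std::unique_ptr<T> Create(Args&&... args) {
        auto sched = std::make_unique<T>(std::forward<Args>(args)...);
        sched->InitialStep();  // virtual dispatch now resolves to T's overrides
        return sched;
    }
    virtual ~LRScheduler() = default;
    double lr() const { return lr_; }

protected:
    // Default initialization just evaluates the schedule once.
    virtual void InitialStep() { lr_ = ComputeLR(); }
    // Subclasses implement the actual schedule.
    virtual double ComputeLR() = 0;
    double lr_ = 0.0;
};

// Minimal example subclass (hypothetical, for illustration only).
class ConstantLR : public LRScheduler {
public:
    explicit ConstantLR(double v) : v_(v) {}

protected:
    double ComputeLR() override { return v_; }

private:
    double v_;
};
```

A caller then writes `LRScheduler::Create<ConstantLR>(0.01)` instead of constructing the scheduler directly, so initialization can never be skipped.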
…closed and chained form, adjust LinearLR and SequentialLR.
- Enhance LRScheduler with chained-form and closed-form learning-rate methods.
- Adapt methods (Step, InitialStep, GetClosedFormLR, GetChainedFormLR) to match PyTorch's design.
- Add tests for consistency.
- Refactor LinearLR: add end_factor and rename the class.
- Add SequentialLR InitialStep and UndoChildInitialSteps.
BREAKING CHANGE: Subclasses must implement GetClosedFormLR() instead of ComputeLR(). Use LinearLR instead of LinearwarmupLR.
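The closed-form idea can be sketched as follows, using PyTorch's LinearLR formula (interpolate from `start_factor` to `end_factor` over `total_iters`, then hold). Everything here is a simplified stand-in: the `Optimizer` struct, constructor signatures, and step bookkeeping are assumptions, and the chained form (`GetChainedFormLR`) is omitted; only the class and method names follow the commit message.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical stand-in for the project's Optimizer, holding just the rates.
struct Optimizer {
    double lr;
    double initial_lr;
};

class LRScheduler {
public:
    explicit LRScheduler(Optimizer* opt) : opt_(opt) {}
    virtual ~LRScheduler() = default;
    // Default Step(): advance the counter, then apply the closed-form rate.
    virtual void Step() {
        ++step_;
        opt_->lr = GetClosedFormLR();
    }

protected:
    // Subclasses express their schedule as a pure function of the step count.
    virtual double GetClosedFormLR() const = 0;
    Optimizer* opt_;
    long step_ = 0;
};

// LinearLR: linearly interpolate the lr multiplier from start_factor to
// end_factor over total_iters steps, then stay at end_factor.
class LinearLR : public LRScheduler {
public:
    LinearLR(Optimizer* opt, double start_factor, double end_factor,
             long total_iters)
        : LRScheduler(opt), start_(start_factor), end_(end_factor),
          total_(total_iters) {}

protected:
    double GetClosedFormLR() const override {
        double progress =
            std::min(step_, total_) / static_cast<double>(total_);
        return opt_->initial_lr * (start_ + (end_ - start_) * progress);
    }

private:
    double start_, end_;
    long total_;
};
```

The closed form makes each step's learning rate independent of the previous one, which is what makes consistency tests against the chained form possible.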
- Add LRSchedulerConfig struct with parameters for all basic schedulers (constant, linear, step).
- Add CreateLRScheduler() factory function.
- Support automatic warmup wrapping via SequentialLR when warmup_steps > 0.
- Adapt test files.
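The config-plus-factory flow can be sketched as below. This is a deliberately minimal reconstruction: the config field names, defaults, and the `std::function` stand-in for the scheduler classes are all assumptions; only `LRSchedulerConfig`, `CreateLRScheduler()`, the three scheduler types, and the warmup-wrapping behavior (modeled on SequentialLR) come from the commit message.

```cpp
#include <algorithm>
#include <functional>
#include <stdexcept>
#include <string>

// Stand-in: a schedule maps a global step to a multiplier on the base lr.
using LRFn = std::function<double(long)>;

// Hypothetical mirror of the PR's LRSchedulerConfig (field names assumed).
struct LRSchedulerConfig {
    std::string type = "constant";  // "constant" | "linear" | "step"
    long warmup_steps = 0;          // > 0 prepends a linear warmup phase
    long total_steps = 100;         // horizon for the "linear" schedule
    long step_size = 10;            // decay interval for the "step" schedule
    double gamma = 0.1;             // per-interval decay for "step"
    double end_factor = 0.0;        // final multiplier for "linear"
};

LRFn CreateLRScheduler(const LRSchedulerConfig& cfg) {
    LRFn base;
    if (cfg.type == "constant") {
        base = [](long) { return 1.0; };
    } else if (cfg.type == "linear") {
        base = [cfg](long step) {
            double p = std::min(step, cfg.total_steps) /
                       static_cast<double>(cfg.total_steps);
            return 1.0 + (cfg.end_factor - 1.0) * p;
        };
    } else if (cfg.type == "step") {
        base = [cfg](long step) {
            double f = 1.0;
            for (long i = 0; i < step / cfg.step_size; ++i) f *= cfg.gamma;
            return f;
        };
    } else {
        throw std::invalid_argument("unknown scheduler type: " + cfg.type);
    }
    if (cfg.warmup_steps <= 0) return base;
    // Mimic the SequentialLR wrapping: ramp up first, then hand off to the
    // base schedule, restarting its step count at zero.
    long w = cfg.warmup_steps;
    return [base, w](long step) {
        if (step < w) return static_cast<double>(step + 1) / w;
        return base(step - w);
    };
}
```

Keeping the warmup wrapping inside the factory means callers only ever deal with one scheduler object, regardless of whether warmup is enabled.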
…ogs, and integrate scheduler into training loop
…s, add validation tests for learning rate schedulers - it is now used only for learning rate recovery when using loadstate
In the 2025 training-camp project topic "Learning Rate Scheduler Implementation", @littleotherut completed the basic implementation of the learning-rate scheduler module (PR #113). This PR builds on that implementation, revising the interfaces and related conventions so they fit the actual needs of our project.
The design document is available at: https://gxtctab8no8.feishu.cn/wiki/Bd6Pw1BeeiQ7QfktiT8cbSiTnFb?from=from_copylink
Core changes:
1. Added the corresponding gflags to main.cc. Once those parameters are parsed, a TrainingLRSchedulerConfig struct is built and passed, together with the optimizer, to CreateLRScheduler(), which constructs the matching learning-rate scheduler.
   a. Added an LRScheduler base class plus a derived class for each scheduling strategy, aligned with the torch implementation and usage. In the training loop, calling scheduler.step() right after optimizer.step() completes the learning-rate update.
   b. LRScheduler has to interact with Optimizer to read, update, and synchronize the learning rate, so the corresponding setters and getters were added to the Optimizer base class.
   c. LRScheduler::State() is still a fairly naive implementation; it will be reworked further once the checkpoint (ckpt) mechanism is in place.
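The step ordering in point (a) and the optimizer accessors in point (b) can be illustrated together with a small sketch. The `Optimizer`/`StepLR` bodies, the `gamma`/`step_size` values, and the `TrainLoop` helper are hypothetical; only `SetLearningRate`/`GetLearningRate`, `StepLR`, and the optimizer-then-scheduler ordering come from this PR.

```cpp
#include <cassert>

// Stand-in for the Optimizer base class with the accessors the PR adds.
struct Optimizer {
    double lr;
    void Step() { /* would apply gradients at the current lr */ }
    void SetLearningRate(double v) { lr = v; }
    double GetLearningRate() const { return lr; }
};

// Toy StepLR: multiply the lr by gamma every step_size scheduler steps.
struct StepLR {
    Optimizer* opt;
    long step_size;
    double gamma;
    long count = 0;
    void Step() {
        if (++count % step_size == 0)
            opt->SetLearningRate(opt->GetLearningRate() * gamma);
    }
};

// Training-loop ordering, matching the torch convention the PR follows:
// the optimizer updates parameters first, then the scheduler advances.
void TrainLoop(Optimizer& opt, StepLR& sched, int iters) {
    for (int i = 0; i < iters; ++i) {
        opt.Step();    // parameter update at the current learning rate
        sched.Step();  // then update the learning rate for the next iteration
    }
}
```

Calling the scheduler before the optimizer would skip the very first learning rate, which is exactly the off-by-one torch warns about; keeping `scheduler.step()` after `optimizer.step()` avoids it.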