[Feature] support async rl #1360
base: main
Conversation
Force-pushed from efb3109 to 1601d51
Force-pushed from 5e3f135 to aaa4860
Force-pushed from 31b3535 to 953a613
Force-pushed from f6fa0fd to 4bd4c4f
tail_batch_trigger_size: Annotated[
    Optional[int],
    Parameter(
        help="Number of candidate samples needed in the queue to trigger a tail batch operation. Set to 0 to disable."
The description "Set to 0 to disable." is not accurate.
There isn't really an "enable" notion here; it only takes effect in combination with tail_batch_candidate_steps.
This does support tail_batch_candidate_steps > 0 with tail_batch_trigger_size = 0: in that case, expired data is put into the expired queue but never triggers a tail batch, which effectively means this data is not used for training.
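For readers following along, here is a minimal sketch of the behavior settled in this thread (expired queue plus an optional tail-batch flush). The class and method names (ExpiredSampleBuffer, maybe_expire, pop_tail_batch) are hypothetical and are not the PR's actual implementation.

from collections import deque


class ExpiredSampleBuffer:
    """Holds samples whose abort count exceeded tail_batch_candidate_steps (illustrative only)."""

    def __init__(self, tail_batch_candidate_steps: int, tail_batch_trigger_size: int):
        self.tail_batch_candidate_steps = tail_batch_candidate_steps
        self.tail_batch_trigger_size = tail_batch_trigger_size
        self.candidates: deque = deque()

    def maybe_expire(self, sample, abort_count: int) -> None:
        # A sample becomes a tail-batch candidate only after it has been aborted
        # more than tail_batch_candidate_steps times (i.e. reaches candidate_steps + 1).
        if self.tail_batch_candidate_steps > 0 and abort_count > self.tail_batch_candidate_steps:
            self.candidates.append(sample)

    def pop_tail_batch(self) -> list:
        # With tail_batch_trigger_size == 0 the expired samples stay in the queue
        # but are never flushed as a tail batch, i.e. they are simply not trained on.
        if self.tail_batch_trigger_size <= 0:
            return []
        if len(self.candidates) >= self.tail_batch_trigger_size:
            return [self.candidates.popleft() for _ in range(self.tail_batch_trigger_size)]
        return []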
xtuner/v1/data_proto/rl_data.py
Outdated
response_ids: Optional[List[int]] = None
logprobs: Optional[List[float]] = None
num_return_tokens: Optional[int] = None
versioned_response: List[str] = Field(default_factory=list)
Think about whether this part will need to change for the multi-turn case in the future.
Multi-turn output should be a list[RolloutResonseItem]; the structure inside RolloutResonseItem shouldn't need to change.
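For concreteness, a small sketch of that point, assuming a pydantic model with the fields shown in the quoted diff (the class name below is only illustrative):

from typing import List, Optional

from pydantic import BaseModel, Field


class RolloutResponseItem(BaseModel):
    # Field names follow the quoted diff above.
    response_ids: Optional[List[int]] = None
    logprobs: Optional[List[float]] = None
    num_return_tokens: Optional[int] = None
    versioned_response: List[str] = Field(default_factory=list)


# A multi-turn rollout can then be represented without touching the per-turn schema:
multi_turn_output: List[RolloutResponseItem] = [
    RolloutResponseItem(response_ids=[1, 2, 3], num_return_tokens=3),
    RolloutResponseItem(response_ids=[4, 5], num_return_tokens=2),
]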
if not self.enable_partial_rollout:
    # Clear the previous response_ids and other env data
    if "routed_experts" in sample.env.rollout.extra_info:
        del sample.env.rollout.extra_info["routed_experts"]
Same comment applies here.
Have you considered, when a rollout is interrupted, sending the next request to the same server so that the cache can be reused?
The cache can't be reused right now anyway: every weight update clears the cache. This feature probably needs to be added together with KV-cache off-policy support.
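Since the cache-reuse idea is deferred, here is only a rough sketch of what sticky routing of resumed samples to the same rollout server could look like. None of these names exist in the PR, and it would only pay off once the KV cache survives weight updates.

from typing import Dict, List


class StickyRouter:
    """Hypothetical router that sends a resumed sample back to its previous server."""

    def __init__(self, server_urls: List[str]):
        self.server_urls = server_urls
        self._assignment: Dict[str, str] = {}  # sample_id -> server url
        self._next = 0

    def pick_server(self, sample_id: str) -> str:
        # Reuse the previous assignment for resumed (aborted) samples; otherwise round-robin.
        if sample_id in self._assignment:
            return self._assignment[sample_id]
        url = self.server_urls[self._next % len(self.server_urls)]
        self._next += 1
        self._assignment[sample_id] = url
        return url

    def reset(self) -> None:
        # After a weight update the KV cache is cleared anyway, so stickiness
        # no longer helps and the mapping can be dropped.
        self._assignment.clear()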
Force-pushed from 2397345 to 003ca72
Force-pushed from 003ca72 to ab6bccc
Force-pushed from 50afb64 to 718f817
The async functionality is covered by test_lmdeploy_dataflow_save_resume_with_partial_rollout / test_lmdeploy_dataflow_save_resume_with_partial_rollout_r3; let's add the accuracy test together with the RL e2e tests.
Force-pushed from 718f817 to 283b1e1
Force-pushed from fc8681a to e2d18c0
This PR introduces asynchronous RL support to Xtuner, enabling partial rollouts and version-based sample management for more efficient training data generation.
1. Key Concepts:
2. Async logic:
Scenario 1: staleness_threshold=0.0, enable_partial_rollout=0, tail_batch_candidate_steps=0

Scenario 2: staleness_threshold=0.2, enable_partial_rollout=0, tail_batch_candidate_steps=0
2. Responses are not retained when the rollout is paused
3. Prioritize sampling data from the abort queue

Scenario 3: staleness_threshold=0.2, enable_partial_rollout=0, tail_batch_candidate_steps=1, tail_batch_trigger_size=0
2. Responses are not retained when the rollout is paused
3. Prioritize sampling data from the abort queue
4. Put a sample into the candidate pool when its abort count reaches tail_batch_candidate_steps+1

Scenario 4: staleness_threshold=0.2, enable_partial_rollout=1, tail_batch_candidate_steps=0, tail_batch_trigger_size=0
2. Responses are retained and concatenated when the rollout is paused
3. Prioritize sampling data from the abort queue

Scenario 5: staleness_threshold=0.2, enable_partial_rollout=1, tail_batch_candidate_steps=1, tail_batch_trigger_size=0
2. Responses are retained and concatenated when the rollout is paused
3. Prioritize sampling data from the abort queue
4. Put a sample into the candidate pool when its abort count reaches tail_batch_candidate_steps+1 (tail_batch_candidate_steps is the off-policy step)
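As a rough illustration of how these knobs fit together, here is a hypothetical config for the last scenario above. The DataFlowConfig dataclass is illustrative only; the actual config class and parameter plumbing in the PR may differ.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DataFlowConfig:
    staleness_threshold: float = 0.0        # fraction of extra ("stale") samples allowed in flight
    enable_partial_rollout: int = 0         # 1: keep and concatenate partial responses after a pause
    tail_batch_candidate_steps: int = 0     # abort count after which a sample becomes a tail-batch candidate
    tail_batch_trigger_size: Optional[int] = 0  # candidates needed to actually flush a tail batch


# Scenario 5: partial rollout plus tail-batch candidates, but with
# tail_batch_trigger_size=0 the expired samples are never trained on.
cfg = DataFlowConfig(
    staleness_threshold=0.2,
    enable_partial_rollout=1,
    tail_batch_candidate_steps=1,
    tail_batch_trigger_size=0,
)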
3. Benchmark

4. Related PR
sample_from_expired_storage in dataflow. When sample_from_expired_storage is set to True, the dataflow will not oversend data and will return data only after all tasks of the current batch are completed.
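A minimal sketch of the described switch, assuming a helper that decides how many samples to request per step; the function name and signature are hypothetical.

def compute_num_to_send(batch_size: int,
                        staleness_threshold: float,
                        sample_from_expired_storage: bool) -> int:
    # When sampling from expired storage, do not oversend: only the current batch
    # is requested, and results are returned once all of its tasks finish.
    if sample_from_expired_storage:
        return batch_size
    # Otherwise oversend according to the staleness threshold.
    return int(batch_size * (1.0 + staleness_threshold))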