33 changes: 15 additions & 18 deletions xtuner/v1/ray/dataflow/flow.py
@@ -140,11 +140,11 @@ def __init__(
         self.env = env
         self.config = dataflow_cfg
         replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
-        self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg)  # type: ignore[attr-defined]
-        self.replay_buffer.setup_storage_config.remote(  # type: ignore[attr-defined]
+        self.replay_buffer = ReplayBuffer(replay_buffer_cfg)
+        self.replay_buffer.setup_storage_config(
             enable_partial_rollout=self.config.enable_partial_rollout,
             tail_batch_candidate_steps=self.config.tail_batch_candidate_steps,
-            tail_batch_trigger_size=self.config.tail_batch_trigger_size,
+            tail_batch_trigger_size=self.config.tail_batch_trigger_size,  # type: ignore
         )
         self.staleness_threshold = self.config.staleness_threshold
         self.env_controller = environment
@@ -201,9 +201,7 @@ def _reset_internal_states(
         else:
             self.staleness_threshold = self.config.staleness_threshold

-        self.sample_from_expired_storage, self.finished_samples_count = ray.get(
-            self.replay_buffer.get_prerun_state.remote()
-        )
+        self.sample_from_expired_storage, self.finished_samples_count = self.replay_buffer.get_prerun_state()
         ray.get(self.env_controller.restart.remote())  # type: ignore[attr-defined]
         self.sample_params = sample_params if sample_params else self.config.sample_params
         self.extra_params = extra_params if extra_params else self.config.extra_params
@@ -217,7 +215,7 @@ def _reset_internal_states(
     @ray_method
     def get_train_dataset_length(self):
         """Gets the length of the training dataset from the replay buffer."""
-        return ray.get(self.replay_buffer.get_train_dataset_length.remote())
+        return self.replay_buffer.get_train_dataset_length()

     @ray_method
     async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
@@ -242,9 +240,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
         # step 1: sample
         # TODO(@duanyanhui): More fine-grained control over group data generation:
         # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
-        group_data_items = await self.replay_buffer.sample.remote(  # type: ignore[attr-defined]
-            self.env, self.config.prompt_repeat_k
-        )
+        group_data_items = self.replay_buffer.sample(self.env, self.config.prompt_repeat_k)
         assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
         action_id = group_data_items[0].uid.action_id
         # step 2: env generate
@@ -258,14 +254,15 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
         group_state = determine_group_state(group_data_items)
         self.logger.debug(f"Determined replay state for {action_id}: {group_state}")
         if group_state == RolloutState.COMPLETED:
-            group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
+            if not self.sample_from_expired_storage:
+                group_data_items = self.replay_buffer.post_processor(group_data_items)  # type: ignore[attr-defined]
             if len(group_data_items) > 0:
-                await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+                self.replay_buffer.add(group_data_items)  # type: ignore[attr-defined]
             else:
                 self.filtered_samples_count += 1
             self.logger.debug(f"Worker task completed successfully for {action_id}.")
         elif group_state == RolloutState.ABORTED:
-            await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+            self.replay_buffer.add(group_data_items)  # type: ignore[attr-defined]
             self.logger.debug(f"Adding aborted sample {action_id} to aborted storage")
         elif group_state == RolloutState.SKIPPED:
             self.skipped_sample_count += 1
@@ -343,7 +340,7 @@ async def concurrent_task_runner(self):
                 task_time = done_task.result()
                 task_completion_times.append(task_time)

-                self.finished_samples_count = await self.replay_buffer.get_completed_samples_count.remote()
+                self.finished_samples_count = self.replay_buffer.get_completed_samples_count()
                 pbar.update(self.finished_samples_count - last_pbar_n)
                 last_pbar_n = self.finished_samples_count
@@ -473,7 +470,7 @@ async def run(
         self.logging_replaybuffer_state("DataFlow run completed. ")

         get_start_time = time.perf_counter()
-        return_samples = await self.replay_buffer.get_samples.remote(self.target_batch_size)  # type: ignore[attr-defined]
+        return_samples = self.replay_buffer.get_samples(self.target_batch_size)  # type: ignore[attr-defined]
         self.logger.info(
             f"Getting {self.target_batch_size} samples from replay buffer took {time.perf_counter() - get_start_time:.2f}s"
         )
@@ -496,7 +493,7 @@ def logging_replaybuffer_state(self, logging_msg: Optional[str] = None):
         self.logger.info(logging_msg)

     def get_replaybuffer_status(self):
-        return ray.get(self.replay_buffer.status.remote())
+        return self.replay_buffer.status()

     async def _send_abort_request(self, client, url, timeout):
         worker_url = f"{url}/abort_request"
@@ -543,15 +540,15 @@ def save(self, save_path: Path | str):
         Args:
             save_path (str): The path to the checkpoint file to save to.
         """
-        ray.get(self.replay_buffer.save.remote(save_path))
+        self.replay_buffer.save(save_path)

     def resume(self, resume_path: Path | str):
         """Resumes the replay buffer from the specified path.

         Args:
             resume_path (str): The path to the checkpoint file to resume from.
         """
-        ray.get(self.replay_buffer.resume.remote(resume_path))
+        self.replay_buffer.resume(resume_path)


 DataFlow = ray.remote(RawDataFlow)
1 change: 0 additions & 1 deletion xtuner/v1/ray/dataflow/replay_buffer.py
@@ -792,7 +792,6 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
         )


-@ray.remote
 class ReplayBuffer:
     """A Ray actor that manages experience replay for reinforcement
     learning."""
4 changes: 2 additions & 2 deletions xtuner/v1/ray/rollout/worker.py
@@ -378,7 +378,7 @@ async def rollout_task(
         # If the concatenated response_ids have already reached max_tokens, there is no need to send a request; return directly.
         if extra_info.get("partial_rollout_input_ids", None) is not None:
             if sample_params["max_tokens"] == 0:
-                self.logger.info(
+                self.logger.debug(
                     f"Request {uid} reached max context length {self.config.context_length}, no need to rollout more."
                 )
                 return RLRolloutResponseItem(
@@ -390,7 +390,7 @@ async def rollout_task(
                     state=RolloutState.COMPLETED,
                 )
             if extra_info["partial_rollout_input_ids"][-1] in self.eos_token:
-                self.logger.info(
+                self.logger.debug(
                     f"Request {uid} already ends with eos token {extra_info['partial_rollout_input_ids'][-1]}, no need to rollout more"
                 )
                 return RLRolloutResponseItem(
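The two worker.py hunks demote per-request messages from info to debug on the partial-rollout early-return paths, where a resumed request may need no further generation at all. A simplified sketch of the guard logic those paths implement; ResponseItem and maybe_skip_rollout are illustrative stand-ins, not xtuner's RLRolloutResponseItem API:

from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class ResponseItem:
    state: str = "COMPLETED"
    response_ids: List[int] = field(default_factory=list)


def maybe_skip_rollout(
    partial_ids: List[int], max_tokens: int, eos_tokens: Set[int]
) -> Optional[ResponseItem]:
    # Token budget already spent: the context length is reached, so
    # there is nothing left to generate for this request.
    if max_tokens == 0:
        return ResponseItem(response_ids=partial_ids)
    # The previous partial rollout already ended at an EOS token, so
    # the sequence is complete as-is.
    if partial_ids and partial_ids[-1] in eos_tokens:
        return ResponseItem(response_ids=partial_ids)
    return None  # fall through to the real inference request


assert maybe_skip_rollout([5, 7, 2], max_tokens=0, eos_tokens={2}) is not None
assert maybe_skip_rollout([5, 7, 2], max_tokens=16, eos_tokens={2}) is not None
assert maybe_skip_rollout([5, 7, 9], max_tokens=16, eos_tokens={2}) is None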