Skip to content
Closed
429 changes: 429 additions & 0 deletions design/redesign_disagg.md

Large diffs are not rendered by default.

843 changes: 843 additions & 0 deletions design/redesign_disagg.py

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions run_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env bash
# Run the RL producer / colocate-trainer / disaggregated-trainer test suites
# on the GPUs visible to this shell.
#
# Inputs:
#   CUDA_VISIBLE_DEVICES  optional; comma-separated numeric GPU indices.
#                         When unset or empty, every GPU listed by nvidia-smi
#                         is used.
# Requirements: exactly 4 or 8 visible GPUs, ./zdev/env.sh, and a conda
# environment named "fla".
set -euo pipefail

# Default to all GPUs reported by nvidia-smi. The assignment is deliberately
# separated from `export`: `export VAR="$(cmd)"` returns the status of
# `export`, which would hide a failing nvidia-smi pipeline from `set -e`.
if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    CUDA_VISIBLE_DEVICES="$(nvidia-smi --query-gpu=index --format=csv,noheader | paste -sd, -)"
    export CUDA_VISIBLE_DEVICES
fi

# Split the device list on commas and enforce the supported GPU counts.
IFS=',' read -r -a XTUNER_TEST_VISIBLE_GPUS <<< "${CUDA_VISIBLE_DEVICES}"
XTUNER_TEST_GPU_NUM="${#XTUNER_TEST_VISIBLE_GPUS[@]}"
if [[ "${XTUNER_TEST_GPU_NUM}" -ne 4 && "${XTUNER_TEST_GPU_NUM}" -ne 8 ]]; then
    echo "run_test.sh expects 4 or 8 visible GPUs, got CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" >&2
    exit 1
fi

# Every entry must be a plain integer once whitespace is stripped, because the
# first entry is used in shell arithmetic below (UUID-style device names are
# therefore rejected).
XTUNER_TEST_GPU0="${XTUNER_TEST_VISIBLE_GPUS[0]//[[:space:]]/}"
for XTUNER_TEST_GPU in "${XTUNER_TEST_VISIBLE_GPUS[@]}"; do
    XTUNER_TEST_GPU="${XTUNER_TEST_GPU//[[:space:]]/}"
    if [[ ! "${XTUNER_TEST_GPU}" =~ ^[0-9]+$ ]]; then
        echo "run_test.sh expects numeric CUDA_VISIBLE_DEVICES entries, got CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" >&2
        exit 1
    fi
done

source ./zdev/env.sh
# Quote the substitution so a conda base path containing spaces is not
# word-split into several arguments to `source`.
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate fla

export RAY_ADDRESS=local
# Temp dir is keyed by first GPU index, GPU count and this shell's PID;
# presumably so that concurrent runs do not share a Ray temp dir.
export RAY_TMPDIR="/tmp/xrt_${XTUNER_TEST_GPU0}g${XTUNER_TEST_GPU_NUM}_$$"
# Port base is offset by 1024 per first-GPU index; presumably to give
# concurrent runs on different GPU sets disjoint port ranges.
export XTUNER_DIST_PORT_BASE="$((35000 + XTUNER_TEST_GPU0 * 1024))"
export XTUNER_TEST_NUM_WORKERS="${XTUNER_TEST_GPU_NUM}"

echo "run_test.sh: CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}"
echo "run_test.sh: XTUNER_DIST_PORT_BASE=${XTUNER_DIST_PORT_BASE}"

pytest --durations=20 \
    tests/rl/test_producer.py \
    tests/rl/test_rl_colocate_trainer.py \
    tests/rl/test_rl_disaggregated_trainer.py
211 changes: 124 additions & 87 deletions tests/rl/test_multi_task_agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,34 +54,30 @@ def __init__(
self.cleanup_model_steps: list[int] = []
self.cleanup_progresses: list[object | None] = []
self.cleanup_call_count = 0

async def produce_batch(
self,
agent_loop,
sampler,
replay_buffer,
batch_size: int,
task_name: str,
train_step: int = 0,
update_event=None,
*,
model_step: int,
progress,
) -> ProduceBatchStatus:
self.called_batch_sizes.append(batch_size)
self.called_train_steps.append(train_step)
self.called_model_steps.append(model_step)
self.called_update_events.append(update_event)
self.called_update_event_states.append(None if update_event is None else update_event.is_set())
self.called_progresses.append(progress)
self.pending_task_count_value = 0

# Fake produce step: record the ctx fields the tests assert on, then return the preconfigured status.
async def produce_batch(self, ctx) -> ProduceBatchStatus:
self.called_batch_sizes.append(ctx.task_batch_size)
self.called_train_steps.append(ctx.train_step)
self.called_model_steps.append(ctx.model_step)
self.called_update_events.append(ctx.update_event)
# Snapshot whether the update event was set at call time (None when no event was supplied).
self.called_update_event_states.append(None if ctx.update_event is None else ctx.update_event.is_set())
self.called_progresses.append(ctx.progress)
return self.status

async def pause_produce(self, agent_loop, replay_buffer, task_name: str, *, model_step: int, progress) -> float:
async def pause_produce(self, ctx) -> float:
self.cleanup_call_count += 1
self.cleanup_model_steps.append(model_step)
self.cleanup_progresses.append(progress)
self.cleanup_model_steps.append(ctx.model_step)
self.cleanup_progresses.append(ctx.progress)
return self.cleanup_pause_time_s

def is_model_expired(self, train_step: int, model_step: int) -> bool:
# The fake strategy never reports expiry itself; test cases control the expired state by explicitly returning a status.
return False

# Return the pending-task counter stored on this fake (tests assign pending_task_count_value directly).
def pending_task_count(self) -> int:
return self.pending_task_count_value


class _FakeStatusProduceStrategy:
def __init__(self, status: ProduceBatchStatus, pause_time_s: float):
Expand All @@ -95,33 +91,29 @@ def __init__(self, status: ProduceBatchStatus, pause_time_s: float):
self.called_progresses: list[object] = []
self.cleanup_model_steps: list[int] = []
self.cleanup_progresses: list[object | None] = []

async def produce_batch(
self,
agent_loop,
sampler,
replay_buffer,
batch_size: int,
task_name: str,
train_step: int = 0,
update_event=None,
*,
model_step: int,
progress,
) -> ProduceBatchStatus:
self.called_train_steps.append(train_step)
self.called_model_steps.append(model_step)
self.called_update_events.append(update_event)
self.called_update_event_states.append(None if update_event is None else update_event.is_set())
self.called_progresses.append(progress)
self.pending_task_count_value = 0

# Fake produce step: record the ctx fields the tests assert on, then return the preconfigured status.
async def produce_batch(self, ctx) -> ProduceBatchStatus:
self.called_train_steps.append(ctx.train_step)
self.called_model_steps.append(ctx.model_step)
self.called_update_events.append(ctx.update_event)
# Snapshot whether the update event was set at call time (None when no event was supplied).
self.called_update_event_states.append(None if ctx.update_event is None else ctx.update_event.is_set())
self.called_progresses.append(ctx.progress)
return self.status

async def pause_produce(self, agent_loop, replay_buffer, task_name: str, *, model_step: int, progress) -> float:
async def pause_produce(self, ctx) -> float:
self.cleanup_call_count += 1
self.cleanup_model_steps.append(model_step)
self.cleanup_progresses.append(progress)
self.cleanup_model_steps.append(ctx.model_step)
self.cleanup_progresses.append(ctx.progress)
return self.pause_time_s

def is_model_expired(self, train_step: int, model_step: int) -> bool:
# The fake strategy never reports expiry itself; test cases control the expired state by explicitly returning a status.
return False

# Return the pending-task counter stored on this fake (tests assign pending_task_count_value directly).
def pending_task_count(self) -> int:
return self.pending_task_count_value


class _FakeRolloutState:
def __init__(self, uid: str, group_generate_time_s: float):
Expand All @@ -141,25 +133,13 @@ def __init__(self, statuses: list[ProduceBatchStatus], cleanup_pause_time_s: flo
super().__init__(status=ProduceBatchStatus.NORMAL, cleanup_pause_time_s=cleanup_pause_time_s)
self._statuses = list(statuses)

async def produce_batch(
self,
agent_loop,
sampler,
replay_buffer,
batch_size: int,
task_name: str,
train_step: int = 0,
update_event=None,
*,
model_step: int,
progress,
) -> ProduceBatchStatus:
self.called_batch_sizes.append(batch_size)
self.called_train_steps.append(train_step)
self.called_model_steps.append(model_step)
self.called_update_events.append(update_event)
self.called_update_event_states.append(None if update_event is None else update_event.is_set())
self.called_progresses.append(progress)
# Fake produce step that replays a scripted list of statuses, one per call.
async def produce_batch(self, ctx) -> ProduceBatchStatus:
self.called_batch_sizes.append(ctx.task_batch_size)
self.called_train_steps.append(ctx.train_step)
self.called_model_steps.append(ctx.model_step)
self.called_update_events.append(ctx.update_event)
# Snapshot whether the update event was set at call time (None when no event was supplied).
self.called_update_event_states.append(None if ctx.update_event is None else ctx.update_event.is_set())
self.called_progresses.append(ctx.progress)
# Consume the next scripted status; once the script is exhausted, fall back to NORMAL.
return self._statuses.pop(0) if self._statuses else ProduceBatchStatus.NORMAL


Expand All @@ -183,20 +163,49 @@ async def count(self, task_name: str, group_status: Status):

async def refresh_staleness(
self,
task_name: str,
*,
task_stale_thresholds: dict[str, int],
current_train_step: int,
stale_threshold: int,
statuses: list[Status] | None = None,
):
self.refresh_staleness_calls.append((task_name, current_train_step, stale_threshold, tuple(statuses or ())))
for group in self._rollout_states_by_task.get(task_name, []):
for state in group:
response_model_steps = getattr(state, "response_model_steps", None) or []
if response_model_steps and hasattr(state, "seq_staleness"):
state.seq_staleness = calculate_seq_staleness(
min(response_model_steps), current_train_step
)
return 0
expired_counts = {}
for task_name, stale_threshold in task_stale_thresholds.items():
self.refresh_staleness_calls.append(
(task_name, current_train_step, stale_threshold, tuple(statuses or ()))
)
for group in self._rollout_states_by_task.get(task_name, []):
for state in group:
response_model_steps = getattr(state, "response_model_steps", None) or []
if response_model_steps and hasattr(state, "seq_staleness"):
state.seq_staleness = calculate_seq_staleness(
min(response_model_steps), current_train_step
)
expired_counts[task_name] = 0
return expired_counts

async def is_ready(self, task_batch_sizes: dict[str, int], *, group_status: Status = Status.COMPLETED):
    """Report whether every task already holds at least its requested number of groups.

    Tasks are checked in mapping order via ``self.count``; the check stops at
    the first task that is short, so later tasks are not queried.
    """
    for name, required in task_batch_sizes.items():
        available = await self.count(name, group_status)
        if available < required:
            return False
    return True

async def take_batch(self, task_batch_sizes: dict[str, int], *, group_status: Status = Status.COMPLETED):
    """Fetch the requested groups for each task through ``self.get``.

    Returns a pair ``(groups_by_task, taken_by_task)`` where the second mapping
    holds the length of what ``self.get`` returned for each task.
    """
    groups_by_task = {}
    taken_by_task = {}
    for name, wanted in task_batch_sizes.items():
        groups = await self.get(wanted, name, group_status)
        # Measure the size right after the fetch, before any later task is fetched.
        groups_by_task[name], taken_by_task[name] = groups, len(groups)
    return groups_by_task, taken_by_task

async def count_statuses(self, task_names: list[str], statuses: list[Status]):
    """Return ``{task_name: {status: count}}`` read from ``self._leftover_counts``.

    ``(task_name, status)`` pairs missing from the table are reported as 0.
    """
    counts_by_task = {}
    for name in task_names:
        per_status = {}
        for wanted in statuses:
            per_status[wanted] = self._leftover_counts.get((name, wanted), 0)
        counts_by_task[name] = per_status
    return counts_by_task

async def save(self, checkpoint_path: Path | str):
self.saved_paths.append(Path(checkpoint_path))
Expand Down Expand Up @@ -301,14 +310,14 @@ async def test_produce_batch_allocates_by_weight_and_returns_task_sorted_results
],
replay_buffer=replay_buffer,
)
multi_task_manager._status = AgentLoopManagerStatus.UPDATE_ABORT
multi_task_manager._status = AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT
multi_task_manager._update_event.set()

result = await multi_task_manager.produce_batch(batch_size=7, train_step=3, model_step=2)

self.assertEqual(result.task_batch_sizes, {"task_a": 5, "task_b": 2, "task_c": 0})
# sync produce_batch 在本轮入口恢复 NORMAL,收尾 pause 后保留 UPDATE_ABORT 到下一轮入口再清理。
self.assertEqual(multi_task_manager._status, AgentLoopManagerStatus.UPDATE_ABORT)
# sync produce_batch 在本轮入口恢复 NORMAL,收尾 pause 后保留 UPDATE_WEIGHT_AND_ABORT 到下一轮入口再清理。
self.assertEqual(multi_task_manager._status, AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT)
self.assertTrue(multi_task_manager._update_event.is_set())
self.assertEqual(multi_task_manager._model_step, 2)
self.assertEqual(strategy_a.called_batch_sizes, [5])
Expand Down Expand Up @@ -373,7 +382,7 @@ def test_save_and_resume_roundtrip_restores_paused_manager_state(self):
restored_step = manager.resume(checkpoint_path)

self.assertEqual(restored_step, 7)
self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_ABORT)
self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT)
self.assertTrue(manager._update_event.is_set())
self.assertFalse(manager._finish_event.is_set())
self.assertEqual(manager._pause_time_s, 0.0)
Expand All @@ -385,7 +394,7 @@ def test_save_and_resume_roundtrip_restores_paused_manager_state(self):

def test_save_rejects_pending_async_tasks(self):
strategy = _FakeProduceStrategy()
strategy._pending_tasks = {object()}
strategy.pending_task_count_value = 1
manager = AgentLoopManager(
task_runners=[
_TaskRunner(
Expand Down Expand Up @@ -488,7 +497,7 @@ async def test_status_returning_strategy_uses_cleanup_and_reconstructs_group_tim
self.assertEqual(strategy.called_model_steps, [6])
self.assertEqual(len(strategy.called_update_events), 1)
self.assertFalse(strategy.called_update_event_states[0])
self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_ABORT)
self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT)
self.assertTrue(manager._update_event.is_set())
self.assertEqual(result.group_gen_count, 2)
self.assertAlmostEqual(result.group_gen_mean_s, 0.75)
Expand Down Expand Up @@ -538,7 +547,7 @@ async def test_pause_produce_from_async_produce_loop_sets_status_and_pause_time(
self.assertEqual(strategy.cleanup_model_steps, [0])
self.assertIs(strategy.cleanup_progresses[0], manager._produce_progress)
self.assertTrue(manager._update_event.is_set())
self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_ABORT)
self.assertEqual(manager._status, AgentLoopManagerStatus.UPDATE_WEIGHT_AND_ABORT)
self.assertEqual(manager._pause_time_s, 2.5)

async def test_pause_produce_validates_progress_selection_before_state_change(self):
Expand Down Expand Up @@ -678,7 +687,7 @@ async def test_produce_batch_to_buffer_aggregates_status_with_update_abort_prior
_TaskRunner(
task_name="task_c",
agent_loop=_fake_agent_loop(),
produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.UPDATE_ABORT),
produce_strategy=_FakeProduceStrategy(status=ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT),
sampler=_FakeSampler(),
weight=1.0,
order=2,
Expand All @@ -689,9 +698,17 @@ async def test_produce_batch_to_buffer_aggregates_status_with_update_abort_prior

manager._model_step = 5
manager._produce_progress.producer_future_step = 5
status = await manager._produce_batch_to_buffer(batch_size=3, progress=manager._produce_progress)
task_batch_sizes = manager._produce_progress.ensure_target_upto(
batch_size=3,
future_step=manager._produce_progress.producer_future_step,
allocate_batch_sizes=manager._get_task_batch_sizes_for_step,
)
status = await manager._produce_batch_to_buffer(
task_batch_sizes=task_batch_sizes,
progress=manager._produce_progress,
)

self.assertEqual(status, ProduceBatchStatus.UPDATE_ABORT)
self.assertEqual(status, ProduceBatchStatus.UPDATE_WEIGHT_AND_ABORT)

async def test_produce_loop_waits_for_continue_produce_and_stops_on_finish(self):
strategy = _SequencedProduceStrategy(
Expand Down Expand Up @@ -724,6 +741,26 @@ async def test_produce_loop_waits_for_continue_produce_and_stops_on_finish(self)
self.assertEqual(strategy.called_train_steps[:3], [3, 4, 4])
self.assertEqual(strategy.called_model_steps[2], 9)

manager._status = AgentLoopManagerStatus.FINISH
manager._finish_event.set()
manager.shutdown()
await asyncio.wait_for(loop_task, timeout=1.0)

async def test_shutdown_sets_finish_signals(self):
    """``shutdown()`` must move the manager to FINISH and set both the update and finish events."""
    runner = _TaskRunner(
        task_name="task_a",
        agent_loop=_fake_agent_loop(),
        produce_strategy=_FakeProduceStrategy(),
        sampler=_FakeSampler(),
        weight=1.0,
        order=0,
    )
    manager = AgentLoopManager(
        task_runners=[runner],
        replay_buffer=_FakeReplayBuffer({}, {}),
    )

    manager.shutdown()

    self.assertEqual(manager._status, AgentLoopManagerStatus.FINISH)
    self.assertTrue(manager._update_event.is_set())
    self.assertTrue(manager._finish_event.is_set())
Loading