@@ -173,6 +173,7 @@ def __init__(self, fd_config: FDConfig, block_table, total_block_num):
num_experts = fd_config.model_config.moe_num_experts + fd_config.model_config.moe_num_shared_experts
self.routing_dtype = self.get_routing_dtype(num_experts=num_experts)
self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num)
self.pending_update_positions = None

# Initialize routing store wrapper
if self.tp_rank == 0:
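
The new pending_update_positions attribute (this hunk appears to belong to routing_indices_cache.py, judging by the import added to pre_and_post_process.py further down) turns the manager into the hand-off point between the model runner and post-processing: the runner stages the token positions it computed for the current step, and post_process_normal later reads them back when updating the host routing cache. Below is a minimal, self-contained sketch of that staging pattern; the class name, the dict-based cache, and the stage/flush helpers are illustrative stand-ins, not the real RoutingReplayManager API.

class RoutingReplaySketch:
    """Toy stand-in showing how a staged-positions hand-off can work."""

    def __init__(self):
        self.pending_update_positions = None  # staged by the runner, consumed in post-processing
        self.host_cache = {}

    def stage(self, positions):
        # Model runner side: remember which token positions this step produced routing for.
        self.pending_update_positions = positions

    def flush(self, routing_indices):
        # Post-processing side: copy this step's routing into the host-side cache.
        if self.pending_update_positions is None:
            return
        for pos, idx in zip(self.pending_update_positions, routing_indices):
            self.host_cache[pos] = idx
        self.pending_update_positions = None


manager = RoutingReplaySketch()
manager.stage([0, 1, 2])      # before the forward pass
manager.flush([5, 3, 7])      # after sampling
print(manager.host_cache)     # {0: 5, 1: 3, 2: 7}
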
@@ -287,6 +288,8 @@ def put_finished_batch(
seq_lens_decoder,
):
finished_batch_ids_list = finished_batch_ids.cpu().tolist()
logger.info(f"[R3] Finished batch id list: {finished_batch_ids_list}")
logger.info(f"[R3] batch id to request map: {self.routing_batch_to_request}")
for batch_id, finished in enumerate(finished_batch_ids_list):
if finished:
assert batch_id in self.routing_batch_to_request.keys()
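
The two new [R3] log lines expose the inputs to the assertion that follows: every batch slot flagged as finished must already be registered in routing_batch_to_request. A tiny illustration of that invariant, with entirely made-up map contents and mask values:

# Hypothetical bookkeeping the assert above depends on.
routing_batch_to_request = {0: "req-a", 1: "req-b", 2: "req-c"}   # made-up request ids
finished_batch_ids_list = [False, True, False]                    # e.g. finished_batch_ids.cpu().tolist()

for batch_id, finished in enumerate(finished_batch_ids_list):
    if finished:
        assert batch_id in routing_batch_to_request
        print(f"publishing routing for {routing_batch_to_request[batch_id]}")  # req-b
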
@@ -397,7 +400,7 @@ def __init__(self, fd_config: False) -> None:
# Initialize task queue
moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index
max_num_seqs = fd_config.scheduler_config.max_num_seqs
self.queue_max_size = moe_layer_num * max_num_seqs * 10
self.queue_max_size = moe_layer_num * max_num_seqs * 1000

self.manager = multiprocessing.Manager()
self._task_queue = self.manager.Queue(maxsize=self.queue_max_size)
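
Raising the multiplier from 10 to 1000 gives the task queue far more headroom; its capacity still scales with the number of MoE layers times the maximum number of concurrent sequences. A back-of-the-envelope sizing with made-up configuration values (the real numbers depend on the model and scheduler config):

# Illustrative numbers only; substitute the actual model/scheduler config.
num_hidden_layers = 61
moe_layer_start_index = 3
max_num_seqs = 256

moe_layer_num = num_hidden_layers - moe_layer_start_index   # 58 MoE layers
queue_max_size = moe_layer_num * max_num_seqs * 1000
print(queue_max_size)                                        # 14,848,000 queue slots
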
21 changes: 21 additions & 0 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -96,6 +96,9 @@
calculate_logits_entropy,
speculate_calculate_logits_entropy,
)
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
@@ -326,6 +329,7 @@ def post_process_normal(
think_end_id: int = -1,
line_break_id: int = -1,
enable_entropy: bool = False,
routing_replay_manager: RoutingReplayManager = None,
):
"""Post-processing steps after completing a single token generation."""
if think_end_id > 0:
@@ -394,6 +398,21 @@ def post_process_normal(
if enable_entropy:
calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)

# Routing replay
if routing_replay_manager is not None:
# Update host cache
slot_mapping = routing_replay_manager.compute_slot_mapping(
positions=routing_replay_manager.pending_update_positions
)
routing_replay_manager.update_host_cache(
positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping
)

# Put routing of finished requests to store
finished_batch_ids = paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)[:, 0]
context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder
routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens)

# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
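
In the block above, the finished mask is derived by checking each newly sampled token against the eos id(s): paddle.isin(...)[:, 0] reduces what is assumed here to be a [batch, 1] token tensor to a per-batch boolean, and context_lens adds the encoder and decoder lengths to get the full sequence length per slot. A small standalone check of both patterns; all tensor values are made up:

import paddle

# Shapes assumed here: sampled_token_ids is [batch, 1], eos_token_id holds the eos id(s).
sampled_token_ids = paddle.to_tensor([[17], [2], [9]])
eos_token_id = paddle.to_tensor([2, 9])

finished_batch_ids = paddle.isin(sampled_token_ids, eos_token_id)[:, 0]
print(finished_batch_ids.tolist())   # [False, True, True]

# context_lens combines prompt and generated lengths per batch slot.
seq_lens_encoder = paddle.to_tensor([5, 12, 7])
seq_lens_decoder = paddle.to_tensor([3, 0, 4])
context_lens = seq_lens_decoder + seq_lens_encoder
print(context_lens.tolist())         # [8, 12, 11]
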
Expand Down Expand Up @@ -564,6 +583,7 @@ def post_process(
think_end_id: int = -1,
line_break_id: int = -1,
enable_entropy: bool = False,
routing_replay_manager: RoutingReplayManager = None,
) -> None:
"""Post-processing steps after completing a single token generation."""

@@ -603,6 +623,7 @@
think_end_id,
line_break_id,
enable_entropy,
routing_replay_manager,
)


30 changes: 16 additions & 14 deletions fastdeploy/worker/gpu_model_runner.py
@@ -2370,7 +2370,7 @@ class at the server level, which is too granular for ModelRunner.
self._prepare_inputs()
self.sampler.pre_process(p_done_idxs)
if self.fd_config.routing_replay_config.enable_routing_replay:
self.positions = self.routing_replay_manager.get_token_positions(
self.routing_replay_manager.pending_update_positions = self.routing_replay_manager.get_token_positions(
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.seq_lens_this_time_buffer,
)
@@ -2450,6 +2450,7 @@ class at the server level, which is too granular for ModelRunner.
skip_save_output=False,
async_output_queue=self.async_output_queue,
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
routing_replay_manager=self.routing_replay_manager,
)

return None
@@ -2579,6 +2580,7 @@ class at the server level, which is too granular for ModelRunner.
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
enable_entropy=self.enable_entropy and self.parallel_config.tensor_parallel_rank == 0,
routing_replay_manager=self.routing_replay_manager,
)
if self.guided_backend is not None and sampler_output is not None:
self.sampler.post_process(sampler_output.sampled_token_ids)
@@ -2626,19 +2628,19 @@ class at the server level, which is too granular for ModelRunner.
self.speculative_config.num_speculative_tokens,
)

# Routing replay
if self.fd_config.routing_replay_config.enable_routing_replay:
# Update host cache
slot_mapping = self.routing_replay_manager.compute_slot_mapping(positions=self.positions)
self.routing_replay_manager.update_host_cache(positions=self.positions, slot_mapping=slot_mapping)

# Put routing of finished requests to store
finished_batch_ids = paddle.isin(sampler_output.sampled_token_ids, self.share_inputs["eos_token_id"])[:, 0]
self.routing_replay_manager.put_finished_batch(
finished_batch_ids=finished_batch_ids,
seq_lens_decoder=self.seq_lens_routing_buffer,
)
paddle.assign(self.share_inputs["seq_lens_decoder"], self.seq_lens_routing_buffer)
# # Routing replay
# if self.fd_config.routing_replay_config.enable_routing_replay:
# # Update host cache
# slot_mapping = self.routing_replay_manager.compute_slot_mapping(positions=self.positions)
# self.routing_replay_manager.update_host_cache(positions=self.positions, slot_mapping=slot_mapping)

# # Put routing of finished requests to store
# finished_batch_ids = paddle.isin(sampler_output.sampled_token_ids, self.share_inputs["eos_token_id"])[:, 0]
# self.routing_replay_manager.put_finished_batch(
# finished_batch_ids=finished_batch_ids,
# seq_lens_decoder=self.seq_lens_routing_buffer,
# )
# paddle.assign(self.share_inputs["seq_lens_decoder"], self.seq_lens_routing_buffer)

return None
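
With the in-runner block above now commented out, the same cache update and publication run inside post_process_normal, guarded by the routing_replay_manager argument the runner passes through. A minimal stand-in for that relocation and its None guard; SimpleNamespace and the string cache values are purely illustrative:

from types import SimpleNamespace

def post_process_sketch(manager):
    # Mirrors the new guard: a no-op when routing replay is disabled.
    if manager is None:
        return
    positions = manager.pending_update_positions
    manager.host_cache.update({p: f"routing@{p}" for p in positions})  # hypothetical cache write
    manager.pending_update_positions = None

mgr = SimpleNamespace(pending_update_positions=[4, 5], host_cache={})
post_process_sketch(None)   # routing replay disabled: nothing happens
post_process_sketch(mgr)    # enabled: staged positions are consumed
print(mgr.host_cache)       # {4: 'routing@4', 5: 'routing@5'}
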
