Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/usage/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,5 +249,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_CONFIG_ROOT": lambda: os.path.expanduser(
os.getenv("FD_CONFIG_ROOT", os.path.join(os.path.expanduser("~"), ".config", "fastdeploy"))
),
# Whether to force the inference engine to synchronize token_ids sampled by TP groups.
"FD_SYNC_TOKEN_IDS_ACROSS_TP": lambda: bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "1"))),
}
```
3 changes: 3 additions & 0 deletions docs/zh/usage/environment_variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_CONFIG_ROOT": lambda: os.path.expanduser(
os.getenv("FD_CONFIG_ROOT", os.path.join(os.path.expanduser("~"), ".config", "fastdeploy"))
),

# 是否强制推理引擎同步 TP 组采样到的 token_ids，默认同步
"FD_SYNC_TOKEN_IDS_ACROSS_TP": lambda: bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "1"))),
}
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def _validate_split_kv_size(value: int) -> int:
# When v1 is enabled, the legacy /clear_load_weight and /update_model_weight
# will adopt this new communication pattern.
"FD_ENABLE_V1_UPDATE_WEIGHTS": lambda: bool(int(os.getenv("FD_ENABLE_V1_UPDATE_WEIGHTS", "0"))),
# Whether to force the inference engine to synchronize token_ids sampled by TP groups (enabled by default)
"FD_SYNC_TOKEN_IDS_ACROSS_TP": lambda: bool(int(os.getenv("FD_SYNC_TOKEN_IDS_ACROSS_TP", "1"))),
}


Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,7 +1672,7 @@ def _dummy_sampler_run(
self.share_inputs["stop_flags"],
)
sampler_output = self.sampler(logits, self.sampling_metadata)
if self.parallel_config.tensor_parallel_size > 1:
if self.parallel_config.tensor_parallel_size > 1 and envs.FD_SYNC_TOKEN_IDS_ACROSS_TP:
paddle.distributed.broadcast(
sampler_output.sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
Expand Down Expand Up @@ -2280,7 +2280,7 @@ def _postprocess(
[sampler_output.sampled_token_ids.shape[0]], device="cpu", dtype="int64"
),
)
if self.parallel_config.tensor_parallel_size > 1:
if self.parallel_config.tensor_parallel_size > 1 and envs.FD_SYNC_TOKEN_IDS_ACROSS_TP:
paddle.distributed.broadcast(
sampler_output.sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
Expand Down
Loading