@@ -199,6 +199,11 @@ def _get_s3_endpoint_flag() -> str:
 if SERVICE_IDENTIFIER:
     SERVICE_NAME += f"-{SERVICE_IDENTIFIER}"
 
+# Host address for inference servers (vLLM, SGLang, TGI).
+# Defaults to "::" (IPv6 all-interfaces) for clusters with IPv6 pod networking.
+# Set to "0.0.0.0" for clusters using IPv4 pod networking.
+INFERENCE_SERVER_HOST = os.getenv("INFERENCE_SERVER_HOST", "::")
+
 
 def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int:
     """
@@ -562,7 +567,7 @@ async def create_text_generation_inference_bundle(
     )
 
     subcommands.append(
-        f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
+        f"text-generation-launcher --hostname {INFERENCE_SERVER_HOST} --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
     )
 
     if quantize:
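
For reference, here is roughly what the rendered TGI launcher command looks like once the constant is substituted; the weight path, shard count, and token limits below are illustrative placeholders:

```python
import os

INFERENCE_SERVER_HOST = os.getenv("INFERENCE_SERVER_HOST", "::")

# Placeholder values; the real ones come from the bundle request.
final_weights_folder, num_shards = "/data/weights", 2
max_input_length, max_total_tokens = 4096, 8192

print(
    f"text-generation-launcher --hostname {INFERENCE_SERVER_HOST} "
    f"--model-id {final_weights_folder} --num-shard {num_shards} --port 5005 "
    f"--max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
)
# With INFERENCE_SERVER_HOST=0.0.0.0 the server binds the IPv4 wildcard instead.
```
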
@@ -851,7 +856,7 @@ async def create_sglang_bundle(  # pragma: no cover
     if chat_template_override:
         sglang_args.chat_template = chat_template_override
 
-    sglang_cmd = f"python3 -m sglang.launch_server --model-path {huggingface_repo} --served-model-name {model_name} --port 5005 --host '::'"
+    sglang_cmd = f"python3 -m sglang.launch_server --model-path {huggingface_repo} --served-model-name {model_name} --port 5005 --host '{INFERENCE_SERVER_HOST}'"
     for field in SGLangEndpointAdditionalArgs.model_fields.keys():
         config_value = getattr(sglang_args, field, None)
         if config_value is not None:
@@ -1000,7 +1005,7 @@ def _create_vllm_bundle_command(
 
     # Use wrapper if startup metrics enabled, otherwise use vllm_server directly
     server_module = "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server"
-    vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "::"'
+    vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "{INFERENCE_SERVER_HOST}"'
     for field in VLLMEndpointAdditionalArgs.model_fields.keys():
         config_value = getattr(vllm_args, field, None)
         if config_value is not None:
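
Both the SGLang and vLLM branches then extend the base command by iterating the pydantic args model. A simplified sketch of that pattern (the model and the flag rendering here are illustrative stand-ins, not the actual SGLangEndpointAdditionalArgs / VLLMEndpointAdditionalArgs classes):

```python
from typing import Optional
from pydantic import BaseModel

class ExampleArgs(BaseModel):
    max_num_seqs: Optional[int] = None        # illustrative field
    trust_remote_code: Optional[bool] = None  # illustrative field

args = ExampleArgs(max_num_seqs=256, trust_remote_code=True)
cmd = "python3 -m sglang.launch_server --port 5005"
for field in ExampleArgs.model_fields.keys():
    value = getattr(args, field, None)
    if value is not None:
        # Assumed rendering: bools become bare flags, others take a value.
        flag = f"--{field.replace('_', '-')}"
        cmd += f" {flag}" if value is True else f" {flag} {value}"
print(cmd)
```
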
@@ -107,7 +107,7 @@ def main(
         "--tp",
         str(tp),
         "--host",
-        "::",
+        os.environ.get("INFERENCE_SERVER_HOST", "::"),
         "--port",
         str(worker_port),
         "--dist-init-addr",
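
This last hunk lives in a separate worker-launcher module, which reads the environment variable directly via `os.environ.get` rather than importing the shared `INFERENCE_SERVER_HOST` constant. A minimal sketch of the surrounding argv assembly (the function name, extra arguments, and the `sglang.launch_server` target are assumptions for illustration):

```python
import os
import subprocess

def launch_worker(tp: int, worker_port: int, dist_init_addr: str) -> None:
    # Illustrative reconstruction of the command list around the changed line.
    cmd = [
        "python3", "-m", "sglang.launch_server",
        "--tp", str(tp),
        "--host", os.environ.get("INFERENCE_SERVER_HOST", "::"),
        "--port", str(worker_port),
        "--dist-init-addr", dist_init_addr,
    ]
    subprocess.run(cmd, check=True)
```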