18 changes: 17 additions & 1 deletion model-engine/model_engine_server/common/config.py
@@ -81,7 +81,23 @@ class HostedModelInferenceServiceConfig:

@classmethod
def from_json(cls, json):
-        return cls(**{k: v for k, v in json.items() if k in inspect.signature(cls).parameters})
+        # NOTE: Our Helm chart historically rendered booleans as strings (e.g. "false")
+        # via a blanket `| quote`. Dataclasses don't enforce runtime types, so a non-empty
+        # string would evaluate truthy and accidentally enable features like Istio.
+        sig = inspect.signature(cls)
+        kwargs = {}
+        for k, v in json.items():
+            if k not in sig.parameters:
+                continue
+            ann = sig.parameters[k].annotation
+            if ann is bool and isinstance(v, str):
+                vv = v.strip().lower()
+                if vv in {"true", "1", "yes", "y"}:
+                    v = True
+                elif vv in {"false", "0", "no", "n", ""}:
+                    v = False
+            kwargs[k] = v
+        return cls(**kwargs)

@classmethod
def from_yaml(cls, yaml_path):
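To illustrate what the coercion above guards against, here is a minimal sketch (the `DemoConfig` class is a hypothetical stand-in, not code from this PR): a Helm-quoted `"false"` is a non-empty string and therefore truthy, so a flag like Istio support would silently turn on.

```python
import inspect
from dataclasses import dataclass


@dataclass
class DemoConfig:
    # Hypothetical stand-in for HostedModelInferenceServiceConfig.
    istio_enabled: bool = False


# Helm's blanket `| quote` renders the value as the string "false", not the bool False.
raw = {"istio_enabled": "false", "unknown_key": 1}

# Naive construction: the string survives as-is and is truthy.
naive = DemoConfig(**{k: v for k, v in raw.items() if k in inspect.signature(DemoConfig).parameters})
print(bool(naive.istio_enabled))  # True -- the bug

# With the coercion in from_json above, "false"/"0"/"no"/"" become a real False.
sig = inspect.signature(DemoConfig)
kwargs = {}
for k, v in raw.items():
    if k not in sig.parameters:
        continue
    if sig.parameters[k].annotation is bool and isinstance(v, str):
        vv = v.strip().lower()
        if vv in {"true", "1", "yes", "y"}:
            v = True
        elif vv in {"false", "0", "no", "n", ""}:
            v = False
    kwargs[k] = v
print(DemoConfig(**kwargs).istio_enabled)  # False
```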
@@ -1,3 +1,4 @@
+import hashlib
from typing import Dict, Optional

from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User
@@ -13,10 +14,25 @@ def __init__(self, user_team_override: Optional[Dict[str, str]] = None):
def is_allowed_team(team: str) -> bool:
return True

+    @staticmethod
+    def _stable_id(username: str) -> str:
+        """
+        Model-engine DB schemas often store created_by/owner as VARCHAR(24) (mongo-style ids).
+        When fake auth is used with long bearer tokens, usernames must be mapped deterministically
+        into 24 characters to avoid DB insert failures.
+        """
+        if len(username) <= 24:
+            return username
+        return hashlib.sha1(username.encode("utf-8")).hexdigest()[:24]

def get_auth_from_username(self, username: str) -> Optional[User]:
-        team_id = self.user_team_override.get(username, username)
-        return User(user_id=username, team_id=team_id, is_privileged_user=True)
+        user_id = self._stable_id(username)
+        team_id_raw = self.user_team_override.get(username, username)
+        team_id = self._stable_id(team_id_raw)
+        return User(user_id=user_id, team_id=team_id, is_privileged_user=True)

async def get_auth_from_username_async(self, username: str) -> Optional[User]:
-        team_id = self.user_team_override.get(username, username)
-        return User(user_id=username, team_id=team_id, is_privileged_user=True)
+        user_id = self._stable_id(username)
+        team_id_raw = self.user_team_override.get(username, username)
+        team_id = self._stable_id(team_id_raw)
+        return User(user_id=user_id, team_id=team_id, is_privileged_user=True)
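A quick sanity check of the mapping (standard library only; the example token is made up): short usernames pass through unchanged, long bearer tokens collapse deterministically to 24 hex characters.

```python
import hashlib


def stable_id(username: str) -> str:
    # Mirrors _stable_id above: pass short ids through, hash long ones down to 24 chars.
    if len(username) <= 24:
        return username
    return hashlib.sha1(username.encode("utf-8")).hexdigest()[:24]


print(stable_id("alice"))                    # "alice" (already fits VARCHAR(24))
token = "bearer-" + "x" * 64                 # a long fake-auth token (made up)
print(len(stable_id(token)))                 # 24
print(stable_id(token) == stable_id(token))  # True -- deterministic across calls
```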
@@ -199,6 +199,11 @@ def _get_s3_endpoint_flag() -> str:
if SERVICE_IDENTIFIER:
SERVICE_NAME += f"-{SERVICE_IDENTIFIER}"

+# Host address for inference servers (vLLM, SGLang, TGI).
+# Defaults to "::" (IPv6 all-interfaces) for clusters with IPv6 pod networking.
+# Set to "0.0.0.0" for clusters using IPv4 pod networking (e.g., AWS EKS).
+INFERENCE_SERVER_HOST = os.getenv("INFERENCE_SERVER_HOST", "::")


def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int:
"""
@@ -562,7 +567,7 @@ async def create_text_generation_inference_bundle(
)

subcommands.append(
f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
f"text-generation-launcher --hostname {INFERENCE_SERVER_HOST} --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}"
)

if quantize:
@@ -851,7 +856,7 @@ async def create_sglang_bundle(  # pragma: no cover
if chat_template_override:
sglang_args.chat_template = chat_template_override

sglang_cmd = f"python3 -m sglang.launch_server --model-path {huggingface_repo} --served-model-name {model_name} --port 5005 --host '::'"
sglang_cmd = f"python3 -m sglang.launch_server --model-path {huggingface_repo} --served-model-name {model_name} --port 5005 --host '{INFERENCE_SERVER_HOST}'"
for field in SGLangEndpointAdditionalArgs.model_fields.keys():
config_value = getattr(sglang_args, field, None)
if config_value is not None:
@@ -1000,7 +1005,10 @@ def _create_vllm_bundle_command(

# Use wrapper if startup metrics enabled, otherwise use vllm_server directly
server_module = "vllm_startup_wrapper" if enable_startup_metrics else "vllm_server"
-    vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005 --host "::"'
+    # NOTE: vLLM's OpenAI server expects the model path to be provided via `--model`.
+    # Do not add any extra positional args (they can confuse argparse and lead to
+    # incorrect served model names / startup failures).
+    vllm_cmd = f'python -m {server_module} --model {final_weights_folder} --served-model-name {model_name} --port 5005 --host "{INFERENCE_SERVER_HOST}"'
for field in VLLMEndpointAdditionalArgs.model_fields.keys():
config_value = getattr(vllm_args, field, None)
if config_value is not None:
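For clarity, a simplified sketch of how the env-driven host now flows into the launch command. `build_vllm_cmd` is a made-up helper, it hard-codes `vllm_server` (ignoring the startup-metrics wrapper and the extra args appended from `VLLMEndpointAdditionalArgs`), and the path and model name are illustrative.

```python
import os

# Same default as the new module-level constant: "::" for IPv6 pod networking,
# overridable with INFERENCE_SERVER_HOST=0.0.0.0 on IPv4-only clusters (e.g. EKS).
INFERENCE_SERVER_HOST = os.getenv("INFERENCE_SERVER_HOST", "::")


def build_vllm_cmd(final_weights_folder: str, model_name: str, port: int = 5005) -> str:
    # Simplified: the model path is passed only via --model, never as a positional arg.
    return (
        f"python -m vllm_server --model {final_weights_folder} "
        f'--served-model-name {model_name} --port {port} --host "{INFERENCE_SERVER_HOST}"'
    )


print(build_vllm_cmd("/workspace/model_files", "llama-3-8b-instruct"))
# python -m vllm_server --model /workspace/model_files --served-model-name llama-3-8b-instruct --port 5005 --host "::"
```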
@@ -107,7 +107,7 @@ def main(
"--tp",
str(tp),
"--host",
"::",
os.environ.get("INFERENCE_SERVER_HOST", "::"),
"--port",
str(worker_port),
"--dist-init-addr",
@@ -73,6 +73,12 @@
USER_CONTAINER_PORT = 5005
ARTIFACT_LIKE_CONTAINER_PORT = FORWARDER_PORT

+# LLM endpoints (vLLM) expose OpenAI-compatible routes on the user container port.
+# Route Services directly to the user port for those routes, to avoid depending on
+# the http-forwarder container's behavior.
+OPENAI_CHAT_COMPLETION_PATH = "/v1/chat/completions"
+OPENAI_COMPLETION_PATH = "/v1/completions"


class _BaseResourceArguments(TypedDict):
"""Keyword-arguments for substituting into all resource templates."""
@@ -1325,6 +1331,18 @@ def get_endpoint_resource_arguments_from_request(
node_port_dict = DictStrInt(f"nodePort: {node_port}")
else:
node_port_dict = DictStrInt("")

+        all_routes: List[str] = []
+        if isinstance(flavor, RunnableImageLike):
+            all_routes = list(flavor.routes or []) + list(flavor.extra_routes or [])
+
+        # If this endpoint exposes OpenAI-compatible routes, send traffic directly to the
+        # vLLM OpenAI server (5005). Otherwise keep legacy behavior (forwarder on 5000).
+        service_target_port = (
+            USER_CONTAINER_PORT
+            if (OPENAI_CHAT_COMPLETION_PATH in all_routes or OPENAI_COMPLETION_PATH in all_routes)
+            else FORWARDER_PORT
+        )
return ServiceArguments(
# Base resource arguments
RESOURCE_NAME=k8s_resource_group_name,
@@ -1339,7 +1357,7 @@
# Service arguments
NODE_PORT_DICT=node_port_dict,
SERVICE_TYPE=service_type,
-            SERVICE_TARGET_PORT=FORWARDER_PORT,
+            SERVICE_TARGET_PORT=service_target_port,
)
elif endpoint_resource_name == "lws-service":
# Use ClusterIP by default for sync endpoint.
@@ -1350,6 +1368,16 @@
node_port_dict = DictStrInt(f"nodePort: {node_port}")
else:
node_port_dict = DictStrInt("")

+        all_routes: List[str] = []
+        if isinstance(flavor, RunnableImageLike):
+            all_routes = list(flavor.routes or []) + list(flavor.extra_routes or [])
+
+        service_target_port = (
+            USER_CONTAINER_PORT
+            if (OPENAI_CHAT_COMPLETION_PATH in all_routes or OPENAI_COMPLETION_PATH in all_routes)
+            else FORWARDER_PORT
+        )
return LwsServiceArguments(
# Base resource arguments
RESOURCE_NAME=k8s_resource_group_name,
@@ -1364,7 +1392,7 @@
# Service arguments
NODE_PORT_DICT=node_port_dict,
SERVICE_TYPE=service_type,
-            SERVICE_TARGET_PORT=FORWARDER_PORT,
+            SERVICE_TARGET_PORT=service_target_port,
# LWS Service args
SERVICE_NAME_OVERRIDE=service_name_override,
)
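Putting the Service routing change together, a rough sketch of the port decision. `pick_service_target_port` is a made-up helper, and FORWARDER_PORT's value of 5000 is assumed from the comment in the diff (only USER_CONTAINER_PORT = 5005 is shown explicitly).

```python
from typing import List, Optional

USER_CONTAINER_PORT = 5005  # vLLM OpenAI server (from this diff)
FORWARDER_PORT = 5000       # http-forwarder (value assumed from the comment above)

OPENAI_CHAT_COMPLETION_PATH = "/v1/chat/completions"
OPENAI_COMPLETION_PATH = "/v1/completions"


def pick_service_target_port(routes: Optional[List[str]], extra_routes: Optional[List[str]]) -> int:
    # Same decision as in get_endpoint_resource_arguments_from_request: endpoints that
    # expose OpenAI-compatible routes get traffic sent straight to the user container.
    all_routes = list(routes or []) + list(extra_routes or [])
    if OPENAI_CHAT_COMPLETION_PATH in all_routes or OPENAI_COMPLETION_PATH in all_routes:
        return USER_CONTAINER_PORT
    return FORWARDER_PORT


print(pick_service_target_port(["/v1/chat/completions"], None))  # 5005 -> direct to vLLM
print(pick_service_target_port(["/predict"], ["/healthz"]))      # 5000 -> legacy forwarder
```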