Commit a849a7f

Add on-prem env vars to Helm chart (S3_ENDPOINT_URL, REDIS_HOST, etc.)
1 parent 0f6674a · commit a849a7f

11 files changed: 612 additions & 46 deletions

charts/model-engine/templates/_helpers.tpl

Lines changed: 23 additions & 1 deletion
@@ -256,6 +256,10 @@ env:
 - name: ABS_CONTAINER_NAME
   value: {{ .Values.azure.abs_container_name }}
 {{- end }}
+{{- if .Values.s3EndpointUrl }}
+- name: S3_ENDPOINT_URL
+  value: {{ .Values.s3EndpointUrl | quote }}
+{{- end }}
 {{- end }}
 
 {{- define "modelEngine.syncForwarderTemplateEnv" -}}
@@ -342,9 +346,27 @@ env:
   value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml"
 {{- end }}
 - name: CELERY_ELASTICACHE_ENABLED
-  value: "true"
+  value: {{ .Values.celeryElasticacheEnabled | default true | quote }}
 - name: LAUNCH_SERVICE_TEMPLATE_FOLDER
   value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates"
+{{- if .Values.s3EndpointUrl }}
+- name: S3_ENDPOINT_URL
+  value: {{ .Values.s3EndpointUrl | quote }}
+{{- end }}
+{{- if .Values.redisHost }}
+- name: REDIS_HOST
+  value: {{ .Values.redisHost | quote }}
+- name: REDIS_PORT
+  value: {{ .Values.redisPort | default "6379" | quote }}
+{{- end }}
+{{- if .Values.celeryBrokerUrl }}
+- name: CELERY_BROKER_URL
+  value: {{ .Values.celeryBrokerUrl | quote }}
+{{- end }}
+{{- if .Values.celeryResultBackend }}
+- name: CELERY_RESULT_BACKEND
+  value: {{ .Values.celeryResultBackend | quote }}
+{{- end }}
 {{- if .Values.redis.auth}}
 - name: REDIS_AUTH_TOKEN
   value: {{ .Values.redis.auth }}
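
Note: for context, here is how a pod could consume the env vars this template now injects. This is a minimal sketch, not code from this repo — the helper function is ours, and it assumes boto3 and redis-py are installed in the image:

import os

import boto3
import redis


def clients_from_env():
    # S3_ENDPOINT_URL is only rendered when .Values.s3EndpointUrl is set,
    # so None (plain AWS S3) is the natural fallback.
    s3 = boto3.client("s3", endpoint_url=os.getenv("S3_ENDPOINT_URL"))
    # REDIS_HOST/REDIS_PORT mirror the chart's conditional env entries.
    cache = redis.Redis(
        host=os.getenv("REDIS_HOST", "localhost"),
        port=int(os.getenv("REDIS_PORT", "6379")),
    )
    return s3, cache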

model-engine/model_engine_server/common/config.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def cache_redis_url(self) -> str:
         # On-prem Redis support - check explicit URL first, then fallback to env vars
         if self.cache_redis_onprem_url:
             return self.cache_redis_onprem_url
-
+
         if cloud_provider == "onprem":
             if self.cache_redis_aws_url:
                 logger.info("On-prem deployment using cache_redis_aws_url")

model-engine/model_engine_server/core/aws/roles.py

Lines changed: 2 additions & 2 deletions
@@ -119,15 +119,15 @@ def session(role: Optional[str], session_type: SessionT = Session) -> SessionT:
 
     :param:`session_type` defines the type of session to return. Most users will use
     the default boto3 type. Some users required a special type (e.g aioboto3 session).
-
+
     For on-prem deployments without AWS profiles, pass role=None or role=""
     to use default credentials from environment variables (AWS_ACCESS_KEY_ID, etc).
     """
     # Do not assume roles in CIRCLECI
     if os.getenv("CIRCLECI"):
         logger.warning(f"In circleci, not assuming role (ignoring: {role})")
         role = None
-
+
     # Use profile-based auth only if role is specified
     # For on-prem with MinIO, role will be None or empty - use env var credentials
     if role:
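
Note: under these semantics an on-prem caller can skip profiles entirely. A minimal usage sketch, assuming MinIO credentials are exported as the standard AWS env vars (the credential values are placeholders):

import os

from model_engine_server.core.aws.roles import session

# Placeholder credentials; real deployments inject these via secrets.
os.environ.setdefault("AWS_ACCESS_KEY_ID", "minio-access-key")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "minio-secret-key")

# role=None skips profile-based auth, so boto3 falls back to the env vars above.
boto_session = session(role=None)
s3 = boto_session.client("s3", endpoint_url=os.getenv("S3_ENDPOINT_URL"))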

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 82 additions & 29 deletions
@@ -16,7 +16,6 @@
 
 import yaml
 from model_engine_server.common.config import hmi_config
-from model_engine_server.core.config import infra_config
 from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests
 from model_engine_server.common.dtos.llms import (
     ChatCompletionV2Request,
@@ -62,6 +61,7 @@
 from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus
 from model_engine_server.common.resource_limits import validate_resource_requests
 from model_engine_server.core.auth.authentication_repository import User
+from model_engine_server.core.config import infra_config
 from model_engine_server.core.configmap import read_config_map
 from model_engine_server.core.loggers import (
     LoggerTagKey,
@@ -373,7 +373,7 @@ def check_docker_image_exists_for_image_tag(
         # Skip ECR validation for on-prem deployments - images are in local registry
         if infra_config().cloud_provider == "onprem":
             return
-
+
         if not self.docker_repository.image_exists(
             image_tag=framework_image_tag,
             repository_name=repository_name,
@@ -638,9 +638,11 @@ def load_model_weights_sub_commands_s3(
     file_selection_str = '--include "*.model" --include "*.model.v*" --include "*.json" --include "*.safetensors" --include "*.txt" --exclude "optimizer*"'
     if trust_remote_code:
         file_selection_str += ' --include "*.py"'
-
+
     # Support for MinIO/on-prem S3-compatible storage via S3_ENDPOINT_URL env var
-    endpoint_flag = '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+    endpoint_flag = (
+        '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+    )
     subcommands.append(
         f"{s5cmd} {endpoint_flag} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
     )
@@ -695,7 +697,9 @@ def load_model_files_sub_commands_trt_llm(
     """
     if checkpoint_path.startswith("s3://"):
         # Support for MinIO/on-prem S3-compatible storage via S3_ENDPOINT_URL env var
-        endpoint_flag = '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+        endpoint_flag = (
+            '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+        )
         subcommands = [
             f"./s5cmd {endpoint_flag} --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
         ]
@@ -1028,8 +1032,9 @@ async def create_vllm_bundle(
                 protocol="http",
                 readiness_initial_delay_seconds=10,
                 healthcheck_route="/health",
-                predict_route="/predict",
-                streaming_predict_route="/stream",
+                # vLLM 0.5+ uses OpenAI-compatible endpoints
+                predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions"
+                streaming_predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions" (streaming via same endpoint)
                 routes=[
                     OPENAI_CHAT_COMPLETION_PATH,
                     OPENAI_COMPLETION_PATH,
@@ -1110,8 +1115,9 @@ async def create_vllm_multinode_bundle(
             protocol="http",
             readiness_initial_delay_seconds=10,
             healthcheck_route="/health",
-            predict_route="/predict",
-            streaming_predict_route="/stream",
+            # vLLM 0.5+ uses OpenAI-compatible endpoints
+            predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions"
+            streaming_predict_route=OPENAI_COMPLETION_PATH,  # "/v1/completions" (streaming via same endpoint)
             routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
             env=common_vllm_envs,
             worker_command=worker_command,
@@ -1912,18 +1918,42 @@ def model_output_to_completion_output(
 
     elif model_content.inference_framework == LLMInferenceFramework.VLLM:
         tokens = None
-        if with_token_probs:
-            tokens = [
-                TokenOutput(
-                    token=model_output["tokens"][index],
-                    log_prob=list(t.values())[0],
-                )
-                for index, t in enumerate(model_output["log_probs"])
-            ]
+        # Handle OpenAI-compatible format (vLLM 0.5+) vs legacy format
+        if "choices" in model_output and model_output["choices"]:
+            # OpenAI-compatible format: {"choices": [{"text": "...", ...}], "usage": {...}}
+            choice = model_output["choices"][0]
+            text = choice.get("text", "")
+            usage = model_output.get("usage", {})
+            num_prompt_tokens = usage.get("prompt_tokens", 0)
+            num_completion_tokens = usage.get("completion_tokens", 0)
+            # OpenAI format logprobs are in choice.logprobs
+            if with_token_probs and choice.get("logprobs"):
+                logprobs = choice["logprobs"]
+                if logprobs.get("tokens") and logprobs.get("token_logprobs"):
+                    tokens = [
+                        TokenOutput(
+                            token=logprobs["tokens"][i],
+                            log_prob=logprobs["token_logprobs"][i] or 0.0,
+                        )
+                        for i in range(len(logprobs["tokens"]))
+                    ]
+        else:
+            # Legacy format: {"text": "...", "count_prompt_tokens": ..., ...}
+            text = model_output["text"]
+            num_prompt_tokens = model_output["count_prompt_tokens"]
+            num_completion_tokens = model_output["count_output_tokens"]
+            if with_token_probs and model_output.get("log_probs"):
+                tokens = [
+                    TokenOutput(
+                        token=model_output["tokens"][index],
+                        log_prob=list(t.values())[0],
+                    )
+                    for index, t in enumerate(model_output["log_probs"])
+                ]
         return CompletionOutput(
-            text=model_output["text"],
-            num_prompt_tokens=model_output["count_prompt_tokens"],
-            num_completion_tokens=model_output["count_output_tokens"],
+            text=text,
+            num_prompt_tokens=num_prompt_tokens,
+            num_completion_tokens=num_completion_tokens,
             tokens=tokens,
         )
     elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
@@ -2663,20 +2693,43 @@ async def _response_chunk_generator(
     # VLLM
     elif model_content.inference_framework == LLMInferenceFramework.VLLM:
         token = None
-        if request.return_token_log_probs:
-            token = TokenOutput(
-                token=result["result"]["text"],
-                log_prob=list(result["result"]["log_probs"].values())[0],
-            )
-        finished = result["result"]["finished"]
-        num_prompt_tokens = result["result"]["count_prompt_tokens"]
+        vllm_output: dict = result["result"]
+        # Handle OpenAI-compatible streaming format (vLLM 0.5+) vs legacy format
+        if "choices" in vllm_output and vllm_output["choices"]:
+            # OpenAI streaming format: {"choices": [{"text": "...", "finish_reason": ...}], ...}
+            choice = vllm_output["choices"][0]
+            text = choice.get("text", "")
+            finished = choice.get("finish_reason") is not None
+            usage = vllm_output.get("usage", {})
+            num_prompt_tokens = usage.get("prompt_tokens", 0)
+            num_completion_tokens = usage.get("completion_tokens", 0)
+            if request.return_token_log_probs and choice.get("logprobs"):
+                logprobs = choice["logprobs"]
+                if logprobs.get("tokens") and logprobs.get("token_logprobs"):
+                    # Get the last token from the logprobs
+                    idx = len(logprobs["tokens"]) - 1
+                    token = TokenOutput(
+                        token=logprobs["tokens"][idx],
+                        log_prob=logprobs["token_logprobs"][idx] or 0.0,
+                    )
+        else:
+            # Legacy format: {"text": "...", "finished": ..., ...}
+            text = vllm_output["text"]
+            finished = vllm_output["finished"]
+            num_prompt_tokens = vllm_output["count_prompt_tokens"]
+            num_completion_tokens = vllm_output["count_output_tokens"]
+            if request.return_token_log_probs and vllm_output.get("log_probs"):
+                token = TokenOutput(
+                    token=vllm_output["text"],
+                    log_prob=list(vllm_output["log_probs"].values())[0],
+                )
         yield CompletionStreamV1Response(
             request_id=request_id,
            output=CompletionStreamOutput(
-                text=result["result"]["text"],
+                text=text,
                 finished=finished,
                 num_prompt_tokens=num_prompt_tokens if finished else None,
-                num_completion_tokens=result["result"]["count_output_tokens"],
+                num_completion_tokens=num_completion_tokens,
                 token=token,
             ),
         )
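
Note: for reference, the two payload shapes the new branches distinguish, as minimal Python examples. The key names come from the code above; the concrete values are invented for illustration:

# OpenAI-compatible shape emitted by vLLM 0.5+
openai_style = {
    "choices": [
        {"text": "Hello", "logprobs": {"tokens": ["Hello"], "token_logprobs": [-0.12]}}
    ],
    "usage": {"prompt_tokens": 4, "completion_tokens": 1},
}

# Legacy shape the old code path assumed
legacy_style = {
    "text": "Hello",
    "count_prompt_tokens": 4,
    "count_output_tokens": 1,
    "tokens": ["Hello"],
    "log_probs": [{"Hello": -0.12}],
}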

model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ def get_s3_client():
         session = boto3.Session(profile_name=profile_name)
     else:
         session = boto3.Session()
-
+
     # Support for MinIO/on-prem S3-compatible storage
     endpoint_url = os.getenv("S3_ENDPOINT_URL")
     return session.client("s3", region_name=AWS_REGION, endpoint_url=endpoint_url)
@@ -72,7 +72,7 @@ def download_model(checkpoint_path, final_weights_folder):
     # Support for MinIO/on-prem S3-compatible storage
     s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "")
    endpoint_flag = f"--endpoint-url {s3_endpoint_url}" if s3_endpoint_url else ""
-
+
     s5cmd = f"./s5cmd {endpoint_flag} --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
     env = os.environ.copy()
     env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default")

model-engine/model_engine_server/inference/vllm/vllm_batch.py

Lines changed: 2 additions & 2 deletions
@@ -78,11 +78,11 @@ async def download_model(checkpoint_path: str, target_dir: str, trust_remote_cod
 
     print(f"Downloading model from {checkpoint_path} to {target_dir}", flush=True)
     additional_include = "--include '*.py'" if trust_remote_code else ""
-
+
     # Support for MinIO/on-prem S3-compatible storage
     s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "")
     endpoint_flag = f"--endpoint-url {s3_endpoint_url}" if s3_endpoint_url else ""
-
+
     s5cmd = f"./s5cmd {endpoint_flag} --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --include '*.txt' {additional_include} --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}"
     print(s5cmd, flush=True)
     env = os.environ.copy()

model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py

Lines changed: 1 addition & 1 deletion
@@ -579,7 +579,7 @@ def get_endpoint_resource_arguments_from_request(
     abs_account_name = os.getenv("ABS_ACCOUNT_NAME")
     if abs_account_name is not None:
         main_env.append({"name": "ABS_ACCOUNT_NAME", "value": abs_account_name})
-
+
     # Support for MinIO/on-prem S3-compatible storage
     s3_endpoint_url = os.getenv("S3_ENDPOINT_URL")
     if s3_endpoint_url:

model-engine/model_engine_server/infra/gateways/s3_utils.py

Lines changed: 2 additions & 6 deletions
@@ -14,15 +14,11 @@ def _get_onprem_client_kwargs() -> Dict[str, Any]:
     global _s3_config_logged
     client_kwargs: Dict[str, Any] = {}
 
-    s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv(
-        "S3_ENDPOINT_URL"
-    )
+    s3_endpoint = getattr(infra_config(), "s3_endpoint_url", None) or os.getenv("S3_ENDPOINT_URL")
     if s3_endpoint:
         client_kwargs["endpoint_url"] = s3_endpoint
 
-    addressing_style = cast(
-        AddressingStyle, getattr(infra_config(), "s3_addressing_style", "path")
-    )
+    addressing_style = cast(AddressingStyle, getattr(infra_config(), "s3_addressing_style", "path"))
     client_kwargs["config"] = Config(s3={"addressing_style": addressing_style})
 
     if not _s3_config_logged and s3_endpoint:
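
Note: path-style addressing matters for MinIO because virtual-hosted URLs put the bucket name in the hostname, which a single-host endpoint usually cannot serve. A minimal sketch of a client built with these kwargs — the endpoint URL is an example value, not from this repo:

import boto3
from botocore.config import Config

client = boto3.client(
    "s3",
    # Path style: http://minio:9000/my-bucket/key, not http://my-bucket.minio:9000/key.
    endpoint_url="http://minio.example.internal:9000",  # example endpoint
    config=Config(s3={"addressing_style": "path"}),
)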

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 6 additions & 2 deletions
@@ -583,8 +583,12 @@ def test_load_model_weights_sub_commands(
         framework, framework_image_tag, checkpoint_path, final_weights_folder
     )
 
+    # Support for MinIO/on-prem S3-compatible storage via S3_ENDPOINT_URL env var
+    endpoint_flag = (
+        '$(if [ -n "$S3_ENDPOINT_URL" ]; then echo "--endpoint-url $S3_ENDPOINT_URL"; fi)'
+    )
     expected_result = [
-        './s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.model.v*" --include "*.json" --include "*.safetensors" --include "*.txt" --exclude "optimizer*" s3://fake-checkpoint/* test_folder',
+        f'./s5cmd {endpoint_flag} --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.model.v*" --include "*.json" --include "*.safetensors" --include "*.txt" --exclude "optimizer*" s3://fake-checkpoint/* test_folder',
     ]
     assert expected_result == subcommands
 
@@ -594,7 +598,7 @@ def test_load_model_weights_sub_commands(
     )
 
     expected_result = [
-        './s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.model.v*" --include "*.json" --include "*.safetensors" --include "*.txt" --exclude "optimizer*" --include "*.py" s3://fake-checkpoint/* test_folder',
+        f'./s5cmd {endpoint_flag} --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.model.v*" --include "*.json" --include "*.safetensors" --include "*.txt" --exclude "optimizer*" --include "*.py" s3://fake-checkpoint/* test_folder',
     ]
     assert expected_result == subcommands
 
