Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class LLMModelEndpointCommonArgs(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = True # LLM endpoints are public by default.
task_expires_seconds: Optional[int] = Field(
default=None,
gt=0,
description="For async endpoints, how long a task can wait in queue before expiring (in seconds).",
)
chat_template_override: Optional[str] = Field(
default=None,
description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
Expand Down Expand Up @@ -166,6 +171,10 @@ class GetLLMModelEndpointV1Response(BaseModel):
default=None,
description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint",
)
task_expires_seconds: Optional[int] = Field(
default=None,
description="For async endpoints, how long a task can wait in queue before expiring (in seconds).",
)
spec: Optional[GetModelEndpointV1Response] = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class CreateModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = Field(default=False)
task_expires_seconds: Optional[int] = Field(default=None, gt=0)


class CreateModelEndpointV1Response(BaseModel):
Expand Down Expand Up @@ -100,6 +101,7 @@ class UpdateModelEndpointV1Request(BaseModel):
default_callback_url: Optional[HttpUrlStr] = None
default_callback_auth: Optional[CallbackAuth] = None
public_inference: Optional[bool] = None
task_expires_seconds: Optional[int] = Field(default=None, gt=0)


class UpdateModelEndpointV1Response(BaseModel):
Expand Down Expand Up @@ -128,6 +130,7 @@ class GetModelEndpointV1Response(BaseModel):
resource_state: Optional[ModelEndpointResourceState] = Field(default=None)
num_queued_items: Optional[int] = Field(default=None)
public_inference: Optional[bool] = Field(default=None)
task_expires_seconds: Optional[int] = Field(default=None)


class ListModelEndpointsV1Response(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""add task_expires_seconds column

Revision ID: 62da4f8b3403
Revises: 221aa19d3f32
Create Date: 2026-02-10 19:20:00.000000

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '62da4f8b3403'
down_revision = '221aa19d3f32'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add the nullable ``task_expires_seconds`` column to the endpoints table.

    The column stores, for async endpoints, how long a task may wait in the
    queue before expiring (seconds). Nullable so existing rows need no backfill.
    """
    column = sa.Column("task_expires_seconds", sa.Integer, nullable=True)
    op.add_column("endpoints", column, schema="hosted_model_inference")


def downgrade() -> None:
    """Revert the migration by dropping ``task_expires_seconds`` from endpoints."""
    op.drop_column("endpoints", "task_expires_seconds", schema="hosted_model_inference")
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ class Endpoint(Base):
current_bundle = relationship("Bundle")
owner = Column(String(SHORT_STRING))
public_inference = Column(Boolean, default=False)
# Task expiration time in seconds for async endpoints (how long a task can wait in queue)
task_expires_seconds = Column(Integer, nullable=True)

def __init__(
self,
Expand All @@ -482,6 +484,7 @@ def __init__(
endpoint_status: Optional[str] = "READY", # EndpointStatus.ready.value
owner: Optional[str] = None,
public_inference: Optional[bool] = False,
task_expires_seconds: Optional[int] = None,
):
self.id = f"end_{get_xid()}"
self.name = name
Expand All @@ -494,6 +497,7 @@ def __init__(
self.endpoint_status = endpoint_status
self.owner = owner
self.public_inference = public_inference
self.task_expires_seconds = task_expires_seconds

@classmethod
async def create(cls, session: AsyncSession, endpoint: "Endpoint") -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class ModelEndpointRecord(OwnedEntity):
current_model_bundle: ModelBundle
owner: str
public_inference: Optional[bool] = None
task_expires_seconds: Optional[int] = None


class ModelEndpointInfraState(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def create_task(
self,
topic: str,
predict_request: EndpointPredictV1Request,
task_timeout_seconds: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is being used in async_inference_use_cases.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just renamed the parameter to task_expires_seconds here. Can you explain more what you mean?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue was that it wasn't properly passed down to where Celery creates the task:

res = celery_dest.send_task(
name=task_name,
args=args,
kwargs=kwargs,
queue=queue_name,
)

Also, it wasn't possible to configure this parameter. So basically what I did was:

  • rename it (you could argue whether that's necessary)
  • pass it down properly to downstream code
  • make it configurable from outside

task_expires_seconds: int,
*,
task_name: str = DEFAULT_CELERY_TASK_NAME,
) -> CreateAsyncTaskV1Response:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ async def create_model_endpoint(
default_callback_url: Optional[str],
default_callback_auth: Optional[CallbackAuth],
public_inference: Optional[bool] = False,
task_expires_seconds: Optional[int] = None,
) -> ModelEndpointRecord:
"""
Creates a model endpoint.
Expand Down Expand Up @@ -125,6 +126,7 @@ async def create_model_endpoint(
default_callback_url: The default callback URL to use for the model endpoint.
default_callback_auth: The default callback auth to use for the model endpoint.
public_inference: Whether to allow public inference.
task_expires_seconds: For async endpoints, how long a task can wait in queue before expiring.
Returns:
A Model Endpoint Record domain entity object of the created endpoint.
Raises:
Expand Down Expand Up @@ -222,6 +224,7 @@ async def update_model_endpoint(
default_callback_url: Optional[str] = None,
default_callback_auth: Optional[CallbackAuth] = None,
public_inference: Optional[bool] = None,
task_expires_seconds: Optional[int] = None,
) -> ModelEndpointRecord:
"""
Updates a model endpoint.
Expand Down Expand Up @@ -250,6 +253,7 @@ async def update_model_endpoint(
default_callback_url: The default callback URL to use for the model endpoint.
default_callback_auth: The default callback auth to use for the model endpoint.
public_inference: Whether to allow public inference.
task_expires_seconds: For async endpoints, how long a task can wait in queue before expiring.
Returns:
A Model Endpoint Record domain entity object of the updated endpoint.
Raises:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService

DEFAULT_TASK_TIMEOUT_SECONDS = 86400
DEFAULT_TASK_EXPIRES_SECONDS = 86400


class CreateAsyncInferenceTaskV1UseCase:
Expand Down Expand Up @@ -66,10 +66,11 @@ async def execute(
task_name = model_endpoint.record.current_model_bundle.celery_task_name()

inference_gateway = self.model_endpoint_service.get_async_model_endpoint_inference_gateway()
task_expires = model_endpoint.record.task_expires_seconds or DEFAULT_TASK_EXPIRES_SECONDS
return inference_gateway.create_task(
topic=model_endpoint.record.destination,
predict_request=request,
task_timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS,
task_expires_seconds=task_expires,
task_name=task_name,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def _model_endpoint_entity_to_get_llm_model_endpoint_response(
quantize=llm_metadata.get("quantize"),
checkpoint_path=llm_metadata.get("checkpoint_path"),
chat_template_override=llm_metadata.get("chat_template_override"),
task_expires_seconds=model_endpoint.record.task_expires_seconds,
spec=model_endpoint_entity_to_get_model_endpoint_response(model_endpoint),
)
return response
Expand Down Expand Up @@ -1463,6 +1464,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
task_expires_seconds=request.task_expires_seconds,
)
_handle_post_inference_hooks(
created_by=user.user_id,
Expand Down Expand Up @@ -1730,6 +1732,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
task_expires_seconds=request.task_expires_seconds,
)
_handle_post_inference_hooks(
created_by=endpoint_record.created_by,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def model_endpoint_entity_to_get_model_endpoint_response(
resource_state=(None if infra_state is None else infra_state.resource_state),
num_queued_items=(None if infra_state is None else infra_state.num_queued_items),
public_inference=model_endpoint.record.public_inference,
task_expires_seconds=model_endpoint.record.task_expires_seconds,
)


Expand Down Expand Up @@ -389,6 +390,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
task_expires_seconds=request.task_expires_seconds,
)
_handle_post_inference_hooks(
created_by=user.user_id,
Expand Down Expand Up @@ -517,6 +519,7 @@ async def execute(
default_callback_url=request.default_callback_url,
default_callback_auth=request.default_callback_auth,
public_inference=request.public_inference,
task_expires_seconds=request.task_expires_seconds,
)
_handle_post_inference_hooks(
created_by=endpoint_record.created_by,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def send_task(
args=args,
kwargs=kwargs,
queue=queue_name,
expires=expires,
)

if infra_config().debug_mode: # pragma: no cover
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_task(
self,
topic: str,
predict_request: EndpointPredictV1Request,
task_timeout_seconds: int,
task_expires_seconds: int,
*,
task_name: str = DEFAULT_CELERY_TASK_NAME,
) -> CreateAsyncTaskV1Response:
Expand All @@ -39,7 +39,7 @@ def create_task(
task_name=task_name,
queue_name=topic,
args=[predict_args, datetime.now(), predict_request.return_pickled],
expires=task_timeout_seconds,
expires=task_expires_seconds,
)
return CreateAsyncTaskV1Response(task_id=send_task_response.task_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
FORWARDER_PORT = 5000
USER_CONTAINER_PORT = 5005
ARTIFACT_LIKE_CONTAINER_PORT = FORWARDER_PORT
DEFAULT_TASK_EXPIRES_SECONDS = 86400


class _BaseResourceArguments(TypedDict):
Expand Down Expand Up @@ -127,6 +128,7 @@ class _AsyncDeploymentArguments(TypedDict):
BROKER_TYPE: str
SQS_QUEUE_URL: str
SQS_PROFILE: str
TASK_EXPIRES_SECONDS: int


class _SyncArtifactDeploymentArguments(TypedDict):
Expand Down Expand Up @@ -682,6 +684,7 @@ def get_endpoint_resource_arguments_from_request(
BROKER_TYPE=broker_type,
SQS_QUEUE_URL=sqs_queue_url,
SQS_PROFILE=sqs_profile,
TASK_EXPIRES_SECONDS=model_endpoint_record.task_expires_seconds or DEFAULT_TASK_EXPIRES_SECONDS,
)
elif endpoint_resource_name == "deployment-runnable-image-async-gpu":
assert isinstance(flavor, RunnableImageLike)
Expand Down Expand Up @@ -732,6 +735,7 @@ def get_endpoint_resource_arguments_from_request(
BROKER_TYPE=broker_type,
SQS_QUEUE_URL=sqs_queue_url,
SQS_PROFILE=sqs_profile,
TASK_EXPIRES_SECONDS=model_endpoint_record.task_expires_seconds or DEFAULT_TASK_EXPIRES_SECONDS,
# GPU Deployment Arguments
GPU_TYPE=build_endpoint_request.gpu_type.value,
GPUS=build_endpoint_request.gpus,
Expand Down Expand Up @@ -984,6 +988,7 @@ def get_endpoint_resource_arguments_from_request(
BROKER_TYPE=broker_type,
SQS_QUEUE_URL=sqs_queue_url,
SQS_PROFILE=sqs_profile,
TASK_EXPIRES_SECONDS=model_endpoint_record.task_expires_seconds or DEFAULT_TASK_EXPIRES_SECONDS,
# Triton Deployment Arguments
TRITON_MODEL_REPOSITORY=flavor.triton_model_repository,
TRITON_CPUS=str(flavor.triton_num_cpu),
Expand Down Expand Up @@ -1042,6 +1047,7 @@ def get_endpoint_resource_arguments_from_request(
BROKER_TYPE=broker_type,
SQS_QUEUE_URL=sqs_queue_url,
SQS_PROFILE=sqs_profile,
TASK_EXPIRES_SECONDS=model_endpoint_record.task_expires_seconds or DEFAULT_TASK_EXPIRES_SECONDS,
# GPU Deployment Arguments
GPU_TYPE=build_endpoint_request.gpu_type.value,
GPUS=build_endpoint_request.gpus,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ data:
celery.scaleml.autoscaler/perWorker: "${PER_WORKER}"
celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}"
celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}"
celery.scaleml.autoscaler/taskExpiresSeconds: "${TASK_EXPIRES_SECONDS}"
spec:
strategy:
type: RollingUpdate
Expand Down Expand Up @@ -318,6 +319,7 @@ data:
celery.scaleml.autoscaler/perWorker: "${PER_WORKER}"
celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}"
celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}"
celery.scaleml.autoscaler/taskExpiresSeconds: "${TASK_EXPIRES_SECONDS}"
spec:
strategy:
type: RollingUpdate
Expand Down Expand Up @@ -1266,6 +1268,7 @@ data:
celery.scaleml.autoscaler/perWorker: "${PER_WORKER}"
celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}"
celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}"
celery.scaleml.autoscaler/taskExpiresSeconds: "${TASK_EXPIRES_SECONDS}"
spec:
strategy:
type: RollingUpdate
Expand Down Expand Up @@ -1547,6 +1550,7 @@ data:
celery.scaleml.autoscaler/perWorker: "${PER_WORKER}"
celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}"
celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}"
celery.scaleml.autoscaler/taskExpiresSeconds: "${TASK_EXPIRES_SECONDS}"
spec:
strategy:
type: RollingUpdate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def translate_model_endpoint_orm_to_model_endpoint_record(
status=model_endpoint_orm.endpoint_status,
current_model_bundle=current_model_bundle,
public_inference=model_endpoint_orm.public_inference,
task_expires_seconds=model_endpoint_orm.task_expires_seconds,
)


Expand Down Expand Up @@ -120,6 +121,7 @@ async def create_model_endpoint_record(
status: str,
owner: str,
public_inference: Optional[bool] = False,
task_expires_seconds: Optional[int] = None,
) -> ModelEndpointRecord:
model_endpoint_record = OrmModelEndpoint(
name=name,
Expand All @@ -132,6 +134,7 @@ async def create_model_endpoint_record(
endpoint_status=status,
owner=owner,
public_inference=public_inference,
task_expires_seconds=task_expires_seconds,
)
async with self.session() as session:
await OrmModelEndpoint.create(session, model_endpoint_record)
Expand Down Expand Up @@ -305,6 +308,7 @@ async def update_model_endpoint_record(
destination: Optional[str] = None,
status: Optional[str] = None,
public_inference: Optional[bool] = None,
task_expires_seconds: Optional[int] = None,
) -> Optional[ModelEndpointRecord]:
async with self.session() as session:
model_endpoint_orm = await OrmModelEndpoint.select_by_id(
Expand All @@ -324,6 +328,7 @@ async def update_model_endpoint_record(
endpoint_status=status,
last_updated_at=datetime.utcnow(),
public_inference=public_inference,
task_expires_seconds=task_expires_seconds,
)
await OrmModelEndpoint.update_by_name_owner(
session=session,
Expand Down
Loading