Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions src/runpod_flash/core/resources/live_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@


class LiveServerlessMixin:
"""Common mixin for live serverless endpoints that locks the image."""
"""Common mixin for live serverless endpoints.

Treats the Flash runtime image as a *default*: if the caller passes an
``imageName`` (e.g. via ``Endpoint(image=...)`` in client mode), that
value wins. Otherwise the Flash runtime image for this resource type is
used so decorator-mode workloads continue to deploy the Flash wrapper.

The default is applied via the ``@model_validator(mode="before")`` on each
concrete subclass (see ``_apply_default_live_image``); reads and writes of
``imageName`` go through the normal Pydantic field machinery so model
serialization, drift detection, and ``setattr`` all stay consistent.
"""

_image_type: ClassVar[str] = (
"" # override in subclasses: 'gpu', 'cpu', 'lb', 'lb-cpu'
Expand All @@ -27,13 +38,15 @@ def _live_image(self) -> str:
python_version = getattr(self, "python_version", None) or DEFAULT_PYTHON_VERSION
return get_image_name(self._image_type, python_version)

@property
def imageName(self):
return self._live_image

@imageName.setter
def imageName(self, value):
pass
def _apply_default_live_image(data, image_type: str):
    """Fill in the Flash runtime image when the caller did not supply one.

    Meant to be called from ``@model_validator(mode="before")`` hooks.

    Args:
        data: Raw validator input. Anything that is not a ``dict`` is
            returned untouched.
        image_type: Image flavor forwarded to ``get_image_name``.

    Returns:
        ``data`` itself when it is not a dict or already carries a truthy
        ``imageName``; otherwise a shallow copy of ``data`` whose
        ``imageName`` is the Flash runtime image for ``image_type`` and the
        requested (or default) Python version.
    """
    if not isinstance(data, dict):
        return data
    if data.get("imageName"):
        # Caller-supplied image wins; nothing to default.
        return data
    python_version = data.get("python_version") or DEFAULT_PYTHON_VERSION
    # Build a new dict instead of writing into the input: a "before"
    # validator may be handed the caller's own dict, and an in-place write
    # would leak the default back into it.
    return {**data, "imageName": get_image_name(image_type, python_version)}


class LiveServerless(LiveServerlessMixin, ServerlessEndpoint):
Expand All @@ -44,10 +57,8 @@ class LiveServerless(LiveServerlessMixin, ServerlessEndpoint):
@model_validator(mode="before")
@classmethod
def set_live_serverless_template(cls, data: dict):
"""Set default GPU image for Live Serverless."""
python_version = data.get("python_version") or DEFAULT_PYTHON_VERSION
data["imageName"] = get_image_name("gpu", python_version)
return data
"""Default to the GPU Flash runtime image when none is supplied."""
return _apply_default_live_image(data, "gpu")


class CpuLiveServerless(LiveServerlessMixin, CpuServerlessEndpoint):
Expand All @@ -58,10 +69,8 @@ class CpuLiveServerless(LiveServerlessMixin, CpuServerlessEndpoint):
@model_validator(mode="before")
@classmethod
def set_live_serverless_template(cls, data: dict):
"""Set default CPU image for Live Serverless."""
python_version = data.get("python_version") or DEFAULT_PYTHON_VERSION
data["imageName"] = get_image_name("cpu", python_version)
return data
"""Default to the CPU Flash runtime image when none is supplied."""
return _apply_default_live_image(data, "cpu")


class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource):
Expand All @@ -72,10 +81,8 @@ class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource):
@model_validator(mode="before")
@classmethod
def set_live_lb_template(cls, data: dict):
"""Set default image for Live Load-Balanced endpoint."""
python_version = data.get("python_version") or DEFAULT_PYTHON_VERSION
data["imageName"] = get_image_name("lb", python_version)
return data
"""Default to the LB Flash runtime image when none is supplied."""
return _apply_default_live_image(data, "lb")


class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource):
Expand All @@ -86,7 +93,5 @@ class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource):
@model_validator(mode="before")
@classmethod
def set_live_cpu_lb_template(cls, data: dict):
"""Set default CPU image for Live Load-Balanced endpoint."""
python_version = data.get("python_version") or DEFAULT_PYTHON_VERSION
data["imageName"] = get_image_name("lb-cpu", python_version)
return data
"""Default to the CPU LB Flash runtime image when none is supplied."""
return _apply_default_live_image(data, "lb-cpu")
6 changes: 6 additions & 0 deletions src/runpod_flash/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ def __init__(
self.execution_timeout_ms = execution_timeout_ms
self.flashboot = flashboot
self.image = image
if image is not None:
log.info(
"Endpoint %r: using user-supplied image %r (overrides Flash runtime image)",
name,
image,
)
self._explicit_scaler_type = scaler_type
self.scaler_value = scaler_value
self.template = template
Expand Down
24 changes: 12 additions & 12 deletions tests/integration/test_cpu_disk_sizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,28 +244,28 @@ def test_mixed_cpu_generations_integration(self):
assert "cpu5c-1-2: max 15GB" in error_msg


class TestLiveServerlessImageLockingIntegration:
"""Test image locking integration in live serverless variants."""
class TestLiveServerlessImageIntegration:
"""Test image default + override behavior in live serverless variants (AE-3153)."""

def test_live_serverless_image_consistency(self):
"""Test that LiveServerless variants maintain image consistency."""
"""LiveServerless variants default to distinct Flash runtime images."""
gpu_live = LiveServerless(name="gpu-live")
cpu_live = CpuLiveServerless(name="cpu-live")

# Verify different images are used
# Verify different default images are used per resource type.
assert gpu_live.imageName != cpu_live.imageName
assert "flash:" in gpu_live.imageName
assert "flash-cpu:" in cpu_live.imageName

# Verify images remain locked despite attempts to change
original_gpu_image = gpu_live.imageName
original_cpu_image = cpu_live.imageName

gpu_live.imageName = "custom/image:latest"
cpu_live.imageName = "custom/image:latest"
def test_live_serverless_image_override_via_constructor(self):
"""Caller-supplied imageName overrides the Flash runtime default (AE-3153)."""
gpu_live = LiveServerless(name="gpu-live", imageName="custom/image:latest")
cpu_live = CpuLiveServerless(
name="cpu-live", imageName="custom/cpu-image:latest"
)

assert gpu_live.imageName == original_gpu_image
assert cpu_live.imageName == original_cpu_image
assert gpu_live.imageName == "custom/image:latest"
assert cpu_live.imageName == "custom/cpu-image:latest"

def test_live_serverless_template_integration(self):
"""Test live serverless template integration with disk sizing."""
Expand Down
24 changes: 11 additions & 13 deletions tests/integration/test_lb_remote_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,17 @@ async def echo(message: str):
assert "flash-lb" in lb.imageName
assert echo.__remote_config__["method"] == "POST"

def test_live_load_balancer_image_locked(self):
"""Test that LiveLoadBalancer locks the image to Flash LB image."""
lb = LiveLoadBalancer(name="test-api")

# Verify image is locked and cannot be overridden
original_image = lb.imageName
assert "flash-lb" in original_image

# Try to set a different image (should be ignored due to property)
lb.imageName = "custom-image:latest"

# Image should still be locked to Flash
assert lb.imageName == original_image
def test_live_load_balancer_image_default_and_override(self):
    """Check both image paths for LiveLoadBalancer (AE-3153).

    Without an ``imageName`` the Flash LB runtime image is expected; with
    one, the caller's value is expected back verbatim.
    """
    # No image supplied: the Flash LB runtime image is used.
    without_image = LiveLoadBalancer(name="test-api-default")
    assert "flash-lb" in without_image.imageName

    # Image supplied: it comes back unchanged.
    with_image = LiveLoadBalancer(
        name="test-api-custom",
        imageName="custom-image:latest",
    )
    assert with_image.imageName == "custom-image:latest"

def test_load_balancer_vs_queue_based_endpoints(self):
"""Test that LB and QB endpoints have different characteristics."""
Expand Down
82 changes: 47 additions & 35 deletions tests/unit/resources/test_live_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,20 @@ def test_live_serverless_gpu_defaults(self):
assert live_serverless.template.containerDiskInGb == 64
assert "flash:" in live_serverless.imageName # GPU image

def test_live_serverless_image_locked(self):
"""Test LiveServerless imageName is locked to GPU image."""
def test_live_serverless_image_override_via_constructor(self):
"""LiveServerless accepts a caller-supplied imageName (AE-3153)."""
live_serverless = LiveServerless(
name="example_gpu_live_serverless",
imageName="custom/image:latest",
)

original_image = live_serverless.imageName
# User-supplied image wins over the Flash runtime default.
assert live_serverless.imageName == "custom/image:latest"

# Attempt to change imageName - should be ignored
live_serverless.imageName = "custom/image:latest"

assert live_serverless.imageName == original_image
assert "flash:" in live_serverless.imageName # Still GPU image
def test_live_serverless_image_default_unchanged(self):
    """With no imageName given, LiveServerless uses the Flash GPU image."""
    default_image = LiveServerless(name="example_gpu_live_serverless").imageName
    assert "flash:" in default_image

def test_live_serverless_with_custom_template(self):
"""Test LiveServerless with custom template."""
Expand Down Expand Up @@ -113,20 +114,23 @@ def test_cpu_live_serverless_multiple_instances(self):
assert live_serverless.template is not None
assert live_serverless.template.containerDiskInGb == 10 # Min of 10 and 30

def test_cpu_live_serverless_image_locked(self):
"""Test CpuLiveServerless imageName is locked to CPU image."""
def test_cpu_live_serverless_image_override_via_constructor(self):
"""CpuLiveServerless accepts a caller-supplied imageName (AE-3153)."""
live_serverless = CpuLiveServerless(
name="example_cpu_live_serverless",
instanceIds=[CpuInstanceType.CPU3G_1_4],
imageName="custom/image:latest",
)

original_image = live_serverless.imageName

# Attempt to change imageName - should be ignored
live_serverless.imageName = "custom/image:latest"
assert live_serverless.imageName == "custom/image:latest"

assert live_serverless.imageName == original_image
assert "flash-cpu:" in live_serverless.imageName # Still CPU image
def test_cpu_live_serverless_image_default_unchanged(self):
    """With no imageName given, CpuLiveServerless uses the Flash CPU image."""
    endpoint_kwargs = {
        "name": "example_cpu_live_serverless",
        "instanceIds": [CpuInstanceType.CPU3G_1_4],
    }
    endpoint = CpuLiveServerless(**endpoint_kwargs)
    assert "flash-cpu:" in endpoint.imageName

def test_cpu_live_serverless_validation_failure(self):
"""Test CpuLiveServerless validation fails with excessive disk size."""
Expand Down Expand Up @@ -190,32 +194,40 @@ def test_live_image_property_cpu(self):
assert "flash-cpu:" in live_serverless._live_image

def test_image_name_property_gpu(self):
"""Test LiveServerless imageName property returns locked image."""
"""LiveServerless defaults imageName to the Flash runtime image when none supplied."""
live_serverless = LiveServerless(name="test")
assert live_serverless.imageName == live_serverless._live_image

def test_image_name_property_cpu(self):
"""Test CpuLiveServerless imageName property returns locked image."""
"""CpuLiveServerless defaults imageName to the Flash runtime image when none supplied."""
live_serverless = CpuLiveServerless(name="test")
assert live_serverless.imageName == live_serverless._live_image

def test_image_name_setter_ignored_gpu(self):
"""Test LiveServerless imageName setter is ignored."""
live_serverless = LiveServerless(name="test")
original_image = live_serverless.imageName

live_serverless.imageName = "should-be-ignored"

assert live_serverless.imageName == original_image

def test_image_name_setter_ignored_cpu(self):
"""Test CpuLiveServerless imageName setter is ignored."""
live_serverless = CpuLiveServerless(name="test")
original_image = live_serverless.imageName

live_serverless.imageName = "should-be-ignored"

assert live_serverless.imageName == original_image
def test_image_name_override_gpu(self):
    """A caller-provided imageName is kept as-is by LiveServerless (AE-3153)."""
    endpoint = LiveServerless(
        name="test",
        imageName="byo/image:v1",
    )
    assert endpoint.imageName == "byo/image:v1"

def test_image_name_override_cpu(self):
    """A caller-provided imageName is kept as-is by CpuLiveServerless (AE-3153)."""
    endpoint = CpuLiveServerless(
        name="test",
        imageName="byo/cpu-image:v1",
    )
    assert endpoint.imageName == "byo/cpu-image:v1"

def test_image_name_override_lb(self):
    """A caller-provided imageName is kept as-is by LiveLoadBalancer (AE-3153)."""
    balancer = LiveLoadBalancer(
        name="test",
        imageName="byo/lb-image:v1",
    )
    assert balancer.imageName == "byo/lb-image:v1"

def test_image_name_override_cpu_lb(self):
    """A caller-provided imageName is kept as-is by CpuLiveLoadBalancer (AE-3153)."""
    balancer = CpuLiveLoadBalancer(
        name="test",
        imageName="byo/cpu-lb-image:v1",
    )
    assert balancer.imageName == "byo/cpu-lb-image:v1"

def test_default_image_validator_passes_through_non_dict(self):
    """Re-validating an existing instance (non-dict input) keeps its imageName."""
    first_pass = LiveServerless(name="test", imageName="byo/image:v1")
    second_pass = LiveServerless.model_validate(first_pass)
    assert second_pass.imageName == "byo/image:v1"


class TestLiveServerlessPythonVersion:
Expand Down
Loading