Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions integration-tests/tests/healthcheck_custom.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Skip for coglet_alpha which does not support this
[coglet_alpha] skip
# Skip for coglet_rust which does not support this
[coglet_rust] skip
# Skip for cog-dataclass which does not support this
[cog_dataclass] skip

# Test custom healthcheck functionality
# This tests the user-defined healthcheck() method in predictors

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Test 1: Healthy healthcheck returns READY status
curl GET /health-check
stdout '"status":"READY"'
! stdout 'user_healthcheck_error'

# Test 2: Make a prediction to ensure predictor works
curl POST /predictions '{"input":{"text":"world"}}'
stdout '"output":"hello world"'

# Test 3: Health check still works after prediction
curl GET /health-check
stdout '"status":"READY"'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
def predict(self, text: str) -> str:
return f"hello {text}"

def healthcheck(self) -> bool:
"""Custom healthcheck that always returns healthy."""
return True
43 changes: 43 additions & 0 deletions integration-tests/tests/healthcheck_exception.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Skip for coglet_alpha which does not support this
[coglet_alpha] skip
# Skip for coglet_rust which does not support this
[coglet_rust] skip
# Skip for cog-dataclass which does not support this
[cog_dataclass] skip

# Test healthcheck that raises an exception
# This tests error handling when healthcheck throws

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Exception in healthcheck should return UNHEALTHY with error message
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'Critical system error'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
def setup(self) -> None:
self._healthcheck_calls = 0

def predict(self, text: str) -> str:
return f"hello {text}"

def healthcheck(self) -> bool:
"""Healthcheck that raises an exception after startup."""
self._healthcheck_calls += 1
if self._healthcheck_calls == 1:
return True
raise RuntimeError("Critical system error")
45 changes: 45 additions & 0 deletions integration-tests/tests/healthcheck_timeout.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Skip for coglet_alpha which does not support this
[coglet_alpha] skip
# Skip for coglet_rust which does not support this
[coglet_rust] skip
# Skip for cog-dataclass which does not support this
[cog_dataclass] skip

# Test healthcheck timeout behavior
# This tests when healthcheck takes too long (>5 seconds)

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Timeout in healthcheck should return UNHEALTHY with timeout message
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'timed out after 5.0 seconds'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import asyncio
from cog import BasePredictor

class Predictor(BasePredictor):
def setup(self) -> None:
self._healthcheck_calls = 0

def predict(self, text: str) -> str:
return f"hello {text}"

async def healthcheck(self) -> bool:
"""Healthcheck that times out after startup."""
self._healthcheck_calls += 1
if self._healthcheck_calls == 1:
return True
await asyncio.sleep(10) # Sleep longer than the 5 second timeout
return True
43 changes: 43 additions & 0 deletions integration-tests/tests/healthcheck_unhealthy.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Skip for coglet_alpha which does not support this
[coglet_alpha] skip
# Skip for coglet_rust which does not support this
[coglet_rust] skip
# Skip for cog-dataclass which does not support this
[cog_dataclass] skip

# Test unhealthy healthcheck behavior
# This tests when healthcheck returns False

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Unhealthy healthcheck should return UNHEALTHY status
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'user-defined healthcheck returned False'

-- cog.yaml --
build:
python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
def setup(self) -> None:
self._healthcheck_calls = 0

def predict(self, text: str) -> str:
return f"hello {text}"

def healthcheck(self) -> bool:
"""Unhealthy healthcheck after startup."""
self._healthcheck_calls += 1
if self._healthcheck_calls == 1:
return True
return False
7 changes: 7 additions & 0 deletions python/cog/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@ def train(self, **kwargs: Any) -> Any:
Run a single train on the model
"""
raise NotImplementedError("train has not been implemented by parent class.")

def healthcheck(self) -> Any:
"""
An optional method to perform custom health checks on the model.
Return True if healthy, False or raise an exception if unhealthy.
"""
return True # If unimplemented don't kill the container
7 changes: 7 additions & 0 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,13 @@ def get_train(predictor: Any) -> Callable[..., Any]:
return predictor


def get_healthcheck(predictor: Any) -> Optional[Callable[..., Any]]:
"""Get the healthcheck method if it exists."""
if hasattr(predictor, "healthcheck"):
return predictor.healthcheck
return None


def get_training_input_type(predictor: BasePredictor) -> Type[BaseInput]:
"""
Creates a Pydantic Input model from the arguments of a Predictor's train() method.
Expand Down
7 changes: 7 additions & 0 deletions python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class Shutdown:
pass


@define
class Healthcheck:
pass


# From predictor child process
#
@define
Expand Down Expand Up @@ -50,6 +55,7 @@ class Done:
canceled: bool = False
error: bool = False
error_detail: str = ""
event_type: str = "prediction"


@define
Expand All @@ -63,6 +69,7 @@ class Envelope:
Cancel,
PredictionInput,
Shutdown,
Healthcheck,
Log,
PredictionMetric,
PredictionOutput,
Expand Down
24 changes: 23 additions & 1 deletion python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Health(Enum):
BUSY = auto()
SETUP_FAILED = auto()
DEFUNCT = auto()
UNHEALTHY = auto()


class MyState:
Expand Down Expand Up @@ -358,10 +359,31 @@ async def root() -> Any:
async def healthcheck() -> Any:
if app.state.health == Health.READY:
health = Health.BUSY if runner.is_busy() else Health.READY

# Run custom healthcheck. If it doesn't exist, this will
# always return healthy (healthcheck_result.error = False)
healthcheck_result = await runner.healthcheck()
custom_health_ok = not healthcheck_result.error
custom_health_error = healthcheck_result.error_detail

if not custom_health_ok:
health = Health.UNHEALTHY
else:
health = app.state.health
custom_health_ok = True
custom_health_error = None

setup = app.state.setup_result.to_dict() if app.state.setup_result else {}
return jsonable_encoder({"status": health.name, "setup": setup})

response = {
"status": health.name,
"setup": setup,
}

if not custom_health_ok:
response["user_healthcheck_error"] = custom_health_error

return jsonable_encoder(response)

@limited
@app.post(
Expand Down
9 changes: 9 additions & 0 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ def cancel(self, prediction_id: str) -> None:
raise UnknownPredictionError("unknown prediction id")
self._worker.cancel(tag=prediction_id)

async def healthcheck(self) -> Done:
"""Run the user's healthcheck method."""
# Don't run healthcheck if the setup task is not completed
if self._setup_task is None or not self._setup_task.done():
return Done(event_type="healthcheck")

result = await asyncio.wrap_future(self._worker.healthcheck())
return result

def _raise_if_busy(self) -> None:
if self._setup_task is None:
# Setup hasn't been called yet.
Expand Down
Loading