Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.fleets import InstanceGroupPlacement
from dstack._internal.core.models.instances import (
HealthStatus,
InstanceAvailability,
InstanceOfferWithAvailability,
InstanceRuntime,
Expand Down Expand Up @@ -75,6 +76,7 @@
InstanceHealthResponse,
)
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services import events
from dstack._internal.server.services.fleets import (
fleet_model_to_fleet,
get_create_instance_offers,
Expand Down Expand Up @@ -759,8 +761,8 @@ async def _check_instance(session: AsyncSession, instance: InstanceModel) -> Non
)
session.add(health_check_model)

instance.health = health_status
instance.unreachable = not instance_check.reachable
_set_health(session, instance, health_status)
_set_unreachable(session, instance, unreachable=not instance_check.reachable)

if instance_check.reachable:
instance.termination_deadline = None
Expand Down Expand Up @@ -1093,6 +1095,31 @@ async def _terminate(session: AsyncSession, instance: InstanceModel) -> None:
switch_instance_status(session, instance, InstanceStatus.TERMINATED)


def _set_health(session: AsyncSession, instance: InstanceModel, health: HealthStatus) -> None:
if instance.health != health:
events.emit(
session,
f"Instance health changed {instance.health.upper()} -> {health.upper()}",
actor=events.SystemActor(),
targets=[events.Target.from_model(instance)],
)
instance.health = health


def _set_unreachable(session: AsyncSession, instance: InstanceModel, unreachable: bool) -> None:
if (
instance.status.is_available() # avoid misleading event during provisioning
and instance.unreachable != unreachable
):
events.emit(
session,
"Instance became unreachable" if unreachable else "Instance became reachable",
actor=events.SystemActor(),
targets=[events.Target.from_model(instance)],
)
instance.unreachable = unreachable


def _next_termination_retry_at(instance: InstanceModel) -> datetime.datetime:
assert instance.last_termination_retry_at is not None
return instance.last_termination_retry_at + TERMINATION_RETRY_TIMEOUT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@
UserModel,
)
from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus
from dstack._internal.server.services import events, services
from dstack._internal.server.services import files as files_services
from dstack._internal.server.services import logs as logs_services
from dstack._internal.server.services import services
from dstack._internal.server.services.instances import get_instance_ssh_private_keys
from dstack._internal.server.services.jobs import (
find_job,
Expand Down Expand Up @@ -355,7 +355,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
)

if success:
job_model.disconnected_at = None
_reset_disconnected_at(session, job_model)
else:
if job_model.termination_reason:
logger.warning(
Expand All @@ -368,8 +368,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
# job will be terminated and instance will be emptied by process_terminating_jobs
else:
# No job_model.termination_reason set means ssh connection failed
if job_model.disconnected_at is None:
job_model.disconnected_at = common_utils.get_current_datetime()
_set_disconnected_at_now(session, job_model)
if _should_terminate_job_due_to_disconnect(job_model):
# TODO: Replace with JobTerminationReason.INSTANCE_UNREACHABLE for on-demand.
job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY
Expand Down Expand Up @@ -933,6 +932,28 @@ def _should_terminate_due_to_low_gpu_util(min_util: int, gpus_util: Iterable[Ite
return False


def _set_disconnected_at_now(session: AsyncSession, job_model: JobModel) -> None:
if job_model.disconnected_at is None:
job_model.disconnected_at = common_utils.get_current_datetime()
events.emit(
session,
"Job became unreachable",
actor=events.SystemActor(),
targets=[events.Target.from_model(job_model)],
)


def _reset_disconnected_at(session: AsyncSession, job_model: JobModel) -> None:
if job_model.disconnected_at is not None:
job_model.disconnected_at = None
events.emit(
session,
"Job became reachable",
actor=events.SystemActor(),
targets=[events.Target.from_model(job_model)],
)


def _get_cluster_info(
jobs: List[Job],
replica_num: int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
get_job_provisioning_data,
get_placement_group_provisioning_data,
get_remote_connection_info,
list_events,
)
from dstack._internal.utils.common import get_current_datetime

Expand Down Expand Up @@ -324,10 +325,13 @@ async def test_check_shim_process_ureachable_state(
healthcheck.assert_called()

await session.refresh(instance)
events = await list_events(session)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert not instance.unreachable
assert len(events) == 1
assert events[0].message == "Instance became reachable"

@pytest.mark.asyncio
@pytest.mark.parametrize("health_status", [HealthStatus.HEALTHY, HealthStatus.FAILURE])
Expand All @@ -351,12 +355,15 @@ async def test_check_shim_switch_to_unreachable_state(
await process_instances()

await session.refresh(instance)
events = await list_events(session)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert instance.unreachable
# Should keep the previous status
assert instance.health == health_status
assert len(events) == 1
assert events[0].message == "Instance became unreachable"

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down Expand Up @@ -384,11 +391,14 @@ async def test_check_shim_check_instance_health(self, test_db, session: AsyncSes
await process_instances()

await session.refresh(instance)
events = await list_events(session)

assert instance is not None
assert instance.status == InstanceStatus.IDLE
assert not instance.unreachable
assert instance.health == HealthStatus.WARNING
assert len(events) == 1
assert events[0].message == "Instance health changed HEALTHY -> WARNING"

res = await session.execute(select(InstanceHealthCheckModel))
health_check = res.scalars().one()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
get_job_runtime_data,
get_run_spec,
get_volume_configuration,
list_events,
)
from dstack._internal.utils.common import get_current_datetime

Expand Down Expand Up @@ -515,9 +516,12 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession):
await process_running_jobs()
assert SSHTunnelMock.call_count == 3
await session.refresh(job)
events = await list_events(session)
assert job is not None
assert job.disconnected_at is not None
assert job.status == JobStatus.PULLING
assert len(events) == 1
assert events[0].message == "Job became unreachable"
with (
patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock,
patch("dstack._internal.server.services.runner.ssh.time.sleep"),
Expand Down