Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from uuid import UUID

import gpuhunt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.backends.base.compute import (
Expand Down Expand Up @@ -90,6 +91,7 @@
BackendModel,
ComputeGroupModel,
DecryptedString,
EventModel,
FileArchiveModel,
FleetModel,
GatewayComputeModel,
Expand Down Expand Up @@ -1111,6 +1113,11 @@ async def create_secret(
return secret_model


async def list_events(session: AsyncSession) -> list[EventModel]:
res = await session.execute(select(EventModel).order_by(EventModel.recorded_at, EventModel.id))
return list(res.scalars().all())


def get_private_key_string() -> str:
return """
-----BEGIN RSA PRIVATE KEY-----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@

import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.server import settings
from dstack._internal.server.background.tasks.process_events import delete_events
from dstack._internal.server.models import EventModel
from dstack._internal.server.services import events
from dstack._internal.server.testing.common import create_user
from dstack._internal.server.testing.common import create_user, list_events


@pytest.mark.asyncio
Expand All @@ -27,8 +25,7 @@ async def test_deletes_old_events(test_db, session: AsyncSession) -> None:
)
await session.commit()

res = await session.execute(select(EventModel))
all_events = res.scalars().all()
all_events = await list_events(session)
assert len(all_events) == 10

with (
Expand All @@ -37,8 +34,7 @@ async def test_deletes_old_events(test_db, session: AsyncSession) -> None:
):
await delete_events()

res = await session.execute(select(EventModel).order_by(EventModel.recorded_at))
remaining_events = res.scalars().all()
remaining_events = await list_events(session)
assert len(remaining_events) == 5
assert [e.message for e in remaining_events] == [
"Event 5",
Expand Down
10 changes: 4 additions & 6 deletions src/tests/_internal/server/services/test_instances.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

import dstack._internal.server.services.instances as instances_services
Expand All @@ -15,13 +14,14 @@
Resources,
)
from dstack._internal.core.models.profiles import Profile
from dstack._internal.server.models import EventModel, InstanceModel
from dstack._internal.server.models import InstanceModel
from dstack._internal.server.testing.common import (
create_instance,
create_project,
create_user,
get_volume,
get_volume_configuration,
list_events,
)
from dstack._internal.utils.common import get_current_datetime

Expand All @@ -41,8 +41,7 @@ async def test_includes_termination_reason_in_event_messages_only_once(
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATING)
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)

res = await session.execute(select(EventModel))
events = res.scalars().all()
events = await list_events(session)
assert len(events) == 2
assert {e.message for e in events} == {
"Instance status changed PENDING -> TERMINATING. Termination reason: ERROR (Some err)",
Expand All @@ -63,8 +62,7 @@ async def test_includes_termination_reason_in_event_message_when_switching_direc
instance.termination_reason_message = "Some err"
instances_services.switch_instance_status(session, instance, InstanceStatus.TERMINATED)

res = await session.execute(select(EventModel))
events = res.scalars().all()
events = await list_events(session)
assert len(events) == 1
assert events[0].message == (
"Instance status changed PENDING -> TERMINATED. Termination reason: ERROR (Some err)"
Expand Down