Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions backend/app/core/startup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
from typing import Any

from fastapi import HTTPException
from sqlalchemy.orm import Session

from app.modules.admin.seed.service import seed_all
from app.modules.auth.schemas import UserCreate
from app.modules.auth.service import create_user_if_not_exists
from app.settings import Settings

logger = logging.getLogger(__name__)


def setup_test_users(db: Session, users: list[dict[str, Any]], default_password: str):
"""Create initial test users if they don't exist. Easily extendable."""
for user_info in users:
# Use default password if not provided in user_info
data = user_info.copy()
if "password" not in data:
data["password"] = default_password

logger.info(f"Ensuring test user exists: {data['email']}")
create_user_if_not_exists(db, UserCreate(**data))


def auto_seed_data(db: Session):
"""Seed the database with initial data if it's empty."""
try:
seed_all(db, n_tags=7, n_fields=12, n_events=30)
logger.info("Auto-seeding completed successfully.")
except HTTPException as e:
if e.status_code == 405:
logger.info("Database already contains data. Skipping auto-seeding.")
else:
logger.exception(f"Auto-seeding failed with unexpected error: {e.detail}")
except Exception:
logger.exception("Auto-seeding failed")


def run_startup_tasks(db: Session, settings: Settings):
"""Run all necessary startup tasks for development environment."""
if settings.is_dev:
setup_test_users(db, settings.dev_users, settings.dev_users_password)
auto_seed_data(db)
elif settings.is_demo:
create_user_if_not_exists(
db,
UserCreate(
email=settings.demo_user_email, password=settings.demo_user_password
),
)
10 changes: 3 additions & 7 deletions backend/app/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

from app.api.v1.routes import admin, auth, events, fields, generic, tags
from app.core.handlers import http_exception_handler, validation_exception_handler
from app.modules.auth.schemas import UserCreate
from app.modules.auth.service import create_user_if_not_exists
from app.core.startup import run_startup_tasks
from app.settings import Settings

logger = logging.getLogger(__name__)
Expand All @@ -24,11 +23,8 @@ def create_app(
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info(f"Starting application in {settings.env} mode")
if settings.is_demo:
with SessionLocal() as db:
create_user_if_not_exists(
db, UserCreate(email="demo@evsy.dev", password="bestructured")
)
with SessionLocal() as db:
run_startup_tasks(db, settings)
yield
logger.info("Shutting down application")
engine.dispose()
Expand Down
12 changes: 9 additions & 3 deletions backend/app/modules/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
settings = get_settings()
to_encode = data.copy()
expire = datetime.now(UTC) + (
expires_delta or timedelta(minutes=settings.access_token_expire_minutes)
)
if expires_delta:
expire = datetime.now(UTC) + expires_delta
elif settings.is_dev:
# 100 years for dev mode
expire = datetime.now(UTC) + timedelta(days=365 * 100)
else:
expire = datetime.now(UTC) + timedelta(
minutes=settings.access_token_expire_minutes
)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, settings.secret_key, algorithm=settings.jwt_algorithm)

Expand Down
9 changes: 9 additions & 0 deletions backend/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def __init__(self, _env_file: Optional[str] = None, **kwargs: Any):
default=None, alias="GOOGLE_CLIENT_SECRET"
)

dev_users: list[dict[str, Any]] = Field(
default=[{"email": "user@example.com"}],
alias="DEV_USERS",
)
dev_users_password: str = Field(default="12345678", alias="DEV_USERS_PASSWORD")

demo_user_email: str = Field(default="demo@evsy.dev", alias="DEMO_USER_EMAIL")
demo_user_password: str = Field(default="bestructured", alias="DEMO_USER_PASSWORD")

model_config = SettingsConfigDict(
env_file_encoding="utf-8",
case_sensitive=False,
Expand Down
98 changes: 98 additions & 0 deletions backend/tests/test_startup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch

from fastapi import HTTPException
from jose import jwt

from app.core.startup import auto_seed_data, run_startup_tasks, setup_test_users
from app.modules.auth.models import User
from app.modules.auth.token import create_access_token


def test_create_access_token_dev_long_expiry():
"""Test that tokens in dev mode have a very long expiry."""
mock_settings = MagicMock()
mock_settings.is_dev = True
mock_settings.secret_key = "test_secret"
mock_settings.jwt_algorithm = "HS256"

with patch("app.modules.auth.token.get_settings", return_value=mock_settings):
token = create_access_token({"sub": "user@example.com"})
payload = jwt.decode(token, "test_secret", algorithms=["HS256"])

exp = payload["exp"]
expected_min_exp = (datetime.now(UTC) + timedelta(days=365 * 99)).timestamp()
assert exp > expected_min_exp


def test_create_access_token_prod_normal_expiry():
"""Test that tokens in prod mode have normal expiry."""
mock_settings = MagicMock()
mock_settings.is_dev = False
mock_settings.access_token_expire_minutes = 60
mock_settings.secret_key = "test_secret"
mock_settings.jwt_algorithm = "HS256"

with patch("app.modules.auth.token.get_settings", return_value=mock_settings):
token = create_access_token({"sub": "user@example.com"})
payload = jwt.decode(token, "test_secret", algorithms=["HS256"])

exp = payload["exp"]
# Should be roughly 60 minutes from now
expected_exp = (datetime.now(UTC) + timedelta(minutes=60)).timestamp()
assert abs(exp - expected_exp) < 10 # Allow 10s difference


def test_setup_test_users(db, test_settings):
"""Test that test users are created if they don't exist."""
# Ensure user doesn't exist in the current transaction
primary_dev_user = test_settings.dev_users[0]
user = db.query(User).filter(User.email == primary_dev_user["email"]).first()
if user:
db.delete(user)
db.flush()

setup_test_users(db, test_settings.dev_users, test_settings.dev_users_password)

user = db.query(User).filter(User.email == primary_dev_user["email"]).first()
assert user is not None
assert user.email == primary_dev_user["email"]


@patch("app.core.startup.seed_all")
def test_auto_seed_data_empty_db(mock_seed_all, db):
"""Test that seeding is called when DB is empty."""
mock_seed_all.return_value = None
auto_seed_data(db)
mock_seed_all.assert_called_once()


@patch("app.core.startup.seed_all")
def test_auto_seed_data_already_seeded(mock_seed_all, db):
"""Test that seeding is skipped if DB already has data (simulated by HTTPException 405)."""
mock_seed_all.side_effect = HTTPException(
status_code=405, detail="Action is only allowed on empty database"
)

# This should not raise an exception, just log and return
auto_seed_data(db)
mock_seed_all.assert_called_once()


def test_run_startup_tasks_dev_calls_subtasks(db):
"""Test that all dev startup tasks are triggered in dev mode."""
mock_settings = MagicMock()
mock_settings.is_dev = True
mock_settings.is_demo = False
mock_settings.dev_users = [{"email": "user@example.com"}]
mock_settings.dev_users_password = "password"

with (
patch("app.core.startup.setup_test_users") as mock_setup_users,
patch("app.core.startup.auto_seed_data") as mock_auto_seed,
):
run_startup_tasks(db, mock_settings)
mock_setup_users.assert_called_once_with(
db, mock_settings.dev_users, mock_settings.dev_users_password
)
mock_auto_seed.assert_called_once_with(db)
Loading