Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/fastsqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from typing import Annotated, Generic, TypeVar, TypedDict
from typing import Annotated, Generic, TypedDict, TypeVar

from fastapi import Depends as BaseDepends
from fastapi import FastAPI, Query
Expand Down Expand Up @@ -90,7 +90,7 @@ class State(TypedDict):

def new_lifespan(
url: str | None = None, **kw
) -> Callable[[FastAPI], _AsyncGeneratorContextManager[State, None]]:
) -> Callable[[FastAPI | None], _AsyncGeneratorContextManager[State, None]]:
"""Create a new lifespan async context manager.

It expects the exact same parameters as
Expand All @@ -117,7 +117,7 @@ def new_lifespan(
has_config = url is not None

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[State, None]:
async def lifespan(app: FastAPI | None) -> AsyncGenerator[State, None]:
if has_config:
prefix = ""
sqla_config = {**kw, **{"url": url}}
Expand Down
22 changes: 14 additions & 8 deletions tests/unit/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from fastapi import FastAPI
from pytest import raises
from pytest import raises, fixture

app = FastAPI()
_app = FastAPI()


async def test_it_returns_state(environ):
@fixture(params=[_app, None])
def app(request):
# lifespan tests pass whether lifespan receives app or None
return request.param


async def test_it_returns_state(environ, app):
from fastsqla import lifespan

async with lifespan(app) as state:
assert "fastsqla_engine" in state


async def test_it_binds_an_sqla_engine_to_sessionmaker(environ):
async def test_it_binds_an_sqla_engine_to_sessionmaker(environ, app):
from fastsqla import SessionFactory, lifespan

assert SessionFactory.kw["bind"] is None
Expand All @@ -24,7 +30,7 @@ async def test_it_binds_an_sqla_engine_to_sessionmaker(environ):
assert SessionFactory.kw["bind"] is None


async def test_it_fails_on_a_missing_sqlalchemy_url(monkeypatch):
async def test_it_fails_on_a_missing_sqlalchemy_url(monkeypatch, app):
from fastsqla import lifespan

monkeypatch.delenv("SQLALCHEMY_URL", raising=False)
Expand All @@ -35,7 +41,7 @@ async def test_it_fails_on_a_missing_sqlalchemy_url(monkeypatch):
assert raise_info.value.args[0] == "Missing sqlalchemy_url in environ."


async def test_it_fails_on_not_async_engine(monkeypatch):
async def test_it_fails_on_not_async_engine(monkeypatch, app):
from fastsqla import lifespan

monkeypatch.setenv("SQLALCHEMY_URL", "sqlite:///:memory:")
Expand All @@ -46,7 +52,7 @@ async def test_it_fails_on_not_async_engine(monkeypatch):
assert "'pysqlite' is not async." in raise_info.value.args[0]


async def test_new_lifespan_with_connect_args(sqlalchemy_url):
async def test_new_lifespan_with_connect_args(sqlalchemy_url, app):
from fastsqla import new_lifespan

lifespan = new_lifespan(sqlalchemy_url, connect_args={"autocommit": False})
Expand All @@ -55,7 +61,7 @@ async def test_new_lifespan_with_connect_args(sqlalchemy_url):
pass


async def test_new_lifespan_fails_with_invalid_connect_args(sqlalchemy_url):
async def test_new_lifespan_fails_with_invalid_connect_args(sqlalchemy_url, app):
from fastsqla import new_lifespan

lifespan = new_lifespan(sqlalchemy_url, connect_args={"this is wrong": False})
Expand Down