Skip to content

TortoiseContext context manager breaks when used by parallel tasks #2186

@i-am-grub

Description

@i-am-grub

Describe the bug
Parallel tasks that enter the same TortoiseContext as a context manager overwrite its shared `_token` attribute. As a result, when a task whose token was overwritten exits the context manager, `_current_context` can no longer be reset correctly.

To Reproduce

import asyncio

from tortoise import Tortoise, fields, run_async
from tortoise.context import TortoiseContext
from tortoise.models import Model


class Event(Model):
    """Example model used by the reproduction; stored in the ``event`` table."""

    # Integer primary key (the attribute name shadows the ``id`` builtin only
    # inside the class namespace).
    id = fields.IntField(primary_key=True)
    name = fields.TextField(
        description="Name of the event that corresponds to an action"
    )
    # Nullable, so rows can be created with only a name, as the script does.
    datetime = fields.DatetimeField(
        null=True, description="Datetime of when the event was generated"
    )

    class Meta:
        table = "event"
        table_description = "This table contains a list of all the example events"

    def __str__(self):
        # Human-readable representation: just the event name.
        return self.name


async def database_work(context: TortoiseContext):
    """Enter *context* and run a sequence of ORM operations on ``Event``.

    ``run()`` schedules two copies of this coroutine concurrently against the
    same ``TortoiseContext`` instance, which is the scenario the issue reports.
    """
    async with context:
        # Insert a row, then rename it through a filtered bulk update.
        new_event = await Event.create(name="Test")
        await Event.filter(id=new_event.id).update(name="Updated name")

        # Read the row back using its new name.
        await Event.filter(name="Updated name").first()

        # Insert a second row via the instance ``save`` path.
        second_event = Event(name="Test 2")
        await second_event.save()

        # Exercise both the flat-list and dict projections.
        await Event.all().values_list("id", flat=True)
        await Event.all().values("id", "name")


async def run():
    """Initialise an in-memory SQLite database and race two workers on one context.

    Both tasks receive the same ``TortoiseContext`` instance returned by
    ``Tortoise.init``.
    """
    shared_context = await Tortoise.init(
        db_url="sqlite://:memory:", modules={"models": ["__main__"]}
    )
    await Tortoise.generate_schemas()

    # Create two tasks running in parallel, both utilizing the same context.
    async with asyncio.TaskGroup() as group:
        for _ in range(2):
            group.create_task(database_work(shared_context))


if __name__ == "__main__":
    # Entry point: execute the reproduction through tortoise's ``run_async`` helper.
    run_async(run())

Expected behavior
Parallel tasks should be able to enter and exit the same TortoiseContext at the same time, with each task cleanly setting and resetting the current context.

Additional context
Starlette now requires the use of lifespan state. When the database is initialized during the lifespan, its context must be propagated manually into each request. Because requests are handled as parallel tasks, this propagation runs into the issue described above.

import contextlib
from typing import TypedDict

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.routing import Route
from tortoise import Tortoise, fields, run_async
from tortoise.context import TortoiseContext, _current_context
from tortoise.models import Model


class Event(Model):
    """Example model for the Starlette reproduction; stored in the ``event`` table.

    NOTE(review): the lifespan registers models from ``myapp.models``;
    presumably this class is meant to live in that module — confirm.
    """

    # Integer primary key.
    id = fields.IntField(primary_key=True)
    name = fields.TextField(
        description="Name of the event that corresponds to an action"
    )
    # Nullable, so rows can be created with only a name.
    datetime = fields.DatetimeField(
        null=True, description="Datetime of when the event was generated"
    )

    class Meta:
        table = "event"
        table_description = "This table contains a list of all the example events"

    def __str__(self):
        # Human-readable representation: just the event name.
        return self.name


class ContextState(TypedDict):
    """Lifespan state yielded by ``lifespan`` and read back from ``scope["state"]``."""

    # The TortoiseContext created during the lifespan, propagated into each request.
    database_ctx: TortoiseContext


class ContextMiddleware:
    """
    Middleware for propagating the database context into http and websocket requests.

    Two propagation strategies are available, selected by ``use_context_manager``:

    * ``True``: enter ``state["database_ctx"]`` as a context manager around the
      request. This is the variant reported to break under concurrent requests,
      because the shared context instance keeps a single ``_token``.
    * ``False`` (default): set ``_current_context`` directly and reset it with the
      token returned for this request, so concurrent requests never share a token.

    Exactly one strategy runs per request, so the wrapped app is invoked once.
    """

    def __init__(self, app, *, use_context_manager: bool = False) -> None:
        self.app = app
        self.use_context_manager = use_context_manager

    async def __call__(self, scope, receive, send) -> None:
        # Non-request traffic (e.g. lifespan events) passes straight through.
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        state: ContextState = scope["state"]
        database_ctx = state["database_ctx"]

        if self.use_context_manager:
            # Breaking version: the context manager stores its reset token on the
            # shared context instance, so overlapping requests overwrite it.
            # NOTE(review): the first reproduction uses ``async with``; confirm
            # TortoiseContext also supports the synchronous protocol used here.
            with database_ctx:
                await self.app(scope, receive, send)
            return

        # Working alternative: the token lives in this call's local scope, so
        # every request resets exactly the value it set.
        token = _current_context.set(database_ctx)
        try:
            await self.app(scope, receive, send)
        finally:
            _current_context.reset(token)


async def create_event(request: Request):
    """Create an ``Event`` row and respond with its id.

    Starlette awaits whatever an endpoint returns as an ASGI response, so the
    handler must return a ``Response`` object; returning ``None`` would make the
    request fail after the database work completes.
    """
    # Local import keeps this handler self-contained; the script already
    # depends on Starlette.
    from starlette.responses import JSONResponse

    event = await Event.create(name="Test")
    return JSONResponse({"id": event.id}, status_code=201)


@contextlib.asynccontextmanager
async def lifespan(_app: Starlette):
    """Initialise the database for the app's lifetime and expose it as lifespan state.

    The ``TortoiseContext`` entered here is yielded inside ``ContextState`` so that
    per-request middleware can propagate it into each request.
    """
    async with TortoiseContext() as ctx:
        await ctx.init(
            db_url="sqlite://:memory:",
            # NOTE(review): confirm the models actually live in ``myapp.models``.
            modules={"models": ["myapp.models"]},
        )
        await ctx.generate_schemas()

        # Hand over the context bound by ``async with`` above.
        state = ContextState(database_ctx=ctx)

        yield state


def main() -> Starlette:
    """Build the Starlette application wired with the database-context middleware.

    Returns:
        The configured ASGI app, ready to be served by an ASGI server.
    """
    app = Starlette(
        routes=[Route("/test", create_event)],
        middleware=[Middleware(ContextMiddleware)],
        lifespan=lifespan,
    )

    # Return the app so the caller can actually run it instead of discarding it.
    return app


if __name__ == "__main__":
    # ``main`` is a plain function, not a coroutine, so passing its return value
    # to ``run_async`` cannot work; call it directly. Serving the resulting app
    # requires an ASGI server, which this script leaves to the reader.
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions