Skip to content

Commit e117948

Browse files
uof with di for handlers
1 parent f765535 commit e117948

File tree

22 files changed

+489
-155
lines changed

22 files changed

+489
-155
lines changed

app/application/repository/interfaces.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@
66
from app.domain.entities.sample_record import SampleRecord
77

88

9+
class IUnitOfWork(ABC):
10+
"""Interface for unit of work operations."""
11+
12+
@abstractmethod
13+
async def __aenter__(self):
14+
return self
15+
16+
@abstractmethod
17+
async def __aexit__(self, exc_type, exc_val, exc_tb):
18+
pass
19+
20+
21+
class ISampleRecordUnitOfWork(IUnitOfWork, ABC):
22+
"""Interface for record unit of work operations."""
23+
24+
@abstractmethod
25+
def get_sample_record_repository(self) -> "ISampleRecordRepository":
26+
pass
27+
28+
929
class ISampleRecordRepository(ABC):
1030
"""Interface for record repository operations."""
1131

@@ -60,3 +80,8 @@ async def get_by_id(self, record_id: int) -> SampleRecord:
6080
async def get_all(self) -> List[SampleRecord]:
6181
"""Get all records from the database"""
6282
pass
83+
84+
@abstractmethod
85+
async def commit(self) -> None:
86+
"""Commit changes to the database"""
87+
pass

app/application/use_cases/record_use_cases.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ async def create_record(
2323
domain_object = SampleRecord(
2424
record_data=request_object.record_data, name=request_object.name
2525
)
26-
created_record = await self._repo.create(domain_object)
27-
return SampleRecordResponseSchema.from_orm(created_record)
26+
created_record = SampleRecordResponseSchema.from_orm(
27+
await self._repo.create(domain_object)
28+
)
29+
30+
return created_record
2831

2932
async def update_record(
3033
self, update_request: SampleRecordUpdateRequestSchema
@@ -33,10 +36,12 @@ async def update_record(
3336
domain_object = SampleRecord(
3437
record_data=update_request.record_data,
3538
id=update_request.id,
36-
name=update_request.name
39+
name=update_request.name,
40+
)
41+
updated_record = SampleRecordResponseSchema.from_orm(
42+
await self._repo.update(domain_object)
3743
)
38-
updated_record = await self._repo.update(domain_object)
39-
return SampleRecordResponseSchema.from_orm(updated_record)
44+
return updated_record
4045

4146
async def delete_record(self, record_id: int) -> None:
4247
"""Delete a record."""

app/infrastructure/containers.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from importlib import import_module
23

34
from dependency_injector import containers, providers
45
from dependency_injector.providers import Callable, Factory, Singleton
@@ -14,7 +15,10 @@
1415
PubsubExceptionHandler,
1516
)
1617
from app.infrastructure.repositories.caching.redis_repo import RedisRepo
17-
from app.infrastructure.repositories.sample_record import SampleRecordRepository
18+
from app.infrastructure.repositories.unit_of_work import (
19+
ReadOnlySampleRecordUnitOfWork,
20+
WriteSampleRecordUnitOfWork,
21+
)
1822
from app.logger import logger
1923

2024
from app.presentation.bot.error_handlers.internal_error_handler import (
@@ -27,32 +31,18 @@
2731

2832

2933
class BotSampleRecordCommandContainer(containers.DeclarativeContainer):
30-
record_use_cases_factory = Callable(
31-
lambda session: SampleRecordUseCases(
32-
record_repo=SampleRecordRepository(session=session)
33-
)
34-
)
34+
session_factory = providers.Dependency()
3535

36+
ro_unit_of_work: Factory[ReadOnlySampleRecordUnitOfWork] = Factory(
37+
ReadOnlySampleRecordUnitOfWork, session_factory
38+
)
39+
rw_unit_of_work: Factory[WriteSampleRecordUnitOfWork] = Factory(
40+
WriteSampleRecordUnitOfWork, session_factory
41+
)
3642

37-
# class StorageContainer(containers.DeclarativeContainer):
38-
# wiring_config = containers.WiringConfiguration(
39-
# modules=["app.presentation.bot.commands.sample_records"]
40-
# )
41-
#
42-
# # Provider that returns a factory to create sessions
43-
# session_factory = Factory(build_db_session_factory)
44-
#
45-
# # Provider that creates a session (e.g., AsyncSession instance)
46-
# session = Resource(session_factory)
47-
#
48-
# # Provider that creates the SampleRecordUseCases, injecting the session
49-
# record_use_cases = Factory(
50-
# SampleRecordUseCases,
51-
# record_repo=Factory(
52-
# SampleRecordRepository,
53-
# session=session
54-
# )
55-
# )
43+
record_use_cases_factory = Callable(
44+
lambda repository: SampleRecordUseCases(record_repo=repository)
45+
)
5646

5747

5848
class CallbackTaskManager:
@@ -117,11 +107,16 @@ class ApplicationStartupContainer(containers.DeclarativeContainer):
117107
{} if not settings.RAISE_BOT_EXCEPTIONS else {Exception: internal_error_handler}
118108
)
119109

120-
from app.presentation.bot.commands import common, sample_record
110+
# Ленивая загрузка коллекторов
111+
@staticmethod
112+
def get_collectors():
113+
common = import_module("app.presentation.bot.commands.common")
114+
sample_record = import_module("app.presentation.bot.commands.sample_record")
115+
return [common.collector, sample_record.collector]
121116

122117
bot = providers.Singleton(
123118
Bot,
124-
collectors=[common.collector, sample_record.collector],
119+
collectors=Callable(get_collectors),
125120
bot_accounts=settings.BOT_CREDENTIALS,
126121
exception_handlers=exception_handlers, # type: ignore
127122
default_callback_timeout=settings.BOTX_CALLBACK_TIMEOUT_IN_SECONDS,
@@ -146,7 +141,6 @@ class ApplicationStartupContainer(containers.DeclarativeContainer):
146141
)
147142

148143

149-
150144
class WorkerStartupContainer(containers.DeclarativeContainer):
151145
redis_client = Singleton(lambda: aioredis.from_url(settings.REDIS_DSN))
152146

app/infrastructure/db/sqlalchemy.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def make_url_sync(url: str) -> str:
4141

4242
Base = declarative_base(metadata=MetaData(naming_convention=convention))
4343

44-
4544
@lru_cache(maxsize=1)
4645
def get_engine() -> AsyncEngine:
4746
"""Lazily initialize and cache a single SQLAlchemy async engine."""
@@ -54,33 +53,31 @@ def get_engine() -> AsyncEngine:
5453
)
5554

5655

57-
session_factory = async_sessionmaker(bind=get_engine(), expire_on_commit=False)
58-
59-
60-
async def close_db_connections() -> None:
61-
await get_engine().dispose()
62-
63-
64-
def provide_session(func: Callable) -> Callable:
65-
"""
66-
Provides a database session to an async function if one is not already passed.
67-
68-
:param func: The asynchronous function to wrap. It must accept a `session`
69-
keyword argument.
70-
:return: The wrapped function with automatic session provisioning."""
71-
72-
@wraps(func)
73-
async def wrapper(*args: Any, **kwargs: Any) -> Any:
74-
if kwargs.get("session"):
75-
return await func(*args, **kwargs)
76-
77-
async with session_factory() as session:
78-
try:
79-
return await func(*args, **kwargs, session=session)
80-
except Exception:
81-
await session.rollback()
82-
raise
83-
finally:
84-
await session.close()
85-
86-
return wrapper
56+
def get_session_factory() -> async_sessionmaker:
57+
engine = get_engine()
58+
return async_sessionmaker(bind=engine, expire_on_commit=False)
59+
60+
#
61+
# def provide_session(func: Callable) -> Callable:
62+
# """
63+
# Provides a database session to an async function if one is not already passed.
64+
#
65+
# :param func: The asynchronous function to wrap. It must accept a `session`
66+
# keyword argument.
67+
# :return: The wrapped function with automatic session provisioning."""
68+
#
69+
# @wraps(func)
70+
# async def wrapper(*args: Any, **kwargs: Any) -> Any:
71+
# if kwargs.get("session"):
72+
# return await func(*args, **kwargs)
73+
#
74+
# async with session_factory() as session:
75+
# try:
76+
# return await func(*args, **kwargs, session=session)
77+
# except Exception:
78+
# await session.rollback()
79+
# raise
80+
# finally:
81+
# await session.close()
82+
#
83+
# return wrapper

app/infrastructure/repositories/sample_record.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
RecordUpdateError,
1414
RecordAlreadyExistsError,
1515
ForeignKeyError,
16-
ValidationError,
16+
ValidationError, BaseRepositoryError,
1717
)
1818
from app.application.repository.interfaces import ISampleRecordRepository
1919
from app.decorators.mapper.exception_mapper import (
@@ -58,6 +58,15 @@ def __init__(self, session: AsyncSession):
5858
"""
5959
self._session = session
6060

61+
@ExceptionMapper(
62+
{
63+
Exception: EnrichedExceptionFactory(BaseRepositoryError),
64+
},
65+
is_bound_method=True,
66+
)
67+
async def commit(self) -> None:
68+
await self._session.commit()
69+
6170
@ExceptionMapper(
6271
{
6372
IntegrityError: IntegrityErrorFactory(RecordCreateError),
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import asyncio
2+
3+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
4+
5+
from app.application.repository.interfaces import (
6+
ISampleRecordUnitOfWork,
7+
ISampleRecordRepository,
8+
)
9+
from app.infrastructure.repositories.sample_record import SampleRecordRepository
10+
11+
12+
class ReadOnlySampleRecordUnitOfWork(ISampleRecordUnitOfWork):
13+
def get_sample_record_repository(self) -> ISampleRecordRepository:
14+
if not self._session:
15+
raise RuntimeError("Session is not initialized")
16+
17+
return SampleRecordRepository(self._session)
18+
19+
def __init__(self, session_factory: async_sessionmaker):
20+
super().__init__()
21+
self.session_factory = session_factory
22+
self._session: AsyncSession | None = None
23+
24+
async def __aenter__(self):
25+
self._session = self.session_factory()
26+
return self
27+
28+
async def __aexit__(self, exc_type, exc_val, exc_tb):
29+
try:
30+
# Recommended for implicit resources cleanup
31+
await self._session.rollback()
32+
finally:
33+
await self._session.close()
34+
35+
36+
class WriteSampleRecordUnitOfWork(ReadOnlySampleRecordUnitOfWork):
37+
"""Unit of Work for write operations with full transaction management."""
38+
39+
async def __aenter__(self):
40+
self._session = self.session_factory()
41+
await asyncio.wait_for(self._session.begin(), timeout=5)
42+
return self
43+
44+
async def __aexit__(self, exc_type, exc_val, exc_tb):
45+
try:
46+
if exc_type:
47+
await self._session.rollback()
48+
else:
49+
await self._session.commit()
50+
finally:
51+
await self._session.close()

app/main.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,41 +12,48 @@
1212
BotSampleRecordCommandContainer,
1313
CallbackTaskManager,
1414
)
15-
from app.infrastructure.db.sqlalchemy import close_db_connections
15+
from app.infrastructure.db.sqlalchemy import (
16+
get_session_factory, get_engine,
17+
)
1618
from app.presentation.api.routers import router
1719
from app.presentation.bot.resources import strings
1820

1921

2022
async def startup(
21-
bot: Bot = Provide[ApplicationStartupContainer.bot],
23+
bot: Bot,
2224
) -> None:
2325
await bot.startup()
2426

2527

2628
async def shutdown(
27-
callback_task_manager: CallbackTaskManager = Provide[
28-
ApplicationStartupContainer.callback_task_manager
29-
],
30-
bot: Bot = Provide[ApplicationStartupContainer.bot],
31-
redis_client: Redis = Provide[ApplicationStartupContainer.redis_client],
29+
container: ApplicationStartupContainer = Provide[ApplicationStartupContainer],
3230
) -> None:
33-
await bot.shutdown()
31+
await container.bot().shutdown()
3432

35-
await callback_task_manager.shutdown()
33+
await container.callback_task_manager().shutdown()
3634

37-
await redis_client.aclose()
38-
await close_db_connections()
35+
await container.redis_client().aclose()
36+
await container.shutdown_resources()
37+
await get_engine().dispose()
3938

4039

4140
def get_application() -> FastAPI:
4241
"""Create configured server application instance."""
4342

4443
# Initialize the main application container
4544
main_container = ApplicationStartupContainer()
46-
main_container.wire(modules=["app.main", "app.presentation.api.botx"])
45+
main_container.wire(
46+
modules=[
47+
"app.main",
48+
"app.presentation.api.botx",
49+
"app.presentation.bot.commands.sample_record",
50+
]
51+
)
4752

4853
# Initialize the SampleRecord commands container
49-
sample_record_commands_container = BotSampleRecordCommandContainer()
54+
sample_record_commands_container = BotSampleRecordCommandContainer(
55+
session_factory=get_session_factory()
56+
)
5057
sample_record_commands_container.wire(
5158
modules=["app.presentation.bot.commands.sample_record"]
5259
)
@@ -64,9 +71,10 @@ def get_application() -> FastAPI:
6471
"shutdown",
6572
partial(
6673
shutdown,
67-
callback_task_manager=main_container.callback_task_manager(),
68-
bot=main_container.bot(),
69-
redis_client=main_container.redis_client(),
74+
# callback_task_manager=main_container.callback_task_manager(),
75+
# bot=main_container.bot(),
76+
# redis_client=main_container.redis_client(),
77+
container=main_container,
7078
),
7179
)
7280

0 commit comments

Comments
 (0)