Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/replit_river/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import random
from contextvars import Context
from typing import Protocol

from replit_river.error_schema import RiverException
from replit_river.transport_options import ConnectionRetryOptions
Expand All @@ -15,6 +16,13 @@ def __init__(self, code: str, message: str, client_id: str) -> None:
self.client_id = client_id


class RateLimiter(Protocol):
def start_restoring_budget(self, user: str) -> None: ...
def get_backoff_ms(self, user: str) -> float: ...
def has_budget(self, user: str) -> bool: ...
def consume_budget(self, user: str) -> None: ...


class LeakyBucketRateLimit:
"""Asynchronous leaky bucket rate limiter.

Expand Down
6 changes: 4 additions & 2 deletions src/replit_river/v2/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ async def get_or_create_session(self) -> Session:
call ensure_connected on whatever session is active.
"""
existing_session = self._session
if not existing_session or existing_session.is_closed():
if not existing_session or existing_session.is_terminal():
logger.info("Creating new session")
if existing_session:
await existing_session.close()
new_session = Session(
client_id=self._client_id,
server_id=self._server_id,
Expand All @@ -80,7 +82,7 @@ async def _retry_connection(self) -> Session:
logger.debug("Triggering get_or_create_session")
return await self.get_or_create_session()

async def _delete_session(self, session: Session) -> None:
def _delete_session(self, session: Session) -> None:
if self._session is session:
self._session = None
else:
Expand Down
Loading
Loading