|
17 | 17 | ERROR_SESSION, |
18 | 18 | RiverException, |
19 | 19 | ) |
| 20 | +from replit_river.lock import AcquiredLock, assert_correct_lock |
20 | 21 | from replit_river.messages import ( |
21 | 22 | PROTOCOL_VERSION, |
22 | 23 | FailedSendingMessageException, |
@@ -70,51 +71,66 @@ def __init__( |
70 | 71 |
|
71 | 72 | async def close(self) -> None: |
72 | 73 | self._rate_limiter.close() |
73 | | - await self._close_all_sessions() |
| 74 | + async with self._session_lock() as session_lock: |
| 75 | + await self._close_all_sessions(session_lock) |
74 | 76 |
|
| 77 | + # get a session, ensuring its in connected state |
75 | 78 | async def get_or_create_session(self) -> ClientSession: |
76 | | - async with self._create_session_lock: |
77 | | - existing_session = await self._get_existing_session() |
| 79 | + async with self._session_lock() as session_lock: |
| 80 | + existing_session = await self._get_existing_session(session_lock) |
78 | 81 | if not existing_session: |
79 | | - return await self._create_new_session() |
| 82 | + # create a new session |
| 83 | + return await self._create_new_session(session_lock) |
80 | 84 | is_session_open = await existing_session.is_session_open() |
81 | 85 | if not is_session_open: |
82 | | - return await self._create_new_session() |
83 | | - is_ws_open = await existing_session.is_websocket_open() |
84 | | - if is_ws_open: |
85 | | - return existing_session |
86 | | - new_ws, _, hs_response = await self._establish_new_connection( |
87 | | - existing_session |
88 | | - ) |
89 | | - if hs_response.status.sessionId == existing_session.session_id: |
90 | | - logger.info( |
91 | | - "Replacing ws connection in session id %s", |
92 | | - existing_session.session_id, |
93 | | - ) |
94 | | - await existing_session.replace_with_new_websocket(new_ws) |
95 | | - return existing_session |
96 | | - else: |
97 | | - logger.info("Closing stale session %s", existing_session.session_id) |
| 86 | + # replace the session with a new one |
98 | 87 | await existing_session.close() |
99 | | - return await self._create_new_session() |
| 88 | + return await self._create_new_session(session_lock) |
100 | 89 |
|
101 | | - async def _get_existing_session(self) -> Optional[ClientSession]: |
102 | | - async with self._session_lock: |
103 | | - if not self._sessions: |
104 | | - return None |
105 | | - if len(self._sessions) > 1: |
106 | | - raise RiverException( |
107 | | - "session_error", |
108 | | - "More than one session found in client, should only be one", |
109 | | - ) |
110 | | - session = list(self._sessions.values())[0] |
111 | | - if isinstance(session, ClientSession): |
112 | | - return session |
113 | | - else: |
114 | | - raise RiverException( |
115 | | - "session_error", f"Client session type wrong, got {type(session)}" |
| 90 | + # we have a websocket, check if its open |
| 91 | + is_ws_open = await existing_session.is_websocket_open() |
| 92 | + if not is_ws_open: |
| 93 | + # dial a new connection |
| 94 | + new_ws, _, hs_response = await self._establish_new_connection( |
| 95 | + existing_session |
116 | 96 | ) |
117 | 97 |
|
| 98 | + # session id is the same, it was a transparent reconnect |
| 99 | + if hs_response.status.sessionId == existing_session.session_id: |
| 100 | + logger.info( |
| 101 | + "Replacing ws connection in session id %s", |
| 102 | + existing_session.session_id, |
| 103 | + ) |
| 104 | + await existing_session.replace_with_new_websocket(new_ws) |
| 105 | + return existing_session |
| 106 | + else: |
| 107 | + # session id is different, we need to close the old session |
| 108 | + logger.info("Closing stale session %s", existing_session.session_id) |
| 109 | + await existing_session.close() |
| 110 | + return await self._create_new_session(session_lock) |
| 111 | + |
| 112 | + # happy path, we have a session and its open |
| 113 | + return existing_session |
| 114 | + |
| 115 | + async def _get_existing_session( |
| 116 | + self, session_lock: AcquiredLock |
| 117 | + ) -> Optional[ClientSession]: |
| 118 | + assert_correct_lock(session_lock, self._session_lock) |
| 119 | + if not self._sessions: |
| 120 | + return None |
| 121 | + if len(self._sessions) > 1: |
| 122 | + raise RiverException( |
| 123 | + "session_error", |
| 124 | + "More than one session found in client, should only be one", |
| 125 | + ) |
| 126 | + session = list(self._sessions.values())[0] |
| 127 | + if isinstance(session, ClientSession): |
| 128 | + return session |
| 129 | + else: |
| 130 | + raise RiverException( |
| 131 | + "session_error", f"Client session type wrong, got {type(session)}" |
| 132 | + ) |
| 133 | + |
118 | 134 | async def _establish_new_connection( |
119 | 135 | self, |
120 | 136 | old_session: Optional[ClientSession] = None, |
@@ -186,35 +202,39 @@ async def _establish_new_connection( |
186 | 202 |
|
187 | 203 | async def _create_new_session( |
188 | 204 | self, |
| 205 | + session_lock: AcquiredLock, |
189 | 206 | ) -> ClientSession: |
190 | | - logger.info("Creating new session") |
191 | | - new_ws, hs_request, hs_response = await self._establish_new_connection() |
192 | | - if not hs_response.status.ok: |
193 | | - message = hs_response.status.reason |
194 | | - raise RiverException( |
195 | | - ERROR_SESSION, |
196 | | - f"Server did not return OK status on handshake response: {message}", |
| 207 | + async with self._create_session_lock: |
| 208 | + logger.info("Creating new session") |
| 209 | + new_ws, hs_request, hs_response = await self._establish_new_connection() |
| 210 | + if not hs_response.status.ok: |
| 211 | + message = hs_response.status.reason |
| 212 | + raise RiverException( |
| 213 | + ERROR_SESSION, |
| 214 | + f"Server did not return OK status on handshake response: {message}", |
| 215 | + ) |
| 216 | + new_session = ClientSession( |
| 217 | + transport_id=self._transport_id, |
| 218 | + to_id=self._server_id, |
| 219 | + session_id=hs_request.sessionId, |
| 220 | + websocket=new_ws, |
| 221 | + transport_options=self._transport_options, |
| 222 | + is_server=False, |
| 223 | + close_session_callback=self.lock_and_delete_session, |
| 224 | + retry_connection_callback=self._retry_connection, |
| 225 | + handlers={}, |
197 | 226 | ) |
198 | | - new_session = ClientSession( |
199 | | - transport_id=self._transport_id, |
200 | | - to_id=self._server_id, |
201 | | - session_id=hs_request.sessionId, |
202 | | - websocket=new_ws, |
203 | | - transport_options=self._transport_options, |
204 | | - is_server=False, |
205 | | - close_session_callback=self._delete_session, |
206 | | - retry_connection_callback=self._retry_connection, |
207 | | - handlers={}, |
208 | | - ) |
209 | 227 |
|
210 | | - self._set_session(new_session) |
211 | | - await new_session.start_serve_responses() |
212 | | - return new_session |
| 228 | + self._set_session(session_lock, new_session) |
| 229 | + await new_session.start_serve_responses() |
| 230 | + return new_session |
213 | 231 |
|
214 | 232 | async def _retry_connection(self) -> ClientSession: |
215 | | - if not self._transport_options.transparent_reconnect: |
216 | | - await self._close_all_sessions() |
217 | | - return await self.get_or_create_session() |
| 233 | + if self._transport_options.transparent_reconnect: |
| 234 | + return await self.get_or_create_session() |
| 235 | + |
| 236 | + async with self._session_lock() as session_lock: |
| 237 | + return await self._create_new_session(session_lock) |
218 | 238 |
|
219 | 239 | async def _send_handshake_request( |
220 | 240 | self, |
@@ -344,7 +364,6 @@ async def _establish_handshake( |
344 | 364 | # If the session status is mismatched, we should close the old session |
345 | 365 | # and let the retry logic to create a new session. |
346 | 366 | await old_session.close() |
347 | | - await self._delete_session(old_session) |
348 | 367 |
|
349 | 368 | raise RiverException( |
350 | 369 | ERROR_HANDSHAKE, |
|
0 commit comments