Skip to content

Commit 9afaf2d

Browse files
committed
fix deadlock
1 parent 26ba67e commit 9afaf2d

File tree

7 files changed

+281
-138
lines changed

7 files changed

+281
-138
lines changed

src/replit_river/client_transport.py

Lines changed: 80 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ERROR_SESSION,
1818
RiverException,
1919
)
20+
from replit_river.lock import AcquiredLock, assert_correct_lock
2021
from replit_river.messages import (
2122
PROTOCOL_VERSION,
2223
FailedSendingMessageException,
@@ -70,51 +71,66 @@ def __init__(
7071

7172
async def close(self) -> None:
7273
self._rate_limiter.close()
73-
await self._close_all_sessions()
74+
async with self._session_lock() as session_lock:
75+
await self._close_all_sessions(session_lock)
7476

77+
# get a session, ensuring its in connected state
7578
async def get_or_create_session(self) -> ClientSession:
76-
async with self._create_session_lock:
77-
existing_session = await self._get_existing_session()
79+
async with self._session_lock() as session_lock:
80+
existing_session = await self._get_existing_session(session_lock)
7881
if not existing_session:
79-
return await self._create_new_session()
82+
# create a new session
83+
return await self._create_new_session(session_lock)
8084
is_session_open = await existing_session.is_session_open()
8185
if not is_session_open:
82-
return await self._create_new_session()
83-
is_ws_open = await existing_session.is_websocket_open()
84-
if is_ws_open:
85-
return existing_session
86-
new_ws, _, hs_response = await self._establish_new_connection(
87-
existing_session
88-
)
89-
if hs_response.status.sessionId == existing_session.session_id:
90-
logger.info(
91-
"Replacing ws connection in session id %s",
92-
existing_session.session_id,
93-
)
94-
await existing_session.replace_with_new_websocket(new_ws)
95-
return existing_session
96-
else:
97-
logger.info("Closing stale session %s", existing_session.session_id)
86+
# replace the session with a new one
9887
await existing_session.close()
99-
return await self._create_new_session()
88+
return await self._create_new_session(session_lock)
10089

101-
async def _get_existing_session(self) -> Optional[ClientSession]:
102-
async with self._session_lock:
103-
if not self._sessions:
104-
return None
105-
if len(self._sessions) > 1:
106-
raise RiverException(
107-
"session_error",
108-
"More than one session found in client, should only be one",
109-
)
110-
session = list(self._sessions.values())[0]
111-
if isinstance(session, ClientSession):
112-
return session
113-
else:
114-
raise RiverException(
115-
"session_error", f"Client session type wrong, got {type(session)}"
90+
# we have a websocket, check if its open
91+
is_ws_open = await existing_session.is_websocket_open()
92+
if not is_ws_open:
93+
# dial a new connection
94+
new_ws, _, hs_response = await self._establish_new_connection(
95+
existing_session
11696
)
11797

98+
# session id is the same, it was a transparent reconnect
99+
if hs_response.status.sessionId == existing_session.session_id:
100+
logger.info(
101+
"Replacing ws connection in session id %s",
102+
existing_session.session_id,
103+
)
104+
await existing_session.replace_with_new_websocket(new_ws)
105+
return existing_session
106+
else:
107+
# session id is different, we need to close the old session
108+
logger.info("Closing stale session %s", existing_session.session_id)
109+
await existing_session.close()
110+
return await self._create_new_session(session_lock)
111+
112+
# happy path, we have a session and its open
113+
return existing_session
114+
115+
async def _get_existing_session(
116+
self, session_lock: AcquiredLock
117+
) -> Optional[ClientSession]:
118+
assert_correct_lock(session_lock, self._session_lock)
119+
if not self._sessions:
120+
return None
121+
if len(self._sessions) > 1:
122+
raise RiverException(
123+
"session_error",
124+
"More than one session found in client, should only be one",
125+
)
126+
session = list(self._sessions.values())[0]
127+
if isinstance(session, ClientSession):
128+
return session
129+
else:
130+
raise RiverException(
131+
"session_error", f"Client session type wrong, got {type(session)}"
132+
)
133+
118134
async def _establish_new_connection(
119135
self,
120136
old_session: Optional[ClientSession] = None,
@@ -186,35 +202,39 @@ async def _establish_new_connection(
186202

187203
async def _create_new_session(
188204
self,
205+
session_lock: AcquiredLock,
189206
) -> ClientSession:
190-
logger.info("Creating new session")
191-
new_ws, hs_request, hs_response = await self._establish_new_connection()
192-
if not hs_response.status.ok:
193-
message = hs_response.status.reason
194-
raise RiverException(
195-
ERROR_SESSION,
196-
f"Server did not return OK status on handshake response: {message}",
207+
async with self._create_session_lock:
208+
logger.info("Creating new session")
209+
new_ws, hs_request, hs_response = await self._establish_new_connection()
210+
if not hs_response.status.ok:
211+
message = hs_response.status.reason
212+
raise RiverException(
213+
ERROR_SESSION,
214+
f"Server did not return OK status on handshake response: {message}",
215+
)
216+
new_session = ClientSession(
217+
transport_id=self._transport_id,
218+
to_id=self._server_id,
219+
session_id=hs_request.sessionId,
220+
websocket=new_ws,
221+
transport_options=self._transport_options,
222+
is_server=False,
223+
close_session_callback=self.lock_and_delete_session,
224+
retry_connection_callback=self._retry_connection,
225+
handlers={},
197226
)
198-
new_session = ClientSession(
199-
transport_id=self._transport_id,
200-
to_id=self._server_id,
201-
session_id=hs_request.sessionId,
202-
websocket=new_ws,
203-
transport_options=self._transport_options,
204-
is_server=False,
205-
close_session_callback=self._delete_session,
206-
retry_connection_callback=self._retry_connection,
207-
handlers={},
208-
)
209227

210-
self._set_session(new_session)
211-
await new_session.start_serve_responses()
212-
return new_session
228+
self._set_session(session_lock, new_session)
229+
await new_session.start_serve_responses()
230+
return new_session
213231

214232
async def _retry_connection(self) -> ClientSession:
215-
if not self._transport_options.transparent_reconnect:
216-
await self._close_all_sessions()
217-
return await self.get_or_create_session()
233+
if self._transport_options.transparent_reconnect:
234+
return await self.get_or_create_session()
235+
236+
async with self._session_lock() as session_lock:
237+
return await self._create_new_session(session_lock)
218238

219239
async def _send_handshake_request(
220240
self,
@@ -344,7 +364,6 @@ async def _establish_handshake(
344364
# If the session status is mismatched, we should close the old session
345365
# and let the retry logic to create a new session.
346366
await old_session.close()
347-
await self._delete_session(old_session)
348367

349368
raise RiverException(
350369
ERROR_HANDSHAKE,

src/replit_river/lock.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import asyncio
2+
from contextlib import asynccontextmanager
3+
from typing import AsyncGenerator, ParamSpec, TypeVar
4+
5+
P = ParamSpec("P")
6+
R = TypeVar("R")
7+
8+
9+
class AcquiredLock:
10+
def __init__(self, lock: asyncio.Lock):
11+
self._lock = lock
12+
13+
14+
class TransferableLock:
15+
"""A lock that can be transferred between coroutines."""
16+
17+
def __init__(self) -> None:
18+
self._lock = asyncio.Lock()
19+
20+
class BoundLock(AcquiredLock):
21+
pass
22+
23+
self.BoundLock = BoundLock
24+
25+
@asynccontextmanager
26+
async def __call__(self) -> AsyncGenerator[AcquiredLock]:
27+
await self._lock.acquire()
28+
yield self.BoundLock(self._lock)
29+
self._lock.release()
30+
31+
32+
def assert_correct_lock(
33+
acquired_lock: AcquiredLock, expected_lock_class: TransferableLock
34+
) -> None:
35+
"""Assert that the acquired lock is the correct class."""
36+
if not isinstance(acquired_lock, expected_lock_class.BoundLock):
37+
raise ValueError(
38+
f"Expected {expected_lock_class.BoundLock.__name__}, "
39+
f"got {type(acquired_lock).__name__}"
40+
)

src/replit_river/seq_manager.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,4 @@ async def check_seq_and_update(self, msg: TransportMessage) -> None:
8787
f"{self.ack}"
8888
)
8989
self.receiver_ack = msg.ack
90-
await self._set_ack(msg.seq + 1)
91-
92-
async def _set_ack(self, new_ack: int) -> int:
93-
async with self._ack_lock:
94-
self.ack = new_ack
95-
return self.ack
90+
self.ack = msg.seq + 1

src/replit_river/server_transport.py

Lines changed: 54 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ async def handshake_to_get_session(
8888
raise WebsocketClosedException("No handshake message received")
8989

9090
async def close(self) -> None:
91-
await self._close_all_sessions()
91+
async with self._session_lock() as session_lock:
92+
await self._close_all_sessions(session_lock)
9293

9394
async def _get_or_create_session(
9495
self,
@@ -97,7 +98,7 @@ async def _get_or_create_session(
9798
session_id: str,
9899
websocket: WebSocketCommonProtocol,
99100
) -> Session:
100-
async with self._session_lock:
101+
async with self._session_lock() as session_lock:
101102
session_to_close: Optional[Session] = None
102103
new_session: Optional[Session] = None
103104
if to_id not in self._sessions:
@@ -112,7 +113,7 @@ async def _get_or_create_session(
112113
self._transport_options,
113114
self._is_server,
114115
self._handlers,
115-
close_session_callback=self._delete_session,
116+
close_session_callback=self.lock_and_delete_session,
116117
)
117118
else:
118119
old_session = self._sessions[to_id]
@@ -133,7 +134,7 @@ async def _get_or_create_session(
133134
self._transport_options,
134135
self._is_server,
135136
self._handlers,
136-
close_session_callback=self._delete_session,
137+
close_session_callback=self.lock_and_delete_session,
137138
)
138139
else:
139140
# If the instance id is the same, we reuse the session and assign
@@ -152,7 +153,7 @@ async def _get_or_create_session(
152153
if session_to_close:
153154
logger.info("Closing stale session %s", session_to_close.session_id)
154155
await session_to_close.close()
155-
self._set_session(new_session)
156+
self._set_session(session_lock, new_session)
156157
return new_session
157158

158159
async def _send_handshake_response(
@@ -228,68 +229,64 @@ async def _establish_handshake(
228229
)
229230
raise InvalidMessageException("handshake request to wrong server")
230231

231-
async with self._session_lock:
232-
old_session = self._sessions.get(request_message.from_, None)
233-
client_next_expected_seq = (
234-
handshake_request.expectedSessionState.nextExpectedSeq
235-
)
236-
client_next_sent_seq = (
237-
handshake_request.expectedSessionState.nextSentSeq or 0
238-
)
239-
if old_session and old_session.session_id == handshake_request.sessionId:
240-
# check invariants
241-
# ordering must be correct
242-
our_next_seq = await old_session.get_next_sent_seq()
243-
our_ack = await old_session.get_next_expected_seq()
244-
245-
if client_next_sent_seq > our_ack:
246-
message = (
247-
"client is in the future: "
248-
f"server wanted {our_ack} but client has {client_next_sent_seq}"
249-
)
250-
await self._send_handshake_response(
251-
request_message,
252-
HandShakeStatus(ok=False, reason=message),
253-
websocket,
254-
)
255-
raise SessionStateMismatchException(message)
232+
old_session = self._sessions.get(request_message.from_, None)
233+
client_next_expected_seq = (
234+
handshake_request.expectedSessionState.nextExpectedSeq
235+
)
236+
client_next_sent_seq = handshake_request.expectedSessionState.nextSentSeq or 0
237+
if old_session and old_session.session_id == handshake_request.sessionId:
238+
# check invariants
239+
# ordering must be correct
240+
our_next_seq = await old_session.get_next_sent_seq()
241+
our_ack = await old_session.get_next_expected_seq()
256242

257-
if our_next_seq > client_next_expected_seq:
258-
message = (
259-
"server is in the future: "
260-
f"client wanted {client_next_expected_seq} "
261-
f"but server has {our_next_seq}"
262-
)
263-
await self._send_handshake_response(
264-
request_message,
265-
HandShakeStatus(ok=False, reason=message),
266-
websocket,
267-
)
268-
raise SessionStateMismatchException(message)
269-
elif old_session:
270-
# we have an old session but the session id is different
271-
# just delete the old session
272-
await old_session.close()
273-
await self._delete_session(old_session)
274-
old_session = None
243+
if client_next_sent_seq > our_ack:
244+
message = (
245+
"client is in the future: "
246+
f"server wanted {our_ack} but client has {client_next_sent_seq}"
247+
)
248+
await self._send_handshake_response(
249+
request_message,
250+
HandShakeStatus(ok=False, reason=message),
251+
websocket,
252+
)
253+
raise SessionStateMismatchException(message)
275254

276-
if not old_session and (
277-
client_next_sent_seq > 0 or client_next_expected_seq > 0
278-
):
279-
message = "client is trying to resume a session but we don't have it"
255+
if our_next_seq > client_next_expected_seq:
256+
message = (
257+
"server is in the future: "
258+
f"client wanted {client_next_expected_seq} "
259+
f"but server has {our_next_seq}"
260+
)
280261
await self._send_handshake_response(
281262
request_message,
282263
HandShakeStatus(ok=False, reason=message),
283264
websocket,
284265
)
285266
raise SessionStateMismatchException(message)
267+
elif old_session:
268+
# we have an old session but the session id is different
269+
# just delete the old session
270+
await old_session.close()
271+
old_session = None
286272

287-
# from this point on, we're committed to connecting
288-
session_id = handshake_request.sessionId
289-
handshake_response = await self._send_handshake_response(
273+
if not old_session and (
274+
client_next_sent_seq > 0 or client_next_expected_seq > 0
275+
):
276+
message = "client is trying to resume a session but we don't have it"
277+
await self._send_handshake_response(
290278
request_message,
291-
HandShakeStatus(ok=True, sessionId=session_id),
279+
HandShakeStatus(ok=False, reason=message),
292280
websocket,
293281
)
282+
raise SessionStateMismatchException(message)
283+
284+
# from this point on, we're committed to connecting
285+
session_id = handshake_request.sessionId
286+
handshake_response = await self._send_handshake_response(
287+
request_message,
288+
HandShakeStatus(ok=True, sessionId=session_id),
289+
websocket,
290+
)
294291

295-
return handshake_request, handshake_response
292+
return handshake_request, handshake_response

0 commit comments

Comments
 (0)