Skip to content

Commit e24506a

Browse files
authored
Merge pull request #574 from PROCOLLAB-github/feature/ws-auth-middleware
WebSocket через subprotocol + единый middleware
2 parents ecedd7e + 93ed76a commit e24506a

5 files changed

Lines changed: 38 additions & 21 deletions

File tree

chats/consumers/chat.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,15 @@ async def connect(self):
7777
await self.channel_layer.group_add(
7878
EventGroupType.GENERAL_EVENTS, self.channel_name
7979
)
80-
await self.accept()
80+
# Confirm selected subprotocol so browser clients finish handshake.
81+
subprotocol = None
82+
if (
83+
self.scope.get("subprotocols")
84+
and len(self.scope["subprotocols"]) >= 1
85+
):
86+
subprotocol = self.scope["subprotocols"][0]
87+
88+
await self.accept(subprotocol=subprotocol)
8189

8290
async def disconnect(self, close_code):
8391
"""User disconnected from websocket"""

core/auth/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""
2+
Authentication utilities for ASGI/WebSocket middleware.
3+
"""
Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from urllib.parse import parse_qs
2-
31
import jwt
42
from channels.db import database_sync_to_async
53
from django.conf import settings
@@ -97,33 +95,37 @@ def get_user(scope):
9795

9896
class TokenAuthMiddleware:
9997
"""
100-
Custom middleware that takes a token from the query string and authenticates via Django Rest Framework authtoken.
98+
Custom middleware that takes a token from WebSocket subprotocols and authenticates via JWT.
10199
"""
102100

101+
SUBPROTOCOL_KEYWORD = "Bearer"
102+
103103
def __init__(self, app):
104104
# Store the ASGI application we were passed
105105
self.app = app
106106

107107
async def __call__(self, scope, receive, send):
108-
# Look up user from query string
109-
110-
# TODO: (you should also do things like
111-
# checking if it is a valid user ID, or if scope["user" ] is already
112-
# populated).
113-
114-
query_string = scope["query_string"].decode()
115-
query_dict = parse_qs(query_string)
116-
try:
117-
token = query_dict["token"][0]
118-
if token is None:
119-
raise ValueError("Token is missing from headers")
108+
# Extract token from Sec-WebSocket-Protocol header.
109+
token = self._extract_token_from_subprotocol(scope.get("subprotocols", []))
120110

111+
if token:
121112
scope["token"] = token
122113
scope["user"] = await get_user(scope)
123-
except (ValueError, KeyError, IndexError):
124-
# Token is missing from query string
114+
else:
125115
from django.contrib.auth.models import AnonymousUser
126116

127117
scope["user"] = AnonymousUser()
128118

129119
return await self.app(scope, receive, send)
120+
121+
def _extract_token_from_subprotocol(self, subprotocols: list[str]) -> str | None:
122+
"""
123+
Expect subprotocols in the form ["Bearer", "<JWT>"].
124+
"""
125+
if not subprotocols:
126+
return None
127+
128+
if len(subprotocols) >= 2 and subprotocols[0] == self.SUBPROTOCOL_KEYWORD:
129+
return subprotocols[1]
130+
131+
return None

procollab/asgi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
# Ensure Django app registry is loaded before importing project routes.
99
django_asgi_app = get_asgi_application()
1010

11-
import chats.routing # noqa: E402
12-
from chats.middleware import TokenAuthMiddleware # noqa: E402
11+
from core.auth.middleware import TokenAuthMiddleware # noqa: E402
12+
from procollab.websocket_routing import websocket_urlpatterns # noqa: E402
1313

1414
application = ProtocolTypeRouter(
1515
{
1616
"http": django_asgi_app,
17-
"websocket": TokenAuthMiddleware(URLRouter(chats.routing.websocket_urlpatterns)),
17+
"websocket": TokenAuthMiddleware(URLRouter(websocket_urlpatterns)),
1818
}
1919
)

procollab/websocket_routing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from chats.routing import websocket_urlpatterns as chat_websocket_urlpatterns
2+
3+
websocket_urlpatterns = []
4+
websocket_urlpatterns += chat_websocket_urlpatterns

0 commit comments

Comments
 (0)