|
1 | | -from urllib.parse import parse_qs |
2 | | - |
3 | 1 | import jwt |
4 | 2 | from channels.db import database_sync_to_async |
5 | 3 | from django.conf import settings |
@@ -97,33 +95,37 @@ def get_user(scope): |
97 | 95 |
|
98 | 96 | class TokenAuthMiddleware: |
99 | 97 | """ |
100 | | - Custom middleware that takes a token from the query string and authenticates via Django Rest Framework authtoken. |
| 98 | + Custom middleware that takes a token from WebSocket subprotocols and authenticates via JWT. |
101 | 99 | """ |
102 | 100 |
|
| 101 | + SUBPROTOCOL_KEYWORD = "Bearer" |
| 102 | + |
103 | 103 | def __init__(self, app): |
104 | 104 | # Store the ASGI application we were passed |
105 | 105 | self.app = app |
106 | 106 |
|
107 | 107 | async def __call__(self, scope, receive, send): |
108 | | - # Look up user from query string |
109 | | - |
110 | | - # TODO: (you should also do things like |
111 | | - # checking if it is a valid user ID, or if scope["user" ] is already |
112 | | - # populated). |
113 | | - |
114 | | - query_string = scope["query_string"].decode() |
115 | | - query_dict = parse_qs(query_string) |
116 | | - try: |
117 | | - token = query_dict["token"][0] |
118 | | - if token is None: |
119 | | - raise ValueError("Token is missing from headers") |
| 108 | + # Extract token from Sec-WebSocket-Protocol header. |
| 109 | + token = self._extract_token_from_subprotocol(scope.get("subprotocols", [])) |
120 | 110 |
|
| 111 | + if token: |
121 | 112 | scope["token"] = token |
122 | 113 | scope["user"] = await get_user(scope) |
123 | | - except (ValueError, KeyError, IndexError): |
124 | | - # Token is missing from query string |
| 114 | + else: |
125 | 115 | from django.contrib.auth.models import AnonymousUser |
126 | 116 |
|
127 | 117 | scope["user"] = AnonymousUser() |
128 | 118 |
|
129 | 119 | return await self.app(scope, receive, send) |
| 120 | + |
| 121 | + def _extract_token_from_subprotocol(self, subprotocols: list[str]) -> str | None: |
| 122 | + """ |
| 123 | + Expect subprotocols in the form ["Bearer", "<JWT>"]. |
| 124 | + """ |
| 125 | + if not subprotocols: |
| 126 | + return None |
| 127 | + |
| 128 | + if len(subprotocols) >= 2 and subprotocols[0] == self.SUBPROTOCOL_KEYWORD: |
| 129 | + return subprotocols[1] |
| 130 | + |
| 131 | + return None |
0 commit comments