Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .env_sample
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
secret=please_please_update_me_please
algorithm=HS256
# expiry time in milliseconds (3600000 = 1 hour)
jwt_expiry_time=3600000
# expiry time in seconds (3600 = 1 hour)
jwt_expiry_time=3600

# the log level
# values [TRACE, DEBUG, INFO, SUCCESS, WARNING, ERROR, CRITICAL]
Expand Down
4 changes: 2 additions & 2 deletions app/auth/auth_bearer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from .auth_handler import decodeJWT
from .auth_handler import decode_jwt

# https://testdriven.io/blog/fastapi-jwt-auth/

Expand Down Expand Up @@ -36,7 +36,7 @@ def verify_jwt(self, jwtoken: str) -> bool:
isTokenValid: bool = False

try:
payload = decodeJWT(jwtoken)
payload = decode_jwt(jwtoken)
except: # noqa: E722
payload = None

Expand Down
29 changes: 17 additions & 12 deletions app/auth/auth_handler.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
import asyncio
import os
import time
from typing import Dict

import jwt
from decouple import config
from loguru import logger

JWT_SECRET = config("secret")
JWT_ALGORITHM = config("algorithm")
JWT_EXPIRY_TIME = config("jwt_expiry_time", default=300, cast=int)


def sign_jwt() -> Dict[str, str]:
def sign_jwt() -> str:
payload = {
"user_id": "admin",
"expires": int(round(time.time() * 1000) + JWT_EXPIRY_TIME),
"expires": int(
time.time() + config("jwt_expiry_time", default=300, cast=int),
),
}
token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
token = jwt.encode(
payload,
config("secret"),
algorithm=str(config("algorithm")),
)
return token


def decodeJWT(token: str) -> dict:
def decode_jwt(token: str):
try:
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return decoded_token if decoded_token["expires"] >= time.time() * 1000 else None
decoded_token = jwt.decode(
token, config("secret"), algorithms=[str(config("algorithm"))]
)
return decoded_token if decoded_token["expires"] >= time.time() else None
except Exception as e:
logger.warning(f"Unable to decode jwt_token {e}")
return {}
Expand Down Expand Up @@ -57,9 +60,11 @@ def remove_local_cookie():

def register_cookie_updater():
# We need to update the cookie file once the cookie is expired
expiry_time = config("jwt_expiry_time", default=300, cast=int)

async def _cookie_updater():
while True:
await asyncio.sleep(JWT_EXPIRY_TIME - 10)
await asyncio.sleep(expiry_time - 10)
handle_local_cookie()

loop = asyncio.get_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion app/system/impl/native_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def get_connection_info(self) -> ConnectionInfo:
# return an empty connection info object for now
return ConnectionInfo()

async def login(self, i: LoginInput) -> Dict[str, str]:
async def login(self, i: LoginInput) -> str:
matches = secrets.compare_digest(i.password, config("login_password", cast=str))
if matches:
return sign_jwt()
Expand Down
2 changes: 1 addition & 1 deletion app/system/impl/raspiblitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async def get_connection_info(self) -> ConnectionInfo:
cl_rest_onion=data_cl_rest_onion,
)

async def login(self, i: LoginInput) -> Dict[str, str]:
async def login(self, i: LoginInput) -> str:
matches = await self._match_password(i)
if matches:
return sign_jwt()
Expand Down
2 changes: 1 addition & 1 deletion app/system/impl/system_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def get_connection_info(self) -> ConnectionInfo:
raise NotImplementedError()

@abstractmethod
async def login(self, i: LoginInput) -> Dict[str, str]:
async def login(self, i: LoginInput) -> str:
raise NotImplementedError()

@abstractmethod
Expand Down
6 changes: 4 additions & 2 deletions app/system/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ async def login_path(i: LoginInput, response: Response):
response_description="Returns a fresh JWT token.",
dependencies=[Depends(JWTBearer())],
)
def refresh_token():
return sign_jwt()
def refresh_token(response: Response):
token = sign_jwt()
response.set_cookie("access_token", token)
return token


@router.post(
Expand Down
2 changes: 1 addition & 1 deletion app/system/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async def register_hardware_info_gatherer():
loop.create_task(_handle_gather_hardware_info())


async def login(i: LoginInput) -> Dict[str, str]:
async def login(i: LoginInput) -> str:
try:
return await system.login(i)
except HTTPException:
Expand Down
49 changes: 47 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ ruff = "^0.1.15"
ruff-lsp = "^0.0.52"
debugpy = "^1.8.1"
click = "^8.1.7"
httpx = "^0.27.0"

[build-system]
requires = ["poetry-core"]
Expand Down
Empty file added tests/auth/__init__.py
Empty file.
56 changes: 56 additions & 0 deletions tests/auth/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import jwt

from app.auth.auth_handler import decode_jwt, sign_jwt

UNIX_TIME = 1629200000

DEFAULT_VALUES = {
"secret": "test_secret",
"algorithm": "HS256",
"jwt_expiry_time": 36008,
}


def mock_config(key, default=None, cast=None, vals=DEFAULT_VALUES):
if key not in vals:
raise ValueError(f"Unknown key {key}")

return vals.get(key, default)


def test_sign_jwt_valid_token(monkeypatch):
monkeypatch.setattr("app.auth.auth_handler.config", mock_config)
monkeypatch.setattr("app.auth.auth_handler.time.time", lambda: UNIX_TIME)

token = sign_jwt()

try:
t = jwt.decode(
token,
mock_config("secret"),
algorithms=[str(mock_config("algorithm"))],
)

assert "user_id" in t
assert "expires" in t
assert t["user_id"] == "admin"
assert t["expires"] == UNIX_TIME + mock_config("jwt_expiry_time")
except jwt.ExpiredSignatureError:
raise AssertionError("Token expired unexpectedly")
except jwt.InvalidTokenError as e:
print(e)
raise AssertionError(f"Invalid token: {e}")


def test_sign_jwt_expired_token(monkeypatch):
monkeypatch.setattr("app.auth.auth_handler.config", mock_config)
monkeypatch.setattr("app.auth.auth_handler.time.time", lambda: UNIX_TIME)

token = sign_jwt()

expired_time = UNIX_TIME + mock_config("jwt_expiry_time") + 3600
monkeypatch.setattr("app.auth.auth_handler.time.time", lambda: expired_time)

# Should return None when expired
res = decode_jwt(token)
assert res is None