Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions modules/sqlite_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def insert_url(sqlite_file: str, url: str, alias: str, expiration_date: typing.U
cursor = db.cursor()
timestamp = datetime.now()
if expiration_date is not None:
if isinstance(expiration_date, str) and expiration_date.endswith('Z'):
expiration_date = expiration_date[:-1] + '+00:00'
expiration_date = datetime.fromisoformat(expiration_date)
try:
sql = "INSERT INTO urls(url, alias, created_at, expires_at) VALUES (?, ?, ?, ?)"
Expand Down Expand Up @@ -83,6 +85,7 @@ def get_urls(sqlite_file, page=0, search=None, sort_by="created_at", order="DESC
"alias": row[2],
"created_at": row[3],
"used": row[4],
"expires_at": row[5]
}
url_array.append(url_data)
except KeyError:
Expand Down Expand Up @@ -126,9 +129,14 @@ def maybe_delete_expired_url(sqlite_file, sqlite_row) -> bool: #returns True if
utc_tz = ZoneInfo('UTC')

expiration_datetime = None
# sqlite_row[5] represents the expiration datetime e.g., "2024-11-04 18:05:24.006593"
# sqlite_row[5] represents the expiration datetime in UTC timezone e.g., "2024-11-04 05:09:00+00:00"
# The following date and datetime formats are now supported:
# 2024-11-04
# 2024-11-04T18:05:24
# 2024-11-04T18:05:24.123456
# 2024-11-04T18:05:24+02:00
if sqlite_row[5] is not None:
expiration_datetime = datetime.strptime(sqlite_row[5], "%Y-%m-%d %H:%M:%S.%f")
expiration_datetime = datetime.fromisoformat(sqlite_row[5])
expiration_datetime = expiration_datetime.replace(tzinfo=utc_tz)

now = datetime.now(tz=utc_tz)
Expand Down
8 changes: 6 additions & 2 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def create_url(request: Request):
alias = generate_alias(urljson["url"])
if not alias.isalnum():
raise ValueError("alias must only contain alphanumeric characters")
expiration_date = urljson.get("expiration_date")
expiration_date = urljson.get("expires_at")

with MetricsHandler.query_time.labels("create").time():
response = sqlite_helpers.insert_url(
Expand Down Expand Up @@ -99,7 +99,7 @@ async def get_urls(
sort_by: str = "created_at",
order: str = "DESC",
):
valid_sort_attributes = {"id", "url", "alias", "created_at", "used"}
valid_sort_attributes = {"id", "url", "alias", "created_at", "expires_at", "used"}
if order not in {"DESC", "ASC"}:
raise HTTPException(status_code=400, detail="Invalid order")
if sort_by not in valid_sort_attributes:
Expand Down Expand Up @@ -128,6 +128,10 @@ async def get_url(alias: str):
logging.debug(f"/find called with alias: {alias}")
url_output = cache.find(alias) # try to find url in cache
if url_output is not None:
valid = sqlite_helpers.get_url(DATABASE_FILE, alias)
if valid is None:
cache.delete(alias)
raise HTTPException(status_code=HttpResponse.NOT_FOUND.code)
alias_queue.put(alias)
return RedirectResponse(url_output)

Expand Down