Skip to content

Commit 4808607

Browse files
committed
Guess file extension if there isnt any
1 parent dad95d0 commit 4808607

3 files changed

Lines changed: 22 additions & 3 deletions

File tree

main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from util.misc import *
66
from util.filenameGenerator import filenameGenerator
7-
7+
import util.extensionSolver
88

99
app = FastAPI()
1010
app.mount(configuration.IMAGE_DIR, StaticFiles(directory=configuration.UPLOAD_PATH), name=configuration.IMAGE_DIR)
@@ -16,11 +16,19 @@ async def index():
1616

1717
@app.post("/upload")
1818
async def upload(file: UploadFile = File(...), authorized = Depends(checkCredentials)):
19+
if len(file.filename.split('.')) == 1: # Guess the extension if there is none
20+
_file = await file.read()
21+
guessExt = util.extensionSolver.guessFileExtension(_file)
22+
guessMime = util.extensionSolver.guessMime(_file)
23+
file.content_type = guessMime
24+
file.filename = file.filename + guessExt
25+
await file.seek(0) # Seek back, otherwise upload will fail
1926
if file.content_type not in configuration.ALLOWED_CONTENT:
20-
return {"message": configuration.ERROR_UNALLOWED_CONTENT}
27+
return {"error": configuration.ERROR_UNALLOWED_CONTENT}
2128
if authorized:
2229
if configuration.RANDOMIZED_FILENAMES:
2330
_extension = file.filename.split('.')[1].lower() # Save the extension of the file for upload
31+
print(_extension)
2432
file.filename = filenameGenerator.generateName(5) +f".{_extension}"
2533
await uploadFile(file)
2634
return {"url": configuration.BASE_URL+configuration.IMAGE_DIR+file.filename}

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
fastapi
22
python-multipart
3-
uvicorn
3+
uvicorn
4+
python-magic

util/extensionSolver.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import magic
2+
import mimetypes
3+
4+
def guessMime(file) -> str:
5+
return magic.from_buffer(file, mime=True)
6+
7+
def guessFileExtension(file) -> str:
8+
mime = guessMime(file)
9+
return mimetypes.guess_extension(mime)
10+

0 commit comments

Comments
 (0)