Created
November 10, 2024 15:12
-
-
Save otmb/83d5d8a465d357fdefcf25dc0caacce2 to your computer and use it in GitHub Desktop.
File upload progress with FastAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://www.starlette.io/requests/ | |
# https://developer.mozilla.org/ja/docs/Web/API/Streams_API/Using_readable_streams | |
from fastapi import FastAPI, Request, HTTPException | |
from fastapi.responses import HTMLResponse, StreamingResponse, Response | |
from logging import getLogger, StreamHandler | |
import asyncio | |
import tempfile | |
import json | |
import time | |
import os | |
import typing | |
from starlette.datastructures import FormData | |
from starlette.formparsers import MultiPartParser, MultiPartException, _user_safe_decode | |
if typing.TYPE_CHECKING: | |
import multipart | |
from multipart.multipart import MultipartCallbacks, parse_options_header | |
else: | |
try: | |
try: | |
import python_multipart as multipart | |
from python_multipart.multipart import parse_options_header | |
except ModuleNotFoundError: # pragma: no cover | |
import multipart | |
from multipart.multipart import parse_options_header | |
except ModuleNotFoundError: # pragma: no cover | |
multipart = None | |
parse_options_header = None | |
# Application singleton; all route decorators below attach to this instance.
app = FastAPI()
# Module-level logger: StreamHandler writes records to stderr, and the level
# is raised to INFO so the progress/debug messages in the handlers are visible.
logger = getLogger(__name__)
logger.addHandler(StreamHandler())
logger.setLevel("INFO")
@app.get("/", response_class=HTMLResponse)
async def upload_form():
    """Serve the single-page upload UI.

    The page posts the chosen file to POST /upload and, in parallel, opens a
    streaming POST /progress request.  Both requests carry the same TaskId
    header (a Date.now() timestamp) so the server can correlate them and
    stream back newline-delimited JSON snapshots that drive the progress bar.

    Returns:
        HTMLResponse: the static page with its inline JavaScript.
    """
    return HTMLResponse("""
<input type="file" id="file" />
<input type="button" value="submit" onclick="upload_post()"/>
<div><progress id="progressBar" max="100" value="0" style="width:300px;"></progress></div>
<div id="progressText"></div>
<div><a id="download"></a></div>
<script type="text/javascript">
async function upload_post(){
    document.getElementById("progressBar").value = "0";
    document.getElementById("progressText").innerText = "";
    const formData = new FormData();
    const file = document.querySelector("#file");
    formData.append("file", file.files[0]);
    const filetype = file.files[0].type;
    const filename = file.files[0].name;
    const task_id = Date.now();
    progress_request(task_id);
    const response = await fetch("upload", {
        method: 'POST',
        body: formData,
        headers: { 'TaskId': task_id, 'filename': filename }
    });
    // Use response.ok rather than comparing statusText to "OK": HTTP/2 has
    // no reason phrase, so statusText can be empty even for a 200 response.
    if (response.ok){
        const reader = response.body.getReader();
        const { done, value } = await reader.read();
        if (done){ return; }
        writeText(value);
    }
}
async function progress_request(task_id){
    const response = await fetch("progress", {
        method: 'POST',
        headers: { 'TaskId': task_id }
    });
    const reader = response.body.getReader();
    while (true) {
        const { done, value } = await reader.read();
        if (done){ return; }
        writeText(value);
    }
}
async function writeText(value){
    try {
        const text = new TextDecoder().decode(value);
        const json = JSON.parse(text);
        document.getElementById("progressBar").value = json.progress;
        document.getElementById("progressText").innerText = `${json.length}/${json.max_length} (${json.progress}%)`;
        if ('filepath' in json){
            const params = {
                "filepath": json.filepath
            };
            const query = new URLSearchParams(params);
            const download = document.getElementById("download");
            download.href = "download?" + query;
            download.download = json.filename;
            download.innerText = json.filename;
        }
    } catch (error) {
        console.log(error);
    }
}
</script>
""")
# In-flight uploads keyed by the client-supplied TaskId header.  Each value is
# a progress dict: {"progress": percent, "length": bytes consumed,
# "max_length": content-length} plus, once finished, "filename"/"filepath".
job = {}
# Guards every read/write of `job`, which is shared between the /upload
# handler, CustomMultiPartParser.parse, and the /progress stream generator.
lock = asyncio.Lock()
@app.post("/upload")
async def upload_file(request: Request):
    """Receive a multipart upload and track its progress in the `job` dict.

    The client supplies a TaskId header (shared with its /progress request)
    and a filename header.  The body is parsed with CustomMultiPartParser,
    which updates job[task_id] as chunks arrive; the finished file is written
    to the system temp directory under the sanitized client filename.

    Returns:
        Response: the final progress dict serialized as JSON (text/plain).
    """
    logger.info(request.headers)
    content_length = int(request.headers["content-length"])
    response = {"progress": 0, "length": 0, "max_length": content_length}
    task_id = request.headers["TaskId"]
    # basename() strips any directory components from the client-controlled
    # header; without it a value like "../../etc/cron.d/x" would make the
    # os.rename below write outside the temp directory (path traversal).
    filename = os.path.basename(request.headers["filename"])
    logger.info(f"task_id: {task_id}")
    try:
        async with lock:
            job[task_id] = response
        # Brief pause so the parallel /progress stream can observe the entry.
        await asyncio.sleep(0.1)
        parser = CustomMultiPartParser(request.headers, request.stream())
        form = await parser.parse()
        file = form['file']
        data = await file.read()
        # delete=False keeps the file after close so it can be renamed into
        # place under the client-visible name in the same temp directory.
        with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as t:
            t.write(data)
        filepath = os.path.join(os.path.dirname(t.name), filename)
        os.rename(t.name, filepath)
        logger.info(filepath)
        async with lock:
            response = job[task_id]
            response["filename"] = filename
            response["filepath"] = filepath
    finally:
        # Always drop the entry so the /progress generator sees completion
        # and terminates, even if parsing or the file write failed.
        async with lock:
            if task_id in job:
                del job[task_id]
    return Response(json.dumps(response), media_type="text/plain")
@app.post("/progress")
async def get_progress(request: Request):
    """Stream newline-delimited JSON progress snapshots for one TaskId.

    Polls the shared `job` dict (under `lock`) roughly once per second until
    the upload reaches 100%, the entry disappears after having been seen
    (the upload handler deleted it — transfer finished), or 60 seconds pass
    without the entry ever appearing (client never started the upload).
    """
    async def stream_progress(request):
        try:
            start_time = time.time()
            task_id = request.headers["TaskId"]
            # Flips to True once job[task_id] has been observed; a later
            # disappearance then means "finished" rather than "not started".
            is_entry_job = False
            progress = 0
            job_data = {}
            logger.info(f"task_id: {task_id}")
            while(progress < 100):
                async with lock:
                    if task_id in job:
                        is_entry_job = True
                        job_data = job[task_id]
                        progress = job_data["progress"]
                    elif is_entry_job:
                        # Entry was present earlier and is gone now: the
                        # /upload handler removed it in its finally block.
                        logger.info("progress task finish")
                        break
                    elif start_time < time.time() - 60:
                        # Entry never appeared within 60 seconds; give up.
                        logger.info("progress timeout")
                        break
                if is_entry_job:
                    yield json.dumps(job_data) + "\n"
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            # Client disconnected mid-stream; GeneratorExit shuts the async
            # generator down instead of propagating the cancellation.
            logger.info("caught cancelled error")
            raise GeneratorExit
    return StreamingResponse(content=stream_progress(request), media_type="text/plain")
@app.get("/download")
async def download(request: Request, filepath: str):
    """Serve a previously uploaded file back to the client.

    Only files inside the system temp directory (where /upload stores them)
    may be fetched.  Without this containment check the raw `filepath` query
    parameter would let any caller read arbitrary files on the server
    (e.g. ?filepath=/etc/passwd).

    Raises:
        HTTPException: 404 if the path escapes the temp dir or is not a file.
    """
    tmpdir = os.path.realpath(tempfile.gettempdir())
    realpath = os.path.realpath(filepath)
    try:
        # realpath resolves symlinks/".." first, so commonpath reliably
        # detects traversal attempts; ValueError covers cross-drive paths
        # on Windows, which are certainly outside the temp dir.
        allowed = os.path.commonpath([tmpdir, realpath]) == tmpdir
    except ValueError:
        allowed = False
    if allowed and os.path.isfile(realpath):
        with open(realpath, "rb") as f:
            data = f.read()
    else:
        raise HTTPException(status_code=404, detail="Item not found")
    return Response(data, media_type="application/octet-stream")
# MultiPartParser: https://github.com/encode/starlette/blob/master/starlette/formparsers.py
class CustomMultiPartParser(MultiPartParser):
    """Starlette MultiPartParser that reports byte-level upload progress.

    ``parse`` is a copy of the upstream starlette implementation with one
    addition: as each request chunk is consumed, the module-level ``job``
    dict is updated (under ``lock``) with the running byte count and the
    percentage of the total content-length, so the /progress endpoint can
    stream live progress to the client.
    """

    async def parse(self) -> FormData:
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")
        # Callbacks dictionary.
        callbacks: MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }
        # Progress bookkeeping: total request size and the TaskId that keys
        # this upload's entry in the shared `job` dict.
        content_length = int(self.headers["content-length"])
        task_id = self.headers["TaskId"]
        # NOTE(review): despite the name, this is a running count of body
        # BYTES consumed, not the number of chunks.
        chunks = 0
        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
                chunks += len(chunk)
                # Percentage is measured against the whole request body, so
                # it includes multipart framing overhead — close enough for
                # a progress bar.
                progress = int(chunks / content_length * 100)
                async with lock:
                    if task_id in job:
                        job[task_id]["progress"] = progress
                        job[task_id]["length"] = chunks
                # Yield to the event loop so the /progress stream can emit
                # an update while parsing is still in flight.
                await asyncio.sleep(0.001)
        except MultiPartException as exc:
            # Close all the files if there was an error.
            for file in self._files_to_close_on_error:
                file.close()
            raise exc
        parser.finalize()
        return FormData(self.items)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment