@otmb
Created November 10, 2024 15:12
File upload progress with FastAPI
# https://www.starlette.io/requests/
# https://developer.mozilla.org/ja/docs/Web/API/Streams_API/Using_readable_streams
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, StreamingResponse, Response
from logging import getLogger, StreamHandler
import asyncio
import tempfile
import json
import time
import os
import typing
from starlette.datastructures import FormData
from starlette.formparsers import MultiPartParser, MultiPartException, _user_safe_decode
if typing.TYPE_CHECKING:
    import multipart
    from multipart.multipart import MultipartCallbacks, parse_options_header
else:
    try:
        try:
            import python_multipart as multipart
            from python_multipart.multipart import parse_options_header
        except ModuleNotFoundError:  # pragma: no cover
            import multipart
            from multipart.multipart import parse_options_header
    except ModuleNotFoundError:  # pragma: no cover
        multipart = None
        parse_options_header = None
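
# Note: python-multipart >= 0.0.12 also installs under the import name
# "python_multipart"; the nested try/except above (mirroring Starlette's own
# formparsers module) works with both the new and the legacy "multipart" name.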
app = FastAPI()
logger = getLogger(__name__)
logger.addHandler(StreamHandler())
logger.setLevel("INFO")
@app.get("/", response_class=HTMLResponse)
async def upload_form():
    return HTMLResponse("""
    <input type="file" id="file" />
    <input type="button" value="submit" onclick="upload_post()"/>
    <div><progress id="progressBar" max="100" value="0" style="width:300px;"></progress></div>
    <div id="progressText"></div>
    <div><a id="download"></a></div>
    <script type="text/javascript">
    async function upload_post(){
        document.getElementById("progressBar").value = "0";
        document.getElementById("progressText").innerText = "";
        const formData = new FormData();
        const file = document.querySelector("#file");
        formData.append("file", file.files[0]);
        const filetype = file.files[0].type;
        const filename = file.files[0].name;
        const task_id = Date.now();
        progress_request(task_id);
        const response = await fetch("upload", {
            method: 'POST',
            body: formData,
            headers: { 'TaskId': task_id, 'filename': filename }
        });
        if (response.ok){
            const reader = response.body.getReader();
            const { done, value } = await reader.read();
            if (done){ return; }
            writeText(value);
        }
    }
    async function progress_request(task_id){
        const response = await fetch("progress", {
            method: 'POST',
            headers: { 'TaskId': task_id }
        });
        const reader = response.body.getReader();
        while (true) {
            const { done, value } = await reader.read();
            if (done){ return; }
            writeText(value);
        }
    }
    async function writeText(value){
        try {
            const text = new TextDecoder().decode(value);
            const json = JSON.parse(text);
            document.getElementById("progressBar").value = json.progress;
            document.getElementById("progressText").innerText = `${json.length}/${json.max_length} (${json.progress}%)`;
            if ('filepath' in json){
                const params = {
                    "filepath": json.filepath
                };
                const query = new URLSearchParams(params);
                const download = document.getElementById("download");
                download.href = "download?" + query;
                download.download = json.filename;
                download.innerText = json.filename;
            }
        } catch (error) {
            console.log(error);
        }
    }
    </script>
    """)
job = {}
lock = asyncio.Lock()
@app.post("/upload")
async def upload_file(request: Request):
    logger.info(request.headers)
    content_length = int(request.headers["content-length"])
    response = {"progress": 0, "length": 0, "max_length": content_length}
    task_id = request.headers["TaskId"]
    # Keep only the base name so a crafted header cannot point outside the temp directory.
    filename = os.path.basename(request.headers["filename"])
    logger.info(f"task_id: {task_id}")
    try:
        async with lock:
            job[task_id] = response
        await asyncio.sleep(0.1)
        parser = CustomMultiPartParser(request.headers, request.stream())
        form = await parser.parse()
        file = form['file']
        data = await file.read()
        with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as t:
            t.write(data)
            filepath = os.path.join(os.path.dirname(t.name), filename)
        os.rename(t.name, filepath)
        logger.info(filepath)
        async with lock:
            response = job[task_id]
            response["filename"] = filename
            response["filepath"] = filepath
    finally:
        async with lock:
            if task_id in job:
                del job[task_id]
    return Response(json.dumps(response), media_type="text/plain")
@app.post("/progress")
async def get_progress(request: Request):
    async def stream_progress(request):
        try:
            start_time = time.time()
            task_id = request.headers["TaskId"]
            is_entry_job = False
            progress = 0
            job_data = {}
            logger.info(f"task_id: {task_id}")
            while progress < 100:
                async with lock:
                    if task_id in job:
                        is_entry_job = True
                        job_data = job[task_id]
                        progress = job_data["progress"]
                    elif is_entry_job:
                        logger.info("progress task finish")
                        break
                    elif start_time < time.time() - 60:
                        logger.info("progress timeout")
                        break
                if is_entry_job:
                    yield json.dumps(job_data) + "\n"
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            logger.info("caught cancelled error")
            raise GeneratorExit
    return StreamingResponse(content=stream_progress(request), media_type="text/plain")
@app.get("/download")
async def download(request: Request, filepath: str):
    # Only serve files from the temp directory where uploads are stored,
    # so the query parameter cannot be used to read arbitrary files.
    tmpdir = os.path.realpath(tempfile.gettempdir())
    realpath = os.path.realpath(filepath)
    if os.path.dirname(realpath) == tmpdir and os.path.isfile(realpath):
        with open(realpath, "rb") as f:
            data = f.read()
    else:
        raise HTTPException(status_code=404, detail="Item not found")
    return Response(data, media_type="application/octet-stream")
# MultiPartParser: https://github.com/encode/starlette/blob/master/starlette/formparsers.py
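# CustomMultiPartParser re-implements MultiPartParser.parse() from the module
# linked above: after each request chunk is fed to the multipart parser it also
# records the byte count and percentage in job[TaskId] for the /progress stream.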
class CustomMultiPartParser(MultiPartParser):
    async def parse(self) -> FormData:
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")

        # Callbacks dictionary.
        callbacks: MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        content_length = int(self.headers["content-length"])
        task_id = self.headers["TaskId"]
        chunks = 0

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
                chunks += len(chunk)
                progress = int(chunks / content_length * 100)
                async with lock:
                    if task_id in job:
                        job[task_id]["progress"] = progress
                        job[task_id]["length"] = chunks
                await asyncio.sleep(0.001)
        except MultiPartException as exc:
            # Close all the files if there was an error.
            for file in self._files_to_close_on_error:
                file.close()
            raise exc

        parser.finalize()
        return FormData(self.items)
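
For reference, a minimal command-line client sketch (not part of the original gist). It assumes the app above is saved as main.py and served with uvicorn main:app at http://127.0.0.1:8000, and that the httpx package is installed; the TaskId and filename headers mirror what the in-page JavaScript sends.

import asyncio
import os
import time
import httpx

async def watch_progress(client: httpx.AsyncClient, task_id: str) -> None:
    # Read the newline-delimited JSON progress stream until the server closes it.
    async with client.stream("POST", "/progress", headers={"TaskId": task_id}) as r:
        async for line in r.aiter_lines():
            if line:
                print(line)

async def upload(path: str) -> None:
    task_id = str(int(time.time() * 1000))  # same idea as Date.now() in the browser code
    async with httpx.AsyncClient(base_url="http://127.0.0.1:8000", timeout=None) as client:
        # Start watching progress first, then upload; for very small files the
        # watcher may only see the 60-second timeout path, as in the browser demo.
        watcher = asyncio.create_task(watch_progress(client, task_id))
        with open(path, "rb") as f:
            r = await client.post(
                "/upload",
                files={"file": f},
                headers={"TaskId": task_id, "filename": os.path.basename(path)},
            )
        print(r.text)
        await watcher

# asyncio.run(upload("example.bin"))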