Skip to content

Instantly share code, notes, and snippets.

@YoukouTenhouin
Last active December 21, 2015 04:38
Show Gist options
  • Save YoukouTenhouin/6250840 to your computer and use it in GitHub Desktop.
Save YoukouTenhouin/6250840 to your computer and use it in GitHub Desktop.
Simple static-file HTTP server
import argparse
import mimetypes
import os
import select
import socket
import time
import urllib.parse
class IOLoop:
    """Single-threaded epoll event loop serving one Server and its clients.

    A client object must provide ``socket``, ``deadline`` (a ``time.time()``
    timestamp in whole seconds), ``read_callback()``, ``write_callback()``
    and ``stop()``. Clients may call ``modify()`` so they are only woken for
    the readiness events they currently need.

    Raises NotImplementedError on construction when ``select.epoll`` is not
    available; epoll is the only implemented backend.
    """

    def __init__(self, server):
        self.clients = {}  # fd -> client
        self.server = server
        # Clients in registration order. Expiry only inspects the head, so
        # this assumes deadlines are appended in non-decreasing order (true
        # while every client uses the same timeout).
        self.timeout_queue = []
        self._init_loop()

    def _init_loop(self):
        """Create the epoll object and watch the server's listening socket."""
        if not hasattr(select, "epoll"):
            raise NotImplementedError("IOLoop requires select.epoll")
        self.poll_obj = select.epoll()
        self.poll_obj.register(self.server.ss.fileno(), select.EPOLLIN)

    def _add_epoll(self, client):
        """Start watching *client*'s socket, switching it to non-blocking mode.

        The socket is registered for both readability and writability, so a
        client that never calls modify() still receives write callbacks.
        """
        fd = client.socket.fileno()
        self.clients[fd] = client
        client.socket.setblocking(False)
        self.timeout_queue.append(client)
        try:
            self.poll_obj.register(fd, select.EPOLLIN | select.EPOLLOUT)
        except OSError:
            # Tolerated: such a client gets no events, but it is still
            # reaped by _expire_clients() once its deadline passes.
            pass

    def _rm_epoll(self, client):
        """Stop watching *client*; repeated calls for the same client are no-ops."""
        fd = client.socket.fileno()
        if self.clients.pop(fd, None) is None:
            return
        if client in self.timeout_queue:
            self.timeout_queue.remove(client)
        try:
            self.poll_obj.unregister(fd)
        except OSError:
            pass  # registration failed in _add_epoll, nothing to undo

    def _modify_epoll(self, client, readable=True, writable=True):
        """Choose which readiness events wake *client*.

        epoll always reports errors and hang-ups regardless of this mask.
        Ignored for clients that are not (or no longer) watched.
        """
        fd = client.socket.fileno()
        if self.clients.get(fd) is not client:
            return
        mask = 0
        if readable:
            mask |= select.EPOLLIN
        if writable:
            mask |= select.EPOLLOUT
        try:
            self.poll_obj.modify(fd, mask)
        except OSError:
            pass  # registration failed in _add_epoll

    def _poll_timeout(self, now=None):
        """Return seconds until the earliest deadline (never negative), or -1 to block."""
        if not self.timeout_queue:
            return -1
        if now is None:
            now = int(time.time())
        # A negative value would make epoll wait forever instead of expiring.
        return max(0, self.timeout_queue[0].deadline - now)

    def _expire_clients(self, now=None):
        """Stop the clients at the head of timeout_queue whose deadline has passed."""
        if now is None:
            now = int(time.time())
        while self.timeout_queue and self.timeout_queue[0].deadline <= now:
            self.timeout_queue.pop(0).stop()

    def _loop_epoll(self):
        """Dispatch socket events and expire clients; never returns normally."""
        listen_fd = self.server.ss.fileno()
        while True:
            events = self.poll_obj.poll(self._poll_timeout())
            for fd, mask in events:
                if fd == listen_fd:
                    self.server.callback()
                    continue
                client = self.clients.get(fd)
                if client is None:
                    continue  # stopped earlier in this batch
                if mask & (select.EPOLLERR | select.EPOLLHUP):
                    # Dead connection: level-triggered, it would otherwise
                    # be reported on every poll.
                    client.stop()
                elif mask & select.EPOLLIN:
                    client.read_callback()
                elif mask & select.EPOLLOUT:
                    client.write_callback()
            # Expire on every pass, not only after an empty poll: a steady
            # stream of events would otherwise postpone timeouts forever.
            self._expire_clients()

    add = _add_epoll
    remove = _rm_epoll
    modify = _modify_epoll
    loop = _loop_epoll


class httpBadRequestError(Exception):
    """The request could not be parsed; answered with 400 and the connection closed."""


class httpNotFoundError(Exception):
    """The requested file does not exist; answered with 404."""


class httpForbiddenError(Exception):
    """The requested path may not be served; answered with 403."""


class Client:
    """One HTTP/1.1 connection serving static files from the current directory.

    The IOLoop calls read_callback() / write_callback() when the socket is
    ready. Response headers (bytes) and bodies (open files) are queued in a
    single ordered queue and streamed in chunks, so consecutive keep-alive
    requests are answered in order without blocking the loop.
    """

    # Connection lifetime in seconds when neither the constructor nor the
    # module-level TIMEOUT (set from --timeout) supplies one.
    DEFAULT_TIMEOUT = 300
    # Bytes per recv() call and per chunk read from a served file.
    CHUNK_SIZE = 8192
    # Request heads growing beyond this without a blank line get a 400.
    MAX_HEADER_SIZE = 65536
    # Upper bound, in seconds, for each blocking sendall() in flush().
    FLUSH_TIMEOUT = 5

    def __init__(self, socket, ioloop, timeout=None):
        """Wrap an accepted *socket* and register it with *ioloop*.

        *timeout* is how long, in seconds, the connection may stay open in
        total; the deadline is fixed here and never extended.
        """
        if timeout is None:
            timeout = globals().get('TIMEOUT', self.DEFAULT_TIMEOUT)
        self.socket = socket
        self.ioloop = ioloop
        self.buffer = b''       # received bytes not yet parsed into a request
        self.write_queue = []   # pending output in order: bytes or open files
        self.currf = None       # file currently being streamed
        self.outbuf = b''       # tail of a chunk that send() has not taken yet
        self.closing = False    # hang up once all pending output is sent
        self.deadline = int(time.time()) + timeout
        self.ioloop.add(self)
        self._update_interest()

    def read_from_sock(self):
        """Receive one chunk and answer the first request it completes.

        An empty read means the peer stopped sending, so the connection is
        marked to close once pending output has gone out.

        Raises:
            httpBadRequestError, httpNotFoundError, httpForbiddenError:
                see handle_request(); OSError from recv().
        """
        data = self.socket.recv(self.CHUNK_SIZE)
        if not data:
            self.closing = True
            return
        if self.closing:
            return  # already hanging up; ignore further input
        self.buffer += data
        self._handle_buffered_request()

    def _handle_buffered_request(self):
        """Answer one complete request head from the buffer, if there is one."""
        head, sep, rest = self.buffer.partition(b'\r\n\r\n')
        if not sep:
            if len(self.buffer) > self.MAX_HEADER_SIZE:
                self.buffer = b''
                raise httpBadRequestError()  # refuse to buffer without bound
            return
        self.buffer = rest
        try:
            # Decode the whole head at once so multi-byte characters split
            # across recv() chunks are not rejected.
            request = head.decode('utf8')
        except UnicodeDecodeError:
            raise httpBadRequestError() from None
        self.handle_request(request)

    def cut(self):
        """Send all queued output synchronously, then close the connection."""
        self.ioloop.remove(self)
        self.flush()
        self.stop()

    def stop(self):
        """Unregister from the loop and close the socket and any queued files."""
        self.ioloop.remove(self)
        self.socket.close()
        pending_files = [item for item in self.write_queue if hasattr(item, 'read')]
        if self.currf is not None:
            pending_files.append(self.currf)
        for f in pending_files:
            f.close()
        self.currf = None
        self.write_queue = []
        self.outbuf = b''

    def flush(self):
        """Send all queued output synchronously, blocking the event loop.

        Each sendall() is bounded by FLUSH_TIMEOUT seconds. Errors end the
        flush silently: it is best-effort output used right before the
        connection is closed (see cut()).
        """
        self.socket.settimeout(self.FLUSH_TIMEOUT)
        try:
            chunk = self.outbuf or self._next_chunk()
            self.outbuf = b''
            while chunk:
                self.socket.sendall(chunk)
                chunk = self._next_chunk()
        except OSError:
            pass
        finally:
            self.socket.setblocking(False)

    def read_callback(self):
        """Handle readability: read, answer complete requests, re-arm the loop."""
        try:
            self._run_guarded(self.read_from_sock)
            # Pipelined requests may already sit in the buffer; answer them
            # now, because no further readiness event may arrive for them.
            while not self.closing and b'\r\n\r\n' in self.buffer:
                self._run_guarded(self._handle_buffered_request)
        except BlockingIOError:
            pass  # spurious wake-up: nothing to read yet
        except OSError:
            self.stop()  # connection reset, or an unexpected OS error while serving
            return
        self._update_interest()

    def _run_guarded(self, step):
        """Run *step*, turning HTTP error exceptions into queued error responses."""
        try:
            step()
        except httpBadRequestError:
            self.br_handler()
            # The byte stream can no longer be trusted: hang up after the 400.
            self.closing = True
        except httpNotFoundError:
            self.nf_handler()
        except httpForbiddenError:
            self.fb_handler()

    def handle_request(self, request):
        """Answer one request head (request line plus headers).

        Only GET is served. Other methods get 405 and the connection is
        closed afterwards, because their bodies are never read and would
        otherwise be parsed as the next request.

        Raises:
            httpBadRequestError: malformed request line or unusable path.
            httpNotFoundError: the file does not exist.
            httpForbiddenError: the path escapes the root, is a directory,
                or is not readable.
        """
        start_line = request.split('\r\n', 1)[0]
        parts = start_line.split()
        if len(parts) != 3:
            raise httpBadRequestError()
        method, target, _version = parts
        method = method.upper()
        print("%s %s => " % (method, target), end='')
        if method != 'GET':
            self.closing = True
            self.create_response('405 Method Not Allowed',
                                 {'Allow': 'GET', 'Connection': 'close'},
                                 'Method not allowed')
            return
        path = self._safe_path(target)
        if path is None:
            raise httpForbiddenError()
        headers = {}
        content_type, encoding = mimetypes.guess_type(path)
        if encoding is not None:
            # e.g. .tar.gz: the bytes sent are compressed, not the guessed type.
            content_type = 'application/octet-stream'
        if content_type is not None:
            headers['Content-Type'] = content_type
        try:
            self.create_response('200 OK', headers, path=path)
        except (FileNotFoundError, NotADirectoryError):
            raise httpNotFoundError() from None
        except (IsADirectoryError, PermissionError):
            raise httpForbiddenError() from None
        except ValueError:
            raise httpBadRequestError() from None  # e.g. a NUL byte in the path

    @staticmethod
    def _safe_path(target):
        """Map a request target to a path relative to the served root.

        Drops the query string, percent-decodes, and maps '/' to index.html.
        Returns None when the path would leave the root (via '..').
        """
        path = target.split('?', 1)[0]
        path = urllib.parse.unquote(path).lstrip('/')
        if not path:
            return 'index.html'
        path = os.path.normpath(path)
        if (os.path.isabs(path) or path == os.pardir
                or path.startswith(os.pardir + os.sep)):
            return None
        return path

    def create_response(self, status, headers, content=None, path=None):
        """Queue a complete HTTP/1.1 response: status line, headers and body.

        Args:
            status: status code and reason phrase, e.g. '404 Not Found'.
            headers: extra headers; copied, not modified.
            content: body as str (UTF-8 encoded) or bytes; used when truthy.
            path: file to stream as the body when content is empty.

        Raises:
            IsADirectoryError: *path* is a directory.
            OSError: *path* cannot be opened (FileNotFoundError,
                PermissionError, ...); nothing is queued in that case.
        """
        headers = dict(headers)
        body_file = None
        if content:
            if isinstance(content, str):
                content = content.encode('utf8')
            # Measured after encoding: Content-Length counts bytes.
            headers['Content-Length'] = len(content)
        elif path:
            if os.path.isdir(path):
                raise IsADirectoryError()
            # Open before queuing anything so a failure cannot leave a
            # stray "200 OK" header in the queue.
            body_file = open(path, 'rb')
            headers['Content-Length'] = str(os.fstat(body_file.fileno()).st_size)
        else:
            # Keep-alive clients need an explicit length even for no body.
            headers['Content-Length'] = 0
        headers['Server'] = 'RSHS'
        headers.setdefault('Content-Type', 'text/html')
        head = 'HTTP/1.1 ' + status + '\r\n'
        head += ''.join('%s: %s\r\n' % item for item in headers.items())
        self.write((head + '\r\n').encode('utf8'))
        print(status)
        if content:
            self.write(content)
        elif body_file is not None:
            self.writefile(body_file)

    def write(self, response):
        """Queue *response* bytes after everything queued before them."""
        self.write_queue.append(response)
        self._update_interest()

    def writefile(self, file):
        """Queue an open binary *file*; it is streamed in order, then closed."""
        self.write_queue.append(file)
        self._update_interest()

    def _has_pending_output(self):
        """Return True while any response bytes or file data remain to be sent."""
        return bool(self.outbuf or self.write_queue or self.currf is not None)

    def _next_chunk(self):
        """Return the next bytes to send, reading queued files in order; b'' if idle."""
        while True:
            if self.currf is not None:
                data = self.currf.read(self.CHUNK_SIZE)
                if data:
                    return data
                self.currf.close()
                self.currf = None
            if not self.write_queue:
                return b''
            item = self.write_queue.pop(0)
            if hasattr(item, 'read'):
                self.currf = item
            elif item:
                return item

    def _update_interest(self):
        """Tell the loop which events this connection currently needs.

        Waiting for writability with nothing to send would make the
        level-triggered poll return immediately and spin the CPU. A closing
        connection with nothing left to send is stopped here.
        """
        if self.socket.fileno() < 0:
            return  # already stopped
        pending = self._has_pending_output()
        if self.closing and not pending:
            self.stop()
            return
        self.ioloop.modify(self, readable=not self.closing, writable=pending)

    def write_callback(self):
        """Handle writability: send the next piece of queued output."""
        try:
            if not self.outbuf:
                self.outbuf = self._next_chunk()
            if self.outbuf:
                sent = self.socket.send(self.outbuf)
                # send() may take only part of the data; keep the rest.
                self.outbuf = self.outbuf[sent:]
        except BlockingIOError:
            pass  # socket buffer full after all; retry on the next wake-up
        except OSError:
            self.stop()  # peer went away, or reading the file failed
            return
        self._update_interest()

    def br_handler(self):
        """Queue a 400 response that announces the connection will close."""
        self.create_response('400 Bad Request', {'Connection': 'close'},
                             'Your client sent a bad request')

    def nf_handler(self):
        """Queue a 404 response."""
        self.create_response('404 Not Found', dict(), 'File not found')

    def fb_handler(self):
        """Queue a 403 response."""
        self.create_response('403 Forbidden', dict(), 'Forbidden')
class Server:
    """Listening TCP socket that hands accepted connections to an IOLoop.

    Args:
        client_t: factory called as ``client_t(conn, ioloop)`` for each
            accepted connection; defaults to Client.
    """

    def __init__(self, client_t=Client):
        self.client_t = client_t
        self.ss = None      # listening socket, created by start()
        self.ioloop = None  # event loop, created by start()

    def start(self, host='localhost', port=8080, backlog=socket.SOMAXCONN):
        """Listen on (host, port) and run the event loop.

        However the loop exits (error or Ctrl-C), the listening socket is
        closed and any exception propagates to the caller.
        """
        self.ss = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.ss.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.ss.bind((host, port))
            # A real backlog queues bursts of connections instead of refusing them.
            self.ss.listen(backlog)
            # Non-blocking, so a connection reset between readiness and
            # accept() cannot stall the whole loop.
            self.ss.setblocking(False)
            print('Listening on %s:%d' % (host, port))
            self.ioloop = IOLoop(self)
            self.ioloop.loop()
        finally:
            self.stop()

    def callback(self):
        """Accept one pending connection and hand it to ``client_t``."""
        try:
            conn, _addr = self.ss.accept()
        except (BlockingIOError, ConnectionAbortedError):
            return  # the connection vanished before we could accept it
        self.client_t(conn, self.ioloop)

    def stop(self):
        """Close the listening socket; safe to call before start() created it."""
        print('shutdown')
        ss = getattr(self, 'ss', None)
        if ss is not None:
            ss.close()
def main(argv=None):
    """Parse the command line, then serve files from --root until interrupted.

    Args:
        argv: arguments to parse instead of ``sys.argv[1:]``.
    """
    # Client reads this module-level name when computing connection deadlines.
    global TIMEOUT
    parser = argparse.ArgumentParser(description='Static file http server')
    parser.add_argument('--port', dest='port', type=int, help='port to listen on', default=8080)
    parser.add_argument('--host', dest='host', default='localhost', help='host to bind')
    parser.add_argument('--root', dest='root', default='./', help='set root path of this server')
    parser.add_argument('--timeout', dest='timeout', default=300, type=int, help='set timeout')
    args = parser.parse_args(argv)
    os.chdir(args.root)
    TIMEOUT = args.timeout
    server = Server()
    try:
        server.start(args.host, args.port)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit without a traceback.
        pass


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment