Created August 11, 2018 10:37
-
-
Save gsw945/b7d5bdfba6dae77bc35a18a908d7f912 to your computer and use it in GitHub Desktop.
Flask reverse proxy to another server
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
try:
    from urllib.parse import (
        urlencode,
        parse_qs,
        urlsplit,
        urlunsplit
    )
except ImportError:
    from urllib import urlencode
    from urlparse import (
        parse_qs,
        urlsplit,
        urlunsplit
    )

from flask import Flask
from flask import (
    request,
    Response,
    make_response,
    abort,
    current_app,
    has_request_context,
)
from flask.logging import default_handler
import requests
# Pattern from the Flask logging docs ("Injecting Request Information"):
# https://flask.palletsprojects.com/en/latest/logging/
class RequestFormatter(logging.Formatter):
    """Log formatter that exposes the current request's URL and client address.

    Sets ``record.url`` and ``record.remote_addr`` so a format string can use
    ``%(url)s`` and ``%(remote_addr)s``. Outside a request context (e.g.
    messages logged at startup) both are set to ``None`` instead of touching
    ``flask.request``, which would raise ``RuntimeError`` there.
    """

    def format(self, record):
        if has_request_context():
            record.url = request.url
            record.remote_addr = request.remote_addr
        else:
            record.url = None
            record.remote_addr = None
        # Explicit base-class call instead of zero-argument super(), which does
        # not exist on Python 2 (still targeted by the urlparse import fallback).
        return logging.Formatter.format(self, record)
# Log lines carry the request context first, then level/module/message.
formatter = RequestFormatter(
    fmt=(
        '[%(asctime)s] %(remote_addr)s requested %(url)s\n'
        '%(levelname)s in %(module)s: %(message)s'
    )
)
default_handler.setFormatter(formatter)
# A NOTSET handler emits every record it receives; severity filtering is left
# to the loggers it is attached to.
# Levels, lowest to highest: NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL
# (https://docs.python.org/3/library/logging.html#logging-levels).
default_handler.setLevel(logging.NOTSET)
# Optional: to route root-logger output through the same handler, call
# logging.getLogger().addHandler(default_handler).
def set_query_parameter(url, param_name, param_value):
    """
    Given a URL, set or replace a query parameter and return the modified URL.

    Every existing value of ``param_name`` is replaced by the single value
    ``param_value``. All other parameters are kept, including ones with an
    empty value such as ``a=`` (a bare ``a`` comes back as ``a=``).

    >>> set_query_parameter('http://example.com?foo=bar&biz=baz', 'foo', 'stuff')
    'http://example.com?foo=stuff&biz=baz'
    >>> set_query_parameter('http://example.com?a=&foo=1', 'foo', '2')
    'http://example.com?a=&foo=2'
    """
    # from: https://stackoverflow.com/questions/4293460/how-to-add-custom-parameters-to-an-url-query-string-with-python#12897375
    scheme, netloc, path, query_string, fragment = urlsplit(url)
    # keep_blank_values=True: by default parse_qs drops "a=" style parameters,
    # so setting one parameter would silently delete unrelated ones.
    query_params = parse_qs(query_string, keep_blank_values=True)
    query_params[param_name] = [param_value]
    new_query_string = urlencode(query_params, doseq=True)
    return urlunsplit((scheme, netloc, path, new_query_string, fragment))
def set_domain(url, new_domain, new_scheme=None):
    """Return ``url`` with its network location replaced by ``new_domain``.

    ``new_domain`` becomes the whole ``netloc``, so it may include a port
    (e.g. ``'127.0.0.1:5001'``). The scheme is swapped for ``new_scheme``
    only when that is truthy; path, query and fragment are left untouched.
    """
    parts = urlsplit(url)
    target_scheme = new_scheme if new_scheme else parts.scheme
    return urlunsplit(
        (target_scheme, new_domain, parts.path, parts.query, parts.fragment)
    )
def get_domain(url):
    """Return the ``netloc`` of ``url``: host, optional port and any userinfo."""
    return urlsplit(url).netloc
# Based on: https://stackoverflow.com/questions/6656363/proxying-to-another-web-service-with-flask/36601467#36601467
def reverse_proxy(old_domain=None, new_domain=None, scheme=None, timeout=10,
                  verify=False):
    """Forward the current Flask request to another server and relay the reply.

    Must be called while handling a request (it reads ``flask.request``).

    :param old_domain: optional host (``netloc``) filter. When given, the URL
        is re-targeted only if the incoming request's host equals it.
    :param new_domain: ``host[:port]`` that replaces the incoming request's
        host, e.g. ``'127.0.0.1:5001'``.
    :param scheme: replacement URL scheme (e.g. ``'https'``); the incoming
        scheme is kept when this is falsy.
    :param timeout: seconds handed to ``requests`` as the request timeout.
    :param verify: passed to ``requests`` as ``verify``. Defaults to ``False``
        (the previously hard-coded value), which disables TLS certificate
        checks; pass ``True`` or a CA bundle path when the upstream is HTTPS.
    :returns: a ``flask.Response`` with the upstream body, status code and
        headers, minus the ones listed in ``excluded_headers``.
    :raises werkzeug.exceptions.HTTPException: via ``abort`` -- 502 on a
        connect timeout, 503 on other connection errors, 504 on a read
        timeout and 500 on any other failure.
    """
    req_url = request.url
    # NOTE(review): when old_domain is set and does not match, req_url still
    # points at this proxy, so the request is forwarded back to ourselves.
    # Confirm that is intended before relying on old_domain.
    if not old_domain or get_domain(req_url) == old_domain:
        req_url = set_domain(req_url, new_domain, scheme)

    # Drop Host so that requests derives it from the upstream URL.
    req_headers = {key: value for (key, value) in request.headers if key != 'Host'}

    try:
        resp = requests.request(
            method=request.method,
            url=req_url,
            headers=req_headers,
            data=request.get_data(),
            cookies=request.cookies,
            # Relay 3xx responses to the client instead of following them here.
            allow_redirects=False,
            timeout=timeout,
            verify=verify,
        )
    # abort() raises, so no handler below falls through to the code after it.
    except requests.exceptions.ConnectTimeout:
        # Must precede ConnectionError, of which ConnectTimeout is a subclass.
        abort(502)
    except requests.exceptions.ConnectionError:
        abort(503)
    except requests.exceptions.ReadTimeout:
        abort(504)
    except Exception:
        # Previously swallowed without a trace; log it so proxy/upstream
        # failures can be diagnosed, then answer 500 as before.
        current_app.logger.exception('reverse proxy request to %s failed', req_url)
        abort(500)

    # content-encoding: resp.content has already been decoded by requests.
    # content-length / transfer-encoding / connection: these describe the
    # upstream connection, not the response Flask will send.
    excluded_headers = frozenset((
        'content-encoding', 'content-length', 'transfer-encoding', 'connection',
    ))
    # resp.raw.headers is urllib3's header dict; presumably used instead of
    # resp.headers so repeated headers (e.g. several Set-Cookie) stay separate.
    headers = [
        (name, value) for (name, value) in resp.raw.headers.items()
        if name.lower() not in excluded_headers
    ]
    return Response(resp.content, resp.status_code, headers)
# The proxy application; its request hooks and routes are registered below.
app = Flask(import_name=__name__)
@app.before_request
def intercept_hook():
    """Intercept each request before it is routed to a view.

    - A non-empty ``username`` (query string or form field, read through
      ``request.values``) is printed to stdout and the request continues.
    - Otherwise a PUT is rejected with 409, and any other method gets a
      hand-written 400 page stating that ``username`` is required.
    """
    username = request.values.get('username', None)
    if username:
        print('窃取到了[username]:', username)
        return None

    if request.method.lower() == 'put':
        abort(409)  # raises; nothing below runs for PUT

    # Appears to mimic the default 400 page, plus an <h3> naming the problem.
    body = '\n'.join((
        '<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">',
        '<title>400 Bad Request</title>',
        '<h1>Bad Request</h1>',
        '<p>The browser (or proxy) sent a request that this server could not understand.</p>',
        '<hr />',
        '<h3>{0}</h3>'.format('[username] is required'),
    ))
    return make_response(body), 400
@app.after_request
def tamper_hook(response):
    """Add an ``X-Tamper: modified`` header and return the same response."""
    outgoing_headers = response.headers
    outgoing_headers['X-Tamper'] = 'modified'
    return response
@app.teardown_request
def cleanup_request(exception):
    """Teardown hook placeholder; it currently does nothing.

    ``exception`` is the exception that ended the request, or ``None``.
    """
@app.route('/', methods=['GET', 'POST'])
@app.route('/index', methods=['GET', 'POST'])
def view_index():
    """Relay ``/`` and ``/index`` to the upstream app on 127.0.0.1:5001."""
    upstream = '127.0.0.1:5001'
    return reverse_proxy(new_domain=upstream)
def main():
    """Start the proxy with Flask's development server on 0.0.0.0:5656."""
    host, port, debug = '0.0.0.0', 5656, False
    print('visit via [http://{host}:{port}/]'.format(host='127.0.0.1', port=port))
    # Outside debug mode, make sure app.logger uses default_handler.
    if not app.debug and not debug:
        app.logger.addHandler(default_handler)
    app.run(host=host, port=port, debug=debug)
if __name__ == '__main__':
    # Start the server only when run as a script, not when imported.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

from flask import Flask, abort

# Minimal upstream server that the proxy forwards to. `time` and `abort` are
# only needed for the failure simulations described in viw_idx's docstring.
app = Flask(import_name=__name__)
@app.route('/')
@app.route('/index')
def viw_idx():
    """Return a fixed body so the proxied output is easy to recognise.

    To exercise the proxy's error handling, replace the return with:
      time.sleep(31)  -- slower than the proxy's default 10 s timeout, so the
                         proxy's read times out and it answers 504
      abort(500)      -- an upstream error status that the proxy relays
    """
    body = 'this is upstream response'
    return body
if __name__ == '__main__':
    # Port 5001 is the upstream address the proxy's view_index forwards to.
    app.run(port=5001)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.