-
-
Save zning1994/b05eb27c63147f7c7b4847df0193a3ee to your computer and use it in GitHub Desktop.
OpenAI ChatGPT proxy written in Python with Flask, with support for SSE (Server-Sent Events) streaming.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging

import requests
from flask import Flask, Response, request
_FIELD_SEPARATOR = ':' | |
class SSEClient(object):
    """Implementation of a SSE (Server-Sent Events) client.

    See http://www.w3.org/TR/2009/WD-eventsource-20091029/ for the
    specification.

    Instances can be used as context managers; the event source is closed
    when the ``with`` block exits.
    """

    # Byte sequences that mark the end of one event (an empty line), for
    # each of the three line-ending conventions.
    _EVENT_TERMINATORS = (b'\r\r', b'\n\n', b'\r\n\r\n')

    def __init__(self, event_source, char_enc='utf-8'):
        """Initialize the SSE client over an existing, ready to consume
        event source.

        Args:
            event_source: Object that is iterated to obtain chunks of
                bytes and that has a ``close()`` method (for example an
                HTTP response object).
            char_enc: Encoding used to decode each line of the stream.
        """
        self._logger = logging.getLogger(self.__class__.__module__)
        self._logger.debug('Initialized SSE client from event source %s',
                           event_source)
        self._event_source = event_source
        self._char_enc = char_enc

    def __enter__(self):
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the event source; exceptions are never suppressed."""
        self.close()
        return False

    def _read(self):
        """Read the incoming event source stream and yield event chunks.

        A server may split one event across several HTTP chunks, so the
        chunks are stitched together and cut again at the SSE delimiter
        (an empty line). Any trailing data that is not followed by a
        delimiter is yielded last.

        Yields:
            bytes: The raw bytes of one event, delimiter included.
        """
        # A bytearray grows in place; repeated ``bytes +=`` would copy the
        # whole buffer on every appended line.
        buf = bytearray()
        for chunk in self._event_source:
            for line in chunk.splitlines(True):
                buf += line
                if buf.endswith(self._EVENT_TERMINATORS):
                    yield bytes(buf)
                    buf = bytearray()
        if buf:
            yield bytes(buf)

    def events(self):
        """Parse the stream and yield one ``Event`` per dispatched event.

        Comment lines, blank lines and unknown fields are skipped. Events
        that carry no data are not dispatched.

        Yields:
            Event: The parsed event, with its trailing data newline removed.
        """
        for chunk in self._read():
            event = Event()
            fields = vars(event)
            # Split before decoding so splitlines() only uses \r and \n.
            for raw_line in chunk.splitlines():
                line = raw_line.decode(self._char_enc)

                # Lines starting with a separator are comments and are to
                # be ignored, as are blank lines.
                if not line.strip() or line.startswith(_FIELD_SEPARATOR):
                    continue

                # A line without a separator yields an empty value.
                field, _, value = line.partition(_FIELD_SEPARATOR)

                # Ignore unknown fields.
                if field not in fields:
                    self._logger.debug('Saw invalid field %s while parsing '
                                       'Server Side Event', field)
                    continue

                # From the spec: "If value starts with a single U+0020
                # SPACE character, remove it from value."
                if value.startswith(' '):
                    value = value[1:]

                # The data field may come over multiple lines and their
                # values are concatenated with each other.
                if field == 'data':
                    event.data += value + '\n'
                else:
                    setattr(event, field, value)

            # Events with no data are not dispatched.
            if not event.data:
                continue

            # If the data field ends with a newline, remove it.
            if event.data.endswith('\n'):
                event.data = event.data[0:-1]

            # Empty event names default to 'message'.
            event.event = event.event or 'message'

            # Dispatch the event.
            self._logger.debug('Dispatching %s...', event)
            yield event

    def close(self):
        """Manually close the event source stream."""
        self._event_source.close()
class Event(object):
    """Representation of an event from the event stream."""

    def __init__(self, id=None, event='message', data='', retry=None):
        """Store the event fields.

        The instance attribute names double as the set of field names the
        SSE parser recognises, so no other attributes are set here.

        Args:
            id: Event identifier, or None when the stream gave none.
            event: Event name; defaults to 'message'.
            data: Event payload as text.
            retry: Retry value as received from the stream, or None.
        """
        self.id = id
        self.event = event
        self.data = data
        self.retry = retry

    def __str__(self):
        """Return a short human-readable summary of the event."""
        parts = ['{0} event'.format(self.event)]
        if self.id:
            parts.append(' #{0}'.format(self.id))
        if self.data:
            size = len(self.data)
            # Pluralise for every size except exactly one.
            parts.append(', {0} byte{1}'.format(size, 's' if size != 1 else ''))
        else:
            parts.append(', no data')
        if self.retry:
            parts.append(', retry in {0}ms'.format(self.retry))
        return ''.join(parts)
app = Flask(__name__)

# Base URL of the upstream API that every request is forwarded to.
_UPSTREAM_BASE_URL = 'https://api.openai.com/'

# HTTP methods accepted by the proxy on every path, including the root.
_PROXY_METHODS = ['GET', 'POST', 'PUT', 'DELETE']

# Upstream response headers that must not be copied to the client because
# the body is re-sent decoded and re-framed by this server.
_EXCLUDED_RESPONSE_HEADERS = frozenset(
    ['content-encoding', 'content-length', 'transfer-encoding', 'connection'])


@app.route('/', defaults={'path': ''}, methods=_PROXY_METHODS)
@app.route('/<path:path>', methods=_PROXY_METHODS)
def proxy(path):
    """Forward the incoming request to the upstream API and relay the reply.

    When the JSON request body sets a truthy ``stream`` value and the
    upstream answers with an event stream, the events are relayed to the
    client as Server-Sent Events. Otherwise the upstream body, status code
    and headers are returned in one piece.

    Args:
        path: Request path captured by the route. The upstream URL is
            derived from ``request.url``, so the value itself is unused.

    Returns:
        A Flask response object.
    """
    # Swap only the leading host part; an unbounded replace would also
    # rewrite the host if it happened to appear in the query string.
    url = request.url.replace(request.host_url, _UPSTREAM_BASE_URL, 1)

    # silent=True yields None for a missing or malformed JSON body instead
    # of raising, so nothing needs to be swallowed with a bare except.
    payload = request.get_json(silent=True)
    wants_stream = isinstance(payload, dict) and bool(payload.get('stream'))

    # NOTE(review): no timeout is passed, so a stalled upstream blocks this
    # worker indefinitely; consider adding one that suits long completions.
    resp = requests.request(
        method=request.method,
        url=url,
        stream=wants_stream,
        headers={key: value for (key, value) in request.headers
                 if key.lower() != 'host'},
        data=request.get_data(),
        allow_redirects=False)

    content_type = resp.headers.get('Content-Type', '').lower()
    is_event_stream = content_type.startswith('text/event-stream')

    # An upstream error on a streamed request arrives as a regular body,
    # not as an event stream; relay it unchanged so that the client sees
    # the real status code and error payload.
    if not (wants_stream and is_event_stream):
        headers = [(name, value)
                   for (name, value) in resp.raw.headers.items()
                   if name.lower() not in _EXCLUDED_RESPONSE_HEADERS]
        return app.make_response((resp.content, resp.status_code, headers))

    def stream_generate():
        """Yield the upstream events re-encoded as SSE ``data`` frames."""
        client = SSEClient(resp)
        try:
            for event in client.events():
                # Every line of a multi-line payload needs its own
                # 'data: ' prefix to remain a valid SSE frame.
                lines = event.data.split('\n')
                yield ''.join('data: ' + line + '\n' for line in lines) + '\n'
        finally:
            # Release the upstream connection even if the client
            # disconnects before the stream is exhausted.
            client.close()

    return Response(stream_generate(), status=resp.status_code,
                    mimetype='text/event-stream')
if __name__ == '__main__':
    # Serve on all network interfaces, port 9000, when run as a script.
    app.run(host='0.0.0.0', port=9000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment