@jerrylususu
Created March 23, 2025 11:39
flask sse proxy

original: https://github.com/wujianguo/openai-proxy

  • flask_proxy_for_continue: fixes the issue where SiliconFlow's reasoning_content is not displayed in the Continue plugin; automatically switches to a smaller model for text summarization
  • flask_proxy: a backup of the original, with the target address changed to SiliconFlow

Used together with pyhttpdbg, you can see exactly what the app actually sends to the LLM:

pyhttpdbg --script proxy.py
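Once either proxy below is running, any OpenAI-compatible client can simply point its base URL at it. A minimal sketch (assuming the proxy is running locally on port 9000 as configured below, and that the key you pass is a valid SiliconFlow key, since the proxy forwards the Authorization header unchanged; the key shown here is hypothetical):

from openai import OpenAI

# Point the client at the local proxy instead of api.siliconflow.cn directly.
client = OpenAI(base_url="http://localhost:9000/v1", api_key="sk-...")  # hypothetical key

stream = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    # For reasoning models behind flask_proxy_for_continue, the reasoning arrives
    # wrapped in <think>...</think> inside the regular content deltas.
    print(chunk.choices[0].delta.content or "", end="")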

flask_proxy:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from flask import Flask, request, Response
import requests
import logging

_FIELD_SEPARATOR = ':'


class SSEClient(object):
    """Implementation of a SSE client.

    See http://www.w3.org/TR/2009/WD-eventsource-20091029/ for the
    specification.
    """

    def __init__(self, event_source, char_enc='utf-8'):
        """Initialize the SSE client over an existing, ready to consume
        event source.

        The event source is expected to be a binary stream and have a close()
        method. That would usually be something that implements
        io.BinaryIOBase, like an httplib or urllib3 HTTPResponse object.
        """
        self._logger = logging.getLogger(self.__class__.__module__)
        self._logger.debug('Initialized SSE client from event source %s',
                           event_source)
        self._event_source = event_source
        self._char_enc = char_enc

    def _read(self):
        """Read the incoming event source stream and yield event chunks.

        Unfortunately it is possible for some servers to decide to break an
        event into multiple HTTP chunks in the response. It is thus necessary
        to correctly stitch together consecutive response chunks and find the
        SSE delimiter (empty new line) to yield full, correct event chunks."""
        data = b''
        for chunk in self._event_source:
            for line in chunk.splitlines(True):
                data += line
                if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
                    yield data
                    data = b''
        if data:
            yield data

    def events(self):
        for chunk in self._read():
            event = Event()
            # Split before decoding so splitlines() only uses \r and \n
            for line in chunk.splitlines():
                # Decode the line.
                line = line.decode(self._char_enc)

                # Lines starting with a separator are comments and are to be
                # ignored.
                if not line.strip() or line.startswith(_FIELD_SEPARATOR):
                    continue

                data = line.split(_FIELD_SEPARATOR, 1)
                field = data[0]

                # Ignore unknown fields.
                if field not in event.__dict__:
                    self._logger.debug('Saw invalid field %s while parsing '
                                       'Server Side Event', field)
                    continue

                if len(data) > 1:
                    # From the spec:
                    # "If value starts with a single U+0020 SPACE character,
                    # remove it from value."
                    if data[1].startswith(' '):
                        value = data[1][1:]
                    else:
                        value = data[1]
                else:
                    # If no value is present after the separator,
                    # assume an empty value.
                    value = ''

                # The data field may come over multiple lines and their values
                # are concatenated with each other.
                if field == 'data':
                    event.__dict__[field] += value + '\n'
                else:
                    event.__dict__[field] = value

            # Events with no data are not dispatched.
            if not event.data:
                continue

            # If the data field ends with a newline, remove it.
            if event.data.endswith('\n'):
                event.data = event.data[0:-1]

            # Empty event names default to 'message'
            event.event = event.event or 'message'

            # Dispatch the event
            self._logger.debug('Dispatching %s...', event)
            yield event

    def close(self):
        """Manually close the event source stream."""
        self._event_source.close()


class Event(object):
    """Representation of an event from the event stream."""

    def __init__(self, id=None, event='message', data='', retry=None):
        self.id = id
        self.event = event
        self.data = data
        self.retry = retry

    def __str__(self):
        s = '{0} event'.format(self.event)
        if self.id:
            s += ' #{0}'.format(self.id)
        if self.data:
            s += ', {0} byte{1}'.format(len(self.data),
                                        's' if len(self.data) else '')
        else:
            s += ', no data'
        if self.retry:
            s += ', retry in {0}ms'.format(self.retry)
        return s

app = Flask(__name__)


@app.route('/', defaults={'path': ''})
@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
def proxy(path):
    # Rewrite the incoming URL so the request is forwarded to SiliconFlow.
    url = request.url.replace(request.host_url, 'https://api.siliconflow.cn/')
    stream = None
    try:
        stream = request.get_json().get('stream', None)
    except Exception:
        pass
    resp = requests.request(
        method=request.method,
        url=url,
        stream=stream,
        headers={key: value for (key, value)
                 in request.headers if key != 'Host'},
        data=request.get_data(),
        allow_redirects=False)

    if not stream:
        # Non-streaming request: return the upstream body as-is, minus
        # hop-by-hop headers that must not be forwarded.
        excluded_headers = ['content-encoding',
                            'content-length', 'transfer-encoding', 'connection']
        headers = [(name, value) for (name, value) in resp.raw.headers.items()
                   if name.lower() not in excluded_headers]
        response = app.make_response((resp.content, resp.status_code, headers))
        return response

    def stream_generate():
        # Re-emit each upstream SSE event as a data-only event.
        client = SSEClient(resp)
        for event in client.events():
            yield ('data: ' + event.data + '\n\n')

    return Response(stream_generate(), mimetype='text/event-stream')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=9000)
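Note that the only URL manipulation proxy() performs is a plain string replace of the local host prefix with the SiliconFlow base URL, so the path and query string pass through untouched. A tiny illustration with hypothetical values:

host_url = "http://localhost:9000/"                 # what request.host_url would be
url = "http://localhost:9000/v1/chat/completions"   # what request.url would be
print(url.replace(host_url, "https://api.siliconflow.cn/"))
# -> https://api.siliconflow.cn/v1/chat/completions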
flask_proxy_for_continue:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from flask import Flask, request, Response
import requests
import logging
import json
from enum import Enum, auto

_FIELD_SEPARATOR = ':'


class SSEClient(object):
    """Implementation of a SSE client.

    See http://www.w3.org/TR/2009/WD-eventsource-20091029/ for the
    specification.
    """

    def __init__(self, event_source, char_enc='utf-8'):
        """Initialize the SSE client over an existing, ready to consume
        event source.

        The event source is expected to be a binary stream and have a close()
        method. That would usually be something that implements
        io.BinaryIOBase, like an httplib or urllib3 HTTPResponse object.
        """
        self._logger = logging.getLogger(self.__class__.__module__)
        self._logger.debug('Initialized SSE client from event source %s',
                           event_source)
        self._event_source = event_source
        self._char_enc = char_enc

    def _read(self):
        """Read the incoming event source stream and yield event chunks.

        Unfortunately it is possible for some servers to decide to break an
        event into multiple HTTP chunks in the response. It is thus necessary
        to correctly stitch together consecutive response chunks and find the
        SSE delimiter (empty new line) to yield full, correct event chunks."""
        data = b''
        for chunk in self._event_source:
            for line in chunk.splitlines(True):
                data += line
                if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
                    yield data
                    data = b''
        if data:
            yield data

    def events(self):
        for chunk in self._read():
            event = Event()
            # Split before decoding so splitlines() only uses \r and \n
            for line in chunk.splitlines():
                # Decode the line.
                line = line.decode(self._char_enc)

                # Lines starting with a separator are comments and are to be
                # ignored.
                if not line.strip() or line.startswith(_FIELD_SEPARATOR):
                    continue

                data = line.split(_FIELD_SEPARATOR, 1)
                field = data[0]

                # Ignore unknown fields.
                if field not in event.__dict__:
                    self._logger.debug('Saw invalid field %s while parsing '
                                       'Server Side Event', field)
                    continue

                if len(data) > 1:
                    # From the spec:
                    # "If value starts with a single U+0020 SPACE character,
                    # remove it from value."
                    if data[1].startswith(' '):
                        value = data[1][1:]
                    else:
                        value = data[1]
                else:
                    # If no value is present after the separator,
                    # assume an empty value.
                    value = ''

                # The data field may come over multiple lines and their values
                # are concatenated with each other.
                if field == 'data':
                    event.__dict__[field] += value + '\n'
                else:
                    event.__dict__[field] = value

            # Events with no data are not dispatched.
            if not event.data:
                continue

            # If the data field ends with a newline, remove it.
            if event.data.endswith('\n'):
                event.data = event.data[0:-1]

            # Empty event names default to 'message'
            event.event = event.event or 'message'

            # Dispatch the event
            self._logger.debug('Dispatching %s...', event)
            yield event

    def close(self):
        """Manually close the event source stream."""
        self._event_source.close()


class Event(object):
    """Representation of an event from the event stream."""

    def __init__(self, id=None, event='message', data='', retry=None):
        self.id = id
        self.event = event
        self.data = data
        self.retry = retry

    def __str__(self):
        s = '{0} event'.format(self.event)
        if self.id:
            s += ' #{0}'.format(self.id)
        if self.data:
            s += ', {0} byte{1}'.format(len(self.data),
                                        's' if len(self.data) else '')
        else:
            s += ', no data'
        if self.retry:
            s += ', retry in {0}ms'.format(self.retry)
        return s

class ThinkTagState(Enum):
    """States for the think tag conversion state machine"""
    WAITING_FOR_FIRST_MESSAGE = auto()  # Initial state, waiting for first assistant message
    IN_THINKING_PHASE = auto()          # Inside the thinking phase (after <think>)
    IN_RESPONSE_PHASE = auto()          # Inside the response phase (after </think>)


class ThinkTagStateMachine:
    """
    State machine for handling think tag conversion in model responses.
    Converts reasoning_content and content into proper <think>...</think> format.
    """

    def __init__(self):
        self.state = ThinkTagState.WAITING_FOR_FIRST_MESSAGE
        print("initial state", self.state)

    def process_delta(self, delta):
        """
        Process a delta object based on current state and return modified delta.

        Args:
            delta (dict): The delta object from the model response

        Returns:
            dict: Modified delta with appropriate think tags
        """
        content = delta.get('content')
        reasoning_content = delta.get('reasoning_content')
        role = delta.get('role')

        # Make a copy of delta to avoid modifying the original
        modified_delta = delta.copy()

        # State transitions and actions
        if self.state == ThinkTagState.WAITING_FOR_FIRST_MESSAGE:
            if role == 'assistant':
                # First message from assistant, insert <think> tag
                modified_delta['content'] = "<think>"
                self.state = ThinkTagState.IN_THINKING_PHASE
                print("state trans: WAITING_FOR_FIRST_MESSAGE -> IN_THINKING_PHASE")
        elif self.state == ThinkTagState.IN_THINKING_PHASE:
            if content is None and reasoning_content is not None:
                # Still in thinking phase, use reasoning_content as content
                modified_delta['content'] = reasoning_content
            elif content is not None and reasoning_content is None:
                # Transition to response phase, add </think> tag
                modified_delta['content'] = "</think>" + (content or "")
                self.state = ThinkTagState.IN_RESPONSE_PHASE
                print("state trans: IN_THINKING_PHASE -> IN_RESPONSE_PHASE")
        elif self.state == ThinkTagState.IN_RESPONSE_PHASE:
            # Already in response phase, no modifications needed
            pass

        return modified_delta

    def reset(self):
        """Reset the state machine to its initial state"""
        self.state = ThinkTagState.WAITING_FOR_FIRST_MESSAGE

app = Flask(__name__)


@app.route('/', defaults={'path': ''})
@app.route('/<path:path>', methods=['GET', 'POST', 'PUT', 'DELETE'])
def proxy(path):
    # Rewrite the incoming URL so the request is forwarded to SiliconFlow.
    url = request.url.replace(request.host_url, 'https://api.siliconflow.cn/')
    stream = None

    # Get the request data
    request_data = request.get_data()

    # Check if this is a title generation request
    try:
        json_data = request.get_json()
        if json_data:
            # Check for title generation conditions:
            # 1. Content starts with "Given the following... please reply with"
            # 2. max_tokens is 12
            messages = json_data.get('messages', [])
            max_tokens = json_data.get('max_tokens')
            is_title_generation = False
            if messages and max_tokens == 12:
                for message in messages:
                    if message.get('role') == 'user' and message.get('content', '').startswith("Given the following... please reply with"):
                        is_title_generation = True
                        break

            # If this is a title generation request, switch to a smaller, cheaper model
            if is_title_generation and 'model' in json_data:
                json_data['model'] = "Qwen/Qwen2.5-7B-Instruct"
                # Update the request data with the modified model
                request_data = json.dumps(json_data).encode('utf-8')

            stream = json_data.get('stream', None)
    except Exception:
        pass

    resp = requests.request(
        method=request.method,
        url=url,
        stream=stream,
        headers={key: value for (key, value)
                 in request.headers if key != 'Host'},
        data=request_data,
        allow_redirects=False)

    if not stream:
        # Non-streaming request: return the upstream body as-is, minus
        # hop-by-hop headers that must not be forwarded.
        excluded_headers = ['content-encoding',
                            'content-length', 'transfer-encoding', 'connection']
        headers = [(name, value) for (name, value) in resp.raw.headers.items()
                   if name.lower() not in excluded_headers]
        response = app.make_response((resp.content, resp.status_code, headers))
        return response

    def stream_generate():
        client = SSEClient(resp)
        # Initialize the think tag state machine
        think_tag_machine = ThinkTagStateMachine()
        for event in client.events():
            data = event.data
            try:
                json_data = json.loads(data)
                model = json_data.get('model')
                thinking_model_list = ["Qwen/QwQ-32B", "deepseek-ai/DeepSeek-R1", "Pro/deepseek-ai/DeepSeek-R1"]
                need_think_tag_convert = False
                for thinking_model in thinking_model_list:
                    if model == thinking_model:
                        need_think_tag_convert = True
                        break

                # Check if the model needs think tag convert
                if need_think_tag_convert:
                    choices = json_data.get('choices', [])
                    if choices and len(choices) > 0:
                        delta = choices[0].get('delta', {})
                        # Process delta through the state machine
                        modified_delta = think_tag_machine.process_delta(delta)
                        # Update the JSON data with modified delta
                        choices[0]['delta'] = modified_delta
                        json_data['choices'] = choices
                        data = json.dumps(json_data, ensure_ascii=False)
            except Exception as e:
                logging.error(f"Error processing stream data: {e}")
            yield ('data: ' + data + '\n\n')

    return Response(stream_generate(), mimetype='text/event-stream')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=9000)
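To see what the think-tag conversion produces end to end, here is a small standalone sketch. It assumes the second file is saved as flask_proxy_for_continue.py (a hypothetical filename) and feeds ThinkTagStateMachine a made-up sequence of streaming deltas shaped like the ones handled above:

from flask_proxy_for_continue import ThinkTagStateMachine  # assumed filename

# Hypothetical deltas as a reasoning model might stream them: the assistant
# role arrives first, then reasoning_content chunks, then normal content.
deltas = [
    {"role": "assistant"},
    {"reasoning_content": "The user asks for 2 + 2. "},
    {"reasoning_content": "That is 4."},
    {"content": "2 + 2 = 4."},
    {"content": " Anything else?"},
]

machine = ThinkTagStateMachine()
converted = "".join(machine.process_delta(d).get("content") or "" for d in deltas)
print(converted)
# Besides the state machine's own debug prints, this outputs:
# <think>The user asks for 2 + 2. That is 4.</think>2 + 2 = 4. Anything else?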