Created May 17, 2025 16:42
-
-
Save whiler/2bfa310a9bc2904e696322f73c6e2556 to your computer and use it in GitHub Desktop.
HighFrequencyBuffer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

from buf import HighFrequencyBuffer


def test_throughput():
    """Single-threaded micro-benchmark of put() and a single get_all() drain.

    Writes ``test_size`` integers one at a time, then drains them all with one
    get_all() call, printing the elapsed time of each phase and checking that
    the drained data is exactly what was written.
    """
    buffer = HighFrequencyBuffer()
    test_size = 100_000

    # Write performance
    start = time.perf_counter()
    for i in range(test_size):
        buffer.put(i)
    write_time = time.perf_counter() - start

    # Read performance
    start = time.perf_counter()
    result = buffer.get_all()
    read_time = time.perf_counter() - start

    print(f"单线程写入 {test_size} 次耗时: {write_time:.4f}s")
    print(f"单次读取 {test_size} 元素耗时: {read_time:.6f}s")
    # A length check alone would accept reordered, duplicated or corrupted
    # items; the buffer must hand back exactly what was written, oldest first.
    assert result == list(range(test_size))


if __name__ == "__main__":
    test_throughput()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque
from threading import Lock


class HighFrequencyBuffer:
    """Thread-safe, unbounded FIFO buffer drained in batches.

    Intended for producer/consumer setups where many threads call put() and
    one or more consumers periodically take everything at once with get_all().

    Design:
      - put() appends to a deque while holding a short lock.
      - get_all() detaches the pending deque and installs a fresh empty one
        while holding the lock (O(1)). It then copies the detached batch to a
        list after releasing the lock, so the O(n) copy never blocks producers.
      - Each detached batch is referenced only by the get_all() call that took
        it, so concurrent consumers can never clear or refill each other's
        batch.
    """

    def __init__(self):
        # Deque that put() appends to; get_all() replaces it wholesale.
        self._write_buf = deque()
        # Guards every access to, and rebinding of, self._write_buf.
        self._lock = Lock()

    def put(self, item):
        """Append ``item`` to the buffer (thread-safe).

        May block briefly while another thread holds the lock. Every critical
        section is O(1): an append, or a buffer swap in get_all().
        """
        with self._lock:
            self._write_buf.append(item)

    def get_all(self):
        """Remove and return every buffered item, oldest first (thread-safe).

        This is a destructive read: items returned here are never returned by
        a later call.

        Returns:
            list: a new list owned by the caller, empty if nothing was
            buffered. Mutating it does not affect the buffer.
        """
        with self._lock:
            # Fast path: nothing buffered.
            if not self._write_buf:
                return []
            # Detach the filled deque and give producers a brand-new one.
            # Do not recycle a second deque here: a later swap would clear the
            # recycled deque while an earlier consumer might still be copying
            # it outside the lock, losing or duplicating items.
            batch = self._write_buf
            self._write_buf = deque()
        # ``batch`` is now reachable only from this call, so copying it
        # outside the lock is safe and keeps the critical section short.
        return list(batch)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from threading import Thread

from buf import HighFrequencyBuffer


def test_concurrent_performance():
    """Benchmark put() throughput with several producers and one consumer.

    ``producer_count`` threads each put ``test_size`` integers, while a single
    consumer drains the buffer with get_all() until it has received every item.
    Prints the overall items-per-second rate.

    Raises:
        AssertionError: if the consumer does not receive exactly
            ``total_items`` items before the deadline (lost or duplicated data).
    """
    buffer = HighFrequencyBuffer()
    producer_count = 4
    test_size = 500_000
    total_items = producer_count * test_size
    # Upper bound on how long the consumer keeps polling. Without it, a single
    # lost item would leave the consumer spinning forever and hang the run.
    timeout_s = 120.0
    collected = 0

    # Producer thread
    def producer():
        for i in range(test_size):
            buffer.put(i)

    # Consumer thread: drain until every item has arrived or the deadline passes.
    def consumer():
        nonlocal collected
        deadline = time.monotonic() + timeout_s
        while collected < total_items and time.monotonic() < deadline:
            collected += len(buffer.get_all())

    # Start the threads
    producers = [Thread(target=producer) for _ in range(producer_count)]
    consumer_thread = Thread(target=consumer)

    start = time.perf_counter()
    for t in producers:
        t.start()
    consumer_thread.start()
    for t in producers:
        t.join()
    consumer_thread.join()
    duration = time.perf_counter() - start

    # An exact count catches both loss (deadline hit) and duplication (overshoot).
    assert collected == total_items, (
        f"consumer received {collected} of {total_items} items"
    )
    print(f"并发吞吐量: {total_items/duration:.2f} ops/s")


if __name__ == "__main__":
    test_concurrent_performance()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
import unittest
from threading import Thread

from buf import HighFrequencyBuffer


class TestHighFrequencyBuffer(unittest.TestCase):
    """Unit tests for HighFrequencyBuffer's put()/get_all() contract."""

    def setUp(self):
        self.buffer = HighFrequencyBuffer()

    # Basic functionality
    def test_basic_operations(self):
        # Empty buffer
        self.assertEqual(self.buffer.get_all(), [])
        # Single element
        self.buffer.put(1)
        self.assertEqual(self.buffer.get_all(), [1])
        self.assertEqual(self.buffer.get_all(), [])  # buffer was drained
        # Multiple elements come back in insertion order
        for i in range(5):
            self.buffer.put(i)
        self.assertEqual(self.buffer.get_all(), [0, 1, 2, 3, 4])

    # Repeated fill/drain cycles
    def test_buffer_swapping(self):
        # First fill
        self.buffer.put("A")
        result1 = self.buffer.get_all()
        self.assertEqual(result1, ["A"])
        # Second fill: a drained buffer must keep accepting items
        self.buffer.put("B")
        result2 = self.buffer.get_all()
        self.assertEqual(result2, ["B"])

    # Thread safety
    def test_concurrent_access(self):
        item_count = 1000
        # Seconds before the consumer gives up. Without a bound, lost items
        # would make this test hang instead of fail.
        timeout_s = 10.0
        received = []

        def producer():
            for i in range(item_count):
                self.buffer.put(i)

        def consumer():
            # No start-up sleep is needed: the consumer simply keeps polling
            # until it has seen every item (or the deadline passes).
            deadline = time.monotonic() + timeout_s
            while len(received) < item_count and time.monotonic() < deadline:
                received.extend(self.buffer.get_all())

        prod_thread = Thread(target=producer)
        cons_thread = Thread(target=consumer)
        prod_thread.start()
        cons_thread.start()
        prod_thread.join()
        cons_thread.join()

        # Data integrity: with one producer and one consumer, the batches
        # arrive in FIFO order, so their concatenation is exactly 0..n-1.
        self.assertEqual(received, list(range(item_count)))
        # Nothing may be left behind in the buffer.
        self.assertEqual(self.buffer.get_all(), [])


if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment