|
#!/usr/bin/env python |
|
# -*- coding: utf-8 -*- |
|
""" Gevent pool/group alike: make asyncio easier to use |
|
|
|
Usage:: |
|
|
|
>>> @asyncio.coroutine |
|
>>> def f(url): |
|
... r = yield from aiohttp.request('get', url) |
|
... content = yield from r.read() |
|
... print('{}: {}'.format(url, content[:80])) |
|
|
|
>>> g = Group() |
|
>>> g.async(f('http://www.baidu.com')) |
|
>>> g.async(f('http://www.sina.com.cn')) |
|
>>> g.join() |
|
|
|
>>> # limit the concurrent coroutines to 3 |
|
>>> p = Pool(3) |
|
>>> for _ in range(10): |
|
... p.async(f('http://www.baidu.com')) |
|
>>> p.join() |
|
""" |
|
import asyncio |
|
|
|
|
|
class Group(object): |
|
|
|
def __init__(self, loop=None): |
|
self.loop = loop or asyncio.get_event_loop() |
|
self._prepare() |
|
|
|
def _prepare(self): |
|
self.counter = 0 |
|
self.task_waiter = asyncio.futures.Future(loop=self.loop) |
|
|
|
def spawn(self, coro_or_future): |
|
self.counter += 1 |
|
task = asyncio.async(coro_or_future) |
|
task.add_done_callback(self._on_completion) |
|
return task |
|
|
|
async = spawn |
|
|
|
def _on_completion(self, f): |
|
self.counter -= 1 |
|
f.remove_done_callback(self._on_completion) |
|
if self.counter <= 0: |
|
if not self.task_waiter.done(): |
|
self.task_waiter.set_result(None) |
|
|
|
def join(self): |
|
self.loop.run_until_complete(self.task_waiter) |
|
self._prepare() |
|
|
|
|
|
class Pool(Group): |
|
|
|
def __init__(self, pool_size, loop=None): |
|
self.sem = asyncio.Semaphore(pool_size, loop=loop) |
|
super(Pool, self).__init__(loop) |
|
|
|
def spawn(self, coro): |
|
assert asyncio.iscoroutine(coro), 'pool only accepts coroutine' |
|
|
|
@asyncio.coroutine |
|
def _limit_coro(): |
|
with (yield from self.sem): |
|
yield from coro |
|
|
|
self.counter += 1 |
|
task = asyncio.async(_limit_coro()) |
|
task.add_done_callback(self._on_completion) |
|
return task |
|
|
|
async = spawn |
|
|
|
|
|
def test_group():
    """Spawn nine 0.5 s sleeps in a Group; they should all finish in ~0.5 s."""
    import time

    # ``async def`` replaces @asyncio.coroutine, which was removed in 3.11.
    async def f(i):
        t0 = time.monotonic()
        await asyncio.sleep(0.5)
        t = time.monotonic() - t0
        print('finish {}, seconds={:4.2f}'.format(i, t))

    print('testing group')
    # Monotonic clock: elapsed time must not jump with wall-clock changes.
    t0 = time.monotonic()
    g = Group()
    for i in range(9):
        g.spawn(f(i))
    g.join()
    # Measure once so the printed and the asserted value agree.
    elapsed = time.monotonic() - t0
    print('total time: {:4.2f}'.format(elapsed))
    assert 0.5 * 0.9 < elapsed < 0.5 * 1.1
|
|
|
|
|
def test_pool():
    """Spawn nine 0.5 s sleeps in a Pool of 3; they run in three ~0.5 s waves."""
    import time

    # ``async def`` replaces @asyncio.coroutine, which was removed in 3.11.
    async def f(i):
        t0 = time.monotonic()
        await asyncio.sleep(0.5)
        t = time.monotonic() - t0
        print('finish {}, seconds={:4.2f}'.format(i, t))

    print('testing pool')
    # Monotonic clock: elapsed time must not jump with wall-clock changes.
    t0 = time.monotonic()
    p = Pool(3)
    for i in range(9):
        p.spawn(f(i))
    p.join()
    # Measure once so the printed and the asserted value agree.
    elapsed = time.monotonic() - t0
    print('total time: {:4.2f}'.format(elapsed))
    assert 0.5 * 3 * 0.9 < elapsed < 0.5 * 3 * 1.1
|
|
|
if __name__ == '__main__':
    # Running the module directly executes both timing smoke tests, in order.
    for _smoke_test in (test_group, test_pool):
        _smoke_test()