# Source: GitHub gist by PetterS, created March 8, 2018
# (gist id f684095a09fd1d8164a4d8b28ce3932d).
import asyncio
import functools
import inspect
import os
import unittest
from unittest import mock
class AsyncTestCaseMeta(type(unittest.TestCase)):
def __new__(mcls, name, bases, ns):
for attrname, attr in ns.items():
if (attrname.startswith('test_') and
inspect.iscoroutinefunction(attr)):
ns[attrname] = mcls._sync_wrap(attr)
return super().__new__(mcls, name, bases, ns)
@staticmethod
def _sync_wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return asyncio.get_event_loop().run_until_complete(func(*args, **kwargs))
return wrapper
class AsyncTestCase(unittest.TestCase, metaclass=AsyncTestCaseMeta):
    """Base class for test cases that may define ``async def test_*`` methods.

    Subclasses are built by ``AsyncTestCaseMeta``, which swaps coroutine test
    methods for synchronous wrappers that the ``unittest`` runner can call.
    """
class MyTest(AsyncTestCase):
    """Example cases: a plain test, an async test, and a patched async test."""

    def test_sync(self):
        """A plain synchronous test, left untouched by the metaclass."""

    async def test_async(self):
        """A bare ``async def`` test that only yields to the event loop once."""
        await asyncio.sleep(delay=0.0)

    @mock.patch(target="os.path", new={})
    async def test_async_with_mock(self):
        """An async test decorated with ``mock.patch``, which swaps ``os.path`` for an empty dict."""
        await asyncio.sleep(delay=0.0)
if __name__ == "__main__":
    # Hand control to unittest's command-line runner, which discovers and
    # runs the TestCase classes defined in this module.
    unittest.main(module="__main__")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment